regularizer.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddle import _C_ops, pir
  15. from paddle.base import framework
  16. from paddle.base.framework import in_dynamic_or_pir_mode
  17. __all__ = ['L1Decay', 'L2Decay']
  18. class WeightDecayRegularizer:
  19. """Base class for weight decay regularizers
  20. Defines the common interface of weight-decay regularizers.
  21. Weight-decay regularizers are added only during the backward
  22. pass for faster regularization. They add operations to the network
  23. that correspond to gradient of the regularization function.
  24. Users should not use this class directly, but need to use one
  25. of its implementations
  26. """
  27. def __init__(self):
  28. pass
  29. def __call__(self, param, grad, block):
  30. """Add corresponding weight decay operations to the network"""
  31. raise NotImplementedError()
  32. def __str__(self):
  33. """Debug string"""
  34. raise NotImplementedError()
  35. class L1Decay(WeightDecayRegularizer):
  36. r"""
  37. Implement the L1 Weight Decay Regularization, which encourages the weights to be sparse.
  38. It can be set in :ref:`api_paddle_ParamAttr` or ``optimizer`` (such as :ref:`api_paddle_optimizer_Momentum` ).
  39. When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
  40. ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
  41. higher priority than ``optimizer`` , which means that for a trainable parameter, if regularizer is defined
  42. in its ParamAttr, then the regularizer in Optimizer will be ignored. Otherwise the regularizer
  43. in Optimizer will be used.
  44. In the implementation, the loss function of L1 Weight Decay Regularization is as follows:
  45. .. math::
  46. loss = coeff * reduce\_sum(abs(x))
  47. Args:
  48. coeff(float, optional): regularization coeff. Default:0.0.
  49. Examples:
  50. .. code-block:: python
  51. :name: code-example1
  52. >>> # Example1: set Regularizer in optimizer
  53. >>> import paddle
  54. >>> from paddle.regularizer import L1Decay
  55. >>> linear = paddle.nn.Linear(10, 10)
  56. >>> inp = paddle.rand(shape=[10, 10], dtype="float32")
  57. >>> out = linear(inp)
  58. >>> loss = paddle.mean(out)
  59. >>> beta1 = paddle.to_tensor([0.9], dtype="float32")
  60. >>> beta2 = paddle.to_tensor([0.99], dtype="float32")
  61. >>> momentum = paddle.optimizer.Momentum(
  62. ... learning_rate=0.1,
  63. ... parameters=linear.parameters(),
  64. ... weight_decay=L1Decay(0.0001))
  65. >>> back = out.backward()
  66. >>> momentum.step()
  67. >>> momentum.clear_grad()
  68. .. code-block:: python
  69. :name: code-example2
  70. >>> # Example2: set Regularizer in parameters
  71. >>> # Set L1 regularization in parameters.
  72. >>> # Global regularizer does not take effect on my_conv2d for this case.
  73. >>> from paddle.nn import Conv2D
  74. >>> from paddle import ParamAttr
  75. >>> from paddle.regularizer import L1Decay
  76. >>> my_conv2d = Conv2D(
  77. ... in_channels=10,
  78. ... out_channels=10,
  79. ... kernel_size=1,
  80. ... stride=1,
  81. ... padding=0,
  82. ... weight_attr=ParamAttr(regularizer=L1Decay(coeff=0.01)),
  83. ... bias_attr=False)
  84. """
  85. def __init__(self, coeff=0.0):
  86. assert coeff is not None
  87. super().__init__()
  88. self._coeff = coeff
  89. def __call__(self, param, grad, block):
  90. """Add L1 weight decay ops to network
  91. Adds L1 weight decay ops.
  92. L1WeightDecay = reg_coeff * sign(parameter)
  93. Args:
  94. param: parameter variable for which regularization is applied
  95. block: block in which variable is to be created
  96. Returns:
  97. new variable for weight decay
  98. """
  99. assert isinstance(
  100. param, (framework.Variable, pir.Value, pir.core.ParameterMeta)
  101. )
  102. assert isinstance(block, (framework.Block, pir.Block))
  103. if in_dynamic_or_pir_mode():
  104. sign = _C_ops.sign(param)
  105. return _C_ops.scale(sign, self._coeff, 0.0, True)
  106. else:
  107. sign = block.create_var(
  108. dtype=param.dtype, shape=param.shape, lod_level=param.lod_level
  109. )
  110. decay = block.create_var(
  111. dtype=param.dtype, shape=param.shape, lod_level=param.lod_level
  112. )
  113. # Append sign op
  114. block.append_op(
  115. type='sign', inputs={"X": param}, outputs={"Out": sign}
  116. )
  117. # Append scale op to the output of sign op
  118. block.append_op(
  119. type='scale',
  120. inputs={"X": sign},
  121. outputs={"Out": decay},
  122. attrs={"scale": self._coeff},
  123. )
  124. return decay
  125. def __str__(self):
  126. return "L1Decay, coeff=%f" % self._coeff
  127. class L2Decay(WeightDecayRegularizer):
  128. r"""
  129. Implement the L2 Weight Decay Regularization, which helps to prevent the model over-fitting.
  130. It can be set in :ref:`api_paddle_ParamAttr` or ``optimizer`` (such as :ref:`api_paddle_optimizer_Momentum` ).
  131. When set in ``ParamAttr`` , it only takes effect for trainable parameters in this layer. When set in
  132. ``optimizer`` , it takes effect for all trainable parameters. When set together, ``ParamAttr`` has
  133. higher priority than ``optimizer`` , which means that for a trainable parameter, if regularizer is defined
  134. in its ParamAttr, then the regularizer in Optimizer will be ignored. Otherwise the regularizer
  135. in Optimizer will be used.
  136. In the implementation, the loss function of L2 Weight Decay Regularization is as follows:
  137. .. math::
  138. loss = 0.5 * coeff * reduce\_sum(square(x))
  139. Args:
  140. coeff(float, optional): regularization coeff. Default:0.0
  141. Examples:
  142. .. code-block:: python
  143. :name: code-example1
  144. >>> # Example1: set Regularizer in optimizer
  145. >>> import paddle
  146. >>> from paddle.regularizer import L2Decay
  147. >>> linear = paddle.nn.Linear(10, 10)
  148. >>> inp = paddle.rand(shape=[10, 10], dtype="float32")
  149. >>> out = linear(inp)
  150. >>> loss = paddle.mean(out)
  151. >>> beta1 = paddle.to_tensor([0.9], dtype="float32")
  152. >>> beta2 = paddle.to_tensor([0.99], dtype="float32")
  153. >>> momentum = paddle.optimizer.Momentum(
  154. ... learning_rate=0.1,
  155. ... parameters=linear.parameters(),
  156. ... weight_decay=L2Decay(0.0001))
  157. >>> back = out.backward()
  158. >>> momentum.step()
  159. >>> momentum.clear_grad()
  160. .. code-block:: python
  161. :name: code-example2
  162. >>> # Example2: set Regularizer in parameters
  163. >>> # Set L2 regularization in parameters.
  164. >>> # Global regularizer does not take effect on my_conv2d for this case.
  165. >>> from paddle.nn import Conv2D
  166. >>> from paddle import ParamAttr
  167. >>> from paddle.regularizer import L2Decay
  168. >>> my_conv2d = Conv2D(
  169. ... in_channels=10,
  170. ... out_channels=10,
  171. ... kernel_size=1,
  172. ... stride=1,
  173. ... padding=0,
  174. ... weight_attr=ParamAttr(regularizer=L2Decay(coeff=0.01)),
  175. ... bias_attr=False)
  176. """
  177. def __init__(self, coeff=0.0):
  178. assert coeff is not None
  179. super().__init__()
  180. self._coeff = coeff
  181. def __call__(self, param, grad, block):
  182. """Add L2 weight decay ops to network
  183. Adds L2 weight decay ops.
  184. L2WeightDecay = reg_coeff * parameter
  185. Args:
  186. param: parameter variable for which regularization is applied
  187. block: block in which variable is to be created
  188. Returns:
  189. new variable for weight decay
  190. """
  191. assert isinstance(
  192. param, (framework.Variable, pir.Value, pir.core.ParameterMeta)
  193. )
  194. assert isinstance(block, (framework.Block, pir.Block))
  195. if in_dynamic_or_pir_mode():
  196. return _C_ops.scale(param, self._coeff, 0.0, True)
  197. else:
  198. decay = block.create_var(
  199. dtype=param.dtype, shape=param.shape, lod_level=param.lod_level
  200. )
  201. # Append Op to calculate decay
  202. block.append_op(
  203. type='scale',
  204. inputs={"X": param},
  205. outputs={"Out": decay},
  206. attrs={"scale": self._coeff},
  207. )
  208. return decay
  209. def __str__(self):
  210. return "L2Decay, coeff=%f" % self._coeff