initializer.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. from .data_feeder import check_type
  16. __all__ = []
  17. _global_weight_initializer_ = None
  18. _global_bias_initializer_ = None
  19. def _global_weight_initializer():
  20. """
  21. Return the global weight initializer, The user doesn't need to use it.
  22. """
  23. return _global_weight_initializer_
  24. def _global_bias_initializer():
  25. """
  26. Return the global weight initializer, The user doesn't need to use it.
  27. """
  28. return _global_bias_initializer_
  29. def set_global_initializer(weight_init, bias_init=None):
  30. """
  31. This API is used to set up global model parameter initializer in framework.
  32. After this API is invoked, the global initializer will takes effect in subsequent code.
  33. The model parameters include ``weight`` and ``bias`` . In the framework, they correspond
  34. to ``paddle.ParamAttr`` , which is inherited from ``paddle.Tensor`` , and is a persistable Variable.
  35. This API only takes effect for model parameters, not for variables created through apis such as
  36. :ref:`api_paddle_static_create_global_var` , :ref:`api_paddle_Tensor_create_tensor`.
  37. If the initializer is also set up by ``param_attr`` or ``bias_attr`` when creating a network layer,
  38. the global initializer setting here will not take effect because it has a lower priority.
  39. If you want to cancel the global initializer in framework, please set global initializer to ``None`` .
  40. Args:
  41. weight_init (Initializer): set the global initializer for ``weight`` of model parameters.
  42. bias_init (Initializer, optional): set the global initializer for ``bias`` of model parameters.
  43. Default: None.
  44. Returns:
  45. None
  46. Examples:
  47. .. code-block:: python
  48. >>> import paddle
  49. >>> import paddle.nn as nn
  50. >>> nn.initializer.set_global_initializer(nn.initializer.Uniform(), nn.initializer.Constant())
  51. >>> x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
  52. >>> # The weight of conv1 is initialized by Uniform
  53. >>> # The bias of conv1 is initialized by Constant
  54. >>> conv1 = nn.Conv2D(4, 6, (3, 3))
  55. >>> y_var1 = conv1(x_var)
  56. >>> # If set param_attr/bias_attr too, global initializer will not take effect
  57. >>> # The weight of conv2 is initialized by Xavier
  58. >>> # The bias of conv2 is initialized by Normal
  59. >>> conv2 = nn.Conv2D(4, 6, (3, 3),
  60. ... weight_attr=nn.initializer.XavierUniform(),
  61. ... bias_attr=nn.initializer.Normal())
  62. >>> y_var2 = conv2(x_var)
  63. >>> # Cancel the global initializer in framework, it will takes effect in subsequent code
  64. >>> nn.initializer.set_global_initializer(None)
  65. """
  66. check_type(
  67. weight_init,
  68. 'weight_init',
  69. (paddle.nn.initializer.Initializer, type(None)),
  70. 'set_global_initializer',
  71. )
  72. global _global_weight_initializer_
  73. _global_weight_initializer_ = weight_init
  74. check_type(
  75. bias_init,
  76. 'bias_init',
  77. (paddle.nn.initializer.Initializer, type(None)),
  78. 'set_global_initializer',
  79. )
  80. global _global_bias_initializer_
  81. _global_bias_initializer_ = bias_init