dirac.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import paddle
  15. from paddle import _C_ops, in_dynamic_mode, pir
  16. from paddle.utils import unique_name
  17. from ... import base
  18. from ...base import core, framework
  19. from ...base.core import VarDesc
  20. from ...base.data_feeder import check_variable_and_dtype
  21. from ...base.framework import _current_expected_place
  22. from .initializer import Initializer
  23. __all__ = []
  24. class Dirac(Initializer):
  25. r"""Initialize the 3D/4D/5D Tensor with Dirac delta function.
  26. It can reserve the feature of convolution layer input, which means that
  27. as many channels are reserved as possible.
  28. In this initialize method, elements in the middle of convolution kernels will
  29. be set to 1 . The formula can be described as follow.
  30. .. math::
  31. X[d, d, shape[2]//2, shape[3]//2, ...]=1, \ d=0,1...N
  32. where, ``N`` is the minimum value of ``in_channels`` and ``out_channels``
  33. Args:
  34. groups(int, optional): 0-dimension of the Tensor will be divided by groups,
  35. each group has the same value. Default: 1.
  36. name(str, optional): The default value is None. Normally there is no need for user to set this
  37. property. For more information, please refer to :ref:`api_guide_Name`.
  38. Returns:
  39. Dirac initializer instance objects.
  40. Examples:
  41. .. code-block:: python
  42. >>> import paddle
  43. >>> # 1. For kernel_size is uneven number:
  44. >>> attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Dirac())
  45. >>> conv = paddle.nn.Conv1D(3, 2, 3, weight_attr=attr)
  46. >>> print(conv.weight)
  47. Parameter containing:
  48. Tensor(shape=[2, 3, 3], dtype=float32, place=CPUPlace, stop_gradient=False,
  49. [[[0., 1., 0.],
  50. [0., 0., 0.],
  51. [0., 0., 0.]],
  52. [[0., 0., 0.],
  53. [0., 1., 0.],
  54. [0., 0., 0.]]])
  55. >>> input = paddle.rand([8, 3, 10])
  56. >>> output = conv(input)
  57. >>> output == input[:, 0:2, 1:9]
  58. >>> print(output.shape)
  59. [8, 2, 8]
  60. >>> # It means output is almost the same with input, 2 channels are reserved
  61. >>> # 2. For kernel_size is even number:
  62. >>> attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Dirac())
  63. >>> conv = paddle.nn.Conv1D(3, 2, 4, weight_attr=attr)
  64. >>> print(conv.weight)
  65. Parameter containing:
  66. Tensor(shape=[2, 3, 4], dtype=float32, place=CPUPlace, stop_gradient=False,
  67. [[[0., 0., 1., 0.],
  68. [0., 0., 0., 0.],
  69. [0., 0., 0., 0.]],
  70. [[0., 0., 0., 0.],
  71. [0., 0., 1., 0.],
  72. [0., 0., 0., 0.]]])
  73. """
  74. def __init__(self, groups=1, name=None):
  75. assert groups > 0 and isinstance(
  76. groups, int
  77. ), " 'groups' must be a positive integer. "
  78. super().__init__()
  79. self._groups = groups
  80. def __call__(self, var, block=None):
  81. """Initialize the input tensor with dirac initializer.
  82. Args:
  83. var(Tensor): Tensor that needs to be initialized.
  84. block(Block, optional): The block in which initialization ops
  85. should be added. Used in static graph only, default None.
  86. Returns:
  87. The most critical OP(scatter) in this initializer, which contains 7~8 ops in total.
  88. """
  89. assert not (
  90. isinstance(var, framework.EagerParamBase) and var.is_dist()
  91. ), "Currently, dirac initializer not support lazy init for dist param."
  92. block = self._check_block(block)
  93. assert isinstance(var, (framework.Variable, pir.core.ParameterMeta))
  94. assert isinstance(block, (framework.Block, pir.Block))
  95. check_variable_and_dtype(
  96. var, "Out", ['float16', 'bfloat16', 'float32', 'float64'], 'Dirac'
  97. )
  98. assert len(var.shape) in [
  99. 3,
  100. 4,
  101. 5,
  102. ], "Only Tensor with 3/4/5 dimensions can be initialized by Dirac"
  103. assert (
  104. var.shape[0] % self._groups
  105. ) == 0, "Tensor 0-dimension must be divisible by groups"
  106. if framework.in_pir_mode():
  107. if var.dtype != core.DataType.FLOAT32:
  108. out_dtype = core.DataType.FLOAT32
  109. out_var = var
  110. else:
  111. out_dtype = var.dtype
  112. out_var = var
  113. else:
  114. if var.dtype != VarDesc.VarType.FP32:
  115. out_dtype = VarDesc.VarType.FP32
  116. out_var = block.create_var(
  117. name=unique_name.generate(
  118. ".".join(['dirac', var.name, 'tmp'])
  119. ),
  120. shape=var.shape,
  121. dtype=out_dtype,
  122. type=VarDesc.VarType.LOD_TENSOR,
  123. persistable=False,
  124. )
  125. else:
  126. out_dtype = var.dtype
  127. out_var = var
  128. op = None
  129. if framework.in_dygraph_mode():
  130. with base.dygraph.no_grad():
  131. place = _current_expected_place()
  132. _C_ops.full_(
  133. out_var, out_var.shape, str(float(0)), out_dtype, place
  134. )
  135. elif framework.in_pir_mode():
  136. place = _current_expected_place()
  137. out_var = _C_ops.full(out_var.shape, float(0), out_dtype, place)
  138. else:
  139. block.append_op(
  140. type='fill_constant',
  141. inputs={},
  142. outputs={'Out': out_var},
  143. attrs={
  144. 'value': float(0),
  145. 'dtype': out_var.dtype,
  146. 'shape': out_var.shape,
  147. },
  148. stop_gradient=True,
  149. )
  150. origin_shape = var.shape
  151. num_per_group = origin_shape[0] // self._groups
  152. min_shape = min(num_per_group, origin_shape[1])
  153. idx_list = []
  154. value_list = []
  155. strides = []
  156. prod = 1
  157. for dim in reversed(origin_shape):
  158. strides.insert(0, prod)
  159. prod *= dim
  160. for i in range(self._groups):
  161. for j in range(min_shape):
  162. value_list.append(1.0)
  163. offset = 0
  164. for k, stride in enumerate(strides):
  165. if k == 0:
  166. offset += (j + i * num_per_group) * stride
  167. elif k == 1:
  168. offset += j * stride
  169. else:
  170. offset += origin_shape[k] // 2 * stride
  171. idx_list.append(offset)
  172. if framework.in_dygraph_mode():
  173. with base.dygraph.no_grad():
  174. tmp_out = _C_ops.reshape(out_var, [-1])
  175. tmp_out._share_underline_tensor_to(out_var)
  176. elif framework.in_pir_mode():
  177. out_var = _C_ops.reshape(out_var, [-1])
  178. else:
  179. x_shape = block.create_var(
  180. name=unique_name.generate(".".join([out_var.name, "XShape"])),
  181. dtype=out_dtype,
  182. shape=out_var.shape,
  183. type=VarDesc.VarType.LOD_TENSOR,
  184. persistable=False,
  185. stop_gradient=True,
  186. )
  187. block.append_op(
  188. type="reshape2",
  189. inputs={"X": out_var},
  190. attrs={'shape': [-1]},
  191. outputs={"Out": out_var, "XShape": x_shape},
  192. stop_gradient=True,
  193. )
  194. if framework.in_pir_mode():
  195. index_tensor = paddle.zeros(
  196. [len(idx_list)], dtype=core.DataType.INT64
  197. )
  198. index_tensor.stop_gradient = True
  199. else:
  200. index_tensor = block.create_var(
  201. name=unique_name.generate('scatter_index'),
  202. persistable=False,
  203. stop_gradient=True,
  204. )
  205. if framework.in_dygraph_mode():
  206. with base.dygraph.no_grad():
  207. tmp_tensor = framework._create_tensor()
  208. _C_ops.assign_value_(
  209. tmp_tensor,
  210. [len(idx_list)],
  211. VarDesc.VarType.INT64,
  212. idx_list,
  213. _current_expected_place(),
  214. )
  215. tmp_tensor._share_underline_tensor_to(index_tensor)
  216. elif framework.in_pir_mode():
  217. _C_ops.assign_value_(
  218. index_tensor,
  219. [len(idx_list)],
  220. core.DataType.INT64,
  221. idx_list,
  222. _current_expected_place(),
  223. )
  224. else:
  225. block.append_op(
  226. type='assign_value',
  227. outputs={'Out': index_tensor},
  228. attrs={
  229. 'dtype': VarDesc.VarType.INT64,
  230. 'shape': [len(idx_list)],
  231. 'values': idx_list,
  232. },
  233. stop_gradient=True,
  234. )
  235. if framework.in_pir_mode():
  236. value_tensor = paddle.zeros(
  237. [len(value_list)], dtype=core.DataType.FLOAT32
  238. )
  239. value_tensor.stop_gradient = True
  240. else:
  241. value_tensor = block.create_var(
  242. name=unique_name.generate('scatter_value'),
  243. persistable=False,
  244. stop_gradient=True,
  245. )
  246. if framework.in_dygraph_mode():
  247. with base.dygraph.no_grad():
  248. tmp_tensor = framework._create_tensor()
  249. _C_ops.assign_value_(
  250. tmp_tensor,
  251. [len(value_list)],
  252. VarDesc.VarType.FP32,
  253. value_list,
  254. _current_expected_place(),
  255. )
  256. tmp_tensor._share_underline_tensor_to(value_tensor)
  257. elif framework.in_pir_mode():
  258. _C_ops.assign_value_(
  259. value_tensor,
  260. [len(value_list)],
  261. core.DataType.FLOAT32,
  262. value_list,
  263. _current_expected_place(),
  264. )
  265. else:
  266. block.append_op(
  267. type='assign_value',
  268. outputs={'Out': value_tensor},
  269. attrs={
  270. 'dtype': VarDesc.VarType.FP32,
  271. 'shape': [len(value_list)],
  272. 'values': value_list,
  273. },
  274. stop_gradient=True,
  275. )
  276. if framework.in_dygraph_mode():
  277. with base.dygraph.no_grad():
  278. tmp_out = _C_ops.scatter(
  279. out_var, index_tensor, value_tensor, True
  280. )
  281. tmp_out._share_underline_tensor_to(out_var)
  282. tmp_reshape_out = _C_ops.reshape(out_var, origin_shape)
  283. tmp_reshape_out._share_underline_tensor_to(out_var)
  284. if var.dtype != VarDesc.VarType.FP32:
  285. tmp_cast_out = _C_ops.cast(out_var, var.dtype)
  286. tmp_cast_out._share_underline_tensor_to(var)
  287. elif framework.in_pir_mode():
  288. out_var = _C_ops.scatter(out_var, index_tensor, value_tensor, True)
  289. out_var = _C_ops.reshape(out_var, origin_shape)
  290. if var.dtype != core.DataType.FLOAT32:
  291. return _C_ops.cast(out_var, var.dtype)
  292. return out_var
  293. else:
  294. op = block.append_op(
  295. type="scatter",
  296. inputs={
  297. "X": out_var,
  298. "Ids": index_tensor,
  299. "Updates": value_tensor,
  300. },
  301. attrs={'overwrite': True},
  302. outputs={"Out": out_var},
  303. stop_gradient=True,
  304. )
  305. x_shape = block.create_var(
  306. name=unique_name.generate(".".join([out_var.name, "XShape"])),
  307. dtype=out_dtype,
  308. shape=out_var.shape,
  309. type=VarDesc.VarType.LOD_TENSOR,
  310. persistable=False,
  311. stop_gradient=True,
  312. )
  313. block.append_op(
  314. type="reshape2",
  315. inputs={"X": out_var},
  316. attrs={'shape': origin_shape},
  317. outputs={"Out": out_var, "XShape": x_shape},
  318. stop_gradient=True,
  319. )
  320. if var.dtype != VarDesc.VarType.FP32:
  321. block.append_op(
  322. type="cast",
  323. inputs={"X": out_var},
  324. outputs={"Out": var},
  325. attrs={"in_dtype": out_var.dtype, "out_dtype": var.dtype},
  326. stop_gradient=True,
  327. )
  328. if not in_dynamic_mode():
  329. var.op = op
  330. return op