extension.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # TODO: define the extension functions
  15. from paddle import _C_ops, tensor
  16. from paddle.utils import deprecated
  17. from ...base.data_feeder import check_type, check_variable_and_dtype
  18. from ...base.layer_helper import LayerHelper
  19. from ...common_ops_import import Variable
  20. from ...framework import (
  21. convert_np_dtype_to_dtype_,
  22. core,
  23. in_dynamic_or_pir_mode,
  24. )
  25. __all__ = []
  26. @deprecated(
  27. since="2.5.2",
  28. update_to="paddle.diag_embed",
  29. level=1,
  30. reason="diag_embed in paddle.nn.functional will be removed in future",
  31. )
  32. def diag_embed(input, offset=0, dim1=-2, dim2=-1):
  33. return tensor.diag_embed(input, offset, dim1, dim2)
  34. def sequence_mask(x, maxlen=None, dtype='int64', name=None):
  35. r"""
  36. **SequenceMask Layer**
  37. This layer outputs a mask according to the input :code:`x` and
  38. :code:`maxlen` with data type of :code:`dtype`.
  39. Supposing :code:`x` is a Tensor with shape [d_1, d_2, ..., d_n], the
  40. :code:`y` is a mask with shape [d_1, d_2, ..., d_n, maxlen], where:
  41. .. math::
  42. y(i_1, i_2,..., i_n, j) = (j < x(i_1, i_2,..., i_n))
  43. .. code-block:: text
  44. Case:
  45. Consider input:
  46. x = [3, 1, 1, 0] max_len = 4
  47. then we get out:
  48. mask = [[1, 1, 1, 0],
  49. [1, 0, 0, 0],
  50. [1, 0, 0, 0],
  51. [0, 0, 0, 0]]
  52. Args:
  53. x (Variable): Input tensor of sequence_mask layer, \
  54. whose elements are integers less than :code:`maxlen`. \
  55. Tensor or LodTensor with shape [d_1, d_2, ..., d_n].
  56. maxlen (int, optional): Maximum length of the sequence. If :code:`maxlen` \
  57. is None, it would be replace with :math:`max(x)`.
  58. dtype (np.dtype|paddle.dtype|str, optional): Data type of the output, \
  59. ``int64`` by default.
  60. name(str, optional): For detailed information, please refer \
  61. to :ref:`api_guide_Name`. Usually name is no need to set and \
  62. None by default.
  63. Returns:
  64. Tensor, The output sequence mask. Tensor with shape [d_1, d_2, ..., d_n, maxlen] \
  65. and data type of :code:`dtype`. The data type should be bool, float32, float64, int8, \
  66. int32 or int64.
  67. Examples:
  68. .. code-block:: python
  69. >>> import paddle
  70. >>> lengths = paddle.to_tensor([10, 9, 8])
  71. >>> mask = paddle.nn.functional.sequence_mask(lengths)
  72. >>> print(mask)
  73. Tensor(shape=[3, 10], dtype=int64, place=Place(cpu), stop_gradient=True,
  74. [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
  75. [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
  76. [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
  77. """
  78. if in_dynamic_or_pir_mode():
  79. if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
  80. dtype = convert_np_dtype_to_dtype_(dtype)
  81. if maxlen is None:
  82. maxlen = -1
  83. out = _C_ops.sequence_mask(x, maxlen, dtype)
  84. out.stop_gradient = True
  85. return out
  86. helper = LayerHelper('sequence_mask', **locals())
  87. out = helper.create_variable_for_type_inference(dtype=dtype)
  88. inputs = {'X': [x]}
  89. attrs = {'out_dtype': out.dtype}
  90. if maxlen is not None:
  91. if isinstance(maxlen, Variable):
  92. inputs['MaxLenTensor'] = maxlen
  93. else:
  94. attrs['maxlen'] = maxlen
  95. helper.append_op(
  96. type='sequence_mask', inputs=inputs, outputs={'Y': out}, attrs=attrs
  97. )
  98. out.stop_gradient = True
  99. return out
  100. def gather_tree(ids, parents):
  101. r"""
  102. To be used after beam search. After beam search, we get selected ids at
  103. each time step and the corresponding parents in the search tree. Both ids
  104. and parents have the layout :attr:`[max_time, batch_size, beam_size]`. Then
  105. :attr:`gather_tree` is used to backtrace from the last time step and
  106. generate the full sequences by collecting selected ids.
  107. Here is an example:
  108. .. code-block:: text
  109. Given:
  110. ids = [[[2 2]
  111. [6 1]]
  112. [[3 9]
  113. [6 1]]
  114. [[0 1]
  115. [9 0]]]
  116. parents = [[[0 0]
  117. [1 1]]
  118. [[1 0]
  119. [1 0]]
  120. [[0 0]
  121. [0 1]]]
  122. Then:
  123. gather_tree(ids, parents)
  124. = [[[2 2]
  125. [1 6]]
  126. [[3 3]
  127. [6 1]]
  128. [[0 1]
  129. [9 0]]]
  130. Args:
  131. ids(Tensor): A Tensor with shape :attr:`[length, batch_size, beam_size]`
  132. and data type :attr:`int32` or :attr:`int64`. It contains the selected
  133. ids of all time steps.
  134. parents(Tensor): A Tensor with the same shape and data type as :attr:`ids`,
  135. It contains the parents corresponding to selected ids when searching
  136. among beams.
  137. Returns:
  138. A Tensor with the same shape and data type as :attr:`ids`. \
  139. It contains the full sequences. The sequences are collected from \
  140. :attr:`ids` by backtracing according to :attr:`parents`.
  141. Examples:
  142. .. code-block:: python
  143. >>> import paddle
  144. >>> ids = paddle.to_tensor([[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]])
  145. >>> parents = paddle.to_tensor([[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]])
  146. >>> final_sequences = paddle.nn.functional.gather_tree(ids, parents)
  147. >>> [[[2, 2], [1, 6]], [[3, 3], [6, 1]], [[0, 1], [9, 0]]]
  148. >>> final_sequences = paddle.nn.functional.gather_tree(ids, parents)
  149. >>> print(final_sequences)
  150. Tensor(shape=[3, 2, 2], dtype=int64, place=Place(cpu), stop_gradient=True,
  151. [[[2, 2],
  152. [1, 6]],
  153. [[3, 3],
  154. [6, 1]],
  155. [[0, 1],
  156. [9, 0]]])
  157. """
  158. if ids.ndim != 3:
  159. raise ValueError(
  160. "The input ids must be a 3D tensor with shape [length, batch_size, beam_size]"
  161. )
  162. if ids.ndim != parents.ndim:
  163. raise ValueError("The ids's shape must be the same as parents' shape. ")
  164. if in_dynamic_or_pir_mode():
  165. return _C_ops.gather_tree(ids, parents)
  166. else:
  167. helper = LayerHelper('gather_tree', **locals())
  168. check_variable_and_dtype(ids, 'ids', ['int32', 'int64'], 'gather_tree')
  169. check_variable_and_dtype(
  170. parents, 'parents', ['int32', 'int64'], 'gather_tree'
  171. )
  172. out = helper.create_variable_for_type_inference(dtype=ids.dtype)
  173. helper.append_op(
  174. type="gather_tree",
  175. inputs={"Ids": ids, "Parents": parents},
  176. outputs={"Out": out},
  177. )
  178. return out
  179. def temporal_shift(x, seg_num, shift_ratio=0.25, name=None, data_format="NCHW"):
  180. """
  181. **Temporal Shift Operator**
  182. Calculate the temporal shifting features for Input(X).
  183. Input(X) should be in shape of [N*T, C, H, W] or [N*T, H, W, C], while
  184. N is the batch size, T is the temporal segment number specified by
  185. :attr:`seg_num`, C is the channel number, H and W is the height and
  186. width of features.
  187. Temporal Shifting is calculated as follows when data format is NCHW:
  188. Step 1: Reshape Input(X) to [N, T, C, H, W].
  189. Step 2: Pad 0 to reshaping result in the 2nd(T) dimension with
  190. padding width as 1 on each side, padding result will be in shape
  191. of [N, T+2, C, H, W].
  192. Step 3: Assume :attr:`shift_ratio` is :math:`1/4`, slice padding
  193. result as follows:
  194. $$
  195. slice1 = x[:, :T, :C/4, :, :]
  196. $$
  197. $$
  198. slice2 = x[:, 2:T+2, C/4:C/2, :, :]
  199. $$
  200. $$
  201. slice3 = x[:, 1:T+1, C/2:, :, :]
  202. $$
  203. Step 4: Concatenate three slices along the 3rd(C) dimension and
  204. reshape result to [N*T, C, H, W].
  205. For details of temporal shifting, please refer to paper:
  206. `Temporal Shift Module <http://arxiv.org/abs/1811.08383>`_ .
  207. Args:
  208. x(Tensor): ${x_comment}
  209. seg_num(int): ${seg_num_comment}
  210. shift_ratio(float): ${shift_ratio_comment}
  211. name(str, optional): For detailed information, please refer
  212. to :ref:`api_guide_Name`. Usually name is no need to set and
  213. None by default.
  214. data_format(str, optional): Data format that specifies the layout of input.
  215. It can be "NCHW" or "NHWC". Default: "NCHW".
  216. Returns:
  217. out(Tensor): The temporal shifting result is a tensor with the
  218. same shape and same data type as the input.
  219. Examples:
  220. .. code-block:: python
  221. >>> import paddle
  222. >>> import paddle.nn.functional as F
  223. >>> input = paddle.randn([6, 4, 2, 2])
  224. >>> out = F.temporal_shift(x=input, seg_num=2, shift_ratio=0.2)
  225. """
  226. if data_format not in ["NCHW", "NHWC"]:
  227. raise ValueError(
  228. "Attr(data_format) should be 'NCHW' or 'NHWC'. "
  229. f"Received Attr(data_format): {data_format}."
  230. )
  231. if in_dynamic_or_pir_mode():
  232. return _C_ops.temporal_shift(x, seg_num, shift_ratio, data_format)
  233. else:
  234. helper = LayerHelper("temporal_shift", **locals())
  235. check_variable_and_dtype(
  236. x,
  237. 'x',
  238. ['float16', 'uint16', 'float32', 'float64'],
  239. 'temporal_shift',
  240. )
  241. check_type(seg_num, 'seg_num', int, 'temporal_shift')
  242. check_type(shift_ratio, 'shift_ratio', float, 'temporal_shift')
  243. out = helper.create_variable_for_type_inference(dtype=x.dtype)
  244. if not isinstance(seg_num, int):
  245. raise TypeError("seg_num must be int type.")
  246. helper.append_op(
  247. type="temporal_shift",
  248. inputs={"X": x},
  249. outputs={"Out": out},
  250. attrs={
  251. "seg_num": seg_num,
  252. "shift_ratio": shift_ratio,
  253. "data_format": data_format,
  254. },
  255. )
  256. return out