transformer.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import warnings
  4. from typing import Any, Callable, Optional, Union
  5. import torch
  6. import torch.nn.functional as F
  7. from torch import Tensor
  8. from torch.nn.init import xavier_uniform_
  9. from .activation import MultiheadAttention
  10. from .container import ModuleList
  11. from .dropout import Dropout
  12. from .linear import Linear
  13. from .module import Module
  14. from .normalization import LayerNorm
  15. __all__ = [
  16. "Transformer",
  17. "TransformerEncoder",
  18. "TransformerDecoder",
  19. "TransformerEncoderLayer",
  20. "TransformerDecoderLayer",
  21. ]
  22. def _generate_square_subsequent_mask(
  23. sz: int,
  24. device: Optional[torch.device] = None,
  25. dtype: Optional[torch.dtype] = None,
  26. ) -> Tensor:
  27. r"""Generate a square causal mask for the sequence.
  28. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
  29. """
  30. return torch.triu(
  31. torch.full((sz, sz), float("-inf"), dtype=dtype, device=device),
  32. diagonal=1,
  33. )
  34. def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]:
  35. if src.is_nested:
  36. return None
  37. else:
  38. src_size = src.size()
  39. if len(src_size) == 2:
  40. # unbatched: S, E
  41. return src_size[0]
  42. else:
  43. # batched: B, S, E if batch_first else S, B, E
  44. seq_len_pos = 1 if batch_first else 0
  45. return src_size[seq_len_pos]
  46. class Transformer(Module):
  47. r"""A basic transformer layer.
  48. This Transformer layer implements the original Transformer architecture described
  49. in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
  50. intent of this layer is as a reference implementation for foundational understanding
  51. and thus it contains only limited features relative to newer Transformer architectures.
  52. Given the fast pace of innovation in transformer-like architectures, we recommend
  53. exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
  54. to build an efficient transformer layer from building blocks in core or using higher
  55. level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
  56. Args:
  57. d_model: the number of expected features in the encoder/decoder inputs (default=512).
  58. nhead: the number of heads in the multiheadattention models (default=8).
  59. num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
  60. num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
  61. dim_feedforward: the dimension of the feedforward network model (default=2048).
  62. dropout: the dropout value (default=0.1).
  63. activation: the activation function of encoder/decoder intermediate layer, can be a string
  64. ("relu" or "gelu") or a unary callable. Default: relu
  65. custom_encoder: custom encoder (default=None).
  66. custom_decoder: custom decoder (default=None).
  67. layer_norm_eps: the eps value in layer normalization components (default=1e-5).
  68. batch_first: If ``True``, then the input and output tensors are provided
  69. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  70. norm_first: if ``True``, encoder and decoder layers will perform LayerNorms before
  71. other attention and feedforward operations, otherwise after. Default: ``False`` (after).
  72. bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
  73. bias. Default: ``True``.
  74. Examples:
  75. >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
  76. >>> src = torch.rand((10, 32, 512))
  77. >>> tgt = torch.rand((20, 32, 512))
  78. >>> out = transformer_model(src, tgt)
  79. Note: A full example to apply nn.Transformer module for the word language model is available in
  80. https://github.com/pytorch/examples/tree/master/word_language_model
  81. """
  82. def __init__(
  83. self,
  84. d_model: int = 512,
  85. nhead: int = 8,
  86. num_encoder_layers: int = 6,
  87. num_decoder_layers: int = 6,
  88. dim_feedforward: int = 2048,
  89. dropout: float = 0.1,
  90. activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
  91. custom_encoder: Optional[Any] = None,
  92. custom_decoder: Optional[Any] = None,
  93. layer_norm_eps: float = 1e-5,
  94. batch_first: bool = False,
  95. norm_first: bool = False,
  96. bias: bool = True,
  97. device=None,
  98. dtype=None,
  99. ) -> None:
  100. factory_kwargs = {"device": device, "dtype": dtype}
  101. super().__init__()
  102. torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
  103. if custom_encoder is not None:
  104. self.encoder = custom_encoder
  105. else:
  106. encoder_layer = TransformerEncoderLayer(
  107. d_model,
  108. nhead,
  109. dim_feedforward,
  110. dropout,
  111. activation,
  112. layer_norm_eps,
  113. batch_first,
  114. norm_first,
  115. bias,
  116. **factory_kwargs,
  117. )
  118. encoder_norm = LayerNorm(
  119. d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
  120. )
  121. self.encoder = TransformerEncoder(
  122. encoder_layer, num_encoder_layers, encoder_norm
  123. )
  124. if custom_decoder is not None:
  125. self.decoder = custom_decoder
  126. else:
  127. decoder_layer = TransformerDecoderLayer(
  128. d_model,
  129. nhead,
  130. dim_feedforward,
  131. dropout,
  132. activation,
  133. layer_norm_eps,
  134. batch_first,
  135. norm_first,
  136. bias,
  137. **factory_kwargs,
  138. )
  139. decoder_norm = LayerNorm(
  140. d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs
  141. )
  142. self.decoder = TransformerDecoder(
  143. decoder_layer, num_decoder_layers, decoder_norm
  144. )
  145. self._reset_parameters()
  146. self.d_model = d_model
  147. self.nhead = nhead
  148. self.batch_first = batch_first
  149. def forward(
  150. self,
  151. src: Tensor,
  152. tgt: Tensor,
  153. src_mask: Optional[Tensor] = None,
  154. tgt_mask: Optional[Tensor] = None,
  155. memory_mask: Optional[Tensor] = None,
  156. src_key_padding_mask: Optional[Tensor] = None,
  157. tgt_key_padding_mask: Optional[Tensor] = None,
  158. memory_key_padding_mask: Optional[Tensor] = None,
  159. src_is_causal: Optional[bool] = None,
  160. tgt_is_causal: Optional[bool] = None,
  161. memory_is_causal: bool = False,
  162. ) -> Tensor:
  163. r"""Take in and process masked source/target sequences.
  164. .. note::
  165. If a boolean tensor is provided for any of the [src/tgt/memory]_mask arguments, positions with a ``True`` value are
  166. not allowed to participate in the attention,
  167. which is the opposite of the definition for :attr:`attn_mask`
  168. in :func:`torch.nn.functional.scaled_dot_product_attention`.
  169. Args:
  170. src: the sequence to the encoder (required).
  171. tgt: the sequence to the decoder (required).
  172. src_mask: the additive mask for the src sequence (optional).
  173. tgt_mask: the additive mask for the tgt sequence (optional).
  174. memory_mask: the additive mask for the encoder output (optional).
  175. src_key_padding_mask: the Tensor mask for src keys per batch (optional).
  176. tgt_key_padding_mask: the Tensor mask for tgt keys per batch (optional).
  177. memory_key_padding_mask: the Tensor mask for memory keys per batch (optional).
  178. src_is_causal: If specified, applies a causal mask as ``src_mask``.
  179. Default: ``None``; try to detect a causal mask.
  180. Warning:
  181. ``src_is_causal`` provides a hint that ``src_mask`` is
  182. the causal mask. Providing incorrect hints can result in
  183. incorrect execution, including forward and backward
  184. compatibility.
  185. tgt_is_causal: If specified, applies a causal mask as ``tgt_mask``.
  186. Default: ``None``; try to detect a causal mask.
  187. Warning:
  188. ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
  189. the causal mask. Providing incorrect hints can result in
  190. incorrect execution, including forward and backward
  191. compatibility.
  192. memory_is_causal: If specified, applies a causal mask as
  193. ``memory_mask``.
  194. Default: ``False``.
  195. Warning:
  196. ``memory_is_causal`` provides a hint that
  197. ``memory_mask`` is the causal mask. Providing incorrect
  198. hints can result in incorrect execution, including
  199. forward and backward compatibility.
  200. Shape:
  201. - src: :math:`(S, E)` for unbatched input, :math:`(S, N, E)` if `batch_first=False` or
  202. `(N, S, E)` if `batch_first=True`.
  203. - tgt: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
  204. `(N, T, E)` if `batch_first=True`.
  205. - src_mask: :math:`(S, S)` or :math:`(N\cdot\text{num\_heads}, S, S)`.
  206. - tgt_mask: :math:`(T, T)` or :math:`(N\cdot\text{num\_heads}, T, T)`.
  207. - memory_mask: :math:`(T, S)`.
  208. - src_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
  209. - tgt_key_padding_mask: :math:`(T)` for unbatched input otherwise :math:`(N, T)`.
  210. - memory_key_padding_mask: :math:`(S)` for unbatched input otherwise :math:`(N, S)`.
  211. Note: [src/tgt/memory]_mask ensures that position :math:`i` is allowed to attend the unmasked
  212. positions. If a BoolTensor is provided, positions with ``True``
  213. are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
  214. is provided, it will be added to the attention weight.
  215. [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
  216. the attention. If a BoolTensor is provided, the positions with the
  217. value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
  218. - output: :math:`(T, E)` for unbatched input, :math:`(T, N, E)` if `batch_first=False` or
  219. `(N, T, E)` if `batch_first=True`.
  220. Note: Due to the multi-head attention architecture in the transformer model,
  221. the output sequence length of a transformer is same as the input sequence
  222. (i.e. target) length of the decoder.
  223. where :math:`S` is the source sequence length, :math:`T` is the target sequence length, :math:`N` is the
  224. batch size, :math:`E` is the feature number
  225. Examples:
  226. >>> # xdoctest: +SKIP
  227. >>> output = transformer_model(
  228. ... src, tgt, src_mask=src_mask, tgt_mask=tgt_mask
  229. ... )
  230. """
  231. is_batched = src.dim() == 3
  232. if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
  233. raise RuntimeError("the batch number of src and tgt must be equal")
  234. elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
  235. raise RuntimeError("the batch number of src and tgt must be equal")
  236. if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
  237. raise RuntimeError(
  238. "the feature number of src and tgt must be equal to d_model"
  239. )
  240. memory = self.encoder(
  241. src,
  242. mask=src_mask,
  243. src_key_padding_mask=src_key_padding_mask,
  244. is_causal=src_is_causal,
  245. )
  246. output = self.decoder(
  247. tgt,
  248. memory,
  249. tgt_mask=tgt_mask,
  250. memory_mask=memory_mask,
  251. tgt_key_padding_mask=tgt_key_padding_mask,
  252. memory_key_padding_mask=memory_key_padding_mask,
  253. tgt_is_causal=tgt_is_causal,
  254. memory_is_causal=memory_is_causal,
  255. )
  256. return output
  257. @staticmethod
  258. def generate_square_subsequent_mask(
  259. sz: int,
  260. device: Optional[torch.device] = None,
  261. dtype: Optional[torch.dtype] = None,
  262. ) -> Tensor:
  263. r"""Generate a square causal mask for the sequence.
  264. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0).
  265. """
  266. return _generate_square_subsequent_mask(sz, dtype=dtype, device=device)
  267. def _reset_parameters(self) -> None:
  268. r"""Initiate parameters in the transformer model."""
  269. for p in self.parameters():
  270. if p.dim() > 1:
  271. xavier_uniform_(p)
  272. class TransformerEncoder(Module):
  273. r"""TransformerEncoder is a stack of N encoder layers.
  274. This TransformerEncoder layer implements the original architecture described
  275. in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
  276. intent of this layer is as a reference implementation for foundational understanding
  277. and thus it contains only limited features relative to newer Transformer architectures.
  278. Given the fast pace of innovation in transformer-like architectures, we recommend
  279. exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
  280. to build efficient layers from building blocks in core or using higher
  281. level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
  282. .. warning::
  283. All layers in the TransformerEncoder are initialized with the same parameters.
  284. It is recommended to manually initialize the layers after creating the TransformerEncoder instance.
  285. Args:
  286. encoder_layer: an instance of the TransformerEncoderLayer() class (required).
  287. num_layers: the number of sub-encoder-layers in the encoder (required).
  288. norm: the layer normalization component (optional).
  289. enable_nested_tensor: if True, input will automatically convert to nested tensor
  290. (and convert back on output). This will improve the overall performance of
  291. TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
  292. Examples:
  293. >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
  294. >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
  295. >>> src = torch.rand(10, 32, 512)
  296. >>> out = transformer_encoder(src)
  297. """
  298. __constants__ = ["norm"]
  299. def __init__(
  300. self,
  301. encoder_layer: "TransformerEncoderLayer",
  302. num_layers: int,
  303. norm: Optional[Module] = None,
  304. enable_nested_tensor: bool = True,
  305. mask_check: bool = True,
  306. ) -> None:
  307. super().__init__()
  308. torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
  309. self.layers = _get_clones(encoder_layer, num_layers)
  310. self.num_layers = num_layers
  311. self.norm = norm
  312. # this attribute saves the value providedat object construction
  313. self.enable_nested_tensor = enable_nested_tensor
  314. # this attribute controls whether nested tensors are used
  315. self.use_nested_tensor = enable_nested_tensor
  316. self.mask_check = mask_check
  317. enc_layer = "encoder_layer"
  318. why_not_sparsity_fast_path = ""
  319. if not isinstance(encoder_layer, torch.nn.TransformerEncoderLayer):
  320. why_not_sparsity_fast_path = f"{enc_layer} was not TransformerEncoderLayer"
  321. elif encoder_layer.norm_first:
  322. why_not_sparsity_fast_path = f"{enc_layer}.norm_first was True"
  323. elif not encoder_layer.self_attn.batch_first:
  324. why_not_sparsity_fast_path = (
  325. f"{enc_layer}.self_attn.batch_first was not True"
  326. + "(use batch_first for better inference performance)"
  327. )
  328. elif not encoder_layer.self_attn._qkv_same_embed_dim:
  329. why_not_sparsity_fast_path = (
  330. f"{enc_layer}.self_attn._qkv_same_embed_dim was not True"
  331. )
  332. elif encoder_layer.self_attn.in_proj_bias is None:
  333. why_not_sparsity_fast_path = f"{enc_layer}.self_attn was passed bias=False"
  334. elif not encoder_layer.activation_relu_or_gelu:
  335. why_not_sparsity_fast_path = (
  336. f"{enc_layer}.activation_relu_or_gelu was not True"
  337. )
  338. elif not (encoder_layer.norm1.eps == encoder_layer.norm2.eps):
  339. why_not_sparsity_fast_path = (
  340. f"{enc_layer}.norm1.eps was not equal to {enc_layer}.norm2.eps"
  341. )
  342. elif encoder_layer.self_attn.num_heads % 2 == 1:
  343. why_not_sparsity_fast_path = f"{enc_layer}.self_attn.num_heads is odd"
  344. if enable_nested_tensor and why_not_sparsity_fast_path:
  345. warnings.warn(
  346. f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}"
  347. )
  348. self.use_nested_tensor = False
  349. def forward(
  350. self,
  351. src: Tensor,
  352. mask: Optional[Tensor] = None,
  353. src_key_padding_mask: Optional[Tensor] = None,
  354. is_causal: Optional[bool] = None,
  355. ) -> Tensor:
  356. r"""Pass the input through the encoder layers in turn.
  357. Args:
  358. src: the sequence to the encoder (required).
  359. mask: the mask for the src sequence (optional).
  360. src_key_padding_mask: the mask for the src keys per batch (optional).
  361. is_causal: If specified, applies a causal mask as ``mask``.
  362. Default: ``None``; try to detect a causal mask.
  363. Warning:
  364. ``is_causal`` provides a hint that ``mask`` is the
  365. causal mask. Providing incorrect hints can result in
  366. incorrect execution, including forward and backward
  367. compatibility.
  368. Shape:
  369. see the docs in :class:`~torch.nn.Transformer`.
  370. """
  371. src_key_padding_mask = F._canonical_mask(
  372. mask=src_key_padding_mask,
  373. mask_name="src_key_padding_mask",
  374. other_type=F._none_or_dtype(mask),
  375. other_name="mask",
  376. target_type=src.dtype,
  377. )
  378. mask = F._canonical_mask(
  379. mask=mask,
  380. mask_name="mask",
  381. other_type=None,
  382. other_name="",
  383. target_type=src.dtype,
  384. check_other=False,
  385. )
  386. output = src
  387. convert_to_nested = False
  388. first_layer = self.layers[0]
  389. src_key_padding_mask_for_layers = src_key_padding_mask
  390. why_not_sparsity_fast_path = ""
  391. str_first_layer = "self.layers[0]"
  392. batch_first = first_layer.self_attn.batch_first
  393. is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
  394. if not is_fastpath_enabled:
  395. why_not_sparsity_fast_path = (
  396. "torch.backends.mha.get_fastpath_enabled() was not True"
  397. )
  398. elif not hasattr(self, "use_nested_tensor"):
  399. why_not_sparsity_fast_path = "use_nested_tensor attribute not present"
  400. elif not self.use_nested_tensor:
  401. why_not_sparsity_fast_path = (
  402. "self.use_nested_tensor (set in init) was not True"
  403. )
  404. elif first_layer.training:
  405. why_not_sparsity_fast_path = f"{str_first_layer} was in training mode"
  406. elif not src.dim() == 3:
  407. why_not_sparsity_fast_path = (
  408. f"input not batched; expected src.dim() of 3 but got {src.dim()}"
  409. )
  410. elif src_key_padding_mask is None:
  411. why_not_sparsity_fast_path = "src_key_padding_mask was None"
  412. elif (
  413. (not hasattr(self, "mask_check")) or self.mask_check
  414. ) and not torch._nested_tensor_from_mask_left_aligned(
  415. src, src_key_padding_mask.logical_not()
  416. ):
  417. why_not_sparsity_fast_path = "mask_check enabled, and src and src_key_padding_mask was not left aligned"
  418. elif output.is_nested:
  419. why_not_sparsity_fast_path = "NestedTensor input is not supported"
  420. elif mask is not None:
  421. why_not_sparsity_fast_path = (
  422. "src_key_padding_mask and mask were both supplied"
  423. )
  424. elif torch.is_autocast_enabled():
  425. why_not_sparsity_fast_path = "autocast is enabled"
  426. if not why_not_sparsity_fast_path:
  427. tensor_args = (
  428. src,
  429. first_layer.self_attn.in_proj_weight,
  430. first_layer.self_attn.in_proj_bias,
  431. first_layer.self_attn.out_proj.weight,
  432. first_layer.self_attn.out_proj.bias,
  433. first_layer.norm1.weight,
  434. first_layer.norm1.bias,
  435. first_layer.norm2.weight,
  436. first_layer.norm2.bias,
  437. first_layer.linear1.weight,
  438. first_layer.linear1.bias,
  439. first_layer.linear2.weight,
  440. first_layer.linear2.bias,
  441. )
  442. _supported_device_type = [
  443. "cpu",
  444. "cuda",
  445. torch.utils.backend_registration._privateuse1_backend_name,
  446. ]
  447. if torch.overrides.has_torch_function(tensor_args):
  448. why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
  449. elif src.device.type not in _supported_device_type:
  450. why_not_sparsity_fast_path = (
  451. f"src device is neither one of {_supported_device_type}"
  452. )
  453. elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
  454. why_not_sparsity_fast_path = (
  455. "grad is enabled and at least one of query or the "
  456. "input/output projection weights or biases requires_grad"
  457. )
  458. if (not why_not_sparsity_fast_path) and (src_key_padding_mask is not None):
  459. convert_to_nested = True
  460. output = torch._nested_tensor_from_mask(
  461. output, src_key_padding_mask.logical_not(), mask_check=False
  462. )
  463. src_key_padding_mask_for_layers = None
  464. seq_len = _get_seq_len(src, batch_first)
  465. is_causal = _detect_is_causal_mask(mask, is_causal, seq_len)
  466. for mod in self.layers:
  467. output = mod(
  468. output,
  469. src_mask=mask,
  470. is_causal=is_causal,
  471. src_key_padding_mask=src_key_padding_mask_for_layers,
  472. )
  473. if convert_to_nested:
  474. output = output.to_padded_tensor(0.0, src.size())
  475. if self.norm is not None:
  476. output = self.norm(output)
  477. return output
  478. class TransformerDecoder(Module):
  479. r"""TransformerDecoder is a stack of N decoder layers.
  480. This TransformerDecoder layer implements the original architecture described
  481. in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
  482. intent of this layer is as a reference implementation for foundational understanding
  483. and thus it contains only limited features relative to newer Transformer architectures.
  484. Given the fast pace of innovation in transformer-like architectures, we recommend
  485. exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
  486. to build efficient layers from building blocks in core or using higher
  487. level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
  488. .. warning::
  489. All layers in the TransformerDecoder are initialized with the same parameters.
  490. It is recommended to manually initialize the layers after creating the TransformerDecoder instance.
  491. Args:
  492. decoder_layer: an instance of the TransformerDecoderLayer() class (required).
  493. num_layers: the number of sub-decoder-layers in the decoder (required).
  494. norm: the layer normalization component (optional).
  495. Examples:
  496. >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
  497. >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
  498. >>> memory = torch.rand(10, 32, 512)
  499. >>> tgt = torch.rand(20, 32, 512)
  500. >>> out = transformer_decoder(tgt, memory)
  501. """
  502. __constants__ = ["norm"]
  503. def __init__(
  504. self,
  505. decoder_layer: "TransformerDecoderLayer",
  506. num_layers: int,
  507. norm: Optional[Module] = None,
  508. ) -> None:
  509. super().__init__()
  510. torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
  511. self.layers = _get_clones(decoder_layer, num_layers)
  512. self.num_layers = num_layers
  513. self.norm = norm
  514. def forward(
  515. self,
  516. tgt: Tensor,
  517. memory: Tensor,
  518. tgt_mask: Optional[Tensor] = None,
  519. memory_mask: Optional[Tensor] = None,
  520. tgt_key_padding_mask: Optional[Tensor] = None,
  521. memory_key_padding_mask: Optional[Tensor] = None,
  522. tgt_is_causal: Optional[bool] = None,
  523. memory_is_causal: bool = False,
  524. ) -> Tensor:
  525. r"""Pass the inputs (and mask) through the decoder layer in turn.
  526. Args:
  527. tgt: the sequence to the decoder (required).
  528. memory: the sequence from the last layer of the encoder (required).
  529. tgt_mask: the mask for the tgt sequence (optional).
  530. memory_mask: the mask for the memory sequence (optional).
  531. tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  532. memory_key_padding_mask: the mask for the memory keys per batch (optional).
  533. tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
  534. Default: ``None``; try to detect a causal mask.
  535. Warning:
  536. ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
  537. the causal mask. Providing incorrect hints can result in
  538. incorrect execution, including forward and backward
  539. compatibility.
  540. memory_is_causal: If specified, applies a causal mask as
  541. ``memory mask``.
  542. Default: ``False``.
  543. Warning:
  544. ``memory_is_causal`` provides a hint that
  545. ``memory_mask`` is the causal mask. Providing incorrect
  546. hints can result in incorrect execution, including
  547. forward and backward compatibility.
  548. Shape:
  549. see the docs in :class:`~torch.nn.Transformer`.
  550. """
  551. output = tgt
  552. seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
  553. tgt_is_causal = _detect_is_causal_mask(tgt_mask, tgt_is_causal, seq_len)
  554. for mod in self.layers:
  555. output = mod(
  556. output,
  557. memory,
  558. tgt_mask=tgt_mask,
  559. memory_mask=memory_mask,
  560. tgt_key_padding_mask=tgt_key_padding_mask,
  561. memory_key_padding_mask=memory_key_padding_mask,
  562. tgt_is_causal=tgt_is_causal,
  563. memory_is_causal=memory_is_causal,
  564. )
  565. if self.norm is not None:
  566. output = self.norm(output)
  567. return output
  568. class TransformerEncoderLayer(Module):
  569. r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
  570. This TransformerEncoderLayer implements the original architecture described
  571. in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
  572. intent of this layer is as a reference implementation for foundational understanding
  573. and thus it contains only limited features relative to newer Transformer architectures.
  574. Given the fast pace of innovation in transformer-like architectures, we recommend
  575. exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
  576. to build efficient layers from building blocks in core or using higher
  577. level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
  578. TransformerEncoderLayer can handle either traditional torch.tensor inputs,
  579. or Nested Tensor inputs. Derived classes are expected to similarly accept
  580. both input formats. (Not all combinations of inputs are currently
  581. supported by TransformerEncoderLayer while Nested Tensor is in prototype
  582. state.)
  583. If you are implementing a custom layer, you may derive it either from
  584. the Module or TransformerEncoderLayer class. If your custom layer
  585. supports both torch.Tensors and Nested Tensors inputs, make its
  586. implementation a derived class of TransformerEncoderLayer. If your custom
  587. Layer supports only torch.Tensor inputs, derive its implementation from
  588. Module.
  589. Args:
  590. d_model: the number of expected features in the input (required).
  591. nhead: the number of heads in the multiheadattention models (required).
  592. dim_feedforward: the dimension of the feedforward network model (default=2048).
  593. dropout: the dropout value (default=0.1).
  594. activation: the activation function of the intermediate layer, can be a string
  595. ("relu" or "gelu") or a unary callable. Default: relu
  596. layer_norm_eps: the eps value in layer normalization components (default=1e-5).
  597. batch_first: If ``True``, then the input and output tensors are provided
  598. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  599. norm_first: if ``True``, layer norm is done prior to attention and feedforward
  600. operations, respectively. Otherwise it's done after. Default: ``False`` (after).
  601. bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
  602. bias. Default: ``True``.
  603. Examples:
  604. >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
  605. >>> src = torch.rand(10, 32, 512)
  606. >>> out = encoder_layer(src)
  607. Alternatively, when ``batch_first`` is ``True``:
  608. >>> encoder_layer = nn.TransformerEncoderLayer(
  609. ... d_model=512, nhead=8, batch_first=True
  610. ... )
  611. >>> src = torch.rand(32, 10, 512)
  612. >>> out = encoder_layer(src)
  613. Fast path:
  614. forward() will use a special optimized implementation described in
  615. `FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`_ if all of the following
  616. conditions are met:
  617. - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
  618. argument ``requires_grad``
  619. - training is disabled (using ``.eval()``)
  620. - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
        - activation is one of: ``"relu"``, ``"gelu"``, ``torch.nn.functional.relu``, or ``torch.nn.functional.gelu``
  622. - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
  623. - if src is a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_, neither ``src_mask``
  624. nor ``src_key_padding_mask`` is passed
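        - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
          unless the caller has manually modified one without modifying the other)
        - the fast path is globally enabled (``torch.backends.mha.get_fastpath_enabled()`` returns ``True``)
        - ``self_attn`` has an input-projection bias (i.e. the layer was constructed with ``bias=True``)
          and ``self_attn._qkv_same_embed_dim`` is ``True``
        - the number of attention heads is even
        - autocast is disabled
        - no forward hooks or forward pre-hooks are registered on the layer or any of its submodules
        - every tensor argument (``src`` and the layer's weights and biases) lives on a CPU, CUDA, or
          registered PrivateUse1 device, and none of them has a ``__torch_function__`` override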
  627. If the optimized implementation is in use, a
  628. `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be
  629. passed for ``src`` to represent padding more efficiently than using a padding
  630. mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ will be
  631. returned, and an additional speedup proportional to the fraction of the input that
  632. is padding can be expected.
  633. .. _`FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness`:
  634. https://arxiv.org/abs/2205.14135
  635. """
  636. __constants__ = ["norm_first"]
  637. def __init__(
  638. self,
  639. d_model: int,
  640. nhead: int,
  641. dim_feedforward: int = 2048,
  642. dropout: float = 0.1,
  643. activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
  644. layer_norm_eps: float = 1e-5,
  645. batch_first: bool = False,
  646. norm_first: bool = False,
  647. bias: bool = True,
  648. device=None,
  649. dtype=None,
  650. ) -> None:
  651. factory_kwargs = {"device": device, "dtype": dtype}
  652. super().__init__()
  653. self.self_attn = MultiheadAttention(
  654. d_model,
  655. nhead,
  656. dropout=dropout,
  657. bias=bias,
  658. batch_first=batch_first,
  659. **factory_kwargs,
  660. )
  661. # Implementation of Feedforward model
  662. self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
  663. self.dropout = Dropout(dropout)
  664. self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
  665. self.norm_first = norm_first
  666. self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  667. self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  668. self.dropout1 = Dropout(dropout)
  669. self.dropout2 = Dropout(dropout)
  670. # Legacy string support for activation function.
  671. if isinstance(activation, str):
  672. activation = _get_activation_fn(activation)
  673. # We can't test self.activation in forward() in TorchScript,
  674. # so stash some information about it instead.
  675. if activation is F.relu or isinstance(activation, torch.nn.ReLU):
  676. self.activation_relu_or_gelu = 1
  677. elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
  678. self.activation_relu_or_gelu = 2
  679. else:
  680. self.activation_relu_or_gelu = 0
  681. self.activation = activation
  682. def __setstate__(self, state):
  683. super().__setstate__(state)
  684. if not hasattr(self, "activation"):
  685. self.activation = F.relu
  686. def forward(
  687. self,
  688. src: Tensor,
  689. src_mask: Optional[Tensor] = None,
  690. src_key_padding_mask: Optional[Tensor] = None,
  691. is_causal: bool = False,
  692. ) -> Tensor:
  693. r"""Pass the input through the encoder layer.
  694. Args:
  695. src: the sequence to the encoder layer (required).
  696. src_mask: the mask for the src sequence (optional).
  697. src_key_padding_mask: the mask for the src keys per batch (optional).
  698. is_causal: If specified, applies a causal mask as ``src mask``.
  699. Default: ``False``.
  700. Warning:
  701. ``is_causal`` provides a hint that ``src_mask`` is the
  702. causal mask. Providing incorrect hints can result in
  703. incorrect execution, including forward and backward
  704. compatibility.
  705. Shape:
  706. see the docs in :class:`~torch.nn.Transformer`.
  707. """
  708. src_key_padding_mask = F._canonical_mask(
  709. mask=src_key_padding_mask,
  710. mask_name="src_key_padding_mask",
  711. other_type=F._none_or_dtype(src_mask),
  712. other_name="src_mask",
  713. target_type=src.dtype,
  714. )
  715. src_mask = F._canonical_mask(
  716. mask=src_mask,
  717. mask_name="src_mask",
  718. other_type=None,
  719. other_name="",
  720. target_type=src.dtype,
  721. check_other=False,
  722. )
  723. is_fastpath_enabled = torch.backends.mha.get_fastpath_enabled()
  724. why_not_sparsity_fast_path = ""
  725. if not is_fastpath_enabled:
  726. why_not_sparsity_fast_path = (
  727. "torch.backends.mha.get_fastpath_enabled() was not True"
  728. )
  729. elif not src.dim() == 3:
  730. why_not_sparsity_fast_path = (
  731. f"input not batched; expected src.dim() of 3 but got {src.dim()}"
  732. )
  733. elif self.training:
  734. why_not_sparsity_fast_path = "training is enabled"
  735. elif not self.self_attn.batch_first:
  736. why_not_sparsity_fast_path = "self_attn.batch_first was not True"
  737. elif self.self_attn.in_proj_bias is None:
  738. why_not_sparsity_fast_path = "self_attn was passed bias=False"
  739. elif not self.self_attn._qkv_same_embed_dim:
  740. why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
  741. elif not self.activation_relu_or_gelu:
  742. why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
  743. elif not (self.norm1.eps == self.norm2.eps):
  744. why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
  745. elif src.is_nested and (
  746. src_key_padding_mask is not None or src_mask is not None
  747. ):
  748. why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
  749. elif self.self_attn.num_heads % 2 == 1:
  750. why_not_sparsity_fast_path = "num_head is odd"
  751. elif torch.is_autocast_enabled():
  752. why_not_sparsity_fast_path = "autocast is enabled"
  753. elif any(
  754. len(getattr(m, "_forward_hooks", {}))
  755. + len(getattr(m, "_forward_pre_hooks", {}))
  756. for m in self.modules()
  757. ):
  758. why_not_sparsity_fast_path = "forward pre-/hooks are attached to the module"
  759. if not why_not_sparsity_fast_path:
  760. tensor_args = (
  761. src,
  762. self.self_attn.in_proj_weight,
  763. self.self_attn.in_proj_bias,
  764. self.self_attn.out_proj.weight,
  765. self.self_attn.out_proj.bias,
  766. self.norm1.weight,
  767. self.norm1.bias,
  768. self.norm2.weight,
  769. self.norm2.bias,
  770. self.linear1.weight,
  771. self.linear1.bias,
  772. self.linear2.weight,
  773. self.linear2.bias,
  774. )
  775. # We have to use list comprehensions below because TorchScript does not support
  776. # generator expressions.
  777. _supported_device_type = [
  778. "cpu",
  779. "cuda",
  780. torch.utils.backend_registration._privateuse1_backend_name,
  781. ]
  782. if torch.overrides.has_torch_function(tensor_args):
  783. why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
  784. elif not all(
  785. (x.device.type in _supported_device_type) for x in tensor_args
  786. ):
  787. why_not_sparsity_fast_path = (
  788. "some Tensor argument's device is neither one of "
  789. f"{_supported_device_type}"
  790. )
  791. elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
  792. why_not_sparsity_fast_path = (
  793. "grad is enabled and at least one of query or the "
  794. "input/output projection weights or biases requires_grad"
  795. )
  796. if not why_not_sparsity_fast_path:
  797. merged_mask, mask_type = self.self_attn.merge_masks(
  798. src_mask, src_key_padding_mask, src
  799. )
  800. return torch._transformer_encoder_layer_fwd(
  801. src,
  802. self.self_attn.embed_dim,
  803. self.self_attn.num_heads,
  804. self.self_attn.in_proj_weight,
  805. self.self_attn.in_proj_bias,
  806. self.self_attn.out_proj.weight,
  807. self.self_attn.out_proj.bias,
  808. self.activation_relu_or_gelu == 2,
  809. self.norm_first,
  810. self.norm1.eps,
  811. self.norm1.weight,
  812. self.norm1.bias,
  813. self.norm2.weight,
  814. self.norm2.bias,
  815. self.linear1.weight,
  816. self.linear1.bias,
  817. self.linear2.weight,
  818. self.linear2.bias,
  819. merged_mask,
  820. mask_type,
  821. )
  822. # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
  823. x = src
  824. if self.norm_first:
  825. x = x + self._sa_block(
  826. self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal
  827. )
  828. x = x + self._ff_block(self.norm2(x))
  829. else:
  830. x = self.norm1(
  831. x
  832. + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal)
  833. )
  834. x = self.norm2(x + self._ff_block(x))
  835. return x
  836. # self-attention block
  837. def _sa_block(
  838. self,
  839. x: Tensor,
  840. attn_mask: Optional[Tensor],
  841. key_padding_mask: Optional[Tensor],
  842. is_causal: bool = False,
  843. ) -> Tensor:
  844. x = self.self_attn(
  845. x,
  846. x,
  847. x,
  848. attn_mask=attn_mask,
  849. key_padding_mask=key_padding_mask,
  850. need_weights=False,
  851. is_causal=is_causal,
  852. )[0]
  853. return self.dropout1(x)
  854. # feed forward block
  855. def _ff_block(self, x: Tensor) -> Tensor:
  856. x = self.linear2(self.dropout(self.activation(self.linear1(x))))
  857. return self.dropout2(x)
  858. class TransformerDecoderLayer(Module):
  859. r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
  860. This TransformerDecoderLayer implements the original architecture described
  861. in the `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_ paper. The
  862. intent of this layer is as a reference implementation for foundational understanding
  863. and thus it contains only limited features relative to newer Transformer architectures.
  864. Given the fast pace of innovation in transformer-like architectures, we recommend
  865. exploring this `tutorial <https://pytorch.org/tutorials/intermediate/transformer_building_blocks.html>`_
  866. to build efficient layers from building blocks in core or using higher
  867. level libraries from the `PyTorch Ecosystem <https://landscape.pytorch.org/>`_.
  868. Args:
  869. d_model: the number of expected features in the input (required).
  870. nhead: the number of heads in the multiheadattention models (required).
  871. dim_feedforward: the dimension of the feedforward network model (default=2048).
  872. dropout: the dropout value (default=0.1).
  873. activation: the activation function of the intermediate layer, can be a string
  874. ("relu" or "gelu") or a unary callable. Default: relu
  875. layer_norm_eps: the eps value in layer normalization components (default=1e-5).
  876. batch_first: If ``True``, then the input and output tensors are provided
  877. as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
  878. norm_first: if ``True``, layer norm is done prior to self attention, multihead
  879. attention and feedforward operations, respectively. Otherwise it's done after.
  880. Default: ``False`` (after).
  881. bias: If set to ``False``, ``Linear`` and ``LayerNorm`` layers will not learn an additive
  882. bias. Default: ``True``.
  883. Examples:
  884. >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
  885. >>> memory = torch.rand(10, 32, 512)
  886. >>> tgt = torch.rand(20, 32, 512)
  887. >>> out = decoder_layer(tgt, memory)
  888. Alternatively, when ``batch_first`` is ``True``:
  889. >>> decoder_layer = nn.TransformerDecoderLayer(
  890. ... d_model=512, nhead=8, batch_first=True
  891. ... )
  892. >>> memory = torch.rand(32, 10, 512)
  893. >>> tgt = torch.rand(32, 20, 512)
  894. >>> out = decoder_layer(tgt, memory)
  895. """
  896. __constants__ = ["norm_first"]
  897. def __init__(
  898. self,
  899. d_model: int,
  900. nhead: int,
  901. dim_feedforward: int = 2048,
  902. dropout: float = 0.1,
  903. activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
  904. layer_norm_eps: float = 1e-5,
  905. batch_first: bool = False,
  906. norm_first: bool = False,
  907. bias: bool = True,
  908. device=None,
  909. dtype=None,
  910. ) -> None:
  911. factory_kwargs = {"device": device, "dtype": dtype}
  912. super().__init__()
  913. self.self_attn = MultiheadAttention(
  914. d_model,
  915. nhead,
  916. dropout=dropout,
  917. batch_first=batch_first,
  918. bias=bias,
  919. **factory_kwargs,
  920. )
  921. self.multihead_attn = MultiheadAttention(
  922. d_model,
  923. nhead,
  924. dropout=dropout,
  925. batch_first=batch_first,
  926. bias=bias,
  927. **factory_kwargs,
  928. )
  929. # Implementation of Feedforward model
  930. self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
  931. self.dropout = Dropout(dropout)
  932. self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
  933. self.norm_first = norm_first
  934. self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  935. self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  936. self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
  937. self.dropout1 = Dropout(dropout)
  938. self.dropout2 = Dropout(dropout)
  939. self.dropout3 = Dropout(dropout)
  940. # Legacy string support for activation function.
  941. if isinstance(activation, str):
  942. self.activation = _get_activation_fn(activation)
  943. else:
  944. self.activation = activation
  945. def __setstate__(self, state):
  946. if "activation" not in state:
  947. state["activation"] = F.relu
  948. super().__setstate__(state)
  949. def forward(
  950. self,
  951. tgt: Tensor,
  952. memory: Tensor,
  953. tgt_mask: Optional[Tensor] = None,
  954. memory_mask: Optional[Tensor] = None,
  955. tgt_key_padding_mask: Optional[Tensor] = None,
  956. memory_key_padding_mask: Optional[Tensor] = None,
  957. tgt_is_causal: bool = False,
  958. memory_is_causal: bool = False,
  959. ) -> Tensor:
  960. r"""Pass the inputs (and mask) through the decoder layer.
  961. Args:
  962. tgt: the sequence to the decoder layer (required).
  963. memory: the sequence from the last layer of the encoder (required).
  964. tgt_mask: the mask for the tgt sequence (optional).
  965. memory_mask: the mask for the memory sequence (optional).
  966. tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  967. memory_key_padding_mask: the mask for the memory keys per batch (optional).
  968. tgt_is_causal: If specified, applies a causal mask as ``tgt mask``.
  969. Default: ``False``.
  970. Warning:
  971. ``tgt_is_causal`` provides a hint that ``tgt_mask`` is
  972. the causal mask. Providing incorrect hints can result in
  973. incorrect execution, including forward and backward
  974. compatibility.
  975. memory_is_causal: If specified, applies a causal mask as
  976. ``memory mask``.
  977. Default: ``False``.
  978. Warning:
  979. ``memory_is_causal`` provides a hint that
  980. ``memory_mask`` is the causal mask. Providing incorrect
  981. hints can result in incorrect execution, including
  982. forward and backward compatibility.
  983. Shape:
  984. see the docs in :class:`~torch.nn.Transformer`.
  985. """
  986. # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
  987. x = tgt
  988. if self.norm_first:
  989. x = x + self._sa_block(
  990. self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal
  991. )
  992. x = x + self._mha_block(
  993. self.norm2(x),
  994. memory,
  995. memory_mask,
  996. memory_key_padding_mask,
  997. memory_is_causal,
  998. )
  999. x = x + self._ff_block(self.norm3(x))
  1000. else:
  1001. x = self.norm1(
  1002. x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal)
  1003. )
  1004. x = self.norm2(
  1005. x
  1006. + self._mha_block(
  1007. x, memory, memory_mask, memory_key_padding_mask, memory_is_causal
  1008. )
  1009. )
  1010. x = self.norm3(x + self._ff_block(x))
  1011. return x
  1012. # self-attention block
  1013. def _sa_block(
  1014. self,
  1015. x: Tensor,
  1016. attn_mask: Optional[Tensor],
  1017. key_padding_mask: Optional[Tensor],
  1018. is_causal: bool = False,
  1019. ) -> Tensor:
  1020. x = self.self_attn(
  1021. x,
  1022. x,
  1023. x,
  1024. attn_mask=attn_mask,
  1025. key_padding_mask=key_padding_mask,
  1026. is_causal=is_causal,
  1027. need_weights=False,
  1028. )[0]
  1029. return self.dropout1(x)
  1030. # multihead attention block
  1031. def _mha_block(
  1032. self,
  1033. x: Tensor,
  1034. mem: Tensor,
  1035. attn_mask: Optional[Tensor],
  1036. key_padding_mask: Optional[Tensor],
  1037. is_causal: bool = False,
  1038. ) -> Tensor:
  1039. x = self.multihead_attn(
  1040. x,
  1041. mem,
  1042. mem,
  1043. attn_mask=attn_mask,
  1044. key_padding_mask=key_padding_mask,
  1045. is_causal=is_causal,
  1046. need_weights=False,
  1047. )[0]
  1048. return self.dropout2(x)
  1049. # feed forward block
  1050. def _ff_block(self, x: Tensor) -> Tensor:
  1051. x = self.linear2(self.dropout(self.activation(self.linear1(x))))
  1052. return self.dropout3(x)
  1053. def _get_clones(module, N):
  1054. # FIXME: copy.deepcopy() is not defined on nn.module
  1055. return ModuleList([copy.deepcopy(module) for i in range(N)])
  1056. def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
  1057. if activation == "relu":
  1058. return F.relu
  1059. elif activation == "gelu":
  1060. return F.gelu
  1061. raise RuntimeError(f"activation should be relu/gelu, not {activation}")
  1062. def _detect_is_causal_mask(
  1063. mask: Optional[Tensor],
  1064. is_causal: Optional[bool] = None,
  1065. size: Optional[int] = None,
  1066. ) -> bool:
  1067. """Return whether the given attention mask is causal.
  1068. Warning:
  1069. If ``is_causal`` is not ``None``, its value will be returned as is. If a
  1070. user supplies an incorrect ``is_causal`` hint,
  1071. ``is_causal=False`` when the mask is in fact a causal attention.mask
  1072. may lead to reduced performance relative to what would be achievable
  1073. with ``is_causal=True``;
  1074. ``is_causal=True`` when the mask is in fact not a causal attention.mask
  1075. may lead to incorrect and unpredictable execution - in some scenarios,
  1076. a causal mask may be applied based on the hint, in other execution
  1077. scenarios the specified mask may be used. The choice may not appear
  1078. to be deterministic, in that a number of factors like alignment,
  1079. hardware SKU, etc influence the decision whether to use a mask or
  1080. rely on the hint.
  1081. ``size`` if not None, check whether the mask is a causal mask of the provided size
  1082. Otherwise, checks for any causal mask.
  1083. """
  1084. # Prevent type refinement
  1085. make_causal = is_causal is True
  1086. if is_causal is None and mask is not None:
  1087. sz = size if size is not None else mask.size(-2)
  1088. causal_comparison = _generate_square_subsequent_mask(
  1089. sz, device=mask.device, dtype=mask.dtype
  1090. )
  1091. # Do not use `torch.equal` so we handle batched masks by
  1092. # broadcasting the comparison.
  1093. if mask.size() == causal_comparison.size():
  1094. make_causal = bool((mask == causal_comparison).all())
  1095. else:
  1096. make_causal = False
  1097. return make_causal