modeling_attn_mask_utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general
  16. `masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now,
  17. and will be removed in the future.
  18. """
  19. from dataclasses import dataclass
  20. from typing import Optional, Union
  21. import torch
  22. from .utils.import_utils import is_torchdynamo_compiling
@dataclass
class AttentionMaskConverter:
    """
    A utility attention mask class that allows one to:
        - Create a causal 4d mask
        - Create a causal 4d mask with a sliding window
        - Convert a 2d attention mask (batch_size, key_value_length) to a 4d attention mask (batch_size, 1,
          query_length, key_value_length) that can be added to the attention scores

    All masks produced here are additive: `0` marks a position that may be attended to and `torch.finfo(dtype).min`
    marks a masked position.

    Deprecated: this class is only kept for backward compatibility; new code should use the `masking_utils.py`
    primitives instead.

    Examples:

    ```python
    >>> import torch
    >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    >>> converter = AttentionMaskConverter(True)
    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
    tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
            [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
    ```

    Parameters:
        is_causal (`bool`):
            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.

        sliding_window (`int`, *optional*):
            Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
    """

    # Whether masks built by this converter are uni-directional (causal).
    is_causal: bool
    # Window size for sliding-window masks, or None for no window (validated to be > 0 in `__init__`).
    sliding_window: Optional[int]

    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        """
        Args:
            is_causal (`bool`): Whether to build uni-directional (causal) masks.
            sliding_window (`int`, *optional*): Strictly positive window size, or `None` to disable windowing.

        Raises:
            ValueError: if `sliding_window` is given and is not strictly positive.
        """
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        # Validate eagerly so an invalid window fails at construction time rather than when a mask is built.
        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
            )

    def to_causal_4d(
        self,
        batch_size: int,
        query_length: int,
        key_value_length: int,
        dtype: torch.dtype,
        device: Union[torch.device, "str"] = "cpu",
    ) -> Optional[torch.Tensor]:
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` with a large negative
        bias on the upper right hand triangle (the causal part) and, if configured, outside the sliding window.

        Args:
            batch_size (`int`): Size of the batch dimension of the result.
            query_length (`int`): Number of query positions.
            key_value_length (`int`): Total number of key/value positions; the first
                `key_value_length - query_length` of them are treated as past (cached) positions.
            dtype (`torch.dtype`): Floating-point dtype of the mask.
            device (`torch.device` or `str`, defaults to `"cpu"`): Device of the mask.

        Returns:
            The additive mask, or `None` when `query_length == 1` and no sliding window is configured (a single
            query may attend to every key, so no mask is needed).

        Raises:
            ValueError: if this converter was created with `is_causal=False`.
        """
        if not self.is_causal:
            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

        # Keys beyond the current queries are treated as past (cached) positions that precede the queries.
        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
        dtype: torch.dtype,
        key_value_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Converts a 2D attention mask to a 4D additive mask of shape `(batch_size, 1, query_length, key_value_length)`
        by adding a large negative bias to not-attended positions. If this converter is causal, the causal (and
        optional sliding-window) mask is merged in as well.

        Args:
            attention_mask_2d (`torch.Tensor`):
                2D mask whose second dimension is the key/value length; entries equal to `1` may be attended to,
                any other value is masked. When a causal mask is built, its second dimension must match
                `key_value_length` for the `masked_fill` merge to broadcast.
            query_length (`int`):
                Number of query positions (third dimension of the result).
            dtype (`torch.dtype`):
                Floating-point dtype of the returned mask.
            key_value_length (`int`, *optional*):
                Total number of key/value positions; required whenever a causal mask has to be built.

        Returns:
            `torch.Tensor` holding `0` for attended positions and `torch.finfo(dtype).min` for masked ones.

        Raises:
            ValueError: if a causal mask is needed but `key_value_length` is `None`.
            NotImplementedError: if `sliding_window` is set on a non-causal converter.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )

            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=attention_mask_2d.device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        # Only reachable with a window when the converter is not causal (a causal one takes the branch above).
        elif self.sliding_window is not None:
            raise NotImplementedError("Sliding window is currently only implemented for causal masking")

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
            attention_mask_2d.device
        )

        # Merge by filling padded positions into the causal mask instead of adding the two masks: summing two
        # `finfo(dtype).min` biases would overflow.
        if causal_4d_mask is not None:
            expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)

        # expanded_attn_mask + causal_4d_mask can cause some overflow
        expanded_4d_mask = expanded_attn_mask

        return expanded_4d_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make the uni-directional (causal) additive mask used for self-attention.

        Query row `i` may attend to the first `past_key_values_length` key columns (the cached positions) and to
        key column `past_key_values_length + j` for every `j <= i`; every other entry is `torch.finfo(dtype).min`.
        With `sliding_window`, keys whose distance to the query position exceeds `sliding_window` are masked too.

        Args:
            input_ids_shape: Two-element shape `(batch_size, query_length)`.
            dtype (`torch.dtype`): Floating-point dtype of the mask.
            device (`torch.device`): Device of the mask.
            past_key_values_length (`int`, defaults to 0): Number of cached key positions preceding the queries.
            sliding_window (`int`, *optional*): Window size, or `None` for no window.

        Returns:
            `torch.Tensor` of shape `(batch_size, 1, query_length, query_length + past_key_values_length)`; the batch
            dimension is produced with `expand`, i.e. a broadcast view rather than a copy.
        """
        bsz, tgt_len = input_ids_shape
        # Start fully masked, then zero the lower triangle (diagonal included) of the query x query block.
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        mask = mask.to(dtype)

        # Cached positions precede every current query, so their columns are fully visible (zeros).
        if past_key_values_length > 0:
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            # Entries on or below this diagonal are more than `sliding_window` positions behind their query.
            diagonal = past_key_values_length - sliding_window - 1

            context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
            # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
            # See https://github.com/pytorch/pytorch/issues/127571
            if is_torchdynamo_compiling():
                mask = mask.clone()
            mask.masked_fill_(context_mask, torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

        Entries equal to `1` become `0` (attend); every other entry becomes `torch.finfo(dtype).min` (masked).
        `tgt_len` defaults to the source length `src_seq_len`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        # 1 -> 0 (keep); anything else -> non-zero, which is then replaced by the most negative finite value.
        inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    @staticmethod
    def _unmask_unattended(
        expanded_mask: torch.FloatTensor,
        min_dtype: float,
    ):
        # fmt: off
        """
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        Details: https://github.com/pytorch/pytorch/issues/110213

        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        `attention_mask` is [bsz, src_seq_len].

        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

        For example, if `expanded_mask` is (e.g. here left-padding case), with `1` meaning "attend"
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1],   <-- modified
           [1, 1, 1],   <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1],   <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```

        In the actual (additive, floating-point) mask a row counts as fully masked when every entry equals
        `min_dtype`; such rows are multiplied by zero, i.e. turned into "attend everywhere", while every other row
        is returned unchanged.

        Raises:
            ValueError: if `expanded_mask` is a boolean tensor.
        """
        # fmt: on
        if expanded_mask.dtype == torch.bool:
            raise ValueError(
                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
            )

        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))

    @staticmethod
    def _ignore_causal_mask_sdpa(
        attention_mask: Optional[torch.Tensor],
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
        sliding_window: Optional[int] = None,
        is_training: bool = False,
    ) -> bool:
        """
        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
        ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.

        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
        `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
        passed).

        Args:
            attention_mask (`torch.Tensor` or `None`):
                The user-provided mask. 4D masks are never ignored.
            inputs_embeds (`torch.Tensor`):
                Only its first two dimensions `(batch_size, query_length)` are read, plus its type for tracing
                detection.
            past_key_values_length (`int`):
                Number of cached positions; `key_value_length = query_length + past_key_values_length`.
            sliding_window (`int`, *optional*):
                If set, the mask can only be ignored while `key_value_length < sliding_window`.
            is_training (`bool`, defaults to `False`):
                Only consulted when `attention_mask is None`: allows ignoring the mask even while tracing.

        Returns:
            `bool`: `True` if the caller may pass `attn_mask=None` to SDPA.
        """

        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
        key_value_length = query_length + past_key_values_length

        is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()

        ignore_causal_mask = False

        if attention_mask is None:
            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
            # shape, thus SDPA's `is_causal` argument is rightfully updated
            # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
            # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
            # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
            # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
            #
            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
            # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
            if (
                (is_training or not is_tracing)
                and (query_length == 1 or key_value_length == query_length)
                and (sliding_window is None or key_value_length < sliding_window)
            ):
                ignore_causal_mask = True
        elif sliding_window is None or key_value_length < sliding_window:
            if len(attention_mask.shape) == 4:
                return False
            # `torch.all` on tensor values is data-dependent control flow, so it is skipped while tracing.
            elif not is_tracing and torch.all(attention_mask == 1):
                if query_length == 1 or key_value_length == query_length:
                    # For query_length == 1, causal attention and bi-directional attention are the same.
                    ignore_causal_mask = True

                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
                # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
                # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
                # Reference: https://github.com/pytorch/pytorch/issues/108108
                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.

        return ignore_causal_mask
  259. def _prepare_4d_causal_attention_mask(
  260. attention_mask: Optional[torch.Tensor],
  261. input_shape: Union[torch.Size, tuple, list],
  262. inputs_embeds: torch.Tensor,
  263. past_key_values_length: int,
  264. sliding_window: Optional[int] = None,
  265. ):
  266. """
  267. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  268. `(batch_size, key_value_length)`
  269. Args:
  270. attention_mask (`torch.Tensor` or `None`):
  271. A 2D attention mask of shape `(batch_size, key_value_length)`
  272. input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
  273. The input shape should be a tuple that defines `(batch_size, query_length)`.
  274. inputs_embeds (`torch.Tensor`):
  275. The embedded inputs as a torch Tensor.
  276. past_key_values_length (`int`):
  277. The length of the key value cache.
  278. sliding_window (`int`, *optional*):
  279. If the model uses windowed attention, a sliding window should be passed.
  280. """
  281. attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
  282. key_value_length = input_shape[-1] + past_key_values_length
  283. # 4d mask is passed through the layers
  284. if attention_mask is not None and len(attention_mask.shape) == 2:
  285. attention_mask = attn_mask_converter.to_4d(
  286. attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
  287. )
  288. elif attention_mask is not None and len(attention_mask.shape) == 4:
  289. expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
  290. if tuple(attention_mask.shape) != expected_shape:
  291. raise ValueError(
  292. f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
  293. )
  294. else:
  295. # if the 4D mask has correct shape - invert it and fill with negative infinity
  296. inverted_mask = 1.0 - attention_mask
  297. attention_mask = inverted_mask.masked_fill(
  298. inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
  299. )
  300. else:
  301. attention_mask = attn_mask_converter.to_causal_4d(
  302. input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
  303. )
  304. return attention_mask
  305. # Adapted from _prepare_4d_causal_attention_mask
  306. def _prepare_4d_causal_attention_mask_for_sdpa(
  307. attention_mask: Optional[torch.Tensor],
  308. input_shape: Union[torch.Size, tuple, list],
  309. inputs_embeds: torch.Tensor,
  310. past_key_values_length: int,
  311. sliding_window: Optional[int] = None,
  312. ):
  313. """
  314. Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
  315. In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
  316. `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
  317. allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
  318. """
  319. attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
  320. key_value_length = input_shape[-1] + past_key_values_length
  321. # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
  322. # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
  323. # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
  324. is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
  325. ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
  326. attention_mask=attention_mask,
  327. inputs_embeds=inputs_embeds,
  328. past_key_values_length=past_key_values_length,
  329. sliding_window=sliding_window,
  330. )
  331. if ignore_causal_mask:
  332. expanded_4d_mask = None
  333. elif attention_mask is None:
  334. expanded_4d_mask = attn_mask_converter.to_causal_4d(
  335. input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
  336. )
  337. else:
  338. if attention_mask.dim() == 4:
  339. expanded_4d_mask = attention_mask
  340. else:
  341. expanded_4d_mask = attn_mask_converter.to_4d(
  342. attention_mask,
  343. input_shape[-1],
  344. dtype=inputs_embeds.dtype,
  345. key_value_length=key_value_length,
  346. )
  347. # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
  348. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  349. # Details: https://github.com/pytorch/pytorch/issues/110213
  350. if not is_tracing and expanded_4d_mask.device.type == "cuda":
  351. expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
  352. expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
  353. )
  354. return expanded_4d_mask
  355. def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
  356. """
  357. Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  358. `(batch_size, key_value_length)`
  359. Args:
  360. mask (`torch.Tensor`):
  361. A 2D attention mask of shape `(batch_size, key_value_length)`
  362. dtype (`torch.dtype`):
  363. The torch dtype the created mask shall have.
  364. tgt_len (`int`):
  365. The target length or query length the created mask shall have.
  366. """
  367. return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
  368. def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
  369. """
  370. Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  371. `(batch_size, key_value_length)`
  372. Args:
  373. mask (`torch.Tensor`):
  374. A 2D attention mask of shape `(batch_size, key_value_length)`
  375. dtype (`torch.dtype`):
  376. The torch dtype the created mask shall have.
  377. tgt_len (`int`):
  378. The target length or query length the created mask shall have.
  379. """
  380. _, key_value_length = mask.shape
  381. tgt_len = tgt_len if tgt_len is not None else key_value_length
  382. is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()
  383. # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
  384. if not is_tracing and torch.all(mask == 1):
  385. return None
  386. else:
  387. return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
  388. def _create_4d_causal_attention_mask(
  389. input_shape: Union[torch.Size, tuple, list],
  390. dtype: torch.dtype,
  391. device: torch.device,
  392. past_key_values_length: int = 0,
  393. sliding_window: Optional[int] = None,
  394. ) -> Optional[torch.Tensor]:
  395. """
  396. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
  397. Args:
  398. input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
  399. The input shape should be a tuple that defines `(batch_size, query_length)`.
  400. dtype (`torch.dtype`):
  401. The torch dtype the created mask shall have.
  402. device (`int`):
  403. The torch device the created mask shall have.
  404. sliding_window (`int`, *optional*):
  405. If the model uses windowed attention, a sliding window should be passed.
  406. """
  407. attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
  408. key_value_length = past_key_values_length + input_shape[-1]
  409. attention_mask = attn_mask_converter.to_causal_4d(
  410. input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
  411. )
  412. return attention_mask