masking_utils.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256
  1. # coding=utf-8
  2. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import itertools
  16. from typing import Callable, Optional, Union
  17. import torch
  18. import torch.nn.functional as F
  19. from .cache_utils import Cache
  20. from .configuration_utils import PretrainedConfig
  21. from .utils import is_torch_xpu_available, logging
  22. from .utils.generic import GeneralInterface
  23. from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_torchdynamo_compiling
# Flex-attention symbols are only imported when flex attention is available. Otherwise `flex_default_block_size`
# and `create_block_mask` stay undefined, so the code paths using them must not run in that case.
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask
else:
    # Register a fake type to avoid crashing for annotations and `isinstance` checks
    BlockMask = torch.Tensor

# Torch version / device feature flags, evaluated once at import time.
# NOTE(review): `accept_dev=True` presumably lets dev/nightly builds count as the named version - confirm against
# `is_torch_greater_or_equal`.
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
_is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
_is_torch_xpu_available = is_torch_xpu_available()

# `TransformGetItemToIndex` is only imported on torch>=2.6 and stays undefined otherwise.
# In this file it is used by `sdpa_mask_recent_torch`, which `sdpa_mask` only selects on torch>=2.6.
if _is_torch_greater_or_equal_than_2_6:
    from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

logger = logging.get_logger(__name__)
  36. def and_masks(*mask_functions: Callable) -> Callable:
  37. """Returns a mask function that is the intersection of provided mask functions"""
  38. if not all(callable(arg) for arg in mask_functions):
  39. raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")
  40. def and_mask(batch_idx, head_idx, q_idx, kv_idx):
  41. result = q_idx.new_ones((), dtype=torch.bool)
  42. for mask in mask_functions:
  43. result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
  44. return result
  45. return and_mask
  46. def or_masks(*mask_functions: Callable) -> Callable:
  47. """Returns a mask function that is the union of provided mask functions"""
  48. if not all(callable(arg) for arg in mask_functions):
  49. raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")
  50. def or_mask(batch_idx, head_idx, q_idx, kv_idx):
  51. result = q_idx.new_zeros((), dtype=torch.bool)
  52. for mask in mask_functions:
  53. result = result | mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
  54. return result
  55. return or_mask
  56. def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  57. """
  58. This creates a basic lower-diagonal causal mask.
  59. """
  60. return kv_idx <= q_idx
  61. def sliding_window_overlay(sliding_window: int) -> Callable:
  62. """
  63. This is an overlay depicting a sliding window pattern. Add it on top of a causal mask for a proper sliding
  64. window mask.
  65. """
  66. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  67. return kv_idx > q_idx - sliding_window
  68. return inner_mask
  69. def chunked_overlay(chunk_size: int, left_padding: torch.Tensor) -> Callable:
  70. """
  71. This is an overlay depicting a chunked attention pattern. Add it on top of a causal mask for a proper chunked
  72. attention mask.
  73. """
  74. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  75. return (kv_idx - left_padding[batch_idx]) // chunk_size == (q_idx - left_padding[batch_idx]) // chunk_size
  76. return inner_mask
  77. def _legacy_chunked_overlay(chunk_size: int) -> Callable:
  78. """
  79. Same as the above function, but do not correctly account for left padding tokens.
  80. Only kept for compatibility with older torch versions (< 2.6).
  81. """
  82. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  83. return kv_idx // chunk_size == q_idx // chunk_size
  84. return inner_mask
  85. def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
  86. """
  87. This return the mask_function function to create a sliding window mask.
  88. """
  89. return and_masks(sliding_window_overlay(sliding_window), causal_mask_function)
  90. def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) -> Callable:
  91. """
  92. This return the mask_function function to create a chunked attention mask.
  93. """
  94. if not _is_torch_greater_or_equal_than_2_6:
  95. return and_masks(_legacy_chunked_overlay(chunk_size), causal_mask_function)
  96. return and_masks(chunked_overlay(chunk_size, left_padding), causal_mask_function)
  97. def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
  98. """
  99. This return the mask_function function corresponding to a 2D padding mask.
  100. """
  101. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  102. # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because
  103. # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not
  104. # vectorizable on accelerator devices
  105. return padding_mask[batch_idx, kv_idx]
  106. return inner_mask
  107. def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable:
  108. """
  109. This return the mask_function function corresponding to a 2D packed sequence mask.
  110. """
  111. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  112. return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
  113. return inner_mask
  114. def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable:
  115. """
  116. This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
  117. not start and end indices.
  118. """
  119. def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
  120. return mask_function(batch_idx, head_idx, q_idx + q_offset, kv_idx + kv_offset)
  121. return inner_mask
  122. def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
  123. """
  124. Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
  125. the batch and head indices as well if `bh_indices=True`.
  126. Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
  127. functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
  128. Args:
  129. mask_function (`Callable`):
  130. The mask_function to vmap.
  131. bh_indices (`bool`, optional):
  132. Whether to vmap over the batch and head indices as well, or only q and kv indices.
  133. Returns:
  134. Callable: The vmapped function.
  135. """
  136. # We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions
  137. dimensions = [(None, None, None, 0), (None, None, 0, None)]
  138. if bh_indices:
  139. # We extend broadcasting over the [batch_idx, head_idx] dimensions
  140. dimensions.extend([(None, 0, None, None), (0, None, None, None)])
  141. for dims in dimensions:
  142. mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
  143. return mask_function
  144. def prepare_padding_mask(
  145. attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int, _slice: bool = True
  146. ) -> Optional[torch.Tensor]:
  147. """
  148. From the 2D attention mask, prepare the correct padding mask to use by potentially padding it, and slicing
  149. according to the `kv_offset` if `_slice` is `True`.
  150. """
  151. local_padding_mask = attention_mask
  152. if attention_mask is not None:
  153. # Pad it if necessary
  154. if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
  155. local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
  156. # For flex, we should not slice them, only use an offset
  157. if _slice:
  158. # Equivalent to: `local_padding_mask = attention_mask[:, kv_offset : kv_offset + kv_length]`,
  159. # but without data-dependent slicing (i.e. torch.compile friendly)
  160. mask_indices = torch.arange(kv_length, device=local_padding_mask.device)
  161. mask_indices += kv_offset
  162. local_padding_mask = local_padding_mask[:, mask_indices]
  163. return local_padding_mask
  164. def _ignore_causal_mask_sdpa(
  165. padding_mask: Optional[torch.Tensor],
  166. query_length: int,
  167. kv_length: int,
  168. kv_offset: int,
  169. local_attention_size: Optional[int] = None,
  170. ) -> bool:
  171. """
  172. Detects whether the causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
  173. In case no token is masked in the 2D `padding_mask` argument, if `query_length == 1` or
  174. `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
  175. allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
  176. passed).
  177. """
  178. is_tracing = torch.jit.is_tracing() or isinstance(padding_mask, torch.fx.Proxy) or is_torchdynamo_compiling()
  179. if padding_mask is not None and padding_mask.shape[-1] > kv_length:
  180. mask_indices = torch.arange(kv_length, device=padding_mask.device)
  181. mask_indices += kv_offset
  182. padding_mask = padding_mask[:, mask_indices]
  183. # When using `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
  184. # hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True`
  185. # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
  186. # `ignore_causal_mask = True` if we are not tracing
  187. if (
  188. not is_tracing
  189. # only cases when lower and upper diags are the same, see https://github.com/pytorch/pytorch/issues/108108
  190. and (query_length == 1 or (kv_length == query_length or _is_torch_xpu_available))
  191. # in this case we need to add special patterns to the mask so cannot be skipped otherwise
  192. and (local_attention_size is None or kv_length < local_attention_size)
  193. # In this case, we need to add padding to the mask, so cannot be skipped otherwise
  194. and (
  195. padding_mask is None
  196. or (
  197. padding_mask.all()
  198. if not _is_torch_xpu_available or query_length == 1
  199. else padding_mask[:, :query_length].all()
  200. )
  201. )
  202. ):
  203. return True
  204. return False
  205. def sdpa_mask_recent_torch(
  206. batch_size: int,
  207. cache_position: torch.Tensor,
  208. kv_length: int,
  209. kv_offset: int = 0,
  210. mask_function: Callable = causal_mask_function,
  211. attention_mask: Optional[torch.Tensor] = None,
  212. local_size: Optional[int] = None,
  213. allow_is_causal_skip: bool = True,
  214. **kwargs,
  215. ) -> Optional[torch.Tensor]:
  216. """
  217. Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
  218. the element should take part in the attention computation, and False that it should not.
  219. This function can only be used with torch>=2.5, as the context manager is otherwise not available.
  220. Args:
  221. batch_size (`int`):
  222. The batch size of the input sequence.
  223. cache_position (`torch.Tensor`):
  224. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  225. kv_length (`int`):
  226. The size that the key and value states will have during the attention computation.
  227. kv_offset (`int`, optional):
  228. An optional offset to indicate at which first position the key and values states will refer to.
  229. mask_function (`Callable`):
  230. The mask factory function describing the mask pattern.
  231. attention_mask (`torch.Tensor`, optional):
  232. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
  233. local_size (`int`, optional):
  234. The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
  235. to try to skip mask creation if possible.
  236. allow_is_causal_skip (`bool`, optional):
  237. Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
  238. `torch.sdpa` instead. Default to `True`.
  239. allow_torch_fix (`bool`, optional):
  240. Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
  241. versions. We need an arg to skip it when using eager. By default `True`.
  242. ## Creating a simple causal mask:
  243. To create the following causal mask:
  244. 0 ■ ⬚ ⬚ ⬚ ⬚
  245. 1 ■ ■ ⬚ ⬚ ⬚
  246. 2 ■ ■ ■ ⬚ ⬚
  247. 3 ■ ■ ■ ■ ⬚
  248. 4 ■ ■ ■ ■ ■
  249. You can do
  250. ```python
  251. >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5)
  252. >>> tensor([[[[ True, False, False, False, False],
  253. [ True, True, False, False, False],
  254. [ True, True, True, False, False],
  255. [ True, True, True, True, False],
  256. [ True, True, True, True, True]]]])
  257. ```
  258. ## Creating a sliding window mask:
  259. To create the following sliding window mask (`sliding_window=3`):
  260. 0 ■ ⬚ ⬚ ⬚ ⬚
  261. 1 ■ ■ ⬚ ⬚ ⬚
  262. 2 ■ ■ ■ ⬚ ⬚
  263. 3 ⬚ ■ ■ ■ ⬚
  264. 4 ⬚ ⬚ ■ ■ ■
  265. You can do
  266. ```python
  267. >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
  268. >>> tensor([[[[ True, False, False, False, False],
  269. [ True, True, False, False, False],
  270. [ True, True, True, False, False],
  271. [False, True, True, True, False],
  272. [False, False, True, True, True]]]])
  273. ```
  274. ## Creating a chunked attention mask
  275. To create the following chunked attention mask (`chunk_size=3`):
  276. 0 ■ ⬚ ⬚ ⬚ ⬚
  277. 1 ■ ■ ⬚ ⬚ ⬚
  278. 2 ■ ■ ■ ⬚ ⬚
  279. 3 ⬚ ⬚ ⬚ ■ ⬚
  280. 4 ⬚ ⬚ ⬚ ■ ■
  281. You can do
  282. ```python
  283. >>> sdpa_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3, torch.zeros(1, dtype=int)))
  284. >>> tensor([[[[ True, False, False, False, False],
  285. [ True, True, False, False, False],
  286. [ True, True, True, False, False],
  287. [False, False, False, True, False],
  288. [False, False, False, True, True]]]])
  289. ```
  290. """
  291. q_length = cache_position.shape[0]
  292. # Potentially pad the 2D mask, and slice it correctly
  293. padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
  294. # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
  295. if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
  296. return None
  297. # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
  298. # but without data-dependent slicing (i.e. torch.compile friendly)
  299. kv_arange = torch.arange(kv_length, device=cache_position.device)
  300. kv_arange += kv_offset
  301. # Potentially add the padding 2D mask
  302. if padding_mask is not None:
  303. mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
  304. batch_arange = torch.arange(batch_size, device=cache_position.device)
  305. head_arange = torch.arange(1, device=cache_position.device)
  306. # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
  307. # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
  308. # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
  309. with TransformGetItemToIndex():
  310. causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
  311. return causal_mask
  312. def sdpa_mask_older_torch(
  313. batch_size: int,
  314. cache_position: torch.Tensor,
  315. kv_length: int,
  316. kv_offset: int = 0,
  317. mask_function: Callable = causal_mask_function,
  318. attention_mask: Optional[torch.Tensor] = None,
  319. local_size: Optional[int] = None,
  320. allow_is_causal_skip: bool = True,
  321. allow_torch_fix: bool = True,
  322. **kwargs,
  323. ) -> Optional[torch.Tensor]:
  324. """
  325. NOTE: This function is only used when torch version is torch<2.5 - see `sdpa_mask_recent_torch` otherwise.
  326. Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
  327. the element should take part in the attention computation, and False that it should not.
  328. If `allow_torch_fix=True` (the default), rows corresponding to query tokens that do not attend
  329. to any other tokens (due to padding) will be fully attended to instead, in order to avoid `nan` propagation (this does
  330. not change the final result).
  331. Args:
  332. batch_size (`int`):
  333. The batch size of the input sequence.
  334. cache_position (`torch.Tensor`):
  335. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  336. kv_length (`int`):
  337. The size that the key and value states will have during the attention computation.
  338. kv_offset (`int`, optional):
  339. An optional offset to indicate at which first position the key and values states will refer to.
  340. mask_function (`Callable`):
  341. The mask factory function describing the mask pattern.
  342. attention_mask (`torch.Tensor`, optional):
  343. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
  344. local_size (`int`, optional):
  345. The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
  346. to try to skip mask creation if possible.
  347. allow_is_causal_skip (`bool`, optional):
  348. Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
  349. `torch.sdpa` instead. Default to `True`.
  350. allow_torch_fix (`bool`, optional):
  351. Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
  352. versions. We need an arg to skip it when using eager. By default `True`.
  353. """
  354. q_length = cache_position.shape[0]
  355. # Potentially pad the 2D mask, and slice it correctly
  356. padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
  357. # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
  358. if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
  359. return None
  360. # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
  361. # but without data-dependent slicing (i.e. torch.compile friendly)
  362. kv_arange = torch.arange(kv_length, device=cache_position.device)
  363. kv_arange += kv_offset
  364. # This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well,
  365. # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow
  366. # However, in more recent version of Pytorch, a trick was introduced to handle it - which is the reason we have
  367. # `sdpa_mask_recent_torch`, as it allows more general `mask_function`
  368. causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
  369. causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
  370. if padding_mask is not None:
  371. causal_mask = causal_mask * padding_mask[:, None, None, :]
  372. # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
  373. # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
  374. if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
  375. causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
  376. return causal_mask
  377. # We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
  378. # (especially mask_function indexing a tensor, such as the padding mask function)
  379. sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch
  380. def eager_mask(
  381. batch_size: int,
  382. cache_position: torch.Tensor,
  383. kv_length: int,
  384. kv_offset: int = 0,
  385. mask_function: Callable = causal_mask_function,
  386. attention_mask: Optional[torch.Tensor] = None,
  387. dtype: torch.dtype = torch.float32,
  388. **kwargs,
  389. ) -> torch.Tensor:
  390. """
  391. Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
  392. the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
  393. it should not.
  394. Args:
  395. batch_size (`int`):
  396. The batch size of the input sequence.
  397. cache_position (`torch.Tensor`):
  398. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  399. kv_length (`int`):
  400. The size that the key and value states will have during the attention computation.
  401. kv_offset (`int`, optional):
  402. An optional offset to indicate at which first position the key and values states will refer to.
  403. mask_function (`Callable`):
  404. The mask factory function describing the mask pattern.
  405. attention_mask (`torch.Tensor`, optional):
  406. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
  407. dtype (`torch.dtype`, optional):
  408. The dtype to use for the mask. By default, `torch.float32`.
  409. """
  410. # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
  411. _ = kwargs.pop("allow_is_causal_skip", None)
  412. mask = sdpa_mask(
  413. batch_size=batch_size,
  414. cache_position=cache_position,
  415. kv_length=kv_length,
  416. kv_offset=kv_offset,
  417. mask_function=mask_function,
  418. attention_mask=attention_mask,
  419. allow_is_causal_skip=False,
  420. allow_torch_fix=False,
  421. **kwargs,
  422. )
  423. min_dtype = torch.finfo(dtype).min
  424. # we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
  425. mask = torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
  426. return mask
  427. def flash_attention_mask(
  428. batch_size: int,
  429. cache_position: torch.Tensor,
  430. kv_length: int,
  431. kv_offset: int = 0,
  432. mask_function: Callable = causal_mask_function,
  433. attention_mask: Optional[torch.Tensor] = None,
  434. **kwargs,
  435. ):
  436. """
  437. Create the attention mask necessary to use FA2. Since FA2 is un-padded by definition, here we simply return
  438. `None` if the mask is fully causal, or we return the 2D mask which will then be used to extract the seq_lens.
  439. We just slice it in case of sliding window.
  440. Args:
  441. batch_size (`int`):
  442. The batch size of the input sequence.
  443. cache_position (`torch.Tensor`):
  444. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  445. kv_length (`int`):
  446. The size that the key and value states will have during the attention computation.
  447. kv_offset (`int`, optional):
  448. An optional offset to indicate at which first position the key and values states will refer to.
  449. mask_function (`Callable`):
  450. The mask factory function describing the mask pattern.
  451. attention_mask (`torch.Tensor`, optional):
  452. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
  453. """
  454. if attention_mask is not None:
  455. # Here we need to slice from the right if using sliding or chunked (for full attention, this is equivalent to doing nothing)
  456. attention_mask = attention_mask[:, -kv_length:]
  457. # We only return an actual mask if there is at least 1 padding token, otherwise we return `None` and use `is_causal` in FA2
  458. # (note that the attention_mask is a boolean dtype here)
  459. if attention_mask.all():
  460. attention_mask = None
  461. return attention_mask
  462. def flex_attention_mask(
  463. batch_size: int,
  464. cache_position: torch.Tensor,
  465. kv_length: int,
  466. kv_offset: int = 0,
  467. mask_function: Callable = causal_mask_function,
  468. attention_mask: Optional[torch.Tensor] = None,
  469. **kwargs,
  470. ) -> BlockMask:
  471. """
  472. Create a 4D block mask which is a compressed representation of the full 4D block causal mask. BlockMask is essential
  473. for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/
  474. Args:
  475. batch_size (`int`):
  476. The batch size of the input sequence.
  477. cache_position (`torch.Tensor`):
  478. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  479. kv_length (`int`):
  480. The size that the key and value states will have during the attention computation.
  481. kv_offset (`int`, optional):
  482. An optional offset to indicate at which first position the key and values states will refer to.
  483. mask_function (`Callable`):
  484. The mask factory function describing the mask pattern.
  485. attention_mask (`torch.Tensor`, optional):
  486. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
  487. """
  488. q_length, q_offset = cache_position.shape[0], cache_position[0]
  489. # Potentially add the padding 2D mask
  490. if attention_mask is not None:
  491. # Older torch (2.5.x) cannot handle sequences not in multiples of 128 (default block size)
  492. # Hence we pad to multiples of this as a minimum to ensure this
  493. pad_len = ((attention_mask.shape[1] // flex_default_block_size) + 1) * flex_default_block_size
  494. pad_len = pad_len - attention_mask.shape[1]
  495. if not _is_torch_greater_or_equal_than_2_6 and pad_len > 0:
  496. attention_mask = torch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))
  497. padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
  498. mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
  499. # Add the offsets on top (because flex interface only allows length, not start and end indices)
  500. mask_function = add_offsets_to_mask_function(mask_function, q_offset, kv_offset)
  501. # Finally create the block mask
  502. block_mask = create_block_mask(
  503. mask_mod=mask_function,
  504. B=batch_size,
  505. H=None,
  506. Q_LEN=q_length,
  507. KV_LEN=kv_length,
  508. device=cache_position.device,
  509. _compile=_is_torch_greater_or_equal_than_2_6,
  510. )
  511. return block_mask
  512. class AttentionMaskInterface(GeneralInterface):
  513. # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
  514. # a new instance is created (in order to locally override a given function)
  515. _global_mapping = {
  516. "sdpa": sdpa_mask,
  517. "eager": eager_mask,
  518. "flash_attention_2": flash_attention_mask,
  519. "flash_attention_3": flash_attention_mask,
  520. "flex_attention": flex_attention_mask,
  521. }
  522. # Global AttentionMaskInterface shared by all models which do not need to overwrite any of the existing ones
  523. ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
  524. def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor:
  525. """
  526. Find the indices of the sequence to which each new query token in the sequence belongs when using packed
  527. tensor format (i.e. several sequences packed in the same batch dimension).
  528. Args:
  529. position_ids (`torch.Tensor`)
  530. A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
  531. Returns:
  532. A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we
  533. pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
  534. """
  535. # What separate different sequences is when 2 consecutive positions_ids are separated by more than 1. So
  536. # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
  537. # gives exactly the sequence indices
  538. # Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence
  539. # cannot be part of the end of the first batch dim and the start of the 2nd one for example
  540. first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1
  541. position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
  542. packed_sequence_mask = (position_diff != 1).cumsum(-1)
  543. # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
  544. # but it causes issues with export
  545. return packed_sequence_mask
  546. def _preprocess_mask_arguments(
  547. config: PretrainedConfig,
  548. input_embeds: torch.Tensor,
  549. attention_mask: Optional[Union[torch.Tensor, BlockMask]],
  550. cache_position: torch.Tensor,
  551. past_key_values: Optional[Cache],
  552. position_ids: Optional[torch.Tensor],
  553. layer_idx: Optional[int],
  554. ) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]:
  555. """
  556. Perform some common pre-processing of the mask arguments we get from the modeling code. Mostly determine the
  557. key-value length and offsets, and if we should early exit or not.
  558. Args:
  559. config (`PretrainedConfig`):
  560. The model config.
  561. input_embeds (`torch.Tensor`):
  562. The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
  563. batch size, query length and dtype.
  564. attention_mask (`torch.Tensor`, optional):
  565. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
  566. It can also be an already prepared 4D mask, in which case it is returned as-is.
  567. cache_position (`torch.Tensor`):
  568. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  569. past_key_values (`Cache`, optional):
  570. The past key values, if we use a cache.
  571. position_ids (`torch.Tensor`, optional)
  572. A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
  573. layer_idx (`int`, optional):
  574. If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
  575. length and offset. Indeed, for hybrid caches, different layers may return different lengths.
  576. Returns:
  577. early_exit (`bool`):
  578. Whether we should early exit mask creation, and return the mask as-is.
  579. attention_mask (`torch.Tensor` or `BlockMask` or `None`):
  580. The attention mask to either return immediately, or to use in downstream mask creation.
  581. packed_sequence_mask (`torch.Tensor`, optional):
  582. In case we detected packed sequence format, this is a tensor where each similar integer indicates that
  583. the tokens belong to the same sequence.
  584. kv_length (`int`):
  585. The size that the key and value states will have during the attention computation.
  586. kv_offset (`int`):
  587. An offset to indicate at which first position the key and values states will refer to.
  588. """
  589. # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
  590. if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
  591. return True, attention_mask, None, None, None
  592. # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
  593. # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
  594. # full graph dynamo tracing (i.e. torch.export or compile with `fullgraph=True`) will fail on Python<3.11
  595. # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
  596. # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
  597. if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping:
  598. return True, None, None, None, None
  599. # Move the mask to correct device, and potentially switch dtype for efficiency
  600. if attention_mask is not None and attention_mask.ndim == 2:
  601. attention_mask = attention_mask.to(device=cache_position.device, dtype=torch.bool)
  602. # If using a cache, it can give all information about mask sizes based on seen tokens
  603. if past_key_values is not None:
  604. kv_length, kv_offset = past_key_values.get_mask_sizes(cache_position, layer_idx)
  605. # Otherwise, the sizes are simply the input sizes
  606. else:
  607. kv_length, kv_offset = input_embeds.shape[1], 0
  608. # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None,
  609. # and we don't have past_key_values, i.e. generally a training setup)
  610. packed_sequence_mask = None
  611. if position_ids is not None and attention_mask is None and past_key_values is None:
  612. batch_size = input_embeds.shape[0]
  613. # The position ids are sometimes just unsqueezed, without being expanded
  614. if batch_size != position_ids.shape[0]:
  615. position_ids = position_ids.expand(batch_size, -1)
  616. packed_sequence_mask = find_packed_sequence_indices(position_ids)
  617. return False, attention_mask, packed_sequence_mask, kv_length, kv_offset
  618. def create_causal_mask(
  619. config: PretrainedConfig,
  620. input_embeds: torch.Tensor,
  621. attention_mask: Optional[torch.Tensor],
  622. cache_position: torch.Tensor,
  623. past_key_values: Optional[Cache],
  624. position_ids: Optional[torch.Tensor] = None,
  625. or_mask_function: Optional[Callable] = None,
  626. and_mask_function: Optional[Callable] = None,
  627. ) -> Optional[Union[torch.Tensor, BlockMask]]:
  628. """
  629. Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
  630. has an hybrid cache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
  631. to what is needed in the `modeling_xxx.py` files).
  632. Args:
  633. config (`PretrainedConfig`):
  634. The model config.
  635. input_embeds (`torch.Tensor`):
  636. The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
  637. batch size, query length and dtype.
  638. attention_mask (`torch.Tensor`, optional):
  639. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
  640. It can also be an already prepared 4D mask, in which case it is returned as-is.
  641. cache_position (`torch.Tensor`):
  642. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  643. past_key_values (`Cache`, optional):
  644. The past key values, if we use a cache.
  645. position_ids (`torch.Tensor`, optional)
  646. A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
  647. or_mask_function (`Callable`, optional):
  648. An optional mask function to combine with the causal mask function (by doing the union of both). This is
  649. useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
  650. and_mask_function (`Callable`, optional):
  651. An optional mask function to combine with the causal mask function (by doing the intersection of both). This is
  652. useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
  653. """
  654. # If we have an hybrid cache structure, here we want to create the mask for the full layers
  655. if hasattr(past_key_values, "is_sliding") and False in past_key_values.is_sliding:
  656. layer_idx = past_key_values.is_sliding.index(False)
  657. else:
  658. layer_idx = 0
  659. early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
  660. config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
  661. )
  662. if early_exit:
  663. return attention_mask
  664. batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
  665. mask_factory_function = causal_mask_function
  666. mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
  667. # Do not allow skip if we are compiling (this is to match BC)
  668. # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
  669. if _is_torch_xpu_available:
  670. allow_is_causal_skip = True
  671. else:
  672. allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
  673. # Allow slight deviations from causal mask
  674. # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
  675. # padding mask, etc) as the resulting mask may otherwise not be correct!
  676. if or_mask_function is not None:
  677. if not _is_torch_greater_or_equal_than_2_6:
  678. raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
  679. mask_factory_function = or_masks(mask_factory_function, or_mask_function)
  680. allow_is_causal_skip = False
  681. if and_mask_function is not None:
  682. if not _is_torch_greater_or_equal_than_2_6:
  683. raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
  684. mask_factory_function = and_masks(mask_factory_function, and_mask_function)
  685. allow_is_causal_skip = False
  686. # If we detected packing format
  687. if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
  688. mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
  689. allow_is_causal_skip = False
  690. # We now create the mask
  691. causal_mask = mask_interface(
  692. batch_size=batch_size,
  693. cache_position=cache_position,
  694. kv_length=kv_length,
  695. kv_offset=kv_offset,
  696. mask_function=mask_factory_function,
  697. attention_mask=attention_mask,
  698. allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
  699. dtype=dtype, # Additional kwarg for eager
  700. config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
  701. )
  702. return causal_mask
  703. def create_sliding_window_causal_mask(
  704. config: PretrainedConfig,
  705. input_embeds: torch.Tensor,
  706. attention_mask: Optional[torch.Tensor],
  707. cache_position: torch.Tensor,
  708. past_key_values: Optional[Cache],
  709. position_ids: Optional[torch.Tensor] = None,
  710. or_mask_function: Optional[Callable] = None,
  711. and_mask_function: Optional[Callable] = None,
  712. ) -> Optional[Union[torch.Tensor, BlockMask]]:
  713. """
  714. Create a sliding window causal mask based on the attention implementation used (stored in the config). This type
  715. of attention pattern was mostly democratized by Mistral. If `past_key_values` has an hybrid cache structure, this
  716. function will return the mask corresponding to one of the "sliding_attention" layers (to align to what is needed in the
  717. `modeling_xxx.py` files).
  718. Args:
  719. config (`PretrainedConfig`):
  720. The model config.
  721. input_embeds (`torch.Tensor`):
  722. The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
  723. batch size, query length and dtype.
  724. attention_mask (`torch.Tensor`, optional):
  725. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
  726. It can also be an already prepared 4D mask, in which case it is returned as-is.
  727. cache_position (`torch.Tensor`):
  728. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  729. past_key_values (`Cache`, optional):
  730. The past key values, if we use a cache.
  731. position_ids (`torch.Tensor`, optional)
  732. A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
  733. or_mask_function (`Callable`, optional):
  734. An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
  735. useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
  736. and_mask_function (`Callable`, optional):
  737. An optional mask function to combine with the sliding causal mask function (by doing the intersection of both). This is
  738. useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
  739. """
  740. # If we have an hybrid cache structure, here we want to create the mask for the sliding layers
  741. if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
  742. layer_idx = past_key_values.is_sliding.index(True)
  743. else:
  744. layer_idx = 0
  745. early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
  746. config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
  747. )
  748. if early_exit:
  749. return attention_mask
  750. sliding_window = getattr(config, "sliding_window", None)
  751. if sliding_window is None:
  752. raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")
  753. batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
  754. mask_factory_function = sliding_window_causal_mask_function(sliding_window)
  755. mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
  756. # Do not allow skip if we are compiling (this is to match BC)
  757. # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
  758. allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
  759. # Allow slight deviations from causal mask
  760. # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
  761. # padding mask, etc) as the resulting mask may otherwise not be correct!
  762. if or_mask_function is not None:
  763. if not _is_torch_greater_or_equal_than_2_6:
  764. raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
  765. mask_factory_function = or_masks(mask_factory_function, or_mask_function)
  766. allow_is_causal_skip = False
  767. if and_mask_function is not None:
  768. if not _is_torch_greater_or_equal_than_2_6:
  769. raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
  770. mask_factory_function = and_masks(mask_factory_function, and_mask_function)
  771. allow_is_causal_skip = False
  772. # If we detected packing format
  773. if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
  774. mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
  775. allow_is_causal_skip = False
  776. # We now create the mask
  777. causal_mask = mask_interface(
  778. batch_size=batch_size,
  779. cache_position=cache_position,
  780. kv_length=kv_length,
  781. kv_offset=kv_offset,
  782. mask_function=mask_factory_function,
  783. attention_mask=attention_mask,
  784. allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
  785. local_size=sliding_window, # Additional kwarg for sdpa
  786. dtype=dtype, # Additional kwarg for eager
  787. config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
  788. )
  789. return causal_mask
  790. def create_chunked_causal_mask(
  791. config: PretrainedConfig,
  792. input_embeds: torch.Tensor,
  793. attention_mask: Optional[torch.Tensor],
  794. cache_position: torch.Tensor,
  795. past_key_values: Optional[Cache],
  796. position_ids: Optional[torch.Tensor] = None,
  797. or_mask_function: Optional[Callable] = None,
  798. and_mask_function: Optional[Callable] = None,
  799. ) -> Optional[Union[torch.Tensor, BlockMask]]:
  800. """
  801. Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type
  802. of attention pattern was mostly democratized by Llama4. If `past_key_values` has an hybrid cache structure, this
  803. function will return the mask corresponding to one of the "chunked_attention" layers (to align to what is needed in the
  804. `modeling_xxx.py` files).
  805. Args:
  806. config (`PretrainedConfig`):
  807. The model config.
  808. input_embeds (`torch.Tensor`):
  809. The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
  810. batch size, query length and dtype.
  811. attention_mask (`torch.Tensor`, optional):
  812. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
  813. It can also be an already prepared 4D mask, in which case it is returned as-is.
  814. cache_position (`torch.Tensor`):
  815. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  816. past_key_values (`Cache`, optional):
  817. The past key values, if we use a cache.
  818. position_ids (`torch.Tensor`, optional)
  819. A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
  820. or_mask_function (`Callable`, optional):
  821. An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
  822. useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
  823. and_mask_function (`Callable`, optional):
  824. An optional mask function to combine with the chunked causal mask function (by doing the intersection of both). This is
  825. useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling.
  826. """
  827. # If we have an hybrid cache structure, here we want to create the mask for the sliding layers
  828. if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
  829. layer_idx = past_key_values.is_sliding.index(True)
  830. else:
  831. layer_idx = 0
  832. early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments(
  833. config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx
  834. )
  835. if early_exit:
  836. return attention_mask
  837. chunk_size = getattr(config, "attention_chunk_size", None)
  838. if chunk_size is None:
  839. raise ValueError("Could not find an `attention_chunk_size` argument in the config, or it is not set")
  840. # Raise if using chunked attention on context too large with FA2
  841. if config._attn_implementation == "flash_attention_2" and kv_length + kv_offset > chunk_size:
  842. raise ValueError(
  843. "Flash attention 2 cannot handle chunked attention, and the key-value length is larger than the chunk size so the "
  844. "chunked pattern cannot be respected. You should use another `attn_implementation` when instantiating the model"
  845. )
  846. batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
  847. # For chunked attention and batched inputs, we need to take the number of left padding tokens into account
  848. # to start the chunk from the actual start of the sequence for the padded sequence
  849. if attention_mask is not None:
  850. # Only count the left padding tokens, not all of them
  851. left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
  852. else:
  853. left_padding_tokens = torch.zeros(batch_size, device=cache_position.device, dtype=int)
  854. # Raise a warning for older versions if the problematic left-padding situation arises
  855. if (
  856. not _is_torch_greater_or_equal_than_2_6
  857. and kv_length + kv_offset > chunk_size
  858. and (left_padding_tokens > 0).any()
  859. ):
  860. logger.warning_once(
  861. "Due to limitations of your current torch version, we cannot correctly account for the left-padding "
  862. "when computing the chunked attention pattern. This will lead to a wrong attention mask for the padded "
  863. "sequences. Behavior will be undefined. Please upgrade to `torch>=2.6` to solve this issue."
  864. )
  865. mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens)
  866. mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
  867. # Do not allow skip if we are compiling (this is to match BC)
  868. # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
  869. allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
  870. # Allow slight deviations from causal mask
  871. # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
  872. # padding mask, etc) as the resulting mask may otherwise not be correct!
  873. if or_mask_function is not None:
  874. if not _is_torch_greater_or_equal_than_2_6:
  875. raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
  876. mask_factory_function = or_masks(mask_factory_function, or_mask_function)
  877. allow_is_causal_skip = False
  878. if and_mask_function is not None:
  879. if not _is_torch_greater_or_equal_than_2_6:
  880. raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
  881. mask_factory_function = and_masks(mask_factory_function, and_mask_function)
  882. allow_is_causal_skip = False
  883. # If we detected packing format
  884. if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
  885. mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
  886. allow_is_causal_skip = False
  887. # We now create the mask
  888. causal_mask = mask_interface(
  889. batch_size=batch_size,
  890. cache_position=cache_position,
  891. kv_length=kv_length,
  892. kv_offset=kv_offset,
  893. mask_function=mask_factory_function,
  894. attention_mask=attention_mask,
  895. allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
  896. local_size=chunk_size, # Additional kwarg for sdpa
  897. dtype=dtype, # Additional kwarg for eager
  898. config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
  899. )
  900. return causal_mask
  901. LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = {
  902. "full_attention": create_causal_mask,
  903. "sliding_attention": create_sliding_window_causal_mask,
  904. "chunked_attention": create_chunked_causal_mask,
  905. }
  906. def create_masks_for_generate(
  907. config: PretrainedConfig,
  908. input_embeds: torch.Tensor,
  909. attention_mask: Optional[torch.Tensor],
  910. cache_position: torch.Tensor,
  911. past_key_values: Optional[Cache],
  912. position_ids: Optional[torch.Tensor] = None,
  913. or_mask_function: Optional[Callable] = None,
  914. and_mask_function: Optional[Callable] = None,
  915. **kwargs,
  916. ):
  917. """
  918. This function mimics how we create the masks in the `modeling_xxx.py` files, and is used in `generate` in order
  919. to easily create the masks in advance, when we compile the forwards with Static caches.
  920. Args:
  921. config (`PretrainedConfig`):
  922. The model config.
  923. input_embeds (`torch.Tensor`):
  924. The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
  925. batch size, query length and dtype.
  926. attention_mask (`torch.Tensor`, optional):
  927. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
  928. It can also be an already prepared 4D mask, in which case it is returned as-is.
  929. cache_position (`torch.Tensor`):
  930. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  931. past_key_values (`Cache`, optional):
  932. The past key values, if we use a cache.
  933. position_ids (`torch.Tensor`, optional)
  934. A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
  935. or_mask_function (`Callable`, optional):
  936. An optional mask function to combine with the other mask function (by doing the union of both). This is
  937. useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
  938. and_mask_function (`Callable`, optional):
  939. An optional mask function to combine with the other mask function (by doing the intersection of both). This is
  940. useful to easily overlay another mask on top of the causal one, for example for image tokens handling.
  941. """
  942. # The attribute reside in the text config for composite models
  943. effective_config = config.get_text_config()
  944. # Prepare the mask args
  945. mask_kwargs = {
  946. "config": effective_config,
  947. "input_embeds": input_embeds,
  948. "attention_mask": attention_mask,
  949. "cache_position": cache_position,
  950. "past_key_values": past_key_values,
  951. "position_ids": position_ids,
  952. "or_mask_function": or_mask_function,
  953. "and_mask_function": and_mask_function,
  954. }
  955. # If the attribute exist, we need several masks
  956. if hasattr(effective_config, "layer_types"):
  957. causal_masks = {}
  958. for layer_pattern in set(effective_config.layer_types):
  959. causal_masks[layer_pattern] = LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_pattern](**mask_kwargs)
  960. return causal_masks
  961. # In this case, all layers are sliding
  962. elif getattr(effective_config, "sliding_window", None) is not None:
  963. return create_sliding_window_causal_mask(**mask_kwargs)
  964. # In this case, all layers are chunked
  965. elif getattr(effective_config, "attention_chunk_size", None) is not None:
  966. return create_chunked_causal_mask(**mask_kwargs)
  967. # All layers use standard causal attention
  968. return create_causal_mask(**mask_kwargs)
  969. # Below are utilities to pretty-print the different masks
  970. # Print the matrix with words as row labels
  971. GREEN = "\033[92m"
  972. YELLOW = "\033[93m"
  973. RESET = "\033[0m"
  974. BLACK_SQUARE = "■"
  975. WHITE_SQUARE = "⬚"
  976. GREY_SQUARE = "∙"
  977. LOW_TRIANGLE = "⬕"
  978. UPPER_TRIANGLE = "⬔"
  979. def get_style(style):
  980. if style == "majong":
  981. BLACK_SQUARE = "🀞" # Full block (represents "on" or active)
  982. BLACK_SQUARE = "🀙" # Full block (represents "on" or active)
  983. WHITE_SQUARE = "🀆" # "▒" # Light shade (represents "off" or inactive)
  984. LOW_TRIANGLE = "🀛" # Lower left triangle (stylized indication)
  985. UPPER_TRIANGLE = "🀛" # Upper left triangle (stylized indication)
  986. else:
  987. BLACK_SQUARE = "█" # Full block (represents "on" or active)
  988. WHITE_SQUARE = "░" # "▒" # Light shade (represents "off" or inactive)
  989. LOW_TRIANGLE = "▙" # Lower left triangle (stylized indication))
  990. UPPER_TRIANGLE = "▜" # Upper left triangle (stylized indication)
  991. return BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE
  992. # LOW_TRIANGLE = UPPER_TRIANGLE = "⟍" # Upper right triangle (stylized indication)
  993. YELLOW_SQUARE = f"{YELLOW}{BLACK_SQUARE}{RESET}"
  994. GREEN_SQUARE = f"{GREEN}{BLACK_SQUARE}{RESET}"
  995. def tensor_to_mask_visual(original_tensor: torch.Tensor, grid_size=(20, 40), style="majong") -> str:
  996. BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE = get_style(style)
  997. h, w = original_tensor.shape
  998. max_h, max_w = grid_size
  999. if not (h < max_h and w < max_w):
  1000. # Preserve aspect ratio within max grid size
  1001. aspect_ratio = 2 * w / h
  1002. if aspect_ratio > 1:
  1003. w = max_w
  1004. h = min(max_h, max(1, round(max_w / aspect_ratio)))
  1005. else:
  1006. h = max_h
  1007. w = max(1, round(max_h * aspect_ratio))
  1008. # Step 1: Rescale tensor by average pooling
  1009. tensor = original_tensor.unsqueeze(0).unsqueeze(0) # Add batch and channel dimensions
  1010. tensor = F.adaptive_avg_pool2d(tensor, output_size=(h, w))[0, 0] # Remove extra dims
  1011. else:
  1012. tensor = original_tensor
  1013. # Step 3: Build the string representation
  1014. result = []
  1015. for i in range(h):
  1016. row = ""
  1017. for j in range(w):
  1018. if tensor[i, j] == 1:
  1019. row += BLACK_SQUARE
  1020. elif tensor[i, j] == 0:
  1021. row += WHITE_SQUARE
  1022. else:
  1023. if j > 0:
  1024. if tensor[i, j - 1] == 1:
  1025. row += LOW_TRIANGLE
  1026. elif tensor[i, j - 1] == 0:
  1027. row += UPPER_TRIANGLE
  1028. else:
  1029. row += BLACK_SQUARE if tensor[i, j] == 1 else WHITE_SQUARE
  1030. else:
  1031. row += (
  1032. BLACK_SQUARE
  1033. if tensor[i, j] == 1
  1034. else (
  1035. WHITE_SQUARE
  1036. if tensor[i, j] == 0
  1037. else (UPPER_TRIANGLE if tensor[i, j + 1] == 1 else LOW_TRIANGLE)
  1038. )
  1039. )
  1040. result.append(row)
  1041. return "\n".join(result)
  1042. class AttentionMask(torch.Tensor):
  1043. def __new__(cls, data, style=None):
  1044. # Create a new instance of AttentionMask as a Tensor
  1045. cls.style = style
  1046. return torch.Tensor._make_subclass(cls, data, require_grad=False)
  1047. def __init__(self, data):
  1048. # You can initialize any additional metadata here if needed
  1049. pass
  1050. def to_string(self, grid_size=(20, 40), limit=4):
  1051. """Returns a string representation of the block mask."""
  1052. dense_mask = self
  1053. *batch_dims, num_rows, num_cols = dense_mask.shape
  1054. total_vis = []
  1055. for idx, batch_idx in enumerate(itertools.product(*[range(i) for i in batch_dims])):
  1056. if idx == limit:
  1057. total_vis.append("...")
  1058. total_vis.append("To print out more, set AttentionMask.to_string(limit=N)")
  1059. total_vis.append("You can also index (AttentionMask[batch, head]) to choose a specific batch or head")
  1060. break
  1061. block_vis = tensor_to_mask_visual(dense_mask[batch_idx], grid_size=grid_size, style=self.style)
  1062. total_vis.append(block_vis)
  1063. total_vis.append(f"torch.Tensor(shape={tuple(self.shape)}, dtype={self.dtype})")
  1064. return "\n".join(total_vis)
  1065. def __repr__(self):
  1066. return self.to_string()
  1067. def __str__(self):
  1068. return self.to_string()
  1069. @classmethod
  1070. def from_tensor(cls, tensor: torch.Tensor, style: Optional[str] = None) -> "AttentionMask":
  1071. res = cls(tensor)
  1072. res.style = style
  1073. return res