executorch.py 51 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  5. # the License. You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  10. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  11. # specific language governing permissions and limitations under the License.
  12. import logging
  13. from typing import Callable, Optional
  14. import torch
  15. from ..cache_utils import (
  16. DynamicCache,
  17. DynamicLayer,
  18. DynamicSlidingWindowLayer,
  19. EncoderDecoderCache,
  20. StaticCache,
  21. )
  22. from ..generation.configuration_utils import GenerationConfig
  23. from ..masking_utils import (
  24. ALL_MASK_ATTENTION_FUNCTIONS,
  25. _ignore_causal_mask_sdpa,
  26. _is_torch_greater_or_equal_than_2_5,
  27. prepare_padding_mask,
  28. )
  29. from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ..pytorch_utils import (
  31. is_torch_greater_or_equal,
  32. is_torch_greater_or_equal_than_2_3,
  33. is_torch_greater_or_equal_than_2_6,
  34. )
  35. class TorchExportableModuleForVLM:
  36. """
  37. A wrapper class for exporting Vision-Language Models (VLMs) like SmolVLM2 for ExecuTorch.
  38. This class handles the export of three main components:
  39. 1. Vision encoder (processes images to visual features)
  40. 2. Connector/projector (maps visual features to text embedding space)
  41. 3. Text decoder (generates text from combined visual and text tokens)
  42. """
  43. def __init__(self, model, max_batch_size: int = 1, max_cache_len: int = 1024):
  44. """
  45. Initialize the exportable VLM module.
  46. Args:
  47. model: The VLM (e.g. SmolVLM) model instance
  48. max_batch_size: Maximum batch size. Always 1 for ExecuTorch
  49. max_cache_len: Maximum cache length for text generation
  50. """
  51. self.model = model
  52. self.max_batch_size = max_batch_size
  53. self.max_cache_len = max_cache_len
  54. self.config = model.config
  55. # Extract individual components
  56. self.vision_encoder = model.model.vision_model
  57. self.connector = model.model.connector
  58. self.text_decoder = model.model.text_model
  59. # Store exported programs
  60. self.exported_vision_encoder = None
  61. self.exported_connector = None
  62. self.exported_text_decoder = None
  63. def export_vision_encoder(self):
  64. """Export the vision encoder component."""
  65. self.vision_encoder.eval()
  66. # Create example input
  67. pixel_values = torch.randn(1, 3, 384, 384, dtype=torch.float32)
  68. # Define dynamic shapes
  69. dynamic_shapes = {
  70. "pixel_values": {
  71. 2: torch.export.Dim.AUTO,
  72. 3: torch.export.Dim.AUTO,
  73. }
  74. }
  75. self.exported_vision_encoder = torch.export.export(
  76. self.vision_encoder,
  77. args=(pixel_values,),
  78. dynamic_shapes=dynamic_shapes,
  79. strict=False,
  80. )
  81. return self.exported_vision_encoder
  82. def export_connector(self):
  83. """Export the connector component."""
  84. self.connector.eval()
  85. # Vision encoder output shape: [batch_size, num_patches, vision_hidden_size]
  86. vision_hidden_size = self.config.vision_config.hidden_size
  87. image_size = self.config.vision_config.image_size
  88. patch_size = self.config.vision_config.patch_size
  89. patches_per_dim = image_size // patch_size
  90. num_patches = patches_per_dim * patches_per_dim
  91. image_hidden_states = torch.randn(1, num_patches, vision_hidden_size, dtype=torch.float32)
  92. # Define dynamic shapes - static batch_size=1, dynamic num_patches
  93. dynamic_shapes = {"image_hidden_states": {1: torch.export.Dim.AUTO}}
  94. # Export the connector using torch.export
  95. self.exported_connector = torch.export.export(
  96. self.connector,
  97. args=(image_hidden_states,),
  98. dynamic_shapes=dynamic_shapes,
  99. strict=False,
  100. )
  101. return self.exported_connector
  102. def export_text_decoder(self):
  103. """Export the text decoder component."""
  104. # Create text decoder exportable wrapper
  105. self.exportable_text_decoder = TorchExportableModuleForDecoderOnlyLM(model=self.text_decoder)
  106. # Use the existing text decoder exportable wrapper
  107. seq_length = 3
  108. input_ids = torch.zeros((1, seq_length), dtype=torch.long)
  109. cache_position = torch.arange(seq_length, dtype=torch.long)
  110. max_seq_length = min(self.max_cache_len, self.config.text_config.max_position_embeddings)
  111. seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_length - 1)
  112. dynamic_shapes = {
  113. "input_ids": {1: seq_len_dim},
  114. "cache_position": {0: seq_len_dim},
  115. }
  116. self.exported_text_decoder = self.exportable_text_decoder.export(
  117. input_ids=input_ids,
  118. cache_position=cache_position,
  119. dynamic_shapes=dynamic_shapes,
  120. strict=False,
  121. )
  122. return self.exported_text_decoder
  123. def export(self, **kwargs):
  124. """Export all components of the VLM model."""
  125. self.export_vision_encoder(**kwargs)
  126. self.export_connector(**kwargs)
  127. self.export_text_decoder(**kwargs)
  128. return {
  129. "vision_encoder": self.exported_vision_encoder,
  130. "connector": self.exported_connector,
  131. "text_decoder": self.exported_text_decoder,
  132. }
  133. def forward(self, pixel_values, input_ids, cache_position):
  134. """
  135. Simplified forward pass for inference with guaranteed non-null input_ids and cache_position.
  136. Args:
  137. pixel_values: Input images [1, channels, height, width] (optional)
  138. input_ids: Text token IDs [1, seq_len] (required - won't be None)
  139. cache_position: Cache positions [seq_len] (required - won't be None)
  140. Returns:
  141. Output with logits for text generation
  142. """
  143. pass
  144. def generate(
  145. self, pixel_values=None, input_ids=None, max_new_tokens=50, do_sample=False, temperature=1.0, **kwargs
  146. ):
  147. """
  148. Simplified generate method with guaranteed non-null input_ids.
  149. Args:
  150. pixel_values: Input images [1, channels, height, width] (optional)
  151. input_ids: Initial text tokens [1, seq_len] (required - won't be None)
  152. max_new_tokens: Maximum number of tokens to generate
  153. do_sample: Whether to use sampling or greedy decoding
  154. temperature: Temperature for sampling
  155. Returns:
  156. Generated sequences
  157. """
  158. pass
  159. class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
  160. """
  161. A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
  162. specifically for decoder-only LM with cache. This module ensures that the
  163. exported model is compatible with further lowering and execution in `ExecuTorch`.
  164. """
  165. def __init__(
  166. self,
  167. model: PreTrainedModel,
  168. batch_size: Optional[int] = None,
  169. max_cache_len: Optional[int] = None,
  170. device: Optional[torch.device] = None,
  171. ) -> None:
  172. """
  173. Initializes the exportable module.
  174. Args:
  175. model (`PreTrainedModel`): The pretrained model to wrap.
  176. Raises:
  177. ValueError: If the model is configured with a unsupported cache implementation.
  178. """
  179. super().__init__()
  180. config = model.config.get_text_config()
  181. if not hasattr(config, "use_cache") or config.use_cache is False:
  182. raise ValueError("The model must have caching enabled to be performant.")
  183. if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
  184. self.model = TorchExportableModuleWithHybridCache(model, batch_size, max_cache_len, device)
  185. else:
  186. # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
  187. # there is only 1 type of layers, so export will use `StaticCache` by default.
  188. logging.info(
  189. "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
  190. )
  191. self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
  192. # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
  193. ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
  194. ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
  195. self.model.model.config._attn_implementation = "sdpa_without_vmap"
  196. def forward(
  197. self,
  198. input_ids: Optional[torch.Tensor] = None,
  199. inputs_embeds: Optional[torch.Tensor] = None,
  200. cache_position: Optional[torch.Tensor] = None,
  201. ) -> torch.Tensor:
  202. """
  203. Forward pass of the module, which is compatible with the ExecuTorch llm runner.
  204. Args:
  205. input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
  206. inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
  207. cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
  208. Returns:
  209. torch.Tensor: Logits output from the model.
  210. """
  211. return self.model.forward(
  212. input_ids=input_ids,
  213. inputs_embeds=inputs_embeds,
  214. cache_position=cache_position,
  215. )
  216. def export(
  217. self,
  218. input_ids: Optional[torch.Tensor] = None,
  219. inputs_embeds: Optional[torch.Tensor] = None,
  220. cache_position: Optional[torch.Tensor] = None,
  221. dynamic_shapes: Optional[dict] = None,
  222. strict: Optional[bool] = None,
  223. ) -> torch.export.ExportedProgram:
  224. """
  225. Export the wrapped module using `torch.export`.
  226. Args:
  227. input_ids (`Optional[torch.Tensor]`):
  228. Tensor representing current input token id to the module. Must specify either this or inputs_embeds.
  229. inputs_embeds (`Optional[torch.Tensor]`):
  230. Tensor representing current input embeddings to the module. Must specify either this or input_ids.
  231. cache_position (`Optional[torch.Tensor]`):
  232. Tensor representing current input position in the cache. If not provided, a default tensor will be used.
  233. dynamic_shapes (`Optional[dict]`):
  234. Dynamic shapes to use for export if specified.
  235. strict(`Optional[bool]`):
  236. Flag to instruct `torch.export` to use `torchdynamo`.
  237. Returns:
  238. torch.export.ExportedProgram: The exported program that can be used for inference.
  239. Examples:
  240. Export with input_ids:
  241. ```python
  242. # Prepare inputs
  243. input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device)
  244. cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device)
  245. # Export
  246. exported = exportable_module.export(
  247. input_ids=input_ids,
  248. cache_position=cache_position
  249. )
  250. ```
  251. Export with inputs_embeds:
  252. ```python
  253. # Prepare embeddings
  254. inputs_embeds = torch.randn(1, 3, 768, device=model.device) # batch_size=1, seq_len=3, hidden_size=768
  255. cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device)
  256. # Export
  257. exported = exportable_module.export(
  258. inputs_embeds=inputs_embeds,
  259. cache_position=cache_position
  260. )
  261. ```
  262. """
  263. if not (input_ids is None) ^ (inputs_embeds is None):
  264. raise ValueError("Need to specify either input_ids or inputs_embeds.")
  265. if hasattr(self.model, "base_model_prefix"):
  266. base = getattr(self.model, self.model.base_model_prefix, self.model)
  267. model_device = base.device
  268. elif hasattr(self.model, "model"):
  269. model_device = self.model.model.device
  270. else:
  271. model_device = "cpu"
  272. logging.warning(
  273. "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
  274. )
  275. if input_ids is not None:
  276. input_kwargs = {
  277. "input_ids": input_ids,
  278. "cache_position": cache_position
  279. if cache_position is not None
  280. else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device),
  281. }
  282. else: # inputs_embeds
  283. input_kwargs = {
  284. "inputs_embeds": inputs_embeds,
  285. "cache_position": cache_position
  286. if cache_position is not None
  287. else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device),
  288. }
  289. exported_program = torch.export.export(
  290. self.model,
  291. args=(),
  292. kwargs=input_kwargs,
  293. dynamic_shapes=dynamic_shapes,
  294. strict=strict if strict is not None else True,
  295. )
  296. return exported_program
  297. @staticmethod
  298. def generate(
  299. exported_program: torch.export.ExportedProgram,
  300. tokenizer,
  301. prompt: str,
  302. max_new_tokens: int = 20,
  303. do_sample: bool = False,
  304. temperature: float = 1.0,
  305. top_k: int = 50,
  306. top_p: float = 1.0,
  307. device: str = "cpu",
  308. ) -> str:
  309. """
  310. Generate a sequence of tokens using an exported program.
  311. Args:
  312. exported_program (`torch.export.ExportedProgram`): The exported model being used for generate.
  313. tokenizer: The tokenizer to use.
  314. prompt (str): The input prompt.
  315. max_new_tokens (int): Maximum number of new tokens to generate.
  316. do_sample (bool): Whether to use sampling or greedy decoding.
  317. temperature (float): The temperature for sampling.
  318. top_k (int): The number of highest probability tokens to keep for top-k sampling.
  319. top_p (float): The cumulative probability for nucleus sampling.
  320. device (str): The device to use.
  321. Returns:
  322. str: The generated text.
  323. """
  324. # Get the module from the exported program
  325. exported_module = exported_program.module()
  326. # Tokenize the prompt
  327. input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
  328. # Initialize with the prompt
  329. generated_ids = input_ids.clone()
  330. # Process the prompt tokens first
  331. curr_position = 0
  332. for i in range(input_ids.shape[1]):
  333. # Process one token at a time
  334. curr_input_ids = input_ids[:, i : i + 1]
  335. curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
  336. # Forward pass
  337. _ = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)
  338. curr_position += 1
  339. # Generate new tokens
  340. for _ in range(max_new_tokens):
  341. # Get the last token as input
  342. curr_input_ids = generated_ids[:, -1:]
  343. curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
  344. # Forward pass to get next token logits
  345. outputs = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)
  346. # Get the next token ID
  347. if do_sample:
  348. # Apply temperature
  349. if temperature > 0:
  350. logits = outputs / temperature
  351. else:
  352. logits = outputs
  353. # Apply top-k filtering
  354. if top_k > 0:
  355. indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
  356. logits[indices_to_remove] = float("-inf")
  357. # Apply top-p (nucleus) filtering
  358. if top_p < 1.0:
  359. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  360. cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
  361. # Remove tokens with cumulative probability above the threshold
  362. sorted_indices_to_remove = cumulative_probs > top_p
  363. # Shift the indices to the right to keep also the first token above the threshold
  364. sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
  365. sorted_indices_to_remove[..., 0] = 0
  366. # Scatter sorted tensors to original indexing
  367. indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
  368. logits[indices_to_remove] = float("-inf")
  369. # Sample from the filtered distribution
  370. probs = torch.softmax(logits, dim=-1)
  371. next_token_id = torch.multinomial(probs, num_samples=1)
  372. else:
  373. # Greedy decoding
  374. next_token_id = outputs.argmax(dim=-1, keepdim=True)
  375. # Ensure next_token_id has the right shape before concatenation
  376. if next_token_id.dim() > 2:
  377. next_token_id = next_token_id.squeeze(-1)
  378. # Append to the generated sequence
  379. generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
  380. curr_position += 1
  381. # Stop if we generate an EOS token
  382. if next_token_id.item() == tokenizer.eos_token_id:
  383. break
  384. # Decode the generated text
  385. return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
  386. class TorchExportableModuleWithStaticCache(torch.nn.Module):
  387. """
  388. A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
  389. specifically for decoder-only LM to `StaticCache`. This module ensures that the
  390. exported model is compatible with further lowering and execution in `ExecuTorch`.
  391. Note:
  392. This class is specifically designed to support export process using `torch.export`
  393. in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
  394. """
  395. def __init__(
  396. self,
  397. model: PreTrainedModel,
  398. batch_size: Optional[int] = None,
  399. max_cache_len: Optional[int] = None,
  400. device: Optional[torch.device] = None,
  401. ) -> None:
  402. """
  403. Initializes the wrapper module with the pretrained model.
  404. Args:
  405. model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
  406. enabled and use a 'static' caching implementation.
  407. batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
  408. in `generation_config.cache_config` and otherwise we raise a ValueError.
  409. max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
  410. not provided.
  411. device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
  412. in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).
  413. Raises:
  414. AssertionError: If the pretrained model does not have caching enabled or if it does
  415. not use a 'static' caching implementation in `model.generation_config`.
  416. ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
  417. """
  418. super().__init__()
  419. config = model.config.get_text_config()
  420. generation_config = model.generation_config
  421. # Sanity checks
  422. if generation_config is None:
  423. raise AssertionError(
  424. "The model must have a generation config to be exported with static caching. "
  425. "Please set `generation_config` in `model`."
  426. )
  427. if not generation_config.use_cache:
  428. raise AssertionError(
  429. "The model must have caching enabled to be exported with static caching. "
  430. "Please set `generation_config.use_cache=True`."
  431. )
  432. if generation_config.cache_implementation != "static":
  433. raise AssertionError(
  434. "The model must use a 'static' caching implementation to be exported with static caching. "
  435. "Please set `generation_config.cache_implementation='static'`."
  436. )
  437. cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
  438. # Ensure batch_size and max_cache_len are set
  439. if batch_size is None:
  440. batch_size = cache_config.get("batch_size", None)
  441. if batch_size is None:
  442. raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
  443. if max_cache_len is None:
  444. max_cache_len = cache_config.get("max_cache_len", None)
  445. if max_cache_len is None:
  446. raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
  447. # Infer device if not provided
  448. if device is None:
  449. device = cache_config.get("device", model.device)
  450. # Initialize the static cache
  451. self.model = model
  452. self.static_cache = StaticCache(max_cache_len=max_cache_len, config=config)
  453. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  454. num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
  455. dtype = self.model.dtype
  456. # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
  457. self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
  458. for i in range(len(self.static_cache)):
  459. self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
  460. self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
  461. def forward(
  462. self,
  463. input_ids: Optional[torch.LongTensor] = None,
  464. inputs_embeds: Optional[torch.Tensor] = None,
  465. cache_position: Optional[torch.Tensor] = None,
  466. ):
  467. """
  468. Forward pass of the module, which is compatible with the ExecuTorch runtime.
  469. Args:
  470. input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
  471. inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
  472. cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
  473. Returns:
  474. torch.Tensor: Logits output from the model.
  475. This forward adapter serves two primary purposes:
  476. 1. **Making the Model `torch.export`-Compatible**:
  477. The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs,
  478. enabling the model to be exportable using `torch.export` without encountering issues.
  479. 2. **Ensuring Compatibility with `ExecuTorch` runtime**:
  480. The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
  481. ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
  482. """
  483. past_key_values = self.static_cache
  484. outs = self.model(
  485. input_ids=input_ids,
  486. inputs_embeds=inputs_embeds,
  487. cache_position=cache_position,
  488. attention_mask=None,
  489. past_key_values=past_key_values,
  490. use_cache=True,
  491. )
  492. if hasattr(outs, "logits"):
  493. # Returned outputs is `CausalLMOutputWithPast`
  494. return outs.logits
  495. else:
  496. # Returned the `last_hidden_state` from `BaseModelOutputWithPast`
  497. return outs.last_hidden_state
  498. @staticmethod
  499. def generate(
  500. exported_program: torch.export.ExportedProgram,
  501. prompt_token_ids: torch.Tensor,
  502. max_new_tokens: int,
  503. ) -> torch.Tensor:
  504. """
  505. Generate a sequence of tokens using an exported program.
  506. This util function is designed to test exported models by simulating the generation process.
  507. It processes the input prompt tokens sequentially (no parallel prefill).
  508. This generate function is not intended to replace the original `generate` method, and the support
  509. for leveraging the original `generate` is potentially planned!
  510. Args:
  511. exported_program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
  512. prompt_token_ids (`torch.Tensor`): Tensor representing the input prompt token IDs.
  513. max_new_tokens (`int`): Maximum number of new tokens to generate. Note that the total generation
  514. length is limited by both `max_new_tokens` and the model's cache size.
  515. Returns:
  516. torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
  517. """
  518. device = prompt_token_ids.device
  519. prompt_token_len = prompt_token_ids.shape[-1]
  520. max_generation_length = prompt_token_len + max_new_tokens
  521. for buffer_name, buffer in exported_program.named_buffers():
  522. if buffer_name.startswith("key_cache"):
  523. max_cache_len = buffer.shape[2]
  524. max_generation_length = min(max_generation_length, max_cache_len)
  525. break
  526. response_tokens = []
  527. for input_pos in range(min(max_generation_length, prompt_token_len)):
  528. result = exported_program.module().forward(
  529. input_ids=prompt_token_ids[:, input_pos : input_pos + 1],
  530. cache_position=torch.tensor([input_pos], dtype=torch.long, device=device),
  531. )
  532. response_tokens.append(prompt_token_ids[0][input_pos].item())
  533. current_token = torch.argmax(result[:, -1, :], dim=-1).item()
  534. response_tokens.append(current_token)
  535. while len(response_tokens) < max_generation_length:
  536. result = exported_program.module().forward(
  537. input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device),
  538. cache_position=torch.tensor([len(response_tokens)], dtype=torch.long, device=device),
  539. )
  540. current_token = torch.argmax(result[:, -1, :], dim=-1).item()
  541. response_tokens.append(current_token)
  542. return torch.tensor([response_tokens], dtype=torch.long, device=device)
  543. class TorchExportableModuleWithHybridCache(torch.nn.Module):
  544. """
  545. A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
  546. specifically for decoder-only LM to hybrid `StaticCache`. This module ensures that the
  547. exported model is compatible with further lowering and execution in `ExecuTorch`.
  548. """
  549. def __init__(
  550. self,
  551. model: PreTrainedModel,
  552. batch_size: Optional[int] = None,
  553. max_cache_len: Optional[int] = None,
  554. device: Optional[torch.device] = None,
  555. ) -> None:
  556. """
  557. Initializes the exportable module.
  558. Args:
  559. model (`PreTrainedModel`): The pretrained model to wrap.
  560. batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
  561. in `generation_config.cache_config` and otherwise we raise a ValueError.
  562. max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
  563. not provided.
  564. device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
  565. in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).
  566. Raises:
  567. AssertionError: If the model doesn't have the expected configuration for hybrid StaticCache.
  568. ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
  569. """
  570. super().__init__()
  571. self.model = model
  572. config = model.config.get_text_config()
  573. generation_config = model.generation_config
  574. # Sanity checks
  575. if generation_config is None:
  576. raise AssertionError(
  577. "The model must have a generation config to be exported with static caching. "
  578. "Please set `generation_config` in `model`."
  579. )
  580. if not config.use_cache:
  581. raise AssertionError("Model must have caching enabled.")
  582. cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
  583. # Ensure batch_size and max_cache_len are set
  584. if batch_size is None:
  585. batch_size = cache_config.get("batch_size", None)
  586. if batch_size is None:
  587. raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
  588. if max_cache_len is None:
  589. max_cache_len = cache_config.get("max_cache_len", None)
  590. if max_cache_len is None:
  591. raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
  592. # Infer device if not provided
  593. if device is None:
  594. device = cache_config.get("device", model.device)
  595. # Initialize the cache
  596. self.cache = StaticCache(config=config, max_cache_len=max_cache_len)
  597. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  598. num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
  599. dtype = self.model.dtype
  600. # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
  601. self.cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
  602. # Register all key and value cache tensors as buffers
  603. for i in range(len(self.cache)):
  604. self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False)
  605. self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False)
  606. def forward(
  607. self,
  608. input_ids: Optional[torch.LongTensor] = None,
  609. inputs_embeds: Optional[torch.Tensor] = None,
  610. cache_position: Optional[torch.Tensor] = None,
  611. ) -> torch.Tensor:
  612. """
  613. Forward pass of the module, which is compatible with the ExecuTorch llm runner.
  614. Args:
  615. input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
  616. inputs_embeds (`Optional[torch.Tensor]`): Tensor representing current input embeddings to the module.
  617. cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
  618. Returns:
  619. torch.Tensor: Logits output from the model.
  620. """
  621. # Forward pass with the model
  622. outputs = self.model(
  623. input_ids=input_ids,
  624. inputs_embeds=inputs_embeds,
  625. cache_position=cache_position,
  626. attention_mask=None,
  627. past_key_values=self.cache,
  628. use_cache=True,
  629. )
  630. # Return only the logits to simplify the export
  631. return outputs.logits
  632. def convert_and_export_with_cache(
  633. model: PreTrainedModel,
  634. example_input_ids: Optional[torch.Tensor] = None,
  635. example_cache_position: Optional[torch.Tensor] = None,
  636. dynamic_shapes: Optional[dict] = None,
  637. strict: Optional[bool] = None,
  638. ):
  639. """
  640. Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
  641. ensuring the exported model is compatible with `ExecuTorch`.
  642. Args:
  643. model (`PreTrainedModel`): The pretrained model to be exported.
  644. example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
  645. example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
  646. dynamic_shapes(`Optional[dict]`): Dynamic shapes used by `torch.export`.
  647. strict(`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`.
  648. Returns:
  649. Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
  650. """
  651. if not is_torch_greater_or_equal_than_2_3:
  652. raise ImportError("torch >= 2.3 is required.")
  653. import torch.export._trace
  654. # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
  655. ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
  656. ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
  657. model.config._attn_implementation = "sdpa_without_vmap"
  658. with torch.no_grad():
  659. # TODO: The default inputs only work for text models. We need to add support for vision/audio models.
  660. example_input_ids = (
  661. example_input_ids
  662. if example_input_ids is not None
  663. else torch.tensor([[1]], dtype=torch.long, device=model.device)
  664. )
  665. example_cache_position = (
  666. example_cache_position
  667. if example_cache_position is not None
  668. else torch.tensor([0], dtype=torch.long, device=model.device)
  669. )
  670. if is_torch_greater_or_equal("2.6.0"):
  671. exported_program = torch.export.export(
  672. TorchExportableModuleWithStaticCache(model),
  673. args=(),
  674. kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
  675. dynamic_shapes=dynamic_shapes,
  676. strict=strict if strict is not None else True,
  677. )
  678. else:
  679. if dynamic_shapes is not None:
  680. logging.warning(
  681. "Dynamic shapes spec will be ignored by convert_and_export_with_cache for torch < 2.6.0."
  682. )
  683. if strict is not None:
  684. logging.warning("The strict flag will be ignored by convert_and_export_with_cache for torch < 2.6.0.")
  685. # We have to keep this path for BC.
  686. #
  687. # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
  688. # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
  689. exported_program = torch.export._trace._export(
  690. TorchExportableModuleWithStaticCache(model),
  691. args=(),
  692. kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
  693. pre_dispatch=False,
  694. strict=True,
  695. )
  696. return exported_program
  697. class Seq2SeqLMEncoderExportableModule(torch.nn.Module):
  698. """
  699. A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`.
  700. This module ensures that the exported encoder model is compatible with ExecuTorch.
  701. """
  702. def __init__(self, encoder_model):
  703. super().__init__()
  704. self.encoder = encoder_model
  705. def forward(self, input_ids):
  706. return self.encoder(input_ids=input_ids).last_hidden_state
  707. class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
  708. """
  709. A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`,
  710. specifically for use with static caching. This module ensures the exported decoder
  711. is compatible with ExecuTorch.
  712. """
  713. def __init__(self, model, max_static_cache_length, batch_size):
  714. super().__init__()
  715. # Get the decoder component
  716. self.decoder = model.get_decoder()
  717. self.lm_head = model.lm_head
  718. self.config = model.config
  719. # Detect the device of the exported models by checking a parameter
  720. # We'll use the model's device as the target device
  721. model_device = next(model.parameters()).device
  722. # Initialize static cache for decoder and DynamicCache for encoder
  723. self.static_cache = StaticCache(config=self.config, max_cache_len=max_static_cache_length)
  724. head_dim = getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads)
  725. num_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
  726. self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model_device)
  727. self.cache = EncoderDecoderCache(self.static_cache, DynamicCache(config=self.config))
  728. register_dynamic_cache_export_support()
  729. # Register cache buffers to make them exportable
  730. for i in range(len(self.static_cache)):
  731. self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
  732. self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
  733. def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
  734. # Get outputs from decoder
  735. outputs = self.decoder(
  736. input_ids=decoder_input_ids,
  737. encoder_hidden_states=encoder_hidden_states,
  738. past_key_values=self.cache,
  739. use_cache=True,
  740. cache_position=cache_position,
  741. )
  742. # Apply language model head
  743. lm_logits = self.lm_head(outputs[0])
  744. return lm_logits
  745. class Seq2SeqLMExportableModule(torch.nn.Module):
  746. def __init__(
  747. self, model, batch_size=1, max_hidden_seq_length=4096, cache_implementation="static", max_cache_length=1024
  748. ):
  749. super().__init__()
  750. self.full_model = model
  751. self.encoder = model.get_encoder()
  752. self.config = model.config
  753. self.max_hidden_seq_length = max_hidden_seq_length
  754. self.generation_config = GenerationConfig(
  755. use_cache=True,
  756. max_length=max_cache_length,
  757. cache_implementation=cache_implementation,
  758. cache_config={
  759. "batch_size": batch_size,
  760. "max_cache_len": max_cache_length,
  761. },
  762. )
  763. self.exported_encoder = None
  764. self.exported_decoder = None
  765. def _export_encoder(self, encoder_input_ids):
  766. wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to(self.full_model.device).eval()
  767. # Define dynamic sequence length for encoder
  768. seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
  769. # Export the encoder
  770. with torch.no_grad():
  771. exported_encoder = torch.export.export(
  772. wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
  773. )
  774. return exported_encoder
  775. def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
  776. target_device = self.full_model.device
  777. wrapped_decoder = (
  778. Seq2SeqLMDecoderExportableModuleWithStaticCache(
  779. model=self.full_model,
  780. max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
  781. batch_size=self.generation_config.cache_config.get("batch_size"),
  782. )
  783. .to(target_device)
  784. .eval()
  785. )
  786. # Move input tensors to the same device as the wrapped decoder
  787. decoder_input_ids = decoder_input_ids.to(target_device)
  788. encoder_hidden_states = encoder_hidden_states.to(target_device)
  789. cache_position = cache_position.to(target_device)
  790. # Define dynamic dimension for encoder output sequence length
  791. encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
  792. # Export the decoder
  793. with torch.no_grad():
  794. exported_decoder = torch.export.export(
  795. wrapped_decoder,
  796. (decoder_input_ids, encoder_hidden_states, cache_position),
  797. dynamic_shapes={
  798. "decoder_input_ids": None,
  799. "encoder_hidden_states": {1: encoder_seq_len_dim},
  800. "cache_position": None,
  801. },
  802. strict=True,
  803. )
  804. return exported_decoder
  805. def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None):
  806. device = self.full_model.device
  807. example_encoder_input_ids = (
  808. encoder_input_ids
  809. if encoder_input_ids is not None
  810. else torch.ones((1, 10), dtype=torch.long, device=device)
  811. )
  812. example_decoder_input_ids = (
  813. decoder_input_ids
  814. if decoder_input_ids is not None
  815. else torch.tensor([[0]], dtype=torch.long, device=device)
  816. ) # Start token
  817. example_cache_position = (
  818. cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=device)
  819. )
  820. example_encoder_hidden_states = (
  821. encoder_hidden_states
  822. if encoder_hidden_states is not None
  823. else torch.zeros(
  824. (self.generation_config.cache_config.get("batch_size"), 10, self.config.d_model),
  825. dtype=torch.float32,
  826. device=device,
  827. )
  828. )
  829. self.exported_encoder = self._export_encoder(example_encoder_input_ids)
  830. self.exported_decoder = self._export_decoder(
  831. example_decoder_input_ids, example_encoder_hidden_states, example_cache_position
  832. )
  833. # Return self to allow chaining
  834. return self
  835. def generate(self, prompt_token_ids, max_new_tokens):
  836. with torch.no_grad():
  837. model_device = self.full_model.device
  838. # Move input to the model's device if it's on a different device
  839. if prompt_token_ids.device != model_device:
  840. prompt_token_ids = prompt_token_ids.to(model_device)
  841. # Run encoder
  842. encoder_output = self.exported_encoder.module()(prompt_token_ids)
  843. # Initialize with start token (0 for T5) on the correct device
  844. decoder_input_ids = torch.tensor([[0]], dtype=torch.long, device=model_device)
  845. generated_ids = [0]
  846. # Generate tokens one by one
  847. for i in range(max_new_tokens - 1):
  848. # Run decoder for next token prediction
  849. logits = self.exported_decoder.module()(
  850. decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long, device=model_device)
  851. )
  852. # Get next token
  853. next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
  854. generated_ids.append(next_token)
  855. # Update input for next iteration on the correct device
  856. decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long, device=model_device)
  857. # Check if EOS token
  858. if next_token == self.config.eos_token_id:
  859. break
  860. return generated_ids
  861. def export_with_dynamic_cache(
  862. model: PreTrainedModel,
  863. example_input_ids: Optional[torch.Tensor] = None,
  864. example_attention_mask: Optional[torch.Tensor] = None,
  865. ):
  866. """
  867. Export a model with DynamicCache using `torch.export`, ensuring the exported model is compatible with `ExecuTorch`.
  868. Args:
  869. model (`PreTrainedModel`): The pretrained model to be exported.
  870. example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
  871. example_attention_mask (`Optional[torch.Tensor]`): Example attention mask used by `torch.export`.
  872. Returns:
  873. Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
  874. """
  875. if not is_torch_greater_or_equal_than_2_3:
  876. raise ImportError("torch >= 2.3 is required.")
  877. # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
  878. ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
  879. ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
  880. model.config._attn_implementation = "sdpa_without_vmap"
  881. register_dynamic_cache_export_support()
  882. with torch.no_grad():
  883. exported_program = torch.export.export(
  884. model,
  885. (),
  886. {
  887. "input_ids": example_input_ids,
  888. "attention_mask": example_attention_mask,
  889. "past_key_values": DynamicCache(config=model.config),
  890. "use_cache": True,
  891. },
  892. strict=False,
  893. )
  894. return exported_program
  895. def register_dynamic_cache_export_support():
  896. """
  897. Utilities for `DynamicCache` <> torch.export support
  898. """
  899. try:
  900. torch.utils._pytree.register_pytree_node(
  901. DynamicCache,
  902. lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
  903. _unflatten_dynamic_cache,
  904. serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
  905. flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
  906. _get_cache_dict(dynamic_cache)
  907. ),
  908. )
  909. # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
  910. torch.fx._pytree.register_pytree_flatten_spec(
  911. DynamicCache,
  912. lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec),
  913. )
  914. # Catching this in case there are multiple runs for some test runs
  915. except ValueError as e:
  916. if "already registered as pytree node" not in str(e):
  917. raise
  918. def _get_cache_dict(cache: DynamicCache):
  919. """Convert cache to dictionary format for pytree operations."""
  920. if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
  921. raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
  922. if not is_torch_greater_or_equal_than_2_6:
  923. logging.warning("DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions.")
  924. return {
  925. "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
  926. "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
  927. }
  928. def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
  929. dictionary = torch.utils._pytree._dict_unflatten(values, context)
  930. cache = DynamicCache()
  931. # Reconstruct layers from keys and values lists
  932. key_list = dictionary.get("key_cache", [])
  933. value_list = dictionary.get("value_cache", [])
  934. for idx in range(max(len(key_list), len(value_list))):
  935. key = key_list[idx] if idx < len(key_list) else None
  936. value = value_list[idx] if idx < len(value_list) else None
  937. cache.update(key, value, idx)
  938. return cache
  939. def sdpa_mask_without_vmap(
  940. batch_size: int,
  941. cache_position: torch.Tensor,
  942. kv_length: int,
  943. kv_offset: int = 0,
  944. mask_function: Optional[Callable] = None,
  945. attention_mask: Optional[torch.Tensor] = None,
  946. local_size: Optional[int] = None,
  947. allow_is_causal_skip: bool = True,
  948. allow_torch_fix: bool = True,
  949. **kwargs,
  950. ) -> Optional[torch.Tensor]:
  951. """
  952. Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
  953. the element should take part in the attention computation, and False that it should not.
  954. This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export.
  955. Args:
  956. batch_size (`int`):
  957. The batch size of the input sequence.
  958. cache_position (`torch.Tensor`):
  959. A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
  960. kv_length (`int`):
  961. The size that the key and value states will have during the attention computation.
  962. kv_offset (`int`, optional):
  963. An optional offset to indicate at which first position the key and values states will refer to.
  964. mask_function (`Callable`):
  965. The mask factory function describing the mask pattern.
  966. attention_mask (`torch.Tensor`, optional):
  967. The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
  968. local_size (`int`, optional):
  969. The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
  970. to try to skip mask creation if possible.
  971. allow_is_causal_skip (`bool`, optional):
  972. Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
  973. `torch.sdpa` instead. Default to `True`.
  974. allow_torch_fix (`bool`, optional):
  975. Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
  976. versions. We need an arg to skip it when using eager. By default `True`.
  977. """
  978. q_length = cache_position.shape[0]
  979. # Potentially pad the 2D mask, and slice it correctly
  980. padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
  981. # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
  982. if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
  983. return None
  984. # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
  985. # but without data-dependent slicing (i.e. torch.compile friendly)
  986. kv_arange = torch.arange(kv_length, device=cache_position.device)
  987. kv_arange += kv_offset
  988. reshaped_cache_position = cache_position.view(-1, 1)
  989. # This is a bit hacky to know what pattern we are using, but all mask creation function actually forward
  990. # the config through kwargs anyway, so it allows to rely on it
  991. # Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it,
  992. # but this is more efficient
  993. sliding_window = getattr(kwargs["config"], "sliding_window", None)
  994. chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
  995. if sliding_window is not None and chunk_size is not None:
  996. raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
  997. # Simplest and most efficient way to obtain a causal mask
  998. causal_mask = kv_arange <= reshaped_cache_position
  999. # If using sliding window, add the sliding mask
  1000. if sliding_window is not None:
  1001. sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
  1002. causal_mask *= sliding_mask_overlay
  1003. # If using chunk attention, add the chunked mask
  1004. elif chunk_size is not None:
  1005. chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
  1006. causal_mask *= chunked_mask_overlay
  1007. causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
  1008. if padding_mask is not None:
  1009. causal_mask = causal_mask * padding_mask[:, None, None, :]
  1010. # Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
  1011. # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
  1012. if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
  1013. causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
  1014. return causal_mask