modular_blt.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015
  1. # coding=utf-8
  2. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Blt modular model, inheriting from Mllama where appropriate."""
  16. from typing import Callable, Optional, Union
  17. import torch
  18. import torch.distributions
  19. import torch.nn as nn
  20. import torch.nn.functional as F
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...masking_utils import create_causal_mask
  23. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  25. from ...processing_utils import Unpack
  26. from ...utils import TransformersKwargs, auto_docstring, logging
  27. from ...utils.generic import OutputRecorder, check_model_inputs
  28. from ..cohere2.modeling_cohere2 import (
  29. Cohere2RotaryEmbedding,
  30. rotate_half, # noqa: F401
  31. )
  32. from ..mllama.modeling_mllama import (
  33. MllamaForCausalLM,
  34. MllamaPreTrainedModel,
  35. MllamaSelfAttentionDecoderLayer,
  36. MllamaTextCrossAttention,
  37. MllamaTextMLP,
  38. MllamaTextRMSNorm,
  39. MllamaTextSelfAttention,
  40. eager_attention_forward,
  41. )
  42. from .configuration_blt import (
  43. BltConfig,
  44. BltGlobalTransformerConfig,
  45. BltLocalDecoderConfig,
  46. BltLocalEncoderConfig,
  47. BltPatcherConfig,
  48. )
  49. logger = logging.get_logger(__name__)
  50. def rolling_polynomial_hash(token_tensor, prime: int = 1000000007):
  51. """
  52. A polynomial rolling hash algorithm that converts sequences
  53. of tokens into hash values. The hash is computed as:
  54. hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n)
  55. The rolling hash allows the model to efficiently
  56. identify and encode recurring byte-level patterns in the input text.
  57. Args:
  58. token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash
  59. prime (int): Prime number used as the base for the polynomial hash.
  60. Returns:
  61. torch.Tensor: Hash values of shape [batch_size, seq_len] where each value
  62. represents the hash of the corresponding token group
  63. Example:
  64. >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])
  65. >>> hashes = rolling_polynomial_hash(tokens, prime=31)
  66. >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2
  67. >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2
  68. """
  69. prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device)
  70. powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
  71. prime_powers = prime_tensor**powers
  72. return torch.sum(token_tensor * prime_powers, dim=-1)
  73. def byte_group_hash_function(
  74. token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000
  75. ):
  76. """Hash token groups and map to range [0, max_hash]."""
  77. with torch.no_grad():
  78. batch_size, seq_len = token_ids.shape
  79. # Add padding for sliding window
  80. padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
  81. padded_tokens = torch.cat([padding, token_ids], dim=1)
  82. # Create sliding windows and compute hashes
  83. windows = padded_tokens.unfold(1, group_size, 1)
  84. hashes = rolling_polynomial_hash(windows, prime)
  85. hash_values = hashes % max_hash
  86. return hash_values
  87. def compute_hash_embeddings(
  88. local_encoder_tokens: torch.Tensor,
  89. local_encoder,
  90. encoder_hash_tok_embedding: nn.Embedding,
  91. encoder_hash_byte_group_nb_functions: int,
  92. encoder_hash_byte_group_size: list,
  93. encoder_hash_byte_group_vocab: int,
  94. ) -> torch.Tensor:
  95. """Compute token embeddings enhanced with hash-based embeddings."""
  96. # Available primes for hash functions
  97. primes = [
  98. 1000000007,
  99. 5915587277,
  100. 1500450271,
  101. 3267000013,
  102. 5754853343,
  103. 4093082899,
  104. 9576890767,
  105. 3628273133,
  106. 2860486313,
  107. 5463458053,
  108. 3367900313,
  109. ]
  110. embeddings = local_encoder.embed_tokens(local_encoder_tokens)
  111. embedding_idx = 0
  112. for func_nb in range(encoder_hash_byte_group_nb_functions):
  113. prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes
  114. for group_size in encoder_hash_byte_group_size:
  115. hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
  116. # Apply offset to get the correct slice of the fused embedding
  117. offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
  118. embeddings += encoder_hash_tok_embedding(offset_hash_ids)
  119. embedding_idx += 1
  120. return embeddings
  121. def _prepare_patch_cross_attention_mask(
  122. patch_ids: torch.Tensor,
  123. num_patches: int,
  124. sequence_length: int,
  125. patches_as_queries: bool = False,
  126. cross_attn_k: int = 1,
  127. dtype: torch.dtype = torch.float32,
  128. ) -> tuple[torch.Tensor, torch.Tensor]:
  129. """
  130. Prepare cross-attention mask for patch-based attention, following mllama's robust approach.
  131. This function creates masks that control which patches can attend to which other patches,
  132. with support for query/key role swapping and cross-attention multipliers.
  133. Args:
  134. patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
  135. num_patches (int): Total number of patches.
  136. sequence_length (int): Length of the sequence.
  137. patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
  138. cross_attn_k (int): Cross-attention multiplier for repeating patches.
  139. dtype (torch.dtype): Data type for the output mask.
  140. Returns:
  141. Tuple[torch.Tensor, torch.Tensor]:
  142. - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len]
  143. """
  144. batch_size, seq_len = patch_ids.shape
  145. device = patch_ids.device
  146. # Determine query and key lengths based on configuration
  147. if patches_as_queries:
  148. q_len = num_patches * cross_attn_k
  149. kv_len = sequence_length
  150. # Create patch-to-sequence mapping
  151. q_patch_ids = (
  152. torch.arange(num_patches, device=device)
  153. .unsqueeze(0)
  154. .unsqueeze(-1)
  155. .expand(batch_size, num_patches, seq_len)
  156. )
  157. kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len)
  158. else:
  159. q_len = sequence_length
  160. kv_len = num_patches * cross_attn_k
  161. # Create sequence-to-patch mapping
  162. q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches)
  163. kv_patch_ids = (
  164. torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches)
  165. )
  166. # Create base attention mask - boolean mask where True means "should attend"
  167. # Exact patch matching
  168. cross_attention_mask = q_patch_ids == kv_patch_ids
  169. # Handle cross_attn_k multiplier by repeating along appropriate dimension
  170. repeat_dim = 1 if patches_as_queries else -1
  171. cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim)
  172. # Validate dimensions
  173. expected_shape = (batch_size, q_len, kv_len)
  174. if cross_attention_mask.shape != expected_shape:
  175. raise ValueError(
  176. f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}"
  177. )
  178. # Reshape so it can be used by attn module - add head dimension
  179. cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len]
  180. # Invert the mask (following mllama pattern exactly)
  181. # True -> 0.0 (attend), False -> 1.0 (will become -inf)
  182. inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype)
  183. cross_attention_mask = inverted_cross_attn_mask.masked_fill(
  184. inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
  185. )
  186. return cross_attention_mask
  187. def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: Optional[int]) -> torch.Tensor:
  188. """
  189. Splits patch lengths into smaller segments if they exceed `max_patch_length`.
  190. Pads the result to uniform length across the batch.
  191. Args:
  192. patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
  193. max_patch_length (int, optional): Maximum allowed length per patch.
  194. Returns:
  195. torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
  196. """
  197. if max_patch_length is None:
  198. return patch_lengths
  199. batch_size = patch_lengths.size(0)
  200. processed = []
  201. for seq in patch_lengths:
  202. splits = []
  203. for length in seq[seq > 0]:
  204. length = length.item()
  205. full_chunks, remainder = divmod(length, max_patch_length)
  206. splits.extend([max_patch_length] * full_chunks)
  207. if remainder:
  208. splits.append(remainder)
  209. processed.append(splits)
  210. # Find max length to pad to
  211. max_len = max(len(splits) for splits in processed)
  212. padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
  213. for i, splits in enumerate(processed):
  214. if splits:
  215. padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
  216. # Trim zero columns
  217. if (padded != 0).any(dim=0).sum() < padded.shape[1]:
  218. last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
  219. padded = padded[:, :last_nonzero]
  220. return padded
class BltMLP(MllamaTextMLP):
    """Feed-forward block for Blt; reuses `MllamaTextMLP` unchanged."""

    pass
class BltRMSNorm(MllamaTextRMSNorm):
    """RMS normalisation layer for Blt; reuses `MllamaTextRMSNorm` unchanged."""

    pass
class BltRotaryEmbedding(Cohere2RotaryEmbedding):
    """Rotary position embedding for Blt; reuses `Cohere2RotaryEmbedding` unchanged."""

    pass
class BltTransformerLayer(MllamaSelfAttentionDecoderLayer):
    """
    Self-attention transformer layer used by the Blt models.

    The forward pass is inherited from `MllamaSelfAttentionDecoderLayer`. This class only swaps in the
    Blt attention, MLP and RMSNorm submodules.
    """

    def __init__(self, config, layer_idx: int):
        # NOTE(review): the argument-less `super().__init__()` presumably relies on the
        # modular-transformers conversion of the parent `__init__`; confirm in the generated file.
        super().__init__()
        self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx)
        self.mlp = BltMLP(config)
        # Norm placement around attention / MLP is defined by the inherited forward
        self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class BltSelfAttention(MllamaTextSelfAttention):
    """Causal self-attention for Blt; the computation is delegated to `MllamaTextSelfAttention`."""

    def __init__(self, config: BltConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Blt self-attention is always causal
        self.is_causal = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_embeddings: torch.Tensor,
        use_cache: bool = False,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Pass every argument unchanged, by keyword, to `MllamaTextSelfAttention.forward`."""
        return super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            use_cache=use_cache,
            past_key_values=past_key_values,
            cache_position=cache_position,
            **kwargs,
        )
class BltCrossAttention(MllamaTextCrossAttention):
    """
    Non-causal cross-attention module for Blt.

    Queries come from `hidden_states`; keys and values come from `cross_attention_states`. Both inputs
    are RMS-normalised (`q_norm` / `k_norm`) before projection. The returned output already has the
    un-normalised `hidden_states` added back as a residual.
    """

    def __init__(self, config: BltConfig, layer_idx: int, hidden_size: Optional[int] = None):
        # NOTE(review): `super().__init__()` is called without `config` / `layer_idx`, and the
        # `hidden_size` argument is never read here. `self.hidden_size`, the q/k/v/o projections,
        # `self.config` and `self.layer_idx` are presumably provided by the modular-transformers
        # expansion of the parent `__init__`; confirm in the generated file.
        super().__init__()
        self.is_causal = False
        self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Attend from `hidden_states` to `cross_attention_states`.

        Args:
            hidden_states: query-side input of rank 3, ``(bsz, q_len, hidden)``.
            cross_attention_states: key/value-side input. If None, this layer's cached keys/values are
                read instead when ``cache_position[0] != 0``; otherwise a ValueError is raised.
            past_key_values: optional cache. When given together with `cross_attention_states`, the
                freshly computed keys/values are passed through `past_key_values.update`.
            attention_mask: mask forwarded unchanged to the attention implementation.
            cache_position: only read when `cross_attention_states` is None.

        Returns:
            tuple: ``(o_proj(attention) + hidden_states, attn_weights)``.
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_norm(hidden_states)
        query_states = self.q_proj(query_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        if cross_attention_states is not None:
            cross_attention_states = self.k_norm(cross_attention_states)
            key_states = self.k_proj(cross_attention_states)
            value_states = self.v_proj(cross_attention_states)
            key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            if past_key_values is not None:
                key_states, value_states = past_key_values.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
        elif cache_position[0] != 0:
            # NOTE(review): this branch assumes `cache_position` is not None and that `past_key_values`
            # is a populated cache; neither is checked here.
            key_states, value_states = (
                past_key_values.layers[self.layer_idx].keys,
                past_key_values.layers[self.layer_idx].values,
            )
        else:
            raise ValueError(
                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
            )
        # Eager attention by default; any other configured implementation is looked up by name
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        # Residual connection on the un-normalised query input
        attn_output = attn_output + hidden_states
        return attn_output, attn_weights
@auto_docstring
class BltPreTrainedModel(MllamaPreTrainedModel):
    """Common base class for the Blt models, overriding Mllama's configuration and capability flags."""

    config: BltConfig
    # Generic attention backends, FlashAttention and flex attention are declared unsupported
    _supports_attention_backend = False
    _supports_flash_attn = False
    _supports_flex_attn = False
    _no_split_modules = ["BltTransformerLayer"]
    # Output name -> recorder. `index` selects an element of the module's output tuple; `layer_name`
    # presumably restricts capture to modules under that attribute (here the local decoder) - confirm
    # against `OutputRecorder`.
    _can_record_outputs = {
        "hidden_states": OutputRecorder(BltTransformerLayer, index=0, layer_name="local_decoder"),
        "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
    }

    # NOTE(review): the overrides below raise unconditionally. In modular-transformers files this
    # pattern appears to mark inherited Mllama methods as not wanted in the generated model; confirm.
    def _init_weights(self, module):
        raise AttributeError("No need to inherit it!")

    def _update_causal_mask(self, module):
        raise AttributeError("No need to inherit it!")

    def _prepare_4d_causal_attention_mask_with_cache_position(self, module):
        raise AttributeError("No need to inherit it!")
  330. class BltLocalEncoder(BltPreTrainedModel):
  331. config: BltLocalEncoderConfig
  332. _can_record_outputs = {
  333. "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"),
  334. }
  335. def __init__(self, config: BltLocalEncoderConfig):
  336. super().__init__(config)
  337. self.gradient_checkpointing = False
  338. self.config = config
  339. self.layers = nn.ModuleList(
  340. [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  341. )
  342. self.rotary_emb = BltRotaryEmbedding(config=config)
  343. self.patch_embedding_projection = nn.Linear(
  344. in_features=config.hidden_size,
  345. out_features=config.hidden_size * config.cross_attn_k,
  346. bias=False,
  347. )
  348. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
  349. self.cross_attn_layers = nn.ModuleList()
  350. layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
  351. for layer_idx in range(layers_to_add):
  352. self.cross_attn_layers.append(
  353. BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
  354. )
  355. self.post_init()
  356. def forward(
  357. self,
  358. input_ids: Optional[torch.LongTensor] = None,
  359. inputs_embeds: Optional[torch.Tensor] = None,
  360. patch_embeds: Optional[torch.Tensor] = None,
  361. attention_mask: Optional[torch.Tensor] = None,
  362. position_ids: Optional[torch.LongTensor] = None,
  363. past_key_values: Optional[Cache] = None,
  364. cache_position: Optional[torch.LongTensor] = None,
  365. encoder_attention_mask: Optional[torch.Tensor] = None,
  366. num_patches: Optional[int] = None,
  367. patch_ids: Optional[torch.Tensor] = None,
  368. **kwargs: Unpack[TransformersKwargs],
  369. ):
  370. if inputs_embeds is None:
  371. inputs_embeds = self.embed_tokens(input_ids)
  372. batch_size = inputs_embeds.shape[0]
  373. hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training)
  374. if position_ids is None:
  375. position_ids = (
  376. torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
  377. )
  378. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  379. hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
  380. for idx, layer in enumerate(self.layers):
  381. hidden_states = layer(
  382. hidden_states,
  383. position_embeddings=position_embeddings,
  384. attention_mask=attention_mask,
  385. past_key_values=past_key_values,
  386. cache_position=cache_position,
  387. **kwargs,
  388. )
  389. if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers:
  390. patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids)
  391. patch_embeds = self.patch_embedding_projection(patch_embeds)
  392. patch_embeds = patch_embeds.reshape(
  393. batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
  394. )
  395. layer_idx = idx if self.config.cross_attn_all_layers else 0
  396. cross_attention_output, _ = self.cross_attn_layers[layer_idx](
  397. hidden_states=patch_embeds,
  398. cross_attention_states=hidden_states,
  399. attention_mask=encoder_attention_mask,
  400. **kwargs,
  401. )
  402. patch_embeds = patch_embeds + cross_attention_output
  403. encoder_cross_states = patch_embeds
  404. return hidden_states, encoder_cross_states
  405. def patch_reduce(self, hidden_states, max_num_patches, patch_ids):
  406. """
  407. Reduce variable length patches to single embedding per patch
  408. Note: this works with variable number of patches for different sequences in the batch
  409. It handles variable length patches by assuming that patch_lengths will be 0 for any
  410. extra patches on the *right*. Since there can be a variable number of patches
  411. this function also return the number of patches for each sequence in the batch.
  412. Any embeddings on the right that are not allocated to a patch
  413. (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
  414. will be sent to a dummy patch, which is trimmed before returning.
  415. """
  416. batch_size = hidden_states.shape[0]
  417. embedding_dim = hidden_states.shape[-1]
  418. patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
  419. reduced_embeddings = torch.zeros(
  420. (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device
  421. )
  422. reduced_embeddings = reduced_embeddings.scatter_reduce(
  423. src=hidden_states,
  424. dim=1,
  425. index=patch_ids,
  426. reduce="amax",
  427. include_self=False,
  428. )
  429. reduced_embeddings = reduced_embeddings[:, :max_num_patches, :]
  430. return reduced_embeddings
  431. class BltLocalDecoder(BltPreTrainedModel):
  432. config: BltLocalDecoderConfig
  433. def __init__(self, config: BltLocalDecoderConfig):
  434. super().__init__(config)
  435. self.gradient_checkpointing = False
  436. self.config = config
  437. self.cross_attn_decoder = True
  438. self.layers = nn.ModuleList(
  439. [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  440. )
  441. self.rotary_emb = BltRotaryEmbedding(config=config)
  442. self.patch_embedding_projection = nn.Linear(
  443. in_features=config.hidden_size_global,
  444. out_features=config.hidden_size * config.cross_attn_k,
  445. bias=False,
  446. )
  447. self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  448. self.cross_attn_layers = nn.ModuleList()
  449. layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
  450. for layer_idx in range(layers_to_add):
  451. self.cross_attn_layers.append(
  452. BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
  453. )
  454. self.post_init()
  455. @check_model_inputs()
  456. def forward(
  457. self,
  458. input_ids: Optional[torch.LongTensor] = None,
  459. inputs_embeds: Optional[torch.Tensor] = None,
  460. patch_embeds: Optional[torch.Tensor] = None,
  461. attention_mask: Optional[torch.Tensor] = None,
  462. position_ids: Optional[torch.LongTensor] = None,
  463. past_key_values: Optional[Cache] = None,
  464. cache_position: Optional[torch.LongTensor] = None,
  465. encoder_attention_mask: Optional[torch.Tensor] = None,
  466. **kwargs: Unpack[TransformersKwargs],
  467. ):
  468. batch_size = inputs_embeds.shape[0]
  469. hidden_states = inputs_embeds
  470. patch_embeds = self.patch_embedding_projection(patch_embeds)
  471. patch_embeds = patch_embeds.reshape(
  472. batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
  473. )
  474. if patch_embeds is not None and not self.cross_attn_decoder:
  475. hidden_states = hidden_states + patch_embeds
  476. if position_ids is None:
  477. position_ids = (
  478. torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
  479. )
  480. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  481. hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
  482. for i, layer in enumerate(self.layers):
  483. if i == 0 or self.config.cross_attn_all_layers:
  484. cross_attention_output, _ = self.cross_attn_layers[i](
  485. hidden_states=hidden_states,
  486. cross_attention_states=patch_embeds,
  487. attention_mask=encoder_attention_mask,
  488. **kwargs,
  489. )
  490. hidden_states = hidden_states + cross_attention_output
  491. hidden_states = layer(
  492. hidden_states,
  493. position_embeddings=position_embeddings,
  494. attention_mask=attention_mask,
  495. past_key_values=past_key_values,
  496. cache_position=cache_position,
  497. **kwargs,
  498. )
  499. logits = self.norm(hidden_states)
  500. return logits
  501. class BltGlobalTransformer(BltPreTrainedModel):
  502. config: BltGlobalTransformerConfig
  503. _can_record_outputs = {
  504. "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"),
  505. }
  506. def __init__(self, config: BltGlobalTransformerConfig):
  507. super().__init__(config)
  508. self.config = config
  509. self.layers = nn.ModuleList()
  510. for layer_idx in range(config.num_hidden_layers):
  511. self.layers.append(BltTransformerLayer(config, layer_idx))
  512. self.rotary_emb = BltRotaryEmbedding(config=config)
  513. # Create token embedding projection (use nn.Identity() when no projection needed)
  514. if getattr(config, "encoder_cross_output_size", None) is not None:
  515. self.token_embedding_projection = nn.Linear(
  516. config.encoder_cross_output_size, config.hidden_size, bias=False
  517. )
  518. else:
  519. self.token_embedding_projection = nn.Identity()
  520. self.post_init()
  521. def forward(
  522. self,
  523. input_embeds: torch.Tensor,
  524. attention_mask: Optional[torch.Tensor] = None,
  525. position_ids: Optional[torch.LongTensor] = None,
  526. past_key_values: Optional[Cache] = None,
  527. cache_position: Optional[torch.LongTensor] = None,
  528. **kwargs: Unpack[TransformersKwargs],
  529. ):
  530. batch_size, seq_len, _ = input_embeds.shape
  531. hidden_states = self.token_embedding_projection(input_embeds)
  532. hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
  533. if position_ids is None:
  534. position_ids = (
  535. torch.arange(input_embeds.shape[1], device=input_embeds.device).unsqueeze(0).expand(batch_size, -1)
  536. )
  537. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  538. for i, layer in enumerate(self.layers):
  539. hidden_states = layer(
  540. hidden_states,
  541. position_embeddings=position_embeddings,
  542. attention_mask=attention_mask,
  543. past_key_values=past_key_values,
  544. cache_position=cache_position,
  545. **kwargs,
  546. )
  547. return hidden_states
  548. class BltPatcher(BltPreTrainedModel):
  549. config: BltPatcherConfig
  550. def __init__(self, config: BltPatcherConfig):
  551. super().__init__(config)
  552. self.rotary_emb = BltRotaryEmbedding(config=self.config)
  553. self.layers = nn.ModuleList()
  554. for layer_idx in range(self.config.num_hidden_layers):
  555. self.layers.append(BltTransformerLayer(self.config, layer_idx))
  556. self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
  557. self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  558. self.lm_head = nn.Linear(
  559. self.config.hidden_size,
  560. self.config.vocab_size,
  561. bias=False,
  562. )
  563. def forward(
  564. self,
  565. input_ids: Optional[torch.LongTensor] = None,
  566. attention_mask: Optional[torch.Tensor] = None,
  567. position_ids: Optional[torch.LongTensor] = None,
  568. past_key_values: Optional[Cache] = None,
  569. inputs_embeds: Optional[torch.FloatTensor] = None,
  570. use_cache: Optional[bool] = None,
  571. cache_position: Optional[torch.LongTensor] = None,
  572. patch_size: Optional[int] = None,
  573. threshold: Optional[float] = None,
  574. max_patch_length: Optional[int] = None,
  575. **kwargs: Unpack[TransformersKwargs],
  576. ):
  577. if (input_ids is None) ^ (inputs_embeds is not None):
  578. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  579. if inputs_embeds is None:
  580. inputs_embeds = self.embed_tokens(input_ids)
  581. if use_cache and past_key_values is None:
  582. past_key_values = DynamicCache()
  583. if cache_position is None:
  584. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  585. cache_position = torch.arange(
  586. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  587. )
  588. if position_ids is None:
  589. position_ids = cache_position.unsqueeze(0)
  590. causal_mask = create_causal_mask(
  591. config=self.config,
  592. input_embeds=inputs_embeds,
  593. attention_mask=attention_mask,
  594. cache_position=cache_position,
  595. past_key_values=past_key_values,
  596. position_ids=position_ids,
  597. )
  598. hidden_states = inputs_embeds
  599. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  600. for layer in self.layers:
  601. hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)
  602. logits = self.lm_head(self.norm(hidden_states))
  603. prediction_entropies = torch.distributions.Categorical(logits=logits).entropy()
  604. batch_size, sequence_length = inputs_embeds.shape[:2]
  605. if patch_size is not None:
  606. patch_lengths = self.patch_lengths_from_entropies(
  607. entropies=prediction_entropies,
  608. sequence_length=sequence_length,
  609. patch_size=patch_size,
  610. threshold=threshold,
  611. )
  612. else:
  613. patch_lengths = torch.ones(
  614. (batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
  615. )
  616. patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
  617. return prediction_entropies, patch_lengths, logits
  618. @staticmethod
  619. def patch_lengths_from_entropies(
  620. entropies,
  621. sequence_length,
  622. patch_size=None,
  623. threshold=None,
  624. ):
  625. """
  626. Computes patch lengths from token entropies.
  627. Depending on whether a threshold is provided, the function uses either:
  628. - Thresholding the entropy values (when `threshold` is set).
  629. """
  630. batch_size = entropies.shape[0]
  631. # Always include token 0 and 1 as starting tokens
  632. init_tokens = (
  633. torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1)
  634. )
  635. offset = init_tokens.shape[1]
  636. # Ignore first token entropy (BOS)
  637. entropies = entropies[:, 1:]
  638. # Threshold the entropy values to define patch start points
  639. patch_mask = entropies > threshold
  640. seq_len = patch_mask.shape[1]
  641. # Create patch IDs (token indices), and add a sentinel to ensure alignment
  642. token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1)
  643. sentinel = torch.full_like(token_indices, seq_len)
  644. padded_indices = torch.cat([token_indices, sentinel], dim=1)
  645. # Pad mask with inverse to align sentinel correctly
  646. padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1)
  647. # Select indices where mask is True
  648. patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len)
  649. max_valid_patches = patch_mask.sum(dim=1).max()
  650. patch_starts = patch_starts[:, :max_valid_patches]
  651. # Offset patch starts to account for the two initial tokens
  652. patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1)
  653. # Compute patch end positions by shifting start positions
  654. last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1)
  655. patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1)
  656. patch_lengths = patch_ends - patch_start_ids + 1
  657. return patch_lengths
  658. class BltModel(BltPreTrainedModel):
  659. def __init__(self, config: BltConfig):
  660. super().__init__(config)
  661. self.gradient_checkpointing = False
  662. self.config = config
  663. self.local_encoder = BltLocalEncoder(config.encoder_config)
  664. self.global_transformer = BltGlobalTransformer(config.global_config)
  665. self.local_decoder = BltLocalDecoder(config.decoder_config)
  666. num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size)
  667. total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings
  668. self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size)
  669. if self.config.patch_in_forward:
  670. self.patcher = BltPatcher(config.patcher_config)
  671. self.patcher.eval()
  672. for param in self.patcher.parameters():
  673. param.requires_grad = False
  674. else:
  675. self.patcher = None
  676. self.post_init()
  677. @check_model_inputs()
  678. def forward(
  679. self,
  680. input_ids: Optional[torch.LongTensor] = None,
  681. patch_lengths: Optional[torch.Tensor] = None,
  682. attention_mask: Optional[torch.Tensor] = None,
  683. position_ids: Optional[torch.LongTensor] = None,
  684. past_key_values: Optional[Cache] = None,
  685. inputs_embeds: Optional[torch.FloatTensor] = None,
  686. use_cache: Optional[bool] = None,
  687. cache_position: Optional[torch.LongTensor] = None,
  688. **kwargs: Unpack[TransformersKwargs],
  689. ) -> BaseModelOutputWithPast:
  690. if (input_ids is None) ^ (inputs_embeds is not None):
  691. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  692. # Extract input embeddings as early as possible
  693. if inputs_embeds is not None:
  694. encoder_embeds = inputs_embeds
  695. batch_size, sequence_length, _ = inputs_embeds.shape
  696. else:
  697. batch_size, sequence_length = input_ids.shape
  698. encoder_embeds = compute_hash_embeddings(
  699. input_ids,
  700. self.local_encoder,
  701. self.encoder_hash_tok_embedding,
  702. self.config.encoder_hash_byte_group_nb_functions,
  703. self.config.encoder_hash_byte_group_size,
  704. self.config.encoder_hash_byte_group_vocab,
  705. )
  706. if patch_lengths is None:
  707. if self.config.patching_mode == "entropy" and self.patcher is not None:
  708. if input_ids is None:
  709. raise ValueError("input_ids is required for entropy-based patching")
  710. _, patch_lengths, _ = self.patcher(
  711. input_ids,
  712. patch_size=self.config.patch_size,
  713. threshold=self.config.patching_threshold,
  714. max_patch_length=self.config.max_patch_length,
  715. patching_batch_size=self.config.patching_batch_size,
  716. device=input_ids.device,
  717. )
  718. else:
  719. device = input_ids.device if input_ids is not None else inputs_embeds.device
  720. dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype
  721. patch_lengths = process_patch_lengths(
  722. torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device),
  723. self.config.max_patch_length,
  724. )
  725. patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
  726. if cache_position is None:
  727. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  728. cache_position = torch.arange(
  729. past_seen_tokens, past_seen_tokens + encoder_embeds.shape[1], device=encoder_embeds.device
  730. )
  731. if position_ids is None:
  732. position_ids = cache_position.unsqueeze(0)
  733. causal_mask = create_causal_mask(
  734. config=self.config,
  735. input_embeds=encoder_embeds,
  736. attention_mask=attention_mask,
  737. cache_position=cache_position,
  738. past_key_values=past_key_values,
  739. position_ids=position_ids,
  740. )
  741. cross_attn_mask_enc = _prepare_patch_cross_attention_mask(
  742. patch_ids=patch_ids,
  743. num_patches=patch_lengths.shape[1],
  744. sequence_length=sequence_length,
  745. patches_as_queries=True,
  746. cross_attn_k=self.config.cross_attn_k,
  747. dtype=encoder_embeds.dtype,
  748. )
  749. encoder_hidden_states, encoder_cross_states = self.local_encoder(
  750. input_ids=input_ids,
  751. inputs_embeds=encoder_embeds,
  752. attention_mask=causal_mask,
  753. position_ids=position_ids,
  754. encoder_attention_mask=cross_attn_mask_enc,
  755. num_patches=patch_lengths.shape[1],
  756. patch_ids=patch_ids,
  757. **kwargs,
  758. )
  759. encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
  760. global_cache_position = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device)
  761. global_position_ids = global_cache_position.unsqueeze(0)
  762. global_causal_mask = create_causal_mask(
  763. config=self.config,
  764. input_embeds=encoder_cross_states,
  765. attention_mask=None,
  766. cache_position=global_cache_position,
  767. past_key_values=None,
  768. position_ids=None,
  769. )
  770. global_hidden_states = self.global_transformer(
  771. input_embeds=encoder_cross_states,
  772. attention_mask=global_causal_mask,
  773. position_ids=global_position_ids,
  774. **kwargs,
  775. )
  776. decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
  777. cross_attn_mask_dec = _prepare_patch_cross_attention_mask(
  778. patch_ids=decoder_patch_ids,
  779. num_patches=patch_lengths.shape[1],
  780. sequence_length=sequence_length,
  781. patches_as_queries=False,
  782. cross_attn_k=self.config.cross_attn_k,
  783. dtype=encoder_embeds.dtype,
  784. )
  785. output = self.local_decoder(
  786. input_ids=input_ids,
  787. inputs_embeds=encoder_hidden_states,
  788. patch_embeds=global_hidden_states,
  789. attention_mask=causal_mask,
  790. position_ids=position_ids,
  791. past_key_values=past_key_values,
  792. cache_position=cache_position,
  793. encoder_attention_mask=cross_attn_mask_dec,
  794. **kwargs,
  795. )
  796. return BaseModelOutputWithPast(
  797. last_hidden_state=output,
  798. past_key_values=past_key_values,
  799. )
  800. def get_input_embeddings(self):
  801. return self.local_encoder.embed_tokens
  802. def set_input_embeddings(self, value):
  803. self.local_encoder.embed_tokens = value
  804. def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
  805. batch_size = patch_lengths.shape[0]
  806. patch_starts = torch.cat(
  807. [
  808. torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
  809. patch_lengths.cumsum(dim=-1)[:, :-1],
  810. ],
  811. dim=-1,
  812. )
  813. token_positions = torch.arange(seq_len, device=patch_lengths.device)
  814. return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
  815. class BltForCausalLM(MllamaForCausalLM):
  816. config: BltConfig
  817. _can_compile_fullgraph = False
  818. base_model_prefix = "model"
  819. _tied_weights_keys = ["lm_head.weight"]
  820. def __init__(self, config: BltConfig):
  821. super().__init__(config)
  822. self.vocab_size = config.vocab_size
  823. self.model = BltModel(config)
  824. self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False)
  825. self.post_init()
  826. def forward(
  827. self,
  828. input_ids: Optional[torch.LongTensor] = None,
  829. attention_mask: Optional[torch.Tensor] = None,
  830. position_ids: Optional[torch.LongTensor] = None,
  831. cross_attention_states: Optional[torch.LongTensor] = None, # Keep for compatibility
  832. cross_attention_mask: Optional[torch.LongTensor] = None,
  833. full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  834. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  835. inputs_embeds: Optional[torch.FloatTensor] = None,
  836. labels: Optional[torch.LongTensor] = None,
  837. use_cache: Optional[bool] = None,
  838. cache_position: Optional[torch.LongTensor] = None,
  839. logits_to_keep: Union[int, torch.Tensor] = 0,
  840. **kwargs: Unpack[TransformersKwargs],
  841. ) -> Union[tuple, CausalLMOutputWithPast]:
  842. # Call parent forward but exclude cross_attention_states from model call
  843. outputs = self.model(
  844. input_ids=input_ids,
  845. attention_mask=attention_mask,
  846. position_ids=position_ids,
  847. cross_attention_mask=cross_attention_mask,
  848. full_text_row_masked_out_mask=full_text_row_masked_out_mask,
  849. past_key_values=past_key_values,
  850. inputs_embeds=inputs_embeds,
  851. use_cache=use_cache,
  852. cache_position=cache_position,
  853. **kwargs,
  854. )
  855. hidden_states = outputs.last_hidden_state
  856. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  857. logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
  858. loss = None
  859. if labels is not None:
  860. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  861. return CausalLMOutputWithPast(
  862. loss=loss,
  863. logits=logits,
  864. past_key_values=outputs.past_key_values,
  865. hidden_states=outputs.hidden_states,
  866. attentions=outputs.attentions,
  867. )
  868. __all__ = [
  869. "BltPreTrainedModel",
  870. "BltModel",
  871. "BltPatcher",
  872. "BltForCausalLM",
  873. ]