modeling_olmo2.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/olmo2/modular_olmo2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_olmo2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. from typing import Callable, Optional, Union
  8. import torch
  9. import torch.nn as nn
  10. from transformers.utils.generic import TransformersKwargs
  11. from ...activations import ACT2FN
  12. from ...cache_utils import Cache, DynamicCache
  13. from ...generation import GenerationMixin
  14. from ...integrations import use_kernel_forward_from_hub
  15. from ...masking_utils import create_causal_mask
  16. from ...modeling_layers import GradientCheckpointingLayer
  17. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  18. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  19. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  20. from ...processing_utils import Unpack
  21. from ...utils import auto_docstring, can_return_tuple
  22. from ...utils.deprecation import deprecate_kwarg
  23. from ...utils.generic import check_model_inputs
  24. from .configuration_olmo2 import Olmo2Config
  25. @use_kernel_forward_from_hub("RMSNorm")
  26. class Olmo2RMSNorm(nn.Module):
  27. def __init__(self, hidden_size, eps=1e-6):
  28. """
  29. Olmo2RMSNorm is equivalent to T5LayerNorm
  30. """
  31. super().__init__()
  32. self.weight = nn.Parameter(torch.ones(hidden_size))
  33. self.variance_epsilon = eps
  34. def forward(self, hidden_states):
  35. input_dtype = hidden_states.dtype
  36. hidden_states = hidden_states.to(torch.float32)
  37. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  38. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  39. return (self.weight * hidden_states).to(input_dtype)
  40. def extra_repr(self):
  41. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  42. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  43. """
  44. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  45. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  46. """
  47. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  48. if n_rep == 1:
  49. return hidden_states
  50. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  51. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  52. def eager_attention_forward(
  53. module: nn.Module,
  54. query: torch.Tensor,
  55. key: torch.Tensor,
  56. value: torch.Tensor,
  57. attention_mask: Optional[torch.Tensor],
  58. scaling: float,
  59. dropout: float = 0.0,
  60. **kwargs: Unpack[TransformersKwargs],
  61. ):
  62. key_states = repeat_kv(key, module.num_key_value_groups)
  63. value_states = repeat_kv(value, module.num_key_value_groups)
  64. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  65. if attention_mask is not None:
  66. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  67. attn_weights = attn_weights + causal_mask
  68. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  69. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  70. attn_output = torch.matmul(attn_weights, value_states)
  71. attn_output = attn_output.transpose(1, 2).contiguous()
  72. return attn_output, attn_weights
  73. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  74. """Applies Rotary Position Embedding to the query and key tensors.
  75. Args:
  76. q (`torch.Tensor`): The query tensor.
  77. k (`torch.Tensor`): The key tensor.
  78. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  79. sin (`torch.Tensor`): The sine part of the rotary embedding.
  80. position_ids (`torch.Tensor`, *optional*):
  81. Deprecated and unused.
  82. unsqueeze_dim (`int`, *optional*, defaults to 1):
  83. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  84. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  85. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  86. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  87. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  88. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  89. Returns:
  90. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  91. """
  92. q_type, k_type = q.dtype, k.dtype
  93. cos = cos.unsqueeze(unsqueeze_dim)
  94. sin = sin.unsqueeze(unsqueeze_dim)
  95. q_embed = (q * cos) + (rotate_half(q) * sin)
  96. k_embed = (k * cos) + (rotate_half(k) * sin)
  97. return q_embed.to(q_type), k_embed.to(k_type)
  98. def rotate_half(x):
  99. """Rotates half the hidden dims of the input."""
  100. x1 = x[..., : x.shape[-1] // 2]
  101. x2 = x[..., x.shape[-1] // 2 :]
  102. return torch.cat((-x2, x1), dim=-1)
  103. class Olmo2Attention(nn.Module):
  104. """Multi-headed attention from 'Attention Is All You Need' paper"""
  105. def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None):
  106. super().__init__()
  107. self.config = config
  108. self.layer_idx = layer_idx
  109. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  110. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  111. self.scaling = self.head_dim**-0.5
  112. self.attention_dropout = config.attention_dropout
  113. self.is_causal = True
  114. self.q_proj = nn.Linear(
  115. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  116. )
  117. self.k_proj = nn.Linear(
  118. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  119. )
  120. self.v_proj = nn.Linear(
  121. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  122. )
  123. self.o_proj = nn.Linear(
  124. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  125. )
  126. self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
  127. self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
  128. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  129. def forward(
  130. self,
  131. hidden_states: torch.Tensor,
  132. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  133. attention_mask: Optional[torch.Tensor],
  134. past_key_values: Optional[Cache] = None,
  135. cache_position: Optional[torch.LongTensor] = None,
  136. **kwargs: Unpack[TransformersKwargs],
  137. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  138. input_shape = hidden_states.shape[:-1]
  139. hidden_shape = (*input_shape, -1, self.head_dim)
  140. query_states = self.q_norm(self.q_proj(hidden_states))
  141. key_states = self.k_norm(self.k_proj(hidden_states))
  142. value_states = self.v_proj(hidden_states)
  143. query_states = query_states.view(hidden_shape).transpose(1, 2)
  144. key_states = key_states.view(hidden_shape).transpose(1, 2)
  145. value_states = value_states.view(hidden_shape).transpose(1, 2)
  146. cos, sin = position_embeddings
  147. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  148. if past_key_values is not None:
  149. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  150. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  151. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  152. attention_interface: Callable = eager_attention_forward
  153. if self.config._attn_implementation != "eager":
  154. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  155. attn_output, attn_weights = attention_interface(
  156. self,
  157. query_states,
  158. key_states,
  159. value_states,
  160. attention_mask,
  161. dropout=0.0 if not self.training else self.attention_dropout,
  162. scaling=self.scaling,
  163. **kwargs,
  164. )
  165. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  166. attn_output = self.o_proj(attn_output)
  167. return attn_output, attn_weights
  168. class Olmo2MLP(nn.Module):
  169. def __init__(self, config):
  170. super().__init__()
  171. self.config = config
  172. self.hidden_size = config.hidden_size
  173. self.intermediate_size = config.intermediate_size
  174. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  175. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  176. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  177. self.act_fn = ACT2FN[config.hidden_act]
  178. def forward(self, x):
  179. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  180. return down_proj
  181. class Olmo2DecoderLayer(GradientCheckpointingLayer):
  182. def __init__(self, config: Olmo2Config, layer_idx: int):
  183. super().__init__()
  184. self.hidden_size = config.hidden_size
  185. self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx)
  186. self.mlp = Olmo2MLP(config)
  187. self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  188. self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  189. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  190. def forward(
  191. self,
  192. hidden_states: torch.Tensor,
  193. attention_mask: Optional[torch.Tensor] = None,
  194. position_ids: Optional[torch.LongTensor] = None,
  195. past_key_values: Optional[Cache] = None,
  196. use_cache: Optional[bool] = False,
  197. cache_position: Optional[torch.LongTensor] = None,
  198. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  199. **kwargs: Unpack[TransformersKwargs],
  200. ) -> torch.Tensor:
  201. residual = hidden_states
  202. hidden_states, _ = self.self_attn(
  203. hidden_states=hidden_states,
  204. attention_mask=attention_mask,
  205. position_ids=position_ids,
  206. past_key_values=past_key_values,
  207. use_cache=use_cache,
  208. cache_position=cache_position,
  209. position_embeddings=position_embeddings,
  210. **kwargs,
  211. )
  212. hidden_states = self.post_attention_layernorm(hidden_states)
  213. hidden_states = residual + hidden_states
  214. # Fully Connected
  215. residual = hidden_states
  216. hidden_states = self.mlp(hidden_states)
  217. hidden_states = self.post_feedforward_layernorm(hidden_states)
  218. hidden_states = residual + hidden_states
  219. return hidden_states
  220. class Olmo2RotaryEmbedding(nn.Module):
  221. inv_freq: torch.Tensor # fix linting for `register_buffer`
  222. def __init__(self, config: Olmo2Config, device=None):
  223. super().__init__()
  224. # BC: "rope_type" was originally "type"
  225. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  226. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  227. else:
  228. self.rope_type = "default"
  229. self.max_seq_len_cached = config.max_position_embeddings
  230. self.original_max_seq_len = config.max_position_embeddings
  231. self.config = config
  232. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  233. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  234. self.register_buffer("inv_freq", inv_freq, persistent=False)
  235. self.original_inv_freq = self.inv_freq
  236. @torch.no_grad()
  237. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  238. def forward(self, x, position_ids):
  239. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  240. position_ids_expanded = position_ids[:, None, :].float()
  241. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  242. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  243. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  244. emb = torch.cat((freqs, freqs), dim=-1)
  245. cos = emb.cos() * self.attention_scaling
  246. sin = emb.sin() * self.attention_scaling
  247. return cos, sin
@auto_docstring
class Olmo2PreTrainedModel(PreTrainedModel):
    # Shared base for the OLMo2 backbone (Olmo2Model) and LM head (Olmo2ForCausalLM). It only declares the config
    # type and capability flags; what each flag controls is defined by PreTrainedModel, not in this file.
    config: Olmo2Config
    # Olmo2ForCausalLM stores its backbone as `self.model`, matching this prefix.
    base_model_prefix = "model"
    # Decoder layers subclass GradientCheckpointingLayer.
    supports_gradient_checkpointing = True
    # NOTE(review): presumably keeps each decoder layer on a single device when sharding — confirm in PreTrainedModel.
    _no_split_modules = ["Olmo2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Attention-backend capability flags. Olmo2Attention dispatches through
    # ALL_ATTENTION_FUNCTIONS[config._attn_implementation] for anything other than "eager".
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Output name -> module class whose outputs are collected when that output is requested; presumably consumed
    # by `check_model_inputs` on Olmo2Model.forward — confirm.
    _can_record_outputs = {
        "hidden_states": Olmo2DecoderLayer,
        "attentions": Olmo2Attention,
    }
  264. @auto_docstring
  265. class Olmo2Model(Olmo2PreTrainedModel):
  266. def __init__(self, config: Olmo2Config):
  267. super().__init__(config)
  268. self.padding_idx = config.pad_token_id
  269. self.vocab_size = config.vocab_size
  270. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  271. self.layers = nn.ModuleList(
  272. [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  273. )
  274. self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  275. self.rotary_emb = Olmo2RotaryEmbedding(config=config)
  276. self.gradient_checkpointing = False
  277. # Initialize weights and apply final processing
  278. self.post_init()
  279. @check_model_inputs()
  280. @auto_docstring
  281. def forward(
  282. self,
  283. input_ids: Optional[torch.LongTensor] = None,
  284. attention_mask: Optional[torch.Tensor] = None,
  285. position_ids: Optional[torch.LongTensor] = None,
  286. past_key_values: Optional[Cache] = None,
  287. inputs_embeds: Optional[torch.FloatTensor] = None,
  288. cache_position: Optional[torch.LongTensor] = None,
  289. use_cache: Optional[bool] = None,
  290. **kwargs: Unpack[TransformersKwargs],
  291. ) -> BaseModelOutputWithPast:
  292. if (input_ids is None) ^ (inputs_embeds is not None):
  293. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  294. if inputs_embeds is None:
  295. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  296. if use_cache and past_key_values is None:
  297. past_key_values = DynamicCache(config=self.config)
  298. if cache_position is None:
  299. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  300. cache_position: torch.Tensor = torch.arange(
  301. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  302. )
  303. if position_ids is None:
  304. position_ids = cache_position.unsqueeze(0)
  305. causal_mask = create_causal_mask(
  306. config=self.config,
  307. input_embeds=inputs_embeds,
  308. attention_mask=attention_mask,
  309. cache_position=cache_position,
  310. past_key_values=past_key_values,
  311. position_ids=position_ids,
  312. )
  313. hidden_states = inputs_embeds
  314. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  315. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  316. hidden_states = decoder_layer(
  317. hidden_states,
  318. attention_mask=causal_mask,
  319. position_ids=position_ids,
  320. past_key_values=past_key_values,
  321. cache_position=cache_position,
  322. position_embeddings=position_embeddings,
  323. **kwargs,
  324. )
  325. hidden_states = self.norm(hidden_states)
  326. return BaseModelOutputWithPast(
  327. last_hidden_state=hidden_states,
  328. past_key_values=past_key_values,
  329. )
  330. @auto_docstring
  331. class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin):
  332. _tied_weights_keys = ["lm_head.weight"]
  333. _tp_plan = {"lm_head": "colwise_rep"}
  334. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  335. def __init__(self, config):
  336. super().__init__(config)
  337. self.model = Olmo2Model(config)
  338. self.vocab_size = config.vocab_size
  339. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  340. # Initialize weights and apply final processing
  341. self.post_init()
  342. @can_return_tuple
  343. @auto_docstring
  344. def forward(
  345. self,
  346. input_ids: Optional[torch.LongTensor] = None,
  347. attention_mask: Optional[torch.Tensor] = None,
  348. position_ids: Optional[torch.LongTensor] = None,
  349. past_key_values: Optional[Cache] = None,
  350. inputs_embeds: Optional[torch.FloatTensor] = None,
  351. labels: Optional[torch.LongTensor] = None,
  352. use_cache: Optional[bool] = None,
  353. cache_position: Optional[torch.LongTensor] = None,
  354. logits_to_keep: Union[int, torch.Tensor] = 0,
  355. **kwargs: Unpack[TransformersKwargs],
  356. ) -> CausalLMOutputWithPast:
  357. r"""
  358. Example:
  359. ```python
  360. >>> from transformers import AutoTokenizer, Olmo2ForCausalLM
  361. >>> model = Olmo2ForCausalLM.from_pretrained("meta-olmo2/Olmo2-2-7b-hf")
  362. >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo2/Olmo2-2-7b-hf")
  363. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  364. >>> inputs = tokenizer(prompt, return_tensors="pt")
  365. >>> # Generate
  366. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  367. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  368. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  369. ```"""
  370. outputs: BaseModelOutputWithPast = self.model(
  371. input_ids=input_ids,
  372. attention_mask=attention_mask,
  373. position_ids=position_ids,
  374. past_key_values=past_key_values,
  375. inputs_embeds=inputs_embeds,
  376. use_cache=use_cache,
  377. cache_position=cache_position,
  378. **kwargs,
  379. )
  380. hidden_states = outputs.last_hidden_state
  381. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  382. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  383. logits = self.lm_head(hidden_states[:, slice_indices, :])
  384. loss = None
  385. if labels is not None:
  386. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  387. return CausalLMOutputWithPast(
  388. loss=loss,
  389. logits=logits,
  390. past_key_values=outputs.past_key_values,
  391. hidden_states=outputs.hidden_states,
  392. attentions=outputs.attentions,
  393. )
  394. __all__ = ["Olmo2ForCausalLM", "Olmo2Model", "Olmo2PreTrainedModel"]