modeling_nemotron.py 43 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962
  1. # coding=utf-8
  2. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  3. # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch Nemotron model."""
  17. import math
  18. from typing import Optional, Union
  19. import torch
  20. import torch.nn.functional as F
  21. from torch import Size, Tensor, nn
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, StaticCache
  24. from ...generation import GenerationMixin
  25. from ...modeling_attn_mask_utils import AttentionMaskConverter
  26. from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
  27. from ...modeling_layers import (
  28. GenericForQuestionAnswering,
  29. GenericForSequenceClassification,
  30. GenericForTokenClassification,
  31. GradientCheckpointingLayer,
  32. )
  33. from ...modeling_outputs import (
  34. BaseModelOutputWithPast,
  35. CausalLMOutputWithPast,
  36. )
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import PreTrainedModel
  39. from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
  40. from ...utils.deprecation import deprecate_kwarg
  41. from .configuration_nemotron import NemotronConfig
  42. if is_torch_flex_attn_available():
  43. from torch.nn.attention.flex_attention import BlockMask
  44. from ...integrations.flex_attention import make_flex_block_causal_mask
  45. logger = logging.get_logger(__name__)
  46. def _cast_if_autocast_enabled(device_type, *args):
  47. if not torch.is_autocast_enabled():
  48. return args
  49. else:
  50. # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
  51. target_dtype = (
  52. torch.get_autocast_dtype(device_type)
  53. if hasattr(torch, "get_autocast_dtype")
  54. else torch.get_autocast_gpu_dtype()
  55. )
  56. return torch.amp.autocast_mode._cast(args, device_type, target_dtype)
  57. class NemotronLayerNorm1P(nn.LayerNorm):
  58. def __init__(
  59. self,
  60. normalized_shape: Union[int, list[int], Size],
  61. eps: float = 1e-5,
  62. elementwise_affine: bool = True,
  63. bias: bool = True,
  64. device=None,
  65. dtype=None,
  66. ):
  67. super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)
  68. def forward(self, input: Tensor) -> Tensor:
  69. device_type = input.device.type if input.device.type != "mps" else "cpu"
  70. args = _cast_if_autocast_enabled(
  71. device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
  72. )
  73. with torch.autocast(device_type=input.device.type, enabled=False):
  74. return F.layer_norm(*args)
  75. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  76. class NemotronRotaryEmbedding(nn.Module):
  77. inv_freq: torch.Tensor # fix linting for `register_buffer`
  78. # Ignore copy
  79. def __init__(
  80. self,
  81. config: NemotronConfig,
  82. device=None,
  83. ):
  84. super().__init__()
  85. self.rope_type = "default"
  86. self.max_seq_len_cached = config.max_position_embeddings
  87. self.original_max_seq_len = config.max_position_embeddings
  88. self.config = config
  89. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  90. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  91. self.register_buffer("inv_freq", inv_freq, persistent=False)
  92. self.original_inv_freq = self.inv_freq
  93. @torch.no_grad()
  94. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  95. def forward(self, x, position_ids):
  96. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  97. position_ids_expanded = position_ids[:, None, :].float()
  98. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  99. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  100. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  101. emb = torch.cat((freqs, freqs), dim=-1)
  102. cos = emb.cos() * self.attention_scaling
  103. sin = emb.sin() * self.attention_scaling
  104. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  105. # Copied from transformers.models.llama.modeling_llama.rotate_half
  106. def rotate_half(x):
  107. """Rotates half the hidden dims of the input."""
  108. x1 = x[..., : x.shape[-1] // 2]
  109. x2 = x[..., x.shape[-1] // 2 :]
  110. return torch.cat((-x2, x1), dim=-1)
  111. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  112. """Applies Rotary Position Embedding to the query and key tensors.
  113. Args:
  114. q (`torch.Tensor`): The query tensor.
  115. k (`torch.Tensor`): The key tensor.
  116. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  117. sin (`torch.Tensor`): The sine part of the rotary embedding.
  118. position_ids (`torch.Tensor`, *optional*):
  119. Deprecated and unused.
  120. unsqueeze_dim (`int`, *optional*, defaults to 1):
  121. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  122. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  123. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  124. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  125. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  126. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  127. Returns:
  128. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  129. """
  130. cos = cos.unsqueeze(unsqueeze_dim)
  131. sin = sin.unsqueeze(unsqueeze_dim)
  132. rot_dim = cos.shape[-1]
  133. # If q_pass/k_pass is empty, rotary pos embedding is applied to all tensor q/k
  134. q, q_pass = q[..., :rot_dim], q[..., rot_dim:]
  135. k, k_pass = k[..., :rot_dim], k[..., rot_dim:]
  136. q_embed = (q * cos) + (rotate_half(q) * sin)
  137. k_embed = (k * cos) + (rotate_half(k) * sin)
  138. return torch.cat((q_embed, q_pass), dim=-1), torch.cat((k_embed, k_pass), dim=-1)
  139. class NemotronMLP(nn.Module):
  140. def __init__(self, config):
  141. super().__init__()
  142. self.config = config
  143. self.hidden_size = config.hidden_size
  144. self.intermediate_size = config.intermediate_size
  145. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  146. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  147. self.act_fn = ACT2FN[config.hidden_act]
  148. def forward(self, x):
  149. return self.down_proj(self.act_fn(self.up_proj(x)))
  150. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  151. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  152. """
  153. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  154. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  155. """
  156. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  157. if n_rep == 1:
  158. return hidden_states
  159. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  160. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  161. class NemotronAttention(nn.Module):
  162. """Multi-headed attention from 'Attention Is All You Need' paper"""
  163. def __init__(self, config: NemotronConfig, layer_idx: Optional[int] = None):
  164. super().__init__()
  165. self.config = config
  166. self.layer_idx = layer_idx
  167. if layer_idx is None:
  168. logger.warning_once(
  169. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  170. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  171. "when creating this class."
  172. )
  173. self.attention_dropout = config.attention_dropout
  174. self.hidden_size = config.hidden_size
  175. self.num_heads = config.num_attention_heads
  176. self.head_dim = config.head_dim
  177. self.num_key_value_heads = config.num_key_value_heads
  178. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  179. self.max_position_embeddings = config.max_position_embeddings
  180. self.rope_theta = config.rope_theta
  181. self.partial_rotary_factor = config.partial_rotary_factor
  182. self.is_causal = True
  183. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  184. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  185. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  186. self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias)
  187. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  188. def forward(
  189. self,
  190. hidden_states: torch.Tensor,
  191. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  192. attention_mask: Optional[torch.Tensor] = None,
  193. position_ids: Optional[torch.LongTensor] = None,
  194. past_key_values: Optional[Cache] = None,
  195. output_attentions: bool = False,
  196. use_cache: bool = False,
  197. cache_position: Optional[torch.LongTensor] = None,
  198. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  199. bsz, q_len, _ = hidden_states.size()
  200. query_states = self.q_proj(hidden_states)
  201. key_states = self.k_proj(hidden_states)
  202. value_states = self.v_proj(hidden_states)
  203. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  204. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  205. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  206. if position_embeddings is not None:
  207. cos, sin = position_embeddings
  208. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  209. if past_key_values is not None:
  210. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  211. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  212. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  213. key_states = repeat_kv(key_states, self.num_key_value_groups)
  214. value_states = repeat_kv(value_states, self.num_key_value_groups)
  215. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  216. if attention_mask is not None: # no matter the length, we just slice it
  217. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  218. attn_weights = attn_weights + causal_mask
  219. # upcast attention to fp32
  220. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  221. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  222. attn_output = torch.matmul(attn_weights, value_states)
  223. attn_output = attn_output.transpose(1, 2).contiguous()
  224. attn_output = attn_output.reshape(bsz, q_len, -1)
  225. attn_output = self.o_proj(attn_output)
  226. if not output_attentions:
  227. attn_weights = None
  228. return attn_output, attn_weights
  229. # NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  230. # TODO cyril: modular
  231. class NemotronFlashAttention2(NemotronAttention):
  232. """
  233. Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays
  234. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  235. flash attention and deal with padding tokens in case the input contains any of them.
  236. """
  237. def __init__(self, *args, **kwargs):
  238. super().__init__(*args, **kwargs)
  239. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  240. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  241. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  242. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  243. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  244. def forward(
  245. self,
  246. hidden_states: torch.Tensor,
  247. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  248. attention_mask: Optional[torch.LongTensor] = None,
  249. position_ids: Optional[torch.LongTensor] = None,
  250. past_key_values: Optional[Cache] = None,
  251. output_attentions: bool = False,
  252. use_cache: bool = False,
  253. cache_position: Optional[torch.LongTensor] = None,
  254. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  255. if isinstance(past_key_values, StaticCache):
  256. raise ValueError(
  257. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  258. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  259. )
  260. output_attentions = False
  261. bsz, q_len, _ = hidden_states.size()
  262. query_states = self.q_proj(hidden_states)
  263. key_states = self.k_proj(hidden_states)
  264. value_states = self.v_proj(hidden_states)
  265. # Flash attention requires the input to have the shape
  266. # batch_size x seq_length x head_dim x hidden_dim
  267. # therefore we just need to keep the original shape
  268. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  269. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  270. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  271. if position_embeddings is not None:
  272. cos, sin = position_embeddings
  273. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  274. if past_key_values is not None:
  275. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  276. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  277. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  278. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  279. # to be able to avoid many of these transpose/reshape/view.
  280. query_states = query_states.transpose(1, 2)
  281. key_states = key_states.transpose(1, 2)
  282. value_states = value_states.transpose(1, 2)
  283. dropout_rate = self.attention_dropout if self.training else 0.0
  284. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  285. # therefore the input hidden states gets silently casted in float32. Hence, we need
  286. # cast them back in the correct dtype just to be sure everything works as expected.
  287. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  288. # in fp32. (NemotronRMSNorm handles it correctly)
  289. input_dtype = query_states.dtype
  290. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  291. if input_dtype == torch.float32:
  292. if torch.is_autocast_enabled():
  293. # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
  294. target_dtype = (
  295. torch.get_autocast_dtype(device_type)
  296. if hasattr(torch, "get_autocast_dtype")
  297. else torch.get_autocast_gpu_dtype()
  298. )
  299. # Handle the case where the model is quantized
  300. elif hasattr(self.config, "_pre_quantization_dtype"):
  301. target_dtype = self.config._pre_quantization_dtype
  302. else:
  303. target_dtype = self.q_proj.weight.dtype
  304. logger.warning_once(
  305. f"The input hidden states seems to be silently casted in float32, this might be related to"
  306. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  307. f" {target_dtype}."
  308. )
  309. query_states = query_states.to(target_dtype)
  310. key_states = key_states.to(target_dtype)
  311. value_states = value_states.to(target_dtype)
  312. attn_output = _flash_attention_forward(
  313. query_states,
  314. key_states,
  315. value_states,
  316. attention_mask,
  317. q_len,
  318. position_ids=position_ids,
  319. dropout=dropout_rate,
  320. sliding_window=getattr(self, "sliding_window", None),
  321. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  322. is_causal=self.is_causal,
  323. )
  324. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  325. attn_output = self.o_proj(attn_output)
  326. if not output_attentions:
  327. attn_weights = None
  328. return attn_output, attn_weights
  329. # NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  330. # TODO cyril: modular
  331. class NemotronSdpaAttention(NemotronAttention):
  332. """
  333. Nemotron attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  334. `NemotronAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  335. SDPA API.
  336. """
  337. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  338. def forward(
  339. self,
  340. hidden_states: torch.Tensor,
  341. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  342. attention_mask: Optional[torch.Tensor] = None,
  343. position_ids: Optional[torch.LongTensor] = None,
  344. past_key_values: Optional[Cache] = None,
  345. output_attentions: bool = False,
  346. use_cache: bool = False,
  347. cache_position: Optional[torch.LongTensor] = None,
  348. **kwargs,
  349. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  350. if output_attentions:
  351. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  352. logger.warning_once(
  353. "NemotronModel is using NemotronSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  354. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  355. )
  356. return super().forward(
  357. hidden_states=hidden_states,
  358. attention_mask=attention_mask,
  359. position_ids=position_ids,
  360. past_key_values=past_key_values,
  361. output_attentions=output_attentions,
  362. use_cache=use_cache,
  363. cache_position=cache_position,
  364. position_embeddings=position_embeddings,
  365. )
  366. bsz, q_len, _ = hidden_states.size()
  367. query_states = self.q_proj(hidden_states)
  368. key_states = self.k_proj(hidden_states)
  369. value_states = self.v_proj(hidden_states)
  370. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  371. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  372. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  373. if position_embeddings is not None:
  374. cos, sin = position_embeddings
  375. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  376. if past_key_values is not None:
  377. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  378. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  379. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  380. key_states = repeat_kv(key_states, self.num_key_value_groups)
  381. value_states = repeat_kv(value_states, self.num_key_value_groups)
  382. causal_mask = attention_mask
  383. if attention_mask is not None:
  384. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  385. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  386. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  387. if query_states.device.type == "cuda" and causal_mask is not None:
  388. query_states = query_states.contiguous()
  389. key_states = key_states.contiguous()
  390. value_states = value_states.contiguous()
  391. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  392. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  393. is_causal = causal_mask is None and q_len > 1
  394. attn_output = torch.nn.functional.scaled_dot_product_attention(
  395. query_states,
  396. key_states,
  397. value_states,
  398. attn_mask=causal_mask,
  399. dropout_p=self.attention_dropout if self.training else 0.0,
  400. is_causal=is_causal,
  401. )
  402. attn_output = attn_output.transpose(1, 2).contiguous()
  403. attn_output = attn_output.view(bsz, q_len, -1)
  404. attn_output = self.o_proj(attn_output)
  405. return attn_output, None
  406. NEMOTRON_ATTENTION_CLASSES = {
  407. "eager": NemotronAttention,
  408. "flash_attention_2": NemotronFlashAttention2,
  409. "sdpa": NemotronSdpaAttention,
  410. }
  411. # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  412. # no longer copied after attention refactors
  413. class NemotronDecoderLayer(GradientCheckpointingLayer):
  414. # Ignore copy
  415. def __init__(self, config: NemotronConfig, layer_idx: int):
  416. super().__init__()
  417. self.hidden_size = config.hidden_size
  418. self.self_attn = NEMOTRON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  419. self.mlp = NemotronMLP(config)
  420. self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps)
  421. self.post_attention_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps)
  422. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  423. def forward(
  424. self,
  425. hidden_states: torch.Tensor,
  426. attention_mask: Optional[torch.Tensor] = None,
  427. position_ids: Optional[torch.LongTensor] = None,
  428. past_key_values: Optional[Cache] = None,
  429. output_attentions: Optional[bool] = False,
  430. use_cache: Optional[bool] = False,
  431. cache_position: Optional[torch.LongTensor] = None,
  432. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
  433. **kwargs,
  434. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  435. """
  436. Args:
  437. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  438. attention_mask (`torch.FloatTensor`, *optional*):
  439. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  440. query_sequence_length, key_sequence_length)` if default attention is used.
  441. output_attentions (`bool`, *optional*):
  442. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  443. returned tensors for more detail.
  444. use_cache (`bool`, *optional*):
  445. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  446. (see `past_key_values`).
  447. past_key_values (`Cache`, *optional*): cached past key and value projection states
  448. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  449. Indices depicting the position of the input sequence tokens in the sequence
  450. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  451. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  452. with `head_dim` being the embedding dimension of each attention head.
  453. kwargs (`dict`, *optional*):
  454. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  455. into the model
  456. """
  457. residual = hidden_states
  458. hidden_states = self.input_layernorm(hidden_states)
  459. # Self Attention
  460. hidden_states, self_attn_weights = self.self_attn(
  461. hidden_states=hidden_states,
  462. attention_mask=attention_mask,
  463. position_ids=position_ids,
  464. past_key_values=past_key_values,
  465. output_attentions=output_attentions,
  466. use_cache=use_cache,
  467. cache_position=cache_position,
  468. position_embeddings=position_embeddings,
  469. **kwargs,
  470. )
  471. hidden_states = residual + hidden_states
  472. # Fully Connected
  473. residual = hidden_states
  474. hidden_states = self.post_attention_layernorm(hidden_states)
  475. hidden_states = self.mlp(hidden_states)
  476. hidden_states = residual + hidden_states
  477. outputs = (hidden_states,)
  478. if output_attentions:
  479. outputs += (self_attn_weights,)
  480. return outputs
  481. @auto_docstring
  482. class NemotronPreTrainedModel(PreTrainedModel):
  483. config: NemotronConfig
  484. base_model_prefix = "model"
  485. supports_gradient_checkpointing = True
  486. _no_split_modules = ["NemotronDecoderLayer"]
  487. _skip_keys_device_placement = ["past_key_values"]
  488. _supports_flash_attn = True
  489. _supports_sdpa = True
  490. _can_compile_fullgraph = True
  491. def _init_weights(self, module):
  492. std = self.config.initializer_range
  493. if isinstance(module, nn.Linear):
  494. module.weight.data.normal_(mean=0.0, std=std)
  495. if module.bias is not None:
  496. module.bias.data.zero_()
  497. elif isinstance(module, nn.Embedding):
  498. module.weight.data.normal_(mean=0.0, std=std)
  499. if module.padding_idx is not None:
  500. module.weight.data[module.padding_idx].zero_()
  501. elif isinstance(module, NemotronLayerNorm1P):
  502. module.weight.data.fill_(1.0)
  503. module.bias.data.zero_()
  504. @auto_docstring
  505. class NemotronModel(NemotronPreTrainedModel):
  506. """
  507. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`NemotronDecoderLayer`]
  508. Args:
  509. config: NemotronConfig
  510. """
  511. def __init__(self, config: NemotronConfig):
  512. super().__init__(config)
  513. self.padding_idx = config.pad_token_id
  514. self.vocab_size = config.vocab_size
  515. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  516. self.layers = nn.ModuleList(
  517. [NemotronDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  518. )
  519. self.norm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps)
  520. self.rotary_emb = NemotronRotaryEmbedding(config=config)
  521. self.gradient_checkpointing = False
  522. # Initialize weights and apply final processing
  523. self.post_init()
  524. @can_return_tuple
  525. @auto_docstring
  526. def forward(
  527. self,
  528. input_ids: Optional[torch.LongTensor] = None,
  529. attention_mask: Optional[torch.Tensor] = None,
  530. position_ids: Optional[torch.LongTensor] = None,
  531. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  532. inputs_embeds: Optional[torch.FloatTensor] = None,
  533. use_cache: Optional[bool] = None,
  534. output_attentions: Optional[bool] = None,
  535. output_hidden_states: Optional[bool] = None,
  536. cache_position: Optional[torch.LongTensor] = None,
  537. ) -> BaseModelOutputWithPast:
  538. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  539. output_hidden_states = (
  540. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  541. )
  542. use_cache = use_cache if use_cache is not None else self.config.use_cache
  543. if (input_ids is None) ^ (inputs_embeds is not None):
  544. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  545. if self.gradient_checkpointing and self.training and use_cache:
  546. logger.warning_once(
  547. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  548. )
  549. use_cache = False
  550. if use_cache and past_key_values is None:
  551. past_key_values = DynamicCache(config=self.config)
  552. if inputs_embeds is None:
  553. inputs_embeds = self.embed_tokens(input_ids)
  554. if use_cache and past_key_values is None:
  555. past_key_values = DynamicCache(config=self.config)
  556. if cache_position is None:
  557. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  558. cache_position = torch.arange(
  559. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  560. )
  561. if position_ids is None:
  562. position_ids = cache_position.unsqueeze(0)
  563. causal_mask = self._update_causal_mask(
  564. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  565. )
  566. # embed positions
  567. hidden_states = inputs_embeds
  568. # create position embeddings to be shared across the decoder layers
  569. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  570. # decoder layers
  571. all_hidden_states = () if output_hidden_states else None
  572. all_self_attns = () if output_attentions else None
  573. for decoder_layer in self.layers:
  574. if output_hidden_states:
  575. all_hidden_states += (hidden_states,)
  576. layer_outputs = decoder_layer(
  577. hidden_states,
  578. attention_mask=causal_mask,
  579. position_ids=position_ids,
  580. past_key_values=past_key_values,
  581. output_attentions=output_attentions,
  582. use_cache=use_cache,
  583. cache_position=cache_position,
  584. position_embeddings=position_embeddings,
  585. )
  586. hidden_states = layer_outputs[0]
  587. if output_attentions:
  588. all_self_attns += (layer_outputs[1],)
  589. hidden_states = self.norm(hidden_states)
  590. # add hidden states from the last decoder layer
  591. if output_hidden_states:
  592. all_hidden_states += (hidden_states,)
  593. return BaseModelOutputWithPast(
  594. last_hidden_state=hidden_states,
  595. past_key_values=past_key_values,
  596. hidden_states=all_hidden_states,
  597. attentions=all_self_attns,
  598. )
  599. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
  600. def _update_causal_mask(
  601. self,
  602. attention_mask: Union[torch.Tensor, "BlockMask"],
  603. input_tensor: torch.Tensor,
  604. cache_position: torch.Tensor,
  605. past_key_values: Cache,
  606. output_attentions: bool = False,
  607. ):
  608. if self.config._attn_implementation == "flash_attention_2":
  609. if attention_mask is not None and (attention_mask == 0.0).any():
  610. return attention_mask
  611. return None
  612. if self.config._attn_implementation == "flex_attention":
  613. if isinstance(attention_mask, torch.Tensor):
  614. attention_mask = make_flex_block_causal_mask(attention_mask)
  615. return attention_mask
  616. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  617. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  618. # to infer the attention mask.
  619. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  620. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  621. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  622. if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
  623. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  624. attention_mask,
  625. inputs_embeds=input_tensor,
  626. past_key_values_length=past_seen_tokens,
  627. is_training=self.training,
  628. ):
  629. return None
  630. dtype = input_tensor.dtype
  631. sequence_length = input_tensor.shape[1]
  632. if using_compilable_cache:
  633. target_length = past_key_values.get_max_cache_shape()
  634. else:
  635. target_length = (
  636. attention_mask.shape[-1]
  637. if isinstance(attention_mask, torch.Tensor)
  638. else past_seen_tokens + sequence_length + 1
  639. )
  640. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  641. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  642. attention_mask,
  643. sequence_length=sequence_length,
  644. target_length=target_length,
  645. dtype=dtype,
  646. cache_position=cache_position,
  647. batch_size=input_tensor.shape[0],
  648. )
  649. if (
  650. self.config._attn_implementation == "sdpa"
  651. and attention_mask is not None
  652. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  653. and not output_attentions
  654. ):
  655. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  656. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  657. # Details: https://github.com/pytorch/pytorch/issues/110213
  658. min_dtype = torch.finfo(dtype).min
  659. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  660. return causal_mask
  661. @staticmethod
  662. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  663. def _prepare_4d_causal_attention_mask_with_cache_position(
  664. attention_mask: torch.Tensor,
  665. sequence_length: int,
  666. target_length: int,
  667. dtype: torch.dtype,
  668. cache_position: torch.Tensor,
  669. batch_size: int,
  670. **kwargs,
  671. ):
  672. """
  673. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  674. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  675. Args:
  676. attention_mask (`torch.Tensor`):
  677. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  678. `(batch_size, 1, query_length, key_value_length)`.
  679. sequence_length (`int`):
  680. The sequence length being processed.
  681. target_length (`int`):
  682. The target length: when generating with static cache, the mask should be as long as the static cache,
  683. to account for the 0 padding, the part of the cache that is not filled yet.
  684. dtype (`torch.dtype`):
  685. The dtype to use for the 4D attention mask.
  686. cache_position (`torch.Tensor`):
  687. Indices depicting the position of the input sequence tokens in the sequence.
  688. batch_size (`torch.Tensor`):
  689. Batch size.
  690. """
  691. if attention_mask is not None and attention_mask.dim() == 4:
  692. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  693. causal_mask = attention_mask
  694. else:
  695. min_dtype = torch.finfo(dtype).min
  696. causal_mask = torch.full(
  697. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  698. )
  699. if sequence_length != 1:
  700. causal_mask = torch.triu(causal_mask, diagonal=1)
  701. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  702. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  703. if attention_mask is not None:
  704. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  705. mask_length = attention_mask.shape[-1]
  706. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  707. causal_mask.device
  708. )
  709. padding_mask = padding_mask == 0
  710. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  711. padding_mask, min_dtype
  712. )
  713. return causal_mask
  714. # TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  715. class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
  716. _tied_weights_keys = ["lm_head.weight"]
  717. def __init__(self, config):
  718. super().__init__(config)
  719. self.model = NemotronModel(config)
  720. self.vocab_size = config.vocab_size
  721. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  722. # Initialize weights and apply final processing
  723. self.post_init()
  724. @can_return_tuple
  725. @auto_docstring
  726. def forward(
  727. self,
  728. input_ids: Optional[torch.LongTensor] = None,
  729. attention_mask: Optional[torch.Tensor] = None,
  730. position_ids: Optional[torch.LongTensor] = None,
  731. past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
  732. inputs_embeds: Optional[torch.FloatTensor] = None,
  733. labels: Optional[torch.LongTensor] = None,
  734. use_cache: Optional[bool] = None,
  735. output_attentions: Optional[bool] = None,
  736. output_hidden_states: Optional[bool] = None,
  737. cache_position: Optional[torch.LongTensor] = None,
  738. logits_to_keep: Union[int, torch.Tensor] = 0,
  739. **kwargs,
  740. ) -> CausalLMOutputWithPast:
  741. r"""
  742. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  743. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  744. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  745. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  746. Example:
  747. ```python
  748. >>> from transformers import AutoTokenizer, NemotronForCausalLM
  749. >>> model = NemotronForCausalLM.from_pretrained("nvidia/nemotron-3-8b-base-4k-hf")
  750. >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/nemotron-3-8b-base-4k-hf")
  751. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  752. >>> inputs = tokenizer(prompt, return_tensors="pt")
  753. >>> # Generate
  754. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  755. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  756. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  757. ```"""
  758. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  759. output_hidden_states = (
  760. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  761. )
  762. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  763. outputs: BaseModelOutputWithPast = self.model(
  764. input_ids=input_ids,
  765. attention_mask=attention_mask,
  766. position_ids=position_ids,
  767. past_key_values=past_key_values,
  768. inputs_embeds=inputs_embeds,
  769. use_cache=use_cache,
  770. output_attentions=output_attentions,
  771. output_hidden_states=output_hidden_states,
  772. cache_position=cache_position,
  773. )
  774. hidden_states = outputs.last_hidden_state
  775. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  776. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  777. logits = self.lm_head(hidden_states[:, slice_indices, :])
  778. loss = None
  779. if labels is not None:
  780. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  781. return CausalLMOutputWithPast(
  782. loss=loss,
  783. logits=logits,
  784. past_key_values=outputs.past_key_values,
  785. hidden_states=outputs.hidden_states,
  786. attentions=outputs.attentions,
  787. )
  788. class NemotronForSequenceClassification(GenericForSequenceClassification, NemotronPreTrainedModel): ...
  789. class NemotronForQuestionAnswering(GenericForQuestionAnswering, NemotronPreTrainedModel):
  790. base_model_prefix = "transformer"
  791. class NemotronForTokenClassification(GenericForTokenClassification, NemotronPreTrainedModel): ...
  792. __all__ = [
  793. "NemotronForQuestionAnswering",
  794. "NemotronForCausalLM",
  795. "NemotronModel",
  796. "NemotronPreTrainedModel",
  797. "NemotronForSequenceClassification",
  798. "NemotronForTokenClassification",
  799. ]