modular_modernbert.py 72 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698
  1. # Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import copy
  16. import math
  17. from contextlib import nullcontext
  18. from typing import Literal, Optional, Union
  19. import torch
  20. import torch.nn.functional as F
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ...activations import ACT2FN
  24. from ...configuration_utils import PretrainedConfig
  25. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. MaskedLMOutput,
  30. MultipleChoiceModelOutput,
  31. QuestionAnsweringModelOutput,
  32. SequenceClassifierOutput,
  33. TokenClassifierOutput,
  34. )
  35. from ...modeling_utils import PreTrainedModel
  36. from ...utils import auto_docstring, is_flash_attn_2_available, logging
  37. from ...utils.import_utils import is_triton_available
  38. from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
# flash-attn is optional: its varlen attention kernel, rotary base class and Triton rotary kernel are only imported
# when `is_flash_attn_2_available()` reports the package as usable.
if is_flash_attn_2_available():
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
    from flash_attn.layers.rotary import RotaryEmbedding
    from flash_attn.ops.triton.rotary import apply_rotary
else:
    # Placeholder base so `ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding)` can still be defined without
    # flash-attn; that class is only instantiated when `config._attn_implementation == "flash_attention_2"`.
    RotaryEmbedding = object
logger = logging.get_logger(__name__)
  46. class ModernBertConfig(PretrainedConfig):
  47. r"""
  48. This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert
  49. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  50. defaults will yield a similar configuration to that of the ModernBERT-base.
  51. e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
  52. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  53. documentation from [`PretrainedConfig`] for more information.
  54. Args:
  55. vocab_size (`int`, *optional*, defaults to 50368):
  56. Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
  57. `inputs_ids` passed when calling [`ModernBertModel`]
  58. hidden_size (`int`, *optional*, defaults to 768):
  59. Dimension of the hidden representations.
  60. intermediate_size (`int`, *optional*, defaults to 1152):
  61. Dimension of the MLP representations.
  62. num_hidden_layers (`int`, *optional*, defaults to 22):
  63. Number of hidden layers in the Transformer decoder.
  64. num_attention_heads (`int`, *optional*, defaults to 12):
  65. Number of attention heads for each attention layer in the Transformer decoder.
  66. hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
  67. The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
  68. if not specified.
  69. max_position_embeddings (`int`, *optional*, defaults to 8192):
  70. The maximum sequence length that this model might ever be used with.
  71. initializer_range (`float`, *optional*, defaults to 0.02):
  72. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  73. initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
  74. The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
  75. norm_eps (`float`, *optional*, defaults to 1e-05):
  76. The epsilon used by the rms normalization layers.
  77. norm_bias (`bool`, *optional*, defaults to `False`):
  78. Whether to use bias in the normalization layers.
  79. pad_token_id (`int`, *optional*, defaults to 50283):
  80. Padding token id.
  81. eos_token_id (`int`, *optional*, defaults to 50282):
  82. End of stream token id.
  83. bos_token_id (`int`, *optional*, defaults to 50281):
  84. Beginning of stream token id.
  85. cls_token_id (`int`, *optional*, defaults to 50281):
  86. Classification token id.
  87. sep_token_id (`int`, *optional*, defaults to 50282):
  88. Separation token id.
  89. global_rope_theta (`float`, *optional*, defaults to 160000.0):
  90. The base period of the global RoPE embeddings.
  91. attention_bias (`bool`, *optional*, defaults to `False`):
  92. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  93. attention_dropout (`float`, *optional*, defaults to 0.0):
  94. The dropout ratio for the attention probabilities.
  95. global_attn_every_n_layers (`int`, *optional*, defaults to 3):
  96. The number of layers between global attention layers.
  97. local_attention (`int`, *optional*, defaults to 128):
  98. The window size for local attention.
  99. local_rope_theta (`float`, *optional*, defaults to 10000.0):
  100. The base period of the local RoPE embeddings.
  101. embedding_dropout (`float`, *optional*, defaults to 0.0):
  102. The dropout ratio for the embeddings.
  103. mlp_bias (`bool`, *optional*, defaults to `False`):
  104. Whether to use bias in the MLP layers.
  105. mlp_dropout (`float`, *optional*, defaults to 0.0):
  106. The dropout ratio for the MLP layers.
  107. decoder_bias (`bool`, *optional*, defaults to `True`):
  108. Whether to use bias in the decoder layers.
  109. classifier_pooling (`str`, *optional*, defaults to `"cls"`):
  110. The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
  111. CLS token doesn't attend to all tokens on long sequences.
  112. classifier_dropout (`float`, *optional*, defaults to 0.0):
  113. The dropout ratio for the classifier.
  114. classifier_bias (`bool`, *optional*, defaults to `False`):
  115. Whether to use bias in the classifier.
  116. classifier_activation (`str`, *optional*, defaults to `"gelu"`):
  117. The activation function for the classifier.
  118. deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
  119. Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
  120. sparse_prediction (`bool`, *optional*, defaults to `False`):
  121. Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
  122. sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
  123. The index to ignore for the sparse prediction.
  124. reference_compile (`bool`, *optional*):
  125. Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
  126. the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
  127. shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
  128. be faster in some scenarios.
  129. repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
  130. When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
  131. applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
  132. Examples:
  133. ```python
  134. >>> from transformers import ModernBertModel, ModernBertConfig
  135. >>> # Initializing a ModernBert style configuration
  136. >>> configuration = ModernBertConfig()
  137. >>> # Initializing a model from the modernbert-base style configuration
  138. >>> model = ModernBertModel(configuration)
  139. >>> # Accessing the model configuration
  140. >>> configuration = model.config
  141. ```"""
  142. model_type = "modernbert"
  143. attribute_map = {"rope_theta": "global_rope_theta"}
  144. keys_to_ignore_at_inference = ["past_key_values"]
  145. def __init__(
  146. self,
  147. vocab_size=50368,
  148. hidden_size=768,
  149. intermediate_size=1152,
  150. num_hidden_layers=22,
  151. num_attention_heads=12,
  152. hidden_activation="gelu",
  153. max_position_embeddings=8192,
  154. initializer_range=0.02,
  155. initializer_cutoff_factor=2.0,
  156. norm_eps=1e-5,
  157. norm_bias=False,
  158. pad_token_id=50283,
  159. eos_token_id=50282,
  160. bos_token_id=50281,
  161. cls_token_id=50281,
  162. sep_token_id=50282,
  163. global_rope_theta=160000.0,
  164. attention_bias=False,
  165. attention_dropout=0.0,
  166. global_attn_every_n_layers=3,
  167. local_attention=128,
  168. local_rope_theta=10000.0,
  169. embedding_dropout=0.0,
  170. mlp_bias=False,
  171. mlp_dropout=0.0,
  172. decoder_bias=True,
  173. classifier_pooling: Literal["cls", "mean"] = "cls",
  174. classifier_dropout=0.0,
  175. classifier_bias=False,
  176. classifier_activation="gelu",
  177. deterministic_flash_attn=False,
  178. sparse_prediction=False,
  179. sparse_pred_ignore_index=-100,
  180. reference_compile=None,
  181. repad_logits_with_grad=False,
  182. **kwargs,
  183. ):
  184. super().__init__(
  185. pad_token_id=pad_token_id,
  186. bos_token_id=bos_token_id,
  187. eos_token_id=eos_token_id,
  188. cls_token_id=cls_token_id,
  189. sep_token_id=sep_token_id,
  190. **kwargs,
  191. )
  192. self.vocab_size = vocab_size
  193. self.max_position_embeddings = max_position_embeddings
  194. self.hidden_size = hidden_size
  195. self.intermediate_size = intermediate_size
  196. self.num_hidden_layers = num_hidden_layers
  197. self.num_attention_heads = num_attention_heads
  198. self.initializer_range = initializer_range
  199. self.initializer_cutoff_factor = initializer_cutoff_factor
  200. self.norm_eps = norm_eps
  201. self.norm_bias = norm_bias
  202. self.global_rope_theta = global_rope_theta
  203. self.attention_bias = attention_bias
  204. self.attention_dropout = attention_dropout
  205. self.hidden_activation = hidden_activation
  206. self.global_attn_every_n_layers = global_attn_every_n_layers
  207. self.local_attention = local_attention
  208. self.local_rope_theta = local_rope_theta
  209. self.embedding_dropout = embedding_dropout
  210. self.mlp_bias = mlp_bias
  211. self.mlp_dropout = mlp_dropout
  212. self.decoder_bias = decoder_bias
  213. self.classifier_pooling = classifier_pooling
  214. self.classifier_dropout = classifier_dropout
  215. self.classifier_bias = classifier_bias
  216. self.classifier_activation = classifier_activation
  217. self.deterministic_flash_attn = deterministic_flash_attn
  218. self.sparse_prediction = sparse_prediction
  219. self.sparse_pred_ignore_index = sparse_pred_ignore_index
  220. self.reference_compile = reference_compile
  221. self.repad_logits_with_grad = repad_logits_with_grad
  222. if self.classifier_pooling not in ["cls", "mean"]:
  223. raise ValueError(
  224. f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
  225. )
  226. def to_dict(self):
  227. output = super().to_dict()
  228. output.pop("reference_compile", None)
  229. return output
  230. def _unpad_modernbert_input(
  231. inputs: torch.Tensor,
  232. attention_mask: torch.Tensor,
  233. position_ids: Optional[torch.Tensor] = None,
  234. labels: Optional[torch.Tensor] = None,
  235. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
  236. """
  237. Remove padding from input sequences.
  238. Args:
  239. inputs: (batch, seqlen, ...) or (batch, seqlen)
  240. attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
  241. position_ids: (batch, seqlen), int, position ids
  242. labels: (batch, seqlen), int, labels
  243. Returns:
  244. unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
  245. indices: (total_nnz)
  246. cu_seqlens: (batch + 1), the cumulative sequence lengths
  247. max_seqlen_in_batch: int
  248. unpadded_position_ids: (total_nnz) or None
  249. unpadded_labels: (total_nnz) or None
  250. """
  251. seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  252. indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  253. max_seqlen_in_batch = int(seqlens_in_batch.max().item())
  254. cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  255. if inputs.dim() == 2:
  256. unpadded_inputs = inputs.flatten()[indices]
  257. else:
  258. batch, seqlen, *rest = inputs.shape
  259. shape = batch * seqlen
  260. unpadded_inputs = inputs.view(shape, *rest)[indices]
  261. unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
  262. unpadded_labels = labels.flatten()[indices] if labels is not None else None
  263. return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
  264. def _pad_modernbert_output(
  265. inputs: torch.Tensor,
  266. indices: torch.Tensor,
  267. batch: int,
  268. seqlen: int,
  269. ) -> torch.Tensor:
  270. """
  271. Add padding to sequences.
  272. Args:
  273. inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
  274. indices: (total_nnz)
  275. batch: int, batch size
  276. seqlen: int, max sequence length
  277. Returns:
  278. padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
  279. """
  280. if inputs.dim() == 1:
  281. output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
  282. output[indices] = inputs
  283. padded_inputs = output.view(batch, seqlen)
  284. else:
  285. _, *rest = inputs.shape
  286. output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
  287. output[indices] = inputs
  288. padded_inputs = output.view(batch, seqlen, *rest)
  289. return padded_inputs
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """Autograd wrapper around flash-attn's Triton `apply_rotary` kernel for packed, unpadded QKV.

    Rotary embeddings are applied in place to the query and key slices of `qkv`; the value slice is not touched.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        # (total_nnz, 3, nheads, headdim)
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape
        # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
        # we get the same tensor
        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
        # `qk` is a view into `qkv`, so the in-place kernel below rotates q and k inside `qkv` itself.
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )
        # Only the rotation parameters are needed to undo the rotation in backward; qkv itself is not saved.
        ctx.save_for_backward(cos, sin, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
        # we get the same tensor
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        # A rotation's gradient is its inverse rotation; `conjugate=True` asks the kernel for that inverse.
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )
        # NOTE(review): forward takes 5 inputs (qkv, cos, sin, cu_seqlens, max_seqlen) but 7 gradients are
        # returned. Autograd accepts extra trailing `None`s, so this works, but the count could be trimmed to 5.
        return do, None, None, None, None, None, None
  340. def apply_rotary_unpadded(
  341. qkv,
  342. cos,
  343. sin,
  344. cu_seqlens: Optional[torch.Tensor] = None,
  345. max_seqlen: Optional[int] = None,
  346. ):
  347. """
  348. Arguments:
  349. qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
  350. cos, sin: (seqlen_rotary, rotary_dim / 2)
  351. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
  352. of 1st half and 2nd half (GPT-NeoX style).
  353. inplace: if True, apply rotary embedding in-place.
  354. seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
  355. Most commonly used in inference when we have KV cache.
  356. cu_seqlens: (batch + 1,) or None
  357. max_seqlen: int
  358. Return:
  359. out: (total_nnz, dim)
  360. rotary_dim must be <= headdim
  361. Apply rotary embedding to the first rotary_dim of x.
  362. """
  363. return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
  364. class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
  365. """
  366. The rotary position embeddings applied directly to unpadded sequences.
  367. """
  368. def __init__(
  369. self,
  370. dim: int,
  371. base: float = 10000.0,
  372. max_seqlen: Optional[int] = None,
  373. device: Optional[torch.device] = None,
  374. dtype: Optional[torch.dtype] = None,
  375. ):
  376. """
  377. max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
  378. up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
  379. the cos_sin_cache will be recomputed during the forward pass.
  380. """
  381. super().__init__(dim=dim, base=base, device=device, interleaved=False)
  382. self.max_seqlen = max_seqlen
  383. if max_seqlen is not None and device is not None and dtype is not None:
  384. self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
  385. def forward(
  386. self,
  387. qkv: torch.Tensor,
  388. cu_seqlens: torch.Tensor,
  389. max_seqlen: Optional[int] = None,
  390. ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  391. """
  392. Apply rotary embedding *inplace* to qkv.
  393. qkv: (total_nnz, 3, nheads, headdim)
  394. cu_seqlens: (batch + 1,) cumulative sequence lengths
  395. max_seqlen: int max seq length in the batch
  396. """
  397. if max_seqlen is not None:
  398. self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
  399. qkv = apply_rotary_unpadded(
  400. qkv,
  401. self._cos_cached,
  402. self._sin_cached,
  403. cu_seqlens=cu_seqlens,
  404. max_seqlen=max_seqlen,
  405. )
  406. return qkv
  407. def extra_repr(self) -> str:
  408. return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
  409. class ModernBertEmbeddings(nn.Module):
  410. """
  411. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  412. """
  413. def __init__(self, config: ModernBertConfig):
  414. super().__init__()
  415. self.config = config
  416. self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  417. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  418. self.drop = nn.Dropout(config.embedding_dropout)
  419. @torch.compile(dynamic=True)
  420. def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
  421. return self.drop(self.norm(self.tok_embeddings(input_ids)))
  422. def forward(
  423. self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
  424. ) -> torch.Tensor:
  425. if inputs_embeds is not None:
  426. hidden_states = self.drop(self.norm(inputs_embeds))
  427. else:
  428. hidden_states = (
  429. self.compiled_embeddings(input_ids)
  430. if self.config.reference_compile
  431. else self.drop(self.norm(self.tok_embeddings(input_ids)))
  432. )
  433. return hidden_states
  434. class ModernBertMLP(nn.Module):
  435. """Applies the GLU at the end of each ModernBERT layer.
  436. Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
  437. and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
  438. """
  439. def __init__(self, config: ModernBertConfig):
  440. super().__init__()
  441. self.config = config
  442. self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
  443. self.act = ACT2FN[config.hidden_activation]
  444. self.drop = nn.Dropout(config.mlp_dropout)
  445. self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
  446. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  447. input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
  448. return self.Wo(self.drop(self.act(input) * gate))
class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
    """Rotary position embedding for the padded (eager / SDPA) attention paths; Gemma's implementation unchanged.

    Called as `rotary_emb(qkv, position_ids=position_ids)` and returns `(cos, sin)` for `apply_rotary_pos_emb`.
    """

    pass
  451. def eager_attention_forward(
  452. module: "ModernBertAttention",
  453. qkv: torch.Tensor,
  454. attention_mask: torch.Tensor,
  455. sliding_window_mask: torch.Tensor,
  456. position_ids: Optional[torch.LongTensor],
  457. local_attention: tuple[int, int],
  458. bs: int,
  459. dim: int,
  460. output_attentions: Optional[bool] = False,
  461. **_kwargs,
  462. ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
  463. # qkv: [batch_size, seqlen, 3, nheads, headdim]
  464. cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
  465. query, key, value = qkv.transpose(3, 1).unbind(dim=2)
  466. # query, key, value: [batch_size, heads, seq_len, head_dim]
  467. query, key = apply_rotary_pos_emb(query, key, cos, sin)
  468. scale = module.head_dim**-0.5
  469. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
  470. if local_attention != (-1, -1):
  471. attention_mask = sliding_window_mask
  472. attn_weights = attn_weights + attention_mask
  473. # upcast attention to fp32
  474. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  475. attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
  476. attn_output = torch.matmul(attn_weights, value)
  477. attn_output = attn_output.transpose(1, 2).contiguous()
  478. attn_output = attn_output.view(bs, -1, dim)
  479. if output_attentions:
  480. return (attn_output, attn_weights)
  481. return (attn_output,)
  482. def flash_attention_forward(
  483. module: "ModernBertAttention",
  484. qkv: torch.Tensor,
  485. rotary_emb: ModernBertUnpaddedRotaryEmbedding,
  486. cu_seqlens: torch.Tensor,
  487. max_seqlen: int,
  488. local_attention: tuple[int, int],
  489. bs: int,
  490. dim: int,
  491. target_dtype: torch.dtype = torch.bfloat16,
  492. **_kwargs,
  493. ) -> tuple[torch.Tensor]:
  494. # (total_seqlen, 3, nheads, headdim)
  495. qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
  496. convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
  497. if convert_dtype:
  498. # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
  499. # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
  500. orig_dtype = qkv.dtype
  501. qkv = qkv.to(target_dtype)
  502. attn = flash_attn_varlen_qkvpacked_func(
  503. qkv,
  504. cu_seqlens=cu_seqlens,
  505. max_seqlen=max_seqlen,
  506. dropout_p=module.attention_dropout if module.training else 0.0,
  507. deterministic=module.deterministic_flash_attn,
  508. window_size=local_attention,
  509. )
  510. attn = attn.to(orig_dtype) # type: ignore
  511. else:
  512. attn = flash_attn_varlen_qkvpacked_func(
  513. qkv,
  514. cu_seqlens=cu_seqlens,
  515. max_seqlen=max_seqlen,
  516. dropout_p=module.attention_dropout if module.training else 0.0,
  517. deterministic=module.deterministic_flash_attn,
  518. window_size=local_attention,
  519. )
  520. return (attn.view(bs, dim),)
  521. def sdpa_attention_forward(
  522. module: "ModernBertAttention",
  523. qkv: torch.Tensor,
  524. attention_mask: torch.Tensor,
  525. sliding_window_mask: torch.Tensor,
  526. position_ids: Optional[torch.LongTensor],
  527. local_attention: tuple[int, int],
  528. bs: int,
  529. dim: int,
  530. **_kwargs,
  531. ) -> tuple[torch.Tensor]:
  532. # qkv: [batch_size, seqlen, 3, nheads, headdim]
  533. cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
  534. query, key, value = qkv.transpose(3, 1).unbind(dim=2)
  535. # query, key, value: [batch_size, heads, seq_len, head_dim]
  536. query, key = apply_rotary_pos_emb(query, key, cos, sin)
  537. if local_attention != (-1, -1):
  538. attention_mask = sliding_window_mask
  539. attn_output = (
  540. F.scaled_dot_product_attention(
  541. query,
  542. key,
  543. value,
  544. dropout_p=module.attention_dropout if module.training else 0.0,
  545. attn_mask=attention_mask,
  546. )
  547. .transpose(1, 2)
  548. .contiguous()
  549. )
  550. attn_output = attn_output.view(bs, -1, dim)
  551. return (attn_output,)
# Dispatch table keyed by `config._attn_implementation`. The flash path consumes unpadded
# (total_nnz, 3, nheads, headdim) QKV; the eager and SDPA paths consume padded (batch, seqlen, 3, nheads, headdim) QKV.
MODERNBERT_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}
  557. class ModernBertAttention(nn.Module):
  558. """Performs multi-headed self attention on a batch of unpadded sequences.
  559. If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
  560. If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
  561. which requires padding and unpadding inputs, adding some overhead.
  562. See `forward` method for additional details.
  563. """
  564. def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
  565. super().__init__()
  566. self.config = config
  567. self.layer_id = layer_id
  568. if config.hidden_size % config.num_attention_heads != 0:
  569. raise ValueError(
  570. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
  571. )
  572. self.attention_dropout = config.attention_dropout
  573. self.deterministic_flash_attn = config.deterministic_flash_attn
  574. self.num_heads = config.num_attention_heads
  575. self.head_dim = config.hidden_size // config.num_attention_heads
  576. self.all_head_size = self.head_dim * self.num_heads
  577. self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
  578. if layer_id % config.global_attn_every_n_layers != 0:
  579. self.local_attention = (config.local_attention // 2, config.local_attention // 2)
  580. rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
  581. max_position_embeddings = config.local_attention
  582. else:
  583. self.local_attention = (-1, -1)
  584. max_position_embeddings = config.max_position_embeddings
  585. rope_theta = config.global_rope_theta
  586. if config._attn_implementation == "flash_attention_2":
  587. self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
  588. dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
  589. )
  590. else:
  591. config_copy = copy.deepcopy(config)
  592. config_copy.rope_theta = rope_theta
  593. self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
  594. self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
  595. self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
  596. self.pruned_heads = set()
  597. def forward(
  598. self,
  599. hidden_states: torch.Tensor,
  600. output_attentions: Optional[bool] = False,
  601. **kwargs,
  602. ) -> torch.Tensor:
  603. qkv = self.Wqkv(hidden_states)
  604. bs = hidden_states.shape[0]
  605. if self.config._attn_implementation == "flash_attention_2":
  606. qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
  607. else:
  608. qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
  609. attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
  610. self,
  611. qkv=qkv,
  612. rotary_emb=self.rotary_emb,
  613. local_attention=self.local_attention,
  614. bs=bs,
  615. dim=self.all_head_size,
  616. output_attentions=output_attentions,
  617. **kwargs,
  618. )
  619. hidden_states = attn_outputs[0]
  620. hidden_states = self.out_drop(self.Wo(hidden_states))
  621. return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
  622. class ModernBertEncoderLayer(GradientCheckpointingLayer):
  623. def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
  624. super().__init__()
  625. self.config = config
  626. if layer_id == 0:
  627. self.attn_norm = nn.Identity()
  628. else:
  629. self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  630. self.attn = ModernBertAttention(config=config, layer_id=layer_id)
  631. self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  632. self.mlp = ModernBertMLP(config)
  633. @torch.compile(dynamic=True)
  634. def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
  635. return self.mlp(self.mlp_norm(hidden_states))
  636. def forward(
  637. self,
  638. hidden_states: torch.Tensor,
  639. attention_mask: Optional[torch.Tensor] = None,
  640. sliding_window_mask: Optional[torch.Tensor] = None,
  641. position_ids: Optional[torch.LongTensor] = None,
  642. cu_seqlens: Optional[torch.Tensor] = None,
  643. max_seqlen: Optional[int] = None,
  644. output_attentions: Optional[bool] = False,
  645. ) -> torch.Tensor:
  646. attn_outputs = self.attn(
  647. self.attn_norm(hidden_states),
  648. attention_mask=attention_mask,
  649. sliding_window_mask=sliding_window_mask,
  650. position_ids=position_ids,
  651. cu_seqlens=cu_seqlens,
  652. max_seqlen=max_seqlen,
  653. output_attentions=output_attentions,
  654. )
  655. hidden_states = hidden_states + attn_outputs[0]
  656. mlp_output = (
  657. self.compiled_mlp(hidden_states)
  658. if self.config.reference_compile
  659. else self.mlp(self.mlp_norm(hidden_states))
  660. )
  661. hidden_states = hidden_states + mlp_output
  662. return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
@auto_docstring
class ModernBertPreTrainedModel(PreTrainedModel):
    # Shared base for the ModernBert models in this file. It provides weight initialization,
    # attention-implementation selection (preferring flash_attention_2 when nothing is requested)
    # and gating of the `config.reference_compile` torch.compile switch.
    config: ModernBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = False

    def _init_weights(self, module: nn.Module):
        """Initialize `module`'s parameters in place according to its type.

        Weights of the ModernBert sub-layers handled below are drawn from a normal distribution
        truncated at `cutoff_factor` standard deviations (3 when `config.initializer_cutoff_factor`
        is None), and the biases of those Linear layers are zeroed. LayerNorm weights are set to 1
        and biases to 0. Other module types are left untouched.
        """
        cutoff_factor = self.config.initializer_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        def init_weight(module: nn.Module, std: float):
            # Truncated normal in [-cutoff_factor * std, cutoff_factor * std]; Linear biases are zeroed.
            nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-cutoff_factor * std,
                b=cutoff_factor * std,
            )

            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        stds = {
            "in": self.config.initializer_range,
            # Output projections shrink with depth: initializer_range / sqrt(2 * num_hidden_layers).
            "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
            "embedding": self.config.initializer_range,
            # Task classifiers use 1 / sqrt(hidden_size).
            "final_out": self.config.hidden_size**-0.5,
        }

        if isinstance(module, ModernBertEmbeddings):
            init_weight(module.tok_embeddings, stds["embedding"])
        elif isinstance(module, ModernBertMLP):
            init_weight(module.Wi, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, ModernBertAttention):
            init_weight(module.Wqkv, stds["in"])
            init_weight(module.Wo, stds["out"])
        elif isinstance(module, ModernBertPredictionHead):
            init_weight(module.dense, stds["out"])
        elif isinstance(module, ModernBertForMaskedLM):
            init_weight(module.decoder, stds["out"])
        elif isinstance(
            module,
            (
                ModernBertForSequenceClassification,
                ModernBertForMultipleChoice,
                ModernBertForTokenClassification,
                ModernBertForQuestionAnswering,
            ),
        ):
            init_weight(module.classifier, stds["final_out"])
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    def _check_and_adjust_attn_implementation(
        self, attn_implementation: Optional[str], is_init_check: bool = False
    ) -> str:
        """
        Checks and dispatches to the requested attention implementation.
        """
        # If the user didn't specify anything, try to use flash_attention_2 if available.
        # Otherwise we fall back to the default SDPA -> Eager from the super() method.
        # ModernBert's FA2 implementation is meant to handle non-fp16/bf16 dtypes, so the FA2
        # dtype warning is not wanted here.
        # NOTE(review): an earlier comment said fp16 is set for the FA2 check, but no dtype is
        # passed to `_flash_attn_2_can_dispatch()` below. Confirm whether that warning can still fire.
        try:
            attn_implementation = (
                "flash_attention_2"
                if attn_implementation is None and self._flash_attn_2_can_dispatch()
                else attn_implementation
            )
        except (ValueError, ImportError):
            # Deliberate best effort: if FA2 cannot be dispatched, keep the caller's choice (possibly
            # None) and let the parent class pick the fallback.
            pass
        return super()._check_and_adjust_attn_implementation(
            attn_implementation=attn_implementation, is_init_check=is_init_check
        )

    def _maybe_set_compile(self):
        """Resolve `config.reference_compile`, forcing it off where torch.compile is not supported.

        An explicit False is left alone. Multi-device `hf_device_map` placement, MPS and CPU each
        force it to False; the warning is only logged when it had been explicitly True. If it is
        still None afterwards, it is set to whether Triton is available.
        """
        if self.config.reference_compile is False:
            return

        if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
            if self.config.reference_compile:
                logger.warning_once(
                    "If `accelerate` split the model across devices, `torch.compile` will not work. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.device.type == "mps":
            if self.config.reference_compile:
                logger.warning_once(
                    "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.device.type == "cpu":
            if self.config.reference_compile:
                logger.warning_once(
                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
                    "Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        if self.config.reference_compile is None:
            self.config.reference_compile = is_triton_available()

    def resize_token_embeddings(self, *args, **kwargs):
        """Resize the token embeddings via the parent class, then disable `reference_compile`.

        Compile is switched off whenever it was True or undecided (None). A warning is logged only
        when it had been explicitly True.
        """
        model_embeds = super().resize_token_embeddings(*args, **kwargs)

        if self.config.reference_compile in {True, None}:
            if self.config.reference_compile:
                logger.warning_once(
                    "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        return model_embeds
  775. @auto_docstring
  776. class ModernBertModel(ModernBertPreTrainedModel):
  777. def __init__(self, config: ModernBertConfig):
  778. super().__init__(config)
  779. self.config = config
  780. self.embeddings = ModernBertEmbeddings(config)
  781. self.layers = nn.ModuleList(
  782. [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
  783. )
  784. self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  785. self.gradient_checkpointing = False
  786. self.post_init()
  787. def get_input_embeddings(self):
  788. return self.embeddings.tok_embeddings
  789. def set_input_embeddings(self, value):
  790. self.embeddings.tok_embeddings = value
  791. @auto_docstring
  792. def forward(
  793. self,
  794. input_ids: Optional[torch.LongTensor] = None,
  795. attention_mask: Optional[torch.Tensor] = None,
  796. sliding_window_mask: Optional[torch.Tensor] = None,
  797. position_ids: Optional[torch.LongTensor] = None,
  798. inputs_embeds: Optional[torch.Tensor] = None,
  799. indices: Optional[torch.Tensor] = None,
  800. cu_seqlens: Optional[torch.Tensor] = None,
  801. max_seqlen: Optional[int] = None,
  802. batch_size: Optional[int] = None,
  803. seq_len: Optional[int] = None,
  804. output_attentions: Optional[bool] = None,
  805. output_hidden_states: Optional[bool] = None,
  806. return_dict: Optional[bool] = None,
  807. ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
  808. r"""
  809. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  810. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  811. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  812. far-away tokens in the local attention layers when not using Flash Attention.
  813. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  814. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  815. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  816. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  817. max_seqlen (`int`, *optional*):
  818. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  819. batch_size (`int`, *optional*):
  820. Batch size of the input sequences. Used to pad the output tensors.
  821. seq_len (`int`, *optional*):
  822. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  823. """
  824. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  825. output_hidden_states = (
  826. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  827. )
  828. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  829. if (input_ids is None) ^ (inputs_embeds is not None):
  830. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  831. all_hidden_states = () if output_hidden_states else None
  832. all_self_attentions = () if output_attentions else None
  833. self._maybe_set_compile()
  834. if input_ids is not None:
  835. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  836. if batch_size is None and seq_len is None:
  837. if inputs_embeds is not None:
  838. batch_size, seq_len = inputs_embeds.shape[:2]
  839. else:
  840. batch_size, seq_len = input_ids.shape[:2]
  841. device = input_ids.device if input_ids is not None else inputs_embeds.device
  842. if attention_mask is None:
  843. attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
  844. repad = False
  845. if self.config._attn_implementation == "flash_attention_2":
  846. if indices is None and cu_seqlens is None and max_seqlen is None:
  847. repad = True
  848. if inputs_embeds is None:
  849. with torch.no_grad():
  850. input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
  851. inputs=input_ids, attention_mask=attention_mask
  852. )
  853. else:
  854. inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
  855. inputs=inputs_embeds, attention_mask=attention_mask
  856. )
  857. else:
  858. if position_ids is None:
  859. position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
  860. attention_mask, sliding_window_mask = self._update_attention_mask(
  861. attention_mask, output_attentions=output_attentions
  862. )
  863. hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
  864. for encoder_layer in self.layers:
  865. if output_hidden_states:
  866. all_hidden_states = all_hidden_states + (hidden_states,)
  867. layer_outputs = encoder_layer(
  868. hidden_states,
  869. attention_mask=attention_mask,
  870. sliding_window_mask=sliding_window_mask,
  871. position_ids=position_ids,
  872. cu_seqlens=cu_seqlens,
  873. max_seqlen=max_seqlen,
  874. output_attentions=output_attentions,
  875. )
  876. hidden_states = layer_outputs[0]
  877. if output_attentions and len(layer_outputs) > 1:
  878. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  879. if output_hidden_states:
  880. all_hidden_states = all_hidden_states + (hidden_states,)
  881. hidden_states = self.final_norm(hidden_states)
  882. if repad:
  883. hidden_states = _pad_modernbert_output(
  884. inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
  885. )
  886. if all_hidden_states is not None:
  887. all_hidden_states = tuple(
  888. _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
  889. for hs in all_hidden_states
  890. )
  891. # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
  892. # dimension missing
  893. elif (
  894. self.config._attn_implementation == "flash_attention_2"
  895. and all_hidden_states is not None
  896. and all_hidden_states[-1].dim() == 2
  897. ):
  898. hidden_states = hidden_states.unsqueeze(0)
  899. all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
  900. if not return_dict:
  901. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  902. return BaseModelOutput(
  903. last_hidden_state=hidden_states,
  904. hidden_states=all_hidden_states,
  905. attentions=all_self_attentions,
  906. )
  907. def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
  908. if output_attentions:
  909. if self.config._attn_implementation == "sdpa":
  910. logger.warning_once(
  911. "Outputting attentions is only supported with the 'eager' attention implementation, "
  912. 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
  913. )
  914. self.config._attn_implementation = "eager"
  915. elif self.config._attn_implementation != "eager":
  916. logger.warning_once(
  917. "Outputting attentions is only supported with the eager attention implementation, "
  918. f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
  919. " Setting `output_attentions=False`."
  920. )
  921. global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
  922. # Create position indices
  923. rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
  924. # Calculate distance between positions
  925. distance = torch.abs(rows - rows.T)
  926. # Create sliding window mask (1 for positions within window, 0 outside)
  927. window_mask = (
  928. (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
  929. )
  930. # Combine with existing mask
  931. sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
  932. return global_attention_mask, sliding_window_mask
  933. class ModernBertPredictionHead(nn.Module):
  934. def __init__(self, config: ModernBertConfig):
  935. super().__init__()
  936. self.config = config
  937. self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
  938. self.act = ACT2FN[config.classifier_activation]
  939. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  940. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  941. return self.norm(self.act(self.dense(hidden_states)))
  942. @auto_docstring(
  943. custom_intro="""
  944. The ModernBert Model with a decoder head on top that is used for masked language modeling.
  945. """
  946. )
  947. class ModernBertForMaskedLM(ModernBertPreTrainedModel):
  948. _tied_weights_keys = ["decoder.weight"]
  949. def __init__(self, config: ModernBertConfig):
  950. super().__init__(config)
  951. self.config = config
  952. self.model = ModernBertModel(config)
  953. self.head = ModernBertPredictionHead(config)
  954. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
  955. self.sparse_prediction = self.config.sparse_prediction
  956. self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
  957. # Initialize weights and apply final processing
  958. self.post_init()
  959. def get_output_embeddings(self):
  960. return self.decoder
  961. def set_output_embeddings(self, new_embeddings: nn.Linear):
  962. self.decoder = new_embeddings
  963. @torch.compile(dynamic=True)
  964. def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
  965. return self.decoder(self.head(output))
  966. @auto_docstring
  967. def forward(
  968. self,
  969. input_ids: Optional[torch.LongTensor] = None,
  970. attention_mask: Optional[torch.Tensor] = None,
  971. sliding_window_mask: Optional[torch.Tensor] = None,
  972. position_ids: Optional[torch.Tensor] = None,
  973. inputs_embeds: Optional[torch.Tensor] = None,
  974. labels: Optional[torch.Tensor] = None,
  975. indices: Optional[torch.Tensor] = None,
  976. cu_seqlens: Optional[torch.Tensor] = None,
  977. max_seqlen: Optional[int] = None,
  978. batch_size: Optional[int] = None,
  979. seq_len: Optional[int] = None,
  980. output_attentions: Optional[bool] = None,
  981. output_hidden_states: Optional[bool] = None,
  982. return_dict: Optional[bool] = None,
  983. **kwargs,
  984. ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
  985. r"""
  986. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  987. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  988. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  989. far-away tokens in the local attention layers when not using Flash Attention.
  990. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  991. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  992. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  993. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  994. max_seqlen (`int`, *optional*):
  995. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  996. batch_size (`int`, *optional*):
  997. Batch size of the input sequences. Used to pad the output tensors.
  998. seq_len (`int`, *optional*):
  999. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  1000. """
  1001. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1002. self._maybe_set_compile()
  1003. if self.config._attn_implementation == "flash_attention_2":
  1004. if indices is None and cu_seqlens is None and max_seqlen is None:
  1005. if batch_size is None and seq_len is None:
  1006. if inputs_embeds is not None:
  1007. batch_size, seq_len = inputs_embeds.shape[:2]
  1008. else:
  1009. batch_size, seq_len = input_ids.shape[:2]
  1010. device = input_ids.device if input_ids is not None else inputs_embeds.device
  1011. if attention_mask is None:
  1012. attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
  1013. if inputs_embeds is None:
  1014. with torch.no_grad():
  1015. input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
  1016. inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
  1017. )
  1018. else:
  1019. inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
  1020. inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
  1021. )
  1022. outputs = self.model(
  1023. input_ids=input_ids,
  1024. attention_mask=attention_mask,
  1025. sliding_window_mask=sliding_window_mask,
  1026. position_ids=position_ids,
  1027. inputs_embeds=inputs_embeds,
  1028. indices=indices,
  1029. cu_seqlens=cu_seqlens,
  1030. max_seqlen=max_seqlen,
  1031. batch_size=batch_size,
  1032. seq_len=seq_len,
  1033. output_attentions=output_attentions,
  1034. output_hidden_states=output_hidden_states,
  1035. return_dict=return_dict,
  1036. )
  1037. last_hidden_state = outputs[0]
  1038. if self.sparse_prediction and labels is not None:
  1039. # flatten labels and output first
  1040. labels = labels.view(-1)
  1041. last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
  1042. # then filter out the non-masked tokens
  1043. mask_tokens = labels != self.sparse_pred_ignore_index
  1044. last_hidden_state = last_hidden_state[mask_tokens]
  1045. labels = labels[mask_tokens]
  1046. logits = (
  1047. self.compiled_head(last_hidden_state)
  1048. if self.config.reference_compile
  1049. else self.decoder(self.head(last_hidden_state))
  1050. )
  1051. loss = None
  1052. if labels is not None:
  1053. loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
  1054. if self.config._attn_implementation == "flash_attention_2":
  1055. # Logits padding
  1056. with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
  1057. logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
  1058. # Hidden states padding
  1059. if getattr(outputs, "hidden_states", None) is not None:
  1060. padded_hidden_states = []
  1061. for hs in outputs.hidden_states:
  1062. if hs.dim() == 3 and hs.shape[0] == 1:
  1063. hs = hs.squeeze(0)
  1064. padded_hidden_states.append(
  1065. _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
  1066. )
  1067. outputs.hidden_states = tuple(padded_hidden_states)
  1068. if not return_dict:
  1069. output = (logits,)
  1070. return ((loss,) + output) if loss is not None else output
  1071. return MaskedLMOutput(
  1072. loss=loss,
  1073. logits=logits,
  1074. hidden_states=outputs.hidden_states,
  1075. attentions=outputs.attentions,
  1076. )
  1077. @auto_docstring(
  1078. custom_intro="""
  1079. The ModernBert Model with a sequence classification head on top that performs pooling.
  1080. """
  1081. )
  1082. class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
  1083. def __init__(self, config: ModernBertConfig):
  1084. super().__init__(config)
  1085. self.num_labels = config.num_labels
  1086. self.config = config
  1087. self.model = ModernBertModel(config)
  1088. self.head = ModernBertPredictionHead(config)
  1089. self.drop = torch.nn.Dropout(config.classifier_dropout)
  1090. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1091. # Initialize weights and apply final processing
  1092. self.post_init()
  1093. @auto_docstring
  1094. def forward(
  1095. self,
  1096. input_ids: Optional[torch.LongTensor] = None,
  1097. attention_mask: Optional[torch.Tensor] = None,
  1098. sliding_window_mask: Optional[torch.Tensor] = None,
  1099. position_ids: Optional[torch.Tensor] = None,
  1100. inputs_embeds: Optional[torch.Tensor] = None,
  1101. labels: Optional[torch.Tensor] = None,
  1102. indices: Optional[torch.Tensor] = None,
  1103. cu_seqlens: Optional[torch.Tensor] = None,
  1104. max_seqlen: Optional[int] = None,
  1105. batch_size: Optional[int] = None,
  1106. seq_len: Optional[int] = None,
  1107. output_attentions: Optional[bool] = None,
  1108. output_hidden_states: Optional[bool] = None,
  1109. return_dict: Optional[bool] = None,
  1110. **kwargs,
  1111. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  1112. r"""
  1113. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1114. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  1115. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  1116. far-away tokens in the local attention layers when not using Flash Attention.
  1117. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1118. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1119. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1120. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1121. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  1122. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  1123. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  1124. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  1125. max_seqlen (`int`, *optional*):
  1126. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  1127. batch_size (`int`, *optional*):
  1128. Batch size of the input sequences. Used to pad the output tensors.
  1129. seq_len (`int`, *optional*):
  1130. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  1131. """
  1132. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1133. self._maybe_set_compile()
  1134. if input_ids is not None:
  1135. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  1136. if batch_size is None and seq_len is None:
  1137. if inputs_embeds is not None:
  1138. batch_size, seq_len = inputs_embeds.shape[:2]
  1139. else:
  1140. batch_size, seq_len = input_ids.shape[:2]
  1141. device = input_ids.device if input_ids is not None else inputs_embeds.device
  1142. if attention_mask is None:
  1143. attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
  1144. outputs = self.model(
  1145. input_ids=input_ids,
  1146. attention_mask=attention_mask,
  1147. sliding_window_mask=sliding_window_mask,
  1148. position_ids=position_ids,
  1149. inputs_embeds=inputs_embeds,
  1150. indices=indices,
  1151. cu_seqlens=cu_seqlens,
  1152. max_seqlen=max_seqlen,
  1153. batch_size=batch_size,
  1154. seq_len=seq_len,
  1155. output_attentions=output_attentions,
  1156. output_hidden_states=output_hidden_states,
  1157. return_dict=return_dict,
  1158. )
  1159. last_hidden_state = outputs[0]
  1160. if self.config.classifier_pooling == "cls":
  1161. last_hidden_state = last_hidden_state[:, 0]
  1162. elif self.config.classifier_pooling == "mean":
  1163. last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
  1164. dim=1, keepdim=True
  1165. )
  1166. pooled_output = self.head(last_hidden_state)
  1167. pooled_output = self.drop(pooled_output)
  1168. logits = self.classifier(pooled_output)
  1169. loss = None
  1170. if labels is not None:
  1171. if self.config.problem_type is None:
  1172. if self.num_labels == 1:
  1173. self.config.problem_type = "regression"
  1174. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1175. self.config.problem_type = "single_label_classification"
  1176. else:
  1177. self.config.problem_type = "multi_label_classification"
  1178. if self.config.problem_type == "regression":
  1179. loss_fct = MSELoss()
  1180. if self.num_labels == 1:
  1181. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1182. else:
  1183. loss = loss_fct(logits, labels)
  1184. elif self.config.problem_type == "single_label_classification":
  1185. loss_fct = CrossEntropyLoss()
  1186. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1187. elif self.config.problem_type == "multi_label_classification":
  1188. loss_fct = BCEWithLogitsLoss()
  1189. loss = loss_fct(logits, labels)
  1190. if not return_dict:
  1191. output = (logits,)
  1192. return ((loss,) + output) if loss is not None else output
  1193. return SequenceClassifierOutput(
  1194. loss=loss,
  1195. logits=logits,
  1196. hidden_states=outputs.hidden_states,
  1197. attentions=outputs.attentions,
  1198. )
  1199. @auto_docstring(
  1200. custom_intro="""
  1201. The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
  1202. """
  1203. )
  1204. class ModernBertForTokenClassification(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        """Build the ModernBert backbone plus a per-token classification head.

        Creates the encoder, the shared prediction head, dropout with `config.classifier_dropout`,
        and a Linear layer mapping `hidden_size` to `num_labels`, then runs `post_init`.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        # Submodule creation order fixes parameter registration order; keep it stable.
        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()
  1214. @auto_docstring
  1215. def forward(
  1216. self,
  1217. input_ids: Optional[torch.LongTensor] = None,
  1218. attention_mask: Optional[torch.Tensor] = None,
  1219. sliding_window_mask: Optional[torch.Tensor] = None,
  1220. position_ids: Optional[torch.Tensor] = None,
  1221. inputs_embeds: Optional[torch.Tensor] = None,
  1222. labels: Optional[torch.Tensor] = None,
  1223. indices: Optional[torch.Tensor] = None,
  1224. cu_seqlens: Optional[torch.Tensor] = None,
  1225. max_seqlen: Optional[int] = None,
  1226. batch_size: Optional[int] = None,
  1227. seq_len: Optional[int] = None,
  1228. output_attentions: Optional[bool] = None,
  1229. output_hidden_states: Optional[bool] = None,
  1230. return_dict: Optional[bool] = None,
  1231. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  1232. r"""
  1233. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1234. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  1235. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  1236. far-away tokens in the local attention layers when not using Flash Attention.
  1237. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1238. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1239. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  1240. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  1241. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  1242. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  1243. max_seqlen (`int`, *optional*):
  1244. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  1245. batch_size (`int`, *optional*):
  1246. Batch size of the input sequences. Used to pad the output tensors.
  1247. seq_len (`int`, *optional*):
  1248. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  1249. """
  1250. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1251. self._maybe_set_compile()
  1252. outputs = self.model(
  1253. input_ids=input_ids,
  1254. attention_mask=attention_mask,
  1255. sliding_window_mask=sliding_window_mask,
  1256. position_ids=position_ids,
  1257. inputs_embeds=inputs_embeds,
  1258. indices=indices,
  1259. cu_seqlens=cu_seqlens,
  1260. max_seqlen=max_seqlen,
  1261. batch_size=batch_size,
  1262. seq_len=seq_len,
  1263. output_attentions=output_attentions,
  1264. output_hidden_states=output_hidden_states,
  1265. return_dict=return_dict,
  1266. )
  1267. last_hidden_state = outputs[0]
  1268. last_hidden_state = self.head(last_hidden_state)
  1269. last_hidden_state = self.drop(last_hidden_state)
  1270. logits = self.classifier(last_hidden_state)
  1271. loss = None
  1272. if labels is not None:
  1273. loss_fct = CrossEntropyLoss()
  1274. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1275. if not return_dict:
  1276. output = (logits,) + outputs[1:]
  1277. return ((loss,) + output) if loss is not None else output
  1278. return TokenClassifierOutput(
  1279. loss=loss,
  1280. logits=logits,
  1281. hidden_states=outputs.hidden_states,
  1282. attentions=outputs.attentions,
  1283. )
  1284. @auto_docstring
  1285. class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
  1286. def __init__(self, config: ModernBertConfig):
  1287. super().__init__(config)
  1288. self.num_labels = config.num_labels
  1289. self.model = ModernBertModel(config)
  1290. self.head = ModernBertPredictionHead(config)
  1291. self.drop = torch.nn.Dropout(config.classifier_dropout)
  1292. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1293. self.post_init()
  1294. @auto_docstring
  1295. def forward(
  1296. self,
  1297. input_ids: Optional[torch.Tensor],
  1298. attention_mask: Optional[torch.Tensor] = None,
  1299. sliding_window_mask: Optional[torch.Tensor] = None,
  1300. position_ids: Optional[torch.Tensor] = None,
  1301. start_positions: Optional[torch.Tensor] = None,
  1302. end_positions: Optional[torch.Tensor] = None,
  1303. indices: Optional[torch.Tensor] = None,
  1304. cu_seqlens: Optional[torch.Tensor] = None,
  1305. max_seqlen: Optional[int] = None,
  1306. batch_size: Optional[int] = None,
  1307. seq_len: Optional[int] = None,
  1308. output_attentions: Optional[bool] = None,
  1309. output_hidden_states: Optional[bool] = None,
  1310. return_dict: Optional[bool] = None,
  1311. **kwargs,
  1312. ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  1313. r"""
  1314. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1315. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  1316. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  1317. far-away tokens in the local attention layers when not using Flash Attention.
  1318. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  1319. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  1320. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  1321. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  1322. max_seqlen (`int`, *optional*):
  1323. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  1324. batch_size (`int`, *optional*):
  1325. Batch size of the input sequences. Used to pad the output tensors.
  1326. seq_len (`int`, *optional*):
  1327. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  1328. """
  1329. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1330. self._maybe_set_compile()
  1331. outputs = self.model(
  1332. input_ids,
  1333. attention_mask=attention_mask,
  1334. sliding_window_mask=sliding_window_mask,
  1335. position_ids=position_ids,
  1336. indices=indices,
  1337. cu_seqlens=cu_seqlens,
  1338. max_seqlen=max_seqlen,
  1339. batch_size=batch_size,
  1340. seq_len=seq_len,
  1341. output_attentions=output_attentions,
  1342. output_hidden_states=output_hidden_states,
  1343. return_dict=return_dict,
  1344. )
  1345. last_hidden_state = outputs[0]
  1346. last_hidden_state = self.head(last_hidden_state)
  1347. last_hidden_state = self.drop(last_hidden_state)
  1348. logits = self.classifier(last_hidden_state)
  1349. start_logits, end_logits = logits.split(1, dim=-1)
  1350. start_logits = start_logits.squeeze(-1).contiguous()
  1351. end_logits = end_logits.squeeze(-1).contiguous()
  1352. loss = None
  1353. if start_positions is not None and end_positions is not None:
  1354. loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
  1355. if not return_dict:
  1356. output = (start_logits, end_logits) + outputs[1:]
  1357. return ((loss,) + output) if loss is not None else output
  1358. return QuestionAnsweringModelOutput(
  1359. loss=loss,
  1360. start_logits=start_logits,
  1361. end_logits=end_logits,
  1362. hidden_states=outputs.hidden_states,
  1363. attentions=outputs.attentions,
  1364. )
  1365. @auto_docstring(
  1366. custom_intro="""
  1367. The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
  1368. """
  1369. )
  1370. class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
  1371. def __init__(self, config: ModernBertConfig):
  1372. super().__init__(config)
  1373. self.config = config
  1374. self.model = ModernBertModel(config)
  1375. self.head = ModernBertPredictionHead(config)
  1376. self.drop = torch.nn.Dropout(config.classifier_dropout)
  1377. self.classifier = nn.Linear(config.hidden_size, 1)
  1378. # Initialize weights and apply final processing
  1379. self.post_init()
  1380. @auto_docstring
  1381. def forward(
  1382. self,
  1383. input_ids: Optional[torch.LongTensor] = None,
  1384. attention_mask: Optional[torch.Tensor] = None,
  1385. sliding_window_mask: Optional[torch.Tensor] = None,
  1386. position_ids: Optional[torch.Tensor] = None,
  1387. inputs_embeds: Optional[torch.Tensor] = None,
  1388. labels: Optional[torch.Tensor] = None,
  1389. indices: Optional[torch.Tensor] = None,
  1390. cu_seqlens: Optional[torch.Tensor] = None,
  1391. max_seqlen: Optional[int] = None,
  1392. batch_size: Optional[int] = None,
  1393. seq_len: Optional[int] = None,
  1394. output_attentions: Optional[bool] = None,
  1395. output_hidden_states: Optional[bool] = None,
  1396. return_dict: Optional[bool] = None,
  1397. **kwargs,
  1398. ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
  1399. r"""
  1400. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1401. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  1402. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  1403. far-away tokens in the local attention layers when not using Flash Attention.
  1404. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1405. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1406. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
  1407. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  1408. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  1409. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  1410. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  1411. max_seqlen (`int`, *optional*):
  1412. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  1413. batch_size (`int`, *optional*):
  1414. Batch size of the input sequences. Used to pad the output tensors.
  1415. seq_len (`int`, *optional*):
  1416. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  1417. """
  1418. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1419. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1420. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1421. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1422. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1423. inputs_embeds = (
  1424. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1425. if inputs_embeds is not None
  1426. else None
  1427. )
  1428. self._maybe_set_compile()
  1429. outputs = self.model(
  1430. input_ids=input_ids,
  1431. attention_mask=attention_mask,
  1432. sliding_window_mask=sliding_window_mask,
  1433. position_ids=position_ids,
  1434. inputs_embeds=inputs_embeds,
  1435. indices=indices,
  1436. cu_seqlens=cu_seqlens,
  1437. max_seqlen=max_seqlen,
  1438. batch_size=batch_size,
  1439. seq_len=seq_len,
  1440. output_attentions=output_attentions,
  1441. output_hidden_states=output_hidden_states,
  1442. return_dict=return_dict,
  1443. )
  1444. last_hidden_state = outputs[0] # shape (num_choices, seq_len, hidden_size)
  1445. # If classifier_pooling is "cls", isolate the <cls> token
  1446. if self.config.classifier_pooling == "cls":
  1447. indices_0 = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
  1448. # for left or right padding, <cls> is the first non-pad token
  1449. if attention_mask is not None:
  1450. cls_mask = attention_mask.argmax(dim=-1).to(last_hidden_state.device)
  1451. # if no pad, <cls> is the first token
  1452. else:
  1453. cls_mask = torch.tensor(0, dtype=torch.long, device=last_hidden_state.device)
  1454. # extract the <cls> token for the logits
  1455. last_hidden_state = last_hidden_state[indices_0, cls_mask]
  1456. # If classifier_pooling is "mean", pool the hidden states by averaging over the sequence length
  1457. elif self.config.classifier_pooling == "mean":
  1458. num_non_pad_tokens = attention_mask.sum(dim=1, keepdim=True)
  1459. last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / num_non_pad_tokens
  1460. pooled_output = self.head(last_hidden_state)
  1461. pooled_output = self.drop(pooled_output)
  1462. logits = self.classifier(pooled_output)
  1463. reshaped_logits = logits.view(-1, num_choices)
  1464. loss = None
  1465. if labels is not None:
  1466. loss_fct = nn.CrossEntropyLoss()
  1467. loss = loss_fct(reshaped_logits, labels)
  1468. if not return_dict:
  1469. output = (reshaped_logits,) + outputs[1:]
  1470. return ((loss,) + output) if loss is not None else output
  1471. return MultipleChoiceModelOutput(
  1472. loss=loss,
  1473. logits=reshaped_logits,
  1474. hidden_states=outputs.hidden_states,
  1475. attentions=outputs.attentions,
  1476. )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "ModernBertConfig",
    "ModernBertModel",
    "ModernBertPreTrainedModel",
    "ModernBertForMaskedLM",
    "ModernBertForSequenceClassification",
    "ModernBertForTokenClassification",
    "ModernBertForQuestionAnswering",
    "ModernBertForMultipleChoice",
]