modeling_modernbert.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_modernbert.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import copy
  22. import math
  23. from contextlib import nullcontext
  24. from typing import Optional, Union
  25. import torch
  26. import torch.nn.functional as F
  27. from torch import nn
  28. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  29. from ...activations import ACT2FN
  30. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import (
  33. BaseModelOutput,
  34. MaskedLMOutput,
  35. MultipleChoiceModelOutput,
  36. QuestionAnsweringModelOutput,
  37. SequenceClassifierOutput,
  38. TokenClassifierOutput,
  39. )
  40. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  41. from ...modeling_utils import PreTrainedModel
  42. from ...utils import auto_docstring, is_flash_attn_2_available, logging
  43. from ...utils.import_utils import is_triton_available
  44. from .configuration_modernbert import ModernBertConfig
  45. if is_flash_attn_2_available():
  46. from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
  47. from flash_attn.layers.rotary import RotaryEmbedding
  48. from flash_attn.ops.triton.rotary import apply_rotary
  49. else:
  50. RotaryEmbedding = object
# Module-level logger from transformers' logging utilities (used for the compile/resize warnings).
logger = logging.get_logger(__name__)
  52. class ApplyRotaryEmbUnpad(torch.autograd.Function):
  53. @staticmethod
  54. def forward(
  55. ctx,
  56. qkv,
  57. cos,
  58. sin,
  59. cu_seqlens: Optional[torch.Tensor] = None,
  60. max_seqlen: Optional[int] = None,
  61. ):
  62. # (total_nnz, 3, nheads, headdim)
  63. qkv = qkv.contiguous()
  64. total_nnz, _three, _nheads, headdim = qkv.shape
  65. # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
  66. # we get the same tensor
  67. # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
  68. qk = qkv[:, :2].view(total_nnz, -1, headdim)
  69. apply_rotary(
  70. qk,
  71. cos,
  72. sin,
  73. seqlen_offsets=0,
  74. cu_seqlens=cu_seqlens,
  75. max_seqlen=max_seqlen,
  76. interleaved=False,
  77. inplace=True,
  78. )
  79. ctx.save_for_backward(cos, sin, cu_seqlens)
  80. ctx.max_seqlen = max_seqlen
  81. return qkv
  82. @staticmethod
  83. def backward(ctx, do):
  84. cos, sin, cu_seqlens = ctx.saved_tensors
  85. do = do.contiguous()
  86. total_nnz, _three, _nheads, headdim = do.shape
  87. # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
  88. # we get the same tensor
  89. dqk = do[:, :2].view(total_nnz, -1, headdim)
  90. apply_rotary(
  91. dqk,
  92. cos,
  93. sin,
  94. seqlen_offsets=0,
  95. cu_seqlens=cu_seqlens,
  96. max_seqlen=ctx.max_seqlen,
  97. interleaved=False,
  98. inplace=True,
  99. conjugate=True,
  100. )
  101. return do, None, None, None, None, None, None
  102. def apply_rotary_unpadded(
  103. qkv,
  104. cos,
  105. sin,
  106. cu_seqlens: Optional[torch.Tensor] = None,
  107. max_seqlen: Optional[int] = None,
  108. ):
  109. """
  110. Arguments:
  111. qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
  112. cos, sin: (seqlen_rotary, rotary_dim / 2)
  113. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
  114. of 1st half and 2nd half (GPT-NeoX style).
  115. inplace: if True, apply rotary embedding in-place.
  116. seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
  117. Most commonly used in inference when we have KV cache.
  118. cu_seqlens: (batch + 1,) or None
  119. max_seqlen: int
  120. Return:
  121. out: (total_nnz, dim)
  122. rotary_dim must be <= headdim
  123. Apply rotary embedding to the first rotary_dim of x.
  124. """
  125. return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
  126. class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
  127. """
  128. The rotary position embeddings applied directly to unpadded sequences.
  129. """
  130. def __init__(
  131. self,
  132. dim: int,
  133. base: float = 10000.0,
  134. max_seqlen: Optional[int] = None,
  135. device: Optional[torch.device] = None,
  136. dtype: Optional[torch.dtype] = None,
  137. ):
  138. """
  139. max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
  140. up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
  141. the cos_sin_cache will be recomputed during the forward pass.
  142. """
  143. super().__init__(dim=dim, base=base, device=device, interleaved=False)
  144. self.max_seqlen = max_seqlen
  145. if max_seqlen is not None and device is not None and dtype is not None:
  146. self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
  147. def forward(
  148. self,
  149. qkv: torch.Tensor,
  150. cu_seqlens: torch.Tensor,
  151. max_seqlen: Optional[int] = None,
  152. ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  153. """
  154. Apply rotary embedding *inplace* to qkv.
  155. qkv: (total_nnz, 3, nheads, headdim)
  156. cu_seqlens: (batch + 1,) cumulative sequence lengths
  157. max_seqlen: int max seq length in the batch
  158. """
  159. if max_seqlen is not None:
  160. self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
  161. qkv = apply_rotary_unpadded(
  162. qkv,
  163. self._cos_cached,
  164. self._sin_cached,
  165. cu_seqlens=cu_seqlens,
  166. max_seqlen=max_seqlen,
  167. )
  168. return qkv
  169. def extra_repr(self) -> str:
  170. return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
  171. class ModernBertEmbeddings(nn.Module):
  172. """
  173. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  174. """
  175. def __init__(self, config: ModernBertConfig):
  176. super().__init__()
  177. self.config = config
  178. self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  179. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  180. self.drop = nn.Dropout(config.embedding_dropout)
  181. @torch.compile(dynamic=True)
  182. def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
  183. return self.drop(self.norm(self.tok_embeddings(input_ids)))
  184. def forward(
  185. self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
  186. ) -> torch.Tensor:
  187. if inputs_embeds is not None:
  188. hidden_states = self.drop(self.norm(inputs_embeds))
  189. else:
  190. hidden_states = (
  191. self.compiled_embeddings(input_ids)
  192. if self.config.reference_compile
  193. else self.drop(self.norm(self.tok_embeddings(input_ids)))
  194. )
  195. return hidden_states
  196. class ModernBertMLP(nn.Module):
  197. """Applies the GLU at the end of each ModernBERT layer.
  198. Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
  199. and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
  200. """
  201. def __init__(self, config: ModernBertConfig):
  202. super().__init__()
  203. self.config = config
  204. self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
  205. self.act = ACT2FN[config.hidden_activation]
  206. self.drop = nn.Dropout(config.mlp_dropout)
  207. self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)
  208. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  209. input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
  210. return self.Wo(self.drop(self.act(input) * gate))
  211. class ModernBertRotaryEmbedding(nn.Module):
  212. inv_freq: torch.Tensor # fix linting for `register_buffer`
  213. def __init__(self, config: ModernBertConfig, device=None):
  214. super().__init__()
  215. # BC: "rope_type" was originally "type"
  216. if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
  217. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  218. else:
  219. self.rope_type = "default"
  220. self.max_seq_len_cached = config.max_position_embeddings
  221. self.original_max_seq_len = config.max_position_embeddings
  222. self.config = config
  223. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  224. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
  225. self.register_buffer("inv_freq", inv_freq, persistent=False)
  226. self.original_inv_freq = self.inv_freq
  227. @torch.no_grad()
  228. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  229. def forward(self, x, position_ids):
  230. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  231. position_ids_expanded = position_ids[:, None, :].float()
  232. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  233. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  234. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  235. emb = torch.cat((freqs, freqs), dim=-1)
  236. cos = emb.cos() * self.attention_scaling
  237. sin = emb.sin() * self.attention_scaling
  238. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  239. def rotate_half(x):
  240. """Rotates half the hidden dims of the input."""
  241. x1 = x[..., : x.shape[-1] // 2]
  242. x2 = x[..., x.shape[-1] // 2 :]
  243. return torch.cat((-x2, x1), dim=-1)
  244. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  245. """Applies Rotary Position Embedding to the query and key tensors.
  246. Args:
  247. q (`torch.Tensor`): The query tensor.
  248. k (`torch.Tensor`): The key tensor.
  249. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  250. sin (`torch.Tensor`): The sine part of the rotary embedding.
  251. position_ids (`torch.Tensor`, *optional*):
  252. Deprecated and unused.
  253. unsqueeze_dim (`int`, *optional*, defaults to 1):
  254. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  255. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  256. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  257. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  258. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  259. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  260. Returns:
  261. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  262. """
  263. cos = cos.unsqueeze(unsqueeze_dim)
  264. sin = sin.unsqueeze(unsqueeze_dim)
  265. q_embed = (q * cos) + (rotate_half(q) * sin)
  266. k_embed = (k * cos) + (rotate_half(k) * sin)
  267. return q_embed, k_embed
  268. def eager_attention_forward(
  269. module: "ModernBertAttention",
  270. qkv: torch.Tensor,
  271. attention_mask: torch.Tensor,
  272. sliding_window_mask: torch.Tensor,
  273. position_ids: Optional[torch.LongTensor],
  274. local_attention: tuple[int, int],
  275. bs: int,
  276. dim: int,
  277. output_attentions: Optional[bool] = False,
  278. **_kwargs,
  279. ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
  280. # qkv: [batch_size, seqlen, 3, nheads, headdim]
  281. cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
  282. query, key, value = qkv.transpose(3, 1).unbind(dim=2)
  283. # query, key, value: [batch_size, heads, seq_len, head_dim]
  284. query, key = apply_rotary_pos_emb(query, key, cos, sin)
  285. scale = module.head_dim**-0.5
  286. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
  287. if local_attention != (-1, -1):
  288. attention_mask = sliding_window_mask
  289. attn_weights = attn_weights + attention_mask
  290. # upcast attention to fp32
  291. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  292. attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
  293. attn_output = torch.matmul(attn_weights, value)
  294. attn_output = attn_output.transpose(1, 2).contiguous()
  295. attn_output = attn_output.view(bs, -1, dim)
  296. if output_attentions:
  297. return (attn_output, attn_weights)
  298. return (attn_output,)
  299. def flash_attention_forward(
  300. module: "ModernBertAttention",
  301. qkv: torch.Tensor,
  302. rotary_emb: ModernBertUnpaddedRotaryEmbedding,
  303. cu_seqlens: torch.Tensor,
  304. max_seqlen: int,
  305. local_attention: tuple[int, int],
  306. bs: int,
  307. dim: int,
  308. target_dtype: torch.dtype = torch.bfloat16,
  309. **_kwargs,
  310. ) -> tuple[torch.Tensor]:
  311. # (total_seqlen, 3, nheads, headdim)
  312. qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
  313. convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
  314. if convert_dtype:
  315. # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
  316. # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
  317. orig_dtype = qkv.dtype
  318. qkv = qkv.to(target_dtype)
  319. attn = flash_attn_varlen_qkvpacked_func(
  320. qkv,
  321. cu_seqlens=cu_seqlens,
  322. max_seqlen=max_seqlen,
  323. dropout_p=module.attention_dropout if module.training else 0.0,
  324. deterministic=module.deterministic_flash_attn,
  325. window_size=local_attention,
  326. )
  327. attn = attn.to(orig_dtype) # type: ignore
  328. else:
  329. attn = flash_attn_varlen_qkvpacked_func(
  330. qkv,
  331. cu_seqlens=cu_seqlens,
  332. max_seqlen=max_seqlen,
  333. dropout_p=module.attention_dropout if module.training else 0.0,
  334. deterministic=module.deterministic_flash_attn,
  335. window_size=local_attention,
  336. )
  337. return (attn.view(bs, dim),)
  338. def sdpa_attention_forward(
  339. module: "ModernBertAttention",
  340. qkv: torch.Tensor,
  341. attention_mask: torch.Tensor,
  342. sliding_window_mask: torch.Tensor,
  343. position_ids: Optional[torch.LongTensor],
  344. local_attention: tuple[int, int],
  345. bs: int,
  346. dim: int,
  347. **_kwargs,
  348. ) -> tuple[torch.Tensor]:
  349. # qkv: [batch_size, seqlen, 3, nheads, headdim]
  350. cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
  351. query, key, value = qkv.transpose(3, 1).unbind(dim=2)
  352. # query, key, value: [batch_size, heads, seq_len, head_dim]
  353. query, key = apply_rotary_pos_emb(query, key, cos, sin)
  354. if local_attention != (-1, -1):
  355. attention_mask = sliding_window_mask
  356. attn_output = (
  357. F.scaled_dot_product_attention(
  358. query,
  359. key,
  360. value,
  361. dropout_p=module.attention_dropout if module.training else 0.0,
  362. attn_mask=attention_mask,
  363. )
  364. .transpose(1, 2)
  365. .contiguous()
  366. )
  367. attn_output = attn_output.view(bs, -1, dim)
  368. return (attn_output,)
  369. MODERNBERT_ATTENTION_FUNCTION = {
  370. "flash_attention_2": flash_attention_forward,
  371. "eager": eager_attention_forward,
  372. "sdpa": sdpa_attention_forward,
  373. }
  374. class ModernBertAttention(nn.Module):
  375. """Performs multi-headed self attention on a batch of unpadded sequences.
  376. If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
  377. If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
  378. which requires padding and unpadding inputs, adding some overhead.
  379. See `forward` method for additional details.
  380. """
  381. def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
  382. super().__init__()
  383. self.config = config
  384. self.layer_id = layer_id
  385. if config.hidden_size % config.num_attention_heads != 0:
  386. raise ValueError(
  387. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
  388. )
  389. self.attention_dropout = config.attention_dropout
  390. self.deterministic_flash_attn = config.deterministic_flash_attn
  391. self.num_heads = config.num_attention_heads
  392. self.head_dim = config.hidden_size // config.num_attention_heads
  393. self.all_head_size = self.head_dim * self.num_heads
  394. self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
  395. if layer_id % config.global_attn_every_n_layers != 0:
  396. self.local_attention = (config.local_attention // 2, config.local_attention // 2)
  397. rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
  398. max_position_embeddings = config.local_attention
  399. else:
  400. self.local_attention = (-1, -1)
  401. max_position_embeddings = config.max_position_embeddings
  402. rope_theta = config.global_rope_theta
  403. if config._attn_implementation == "flash_attention_2":
  404. self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
  405. dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
  406. )
  407. else:
  408. config_copy = copy.deepcopy(config)
  409. config_copy.rope_theta = rope_theta
  410. self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy)
  411. self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
  412. self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
  413. self.pruned_heads = set()
  414. def forward(
  415. self,
  416. hidden_states: torch.Tensor,
  417. output_attentions: Optional[bool] = False,
  418. **kwargs,
  419. ) -> torch.Tensor:
  420. qkv = self.Wqkv(hidden_states)
  421. bs = hidden_states.shape[0]
  422. if self.config._attn_implementation == "flash_attention_2":
  423. qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
  424. else:
  425. qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
  426. attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
  427. self,
  428. qkv=qkv,
  429. rotary_emb=self.rotary_emb,
  430. local_attention=self.local_attention,
  431. bs=bs,
  432. dim=self.all_head_size,
  433. output_attentions=output_attentions,
  434. **kwargs,
  435. )
  436. hidden_states = attn_outputs[0]
  437. hidden_states = self.out_drop(self.Wo(hidden_states))
  438. return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
  439. class ModernBertEncoderLayer(GradientCheckpointingLayer):
  440. def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
  441. super().__init__()
  442. self.config = config
  443. if layer_id == 0:
  444. self.attn_norm = nn.Identity()
  445. else:
  446. self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  447. self.attn = ModernBertAttention(config=config, layer_id=layer_id)
  448. self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  449. self.mlp = ModernBertMLP(config)
  450. @torch.compile(dynamic=True)
  451. def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
  452. return self.mlp(self.mlp_norm(hidden_states))
  453. def forward(
  454. self,
  455. hidden_states: torch.Tensor,
  456. attention_mask: Optional[torch.Tensor] = None,
  457. sliding_window_mask: Optional[torch.Tensor] = None,
  458. position_ids: Optional[torch.LongTensor] = None,
  459. cu_seqlens: Optional[torch.Tensor] = None,
  460. max_seqlen: Optional[int] = None,
  461. output_attentions: Optional[bool] = False,
  462. ) -> torch.Tensor:
  463. attn_outputs = self.attn(
  464. self.attn_norm(hidden_states),
  465. attention_mask=attention_mask,
  466. sliding_window_mask=sliding_window_mask,
  467. position_ids=position_ids,
  468. cu_seqlens=cu_seqlens,
  469. max_seqlen=max_seqlen,
  470. output_attentions=output_attentions,
  471. )
  472. hidden_states = hidden_states + attn_outputs[0]
  473. mlp_output = (
  474. self.compiled_mlp(hidden_states)
  475. if self.config.reference_compile
  476. else self.mlp(self.mlp_norm(hidden_states))
  477. )
  478. hidden_states = hidden_states + mlp_output
  479. return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
  480. @auto_docstring
  481. class ModernBertPreTrainedModel(PreTrainedModel):
  482. config: ModernBertConfig
  483. base_model_prefix = "model"
  484. supports_gradient_checkpointing = True
  485. _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
  486. _supports_flash_attn = True
  487. _supports_sdpa = True
  488. _supports_flex_attn = False
  489. def _init_weights(self, module: nn.Module):
  490. cutoff_factor = self.config.initializer_cutoff_factor
  491. if cutoff_factor is None:
  492. cutoff_factor = 3
  493. def init_weight(module: nn.Module, std: float):
  494. nn.init.trunc_normal_(
  495. module.weight,
  496. mean=0.0,
  497. std=std,
  498. a=-cutoff_factor * std,
  499. b=cutoff_factor * std,
  500. )
  501. if isinstance(module, nn.Linear):
  502. if module.bias is not None:
  503. nn.init.zeros_(module.bias)
  504. stds = {
  505. "in": self.config.initializer_range,
  506. "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
  507. "embedding": self.config.initializer_range,
  508. "final_out": self.config.hidden_size**-0.5,
  509. }
  510. if isinstance(module, ModernBertEmbeddings):
  511. init_weight(module.tok_embeddings, stds["embedding"])
  512. elif isinstance(module, ModernBertMLP):
  513. init_weight(module.Wi, stds["in"])
  514. init_weight(module.Wo, stds["out"])
  515. elif isinstance(module, ModernBertAttention):
  516. init_weight(module.Wqkv, stds["in"])
  517. init_weight(module.Wo, stds["out"])
  518. elif isinstance(module, ModernBertPredictionHead):
  519. init_weight(module.dense, stds["out"])
  520. elif isinstance(module, ModernBertForMaskedLM):
  521. init_weight(module.decoder, stds["out"])
  522. elif isinstance(
  523. module,
  524. (
  525. ModernBertForSequenceClassification,
  526. ModernBertForMultipleChoice,
  527. ModernBertForTokenClassification,
  528. ModernBertForQuestionAnswering,
  529. ),
  530. ):
  531. init_weight(module.classifier, stds["final_out"])
  532. elif isinstance(module, nn.LayerNorm):
  533. module.weight.data.fill_(1.0)
  534. if module.bias is not None:
  535. module.bias.data.zero_()
  536. def _check_and_adjust_attn_implementation(
  537. self, attn_implementation: Optional[str], is_init_check: bool = False
  538. ) -> str:
  539. """
  540. Checks and dispatches to hhe requested attention implementation.
  541. """
  542. # If the user didn't specify anything, try to use flash_attention_2 if available.
  543. # Otherwise we fall back to the default SDPA -> Eager from the super() method.
  544. # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
  545. # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
  546. try:
  547. attn_implementation = (
  548. "flash_attention_2"
  549. if attn_implementation is None and self._flash_attn_2_can_dispatch()
  550. else attn_implementation
  551. )
  552. except (ValueError, ImportError):
  553. pass
  554. return super()._check_and_adjust_attn_implementation(
  555. attn_implementation=attn_implementation, is_init_check=is_init_check
  556. )
  557. def _maybe_set_compile(self):
  558. if self.config.reference_compile is False:
  559. return
  560. if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
  561. if self.config.reference_compile:
  562. logger.warning_once(
  563. "If `accelerate` split the model across devices, `torch.compile` will not work. "
  564. "Falling back to non-compiled mode."
  565. )
  566. self.config.reference_compile = False
  567. if self.device.type == "mps":
  568. if self.config.reference_compile:
  569. logger.warning_once(
  570. "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
  571. "Falling back to non-compiled mode."
  572. )
  573. self.config.reference_compile = False
  574. if self.device.type == "cpu":
  575. if self.config.reference_compile:
  576. logger.warning_once(
  577. "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
  578. "Falling back to non-compiled mode."
  579. )
  580. self.config.reference_compile = False
  581. if self.config.reference_compile is None:
  582. self.config.reference_compile = is_triton_available()
  583. def resize_token_embeddings(self, *args, **kwargs):
  584. model_embeds = super().resize_token_embeddings(*args, **kwargs)
  585. if self.config.reference_compile in {True, None}:
  586. if self.config.reference_compile:
  587. logger.warning_once(
  588. "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
  589. )
  590. self.config.reference_compile = False
  591. return model_embeds
  592. def _unpad_modernbert_input(
  593. inputs: torch.Tensor,
  594. attention_mask: torch.Tensor,
  595. position_ids: Optional[torch.Tensor] = None,
  596. labels: Optional[torch.Tensor] = None,
  597. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
  598. """
  599. Remove padding from input sequences.
  600. Args:
  601. inputs: (batch, seqlen, ...) or (batch, seqlen)
  602. attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
  603. position_ids: (batch, seqlen), int, position ids
  604. labels: (batch, seqlen), int, labels
  605. Returns:
  606. unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
  607. indices: (total_nnz)
  608. cu_seqlens: (batch + 1), the cumulative sequence lengths
  609. max_seqlen_in_batch: int
  610. unpadded_position_ids: (total_nnz) or None
  611. unpadded_labels: (total_nnz) or None
  612. """
  613. seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  614. indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  615. max_seqlen_in_batch = int(seqlens_in_batch.max().item())
  616. cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
  617. if inputs.dim() == 2:
  618. unpadded_inputs = inputs.flatten()[indices]
  619. else:
  620. batch, seqlen, *rest = inputs.shape
  621. shape = batch * seqlen
  622. unpadded_inputs = inputs.view(shape, *rest)[indices]
  623. unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
  624. unpadded_labels = labels.flatten()[indices] if labels is not None else None
  625. return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
  626. def _pad_modernbert_output(
  627. inputs: torch.Tensor,
  628. indices: torch.Tensor,
  629. batch: int,
  630. seqlen: int,
  631. ) -> torch.Tensor:
  632. """
  633. Add padding to sequences.
  634. Args:
  635. inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
  636. indices: (total_nnz)
  637. batch: int, batch size
  638. seqlen: int, max sequence length
  639. Returns:
  640. padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
  641. """
  642. if inputs.dim() == 1:
  643. output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
  644. output[indices] = inputs
  645. padded_inputs = output.view(batch, seqlen)
  646. else:
  647. _, *rest = inputs.shape
  648. output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
  649. output[indices] = inputs
  650. padded_inputs = output.view(batch, seqlen, *rest)
  651. return padded_inputs
  652. @auto_docstring
  653. class ModernBertModel(ModernBertPreTrainedModel):
  654. def __init__(self, config: ModernBertConfig):
  655. super().__init__(config)
  656. self.config = config
  657. self.embeddings = ModernBertEmbeddings(config)
  658. self.layers = nn.ModuleList(
  659. [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
  660. )
  661. self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  662. self.gradient_checkpointing = False
  663. self.post_init()
  664. def get_input_embeddings(self):
  665. return self.embeddings.tok_embeddings
  666. def set_input_embeddings(self, value):
  667. self.embeddings.tok_embeddings = value
  668. @auto_docstring
  669. def forward(
  670. self,
  671. input_ids: Optional[torch.LongTensor] = None,
  672. attention_mask: Optional[torch.Tensor] = None,
  673. sliding_window_mask: Optional[torch.Tensor] = None,
  674. position_ids: Optional[torch.LongTensor] = None,
  675. inputs_embeds: Optional[torch.Tensor] = None,
  676. indices: Optional[torch.Tensor] = None,
  677. cu_seqlens: Optional[torch.Tensor] = None,
  678. max_seqlen: Optional[int] = None,
  679. batch_size: Optional[int] = None,
  680. seq_len: Optional[int] = None,
  681. output_attentions: Optional[bool] = None,
  682. output_hidden_states: Optional[bool] = None,
  683. return_dict: Optional[bool] = None,
  684. ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
  685. r"""
  686. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  687. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  688. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  689. far-away tokens in the local attention layers when not using Flash Attention.
  690. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  691. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  692. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  693. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  694. max_seqlen (`int`, *optional*):
  695. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  696. batch_size (`int`, *optional*):
  697. Batch size of the input sequences. Used to pad the output tensors.
  698. seq_len (`int`, *optional*):
  699. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  700. """
  701. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  702. output_hidden_states = (
  703. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  704. )
  705. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  706. if (input_ids is None) ^ (inputs_embeds is not None):
  707. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  708. all_hidden_states = () if output_hidden_states else None
  709. all_self_attentions = () if output_attentions else None
  710. self._maybe_set_compile()
  711. if input_ids is not None:
  712. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  713. if batch_size is None and seq_len is None:
  714. if inputs_embeds is not None:
  715. batch_size, seq_len = inputs_embeds.shape[:2]
  716. else:
  717. batch_size, seq_len = input_ids.shape[:2]
  718. device = input_ids.device if input_ids is not None else inputs_embeds.device
  719. if attention_mask is None:
  720. attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
  721. repad = False
  722. if self.config._attn_implementation == "flash_attention_2":
  723. if indices is None and cu_seqlens is None and max_seqlen is None:
  724. repad = True
  725. if inputs_embeds is None:
  726. with torch.no_grad():
  727. input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
  728. inputs=input_ids, attention_mask=attention_mask
  729. )
  730. else:
  731. inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
  732. inputs=inputs_embeds, attention_mask=attention_mask
  733. )
  734. else:
  735. if position_ids is None:
  736. position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
  737. attention_mask, sliding_window_mask = self._update_attention_mask(
  738. attention_mask, output_attentions=output_attentions
  739. )
  740. hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
  741. for encoder_layer in self.layers:
  742. if output_hidden_states:
  743. all_hidden_states = all_hidden_states + (hidden_states,)
  744. layer_outputs = encoder_layer(
  745. hidden_states,
  746. attention_mask=attention_mask,
  747. sliding_window_mask=sliding_window_mask,
  748. position_ids=position_ids,
  749. cu_seqlens=cu_seqlens,
  750. max_seqlen=max_seqlen,
  751. output_attentions=output_attentions,
  752. )
  753. hidden_states = layer_outputs[0]
  754. if output_attentions and len(layer_outputs) > 1:
  755. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  756. if output_hidden_states:
  757. all_hidden_states = all_hidden_states + (hidden_states,)
  758. hidden_states = self.final_norm(hidden_states)
  759. if repad:
  760. hidden_states = _pad_modernbert_output(
  761. inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
  762. )
  763. if all_hidden_states is not None:
  764. all_hidden_states = tuple(
  765. _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
  766. for hs in all_hidden_states
  767. )
  768. # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
  769. # dimension missing
  770. elif (
  771. self.config._attn_implementation == "flash_attention_2"
  772. and all_hidden_states is not None
  773. and all_hidden_states[-1].dim() == 2
  774. ):
  775. hidden_states = hidden_states.unsqueeze(0)
  776. all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
  777. if not return_dict:
  778. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  779. return BaseModelOutput(
  780. last_hidden_state=hidden_states,
  781. hidden_states=all_hidden_states,
  782. attentions=all_self_attentions,
  783. )
  784. def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
  785. if output_attentions:
  786. if self.config._attn_implementation == "sdpa":
  787. logger.warning_once(
  788. "Outputting attentions is only supported with the 'eager' attention implementation, "
  789. 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
  790. )
  791. self.config._attn_implementation = "eager"
  792. elif self.config._attn_implementation != "eager":
  793. logger.warning_once(
  794. "Outputting attentions is only supported with the eager attention implementation, "
  795. f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
  796. " Setting `output_attentions=False`."
  797. )
  798. global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
  799. # Create position indices
  800. rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
  801. # Calculate distance between positions
  802. distance = torch.abs(rows - rows.T)
  803. # Create sliding window mask (1 for positions within window, 0 outside)
  804. window_mask = (
  805. (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
  806. )
  807. # Combine with existing mask
  808. sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
  809. return global_attention_mask, sliding_window_mask
  810. class ModernBertPredictionHead(nn.Module):
  811. def __init__(self, config: ModernBertConfig):
  812. super().__init__()
  813. self.config = config
  814. self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
  815. self.act = ACT2FN[config.classifier_activation]
  816. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
  817. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  818. return self.norm(self.act(self.dense(hidden_states)))
  819. @auto_docstring(
  820. custom_intro="""
  821. The ModernBert Model with a decoder head on top that is used for masked language modeling.
  822. """
  823. )
  824. class ModernBertForMaskedLM(ModernBertPreTrainedModel):
  825. _tied_weights_keys = ["decoder.weight"]
  826. def __init__(self, config: ModernBertConfig):
  827. super().__init__(config)
  828. self.config = config
  829. self.model = ModernBertModel(config)
  830. self.head = ModernBertPredictionHead(config)
  831. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
  832. self.sparse_prediction = self.config.sparse_prediction
  833. self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
  834. # Initialize weights and apply final processing
  835. self.post_init()
  836. def get_output_embeddings(self):
  837. return self.decoder
  838. def set_output_embeddings(self, new_embeddings: nn.Linear):
  839. self.decoder = new_embeddings
  840. @torch.compile(dynamic=True)
  841. def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
  842. return self.decoder(self.head(output))
  843. @auto_docstring
  844. def forward(
  845. self,
  846. input_ids: Optional[torch.LongTensor] = None,
  847. attention_mask: Optional[torch.Tensor] = None,
  848. sliding_window_mask: Optional[torch.Tensor] = None,
  849. position_ids: Optional[torch.Tensor] = None,
  850. inputs_embeds: Optional[torch.Tensor] = None,
  851. labels: Optional[torch.Tensor] = None,
  852. indices: Optional[torch.Tensor] = None,
  853. cu_seqlens: Optional[torch.Tensor] = None,
  854. max_seqlen: Optional[int] = None,
  855. batch_size: Optional[int] = None,
  856. seq_len: Optional[int] = None,
  857. output_attentions: Optional[bool] = None,
  858. output_hidden_states: Optional[bool] = None,
  859. return_dict: Optional[bool] = None,
  860. **kwargs,
  861. ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
  862. r"""
  863. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  864. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  865. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  866. far-away tokens in the local attention layers when not using Flash Attention.
  867. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  868. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  869. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  870. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  871. max_seqlen (`int`, *optional*):
  872. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  873. batch_size (`int`, *optional*):
  874. Batch size of the input sequences. Used to pad the output tensors.
  875. seq_len (`int`, *optional*):
  876. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  877. """
  878. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  879. self._maybe_set_compile()
  880. if self.config._attn_implementation == "flash_attention_2":
  881. if indices is None and cu_seqlens is None and max_seqlen is None:
  882. if batch_size is None and seq_len is None:
  883. if inputs_embeds is not None:
  884. batch_size, seq_len = inputs_embeds.shape[:2]
  885. else:
  886. batch_size, seq_len = input_ids.shape[:2]
  887. device = input_ids.device if input_ids is not None else inputs_embeds.device
  888. if attention_mask is None:
  889. attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
  890. if inputs_embeds is None:
  891. with torch.no_grad():
  892. input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
  893. inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
  894. )
  895. else:
  896. inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
  897. inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
  898. )
  899. outputs = self.model(
  900. input_ids=input_ids,
  901. attention_mask=attention_mask,
  902. sliding_window_mask=sliding_window_mask,
  903. position_ids=position_ids,
  904. inputs_embeds=inputs_embeds,
  905. indices=indices,
  906. cu_seqlens=cu_seqlens,
  907. max_seqlen=max_seqlen,
  908. batch_size=batch_size,
  909. seq_len=seq_len,
  910. output_attentions=output_attentions,
  911. output_hidden_states=output_hidden_states,
  912. return_dict=return_dict,
  913. )
  914. last_hidden_state = outputs[0]
  915. if self.sparse_prediction and labels is not None:
  916. # flatten labels and output first
  917. labels = labels.view(-1)
  918. last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
  919. # then filter out the non-masked tokens
  920. mask_tokens = labels != self.sparse_pred_ignore_index
  921. last_hidden_state = last_hidden_state[mask_tokens]
  922. labels = labels[mask_tokens]
  923. logits = (
  924. self.compiled_head(last_hidden_state)
  925. if self.config.reference_compile
  926. else self.decoder(self.head(last_hidden_state))
  927. )
  928. loss = None
  929. if labels is not None:
  930. loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
  931. if self.config._attn_implementation == "flash_attention_2":
  932. # Logits padding
  933. with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
  934. logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
  935. # Hidden states padding
  936. if getattr(outputs, "hidden_states", None) is not None:
  937. padded_hidden_states = []
  938. for hs in outputs.hidden_states:
  939. if hs.dim() == 3 and hs.shape[0] == 1:
  940. hs = hs.squeeze(0)
  941. padded_hidden_states.append(
  942. _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
  943. )
  944. outputs.hidden_states = tuple(padded_hidden_states)
  945. if not return_dict:
  946. output = (logits,)
  947. return ((loss,) + output) if loss is not None else output
  948. return MaskedLMOutput(
  949. loss=loss,
  950. logits=logits,
  951. hidden_states=outputs.hidden_states,
  952. attentions=outputs.attentions,
  953. )
  954. @auto_docstring(
  955. custom_intro="""
  956. The ModernBert Model with a sequence classification head on top that performs pooling.
  957. """
  958. )
  959. class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
  960. def __init__(self, config: ModernBertConfig):
  961. super().__init__(config)
  962. self.num_labels = config.num_labels
  963. self.config = config
  964. self.model = ModernBertModel(config)
  965. self.head = ModernBertPredictionHead(config)
  966. self.drop = torch.nn.Dropout(config.classifier_dropout)
  967. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  968. # Initialize weights and apply final processing
  969. self.post_init()
  970. @auto_docstring
  971. def forward(
  972. self,
  973. input_ids: Optional[torch.LongTensor] = None,
  974. attention_mask: Optional[torch.Tensor] = None,
  975. sliding_window_mask: Optional[torch.Tensor] = None,
  976. position_ids: Optional[torch.Tensor] = None,
  977. inputs_embeds: Optional[torch.Tensor] = None,
  978. labels: Optional[torch.Tensor] = None,
  979. indices: Optional[torch.Tensor] = None,
  980. cu_seqlens: Optional[torch.Tensor] = None,
  981. max_seqlen: Optional[int] = None,
  982. batch_size: Optional[int] = None,
  983. seq_len: Optional[int] = None,
  984. output_attentions: Optional[bool] = None,
  985. output_hidden_states: Optional[bool] = None,
  986. return_dict: Optional[bool] = None,
  987. **kwargs,
  988. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  989. r"""
  990. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  991. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  992. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  993. far-away tokens in the local attention layers when not using Flash Attention.
  994. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  995. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  996. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  997. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  998. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  999. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  1000. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  1001. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  1002. max_seqlen (`int`, *optional*):
  1003. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  1004. batch_size (`int`, *optional*):
  1005. Batch size of the input sequences. Used to pad the output tensors.
  1006. seq_len (`int`, *optional*):
  1007. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  1008. """
  1009. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1010. self._maybe_set_compile()
  1011. if input_ids is not None:
  1012. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  1013. if batch_size is None and seq_len is None:
  1014. if inputs_embeds is not None:
  1015. batch_size, seq_len = inputs_embeds.shape[:2]
  1016. else:
  1017. batch_size, seq_len = input_ids.shape[:2]
  1018. device = input_ids.device if input_ids is not None else inputs_embeds.device
  1019. if attention_mask is None:
  1020. attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
  1021. outputs = self.model(
  1022. input_ids=input_ids,
  1023. attention_mask=attention_mask,
  1024. sliding_window_mask=sliding_window_mask,
  1025. position_ids=position_ids,
  1026. inputs_embeds=inputs_embeds,
  1027. indices=indices,
  1028. cu_seqlens=cu_seqlens,
  1029. max_seqlen=max_seqlen,
  1030. batch_size=batch_size,
  1031. seq_len=seq_len,
  1032. output_attentions=output_attentions,
  1033. output_hidden_states=output_hidden_states,
  1034. return_dict=return_dict,
  1035. )
  1036. last_hidden_state = outputs[0]
  1037. if self.config.classifier_pooling == "cls":
  1038. last_hidden_state = last_hidden_state[:, 0]
  1039. elif self.config.classifier_pooling == "mean":
  1040. last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
  1041. dim=1, keepdim=True
  1042. )
  1043. pooled_output = self.head(last_hidden_state)
  1044. pooled_output = self.drop(pooled_output)
  1045. logits = self.classifier(pooled_output)
  1046. loss = None
  1047. if labels is not None:
  1048. if self.config.problem_type is None:
  1049. if self.num_labels == 1:
  1050. self.config.problem_type = "regression"
  1051. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1052. self.config.problem_type = "single_label_classification"
  1053. else:
  1054. self.config.problem_type = "multi_label_classification"
  1055. if self.config.problem_type == "regression":
  1056. loss_fct = MSELoss()
  1057. if self.num_labels == 1:
  1058. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1059. else:
  1060. loss = loss_fct(logits, labels)
  1061. elif self.config.problem_type == "single_label_classification":
  1062. loss_fct = CrossEntropyLoss()
  1063. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1064. elif self.config.problem_type == "multi_label_classification":
  1065. loss_fct = BCEWithLogitsLoss()
  1066. loss = loss_fct(logits, labels)
  1067. if not return_dict:
  1068. output = (logits,)
  1069. return ((loss,) + output) if loss is not None else output
  1070. return SequenceClassifierOutput(
  1071. loss=loss,
  1072. logits=logits,
  1073. hidden_states=outputs.hidden_states,
  1074. attentions=outputs.attentions,
  1075. )
  1076. @auto_docstring(
  1077. custom_intro="""
  1078. The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
  1079. """
  1080. )
  1081. class ModernBertForTokenClassification(ModernBertPreTrainedModel):
  1082. def __init__(self, config: ModernBertConfig):
  1083. super().__init__(config)
  1084. self.num_labels = config.num_labels
  1085. self.model = ModernBertModel(config)
  1086. self.head = ModernBertPredictionHead(config)
  1087. self.drop = torch.nn.Dropout(config.classifier_dropout)
  1088. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1089. # Initialize weights and apply final processing
  1090. self.post_init()
  1091. @auto_docstring
  1092. def forward(
  1093. self,
  1094. input_ids: Optional[torch.LongTensor] = None,
  1095. attention_mask: Optional[torch.Tensor] = None,
  1096. sliding_window_mask: Optional[torch.Tensor] = None,
  1097. position_ids: Optional[torch.Tensor] = None,
  1098. inputs_embeds: Optional[torch.Tensor] = None,
  1099. labels: Optional[torch.Tensor] = None,
  1100. indices: Optional[torch.Tensor] = None,
  1101. cu_seqlens: Optional[torch.Tensor] = None,
  1102. max_seqlen: Optional[int] = None,
  1103. batch_size: Optional[int] = None,
  1104. seq_len: Optional[int] = None,
  1105. output_attentions: Optional[bool] = None,
  1106. output_hidden_states: Optional[bool] = None,
  1107. return_dict: Optional[bool] = None,
  1108. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  1109. r"""
  1110. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1111. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  1112. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  1113. far-away tokens in the local attention layers when not using Flash Attention.
  1114. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1115. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1116. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  1117. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  1118. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  1119. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  1120. max_seqlen (`int`, *optional*):
  1121. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  1122. batch_size (`int`, *optional*):
  1123. Batch size of the input sequences. Used to pad the output tensors.
  1124. seq_len (`int`, *optional*):
  1125. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  1126. """
  1127. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1128. self._maybe_set_compile()
  1129. outputs = self.model(
  1130. input_ids=input_ids,
  1131. attention_mask=attention_mask,
  1132. sliding_window_mask=sliding_window_mask,
  1133. position_ids=position_ids,
  1134. inputs_embeds=inputs_embeds,
  1135. indices=indices,
  1136. cu_seqlens=cu_seqlens,
  1137. max_seqlen=max_seqlen,
  1138. batch_size=batch_size,
  1139. seq_len=seq_len,
  1140. output_attentions=output_attentions,
  1141. output_hidden_states=output_hidden_states,
  1142. return_dict=return_dict,
  1143. )
  1144. last_hidden_state = outputs[0]
  1145. last_hidden_state = self.head(last_hidden_state)
  1146. last_hidden_state = self.drop(last_hidden_state)
  1147. logits = self.classifier(last_hidden_state)
  1148. loss = None
  1149. if labels is not None:
  1150. loss_fct = CrossEntropyLoss()
  1151. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1152. if not return_dict:
  1153. output = (logits,) + outputs[1:]
  1154. return ((loss,) + output) if loss is not None else output
  1155. return TokenClassifierOutput(
  1156. loss=loss,
  1157. logits=logits,
  1158. hidden_states=outputs.hidden_states,
  1159. attentions=outputs.attentions,
  1160. )
  1161. @auto_docstring
  1162. class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
  1163. def __init__(self, config: ModernBertConfig):
  1164. super().__init__(config)
  1165. self.num_labels = config.num_labels
  1166. self.model = ModernBertModel(config)
  1167. self.head = ModernBertPredictionHead(config)
  1168. self.drop = torch.nn.Dropout(config.classifier_dropout)
  1169. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1170. self.post_init()
  1171. @auto_docstring
  1172. def forward(
  1173. self,
  1174. input_ids: Optional[torch.Tensor],
  1175. attention_mask: Optional[torch.Tensor] = None,
  1176. sliding_window_mask: Optional[torch.Tensor] = None,
  1177. position_ids: Optional[torch.Tensor] = None,
  1178. start_positions: Optional[torch.Tensor] = None,
  1179. end_positions: Optional[torch.Tensor] = None,
  1180. indices: Optional[torch.Tensor] = None,
  1181. cu_seqlens: Optional[torch.Tensor] = None,
  1182. max_seqlen: Optional[int] = None,
  1183. batch_size: Optional[int] = None,
  1184. seq_len: Optional[int] = None,
  1185. output_attentions: Optional[bool] = None,
  1186. output_hidden_states: Optional[bool] = None,
  1187. return_dict: Optional[bool] = None,
  1188. **kwargs,
  1189. ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  1190. r"""
  1191. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1192. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  1193. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  1194. far-away tokens in the local attention layers when not using Flash Attention.
  1195. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  1196. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  1197. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  1198. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  1199. max_seqlen (`int`, *optional*):
  1200. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  1201. batch_size (`int`, *optional*):
  1202. Batch size of the input sequences. Used to pad the output tensors.
  1203. seq_len (`int`, *optional*):
  1204. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  1205. """
  1206. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1207. self._maybe_set_compile()
  1208. outputs = self.model(
  1209. input_ids,
  1210. attention_mask=attention_mask,
  1211. sliding_window_mask=sliding_window_mask,
  1212. position_ids=position_ids,
  1213. indices=indices,
  1214. cu_seqlens=cu_seqlens,
  1215. max_seqlen=max_seqlen,
  1216. batch_size=batch_size,
  1217. seq_len=seq_len,
  1218. output_attentions=output_attentions,
  1219. output_hidden_states=output_hidden_states,
  1220. return_dict=return_dict,
  1221. )
  1222. last_hidden_state = outputs[0]
  1223. last_hidden_state = self.head(last_hidden_state)
  1224. last_hidden_state = self.drop(last_hidden_state)
  1225. logits = self.classifier(last_hidden_state)
  1226. start_logits, end_logits = logits.split(1, dim=-1)
  1227. start_logits = start_logits.squeeze(-1).contiguous()
  1228. end_logits = end_logits.squeeze(-1).contiguous()
  1229. loss = None
  1230. if start_positions is not None and end_positions is not None:
  1231. loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
  1232. if not return_dict:
  1233. output = (start_logits, end_logits) + outputs[1:]
  1234. return ((loss,) + output) if loss is not None else output
  1235. return QuestionAnsweringModelOutput(
  1236. loss=loss,
  1237. start_logits=start_logits,
  1238. end_logits=end_logits,
  1239. hidden_states=outputs.hidden_states,
  1240. attentions=outputs.attentions,
  1241. )
  1242. @auto_docstring(
  1243. custom_intro="""
  1244. The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
  1245. """
  1246. )
  1247. class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
  1248. def __init__(self, config: ModernBertConfig):
  1249. super().__init__(config)
  1250. self.config = config
  1251. self.model = ModernBertModel(config)
  1252. self.head = ModernBertPredictionHead(config)
  1253. self.drop = torch.nn.Dropout(config.classifier_dropout)
  1254. self.classifier = nn.Linear(config.hidden_size, 1)
  1255. # Initialize weights and apply final processing
  1256. self.post_init()
  1257. @auto_docstring
  1258. def forward(
  1259. self,
  1260. input_ids: Optional[torch.LongTensor] = None,
  1261. attention_mask: Optional[torch.Tensor] = None,
  1262. sliding_window_mask: Optional[torch.Tensor] = None,
  1263. position_ids: Optional[torch.Tensor] = None,
  1264. inputs_embeds: Optional[torch.Tensor] = None,
  1265. labels: Optional[torch.Tensor] = None,
  1266. indices: Optional[torch.Tensor] = None,
  1267. cu_seqlens: Optional[torch.Tensor] = None,
  1268. max_seqlen: Optional[int] = None,
  1269. batch_size: Optional[int] = None,
  1270. seq_len: Optional[int] = None,
  1271. output_attentions: Optional[bool] = None,
  1272. output_hidden_states: Optional[bool] = None,
  1273. return_dict: Optional[bool] = None,
  1274. **kwargs,
  1275. ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
  1276. r"""
  1277. sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1278. Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
  1279. perform global attention, while the rest perform local attention. This mask is used to avoid attending to
  1280. far-away tokens in the local attention layers when not using Flash Attention.
  1281. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1282. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1283. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
  1284. indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
  1285. Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
  1286. cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
  1287. Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
  1288. max_seqlen (`int`, *optional*):
  1289. Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
  1290. batch_size (`int`, *optional*):
  1291. Batch size of the input sequences. Used to pad the output tensors.
  1292. seq_len (`int`, *optional*):
  1293. Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
  1294. """
  1295. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1296. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1297. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1298. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1299. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1300. inputs_embeds = (
  1301. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1302. if inputs_embeds is not None
  1303. else None
  1304. )
  1305. self._maybe_set_compile()
  1306. outputs = self.model(
  1307. input_ids=input_ids,
  1308. attention_mask=attention_mask,
  1309. sliding_window_mask=sliding_window_mask,
  1310. position_ids=position_ids,
  1311. inputs_embeds=inputs_embeds,
  1312. indices=indices,
  1313. cu_seqlens=cu_seqlens,
  1314. max_seqlen=max_seqlen,
  1315. batch_size=batch_size,
  1316. seq_len=seq_len,
  1317. output_attentions=output_attentions,
  1318. output_hidden_states=output_hidden_states,
  1319. return_dict=return_dict,
  1320. )
  1321. last_hidden_state = outputs[0] # shape (num_choices, seq_len, hidden_size)
  1322. # If classifier_pooling is "cls", isolate the <cls> token
  1323. if self.config.classifier_pooling == "cls":
  1324. indices_0 = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
  1325. # for left or right padding, <cls> is the first non-pad token
  1326. if attention_mask is not None:
  1327. cls_mask = attention_mask.argmax(dim=-1).to(last_hidden_state.device)
  1328. # if no pad, <cls> is the first token
  1329. else:
  1330. cls_mask = torch.tensor(0, dtype=torch.long, device=last_hidden_state.device)
  1331. # extract the <cls> token for the logits
  1332. last_hidden_state = last_hidden_state[indices_0, cls_mask]
  1333. # If classifier_pooling is "mean", pool the hidden states by averaging over the sequence length
  1334. elif self.config.classifier_pooling == "mean":
  1335. num_non_pad_tokens = attention_mask.sum(dim=1, keepdim=True)
  1336. last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / num_non_pad_tokens
  1337. pooled_output = self.head(last_hidden_state)
  1338. pooled_output = self.drop(pooled_output)
  1339. logits = self.classifier(pooled_output)
  1340. reshaped_logits = logits.view(-1, num_choices)
  1341. loss = None
  1342. if labels is not None:
  1343. loss_fct = nn.CrossEntropyLoss()
  1344. loss = loss_fct(reshaped_logits, labels)
  1345. if not return_dict:
  1346. output = (reshaped_logits,) + outputs[1:]
  1347. return ((loss,) + output) if loss is not None else output
  1348. return MultipleChoiceModelOutput(
  1349. loss=loss,
  1350. logits=reshaped_logits,
  1351. hidden_states=outputs.hidden_states,
  1352. attentions=outputs.attentions,
  1353. )
  1354. __all__ = [
  1355. "ModernBertModel",
  1356. "ModernBertPreTrainedModel",
  1357. "ModernBertForMaskedLM",
  1358. "ModernBertForSequenceClassification",
  1359. "ModernBertForTokenClassification",
  1360. "ModernBertForQuestionAnswering",
  1361. "ModernBertForMultipleChoice",
  1362. ]