| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/modernbert/modular_modernbert.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_modernbert.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
- #
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- import math
- from contextlib import nullcontext
- from typing import Optional, Union
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ...activations import ACT2FN
- from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutput,
- MaskedLMOutput,
- MultipleChoiceModelOutput,
- QuestionAnsweringModelOutput,
- SequenceClassifierOutput,
- TokenClassifierOutput,
- )
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import PreTrainedModel
- from ...utils import auto_docstring, is_flash_attn_2_available, logging
- from ...utils.import_utils import is_triton_available
- from .configuration_modernbert import ModernBertConfig
# flash-attn is an optional dependency: its kernels are imported only when it is usable.
if is_flash_attn_2_available():
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
    from flash_attn.layers.rotary import RotaryEmbedding
    from flash_attn.ops.triton.rotary import apply_rotary
else:
    # Placeholder base class so that `ModernBertUnpaddedRotaryEmbedding` can still be defined.
    RotaryEmbedding = object

# Module-level logger shared by every class and function in this file.
logger = logging.get_logger(__name__)
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """
    Autograd function applying rotary embeddings in-place to the query and key slices of a packed QKV tensor.

    The value slice (`qkv[:, 2]`) is left untouched.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        """
        Rotate the Q and K parts of `qkv` and return the (contiguous) tensor.

        Args:
            qkv: (total_nnz, 3, nheads, headdim) packed query/key/value tensor.
            cos, sin: rotary tables forwarded as-is to `apply_rotary`.
            cu_seqlens: cumulative sequence lengths forwarded to `apply_rotary`, or None.
            max_seqlen: maximum sequence length forwarded to `apply_rotary`, or None.
        """
        # (total_nnz, 3, nheads, headdim)
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape
        # We need qkv to be contiguous so that when we reshape to combine the (2, nheads) dimensions
        # we get a view onto the same storage: the in-place rotation below is then visible in `qkv`.
        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )
        ctx.save_for_backward(cos, sin, cu_seqlens)
        # `max_seqlen` is a plain int, so it is stored on the context rather than saved as a tensor.
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        """
        Apply the conjugate rotation in-place to the Q and K parts of the incoming gradient.

        Returns one gradient per `forward` input; only `qkv` receives a gradient.
        """
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        # We need dqkv to be contiguous so that when we reshape to combine the (2, nheads) dimensions
        # we get a view onto the same storage.
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )
        # Exactly one entry per forward input: (qkv, cos, sin, cu_seqlens, max_seqlen).
        return do, None, None, None, None
def apply_rotary_unpadded(
    qkv,
    cos,
    sin,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Apply rotary embeddings to the query and key slices of an unpadded, packed QKV tensor.

    Thin functional wrapper around `ApplyRotaryEmbUnpad`, so the operation takes part in autograd.

    Arguments:
        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        The result of `ApplyRotaryEmbUnpad.apply`, i.e. the packed QKV tensor with Q and K rotated.
    """
    rotated_qkv = ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
    return rotated_qkv
class ModernBertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    Rotary position embeddings that operate directly on unpadded (packed) sequences.
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        max_seqlen: when max_seqlen, device and dtype are all given, the cos/sin cache is precomputed
            up to max_seqlen right away; otherwise it is left to be built during the forward pass.
        """
        super().__init__(dim=dim, base=base, device=device, interleaved=False)
        self.max_seqlen = max_seqlen

        # Precomputation needs every one of the three pieces of information.
        can_precompute = not (max_seqlen is None or device is None or dtype is None)
        if can_precompute:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Apply rotary embedding *inplace* to qkv.

        qkv: (total_nnz, 3, nheads, headdim)
        cu_seqlens: (batch + 1,) cumulative sequence lengths
        max_seqlen: int max seq length in the batch
        """
        # Refresh the cache for the current batch only when its maximum length is known.
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)

        return apply_rotary_unpadded(
            qkv,
            self._cos_cached,
            self._sin_cached,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

    def extra_repr(self) -> str:
        """Summarise the rotary configuration for `repr(module)`."""
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
class ModernBertEmbeddings(nn.Module):
    """
    Token embedding lookup followed by LayerNorm and dropout.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.drop = nn.Dropout(config.embedding_dropout)

    @torch.compile(dynamic=True)
    def compiled_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
        """`torch.compile`d variant of the lookup -> norm -> dropout pipeline."""
        token_embeds = self.tok_embeddings(input_ids)
        return self.drop(self.norm(token_embeds))

    def forward(
        self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Embed `input_ids`, or normalise pre-computed `inputs_embeds`, which take precedence when given.
        """
        # Pre-computed embeddings skip the lookup table (and the compiled path) entirely.
        if inputs_embeds is not None:
            return self.drop(self.norm(inputs_embeds))

        if self.config.reference_compile:
            return self.compiled_embeddings(input_ids)

        token_embeds = self.tok_embeddings(input_ids)
        return self.drop(self.norm(token_embeds))
class ModernBertMLP(nn.Module):
    """Gated linear unit (GLU) feed-forward block used at the end of each ModernBERT layer.

    A single input projection `Wi` produces both the activated branch and the gate; their product goes
    through dropout and the output projection `Wo`.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        # `Wi` is twice as wide as the intermediate size: one half is activated, the other is the gate.
        self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=config.mlp_bias)
        self.act = ACT2FN[config.hidden_activation]
        self.drop = nn.Dropout(config.mlp_dropout)
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `Wo(drop(act(a) * g))` where `a, g` are the two halves of `Wi(hidden_states)`."""
        projected = self.Wi(hidden_states)
        to_activate, gate = projected.chunk(2, dim=-1)
        gated = self.act(to_activate) * gate
        return self.Wo(self.drop(gated))
class ModernBertRotaryEmbedding(nn.Module):
    """Computes the rotary cos/sin tables for given position ids (used by the non-flash attention paths)."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: ModernBertConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config

        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the frequencies are derived from the config, not stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """
        Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only used for its device and dtype.
        """
        batch_size = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        raw_device_type = x.device.type
        device_type = raw_device_type if isinstance(raw_device_type, str) and raw_device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            # The angles are duplicated so the table spans both halves handled by `rotate_half`.
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Split the last dimension in two halves `(a, b)` and return `(-b, a)` concatenated."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with the given rotary cos/sin tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            For example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when
            `q` and `k` are [batch_size, heads, seq_len, head_dim], and `unsqueeze_dim=2` when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard rotary formula: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
def eager_attention_forward(
    module: "ModernBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    output_attentions: Optional[bool] = False,
    **_kwargs,
) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
    """
    Plain PyTorch attention (matmul -> additive mask -> softmax -> dropout -> matmul).

    Returns `(attn_output,)`, or `(attn_output, attn_weights)` when `output_attentions` is truthy.
    `attn_output` is reshaped to `(bs, -1, dim)`.
    """
    # qkv: [batch_size, seqlen, 3, nheads, headdim]
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # query, key, value: [batch_size, heads, seq_len, head_dim]
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    scores = torch.matmul(query, key.transpose(2, 3)) * module.head_dim**-0.5

    # Local-attention layers use the sliding-window mask instead of the global one.
    additive_mask = attention_mask if local_attention == (-1, -1) else sliding_window_mask
    scores = scores + additive_mask

    # upcast attention to fp32
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=module.attention_dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    context = context.view(bs, -1, dim)

    return (context, probs) if output_attentions else (context,)
def flash_attention_forward(
    module: "ModernBertAttention",
    qkv: torch.Tensor,
    rotary_emb: ModernBertUnpaddedRotaryEmbedding,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> tuple[torch.Tensor]:
    """
    Flash Attention 2 path operating on unpadded, packed QKV.

    Inputs that are neither fp16 nor bf16 are cast to `target_dtype` for the kernel and the result is cast
    back to the original dtype. Returns a 1-tuple with the output viewed as `(bs, dim)`.
    """
    # (total_seqlen, 3, nheads, headdim)
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
    # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
    orig_dtype = qkv.dtype
    needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
    if needs_cast:
        qkv = qkv.to(target_dtype)

    attn = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=module.attention_dropout if module.training else 0.0,
        deterministic=module.deterministic_flash_attn,
        window_size=local_attention,
    )

    if needs_cast:
        attn = attn.to(orig_dtype)  # type: ignore

    return (attn.view(bs, dim),)
def sdpa_attention_forward(
    module: "ModernBertAttention",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> tuple[torch.Tensor]:
    """
    Attention through `torch.nn.functional.scaled_dot_product_attention`.

    Returns a 1-tuple holding the attention output reshaped to `(bs, -1, dim)`.
    """
    # qkv: [batch_size, seqlen, 3, nheads, headdim]
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # query, key, value: [batch_size, heads, seq_len, head_dim]
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    # Local-attention layers use the sliding-window mask instead of the global one.
    mask = attention_mask if local_attention == (-1, -1) else sliding_window_mask
    dropout_p = module.attention_dropout if module.training else 0.0

    context = F.scaled_dot_product_attention(
        query,
        key,
        value,
        dropout_p=dropout_p,
        attn_mask=mask,
    )
    context = context.transpose(1, 2).contiguous()
    return (context.view(bs, -1, dim),)
# Dispatch table from `config._attn_implementation` to the matching attention function;
# `ModernBertAttention.forward` looks the implementation up here on every call.
MODERNBERT_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
}
class ModernBertAttention(nn.Module):
    """Multi-headed self-attention with rotary position embeddings.

    Layers whose `layer_id` is a multiple of `config.global_attn_every_n_layers` attend globally
    (`local_attention == (-1, -1)`); all other layers use a local window of `config.local_attention // 2`
    tokens on each side. The actual computation is delegated to the function registered in
    `MODERNBERT_ATTENTION_FUNCTION` for `config._attn_implementation`.

    See `forward` method for additional details.
    """

    def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
            )

        self.attention_dropout = config.attention_dropout
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.all_head_size = self.head_dim * self.num_heads
        # Fused projection producing queries, keys and values in one matmul.
        self.Wqkv = nn.Linear(hidden_size, 3 * self.all_head_size, bias=config.attention_bias)

        uses_global_attention = layer_id % config.global_attn_every_n_layers == 0
        if uses_global_attention:
            self.local_attention = (-1, -1)
            max_position_embeddings = config.max_position_embeddings
            rope_theta = config.global_rope_theta
        else:
            half_window = config.local_attention // 2
            self.local_attention = (half_window, half_window)
            # Local layers fall back to the global theta when no local one is configured.
            rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta
            max_position_embeddings = config.local_attention

        if config._attn_implementation == "flash_attention_2":
            self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
                dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
            )
        else:
            # Work on a copy so the layer-specific theta does not leak into the shared config.
            layer_config = copy.deepcopy(config)
            layer_config.rope_theta = rope_theta
            self.rotary_emb = ModernBertRotaryEmbedding(config=layer_config)

        self.Wo = nn.Linear(hidden_size, hidden_size, bias=config.attention_bias)
        self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> torch.Tensor:
        """
        Project to QKV, run the configured attention function, then apply the output projection and dropout.

        Extra keyword arguments are forwarded untouched to the attention function. Returns a tuple whose
        first element is the attention output, followed by whatever else the attention function returned
        (the attention weights, if any).
        """
        implementation = self.config._attn_implementation
        bs = hidden_states.shape[0]

        qkv = self.Wqkv(hidden_states)
        if implementation == "flash_attention_2":
            # Flash attention works on unpadded tokens, so there is no separate batch dimension.
            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        else:
            qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)

        attention_fn = MODERNBERT_ATTENTION_FUNCTION[implementation]
        attn_outputs = attention_fn(
            self,
            qkv=qkv,
            rotary_emb=self.rotary_emb,
            local_attention=self.local_attention,
            bs=bs,
            dim=self.all_head_size,
            output_attentions=output_attentions,
            **kwargs,
        )

        projected = self.out_drop(self.Wo(attn_outputs[0]))
        return (projected,) + attn_outputs[1:]  # add attentions if outputted
class ModernBertEncoderLayer(GradientCheckpointingLayer):
    """One pre-norm encoder layer: attention and a gated MLP, each wrapped in a residual connection."""

    def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        # The first layer skips the attention pre-norm.
        self.attn_norm = (
            nn.Identity()
            if layer_id == 0
            else nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        )
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
        self.mlp = ModernBertMLP(config)

    @torch.compile(dynamic=True)
    def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """`torch.compile`d variant of norm -> MLP."""
        normed = self.mlp_norm(hidden_states)
        return self.mlp(normed)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Run the attention and MLP sub-blocks with residual connections.

        Returns a tuple whose first element is the new hidden states, followed by any extra outputs of the
        attention module (the attention weights, if any).
        """
        attn_outputs = self.attn(
            self.attn_norm(hidden_states),
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_outputs[0]

        if self.config.reference_compile:
            mlp_output = self.compiled_mlp(hidden_states)
        else:
            mlp_output = self.mlp(self.mlp_norm(hidden_states))
        hidden_states = hidden_states + mlp_output

        return (hidden_states,) + attn_outputs[1:]  # add attentions if outputted
@auto_docstring
class ModernBertPreTrainedModel(PreTrainedModel):
    config: ModernBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = False

    def _init_weights(self, module: nn.Module):
        """Initialise `module` with truncated-normal weights whose std depends on the module's role."""
        cutoff_factor = self.config.initializer_cutoff_factor
        if cutoff_factor is None:
            cutoff_factor = 3

        def init_weight(module: nn.Module, std: float):
            # Truncate at +/- cutoff_factor standard deviations; linear biases start at zero.
            nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-cutoff_factor * std,
                b=cutoff_factor * std,
            )
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)

        initializer_range = self.config.initializer_range
        in_std = initializer_range
        # Output projections are scaled down with the depth of the model.
        out_std = initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers)
        embedding_std = initializer_range
        final_out_std = self.config.hidden_size**-0.5

        if isinstance(module, ModernBertEmbeddings):
            init_weight(module.tok_embeddings, embedding_std)
        elif isinstance(module, ModernBertMLP):
            init_weight(module.Wi, in_std)
            init_weight(module.Wo, out_std)
        elif isinstance(module, ModernBertAttention):
            init_weight(module.Wqkv, in_std)
            init_weight(module.Wo, out_std)
        elif isinstance(module, ModernBertPredictionHead):
            init_weight(module.dense, out_std)
        elif isinstance(module, ModernBertForMaskedLM):
            init_weight(module.decoder, out_std)
        elif isinstance(
            module,
            (
                ModernBertForSequenceClassification,
                ModernBertForMultipleChoice,
                ModernBertForTokenClassification,
                ModernBertForQuestionAnswering,
            ),
        ):
            init_weight(module.classifier, final_out_std)
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    def _check_and_adjust_attn_implementation(
        self, attn_implementation: Optional[str], is_init_check: bool = False
    ) -> str:
        """
        Checks and dispatches to the requested attention implementation.
        """
        # When nothing was requested, prefer flash_attention_2 if it can be dispatched; otherwise the
        # parent implementation picks its own default. A failing availability probe is not an error:
        # the request is simply forwarded unchanged.
        try:
            if attn_implementation is None and self._flash_attn_2_can_dispatch():
                attn_implementation = "flash_attention_2"
        except (ValueError, ImportError):
            pass

        return super()._check_and_adjust_attn_implementation(
            attn_implementation=attn_implementation, is_init_check=is_init_check
        )

    def _maybe_set_compile(self):
        """Resolve `config.reference_compile`, turning it off in setups where compilation is unsupported."""
        config = self.config
        if config.reference_compile is False:
            return

        def fall_back(reason: str):
            # Only warn when compilation was explicitly requested; always disable it.
            if config.reference_compile:
                logger.warning_once(reason)
            config.reference_compile = False

        if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
            fall_back(
                "If `accelerate` split the model across devices, `torch.compile` will not work. "
                "Falling back to non-compiled mode."
            )

        if self.device.type == "mps":
            fall_back(
                "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
                "Falling back to non-compiled mode."
            )

        if self.device.type == "cpu":
            fall_back(
                "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
                "Falling back to non-compiled mode."
            )

        # Still undecided: compile exactly when triton is available.
        if config.reference_compile is None:
            config.reference_compile = is_triton_available()

    def resize_token_embeddings(self, *args, **kwargs):
        """Resize the token embeddings through the parent class and switch compiled mode off."""
        model_embeds = super().resize_token_embeddings(*args, **kwargs)

        compile_setting = self.config.reference_compile
        if compile_setting in {True, None}:
            if compile_setting:
                logger.warning_once(
                    "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
                )
            self.config.reference_compile = False

        return model_embeds
- def _unpad_modernbert_input(
- inputs: torch.Tensor,
- attention_mask: torch.Tensor,
- position_ids: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
- """
- Remove padding from input sequences.
- Args:
- inputs: (batch, seqlen, ...) or (batch, seqlen)
- attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
- position_ids: (batch, seqlen), int, position ids
- labels: (batch, seqlen), int, labels
- Returns:
- unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
- indices: (total_nnz)
- cu_seqlens: (batch + 1), the cumulative sequence lengths
- max_seqlen_in_batch: int
- unpadded_position_ids: (total_nnz) or None
- unpadded_labels: (total_nnz) or None
- """
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = int(seqlens_in_batch.max().item())
- cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
- if inputs.dim() == 2:
- unpadded_inputs = inputs.flatten()[indices]
- else:
- batch, seqlen, *rest = inputs.shape
- shape = batch * seqlen
- unpadded_inputs = inputs.view(shape, *rest)[indices]
- unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
- unpadded_labels = labels.flatten()[indices] if labels is not None else None
- return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
- def _pad_modernbert_output(
- inputs: torch.Tensor,
- indices: torch.Tensor,
- batch: int,
- seqlen: int,
- ) -> torch.Tensor:
- """
- Add padding to sequences.
- Args:
- inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
- indices: (total_nnz)
- batch: int, batch size
- seqlen: int, max sequence length
- Returns:
- padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
- """
- if inputs.dim() == 1:
- output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
- output[indices] = inputs
- padded_inputs = output.view(batch, seqlen)
- else:
- _, *rest = inputs.shape
- output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
- output[indices] = inputs
- padded_inputs = output.view(batch, seqlen, *rest)
- return padded_inputs
- @auto_docstring
- class ModernBertModel(ModernBertPreTrainedModel):
def __init__(self, config: ModernBertConfig):
    """Build the embeddings, the stack of `config.num_hidden_layers` encoder layers and the final norm."""
    super().__init__(config)
    self.config = config
    self.embeddings = ModernBertEmbeddings(config)
    # Each layer receives its index, which the layer uses to configure itself.
    self.layers = nn.ModuleList(
        ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)
    )
    self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
    self.gradient_checkpointing = False
    self.post_init()
def get_input_embeddings(self):
    """Return the token embedding module held by `self.embeddings`."""
    embeddings = self.embeddings
    return embeddings.tok_embeddings
def set_input_embeddings(self, value):
    """Replace the token embedding module held by `self.embeddings` with `value`."""
    embeddings = self.embeddings
    embeddings.tok_embeddings = value
- @auto_docstring
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- sliding_window_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- indices: Optional[torch.Tensor] = None,
- cu_seqlens: Optional[torch.Tensor] = None,
- max_seqlen: Optional[int] = None,
- batch_size: Optional[int] = None,
- seq_len: Optional[int] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]:
- r"""
- sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
- perform global attention, while the rest perform local attention. This mask is used to avoid attending to
- far-away tokens in the local attention layers when not using Flash Attention.
- indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
- Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
- cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
- Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
- max_seqlen (`int`, *optional*):
- Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
- batch_size (`int`, *optional*):
- Batch size of the input sequences. Used to pad the output tensors.
- seq_len (`int`, *optional*):
- Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
- """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- all_hidden_states = () if output_hidden_states else None
- all_self_attentions = () if output_attentions else None
- self._maybe_set_compile()
- if input_ids is not None:
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
- if batch_size is None and seq_len is None:
- if inputs_embeds is not None:
- batch_size, seq_len = inputs_embeds.shape[:2]
- else:
- batch_size, seq_len = input_ids.shape[:2]
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
- repad = False
- if self.config._attn_implementation == "flash_attention_2":
- if indices is None and cu_seqlens is None and max_seqlen is None:
- repad = True
- if inputs_embeds is None:
- with torch.no_grad():
- input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
- inputs=input_ids, attention_mask=attention_mask
- )
- else:
- inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
- inputs=inputs_embeds, attention_mask=attention_mask
- )
- else:
- if position_ids is None:
- position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
- attention_mask, sliding_window_mask = self._update_attention_mask(
- attention_mask, output_attentions=output_attentions
- )
- hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
- for encoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- layer_outputs = encoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- sliding_window_mask=sliding_window_mask,
- position_ids=position_ids,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
- output_attentions=output_attentions,
- )
- hidden_states = layer_outputs[0]
- if output_attentions and len(layer_outputs) > 1:
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- hidden_states = self.final_norm(hidden_states)
- if repad:
- hidden_states = _pad_modernbert_output(
- inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
- )
- if all_hidden_states is not None:
- all_hidden_states = tuple(
- _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
- for hs in all_hidden_states
- )
- # If the attention implementation is FA2 and there is no need for repadding, there might still be the batch
- # dimension missing
- elif (
- self.config._attn_implementation == "flash_attention_2"
- and all_hidden_states is not None
- and all_hidden_states[-1].dim() == 2
- ):
- hidden_states = hidden_states.unsqueeze(0)
- all_hidden_states = tuple(hs.unsqueeze(0) for hs in all_hidden_states)
- if not return_dict:
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
- return BaseModelOutput(
- last_hidden_state=hidden_states,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- )
- def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> torch.Tensor:
- if output_attentions:
- if self.config._attn_implementation == "sdpa":
- logger.warning_once(
- "Outputting attentions is only supported with the 'eager' attention implementation, "
- 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
- )
- self.config._attn_implementation = "eager"
- elif self.config._attn_implementation != "eager":
- logger.warning_once(
- "Outputting attentions is only supported with the eager attention implementation, "
- f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
- " Setting `output_attentions=False`."
- )
- global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
- # Create position indices
- rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
- # Calculate distance between positions
- distance = torch.abs(rows - rows.T)
- # Create sliding window mask (1 for positions within window, 0 outside)
- window_mask = (
- (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
- )
- # Combine with existing mask
- sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
- return global_attention_mask, sliding_window_mask
class ModernBertPredictionHead(nn.Module):
    """Linear projection, activation and LayerNorm applied in that order; the hidden size is preserved."""

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.dense = nn.Linear(width, width, config.classifier_bias)
        self.act = ACT2FN[config.classifier_activation]
        self.norm = nn.LayerNorm(width, eps=config.norm_eps, bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)
@auto_docstring(
    custom_intro="""
    The ModernBert Model with a decoder head on top that is used for masked language modeling.
    """
)
class ModernBertForMaskedLM(ModernBertPreTrainedModel):
    _tied_weights_keys = ["decoder.weight"]

    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.config = config
        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)

        self.sparse_prediction = self.config.sparse_prediction
        self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    @torch.compile(dynamic=True)
    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.head(output))

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        # Under flash attention the padding is removed here (rather than inside the base model) so that
        # `position_ids` and `labels` are unpadded together with the inputs.
        if self.config._attn_implementation == "flash_attention_2":
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None and seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device

                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                    )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        # When hidden states are requested, the base model returns unpadded FA2 outputs with a leading
        # singleton batch dim. Drop it (as done for `outputs.hidden_states` below) so the logits stay
        # (total_nnz, vocab); otherwise re-padding would broadcast them into a 4-D tensor.
        if (
            self.config._attn_implementation == "flash_attention_2"
            and last_hidden_state.dim() == 3
            and last_hidden_state.shape[0] == 1
        ):
            last_hidden_state = last_hidden_state.squeeze(0)

        mask_tokens = None
        if self.sparse_prediction and labels is not None:
            # flatten labels and output first
            labels = labels.view(-1)
            last_hidden_state = last_hidden_state.view(labels.shape[0], -1)

            # then filter out the non-masked tokens
            mask_tokens = labels != self.sparse_pred_ignore_index
            last_hidden_state = last_hidden_state[mask_tokens]
            labels = labels[mask_tokens]

        logits = (
            self.compiled_head(last_hidden_state)
            if self.config.reference_compile
            else self.decoder(self.head(last_hidden_state))
        )

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)

        if self.config._attn_implementation == "flash_attention_2":
            # With sparse prediction only the kept tokens have a logits row, so only their positions are
            # scattered back; using the full `indices` would mismatch the number of rows.
            logit_indices = indices if mask_tokens is None else indices[mask_tokens]

            # Logits padding
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_modernbert_output(
                    inputs=logits, indices=logit_indices, batch=batch_size, seqlen=seq_len
                )

            # Hidden states padding (tuple outputs have no `hidden_states` attribute and are left as returned)
            if getattr(outputs, "hidden_states", None) is not None:
                padded_hidden_states = []
                for hs in outputs.hidden_states:
                    if hs.dim() == 3 and hs.shape[0] == 1:
                        hs = hs.squeeze(0)
                    padded_hidden_states.append(
                        _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
                    )
                outputs.hidden_states = tuple(padded_hidden_states)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    The ModernBert Model with a sequence classification head on top that performs pooling.
    """
)
class ModernBertForSequenceClassification(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        self._maybe_set_compile()

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)

        # Infer the padded geometry from whichever input was given, preferring `inputs_embeds`.
        if batch_size is None and seq_len is None:
            shape_source = inputs_embeds if inputs_embeds is not None else input_ids
            batch_size, seq_len = shape_source.shape[:2]
        device = inputs_embeds.device if input_ids is None else input_ids.device

        # A missing mask is materialised here because mean pooling below needs one.
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]

        pooling = self.config.classifier_pooling
        if pooling == "cls":
            # Take the hidden state at sequence position 0.
            last_hidden_state = last_hidden_state[:, 0]
        elif pooling == "mean":
            # Masked average: masked-out positions contribute zero and are not counted.
            masked_sum = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1)
            token_counts = attention_mask.sum(dim=1, keepdim=True)
            last_hidden_state = masked_sum / token_counts

        pooled_output = self.drop(self.head(last_hidden_state))
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # The problem type is inferred once from the first labelled batch and stored on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if not return_dict:
            return (logits,) if loss is None else (loss, logits)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
    """
)
class ModernBertForTokenClassification(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token head: prediction head -> dropout -> linear classifier.
        token_features = self.drop(self.head(outputs[0]))
        logits = self.classifier(token_features)

        loss = None
        if labels is not None:
            # Tokens and labels are flattened so each token is scored as an independent classification.
            criterion = CrossEntropyLoss()
            loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Tuple mode keeps whatever extra entries the base model returned after the hidden state.
            result = (logits,) + outputs[1:]
            if loss is None:
                return result
            return (loss,) + result

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring
class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
    # Span-extraction head: the classifier's outputs are split into start and end logits per token.

    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        self._maybe_set_compile()

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_features = self.drop(self.head(outputs[0]))
        logits = self.classifier(token_features)

        # Width-1 chunks along the last axis; unpacking into two names requires exactly two output channels.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # A loss is only produced when both span boundaries are supplied.
        loss = None
        if not (start_positions is None or end_positions is None):
            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        if not return_dict:
            result = (start_logits, end_logits) + outputs[1:]
            if loss is None:
                return result
            return (loss,) + result

        return QuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@auto_docstring(
    custom_intro="""
    The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
    """
)
class ModernBertForMultipleChoice(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.config = config

        self.model = ModernBertModel(config)
        self.head = ModernBertPredictionHead(config)
        self.drop = torch.nn.Dropout(config.classifier_dropout)
        # One score per (example, choice) pair; scores are regrouped per example in `forward`.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Fold the choice axis into the batch axis so every choice is encoded as an independent sequence.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        self._maybe_set_compile()

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # One row per flattened (example, choice) sequence; presumably
        # (batch_size * num_choices, seq_len, hidden_size) given the views above.
        last_hidden_state = outputs[0]

        # If classifier_pooling is "cls", isolate the <cls> token
        if self.config.classifier_pooling == "cls":
            indices_0 = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
            # for left or right padding, <cls> is the first non-pad token
            if attention_mask is not None:
                cls_mask = attention_mask.argmax(dim=-1).to(last_hidden_state.device)
            # if no pad, <cls> is the first token
            else:
                cls_mask = torch.tensor(0, dtype=torch.long, device=last_hidden_state.device)
            # extract the <cls> token for the logits
            last_hidden_state = last_hidden_state[indices_0, cls_mask]

        # If classifier_pooling is "mean", pool the hidden states by averaging over the sequence length
        elif self.config.classifier_pooling == "mean":
            if attention_mask is None:
                # No mask means no padding information: every position counts, which is a plain mean.
                last_hidden_state = last_hidden_state.mean(dim=1)
            else:
                num_non_pad_tokens = attention_mask.sum(dim=1, keepdim=True)
                last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / num_non_pad_tokens

        pooled_output = self.head(last_hidden_state)
        pooled_output = self.drop(pooled_output)
        logits = self.classifier(pooled_output)

        # Regroup the per-sequence scores so each row holds all choices of one example.
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public API of this module: the base model, its pretrained base class and the task-specific heads.
__all__ = [
    "ModernBertModel",
    "ModernBertPreTrainedModel",
    "ModernBertForMaskedLM",
    "ModernBertForSequenceClassification",
    "ModernBertForTokenClassification",
    "ModernBertForQuestionAnswering",
    "ModernBertForMultipleChoice",
]
|