modeling_mbart.py 87 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910
  1. # coding=utf-8
  2. # Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch MBART model."""
  16. import math
  17. from typing import Callable, Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_attn_mask_utils import (
  25. AttentionMaskConverter,
  26. _prepare_4d_attention_mask,
  27. _prepare_4d_attention_mask_for_sdpa,
  28. )
  29. from ...modeling_flash_attention_utils import (
  30. FlashAttentionKwargs,
  31. )
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import (
  34. BaseModelOutput,
  35. BaseModelOutputWithPastAndCrossAttentions,
  36. CausalLMOutputWithCrossAttentions,
  37. Seq2SeqLMOutput,
  38. Seq2SeqModelOutput,
  39. Seq2SeqQuestionAnsweringModelOutput,
  40. Seq2SeqSequenceClassifierOutput,
  41. )
  42. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import (
  45. auto_docstring,
  46. is_torch_flex_attn_available,
  47. is_torchdynamo_compiling,
  48. logging,
  49. )
  50. from ...utils.deprecation import deprecate_kwarg
  51. from .configuration_mbart import MBartConfig
  52. if is_torch_flex_attn_available():
  53. from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
# Module-level logger for this file (used e.g. by MBartAttention's missing-`layer_idx` warning).
logger = logging.get_logger(__name__)
  55. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
  56. """
  57. Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that MBart does not
  58. have a single `decoder_start_token_id` in contrast to other Bart-like models.
  59. """
  60. prev_output_tokens = input_ids.clone()
  61. if pad_token_id is None:
  62. raise ValueError("self.model.config.pad_token_id has to be defined.")
  63. # replace possible -100 values in labels by `pad_token_id`
  64. prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
  65. index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  66. decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
  67. prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
  68. prev_output_tokens[:, 0] = decoder_start_tokens
  69. return prev_output_tokens
  70. # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
  71. class MBartLearnedPositionalEmbedding(nn.Embedding):
  72. """
  73. This module learns positional embeddings up to a fixed maximum size.
  74. """
  75. def __init__(self, num_embeddings: int, embedding_dim: int):
  76. # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
  77. # and adjust num_embeddings appropriately. Other models don't have this hack
  78. self.offset = 2
  79. super().__init__(num_embeddings + self.offset, embedding_dim)
  80. def forward(
  81. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
  82. ):
  83. """`input_ids' shape is expected to be [bsz x seqlen]."""
  84. if position_ids is None:
  85. bsz, seq_len = input_ids.shape[:2]
  86. position_ids = torch.arange(
  87. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  88. ).expand(bsz, -1)
  89. else:
  90. position_ids = position_ids.unsqueeze(0)
  91. return super().forward(position_ids + self.offset)
  92. # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart
  93. class MBartScaledWordEmbedding(nn.Embedding):
  94. """
  95. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  96. """
  97. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
  98. super().__init__(num_embeddings, embedding_dim, padding_idx)
  99. self.embed_scale = embed_scale
  100. def forward(self, input_ids: torch.Tensor):
  101. return super().forward(input_ids) * self.embed_scale
  102. # Copied from transformers.models.bart.modeling_bart.eager_attention_forward
  103. def eager_attention_forward(
  104. module: nn.Module,
  105. query: torch.Tensor,
  106. key: torch.Tensor,
  107. value: torch.Tensor,
  108. attention_mask: Optional[torch.Tensor],
  109. scaling: Optional[float] = None,
  110. dropout: float = 0.0,
  111. head_mask: Optional[torch.Tensor] = None,
  112. **kwargs,
  113. ):
  114. if scaling is None:
  115. scaling = query.size(-1) ** -0.5
  116. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  117. if attention_mask is not None:
  118. attn_weights = attn_weights + attention_mask
  119. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  120. if head_mask is not None:
  121. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  122. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  123. attn_output = torch.matmul(attn_weights, value)
  124. attn_output = attn_output.transpose(1, 2).contiguous()
  125. return attn_output, attn_weights
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart
class MBartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Acts as self-attention when `forward` receives no `key_value_states`. Otherwise it acts as cross-attention,
    with keys/values projected from `key_value_states`. The score/softmax step is delegated to
    `eager_attention_forward` or to the backend registered in `ALL_ATTENTION_FUNCTIONS` under
    `config._attn_implementation`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[MBartConfig] = None,
        layer_idx: Optional[int] = None,
    ):
        """
        Args:
            embed_dim: model width. Must be divisible by `num_heads`, otherwise ValueError is raised.
            num_heads: number of attention heads. Each head has `embed_dim // num_heads` channels.
            dropout: dropout probability handed to the attention function while the module is training.
            is_decoder: only used here to warn when a decoder attention is built without `layer_idx`.
            bias: whether the q/k/v/out projections have a bias term.
            is_causal: stored as `self.is_causal`. It is not read inside this class; presumably the non-eager
                attention backends consult it (TODO confirm).
            config: `forward` reads `config._attn_implementation`, so a config must be supplied despite the
                `Optional` default.
            layer_idx: index of this layer, used to address its entry in the KV cache.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Standard 1/sqrt(head_dim) score scaling, passed explicitly to the attention function.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Returns `(attn_output, attn_weights)`.
        - `attn_output` has shape `(batch, tgt_len, embed_dim)` after `out_proj`.
        - `attn_weights` is whatever the selected attention function returns as its second value.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        # -1 lets the projection width determine the head count when splitting into heads.
        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        # get query proj
        # (bsz, tgt_len, embed_dim) -> (bsz, num_heads, tgt_len, head_dim)
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        # Select the cache that holds this attention's keys/values.
        # An EncoderDecoderCache keeps separate self- and cross-attention caches, plus a per-layer
        # `is_updated` flag recording whether cross-attention K/V were already stored for this layer.
        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            # Project and split into heads: (bsz, src_len, embed_dim) -> (bsz, num_heads, src_len, head_dim)
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(*kv_input_shape).transpose(1, 2)
            value_states = value_states.view(*kv_input_shape).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                # cache_position is only forwarded for self-attention; cross-attention passes None.
                cache_position = cache_position if not is_cross_attention else None
                # `update` returns the key/value tensors to attend over, which replace the freshly projected ones.
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Dispatch to the configured attention backend; "eager" uses the local reference implementation.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            head_mask=layer_head_mask,
            **kwargs,
        )

        # Merge heads back into the channel dimension before the output projection.
        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
  238. class MBartEncoderLayer(GradientCheckpointingLayer):
  239. def __init__(self, config: MBartConfig):
  240. super().__init__()
  241. self.embed_dim = config.d_model
  242. self.self_attn = MBartAttention(
  243. embed_dim=self.embed_dim,
  244. num_heads=config.encoder_attention_heads,
  245. dropout=config.attention_dropout,
  246. config=config,
  247. )
  248. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  249. self.dropout = config.dropout
  250. self.activation_fn = ACT2FN[config.activation_function]
  251. self.activation_dropout = config.activation_dropout
  252. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  253. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  254. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  255. def forward(
  256. self,
  257. hidden_states: torch.Tensor,
  258. attention_mask: torch.Tensor,
  259. layer_head_mask: torch.Tensor,
  260. output_attentions: bool = False,
  261. ) -> torch.Tensor:
  262. """
  263. Args:
  264. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  265. attention_mask (`torch.FloatTensor`): attention mask of size
  266. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  267. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  268. `(encoder_attention_heads,)`.
  269. output_attentions (`bool`, *optional*):
  270. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  271. returned tensors for more detail.
  272. """
  273. residual = hidden_states
  274. hidden_states = self.self_attn_layer_norm(hidden_states)
  275. hidden_states, attn_weights = self.self_attn(
  276. hidden_states=hidden_states,
  277. attention_mask=attention_mask,
  278. layer_head_mask=layer_head_mask,
  279. output_attentions=output_attentions,
  280. )
  281. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  282. hidden_states = residual + hidden_states
  283. residual = hidden_states
  284. hidden_states = self.final_layer_norm(hidden_states)
  285. hidden_states = self.activation_fn(self.fc1(hidden_states))
  286. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  287. hidden_states = self.fc2(hidden_states)
  288. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  289. hidden_states = residual + hidden_states
  290. if hidden_states.dtype == torch.float16:
  291. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  292. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  293. return hidden_states, attn_weights
  294. class MBartDecoderLayer(GradientCheckpointingLayer):
  295. def __init__(self, config: MBartConfig, layer_idx: Optional[int] = None):
  296. super().__init__()
  297. self.embed_dim = config.d_model
  298. self.self_attn = MBartAttention(
  299. embed_dim=self.embed_dim,
  300. num_heads=config.decoder_attention_heads,
  301. dropout=config.attention_dropout,
  302. is_decoder=True,
  303. is_causal=True,
  304. config=config,
  305. layer_idx=layer_idx,
  306. )
  307. self.dropout = config.dropout
  308. self.activation_fn = ACT2FN[config.activation_function]
  309. self.activation_dropout = config.activation_dropout
  310. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  311. self.encoder_attn = MBartAttention(
  312. self.embed_dim,
  313. config.decoder_attention_heads,
  314. dropout=config.attention_dropout,
  315. is_decoder=True,
  316. config=config,
  317. layer_idx=layer_idx,
  318. )
  319. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  320. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  321. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  322. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  323. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  324. def forward(
  325. self,
  326. hidden_states: torch.Tensor,
  327. attention_mask: Optional[torch.Tensor] = None,
  328. encoder_hidden_states: Optional[torch.Tensor] = None,
  329. encoder_attention_mask: Optional[torch.Tensor] = None,
  330. layer_head_mask: Optional[torch.Tensor] = None,
  331. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  332. past_key_values: Optional[Cache] = None,
  333. output_attentions: Optional[bool] = False,
  334. use_cache: Optional[bool] = True,
  335. cache_position: Optional[torch.Tensor] = None,
  336. ) -> torch.Tensor:
  337. """
  338. Args:
  339. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  340. attention_mask (`torch.FloatTensor`): attention mask of size
  341. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  342. encoder_hidden_states (`torch.FloatTensor`):
  343. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  344. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  345. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  346. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  347. `(encoder_attention_heads,)`.
  348. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  349. size `(decoder_attention_heads,)`.
  350. past_key_values (`Cache`): cached past key and value projection states
  351. output_attentions (`bool`, *optional*):
  352. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  353. returned tensors for more detail.
  354. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  355. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  356. cache in the correct position and to infer the complete sequence length.
  357. """
  358. residual = hidden_states
  359. hidden_states = self.self_attn_layer_norm(hidden_states)
  360. # Self Attention
  361. hidden_states, self_attn_weights = self.self_attn(
  362. hidden_states=hidden_states,
  363. past_key_values=past_key_values,
  364. attention_mask=attention_mask,
  365. layer_head_mask=layer_head_mask,
  366. output_attentions=output_attentions,
  367. cache_position=cache_position,
  368. )
  369. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  370. hidden_states = residual + hidden_states
  371. # Cross-Attention Block
  372. cross_attn_weights = None
  373. if encoder_hidden_states is not None:
  374. residual = hidden_states
  375. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  376. hidden_states, cross_attn_weights = self.encoder_attn(
  377. hidden_states=hidden_states,
  378. key_value_states=encoder_hidden_states,
  379. attention_mask=encoder_attention_mask,
  380. layer_head_mask=cross_attn_layer_head_mask,
  381. past_key_values=past_key_values,
  382. output_attentions=output_attentions,
  383. )
  384. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  385. hidden_states = residual + hidden_states
  386. # Fully Connected
  387. residual = hidden_states
  388. hidden_states = self.final_layer_norm(hidden_states)
  389. hidden_states = self.activation_fn(self.fc1(hidden_states))
  390. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  391. hidden_states = self.fc2(hidden_states)
  392. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  393. hidden_states = residual + hidden_states
  394. outputs = (hidden_states,)
  395. if output_attentions:
  396. outputs += (self_attn_weights, cross_attn_weights)
  397. return outputs
  398. # Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->MBart
  399. class MBartClassificationHead(nn.Module):
  400. """Head for sentence-level classification tasks."""
  401. def __init__(
  402. self,
  403. input_dim: int,
  404. inner_dim: int,
  405. num_classes: int,
  406. pooler_dropout: float,
  407. ):
  408. super().__init__()
  409. self.dense = nn.Linear(input_dim, inner_dim)
  410. self.dropout = nn.Dropout(p=pooler_dropout)
  411. self.out_proj = nn.Linear(inner_dim, num_classes)
  412. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  413. hidden_states = self.dropout(hidden_states)
  414. hidden_states = self.dense(hidden_states)
  415. hidden_states = torch.tanh(hidden_states)
  416. hidden_states = self.dropout(hidden_states)
  417. hidden_states = self.out_proj(hidden_states)
  418. return hidden_states
  419. @auto_docstring
  420. class MBartPreTrainedModel(PreTrainedModel):
  421. config: MBartConfig
  422. base_model_prefix = "model"
  423. supports_gradient_checkpointing = True
  424. _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"]
  425. _supports_flash_attn = True
  426. _supports_sdpa = True
  427. _supports_flex_attn = True
  428. _can_compile_fullgraph = True
  429. def _init_weights(self, module):
  430. std = self.config.init_std
  431. if isinstance(module, nn.Linear):
  432. module.weight.data.normal_(mean=0.0, std=std)
  433. if module.bias is not None:
  434. module.bias.data.zero_()
  435. elif isinstance(module, nn.Embedding):
  436. module.weight.data.normal_(mean=0.0, std=std)
  437. if module.padding_idx is not None:
  438. module.weight.data[module.padding_idx].zero_()
  439. elif isinstance(module, nn.LayerNorm):
  440. module.weight.data.fill_(1.0)
  441. module.bias.data.zero_()
  442. @property
  443. def dummy_inputs(self):
  444. pad_token = self.config.pad_token_id
  445. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  446. dummy_inputs = {
  447. "attention_mask": input_ids.ne(pad_token),
  448. "input_ids": input_ids,
  449. }
  450. return dummy_inputs
  451. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
  452. def _update_full_mask(
  453. self,
  454. attention_mask: Union[torch.Tensor, None],
  455. inputs_embeds: torch.Tensor,
  456. ):
  457. if attention_mask is not None:
  458. if self.config._attn_implementation == "flash_attention_2":
  459. attention_mask = attention_mask if 0 in attention_mask else None
  460. elif self.config._attn_implementation == "sdpa":
  461. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  462. # the manual implementation that requires a 4D causal mask in all cases.
  463. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  464. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  465. elif self.config._attn_implementation == "flex_attention":
  466. if isinstance(attention_mask, torch.Tensor):
  467. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  468. else:
  469. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  470. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  471. return attention_mask
  472. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
  473. def _update_causal_mask(
  474. self,
  475. attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
  476. input_tensor: torch.Tensor,
  477. cache_position: torch.Tensor,
  478. past_key_values: Cache,
  479. ):
  480. if self.config._attn_implementation == "flex_attention":
  481. if isinstance(attention_mask, torch.Tensor):
  482. attention_mask = make_flex_block_causal_mask(attention_mask)
  483. # Other attention flavors support in-built causal (when `mask is None`)
  484. # while we need to create our specific block mask regardless
  485. elif attention_mask is None:
  486. attention_mask = make_flex_block_causal_mask(
  487. torch.ones(
  488. size=(input_tensor.shape[0], input_tensor.shape[1]),
  489. device=attention_mask.device,
  490. )
  491. )
  492. return attention_mask
  493. if self.config._attn_implementation == "flash_attention_2":
  494. if attention_mask is not None and (attention_mask == 0.0).any():
  495. return attention_mask
  496. return None
  497. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  498. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  499. # to infer the attention mask.
  500. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  501. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  502. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  503. if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
  504. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  505. attention_mask,
  506. inputs_embeds=input_tensor,
  507. past_key_values_length=past_seen_tokens,
  508. is_training=self.training,
  509. ):
  510. return None
  511. dtype = input_tensor.dtype
  512. sequence_length = input_tensor.shape[1]
  513. if using_compilable_cache:
  514. target_length = past_key_values.get_max_cache_shape()
  515. else:
  516. target_length = (
  517. attention_mask.shape[-1]
  518. if isinstance(attention_mask, torch.Tensor)
  519. else past_seen_tokens + sequence_length + 1
  520. )
  521. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  522. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  523. attention_mask,
  524. sequence_length=sequence_length,
  525. target_length=target_length,
  526. dtype=dtype,
  527. cache_position=cache_position,
  528. batch_size=input_tensor.shape[0],
  529. )
  530. if (
  531. self.config._attn_implementation == "sdpa"
  532. and attention_mask is not None
  533. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  534. ):
  535. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  536. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  537. # Details: https://github.com/pytorch/pytorch/issues/110213
  538. min_dtype = torch.finfo(dtype).min
  539. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  540. return causal_mask
  541. @staticmethod
  542. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  543. def _prepare_4d_causal_attention_mask_with_cache_position(
  544. attention_mask: torch.Tensor,
  545. sequence_length: int,
  546. target_length: int,
  547. dtype: torch.dtype,
  548. cache_position: torch.Tensor,
  549. batch_size: int,
  550. **kwargs,
  551. ):
  552. """
  553. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  554. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  555. Args:
  556. attention_mask (`torch.Tensor`):
  557. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  558. `(batch_size, 1, query_length, key_value_length)`.
  559. sequence_length (`int`):
  560. The sequence length being processed.
  561. target_length (`int`):
  562. The target length: when generating with static cache, the mask should be as long as the static cache,
  563. to account for the 0 padding, the part of the cache that is not filled yet.
  564. dtype (`torch.dtype`):
  565. The dtype to use for the 4D attention mask.
  566. cache_position (`torch.Tensor`):
  567. Indices depicting the position of the input sequence tokens in the sequence.
  568. batch_size (`torch.Tensor`):
  569. Batch size.
  570. """
  571. if attention_mask is not None and attention_mask.dim() == 4:
  572. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  573. causal_mask = attention_mask
  574. else:
  575. min_dtype = torch.finfo(dtype).min
  576. causal_mask = torch.full(
  577. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  578. )
  579. if sequence_length != 1:
  580. causal_mask = torch.triu(causal_mask, diagonal=1)
  581. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  582. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  583. if attention_mask is not None:
  584. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  585. mask_length = attention_mask.shape[-1]
  586. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  587. causal_mask.device
  588. )
  589. padding_mask = padding_mask == 0
  590. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  591. padding_mask, min_dtype
  592. )
  593. return causal_mask
  594. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
  595. def _update_cross_attn_mask(
  596. self,
  597. encoder_hidden_states: Union[torch.Tensor, None],
  598. encoder_attention_mask: Union[torch.Tensor, None],
  599. input_shape: torch.Size,
  600. inputs_embeds: torch.Tensor,
  601. ):
  602. # expand encoder attention mask
  603. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  604. if self.config._attn_implementation == "flash_attention_2":
  605. encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
  606. elif self.config._attn_implementation == "sdpa":
  607. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  608. # the manual implementation that requires a 4D causal mask in all cases.
  609. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  610. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  611. encoder_attention_mask,
  612. inputs_embeds.dtype,
  613. tgt_len=input_shape[-1],
  614. )
  615. elif self.config._attn_implementation == "flex_attention":
  616. if isinstance(encoder_attention_mask, torch.Tensor):
  617. encoder_attention_mask = make_flex_block_causal_mask(
  618. encoder_attention_mask,
  619. query_length=input_shape[-1],
  620. is_causal=False,
  621. )
  622. else:
  623. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  624. encoder_attention_mask = _prepare_4d_attention_mask(
  625. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  626. )
  627. return encoder_attention_mask
  628. class MBartEncoder(MBartPreTrainedModel):
  629. """
  630. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  631. [`MBartEncoderLayer`].
  632. Args:
  633. config: MBartConfig
  634. embed_tokens (nn.Embedding): output embedding
  635. """
  636. def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
  637. super().__init__(config)
  638. self.dropout = config.dropout
  639. self.layerdrop = config.encoder_layerdrop
  640. embed_dim = config.d_model
  641. self.padding_idx = config.pad_token_id
  642. self.max_source_positions = config.max_position_embeddings
  643. embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  644. self.embed_tokens = MBartScaledWordEmbedding(
  645. config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
  646. )
  647. if embed_tokens is not None:
  648. self.embed_tokens.weight = embed_tokens.weight
  649. self.embed_positions = MBartLearnedPositionalEmbedding(
  650. config.max_position_embeddings,
  651. embed_dim,
  652. )
  653. self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
  654. self.config = config
  655. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  656. self.layer_norm = nn.LayerNorm(config.d_model)
  657. self.gradient_checkpointing = False
  658. # Initialize weights and apply final processing
  659. self.post_init()
  660. def _backward_compatibility_gradient_checkpointing(self):
  661. # Override to not delete the attribute from the config
  662. if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
  663. self.gradient_checkpointing_enable()
  664. def forward(
  665. self,
  666. input_ids: Optional[torch.LongTensor] = None,
  667. attention_mask: Optional[torch.Tensor] = None,
  668. head_mask: Optional[torch.Tensor] = None,
  669. inputs_embeds: Optional[torch.FloatTensor] = None,
  670. output_attentions: Optional[bool] = None,
  671. output_hidden_states: Optional[bool] = None,
  672. return_dict: Optional[bool] = None,
  673. ) -> Union[tuple, BaseModelOutput]:
  674. r"""
  675. Args:
  676. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  677. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  678. provide it.
  679. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  680. [`PreTrainedTokenizer.__call__`] for details.
  681. [What are input IDs?](../glossary#input-ids)
  682. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  683. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  684. - 1 for tokens that are **not masked**,
  685. - 0 for tokens that are **masked**.
  686. [What are attention masks?](../glossary#attention-mask)
  687. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  688. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  689. - 1 indicates the head is **not masked**,
  690. - 0 indicates the head is **masked**.
  691. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  692. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  693. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  694. than the model's internal embedding lookup matrix.
  695. output_attentions (`bool`, *optional*):
  696. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  697. returned tensors for more detail.
  698. output_hidden_states (`bool`, *optional*):
  699. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  700. for more detail.
  701. return_dict (`bool`, *optional*):
  702. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  703. """
  704. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  705. output_hidden_states = (
  706. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  707. )
  708. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  709. # retrieve input_ids and inputs_embeds
  710. if input_ids is not None and inputs_embeds is not None:
  711. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  712. elif input_ids is not None:
  713. input = input_ids
  714. input_shape = input.shape
  715. input_ids = input_ids.view(-1, input_shape[-1])
  716. elif inputs_embeds is not None:
  717. input = inputs_embeds[:, :, -1]
  718. else:
  719. raise ValueError("You have to specify either input_ids or inputs_embeds")
  720. if inputs_embeds is None:
  721. inputs_embeds = self.embed_tokens(input_ids)
  722. embed_pos = self.embed_positions(input)
  723. hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device)
  724. hidden_states = self.layernorm_embedding(hidden_states)
  725. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  726. attention_mask = self._update_full_mask(
  727. attention_mask,
  728. inputs_embeds,
  729. )
  730. encoder_states = () if output_hidden_states else None
  731. all_attentions = () if output_attentions else None
  732. # check if head_mask has a correct number of layers specified if desired
  733. if head_mask is not None:
  734. if head_mask.size()[0] != len(self.layers):
  735. raise ValueError(
  736. f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
  737. f" {head_mask.size()[0]}."
  738. )
  739. for idx, encoder_layer in enumerate(self.layers):
  740. if output_hidden_states:
  741. encoder_states = encoder_states + (hidden_states,)
  742. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  743. to_drop = False
  744. if self.training:
  745. dropout_probability = torch.rand([])
  746. if dropout_probability < self.layerdrop: # skip the layer
  747. to_drop = True
  748. if to_drop:
  749. layer_outputs = (None, None)
  750. else:
  751. layer_outputs = encoder_layer(
  752. hidden_states,
  753. attention_mask,
  754. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  755. output_attentions=output_attentions,
  756. )
  757. hidden_states = layer_outputs[0]
  758. if output_attentions:
  759. all_attentions = all_attentions + (layer_outputs[1],)
  760. hidden_states = self.layer_norm(hidden_states)
  761. if output_hidden_states:
  762. encoder_states = encoder_states + (hidden_states,)
  763. if not return_dict:
  764. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  765. return BaseModelOutput(
  766. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  767. )
  768. class MBartDecoder(MBartPreTrainedModel):
  769. """
  770. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
  771. Args:
  772. config: MBartConfig
  773. embed_tokens (nn.Embedding): output embedding
  774. """
  775. def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
  776. super().__init__(config)
  777. self.dropout = config.dropout
  778. self.layerdrop = config.decoder_layerdrop
  779. self.padding_idx = config.pad_token_id
  780. self.max_target_positions = config.max_position_embeddings
  781. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  782. self.embed_tokens = MBartScaledWordEmbedding(
  783. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  784. )
  785. if embed_tokens is not None:
  786. self.embed_tokens.weight = embed_tokens.weight
  787. self.embed_positions = MBartLearnedPositionalEmbedding(
  788. config.max_position_embeddings,
  789. config.d_model,
  790. )
  791. self.layers = nn.ModuleList([MBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  792. self.config = config
  793. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  794. self.layer_norm = nn.LayerNorm(config.d_model)
  795. self.gradient_checkpointing = False
  796. # Initialize weights and apply final processing
  797. self.post_init()
  798. def forward(
  799. self,
  800. input_ids: Optional[torch.LongTensor] = None,
  801. attention_mask: Optional[torch.Tensor] = None,
  802. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  803. encoder_attention_mask: Optional[torch.LongTensor] = None,
  804. head_mask: Optional[torch.Tensor] = None,
  805. cross_attn_head_mask: Optional[torch.Tensor] = None,
  806. past_key_values: Optional[Cache] = None,
  807. inputs_embeds: Optional[torch.FloatTensor] = None,
  808. use_cache: Optional[bool] = None,
  809. output_attentions: Optional[bool] = None,
  810. output_hidden_states: Optional[bool] = None,
  811. return_dict: Optional[bool] = None,
  812. cache_position: Optional[torch.Tensor] = None,
  813. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  814. r"""
  815. Args:
  816. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  817. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  818. provide it.
  819. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  820. [`PreTrainedTokenizer.__call__`] for details.
  821. [What are input IDs?](../glossary#input-ids)
  822. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  823. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  824. - 1 for tokens that are **not masked**,
  825. - 0 for tokens that are **masked**.
  826. [What are attention masks?](../glossary#attention-mask)
  827. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  828. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  829. of the decoder.
  830. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  831. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  832. selected in `[0, 1]`:
  833. - 1 for tokens that are **not masked**,
  834. - 0 for tokens that are **masked**.
  835. [What are attention masks?](../glossary#attention-mask)
  836. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  837. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  838. - 1 indicates the head is **not masked**,
  839. - 0 indicates the head is **masked**.
  840. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  841. Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
  842. cross-attention on hidden heads. Mask values selected in `[0, 1]`:
  843. - 1 indicates the head is **not masked**,
  844. - 0 indicates the head is **masked**.
  845. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  846. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  847. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  848. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  849. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  850. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  851. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  852. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  853. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  854. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  855. than the model's internal embedding lookup matrix.
  856. output_attentions (`bool`, *optional*):
  857. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  858. returned tensors for more detail.
  859. output_hidden_states (`bool`, *optional*):
  860. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  861. for more detail.
  862. return_dict (`bool`, *optional*):
  863. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  864. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  865. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  866. cache in the correct position and to infer the complete sequence length.
  867. """
  868. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  869. output_hidden_states = (
  870. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  871. )
  872. use_cache = use_cache if use_cache is not None else self.config.use_cache
  873. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  874. # retrieve input_ids and inputs_embeds
  875. if (input_ids is None) ^ (inputs_embeds is not None):
  876. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  877. elif input_ids is not None:
  878. input = input_ids
  879. input_shape = input.shape
  880. input_ids = input_ids.view(-1, input_shape[-1])
  881. elif inputs_embeds is not None:
  882. input_shape = inputs_embeds.size()[:-1]
  883. input = inputs_embeds[:, :, -1]
  884. else:
  885. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  886. if inputs_embeds is None:
  887. inputs_embeds = self.embed_tokens(input)
  888. if self.gradient_checkpointing and self.training:
  889. if use_cache:
  890. logger.warning_once(
  891. "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
  892. )
  893. use_cache = False
  894. # initialize `past_key_values`
  895. if use_cache and past_key_values is None:
  896. past_key_values = (
  897. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  898. if encoder_hidden_states is not None
  899. else DynamicCache(config=self.config)
  900. )
  901. if use_cache and isinstance(past_key_values, tuple):
  902. logger.warning_once(
  903. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  904. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  905. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  906. )
  907. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  908. batch_size, seq_length = inputs_embeds.size()[:-1]
  909. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  910. if cache_position is None:
  911. cache_position = torch.arange(
  912. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  913. )
  914. if attention_mask is None and not is_torchdynamo_compiling():
  915. # required mask seq length can be calculated via length of past cache
  916. mask_seq_length = past_key_values_length + seq_length
  917. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  918. self_attn_cache = (
  919. past_key_values.self_attention_cache
  920. if isinstance(past_key_values, EncoderDecoderCache)
  921. else past_key_values
  922. )
  923. causal_mask = self._update_causal_mask(
  924. attention_mask,
  925. inputs_embeds,
  926. cache_position,
  927. self_attn_cache,
  928. )
  929. encoder_attention_mask = self._update_cross_attn_mask(
  930. encoder_hidden_states,
  931. encoder_attention_mask,
  932. input_shape,
  933. inputs_embeds,
  934. )
  935. # embed positions
  936. position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
  937. hidden_states = inputs_embeds + position_ids.to(inputs_embeds.device)
  938. hidden_states = self.layernorm_embedding(hidden_states)
  939. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  940. # decoder layers
  941. all_hidden_states = () if output_hidden_states else None
  942. all_self_attns = () if output_attentions else None
  943. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  944. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  945. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  946. if attn_mask is not None:
  947. if attn_mask.size()[0] != len(self.layers):
  948. raise ValueError(
  949. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  950. f" {attn_mask.size()[0]}."
  951. )
  952. for idx, decoder_layer in enumerate(self.layers):
  953. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  954. if output_hidden_states:
  955. all_hidden_states += (hidden_states,)
  956. if self.training:
  957. dropout_probability = torch.rand([])
  958. if dropout_probability < self.layerdrop:
  959. continue
  960. layer_outputs = decoder_layer(
  961. hidden_states,
  962. causal_mask,
  963. encoder_hidden_states, # as a positional argument for gradient checkpointing
  964. encoder_attention_mask=encoder_attention_mask,
  965. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  966. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  967. past_key_values=past_key_values,
  968. output_attentions=output_attentions,
  969. use_cache=use_cache,
  970. cache_position=cache_position,
  971. )
  972. hidden_states = layer_outputs[0]
  973. if output_attentions:
  974. all_self_attns += (layer_outputs[1],)
  975. if encoder_hidden_states is not None:
  976. all_cross_attentions += (layer_outputs[2],)
  977. hidden_states = self.layer_norm(hidden_states)
  978. # add hidden states from the last decoder layer
  979. if output_hidden_states:
  980. all_hidden_states += (hidden_states,)
  981. if not return_dict:
  982. return tuple(
  983. v
  984. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  985. if v is not None
  986. )
  987. return BaseModelOutputWithPastAndCrossAttentions(
  988. last_hidden_state=hidden_states,
  989. past_key_values=past_key_values,
  990. hidden_states=all_hidden_states,
  991. attentions=all_self_attns,
  992. cross_attentions=all_cross_attentions,
  993. )
  994. @auto_docstring
  995. class MBartModel(MBartPreTrainedModel):
  996. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  997. def __init__(self, config: MBartConfig):
  998. super().__init__(config)
  999. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  1000. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  1001. self.shared = MBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  1002. self.encoder = MBartEncoder(config, self.shared)
  1003. self.decoder = MBartDecoder(config, self.shared)
  1004. # Initialize weights and apply final processing
  1005. self.post_init()
  1006. def get_input_embeddings(self):
  1007. return self.shared
  1008. def set_input_embeddings(self, value):
  1009. self.shared = value
  1010. self.encoder.embed_tokens = self.shared
  1011. self.decoder.embed_tokens = self.shared
  1012. def get_encoder(self):
  1013. return self.encoder
  1014. def _tie_weights(self):
  1015. if self.config.tie_word_embeddings:
  1016. self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
  1017. self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
  1018. @auto_docstring
  1019. def forward(
  1020. self,
  1021. input_ids: Optional[torch.LongTensor] = None,
  1022. attention_mask: Optional[torch.Tensor] = None,
  1023. decoder_input_ids: Optional[torch.LongTensor] = None,
  1024. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1025. head_mask: Optional[torch.Tensor] = None,
  1026. decoder_head_mask: Optional[torch.Tensor] = None,
  1027. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1028. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  1029. past_key_values: Optional[Cache] = None,
  1030. inputs_embeds: Optional[torch.FloatTensor] = None,
  1031. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1032. use_cache: Optional[bool] = None,
  1033. output_attentions: Optional[bool] = None,
  1034. output_hidden_states: Optional[bool] = None,
  1035. return_dict: Optional[bool] = None,
  1036. cache_position: Optional[torch.Tensor] = None,
  1037. ) -> Union[Seq2SeqModelOutput, tuple[torch.FloatTensor]]:
  1038. r"""
  1039. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1040. Indices of decoder input sequence tokens in the vocabulary.
  1041. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1042. [`PreTrainedTokenizer.__call__`] for details.
  1043. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1044. MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  1045. varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If
  1046. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1047. `past_key_values`).
  1048. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1049. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1050. for denoising pre-training following the paper.
  1051. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1052. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1053. be used by default.
  1054. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1055. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1056. 1]`:
  1057. - 1 indicates the head is **not masked**,
  1058. - 0 indicates the head is **masked**.
  1059. """
  1060. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1061. output_hidden_states = (
  1062. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1063. )
  1064. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1065. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1066. # different to other models, MBart automatically creates decoder_input_ids from
  1067. # input_ids if no decoder_input_ids are provided
  1068. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1069. decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
  1070. if encoder_outputs is None:
  1071. encoder_outputs = self.encoder(
  1072. input_ids=input_ids,
  1073. attention_mask=attention_mask,
  1074. head_mask=head_mask,
  1075. inputs_embeds=inputs_embeds,
  1076. output_attentions=output_attentions,
  1077. output_hidden_states=output_hidden_states,
  1078. return_dict=return_dict,
  1079. )
  1080. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  1081. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1082. encoder_outputs = BaseModelOutput(
  1083. last_hidden_state=encoder_outputs[0],
  1084. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1085. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1086. )
  1087. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1088. decoder_outputs = self.decoder(
  1089. input_ids=decoder_input_ids,
  1090. attention_mask=decoder_attention_mask,
  1091. encoder_hidden_states=encoder_outputs[0],
  1092. encoder_attention_mask=attention_mask,
  1093. head_mask=decoder_head_mask,
  1094. cross_attn_head_mask=cross_attn_head_mask,
  1095. past_key_values=past_key_values,
  1096. inputs_embeds=decoder_inputs_embeds,
  1097. use_cache=use_cache,
  1098. output_attentions=output_attentions,
  1099. output_hidden_states=output_hidden_states,
  1100. return_dict=return_dict,
  1101. cache_position=cache_position,
  1102. )
  1103. if not return_dict:
  1104. return decoder_outputs + encoder_outputs
  1105. return Seq2SeqModelOutput(
  1106. last_hidden_state=decoder_outputs.last_hidden_state,
  1107. past_key_values=decoder_outputs.past_key_values,
  1108. decoder_hidden_states=decoder_outputs.hidden_states,
  1109. decoder_attentions=decoder_outputs.attentions,
  1110. cross_attentions=decoder_outputs.cross_attentions,
  1111. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1112. encoder_hidden_states=encoder_outputs.hidden_states,
  1113. encoder_attentions=encoder_outputs.attentions,
  1114. )
  @auto_docstring(
      custom_intro="""
      The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.
      """
  )
  class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin):
      # The seq2seq backbone is stored under `self.model`; checkpoint parameter names carry this prefix.
      base_model_prefix = "model"
      # `final_logits_bias` is a zero-initialised buffer (see __init__), so checkpoints lacking it still load.
      _keys_to_ignore_on_load_missing = ["final_logits_bias"]
      # Parameters the base class treats as tied weights.
      _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]

      def __init__(self, config: MBartConfig):
          super().__init__(config)
          self.model = MBartModel(config)
          # Additive bias on the output logits: shape (1, vocab), broadcast over batch and positions.
          self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
          self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      def get_encoder(self):
          return self.model.get_encoder()

      def get_decoder(self):
          return self.model.get_decoder()

      def resize_token_embeddings(
          self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
      ) -> nn.Embedding:
          """Resize the token embeddings and keep `final_logits_bias` as wide as the resulting vocabulary."""
          new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
          # Size the bias from the actual embedding rows, which can exceed `new_num_tokens`
          # (e.g. when `pad_to_multiple_of` rounds the vocabulary up).
          self._resize_final_logits_bias(new_embeddings.weight.shape[0])
          return new_embeddings

      def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
          """Truncate or zero-pad `final_logits_bias` so that it has exactly `new_num_tokens` columns."""
          old_num_tokens = self.final_logits_bias.shape[-1]
          if new_num_tokens <= old_num_tokens:
              new_bias = self.final_logits_bias[:, :new_num_tokens]
          else:
              # Match the existing buffer's dtype as well as its device; otherwise `torch.cat`
              # promotes a half-precision buffer to float32.
              extra_bias = torch.zeros(
                  (1, new_num_tokens - old_num_tokens),
                  dtype=self.final_logits_bias.dtype,
                  device=self.final_logits_bias.device,
              )
              new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
          self.register_buffer("final_logits_bias", new_bias)

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          decoder_head_mask: Optional[torch.Tensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.Tensor] = None,
      ) -> Union[Seq2SeqLMOutput, tuple[torch.FloatTensor]]:
          r"""
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of decoder input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)

              MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
              varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If
              `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
              `past_key_values`).

              For translation and summarization training, `decoder_input_ids` should be provided. If no
              `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
              for denoising pre-training following the paper.
          decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
              be used by default.
          cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
              1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
              When `labels` is given, `use_cache` is forced to `False`, and if neither `decoder_input_ids` nor
              `decoder_inputs_embeds` is provided the decoder inputs are built by shifting `labels` to the right.

          Example Translation:

          ```python
          >>> from transformers import AutoTokenizer, MBartForConditionalGeneration

          >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
          >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")

          >>> example_english_phrase = "42 is the answer"
          >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")

          >>> # Translate
          >>> generated_ids = model.generate(**inputs, num_beams=4, max_length=5)
          >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          '42 este răspuns'
          ```

          Mask filling example:

          ```python
          >>> from transformers import AutoTokenizer, MBartForConditionalGeneration

          >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
          >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")

          >>> # de_DE is the language symbol id <LID> for German
          >>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"

          >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt")["input_ids"]
          >>> logits = model(input_ids).logits

          >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
          >>> probs = logits[0, masked_index].softmax(dim=0)
          >>> values, predictions = probs.topk(5)

          >>> tokenizer.decode(predictions).split()
          ['nett', 'sehr', 'ganz', 'nicht', 'so']
          ```
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          if labels is not None:
              if use_cache:
                  logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
              use_cache = False
              if decoder_input_ids is None and decoder_inputs_embeds is None:
                  decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)

          outputs = self.model(
              input_ids,
              attention_mask=attention_mask,
              decoder_input_ids=decoder_input_ids,
              encoder_outputs=encoder_outputs,
              decoder_attention_mask=decoder_attention_mask,
              head_mask=head_mask,
              decoder_head_mask=decoder_head_mask,
              cross_attn_head_mask=cross_attn_head_mask,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              decoder_inputs_embeds=decoder_inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
          )

          lm_logits = self.lm_head(outputs[0])
          # When the model is sharded across devices the bias buffer may not sit on the same
          # device as the head's output, so move it before adding.
          lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)

          masked_lm_loss = None
          if labels is not None:
              # Same reasoning as above: align labels with the logits before computing the loss.
              labels = labels.to(lm_logits.device)
              loss_fct = CrossEntropyLoss()
              masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

          if not return_dict:
              output = (lm_logits,) + outputs[1:]
              return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

          return Seq2SeqLMOutput(
              loss=masked_lm_loss,
              logits=lm_logits,
              past_key_values=outputs.past_key_values,
              decoder_hidden_states=outputs.decoder_hidden_states,
              decoder_attentions=outputs.decoder_attentions,
              cross_attentions=outputs.cross_attentions,
              encoder_last_hidden_state=outputs.encoder_last_hidden_state,
              encoder_hidden_states=outputs.encoder_hidden_states,
              encoder_attentions=outputs.encoder_attentions,
          )

      def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
          """Build decoder inputs from `labels` with the same right-shift used in `forward`."""
          return shift_tokens_right(labels, self.config.pad_token_id)
  @auto_docstring(
      custom_intro="""
      MBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
      tasks.
      """
  )
  class MBartForSequenceClassification(MBartPreTrainedModel):
      # Parameters the base class treats as tied weights.
      _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]

      def __init__(self, config: MBartConfig, **kwargs):
          super().__init__(config, **kwargs)
          self.model = MBartModel(config)
          self.classification_head = MBartClassificationHead(
              config.d_model,
              config.d_model,
              config.num_labels,
              config.classifier_dropout,
          )

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      # Derived from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          decoder_head_mask: Optional[torch.Tensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          encoder_outputs: Optional[list[torch.FloatTensor]] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
      ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
          r"""
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of decoder input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)

              Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
              is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

              For translation and summarization training, `decoder_input_ids` should be provided. If no
              `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
              for denoising pre-training following the paper.
          decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
              be used by default.

              If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
              and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
              information on the default strategy.
          cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
              1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss).
              If `config.num_labels > 1` a classification loss is computed (Cross-Entropy), or a multi-label
              (BCE-with-logits) loss when the labels are not integer-typed and `config.problem_type` is unset.
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict
          if labels is not None:
              use_cache = False

          # The sentence representation is taken at the <eos> positions of `input_ids`, so the ids
          # are required. Fail before running the (expensive) encoder-decoder pass.
          if input_ids is None:
              if inputs_embeds is not None:
                  raise NotImplementedError(
                      f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
                  )
              raise ValueError(
                  f"{self.__class__.__name__} requires `input_ids` to locate the <eos> tokens used for pooling."
              )

          outputs = self.model(
              input_ids,
              attention_mask=attention_mask,
              decoder_input_ids=decoder_input_ids,
              decoder_attention_mask=decoder_attention_mask,
              head_mask=head_mask,
              decoder_head_mask=decoder_head_mask,
              cross_attn_head_mask=cross_attn_head_mask,
              encoder_outputs=encoder_outputs,
              inputs_embeds=inputs_embeds,
              decoder_inputs_embeds=decoder_inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
          )
          hidden_states = outputs[0]  # last hidden state

          eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)

          # Every row must contain the same number of <eos> tokens so the masked selection can be
          # reshaped back to (batch, num_eos, hidden).
          if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
              raise ValueError("All examples must have the same number of <eos> tokens.")
          # Pool with the hidden state at the last <eos> of each example.
          sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
              :, -1, :
          ]
          logits = self.classification_head(sentence_representation)

          loss = None
          if labels is not None:
              labels = labels.to(logits.device)
              # Infer (and persist on the config) the loss type from num_labels and the label dtype.
              if self.config.problem_type is None:
                  if self.config.num_labels == 1:
                      self.config.problem_type = "regression"
                  elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                      self.config.problem_type = "single_label_classification"
                  else:
                      self.config.problem_type = "multi_label_classification"

              if self.config.problem_type == "regression":
                  loss_fct = MSELoss()
                  if self.config.num_labels == 1:
                      loss = loss_fct(logits.squeeze(), labels.squeeze())
                  else:
                      loss = loss_fct(logits, labels)
              elif self.config.problem_type == "single_label_classification":
                  loss_fct = CrossEntropyLoss()
                  loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
              elif self.config.problem_type == "multi_label_classification":
                  loss_fct = BCEWithLogitsLoss()
                  loss = loss_fct(logits, labels)

          if not return_dict:
              output = (logits,) + outputs[1:]
              return ((loss,) + output) if loss is not None else output

          return Seq2SeqSequenceClassifierOutput(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              decoder_hidden_states=outputs.decoder_hidden_states,
              decoder_attentions=outputs.decoder_attentions,
              cross_attentions=outputs.cross_attentions,
              encoder_last_hidden_state=outputs.encoder_last_hidden_state,
              encoder_hidden_states=outputs.encoder_hidden_states,
              encoder_attentions=outputs.encoder_attentions,
          )
  @auto_docstring
  class MBartForQuestionAnswering(MBartPreTrainedModel):
      # Parameters the base class treats as tied weights.
      _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]

      def __init__(self, config):
          super().__init__(config)
          # Two scores per token: span start and span end. Note this overwrites the caller's config.
          config.num_labels = 2
          self.num_labels = config.num_labels

          self.model = MBartModel(config)
          self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      # Derived from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          decoder_head_mask: Optional[torch.Tensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          encoder_outputs: Optional[list[torch.FloatTensor]] = None,
          start_positions: Optional[torch.LongTensor] = None,
          end_positions: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
      ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
          r"""
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of decoder input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)

              Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
              is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

              For translation and summarization training, `decoder_input_ids` should be provided. If no
              `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
              for denoising pre-training following the paper.
          decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
              be used by default.

              If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
              and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
              information on the default strategy.
          cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
              1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          # A span loss is only computed when both boundaries are supplied; no cache is kept then.
          has_span_targets = start_positions is not None and end_positions is not None
          if has_span_targets:
              use_cache = False

          outputs = self.model(
              input_ids,
              attention_mask=attention_mask,
              decoder_input_ids=decoder_input_ids,
              decoder_attention_mask=decoder_attention_mask,
              head_mask=head_mask,
              decoder_head_mask=decoder_head_mask,
              cross_attn_head_mask=cross_attn_head_mask,
              encoder_outputs=encoder_outputs,
              inputs_embeds=inputs_embeds,
              decoder_inputs_embeds=decoder_inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
          )

          # Project each position of the last hidden state to (start, end) scores, then split
          # the trailing size-2 axis into two separate per-position score tensors.
          span_scores = self.qa_outputs(outputs[0])
          start_logits, end_logits = (
              scores.squeeze(-1).contiguous() for scores in span_scores.split(1, dim=-1)
          )

          total_loss = None
          if has_span_targets:
              # Drop a trailing singleton axis if the targets arrive with one.
              if start_positions.dim() > 1:
                  start_positions = start_positions.squeeze(-1)
              if end_positions.dim() > 1:
                  end_positions = end_positions.squeeze(-1)

              # Targets past the end are clamped to `seq_len`, which is the loss's ignore_index,
              # so they contribute nothing; negative targets are clamped to 0.
              seq_len = start_logits.size(1)
              criterion = CrossEntropyLoss(ignore_index=seq_len)
              start_loss = criterion(start_logits, start_positions.clamp(0, seq_len))
              end_loss = criterion(end_logits, end_positions.clamp(0, seq_len))
              total_loss = (start_loss + end_loss) / 2

          if return_dict:
              return Seq2SeqQuestionAnsweringModelOutput(
                  loss=total_loss,
                  start_logits=start_logits,
                  end_logits=end_logits,
                  past_key_values=outputs.past_key_values,
                  decoder_hidden_states=outputs.decoder_hidden_states,
                  decoder_attentions=outputs.decoder_attentions,
                  cross_attentions=outputs.cross_attentions,
                  encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                  encoder_hidden_states=outputs.encoder_hidden_states,
                  encoder_attentions=outputs.encoder_attentions,
              )

          flat_output = (start_logits, end_logits, *outputs[1:])
          return flat_output if total_loss is None else (total_loss, *flat_output)
  # Mirrors transformers.models.bart.modeling_bart.BartDecoderWrapper (Bart->MBart)
  class MBartDecoderWrapper(MBartPreTrainedModel):
      """
      Thin holder that places an [`MBartDecoder`] under a `decoder` attribute so that pretrained checkpoints load
      correctly when the causal language model is used together with the [`EncoderDecoderModel`] framework.
      """

      def __init__(self, config):
          super().__init__(config)
          self.decoder = MBartDecoder(config)

      def forward(self, *args, **kwargs):
          # Pure delegation: every positional and keyword argument reaches the wrapped decoder unchanged.
          wrapped_decoder = self.decoder
          return wrapped_decoder(*args, **kwargs)
  # Mirrors transformers.models.bart.modeling_bart.BartForCausalLM (Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25)
  class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin):
      _tied_weights_keys = ["lm_head.weight"]

      def __init__(self, config):
          # This head drives the decoder stack on its own, so the (caller-owned) config is switched
          # to decoder-only mode before the base initializer sees it.
          config.is_decoder = True
          config.is_encoder_decoder = False
          super().__init__(config)
          self.model = MBartDecoderWrapper(config)

          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.model.decoder.embed_tokens

      def set_input_embeddings(self, value):
          self.model.decoder.embed_tokens = value

      def set_decoder(self, decoder):
          self.model.decoder = decoder

      def get_decoder(self):
          return self.model.decoder

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
      ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
          r"""
          cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, MBartForCausalLM

          >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
          >>> model = MBartForCausalLM.from_pretrained("facebook/mbart-large-cc25", add_cross_attention=False)
          >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
          >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
          >>> outputs = model(**inputs)

          >>> logits = outputs.logits
          >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
          >>> list(logits.shape) == expected_shape
          True
          ```"""
          cfg = self.config
          # Unset flags fall back to the model configuration.
          if output_attentions is None:
              output_attentions = cfg.output_attentions
          if output_hidden_states is None:
              output_hidden_states = cfg.output_hidden_states
          if return_dict is None:
              return_dict = cfg.use_return_dict

          # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
          outputs = self.model.decoder(
              input_ids=input_ids,
              attention_mask=attention_mask,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              head_mask=head_mask,
              cross_attn_head_mask=cross_attn_head_mask,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
          )

          logits = self.lm_head(outputs[0])

          loss = None
          if labels is not None:
              # Labels are compared position-by-position with the logits (no shift is applied here).
              flat_labels = labels.to(logits.device).view(-1)
              loss = CrossEntropyLoss()(logits.view(-1, cfg.vocab_size), flat_labels)

          if return_dict:
              return CausalLMOutputWithCrossAttentions(
                  loss=loss,
                  logits=logits,
                  past_key_values=outputs.past_key_values,
                  hidden_states=outputs.hidden_states,
                  attentions=outputs.attentions,
                  cross_attentions=outputs.cross_attentions,
              )

          flat_output = (logits, *outputs[1:])
          return flat_output if loss is None else (loss, *flat_output)
  # Public names exported by `from ... import *` for this module.
  __all__ = [
      "MBartForCausalLM",
      "MBartForConditionalGeneration",
      "MBartForQuestionAnswering",
      "MBartForSequenceClassification",
      "MBartModel",
      "MBartPreTrainedModel",
  ]