modeling_marian.py 77 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700
  1. # coding=utf-8
  2. # Copyright 2021 The Marian Team Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch MarianMTModel model, ported from the Marian C++ repo."""
  16. import copy
  17. import math
  18. from typing import Callable, Optional, Union
  19. import numpy as np
  20. import torch
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  25. from ...generation import GenerationMixin
  26. from ...modeling_attn_mask_utils import (
  27. AttentionMaskConverter,
  28. _prepare_4d_attention_mask,
  29. _prepare_4d_attention_mask_for_sdpa,
  30. )
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import (
  34. BaseModelOutput,
  35. BaseModelOutputWithPastAndCrossAttentions,
  36. CausalLMOutputWithCrossAttentions,
  37. Seq2SeqLMOutput,
  38. Seq2SeqModelOutput,
  39. )
  40. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  41. from ...processing_utils import Unpack
  42. from ...utils import (
  43. auto_docstring,
  44. is_torch_flex_attn_available,
  45. is_torchdynamo_compiling,
  46. logging,
  47. )
  48. from ...utils.deprecation import deprecate_kwarg
  49. from .configuration_marian import MarianConfig
  50. if is_torch_flex_attn_available():
  51. from torch.nn.attention.flex_attention import BlockMask
  52. from ...integrations.flex_attention import make_flex_block_causal_mask
# Module-level logger; MarianAttention uses it for its one-time warning about a decoder built without `layer_idx`.
logger = logging.get_logger(__name__)
  54. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  55. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  56. """
  57. Shift input ids one token to the right.
  58. """
  59. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  60. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  61. shifted_input_ids[:, 0] = decoder_start_token_id
  62. if pad_token_id is None:
  63. raise ValueError("self.model.config.pad_token_id has to be defined.")
  64. # replace possible -100 values in labels by `pad_token_id`
  65. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  66. return shifted_input_ids
  67. class MarianSinusoidalPositionalEmbedding(nn.Embedding):
  68. """This module produces sinusoidal positional embeddings of any length."""
  69. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
  70. super().__init__(num_positions, embedding_dim)
  71. def _init_weight(self):
  72. """
  73. Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
  74. the 2nd half of the vector. [dim // 2:]
  75. """
  76. n_pos, dim = self.weight.shape
  77. position_enc = np.array(
  78. [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
  79. )
  80. out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False)
  81. sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
  82. out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
  83. out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
  84. self.weight = nn.Parameter(out, requires_grad=False)
  85. @torch.no_grad()
  86. def forward(
  87. self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
  88. ) -> torch.Tensor:
  89. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  90. if position_ids is None:
  91. bsz, seq_len = input_ids_shape[:2]
  92. position_ids = torch.arange(
  93. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  94. )
  95. return super().forward(position_ids)
  96. # Copied from transformers.models.bart.modeling_bart.eager_attention_forward
  97. def eager_attention_forward(
  98. module: nn.Module,
  99. query: torch.Tensor,
  100. key: torch.Tensor,
  101. value: torch.Tensor,
  102. attention_mask: Optional[torch.Tensor],
  103. scaling: Optional[float] = None,
  104. dropout: float = 0.0,
  105. head_mask: Optional[torch.Tensor] = None,
  106. **kwargs,
  107. ):
  108. if scaling is None:
  109. scaling = query.size(-1) ** -0.5
  110. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  111. if attention_mask is not None:
  112. attn_weights = attn_weights + attention_mask
  113. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  114. if head_mask is not None:
  115. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  116. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  117. attn_output = torch.matmul(attn_weights, value)
  118. attn_output = attn_output.transpose(1, 2).contiguous()
  119. return attn_output, attn_weights
  120. # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Marian
  121. class MarianAttention(nn.Module):
  122. """Multi-headed attention from 'Attention Is All You Need' paper"""
  123. def __init__(
  124. self,
  125. embed_dim: int,
  126. num_heads: int,
  127. dropout: float = 0.0,
  128. is_decoder: bool = False,
  129. bias: bool = True,
  130. is_causal: bool = False,
  131. config: Optional[MarianConfig] = None,
  132. layer_idx: Optional[int] = None,
  133. ):
  134. super().__init__()
  135. self.embed_dim = embed_dim
  136. self.num_heads = num_heads
  137. self.dropout = dropout
  138. self.head_dim = embed_dim // num_heads
  139. self.config = config
  140. if (self.head_dim * num_heads) != self.embed_dim:
  141. raise ValueError(
  142. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  143. f" and `num_heads`: {num_heads})."
  144. )
  145. self.scaling = self.head_dim**-0.5
  146. self.is_decoder = is_decoder
  147. self.is_causal = is_causal
  148. self.layer_idx = layer_idx
  149. if layer_idx is None and self.is_decoder:
  150. logger.warning_once(
  151. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  152. "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  153. "when creating this class."
  154. )
  155. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  156. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  157. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  158. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  159. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  160. def forward(
  161. self,
  162. hidden_states: torch.Tensor,
  163. key_value_states: Optional[torch.Tensor] = None,
  164. past_key_values: Optional[Cache] = None,
  165. attention_mask: Optional[torch.Tensor] = None,
  166. layer_head_mask: Optional[torch.Tensor] = None,
  167. output_attentions: bool = False,
  168. cache_position: Optional[torch.Tensor] = None,
  169. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  170. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  171. **kwargs: Unpack[FlashAttentionKwargs],
  172. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  173. """Input shape: Batch x Time x Channel"""
  174. # if key_value_states are provided this layer is used as a cross-attention layer
  175. # for the decoder
  176. is_cross_attention = key_value_states is not None
  177. # determine input shapes
  178. bsz, tgt_len = hidden_states.shape[:-1]
  179. src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
  180. q_input_shape = (bsz, tgt_len, -1, self.head_dim)
  181. kv_input_shape = (bsz, src_len, -1, self.head_dim)
  182. # get query proj
  183. query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
  184. is_updated = False
  185. if past_key_values is not None:
  186. if isinstance(past_key_values, EncoderDecoderCache):
  187. is_updated = past_key_values.is_updated.get(self.layer_idx)
  188. if is_cross_attention:
  189. # after the first generated id, we can subsequently re-use all key/value_states from cache
  190. curr_past_key_value = past_key_values.cross_attention_cache
  191. else:
  192. curr_past_key_value = past_key_values.self_attention_cache
  193. else:
  194. curr_past_key_value = past_key_values
  195. current_states = key_value_states if is_cross_attention else hidden_states
  196. if is_cross_attention and past_key_values is not None and is_updated:
  197. # reuse k,v, cross_attentions
  198. key_states = curr_past_key_value.layers[self.layer_idx].keys
  199. value_states = curr_past_key_value.layers[self.layer_idx].values
  200. else:
  201. key_states = self.k_proj(current_states)
  202. value_states = self.v_proj(current_states)
  203. key_states = key_states.view(*kv_input_shape).transpose(1, 2)
  204. value_states = value_states.view(*kv_input_shape).transpose(1, 2)
  205. if past_key_values is not None:
  206. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  207. cache_position = cache_position if not is_cross_attention else None
  208. key_states, value_states = curr_past_key_value.update(
  209. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  210. )
  211. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  212. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  213. past_key_values.is_updated[self.layer_idx] = True
  214. attention_interface: Callable = eager_attention_forward
  215. if self.config._attn_implementation != "eager":
  216. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  217. attn_output, attn_weights = attention_interface(
  218. self,
  219. query_states,
  220. key_states,
  221. value_states,
  222. attention_mask,
  223. dropout=0.0 if not self.training else self.dropout,
  224. scaling=self.scaling,
  225. output_attentions=output_attentions,
  226. head_mask=layer_head_mask,
  227. **kwargs,
  228. )
  229. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  230. attn_output = self.out_proj(attn_output)
  231. return attn_output, attn_weights
  232. # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN
  233. class MarianEncoderLayer(GradientCheckpointingLayer):
  234. def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None):
  235. super().__init__()
  236. self.embed_dim = config.d_model
  237. self.self_attn = MarianAttention(
  238. embed_dim=self.embed_dim,
  239. num_heads=config.encoder_attention_heads,
  240. dropout=config.attention_dropout,
  241. config=config,
  242. layer_idx=layer_idx,
  243. )
  244. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  245. self.dropout = config.dropout
  246. self.activation_fn = ACT2FN[config.activation_function]
  247. self.activation_dropout = config.activation_dropout
  248. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  249. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  250. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  251. def forward(
  252. self,
  253. hidden_states: torch.FloatTensor,
  254. attention_mask: torch.FloatTensor,
  255. layer_head_mask: torch.FloatTensor,
  256. output_attentions: Optional[bool] = False,
  257. ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
  258. """
  259. Args:
  260. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  261. attention_mask (`torch.FloatTensor`): attention mask of size
  262. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  263. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  264. `(encoder_attention_heads,)`.
  265. output_attentions (`bool`, *optional*):
  266. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  267. returned tensors for more detail.
  268. """
  269. residual = hidden_states
  270. hidden_states, attn_weights = self.self_attn(
  271. hidden_states=hidden_states,
  272. attention_mask=attention_mask,
  273. layer_head_mask=layer_head_mask,
  274. output_attentions=output_attentions,
  275. )
  276. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  277. hidden_states = residual + hidden_states
  278. hidden_states = self.self_attn_layer_norm(hidden_states)
  279. residual = hidden_states
  280. hidden_states = self.activation_fn(self.fc1(hidden_states))
  281. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  282. hidden_states = self.fc2(hidden_states)
  283. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  284. hidden_states = residual + hidden_states
  285. hidden_states = self.final_layer_norm(hidden_states)
  286. if hidden_states.dtype == torch.float16 and (
  287. torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
  288. ):
  289. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  290. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  291. outputs = (hidden_states,)
  292. if output_attentions:
  293. outputs += (attn_weights,)
  294. return outputs
  295. # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
  296. class MarianDecoderLayer(GradientCheckpointingLayer):
  297. def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None):
  298. super().__init__()
  299. self.embed_dim = config.d_model
  300. self.self_attn = MarianAttention(
  301. embed_dim=self.embed_dim,
  302. num_heads=config.decoder_attention_heads,
  303. dropout=config.attention_dropout,
  304. is_decoder=True,
  305. is_causal=True,
  306. config=config,
  307. layer_idx=layer_idx,
  308. )
  309. self.dropout = config.dropout
  310. self.activation_fn = ACT2FN[config.activation_function]
  311. self.activation_dropout = config.activation_dropout
  312. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  313. self.encoder_attn = MarianAttention(
  314. self.embed_dim,
  315. config.decoder_attention_heads,
  316. dropout=config.attention_dropout,
  317. is_decoder=True,
  318. config=config,
  319. layer_idx=layer_idx,
  320. )
  321. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  322. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  323. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  324. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  325. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  326. def forward(
  327. self,
  328. hidden_states: torch.Tensor,
  329. attention_mask: Optional[torch.Tensor] = None,
  330. encoder_hidden_states: Optional[torch.Tensor] = None,
  331. encoder_attention_mask: Optional[torch.Tensor] = None,
  332. layer_head_mask: Optional[torch.Tensor] = None,
  333. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  334. past_key_values: Optional[Cache] = None,
  335. output_attentions: Optional[bool] = False,
  336. use_cache: Optional[bool] = True,
  337. cache_position: Optional[torch.Tensor] = None,
  338. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  339. """
  340. Args:
  341. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  342. attention_mask (`torch.FloatTensor`): attention mask of size
  343. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  344. encoder_hidden_states (`torch.FloatTensor`):
  345. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  346. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  347. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  348. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  349. `(encoder_attention_heads,)`.
  350. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  351. size `(decoder_attention_heads,)`.
  352. past_key_values (`Cache`): cached past key and value projection states
  353. output_attentions (`bool`, *optional*):
  354. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  355. returned tensors for more detail.
  356. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  357. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  358. cache in the correct position and to infer the complete sequence length.
  359. """
  360. residual = hidden_states
  361. # Self Attention
  362. hidden_states, self_attn_weights = self.self_attn(
  363. hidden_states=hidden_states,
  364. past_key_values=past_key_values,
  365. attention_mask=attention_mask,
  366. layer_head_mask=layer_head_mask,
  367. output_attentions=output_attentions,
  368. cache_position=cache_position,
  369. )
  370. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  371. hidden_states = residual + hidden_states
  372. hidden_states = self.self_attn_layer_norm(hidden_states)
  373. # Cross-Attention Block
  374. cross_attn_weights = None
  375. if encoder_hidden_states is not None:
  376. residual = hidden_states
  377. hidden_states, cross_attn_weights = self.encoder_attn(
  378. hidden_states=hidden_states,
  379. key_value_states=encoder_hidden_states,
  380. attention_mask=encoder_attention_mask,
  381. layer_head_mask=cross_attn_layer_head_mask,
  382. past_key_values=past_key_values,
  383. output_attentions=output_attentions,
  384. cache_position=cache_position,
  385. )
  386. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  387. hidden_states = residual + hidden_states
  388. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  389. # Fully Connected
  390. residual = hidden_states
  391. hidden_states = self.activation_fn(self.fc1(hidden_states))
  392. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  393. hidden_states = self.fc2(hidden_states)
  394. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  395. hidden_states = residual + hidden_states
  396. hidden_states = self.final_layer_norm(hidden_states)
  397. outputs = (hidden_states,)
  398. if output_attentions:
  399. outputs += (self_attn_weights, cross_attn_weights)
  400. return outputs
  401. @auto_docstring
  402. class MarianPreTrainedModel(PreTrainedModel):
  403. config: MarianConfig
  404. base_model_prefix = "model"
  405. supports_gradient_checkpointing = True
  406. _supports_flash_attn = True
  407. _supports_sdpa = True
  408. _supports_flex_attn = True
  409. _can_compile_fullgraph = True
  410. def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]):
  411. std = self.config.init_std
  412. if isinstance(module, nn.Linear):
  413. module.weight.data.normal_(mean=0.0, std=std)
  414. if module.bias is not None:
  415. module.bias.data.zero_()
  416. elif isinstance(module, MarianSinusoidalPositionalEmbedding):
  417. module._init_weight()
  418. elif isinstance(module, nn.Embedding):
  419. module.weight.data.normal_(mean=0.0, std=std)
  420. if module.padding_idx is not None:
  421. module.weight.data[module.padding_idx].zero_()
  422. elif isinstance(module, nn.LayerNorm):
  423. module.weight.data.fill_(1.0)
  424. module.bias.data.zero_()
  425. @property
  426. def dummy_inputs(self):
  427. pad_token = self.config.pad_token_id
  428. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  429. dummy_inputs = {
  430. "attention_mask": input_ids.ne(pad_token),
  431. "input_ids": input_ids,
  432. "decoder_input_ids": input_ids,
  433. }
  434. return dummy_inputs
  435. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
  436. def _update_full_mask(
  437. self,
  438. attention_mask: Union[torch.Tensor, None],
  439. inputs_embeds: torch.Tensor,
  440. ):
  441. if attention_mask is not None:
  442. if self.config._attn_implementation == "flash_attention_2":
  443. attention_mask = attention_mask if 0 in attention_mask else None
  444. elif self.config._attn_implementation == "sdpa":
  445. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  446. # the manual implementation that requires a 4D causal mask in all cases.
  447. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  448. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  449. elif self.config._attn_implementation == "flex_attention":
  450. if isinstance(attention_mask, torch.Tensor):
  451. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  452. else:
  453. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  454. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  455. return attention_mask
  456. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
  457. def _update_causal_mask(
  458. self,
  459. attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
  460. input_tensor: torch.Tensor,
  461. cache_position: torch.Tensor,
  462. past_key_values: Cache,
  463. ):
  464. if self.config._attn_implementation == "flex_attention":
  465. if isinstance(attention_mask, torch.Tensor):
  466. attention_mask = make_flex_block_causal_mask(attention_mask)
  467. # Other attention flavors support in-built causal (when `mask is None`)
  468. # while we need to create our specific block mask regardless
  469. elif attention_mask is None:
  470. attention_mask = make_flex_block_causal_mask(
  471. torch.ones(
  472. size=(input_tensor.shape[0], input_tensor.shape[1]),
  473. device=attention_mask.device,
  474. )
  475. )
  476. return attention_mask
  477. if self.config._attn_implementation == "flash_attention_2":
  478. if attention_mask is not None and (attention_mask == 0.0).any():
  479. return attention_mask
  480. return None
  481. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  482. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  483. # to infer the attention mask.
  484. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  485. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  486. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  487. if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
  488. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  489. attention_mask,
  490. inputs_embeds=input_tensor,
  491. past_key_values_length=past_seen_tokens,
  492. is_training=self.training,
  493. ):
  494. return None
  495. dtype = input_tensor.dtype
  496. sequence_length = input_tensor.shape[1]
  497. if using_compilable_cache:
  498. target_length = past_key_values.get_max_cache_shape()
  499. else:
  500. target_length = (
  501. attention_mask.shape[-1]
  502. if isinstance(attention_mask, torch.Tensor)
  503. else past_seen_tokens + sequence_length + 1
  504. )
  505. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  506. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  507. attention_mask,
  508. sequence_length=sequence_length,
  509. target_length=target_length,
  510. dtype=dtype,
  511. cache_position=cache_position,
  512. batch_size=input_tensor.shape[0],
  513. )
  514. if (
  515. self.config._attn_implementation == "sdpa"
  516. and attention_mask is not None
  517. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  518. ):
  519. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  520. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  521. # Details: https://github.com/pytorch/pytorch/issues/110213
  522. min_dtype = torch.finfo(dtype).min
  523. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  524. return causal_mask
    @staticmethod
    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        When the mask is built here (i.e. `attention_mask` is `None` or not 4D), it is additive: allowed positions
        hold `0` and blocked positions hold `torch.finfo(dtype).min`. A 4D `attention_mask` is returned untouched and
        is assumed to already be in that inverted form.

        Args:
            attention_mask (`torch.Tensor`, *optional*):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
                structure is applied.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
            kwargs:
                Accepted for signature compatibility; not used.

        Returns:
            `torch.Tensor`: the given 4D mask, or a newly built mask of shape
            `(batch_size, 1, sequence_length, target_length)`.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            # Start with every (query, key) pair blocked.
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                # Keep the block only strictly above the diagonal (future keys); zero the rest.
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Multiplying by a bool tensor keeps `min_dtype` only where the key index is beyond the query's absolute
            # cache position, i.e. every already-seen position becomes 0 (attendable).
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            # Broadcast view over the batch; not writable in place until cloned below.
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # The sum is exactly 0 where the causal mask allows the key AND the 2D mask is 0 there
                # (assumes a 0/1 padding mask — NOTE(review): confirm callers never pass other values).
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                # Block those padded keys; columns beyond `mask_length` keep their causal values.
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
    # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
    def _update_cross_attn_mask(
        self,
        encoder_hidden_states: Union[torch.Tensor, None],
        encoder_attention_mask: Union[torch.Tensor, None],
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
    ):
        """
        Convert the encoder-side padding mask into the form expected by the configured attention implementation
        (`self.config._attn_implementation`) for decoder cross-attention.

        Args:
            encoder_hidden_states (`torch.Tensor`, *optional*):
                Encoder output. Only compared against `None`: when it is `None`, the mask is returned unchanged.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Encoder-side mask, `[bsz, src_seq_len]` for the expanding branches. Returned unchanged when it or
                `encoder_hidden_states` is `None`.
            input_shape (`torch.Size`):
                Shape of the decoder input; its last entry is used as the target (query) length.
            inputs_embeds (`torch.Tensor`):
                Decoder input embeddings; only their dtype is read (SDPA and fallback branches).

        Returns:
            The converted cross-attention mask:
            - `flash_attention_2`: the original mask if it contains any `0`, otherwise `None`.
            - `sdpa`: the output of `_prepare_4d_attention_mask_for_sdpa`.
            - `flex_attention`: the output of `make_flex_block_causal_mask(..., is_causal=False)` when the mask is a
              tensor; otherwise the mask is returned as given.
            - any other implementation: a `[bsz, 1, tgt_seq_len, src_seq_len]` mask from `_prepare_4d_attention_mask`.
        """
        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self.config._attn_implementation == "flash_attention_2":
                # No zeros means nothing is padded, so the mask is dropped entirely.
                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
            elif self.config._attn_implementation == "sdpa":
                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
                # the manual implementation that requires a 4D causal mask in all cases.
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask,
                    inputs_embeds.dtype,
                    tgt_len=input_shape[-1],
                )
            elif self.config._attn_implementation == "flex_attention":
                # Non-tensor masks (presumably an already-built flex mask) are passed through untouched.
                if isinstance(encoder_attention_mask, torch.Tensor):
                    encoder_attention_mask = make_flex_block_causal_mask(
                        encoder_attention_mask,
                        query_length=input_shape[-1],
                        is_causal=False,
                    )
            else:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        return encoder_attention_mask
  612. class MarianEncoder(MarianPreTrainedModel):
  613. """
  614. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  615. [`MarianEncoderLayer`].
  616. Args:
  617. config: MarianConfig
  618. embed_tokens (nn.Embedding): output embedding
  619. """
  620. def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
  621. super().__init__(config)
  622. self.dropout = config.dropout
  623. self.layerdrop = config.encoder_layerdrop
  624. embed_dim = config.d_model
  625. self.padding_idx = config.pad_token_id
  626. self.max_source_positions = config.max_position_embeddings
  627. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  628. if embed_tokens is not None:
  629. self.embed_tokens = embed_tokens
  630. else:
  631. self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
  632. self.embed_positions = MarianSinusoidalPositionalEmbedding(
  633. config.max_position_embeddings, embed_dim, self.padding_idx
  634. )
  635. self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])
  636. self.gradient_checkpointing = False
  637. # Initialize weights and apply final processing
  638. self.post_init()
  639. def forward(
  640. self,
  641. input_ids: Optional[torch.LongTensor] = None,
  642. attention_mask: Optional[torch.LongTensor] = None,
  643. head_mask: Optional[torch.Tensor] = None,
  644. inputs_embeds: Optional[torch.FloatTensor] = None,
  645. output_attentions: Optional[bool] = None,
  646. output_hidden_states: Optional[bool] = None,
  647. return_dict: Optional[bool] = None,
  648. ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
  649. r"""
  650. Args:
  651. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  652. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  653. provide it.
  654. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  655. [`PreTrainedTokenizer.__call__`] for details.
  656. [What are input IDs?](../glossary#input-ids)
  657. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  658. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  659. - 1 for tokens that are **not masked**,
  660. - 0 for tokens that are **masked**.
  661. [What are attention masks?](../glossary#attention-mask)
  662. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  663. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  664. - 1 indicates the head is **not masked**,
  665. - 0 indicates the head is **masked**.
  666. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  667. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  668. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  669. than the model's internal embedding lookup matrix.
  670. output_attentions (`bool`, *optional*):
  671. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  672. returned tensors for more detail.
  673. output_hidden_states (`bool`, *optional*):
  674. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  675. for more detail.
  676. return_dict (`bool`, *optional*):
  677. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  678. """
  679. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  680. output_hidden_states = (
  681. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  682. )
  683. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  684. # retrieve input_ids and inputs_embeds
  685. if input_ids is not None and inputs_embeds is not None:
  686. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  687. elif input_ids is not None:
  688. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  689. input_shape = input_ids.size()
  690. input_ids = input_ids.view(-1, input_shape[-1])
  691. elif inputs_embeds is not None:
  692. input_shape = inputs_embeds.size()[:-1]
  693. else:
  694. raise ValueError("You have to specify either input_ids or inputs_embeds")
  695. if inputs_embeds is None:
  696. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  697. embed_pos = self.embed_positions(input_shape)
  698. hidden_states = inputs_embeds + embed_pos
  699. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  700. attention_mask = self._update_full_mask(
  701. attention_mask,
  702. inputs_embeds,
  703. )
  704. encoder_states = () if output_hidden_states else None
  705. all_attentions = () if output_attentions else None
  706. # check if head_mask has a correct number of layers specified if desired
  707. if head_mask is not None:
  708. assert head_mask.size()[0] == (len(self.layers)), (
  709. f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
  710. )
  711. for idx, encoder_layer in enumerate(self.layers):
  712. if output_hidden_states:
  713. encoder_states = encoder_states + (hidden_states,)
  714. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  715. to_drop = False
  716. if self.training:
  717. dropout_probability = torch.rand([])
  718. if dropout_probability < self.layerdrop: # skip the layer
  719. to_drop = True
  720. if to_drop:
  721. layer_outputs = (None, None)
  722. else:
  723. layer_outputs = encoder_layer(
  724. hidden_states,
  725. attention_mask,
  726. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  727. output_attentions=output_attentions,
  728. )
  729. hidden_states = layer_outputs[0]
  730. if output_attentions:
  731. all_attentions = all_attentions + (layer_outputs[1],)
  732. if output_hidden_states:
  733. encoder_states = encoder_states + (hidden_states,)
  734. if not return_dict:
  735. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  736. return BaseModelOutput(
  737. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  738. )
  739. class MarianDecoder(MarianPreTrainedModel):
  740. """
  741. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MarianDecoderLayer`]
  742. Args:
  743. config: MarianConfig
  744. embed_tokens (nn.Embedding): output embedding
  745. """
  746. def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None):
  747. super().__init__(config)
  748. self.dropout = config.dropout
  749. self.layerdrop = config.decoder_layerdrop
  750. self.padding_idx = config.pad_token_id
  751. self.max_target_positions = config.max_position_embeddings
  752. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  753. if embed_tokens is not None:
  754. self.embed_tokens = embed_tokens
  755. else:
  756. self.embed_tokens = nn.Embedding(config.decoder_vocab_size, config.d_model, self.padding_idx)
  757. self.embed_positions = MarianSinusoidalPositionalEmbedding(
  758. config.max_position_embeddings, config.d_model, self.padding_idx
  759. )
  760. self.layers = nn.ModuleList([MarianDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  761. self.gradient_checkpointing = False
  762. # Initialize weights and apply final processing
  763. self.post_init()
  764. def forward(
  765. self,
  766. input_ids: Optional[torch.LongTensor] = None,
  767. attention_mask: Optional[torch.Tensor] = None,
  768. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  769. encoder_attention_mask: Optional[torch.LongTensor] = None,
  770. head_mask: Optional[torch.Tensor] = None,
  771. cross_attn_head_mask: Optional[torch.Tensor] = None,
  772. past_key_values: Optional[Cache] = None,
  773. inputs_embeds: Optional[torch.FloatTensor] = None,
  774. use_cache: Optional[bool] = None,
  775. output_attentions: Optional[bool] = None,
  776. output_hidden_states: Optional[bool] = None,
  777. return_dict: Optional[bool] = None,
  778. cache_position: Optional[torch.Tensor] = None,
  779. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  780. r"""
  781. Args:
  782. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  783. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  784. provide it.
  785. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  786. [`PreTrainedTokenizer.__call__`] for details.
  787. [What are input IDs?](../glossary#input-ids)
  788. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  789. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  790. - 1 for tokens that are **not masked**,
  791. - 0 for tokens that are **masked**.
  792. [What are attention masks?](../glossary#attention-mask)
  793. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  794. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  795. of the decoder.
  796. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  797. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  798. selected in `[0, 1]`:
  799. - 1 for tokens that are **not masked**,
  800. - 0 for tokens that are **masked**.
  801. [What are attention masks?](../glossary#attention-mask)
  802. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  803. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  804. - 1 indicates the head is **not masked**,
  805. - 0 indicates the head is **masked**.
  806. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  807. Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
  808. cross-attention on hidden heads. Mask values selected in `[0, 1]`:
  809. - 1 indicates the head is **not masked**,
  810. - 0 indicates the head is **masked**.
  811. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  812. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  813. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  814. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  815. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  816. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  817. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  818. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  819. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  820. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  821. than the model's internal embedding lookup matrix.
  822. output_attentions (`bool`, *optional*):
  823. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  824. returned tensors for more detail.
  825. output_hidden_states (`bool`, *optional*):
  826. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  827. for more detail.
  828. return_dict (`bool`, *optional*):
  829. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  830. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  831. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  832. cache in the correct position and to infer the complete sequence length.
  833. """
  834. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  835. output_hidden_states = (
  836. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  837. )
  838. use_cache = use_cache if use_cache is not None else self.config.use_cache
  839. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  840. if self.gradient_checkpointing and self.training:
  841. if use_cache:
  842. logger.warning_once(
  843. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  844. )
  845. use_cache = False
  846. # retrieve input_ids and inputs_embeds
  847. if (input_ids is None) ^ (inputs_embeds is not None):
  848. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  849. elif input_ids is not None:
  850. input = input_ids
  851. input_shape = input.shape
  852. input_ids = input_ids.view(-1, input_shape[-1])
  853. elif inputs_embeds is not None:
  854. input_shape = inputs_embeds.size()[:-1]
  855. input = inputs_embeds[:, :, -1]
  856. if inputs_embeds is None:
  857. inputs_embeds = self.embed_tokens(input)
  858. # Important to apply outside of the above `if`, in case user passes `embeds`
  859. inputs_embeds = inputs_embeds * self.embed_scale
  860. # initialize `past_key_values`
  861. if use_cache and past_key_values is None:
  862. past_key_values = (
  863. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  864. if encoder_hidden_states is not None
  865. else DynamicCache(config=self.config)
  866. )
  867. if use_cache and isinstance(past_key_values, tuple):
  868. logger.warning_once(
  869. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  870. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  871. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  872. )
  873. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  874. batch_size, seq_length = inputs_embeds.size()[:-1]
  875. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  876. if cache_position is None:
  877. cache_position = torch.arange(
  878. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  879. )
  880. if attention_mask is None and not is_torchdynamo_compiling():
  881. # required mask seq length can be calculated via length of past cache
  882. mask_seq_length = past_key_values_length + seq_length
  883. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  884. self_attn_cache = (
  885. past_key_values.self_attention_cache
  886. if isinstance(past_key_values, EncoderDecoderCache)
  887. else past_key_values
  888. )
  889. causal_mask = self._update_causal_mask(
  890. attention_mask,
  891. inputs_embeds,
  892. cache_position,
  893. self_attn_cache,
  894. )
  895. encoder_attention_mask = self._update_cross_attn_mask(
  896. encoder_hidden_states,
  897. encoder_attention_mask,
  898. input_shape,
  899. inputs_embeds,
  900. )
  901. # embed positions
  902. position_ids = self.embed_positions(
  903. (batch_size, seq_length), past_key_values_length, position_ids=cache_position
  904. )
  905. hidden_states = inputs_embeds + position_ids
  906. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  907. # decoder layers
  908. all_hidden_states = () if output_hidden_states else None
  909. all_self_attns = () if output_attentions else None
  910. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  911. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  912. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  913. if attn_mask is not None:
  914. assert attn_mask.size()[0] == (len(self.layers)), (
  915. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  916. f" {head_mask.size()[0]}."
  917. )
  918. for idx, decoder_layer in enumerate(self.layers):
  919. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  920. if output_hidden_states:
  921. all_hidden_states += (hidden_states,)
  922. if self.training:
  923. dropout_probability = torch.rand([])
  924. if dropout_probability < self.layerdrop:
  925. continue
  926. layer_outputs = decoder_layer(
  927. hidden_states,
  928. causal_mask,
  929. encoder_hidden_states, # as a positional argument for gradient checkpointing
  930. encoder_attention_mask=encoder_attention_mask,
  931. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  932. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  933. past_key_values=past_key_values,
  934. output_attentions=output_attentions,
  935. use_cache=use_cache,
  936. cache_position=cache_position,
  937. )
  938. hidden_states = layer_outputs[0]
  939. if output_attentions:
  940. all_self_attns += (layer_outputs[1],)
  941. if encoder_hidden_states is not None:
  942. all_cross_attentions += (layer_outputs[2],)
  943. # add hidden states from the last decoder layer
  944. if output_hidden_states:
  945. all_hidden_states += (hidden_states,)
  946. if not return_dict:
  947. return tuple(
  948. v
  949. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  950. if v is not None
  951. )
  952. return BaseModelOutputWithPastAndCrossAttentions(
  953. last_hidden_state=hidden_states,
  954. past_key_values=past_key_values,
  955. hidden_states=all_hidden_states,
  956. attentions=all_self_attns,
  957. cross_attentions=all_cross_attentions,
  958. )
  959. @auto_docstring
  960. class MarianModel(MarianPreTrainedModel):
  961. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  962. def __init__(self, config: MarianConfig):
  963. super().__init__(config)
  964. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  965. # We always use self.shared for token embeddings to ensure compatibility with all marian models
  966. self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
  967. if self.config.share_encoder_decoder_embeddings:
  968. encoder_embed_tokens = decoder_embed_tokens = self.shared
  969. else:
  970. # Since the embeddings are not shared, deepcopy the embeddings here for encoder
  971. # and decoder to make sure they are not tied.
  972. encoder_embed_tokens = copy.deepcopy(self.shared)
  973. decoder_embed_tokens = copy.deepcopy(self.shared)
  974. self.shared = None
  975. self.encoder = MarianEncoder(config, encoder_embed_tokens)
  976. self.decoder = MarianDecoder(config, decoder_embed_tokens)
  977. # Initialize weights and apply final processing
  978. self.post_init()
  979. def get_input_embeddings(self):
  980. # This will return shared embeddings if they are shared else specific to encoder.
  981. return self.get_encoder().get_input_embeddings()
  982. def set_input_embeddings(self, value):
  983. if self.config.share_encoder_decoder_embeddings:
  984. self.shared = value
  985. self.encoder.embed_tokens = self.shared
  986. self.decoder.embed_tokens = self.shared
  987. else: # if not shared only set encoder embeedings
  988. self.encoder.embed_tokens = value
  989. def get_decoder_input_embeddings(self):
  990. if self.config.share_encoder_decoder_embeddings:
  991. raise ValueError(
  992. "`get_decoder_input_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
  993. "is `True`. Please use `get_input_embeddings` instead."
  994. )
  995. return self.get_decoder().get_input_embeddings()
  996. def set_decoder_input_embeddings(self, value):
  997. if self.config.share_encoder_decoder_embeddings:
  998. raise ValueError(
  999. "`config.share_encoder_decoder_embeddings` is set to `True` meaning the decoder input embeddings "
  1000. "are shared with the encoder. In order to set the decoder input embeddings, you should simply set "
  1001. "the encoder input embeddings by calling `set_input_embeddings` with the appropriate embeddings."
  1002. )
  1003. self.decoder.embed_tokens = value
  1004. def get_encoder(self):
  1005. return self.encoder
  1006. def resize_decoder_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
  1007. if self.config.share_encoder_decoder_embeddings:
  1008. raise ValueError(
  1009. "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
  1010. "is `True`. Please use `resize_token_embeddings` instead."
  1011. )
  1012. old_embeddings = self.get_decoder_input_embeddings()
  1013. new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
  1014. self.set_decoder_input_embeddings(new_embeddings)
  1015. model_embeds = self.get_decoder_input_embeddings()
  1016. if new_num_tokens is None:
  1017. return model_embeds
  1018. # Update base model and current model config
  1019. self.config.decoder_vocab_size = new_num_tokens
  1020. # Tie weights again if needed
  1021. self.tie_weights()
  1022. return model_embeds
  1023. @auto_docstring
  1024. def forward(
  1025. self,
  1026. input_ids: Optional[torch.LongTensor] = None,
  1027. attention_mask: Optional[torch.Tensor] = None,
  1028. decoder_input_ids: Optional[torch.LongTensor] = None,
  1029. decoder_attention_mask: Optional[torch.Tensor] = None,
  1030. head_mask: Optional[torch.Tensor] = None,
  1031. decoder_head_mask: Optional[torch.Tensor] = None,
  1032. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1033. encoder_outputs: Optional[Union[tuple[torch.Tensor], BaseModelOutput]] = None,
  1034. past_key_values: Optional[Cache] = None,
  1035. inputs_embeds: Optional[torch.FloatTensor] = None,
  1036. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1037. use_cache: Optional[bool] = None,
  1038. output_attentions: Optional[bool] = None,
  1039. output_hidden_states: Optional[bool] = None,
  1040. return_dict: Optional[bool] = None,
  1041. cache_position: Optional[torch.Tensor] = None,
  1042. ) -> Seq2SeqModelOutput:
  1043. r"""
  1044. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1045. Indices of decoder input sequence tokens in the vocabulary.
  1046. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1047. [`PreTrainedTokenizer.__call__`] for details.
  1048. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1049. Marian uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1050. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1051. `past_key_values`).
  1052. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1053. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1054. be used by default.
  1055. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1056. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1057. 1]`:
  1058. - 1 indicates the head is **not masked**,
  1059. - 0 indicates the head is **masked**.
  1060. Example:
  1061. ```python
  1062. >>> from transformers import AutoTokenizer, MarianModel
  1063. >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
  1064. >>> model = MarianModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")
  1065. >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
  1066. >>> decoder_inputs = tokenizer(
  1067. ... "<pad> Studien haben gezeigt dass es hilfreich ist einen Hund zu besitzen",
  1068. ... return_tensors="pt",
  1069. ... add_special_tokens=False,
  1070. ... )
  1071. >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
  1072. >>> last_hidden_states = outputs.last_hidden_state
  1073. >>> list(last_hidden_states.shape)
  1074. [1, 26, 512]
  1075. ```"""
  1076. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1077. output_hidden_states = (
  1078. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1079. )
  1080. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1081. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1082. if encoder_outputs is None:
  1083. encoder_outputs = self.encoder(
  1084. input_ids=input_ids,
  1085. attention_mask=attention_mask,
  1086. head_mask=head_mask,
  1087. inputs_embeds=inputs_embeds,
  1088. output_attentions=output_attentions,
  1089. output_hidden_states=output_hidden_states,
  1090. return_dict=return_dict,
  1091. )
  1092. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  1093. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1094. encoder_outputs = BaseModelOutput(
  1095. last_hidden_state=encoder_outputs[0],
  1096. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1097. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1098. )
  1099. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1100. decoder_outputs = self.decoder(
  1101. input_ids=decoder_input_ids,
  1102. attention_mask=decoder_attention_mask,
  1103. encoder_hidden_states=encoder_outputs[0],
  1104. encoder_attention_mask=attention_mask,
  1105. head_mask=decoder_head_mask,
  1106. cross_attn_head_mask=cross_attn_head_mask,
  1107. past_key_values=past_key_values,
  1108. inputs_embeds=decoder_inputs_embeds,
  1109. use_cache=use_cache,
  1110. output_attentions=output_attentions,
  1111. output_hidden_states=output_hidden_states,
  1112. return_dict=return_dict,
  1113. cache_position=cache_position,
  1114. )
  1115. if not return_dict:
  1116. return decoder_outputs + encoder_outputs
  1117. return Seq2SeqModelOutput(
  1118. last_hidden_state=decoder_outputs.last_hidden_state,
  1119. past_key_values=decoder_outputs.past_key_values,
  1120. decoder_hidden_states=decoder_outputs.hidden_states,
  1121. decoder_attentions=decoder_outputs.attentions,
  1122. cross_attentions=decoder_outputs.cross_attentions,
  1123. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1124. encoder_hidden_states=encoder_outputs.hidden_states,
  1125. encoder_attentions=encoder_outputs.attentions,
  1126. )
  1127. @auto_docstring(
  1128. custom_intro="""
  1129. The Marian Model with a language modeling head. Can be used for summarization.
  1130. """
  1131. )
  1132. class MarianMTModel(MarianPreTrainedModel, GenerationMixin):
  1133. base_model_prefix = "model"
  1134. _keys_to_ignore_on_load_missing = [
  1135. "final_logits_bias",
  1136. "encoder.embed_positions.weight",
  1137. "decoder.embed_positions.weight",
  1138. ]
  1139. _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
  1140. _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
  1141. def __init__(self, config: MarianConfig):
  1142. super().__init__(config)
  1143. self.model = MarianModel(config)
  1144. target_vocab_size = config.vocab_size if config.share_encoder_decoder_embeddings else config.decoder_vocab_size
  1145. self.register_buffer("final_logits_bias", torch.zeros((1, target_vocab_size)))
  1146. self.lm_head = nn.Linear(config.d_model, target_vocab_size, bias=False)
  1147. # Initialize weights and apply final processing
  1148. self.post_init()
  1149. def get_encoder(self):
  1150. return self.model.get_encoder()
  1151. def get_decoder(self):
  1152. return self.model.get_decoder()
  1153. def resize_token_embeddings(
  1154. self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
  1155. ) -> nn.Embedding:
  1156. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  1157. if self.config.share_encoder_decoder_embeddings:
  1158. self._resize_final_logits_bias(new_num_tokens)
  1159. return new_embeddings
  1160. # NOTE: `_resize_token_embeddings` was rewritten in the base class, *args exists to absorb the extra arg
  1161. def _resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of=None, *args) -> nn.Embedding:
  1162. old_embeddings = self.get_input_embeddings()
  1163. new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
  1164. self.set_input_embeddings(new_embeddings)
  1165. new_num_tokens = new_embeddings.weight.shape[0]
  1166. # update config.decoder_vocab_size if embeddings are tied
  1167. if self.config.share_encoder_decoder_embeddings:
  1168. self.config.decoder_vocab_size = new_num_tokens
  1169. # if word embeddings are not tied, make sure that lm head is resized as well
  1170. if (
  1171. self.config.share_encoder_decoder_embeddings
  1172. and self.get_output_embeddings() is not None
  1173. and not self.config.tie_word_embeddings
  1174. ):
  1175. old_lm_head = self.get_output_embeddings()
  1176. new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
  1177. self.set_output_embeddings(new_lm_head)
  1178. return self.get_input_embeddings()
  1179. def resize_decoder_token_embeddings(self, new_num_tokens):
  1180. if self.config.share_encoder_decoder_embeddings:
  1181. raise ValueError(
  1182. "`resize_decoder_token_embeddings` should not be called if `config.share_encoder_decoder_embeddings` "
  1183. "is `True`. Please use `resize_token_embeddings` instead."
  1184. )
  1185. old_embeddings = self.model.get_decoder_input_embeddings()
  1186. new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
  1187. self.model.set_decoder_input_embeddings(new_embeddings)
  1188. # if word embeddings are not tied, make sure that lm head is resized as well
  1189. if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
  1190. old_lm_head = self.get_output_embeddings()
  1191. new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
  1192. self.set_output_embeddings(new_lm_head)
  1193. model_embeds = self.model.get_decoder_input_embeddings()
  1194. if new_num_tokens is None:
  1195. return model_embeds
  1196. # Update base model and current model config
  1197. self.config.decoder_vocab_size = new_num_tokens
  1198. # Tie weights again if needed
  1199. self.tie_weights()
  1200. self._resize_final_logits_bias(new_num_tokens)
  1201. return model_embeds
  1202. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  1203. old_num_tokens = self.final_logits_bias.shape[-1]
  1204. if new_num_tokens <= old_num_tokens:
  1205. new_bias = self.final_logits_bias[:, :new_num_tokens]
  1206. else:
  1207. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  1208. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  1209. self.register_buffer("final_logits_bias", new_bias)
  1210. def set_output_embeddings(self, new_embeddings: nn.Embedding):
  1211. self.lm_head = new_embeddings
  1212. def tie_weights(self):
  1213. """
  1214. Tie the weights between the input embeddings and the output embeddings.
  1215. If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
  1216. weights instead.
  1217. """
  1218. output_embeddings = self.get_output_embeddings()
  1219. if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
  1220. # if embeddings are shared this will return shared embeddings otherwise decoder embed_tokens
  1221. word_embeddings = self.get_decoder().get_input_embeddings()
  1222. self._tie_or_clone_weights(output_embeddings, word_embeddings)
  1223. if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
  1224. if hasattr(self, self.base_model_prefix):
  1225. self = getattr(self, self.base_model_prefix)
  1226. tied_weights = self._tie_encoder_decoder_weights(
  1227. self.encoder, self.decoder, self.base_model_prefix, "encoder"
  1228. )
  1229. # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
  1230. # attributed not an instance member, therefore modifying it will modify the entire class
  1231. # Leading to issues on subsequent calls by different tests or subsequent calls.
  1232. self._dynamic_tied_weights_keys = tied_weights
  1233. for module in self.modules():
  1234. if hasattr(module, "_tie_weights"):
  1235. module._tie_weights()
  1236. @auto_docstring
  1237. def forward(
  1238. self,
  1239. input_ids: Optional[torch.LongTensor] = None,
  1240. attention_mask: Optional[torch.Tensor] = None,
  1241. decoder_input_ids: Optional[torch.LongTensor] = None,
  1242. decoder_attention_mask: Optional[torch.Tensor] = None,
  1243. head_mask: Optional[torch.Tensor] = None,
  1244. decoder_head_mask: Optional[torch.Tensor] = None,
  1245. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1246. encoder_outputs: Optional[Union[tuple[torch.Tensor], BaseModelOutput]] = None,
  1247. past_key_values: Optional[Cache] = None,
  1248. inputs_embeds: Optional[torch.FloatTensor] = None,
  1249. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1250. labels: Optional[torch.LongTensor] = None,
  1251. use_cache: Optional[bool] = None,
  1252. output_attentions: Optional[bool] = None,
  1253. output_hidden_states: Optional[bool] = None,
  1254. return_dict: Optional[bool] = None,
  1255. cache_position: Optional[torch.Tensor] = None,
  1256. ) -> Seq2SeqLMOutput:
  1257. r"""
  1258. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1259. Indices of decoder input sequence tokens in the vocabulary.
  1260. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1261. [`PreTrainedTokenizer.__call__`] for details.
  1262. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1263. Marian uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1264. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1265. `past_key_values`).
  1266. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1267. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1268. be used by default.
  1269. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1270. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1271. 1]`:
  1272. - 1 indicates the head is **not masked**,
  1273. - 0 indicates the head is **masked**.
  1274. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1275. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1276. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1277. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1278. Example:
  1279. ```python
  1280. >>> from transformers import AutoTokenizer, MarianMTModel
  1281. >>> src = "fr" # source language
  1282. >>> trg = "en" # target language
  1283. >>> model_name = f"Helsinki-NLP/opus-mt-{src}-{trg}"
  1284. >>> model = MarianMTModel.from_pretrained(model_name)
  1285. >>> tokenizer = AutoTokenizer.from_pretrained(model_name)
  1286. >>> sample_text = "où est l'arrêt de bus ?"
  1287. >>> batch = tokenizer([sample_text], return_tensors="pt")
  1288. >>> generated_ids = model.generate(**batch)
  1289. >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
  1290. "Where's the bus stop?"
  1291. ```
  1292. """
  1293. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1294. if labels is not None:
  1295. if use_cache:
  1296. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  1297. use_cache = False
  1298. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1299. decoder_input_ids = shift_tokens_right(
  1300. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  1301. )
  1302. outputs = self.model(
  1303. input_ids,
  1304. attention_mask=attention_mask,
  1305. decoder_input_ids=decoder_input_ids,
  1306. encoder_outputs=encoder_outputs,
  1307. decoder_attention_mask=decoder_attention_mask,
  1308. head_mask=head_mask,
  1309. decoder_head_mask=decoder_head_mask,
  1310. cross_attn_head_mask=cross_attn_head_mask,
  1311. past_key_values=past_key_values,
  1312. inputs_embeds=inputs_embeds,
  1313. decoder_inputs_embeds=decoder_inputs_embeds,
  1314. use_cache=use_cache,
  1315. output_attentions=output_attentions,
  1316. output_hidden_states=output_hidden_states,
  1317. return_dict=return_dict,
  1318. cache_position=cache_position,
  1319. )
  1320. lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
  1321. masked_lm_loss = None
  1322. if labels is not None:
  1323. loss_fct = CrossEntropyLoss()
  1324. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
  1325. if not return_dict:
  1326. output = (lm_logits,) + outputs[1:]
  1327. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  1328. return Seq2SeqLMOutput(
  1329. loss=masked_lm_loss,
  1330. logits=lm_logits,
  1331. past_key_values=outputs.past_key_values,
  1332. decoder_hidden_states=outputs.decoder_hidden_states,
  1333. decoder_attentions=outputs.decoder_attentions,
  1334. cross_attentions=outputs.cross_attentions,
  1335. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1336. encoder_hidden_states=outputs.encoder_hidden_states,
  1337. encoder_attentions=outputs.encoder_attentions,
  1338. )
  1339. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1340. return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
  1341. # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Marian
  1342. class MarianDecoderWrapper(MarianPreTrainedModel):
  1343. """
  1344. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1345. used in combination with the [`EncoderDecoderModel`] framework.
  1346. """
  1347. def __init__(self, config):
  1348. super().__init__(config)
  1349. self.decoder = MarianDecoder(config)
  1350. def forward(self, *args, **kwargs):
  1351. return self.decoder(*args, **kwargs)
  1352. # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
  1353. class MarianForCausalLM(MarianPreTrainedModel, GenerationMixin):
  1354. _tied_weights_keys = ["lm_head.weight"]
  1355. def __init__(self, config):
  1356. config.is_decoder = True
  1357. config.is_encoder_decoder = False
  1358. super().__init__(config)
  1359. self.model = MarianDecoderWrapper(config)
  1360. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1361. # Initialize weights and apply final processing
  1362. self.post_init()
  1363. def get_input_embeddings(self):
  1364. return self.model.decoder.embed_tokens
  1365. def set_input_embeddings(self, value):
  1366. self.model.decoder.embed_tokens = value
  1367. def set_decoder(self, decoder):
  1368. self.model.decoder = decoder
  1369. def get_decoder(self):
  1370. return self.model.decoder
  1371. @auto_docstring
  1372. def forward(
  1373. self,
  1374. input_ids: Optional[torch.LongTensor] = None,
  1375. attention_mask: Optional[torch.Tensor] = None,
  1376. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  1377. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  1378. head_mask: Optional[torch.Tensor] = None,
  1379. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1380. past_key_values: Optional[Cache] = None,
  1381. inputs_embeds: Optional[torch.FloatTensor] = None,
  1382. labels: Optional[torch.LongTensor] = None,
  1383. use_cache: Optional[bool] = None,
  1384. output_attentions: Optional[bool] = None,
  1385. output_hidden_states: Optional[bool] = None,
  1386. return_dict: Optional[bool] = None,
  1387. cache_position: Optional[torch.LongTensor] = None,
  1388. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  1389. r"""
  1390. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1391. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1392. - 1 indicates the head is **not masked**,
  1393. - 0 indicates the head is **masked**.
  1394. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1395. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1396. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1397. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1398. Example:
  1399. ```python
  1400. >>> from transformers import AutoTokenizer, MarianForCausalLM
  1401. >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en")
  1402. >>> model = MarianForCausalLM.from_pretrained("Helsinki-NLP/opus-mt-fr-en", add_cross_attention=False)
  1403. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1404. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1405. >>> outputs = model(**inputs)
  1406. >>> logits = outputs.logits
  1407. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  1408. >>> list(logits.shape) == expected_shape
  1409. True
  1410. ```"""
  1411. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1412. output_hidden_states = (
  1413. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1414. )
  1415. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1416. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1417. outputs = self.model.decoder(
  1418. input_ids=input_ids,
  1419. attention_mask=attention_mask,
  1420. encoder_hidden_states=encoder_hidden_states,
  1421. encoder_attention_mask=encoder_attention_mask,
  1422. head_mask=head_mask,
  1423. cross_attn_head_mask=cross_attn_head_mask,
  1424. past_key_values=past_key_values,
  1425. inputs_embeds=inputs_embeds,
  1426. use_cache=use_cache,
  1427. output_attentions=output_attentions,
  1428. output_hidden_states=output_hidden_states,
  1429. return_dict=return_dict,
  1430. cache_position=cache_position,
  1431. )
  1432. logits = self.lm_head(outputs[0])
  1433. loss = None
  1434. if labels is not None:
  1435. labels = labels.to(logits.device)
  1436. loss_fct = CrossEntropyLoss()
  1437. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1438. if not return_dict:
  1439. output = (logits,) + outputs[1:]
  1440. return (loss,) + output if loss is not None else output
  1441. return CausalLMOutputWithCrossAttentions(
  1442. loss=loss,
  1443. logits=logits,
  1444. past_key_values=outputs.past_key_values,
  1445. hidden_states=outputs.hidden_states,
  1446. attentions=outputs.attentions,
  1447. cross_attentions=outputs.cross_attentions,
  1448. )
  1449. __all__ = ["MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel"]