modeling_pegasus.py 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671
  1. # coding=utf-8
  2. # Copyright 2021, Google and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch PEGASUS model."""
  16. import copy
  17. import math
  18. from typing import Callable, Optional, Union
  19. import numpy as np
  20. import torch
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  25. from ...generation import GenerationMixin
  26. from ...modeling_attn_mask_utils import (
  27. AttentionMaskConverter,
  28. _prepare_4d_attention_mask,
  29. _prepare_4d_attention_mask_for_sdpa,
  30. )
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import (
  34. BaseModelOutput,
  35. BaseModelOutputWithPastAndCrossAttentions,
  36. CausalLMOutputWithCrossAttentions,
  37. Seq2SeqLMOutput,
  38. Seq2SeqModelOutput,
  39. )
  40. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  41. from ...processing_utils import Unpack
  42. from ...utils import (
  43. auto_docstring,
  44. is_torch_flex_attn_available,
  45. is_torchdynamo_compiling,
  46. logging,
  47. )
  48. from ...utils.deprecation import deprecate_kwarg
  49. from .configuration_pegasus import PegasusConfig
  50. if is_torch_flex_attn_available():
  51. from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
  52. logger = logging.get_logger(__name__)
  53. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  54. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  55. """
  56. Shift input ids one token to the right.
  57. """
  58. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  59. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  60. shifted_input_ids[:, 0] = decoder_start_token_id
  61. if pad_token_id is None:
  62. raise ValueError("self.model.config.pad_token_id has to be defined.")
  63. # replace possible -100 values in labels by `pad_token_id`
  64. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  65. return shifted_input_ids
  66. # Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Pegasus
  67. class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
  68. """This module produces sinusoidal positional embeddings of any length."""
  69. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
  70. super().__init__(num_positions, embedding_dim)
  71. def _init_weight(self):
  72. """
  73. Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
  74. the 2nd half of the vector. [dim // 2:]
  75. """
  76. n_pos, dim = self.weight.shape
  77. position_enc = np.array(
  78. [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
  79. )
  80. out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False)
  81. sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
  82. out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
  83. out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
  84. self.weight = nn.Parameter(out, requires_grad=False)
  85. @torch.no_grad()
  86. def forward(
  87. self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
  88. ) -> torch.Tensor:
  89. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  90. if position_ids is None:
  91. bsz, seq_len = input_ids_shape[:2]
  92. position_ids = torch.arange(
  93. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  94. )
  95. return super().forward(position_ids)
  96. # Copied from transformers.models.bart.modeling_bart.eager_attention_forward
  97. def eager_attention_forward(
  98. module: nn.Module,
  99. query: torch.Tensor,
  100. key: torch.Tensor,
  101. value: torch.Tensor,
  102. attention_mask: Optional[torch.Tensor],
  103. scaling: Optional[float] = None,
  104. dropout: float = 0.0,
  105. head_mask: Optional[torch.Tensor] = None,
  106. **kwargs,
  107. ):
  108. if scaling is None:
  109. scaling = query.size(-1) ** -0.5
  110. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  111. if attention_mask is not None:
  112. attn_weights = attn_weights + attention_mask
  113. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  114. if head_mask is not None:
  115. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  116. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  117. attn_output = torch.matmul(attn_weights, value)
  118. attn_output = attn_output.transpose(1, 2).contiguous()
  119. return attn_output, attn_weights
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus
class PegasusAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper

    Acts as self-attention when `forward` receives no `key_value_states`, and as cross-attention (keys/values
    projected from `key_value_states`) when it does. The actual attention kernel is chosen from
    `config._attn_implementation`, so `config` must be provided before `forward` is called.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PegasusConfig] = None,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        # Read by attention backends other than eager (this class itself does not branch on it).
        self.is_causal = is_causal
        # Index into the KV cache; required whenever a cache is passed to `forward`.
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: query input of shape `(batch, tgt_len, embed_dim)`. It is also the key/value source
                for self-attention.
            key_value_states: if given, the layer acts as cross-attention. Keys/values are projected from this
                tensor, and its second dimension is used as `src_len`.
            past_key_values: optional cache.
                - With an `EncoderDecoderCache`, self-attention uses `self_attention_cache` and cross-attention
                  uses `cross_attention_cache`.
                - Cross-attention keys/values are projected once. Once `is_updated[layer_idx]` is set, they are
                  read back from the cache instead of being recomputed.
                - Any other cache object is updated directly.
            attention_mask: passed unchanged to the attention backend.
            layer_head_mask: passed to the backend as `head_mask`.
            output_attentions: passed to the backend.
            cache_position: passed to the cache update for self-attention only. It is replaced by None for
                cross-attention.

        Returns:
            `(attn_output, attn_weights)`.
            - `attn_output` has shape `(batch, tgt_len, embed_dim)` after `out_proj`.
            - `attn_weights` is whatever the selected backend returns. The eager path returns the post-softmax,
              post-dropout weights.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        # -1 lets `view` infer the number of heads from embed_dim / head_dim.
        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        # get query proj; after the transpose the layout is (bsz, heads, tgt_len, head_dim)
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                # True once this layer's cross-attention keys/values have been written to the cache.
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(*kv_input_shape).transpose(1, 2)
            value_states = value_states.view(*kv_input_shape).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Dispatch to the configured kernel; `self.config` must not be None here.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            head_mask=layer_head_mask,
            **kwargs,
        )

        # Merge heads back into the channel dimension before the output projection.
        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
  232. # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
  233. class PegasusEncoderLayer(GradientCheckpointingLayer):
  234. def __init__(self, config: PegasusConfig):
  235. super().__init__()
  236. self.embed_dim = config.d_model
  237. self.self_attn = PegasusAttention(
  238. embed_dim=self.embed_dim,
  239. num_heads=config.encoder_attention_heads,
  240. dropout=config.attention_dropout,
  241. config=config,
  242. )
  243. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  244. self.dropout = config.dropout
  245. self.activation_fn = ACT2FN[config.activation_function]
  246. self.activation_dropout = config.activation_dropout
  247. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  248. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  249. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  250. def forward(
  251. self,
  252. hidden_states: torch.Tensor,
  253. attention_mask: torch.Tensor,
  254. layer_head_mask: torch.Tensor,
  255. output_attentions: bool = False,
  256. ) -> torch.Tensor:
  257. """
  258. Args:
  259. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  260. attention_mask (`torch.FloatTensor`): attention mask of size
  261. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  262. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  263. `(encoder_attention_heads,)`.
  264. output_attentions (`bool`, *optional*):
  265. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  266. returned tensors for more detail.
  267. """
  268. residual = hidden_states
  269. hidden_states = self.self_attn_layer_norm(hidden_states)
  270. hidden_states, attn_weights = self.self_attn(
  271. hidden_states=hidden_states,
  272. attention_mask=attention_mask,
  273. layer_head_mask=layer_head_mask,
  274. output_attentions=output_attentions,
  275. )
  276. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  277. hidden_states = residual + hidden_states
  278. residual = hidden_states
  279. hidden_states = self.final_layer_norm(hidden_states)
  280. hidden_states = self.activation_fn(self.fc1(hidden_states))
  281. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  282. hidden_states = self.fc2(hidden_states)
  283. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  284. hidden_states = residual + hidden_states
  285. if hidden_states.dtype == torch.float16:
  286. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  287. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  288. return hidden_states, attn_weights
  289. # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS
  290. class PegasusDecoderLayer(GradientCheckpointingLayer):
  291. def __init__(self, config: PegasusConfig, layer_idx: Optional[int] = None):
  292. super().__init__()
  293. self.embed_dim = config.d_model
  294. self.self_attn = PegasusAttention(
  295. embed_dim=self.embed_dim,
  296. num_heads=config.decoder_attention_heads,
  297. dropout=config.attention_dropout,
  298. is_decoder=True,
  299. is_causal=True,
  300. config=config,
  301. layer_idx=layer_idx,
  302. )
  303. self.dropout = config.dropout
  304. self.activation_fn = ACT2FN[config.activation_function]
  305. self.activation_dropout = config.activation_dropout
  306. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  307. self.encoder_attn = PegasusAttention(
  308. self.embed_dim,
  309. config.decoder_attention_heads,
  310. dropout=config.attention_dropout,
  311. is_decoder=True,
  312. config=config,
  313. layer_idx=layer_idx,
  314. )
  315. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  316. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  317. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  318. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  319. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  320. def forward(
  321. self,
  322. hidden_states: torch.Tensor,
  323. attention_mask: Optional[torch.Tensor] = None,
  324. encoder_hidden_states: Optional[torch.Tensor] = None,
  325. encoder_attention_mask: Optional[torch.Tensor] = None,
  326. layer_head_mask: Optional[torch.Tensor] = None,
  327. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  328. past_key_values: Optional[Cache] = None,
  329. output_attentions: Optional[bool] = False,
  330. use_cache: Optional[bool] = True,
  331. cache_position: Optional[torch.Tensor] = None,
  332. ) -> torch.Tensor:
  333. """
  334. Args:
  335. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  336. attention_mask (`torch.FloatTensor`): attention mask of size
  337. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  338. encoder_hidden_states (`torch.FloatTensor`):
  339. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  340. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  341. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  342. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  343. `(encoder_attention_heads,)`.
  344. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  345. size `(decoder_attention_heads,)`.
  346. past_key_values (`Cache`): cached past key and value projection states
  347. output_attentions (`bool`, *optional*):
  348. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  349. returned tensors for more detail.
  350. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  351. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  352. cache in the correct position and to infer the complete sequence length.
  353. """
  354. residual = hidden_states
  355. hidden_states = self.self_attn_layer_norm(hidden_states)
  356. # Self Attention
  357. hidden_states, self_attn_weights = self.self_attn(
  358. hidden_states=hidden_states,
  359. past_key_values=past_key_values,
  360. attention_mask=attention_mask,
  361. layer_head_mask=layer_head_mask,
  362. output_attentions=output_attentions,
  363. cache_position=cache_position,
  364. )
  365. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  366. hidden_states = residual + hidden_states
  367. # Cross-Attention Block
  368. cross_attn_weights = None
  369. if encoder_hidden_states is not None:
  370. residual = hidden_states
  371. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  372. hidden_states, cross_attn_weights = self.encoder_attn(
  373. hidden_states=hidden_states,
  374. key_value_states=encoder_hidden_states,
  375. attention_mask=encoder_attention_mask,
  376. layer_head_mask=cross_attn_layer_head_mask,
  377. past_key_values=past_key_values,
  378. output_attentions=output_attentions,
  379. )
  380. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  381. hidden_states = residual + hidden_states
  382. # Fully Connected
  383. residual = hidden_states
  384. hidden_states = self.final_layer_norm(hidden_states)
  385. hidden_states = self.activation_fn(self.fc1(hidden_states))
  386. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  387. hidden_states = self.fc2(hidden_states)
  388. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  389. hidden_states = residual + hidden_states
  390. outputs = (hidden_states,)
  391. if output_attentions:
  392. outputs += (self_attn_weights, cross_attn_weights)
  393. return outputs
  394. @auto_docstring
  395. class PegasusPreTrainedModel(PreTrainedModel):
  396. config: PegasusConfig
  397. base_model_prefix = "model"
  398. supports_gradient_checkpointing = True
  399. _supports_flash_attn = True
  400. _supports_sdpa = True
  401. _supports_flex_attn = True
  402. _can_compile_fullgraph = True
  403. def _init_weights(self, module):
  404. std = self.config.init_std
  405. if isinstance(module, nn.Linear):
  406. module.weight.data.normal_(mean=0.0, std=std)
  407. if module.bias is not None:
  408. module.bias.data.zero_()
  409. elif isinstance(module, PegasusSinusoidalPositionalEmbedding):
  410. module._init_weight()
  411. elif isinstance(module, nn.Embedding):
  412. module.weight.data.normal_(mean=0.0, std=std)
  413. if module.padding_idx is not None:
  414. module.weight.data[module.padding_idx].zero_()
  415. elif isinstance(module, nn.LayerNorm):
  416. module.weight.data.fill_(1.0)
  417. module.bias.data.zero_()
  418. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
  419. def _update_full_mask(
  420. self,
  421. attention_mask: Union[torch.Tensor, None],
  422. inputs_embeds: torch.Tensor,
  423. ):
  424. if attention_mask is not None:
  425. if self.config._attn_implementation == "flash_attention_2":
  426. attention_mask = attention_mask if 0 in attention_mask else None
  427. elif self.config._attn_implementation == "sdpa":
  428. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  429. # the manual implementation that requires a 4D causal mask in all cases.
  430. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  431. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  432. elif self.config._attn_implementation == "flex_attention":
  433. if isinstance(attention_mask, torch.Tensor):
  434. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  435. else:
  436. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  437. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  438. return attention_mask
  439. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
  440. def _update_causal_mask(
  441. self,
  442. attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
  443. input_tensor: torch.Tensor,
  444. cache_position: torch.Tensor,
  445. past_key_values: Cache,
  446. ):
  447. if self.config._attn_implementation == "flex_attention":
  448. if isinstance(attention_mask, torch.Tensor):
  449. attention_mask = make_flex_block_causal_mask(attention_mask)
  450. # Other attention flavors support in-built causal (when `mask is None`)
  451. # while we need to create our specific block mask regardless
  452. elif attention_mask is None:
  453. attention_mask = make_flex_block_causal_mask(
  454. torch.ones(
  455. size=(input_tensor.shape[0], input_tensor.shape[1]),
  456. device=attention_mask.device,
  457. )
  458. )
  459. return attention_mask
  460. if self.config._attn_implementation == "flash_attention_2":
  461. if attention_mask is not None and (attention_mask == 0.0).any():
  462. return attention_mask
  463. return None
  464. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  465. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  466. # to infer the attention mask.
  467. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  468. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  469. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  470. if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
  471. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  472. attention_mask,
  473. inputs_embeds=input_tensor,
  474. past_key_values_length=past_seen_tokens,
  475. is_training=self.training,
  476. ):
  477. return None
  478. dtype = input_tensor.dtype
  479. sequence_length = input_tensor.shape[1]
  480. if using_compilable_cache:
  481. target_length = past_key_values.get_max_cache_shape()
  482. else:
  483. target_length = (
  484. attention_mask.shape[-1]
  485. if isinstance(attention_mask, torch.Tensor)
  486. else past_seen_tokens + sequence_length + 1
  487. )
  488. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  489. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  490. attention_mask,
  491. sequence_length=sequence_length,
  492. target_length=target_length,
  493. dtype=dtype,
  494. cache_position=cache_position,
  495. batch_size=input_tensor.shape[0],
  496. )
  497. if (
  498. self.config._attn_implementation == "sdpa"
  499. and attention_mask is not None
  500. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  501. ):
  502. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  503. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  504. # Details: https://github.com/pytorch/pytorch/issues/110213
  505. min_dtype = torch.finfo(dtype).min
  506. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  507. return causal_mask
  508. @staticmethod
  509. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  510. def _prepare_4d_causal_attention_mask_with_cache_position(
  511. attention_mask: torch.Tensor,
  512. sequence_length: int,
  513. target_length: int,
  514. dtype: torch.dtype,
  515. cache_position: torch.Tensor,
  516. batch_size: int,
  517. **kwargs,
  518. ):
  519. """
  520. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  521. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  522. Args:
  523. attention_mask (`torch.Tensor`):
  524. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  525. `(batch_size, 1, query_length, key_value_length)`.
  526. sequence_length (`int`):
  527. The sequence length being processed.
  528. target_length (`int`):
  529. The target length: when generating with static cache, the mask should be as long as the static cache,
  530. to account for the 0 padding, the part of the cache that is not filled yet.
  531. dtype (`torch.dtype`):
  532. The dtype to use for the 4D attention mask.
  533. cache_position (`torch.Tensor`):
  534. Indices depicting the position of the input sequence tokens in the sequence.
  535. batch_size (`torch.Tensor`):
  536. Batch size.
  537. """
  538. if attention_mask is not None and attention_mask.dim() == 4:
  539. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  540. causal_mask = attention_mask
  541. else:
  542. min_dtype = torch.finfo(dtype).min
  543. causal_mask = torch.full(
  544. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  545. )
  546. if sequence_length != 1:
  547. causal_mask = torch.triu(causal_mask, diagonal=1)
  548. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  549. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  550. if attention_mask is not None:
  551. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  552. mask_length = attention_mask.shape[-1]
  553. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  554. causal_mask.device
  555. )
  556. padding_mask = padding_mask == 0
  557. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  558. padding_mask, min_dtype
  559. )
  560. return causal_mask
  561. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
  562. def _update_cross_attn_mask(
  563. self,
  564. encoder_hidden_states: Union[torch.Tensor, None],
  565. encoder_attention_mask: Union[torch.Tensor, None],
  566. input_shape: torch.Size,
  567. inputs_embeds: torch.Tensor,
  568. ):
  569. # expand encoder attention mask
  570. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  571. if self.config._attn_implementation == "flash_attention_2":
  572. encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
  573. elif self.config._attn_implementation == "sdpa":
  574. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  575. # the manual implementation that requires a 4D causal mask in all cases.
  576. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  577. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  578. encoder_attention_mask,
  579. inputs_embeds.dtype,
  580. tgt_len=input_shape[-1],
  581. )
  582. elif self.config._attn_implementation == "flex_attention":
  583. if isinstance(encoder_attention_mask, torch.Tensor):
  584. encoder_attention_mask = make_flex_block_causal_mask(
  585. encoder_attention_mask,
  586. query_length=input_shape[-1],
  587. is_causal=False,
  588. )
  589. else:
  590. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  591. encoder_attention_mask = _prepare_4d_attention_mask(
  592. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  593. )
  594. return encoder_attention_mask
  595. class PegasusEncoder(PegasusPreTrainedModel):
  596. """
  597. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  598. [`PegasusEncoderLayer`].
  599. Args:
  600. config: PegasusConfig
  601. embed_tokens (nn.Embedding): output embedding
  602. """
  603. def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
  604. super().__init__(config)
  605. self.dropout = config.dropout
  606. self.layerdrop = config.encoder_layerdrop
  607. embed_dim = config.d_model
  608. self.padding_idx = config.pad_token_id
  609. self.max_source_positions = config.max_position_embeddings
  610. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  611. if embed_tokens is not None:
  612. self.embed_tokens = embed_tokens
  613. else:
  614. self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
  615. self.embed_positions = PegasusSinusoidalPositionalEmbedding(
  616. config.max_position_embeddings,
  617. embed_dim,
  618. self.padding_idx,
  619. )
  620. self.layers = nn.ModuleList([PegasusEncoderLayer(config) for _ in range(config.encoder_layers)])
  621. self.layer_norm = nn.LayerNorm(config.d_model)
  622. self.gradient_checkpointing = False
  623. # Initialize weights and apply final processing
  624. self.post_init()
  625. def resize_position_embeddings(self, new_num_position_embeddings: int):
  626. """
  627. Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
  628. config.max_position_embeddings`.
  629. Arguments:
  630. new_num_position_embeddings (`int`):
  631. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  632. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  633. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  634. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  635. will remove vectors from the end.
  636. """
  637. logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
  638. self.config.max_position_embeddings = new_num_position_embeddings
  639. self.embed_positions = PegasusSinusoidalPositionalEmbedding(
  640. self.config.max_position_embeddings,
  641. self.config.d_model,
  642. self.padding_idx,
  643. )
  644. self.embed_positions._init_weight()
  645. self.embed_positions.to(self.device)
  646. def get_position_embeddings(self) -> nn.Embedding:
  647. """
  648. Returns the position embeddings matrix
  649. """
  650. return self.embed_positions
  651. def forward(
  652. self,
  653. input_ids=None,
  654. attention_mask=None,
  655. head_mask=None,
  656. inputs_embeds=None,
  657. output_attentions=None,
  658. output_hidden_states=None,
  659. return_dict=None,
  660. ):
  661. r"""
  662. Args:
  663. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  664. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  665. provide it.
  666. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  667. [`PreTrainedTokenizer.__call__`] for details.
  668. [What are input IDs?](../glossary#input-ids)
  669. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  670. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  671. - 1 for tokens that are **not masked**,
  672. - 0 for tokens that are **masked**.
  673. [What are attention masks?](../glossary#attention-mask)
  674. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  675. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  676. - 1 indicates the head is **not masked**,
  677. - 0 indicates the head is **masked**.
  678. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  679. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  680. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  681. than the model's internal embedding lookup matrix.
  682. output_attentions (`bool`, *optional*):
  683. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  684. returned tensors for more detail.
  685. output_hidden_states (`bool`, *optional*):
  686. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  687. for more detail.
  688. return_dict (`bool`, *optional*):
  689. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  690. """
  691. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  692. output_hidden_states = (
  693. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  694. )
  695. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  696. # retrieve input_ids and inputs_embeds
  697. if input_ids is not None and inputs_embeds is not None:
  698. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  699. elif input_ids is not None:
  700. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  701. input_shape = input_ids.size()
  702. input_ids = input_ids.view(-1, input_shape[-1])
  703. elif inputs_embeds is not None:
  704. input_shape = inputs_embeds.size()[:-1]
  705. else:
  706. raise ValueError("You have to specify either input_ids or inputs_embeds")
  707. if inputs_embeds is None:
  708. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  709. embed_pos = self.embed_positions(input_shape)
  710. hidden_states = inputs_embeds + embed_pos
  711. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  712. attention_mask = self._update_full_mask(
  713. attention_mask,
  714. inputs_embeds,
  715. )
  716. encoder_states = () if output_hidden_states else None
  717. all_attentions = () if output_attentions else None
  718. # check if head_mask has a correct number of layers specified if desired
  719. if head_mask is not None:
  720. if head_mask.size()[0] != len(self.layers):
  721. raise ValueError(
  722. f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
  723. f" {head_mask.size()[0]}."
  724. )
  725. for idx, encoder_layer in enumerate(self.layers):
  726. if output_hidden_states:
  727. encoder_states = encoder_states + (hidden_states,)
  728. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  729. to_drop = False
  730. if self.training:
  731. dropout_probability = torch.rand([])
  732. if dropout_probability < self.layerdrop: # skip the layer
  733. to_drop = True
  734. if to_drop:
  735. layer_outputs = (None, None)
  736. else:
  737. layer_outputs = encoder_layer(
  738. hidden_states,
  739. attention_mask,
  740. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  741. output_attentions=output_attentions,
  742. )
  743. hidden_states = layer_outputs[0]
  744. if output_attentions:
  745. all_attentions = all_attentions + (layer_outputs[1],)
  746. hidden_states = self.layer_norm(hidden_states)
  747. if output_hidden_states:
  748. encoder_states = encoder_states + (hidden_states,)
  749. if not return_dict:
  750. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  751. return BaseModelOutput(
  752. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  753. )
  754. class PegasusDecoder(PegasusPreTrainedModel):
  755. """
  756. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PegasusDecoderLayer`]
  757. Args:
  758. config: PegasusConfig
  759. embed_tokens (nn.Embedding): output embedding
  760. """
  761. def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
  762. super().__init__(config)
  763. self.dropout = config.dropout
  764. self.layerdrop = config.decoder_layerdrop
  765. self.padding_idx = config.pad_token_id
  766. self.max_target_positions = config.max_position_embeddings
  767. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  768. if embed_tokens is not None:
  769. self.embed_tokens = embed_tokens
  770. else:
  771. self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
  772. self.embed_positions = PegasusSinusoidalPositionalEmbedding(
  773. config.max_position_embeddings,
  774. config.d_model,
  775. self.padding_idx,
  776. )
  777. self.layers = nn.ModuleList([PegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  778. self.layer_norm = nn.LayerNorm(config.d_model)
  779. self.gradient_checkpointing = False
  780. # Initialize weights and apply final processing
  781. self.post_init()
  782. def resize_position_embeddings(self, new_num_position_embeddings: int):
  783. """
  784. Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
  785. config.max_position_embeddings`.
  786. Arguments:
  787. new_num_position_embeddings (`int`):
  788. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  789. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  790. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  791. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  792. will remove vectors from the end.
  793. """
  794. logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
  795. self.config.max_position_embeddings = new_num_position_embeddings
  796. self.embed_positions = PegasusSinusoidalPositionalEmbedding(
  797. self.config.max_position_embeddings,
  798. self.config.d_model,
  799. self.padding_idx,
  800. )
  801. self.embed_positions._init_weight()
  802. self.embed_positions.to(self.device)
  803. def get_position_embeddings(self) -> nn.Embedding:
  804. """
  805. Returns the position embeddings matrix
  806. """
  807. return self.embed_positions
  808. def forward(
  809. self,
  810. input_ids=None,
  811. attention_mask=None,
  812. encoder_hidden_states=None,
  813. encoder_attention_mask=None,
  814. head_mask=None,
  815. cross_attn_head_mask=None,
  816. past_key_values=None,
  817. inputs_embeds=None,
  818. use_cache=None,
  819. output_attentions=None,
  820. output_hidden_states=None,
  821. return_dict=None,
  822. cache_position=None,
  823. ):
  824. r"""
  825. Args:
  826. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  827. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  828. provide it.
  829. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  830. [`PreTrainedTokenizer.__call__`] for details.
  831. [What are input IDs?](../glossary#input-ids)
  832. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  833. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  834. - 1 for tokens that are **not masked**,
  835. - 0 for tokens that are **masked**.
  836. [What are attention masks?](../glossary#attention-mask)
  837. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  838. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  839. of the decoder.
  840. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  841. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  842. selected in `[0, 1]`:
  843. - 1 for tokens that are **not masked**,
  844. - 0 for tokens that are **masked**.
  845. [What are attention masks?](../glossary#attention-mask)
  846. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  847. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  848. - 1 indicates the head is **not masked**,
  849. - 0 indicates the head is **masked**.
  850. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  851. Mask to nullify selected heads of the cross-attention modules in decoder to avoid performing
  852. cross-attention on hidden heads. Mask values selected in `[0, 1]`:
  853. - 1 indicates the head is **not masked**,
  854. - 0 indicates the head is **masked**.
  855. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  856. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  857. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  858. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  859. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  860. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  861. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  862. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  863. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  864. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  865. than the model's internal embedding lookup matrix.
  866. output_attentions (`bool`, *optional*):
  867. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  868. returned tensors for more detail.
  869. output_hidden_states (`bool`, *optional*):
  870. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  871. for more detail.
  872. return_dict (`bool`, *optional*):
  873. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  874. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  875. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  876. cache in the correct position and to infer the complete sequence length.
  877. """
  878. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  879. output_hidden_states = (
  880. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  881. )
  882. use_cache = use_cache if use_cache is not None else self.config.use_cache
  883. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  884. # retrieve input_ids and inputs_embeds
  885. if (input_ids is None) ^ (inputs_embeds is not None):
  886. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  887. elif input_ids is not None:
  888. input = input_ids
  889. input_shape = input.shape
  890. input_ids = input_ids.view(-1, input_shape[-1])
  891. elif inputs_embeds is not None:
  892. input_shape = inputs_embeds.size()[:-1]
  893. input = inputs_embeds[:, :, -1]
  894. else:
  895. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  896. if inputs_embeds is None:
  897. inputs_embeds = self.embed_tokens(input)
  898. # important to apply scale outside of `if` in case users pass `embeds`
  899. inputs_embeds = inputs_embeds * self.embed_scale
  900. if self.gradient_checkpointing and self.training:
  901. if use_cache:
  902. logger.warning_once(
  903. "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
  904. )
  905. use_cache = False
  906. # initialize `past_key_values`
  907. if use_cache and past_key_values is None:
  908. past_key_values = (
  909. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  910. if encoder_hidden_states is not None
  911. else DynamicCache(config=self.config)
  912. )
  913. if use_cache and isinstance(past_key_values, tuple):
  914. logger.warning_once(
  915. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  916. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  917. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  918. )
  919. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  920. batch_size, seq_length = inputs_embeds.size()[:-1]
  921. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  922. if cache_position is None:
  923. cache_position = torch.arange(
  924. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  925. )
  926. if attention_mask is None and not is_torchdynamo_compiling():
  927. # required mask seq length can be calculated via length of past cache
  928. mask_seq_length = past_key_values_length + seq_length
  929. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  930. self_attn_cache = (
  931. past_key_values.self_attention_cache
  932. if isinstance(past_key_values, EncoderDecoderCache)
  933. else past_key_values
  934. )
  935. causal_mask = self._update_causal_mask(
  936. attention_mask,
  937. inputs_embeds,
  938. cache_position,
  939. self_attn_cache,
  940. )
  941. encoder_attention_mask = self._update_cross_attn_mask(
  942. encoder_hidden_states,
  943. encoder_attention_mask,
  944. input_shape,
  945. inputs_embeds,
  946. )
  947. # embed positions
  948. positions = self.embed_positions((batch_size, seq_length), past_key_values_length, position_ids=cache_position)
  949. hidden_states = inputs_embeds + positions
  950. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  951. # decoder layers
  952. all_hidden_states = () if output_hidden_states else None
  953. all_self_attns = () if output_attentions else None
  954. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  955. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  956. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  957. if attn_mask is not None:
  958. if attn_mask.size()[0] != len(self.layers):
  959. raise ValueError(
  960. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  961. f" {head_mask.size()[0]}."
  962. )
  963. for idx, decoder_layer in enumerate(self.layers):
  964. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  965. if output_hidden_states:
  966. all_hidden_states += (hidden_states,)
  967. if self.training:
  968. dropout_probability = torch.rand([])
  969. if dropout_probability < self.layerdrop:
  970. continue
  971. layer_outputs = decoder_layer(
  972. hidden_states,
  973. causal_mask,
  974. encoder_hidden_states, # as a positional argument for gradient checkpointing
  975. encoder_attention_mask=encoder_attention_mask,
  976. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  977. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  978. past_key_values=past_key_values,
  979. output_attentions=output_attentions,
  980. use_cache=use_cache,
  981. cache_position=cache_position,
  982. )
  983. hidden_states = layer_outputs[0]
  984. if output_attentions:
  985. all_self_attns += (layer_outputs[1],)
  986. if encoder_hidden_states is not None:
  987. all_cross_attentions += (layer_outputs[2],)
  988. hidden_states = self.layer_norm(hidden_states)
  989. # add hidden states from the last decoder layer
  990. if output_hidden_states:
  991. all_hidden_states += (hidden_states,)
  992. if not return_dict:
  993. return tuple(
  994. v
  995. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  996. if v is not None
  997. )
  998. return BaseModelOutputWithPastAndCrossAttentions(
  999. last_hidden_state=hidden_states,
  1000. past_key_values=past_key_values,
  1001. hidden_states=all_hidden_states,
  1002. attentions=all_self_attns,
  1003. cross_attentions=all_cross_attentions,
  1004. )
  1005. @auto_docstring
  1006. class PegasusModel(PegasusPreTrainedModel):
  1007. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1008. def __init__(self, config: PegasusConfig):
  1009. super().__init__(config)
  1010. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  1011. self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
  1012. self.encoder = PegasusEncoder(config, self.shared)
  1013. self.decoder = PegasusDecoder(config, self.shared)
  1014. # Initialize weights and apply final processing
  1015. self.post_init()
  1016. def get_input_embeddings(self):
  1017. return self.shared
  1018. def set_input_embeddings(self, value):
  1019. self.shared = value
  1020. self.encoder.embed_tokens = self.shared
  1021. self.decoder.embed_tokens = self.shared
  1022. def get_encoder(self):
  1023. return self.encoder
  1024. def resize_position_embeddings(self, new_num_position_embeddings: int):
  1025. """
  1026. Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
  1027. config.max_position_embeddings`.
  1028. Arguments:
  1029. new_num_position_embeddings (`int`):
  1030. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  1031. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  1032. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  1033. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  1034. will remove vectors from the end.
  1035. """
  1036. self.config.max_position_embeddings = new_num_position_embeddings
  1037. self.encoder.resize_position_embeddings(new_num_position_embeddings)
  1038. self.decoder.resize_position_embeddings(new_num_position_embeddings)
  1039. def get_position_embeddings(self) -> tuple[nn.Embedding]:
  1040. """
  1041. Returns the position embeddings matrix
  1042. """
  1043. return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
  1044. @auto_docstring
  1045. def forward(
  1046. self,
  1047. input_ids: Optional[torch.Tensor] = None,
  1048. attention_mask: Optional[torch.Tensor] = None,
  1049. decoder_input_ids: Optional[torch.Tensor] = None,
  1050. decoder_attention_mask: Optional[torch.Tensor] = None,
  1051. head_mask: Optional[torch.Tensor] = None,
  1052. decoder_head_mask: Optional[torch.Tensor] = None,
  1053. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1054. encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
  1055. past_key_values: Optional[Cache] = None,
  1056. inputs_embeds: Optional[torch.Tensor] = None,
  1057. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1058. use_cache: Optional[bool] = None,
  1059. output_attentions: Optional[bool] = None,
  1060. output_hidden_states: Optional[bool] = None,
  1061. return_dict: Optional[bool] = None,
  1062. cache_position: Optional[torch.Tensor] = None,
  1063. ) -> Union[tuple, Seq2SeqModelOutput]:
  1064. r"""
  1065. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1066. Indices of decoder input sequence tokens in the vocabulary.
  1067. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1068. [`PreTrainedTokenizer.__call__`] for details.
  1069. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1070. Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1071. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1072. `past_key_values`).
  1073. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1074. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1075. be used by default.
  1076. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1077. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1078. 1]`:
  1079. - 1 indicates the head is **not masked**,
  1080. - 0 indicates the head is **masked**.
  1081. Example:
  1082. ```python
  1083. >>> from transformers import AutoTokenizer, PegasusModel
  1084. >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
  1085. >>> model = PegasusModel.from_pretrained("google/pegasus-large")
  1086. >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
  1087. >>> decoder_inputs = tokenizer("Studies show that", return_tensors="pt")
  1088. >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_inputs.input_ids)
  1089. >>> last_hidden_states = outputs.last_hidden_state
  1090. >>> list(last_hidden_states.shape)
  1091. [1, 4, 1024]
  1092. ```"""
  1093. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1094. output_hidden_states = (
  1095. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1096. )
  1097. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1098. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1099. if encoder_outputs is None:
  1100. encoder_outputs = self.encoder(
  1101. input_ids=input_ids,
  1102. attention_mask=attention_mask,
  1103. head_mask=head_mask,
  1104. inputs_embeds=inputs_embeds,
  1105. output_attentions=output_attentions,
  1106. output_hidden_states=output_hidden_states,
  1107. return_dict=return_dict,
  1108. )
  1109. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  1110. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1111. encoder_outputs = BaseModelOutput(
  1112. last_hidden_state=encoder_outputs[0],
  1113. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1114. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1115. )
  1116. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1117. decoder_outputs = self.decoder(
  1118. input_ids=decoder_input_ids,
  1119. attention_mask=decoder_attention_mask,
  1120. encoder_hidden_states=encoder_outputs[0],
  1121. encoder_attention_mask=attention_mask,
  1122. head_mask=decoder_head_mask,
  1123. cross_attn_head_mask=cross_attn_head_mask,
  1124. past_key_values=past_key_values,
  1125. inputs_embeds=decoder_inputs_embeds,
  1126. use_cache=use_cache,
  1127. output_attentions=output_attentions,
  1128. output_hidden_states=output_hidden_states,
  1129. return_dict=return_dict,
  1130. cache_position=cache_position,
  1131. )
  1132. if not return_dict:
  1133. return decoder_outputs + encoder_outputs
  1134. return Seq2SeqModelOutput(
  1135. last_hidden_state=decoder_outputs.last_hidden_state,
  1136. past_key_values=decoder_outputs.past_key_values,
  1137. decoder_hidden_states=decoder_outputs.hidden_states,
  1138. decoder_attentions=decoder_outputs.attentions,
  1139. cross_attentions=decoder_outputs.cross_attentions,
  1140. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1141. encoder_hidden_states=encoder_outputs.hidden_states,
  1142. encoder_attentions=encoder_outputs.attentions,
  1143. )
  1144. @auto_docstring(
  1145. custom_intro="""
  1146. The PEGASUS Model with a language modeling head. Can be used for summarization.
  1147. """
  1148. )
  1149. class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
  1150. base_model_prefix = "model"
  1151. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  1152. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  1153. def __init__(self, config: PegasusConfig):
  1154. super().__init__(config)
  1155. self.model = PegasusModel(config)
  1156. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  1157. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  1158. # Initialize weights and apply final processing
  1159. self.post_init()
  1160. def get_encoder(self):
  1161. return self.model.get_encoder()
  1162. def get_decoder(self):
  1163. return self.model.get_decoder()
  1164. def resize_token_embeddings(
  1165. self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
  1166. ) -> nn.Embedding:
  1167. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  1168. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  1169. return new_embeddings
  1170. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  1171. old_num_tokens = self.final_logits_bias.shape[-1]
  1172. if new_num_tokens <= old_num_tokens:
  1173. new_bias = self.final_logits_bias[:, :new_num_tokens]
  1174. else:
  1175. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  1176. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  1177. self.register_buffer("final_logits_bias", new_bias)
  1178. def resize_position_embeddings(self, new_num_position_embeddings: int):
  1179. """
  1180. Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
  1181. config.max_position_embeddings`.
  1182. Arguments:
  1183. new_num_position_embeddings (`int`):
  1184. The number of new position embeddings. If position embeddings are learned, increasing the size will add
  1185. newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
  1186. position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
  1187. add correct vectors at the end following the position encoding algorithm, whereas reducing the size
  1188. will remove vectors from the end.
  1189. """
  1190. self.config.max_position_embeddings = new_num_position_embeddings
  1191. self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
  1192. self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
  1193. def get_position_embeddings(self) -> tuple[nn.Embedding]:
  1194. """
  1195. Returns the position embeddings matrix
  1196. """
  1197. return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())
  1198. @auto_docstring
  1199. def forward(
  1200. self,
  1201. input_ids: Optional[torch.Tensor] = None,
  1202. attention_mask: Optional[torch.Tensor] = None,
  1203. decoder_input_ids: Optional[torch.Tensor] = None,
  1204. decoder_attention_mask: Optional[torch.Tensor] = None,
  1205. head_mask: Optional[torch.Tensor] = None,
  1206. decoder_head_mask: Optional[torch.Tensor] = None,
  1207. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1208. encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
  1209. past_key_values: Optional[Cache] = None,
  1210. inputs_embeds: Optional[torch.Tensor] = None,
  1211. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1212. labels: Optional[torch.Tensor] = None,
  1213. use_cache: Optional[bool] = None,
  1214. output_attentions: Optional[bool] = None,
  1215. output_hidden_states: Optional[bool] = None,
  1216. return_dict: Optional[bool] = None,
  1217. cache_position: Optional[torch.Tensor] = None,
  1218. ) -> Union[tuple, Seq2SeqLMOutput]:
  1219. r"""
  1220. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1221. Indices of decoder input sequence tokens in the vocabulary.
  1222. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1223. [`PreTrainedTokenizer.__call__`] for details.
  1224. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1225. Pegasus uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1226. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1227. `past_key_values`).
  1228. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1229. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1230. be used by default.
  1231. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1232. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1233. 1]`:
  1234. - 1 indicates the head is **not masked**,
  1235. - 0 indicates the head is **masked**.
  1236. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1237. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1238. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1239. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1240. Example Summarization:
  1241. ```python
  1242. >>> from transformers import AutoTokenizer, PegasusForConditionalGeneration
  1243. >>> model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
  1244. >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")
  1245. >>> ARTICLE_TO_SUMMARIZE = (
  1246. ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
  1247. ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
  1248. ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
  1249. ... )
  1250. >>> inputs = tokenizer(ARTICLE_TO_SUMMARIZE, max_length=1024, return_tensors="pt")
  1251. >>> # Generate Summary
  1252. >>> summary_ids = model.generate(inputs["input_ids"])
  1253. >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1254. "California's largest electricity provider has turned off power to hundreds of thousands of customers."
  1255. ```
  1256. """
  1257. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1258. if labels is not None:
  1259. if use_cache:
  1260. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  1261. use_cache = False
  1262. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1263. decoder_input_ids = shift_tokens_right(
  1264. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  1265. )
  1266. outputs = self.model(
  1267. input_ids,
  1268. attention_mask=attention_mask,
  1269. decoder_input_ids=decoder_input_ids,
  1270. encoder_outputs=encoder_outputs,
  1271. decoder_attention_mask=decoder_attention_mask,
  1272. head_mask=head_mask,
  1273. decoder_head_mask=decoder_head_mask,
  1274. cross_attn_head_mask=cross_attn_head_mask,
  1275. past_key_values=past_key_values,
  1276. inputs_embeds=inputs_embeds,
  1277. decoder_inputs_embeds=decoder_inputs_embeds,
  1278. use_cache=use_cache,
  1279. output_attentions=output_attentions,
  1280. output_hidden_states=output_hidden_states,
  1281. return_dict=return_dict,
  1282. cache_position=cache_position,
  1283. )
  1284. lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
  1285. masked_lm_loss = None
  1286. if labels is not None:
  1287. loss_fct = CrossEntropyLoss()
  1288. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  1289. if not return_dict:
  1290. output = (lm_logits,) + outputs[1:]
  1291. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  1292. return Seq2SeqLMOutput(
  1293. loss=masked_lm_loss,
  1294. logits=lm_logits,
  1295. past_key_values=outputs.past_key_values,
  1296. decoder_hidden_states=outputs.decoder_hidden_states,
  1297. decoder_attentions=outputs.decoder_attentions,
  1298. cross_attentions=outputs.cross_attentions,
  1299. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1300. encoder_hidden_states=outputs.encoder_hidden_states,
  1301. encoder_attentions=outputs.encoder_attentions,
  1302. )
  1303. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1304. return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
class PegasusDecoderWrapper(PegasusPreTrainedModel):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.

    It owns a single `PegasusDecoder` under the `decoder` attribute, so a model that stores this wrapper as `model`
    (as `PegasusForCausalLM` does) exposes the decoder weights with a `model.decoder.` prefix.
    """

    def __init__(self, config):
        super().__init__(config)
        # The attribute name `decoder` fixes the parameter-name prefix used when loading/saving weights.
        self.decoder = PegasusDecoder(config)

    def forward(self, *args, **kwargs):
        # Pure pass-through: every positional and keyword argument goes straight to the wrapped decoder.
        return self.decoder(*args, **kwargs)
class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin):
    """
    Decoder-only Pegasus: a `PegasusDecoder` (held inside `PegasusDecoderWrapper`) with a bias-free linear
    language-modeling head on top.

    The config passed to `__init__` is deep-copied and switched to decoder-only mode (`is_decoder=True`,
    `is_encoder_decoder=False`), so the caller's config object is not modified.
    """

    # NOTE(review): presumably tied to the decoder's input embeddings by the `PreTrainedModel` tying logic.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Work on a private copy so the flags set below do not leak into the caller's config.
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        # Storing the decoder as `model.decoder` presumably mirrors the key layout of the encoder-decoder model
        # (whose `PegasusModel` also keeps its decoder at `model.decoder`), so its checkpoints can be loaded here.
        self.model = PegasusDecoderWrapper(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.model.decoder.embed_tokens = value

    def set_decoder(self, decoder):
        """Swap in a different decoder module inside the wrapper."""
        self.model.decoder = decoder

    def get_decoder(self):
        """Return the wrapped decoder module."""
        return self.model.decoder

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings matrix
        """
        return self.model.decoder.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings matrix of the model if `new_num_position_embeddings !=
        config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embeddings. If position embeddings are learned, increasing the size will add
                newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
                position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
                add correct vectors at the end following the position encoding algorithm, whereas reducing the size
                will remove vectors from the end.
        """
        self.config.max_position_embeddings = new_num_position_embeddings
        self.model.decoder.resize_position_embeddings(new_num_position_embeddings)

    @auto_docstring
    # Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus, facebook/bart-base->google/pegasus-large
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, PegasusForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("google/pegasus-large")
        >>> model = PegasusForCausalLM.from_pretrained("google/pegasus-large", add_cross_attention=False)
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
        >>> list(logits.shape) == expected_shape
        True
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
  1433. __all__ = ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]