modular_plbart.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659
  1. # coding=utf-8
  2. # Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch PLBART model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ...cache_utils import Cache
  22. from ...generation import GenerationMixin
  23. from ...modeling_attn_mask_utils import (
  24. AttentionMaskConverter,
  25. _prepare_4d_attention_mask,
  26. _prepare_4d_attention_mask_for_sdpa,
  27. )
  28. from ...modeling_outputs import (
  29. BaseModelOutput,
  30. Seq2SeqLMOutput,
  31. Seq2SeqModelOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...utils import auto_docstring, is_torch_flex_attn_available
  35. from ..bart.modeling_bart import (
  36. BartClassificationHead,
  37. BartDecoder,
  38. BartEncoder,
  39. BartForCausalLM,
  40. BartScaledWordEmbedding,
  41. )
  42. from ..bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusForSequenceClassification
  43. from ..mbart.modeling_mbart import shift_tokens_right
  44. from .configuration_plbart import PLBartConfig
  45. if is_torch_flex_attn_available():
  46. from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
# PLBart reuses BART's scaled word embedding without modification; this subclass only gives it a
# PLBart-prefixed name. It is instantiated as `self.shared` in `PLBartModel.__init__`.
# NOTE(review): the body is intentionally just `pass` — presumably the modular converter copies the
# parent implementation under the new name; confirm before adding anything to the body.
class PLBartScaledWordEmbedding(BartScaledWordEmbedding):
    pass
  49. @auto_docstring
  50. class PLBartPreTrainedModel(PreTrainedModel):
  51. config: PLBartConfig
  52. base_model_prefix = "model"
  53. supports_gradient_checkpointing = True
  54. _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
  55. _supports_flash_attn = True
  56. _supports_sdpa = True
  57. _supports_flex_attn = True
  58. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
  59. def _update_full_mask(
  60. self,
  61. attention_mask: Union[torch.Tensor, None],
  62. inputs_embeds: torch.Tensor,
  63. ):
  64. if attention_mask is not None:
  65. if self.config._attn_implementation == "flash_attention_2":
  66. attention_mask = attention_mask if 0 in attention_mask else None
  67. elif self.config._attn_implementation == "sdpa":
  68. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  69. # the manual implementation that requires a 4D causal mask in all cases.
  70. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  71. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  72. elif self.config._attn_implementation == "flex_attention":
  73. if isinstance(attention_mask, torch.Tensor):
  74. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  75. else:
  76. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  77. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  78. return attention_mask
  79. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
  80. def _update_causal_mask(
  81. self,
  82. attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
  83. input_tensor: torch.Tensor,
  84. cache_position: torch.Tensor,
  85. past_key_values: Cache,
  86. ):
  87. if self.config._attn_implementation == "flex_attention":
  88. if isinstance(attention_mask, torch.Tensor):
  89. attention_mask = make_flex_block_causal_mask(attention_mask)
  90. # Other attention flavors support in-built causal (when `mask is None`)
  91. # while we need to create our specific block mask regardless
  92. elif attention_mask is None:
  93. attention_mask = make_flex_block_causal_mask(
  94. torch.ones(
  95. size=(input_tensor.shape[0], input_tensor.shape[1]),
  96. device=attention_mask.device,
  97. )
  98. )
  99. return attention_mask
  100. if self.config._attn_implementation == "flash_attention_2":
  101. if attention_mask is not None and (attention_mask == 0.0).any():
  102. return attention_mask
  103. return None
  104. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  105. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  106. # to infer the attention mask.
  107. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  108. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  109. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  110. if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
  111. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  112. attention_mask,
  113. inputs_embeds=input_tensor,
  114. past_key_values_length=past_seen_tokens,
  115. is_training=self.training,
  116. ):
  117. return None
  118. dtype = input_tensor.dtype
  119. sequence_length = input_tensor.shape[1]
  120. if using_compilable_cache:
  121. target_length = past_key_values.get_max_cache_shape()
  122. else:
  123. target_length = (
  124. attention_mask.shape[-1]
  125. if isinstance(attention_mask, torch.Tensor)
  126. else past_seen_tokens + sequence_length + 1
  127. )
  128. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  129. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  130. attention_mask,
  131. sequence_length=sequence_length,
  132. target_length=target_length,
  133. dtype=dtype,
  134. cache_position=cache_position,
  135. batch_size=input_tensor.shape[0],
  136. )
  137. if (
  138. self.config._attn_implementation == "sdpa"
  139. and attention_mask is not None
  140. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  141. ):
  142. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  143. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  144. # Details: https://github.com/pytorch/pytorch/issues/110213
  145. min_dtype = torch.finfo(dtype).min
  146. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  147. return causal_mask
  148. @staticmethod
  149. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  150. def _prepare_4d_causal_attention_mask_with_cache_position(
  151. attention_mask: torch.Tensor,
  152. sequence_length: int,
  153. target_length: int,
  154. dtype: torch.dtype,
  155. cache_position: torch.Tensor,
  156. batch_size: int,
  157. **kwargs,
  158. ):
  159. """
  160. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  161. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  162. Args:
  163. attention_mask (`torch.Tensor`):
  164. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  165. `(batch_size, 1, query_length, key_value_length)`.
  166. sequence_length (`int`):
  167. The sequence length being processed.
  168. target_length (`int`):
  169. The target length: when generating with static cache, the mask should be as long as the static cache,
  170. to account for the 0 padding, the part of the cache that is not filled yet.
  171. dtype (`torch.dtype`):
  172. The dtype to use for the 4D attention mask.
  173. cache_position (`torch.Tensor`):
  174. Indices depicting the position of the input sequence tokens in the sequence.
  175. batch_size (`torch.Tensor`):
  176. Batch size.
  177. """
  178. if attention_mask is not None and attention_mask.dim() == 4:
  179. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  180. causal_mask = attention_mask
  181. else:
  182. min_dtype = torch.finfo(dtype).min
  183. causal_mask = torch.full(
  184. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  185. )
  186. if sequence_length != 1:
  187. causal_mask = torch.triu(causal_mask, diagonal=1)
  188. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  189. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  190. if attention_mask is not None:
  191. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  192. mask_length = attention_mask.shape[-1]
  193. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  194. causal_mask.device
  195. )
  196. padding_mask = padding_mask == 0
  197. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  198. padding_mask, min_dtype
  199. )
  200. return causal_mask
  201. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
  202. def _update_cross_attn_mask(
  203. self,
  204. encoder_hidden_states: Union[torch.Tensor, None],
  205. encoder_attention_mask: Union[torch.Tensor, None],
  206. input_shape: torch.Size,
  207. inputs_embeds: torch.Tensor,
  208. ):
  209. # expand encoder attention mask
  210. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  211. if self.config._attn_implementation == "flash_attention_2":
  212. encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
  213. elif self.config._attn_implementation == "sdpa":
  214. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  215. # the manual implementation that requires a 4D causal mask in all cases.
  216. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  217. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  218. encoder_attention_mask,
  219. inputs_embeds.dtype,
  220. tgt_len=input_shape[-1],
  221. )
  222. elif self.config._attn_implementation == "flex_attention":
  223. if isinstance(encoder_attention_mask, torch.Tensor):
  224. encoder_attention_mask = make_flex_block_causal_mask(
  225. encoder_attention_mask,
  226. query_length=input_shape[-1],
  227. is_causal=False,
  228. )
  229. else:
  230. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  231. encoder_attention_mask = _prepare_4d_attention_mask(
  232. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  233. )
  234. return encoder_attention_mask
# PLBart's encoder is BART's encoder under a PLBart-prefixed name; `PLBartModel.__init__` builds it
# as `PLBartEncoder(config, self.shared)`.
# NOTE(review): the empty body is presumably what tells the modular converter to reuse the parent
# implementation as-is — confirm before adding members here.
class PLBartEncoder(BartEncoder):
    pass
# PLBart's decoder is BART's decoder under a PLBart-prefixed name; `PLBartModel.__init__` builds it
# as `PLBartDecoder(config, self.shared)`.
# NOTE(review): the empty body is presumably what tells the modular converter to reuse the parent
# implementation as-is — confirm before adding members here.
class PLBartDecoder(BartDecoder):
    pass
  239. @auto_docstring
  240. class PLBartModel(PLBartPreTrainedModel):
  241. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  242. def __init__(self, config: PLBartConfig):
  243. super().__init__(config)
  244. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  245. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  246. self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  247. self.encoder = PLBartEncoder(config, self.shared)
  248. self.decoder = PLBartDecoder(config, self.shared)
  249. self.init_weights()
  250. def get_input_embeddings(self):
  251. return self.shared
  252. def set_input_embeddings(self, value):
  253. self.shared = value
  254. self.encoder.embed_tokens = self.shared
  255. self.decoder.embed_tokens = self.shared
  256. def _tie_weights(self):
  257. if self.config.tie_word_embeddings:
  258. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  259. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  260. def get_encoder(self):
  261. return self.encoder
  262. @auto_docstring
  263. def forward(
  264. self,
  265. input_ids: Optional[torch.LongTensor] = None,
  266. attention_mask: Optional[torch.LongTensor] = None,
  267. decoder_input_ids: Optional[torch.LongTensor] = None,
  268. decoder_attention_mask: Optional[torch.Tensor] = None,
  269. head_mask: Optional[torch.Tensor] = None,
  270. decoder_head_mask: Optional[torch.LongTensor] = None,
  271. cross_attn_head_mask: Optional[torch.Tensor] = None,
  272. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  273. past_key_values: Optional[Cache] = None,
  274. inputs_embeds: Optional[torch.FloatTensor] = None,
  275. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  276. use_cache: Optional[bool] = None,
  277. output_attentions: Optional[bool] = None,
  278. output_hidden_states: Optional[bool] = None,
  279. return_dict: Optional[bool] = None,
  280. cache_position: Optional[torch.LongTensor] = None,
  281. ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
  282. r"""
  283. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  284. Indices of decoder input sequence tokens in the vocabulary.
  285. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  286. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  287. [What are decoder input IDs?](../glossary#decoder-input-ids)
  288. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  289. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  290. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  291. `past_key_values`).
  292. For translation and summarization training, `decoder_input_ids` should be provided. If no
  293. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  294. for denoising pre-training following the paper.
  295. decoder_attention_mask (:
  296. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  297. Default behavior:
  298. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  299. cross_attn_head_mask (:
  300. obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  301. Mask to nullify
  302. selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:
  303. - 1 indicates the head is **not masked**,
  304. - 0 indicates the head is **masked**.
  305. """
  306. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  307. output_hidden_states = (
  308. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  309. )
  310. use_cache = use_cache if use_cache is not None else self.config.use_cache
  311. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  312. # different to other models, PLBart automatically creates decoder_input_ids from
  313. # input_ids if no decoder_input_ids are provided
  314. if decoder_input_ids is None and decoder_inputs_embeds is None:
  315. decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
  316. if encoder_outputs is None:
  317. encoder_outputs = self.encoder(
  318. input_ids=input_ids,
  319. attention_mask=attention_mask,
  320. head_mask=head_mask,
  321. inputs_embeds=inputs_embeds,
  322. output_attentions=output_attentions,
  323. output_hidden_states=output_hidden_states,
  324. return_dict=return_dict,
  325. )
  326. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  327. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  328. encoder_outputs = BaseModelOutput(
  329. last_hidden_state=encoder_outputs[0],
  330. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  331. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  332. )
  333. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  334. decoder_outputs = self.decoder(
  335. input_ids=decoder_input_ids,
  336. attention_mask=decoder_attention_mask,
  337. encoder_hidden_states=encoder_outputs[0],
  338. encoder_attention_mask=attention_mask,
  339. head_mask=decoder_head_mask,
  340. cross_attn_head_mask=cross_attn_head_mask,
  341. past_key_values=past_key_values,
  342. inputs_embeds=decoder_inputs_embeds,
  343. use_cache=use_cache,
  344. output_attentions=output_attentions,
  345. output_hidden_states=output_hidden_states,
  346. return_dict=return_dict,
  347. cache_position=cache_position,
  348. )
  349. if not return_dict:
  350. return decoder_outputs + encoder_outputs
  351. return Seq2SeqModelOutput(
  352. last_hidden_state=decoder_outputs.last_hidden_state,
  353. past_key_values=decoder_outputs.past_key_values,
  354. decoder_hidden_states=decoder_outputs.hidden_states,
  355. decoder_attentions=decoder_outputs.attentions,
  356. cross_attentions=decoder_outputs.cross_attentions,
  357. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  358. encoder_hidden_states=encoder_outputs.hidden_states,
  359. encoder_attentions=encoder_outputs.attentions,
  360. )
  361. @auto_docstring(
  362. custom_intro="""
  363. The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.
  364. """
  365. )
  366. class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin):
  367. base_model_prefix = "model"
  368. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  369. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  370. def __init__(self, config: PLBartConfig):
  371. super().__init__(config)
  372. self.model = PLBartModel(config)
  373. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  374. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  375. self.init_weights()
  376. def get_encoder(self):
  377. return self.model.get_encoder()
  378. def get_decoder(self):
  379. return self.model.get_decoder()
  380. def resize_token_embeddings(
  381. self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
  382. ) -> nn.Embedding:
  383. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  384. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  385. return new_embeddings
  386. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  387. old_num_tokens = self.final_logits_bias.shape[-1]
  388. if new_num_tokens <= old_num_tokens:
  389. new_bias = self.final_logits_bias[:, :new_num_tokens]
  390. else:
  391. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  392. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  393. self.register_buffer("final_logits_bias", new_bias)
  394. @auto_docstring
  395. def forward(
  396. self,
  397. input_ids: Optional[torch.LongTensor] = None,
  398. attention_mask: Optional[torch.LongTensor] = None,
  399. decoder_input_ids: Optional[torch.LongTensor] = None,
  400. decoder_attention_mask: Optional[torch.Tensor] = None,
  401. head_mask: Optional[torch.Tensor] = None,
  402. decoder_head_mask: Optional[torch.LongTensor] = None,
  403. cross_attn_head_mask: Optional[torch.Tensor] = None,
  404. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  405. past_key_values: Optional[Cache] = None,
  406. inputs_embeds: Optional[torch.FloatTensor] = None,
  407. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  408. labels: Optional[torch.Tensor] = None,
  409. use_cache: Optional[bool] = None,
  410. output_attentions: Optional[bool] = None,
  411. output_hidden_states: Optional[bool] = None,
  412. return_dict: Optional[bool] = None,
  413. cache_position: Optional[torch.LongTensor] = None,
  414. ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
  415. r"""
  416. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  417. Indices of decoder input sequence tokens in the vocabulary.
  418. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  419. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  420. [What are decoder input IDs?](../glossary#decoder-input-ids)
  421. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  422. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  423. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  424. `past_key_values`).
  425. For translation and summarization training, `decoder_input_ids` should be provided. If no
  426. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  427. for denoising pre-training following the paper.
  428. decoder_attention_mask (:
  429. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  430. Default behavior:
  431. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  432. cross_attn_head_mask (:
  433. obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  434. Mask to nullify
  435. selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:
  436. - 1 indicates the head is **not masked**,
  437. - 0 indicates the head is **masked**.
  438. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  439. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  440. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  441. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  442. Example Mask-filling:
  443. ```python
  444. >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration
  445. >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
  446. >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
  447. >>> # en_XX is the language symbol id <LID> for English
  448. >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? </s> en_XX"
  449. >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids
  450. >>> logits = model(input_ids).logits
  451. >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
  452. >>> probs = logits[0, masked_index].softmax(dim=0)
  453. >>> values, predictions = probs.topk(5)
  454. >>> tokenizer.decode(predictions).split()
  455. ['first', 'same', 'highest', 'result', 'number']
  456. ```
  457. """
  458. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  459. if labels is not None:
  460. if decoder_input_ids is None and decoder_inputs_embeds is None:
  461. decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
  462. outputs = self.model(
  463. input_ids,
  464. attention_mask=attention_mask,
  465. decoder_input_ids=decoder_input_ids,
  466. encoder_outputs=encoder_outputs,
  467. decoder_attention_mask=decoder_attention_mask,
  468. head_mask=head_mask,
  469. decoder_head_mask=decoder_head_mask,
  470. cross_attn_head_mask=cross_attn_head_mask,
  471. past_key_values=past_key_values,
  472. inputs_embeds=inputs_embeds,
  473. decoder_inputs_embeds=decoder_inputs_embeds,
  474. use_cache=use_cache,
  475. output_attentions=output_attentions,
  476. output_hidden_states=output_hidden_states,
  477. return_dict=return_dict,
  478. cache_position=cache_position,
  479. )
  480. lm_logits = self.lm_head(outputs[0])
  481. lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
  482. masked_lm_loss = None
  483. if labels is not None:
  484. loss_fct = CrossEntropyLoss()
  485. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  486. if not return_dict:
  487. output = (lm_logits,) + outputs[1:]
  488. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  489. return Seq2SeqLMOutput(
  490. loss=masked_lm_loss,
  491. logits=lm_logits,
  492. past_key_values=outputs.past_key_values,
  493. decoder_hidden_states=outputs.decoder_hidden_states,
  494. decoder_attentions=outputs.decoder_attentions,
  495. cross_attentions=outputs.cross_attentions,
  496. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  497. encoder_hidden_states=outputs.encoder_hidden_states,
  498. encoder_attentions=outputs.encoder_attentions,
  499. )
  500. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  501. return shift_tokens_right(labels, self.config.pad_token_id)
# BART's classification head under a PLBart-prefixed name. It is not referenced by any other
# definition visible in this file.
# NOTE(review): presumably consumed by the sequence-classification model once the modular converter
# expands the parent classes — confirm against the generated modeling file.
class PLBartClassificationHead(BartClassificationHead):
    pass
# Sequence-classification model derived from `BigBirdPegasusForSequenceClassification`. The only thing
# overridden here is `forward`, which supplies PLBart-specific argument documentation and forwards all
# keyword arguments to the parent implementation.
# NOTE(review): `forward(**super_kwargs)` has no `self` and does not return the parent's result, which
# looks like the modular-converter placeholder idiom rather than directly executable code — confirm it
# is expanded by the code generator before changing the signature or body.
class PLBartForSequenceClassification(BigBirdPegasusForSequenceClassification):
    def forward(**super_kwargs):
        r"""
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
            See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
            varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            For translation and summarization training, `decoder_input_ids` should be provided. If no
            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
            for denoising pre-training following the paper.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        super().forward(**super_kwargs)
# Decoder-only language model derived from `BartForCausalLM`. Only `forward` is overridden, to attach
# PLBart-specific argument documentation and a usage example; all keyword arguments are forwarded to
# the parent implementation.
# NOTE(review): `forward(**super_kwargs)` has no `self` and does not return the parent's result, which
# looks like the modular-converter placeholder idiom rather than directly executable code — confirm it
# is expanded by the code generator before changing the signature or body.
class PLBartForCausalLM(BartForCausalLM):
    @auto_docstring
    def forward(**super_kwargs):
        r"""
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, PLBartForCausalLM

        >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
        >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False)
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
        >>> list(logits.shape) == expected_shape
        True
        ```"""
        super().forward(**super_kwargs)
# Public API of this module: the five model classes meant to be imported by users. The helper
# subclasses (scaled embedding, encoder, decoder, classification head) are deliberately not listed.
__all__ = [
    "PLBartForCausalLM",
    "PLBartForConditionalGeneration",
    "PLBartForSequenceClassification",
    "PLBartModel",
    "PLBartPreTrainedModel",
]