modeling_plbart.py 79 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/plbart/modular_plbart.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_plbart.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from typing import Callable, Optional, Union
  23. import torch
  24. from torch import nn
  25. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  28. from ...generation import GenerationMixin
  29. from ...modeling_attn_mask_utils import (
  30. AttentionMaskConverter,
  31. _prepare_4d_attention_mask,
  32. _prepare_4d_attention_mask_for_sdpa,
  33. )
  34. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  35. from ...modeling_layers import GradientCheckpointingLayer
  36. from ...modeling_outputs import (
  37. BaseModelOutput,
  38. BaseModelOutputWithPastAndCrossAttentions,
  39. CausalLMOutputWithCrossAttentions,
  40. Seq2SeqLMOutput,
  41. Seq2SeqModelOutput,
  42. Seq2SeqSequenceClassifierOutput,
  43. )
  44. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  45. from ...processing_utils import Unpack
  46. from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
  47. from ...utils.deprecation import deprecate_kwarg
  48. from .configuration_plbart import PLBartConfig
  49. if is_torch_flex_attn_available():
  50. from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
# Module-level logger obtained from the library's `...utils.logging` helpers (used below, e.g. for the
# missing-`layer_idx` warning in PLBartAttention).
logger = logging.get_logger(__name__)
  52. class PLBartScaledWordEmbedding(nn.Embedding):
  53. """
  54. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  55. """
  56. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
  57. super().__init__(num_embeddings, embedding_dim, padding_idx)
  58. self.embed_scale = embed_scale
  59. def forward(self, input_ids: torch.Tensor):
  60. return super().forward(input_ids) * self.embed_scale
  61. @auto_docstring
  62. class PLBartPreTrainedModel(PreTrainedModel):
  63. config: PLBartConfig
  64. base_model_prefix = "model"
  65. supports_gradient_checkpointing = True
  66. _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
  67. _supports_flash_attn = True
  68. _supports_sdpa = True
  69. _supports_flex_attn = True
  70. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
  71. def _update_full_mask(
  72. self,
  73. attention_mask: Union[torch.Tensor, None],
  74. inputs_embeds: torch.Tensor,
  75. ):
  76. if attention_mask is not None:
  77. if self.config._attn_implementation == "flash_attention_2":
  78. attention_mask = attention_mask if 0 in attention_mask else None
  79. elif self.config._attn_implementation == "sdpa":
  80. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  81. # the manual implementation that requires a 4D causal mask in all cases.
  82. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  83. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  84. elif self.config._attn_implementation == "flex_attention":
  85. if isinstance(attention_mask, torch.Tensor):
  86. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  87. else:
  88. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  89. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  90. return attention_mask
  91. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
  92. def _update_causal_mask(
  93. self,
  94. attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
  95. input_tensor: torch.Tensor,
  96. cache_position: torch.Tensor,
  97. past_key_values: Cache,
  98. ):
  99. if self.config._attn_implementation == "flex_attention":
  100. if isinstance(attention_mask, torch.Tensor):
  101. attention_mask = make_flex_block_causal_mask(attention_mask)
  102. # Other attention flavors support in-built causal (when `mask is None`)
  103. # while we need to create our specific block mask regardless
  104. elif attention_mask is None:
  105. attention_mask = make_flex_block_causal_mask(
  106. torch.ones(
  107. size=(input_tensor.shape[0], input_tensor.shape[1]),
  108. device=attention_mask.device,
  109. )
  110. )
  111. return attention_mask
  112. if self.config._attn_implementation == "flash_attention_2":
  113. if attention_mask is not None and (attention_mask == 0.0).any():
  114. return attention_mask
  115. return None
  116. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  117. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  118. # to infer the attention mask.
  119. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  120. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  121. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  122. if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
  123. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  124. attention_mask,
  125. inputs_embeds=input_tensor,
  126. past_key_values_length=past_seen_tokens,
  127. is_training=self.training,
  128. ):
  129. return None
  130. dtype = input_tensor.dtype
  131. sequence_length = input_tensor.shape[1]
  132. if using_compilable_cache:
  133. target_length = past_key_values.get_max_cache_shape()
  134. else:
  135. target_length = (
  136. attention_mask.shape[-1]
  137. if isinstance(attention_mask, torch.Tensor)
  138. else past_seen_tokens + sequence_length + 1
  139. )
  140. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  141. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  142. attention_mask,
  143. sequence_length=sequence_length,
  144. target_length=target_length,
  145. dtype=dtype,
  146. cache_position=cache_position,
  147. batch_size=input_tensor.shape[0],
  148. )
  149. if (
  150. self.config._attn_implementation == "sdpa"
  151. and attention_mask is not None
  152. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  153. ):
  154. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  155. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  156. # Details: https://github.com/pytorch/pytorch/issues/110213
  157. min_dtype = torch.finfo(dtype).min
  158. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  159. return causal_mask
  160. @staticmethod
  161. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  162. def _prepare_4d_causal_attention_mask_with_cache_position(
  163. attention_mask: torch.Tensor,
  164. sequence_length: int,
  165. target_length: int,
  166. dtype: torch.dtype,
  167. cache_position: torch.Tensor,
  168. batch_size: int,
  169. **kwargs,
  170. ):
  171. """
  172. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  173. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  174. Args:
  175. attention_mask (`torch.Tensor`):
  176. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  177. `(batch_size, 1, query_length, key_value_length)`.
  178. sequence_length (`int`):
  179. The sequence length being processed.
  180. target_length (`int`):
  181. The target length: when generating with static cache, the mask should be as long as the static cache,
  182. to account for the 0 padding, the part of the cache that is not filled yet.
  183. dtype (`torch.dtype`):
  184. The dtype to use for the 4D attention mask.
  185. cache_position (`torch.Tensor`):
  186. Indices depicting the position of the input sequence tokens in the sequence.
  187. batch_size (`torch.Tensor`):
  188. Batch size.
  189. """
  190. if attention_mask is not None and attention_mask.dim() == 4:
  191. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  192. causal_mask = attention_mask
  193. else:
  194. min_dtype = torch.finfo(dtype).min
  195. causal_mask = torch.full(
  196. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  197. )
  198. if sequence_length != 1:
  199. causal_mask = torch.triu(causal_mask, diagonal=1)
  200. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  201. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  202. if attention_mask is not None:
  203. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  204. mask_length = attention_mask.shape[-1]
  205. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  206. causal_mask.device
  207. )
  208. padding_mask = padding_mask == 0
  209. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  210. padding_mask, min_dtype
  211. )
  212. return causal_mask
  213. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
  214. def _update_cross_attn_mask(
  215. self,
  216. encoder_hidden_states: Union[torch.Tensor, None],
  217. encoder_attention_mask: Union[torch.Tensor, None],
  218. input_shape: torch.Size,
  219. inputs_embeds: torch.Tensor,
  220. ):
  221. # expand encoder attention mask
  222. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  223. if self.config._attn_implementation == "flash_attention_2":
  224. encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
  225. elif self.config._attn_implementation == "sdpa":
  226. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  227. # the manual implementation that requires a 4D causal mask in all cases.
  228. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  229. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  230. encoder_attention_mask,
  231. inputs_embeds.dtype,
  232. tgt_len=input_shape[-1],
  233. )
  234. elif self.config._attn_implementation == "flex_attention":
  235. if isinstance(encoder_attention_mask, torch.Tensor):
  236. encoder_attention_mask = make_flex_block_causal_mask(
  237. encoder_attention_mask,
  238. query_length=input_shape[-1],
  239. is_causal=False,
  240. )
  241. else:
  242. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  243. encoder_attention_mask = _prepare_4d_attention_mask(
  244. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  245. )
  246. return encoder_attention_mask
  247. class PLBartLearnedPositionalEmbedding(nn.Embedding):
  248. """
  249. This module learns positional embeddings up to a fixed maximum size.
  250. """
  251. def __init__(self, num_embeddings: int, embedding_dim: int):
  252. # PLBart is set up so that if padding_idx is specified then offset the embedding ids by 2
  253. # and adjust num_embeddings appropriately. Other models don't have this hack
  254. self.offset = 2
  255. super().__init__(num_embeddings + self.offset, embedding_dim)
  256. def forward(
  257. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
  258. ):
  259. """`input_ids' shape is expected to be [bsz x seqlen]."""
  260. if position_ids is None:
  261. bsz, seq_len = input_ids.shape[:2]
  262. position_ids = torch.arange(
  263. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  264. ).expand(bsz, -1)
  265. else:
  266. position_ids = position_ids.unsqueeze(0)
  267. return super().forward(position_ids + self.offset)
  268. def eager_attention_forward(
  269. module: nn.Module,
  270. query: torch.Tensor,
  271. key: torch.Tensor,
  272. value: torch.Tensor,
  273. attention_mask: Optional[torch.Tensor],
  274. scaling: Optional[float] = None,
  275. dropout: float = 0.0,
  276. head_mask: Optional[torch.Tensor] = None,
  277. **kwargs,
  278. ):
  279. if scaling is None:
  280. scaling = query.size(-1) ** -0.5
  281. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  282. if attention_mask is not None:
  283. attn_weights = attn_weights + attention_mask
  284. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  285. if head_mask is not None:
  286. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  287. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  288. attn_output = torch.matmul(attn_weights, value)
  289. attn_output = attn_output.transpose(1, 2).contiguous()
  290. return attn_output, attn_weights
class PLBartAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper

    Used as encoder self-attention, decoder self-attention, or (when `key_value_states` is passed to `forward`)
    decoder cross-attention. The actual attention kernel is selected from `config._attn_implementation`:
    `eager_attention_forward` for "eager", otherwise an entry of `ALL_ATTENTION_FUNCTIONS`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PLBartConfig] = None,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Attention-probability dropout; only applied in training mode (see `forward`).
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        # NOTE(review): defaults to None, but `forward` reads `self.config._attn_implementation`,
        # so a config is effectively required before calling the module.
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Standard 1/sqrt(head_dim) softmax temperature, passed to the attention kernel.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        # Not read in this class; presumably consumed by non-eager kernels, which receive `self` as `module`.
        self.is_causal = is_causal
        # Index of this layer inside the cache objects; required when caching is used.
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: query-side input `(bsz, tgt_len, embed_dim)`.
            key_value_states: when given, the layer acts as cross-attention. Keys and values are then projected
                from this tensor (`(bsz, src_len, embed_dim)`) instead of from `hidden_states`.
            past_key_values: optional cache. With an `EncoderDecoderCache`, self- and cross-attention use its
                separate sub-caches. Cross-attention keys/values are stored once and reused on later calls.
            attention_mask: forwarded unchanged to the attention kernel.
            layer_head_mask: forwarded to the kernel as `head_mask`.
            output_attentions: forwarded to the kernel.
            cache_position: forwarded to the cache update for self-attention only.

        Returns:
            `(attn_output, attn_weights)`. `attn_output` has shape `(bsz, tgt_len, embed_dim)` after `out_proj`.
            `attn_weights` is whatever the selected kernel returns.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        # `-1` lets `view` infer the head count from the projected width (embed_dim // head_dim).
        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        # get query proj; after the transpose the layout is (bsz, num_heads, tgt_len, head_dim)
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        # True once this layer's cross-attention keys/values have been written to an EncoderDecoderCache.
        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(*kv_input_shape).transpose(1, 2)
            value_states = value_states.view(*kv_input_shape).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                # (cross-attention entries are written without a cache position)
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Pick the attention kernel: the local eager implementation, or a registered backend.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # attention dropout only while training
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            head_mask=layer_head_mask,
            **kwargs,
        )

        # Merge heads back to (bsz, tgt_len, embed_dim) before the output projection.
        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
  402. class PLBartEncoderLayer(GradientCheckpointingLayer):
  403. def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None):
  404. super().__init__()
  405. self.embed_dim = config.d_model
  406. self.self_attn = PLBartAttention(
  407. embed_dim=self.embed_dim,
  408. num_heads=config.encoder_attention_heads,
  409. dropout=config.attention_dropout,
  410. config=config,
  411. layer_idx=layer_idx,
  412. )
  413. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  414. self.dropout = config.dropout
  415. self.activation_fn = ACT2FN[config.activation_function]
  416. self.activation_dropout = config.activation_dropout
  417. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  418. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  419. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  420. def forward(
  421. self,
  422. hidden_states: torch.FloatTensor,
  423. attention_mask: torch.FloatTensor,
  424. layer_head_mask: torch.FloatTensor,
  425. output_attentions: Optional[bool] = False,
  426. ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
  427. """
  428. Args:
  429. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  430. attention_mask (`torch.FloatTensor`): attention mask of size
  431. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  432. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  433. `(encoder_attention_heads,)`.
  434. output_attentions (`bool`, *optional*):
  435. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  436. returned tensors for more detail.
  437. """
  438. residual = hidden_states
  439. hidden_states, attn_weights = self.self_attn(
  440. hidden_states=hidden_states,
  441. attention_mask=attention_mask,
  442. layer_head_mask=layer_head_mask,
  443. output_attentions=output_attentions,
  444. )
  445. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  446. hidden_states = residual + hidden_states
  447. hidden_states = self.self_attn_layer_norm(hidden_states)
  448. residual = hidden_states
  449. hidden_states = self.activation_fn(self.fc1(hidden_states))
  450. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  451. hidden_states = self.fc2(hidden_states)
  452. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  453. hidden_states = residual + hidden_states
  454. hidden_states = self.final_layer_norm(hidden_states)
  455. if hidden_states.dtype == torch.float16 and (
  456. torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
  457. ):
  458. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  459. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  460. outputs = (hidden_states,)
  461. if output_attentions:
  462. outputs += (attn_weights,)
  463. return outputs
  464. class PLBartEncoder(PLBartPreTrainedModel):
  465. """
  466. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  467. [`PLBartEncoderLayer`].
  468. Args:
  469. config: PLBartConfig
  470. embed_tokens (nn.Embedding): output embedding
  471. """
  472. def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None):
  473. super().__init__(config)
  474. self.dropout = config.dropout
  475. self.layerdrop = config.encoder_layerdrop
  476. embed_dim = config.d_model
  477. self.padding_idx = config.pad_token_id
  478. self.max_source_positions = config.max_position_embeddings
  479. embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  480. self.embed_tokens = PLBartScaledWordEmbedding(
  481. config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
  482. )
  483. if embed_tokens is not None:
  484. self.embed_tokens.weight = embed_tokens.weight
  485. self.embed_positions = PLBartLearnedPositionalEmbedding(
  486. config.max_position_embeddings,
  487. embed_dim,
  488. )
  489. self.layers = nn.ModuleList([PLBartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)])
  490. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  491. self.gradient_checkpointing = False
  492. # Initialize weights and apply final processing
  493. self.post_init()
  494. def forward(
  495. self,
  496. input_ids: Optional[torch.LongTensor] = None,
  497. attention_mask: Optional[torch.Tensor] = None,
  498. head_mask: Optional[torch.Tensor] = None,
  499. inputs_embeds: Optional[torch.FloatTensor] = None,
  500. output_attentions: Optional[bool] = None,
  501. output_hidden_states: Optional[bool] = None,
  502. return_dict: Optional[bool] = None,
  503. ) -> Union[tuple, BaseModelOutput]:
  504. r"""
  505. Args:
  506. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  507. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  508. provide it.
  509. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  510. [`PreTrainedTokenizer.__call__`] for details.
  511. [What are input IDs?](../glossary#input-ids)
  512. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  513. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  514. - 1 for tokens that are **not masked**,
  515. - 0 for tokens that are **masked**.
  516. [What are attention masks?](../glossary#attention-mask)
  517. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  518. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  519. - 1 indicates the head is **not masked**,
  520. - 0 indicates the head is **masked**.
  521. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  522. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  523. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  524. than the model's internal embedding lookup matrix.
  525. output_attentions (`bool`, *optional*):
  526. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  527. returned tensors for more detail.
  528. output_hidden_states (`bool`, *optional*):
  529. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  530. for more detail.
  531. return_dict (`bool`, *optional*):
  532. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  533. """
  534. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  535. output_hidden_states = (
  536. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  537. )
  538. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  539. # retrieve input_ids and inputs_embeds
  540. if input_ids is not None and inputs_embeds is not None:
  541. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  542. elif input_ids is not None:
  543. input = input_ids
  544. input_ids = input_ids.view(-1, input_ids.shape[-1])
  545. elif inputs_embeds is not None:
  546. input = inputs_embeds[:, :, -1]
  547. else:
  548. raise ValueError("You have to specify either input_ids or inputs_embeds")
  549. if inputs_embeds is None:
  550. inputs_embeds = self.embed_tokens(input_ids)
  551. embed_pos = self.embed_positions(input)
  552. embed_pos = embed_pos.to(inputs_embeds.device)
  553. hidden_states = inputs_embeds + embed_pos
  554. hidden_states = self.layernorm_embedding(hidden_states)
  555. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  556. attention_mask = self._update_full_mask(
  557. attention_mask,
  558. inputs_embeds,
  559. )
  560. encoder_states = () if output_hidden_states else None
  561. all_attentions = () if output_attentions else None
  562. # check if head_mask has a correct number of layers specified if desired
  563. if head_mask is not None:
  564. if head_mask.size()[0] != (len(self.layers)):
  565. raise ValueError(
  566. f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
  567. f" {head_mask.size()[0]}."
  568. )
  569. for idx, encoder_layer in enumerate(self.layers):
  570. if output_hidden_states:
  571. encoder_states = encoder_states + (hidden_states,)
  572. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  573. to_drop = False
  574. if self.training:
  575. dropout_probability = torch.rand([])
  576. if dropout_probability < self.layerdrop: # skip the layer
  577. to_drop = True
  578. if to_drop:
  579. layer_outputs = (None, None)
  580. else:
  581. layer_outputs = encoder_layer(
  582. hidden_states,
  583. attention_mask,
  584. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  585. output_attentions=output_attentions,
  586. )
  587. hidden_states = layer_outputs[0]
  588. if output_attentions:
  589. all_attentions = all_attentions + (layer_outputs[1],)
  590. if output_hidden_states:
  591. encoder_states = encoder_states + (hidden_states,)
  592. if not return_dict:
  593. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  594. return BaseModelOutput(
  595. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  596. )
  597. class PLBartDecoderLayer(GradientCheckpointingLayer):
  598. def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None):
  599. super().__init__()
  600. self.embed_dim = config.d_model
  601. self.self_attn = PLBartAttention(
  602. embed_dim=self.embed_dim,
  603. num_heads=config.decoder_attention_heads,
  604. dropout=config.attention_dropout,
  605. is_decoder=True,
  606. is_causal=True,
  607. config=config,
  608. layer_idx=layer_idx,
  609. )
  610. self.dropout = config.dropout
  611. self.activation_fn = ACT2FN[config.activation_function]
  612. self.activation_dropout = config.activation_dropout
  613. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  614. self.encoder_attn = PLBartAttention(
  615. self.embed_dim,
  616. config.decoder_attention_heads,
  617. dropout=config.attention_dropout,
  618. is_decoder=True,
  619. config=config,
  620. layer_idx=layer_idx,
  621. )
  622. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  623. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  624. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  625. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  626. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  627. def forward(
  628. self,
  629. hidden_states: torch.Tensor,
  630. attention_mask: Optional[torch.Tensor] = None,
  631. encoder_hidden_states: Optional[torch.Tensor] = None,
  632. encoder_attention_mask: Optional[torch.Tensor] = None,
  633. layer_head_mask: Optional[torch.Tensor] = None,
  634. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  635. past_key_values: Optional[Cache] = None,
  636. output_attentions: Optional[bool] = False,
  637. use_cache: Optional[bool] = True,
  638. cache_position: Optional[torch.Tensor] = None,
  639. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  640. """
  641. Args:
  642. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  643. attention_mask (`torch.FloatTensor`): attention mask of size
  644. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  645. encoder_hidden_states (`torch.FloatTensor`):
  646. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  647. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  648. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  649. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  650. `(encoder_attention_heads,)`.
  651. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  652. size `(decoder_attention_heads,)`.
  653. past_key_values (`Cache`): cached past key and value projection states
  654. output_attentions (`bool`, *optional*):
  655. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  656. returned tensors for more detail.
  657. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  658. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  659. cache in the correct position and to infer the complete sequence length.
  660. """
  661. residual = hidden_states
  662. # Self Attention
  663. hidden_states, self_attn_weights = self.self_attn(
  664. hidden_states=hidden_states,
  665. past_key_values=past_key_values,
  666. attention_mask=attention_mask,
  667. layer_head_mask=layer_head_mask,
  668. output_attentions=output_attentions,
  669. cache_position=cache_position,
  670. )
  671. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  672. hidden_states = residual + hidden_states
  673. hidden_states = self.self_attn_layer_norm(hidden_states)
  674. # Cross-Attention Block
  675. cross_attn_weights = None
  676. if encoder_hidden_states is not None:
  677. residual = hidden_states
  678. hidden_states, cross_attn_weights = self.encoder_attn(
  679. hidden_states=hidden_states,
  680. key_value_states=encoder_hidden_states,
  681. attention_mask=encoder_attention_mask,
  682. layer_head_mask=cross_attn_layer_head_mask,
  683. past_key_values=past_key_values,
  684. output_attentions=output_attentions,
  685. cache_position=cache_position,
  686. )
  687. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  688. hidden_states = residual + hidden_states
  689. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  690. # Fully Connected
  691. residual = hidden_states
  692. hidden_states = self.activation_fn(self.fc1(hidden_states))
  693. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  694. hidden_states = self.fc2(hidden_states)
  695. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  696. hidden_states = residual + hidden_states
  697. hidden_states = self.final_layer_norm(hidden_states)
  698. outputs = (hidden_states,)
  699. if output_attentions:
  700. outputs += (self_attn_weights, cross_attn_weights)
  701. return outputs
  702. class PLBartDecoder(PLBartPreTrainedModel):
  703. """
  704. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PLBartDecoderLayer`]
  705. Args:
  706. config: PLBartConfig
  707. embed_tokens (nn.Embedding): output embedding
  708. """
  709. def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None):
  710. super().__init__(config)
  711. self.dropout = config.dropout
  712. self.layerdrop = config.decoder_layerdrop
  713. self.padding_idx = config.pad_token_id
  714. self.max_target_positions = config.max_position_embeddings
  715. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  716. self.embed_tokens = PLBartScaledWordEmbedding(
  717. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  718. )
  719. if embed_tokens is not None:
  720. self.embed_tokens.weight = embed_tokens.weight
  721. self.embed_positions = PLBartLearnedPositionalEmbedding(
  722. config.max_position_embeddings,
  723. config.d_model,
  724. )
  725. self.layers = nn.ModuleList([PLBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  726. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  727. self.gradient_checkpointing = False
  728. # Initialize weights and apply final processing
  729. self.post_init()
  730. def forward(
  731. self,
  732. input_ids: Optional[torch.LongTensor] = None,
  733. attention_mask: Optional[torch.Tensor] = None,
  734. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  735. encoder_attention_mask: Optional[torch.LongTensor] = None,
  736. head_mask: Optional[torch.Tensor] = None,
  737. cross_attn_head_mask: Optional[torch.Tensor] = None,
  738. past_key_values: Optional[Cache] = None,
  739. inputs_embeds: Optional[torch.FloatTensor] = None,
  740. use_cache: Optional[bool] = None,
  741. output_attentions: Optional[bool] = None,
  742. output_hidden_states: Optional[bool] = None,
  743. return_dict: Optional[bool] = None,
  744. cache_position: Optional[torch.LongTensor] = None,
  745. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  746. r"""
  747. Args:
  748. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  749. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  750. provide it.
  751. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  752. [`PreTrainedTokenizer.__call__`] for details.
  753. [What are input IDs?](../glossary#input-ids)
  754. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  755. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  756. - 1 for tokens that are **not masked**,
  757. - 0 for tokens that are **masked**.
  758. [What are attention masks?](../glossary#attention-mask)
  759. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  760. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  761. of the decoder.
  762. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  763. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  764. selected in `[0, 1]`:
  765. - 1 for tokens that are **not masked**,
  766. - 0 for tokens that are **masked**.
  767. [What are attention masks?](../glossary#attention-mask)
  768. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  769. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  770. - 1 indicates the head is **not masked**,
  771. - 0 indicates the head is **masked**.
  772. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  773. Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
  774. cross-attention on hidden heads. Mask values selected in `[0, 1]`:
  775. - 1 indicates the head is **not masked**,
  776. - 0 indicates the head is **masked**.
  777. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  778. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  779. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  780. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  781. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  782. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  783. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  784. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  785. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  786. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  787. than the model's internal embedding lookup matrix.
  788. output_attentions (`bool`, *optional*):
  789. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  790. returned tensors for more detail.
  791. output_hidden_states (`bool`, *optional*):
  792. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  793. for more detail.
  794. return_dict (`bool`, *optional*):
  795. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  796. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  797. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  798. cache in the correct position and to infer the complete sequence length.
  799. """
  800. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  801. output_hidden_states = (
  802. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  803. )
  804. use_cache = use_cache if use_cache is not None else self.config.use_cache
  805. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  806. if self.gradient_checkpointing and self.training:
  807. if use_cache:
  808. logger.warning_once(
  809. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  810. )
  811. use_cache = False
  812. # retrieve input_ids and inputs_embeds
  813. if (input_ids is None) ^ (inputs_embeds is not None):
  814. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  815. elif input_ids is not None:
  816. input = input_ids
  817. input_shape = input.shape
  818. input_ids = input_ids.view(-1, input_shape[-1])
  819. elif inputs_embeds is not None:
  820. input_shape = inputs_embeds.size()[:-1]
  821. input = inputs_embeds[:, :, -1]
  822. else:
  823. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  824. if inputs_embeds is None:
  825. inputs_embeds = self.embed_tokens(input)
  826. # initialize `past_key_values`
  827. if use_cache and past_key_values is None:
  828. past_key_values = (
  829. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  830. if encoder_hidden_states is not None
  831. else DynamicCache(config=self.config)
  832. )
  833. if use_cache and isinstance(past_key_values, tuple):
  834. logger.warning_once(
  835. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  836. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  837. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  838. )
  839. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  840. batch_size, seq_length = inputs_embeds.size()[:-1]
  841. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  842. if cache_position is None:
  843. cache_position = torch.arange(
  844. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  845. )
  846. if attention_mask is None and not is_torchdynamo_compiling():
  847. # required mask seq length can be calculated via length of past cache
  848. mask_seq_length = past_key_values_length + seq_length
  849. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  850. self_attn_cache = (
  851. past_key_values.self_attention_cache
  852. if isinstance(past_key_values, EncoderDecoderCache)
  853. else past_key_values
  854. )
  855. attention_mask = self._update_causal_mask(
  856. attention_mask,
  857. inputs_embeds,
  858. cache_position,
  859. self_attn_cache,
  860. )
  861. encoder_attention_mask = self._update_cross_attn_mask(
  862. encoder_hidden_states,
  863. encoder_attention_mask,
  864. input_shape,
  865. inputs_embeds,
  866. )
  867. # embed positions
  868. positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
  869. positions = positions.to(inputs_embeds.device)
  870. hidden_states = inputs_embeds + positions
  871. hidden_states = self.layernorm_embedding(hidden_states)
  872. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  873. # decoder layers
  874. all_hidden_states = () if output_hidden_states else None
  875. all_self_attns = () if output_attentions else None
  876. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  877. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  878. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  879. if attn_mask is not None:
  880. if attn_mask.size()[0] != (len(self.layers)):
  881. raise ValueError(
  882. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  883. f" {head_mask.size()[0]}."
  884. )
  885. for idx, decoder_layer in enumerate(self.layers):
  886. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  887. if output_hidden_states:
  888. all_hidden_states += (hidden_states,)
  889. if self.training:
  890. dropout_probability = torch.rand([])
  891. if dropout_probability < self.layerdrop:
  892. continue
  893. layer_outputs = decoder_layer(
  894. hidden_states,
  895. attention_mask,
  896. encoder_hidden_states, # as a positional argument for gradient checkpointing
  897. encoder_attention_mask=encoder_attention_mask,
  898. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  899. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  900. past_key_values=past_key_values,
  901. output_attentions=output_attentions,
  902. use_cache=use_cache,
  903. cache_position=cache_position,
  904. )
  905. hidden_states = layer_outputs[0]
  906. if output_attentions:
  907. all_self_attns += (layer_outputs[1],)
  908. if encoder_hidden_states is not None:
  909. all_cross_attentions += (layer_outputs[2],)
  910. # add hidden states from the last decoder layer
  911. if output_hidden_states:
  912. all_hidden_states += (hidden_states,)
  913. if not return_dict:
  914. return tuple(
  915. v
  916. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  917. if v is not None
  918. )
  919. return BaseModelOutputWithPastAndCrossAttentions(
  920. last_hidden_state=hidden_states,
  921. past_key_values=past_key_values,
  922. hidden_states=all_hidden_states,
  923. attentions=all_self_attns,
  924. cross_attentions=all_cross_attentions,
  925. )
  926. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
  927. """
  928. Shift input ids one token to the right, and wrap the last non pad token (the <LID> token) Note that PLBart does not
  929. have a single `decoder_start_token_id` in contrast to other Bart-like models.
  930. """
  931. prev_output_tokens = input_ids.clone()
  932. if pad_token_id is None:
  933. raise ValueError("self.model.config.pad_token_id has to be defined.")
  934. # replace possible -100 values in labels by `pad_token_id`
  935. prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
  936. index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
  937. decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
  938. prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
  939. prev_output_tokens[:, 0] = decoder_start_tokens
  940. return prev_output_tokens
  941. @auto_docstring
  942. class PLBartModel(PLBartPreTrainedModel):
  943. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  944. def __init__(self, config: PLBartConfig):
  945. super().__init__(config)
  946. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  947. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  948. self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  949. self.encoder = PLBartEncoder(config, self.shared)
  950. self.decoder = PLBartDecoder(config, self.shared)
  951. self.init_weights()
  952. def get_input_embeddings(self):
  953. return self.shared
  954. def set_input_embeddings(self, value):
  955. self.shared = value
  956. self.encoder.embed_tokens = self.shared
  957. self.decoder.embed_tokens = self.shared
  958. def _tie_weights(self):
  959. if self.config.tie_word_embeddings:
  960. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  961. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  962. def get_encoder(self):
  963. return self.encoder
  964. @auto_docstring
  965. def forward(
  966. self,
  967. input_ids: Optional[torch.LongTensor] = None,
  968. attention_mask: Optional[torch.LongTensor] = None,
  969. decoder_input_ids: Optional[torch.LongTensor] = None,
  970. decoder_attention_mask: Optional[torch.Tensor] = None,
  971. head_mask: Optional[torch.Tensor] = None,
  972. decoder_head_mask: Optional[torch.LongTensor] = None,
  973. cross_attn_head_mask: Optional[torch.Tensor] = None,
  974. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  975. past_key_values: Optional[Cache] = None,
  976. inputs_embeds: Optional[torch.FloatTensor] = None,
  977. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  978. use_cache: Optional[bool] = None,
  979. output_attentions: Optional[bool] = None,
  980. output_hidden_states: Optional[bool] = None,
  981. return_dict: Optional[bool] = None,
  982. cache_position: Optional[torch.LongTensor] = None,
  983. ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
  984. r"""
  985. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  986. Indices of decoder input sequence tokens in the vocabulary.
  987. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  988. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  989. [What are decoder input IDs?](../glossary#decoder-input-ids)
  990. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  991. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  992. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  993. `past_key_values`).
  994. For translation and summarization training, `decoder_input_ids` should be provided. If no
  995. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  996. for denoising pre-training following the paper.
  997. decoder_attention_mask (:
  998. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  999. Default behavior:
  1000. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  1001. cross_attn_head_mask (:
  1002. obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1003. Mask to nullify
  1004. selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:
  1005. - 1 indicates the head is **not masked**,
  1006. - 0 indicates the head is **masked**.
  1007. """
  1008. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1009. output_hidden_states = (
  1010. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1011. )
  1012. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1013. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1014. # different to other models, PLBart automatically creates decoder_input_ids from
  1015. # input_ids if no decoder_input_ids are provided
  1016. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1017. decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
  1018. if encoder_outputs is None:
  1019. encoder_outputs = self.encoder(
  1020. input_ids=input_ids,
  1021. attention_mask=attention_mask,
  1022. head_mask=head_mask,
  1023. inputs_embeds=inputs_embeds,
  1024. output_attentions=output_attentions,
  1025. output_hidden_states=output_hidden_states,
  1026. return_dict=return_dict,
  1027. )
  1028. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  1029. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1030. encoder_outputs = BaseModelOutput(
  1031. last_hidden_state=encoder_outputs[0],
  1032. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1033. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1034. )
  1035. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1036. decoder_outputs = self.decoder(
  1037. input_ids=decoder_input_ids,
  1038. attention_mask=decoder_attention_mask,
  1039. encoder_hidden_states=encoder_outputs[0],
  1040. encoder_attention_mask=attention_mask,
  1041. head_mask=decoder_head_mask,
  1042. cross_attn_head_mask=cross_attn_head_mask,
  1043. past_key_values=past_key_values,
  1044. inputs_embeds=decoder_inputs_embeds,
  1045. use_cache=use_cache,
  1046. output_attentions=output_attentions,
  1047. output_hidden_states=output_hidden_states,
  1048. return_dict=return_dict,
  1049. cache_position=cache_position,
  1050. )
  1051. if not return_dict:
  1052. return decoder_outputs + encoder_outputs
  1053. return Seq2SeqModelOutput(
  1054. last_hidden_state=decoder_outputs.last_hidden_state,
  1055. past_key_values=decoder_outputs.past_key_values,
  1056. decoder_hidden_states=decoder_outputs.hidden_states,
  1057. decoder_attentions=decoder_outputs.attentions,
  1058. cross_attentions=decoder_outputs.cross_attentions,
  1059. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1060. encoder_hidden_states=encoder_outputs.hidden_states,
  1061. encoder_attentions=encoder_outputs.attentions,
  1062. )
  1063. @auto_docstring(
  1064. custom_intro="""
  1065. The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.
  1066. """
  1067. )
  1068. class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin):
  1069. base_model_prefix = "model"
  1070. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  1071. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  1072. def __init__(self, config: PLBartConfig):
  1073. super().__init__(config)
  1074. self.model = PLBartModel(config)
  1075. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  1076. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  1077. self.init_weights()
  1078. def get_encoder(self):
  1079. return self.model.get_encoder()
  1080. def get_decoder(self):
  1081. return self.model.get_decoder()
  1082. def resize_token_embeddings(
  1083. self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
  1084. ) -> nn.Embedding:
  1085. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  1086. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  1087. return new_embeddings
  1088. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  1089. old_num_tokens = self.final_logits_bias.shape[-1]
  1090. if new_num_tokens <= old_num_tokens:
  1091. new_bias = self.final_logits_bias[:, :new_num_tokens]
  1092. else:
  1093. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  1094. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  1095. self.register_buffer("final_logits_bias", new_bias)
  1096. @auto_docstring
  1097. def forward(
  1098. self,
  1099. input_ids: Optional[torch.LongTensor] = None,
  1100. attention_mask: Optional[torch.LongTensor] = None,
  1101. decoder_input_ids: Optional[torch.LongTensor] = None,
  1102. decoder_attention_mask: Optional[torch.Tensor] = None,
  1103. head_mask: Optional[torch.Tensor] = None,
  1104. decoder_head_mask: Optional[torch.LongTensor] = None,
  1105. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1106. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1107. past_key_values: Optional[Cache] = None,
  1108. inputs_embeds: Optional[torch.FloatTensor] = None,
  1109. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1110. labels: Optional[torch.Tensor] = None,
  1111. use_cache: Optional[bool] = None,
  1112. output_attentions: Optional[bool] = None,
  1113. output_hidden_states: Optional[bool] = None,
  1114. return_dict: Optional[bool] = None,
  1115. cache_position: Optional[torch.LongTensor] = None,
  1116. ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
  1117. r"""
  1118. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1119. Indices of decoder input sequence tokens in the vocabulary.
  1120. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  1121. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  1122. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1123. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  1124. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  1125. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1126. `past_key_values`).
  1127. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1128. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1129. for denoising pre-training following the paper.
  1130. decoder_attention_mask (:
  1131. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  1132. Default behavior:
  1133. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  1134. cross_attn_head_mask (:
  1135. obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1136. Mask to nullify
  1137. selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:
  1138. - 1 indicates the head is **not masked**,
  1139. - 0 indicates the head is **masked**.
  1140. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1141. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1142. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1143. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1144. Example Mask-filling:
  1145. ```python
  1146. >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration
  1147. >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
  1148. >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
  1149. >>> # en_XX is the language symbol id <LID> for English
  1150. >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? </s> en_XX"
  1151. >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids
  1152. >>> logits = model(input_ids).logits
  1153. >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
  1154. >>> probs = logits[0, masked_index].softmax(dim=0)
  1155. >>> values, predictions = probs.topk(5)
  1156. >>> tokenizer.decode(predictions).split()
  1157. ['first', 'same', 'highest', 'result', 'number']
  1158. ```
  1159. """
  1160. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1161. if labels is not None:
  1162. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1163. decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
  1164. outputs = self.model(
  1165. input_ids,
  1166. attention_mask=attention_mask,
  1167. decoder_input_ids=decoder_input_ids,
  1168. encoder_outputs=encoder_outputs,
  1169. decoder_attention_mask=decoder_attention_mask,
  1170. head_mask=head_mask,
  1171. decoder_head_mask=decoder_head_mask,
  1172. cross_attn_head_mask=cross_attn_head_mask,
  1173. past_key_values=past_key_values,
  1174. inputs_embeds=inputs_embeds,
  1175. decoder_inputs_embeds=decoder_inputs_embeds,
  1176. use_cache=use_cache,
  1177. output_attentions=output_attentions,
  1178. output_hidden_states=output_hidden_states,
  1179. return_dict=return_dict,
  1180. cache_position=cache_position,
  1181. )
  1182. lm_logits = self.lm_head(outputs[0])
  1183. lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
  1184. masked_lm_loss = None
  1185. if labels is not None:
  1186. loss_fct = CrossEntropyLoss()
  1187. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  1188. if not return_dict:
  1189. output = (lm_logits,) + outputs[1:]
  1190. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  1191. return Seq2SeqLMOutput(
  1192. loss=masked_lm_loss,
  1193. logits=lm_logits,
  1194. past_key_values=outputs.past_key_values,
  1195. decoder_hidden_states=outputs.decoder_hidden_states,
  1196. decoder_attentions=outputs.decoder_attentions,
  1197. cross_attentions=outputs.cross_attentions,
  1198. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1199. encoder_hidden_states=outputs.encoder_hidden_states,
  1200. encoder_attentions=outputs.encoder_attentions,
  1201. )
  1202. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1203. return shift_tokens_right(labels, self.config.pad_token_id)
  1204. class PLBartClassificationHead(nn.Module):
  1205. """Head for sentence-level classification tasks."""
  1206. def __init__(
  1207. self,
  1208. input_dim: int,
  1209. inner_dim: int,
  1210. num_classes: int,
  1211. pooler_dropout: float,
  1212. ):
  1213. super().__init__()
  1214. self.dense = nn.Linear(input_dim, inner_dim)
  1215. self.dropout = nn.Dropout(p=pooler_dropout)
  1216. self.out_proj = nn.Linear(inner_dim, num_classes)
  1217. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1218. hidden_states = self.dropout(hidden_states)
  1219. hidden_states = self.dense(hidden_states)
  1220. hidden_states = torch.tanh(hidden_states)
  1221. hidden_states = self.dropout(hidden_states)
  1222. hidden_states = self.out_proj(hidden_states)
  1223. return hidden_states
  1224. @auto_docstring(
  1225. custom_intro="""
  1226. PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g.
  1227. for GLUE tasks.
  1228. """
  1229. )
  1230. class PLBartForSequenceClassification(PLBartPreTrainedModel):
  1231. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1232. def __init__(self, config: PLBartConfig, **kwargs):
  1233. super().__init__(config, **kwargs)
  1234. self.model = PLBartModel(config)
  1235. self.classification_head = PLBartClassificationHead(
  1236. config.d_model,
  1237. config.d_model,
  1238. config.num_labels,
  1239. config.classifier_dropout,
  1240. )
  1241. # Initialize weights and apply final processing
  1242. self.post_init()
  1243. @auto_docstring
  1244. def forward(
  1245. self,
  1246. input_ids: Optional[torch.LongTensor] = None,
  1247. attention_mask: Optional[torch.Tensor] = None,
  1248. decoder_input_ids: Optional[torch.LongTensor] = None,
  1249. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1250. head_mask: Optional[torch.Tensor] = None,
  1251. decoder_head_mask: Optional[torch.Tensor] = None,
  1252. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1253. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1254. inputs_embeds: Optional[torch.FloatTensor] = None,
  1255. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1256. labels: Optional[torch.LongTensor] = None,
  1257. use_cache: Optional[bool] = None,
  1258. output_attentions: Optional[bool] = None,
  1259. output_hidden_states: Optional[bool] = None,
  1260. return_dict: Optional[bool] = None,
  1261. cache_position: Optional[torch.LongTensor] = None,
  1262. ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
  1263. r"""
  1264. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1265. Indices of decoder input sequence tokens in the vocabulary.
  1266. Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
  1267. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  1268. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1269. PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
  1270. varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
  1271. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1272. `past_key_values`).
  1273. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1274. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1275. for denoising pre-training following the paper.
  1276. decoder_attention_mask (:
  1277. obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*):
  1278. Default behavior:
  1279. generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default.
  1280. cross_attn_head_mask (:
  1281. obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1282. Mask to nullify
  1283. selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:
  1284. - 1 indicates the head is **not masked**,
  1285. - 0 indicates the head is **masked**.
  1286. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1287. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1288. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1289. """
  1290. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1291. if labels is not None:
  1292. use_cache = False
  1293. if input_ids is None and inputs_embeds is not None:
  1294. raise NotImplementedError(
  1295. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1296. )
  1297. outputs = self.model(
  1298. input_ids,
  1299. attention_mask=attention_mask,
  1300. decoder_input_ids=decoder_input_ids,
  1301. decoder_attention_mask=decoder_attention_mask,
  1302. head_mask=head_mask,
  1303. decoder_head_mask=decoder_head_mask,
  1304. cross_attn_head_mask=cross_attn_head_mask,
  1305. encoder_outputs=encoder_outputs,
  1306. inputs_embeds=inputs_embeds,
  1307. decoder_inputs_embeds=decoder_inputs_embeds,
  1308. use_cache=use_cache,
  1309. output_attentions=output_attentions,
  1310. output_hidden_states=output_hidden_states,
  1311. return_dict=return_dict,
  1312. cache_position=cache_position,
  1313. )
  1314. hidden_states = outputs[0] # last hidden state
  1315. eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
  1316. if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
  1317. raise ValueError("All examples must have the same number of <eos> tokens.")
  1318. sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
  1319. :, -1, :
  1320. ]
  1321. logits = self.classification_head(sentence_representation)
  1322. loss = None
  1323. if labels is not None:
  1324. labels = labels.to(logits.device)
  1325. if self.config.problem_type is None:
  1326. if self.config.num_labels == 1:
  1327. self.config.problem_type = "regression"
  1328. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1329. self.config.problem_type = "single_label_classification"
  1330. else:
  1331. self.config.problem_type = "multi_label_classification"
  1332. if self.config.problem_type == "regression":
  1333. loss_fct = MSELoss()
  1334. if self.config.num_labels == 1:
  1335. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1336. else:
  1337. loss = loss_fct(logits, labels)
  1338. elif self.config.problem_type == "single_label_classification":
  1339. loss_fct = CrossEntropyLoss()
  1340. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1341. elif self.config.problem_type == "multi_label_classification":
  1342. loss_fct = BCEWithLogitsLoss()
  1343. loss = loss_fct(logits, labels)
  1344. if not return_dict:
  1345. output = (logits,) + outputs[1:]
  1346. return ((loss,) + output) if loss is not None else output
  1347. return Seq2SeqSequenceClassifierOutput(
  1348. loss=loss,
  1349. logits=logits,
  1350. past_key_values=outputs.past_key_values,
  1351. decoder_hidden_states=outputs.decoder_hidden_states,
  1352. decoder_attentions=outputs.decoder_attentions,
  1353. cross_attentions=outputs.cross_attentions,
  1354. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1355. encoder_hidden_states=outputs.encoder_hidden_states,
  1356. encoder_attentions=outputs.encoder_attentions,
  1357. )
  1358. class PLBartDecoderWrapper(PLBartPreTrainedModel):
  1359. """
  1360. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1361. used in combination with the [`EncoderDecoderModel`] framework.
  1362. """
  1363. def __init__(self, config):
  1364. super().__init__(config)
  1365. self.decoder = PLBartDecoder(config)
  1366. def forward(self, *args, **kwargs):
  1367. return self.decoder(*args, **kwargs)
  1368. @auto_docstring(
  1369. custom_intro="""
  1370. PLBART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
  1371. """
  1372. )
  1373. class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin):
  1374. _tied_weights_keys = ["lm_head.weight"]
  1375. def __init__(self, config):
  1376. config.is_decoder = True
  1377. config.is_encoder_decoder = False
  1378. super().__init__(config)
  1379. self.model = PLBartDecoderWrapper(config)
  1380. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1381. # Initialize weights and apply final processing
  1382. self.post_init()
  1383. def get_input_embeddings(self):
  1384. return self.model.decoder.embed_tokens
  1385. def set_input_embeddings(self, value):
  1386. self.model.decoder.embed_tokens = value
  1387. def set_decoder(self, decoder):
  1388. self.model.decoder = decoder
  1389. def get_decoder(self):
  1390. return self.model.decoder
  1391. @auto_docstring
  1392. def forward(
  1393. self,
  1394. input_ids: Optional[torch.LongTensor] = None,
  1395. attention_mask: Optional[torch.Tensor] = None,
  1396. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  1397. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  1398. head_mask: Optional[torch.Tensor] = None,
  1399. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1400. past_key_values: Optional[Cache] = None,
  1401. inputs_embeds: Optional[torch.FloatTensor] = None,
  1402. labels: Optional[torch.LongTensor] = None,
  1403. use_cache: Optional[bool] = None,
  1404. output_attentions: Optional[bool] = None,
  1405. output_hidden_states: Optional[bool] = None,
  1406. return_dict: Optional[bool] = None,
  1407. cache_position: Optional[torch.LongTensor] = None,
  1408. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  1409. r"""
  1410. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1411. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1412. - 1 indicates the head is **not masked**,
  1413. - 0 indicates the head is **masked**.
  1414. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1415. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1416. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1417. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1418. Example:
  1419. ```python
  1420. >>> from transformers import AutoTokenizer, PLBartForCausalLM
  1421. >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
  1422. >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False)
  1423. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1424. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1425. >>> outputs = model(**inputs)
  1426. >>> logits = outputs.logits
  1427. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  1428. >>> list(logits.shape) == expected_shape
  1429. True
  1430. ```"""
  1431. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1432. output_hidden_states = (
  1433. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1434. )
  1435. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1436. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1437. outputs = self.model.decoder(
  1438. input_ids=input_ids,
  1439. attention_mask=attention_mask,
  1440. encoder_hidden_states=encoder_hidden_states,
  1441. encoder_attention_mask=encoder_attention_mask,
  1442. head_mask=head_mask,
  1443. cross_attn_head_mask=cross_attn_head_mask,
  1444. past_key_values=past_key_values,
  1445. inputs_embeds=inputs_embeds,
  1446. use_cache=use_cache,
  1447. output_attentions=output_attentions,
  1448. output_hidden_states=output_hidden_states,
  1449. return_dict=return_dict,
  1450. cache_position=cache_position,
  1451. )
  1452. logits = self.lm_head(outputs[0])
  1453. loss = None
  1454. if labels is not None:
  1455. labels = labels.to(logits.device)
  1456. loss_fct = CrossEntropyLoss()
  1457. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1458. if not return_dict:
  1459. output = (logits,) + outputs[1:]
  1460. return (loss,) + output if loss is not None else output
  1461. return CausalLMOutputWithCrossAttentions(
  1462. loss=loss,
  1463. logits=logits,
  1464. past_key_values=outputs.past_key_values,
  1465. hidden_states=outputs.hidden_states,
  1466. attentions=outputs.attentions,
  1467. cross_attentions=outputs.cross_attentions,
  1468. )
  1469. __all__ = [
  1470. "PLBartForCausalLM",
  1471. "PLBartForConditionalGeneration",
  1472. "PLBartForSequenceClassification",
  1473. "PLBartModel",
  1474. "PLBartPreTrainedModel",
  1475. ]