modeling_bart.py 88 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BART model."""
  16. import math
  17. import warnings
  18. from typing import Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...modeling_attn_mask_utils import (
  26. AttentionMaskConverter,
  27. _prepare_4d_attention_mask,
  28. _prepare_4d_attention_mask_for_sdpa,
  29. )
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import (
  33. BaseModelOutput,
  34. BaseModelOutputWithPastAndCrossAttentions,
  35. CausalLMOutputWithCrossAttentions,
  36. Seq2SeqLMOutput,
  37. Seq2SeqModelOutput,
  38. Seq2SeqQuestionAnsweringModelOutput,
  39. Seq2SeqSequenceClassifierOutput,
  40. )
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  42. from ...processing_utils import Unpack
  43. from ...utils import (
  44. auto_docstring,
  45. is_torch_flex_attn_available,
  46. is_torchdynamo_compiling,
  47. logging,
  48. )
  49. from ...utils.deprecation import deprecate_kwarg
  50. from .configuration_bart import BartConfig
  51. if is_torch_flex_attn_available():
  52. from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
  53. logger = logging.get_logger(__name__)
  54. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  55. """
  56. Shift input ids one token to the right.
  57. """
  58. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  59. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  60. shifted_input_ids[:, 0] = decoder_start_token_id
  61. if pad_token_id is None:
  62. raise ValueError("self.model.config.pad_token_id has to be defined.")
  63. # replace possible -100 values in labels by `pad_token_id`
  64. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  65. return shifted_input_ids
  66. class BartLearnedPositionalEmbedding(nn.Embedding):
  67. """
  68. This module learns positional embeddings up to a fixed maximum size.
  69. """
  70. def __init__(self, num_embeddings: int, embedding_dim: int):
  71. # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
  72. # and adjust num_embeddings appropriately. Other models don't have this hack
  73. self.offset = 2
  74. super().__init__(num_embeddings + self.offset, embedding_dim)
  75. def forward(
  76. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
  77. ):
  78. """`input_ids' shape is expected to be [bsz x seqlen]."""
  79. if position_ids is None:
  80. bsz, seq_len = input_ids.shape[:2]
  81. position_ids = torch.arange(
  82. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  83. ).expand(bsz, -1)
  84. else:
  85. position_ids = position_ids.unsqueeze(0)
  86. return super().forward(position_ids + self.offset)
  87. class BartScaledWordEmbedding(nn.Embedding):
  88. """
  89. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  90. """
  91. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
  92. super().__init__(num_embeddings, embedding_dim, padding_idx)
  93. self.embed_scale = embed_scale
  94. def forward(self, input_ids: torch.Tensor):
  95. return super().forward(input_ids) * self.embed_scale
  96. def eager_attention_forward(
  97. module: nn.Module,
  98. query: torch.Tensor,
  99. key: torch.Tensor,
  100. value: torch.Tensor,
  101. attention_mask: Optional[torch.Tensor],
  102. scaling: Optional[float] = None,
  103. dropout: float = 0.0,
  104. head_mask: Optional[torch.Tensor] = None,
  105. **kwargs,
  106. ):
  107. if scaling is None:
  108. scaling = query.size(-1) ** -0.5
  109. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  110. if attention_mask is not None:
  111. attn_weights = attn_weights + attention_mask
  112. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  113. if head_mask is not None:
  114. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  115. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  116. attn_output = torch.matmul(attn_weights, value)
  117. attn_output = attn_output.transpose(1, 2).contiguous()
  118. return attn_output, attn_weights
  119. class BartAttention(nn.Module):
  120. """Multi-headed attention from 'Attention Is All You Need' paper"""
  121. def __init__(
  122. self,
  123. embed_dim: int,
  124. num_heads: int,
  125. dropout: float = 0.0,
  126. is_decoder: bool = False,
  127. bias: bool = True,
  128. is_causal: bool = False,
  129. config: Optional[BartConfig] = None,
  130. layer_idx: Optional[int] = None,
  131. ):
  132. super().__init__()
  133. self.embed_dim = embed_dim
  134. self.num_heads = num_heads
  135. self.dropout = dropout
  136. self.head_dim = embed_dim // num_heads
  137. self.config = config
  138. if (self.head_dim * num_heads) != self.embed_dim:
  139. raise ValueError(
  140. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  141. f" and `num_heads`: {num_heads})."
  142. )
  143. self.scaling = self.head_dim**-0.5
  144. self.is_decoder = is_decoder
  145. self.is_causal = is_causal
  146. self.layer_idx = layer_idx
  147. if layer_idx is None and self.is_decoder:
  148. logger.warning_once(
  149. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  150. "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  151. "when creating this class."
  152. )
  153. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  154. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  155. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  156. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  157. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  158. def forward(
  159. self,
  160. hidden_states: torch.Tensor,
  161. key_value_states: Optional[torch.Tensor] = None,
  162. past_key_values: Optional[Cache] = None,
  163. attention_mask: Optional[torch.Tensor] = None,
  164. layer_head_mask: Optional[torch.Tensor] = None,
  165. output_attentions: bool = False,
  166. cache_position: Optional[torch.Tensor] = None,
  167. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  168. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  169. **kwargs: Unpack[FlashAttentionKwargs],
  170. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  171. """Input shape: Batch x Time x Channel"""
  172. # if key_value_states are provided this layer is used as a cross-attention layer
  173. # for the decoder
  174. is_cross_attention = key_value_states is not None
  175. # determine input shapes
  176. bsz, tgt_len = hidden_states.shape[:-1]
  177. src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
  178. q_input_shape = (bsz, tgt_len, -1, self.head_dim)
  179. kv_input_shape = (bsz, src_len, -1, self.head_dim)
  180. # get query proj
  181. query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
  182. is_updated = False
  183. if past_key_values is not None:
  184. if isinstance(past_key_values, EncoderDecoderCache):
  185. is_updated = past_key_values.is_updated.get(self.layer_idx)
  186. if is_cross_attention:
  187. # after the first generated id, we can subsequently re-use all key/value_states from cache
  188. curr_past_key_value = past_key_values.cross_attention_cache
  189. else:
  190. curr_past_key_value = past_key_values.self_attention_cache
  191. else:
  192. curr_past_key_value = past_key_values
  193. current_states = key_value_states if is_cross_attention else hidden_states
  194. if is_cross_attention and past_key_values is not None and is_updated:
  195. # reuse k,v, cross_attentions
  196. key_states = curr_past_key_value.layers[self.layer_idx].keys
  197. value_states = curr_past_key_value.layers[self.layer_idx].values
  198. else:
  199. key_states = self.k_proj(current_states)
  200. value_states = self.v_proj(current_states)
  201. key_states = key_states.view(*kv_input_shape).transpose(1, 2)
  202. value_states = value_states.view(*kv_input_shape).transpose(1, 2)
  203. if past_key_values is not None:
  204. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  205. cache_position = cache_position if not is_cross_attention else None
  206. key_states, value_states = curr_past_key_value.update(
  207. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  208. )
  209. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  210. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  211. past_key_values.is_updated[self.layer_idx] = True
  212. attention_interface: Callable = eager_attention_forward
  213. if self.config._attn_implementation != "eager":
  214. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  215. attn_output, attn_weights = attention_interface(
  216. self,
  217. query_states,
  218. key_states,
  219. value_states,
  220. attention_mask,
  221. dropout=0.0 if not self.training else self.dropout,
  222. scaling=self.scaling,
  223. output_attentions=output_attentions,
  224. head_mask=layer_head_mask,
  225. **kwargs,
  226. )
  227. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  228. attn_output = self.out_proj(attn_output)
  229. return attn_output, attn_weights
  230. class BartEncoderLayer(GradientCheckpointingLayer):
  231. def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
  232. super().__init__()
  233. self.embed_dim = config.d_model
  234. self.self_attn = BartAttention(
  235. embed_dim=self.embed_dim,
  236. num_heads=config.encoder_attention_heads,
  237. dropout=config.attention_dropout,
  238. config=config,
  239. layer_idx=layer_idx,
  240. )
  241. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  242. self.dropout = config.dropout
  243. self.activation_fn = ACT2FN[config.activation_function]
  244. self.activation_dropout = config.activation_dropout
  245. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  246. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  247. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  248. def forward(
  249. self,
  250. hidden_states: torch.FloatTensor,
  251. attention_mask: torch.FloatTensor,
  252. layer_head_mask: torch.FloatTensor,
  253. output_attentions: Optional[bool] = False,
  254. ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
  255. """
  256. Args:
  257. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  258. attention_mask (`torch.FloatTensor`): attention mask of size
  259. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  260. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  261. `(encoder_attention_heads,)`.
  262. output_attentions (`bool`, *optional*):
  263. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  264. returned tensors for more detail.
  265. """
  266. residual = hidden_states
  267. hidden_states, attn_weights = self.self_attn(
  268. hidden_states=hidden_states,
  269. attention_mask=attention_mask,
  270. layer_head_mask=layer_head_mask,
  271. output_attentions=output_attentions,
  272. )
  273. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  274. hidden_states = residual + hidden_states
  275. hidden_states = self.self_attn_layer_norm(hidden_states)
  276. residual = hidden_states
  277. hidden_states = self.activation_fn(self.fc1(hidden_states))
  278. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  279. hidden_states = self.fc2(hidden_states)
  280. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  281. hidden_states = residual + hidden_states
  282. hidden_states = self.final_layer_norm(hidden_states)
  283. if hidden_states.dtype == torch.float16 and (
  284. torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
  285. ):
  286. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  287. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  288. outputs = (hidden_states,)
  289. if output_attentions:
  290. outputs += (attn_weights,)
  291. return outputs
  292. class BartDecoderLayer(GradientCheckpointingLayer):
  293. def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
  294. super().__init__()
  295. self.embed_dim = config.d_model
  296. self.self_attn = BartAttention(
  297. embed_dim=self.embed_dim,
  298. num_heads=config.decoder_attention_heads,
  299. dropout=config.attention_dropout,
  300. is_decoder=True,
  301. is_causal=True,
  302. config=config,
  303. layer_idx=layer_idx,
  304. )
  305. self.dropout = config.dropout
  306. self.activation_fn = ACT2FN[config.activation_function]
  307. self.activation_dropout = config.activation_dropout
  308. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  309. self.encoder_attn = BartAttention(
  310. self.embed_dim,
  311. config.decoder_attention_heads,
  312. dropout=config.attention_dropout,
  313. is_decoder=True,
  314. config=config,
  315. layer_idx=layer_idx,
  316. )
  317. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  318. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  319. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  320. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  321. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  322. def forward(
  323. self,
  324. hidden_states: torch.Tensor,
  325. attention_mask: Optional[torch.Tensor] = None,
  326. encoder_hidden_states: Optional[torch.Tensor] = None,
  327. encoder_attention_mask: Optional[torch.Tensor] = None,
  328. layer_head_mask: Optional[torch.Tensor] = None,
  329. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  330. past_key_values: Optional[Cache] = None,
  331. output_attentions: Optional[bool] = False,
  332. use_cache: Optional[bool] = True,
  333. cache_position: Optional[torch.Tensor] = None,
  334. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  335. """
  336. Args:
  337. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  338. attention_mask (`torch.FloatTensor`): attention mask of size
  339. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  340. encoder_hidden_states (`torch.FloatTensor`):
  341. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  342. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  343. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  344. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  345. `(encoder_attention_heads,)`.
  346. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  347. size `(decoder_attention_heads,)`.
  348. past_key_values (`Cache`): cached past key and value projection states
  349. output_attentions (`bool`, *optional*):
  350. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  351. returned tensors for more detail.
  352. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  353. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  354. cache in the correct position and to infer the complete sequence length.
  355. """
  356. residual = hidden_states
  357. # Self Attention
  358. hidden_states, self_attn_weights = self.self_attn(
  359. hidden_states=hidden_states,
  360. past_key_values=past_key_values,
  361. attention_mask=attention_mask,
  362. layer_head_mask=layer_head_mask,
  363. output_attentions=output_attentions,
  364. cache_position=cache_position,
  365. )
  366. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  367. hidden_states = residual + hidden_states
  368. hidden_states = self.self_attn_layer_norm(hidden_states)
  369. # Cross-Attention Block
  370. cross_attn_weights = None
  371. if encoder_hidden_states is not None:
  372. residual = hidden_states
  373. hidden_states, cross_attn_weights = self.encoder_attn(
  374. hidden_states=hidden_states,
  375. key_value_states=encoder_hidden_states,
  376. attention_mask=encoder_attention_mask,
  377. layer_head_mask=cross_attn_layer_head_mask,
  378. past_key_values=past_key_values,
  379. output_attentions=output_attentions,
  380. cache_position=cache_position,
  381. )
  382. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  383. hidden_states = residual + hidden_states
  384. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  385. # Fully Connected
  386. residual = hidden_states
  387. hidden_states = self.activation_fn(self.fc1(hidden_states))
  388. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  389. hidden_states = self.fc2(hidden_states)
  390. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  391. hidden_states = residual + hidden_states
  392. hidden_states = self.final_layer_norm(hidden_states)
  393. outputs = (hidden_states,)
  394. if output_attentions:
  395. outputs += (self_attn_weights, cross_attn_weights)
  396. return outputs
  397. class BartClassificationHead(nn.Module):
  398. """Head for sentence-level classification tasks."""
  399. def __init__(
  400. self,
  401. input_dim: int,
  402. inner_dim: int,
  403. num_classes: int,
  404. pooler_dropout: float,
  405. ):
  406. super().__init__()
  407. self.dense = nn.Linear(input_dim, inner_dim)
  408. self.dropout = nn.Dropout(p=pooler_dropout)
  409. self.out_proj = nn.Linear(inner_dim, num_classes)
  410. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  411. hidden_states = self.dropout(hidden_states)
  412. hidden_states = self.dense(hidden_states)
  413. hidden_states = torch.tanh(hidden_states)
  414. hidden_states = self.dropout(hidden_states)
  415. hidden_states = self.out_proj(hidden_states)
  416. return hidden_states
  417. @auto_docstring
  418. class BartPreTrainedModel(PreTrainedModel):
  419. config: BartConfig
  420. base_model_prefix = "model"
  421. supports_gradient_checkpointing = True
  422. _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"]
  423. _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
  424. _skip_keys_device_placement = "past_key_values"
  425. _supports_flash_attn = True
  426. _supports_sdpa = True
  427. _supports_flex_attn = True
  428. _can_compile_fullgraph = True
  429. def _init_weights(self, module):
  430. std = self.config.init_std
  431. if isinstance(module, nn.Linear):
  432. module.weight.data.normal_(mean=0.0, std=std)
  433. if module.bias is not None:
  434. module.bias.data.zero_()
  435. elif isinstance(module, nn.Embedding):
  436. module.weight.data.normal_(mean=0.0, std=std)
  437. if module.padding_idx is not None:
  438. module.weight.data[module.padding_idx].zero_()
  439. elif isinstance(module, nn.LayerNorm):
  440. module.weight.data.fill_(1.0)
  441. module.bias.data.zero_()
  442. @property
  443. def dummy_inputs(self):
  444. pad_token = self.config.pad_token_id
  445. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  446. dummy_inputs = {
  447. "attention_mask": input_ids.ne(pad_token),
  448. "input_ids": input_ids,
  449. }
  450. return dummy_inputs
  451. def _update_full_mask(
  452. self,
  453. attention_mask: Union[torch.Tensor, None],
  454. inputs_embeds: torch.Tensor,
  455. ):
  456. if attention_mask is not None:
  457. if self.config._attn_implementation == "flash_attention_2":
  458. attention_mask = attention_mask if 0 in attention_mask else None
  459. elif self.config._attn_implementation == "sdpa":
  460. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  461. # the manual implementation that requires a 4D causal mask in all cases.
  462. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  463. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  464. elif self.config._attn_implementation == "flex_attention":
  465. if isinstance(attention_mask, torch.Tensor):
  466. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  467. else:
  468. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  469. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  470. return attention_mask
  471. def _update_causal_mask(
  472. self,
  473. attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
  474. input_tensor: torch.Tensor,
  475. cache_position: torch.Tensor,
  476. past_key_values: Cache,
  477. ):
  478. if self.config._attn_implementation == "flex_attention":
  479. if isinstance(attention_mask, torch.Tensor):
  480. attention_mask = make_flex_block_causal_mask(attention_mask)
  481. # Other attention flavors support in-built causal (when `mask is None`)
  482. # while we need to create our specific block mask regardless
  483. elif attention_mask is None:
  484. attention_mask = make_flex_block_causal_mask(
  485. torch.ones(
  486. size=(input_tensor.shape[0], input_tensor.shape[1]),
  487. device=attention_mask.device,
  488. )
  489. )
  490. return attention_mask
  491. if self.config._attn_implementation == "flash_attention_2":
  492. if attention_mask is not None and (attention_mask == 0.0).any():
  493. return attention_mask
  494. return None
  495. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  496. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  497. # to infer the attention mask.
  498. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  499. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  500. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  501. if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
  502. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  503. attention_mask,
  504. inputs_embeds=input_tensor,
  505. past_key_values_length=past_seen_tokens,
  506. is_training=self.training,
  507. ):
  508. return None
  509. dtype = input_tensor.dtype
  510. sequence_length = input_tensor.shape[1]
  511. if using_compilable_cache:
  512. target_length = past_key_values.get_max_cache_shape()
  513. else:
  514. target_length = (
  515. attention_mask.shape[-1]
  516. if isinstance(attention_mask, torch.Tensor)
  517. else past_seen_tokens + sequence_length + 1
  518. )
  519. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  520. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  521. attention_mask,
  522. sequence_length=sequence_length,
  523. target_length=target_length,
  524. dtype=dtype,
  525. cache_position=cache_position,
  526. batch_size=input_tensor.shape[0],
  527. )
  528. if (
  529. self.config._attn_implementation == "sdpa"
  530. and attention_mask is not None
  531. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  532. ):
  533. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  534. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  535. # Details: https://github.com/pytorch/pytorch/issues/110213
  536. min_dtype = torch.finfo(dtype).min
  537. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  538. return causal_mask
  539. @staticmethod
  540. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  541. def _prepare_4d_causal_attention_mask_with_cache_position(
  542. attention_mask: torch.Tensor,
  543. sequence_length: int,
  544. target_length: int,
  545. dtype: torch.dtype,
  546. cache_position: torch.Tensor,
  547. batch_size: int,
  548. **kwargs,
  549. ):
  550. """
  551. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  552. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  553. Args:
  554. attention_mask (`torch.Tensor`):
  555. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  556. `(batch_size, 1, query_length, key_value_length)`.
  557. sequence_length (`int`):
  558. The sequence length being processed.
  559. target_length (`int`):
  560. The target length: when generating with static cache, the mask should be as long as the static cache,
  561. to account for the 0 padding, the part of the cache that is not filled yet.
  562. dtype (`torch.dtype`):
  563. The dtype to use for the 4D attention mask.
  564. cache_position (`torch.Tensor`):
  565. Indices depicting the position of the input sequence tokens in the sequence.
  566. batch_size (`torch.Tensor`):
  567. Batch size.
  568. """
  569. if attention_mask is not None and attention_mask.dim() == 4:
  570. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  571. causal_mask = attention_mask
  572. else:
  573. min_dtype = torch.finfo(dtype).min
  574. causal_mask = torch.full(
  575. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  576. )
  577. if sequence_length != 1:
  578. causal_mask = torch.triu(causal_mask, diagonal=1)
  579. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  580. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  581. if attention_mask is not None:
  582. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  583. mask_length = attention_mask.shape[-1]
  584. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  585. causal_mask.device
  586. )
  587. padding_mask = padding_mask == 0
  588. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  589. padding_mask, min_dtype
  590. )
  591. return causal_mask
  592. def _update_cross_attn_mask(
  593. self,
  594. encoder_hidden_states: Union[torch.Tensor, None],
  595. encoder_attention_mask: Union[torch.Tensor, None],
  596. input_shape: torch.Size,
  597. inputs_embeds: torch.Tensor,
  598. ):
  599. # expand encoder attention mask
  600. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  601. if self.config._attn_implementation == "flash_attention_2":
  602. encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
  603. elif self.config._attn_implementation == "sdpa":
  604. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  605. # the manual implementation that requires a 4D causal mask in all cases.
  606. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  607. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  608. encoder_attention_mask,
  609. inputs_embeds.dtype,
  610. tgt_len=input_shape[-1],
  611. )
  612. elif self.config._attn_implementation == "flex_attention":
  613. if isinstance(encoder_attention_mask, torch.Tensor):
  614. encoder_attention_mask = make_flex_block_causal_mask(
  615. encoder_attention_mask,
  616. query_length=input_shape[-1],
  617. is_causal=False,
  618. )
  619. else:
  620. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  621. encoder_attention_mask = _prepare_4d_attention_mask(
  622. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  623. )
  624. return encoder_attention_mask
  625. class PretrainedBartModel(BartPreTrainedModel):
  626. def __init_subclass__(self):
  627. warnings.warn(
  628. "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.",
  629. FutureWarning,
  630. )
  631. class BartPretrainedModel(BartPreTrainedModel):
  632. def __init_subclass__(self):
  633. warnings.warn(
  634. "The class `PretrainedBartModel` has been depreciated, please use `BartPreTrainedModel` instead.",
  635. FutureWarning,
  636. )
  637. class BartEncoder(BartPreTrainedModel):
  638. """
  639. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  640. [`BartEncoderLayer`].
  641. Args:
  642. config: BartConfig
  643. embed_tokens (nn.Embedding): output embedding
  644. """
  645. def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
  646. super().__init__(config)
  647. self.dropout = config.dropout
  648. self.layerdrop = config.encoder_layerdrop
  649. embed_dim = config.d_model
  650. self.padding_idx = config.pad_token_id
  651. self.max_source_positions = config.max_position_embeddings
  652. embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  653. self.embed_tokens = BartScaledWordEmbedding(
  654. config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
  655. )
  656. if embed_tokens is not None:
  657. self.embed_tokens.weight = embed_tokens.weight
  658. self.embed_positions = BartLearnedPositionalEmbedding(
  659. config.max_position_embeddings,
  660. embed_dim,
  661. )
  662. self.layers = nn.ModuleList([BartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)])
  663. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  664. self.gradient_checkpointing = False
  665. # Initialize weights and apply final processing
  666. self.post_init()
  667. def forward(
  668. self,
  669. input_ids: Optional[torch.LongTensor] = None,
  670. attention_mask: Optional[torch.Tensor] = None,
  671. head_mask: Optional[torch.Tensor] = None,
  672. inputs_embeds: Optional[torch.FloatTensor] = None,
  673. output_attentions: Optional[bool] = None,
  674. output_hidden_states: Optional[bool] = None,
  675. return_dict: Optional[bool] = None,
  676. ) -> Union[tuple, BaseModelOutput]:
  677. r"""
  678. Args:
  679. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  680. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  681. provide it.
  682. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  683. [`PreTrainedTokenizer.__call__`] for details.
  684. [What are input IDs?](../glossary#input-ids)
  685. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  686. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  687. - 1 for tokens that are **not masked**,
  688. - 0 for tokens that are **masked**.
  689. [What are attention masks?](../glossary#attention-mask)
  690. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  691. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  692. - 1 indicates the head is **not masked**,
  693. - 0 indicates the head is **masked**.
  694. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  695. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  696. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  697. than the model's internal embedding lookup matrix.
  698. output_attentions (`bool`, *optional*):
  699. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  700. returned tensors for more detail.
  701. output_hidden_states (`bool`, *optional*):
  702. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  703. for more detail.
  704. return_dict (`bool`, *optional*):
  705. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  706. """
  707. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  708. output_hidden_states = (
  709. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  710. )
  711. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  712. # retrieve input_ids and inputs_embeds
  713. if input_ids is not None and inputs_embeds is not None:
  714. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  715. elif input_ids is not None:
  716. input = input_ids
  717. input_ids = input_ids.view(-1, input_ids.shape[-1])
  718. elif inputs_embeds is not None:
  719. input = inputs_embeds[:, :, -1]
  720. else:
  721. raise ValueError("You have to specify either input_ids or inputs_embeds")
  722. if inputs_embeds is None:
  723. inputs_embeds = self.embed_tokens(input_ids)
  724. embed_pos = self.embed_positions(input)
  725. embed_pos = embed_pos.to(inputs_embeds.device)
  726. hidden_states = inputs_embeds + embed_pos
  727. hidden_states = self.layernorm_embedding(hidden_states)
  728. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  729. attention_mask = self._update_full_mask(
  730. attention_mask,
  731. inputs_embeds,
  732. )
  733. encoder_states = () if output_hidden_states else None
  734. all_attentions = () if output_attentions else None
  735. # check if head_mask has a correct number of layers specified if desired
  736. if head_mask is not None:
  737. if head_mask.size()[0] != (len(self.layers)):
  738. raise ValueError(
  739. f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
  740. f" {head_mask.size()[0]}."
  741. )
  742. for idx, encoder_layer in enumerate(self.layers):
  743. if output_hidden_states:
  744. encoder_states = encoder_states + (hidden_states,)
  745. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  746. to_drop = False
  747. if self.training:
  748. dropout_probability = torch.rand([])
  749. if dropout_probability < self.layerdrop: # skip the layer
  750. to_drop = True
  751. if to_drop:
  752. layer_outputs = (None, None)
  753. else:
  754. layer_outputs = encoder_layer(
  755. hidden_states,
  756. attention_mask,
  757. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  758. output_attentions=output_attentions,
  759. )
  760. hidden_states = layer_outputs[0]
  761. if output_attentions:
  762. all_attentions = all_attentions + (layer_outputs[1],)
  763. if output_hidden_states:
  764. encoder_states = encoder_states + (hidden_states,)
  765. if not return_dict:
  766. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  767. return BaseModelOutput(
  768. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  769. )
  770. class BartDecoder(BartPreTrainedModel):
  771. """
  772. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
  773. Args:
  774. config: BartConfig
  775. embed_tokens (nn.Embedding): output embedding
  776. """
  777. def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
  778. super().__init__(config)
  779. self.dropout = config.dropout
  780. self.layerdrop = config.decoder_layerdrop
  781. self.padding_idx = config.pad_token_id
  782. self.max_target_positions = config.max_position_embeddings
  783. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  784. self.embed_tokens = BartScaledWordEmbedding(
  785. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  786. )
  787. if embed_tokens is not None:
  788. self.embed_tokens.weight = embed_tokens.weight
  789. self.embed_positions = BartLearnedPositionalEmbedding(
  790. config.max_position_embeddings,
  791. config.d_model,
  792. )
  793. self.layers = nn.ModuleList([BartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  794. self.layernorm_embedding = nn.LayerNorm(config.d_model)
  795. self.gradient_checkpointing = False
  796. # Initialize weights and apply final processing
  797. self.post_init()
  798. def forward(
  799. self,
  800. input_ids: Optional[torch.LongTensor] = None,
  801. attention_mask: Optional[torch.Tensor] = None,
  802. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  803. encoder_attention_mask: Optional[torch.LongTensor] = None,
  804. head_mask: Optional[torch.Tensor] = None,
  805. cross_attn_head_mask: Optional[torch.Tensor] = None,
  806. past_key_values: Optional[Cache] = None,
  807. inputs_embeds: Optional[torch.FloatTensor] = None,
  808. use_cache: Optional[bool] = None,
  809. output_attentions: Optional[bool] = None,
  810. output_hidden_states: Optional[bool] = None,
  811. return_dict: Optional[bool] = None,
  812. cache_position: Optional[torch.LongTensor] = None,
  813. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  814. r"""
  815. Args:
  816. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  817. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  818. provide it.
  819. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  820. [`PreTrainedTokenizer.__call__`] for details.
  821. [What are input IDs?](../glossary#input-ids)
  822. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  823. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  824. - 1 for tokens that are **not masked**,
  825. - 0 for tokens that are **masked**.
  826. [What are attention masks?](../glossary#attention-mask)
  827. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  828. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  829. of the decoder.
  830. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  831. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  832. selected in `[0, 1]`:
  833. - 1 for tokens that are **not masked**,
  834. - 0 for tokens that are **masked**.
  835. [What are attention masks?](../glossary#attention-mask)
  836. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  837. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  838. - 1 indicates the head is **not masked**,
  839. - 0 indicates the head is **masked**.
  840. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  841. Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
  842. cross-attention on hidden heads. Mask values selected in `[0, 1]`:
  843. - 1 indicates the head is **not masked**,
  844. - 0 indicates the head is **masked**.
  845. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  846. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  847. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  848. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  849. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  850. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  851. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  852. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  853. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  854. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  855. than the model's internal embedding lookup matrix.
  856. output_attentions (`bool`, *optional*):
  857. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  858. returned tensors for more detail.
  859. output_hidden_states (`bool`, *optional*):
  860. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  861. for more detail.
  862. return_dict (`bool`, *optional*):
  863. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  864. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  865. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  866. cache in the correct position and to infer the complete sequence length.
  867. """
  868. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  869. output_hidden_states = (
  870. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  871. )
  872. use_cache = use_cache if use_cache is not None else self.config.use_cache
  873. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  874. if self.gradient_checkpointing and self.training:
  875. if use_cache:
  876. logger.warning_once(
  877. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  878. )
  879. use_cache = False
  880. # retrieve input_ids and inputs_embeds
  881. if (input_ids is None) ^ (inputs_embeds is not None):
  882. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  883. elif input_ids is not None:
  884. input = input_ids
  885. input_shape = input.shape
  886. input_ids = input_ids.view(-1, input_shape[-1])
  887. elif inputs_embeds is not None:
  888. input_shape = inputs_embeds.size()[:-1]
  889. input = inputs_embeds[:, :, -1]
  890. else:
  891. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  892. if inputs_embeds is None:
  893. inputs_embeds = self.embed_tokens(input)
  894. # initialize `past_key_values`
  895. if use_cache and past_key_values is None:
  896. past_key_values = (
  897. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  898. if encoder_hidden_states is not None
  899. else DynamicCache(config=self.config)
  900. )
  901. if use_cache and isinstance(past_key_values, tuple):
  902. logger.warning_once(
  903. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  904. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  905. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  906. )
  907. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  908. batch_size, seq_length = inputs_embeds.size()[:-1]
  909. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  910. if cache_position is None:
  911. cache_position = torch.arange(
  912. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  913. )
  914. if attention_mask is None and not is_torchdynamo_compiling():
  915. # required mask seq length can be calculated via length of past cache
  916. mask_seq_length = past_key_values_length + seq_length
  917. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  918. self_attn_cache = (
  919. past_key_values.self_attention_cache
  920. if isinstance(past_key_values, EncoderDecoderCache)
  921. else past_key_values
  922. )
  923. attention_mask = self._update_causal_mask(
  924. attention_mask,
  925. inputs_embeds,
  926. cache_position,
  927. self_attn_cache,
  928. )
  929. encoder_attention_mask = self._update_cross_attn_mask(
  930. encoder_hidden_states,
  931. encoder_attention_mask,
  932. input_shape,
  933. inputs_embeds,
  934. )
  935. # embed positions
  936. positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
  937. positions = positions.to(inputs_embeds.device)
  938. hidden_states = inputs_embeds + positions
  939. hidden_states = self.layernorm_embedding(hidden_states)
  940. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  941. # decoder layers
  942. all_hidden_states = () if output_hidden_states else None
  943. all_self_attns = () if output_attentions else None
  944. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  945. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  946. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  947. if attn_mask is not None:
  948. if attn_mask.size()[0] != (len(self.layers)):
  949. raise ValueError(
  950. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  951. f" {head_mask.size()[0]}."
  952. )
  953. for idx, decoder_layer in enumerate(self.layers):
  954. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  955. if output_hidden_states:
  956. all_hidden_states += (hidden_states,)
  957. if self.training:
  958. dropout_probability = torch.rand([])
  959. if dropout_probability < self.layerdrop:
  960. continue
  961. layer_outputs = decoder_layer(
  962. hidden_states,
  963. attention_mask,
  964. encoder_hidden_states, # as a positional argument for gradient checkpointing
  965. encoder_attention_mask=encoder_attention_mask,
  966. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  967. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  968. past_key_values=past_key_values,
  969. output_attentions=output_attentions,
  970. use_cache=use_cache,
  971. cache_position=cache_position,
  972. )
  973. hidden_states = layer_outputs[0]
  974. if output_attentions:
  975. all_self_attns += (layer_outputs[1],)
  976. if encoder_hidden_states is not None:
  977. all_cross_attentions += (layer_outputs[2],)
  978. # add hidden states from the last decoder layer
  979. if output_hidden_states:
  980. all_hidden_states += (hidden_states,)
  981. if not return_dict:
  982. return tuple(
  983. v
  984. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  985. if v is not None
  986. )
  987. return BaseModelOutputWithPastAndCrossAttentions(
  988. last_hidden_state=hidden_states,
  989. past_key_values=past_key_values,
  990. hidden_states=all_hidden_states,
  991. attentions=all_self_attns,
  992. cross_attentions=all_cross_attentions,
  993. )
  994. @auto_docstring
  995. class BartModel(BartPreTrainedModel):
  996. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  997. def __init__(self, config: BartConfig):
  998. super().__init__(config)
  999. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  1000. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  1001. self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  1002. self.encoder = BartEncoder(config, self.shared)
  1003. self.decoder = BartDecoder(config, self.shared)
  1004. # Initialize weights and apply final processing
  1005. self.post_init()
  1006. def _tie_weights(self):
  1007. if self.config.tie_word_embeddings:
  1008. # Some model checkpoints like "facebook/bart-large-cnn"'s embedding weight is in decoder.embed_tokens, need check here, see issue #36247
  1009. if self.shared.weight.device == torch.device(
  1010. "meta"
  1011. ) and self.decoder.embed_tokens.weight.device != torch.device("meta"):
  1012. self._tie_or_clone_weights(self.encoder.embed_tokens, self.decoder.embed_tokens)
  1013. self._tie_or_clone_weights(self.shared, self.decoder.embed_tokens)
  1014. else:
  1015. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1016. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1017. def get_input_embeddings(self):
  1018. return self.shared
  1019. def set_input_embeddings(self, value):
  1020. self.shared = value
  1021. self.encoder.embed_tokens = self.shared
  1022. self.decoder.embed_tokens = self.shared
  1023. def get_encoder(self):
  1024. return self.encoder
  1025. @auto_docstring
  1026. def forward(
  1027. self,
  1028. input_ids: Optional[torch.LongTensor] = None,
  1029. attention_mask: Optional[torch.Tensor] = None,
  1030. decoder_input_ids: Optional[torch.LongTensor] = None,
  1031. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1032. head_mask: Optional[torch.Tensor] = None,
  1033. decoder_head_mask: Optional[torch.Tensor] = None,
  1034. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1035. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1036. past_key_values: Optional[Cache] = None,
  1037. inputs_embeds: Optional[torch.FloatTensor] = None,
  1038. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1039. use_cache: Optional[bool] = None,
  1040. output_attentions: Optional[bool] = None,
  1041. output_hidden_states: Optional[bool] = None,
  1042. return_dict: Optional[bool] = None,
  1043. cache_position: Optional[torch.LongTensor] = None,
  1044. ) -> Union[tuple, Seq2SeqModelOutput]:
  1045. r"""
  1046. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1047. Indices of decoder input sequence tokens in the vocabulary.
  1048. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1049. [`PreTrainedTokenizer.__call__`] for details.
  1050. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1051. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1052. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1053. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1054. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1055. for denoising pre-training following the paper.
  1056. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1057. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1058. be used by default.
  1059. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  1060. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1061. information on the default strategy.
  1062. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1063. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1064. 1]`:
  1065. - 1 indicates the head is **not masked**,
  1066. - 0 indicates the head is **masked**.
  1067. """
  1068. # different to other models, Bart automatically creates decoder_input_ids from
  1069. # input_ids if no decoder_input_ids are provided
  1070. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1071. if input_ids is None:
  1072. raise ValueError(
  1073. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1074. "passed, `input_ids` cannot be `None`. Please pass either "
  1075. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1076. )
  1077. decoder_input_ids = shift_tokens_right(
  1078. input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
  1079. )
  1080. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1081. output_hidden_states = (
  1082. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1083. )
  1084. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1085. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1086. if encoder_outputs is None:
  1087. encoder_outputs = self.encoder(
  1088. input_ids=input_ids,
  1089. attention_mask=attention_mask,
  1090. head_mask=head_mask,
  1091. inputs_embeds=inputs_embeds,
  1092. output_attentions=output_attentions,
  1093. output_hidden_states=output_hidden_states,
  1094. return_dict=return_dict,
  1095. )
  1096. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  1097. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1098. encoder_outputs = BaseModelOutput(
  1099. last_hidden_state=encoder_outputs[0],
  1100. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1101. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1102. )
  1103. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1104. decoder_outputs = self.decoder(
  1105. input_ids=decoder_input_ids,
  1106. attention_mask=decoder_attention_mask,
  1107. encoder_hidden_states=encoder_outputs[0],
  1108. encoder_attention_mask=attention_mask,
  1109. head_mask=decoder_head_mask,
  1110. cross_attn_head_mask=cross_attn_head_mask,
  1111. past_key_values=past_key_values,
  1112. inputs_embeds=decoder_inputs_embeds,
  1113. use_cache=use_cache,
  1114. output_attentions=output_attentions,
  1115. output_hidden_states=output_hidden_states,
  1116. return_dict=return_dict,
  1117. cache_position=cache_position,
  1118. )
  1119. if not return_dict:
  1120. return decoder_outputs + encoder_outputs
  1121. return Seq2SeqModelOutput(
  1122. last_hidden_state=decoder_outputs.last_hidden_state,
  1123. past_key_values=decoder_outputs.past_key_values,
  1124. decoder_hidden_states=decoder_outputs.hidden_states,
  1125. decoder_attentions=decoder_outputs.attentions,
  1126. cross_attentions=decoder_outputs.cross_attentions,
  1127. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1128. encoder_hidden_states=encoder_outputs.hidden_states,
  1129. encoder_attentions=encoder_outputs.attentions,
  1130. )
  1131. @auto_docstring(
  1132. custom_intro="""
  1133. The BART Model with a language modeling head. Can be used for summarization.
  1134. """
  1135. )
  1136. class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
  1137. base_model_prefix = "model"
  1138. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  1139. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  1140. def __init__(self, config: BartConfig):
  1141. super().__init__(config)
  1142. self.model = BartModel(config)
  1143. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  1144. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  1145. # Initialize weights and apply final processing
  1146. self.post_init()
  1147. def get_encoder(self):
  1148. return self.model.get_encoder()
  1149. def get_decoder(self):
  1150. return self.model.get_decoder()
  1151. def resize_token_embeddings(
  1152. self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
  1153. ) -> nn.Embedding:
  1154. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  1155. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  1156. return new_embeddings
  1157. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  1158. old_num_tokens = self.final_logits_bias.shape[-1]
  1159. if new_num_tokens <= old_num_tokens:
  1160. new_bias = self.final_logits_bias[:, :new_num_tokens]
  1161. else:
  1162. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  1163. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  1164. self.register_buffer("final_logits_bias", new_bias)
  1165. def _tie_weights(self):
  1166. if self.config.tie_word_embeddings:
  1167. self.model._tie_weights()
  1168. self._tie_or_clone_weights(self.lm_head, self.model.shared)
  1169. @auto_docstring
  1170. def forward(
  1171. self,
  1172. input_ids: Optional[torch.LongTensor] = None,
  1173. attention_mask: Optional[torch.Tensor] = None,
  1174. decoder_input_ids: Optional[torch.LongTensor] = None,
  1175. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1176. head_mask: Optional[torch.Tensor] = None,
  1177. decoder_head_mask: Optional[torch.Tensor] = None,
  1178. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1179. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1180. past_key_values: Optional[Cache] = None,
  1181. inputs_embeds: Optional[torch.FloatTensor] = None,
  1182. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1183. labels: Optional[torch.LongTensor] = None,
  1184. use_cache: Optional[bool] = None,
  1185. output_attentions: Optional[bool] = None,
  1186. output_hidden_states: Optional[bool] = None,
  1187. return_dict: Optional[bool] = None,
  1188. cache_position: Optional[torch.LongTensor] = None,
  1189. ) -> Union[tuple, Seq2SeqLMOutput]:
  1190. r"""
  1191. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1192. Indices of decoder input sequence tokens in the vocabulary.
  1193. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1194. [`PreTrainedTokenizer.__call__`] for details.
  1195. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1196. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1197. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1198. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1199. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1200. for denoising pre-training following the paper.
  1201. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1202. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1203. be used by default.
  1204. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  1205. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1206. information on the default strategy.
  1207. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1208. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1209. 1]`:
  1210. - 1 indicates the head is **not masked**,
  1211. - 0 indicates the head is **masked**.
  1212. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1213. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1214. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1215. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1216. Example summarization:
  1217. ```python
  1218. >>> from transformers import AutoTokenizer, BartForConditionalGeneration
  1219. >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
  1220. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
  1221. >>> ARTICLE_TO_SUMMARIZE = (
  1222. ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
  1223. ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
  1224. ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
  1225. ... )
  1226. >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
  1227. >>> # Generate Summary
  1228. >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
  1229. >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1230. 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
  1231. ```
  1232. Mask filling example:
  1233. ```python
  1234. >>> from transformers import AutoTokenizer, BartForConditionalGeneration
  1235. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
  1236. >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
  1237. >>> TXT = "My friends are <mask> but they eat too many carbs."
  1238. >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
  1239. >>> logits = model(input_ids).logits
  1240. >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
  1241. >>> probs = logits[0, masked_index].softmax(dim=0)
  1242. >>> values, predictions = probs.topk(5)
  1243. >>> tokenizer.decode(predictions).split()
  1244. ['not', 'good', 'healthy', 'great', 'very']
  1245. ```
  1246. """
  1247. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1248. if labels is not None:
  1249. if use_cache:
  1250. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  1251. use_cache = False
  1252. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1253. decoder_input_ids = shift_tokens_right(
  1254. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  1255. )
  1256. outputs = self.model(
  1257. input_ids,
  1258. attention_mask=attention_mask,
  1259. decoder_input_ids=decoder_input_ids,
  1260. encoder_outputs=encoder_outputs,
  1261. decoder_attention_mask=decoder_attention_mask,
  1262. head_mask=head_mask,
  1263. decoder_head_mask=decoder_head_mask,
  1264. cross_attn_head_mask=cross_attn_head_mask,
  1265. past_key_values=past_key_values,
  1266. inputs_embeds=inputs_embeds,
  1267. decoder_inputs_embeds=decoder_inputs_embeds,
  1268. use_cache=use_cache,
  1269. output_attentions=output_attentions,
  1270. output_hidden_states=output_hidden_states,
  1271. return_dict=return_dict,
  1272. cache_position=cache_position,
  1273. )
  1274. lm_logits = self.lm_head(outputs[0])
  1275. lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
  1276. masked_lm_loss = None
  1277. if labels is not None:
  1278. labels = labels.to(lm_logits.device)
  1279. loss_fct = CrossEntropyLoss()
  1280. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  1281. if not return_dict:
  1282. output = (lm_logits,) + outputs[1:]
  1283. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  1284. return Seq2SeqLMOutput(
  1285. loss=masked_lm_loss,
  1286. logits=lm_logits,
  1287. past_key_values=outputs.past_key_values,
  1288. decoder_hidden_states=outputs.decoder_hidden_states,
  1289. decoder_attentions=outputs.decoder_attentions,
  1290. cross_attentions=outputs.cross_attentions,
  1291. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1292. encoder_hidden_states=outputs.encoder_hidden_states,
  1293. encoder_attentions=outputs.encoder_attentions,
  1294. )
  1295. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1296. return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
  1297. @auto_docstring(
  1298. custom_intro="""
  1299. Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  1300. tasks.
  1301. """
  1302. )
  1303. class BartForSequenceClassification(BartPreTrainedModel):
  1304. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1305. def __init__(self, config: BartConfig, **kwargs):
  1306. super().__init__(config, **kwargs)
  1307. self.model = BartModel(config)
  1308. self.classification_head = BartClassificationHead(
  1309. config.d_model,
  1310. config.d_model,
  1311. config.num_labels,
  1312. config.classifier_dropout,
  1313. )
  1314. # Initialize weights and apply final processing
  1315. self.post_init()
  1316. @auto_docstring
  1317. def forward(
  1318. self,
  1319. input_ids: Optional[torch.LongTensor] = None,
  1320. attention_mask: Optional[torch.Tensor] = None,
  1321. decoder_input_ids: Optional[torch.LongTensor] = None,
  1322. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1323. head_mask: Optional[torch.Tensor] = None,
  1324. decoder_head_mask: Optional[torch.Tensor] = None,
  1325. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1326. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1327. inputs_embeds: Optional[torch.FloatTensor] = None,
  1328. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1329. labels: Optional[torch.LongTensor] = None,
  1330. use_cache: Optional[bool] = None,
  1331. output_attentions: Optional[bool] = None,
  1332. output_hidden_states: Optional[bool] = None,
  1333. return_dict: Optional[bool] = None,
  1334. cache_position: Optional[torch.LongTensor] = None,
  1335. ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
  1336. r"""
  1337. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1338. Indices of decoder input sequence tokens in the vocabulary.
  1339. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1340. [`PreTrainedTokenizer.__call__`] for details.
  1341. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1342. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1343. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1344. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1345. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1346. for denoising pre-training following the paper.
  1347. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1348. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1349. be used by default.
  1350. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  1351. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1352. information on the default strategy.
  1353. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1354. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1355. 1]`:
  1356. - 1 indicates the head is **not masked**,
  1357. - 0 indicates the head is **masked**.
  1358. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1359. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1360. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1361. """
  1362. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1363. if labels is not None:
  1364. use_cache = False
  1365. if input_ids is None and inputs_embeds is not None:
  1366. raise NotImplementedError(
  1367. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1368. )
  1369. outputs = self.model(
  1370. input_ids,
  1371. attention_mask=attention_mask,
  1372. decoder_input_ids=decoder_input_ids,
  1373. decoder_attention_mask=decoder_attention_mask,
  1374. head_mask=head_mask,
  1375. decoder_head_mask=decoder_head_mask,
  1376. cross_attn_head_mask=cross_attn_head_mask,
  1377. encoder_outputs=encoder_outputs,
  1378. inputs_embeds=inputs_embeds,
  1379. decoder_inputs_embeds=decoder_inputs_embeds,
  1380. use_cache=use_cache,
  1381. output_attentions=output_attentions,
  1382. output_hidden_states=output_hidden_states,
  1383. return_dict=return_dict,
  1384. cache_position=cache_position,
  1385. )
  1386. hidden_states = outputs[0] # last hidden state
  1387. eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
  1388. if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
  1389. raise ValueError("All examples must have the same number of <eos> tokens.")
  1390. sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
  1391. :, -1, :
  1392. ]
  1393. logits = self.classification_head(sentence_representation)
  1394. loss = None
  1395. if labels is not None:
  1396. labels = labels.to(logits.device)
  1397. if self.config.problem_type is None:
  1398. if self.config.num_labels == 1:
  1399. self.config.problem_type = "regression"
  1400. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1401. self.config.problem_type = "single_label_classification"
  1402. else:
  1403. self.config.problem_type = "multi_label_classification"
  1404. if self.config.problem_type == "regression":
  1405. loss_fct = MSELoss()
  1406. if self.config.num_labels == 1:
  1407. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1408. else:
  1409. loss = loss_fct(logits, labels)
  1410. elif self.config.problem_type == "single_label_classification":
  1411. loss_fct = CrossEntropyLoss()
  1412. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1413. elif self.config.problem_type == "multi_label_classification":
  1414. loss_fct = BCEWithLogitsLoss()
  1415. loss = loss_fct(logits, labels)
  1416. if not return_dict:
  1417. output = (logits,) + outputs[1:]
  1418. return ((loss,) + output) if loss is not None else output
  1419. return Seq2SeqSequenceClassifierOutput(
  1420. loss=loss,
  1421. logits=logits,
  1422. past_key_values=outputs.past_key_values,
  1423. decoder_hidden_states=outputs.decoder_hidden_states,
  1424. decoder_attentions=outputs.decoder_attentions,
  1425. cross_attentions=outputs.cross_attentions,
  1426. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1427. encoder_hidden_states=outputs.encoder_hidden_states,
  1428. encoder_attentions=outputs.encoder_attentions,
  1429. )
  1430. @auto_docstring
  1431. class BartForQuestionAnswering(BartPreTrainedModel):
  1432. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1433. def __init__(self, config):
  1434. super().__init__(config)
  1435. config.num_labels = 2
  1436. self.num_labels = config.num_labels
  1437. self.model = BartModel(config)
  1438. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1439. # Initialize weights and apply final processing
  1440. self.post_init()
  1441. @auto_docstring
  1442. def forward(
  1443. self,
  1444. input_ids: Optional[torch.Tensor] = None,
  1445. attention_mask: Optional[torch.Tensor] = None,
  1446. decoder_input_ids: Optional[torch.LongTensor] = None,
  1447. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1448. head_mask: Optional[torch.Tensor] = None,
  1449. decoder_head_mask: Optional[torch.Tensor] = None,
  1450. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1451. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1452. start_positions: Optional[torch.LongTensor] = None,
  1453. end_positions: Optional[torch.LongTensor] = None,
  1454. inputs_embeds: Optional[torch.FloatTensor] = None,
  1455. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1456. use_cache: Optional[bool] = None,
  1457. output_attentions: Optional[bool] = None,
  1458. output_hidden_states: Optional[bool] = None,
  1459. return_dict: Optional[bool] = None,
  1460. cache_position: Optional[torch.LongTensor] = None,
  1461. ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
  1462. r"""
  1463. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1464. Indices of decoder input sequence tokens in the vocabulary.
  1465. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1466. [`PreTrainedTokenizer.__call__`] for details.
  1467. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1468. Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1469. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1470. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1471. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1472. for denoising pre-training following the paper.
  1473. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1474. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1475. be used by default.
  1476. If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_attention_mask`]
  1477. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1478. information on the default strategy.
  1479. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1480. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1481. 1]`:
  1482. - 1 indicates the head is **not masked**,
  1483. - 0 indicates the head is **masked**.
  1484. """
  1485. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1486. if start_positions is not None and end_positions is not None:
  1487. use_cache = False
  1488. outputs = self.model(
  1489. input_ids,
  1490. attention_mask=attention_mask,
  1491. decoder_input_ids=decoder_input_ids,
  1492. decoder_attention_mask=decoder_attention_mask,
  1493. head_mask=head_mask,
  1494. decoder_head_mask=decoder_head_mask,
  1495. cross_attn_head_mask=cross_attn_head_mask,
  1496. encoder_outputs=encoder_outputs,
  1497. inputs_embeds=inputs_embeds,
  1498. decoder_inputs_embeds=decoder_inputs_embeds,
  1499. use_cache=use_cache,
  1500. output_attentions=output_attentions,
  1501. output_hidden_states=output_hidden_states,
  1502. return_dict=return_dict,
  1503. cache_position=cache_position,
  1504. )
  1505. sequence_output = outputs[0]
  1506. logits = self.qa_outputs(sequence_output)
  1507. start_logits, end_logits = logits.split(1, dim=-1)
  1508. start_logits = start_logits.squeeze(-1).contiguous()
  1509. end_logits = end_logits.squeeze(-1).contiguous()
  1510. total_loss = None
  1511. if start_positions is not None and end_positions is not None:
  1512. # If we are on multi-GPU, split add a dimension
  1513. if len(start_positions.size()) > 1:
  1514. start_positions = start_positions.squeeze(-1)
  1515. if len(end_positions.size()) > 1:
  1516. end_positions = end_positions.squeeze(-1)
  1517. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1518. ignored_index = start_logits.size(1)
  1519. start_positions = start_positions.clamp(0, ignored_index)
  1520. end_positions = end_positions.clamp(0, ignored_index)
  1521. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1522. start_loss = loss_fct(start_logits, start_positions)
  1523. end_loss = loss_fct(end_logits, end_positions)
  1524. total_loss = (start_loss + end_loss) / 2
  1525. if not return_dict:
  1526. output = (
  1527. start_logits,
  1528. end_logits,
  1529. ) + outputs[1:]
  1530. return ((total_loss,) + output) if total_loss is not None else output
  1531. return Seq2SeqQuestionAnsweringModelOutput(
  1532. loss=total_loss,
  1533. start_logits=start_logits,
  1534. end_logits=end_logits,
  1535. past_key_values=outputs.past_key_values,
  1536. decoder_hidden_states=outputs.decoder_hidden_states,
  1537. decoder_attentions=outputs.decoder_attentions,
  1538. cross_attentions=outputs.cross_attentions,
  1539. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1540. encoder_hidden_states=outputs.encoder_hidden_states,
  1541. encoder_attentions=outputs.encoder_attentions,
  1542. )
  1543. class BartDecoderWrapper(BartPreTrainedModel):
  1544. """
  1545. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1546. used in combination with the [`EncoderDecoderModel`] framework.
  1547. """
  1548. def __init__(self, config):
  1549. super().__init__(config)
  1550. self.decoder = BartDecoder(config)
  1551. def forward(self, *args, **kwargs):
  1552. return self.decoder(*args, **kwargs)
  1553. @auto_docstring(
  1554. custom_intro="""
  1555. BART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
  1556. """
  1557. )
  1558. class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
  1559. _tied_weights_keys = ["lm_head.weight"]
  1560. def __init__(self, config):
  1561. config.is_decoder = True
  1562. config.is_encoder_decoder = False
  1563. super().__init__(config)
  1564. self.model = BartDecoderWrapper(config)
  1565. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1566. # Initialize weights and apply final processing
  1567. self.post_init()
  1568. def get_input_embeddings(self):
  1569. return self.model.decoder.embed_tokens
  1570. def set_input_embeddings(self, value):
  1571. self.model.decoder.embed_tokens = value
  1572. def set_decoder(self, decoder):
  1573. self.model.decoder = decoder
  1574. def get_decoder(self):
  1575. return self.model.decoder
  1576. @auto_docstring
  1577. def forward(
  1578. self,
  1579. input_ids: Optional[torch.LongTensor] = None,
  1580. attention_mask: Optional[torch.Tensor] = None,
  1581. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  1582. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  1583. head_mask: Optional[torch.Tensor] = None,
  1584. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1585. past_key_values: Optional[Cache] = None,
  1586. inputs_embeds: Optional[torch.FloatTensor] = None,
  1587. labels: Optional[torch.LongTensor] = None,
  1588. use_cache: Optional[bool] = None,
  1589. output_attentions: Optional[bool] = None,
  1590. output_hidden_states: Optional[bool] = None,
  1591. return_dict: Optional[bool] = None,
  1592. cache_position: Optional[torch.LongTensor] = None,
  1593. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  1594. r"""
  1595. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1596. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1597. - 1 indicates the head is **not masked**,
  1598. - 0 indicates the head is **masked**.
  1599. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1600. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1601. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1602. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1603. Example:
  1604. ```python
  1605. >>> from transformers import AutoTokenizer, BartForCausalLM
  1606. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
  1607. >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
  1608. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1609. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1610. >>> outputs = model(**inputs)
  1611. >>> logits = outputs.logits
  1612. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  1613. >>> list(logits.shape) == expected_shape
  1614. True
  1615. ```"""
  1616. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1617. output_hidden_states = (
  1618. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1619. )
  1620. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1621. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1622. outputs = self.model.decoder(
  1623. input_ids=input_ids,
  1624. attention_mask=attention_mask,
  1625. encoder_hidden_states=encoder_hidden_states,
  1626. encoder_attention_mask=encoder_attention_mask,
  1627. head_mask=head_mask,
  1628. cross_attn_head_mask=cross_attn_head_mask,
  1629. past_key_values=past_key_values,
  1630. inputs_embeds=inputs_embeds,
  1631. use_cache=use_cache,
  1632. output_attentions=output_attentions,
  1633. output_hidden_states=output_hidden_states,
  1634. return_dict=return_dict,
  1635. cache_position=cache_position,
  1636. )
  1637. logits = self.lm_head(outputs[0])
  1638. loss = None
  1639. if labels is not None:
  1640. labels = labels.to(logits.device)
  1641. loss_fct = CrossEntropyLoss()
  1642. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1643. if not return_dict:
  1644. output = (logits,) + outputs[1:]
  1645. return (loss,) + output if loss is not None else output
  1646. return CausalLMOutputWithCrossAttentions(
  1647. loss=loss,
  1648. logits=logits,
  1649. past_key_values=outputs.past_key_values,
  1650. hidden_states=outputs.hidden_states,
  1651. attentions=outputs.attentions,
  1652. cross_attentions=outputs.cross_attentions,
  1653. )
# Names exported by this module for `from ... import *`.
# NOTE(review): `BartPretrainedModel` and `PretrainedBartModel` look like legacy aliases; they are not defined in
# this part of the file -- confirm they exist before changing this list.
__all__ = [
    "BartForCausalLM",
    "BartForConditionalGeneration",
    "BartForQuestionAnswering",
    "BartForSequenceClassification",
    "BartModel",
    "BartPreTrainedModel",
    "BartPretrainedModel",
    "PretrainedBartModel",
]