modeling_blenderbot.py 72 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595
  1. # coding=utf-8
  2. # Copyright 2021 The Facebook, Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Blenderbot model."""
  16. import math
  17. import os
  18. import warnings
  19. from typing import Callable, Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  25. from ...generation import GenerationMixin
  26. from ...modeling_attn_mask_utils import (
  27. AttentionMaskConverter,
  28. _prepare_4d_attention_mask,
  29. _prepare_4d_attention_mask_for_sdpa,
  30. )
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import (
  34. BaseModelOutput,
  35. BaseModelOutputWithPastAndCrossAttentions,
  36. CausalLMOutputWithCrossAttentions,
  37. Seq2SeqLMOutput,
  38. Seq2SeqModelOutput,
  39. )
  40. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  41. from ...processing_utils import Unpack
  42. from ...utils import (
  43. auto_docstring,
  44. is_torch_flex_attn_available,
  45. is_torchdynamo_compiling,
  46. logging,
  47. )
  48. from ...utils.deprecation import deprecate_kwarg
  49. from ..blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel
  50. from .configuration_blenderbot import BlenderbotConfig
# `BlockMask` and `make_flex_block_causal_mask` are only imported when `is_torch_flex_attn_available()` is true.
# The mask helpers in this file call `make_flex_block_causal_mask` only on their
# `config._attn_implementation == "flex_attention"` branches; `BlockMask` is referenced only as a string annotation.
if is_torch_flex_attn_available():
    from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask

# Module-level logger (used, e.g., for the missing-`layer_idx` warning in `BlenderbotAttention.__init__`).
logger = logging.get_logger(__name__)
  54. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  55. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  56. """
  57. Shift input ids one token to the right.
  58. """
  59. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  60. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  61. shifted_input_ids[:, 0] = decoder_start_token_id
  62. if pad_token_id is None:
  63. raise ValueError("self.model.config.pad_token_id has to be defined.")
  64. # replace possible -100 values in labels by `pad_token_id`
  65. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  66. return shifted_input_ids
  67. class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
  68. """
  69. This module learns positional embeddings up to a fixed maximum size.
  70. """
  71. def __init__(self, num_embeddings: int, embedding_dim: int):
  72. super().__init__(num_embeddings, embedding_dim)
  73. def forward(
  74. self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
  75. ):
  76. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  77. if position_ids is None:
  78. bsz, seq_len = input_ids_shape[:2]
  79. position_ids = torch.arange(
  80. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  81. )
  82. return super().forward(position_ids)
  83. # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot
  84. class BlenderbotScaledWordEmbedding(nn.Embedding):
  85. """
  86. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  87. """
  88. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
  89. super().__init__(num_embeddings, embedding_dim, padding_idx)
  90. self.embed_scale = embed_scale
  91. def forward(self, input_ids: torch.Tensor):
  92. return super().forward(input_ids) * self.embed_scale
  93. # Copied from transformers.models.bart.modeling_bart.eager_attention_forward
  94. def eager_attention_forward(
  95. module: nn.Module,
  96. query: torch.Tensor,
  97. key: torch.Tensor,
  98. value: torch.Tensor,
  99. attention_mask: Optional[torch.Tensor],
  100. scaling: Optional[float] = None,
  101. dropout: float = 0.0,
  102. head_mask: Optional[torch.Tensor] = None,
  103. **kwargs,
  104. ):
  105. if scaling is None:
  106. scaling = query.size(-1) ** -0.5
  107. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  108. if attention_mask is not None:
  109. attn_weights = attn_weights + attention_mask
  110. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  111. if head_mask is not None:
  112. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  113. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  114. attn_output = torch.matmul(attn_weights, value)
  115. attn_output = attn_output.transpose(1, 2).contiguous()
  116. return attn_output, attn_weights
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot
class BlenderbotAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    In this file it is used as encoder self-attention, decoder self-attention, and (when `forward` receives
    `key_value_states`) decoder cross-attention. The score computation itself is delegated to
    `eager_attention_forward` or to the backend registered in `ALL_ATTENTION_FUNCTIONS` under
    `config._attn_implementation`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        # NOTE(review): `forward` reads `self.config._attn_implementation`, so despite the `None` default a config
        # must effectively be provided before calling `forward`.
        config: Optional[BlenderbotConfig] = None,
        # Index of this layer's slot in the key/value cache; required whenever a cache is used.
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # Standard 1/sqrt(head_dim) score scaling, passed explicitly to the attention backend.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        # Not read in this class; presumably consumed by the fused attention backends via `self` — confirm.
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        # TODO: we need a refactor so that the different attention modules can get their specific kwargs
        # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: queries (and, for self-attention, keys/values), shape `(bsz, tgt_len, embed_dim)`.
            key_value_states: if given, this call is cross-attention and keys/values are projected from it;
                its dim 1 is used as the source length.
            past_key_values: optional cache. With an `EncoderDecoderCache`, self- and cross-attention use its
                separate sub-caches, and cross-attention keys/values are reused once `is_updated[layer_idx]` is set.
            attention_mask: mask forwarded unchanged to the attention backend.
            layer_head_mask: per-head multiplier forwarded to the backend as `head_mask`.
            output_attentions: forwarded to the backend.
            cache_position: forwarded to the cache update for self-attention only (cross-attention passes `None`).

        Returns:
            `(attn_output, attn_weights)`: the output projection of shape `(bsz, tgt_len, embed_dim)` and whatever
            weights the selected backend returned.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        # determine input shapes
        bsz, tgt_len = hidden_states.shape[:-1]
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        # The -1 resolves to num_heads because every projection maps embed_dim -> embed_dim.
        q_input_shape = (bsz, tgt_len, -1, self.head_dim)
        kv_input_shape = (bsz, src_len, -1, self.head_dim)

        # get query proj -> (bsz, num_heads, tgt_len, head_dim) after the transpose
        query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                # `.get` yields None for a layer not yet recorded, which is falsy like False.
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                # NOTE(review): with a plain (non encoder-decoder) cache, cross-attention below writes its keys/values
                # into the same cache and layer slot as self-attention; confirm callers never combine the two.
                curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(*kv_input_shape).transpose(1, 2)
            value_states = value_states.view(*kv_input_shape).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Pick the backend: the local eager implementation, or a registered fused kernel.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Attention dropout only applies in training mode.
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            output_attentions=output_attentions,
            head_mask=layer_head_mask,
            **kwargs,
        )

        # Merge heads back into the channel dimension before the output projection.
        attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
  229. # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
  230. class BlenderbotEncoderLayer(GradientCheckpointingLayer):
  231. def __init__(self, config: BlenderbotConfig):
  232. super().__init__()
  233. self.embed_dim = config.d_model
  234. self.self_attn = BlenderbotAttention(
  235. embed_dim=self.embed_dim,
  236. num_heads=config.encoder_attention_heads,
  237. dropout=config.attention_dropout,
  238. config=config,
  239. )
  240. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  241. self.dropout = config.dropout
  242. self.activation_fn = ACT2FN[config.activation_function]
  243. self.activation_dropout = config.activation_dropout
  244. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  245. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  246. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  247. def forward(
  248. self,
  249. hidden_states: torch.Tensor,
  250. attention_mask: torch.Tensor,
  251. layer_head_mask: torch.Tensor,
  252. output_attentions: bool = False,
  253. ) -> torch.Tensor:
  254. """
  255. Args:
  256. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  257. attention_mask (`torch.FloatTensor`): attention mask of size
  258. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  259. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  260. `(encoder_attention_heads,)`.
  261. output_attentions (`bool`, *optional*):
  262. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  263. returned tensors for more detail.
  264. """
  265. residual = hidden_states
  266. hidden_states = self.self_attn_layer_norm(hidden_states)
  267. hidden_states, attn_weights = self.self_attn(
  268. hidden_states=hidden_states,
  269. attention_mask=attention_mask,
  270. layer_head_mask=layer_head_mask,
  271. output_attentions=output_attentions,
  272. )
  273. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  274. hidden_states = residual + hidden_states
  275. residual = hidden_states
  276. hidden_states = self.final_layer_norm(hidden_states)
  277. hidden_states = self.activation_fn(self.fc1(hidden_states))
  278. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  279. hidden_states = self.fc2(hidden_states)
  280. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  281. hidden_states = residual + hidden_states
  282. if hidden_states.dtype == torch.float16:
  283. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  284. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  285. return hidden_states, attn_weights
  286. # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
  287. class BlenderbotDecoderLayer(GradientCheckpointingLayer):
  288. def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None):
  289. super().__init__()
  290. self.embed_dim = config.d_model
  291. self.self_attn = BlenderbotAttention(
  292. embed_dim=self.embed_dim,
  293. num_heads=config.decoder_attention_heads,
  294. dropout=config.attention_dropout,
  295. is_decoder=True,
  296. is_causal=True,
  297. config=config,
  298. layer_idx=layer_idx,
  299. )
  300. self.dropout = config.dropout
  301. self.activation_fn = ACT2FN[config.activation_function]
  302. self.activation_dropout = config.activation_dropout
  303. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  304. self.encoder_attn = BlenderbotAttention(
  305. self.embed_dim,
  306. config.decoder_attention_heads,
  307. dropout=config.attention_dropout,
  308. is_decoder=True,
  309. config=config,
  310. layer_idx=layer_idx,
  311. )
  312. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  313. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  314. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  315. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  316. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  317. def forward(
  318. self,
  319. hidden_states: torch.Tensor,
  320. attention_mask: Optional[torch.Tensor] = None,
  321. encoder_hidden_states: Optional[torch.Tensor] = None,
  322. encoder_attention_mask: Optional[torch.Tensor] = None,
  323. layer_head_mask: Optional[torch.Tensor] = None,
  324. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  325. past_key_values: Optional[Cache] = None,
  326. output_attentions: Optional[bool] = False,
  327. use_cache: Optional[bool] = True,
  328. cache_position: Optional[torch.Tensor] = None,
  329. ) -> torch.Tensor:
  330. """
  331. Args:
  332. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  333. attention_mask (`torch.FloatTensor`): attention mask of size
  334. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  335. encoder_hidden_states (`torch.FloatTensor`):
  336. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  337. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  338. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  339. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  340. `(encoder_attention_heads,)`.
  341. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  342. size `(decoder_attention_heads,)`.
  343. past_key_values (`Cache`): cached past key and value projection states
  344. output_attentions (`bool`, *optional*):
  345. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  346. returned tensors for more detail.
  347. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  348. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  349. cache in the correct position and to infer the complete sequence length.
  350. """
  351. residual = hidden_states
  352. hidden_states = self.self_attn_layer_norm(hidden_states)
  353. # Self Attention
  354. hidden_states, self_attn_weights = self.self_attn(
  355. hidden_states=hidden_states,
  356. past_key_values=past_key_values,
  357. attention_mask=attention_mask,
  358. layer_head_mask=layer_head_mask,
  359. output_attentions=output_attentions,
  360. cache_position=cache_position,
  361. )
  362. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  363. hidden_states = residual + hidden_states
  364. # Cross-Attention Block
  365. cross_attn_weights = None
  366. if encoder_hidden_states is not None:
  367. residual = hidden_states
  368. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  369. hidden_states, cross_attn_weights = self.encoder_attn(
  370. hidden_states=hidden_states,
  371. key_value_states=encoder_hidden_states,
  372. attention_mask=encoder_attention_mask,
  373. layer_head_mask=cross_attn_layer_head_mask,
  374. past_key_values=past_key_values,
  375. output_attentions=output_attentions,
  376. )
  377. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  378. hidden_states = residual + hidden_states
  379. # Fully Connected
  380. residual = hidden_states
  381. hidden_states = self.final_layer_norm(hidden_states)
  382. hidden_states = self.activation_fn(self.fc1(hidden_states))
  383. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  384. hidden_states = self.fc2(hidden_states)
  385. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  386. hidden_states = residual + hidden_states
  387. outputs = (hidden_states,)
  388. if output_attentions:
  389. outputs += (self_attn_weights, cross_attn_weights)
  390. return outputs
  391. @auto_docstring
  392. class BlenderbotPreTrainedModel(PreTrainedModel):
  393. config: BlenderbotConfig
  394. base_model_prefix = "model"
  395. supports_gradient_checkpointing = True
  396. _supports_flash_attn = True
  397. _supports_sdpa = True
  398. _supports_flex_attn = True
  399. _can_compile_fullgraph = True
  400. def _init_weights(self, module):
  401. std = self.config.init_std
  402. if isinstance(module, nn.Linear):
  403. module.weight.data.normal_(mean=0.0, std=std)
  404. if module.bias is not None:
  405. module.bias.data.zero_()
  406. elif isinstance(module, nn.Embedding):
  407. module.weight.data.normal_(mean=0.0, std=std)
  408. if module.padding_idx is not None:
  409. module.weight.data[module.padding_idx].zero_()
  410. elif isinstance(module, nn.LayerNorm):
  411. module.weight.data.fill_(1.0)
  412. module.bias.data.zero_()
  413. @property
  414. def dummy_inputs(self):
  415. pad_token = self.config.pad_token_id
  416. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  417. dummy_inputs = {
  418. "attention_mask": input_ids.ne(pad_token),
  419. "input_ids": input_ids,
  420. "decoder_input_ids": input_ids,
  421. }
  422. return dummy_inputs
  423. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
  424. def _update_full_mask(
  425. self,
  426. attention_mask: Union[torch.Tensor, None],
  427. inputs_embeds: torch.Tensor,
  428. ):
  429. if attention_mask is not None:
  430. if self.config._attn_implementation == "flash_attention_2":
  431. attention_mask = attention_mask if 0 in attention_mask else None
  432. elif self.config._attn_implementation == "sdpa":
  433. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  434. # the manual implementation that requires a 4D causal mask in all cases.
  435. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  436. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  437. elif self.config._attn_implementation == "flex_attention":
  438. if isinstance(attention_mask, torch.Tensor):
  439. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  440. else:
  441. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  442. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  443. return attention_mask
  444. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
  445. def _update_causal_mask(
  446. self,
  447. attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
  448. input_tensor: torch.Tensor,
  449. cache_position: torch.Tensor,
  450. past_key_values: Cache,
  451. ):
  452. if self.config._attn_implementation == "flex_attention":
  453. if isinstance(attention_mask, torch.Tensor):
  454. attention_mask = make_flex_block_causal_mask(attention_mask)
  455. # Other attention flavors support in-built causal (when `mask is None`)
  456. # while we need to create our specific block mask regardless
  457. elif attention_mask is None:
  458. attention_mask = make_flex_block_causal_mask(
  459. torch.ones(
  460. size=(input_tensor.shape[0], input_tensor.shape[1]),
  461. device=attention_mask.device,
  462. )
  463. )
  464. return attention_mask
  465. if self.config._attn_implementation == "flash_attention_2":
  466. if attention_mask is not None and (attention_mask == 0.0).any():
  467. return attention_mask
  468. return None
  469. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  470. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  471. # to infer the attention mask.
  472. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  473. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  474. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  475. if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
  476. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  477. attention_mask,
  478. inputs_embeds=input_tensor,
  479. past_key_values_length=past_seen_tokens,
  480. is_training=self.training,
  481. ):
  482. return None
  483. dtype = input_tensor.dtype
  484. sequence_length = input_tensor.shape[1]
  485. if using_compilable_cache:
  486. target_length = past_key_values.get_max_cache_shape()
  487. else:
  488. target_length = (
  489. attention_mask.shape[-1]
  490. if isinstance(attention_mask, torch.Tensor)
  491. else past_seen_tokens + sequence_length + 1
  492. )
  493. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  494. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  495. attention_mask,
  496. sequence_length=sequence_length,
  497. target_length=target_length,
  498. dtype=dtype,
  499. cache_position=cache_position,
  500. batch_size=input_tensor.shape[0],
  501. )
  502. if (
  503. self.config._attn_implementation == "sdpa"
  504. and attention_mask is not None
  505. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  506. ):
  507. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  508. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  509. # Details: https://github.com/pytorch/pytorch/issues/110213
  510. min_dtype = torch.finfo(dtype).min
  511. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  512. return causal_mask
  513. @staticmethod
  514. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  515. def _prepare_4d_causal_attention_mask_with_cache_position(
  516. attention_mask: torch.Tensor,
  517. sequence_length: int,
  518. target_length: int,
  519. dtype: torch.dtype,
  520. cache_position: torch.Tensor,
  521. batch_size: int,
  522. **kwargs,
  523. ):
  524. """
  525. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  526. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  527. Args:
  528. attention_mask (`torch.Tensor`):
  529. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  530. `(batch_size, 1, query_length, key_value_length)`.
  531. sequence_length (`int`):
  532. The sequence length being processed.
  533. target_length (`int`):
  534. The target length: when generating with static cache, the mask should be as long as the static cache,
  535. to account for the 0 padding, the part of the cache that is not filled yet.
  536. dtype (`torch.dtype`):
  537. The dtype to use for the 4D attention mask.
  538. cache_position (`torch.Tensor`):
  539. Indices depicting the position of the input sequence tokens in the sequence.
  540. batch_size (`torch.Tensor`):
  541. Batch size.
  542. """
  543. if attention_mask is not None and attention_mask.dim() == 4:
  544. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  545. causal_mask = attention_mask
  546. else:
  547. min_dtype = torch.finfo(dtype).min
  548. causal_mask = torch.full(
  549. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  550. )
  551. if sequence_length != 1:
  552. causal_mask = torch.triu(causal_mask, diagonal=1)
  553. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  554. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  555. if attention_mask is not None:
  556. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  557. mask_length = attention_mask.shape[-1]
  558. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  559. causal_mask.device
  560. )
  561. padding_mask = padding_mask == 0
  562. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  563. padding_mask, min_dtype
  564. )
  565. return causal_mask
  566. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
  567. def _update_cross_attn_mask(
  568. self,
  569. encoder_hidden_states: Union[torch.Tensor, None],
  570. encoder_attention_mask: Union[torch.Tensor, None],
  571. input_shape: torch.Size,
  572. inputs_embeds: torch.Tensor,
  573. ):
  574. # expand encoder attention mask
  575. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  576. if self.config._attn_implementation == "flash_attention_2":
  577. encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
  578. elif self.config._attn_implementation == "sdpa":
  579. # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
  580. # the manual implementation that requires a 4D causal mask in all cases.
  581. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  582. encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  583. encoder_attention_mask,
  584. inputs_embeds.dtype,
  585. tgt_len=input_shape[-1],
  586. )
  587. elif self.config._attn_implementation == "flex_attention":
  588. if isinstance(encoder_attention_mask, torch.Tensor):
  589. encoder_attention_mask = make_flex_block_causal_mask(
  590. encoder_attention_mask,
  591. query_length=input_shape[-1],
  592. is_causal=False,
  593. )
  594. else:
  595. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  596. encoder_attention_mask = _prepare_4d_attention_mask(
  597. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  598. )
  599. return encoder_attention_mask
  600. class BlenderbotEncoder(BlenderbotPreTrainedModel):
  601. """
  602. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  603. [`BlenderbotEncoderLayer`].
  604. Args:
  605. config: BlenderbotConfig
  606. embed_tokens (nn.Embedding): output embedding
  607. """
  608. def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
  609. super().__init__(config)
  610. self.dropout = config.dropout
  611. self.layerdrop = config.encoder_layerdrop
  612. embed_dim = config.d_model
  613. self.padding_idx = config.pad_token_id
  614. self.max_source_positions = config.max_position_embeddings
  615. embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  616. if embed_tokens is not None:
  617. self.embed_tokens = embed_tokens
  618. else:
  619. self.embed_tokens = BlenderbotScaledWordEmbedding(
  620. config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
  621. )
  622. self.embed_positions = BlenderbotLearnedPositionalEmbedding(
  623. config.max_position_embeddings,
  624. embed_dim,
  625. )
  626. self.layers = nn.ModuleList([BlenderbotEncoderLayer(config) for _ in range(config.encoder_layers)])
  627. self.layer_norm = nn.LayerNorm(config.d_model)
  628. self.gradient_checkpointing = False
  629. # Initialize weights and apply final processing
  630. self.post_init()
  631. def forward(
  632. self,
  633. input_ids=None,
  634. attention_mask=None,
  635. head_mask=None,
  636. inputs_embeds=None,
  637. output_attentions=None,
  638. output_hidden_states=None,
  639. return_dict=None,
  640. ):
  641. r"""
  642. Args:
  643. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  644. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  645. provide it.
  646. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  647. [`PreTrainedTokenizer.__call__`] for details.
  648. [What are input IDs?](../glossary#input-ids)
  649. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  650. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  651. - 1 for tokens that are **not masked**,
  652. - 0 for tokens that are **masked**.
  653. [What are attention masks?](../glossary#attention-mask)
  654. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  655. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  656. - 1 indicates the head is **not masked**,
  657. - 0 indicates the head is **masked**.
  658. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  659. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  660. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  661. than the model's internal embedding lookup matrix.
  662. output_attentions (`bool`, *optional*):
  663. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  664. returned tensors for more detail.
  665. output_hidden_states (`bool`, *optional*):
  666. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  667. for more detail.
  668. return_dict (`bool`, *optional*):
  669. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  670. """
  671. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  672. output_hidden_states = (
  673. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  674. )
  675. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  676. # retrieve input_ids and inputs_embeds
  677. if input_ids is not None and inputs_embeds is not None:
  678. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  679. elif input_ids is not None:
  680. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  681. input_shape = input_ids.size()
  682. input_ids = input_ids.view(-1, input_shape[-1])
  683. elif inputs_embeds is not None:
  684. input_shape = inputs_embeds.size()[:-1]
  685. else:
  686. raise ValueError("You have to specify either input_ids or inputs_embeds")
  687. if inputs_embeds is None:
  688. inputs_embeds = self.embed_tokens(input_ids)
  689. embed_pos = self.embed_positions(input_shape)
  690. hidden_states = inputs_embeds + embed_pos
  691. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  692. attention_mask = self._update_full_mask(
  693. attention_mask,
  694. inputs_embeds,
  695. )
  696. encoder_states = () if output_hidden_states else None
  697. all_attentions = () if output_attentions else None
  698. # check if head_mask has a correct number of layers specified if desired
  699. if head_mask is not None:
  700. if head_mask.size()[0] != len(self.layers):
  701. raise ValueError(
  702. f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
  703. f" {head_mask.size()[0]}."
  704. )
  705. for idx, encoder_layer in enumerate(self.layers):
  706. if output_hidden_states:
  707. encoder_states = encoder_states + (hidden_states,)
  708. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  709. to_drop = False
  710. if self.training:
  711. dropout_probability = torch.rand([])
  712. if dropout_probability < self.layerdrop: # skip the layer
  713. to_drop = True
  714. if to_drop:
  715. layer_outputs = (None, None)
  716. else:
  717. layer_outputs = encoder_layer(
  718. hidden_states,
  719. attention_mask,
  720. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  721. output_attentions=output_attentions,
  722. )
  723. hidden_states = layer_outputs[0]
  724. if output_attentions:
  725. all_attentions = all_attentions + (layer_outputs[1],)
  726. # add final layer norm
  727. hidden_states = self.layer_norm(hidden_states)
  728. if output_hidden_states:
  729. encoder_states = encoder_states + (hidden_states,)
  730. if not return_dict:
  731. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  732. return BaseModelOutput(
  733. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  734. )
  735. class BlenderbotDecoder(BlenderbotPreTrainedModel):
  736. """
  737. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BlenderbotDecoderLayer`]
  738. Args:
  739. config: BlenderbotConfig
  740. embed_tokens (nn.Embedding): output embedding
  741. """
  742. def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
  743. super().__init__(config)
  744. self.dropout = config.dropout
  745. self.layerdrop = config.decoder_layerdrop
  746. self.padding_idx = config.pad_token_id
  747. self.max_target_positions = config.max_position_embeddings
  748. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  749. if embed_tokens is not None:
  750. self.embed_tokens = embed_tokens
  751. else:
  752. self.embed_tokens = BlenderbotScaledWordEmbedding(
  753. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  754. )
  755. self.embed_positions = BlenderbotLearnedPositionalEmbedding(
  756. config.max_position_embeddings,
  757. config.d_model,
  758. )
  759. self.layers = nn.ModuleList(
  760. [BlenderbotDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]
  761. )
  762. self.layer_norm = nn.LayerNorm(config.d_model)
  763. self.gradient_checkpointing = False
  764. # Initialize weights and apply final processing
  765. self.post_init()
  766. def forward(
  767. self,
  768. input_ids=None,
  769. attention_mask=None,
  770. encoder_hidden_states=None,
  771. encoder_attention_mask=None,
  772. head_mask=None,
  773. cross_attn_head_mask=None,
  774. past_key_values=None,
  775. inputs_embeds=None,
  776. use_cache=None,
  777. output_attentions=None,
  778. output_hidden_states=None,
  779. return_dict=None,
  780. cache_position: Optional[torch.Tensor] = None,
  781. ):
  782. r"""
  783. Args:
  784. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  785. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  786. provide it.
  787. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  788. [`PreTrainedTokenizer.__call__`] for details.
  789. [What are input IDs?](../glossary#input-ids)
  790. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  791. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  792. - 1 for tokens that are **not masked**,
  793. - 0 for tokens that are **masked**.
  794. [What are attention masks?](../glossary#attention-mask)
  795. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  796. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  797. of the decoder.
  798. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  799. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  800. selected in `[0, 1]`:
  801. - 1 for tokens that are **not masked**,
  802. - 0 for tokens that are **masked**.
  803. [What are attention masks?](../glossary#attention-mask)
  804. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  805. Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0,
  806. 1]`:
  807. - 1 indicates the head is **not masked**,
  808. - 0 indicates the head is **masked**.
  809. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  810. Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
  811. cross-attention on hidden heads. Mask values selected in `[0, 1]`:
  812. - 1 indicates the head is **not masked**,
  813. - 0 indicates the head is **masked**.
  814. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  815. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  816. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  817. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  818. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  819. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  820. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  821. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  822. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  823. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  824. than the model's internal embedding lookup matrix.
  825. output_attentions (`bool`, *optional*):
  826. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  827. returned tensors for more detail.
  828. output_hidden_states (`bool`, *optional*):
  829. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  830. for more detail.
  831. return_dict (`bool`, *optional*):
  832. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  833. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  834. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  835. cache in the correct position and to infer the complete sequence length.
  836. """
  837. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  838. output_hidden_states = (
  839. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  840. )
  841. use_cache = use_cache if use_cache is not None else self.config.use_cache
  842. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  843. # retrieve input_ids and inputs_embeds
  844. if (input_ids is None) ^ (inputs_embeds is not None):
  845. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  846. elif input_ids is not None:
  847. input = input_ids
  848. input_shape = input.shape
  849. input_ids = input_ids.view(-1, input_shape[-1])
  850. elif inputs_embeds is not None:
  851. input_shape = inputs_embeds.size()[:-1]
  852. input = inputs_embeds[:, :, -1]
  853. else:
  854. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  855. if inputs_embeds is None:
  856. inputs_embeds = self.embed_tokens(input)
  857. if self.gradient_checkpointing and self.training:
  858. if use_cache:
  859. logger.warning_once(
  860. "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
  861. )
  862. use_cache = False
  863. # initialize `past_key_values`
  864. if use_cache and past_key_values is None:
  865. past_key_values = (
  866. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  867. if encoder_hidden_states is not None
  868. else DynamicCache(config=self.config)
  869. )
  870. if use_cache and isinstance(past_key_values, tuple):
  871. logger.warning_once(
  872. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  873. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  874. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  875. )
  876. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  877. batch_size, seq_length = inputs_embeds.size()[:-1]
  878. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  879. if cache_position is None:
  880. cache_position = torch.arange(
  881. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  882. )
  883. if attention_mask is None and not is_torchdynamo_compiling():
  884. # required mask seq length can be calculated via length of past cache
  885. mask_seq_length = past_key_values_length + seq_length
  886. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  887. self_attn_cache = (
  888. past_key_values.self_attention_cache
  889. if isinstance(past_key_values, EncoderDecoderCache)
  890. else past_key_values
  891. )
  892. causal_mask = self._update_causal_mask(
  893. attention_mask,
  894. inputs_embeds,
  895. cache_position,
  896. self_attn_cache,
  897. )
  898. encoder_attention_mask = self._update_cross_attn_mask(
  899. encoder_hidden_states,
  900. encoder_attention_mask,
  901. input_shape,
  902. inputs_embeds,
  903. )
  904. # embed positions
  905. position_ids = self.embed_positions(
  906. (batch_size, seq_length), past_key_values_length, position_ids=cache_position
  907. )
  908. hidden_states = inputs_embeds + position_ids
  909. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  910. # decoder layers
  911. all_hidden_states = () if output_hidden_states else None
  912. all_self_attns = () if output_attentions else None
  913. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  914. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  915. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  916. if attn_mask is not None:
  917. if attn_mask.size()[0] != len(self.layers):
  918. raise ValueError(
  919. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  920. f" {head_mask.size()[0]}."
  921. )
  922. for idx, decoder_layer in enumerate(self.layers):
  923. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  924. if output_hidden_states:
  925. all_hidden_states += (hidden_states,)
  926. if self.training:
  927. dropout_probability = torch.rand([])
  928. if dropout_probability < self.layerdrop:
  929. continue
  930. layer_outputs = decoder_layer(
  931. hidden_states,
  932. causal_mask,
  933. encoder_hidden_states, # as a positional argument for gradient checkpointing
  934. encoder_attention_mask=encoder_attention_mask,
  935. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  936. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  937. past_key_values=past_key_values,
  938. output_attentions=output_attentions,
  939. use_cache=use_cache,
  940. cache_position=cache_position,
  941. )
  942. hidden_states = layer_outputs[0]
  943. if output_attentions:
  944. all_self_attns += (layer_outputs[1],)
  945. if encoder_hidden_states is not None:
  946. all_cross_attentions += (layer_outputs[2],)
  947. # add final layer norm
  948. hidden_states = self.layer_norm(hidden_states)
  949. # add hidden states from the last decoder layer
  950. if output_hidden_states:
  951. all_hidden_states += (hidden_states,)
  952. if not return_dict:
  953. return tuple(
  954. v
  955. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  956. if v is not None
  957. )
  958. return BaseModelOutputWithPastAndCrossAttentions(
  959. last_hidden_state=hidden_states,
  960. past_key_values=past_key_values,
  961. hidden_states=all_hidden_states,
  962. attentions=all_self_attns,
  963. cross_attentions=all_cross_attentions,
  964. )
  965. @auto_docstring
  966. class BlenderbotModel(BlenderbotPreTrainedModel):
  967. _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
  968. def __init__(self, config: BlenderbotConfig):
  969. super().__init__(config)
  970. padding_idx, vocab_size = config.pad_token_id, config.vocab_size
  971. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  972. self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
  973. self.encoder = BlenderbotEncoder(config, self.shared)
  974. self.decoder = BlenderbotDecoder(config, self.shared)
  975. # Initialize weights and apply final processing
  976. self.post_init()
  977. @classmethod
  978. def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
  979. if pretrained_model_name_or_path == "facebook/blenderbot-90M":
  980. warnings.warn(
  981. "The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical"
  982. " checkpoint `facebook/small_blenderbot-90M` with"
  983. " `BlenderbotSmallModel.from_pretrained('facebook/small_blenderbot-90M')` instead.",
  984. FutureWarning,
  985. )
  986. return BlenderbotSmallModel.from_pretrained(pretrained_model_name_or_path)
  987. return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  988. def get_input_embeddings(self):
  989. return self.shared
  990. def set_input_embeddings(self, value):
  991. self.shared = value
  992. self.encoder.embed_tokens = self.shared
  993. self.decoder.embed_tokens = self.shared
  994. def get_encoder(self):
  995. return self.encoder
  996. @auto_docstring
  997. def forward(
  998. self,
  999. input_ids: Optional[torch.LongTensor] = None,
  1000. attention_mask: Optional[torch.Tensor] = None,
  1001. decoder_input_ids: Optional[torch.LongTensor] = None,
  1002. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1003. head_mask: Optional[torch.Tensor] = None,
  1004. decoder_head_mask: Optional[torch.Tensor] = None,
  1005. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1006. encoder_outputs: Optional[Union[tuple, BaseModelOutput]] = None,
  1007. past_key_values: Optional[Cache] = None,
  1008. inputs_embeds: Optional[torch.Tensor] = None,
  1009. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1010. use_cache: Optional[bool] = None,
  1011. output_attentions: Optional[bool] = None,
  1012. output_hidden_states: Optional[bool] = None,
  1013. return_dict: Optional[bool] = None,
  1014. cache_position: Optional[torch.Tensor] = None,
  1015. ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  1016. r"""
  1017. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1018. Indices of decoder input sequence tokens in the vocabulary.
  1019. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1020. [`PreTrainedTokenizer.__call__`] for details.
  1021. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1022. Blenderbot uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If
  1023. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1024. `past_key_values`).
  1025. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1026. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1027. be used by default.
  1028. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1029. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1030. 1]`:
  1031. - 1 indicates the head is **not masked**,
  1032. - 0 indicates the head is **masked**.
  1033. Example:
  1034. ```python
  1035. >>> from transformers import AutoTokenizer, BlenderbotModel
  1036. >>> model = BlenderbotModel.from_pretrained("facebook/blenderbot-400M-distill")
  1037. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
  1038. >>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
  1039. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1040. >>> outputs = model(input_ids=inputs.input_ids, decoder_input_ids=decoder_input_ids)
  1041. >>> last_hidden_states = outputs.last_hidden_state
  1042. >>> list(last_hidden_states.shape)
  1043. [1, 6, 1280]
  1044. ```"""
  1045. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1046. output_hidden_states = (
  1047. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1048. )
  1049. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1050. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1051. if encoder_outputs is None:
  1052. encoder_outputs = self.encoder(
  1053. input_ids=input_ids,
  1054. attention_mask=attention_mask,
  1055. head_mask=head_mask,
  1056. inputs_embeds=inputs_embeds,
  1057. output_attentions=output_attentions,
  1058. output_hidden_states=output_hidden_states,
  1059. return_dict=return_dict,
  1060. )
  1061. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  1062. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1063. encoder_outputs = BaseModelOutput(
  1064. last_hidden_state=encoder_outputs[0],
  1065. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1066. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1067. )
  1068. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1069. decoder_outputs = self.decoder(
  1070. input_ids=decoder_input_ids,
  1071. attention_mask=decoder_attention_mask,
  1072. encoder_hidden_states=encoder_outputs[0],
  1073. encoder_attention_mask=attention_mask,
  1074. head_mask=decoder_head_mask,
  1075. cross_attn_head_mask=cross_attn_head_mask,
  1076. past_key_values=past_key_values,
  1077. inputs_embeds=decoder_inputs_embeds,
  1078. use_cache=use_cache,
  1079. output_attentions=output_attentions,
  1080. output_hidden_states=output_hidden_states,
  1081. return_dict=return_dict,
  1082. cache_position=cache_position,
  1083. )
  1084. if not return_dict:
  1085. return decoder_outputs + encoder_outputs
  1086. return Seq2SeqModelOutput(
  1087. last_hidden_state=decoder_outputs.last_hidden_state,
  1088. past_key_values=decoder_outputs.past_key_values,
  1089. decoder_hidden_states=decoder_outputs.hidden_states,
  1090. decoder_attentions=decoder_outputs.attentions,
  1091. cross_attentions=decoder_outputs.cross_attentions,
  1092. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1093. encoder_hidden_states=encoder_outputs.hidden_states,
  1094. encoder_attentions=encoder_outputs.attentions,
  1095. )
  1096. @auto_docstring(
  1097. custom_intro="""
  1098. The Blenderbot Model with a language modeling head. Can be used for summarization.
  1099. """
  1100. )
  1101. class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin):
  1102. base_model_prefix = "model"
  1103. _keys_to_ignore_on_load_missing = ["final_logits_bias"]
  1104. _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
  1105. def __init__(self, config: BlenderbotConfig):
  1106. super().__init__(config)
  1107. self.model = BlenderbotModel(config)
  1108. self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
  1109. self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
  1110. # Initialize weights and apply final processing
  1111. self.post_init()
  1112. @classmethod
  1113. def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
  1114. if pretrained_model_name_or_path == "facebook/blenderbot-90M":
  1115. warnings.warn(
  1116. "The checkpoint `facebook/blenderbot-90M` is deprecated. In the future, please use the identical"
  1117. " checkpoint `facebook/small_blenderbot-90M` with"
  1118. " `BlenderbotSmallForConditionalGeneration.from_pretrained('facebook/small_blenderbot-90M')` instead.",
  1119. FutureWarning,
  1120. )
  1121. return BlenderbotSmallForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
  1122. return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
  1123. def get_encoder(self):
  1124. return self.model.get_encoder()
  1125. def get_decoder(self):
  1126. return self.model.get_decoder()
  1127. def resize_token_embeddings(
  1128. self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
  1129. ) -> nn.Embedding:
  1130. new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  1131. self._resize_final_logits_bias(new_embeddings.weight.shape[0])
  1132. return new_embeddings
  1133. def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
  1134. old_num_tokens = self.final_logits_bias.shape[-1]
  1135. if new_num_tokens <= old_num_tokens:
  1136. new_bias = self.final_logits_bias[:, :new_num_tokens]
  1137. else:
  1138. extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
  1139. new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
  1140. self.register_buffer("final_logits_bias", new_bias)
  1141. @auto_docstring
  1142. def forward(
  1143. self,
  1144. input_ids: Optional[torch.LongTensor] = None,
  1145. attention_mask: Optional[torch.Tensor] = None,
  1146. decoder_input_ids: Optional[torch.LongTensor] = None,
  1147. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1148. head_mask: Optional[torch.Tensor] = None,
  1149. decoder_head_mask: Optional[torch.Tensor] = None,
  1150. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1151. encoder_outputs: Optional[Union[tuple, BaseModelOutput]] = None,
  1152. past_key_values: Optional[Cache] = None,
  1153. inputs_embeds: Optional[torch.Tensor] = None,
  1154. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1155. labels: Optional[torch.LongTensor] = None,
  1156. use_cache: Optional[bool] = None,
  1157. output_attentions: Optional[bool] = None,
  1158. output_hidden_states: Optional[bool] = None,
  1159. return_dict: Optional[bool] = None,
  1160. cache_position: Optional[torch.Tensor] = None,
  1161. ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
  1162. r"""
  1163. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1164. Indices of decoder input sequence tokens in the vocabulary.
  1165. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1166. [`PreTrainedTokenizer.__call__`] for details.
  1167. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1168. Blenderbot uses the `bos_token_id` as the starting token for `decoder_input_ids` generation. If
  1169. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1170. `past_key_values`).
  1171. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1172. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1173. be used by default.
  1174. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1175. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1176. 1]`:
  1177. - 1 indicates the head is **not masked**,
  1178. - 0 indicates the head is **masked**.
  1179. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1180. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1181. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1182. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1183. Example conversation:
  1184. ```python
  1185. >>> from transformers import AutoTokenizer, BlenderbotForConditionalGeneration
  1186. >>> mname = "facebook/blenderbot-400M-distill"
  1187. >>> model = BlenderbotForConditionalGeneration.from_pretrained(mname)
  1188. >>> tokenizer = AutoTokenizer.from_pretrained(mname)
  1189. >>> UTTERANCE = "My friends are cool but they eat too many carbs."
  1190. >>> print("Human: ", UTTERANCE)
  1191. Human: My friends are cool but they eat too many carbs.
  1192. >>> inputs = tokenizer([UTTERANCE], return_tensors="pt")
  1193. >>> reply_ids = model.generate(**inputs)
  1194. >>> print("Bot: ", tokenizer.batch_decode(reply_ids, skip_special_tokens=True)[0])
  1195. Bot: That's unfortunate. Are they trying to lose weight or are they just trying to be healthier?
  1196. >>> REPLY = "I'm not sure"
  1197. >>> print("Human: ", REPLY)
  1198. Human: I'm not sure
  1199. >>> NEXT_UTTERANCE = (
  1200. ... "My friends are cool but they eat too many carbs.</s> <s>That's unfortunate. "
  1201. ... "Are they trying to lose weight or are they just trying to be healthier?</s> "
  1202. ... "<s> I'm not sure."
  1203. ... )
  1204. >>> inputs = tokenizer([NEXT_UTTERANCE], return_tensors="pt")
  1205. >>> next_reply_ids = model.generate(**inputs)
  1206. >>> print("Bot: ", tokenizer.batch_decode(next_reply_ids, skip_special_tokens=True)[0])
  1207. Bot: I see. Well, it's good that they're trying to change their eating habits.
  1208. ```
  1209. """
  1210. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1211. if labels is not None:
  1212. if use_cache:
  1213. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  1214. use_cache = False
  1215. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1216. decoder_input_ids = shift_tokens_right(
  1217. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  1218. )
  1219. outputs = self.model(
  1220. input_ids,
  1221. attention_mask=attention_mask,
  1222. decoder_input_ids=decoder_input_ids,
  1223. encoder_outputs=encoder_outputs,
  1224. decoder_attention_mask=decoder_attention_mask,
  1225. head_mask=head_mask,
  1226. decoder_head_mask=decoder_head_mask,
  1227. cross_attn_head_mask=cross_attn_head_mask,
  1228. past_key_values=past_key_values,
  1229. inputs_embeds=inputs_embeds,
  1230. decoder_inputs_embeds=decoder_inputs_embeds,
  1231. use_cache=use_cache,
  1232. output_attentions=output_attentions,
  1233. output_hidden_states=output_hidden_states,
  1234. return_dict=return_dict,
  1235. cache_position=cache_position,
  1236. )
  1237. lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
  1238. masked_lm_loss = None
  1239. if labels is not None:
  1240. loss_fct = CrossEntropyLoss()
  1241. masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
  1242. if not return_dict:
  1243. output = (lm_logits,) + outputs[1:]
  1244. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  1245. return Seq2SeqLMOutput(
  1246. loss=masked_lm_loss,
  1247. logits=lm_logits,
  1248. past_key_values=outputs.past_key_values,
  1249. decoder_hidden_states=outputs.decoder_hidden_states,
  1250. decoder_attentions=outputs.decoder_attentions,
  1251. cross_attentions=outputs.cross_attentions,
  1252. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1253. encoder_hidden_states=outputs.encoder_hidden_states,
  1254. encoder_attentions=outputs.encoder_attentions,
  1255. )
  1256. # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Blenderbot
  1257. class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
  1258. """
  1259. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1260. used in combination with the [`EncoderDecoderModel`] framework.
  1261. """
  1262. def __init__(self, config):
  1263. super().__init__(config)
  1264. self.decoder = BlenderbotDecoder(config)
  1265. def forward(self, *args, **kwargs):
  1266. return self.decoder(*args, **kwargs)
  1267. # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
  1268. class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
  1269. _tied_weights_keys = ["lm_head.weight"]
  1270. def __init__(self, config):
  1271. config.is_decoder = True
  1272. config.is_encoder_decoder = False
  1273. super().__init__(config)
  1274. self.model = BlenderbotDecoderWrapper(config)
  1275. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1276. # Initialize weights and apply final processing
  1277. self.post_init()
  1278. def get_input_embeddings(self):
  1279. return self.model.decoder.embed_tokens
  1280. def set_input_embeddings(self, value):
  1281. self.model.decoder.embed_tokens = value
  1282. def set_decoder(self, decoder):
  1283. self.model.decoder = decoder
  1284. def get_decoder(self):
  1285. return self.model.decoder
  1286. @auto_docstring
  1287. def forward(
  1288. self,
  1289. input_ids: Optional[torch.LongTensor] = None,
  1290. attention_mask: Optional[torch.Tensor] = None,
  1291. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  1292. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  1293. head_mask: Optional[torch.Tensor] = None,
  1294. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1295. past_key_values: Optional[Cache] = None,
  1296. inputs_embeds: Optional[torch.FloatTensor] = None,
  1297. labels: Optional[torch.LongTensor] = None,
  1298. use_cache: Optional[bool] = None,
  1299. output_attentions: Optional[bool] = None,
  1300. output_hidden_states: Optional[bool] = None,
  1301. return_dict: Optional[bool] = None,
  1302. cache_position: Optional[torch.LongTensor] = None,
  1303. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  1304. r"""
  1305. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1306. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1307. - 1 indicates the head is **not masked**,
  1308. - 0 indicates the head is **masked**.
  1309. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1310. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1311. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1312. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1313. Example:
  1314. ```python
  1315. >>> from transformers import AutoTokenizer, BlenderbotForCausalLM
  1316. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
  1317. >>> model = BlenderbotForCausalLM.from_pretrained("facebook/blenderbot-400M-distill", add_cross_attention=False)
  1318. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1319. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1320. >>> outputs = model(**inputs)
  1321. >>> logits = outputs.logits
  1322. >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
  1323. >>> list(logits.shape) == expected_shape
  1324. True
  1325. ```"""
  1326. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1327. output_hidden_states = (
  1328. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1329. )
  1330. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1331. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1332. outputs = self.model.decoder(
  1333. input_ids=input_ids,
  1334. attention_mask=attention_mask,
  1335. encoder_hidden_states=encoder_hidden_states,
  1336. encoder_attention_mask=encoder_attention_mask,
  1337. head_mask=head_mask,
  1338. cross_attn_head_mask=cross_attn_head_mask,
  1339. past_key_values=past_key_values,
  1340. inputs_embeds=inputs_embeds,
  1341. use_cache=use_cache,
  1342. output_attentions=output_attentions,
  1343. output_hidden_states=output_hidden_states,
  1344. return_dict=return_dict,
  1345. cache_position=cache_position,
  1346. )
  1347. logits = self.lm_head(outputs[0])
  1348. loss = None
  1349. if labels is not None:
  1350. labels = labels.to(logits.device)
  1351. loss_fct = CrossEntropyLoss()
  1352. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1353. if not return_dict:
  1354. output = (logits,) + outputs[1:]
  1355. return (loss,) + output if loss is not None else output
  1356. return CausalLMOutputWithCrossAttentions(
  1357. loss=loss,
  1358. logits=logits,
  1359. past_key_values=outputs.past_key_values,
  1360. hidden_states=outputs.hidden_states,
  1361. attentions=outputs.attentions,
  1362. cross_attentions=outputs.cross_attentions,
  1363. )
  1364. __all__ = [
  1365. "BlenderbotForCausalLM",
  1366. "BlenderbotForConditionalGeneration",
  1367. "BlenderbotModel",
  1368. "BlenderbotPreTrainedModel",
  1369. ]