modeling_xglm.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch XGLM model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  22. from ...generation import GenerationMixin
  23. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import auto_docstring, logging
  28. from ...utils.deprecation import deprecate_kwarg
  29. from .configuration_xglm import XGLMConfig
# Module-level logger; used below for one-time warnings (gradient checkpointing vs. `use_cache`, legacy tuple caches).
logger = logging.get_logger(__name__)
  31. # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->XGLM
  32. class XGLMScaledWordEmbedding(nn.Embedding):
  33. """
  34. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  35. """
  36. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
  37. super().__init__(num_embeddings, embedding_dim, padding_idx)
  38. self.embed_scale = embed_scale
  39. def forward(self, input_ids: torch.Tensor):
  40. return super().forward(input_ids) * self.embed_scale
  41. class XGLMSinusoidalPositionalEmbedding(nn.Module):
  42. """This module produces sinusoidal positional embeddings of any length."""
  43. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
  44. super().__init__()
  45. self.offset = 2
  46. self.embedding_dim = embedding_dim
  47. self.padding_idx = padding_idx
  48. self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
  49. def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
  50. emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
  51. if hasattr(self, "weights"):
  52. # in forward put the weights on the correct dtype and device of the param
  53. emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
  54. self.register_buffer("weights", emb_weights, persistent=False)
  55. @staticmethod
  56. def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
  57. """
  58. Build sinusoidal embeddings.
  59. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
  60. "Attention Is All You Need".
  61. """
  62. half_dim = embedding_dim // 2
  63. emb = math.log(10000) / (half_dim - 1)
  64. emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
  65. emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
  66. emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
  67. if embedding_dim % 2 == 1:
  68. # zero pad
  69. emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
  70. if padding_idx is not None:
  71. emb[padding_idx, :] = 0
  72. return emb.to(torch.get_default_dtype())
  73. @torch.no_grad()
  74. def forward(self, position_ids: Optional[torch.Tensor] = None, past_key_values_length: int = 0):
  75. bsz, seq_len = position_ids.size()
  76. position_ids += self.offset
  77. # Expand embeddings if needed. `position_ids.max()` is NOT used to keep torch.fx compatibility.
  78. max_pos = 2 + seq_len + past_key_values_length
  79. if max_pos > self.weights.size(0):
  80. self.make_weights(max_pos, self.embedding_dim, self.padding_idx)
  81. return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
  82. class XGLMAttention(nn.Module):
  83. """Multi-headed attention from 'Attention Is All You Need' paper"""
  84. def __init__(
  85. self,
  86. embed_dim: int,
  87. num_heads: int,
  88. dropout: Optional[float] = 0.0,
  89. is_decoder: Optional[bool] = False,
  90. bias: Optional[bool] = True,
  91. layer_idx: Optional[bool] = None,
  92. ):
  93. super().__init__()
  94. self.embed_dim = embed_dim
  95. self.num_heads = num_heads
  96. self.dropout = dropout
  97. self.head_dim = embed_dim // num_heads
  98. if (self.head_dim * num_heads) != self.embed_dim:
  99. raise ValueError(
  100. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  101. f" and `num_heads`: {num_heads})."
  102. )
  103. self.scaling = self.head_dim**-0.5
  104. self.is_decoder = is_decoder
  105. self.layer_idx = layer_idx
  106. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  107. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  108. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  109. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  110. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  111. def forward(
  112. self,
  113. hidden_states: torch.Tensor,
  114. key_value_states: Optional[torch.Tensor] = None,
  115. past_key_values: Optional[Cache] = None,
  116. attention_mask: Optional[torch.Tensor] = None,
  117. layer_head_mask: Optional[torch.Tensor] = None,
  118. output_attentions: bool = False,
  119. cache_position: Optional[torch.Tensor] = None,
  120. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  121. """Input shape: Batch x Time x Channel"""
  122. # if key_value_states are provided this layer is used as a cross-attention layer
  123. # for the decoder
  124. is_cross_attention = key_value_states is not None
  125. bsz, tgt_len, _ = hidden_states.size()
  126. src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
  127. # get query proj
  128. query_states = self.q_proj(hidden_states) * self.scaling
  129. is_updated = False
  130. if past_key_values is not None:
  131. if isinstance(past_key_values, EncoderDecoderCache):
  132. is_updated = past_key_values.is_updated.get(self.layer_idx)
  133. if is_cross_attention:
  134. # after the first generated id, we can subsequently re-use all key/value_states from cache
  135. curr_past_key_value = past_key_values.cross_attention_cache
  136. else:
  137. curr_past_key_value = past_key_values.self_attention_cache
  138. else:
  139. curr_past_key_value = past_key_values
  140. current_states = key_value_states if is_cross_attention else hidden_states
  141. if is_cross_attention and past_key_values is not None and is_updated:
  142. # reuse k,v, cross_attentions
  143. key_states = curr_past_key_value.layers[self.layer_idx].keys
  144. value_states = curr_past_key_value.layers[self.layer_idx].values
  145. else:
  146. key_states = self.k_proj(current_states)
  147. value_states = self.v_proj(current_states)
  148. key_states = key_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2)
  149. value_states = value_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2)
  150. if past_key_values is not None:
  151. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  152. cache_position = cache_position if not is_cross_attention else None
  153. key_states, value_states = curr_past_key_value.update(
  154. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  155. )
  156. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  157. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  158. past_key_values.is_updated[self.layer_idx] = True
  159. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  160. query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  161. query_states = query_states.reshape(*proj_shape)
  162. key_states = key_states.reshape(*proj_shape)
  163. value_states = value_states.reshape(*proj_shape)
  164. src_len = key_states.size(1)
  165. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  166. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  167. raise ValueError(
  168. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  169. f" {attn_weights.size()}"
  170. )
  171. if attention_mask is not None:
  172. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  173. raise ValueError(
  174. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  175. )
  176. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  177. attn_weights = torch.max(
  178. attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
  179. )
  180. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  181. # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
  182. if attn_weights.dtype == torch.float16:
  183. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
  184. else:
  185. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  186. if layer_head_mask is not None:
  187. if layer_head_mask.size() != (self.num_heads,):
  188. raise ValueError(
  189. f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
  190. f" {layer_head_mask.size()}"
  191. )
  192. attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  193. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  194. if output_attentions:
  195. # this operation is a bit awkward, but it's required to
  196. # make sure that attn_weights keeps its gradient.
  197. # In order to do so, attn_weights have to be reshaped
  198. # twice and have to be reused in the following
  199. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  200. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  201. else:
  202. attn_weights_reshaped = None
  203. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  204. attn_output = torch.bmm(attn_probs, value_states)
  205. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  206. raise ValueError(
  207. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  208. f" {attn_output.size()}"
  209. )
  210. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  211. attn_output = attn_output.transpose(1, 2)
  212. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  213. # partitioned across GPUs when using tensor-parallelism.
  214. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  215. attn_output = self.out_proj(attn_output)
  216. return attn_output, attn_weights_reshaped
  217. class XGLMDecoderLayer(GradientCheckpointingLayer):
  218. def __init__(self, config: XGLMConfig, layer_idx=None):
  219. super().__init__()
  220. self.embed_dim = config.d_model
  221. self.self_attn = XGLMAttention(
  222. embed_dim=self.embed_dim,
  223. num_heads=config.attention_heads,
  224. dropout=config.attention_dropout,
  225. is_decoder=True,
  226. layer_idx=layer_idx,
  227. )
  228. self.dropout = config.dropout
  229. self.activation_fn = ACT2FN[config.activation_function]
  230. self.activation_dropout = config.activation_dropout
  231. if config.add_cross_attention:
  232. self.encoder_attn = XGLMAttention(
  233. embed_dim=self.embed_dim,
  234. num_heads=config.attention_heads,
  235. dropout=config.attention_dropout,
  236. is_decoder=True,
  237. layer_idx=layer_idx,
  238. )
  239. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  240. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  241. self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
  242. self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
  243. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  244. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  245. # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward
  246. def forward(
  247. self,
  248. hidden_states: torch.Tensor,
  249. attention_mask: Optional[torch.Tensor] = None,
  250. encoder_hidden_states: Optional[torch.Tensor] = None,
  251. encoder_attention_mask: Optional[torch.Tensor] = None,
  252. layer_head_mask: Optional[torch.Tensor] = None,
  253. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  254. past_key_values: Optional[Cache] = None,
  255. output_attentions: Optional[bool] = False,
  256. use_cache: Optional[bool] = True,
  257. cache_position: Optional[torch.Tensor] = None,
  258. ) -> torch.Tensor:
  259. """
  260. Args:
  261. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  262. attention_mask (`torch.FloatTensor`): attention mask of size
  263. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  264. encoder_hidden_states (`torch.FloatTensor`):
  265. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  266. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  267. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  268. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  269. `(encoder_attention_heads,)`.
  270. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  271. size `(decoder_attention_heads,)`.
  272. past_key_values (`Cache`): cached past key and value projection states
  273. output_attentions (`bool`, *optional*):
  274. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  275. returned tensors for more detail.
  276. """
  277. residual = hidden_states
  278. hidden_states = self.self_attn_layer_norm(hidden_states)
  279. # Self Attention
  280. hidden_states, self_attn_weights = self.self_attn(
  281. hidden_states=hidden_states,
  282. past_key_values=past_key_values,
  283. attention_mask=attention_mask,
  284. layer_head_mask=layer_head_mask,
  285. output_attentions=output_attentions,
  286. cache_position=cache_position,
  287. )
  288. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  289. hidden_states = residual + hidden_states
  290. # Cross-Attention Block
  291. cross_attn_weights = None
  292. if encoder_hidden_states is not None:
  293. residual = hidden_states
  294. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  295. hidden_states, cross_attn_weights = self.encoder_attn(
  296. hidden_states=hidden_states,
  297. key_value_states=encoder_hidden_states,
  298. attention_mask=encoder_attention_mask,
  299. layer_head_mask=cross_attn_layer_head_mask,
  300. past_key_values=past_key_values,
  301. output_attentions=output_attentions,
  302. cache_position=cache_position,
  303. )
  304. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  305. hidden_states = residual + hidden_states
  306. # Fully Connected
  307. residual = hidden_states
  308. hidden_states = self.final_layer_norm(hidden_states)
  309. hidden_states = self.activation_fn(self.fc1(hidden_states))
  310. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  311. hidden_states = self.fc2(hidden_states)
  312. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  313. hidden_states = residual + hidden_states
  314. outputs = (hidden_states,)
  315. if output_attentions:
  316. outputs += (self_attn_weights, cross_attn_weights)
  317. return outputs
  318. @auto_docstring
  319. class XGLMPreTrainedModel(PreTrainedModel):
  320. config: XGLMConfig
  321. base_model_prefix = "model"
  322. supports_gradient_checkpointing = True
  323. _no_split_modules = ["XGLMDecoderLayer"]
  324. def _init_weights(self, module):
  325. std = self.config.init_std
  326. if isinstance(module, nn.Linear):
  327. module.weight.data.normal_(mean=0.0, std=std)
  328. if module.bias is not None:
  329. module.bias.data.zero_()
  330. elif isinstance(module, nn.Embedding):
  331. module.weight.data.normal_(mean=0.0, std=std)
  332. if module.padding_idx is not None:
  333. module.weight.data[module.padding_idx].zero_()
  334. @auto_docstring
  335. class XGLMModel(XGLMPreTrainedModel):
  336. def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None):
  337. r"""
  338. embed_tokens (`nn.Embedding`, *optional*):
  339. output embeddings
  340. """
  341. super().__init__(config)
  342. self.dropout = config.dropout
  343. self.layerdrop = config.layerdrop
  344. self.padding_idx = config.pad_token_id
  345. self.max_target_positions = config.max_position_embeddings
  346. embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  347. if embed_tokens is not None:
  348. self.embed_tokens = embed_tokens
  349. else:
  350. self.embed_tokens = XGLMScaledWordEmbedding(
  351. config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
  352. )
  353. self.embed_positions = XGLMSinusoidalPositionalEmbedding(
  354. config.max_position_embeddings,
  355. config.d_model,
  356. config.pad_token_id,
  357. )
  358. self.layers = nn.ModuleList([XGLMDecoderLayer(config, layer_idx=i) for i in range(config.num_layers)])
  359. self.layer_norm = nn.LayerNorm(config.d_model)
  360. self.gradient_checkpointing = False
  361. # Initialize weights and apply final processing
  362. self.post_init()
  363. @auto_docstring
  364. def forward(
  365. self,
  366. input_ids: Optional[torch.Tensor] = None,
  367. attention_mask: Optional[torch.Tensor] = None,
  368. position_ids: Optional[torch.Tensor] = None,
  369. encoder_hidden_states: Optional[torch.Tensor] = None,
  370. encoder_attention_mask: Optional[torch.Tensor] = None,
  371. head_mask: Optional[torch.Tensor] = None,
  372. cross_attn_head_mask: Optional[torch.Tensor] = None,
  373. past_key_values: Optional[Cache] = None,
  374. inputs_embeds: Optional[torch.Tensor] = None,
  375. use_cache: Optional[bool] = None,
  376. output_attentions: Optional[bool] = None,
  377. output_hidden_states: Optional[bool] = None,
  378. return_dict: Optional[bool] = None,
  379. cache_position: Optional[torch.Tensor] = None,
  380. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  381. r"""
  382. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  383. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
  384. the decoder.
  385. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  386. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  387. selected in `[0, 1]`:
  388. - 1 for tokens that are **not masked**,
  389. - 0 for tokens that are **masked**.
  390. [What are attention masks?](../glossary#attention-mask)
  391. cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
  392. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  393. - 1 indicates the head is **not masked**,
  394. - 0 indicates the head is **masked**.
  395. """
  396. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  397. output_hidden_states = (
  398. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  399. )
  400. use_cache = use_cache if use_cache is not None else self.config.use_cache
  401. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  402. # retrieve input_ids and inputs_embeds
  403. if input_ids is not None and inputs_embeds is not None:
  404. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  405. elif input_ids is not None:
  406. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  407. input_shape = input_ids.size()
  408. input_ids = input_ids.view(-1, input_shape[-1])
  409. elif inputs_embeds is not None:
  410. input_shape = inputs_embeds.size()[:-1]
  411. else:
  412. raise ValueError("You have to specify either input_ids or inputs_embeds")
  413. if inputs_embeds is None:
  414. inputs_embeds = self.embed_tokens(input_ids)
  415. if self.gradient_checkpointing and self.training:
  416. if use_cache:
  417. logger.warning_once(
  418. "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
  419. )
  420. use_cache = False
  421. # initialize `past_key_values`
  422. if use_cache and past_key_values is None:
  423. past_key_values = (
  424. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  425. if encoder_hidden_states is not None
  426. else DynamicCache(config=self.config)
  427. )
  428. if use_cache and isinstance(past_key_values, tuple):
  429. logger.warning_once(
  430. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  431. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  432. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  433. )
  434. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  435. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  436. attention_mask = _prepare_4d_causal_attention_mask(
  437. attention_mask, input_shape, inputs_embeds, past_key_values_length
  438. )
  439. if position_ids is None:
  440. position_ids = torch.arange(
  441. past_key_values_length,
  442. input_shape[-1] + past_key_values_length,
  443. dtype=torch.long,
  444. device=input_ids.device if input_ids is not None else inputs_embeds.device,
  445. )
  446. position_ids = position_ids.unsqueeze(0)
  447. # expand encoder attention mask
  448. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  449. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  450. encoder_attention_mask = _prepare_4d_attention_mask(
  451. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  452. )
  453. hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length).to(
  454. inputs_embeds.device
  455. )
  456. hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)
  457. # decoder layers
  458. all_hidden_states = () if output_hidden_states else None
  459. all_self_attns = () if output_attentions else None
  460. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  461. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  462. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  463. if attn_mask is not None:
  464. if attn_mask.size()[0] != len(self.layers):
  465. raise ValueError(
  466. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  467. f" {head_mask.size()[0]}."
  468. )
  469. for idx, decoder_layer in enumerate(self.layers):
  470. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  471. if output_hidden_states:
  472. all_hidden_states += (hidden_states,)
  473. if self.training:
  474. dropout_probability = torch.rand([])
  475. if dropout_probability < self.layerdrop:
  476. continue
  477. layer_outputs = decoder_layer(
  478. hidden_states,
  479. attention_mask,
  480. encoder_hidden_states, # as a positional argument for gradient checkpointing
  481. encoder_attention_mask=encoder_attention_mask,
  482. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  483. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  484. past_key_values=past_key_values,
  485. output_attentions=output_attentions,
  486. use_cache=use_cache,
  487. cache_position=cache_position,
  488. )
  489. hidden_states = layer_outputs[0]
  490. if output_attentions:
  491. all_self_attns += (layer_outputs[1],)
  492. if encoder_hidden_states is not None:
  493. all_cross_attentions += (layer_outputs[2],)
  494. hidden_states = self.layer_norm(hidden_states)
  495. # add hidden states from the last decoder layer
  496. if output_hidden_states:
  497. all_hidden_states += (hidden_states,)
  498. if not return_dict:
  499. return tuple(
  500. v
  501. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  502. if v is not None
  503. )
  504. return BaseModelOutputWithPastAndCrossAttentions(
  505. last_hidden_state=hidden_states,
  506. past_key_values=past_key_values,
  507. hidden_states=all_hidden_states,
  508. attentions=all_self_attns,
  509. cross_attentions=all_cross_attentions,
  510. )
  511. @auto_docstring(
  512. custom_intro="""
  513. The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
  514. embeddings).
  515. """
  516. )
  517. class XGLMForCausalLM(XGLMPreTrainedModel, GenerationMixin):
  518. base_model_prefix = "model"
  519. _tied_weights_keys = ["lm_head.weight"]
  520. def __init__(self, config):
  521. super().__init__(config)
  522. self.model = XGLMModel(config)
  523. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  524. # Initialize weights and apply final processing
  525. self.post_init()
  526. @auto_docstring
  527. def forward(
  528. self,
  529. input_ids: Optional[torch.Tensor] = None,
  530. attention_mask: Optional[torch.Tensor] = None,
  531. position_ids: Optional[torch.Tensor] = None,
  532. encoder_hidden_states: Optional[torch.Tensor] = None,
  533. encoder_attention_mask: Optional[torch.Tensor] = None,
  534. head_mask: Optional[torch.Tensor] = None,
  535. cross_attn_head_mask: Optional[torch.Tensor] = None,
  536. past_key_values: Optional[Cache] = None,
  537. inputs_embeds: Optional[torch.Tensor] = None,
  538. labels: Optional[torch.Tensor] = None,
  539. use_cache: Optional[bool] = None,
  540. output_attentions: Optional[bool] = None,
  541. output_hidden_states: Optional[bool] = None,
  542. return_dict: Optional[bool] = None,
  543. cache_position: Optional[torch.Tensor] = None,
  544. **kwargs,
  545. ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
  546. r"""
  547. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  548. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
  549. the decoder.
  550. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  551. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  552. selected in `[0, 1]`:
  553. - 1 for tokens that are **not masked**,
  554. - 0 for tokens that are **masked**.
  555. [What are attention masks?](../glossary#attention-mask)
  556. cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
  557. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  558. - 1 indicates the head is **not masked**,
  559. - 0 indicates the head is **masked**.
  560. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  561. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  562. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  563. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  564. """
  565. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  566. output_hidden_states = (
  567. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  568. )
  569. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  570. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  571. outputs = self.model(
  572. input_ids=input_ids,
  573. attention_mask=attention_mask,
  574. position_ids=position_ids,
  575. encoder_hidden_states=encoder_hidden_states,
  576. encoder_attention_mask=encoder_attention_mask,
  577. head_mask=head_mask,
  578. cross_attn_head_mask=cross_attn_head_mask,
  579. past_key_values=past_key_values,
  580. inputs_embeds=inputs_embeds,
  581. use_cache=use_cache,
  582. output_attentions=output_attentions,
  583. output_hidden_states=output_hidden_states,
  584. return_dict=return_dict,
  585. cache_position=cache_position,
  586. )
  587. logits = self.lm_head(outputs[0])
  588. loss = None
  589. if labels is not None:
  590. loss = self.loss_function(
  591. logits,
  592. labels,
  593. vocab_size=self.config.vocab_size,
  594. pad_token_id=self.config.pad_token_id,
  595. **kwargs,
  596. )
  597. if not return_dict:
  598. output = (logits,) + outputs[1:]
  599. return (loss,) + output if loss is not None else output
  600. return CausalLMOutputWithCrossAttentions(
  601. loss=loss,
  602. logits=logits,
  603. past_key_values=outputs.past_key_values,
  604. hidden_states=outputs.hidden_states,
  605. attentions=outputs.attentions,
  606. cross_attentions=outputs.cross_attentions,
  607. )
  608. __all__ = ["XGLMForCausalLM", "XGLMModel", "XGLMPreTrainedModel"]