modeling_opt.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102
  1. # coding=utf-8
  2. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch OPT model."""
  16. from typing import Callable, Optional, Union
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache
  22. from ...generation import GenerationMixin
  23. from ...modeling_attn_mask_utils import AttentionMaskConverter
  24. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutputWithPast,
  28. CausalLMOutputWithPast,
  29. QuestionAnsweringModelOutput,
  30. SequenceClassifierOutputWithPast,
  31. )
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
  35. from ...utils.deprecation import deprecate_kwarg
  36. from .configuration_opt import OPTConfig
  37. if is_torch_flex_attn_available():
  38. from torch.nn.attention.flex_attention import BlockMask
  39. from ...integrations.flex_attention import make_flex_block_causal_mask
  # Module-level logger for this file (its `warning_once` is used by the attention and decoder classes).
  logger = logging.get_logger(__name__)
  class OPTLearnedPositionalEmbedding(nn.Embedding):
      """
      This module learns positional embeddings up to a fixed maximum size.

      Position ids are shifted by a constant `offset` (2) before the table lookup, and the table is enlarged by the
      same amount. Logical positions `0 .. num_embeddings - 1` therefore use rows `2 .. num_embeddings + 1`. The
      id `-1` that `forward` assigns to padded tokens lands on row 1.
      """

      def __init__(self, num_embeddings: int, embedding_dim: int):
          # OPT always offsets the position ids by 2 (regardless of any padding_idx) and enlarges
          # num_embeddings to match. Other models don't have this hack.
          self.offset = 2
          super().__init__(num_embeddings + self.offset, embedding_dim)

      def forward(
          self,
          attention_mask: torch.LongTensor,
          past_key_values_length: int = 0,
          position_ids: Optional[torch.LongTensor] = None,
      ):
          """Return the positional embeddings for the current tokens.

          Args:
              attention_mask: 2D mask (1 = real token, 0 = padding) along dim 1. It should also cover any
                  already-cached positions. It is only used when `position_ids` is not given.
              past_key_values_length: number of leading (cached) positions to drop from the ids derived from
                  `attention_mask`. Ignored when `position_ids` is given.
              position_ids: explicit position ids, used as-is (no slicing).

          Returns:
              The embedding rows for `position_ids + self.offset`.
          """
          if position_ids is None:
              # Count real tokens left-to-right; padded slots become -1.
              position_ids = torch.cumsum(attention_mask, dim=1)
              position_ids = (position_ids * attention_mask - 1).long()
              # cut positions if `past_key_values_length` is > 0
              position_ids = position_ids[:, past_key_values_length:]

          return super().forward(position_ids + self.offset)
  # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: Optional[torch.Tensor],
      scaling: float,
      dropout: float = 0.0,
      **kwargs,
  ):
      """Plain (non-fused) scaled dot-product attention.

      Args:
          module: owning module; only its `training` flag is read (to gate dropout).
          query: per-head queries; heads are on dim 1 and tokens on dim 2 (the output swaps them back).
          key: per-head keys, laid out like `query`.
          value: per-head values, laid out like `key`.
          attention_mask: optional additive mask added to the raw scores before the softmax.
          scaling: factor applied to the query-key scores.
          dropout: dropout probability on the attention probabilities (only active in training mode).
          **kwargs: accepted for interface compatibility and ignored.

      Returns:
          `(attn_output, attn_weights)`. `attn_output` has dims 1 and 2 swapped relative to the inputs and is made
          contiguous. `attn_weights` are the (possibly dropped-out) probabilities in the query dtype.
      """
      scores = torch.matmul(query, key.transpose(-1, -2)) * scaling
      scores = scores if attention_mask is None else scores + attention_mask

      # Softmax in float32 for stability, then return to the working dtype.
      probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
      probs = nn.functional.dropout(probs, p=dropout, training=module.training)

      context = torch.matmul(probs, value)
      return context.transpose(1, 2).contiguous(), probs
  class OPTAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

      def __init__(
          self,
          config: OPTConfig,
          layer_idx: Optional[int] = None,
          **kwargs,
      ):
          super().__init__()
          self.config = config
          self.embed_dim = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.dropout = config.attention_dropout
          self.enable_bias = config.enable_bias
          self.layer_idx = layer_idx
          if layer_idx is None:
              logger.warning_once(
                  f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                  "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                  "when creating this class."
              )
          self.head_dim = self.embed_dim // self.num_heads
          self.is_causal = True

          if (self.head_dim * self.num_heads) != self.embed_dim:
              raise ValueError(
                  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                  f" and `num_heads`: {self.num_heads})."
              )
          self.scaling = self.head_dim**-0.5

          self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
          self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
          self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
          self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          past_key_values: Optional[Cache] = None,
          attention_mask: Optional[torch.Tensor] = None,
          layer_head_mask: Optional[torch.Tensor] = None,
          output_attentions: bool = False,
          cache_position: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Run multi-head self-attention over `hidden_states`.

          Args:
              hidden_states: input of shape `(batch, tgt_len, embed_dim)` ("Batch x Time x Channel").
              past_key_values: optional cache. The new key/value states are appended via `past_key_values.update`,
                  and attention runs over the states it returns.
              attention_mask: mask forwarded unchanged to the attention backend.
              layer_head_mask: optional per-head multiplier of shape `(num_heads,)`. 0 disables a head, 1 keeps it.
              output_attentions: when False, the returned attention weights are `None`.
              cache_position: forwarded to the cache update.
              **kwargs: forwarded to the attention backend.

          Returns:
              `(attn_output, attn_weights)`:
              - `attn_output` is the projected output of shape `(batch, tgt_len, embed_dim)`.
              - `attn_weights` are the attention weights, or `None`. Non-eager backends may not produce them.
          """
          bsz, tgt_len, _ = hidden_states.size()

          # Scaling is susceptible to floating point arithmetics' imprecisions
          # which can lead to different results (this is dependent from model
          # to model, e.g. whisper is one such case). We therefore keep the
          # original order of scaling to follow the original implementation
          # and enforce no scaling (1.0) in the attention call below.
          query_states = self.q_proj(hidden_states) * self.scaling
          query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

          key_states = self.k_proj(hidden_states)
          value_states = self.v_proj(hidden_states)
          key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
          value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

          if past_key_values is not None:
              # save all key/value_states to cache to be re-used for fast auto-regressive generation
              key_states, value_states = past_key_values.update(
                  key_states, value_states, self.layer_idx, {"cache_position": cache_position}
              )

          attention_interface: Callable = eager_attention_forward
          if self.config._attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.dropout,
              scaling=1.0,
              **kwargs,
          )

          if layer_head_mask is not None:
              # Head masking multiplies each head's attention probabilities by a scalar. Because
              # (m_h * P_h) @ V_h == m_h * (P_h @ V_h), the same result is obtained by scaling each head's
              # output, which works for every backend (fused kernels never expose P).
              # The eager path returns (bsz, tgt_len, num_heads, head_dim). Other backends are assumed to use
              # the same layout, which the reshape below also relies on.
              head_scale = layer_head_mask.to(device=attn_output.device, dtype=attn_output.dtype)
              attn_output = attn_output * head_scale.view(1, 1, -1, 1)
              if attn_weights is not None:
                  # Weights are laid out (bsz, num_heads, tgt_len, src_len).
                  attn_weights = attn_weights * head_scale.view(1, -1, 1, 1).to(attn_weights.dtype)

          attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
          attn_output = self.out_proj(attn_output)

          if not output_attentions:
              attn_weights = None

          return attn_output, attn_weights
  class OPTDecoderLayer(GradientCheckpointingLayer):
      """One OPT transformer block: self-attention followed by a two-layer feed-forward network.

      Each sub-block has a residual connection. LayerNorm is applied before or after each sub-block depending on
      `config.do_layer_norm_before`.
      """

      def __init__(self, config: OPTConfig, layer_idx: Optional[int] = None):
          super().__init__()
          self.embed_dim = config.hidden_size

          self.self_attn = OPTAttention(config=config, layer_idx=layer_idx)

          self.do_layer_norm_before = config.do_layer_norm_before
          self.dropout = config.dropout
          self.activation_fn = ACT2FN[config.activation_function]

          self.self_attn_layer_norm = nn.LayerNorm(
              self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
          )
          self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
          self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
          self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          layer_head_mask: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          output_attentions: Optional[bool] = False,
          use_cache: Optional[bool] = False,
          position_ids: Optional[torch.LongTensor] = None,
          cache_position: Optional[torch.Tensor] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.FloatTensor, ...]:
          """
          Args:
              hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
              attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                  `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
              layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
                  `(encoder_attention_heads,)`.
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                  returned tensors for more detail.
              use_cache (`bool`, *optional*):
                  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                  (see `past_key_values`). Not read by this layer itself.
              past_key_values (`Cache`, *optional*): cached past key and value projection states
              position_ids (`torch.LongTensor`, *optional*): forwarded to the attention call as a keyword argument.
              cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                  Indices depicting the position of the input sequence tokens in the sequence.

          Returns:
              `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is set.
          """
          residual = hidden_states

          # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
          if self.do_layer_norm_before:
              hidden_states = self.self_attn_layer_norm(hidden_states)

          # Self Attention
          hidden_states, self_attn_weights = self.self_attn(
              hidden_states=hidden_states,
              past_key_values=past_key_values,
              position_ids=position_ids,
              attention_mask=attention_mask,
              layer_head_mask=layer_head_mask,
              output_attentions=output_attentions,
              cache_position=cache_position,
              **kwargs,
          )
          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
          hidden_states = residual + hidden_states

          # 350m applies layer norm AFTER attention
          if not self.do_layer_norm_before:
              hidden_states = self.self_attn_layer_norm(hidden_states)

          # Fully Connected
          # Flatten to (tokens, embed_dim) for the feed-forward block; the original shape is restored below.
          hidden_states_shape = hidden_states.shape
          hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
          residual = hidden_states

          # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
          if self.do_layer_norm_before:
              hidden_states = self.final_layer_norm(hidden_states)

          hidden_states = self.fc1(hidden_states)
          hidden_states = self.activation_fn(hidden_states)

          hidden_states = self.fc2(hidden_states)
          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

          hidden_states = (residual + hidden_states).view(hidden_states_shape)

          # 350m applies layer norm AFTER the feed-forward block
          if not self.do_layer_norm_before:
              hidden_states = self.final_layer_norm(hidden_states)

          outputs = (hidden_states,)

          if output_attentions:
              outputs += (self_attn_weights,)

          return outputs
  @auto_docstring
  class OPTPreTrainedModel(PreTrainedModel):
      config: OPTConfig
      base_model_prefix = "model"
      supports_gradient_checkpointing = True
      _no_split_modules = ["OPTDecoderLayer"]
      _supports_attention_backend = True
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True
      _can_compile_fullgraph = True

      def _init_weights(self, module):
          """Initialise `module`'s parameters in place.

          - Linear and Embedding weights are drawn from N(0, `config.init_std`).
          - Linear biases and the Embedding padding row are zeroed.
          - LayerNorm gets weight 1 / bias 0 when those parameters exist.
          """
          std = self.config.init_std
          if isinstance(module, nn.Linear):
              module.weight.data.normal_(mean=0.0, std=std)
              if module.bias is not None:
                  module.bias.data.zero_()
          elif isinstance(module, nn.Embedding):
              module.weight.data.normal_(mean=0.0, std=std)
              if module.padding_idx is not None:
                  module.weight.data[module.padding_idx].zero_()
          elif isinstance(module, nn.LayerNorm):
              # LayerNorms built with `elementwise_affine=False` (config.layer_norm_elementwise_affine)
              # have `weight`/`bias` set to None; there is nothing to initialise in that case.
              if module.weight is not None:
                  module.weight.data.fill_(1.0)
              if module.bias is not None:
                  module.bias.data.zero_()
  class OPTDecoder(OPTPreTrainedModel):
      """
      Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`].

      Args:
          config: OPTConfig
      """

      def __init__(self, config: OPTConfig):
          super().__init__(config)
          self.dropout = config.dropout
          self.layerdrop = config.layerdrop
          self.padding_idx = config.pad_token_id
          self.max_target_positions = config.max_position_embeddings
          self.vocab_size = config.vocab_size

          self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
          self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)

          # Token embeddings live in `word_embed_proj_dim`. When it differs from `hidden_size`, project into the
          # decoder width on the way in and back out at the end. Registration order (out, then in) is preserved.
          if config.word_embed_proj_dim != config.hidden_size:
              self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
              self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
          else:
              self.project_out = None
              self.project_in = None

          # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
          # with checkpoints that have been fine-tuned before transformers v4.20.1
          # see https://github.com/facebookresearch/metaseq/pull/164
          if config.do_layer_norm_before and not config._remove_final_layer_norm:
              self.final_layer_norm = nn.LayerNorm(
                  config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
              )
          else:
              self.final_layer_norm = None

          self.layers = nn.ModuleList([OPTDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])

          self.gradient_checkpointing = False
          # Initialize weights and apply final processing
          self.post_init()

      # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
      def _update_causal_mask(
          self,
          attention_mask: Union[torch.Tensor, "BlockMask"],
          input_tensor: torch.Tensor,
          cache_position: torch.Tensor,
          past_key_values: Cache,
          output_attentions: bool = False,
      ):
          if self.config._attn_implementation == "flash_attention_2":
              if attention_mask is not None and (attention_mask == 0.0).any():
                  return attention_mask
              return None
          if self.config._attn_implementation == "flex_attention":
              if isinstance(attention_mask, torch.Tensor):
                  attention_mask = make_flex_block_causal_mask(attention_mask)
              return attention_mask

          # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
          # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
          # to infer the attention mask.
          past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
          using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

          # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
          if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
              if AttentionMaskConverter._ignore_causal_mask_sdpa(
                  attention_mask,
                  inputs_embeds=input_tensor,
                  past_key_values_length=past_seen_tokens,
                  is_training=self.training,
              ):
                  return None

          dtype = input_tensor.dtype
          sequence_length = input_tensor.shape[1]
          if using_compilable_cache:
              target_length = past_key_values.get_max_cache_shape()
          else:
              target_length = (
                  attention_mask.shape[-1]
                  if isinstance(attention_mask, torch.Tensor)
                  else past_seen_tokens + sequence_length + 1
              )

          # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
          causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
              attention_mask,
              sequence_length=sequence_length,
              target_length=target_length,
              dtype=dtype,
              cache_position=cache_position,
              batch_size=input_tensor.shape[0],
          )

          if (
              self.config._attn_implementation == "sdpa"
              and attention_mask is not None
              and attention_mask.device.type in ["cuda", "xpu", "npu"]
              and not output_attentions
          ):
              # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
              # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
              # Details: https://github.com/pytorch/pytorch/issues/110213
              min_dtype = torch.finfo(dtype).min
              causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

          return causal_mask

      @staticmethod
      # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
      def _prepare_4d_causal_attention_mask_with_cache_position(
          attention_mask: torch.Tensor,
          sequence_length: int,
          target_length: int,
          dtype: torch.dtype,
          cache_position: torch.Tensor,
          batch_size: int,
          **kwargs,
      ):
          """
          Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
          `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

          Args:
              attention_mask (`torch.Tensor`):
                  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                  `(batch_size, 1, query_length, key_value_length)`.
              sequence_length (`int`):
                  The sequence length being processed.
              target_length (`int`):
                  The target length: when generating with static cache, the mask should be as long as the static cache,
                  to account for the 0 padding, the part of the cache that is not filled yet.
              dtype (`torch.dtype`):
                  The dtype to use for the 4D attention mask.
              cache_position (`torch.Tensor`):
                  Indices depicting the position of the input sequence tokens in the sequence.
              batch_size (`torch.Tensor`):
                  Batch size.
          """
          if attention_mask is not None and attention_mask.dim() == 4:
              # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
              causal_mask = attention_mask
          else:
              min_dtype = torch.finfo(dtype).min
              causal_mask = torch.full(
                  (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
              )
              if sequence_length != 1:
                  causal_mask = torch.triu(causal_mask, diagonal=1)
              causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
              causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
              if attention_mask is not None:
                  causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                  mask_length = attention_mask.shape[-1]
                  padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                      causal_mask.device
                  )
                  padding_mask = padding_mask == 0
                  causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                      padding_mask, min_dtype
                  )

          return causal_mask

      @can_return_tuple
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          position_ids: Optional[torch.LongTensor] = None,
          cache_position: Optional[torch.Tensor] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> Union[tuple, BaseModelOutputWithPast]:
          r"""
          Args:
              input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                  provide it.

                  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                  [`PreTrainedTokenizer.__call__`] for details.

                  [What are input IDs?](../glossary#input-ids)
              attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                  - 1 for tokens that are **not masked**,
                  - 0 for tokens that are **masked**.

                  When a cache is used, the mask should also cover the cached tokens. The default built when no mask
                  is given has length `past_seen_tokens + sequence_length`.

                  [What are attention masks?](../glossary#attention-mask)
              head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                  Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                  - 1 indicates the head is **not masked**,
                  - 0 indicates the head is **masked**.

              past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                  It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

                  Contains pre-computed key and value states of the self-attention blocks that can be used to speed up
                  sequential decoding.

                  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
                  don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
                  `input_ids` of shape `(batch_size, sequence_length)`.

              inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                  This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                  than the model's internal embedding lookup matrix.
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                  returned tensors for more detail.
              output_hidden_states (`bool`, *optional*):
                  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                  for more detail.
              return_dict (`bool`, *optional*):
                  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
              position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
                  `[0, config.max_position_embeddings - 1]`; use -1 for padding.

                  [What are position IDs?](../glossary#position-ids)
              cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                  Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
                  this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
                  the complete sequence length.
          """
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          use_cache = use_cache if use_cache is not None else self.config.use_cache
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if self.gradient_checkpointing and self.training and use_cache:
              logger.warning_once(
                  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
              )
              use_cache = False

          if input_ids is not None:
              input_ids = input_ids.view(-1, input_ids.shape[-1])

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)

          past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

          if cache_position is None:
              cache_position = torch.arange(
                  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
              )

          if attention_mask is None:
              # Default mask covers cached + current tokens, all attended.
              seq_length = past_seen_tokens + inputs_embeds.shape[1]
              attention_mask = torch.ones(inputs_embeds.shape[0], seq_length, device=inputs_embeds.device)

          causal_mask = self._update_causal_mask(
              attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
          )

          # embed positions
          if position_ids is None:
              # Derive positions from the 2D padding mask (padded slots get -1), then keep only the new tokens.
              position_ids = torch.cumsum(attention_mask, dim=1)
              position_ids = (position_ids * attention_mask - 1).long()
              # cut positions if `past_seen_tokens` is > 0
              position_ids = position_ids[:, past_seen_tokens:]

          pos_embeds = self.embed_positions(attention_mask, past_seen_tokens, position_ids=position_ids)

          if self.project_in is not None:
              inputs_embeds = self.project_in(inputs_embeds)

          hidden_states = inputs_embeds + pos_embeds.to(inputs_embeds.device)

          # decoder layers
          all_hidden_states = () if output_hidden_states else None
          all_self_attns = () if output_attentions else None

          # check if head_mask has a correct number of layers specified if desired
          if head_mask is not None and head_mask.size()[0] != len(self.layers):
              raise ValueError(
                  f"The `head_mask` should be specified for {len(self.layers)} layers, but it is for"
                  f" {head_mask.size()[0]}."
              )

          for idx, decoder_layer in enumerate(self.layers):
              # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
              if output_hidden_states:
                  all_hidden_states += (hidden_states,)

              if self.training:
                  dropout_probability = torch.rand([])
                  if dropout_probability < self.layerdrop:
                      continue

              layer_outputs = decoder_layer(
                  hidden_states,
                  attention_mask=causal_mask,
                  position_ids=position_ids,
                  layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                  past_key_values=past_key_values,
                  output_attentions=output_attentions,
                  use_cache=use_cache,
                  cache_position=cache_position,
                  **kwargs,
              )

              hidden_states = layer_outputs[0]

              if output_attentions:
                  all_self_attns += (layer_outputs[1],)

          if self.final_layer_norm is not None:
              hidden_states = self.final_layer_norm(hidden_states)

          if self.project_out is not None:
              hidden_states = self.project_out(hidden_states)

          # add hidden states from the last decoder layer
          if output_hidden_states:
              all_hidden_states += (hidden_states,)

          return BaseModelOutputWithPast(
              last_hidden_state=hidden_states,
              past_key_values=past_key_values,
              hidden_states=all_hidden_states,
              attentions=all_self_attns,
          )
  568. @auto_docstring
  569. class OPTModel(OPTPreTrainedModel):
  570. def __init__(self, config: OPTConfig):
  571. super().__init__(config)
  572. self.decoder = OPTDecoder(config)
  573. # Initialize weights and apply final processing
  574. self.post_init()
  575. def get_input_embeddings(self):
  576. return self.decoder.embed_tokens
  577. def set_input_embeddings(self, value):
  578. self.decoder.embed_tokens = value
  579. @can_return_tuple
  580. @auto_docstring
  581. def forward(
  582. self,
  583. input_ids: Optional[torch.LongTensor] = None,
  584. attention_mask: Optional[torch.Tensor] = None,
  585. head_mask: Optional[torch.Tensor] = None,
  586. past_key_values: Optional[Cache] = None,
  587. inputs_embeds: Optional[torch.FloatTensor] = None,
  588. use_cache: Optional[bool] = None,
  589. output_attentions: Optional[bool] = None,
  590. output_hidden_states: Optional[bool] = None,
  591. return_dict: Optional[bool] = None,
  592. position_ids: Optional[torch.LongTensor] = None,
  593. cache_position: Optional[torch.Tensor] = None,
  594. **kwargs: Unpack[FlashAttentionKwargs],
  595. ) -> Union[tuple, BaseModelOutputWithPast]:
  596. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  597. output_hidden_states = (
  598. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  599. )
  600. use_cache = use_cache if use_cache is not None else self.config.use_cache
  601. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  602. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  603. decoder_outputs = self.decoder(
  604. input_ids=input_ids,
  605. attention_mask=attention_mask,
  606. position_ids=position_ids,
  607. head_mask=head_mask,
  608. past_key_values=past_key_values,
  609. inputs_embeds=inputs_embeds,
  610. use_cache=use_cache,
  611. output_attentions=output_attentions,
  612. output_hidden_states=output_hidden_states,
  613. return_dict=True,
  614. cache_position=cache_position,
  615. **kwargs,
  616. )
  617. return BaseModelOutputWithPast(
  618. last_hidden_state=decoder_outputs.last_hidden_state,
  619. past_key_values=decoder_outputs.past_key_values,
  620. hidden_states=decoder_outputs.hidden_states,
  621. attentions=decoder_outputs.attentions,
  622. )
  623. class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
  624. _tied_weights_keys = ["lm_head.weight"]
  625. def __init__(self, config):
  626. super().__init__(config)
  627. self.model = OPTModel(config)
  628. # the lm_head weight is automatically tied to the embed tokens weight
  629. self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
  630. # Initialize weights and apply final processing
  631. self.post_init()
  632. def get_input_embeddings(self):
  633. return self.model.decoder.embed_tokens
  634. def set_input_embeddings(self, value):
  635. self.model.decoder.embed_tokens = value
  636. def set_decoder(self, decoder):
  637. self.model.decoder = decoder
  638. def get_decoder(self):
  639. return self.model.decoder
  640. @can_return_tuple
  641. @auto_docstring
  642. def forward(
  643. self,
  644. input_ids: Optional[torch.LongTensor] = None,
  645. attention_mask: Optional[torch.Tensor] = None,
  646. head_mask: Optional[torch.Tensor] = None,
  647. past_key_values: Optional[Cache] = None,
  648. inputs_embeds: Optional[torch.FloatTensor] = None,
  649. labels: Optional[torch.LongTensor] = None,
  650. use_cache: Optional[bool] = None,
  651. output_attentions: Optional[bool] = None,
  652. output_hidden_states: Optional[bool] = None,
  653. return_dict: Optional[bool] = None,
  654. position_ids: Optional[torch.LongTensor] = None,
  655. cache_position: Optional[torch.Tensor] = None,
  656. **kwargs: Unpack[TransformersKwargs],
  657. ) -> Union[tuple, CausalLMOutputWithPast]:
  658. r"""
  659. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  660. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  661. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  662. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  663. Example:
  664. ```python
  665. >>> from transformers import AutoTokenizer, OPTForCausalLM
  666. >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
  667. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
  668. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  669. >>> inputs = tokenizer(prompt, return_tensors="pt")
  670. >>> # Generate
  671. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  672. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  673. "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
  674. ```"""
  675. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  676. output_hidden_states = (
  677. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  678. )
  679. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  680. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  681. outputs = self.model.decoder(
  682. input_ids=input_ids,
  683. attention_mask=attention_mask,
  684. position_ids=position_ids,
  685. head_mask=head_mask,
  686. past_key_values=past_key_values,
  687. inputs_embeds=inputs_embeds,
  688. use_cache=use_cache,
  689. output_attentions=output_attentions,
  690. output_hidden_states=output_hidden_states,
  691. return_dict=True,
  692. cache_position=cache_position,
  693. **kwargs,
  694. )
  695. logits = self.lm_head(outputs[0]).contiguous()
  696. loss = None
  697. if labels is not None:
  698. # move labels to correct device to enable model parallelism
  699. labels = labels.to(logits.device)
  700. loss = self.loss_function(
  701. logits,
  702. labels,
  703. vocab_size=self.config.vocab_size,
  704. **kwargs,
  705. )
  706. return CausalLMOutputWithPast(
  707. loss=loss,
  708. logits=logits,
  709. past_key_values=outputs.past_key_values,
  710. hidden_states=outputs.hidden_states,
  711. attentions=outputs.attentions,
  712. )
  713. @auto_docstring(
  714. custom_intro="""
  715. The OPT Model transformer with a sequence classification head on top (linear layer).
  716. [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  717. (e.g. GPT-2) do.
  718. Since it does classification on the last token, it requires to know the position of the last token. If a
  719. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  720. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  721. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  722. each row of the batch).
  723. """
  724. )
  725. class OPTForSequenceClassification(OPTPreTrainedModel):
  726. def __init__(self, config: OPTConfig):
  727. super().__init__(config)
  728. self.num_labels = config.num_labels
  729. self.model = OPTModel(config)
  730. self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
  731. # Initialize weights and apply final processing
  732. self.post_init()
  733. @auto_docstring
  734. def forward(
  735. self,
  736. input_ids: Optional[torch.LongTensor] = None,
  737. attention_mask: Optional[torch.FloatTensor] = None,
  738. head_mask: Optional[torch.FloatTensor] = None,
  739. past_key_values: Optional[Cache] = None,
  740. inputs_embeds: Optional[torch.FloatTensor] = None,
  741. labels: Optional[torch.LongTensor] = None,
  742. use_cache: Optional[bool] = None,
  743. output_attentions: Optional[bool] = None,
  744. output_hidden_states: Optional[bool] = None,
  745. return_dict: Optional[bool] = None,
  746. position_ids: Optional[torch.LongTensor] = None,
  747. ) -> Union[tuple, SequenceClassifierOutputWithPast]:
  748. r"""
  749. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  750. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  751. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  752. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  753. """
  754. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  755. transformer_outputs = self.model(
  756. input_ids,
  757. past_key_values=past_key_values,
  758. attention_mask=attention_mask,
  759. position_ids=position_ids,
  760. head_mask=head_mask,
  761. inputs_embeds=inputs_embeds,
  762. use_cache=use_cache,
  763. output_attentions=output_attentions,
  764. output_hidden_states=output_hidden_states,
  765. return_dict=return_dict,
  766. )
  767. hidden_states = transformer_outputs[0]
  768. logits = self.score(hidden_states)
  769. if input_ids is not None:
  770. batch_size, sequence_length = input_ids.shape[:2]
  771. else:
  772. batch_size, sequence_length = inputs_embeds.shape[:2]
  773. if self.config.pad_token_id is None and batch_size != 1:
  774. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  775. if self.config.pad_token_id is None:
  776. last_non_pad_token = -1
  777. elif input_ids is not None:
  778. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  779. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  780. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  781. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  782. else:
  783. last_non_pad_token = -1
  784. logger.warning_once(
  785. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  786. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  787. )
  788. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  789. loss = None
  790. if labels is not None:
  791. if self.config.problem_type is None:
  792. if self.num_labels == 1:
  793. self.config.problem_type = "regression"
  794. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  795. self.config.problem_type = "single_label_classification"
  796. else:
  797. self.config.problem_type = "multi_label_classification"
  798. if self.config.problem_type == "regression":
  799. loss_fct = MSELoss()
  800. if self.num_labels == 1:
  801. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  802. else:
  803. loss = loss_fct(pooled_logits, labels)
  804. elif self.config.problem_type == "single_label_classification":
  805. loss_fct = CrossEntropyLoss()
  806. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  807. elif self.config.problem_type == "multi_label_classification":
  808. loss_fct = BCEWithLogitsLoss()
  809. loss = loss_fct(pooled_logits, labels)
  810. if not return_dict:
  811. output = (pooled_logits,) + transformer_outputs[1:]
  812. return ((loss,) + output) if loss is not None else output
  813. return SequenceClassifierOutputWithPast(
  814. loss=loss,
  815. logits=pooled_logits,
  816. past_key_values=transformer_outputs.past_key_values,
  817. hidden_states=transformer_outputs.hidden_states,
  818. attentions=transformer_outputs.attentions,
  819. )
  820. def get_input_embeddings(self):
  821. return self.model.decoder.embed_tokens
  822. def set_input_embeddings(self, value):
  823. self.model.decoder.embed_tokens = value
  824. @auto_docstring
  825. class OPTForQuestionAnswering(OPTPreTrainedModel):
  826. def __init__(self, config: OPTConfig):
  827. super().__init__(config)
  828. self.model = OPTModel(config)
  829. self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)
  830. # Initialize weights and apply final processing
  831. self.post_init()
  832. @auto_docstring
  833. def forward(
  834. self,
  835. input_ids: Optional[torch.LongTensor] = None,
  836. attention_mask: Optional[torch.FloatTensor] = None,
  837. head_mask: Optional[torch.FloatTensor] = None,
  838. past_key_values: Optional[Cache] = None,
  839. inputs_embeds: Optional[torch.FloatTensor] = None,
  840. start_positions: Optional[torch.LongTensor] = None,
  841. end_positions: Optional[torch.LongTensor] = None,
  842. use_cache: Optional[bool] = None,
  843. output_attentions: Optional[bool] = None,
  844. output_hidden_states: Optional[bool] = None,
  845. return_dict: Optional[bool] = None,
  846. position_ids: Optional[torch.LongTensor] = None,
  847. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  848. r"""
  849. Example:
  850. ```python
  851. >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
  852. >>> import torch
  853. >>> torch.manual_seed(4) # doctest: +IGNORE_RESULT
  854. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
  855. >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
  856. >>> # so the head will be randomly initialized, hence the predictions will be random
  857. >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
  858. >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
  859. >>> inputs = tokenizer(question, text, return_tensors="pt")
  860. >>> with torch.no_grad():
  861. ... outputs = model(**inputs)
  862. >>> answer_start_index = outputs.start_logits.argmax()
  863. >>> answer_end_index = outputs.end_logits.argmax()
  864. >>> answer_offset = len(tokenizer(question)[0])
  865. >>> predict_answer_tokens = inputs.input_ids[
  866. ... 0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
  867. ... ]
  868. >>> predicted = tokenizer.decode(predict_answer_tokens)
  869. >>> predicted
  870. ' a nice puppet'
  871. ```"""
  872. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  873. transformer_outputs = self.model(
  874. input_ids,
  875. past_key_values=past_key_values,
  876. attention_mask=attention_mask,
  877. position_ids=position_ids,
  878. head_mask=head_mask,
  879. inputs_embeds=inputs_embeds,
  880. use_cache=use_cache,
  881. output_attentions=output_attentions,
  882. output_hidden_states=output_hidden_states,
  883. return_dict=return_dict,
  884. )
  885. hidden_states = transformer_outputs[0]
  886. logits = self.qa_outputs(hidden_states)
  887. start_logits, end_logits = logits.split(1, dim=-1)
  888. start_logits = start_logits.squeeze(-1).contiguous()
  889. end_logits = end_logits.squeeze(-1).contiguous()
  890. total_loss = None
  891. if start_positions is not None and end_positions is not None:
  892. # If we are on multi-GPU, split add a dimension
  893. if len(start_positions.size()) > 1:
  894. start_positions = start_positions.squeeze(-1)
  895. if len(end_positions.size()) > 1:
  896. end_positions = end_positions.squeeze(-1)
  897. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  898. ignored_index = start_logits.size(1)
  899. start_positions = start_positions.clamp(0, ignored_index).to(logits.device)
  900. end_positions = end_positions.clamp(0, ignored_index).to(logits.device)
  901. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  902. start_loss = loss_fct(start_logits, start_positions)
  903. end_loss = loss_fct(end_logits, end_positions)
  904. total_loss = (start_loss + end_loss) / 2
  905. if not return_dict:
  906. output = (start_logits, end_logits) + transformer_outputs[2:]
  907. return ((total_loss,) + output) if total_loss is not None else output
  908. return QuestionAnsweringModelOutput(
  909. loss=total_loss,
  910. start_logits=start_logits,
  911. end_logits=end_logits,
  912. hidden_states=transformer_outputs.hidden_states,
  913. attentions=transformer_outputs.attentions,
  914. )
  915. def get_input_embeddings(self):
  916. return self.model.decoder.embed_tokens
  917. def set_input_embeddings(self, value):
  918. self.model.decoder.embed_tokens = value
  919. __all__ = [
  920. "OPTForCausalLM",
  921. "OPTModel",
  922. "OPTPreTrainedModel",
  923. "OPTForSequenceClassification",
  924. "OPTForQuestionAnswering",
  925. ]