modeling_gpt2.py 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638
  1. # coding=utf-8
  2. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch OpenAI GPT-2 model."""
  17. import math
  18. import os
  19. import warnings
  20. from dataclasses import dataclass
  21. from typing import Callable, Optional, Union
  22. import torch
  23. from torch import nn
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ...activations import ACT2FN, get_activation
  26. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  27. from ...generation import GenerationMixin
  28. from ...masking_utils import create_causal_mask
  29. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import (
  32. BaseModelOutputWithPastAndCrossAttentions,
  33. CausalLMOutputWithCrossAttentions,
  34. QuestionAnsweringModelOutput,
  35. SequenceClassifierOutputWithPast,
  36. TokenClassifierOutput,
  37. )
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
  40. from ...utils import (
  41. ModelOutput,
  42. add_start_docstrings,
  43. auto_docstring,
  44. logging,
  45. )
  46. from ...utils.deprecation import deprecate_kwarg
  47. from ...utils.model_parallel_utils import assert_device_map, get_device_map
  48. from .configuration_gpt2 import GPT2Config
  49. logger = logging.get_logger(__name__)
  def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
      """Copy the variables of a TensorFlow GPT-2 checkpoint into a PyTorch model.

      The function works in two phases. First every variable listed in the checkpoint is
      read and squeezed. Only afterwards are the arrays assigned, one at a time, to the
      parameter reached by walking the slash-separated variable name through the
      attributes of `model`.

      Args:
          model: Object whose attributes are traversed to find the target parameters.
          config: Accepted for interface compatibility; not read by this function.
          gpt2_checkpoint_path: Path of the TensorFlow checkpoint (made absolute before use).

      Returns:
          The `model` object that was passed in, with the `.data` of each matched
          parameter replaced.

      Raises:
          ImportError: If TensorFlow cannot be imported (an error is logged first).
          ValueError: If a checkpoint array and its target parameter have different shapes.
              The two shapes are appended to the exception's `args`.
      """
      try:
          import re

          import tensorflow as tf
      except ImportError:
          logger.error(
              "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
              "https://www.tensorflow.org/install/ for installation instructions."
          )
          raise

      tf_path = os.path.abspath(gpt2_checkpoint_path)
      logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

      # Phase 1: read every variable before the model is touched.
      loaded = []
      for var_name, var_shape in tf.train.list_variables(tf_path):
          logger.info(f"Loading TF weight {var_name} with shape {var_shape}")
          loaded.append((var_name, tf.train.load_variable(tf_path, var_name).squeeze()))

      # Phase 2: resolve each variable name to a parameter and overwrite its data.
      for var_name, array in loaded:
          # The first six characters (the "model/" scope) are dropped before splitting.
          name = var_name[6:].split("/")
          pointer = model
          for m_name in name:
              # "h3"-like components are split into the attribute name and a numeric index.
              if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                  scope_names = re.split(r"(\d+)", m_name)
              else:
                  scope_names = [m_name]
              head = scope_names[0]
              if head in ("w", "g"):
                  pointer = getattr(pointer, "weight")
              elif head == "b":
                  pointer = getattr(pointer, "bias")
              elif head in ("wpe", "wte"):
                  pointer = getattr(getattr(pointer, head), "weight")
              else:
                  pointer = getattr(pointer, head)
              if len(scope_names) >= 2:
                  pointer = pointer[int(scope_names[1])]

          if pointer.shape != array.shape:
              mismatch = ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
              mismatch.args += (pointer.shape, array.shape)
              raise mismatch

          # `name` is the list of path components at this point, and is logged as such.
          logger.info(f"Initialize PyTorch weight {name}")
          pointer.data = torch.from_numpy(array)
      return model
  def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
      """Compute attention with plain matmul + softmax.

      Args:
          module: Object providing `scale_attn_weights`, `scale_attn_by_inverse_layer_idx`,
              `layer_idx`, `is_cross_attention`, `bias` (boolean causal mask) and `attn_dropout`.
          query, key, value: Tensors whose last two dimensions are (sequence, head_dim).
          attention_mask: Optional additive mask; its last dimension is cut to the key length.
          head_mask: Optional multiplicative mask applied to the probabilities after dropout.
          **kwargs: Accepted and ignored.

      Returns:
          `(attn_output, attn_weights)`, where `attn_output` has dimensions 1 and 2 swapped
          relative to `attn_weights @ value`.
      """
      scores = torch.matmul(query, key.transpose(-1, -2))

      if module.scale_attn_weights:
          # A 0-dim tensor on the same dtype/device is used as the divisor.
          root_dk = torch.full([], value.size(-1) ** 0.5, dtype=scores.dtype, device=scores.device)
          scores = scores / root_dk

      # Layer-wise attention scaling
      if module.scale_attn_by_inverse_layer_idx:
          scores = scores / float(module.layer_idx + 1)

      if not module.is_cross_attention:
          # Only self-attention applies the causal mask stored in `module.bias`.
          q_len = query.size(-2)
          k_len = key.size(-2)
          allowed = module.bias[:, :, k_len - q_len : k_len, :k_len]
          # The fill value must be a tensor with matching dtype and device, otherwise
          # `torch.where` complains about mismatched scalar types or devices.
          fill = torch.full([], torch.finfo(scores.dtype).min, dtype=scores.dtype, device=scores.device)
          scores = torch.where(allowed, scores, fill)

      if attention_mask is not None:
          # Additive mask, restricted to the key positions actually present.
          scores = scores + attention_mask[:, :, :, : key.shape[-2]]

      probs = nn.functional.softmax(scores, dim=-1)
      # Cast back to the dtype of `value` (a no-op unless running in mixed precision).
      probs = module.attn_dropout(probs.type(value.dtype))

      if head_mask is not None:
          probs = probs * head_mask

      context = torch.matmul(probs, value).transpose(1, 2)
      return context, probs
  class GPT2Attention(nn.Module):
      """GPT-2 multi-head attention, usable for self-attention or cross-attention.

      With `is_cross_attention=False` a single `c_attn` projection produces query, key and
      value. With `is_cross_attention=True` the query comes from `q_attn` and `c_attn`
      produces key and value from the encoder hidden states. A non-persistent lower-triangular
      boolean buffer `bias` of size `config.max_position_embeddings` serves as the causal
      mask for the eager code paths.
      """

      def __init__(self, config, is_cross_attention=False, layer_idx=None):
          super().__init__()
          self.config = config

          max_positions = config.max_position_embeddings
          lower_triangle = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
          self.register_buffer(
              "bias",
              lower_triangle.view(1, 1, max_positions, max_positions),
              persistent=False,
          )
          self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

          self.embed_dim = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.embed_dim // self.num_heads
          self.split_size = self.embed_dim
          if self.head_dim * self.num_heads != self.embed_dim:
              raise ValueError(
                  f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                  f" {self.num_heads})."
              )

          self.scale_attn_weights = config.scale_attn_weights
          self.is_cross_attention = is_cross_attention

          # Layer-wise attention scaling, reordering, and upcasting
          self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
          self.layer_idx = layer_idx
          self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

          # Projection layers are created in a fixed order: c_attn, (q_attn), c_proj.
          if self.is_cross_attention:
              self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
              self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
          else:
              self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
          self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

          self.attn_dropout = nn.Dropout(config.attn_pdrop)
          self.resid_dropout = nn.Dropout(config.resid_pdrop)
          self.is_causal = True

          self.pruned_heads = set()

      def prune_heads(self, heads):
          """Remove the given attention heads from `c_attn` / `c_proj` and update the bookkeeping."""
          if len(heads) == 0:
              return
          heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
          # The same indices are selected in each of the three `split_size`-wide slices of c_attn.
          index_attn = torch.cat([index + offset * self.split_size for offset in (0, 1, 2)])

          # Prune conv1d layers
          self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
          self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

          # Update hyper params
          remaining_heads = self.num_heads - len(heads)
          self.split_size = (self.split_size // self.num_heads) * remaining_heads
          self.num_heads = remaining_heads
          self.pruned_heads = self.pruned_heads.union(heads)

      def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
          """Eager attention whose score matmul runs in float32 with autocast disabled.

          Returns `(attn_output, attn_weights)` like `eager_attention_forward`.

          Raises:
              RuntimeError: If the softmax output is not float32.
          """
          # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
          bsz, num_heads, q_seq_len, dk = query.size()
          _, _, k_seq_len, _ = key.size()

          # Preallocate the output buffer for `baddbmm`
          scores = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

          # Fold both optional scalings into one factor handed to `baddbmm` as `alpha`.
          scale_factor = 1.0
          if self.scale_attn_weights:
              scale_factor /= float(value.size(-1)) ** 0.5
          if self.scale_attn_by_inverse_layer_idx:
              scale_factor /= float(self.layer_idx + 1)

          # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
          with torch.autocast(query.device.type, enabled=False):
              flat_q = query.reshape(-1, q_seq_len, dk)
              flat_k = key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
              scores = torch.baddbmm(scores, flat_q.float(), flat_k.float(), beta=0, alpha=scale_factor)
              scores = scores.reshape(bsz, num_heads, q_seq_len, k_seq_len)

          if not self.is_cross_attention:
              # Only self-attention applies the causal mask.
              query_length, key_length = query.size(-2), key.size(-2)
              allowed = self.bias[:, :, key_length - query_length : key_length, :key_length]
              # The fill value must be a tensor with matching dtype and device for `torch.where`.
              fill = torch.tensor(torch.finfo(scores.dtype).min, dtype=scores.dtype, device=scores.device)
              scores = torch.where(allowed, scores, fill)

          if attention_mask is not None:
              # Apply the attention mask
              scores = scores + attention_mask

          probs = nn.functional.softmax(scores, dim=-1)

          # The softmax must still be float32 here; afterwards cast back to the dtype of `value`.
          if probs.dtype != torch.float32:
              raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
          probs = self.attn_dropout(probs.type(value.dtype))

          # Mask heads if we want to
          if head_mask is not None:
              probs = probs * head_mask

          context = torch.matmul(probs, value).transpose(1, 2)
          return context, probs

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: Optional[tuple[torch.FloatTensor]],
          past_key_values: Optional[Cache] = None,
          cache_position: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = False,
          **kwargs,
      ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
          """Project the inputs, consult/update the cache, run attention and project the result.

          Cross-attention is selected when `encoder_hidden_states` is given. In that case
          `encoder_attention_mask` replaces `attention_mask`.

          Returns:
              `(attn_output, attn_weights)`.

          Raises:
              ValueError: If cross-attention is requested but the module has no `q_attn`.
          """
          is_cross_attention = encoder_hidden_states is not None
          has_cache = past_key_values is not None

          # NOTE: `is_updated` is only bound for an `EncoderDecoderCache`.
          if has_cache:
              if isinstance(past_key_values, EncoderDecoderCache):
                  is_updated = past_key_values.is_updated.get(self.layer_idx)
                  # after the first generated id, the cross-attention cache can be re-used as is
                  curr_past_key_value = (
                      past_key_values.cross_attention_cache
                      if is_cross_attention
                      else past_key_values.self_attention_cache
                  )
              else:
                  curr_past_key_value = past_key_values

          if is_cross_attention:
              if not hasattr(self, "q_attn"):
                  raise ValueError(
                      "If class is used as cross attention, the weights `q_attn` have to be defined. "
                      "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                  )
              query_states = self.q_attn(hidden_states)
              attention_mask = encoder_attention_mask

              # Try to get key/value states from cache if possible
              if has_cache and is_updated:
                  cached_layer = curr_past_key_value.layers[self.layer_idx]
                  key_states = cached_layer.keys
                  value_states = cached_layer.values
              else:
                  key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
                  kv_shape = (*key_states.shape[:-1], -1, self.head_dim)
                  key_states = key_states.view(kv_shape).transpose(1, 2)
                  value_states = value_states.view(kv_shape).transpose(1, 2)
          else:
              query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
              kv_shape = (*key_states.shape[:-1], -1, self.head_dim)
              key_states = key_states.view(kv_shape).transpose(1, 2)
              value_states = value_states.view(kv_shape).transpose(1, 2)

          q_shape = (*query_states.shape[:-1], -1, self.head_dim)
          query_states = query_states.view(q_shape).transpose(1, 2)

          # Write to the cache for self-attention always, for cross-attention only the first time.
          if has_cache and not (is_cross_attention and is_updated):
              # save all key/value_layer to cache to be re-used for fast auto-regressive generation
              cache_position = None if is_cross_attention else cache_position
              key_states, value_states = curr_past_key_value.update(
                  key_states, value_states, self.layer_idx, {"cache_position": cache_position}
              )
              # flag the cross-attention entry of this layer as filled so later calls re-use it
              if is_cross_attention:
                  past_key_values.is_updated[self.layer_idx] = True

          is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

          attn_implementation = self.config._attn_implementation
          using_eager = attn_implementation == "eager"
          attention_interface: Callable = (
              eager_attention_forward if using_eager else ALL_ATTENTION_FUNCTIONS[attn_implementation]
          )

          if using_eager and self.reorder_and_upcast_attn:
              attn_output, attn_weights = self._upcast_and_reordered_attn(
                  query_states, key_states, value_states, attention_mask, head_mask
              )
          else:
              attn_output, attn_weights = attention_interface(
                  self,
                  query_states,
                  key_states,
                  value_states,
                  attention_mask,
                  head_mask=head_mask,
                  dropout=self.attn_dropout.p if self.training else 0.0,
                  is_causal=is_causal,
                  **kwargs,
              )

          # Merge the last two dimensions (heads, head_dim) before the output projection.
          attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
          attn_output = self.resid_dropout(self.c_proj(attn_output))

          return attn_output, attn_weights
  class GPT2MLP(nn.Module):
      """Position-wise feed-forward sub-layer: `c_fc` -> activation -> `c_proj` -> dropout."""

      def __init__(self, intermediate_size, config):
          """Build the two `Conv1D` projections, the activation and the dropout.

          Args:
              intermediate_size: Width of the hidden projection.
              config: Provides `hidden_size`, `activation_function` and `resid_pdrop`.
          """
          super().__init__()
          model_width = config.hidden_size
          self.c_fc = Conv1D(intermediate_size, model_width)
          self.c_proj = Conv1D(model_width, intermediate_size)
          self.act = ACT2FN[config.activation_function]
          self.dropout = nn.Dropout(config.resid_pdrop)

      def forward(self, hidden_states: Optional[tuple[torch.FloatTensor]]) -> torch.FloatTensor:
          """Apply expand projection, activation, contract projection and dropout, in that order."""
          expanded = self.act(self.c_fc(hidden_states))
          return self.dropout(self.c_proj(expanded))
  class GPT2Block(GradientCheckpointingLayer):
      """One GPT-2 layer: pre-norm self-attention, optional cross-attention, then the MLP.

      Each sub-layer is wrapped in a residual connection. The cross-attention sub-layer
      exists only when `config.add_cross_attention` is truthy.
      """

      def __init__(self, config, layer_idx=None):
          super().__init__()
          hidden_size = config.hidden_size
          # Fall back to four times the hidden size when `n_inner` is not configured.
          inner_dim = 4 * hidden_size if config.n_inner is None else config.n_inner
          norm_eps = config.layer_norm_epsilon

          self.ln_1 = nn.LayerNorm(hidden_size, eps=norm_eps)
          self.attn = GPT2Attention(config=config, layer_idx=layer_idx)
          self.ln_2 = nn.LayerNorm(hidden_size, eps=norm_eps)

          if config.add_cross_attention:
              self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx)
              self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=norm_eps)

          self.mlp = GPT2MLP(inner_dim, config)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: Optional[tuple[torch.FloatTensor]],
          past_key_values: Optional[Cache] = None,
          cache_position: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = False,
          output_attentions: Optional[bool] = False,
          **kwargs,
      ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
          """Run the layer.

          Returns:
              A tuple starting with the new hidden states. When `output_attentions` is truthy
              the self-attention weights follow, and then the cross-attention weights if
              `encoder_hidden_states` was given.

          Raises:
              ValueError: If `encoder_hidden_states` is passed but the block was built
                  without cross-attention layers.
          """
          # --- self-attention sub-layer ---
          skip = hidden_states
          attn_output, self_attn_weights = self.attn(
              self.ln_1(hidden_states),
              past_key_values=past_key_values,
              cache_position=cache_position,
              attention_mask=attention_mask,
              head_mask=head_mask,
              use_cache=use_cache,
              output_attentions=output_attentions,
              **kwargs,
          )
          hidden_states = attn_output + skip

          # --- optional cross-attention sub-layer ---
          if encoder_hidden_states is not None:
              if not hasattr(self, "crossattention"):
                  raise ValueError(
                      f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                      "cross-attention layers by setting `config.add_cross_attention=True`"
                  )
              skip = hidden_states
              cross_attn_output, cross_attn_weights = self.crossattention(
                  self.ln_cross_attn(hidden_states),
                  past_key_values=past_key_values,
                  attention_mask=attention_mask,
                  head_mask=head_mask,
                  encoder_hidden_states=encoder_hidden_states,
                  encoder_attention_mask=encoder_attention_mask,
                  output_attentions=output_attentions,
              )
              hidden_states = skip + cross_attn_output

          # --- feed-forward sub-layer ---
          skip = hidden_states
          feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
          hidden_states = skip + feed_forward_hidden_states

          if not output_attentions:
              return (hidden_states,)

          outputs = (hidden_states, self_attn_weights)
          if encoder_hidden_states is not None:
              outputs += (cross_attn_weights,)
          return outputs
  394. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
  class GPT2SequenceSummary(nn.Module):
      r"""
      Reduce a sequence of hidden states to a single summary vector.

      Args:
          config ([`GPT2Config`]):
              Model configuration. The following attributes are read (each one is optional unless noted):

              - **summary_type** (`str`, defaults to `"last"`) -- How the vector is picked:

                  - `"last"` -- hidden state of the last token
                  - `"first"` -- hidden state of the first token
                  - `"mean"` -- mean over all token hidden states
                  - `"cls_index"` -- hidden state at a supplied classification-token position
                  - `"attn"` -- not implemented; raises `NotImplementedError` at construction

              - **summary_use_proj** (`bool`) -- Add a linear projection after the extraction.
              - **summary_proj_to_labels** (`bool`) -- If truthy (and `config.num_labels > 0`), the projection
                maps to `config.num_labels` outputs, otherwise to `config.hidden_size`.
              - **summary_activation** (`Optional[str]`) -- Name handed to `get_activation`; a falsy value means
                no activation.
              - **summary_first_dropout** (`float`) -- Dropout probability applied before the projection.
              - **summary_last_dropout** (`float`) -- Dropout probability applied after the activation.
      """

      def __init__(self, config: GPT2Config):
          super().__init__()

          self.summary_type = getattr(config, "summary_type", "last")
          if self.summary_type == "attn":
              # A standard multi-head attention module with absolute positional embedding would be needed.
              # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
              raise NotImplementedError

          # Projection: identity unless `summary_use_proj` is set.
          if getattr(config, "summary_use_proj", False):
              if getattr(config, "summary_proj_to_labels", False) and config.num_labels > 0:
                  num_classes = config.num_labels
              else:
                  num_classes = config.hidden_size
              self.summary = nn.Linear(config.hidden_size, num_classes)
          else:
              self.summary = nn.Identity()

          activation_string = getattr(config, "summary_activation", None)
          self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()

          # Dropout layers: identity unless a positive probability is configured.
          if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
              self.first_dropout = nn.Dropout(config.summary_first_dropout)
          else:
              self.first_dropout = nn.Identity()

          if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
              self.last_dropout = nn.Dropout(config.summary_last_dropout)
          else:
              self.last_dropout = nn.Identity()

      def forward(
          self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
      ) -> torch.FloatTensor:
          """
          Reduce `hidden_states` to one vector per sequence.

          Args:
              hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                  The hidden states of the last layer.
              cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                  Only read when `summary_type == "cls_index"`. When omitted, the last position of the sequence
                  is used as the classification token.

          Returns:
              `torch.FloatTensor`: The summary of the sequence hidden states.
          """
          kind = self.summary_type
          if kind == "last":
              output = hidden_states[:, -1]
          elif kind == "first":
              output = hidden_states[:, 0]
          elif kind == "mean":
              output = hidden_states.mean(dim=1)
          elif kind == "cls_index":
              if cls_index is None:
                  # Default to the final position of every sequence.
                  last_position = hidden_states.shape[-2] - 1
                  cls_index = torch.full_like(hidden_states[..., :1, :], last_position, dtype=torch.long)
              else:
                  cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                  expand_sizes = (-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)
                  cls_index = cls_index.expand(expand_sizes)
              # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
              output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
          elif kind == "attn":
              raise NotImplementedError

          summarized = self.summary(self.first_dropout(output))
          return self.last_dropout(self.activation(summarized))
  @auto_docstring
  class GPT2PreTrainedModel(PreTrainedModel):
      # Shared base for the GPT-2 model classes: declares the class-level flags below and
      # the weight-initialization scheme in `_init_weights`.
      config: GPT2Config
      load_tf_weights = load_tf_weights_in_gpt2
      base_model_prefix = "transformer"
      is_parallelizable = True
      supports_gradient_checkpointing = True
      _no_split_modules = ["GPT2Block"]
      _skip_keys_device_placement = "past_key_values"
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_attention_backend = True
      _can_compile_fullgraph = True

      def __init__(self, *inputs, **kwargs):
          super().__init__(*inputs, **kwargs)

      def _init_weights(self, module):
          """Initialize the weights."""
          if isinstance(module, (nn.Linear, Conv1D)):
              # Slightly different from the TF version which uses truncated_normal for initialization
              # cf https://github.com/pytorch/pytorch/pull/5617
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  module.bias.data.zero_()
          elif isinstance(module, nn.Embedding):
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              padding_idx = module.padding_idx
              if padding_idx is not None:
                  module.weight.data[padding_idx].zero_()
          elif isinstance(module, nn.LayerNorm):
              module.weight.data.fill_(1.0)
              module.bias.data.zero_()

          # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
          #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
          #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
          #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
          #
          # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
          residual_proj = dict(module.named_parameters()).get("c_proj.weight")
          if residual_proj is not None:
              # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
              scaled_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
              residual_proj.data.normal_(mean=0.0, std=scaled_std)
  # Output container for a GPT-2 model with both a language-modeling head and a
  # multiple-choice head. Every field below defaults to None.
  # NOTE(review): the class docstring and `custom_intro` are passed through `@auto_docstring`;
  # presumably they are used to build the rendered documentation -- confirm before rewording them.
  517. @dataclass
  518. @auto_docstring(
  519. custom_intro="""
  520. Base class for outputs of models predicting if two sentences are consecutive or not.
  521. """
  522. )
  523. class GPT2DoubleHeadsModelOutput(ModelOutput):
  524. r"""
  525. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  526. Language modeling loss.
  527. mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
  528. Multiple choice classification loss.
  529. logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
  530. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  531. mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
  532. Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
  533. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  534. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  535. Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
  536. `past_key_values` input) to speed up sequential decoding.
  537. """
  # Losses and prediction scores of the two heads (language modeling and multiple choice).
  538. loss: Optional[torch.FloatTensor] = None
  539. mc_loss: Optional[torch.FloatTensor] = None
  540. logits: Optional[torch.FloatTensor] = None
  541. mc_logits: Optional[torch.FloatTensor] = None
  # Key/value cache usable to speed up sequential decoding (see the docstring entry above).
  542. past_key_values: Optional[Cache] = None
  # `hidden_states` and `attentions` have no entry in the docstring above.
  # NOTE(review): presumably per-layer tuples filled only when requested -- confirm against the producing model.
  543. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  544. attentions: Optional[tuple[torch.FloatTensor]] = None
# Shared documentation text for the `parallelize` methods in this file; each of them receives it through
# `@add_start_docstrings(PARALLELIZE_DOCSTRING)`.
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is a subject to change at a moment's notice.
    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.
    Args:
        device_map (`dict[int, list]`, *optional*):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
            following number of attention modules:
                - openai-community/gpt2: 12
                - openai-community/gpt2-medium: 24
                - openai-community/gpt2-large: 36
                - openai-community/gpt2-xl: 48
    Example:
    ```python
    # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
        1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
        2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
        3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
    }
    model.parallelize(device_map)
    ```
"""
# Shared documentation text for the `deparallelize` methods in this file; each of them receives it through
# `@add_start_docstrings(DEPARALLELIZE_DOCSTRING)`.
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.
    Example:
    ```python
    # On a 4 GPU machine with openai-community/gpt2-large:
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large")
    device_map = {
        0: [0, 1, 2, 3, 4, 5, 6, 7],
        1: [8, 9, 10, 11, 12, 13, 14, 15],
        2: [16, 17, 18, 19, 20, 21, 22, 23],
        3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""
# Bare GPT-2 backbone: token and position embeddings, a stack of `GPT2Block` layers and a final LayerNorm.
# `forward` returns hidden states (plus optional cache / per-layer outputs) and applies no task head.
@auto_docstring
class GPT2Model(GPT2PreTrainedModel):
    _supports_param_buffer_assignment = False

    def __init__(self, config):
        # Builds every sub-module from `config` and finishes with `post_init()`.
        super().__init__(config)

        self.embed_dim = config.hidden_size

        # Token embeddings (`wte`) and learned absolute position embeddings (`wpe`).
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        # One block per hidden layer; each block is told its own index via `layer_idx`.
        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
        self._attn_implementation = config._attn_implementation

        # Initialize weights and apply final processing
        self.post_init()

    # NOTE: no docstring here on purpose; the text comes from `PARALLELIZE_DOCSTRING` via the decorator.
    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        warnings.warn(
            "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
            " ...}",
            FutureWarning,
        )
        # Without an explicit map, ask `get_device_map` for one covering all blocks and all visible CUDA devices.
        self.device_map = (
            get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.h))
        self.model_parallel = True
        # NOTE(review): a "cpu" key is honoured for `first_device` only; `max()` over mixed str/int keys and the
        # `"cuda:" + str(k)` below look unsafe with such a key - confirm whether `assert_device_map` rejects it.
        self.first_device = "cpu" if "cpu" in self.device_map else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Both embedding tables live on the first device.
        self.wte = self.wte.to(self.first_device)
        self.wpe = self.wpe.to(self.first_device)
        # Load onto devices
        for k, v in self.device_map.items():
            for block in v:
                cuda_device = "cuda:" + str(k)
                self.h[block] = self.h[block].to(cuda_device)
        # ln_f to last
        self.ln_f = self.ln_f.to(self.last_device)

    # NOTE: no docstring here on purpose; the text comes from `DEPARALLELIZE_DOCSTRING` via the decorator.
    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        # Reset the parallel bookkeeping, then move every sub-module back to the CPU.
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        self.wte = self.wte.to("cpu")
        self.wpe = self.wpe.to("cpu")
        for index in range(len(self.h)):
            self.h[index] = self.h[index].to("cpu")
        self.ln_f = self.ln_f.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        """Return the token embedding module (`self.wte`)."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module (`self.wte`) with `new_embeddings`."""
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        # Any flag left as `None` falls back to the value stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            # Remember the caller's shape, then flatten any extra leading dims into the batch dim.
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        # Caching is switched off while training with gradient checkpointing.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache(config=self.config)
            elif isinstance(past_key_values, tuple):
                logger.warning_once(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
                    "You should pass an instance of `Cache` instead, e.g. "
                    "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
                )
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

            # With cross-attention enabled, wrap the cache together with a second, fresh `DynamicCache`.
            if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config))

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        # Default positions continue from the number of tokens already held in the cache (0 without a cache).
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)

        # Attention mask.
        # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
        if attention_mask is not None and attention_mask.ndim < 4:
            attention_mask = attention_mask.view(batch_size, -1)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            # A missing encoder mask defaults to all ones over the encoder sequence.
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            if _use_sdpa:
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
                )
            elif self._attn_implementation != "flash_attention_2":
                encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            # No cross-attention in this call: any encoder mask passed in is discarded.
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        # Token-type embeddings are looked up in the token embedding table (`wte`), not a separate one.
        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        # Target shape for the final `view`: the caller's input dims after the first, plus the hidden size.
        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        # Per-layer outputs are accumulated as tuples only when requested; otherwise they stay `None`.
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.h):
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            # Record the input to this block (the final, normalized state is appended after the loop).
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                past_key_values if not (self.gradient_checkpointing and self.training) else None,
                cache_position,
                causal_mask,
                head_mask[i],
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
                **kwargs,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # The cache is only handed back when caching is enabled for this call.
        past_key_values = past_key_values if use_cache else None
        if not return_dict:
            # Tuple output: entries that are `None` are dropped, so positions depend on the requested outputs.
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
  829. @auto_docstring(
  830. custom_intro="""
  831. The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
  832. embeddings).
  833. """
  834. )
  835. class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
  836. _tied_weights_keys = ["lm_head.weight"]
  837. def __init__(self, config):
  838. super().__init__(config)
  839. self.transformer = GPT2Model(config)
  840. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
  841. # Model parallel
  842. self.model_parallel = False
  843. self.device_map = None
  844. # Initialize weights and apply final processing
  845. self.post_init()
  846. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  847. def parallelize(self, device_map=None):
  848. warnings.warn(
  849. "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
  850. " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  851. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
  852. " 0, 'transformer.h.1': 1, ...}",
  853. FutureWarning,
  854. )
  855. self.device_map = (
  856. get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
  857. if device_map is None
  858. else device_map
  859. )
  860. assert_device_map(self.device_map, len(self.transformer.h))
  861. self.transformer.parallelize(self.device_map)
  862. self.lm_head = self.lm_head.to(self.transformer.first_device)
  863. self.model_parallel = True
  864. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  865. def deparallelize(self):
  866. warnings.warn(
  867. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  868. FutureWarning,
  869. )
  870. self.transformer.deparallelize()
  871. self.transformer = self.transformer.to("cpu")
  872. self.lm_head = self.lm_head.to("cpu")
  873. self.model_parallel = False
  874. torch.cuda.empty_cache()
  875. @auto_docstring
  876. def forward(
  877. self,
  878. input_ids: Optional[torch.LongTensor] = None,
  879. past_key_values: Optional[Cache] = None,
  880. cache_position: Optional[torch.LongTensor] = None,
  881. attention_mask: Optional[torch.FloatTensor] = None,
  882. token_type_ids: Optional[torch.LongTensor] = None,
  883. position_ids: Optional[torch.LongTensor] = None,
  884. head_mask: Optional[torch.FloatTensor] = None,
  885. inputs_embeds: Optional[torch.FloatTensor] = None,
  886. encoder_hidden_states: Optional[torch.Tensor] = None,
  887. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  888. labels: Optional[torch.LongTensor] = None,
  889. use_cache: Optional[bool] = None,
  890. output_attentions: Optional[bool] = None,
  891. output_hidden_states: Optional[bool] = None,
  892. return_dict: Optional[bool] = None,
  893. logits_to_keep: Union[int, torch.Tensor] = 0,
  894. **kwargs,
  895. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  896. r"""
  897. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  898. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  899. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  900. sequence tokens in the vocabulary.
  901. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  902. `input_ids`.
  903. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  904. [`PreTrainedTokenizer.__call__`] for details.
  905. [What are input IDs?](../glossary#input-ids)
  906. labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
  907. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  908. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  909. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  910. """
  911. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  912. transformer_outputs = self.transformer(
  913. input_ids,
  914. past_key_values=past_key_values,
  915. attention_mask=attention_mask,
  916. cache_position=cache_position,
  917. token_type_ids=token_type_ids,
  918. position_ids=position_ids,
  919. head_mask=head_mask,
  920. inputs_embeds=inputs_embeds,
  921. encoder_hidden_states=encoder_hidden_states,
  922. encoder_attention_mask=encoder_attention_mask,
  923. use_cache=use_cache,
  924. output_attentions=output_attentions,
  925. output_hidden_states=output_hidden_states,
  926. return_dict=return_dict,
  927. )
  928. hidden_states = transformer_outputs[0]
  929. # Set device for model parallelism
  930. if self.model_parallel:
  931. torch.cuda.set_device(self.transformer.first_device)
  932. hidden_states = hidden_states.to(self.lm_head.weight.device)
  933. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  934. logits = self.lm_head(hidden_states[:, slice_indices, :])
  935. loss = None
  936. if labels is not None:
  937. # Flatten the tokens
  938. loss = self.loss_function(
  939. logits,
  940. labels,
  941. vocab_size=self.config.vocab_size,
  942. **kwargs,
  943. )
  944. if not return_dict:
  945. output = (logits,) + transformer_outputs[1:]
  946. return ((loss,) + output) if loss is not None else output
  947. return CausalLMOutputWithCrossAttentions(
  948. loss=loss,
  949. logits=logits,
  950. past_key_values=transformer_outputs.past_key_values,
  951. hidden_states=transformer_outputs.hidden_states,
  952. attentions=transformer_outputs.attentions,
  953. cross_attentions=transformer_outputs.cross_attentions,
  954. )
  955. @auto_docstring(
  956. custom_intro="""
  957. The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
  958. RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
  959. input embeddings, the classification head takes as input the input of a specified classification token index in the
  960. input sequence).
  961. """
  962. )
  963. class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
  964. _tied_weights_keys = ["lm_head.weight"]
  965. def __init__(self, config):
  966. super().__init__(config)
  967. config.num_labels = 1
  968. self.transformer = GPT2Model(config)
  969. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
  970. self.multiple_choice_head = GPT2SequenceSummary(config)
  971. # Model parallel
  972. self.model_parallel = False
  973. self.device_map = None
  974. # Initialize weights and apply final processing
  975. self.post_init()
  976. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  977. def parallelize(self, device_map=None):
  978. warnings.warn(
  979. "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should"
  980. " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your"
  981. " own `device_map` but it needs to be a dictionary module_name to device, so for instance"
  982. " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}",
  983. FutureWarning,
  984. )
  985. self.device_map = (
  986. get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
  987. if device_map is None
  988. else device_map
  989. )
  990. assert_device_map(self.device_map, len(self.transformer.h))
  991. self.transformer.parallelize(self.device_map)
  992. self.lm_head = self.lm_head.to(self.transformer.first_device)
  993. self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device)
  994. self.model_parallel = True
  995. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  996. def deparallelize(self):
  997. warnings.warn(
  998. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  999. FutureWarning,
  1000. )
  1001. self.transformer.deparallelize()
  1002. self.transformer = self.transformer.to("cpu")
  1003. self.lm_head = self.lm_head.to("cpu")
  1004. self.multiple_choice_head = self.multiple_choice_head.to("cpu")
  1005. self.model_parallel = False
  1006. torch.cuda.empty_cache()
  1007. @auto_docstring
  1008. def forward(
  1009. self,
  1010. input_ids: Optional[torch.LongTensor] = None,
  1011. past_key_values: Optional[Cache] = None,
  1012. cache_position: Optional[torch.LongTensor] = None,
  1013. attention_mask: Optional[torch.FloatTensor] = None,
  1014. token_type_ids: Optional[torch.LongTensor] = None,
  1015. position_ids: Optional[torch.LongTensor] = None,
  1016. head_mask: Optional[torch.FloatTensor] = None,
  1017. inputs_embeds: Optional[torch.FloatTensor] = None,
  1018. mc_token_ids: Optional[torch.LongTensor] = None,
  1019. labels: Optional[torch.LongTensor] = None,
  1020. mc_labels: Optional[torch.LongTensor] = None,
  1021. use_cache: Optional[bool] = None,
  1022. output_attentions: Optional[bool] = None,
  1023. output_hidden_states: Optional[bool] = None,
  1024. return_dict: Optional[bool] = None,
  1025. **kwargs,
  1026. ) -> Union[tuple, GPT2DoubleHeadsModelOutput]:
  1027. r"""
  1028. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  1029. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  1030. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  1031. sequence tokens in the vocabulary.
  1032. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  1033. `input_ids`.
  1034. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1035. [`PreTrainedTokenizer.__call__`] for details.
  1036. [What are input IDs?](../glossary#input-ids)
  1037. mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
  1038. Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
  1039. 1]`.
  1040. labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
  1041. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  1042. `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to
  1043. `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]`
  1044. mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
  1045. Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
  1046. where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
  1047. Example:
  1048. ```python
  1049. >>> import torch
  1050. >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel
  1051. >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
  1052. >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2")
  1053. >>> # Add a [CLS] to the vocabulary (we should train it also!)
  1054. >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
  1055. >>> # Update the model embeddings with the new vocabulary size
  1056. >>> embedding_layer = model.resize_token_embeddings(len(tokenizer))
  1057. >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
  1058. >>> encoded_choices = [tokenizer.encode(s) for s in choices]
  1059. >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
  1060. >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
  1061. >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
  1062. >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
  1063. >>> lm_logits = outputs.logits
  1064. >>> mc_logits = outputs.mc_logits
  1065. ```"""
  1066. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1067. transformer_outputs = self.transformer(
  1068. input_ids,
  1069. past_key_values=past_key_values,
  1070. cache_position=cache_position,
  1071. attention_mask=attention_mask,
  1072. token_type_ids=token_type_ids,
  1073. position_ids=position_ids,
  1074. head_mask=head_mask,
  1075. inputs_embeds=inputs_embeds,
  1076. use_cache=use_cache,
  1077. output_attentions=output_attentions,
  1078. output_hidden_states=output_hidden_states,
  1079. return_dict=return_dict,
  1080. )
  1081. hidden_states = transformer_outputs[0]
  1082. # Set device for model parallelism
  1083. if self.model_parallel:
  1084. torch.cuda.set_device(self.transformer.first_device)
  1085. hidden_states = hidden_states.to(self.lm_head.weight.device)
  1086. lm_logits = self.lm_head(hidden_states)
  1087. mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
  1088. mc_loss = None
  1089. if mc_labels is not None:
  1090. loss_fct = CrossEntropyLoss()
  1091. mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
  1092. lm_loss = None
  1093. if labels is not None:
  1094. labels = labels.to(lm_logits.device)
  1095. shift_logits = lm_logits[..., :-1, :].contiguous()
  1096. shift_labels = labels[..., 1:].contiguous()
  1097. loss_fct = CrossEntropyLoss()
  1098. lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  1099. if not return_dict:
  1100. output = (lm_logits, mc_logits) + transformer_outputs[1:]
  1101. if mc_loss is not None:
  1102. output = (mc_loss,) + output
  1103. return ((lm_loss,) + output) if lm_loss is not None else output
  1104. return GPT2DoubleHeadsModelOutput(
  1105. loss=lm_loss,
  1106. mc_loss=mc_loss,
  1107. logits=lm_logits,
  1108. mc_logits=mc_logits,
  1109. past_key_values=transformer_outputs.past_key_values,
  1110. hidden_states=transformer_outputs.hidden_states,
  1111. attentions=transformer_outputs.attentions,
  1112. )
  1113. @auto_docstring(
  1114. custom_intro="""
  1115. The GPT2 Model transformer with a sequence classification head on top (linear layer).
  1116. [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  1117. (e.g. GPT-1) do.
  1118. Since it does classification on the last token, it requires to know the position of the last token. If a
  1119. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  1120. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  1121. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  1122. each row of the batch).
  1123. """
  1124. )
  1125. class GPT2ForSequenceClassification(GPT2PreTrainedModel):
  1126. def __init__(self, config):
  1127. super().__init__(config)
  1128. self.num_labels = config.num_labels
  1129. self.transformer = GPT2Model(config)
  1130. self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
  1131. # Model parallel
  1132. self.model_parallel = False
  1133. self.device_map = None
  1134. # Initialize weights and apply final processing
  1135. self.post_init()
  1136. @auto_docstring
  1137. def forward(
  1138. self,
  1139. input_ids: Optional[torch.LongTensor] = None,
  1140. past_key_values: Optional[Cache] = None,
  1141. attention_mask: Optional[torch.FloatTensor] = None,
  1142. token_type_ids: Optional[torch.LongTensor] = None,
  1143. position_ids: Optional[torch.LongTensor] = None,
  1144. head_mask: Optional[torch.FloatTensor] = None,
  1145. inputs_embeds: Optional[torch.FloatTensor] = None,
  1146. labels: Optional[torch.LongTensor] = None,
  1147. use_cache: Optional[bool] = None,
  1148. output_attentions: Optional[bool] = None,
  1149. output_hidden_states: Optional[bool] = None,
  1150. return_dict: Optional[bool] = None,
  1151. ) -> Union[tuple, SequenceClassifierOutputWithPast]:
  1152. r"""
  1153. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  1154. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  1155. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  1156. sequence tokens in the vocabulary.
  1157. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  1158. `input_ids`.
  1159. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1160. [`PreTrainedTokenizer.__call__`] for details.
  1161. [What are input IDs?](../glossary#input-ids)
  1162. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1163. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1164. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1165. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1166. """
  1167. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1168. transformer_outputs = self.transformer(
  1169. input_ids,
  1170. past_key_values=past_key_values,
  1171. attention_mask=attention_mask,
  1172. token_type_ids=token_type_ids,
  1173. position_ids=position_ids,
  1174. head_mask=head_mask,
  1175. inputs_embeds=inputs_embeds,
  1176. use_cache=use_cache,
  1177. output_attentions=output_attentions,
  1178. output_hidden_states=output_hidden_states,
  1179. return_dict=return_dict,
  1180. )
  1181. hidden_states = transformer_outputs[0]
  1182. logits = self.score(hidden_states)
  1183. if input_ids is not None:
  1184. batch_size, sequence_length = input_ids.shape[:2]
  1185. else:
  1186. batch_size, sequence_length = inputs_embeds.shape[:2]
  1187. if self.config.pad_token_id is None and batch_size != 1:
  1188. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  1189. if self.config.pad_token_id is None:
  1190. last_non_pad_token = -1
  1191. elif input_ids is not None:
  1192. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  1193. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  1194. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  1195. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  1196. else:
  1197. last_non_pad_token = -1
  1198. logger.warning_once(
  1199. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  1200. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  1201. )
  1202. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  1203. loss = None
  1204. if labels is not None:
  1205. if self.config.problem_type is None:
  1206. if self.num_labels == 1:
  1207. self.config.problem_type = "regression"
  1208. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1209. self.config.problem_type = "single_label_classification"
  1210. else:
  1211. self.config.problem_type = "multi_label_classification"
  1212. if self.config.problem_type == "regression":
  1213. loss_fct = MSELoss()
  1214. if self.num_labels == 1:
  1215. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  1216. else:
  1217. loss = loss_fct(pooled_logits, labels)
  1218. elif self.config.problem_type == "single_label_classification":
  1219. loss_fct = CrossEntropyLoss()
  1220. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  1221. elif self.config.problem_type == "multi_label_classification":
  1222. loss_fct = BCEWithLogitsLoss()
  1223. loss = loss_fct(pooled_logits, labels)
  1224. if not return_dict:
  1225. output = (pooled_logits,) + transformer_outputs[1:]
  1226. return ((loss,) + output) if loss is not None else output
  1227. return SequenceClassifierOutputWithPast(
  1228. loss=loss,
  1229. logits=pooled_logits,
  1230. past_key_values=transformer_outputs.past_key_values,
  1231. hidden_states=transformer_outputs.hidden_states,
  1232. attentions=transformer_outputs.attentions,
  1233. )
  1234. @auto_docstring
  1235. class GPT2ForTokenClassification(GPT2PreTrainedModel):
  1236. def __init__(self, config):
  1237. super().__init__(config)
  1238. self.num_labels = config.num_labels
  1239. self.transformer = GPT2Model(config)
  1240. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  1241. classifier_dropout = config.classifier_dropout
  1242. elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
  1243. classifier_dropout = config.hidden_dropout
  1244. else:
  1245. classifier_dropout = 0.1
  1246. self.dropout = nn.Dropout(classifier_dropout)
  1247. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1248. # Model parallel
  1249. self.model_parallel = False
  1250. self.device_map = None
  1251. # Initialize weights and apply final processing
  1252. self.post_init()
  1253. @auto_docstring
  1254. def forward(
  1255. self,
  1256. input_ids: Optional[torch.LongTensor] = None,
  1257. past_key_values: Optional[Cache] = None,
  1258. attention_mask: Optional[torch.FloatTensor] = None,
  1259. token_type_ids: Optional[torch.LongTensor] = None,
  1260. position_ids: Optional[torch.LongTensor] = None,
  1261. head_mask: Optional[torch.FloatTensor] = None,
  1262. inputs_embeds: Optional[torch.FloatTensor] = None,
  1263. labels: Optional[torch.LongTensor] = None,
  1264. use_cache: Optional[bool] = None,
  1265. output_attentions: Optional[bool] = None,
  1266. output_hidden_states: Optional[bool] = None,
  1267. return_dict: Optional[bool] = None,
  1268. ) -> Union[tuple, TokenClassifierOutput]:
  1269. r"""
  1270. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  1271. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  1272. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  1273. sequence tokens in the vocabulary.
  1274. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  1275. `input_ids`.
  1276. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1277. [`PreTrainedTokenizer.__call__`] for details.
  1278. [What are input IDs?](../glossary#input-ids)
  1279. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1280. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1281. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1282. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1283. """
  1284. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1285. transformer_outputs = self.transformer(
  1286. input_ids,
  1287. past_key_values=past_key_values,
  1288. attention_mask=attention_mask,
  1289. token_type_ids=token_type_ids,
  1290. position_ids=position_ids,
  1291. head_mask=head_mask,
  1292. inputs_embeds=inputs_embeds,
  1293. use_cache=use_cache,
  1294. output_attentions=output_attentions,
  1295. output_hidden_states=output_hidden_states,
  1296. return_dict=return_dict,
  1297. )
  1298. hidden_states = transformer_outputs[0]
  1299. hidden_states = self.dropout(hidden_states)
  1300. logits = self.classifier(hidden_states)
  1301. loss = None
  1302. if labels is not None:
  1303. labels = labels.to(logits.device)
  1304. loss_fct = CrossEntropyLoss()
  1305. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1306. if not return_dict:
  1307. output = (logits,) + transformer_outputs[2:]
  1308. return ((loss,) + output) if loss is not None else output
  1309. return TokenClassifierOutput(
  1310. loss=loss,
  1311. logits=logits,
  1312. hidden_states=transformer_outputs.hidden_states,
  1313. attentions=transformer_outputs.attentions,
  1314. )
  1315. @auto_docstring
  1316. class GPT2ForQuestionAnswering(GPT2PreTrainedModel):
  1317. def __init__(self, config):
  1318. super().__init__(config)
  1319. self.num_labels = config.num_labels
  1320. self.transformer = GPT2Model(config)
  1321. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  1322. # Model parallel
  1323. self.model_parallel = False
  1324. self.device_map = None
  1325. # Initialize weights and apply final processing
  1326. self.post_init()
  1327. @auto_docstring
  1328. def forward(
  1329. self,
  1330. input_ids: Optional[torch.LongTensor] = None,
  1331. attention_mask: Optional[torch.FloatTensor] = None,
  1332. token_type_ids: Optional[torch.LongTensor] = None,
  1333. position_ids: Optional[torch.LongTensor] = None,
  1334. head_mask: Optional[torch.FloatTensor] = None,
  1335. inputs_embeds: Optional[torch.FloatTensor] = None,
  1336. start_positions: Optional[torch.LongTensor] = None,
  1337. end_positions: Optional[torch.LongTensor] = None,
  1338. output_attentions: Optional[bool] = None,
  1339. output_hidden_states: Optional[bool] = None,
  1340. return_dict: Optional[bool] = None,
  1341. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  1342. r"""
  1343. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  1344. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  1345. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  1346. sequence tokens in the vocabulary.
  1347. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  1348. `input_ids`.
  1349. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1350. [`PreTrainedTokenizer.__call__`] for details.
  1351. [What are input IDs?](../glossary#input-ids)
  1352. """
  1353. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1354. outputs = self.transformer(
  1355. input_ids,
  1356. attention_mask=attention_mask,
  1357. token_type_ids=token_type_ids,
  1358. position_ids=position_ids,
  1359. head_mask=head_mask,
  1360. inputs_embeds=inputs_embeds,
  1361. output_attentions=output_attentions,
  1362. output_hidden_states=output_hidden_states,
  1363. return_dict=return_dict,
  1364. )
  1365. sequence_output = outputs[0]
  1366. logits = self.qa_outputs(sequence_output)
  1367. start_logits, end_logits = logits.split(1, dim=-1)
  1368. start_logits = start_logits.squeeze(-1).contiguous()
  1369. end_logits = end_logits.squeeze(-1).contiguous()
  1370. total_loss = None
  1371. if start_positions is not None and end_positions is not None:
  1372. # If we are on multi-GPU, split add a dimension
  1373. if len(start_positions.size()) > 1:
  1374. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  1375. if len(end_positions.size()) > 1:
  1376. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  1377. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1378. ignored_index = start_logits.size(1)
  1379. start_positions = start_positions.clamp(0, ignored_index)
  1380. end_positions = end_positions.clamp(0, ignored_index)
  1381. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1382. start_loss = loss_fct(start_logits, start_positions)
  1383. end_loss = loss_fct(end_logits, end_positions)
  1384. total_loss = (start_loss + end_loss) / 2
  1385. if not return_dict:
  1386. output = (start_logits, end_logits) + outputs[2:]
  1387. return ((total_loss,) + output) if total_loss is not None else output
  1388. return QuestionAnsweringModelOutput(
  1389. loss=total_loss,
  1390. start_logits=start_logits,
  1391. end_logits=end_logits,
  1392. hidden_states=outputs.hidden_states,
  1393. attentions=outputs.attentions,
  1394. )
  1395. __all__ = [
  1396. "GPT2DoubleHeadsModel",
  1397. "GPT2ForQuestionAnswering",
  1398. "GPT2ForSequenceClassification",
  1399. "GPT2ForTokenClassification",
  1400. "GPT2LMHeadModel",
  1401. "GPT2Model",
  1402. "GPT2PreTrainedModel",
  1403. "load_tf_weights_in_gpt2",
  1404. ]