modeling_openai.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853
  1. # coding=utf-8
  2. # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch OpenAI GPT model."""
  17. import json
  18. import math
  19. import os
  20. from dataclasses import dataclass
  21. from typing import Any, Callable, Optional, Union
  22. import torch
  23. from torch import nn
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ...activations import gelu_new, get_activation, silu
  26. from ...generation import GenerationMixin
  27. from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
  28. from ...modeling_utils import PreTrainedModel
  29. from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
  30. from ...utils import (
  31. ModelOutput,
  32. auto_docstring,
  33. logging,
  34. )
  35. from .configuration_openai import OpenAIGPTConfig
  36. logger = logging.get_logger(__name__)
  37. def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
  38. """Load tf pre-trained weights in a pytorch model (from NumPy arrays here)"""
  39. import re
  40. import numpy as np
  41. if ".ckpt" in openai_checkpoint_folder_path:
  42. openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)
  43. logger.info(f"Loading weights from {openai_checkpoint_folder_path}")
  44. with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle:
  45. names = json.load(names_handle)
  46. with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle:
  47. shapes = json.load(shapes_handle)
  48. offsets = np.cumsum([np.prod(shape) for shape in shapes])
  49. init_params = [np.load(openai_checkpoint_folder_path + f"/params_{n}.npy") for n in range(10)]
  50. init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
  51. init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
  52. # This was used when we had a single embedding matrix for positions and tokens
  53. # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
  54. # del init_params[1]
  55. init_params = [arr.squeeze() for arr in init_params]
  56. # Check that the token and position embeddings weight dimensions map those of the init parameters.
  57. if model.tokens_embed.weight.shape != init_params[1].shape:
  58. raise ValueError(
  59. f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape:"
  60. f" {init_params[1].shape}"
  61. )
  62. if model.positions_embed.weight.shape != init_params[0].shape:
  63. raise ValueError(
  64. f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape:"
  65. f" {init_params[0].shape}"
  66. )
  67. model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
  68. model.positions_embed.weight.data = torch.from_numpy(init_params[0])
  69. names.pop(0)
  70. # Pop position and token embedding arrays
  71. init_params.pop(0)
  72. init_params.pop(0)
  73. for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
  74. name = name[6:] # skip "model/"
  75. if name[-2:] != ":0":
  76. raise ValueError(f"Layer {name} does not end with :0")
  77. name = name[:-2]
  78. name = name.split("/")
  79. pointer = model
  80. for m_name in name:
  81. if re.fullmatch(r"[A-Za-z]+\d+", m_name):
  82. scope_names = re.split(r"(\d+)", m_name)
  83. else:
  84. scope_names = [m_name]
  85. if scope_names[0] == "g":
  86. pointer = getattr(pointer, "weight")
  87. elif scope_names[0] == "b":
  88. pointer = getattr(pointer, "bias")
  89. elif scope_names[0] == "w":
  90. pointer = getattr(pointer, "weight")
  91. else:
  92. pointer = getattr(pointer, scope_names[0])
  93. if len(scope_names) >= 2:
  94. num = int(scope_names[1])
  95. pointer = pointer[num]
  96. # Ensure that the pointer and array have compatible shapes.
  97. if pointer.shape != array.shape:
  98. raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
  99. logger.info(f"Initialize PyTorch weight {name}")
  100. pointer.data = torch.from_numpy(array)
  101. return model
  102. ACT_FNS = {"relu": nn.ReLU(), "silu": silu, "gelu": gelu_new, "swish": silu}
  103. class Attention(nn.Module):
  104. def __init__(self, nx, n_positions, config, scale=False):
  105. super().__init__()
  106. n_state = nx # in Attention: n_state=768 (nx=n_embd)
  107. # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
  108. if n_state % config.n_head != 0:
  109. raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
  110. self.register_buffer(
  111. "bias",
  112. torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions),
  113. persistent=False,
  114. )
  115. self.n_head = config.n_head
  116. self.split_size = n_state
  117. self.scale = scale
  118. self.c_attn = Conv1D(n_state * 3, nx)
  119. self.c_proj = Conv1D(n_state, nx)
  120. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  121. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  122. self.pruned_heads = set()
  123. def prune_heads(self, heads):
  124. if len(heads) == 0:
  125. return
  126. heads, index = find_pruneable_heads_and_indices(
  127. heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
  128. )
  129. index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
  130. # Prune conv1d layers
  131. self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
  132. self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
  133. # Update hyper params
  134. self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
  135. self.n_head = self.n_head - len(heads)
  136. self.pruned_heads = self.pruned_heads.union(heads)
  137. def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
  138. w = torch.matmul(q, k)
  139. if self.scale:
  140. w = w / math.sqrt(v.size(-1))
  141. # w = w * self.bias + -1e9 * (1 - self.bias) # TF implementation method: mask_attn_weights
  142. # XD: self.b may be larger than w, so we need to crop it
  143. b = self.bias[:, :, : w.size(-2), : w.size(-1)]
  144. w = w * b + -1e4 * (1 - b)
  145. if attention_mask is not None:
  146. # Apply the attention mask
  147. w = w + attention_mask
  148. w = nn.functional.softmax(w, dim=-1)
  149. w = self.attn_dropout(w)
  150. # Mask heads if we want to
  151. if head_mask is not None:
  152. w = w * head_mask
  153. outputs = [torch.matmul(w, v)]
  154. if output_attentions:
  155. outputs.append(w)
  156. return outputs
  157. def merge_heads(self, x):
  158. x = x.permute(0, 2, 1, 3).contiguous()
  159. new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
  160. return x.view(*new_x_shape) # in Tensorflow implementation: fct merge_states
  161. def split_heads(self, x, k=False):
  162. new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
  163. x = x.view(*new_x_shape) # in Tensorflow implementation: fct split_states
  164. if k:
  165. return x.permute(0, 2, 3, 1)
  166. else:
  167. return x.permute(0, 2, 1, 3)
  168. def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
  169. x = self.c_attn(x)
  170. query, key, value = x.split(self.split_size, dim=2)
  171. query = self.split_heads(query)
  172. key = self.split_heads(key, k=True)
  173. value = self.split_heads(value)
  174. attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
  175. a = attn_outputs[0]
  176. a = self.merge_heads(a)
  177. a = self.c_proj(a)
  178. a = self.resid_dropout(a)
  179. outputs = [a] + attn_outputs[1:]
  180. return outputs # a, (attentions)
  181. class MLP(nn.Module):
  182. def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
  183. super().__init__()
  184. nx = config.n_embd
  185. self.c_fc = Conv1D(n_state, nx)
  186. self.c_proj = Conv1D(nx, n_state)
  187. self.act = ACT_FNS[config.afn]
  188. self.dropout = nn.Dropout(config.resid_pdrop)
  189. def forward(self, x):
  190. h = self.act(self.c_fc(x))
  191. h2 = self.c_proj(h)
  192. return self.dropout(h2)
  193. class Block(nn.Module):
  194. def __init__(self, n_positions, config, scale=False):
  195. super().__init__()
  196. nx = config.n_embd
  197. self.attn = Attention(nx, n_positions, config, scale)
  198. self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
  199. self.mlp = MLP(4 * nx, config)
  200. self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
  201. def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
  202. attn_outputs = self.attn(
  203. x,
  204. attention_mask=attention_mask,
  205. head_mask=head_mask,
  206. output_attentions=output_attentions,
  207. )
  208. a = attn_outputs[0]
  209. n = self.ln_1(x + a)
  210. m = self.mlp(n)
  211. h = self.ln_2(n + m)
  212. outputs = [h] + attn_outputs[1:]
  213. return outputs
  214. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->OpenAIGPT
  215. class OpenAIGPTSequenceSummary(nn.Module):
  216. r"""
  217. Compute a single vector summary of a sequence hidden states.
  218. Args:
  219. config ([`OpenAIGPTConfig`]):
  220. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  221. config class of your model for the default values it uses):
  222. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  223. - `"last"` -- Take the last token hidden state (like XLNet)
  224. - `"first"` -- Take the first token hidden state (like Bert)
  225. - `"mean"` -- Take the mean of all tokens hidden states
  226. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  227. - `"attn"` -- Not implemented now, use multi-head attention
  228. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  229. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  230. (otherwise to `config.hidden_size`).
  231. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  232. another string or `None` will add no activation.
  233. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  234. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  235. """
  236. def __init__(self, config: OpenAIGPTConfig):
  237. super().__init__()
  238. self.summary_type = getattr(config, "summary_type", "last")
  239. if self.summary_type == "attn":
  240. # We should use a standard multi-head attention module with absolute positional embedding for that.
  241. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  242. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  243. raise NotImplementedError
  244. self.summary = nn.Identity()
  245. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  246. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  247. num_classes = config.num_labels
  248. else:
  249. num_classes = config.hidden_size
  250. self.summary = nn.Linear(config.hidden_size, num_classes)
  251. activation_string = getattr(config, "summary_activation", None)
  252. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  253. self.first_dropout = nn.Identity()
  254. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  255. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  256. self.last_dropout = nn.Identity()
  257. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  258. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  259. def forward(
  260. self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
  261. ) -> torch.FloatTensor:
  262. """
  263. Compute a single vector summary of a sequence hidden states.
  264. Args:
  265. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  266. The hidden states of the last layer.
  267. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  268. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  269. Returns:
  270. `torch.FloatTensor`: The summary of the sequence hidden states.
  271. """
  272. if self.summary_type == "last":
  273. output = hidden_states[:, -1]
  274. elif self.summary_type == "first":
  275. output = hidden_states[:, 0]
  276. elif self.summary_type == "mean":
  277. output = hidden_states.mean(dim=1)
  278. elif self.summary_type == "cls_index":
  279. if cls_index is None:
  280. cls_index = torch.full_like(
  281. hidden_states[..., :1, :],
  282. hidden_states.shape[-2] - 1,
  283. dtype=torch.long,
  284. )
  285. else:
  286. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  287. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  288. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  289. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  290. elif self.summary_type == "attn":
  291. raise NotImplementedError
  292. output = self.first_dropout(output)
  293. output = self.summary(output)
  294. output = self.activation(output)
  295. output = self.last_dropout(output)
  296. return output
  297. @auto_docstring
  298. class OpenAIGPTPreTrainedModel(PreTrainedModel):
  299. config: OpenAIGPTConfig
  300. load_tf_weights = load_tf_weights_in_openai_gpt
  301. base_model_prefix = "transformer"
  302. def _init_weights(self, module):
  303. """Initialize the weights."""
  304. if isinstance(module, (nn.Linear, Conv1D)):
  305. # Slightly different from the TF version which uses truncated_normal for initialization
  306. # cf https://github.com/pytorch/pytorch/pull/5617
  307. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  308. if module.bias is not None:
  309. module.bias.data.zero_()
  310. elif isinstance(module, nn.Embedding):
  311. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  312. if module.padding_idx is not None:
  313. module.weight.data[module.padding_idx].zero_()
  314. elif isinstance(module, nn.LayerNorm):
  315. module.bias.data.zero_()
  316. module.weight.data.fill_(1.0)
  317. @dataclass
  318. @auto_docstring(
  319. custom_intro="""
  320. Base class for outputs of models predicting if two sentences are consecutive or not.
  321. """
  322. )
  323. class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
  324. r"""
  325. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  326. Language modeling loss.
  327. mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
  328. Multiple choice classification loss.
  329. logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
  330. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  331. mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
  332. Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
  333. """
  334. loss: Optional[torch.FloatTensor] = None
  335. mc_loss: Optional[torch.FloatTensor] = None
  336. logits: Optional[torch.FloatTensor] = None
  337. mc_logits: Optional[torch.FloatTensor] = None
  338. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  339. attentions: Optional[tuple[torch.FloatTensor]] = None
  340. @auto_docstring
  341. class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
  342. def __init__(self, config):
  343. super().__init__(config)
  344. self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
  345. self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
  346. self.drop = nn.Dropout(config.embd_pdrop)
  347. self.h = nn.ModuleList([Block(config.n_positions, config, scale=True) for _ in range(config.n_layer)])
  348. self.register_buffer("position_ids", torch.arange(config.n_positions), persistent=False)
  349. # Initialize weights and apply final processing
  350. self.post_init()
  351. def get_input_embeddings(self):
  352. return self.tokens_embed
  353. def set_input_embeddings(self, new_embeddings):
  354. self.tokens_embed = new_embeddings
  355. def _prune_heads(self, heads_to_prune):
  356. """
  357. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
  358. """
  359. for layer, heads in heads_to_prune.items():
  360. self.h[layer].attn.prune_heads(heads)
  361. @auto_docstring
  362. def forward(
  363. self,
  364. input_ids: Optional[torch.LongTensor] = None,
  365. attention_mask: Optional[torch.FloatTensor] = None,
  366. token_type_ids: Optional[torch.LongTensor] = None,
  367. position_ids: Optional[torch.LongTensor] = None,
  368. head_mask: Optional[torch.FloatTensor] = None,
  369. inputs_embeds: Optional[torch.FloatTensor] = None,
  370. output_attentions: Optional[bool] = None,
  371. output_hidden_states: Optional[bool] = None,
  372. return_dict: Optional[bool] = None,
  373. ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
  374. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  375. output_hidden_states = (
  376. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  377. )
  378. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  379. if input_ids is not None and inputs_embeds is not None:
  380. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  381. elif input_ids is not None:
  382. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  383. input_shape = input_ids.size()
  384. input_ids = input_ids.view(-1, input_shape[-1])
  385. elif inputs_embeds is not None:
  386. input_shape = inputs_embeds.size()[:-1]
  387. else:
  388. raise ValueError("You have to specify either input_ids or inputs_embeds")
  389. if position_ids is None:
  390. # Code is different from when we had a single embedding matrix from position and token embeddings
  391. position_ids = self.position_ids[None, : input_shape[-1]]
  392. # Attention mask.
  393. if attention_mask is not None:
  394. # We create a 3D attention mask from a 2D tensor mask.
  395. # Sizes are [batch_size, 1, 1, to_seq_length]
  396. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  397. # this attention mask is more simple than the triangular masking of causal attention
  398. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  399. attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  400. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  401. # masked positions, this operation will create a tensor which is 0.0 for
  402. # positions we want to attend and the dtype's smallest value for masked positions.
  403. # Since we are adding it to the raw scores before the softmax, this is
  404. # effectively the same as removing these entirely.
  405. attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
  406. attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
  407. # Prepare head mask if needed
  408. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  409. if inputs_embeds is None:
  410. inputs_embeds = self.tokens_embed(input_ids)
  411. position_embeds = self.positions_embed(position_ids)
  412. if token_type_ids is not None:
  413. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
  414. token_type_embeds = self.tokens_embed(token_type_ids)
  415. else:
  416. token_type_embeds = 0
  417. hidden_states = inputs_embeds + position_embeds + token_type_embeds
  418. hidden_states = self.drop(hidden_states)
  419. output_shape = input_shape + (hidden_states.size(-1),)
  420. all_attentions = () if output_attentions else None
  421. all_hidden_states = () if output_hidden_states else None
  422. for i, block in enumerate(self.h):
  423. if output_hidden_states:
  424. all_hidden_states = all_hidden_states + (hidden_states,)
  425. outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)
  426. hidden_states = outputs[0]
  427. if output_attentions:
  428. all_attentions = all_attentions + (outputs[1],)
  429. hidden_states = hidden_states.view(*output_shape)
  430. # Add last layer
  431. if output_hidden_states:
  432. all_hidden_states = all_hidden_states + (hidden_states,)
  433. if not return_dict:
  434. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  435. return BaseModelOutput(
  436. last_hidden_state=hidden_states,
  437. hidden_states=all_hidden_states,
  438. attentions=all_attentions,
  439. )
  440. @auto_docstring(
  441. custom_intro="""
  442. OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
  443. embeddings).
  444. """
  445. )
  446. class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin):
  447. _tied_weights_keys = ["lm_head.weight"]
  448. def __init__(self, config):
  449. super().__init__(config)
  450. self.transformer = OpenAIGPTModel(config)
  451. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
  452. # Initialize weights and apply final processing
  453. self.post_init()
  454. @auto_docstring
  455. def forward(
  456. self,
  457. input_ids: Optional[torch.LongTensor] = None,
  458. attention_mask: Optional[torch.FloatTensor] = None,
  459. token_type_ids: Optional[torch.LongTensor] = None,
  460. position_ids: Optional[torch.LongTensor] = None,
  461. head_mask: Optional[torch.FloatTensor] = None,
  462. inputs_embeds: Optional[torch.FloatTensor] = None,
  463. labels: Optional[torch.LongTensor] = None,
  464. output_attentions: Optional[bool] = None,
  465. output_hidden_states: Optional[bool] = None,
  466. return_dict: Optional[bool] = None,
  467. **kwargs,
  468. ) -> Union[tuple[torch.Tensor], CausalLMOutput]:
  469. r"""
  470. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  471. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  472. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  473. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  474. """
  475. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  476. transformer_outputs = self.transformer(
  477. input_ids,
  478. attention_mask=attention_mask,
  479. token_type_ids=token_type_ids,
  480. position_ids=position_ids,
  481. head_mask=head_mask,
  482. inputs_embeds=inputs_embeds,
  483. output_attentions=output_attentions,
  484. output_hidden_states=output_hidden_states,
  485. return_dict=return_dict,
  486. )
  487. hidden_states = transformer_outputs[0]
  488. lm_logits = self.lm_head(hidden_states)
  489. loss = None
  490. if labels is not None:
  491. # Flatten the tokens
  492. loss = self.loss_function(
  493. lm_logits,
  494. labels,
  495. vocab_size=self.config.vocab_size,
  496. **kwargs,
  497. )
  498. if not return_dict:
  499. output = (lm_logits,) + transformer_outputs[1:]
  500. return ((loss,) + output) if loss is not None else output
  501. return CausalLMOutput(
  502. loss=loss,
  503. logits=lm_logits,
  504. hidden_states=transformer_outputs.hidden_states,
  505. attentions=transformer_outputs.attentions,
  506. )
  507. def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict[str, Any]:
  508. # Overwritten -- old model with reduced inputs
  509. model_inputs = {"input_ids": input_ids}
  510. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  511. for key, value in kwargs.items():
  512. if key not in model_inputs:
  513. model_inputs[key] = value
  514. return model_inputs
  515. @auto_docstring(
  516. custom_intro="""
  517. OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
  518. RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
  519. input embeddings, the classification head takes as input the input of a specified classification token index in the
  520. input sequence).
  521. """
  522. )
  523. class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
  524. _tied_weights_keys = ["lm_head.weight"]
  525. def __init__(self, config):
  526. super().__init__(config)
  527. config.num_labels = 1
  528. self.transformer = OpenAIGPTModel(config)
  529. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
  530. self.multiple_choice_head = OpenAIGPTSequenceSummary(config)
  531. # Initialize weights and apply final processing
  532. self.post_init()
  533. @auto_docstring
  534. def forward(
  535. self,
  536. input_ids: Optional[torch.LongTensor] = None,
  537. attention_mask: Optional[torch.FloatTensor] = None,
  538. token_type_ids: Optional[torch.LongTensor] = None,
  539. position_ids: Optional[torch.LongTensor] = None,
  540. head_mask: Optional[torch.FloatTensor] = None,
  541. inputs_embeds: Optional[torch.FloatTensor] = None,
  542. mc_token_ids: Optional[torch.LongTensor] = None,
  543. labels: Optional[torch.LongTensor] = None,
  544. mc_labels: Optional[torch.LongTensor] = None,
  545. output_attentions: Optional[bool] = None,
  546. output_hidden_states: Optional[bool] = None,
  547. return_dict: Optional[bool] = None,
  548. ) -> Union[tuple[torch.Tensor], OpenAIGPTDoubleHeadsModelOutput]:
  549. r"""
  550. mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
  551. Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
  552. 1]`.
  553. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  554. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  555. `labels = input_ids` Indices are selected in `[-1, 0, ..., config.vocab_size]` All labels set to `-100` are
  556. ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  557. mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
  558. Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
  559. where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)
  560. Examples:
  561. ```python
  562. >>> from transformers import AutoTokenizer, OpenAIGPTDoubleHeadsModel
  563. >>> import torch
  564. >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
  565. >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt")
  566. >>> tokenizer.add_special_tokens(
  567. ... {"cls_token": "[CLS]"}
  568. ... ) # Add a [CLS] to the vocabulary (we should train it also!)
  569. >>> model.resize_token_embeddings(len(tokenizer))
  570. >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
  571. >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
  572. >>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0) # Batch size 1
  573. >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
  574. >>> lm_logits = outputs.logits
  575. >>> mc_logits = outputs.mc_logits
  576. ```"""
  577. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  578. transformer_outputs = self.transformer(
  579. input_ids,
  580. attention_mask=attention_mask,
  581. token_type_ids=token_type_ids,
  582. position_ids=position_ids,
  583. head_mask=head_mask,
  584. inputs_embeds=inputs_embeds,
  585. output_attentions=output_attentions,
  586. output_hidden_states=output_hidden_states,
  587. return_dict=return_dict,
  588. )
  589. hidden_states = transformer_outputs[0]
  590. lm_logits = self.lm_head(hidden_states)
  591. mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)
  592. lm_loss, mc_loss = None, None
  593. if mc_labels is not None:
  594. loss_fct = CrossEntropyLoss()
  595. mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
  596. if labels is not None:
  597. shift_logits = lm_logits[..., :-1, :].contiguous()
  598. shift_labels = labels[..., 1:].contiguous()
  599. loss_fct = CrossEntropyLoss()
  600. lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  601. if not return_dict:
  602. output = (lm_logits, mc_logits) + transformer_outputs[1:]
  603. if mc_loss is not None:
  604. output = (mc_loss,) + output
  605. return ((lm_loss,) + output) if lm_loss is not None else output
  606. return OpenAIGPTDoubleHeadsModelOutput(
  607. loss=lm_loss,
  608. mc_loss=mc_loss,
  609. logits=lm_logits,
  610. mc_logits=mc_logits,
  611. hidden_states=transformer_outputs.hidden_states,
  612. attentions=transformer_outputs.attentions,
  613. )
  614. @auto_docstring(
  615. custom_intro="""
  616. The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
  617. [`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
  618. models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the
  619. last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding
  620. token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
  621. it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take
  622. the last value in each row of the batch).
  623. """
  624. )
  625. class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
  626. def __init__(self, config):
  627. super().__init__(config)
  628. self.num_labels = config.num_labels
  629. self.transformer = OpenAIGPTModel(config)
  630. self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
  631. # Initialize weights and apply final processing
  632. self.post_init()
  633. @auto_docstring
  634. def forward(
  635. self,
  636. input_ids: Optional[torch.LongTensor] = None,
  637. attention_mask: Optional[torch.FloatTensor] = None,
  638. token_type_ids: Optional[torch.LongTensor] = None,
  639. position_ids: Optional[torch.LongTensor] = None,
  640. head_mask: Optional[torch.FloatTensor] = None,
  641. inputs_embeds: Optional[torch.FloatTensor] = None,
  642. labels: Optional[torch.LongTensor] = None,
  643. output_attentions: Optional[bool] = None,
  644. output_hidden_states: Optional[bool] = None,
  645. return_dict: Optional[bool] = None,
  646. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  647. r"""
  648. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  649. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  650. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  651. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  652. """
  653. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  654. transformer_outputs = self.transformer(
  655. input_ids,
  656. attention_mask=attention_mask,
  657. token_type_ids=token_type_ids,
  658. position_ids=position_ids,
  659. head_mask=head_mask,
  660. inputs_embeds=inputs_embeds,
  661. output_attentions=output_attentions,
  662. output_hidden_states=output_hidden_states,
  663. return_dict=return_dict,
  664. )
  665. hidden_states = transformer_outputs[0]
  666. logits = self.score(hidden_states)
  667. if input_ids is not None:
  668. batch_size, sequence_length = input_ids.shape[:2]
  669. else:
  670. batch_size, sequence_length = inputs_embeds.shape[:2]
  671. # Ensure the batch size is > 1 if there is no padding.
  672. if self.config.pad_token_id is None and batch_size != 1:
  673. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  674. if self.config.pad_token_id is None:
  675. last_non_pad_token = -1
  676. elif input_ids is not None:
  677. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  678. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  679. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  680. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  681. else:
  682. last_non_pad_token = -1
  683. logger.warning_once(
  684. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  685. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  686. )
  687. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  688. loss = None
  689. if labels is not None:
  690. if self.config.problem_type is None:
  691. if self.num_labels == 1:
  692. self.config.problem_type = "regression"
  693. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  694. self.config.problem_type = "single_label_classification"
  695. else:
  696. self.config.problem_type = "multi_label_classification"
  697. if self.config.problem_type == "regression":
  698. loss_fct = MSELoss()
  699. if self.num_labels == 1:
  700. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  701. else:
  702. loss = loss_fct(pooled_logits, labels)
  703. elif self.config.problem_type == "single_label_classification":
  704. loss_fct = CrossEntropyLoss()
  705. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  706. elif self.config.problem_type == "multi_label_classification":
  707. loss_fct = BCEWithLogitsLoss()
  708. loss = loss_fct(pooled_logits, labels)
  709. if not return_dict:
  710. output = (pooled_logits,) + transformer_outputs[1:]
  711. return ((loss,) + output) if loss is not None else output
  712. return SequenceClassifierOutput(
  713. loss=loss,
  714. logits=pooled_logits,
  715. hidden_states=transformer_outputs.hidden_states,
  716. attentions=transformer_outputs.attentions,
  717. )
  718. __all__ = [
  719. "OpenAIGPTDoubleHeadsModel",
  720. "OpenAIGPTForSequenceClassification",
  721. "OpenAIGPTLMHeadModel",
  722. "OpenAIGPTModel",
  723. "OpenAIGPTPreTrainedModel",
  724. "load_tf_weights_in_openai_gpt",
  725. ]