modeling_imagegpt.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024
  1. # coding=utf-8
  2. # Copyright 2021 The OpenAI Team Authors and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch OpenAI ImageGPT model."""
  16. import math
  17. import os
  18. from typing import Any, Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import CrossEntropyLoss
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutputWithPastAndCrossAttentions,
  28. CausalLMOutputWithCrossAttentions,
  29. SequenceClassifierOutputWithPast,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
  33. from ...utils import (
  34. auto_docstring,
  35. logging,
  36. torch_float,
  37. )
  38. from .configuration_imagegpt import ImageGPTConfig
# Module-level logger shared by the TF checkpoint loader and the model code in this file.
logger = logging.get_logger(__name__)
  40. def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path):
  41. """
  42. Load tf checkpoints in a pytorch model
  43. """
  44. try:
  45. import re
  46. import tensorflow as tf
  47. except ImportError:
  48. logger.error(
  49. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  50. "https://www.tensorflow.org/install/ for installation instructions."
  51. )
  52. raise
  53. tf_path = os.path.abspath(imagegpt_checkpoint_path)
  54. logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
  55. # Load weights from TF model
  56. init_vars = tf.train.list_variables(tf_path)
  57. names = []
  58. arrays = []
  59. for name, shape in init_vars:
  60. logger.info(f"Loading TF weight {name} with shape {shape}")
  61. array = tf.train.load_variable(tf_path, name)
  62. names.append(name)
  63. arrays.append(array.squeeze())
  64. for name, array in zip(names, arrays):
  65. name = name[6:] # skip "model/"
  66. name = name.split("/")
  67. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  68. # which are not required for using pretrained model
  69. if any(
  70. n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
  71. for n in name
  72. ) or name[-1] in ["_step"]:
  73. logger.info("Skipping {}".format("/".join(name)))
  74. continue
  75. pointer = model
  76. if name[-1] not in ["wtet"]:
  77. pointer = getattr(pointer, "transformer")
  78. for m_name in name:
  79. if re.fullmatch(r"[A-Za-z]+\d+", m_name):
  80. scope_names = re.split(r"(\d+)", m_name)
  81. else:
  82. scope_names = [m_name]
  83. if scope_names[0] == "w" or scope_names[0] == "g":
  84. pointer = getattr(pointer, "weight")
  85. elif scope_names[0] == "b":
  86. pointer = getattr(pointer, "bias")
  87. elif scope_names[0] == "wpe" or scope_names[0] == "wte":
  88. pointer = getattr(pointer, scope_names[0])
  89. pointer = getattr(pointer, "weight")
  90. elif scope_names[0] in ["q_proj", "k_proj", "v_proj"]:
  91. pointer = getattr(pointer, "c_attn")
  92. pointer = getattr(pointer, "weight")
  93. elif len(name) == 3 and name[1] == "attn" and scope_names[0] == "c_proj":
  94. pointer = getattr(pointer, scope_names[0])
  95. pointer = getattr(pointer, "weight")
  96. elif scope_names[0] == "wtet":
  97. pointer = getattr(pointer, "lm_head")
  98. pointer = getattr(pointer, "weight")
  99. elif scope_names[0] == "sos":
  100. pointer = getattr(pointer, "wte")
  101. pointer = getattr(pointer, "weight")
  102. else:
  103. pointer = getattr(pointer, scope_names[0])
  104. if len(scope_names) >= 2:
  105. num = int(scope_names[1])
  106. pointer = pointer[num]
  107. if len(name) > 1 and name[1] == "attn" or name[-1] == "wtet" or name[-1] == "sos" or name[-1] == "wte":
  108. pass # array is used to initialize only part of the pointer so sizes won't match
  109. else:
  110. try:
  111. assert pointer.shape == array.shape
  112. except AssertionError as e:
  113. e.args += (pointer.shape, array.shape)
  114. raise
  115. logger.info(f"Initialize PyTorch weight {name}")
  116. if name[-1] == "q_proj":
  117. pointer.data[:, : config.n_embd] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T
  118. elif name[-1] == "k_proj":
  119. pointer.data[:, config.n_embd : 2 * config.n_embd] = torch.from_numpy(
  120. array.reshape(config.n_embd, config.n_embd)
  121. ).T
  122. elif name[-1] == "v_proj":
  123. pointer.data[:, 2 * config.n_embd :] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T
  124. elif len(name) == 3 and name[1] == "attn" and name[2] == "c_proj":
  125. pointer.data = torch.from_numpy(array.reshape(config.n_embd, config.n_embd))
  126. elif name[-1] == "wtet":
  127. pointer.data = torch.from_numpy(array)
  128. elif name[-1] == "wte":
  129. pointer.data[: config.vocab_size - 1, :] = torch.from_numpy(array)
  130. elif name[-1] == "sos":
  131. pointer.data[-1] = torch.from_numpy(array)
  132. else:
  133. pointer.data = torch.from_numpy(array)
  134. return model
  135. class ImageGPTLayerNorm(nn.Module):
  136. def __init__(self, hidden_size: tuple[int], eps: float = 1e-5):
  137. super().__init__()
  138. self.eps = eps
  139. self.weight = nn.Parameter(torch.Tensor(hidden_size))
  140. def forward(self, tensor: torch.Tensor) -> torch.Tensor:
  141. # input is not mean centered
  142. tensor = tensor / torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + self.eps)
  143. tensor = tensor * self.weight
  144. return tensor
  145. class ImageGPTAttention(nn.Module):
  146. def __init__(self, config, is_cross_attention: Optional[bool] = False, layer_idx: Optional[int] = None):
  147. super().__init__()
  148. max_positions = config.max_position_embeddings
  149. self.register_buffer(
  150. "bias",
  151. torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
  152. 1, 1, max_positions, max_positions
  153. ),
  154. persistent=False,
  155. )
  156. self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
  157. self.embed_dim = config.hidden_size
  158. self.num_heads = config.num_attention_heads
  159. self.head_dim = self.embed_dim // self.num_heads
  160. self.split_size = self.embed_dim
  161. if self.head_dim * self.num_heads != self.embed_dim:
  162. raise ValueError(
  163. f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  164. f" {self.num_heads})."
  165. )
  166. self.scale_attn_weights = config.scale_attn_weights
  167. self.is_cross_attention = is_cross_attention
  168. # Layer-wise attention scaling, reordering, and upcasting
  169. self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
  170. self.layer_idx = layer_idx
  171. self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
  172. if self.is_cross_attention:
  173. self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
  174. self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
  175. else:
  176. self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
  177. self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
  178. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  179. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  180. self.pruned_heads = set()
  181. def prune_heads(self, heads):
  182. if len(heads) == 0:
  183. return
  184. heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
  185. index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
  186. # Prune conv1d layers
  187. self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
  188. self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
  189. # Update hyper params
  190. self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
  191. self.num_heads = self.num_heads - len(heads)
  192. self.pruned_heads = self.pruned_heads.union(heads)
  193. def _attn(self, query, key, value, attention_mask=None, head_mask=None):
  194. attn_weights = torch.matmul(query, key.transpose(-1, -2))
  195. if self.scale_attn_weights:
  196. attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)
  197. # Layer-wise attention scaling
  198. if self.scale_attn_by_inverse_layer_idx:
  199. attn_weights = attn_weights / float(self.layer_idx + 1)
  200. if not self.is_cross_attention:
  201. # if only "normal" attention layer implements causal mask
  202. query_length, key_length = query.size(-2), key.size(-2)
  203. causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
  204. mask_value = torch.finfo(attn_weights.dtype).min
  205. # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
  206. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
  207. mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
  208. attn_weights = torch.where(causal_mask, attn_weights, mask_value)
  209. if attention_mask is not None:
  210. # Apply the attention mask
  211. attn_weights = attn_weights + attention_mask
  212. attn_weights = nn.Softmax(dim=-1)(attn_weights)
  213. # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
  214. attn_weights = attn_weights.type(value.dtype)
  215. attn_weights = self.attn_dropout(attn_weights)
  216. # Mask heads if we want to
  217. if head_mask is not None:
  218. attn_weights = attn_weights * head_mask
  219. attn_output = torch.matmul(attn_weights, value)
  220. return attn_output, attn_weights
  221. def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
  222. # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
  223. bsz, num_heads, q_seq_len, dk = query.size()
  224. _, _, k_seq_len, _ = key.size()
  225. # Preallocate attn_weights for `baddbmm`
  226. attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
  227. # Compute Scale Factor
  228. scale_factor = 1.0
  229. if self.scale_attn_weights:
  230. scale_factor /= float(value.size(-1)) ** 0.5
  231. if self.scale_attn_by_inverse_layer_idx:
  232. scale_factor /= float(self.layer_idx + 1)
  233. # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
  234. with torch.autocast(query.device.type, enabled=False):
  235. q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
  236. attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
  237. attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
  238. if not self.is_cross_attention:
  239. # if only "normal" attention layer implements causal mask
  240. query_length, key_length = query.size(-2), key.size(-2)
  241. causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
  242. mask_value = torch.finfo(attn_weights.dtype).min
  243. # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
  244. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
  245. mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
  246. attn_weights = torch.where(causal_mask, attn_weights, mask_value)
  247. if attention_mask is not None:
  248. # Apply the attention mask
  249. attn_weights = attn_weights + attention_mask
  250. attn_weights = nn.Softmax(dim=-1)(attn_weights)
  251. # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
  252. if attn_weights.dtype != torch.float32:
  253. raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
  254. attn_weights = attn_weights.type(value.dtype)
  255. attn_weights = self.attn_dropout(attn_weights)
  256. # Mask heads if we want to
  257. if head_mask is not None:
  258. attn_weights = attn_weights * head_mask
  259. attn_output = torch.matmul(attn_weights, value)
  260. return attn_output, attn_weights
  261. def _split_heads(self, tensor, num_heads, attn_head_size):
  262. """
  263. Splits hidden_size dim into attn_head_size and num_heads
  264. """
  265. new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
  266. tensor = tensor.view(*new_shape)
  267. return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
  268. def _merge_heads(self, tensor, num_heads, attn_head_size):
  269. """
  270. Merges attn_head_size dim and num_attn_heads dim into hidden_size
  271. """
  272. tensor = tensor.permute(0, 2, 1, 3).contiguous()
  273. new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
  274. return tensor.view(new_shape)
  275. def forward(
  276. self,
  277. hidden_states: torch.Tensor,
  278. layer_past: Optional[Cache] = None,
  279. attention_mask: Optional[torch.Tensor] = None,
  280. head_mask: Optional[torch.Tensor] = None,
  281. encoder_hidden_states: Optional[torch.Tensor] = None,
  282. encoder_attention_mask: Optional[torch.Tensor] = None,
  283. use_cache: Optional[bool] = False,
  284. output_attentions: Optional[bool] = False,
  285. cache_position: Optional[torch.Tensor] = None,
  286. ) -> tuple:
  287. is_cross_attention = encoder_hidden_states is not None
  288. bsz, seq_len, _ = hidden_states.shape
  289. if layer_past is not None:
  290. if isinstance(layer_past, EncoderDecoderCache):
  291. is_updated = layer_past.is_updated.get(self.layer_idx)
  292. if is_cross_attention:
  293. # after the first generated id, we can subsequently re-use all key/value_states from cache
  294. curr_past_key_value = layer_past.cross_attention_cache
  295. else:
  296. curr_past_key_value = layer_past.self_attention_cache
  297. else:
  298. curr_past_key_value = layer_past
  299. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  300. if is_cross_attention:
  301. if not hasattr(self, "q_attn"):
  302. raise ValueError(
  303. "If class is used as cross attention, the weights `q_attn` have to be defined. "
  304. "Please make sure to instantiate class with `ImageGPTAttention(..., is_cross_attention=True)`."
  305. )
  306. if layer_past is not None and is_updated:
  307. # reuse k,v, cross_attentions, and compute only q
  308. query = self.q_attn(hidden_states)
  309. key = curr_past_key_value.layers[self.layer_idx].keys
  310. value = curr_past_key_value.layers[self.layer_idx].values
  311. else:
  312. query = self.q_attn(hidden_states)
  313. key, value = self.c_attn(current_states).split(self.split_size, dim=2)
  314. key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  315. value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  316. else:
  317. query, key, value = self.c_attn(current_states).split(self.split_size, dim=2)
  318. key = key.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  319. value = value.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  320. if layer_past is not None:
  321. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  322. cache_position = cache_position if not is_cross_attention else None
  323. key, value = curr_past_key_value.update(key, value, self.layer_idx, {"cache_position": cache_position})
  324. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  325. if is_cross_attention:
  326. layer_past.is_updated[self.layer_idx] = True
  327. query = query.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  328. if self.reorder_and_upcast_attn:
  329. attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
  330. else:
  331. attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  332. attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  333. attn_output = self.c_proj(attn_output)
  334. attn_output = self.resid_dropout(attn_output)
  335. return attn_output, attn_weights
  336. class ImageGPTMLP(nn.Module):
  337. def __init__(self, intermediate_size, config):
  338. super().__init__()
  339. embed_dim = config.hidden_size
  340. self.c_fc = Conv1D(intermediate_size, embed_dim)
  341. self.c_proj = Conv1D(embed_dim, intermediate_size)
  342. self.act = ACT2FN[config.activation_function]
  343. self.dropout = nn.Dropout(config.resid_pdrop)
  344. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  345. hidden_states = self.c_fc(hidden_states)
  346. hidden_states = self.act(hidden_states)
  347. hidden_states = self.c_proj(hidden_states)
  348. hidden_states = self.dropout(hidden_states)
  349. return hidden_states
  350. class ImageGPTBlock(GradientCheckpointingLayer):
  351. def __init__(self, config, layer_idx=None):
  352. super().__init__()
  353. hidden_size = config.hidden_size
  354. inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
  355. self.ln_1 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  356. self.attn = ImageGPTAttention(config, layer_idx=layer_idx)
  357. self.ln_2 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  358. if config.add_cross_attention:
  359. self.crossattention = ImageGPTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
  360. self.ln_cross_attn = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  361. self.mlp = ImageGPTMLP(inner_dim, config)
  362. def forward(
  363. self,
  364. hidden_states: torch.Tensor,
  365. layer_past: Optional[Cache] = None,
  366. attention_mask: Optional[torch.Tensor] = None,
  367. head_mask: Optional[torch.Tensor] = None,
  368. encoder_hidden_states: Optional[torch.Tensor] = None,
  369. encoder_attention_mask: Optional[torch.Tensor] = None,
  370. use_cache: Optional[bool] = False,
  371. output_attentions: Optional[bool] = False,
  372. cache_position: Optional[torch.Tensor] = None,
  373. ) -> tuple:
  374. residual = hidden_states
  375. hidden_states = self.ln_1(hidden_states)
  376. attn_outputs = self.attn(
  377. hidden_states,
  378. layer_past=layer_past,
  379. attention_mask=attention_mask,
  380. head_mask=head_mask,
  381. use_cache=use_cache,
  382. output_attentions=output_attentions,
  383. cache_position=cache_position,
  384. )
  385. attn_output = attn_outputs[0]
  386. outputs = attn_outputs[1:]
  387. # residual connection
  388. hidden_states = attn_output + residual
  389. if encoder_hidden_states is not None:
  390. # add one self-attention block for cross-attention
  391. if not hasattr(self, "crossattention"):
  392. raise ValueError(
  393. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
  394. "cross-attention layers by setting `config.add_cross_attention=True`"
  395. )
  396. residual = hidden_states
  397. hidden_states = self.ln_cross_attn(hidden_states)
  398. cross_attn_outputs = self.crossattention(
  399. hidden_states,
  400. layer_past=layer_past,
  401. attention_mask=attention_mask,
  402. head_mask=head_mask,
  403. encoder_hidden_states=encoder_hidden_states,
  404. encoder_attention_mask=encoder_attention_mask,
  405. output_attentions=output_attentions,
  406. cache_position=cache_position,
  407. )
  408. attn_output = cross_attn_outputs[0]
  409. # residual connection
  410. hidden_states = residual + attn_output
  411. outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
  412. residual = hidden_states
  413. hidden_states = self.ln_2(hidden_states)
  414. feed_forward_hidden_states = self.mlp(hidden_states)
  415. # residual connection
  416. hidden_states = residual + feed_forward_hidden_states
  417. return (hidden_states,) + outputs
  418. @auto_docstring
  419. class ImageGPTPreTrainedModel(PreTrainedModel):
  420. config: ImageGPTConfig
  421. load_tf_weights = load_tf_weights_in_imagegpt
  422. base_model_prefix = "transformer"
  423. main_input_name = "input_ids"
  424. supports_gradient_checkpointing = True
  425. _no_split_modules = ["ImageGPTBlock"]
  426. def __init__(self, *inputs, **kwargs):
  427. super().__init__(*inputs, **kwargs)
  428. def _init_weights(self, module):
  429. """Initialize the weights."""
  430. if isinstance(module, (nn.Linear, Conv1D)):
  431. # Slightly different from the TF version which uses truncated_normal for initialization
  432. # cf https://github.com/pytorch/pytorch/pull/5617
  433. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  434. if module.bias is not None:
  435. module.bias.data.zero_()
  436. elif isinstance(module, nn.Embedding):
  437. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  438. if module.padding_idx is not None:
  439. module.weight.data[module.padding_idx].zero_()
  440. elif isinstance(module, ImageGPTLayerNorm):
  441. module.weight.data.fill_(1.0)
  442. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  443. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  444. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  445. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  446. #
  447. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  448. for name, p in module.named_parameters():
  449. if "c_proj" in name and "weight" in name:
  450. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  451. p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
  452. @auto_docstring
  453. class ImageGPTModel(ImageGPTPreTrainedModel):
  454. def __init__(self, config: ImageGPTConfig):
  455. super().__init__(config)
  456. self.embed_dim = config.hidden_size
  457. self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
  458. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  459. self.drop = nn.Dropout(config.embd_pdrop)
  460. self.h = nn.ModuleList([ImageGPTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  461. self.ln_f = ImageGPTLayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  462. # Model parallel
  463. self.model_parallel = False
  464. self.device_map = None
  465. self.gradient_checkpointing = False
  466. # Initialize weights and apply final processing
  467. self.post_init()
  468. def get_input_embeddings(self):
  469. return self.wte
  470. def set_input_embeddings(self, new_embeddings):
  471. self.wte = new_embeddings
  472. def _prune_heads(self, heads_to_prune):
  473. """
  474. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
  475. """
  476. for layer, heads in heads_to_prune.items():
  477. self.h[layer].attn.prune_heads(heads)
  478. @auto_docstring
  479. def forward(
  480. self,
  481. input_ids: Optional[torch.Tensor] = None,
  482. past_key_values: Optional[Cache] = None,
  483. attention_mask: Optional[torch.Tensor] = None,
  484. token_type_ids: Optional[torch.Tensor] = None,
  485. position_ids: Optional[torch.Tensor] = None,
  486. head_mask: Optional[torch.Tensor] = None,
  487. inputs_embeds: Optional[torch.Tensor] = None,
  488. encoder_hidden_states: Optional[torch.Tensor] = None,
  489. encoder_attention_mask: Optional[torch.Tensor] = None,
  490. use_cache: Optional[bool] = None,
  491. output_attentions: Optional[bool] = None,
  492. output_hidden_states: Optional[bool] = None,
  493. return_dict: Optional[bool] = None,
  494. cache_position: Optional[torch.Tensor] = None,
  495. **kwargs: Any,
  496. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  497. r"""
  498. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  499. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  500. `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
  501. sequence tokens in the vocabulary.
  502. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  503. `input_ids`.
  504. Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
  505. Examples:
  506. ```python
  507. >>> from transformers import AutoImageProcessor, ImageGPTModel
  508. >>> from PIL import Image
  509. >>> import requests
  510. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  511. >>> image = Image.open(requests.get(url, stream=True).raw)
  512. >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
  513. >>> model = ImageGPTModel.from_pretrained("openai/imagegpt-small")
  514. >>> inputs = image_processor(images=image, return_tensors="pt")
  515. >>> outputs = model(**inputs)
  516. >>> last_hidden_states = outputs.last_hidden_state
  517. ```"""
  518. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  519. output_hidden_states = (
  520. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  521. )
  522. use_cache = use_cache if use_cache is not None else self.config.use_cache
  523. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  524. if input_ids is not None and inputs_embeds is not None:
  525. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  526. elif input_ids is not None:
  527. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  528. input_shape = input_ids.size()
  529. input_ids = input_ids.view(-1, input_shape[-1])
  530. batch_size = input_ids.shape[0]
  531. elif inputs_embeds is not None:
  532. input_shape = inputs_embeds.size()[:-1]
  533. batch_size = inputs_embeds.shape[0]
  534. else:
  535. raise ValueError("You have to specify either input_ids or inputs_embeds")
  536. device = input_ids.device if input_ids is not None else inputs_embeds.device
  537. if self.gradient_checkpointing and self.training:
  538. if use_cache:
  539. logger.warning_once(
  540. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  541. )
  542. use_cache = False
  543. if use_cache and past_key_values is None:
  544. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  545. if use_cache and isinstance(past_key_values, tuple):
  546. logger.warning_once(
  547. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  548. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  549. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  550. )
  551. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  552. past_length = past_key_values.get_seq_length() if past_key_values is not None else past_key_values
  553. if token_type_ids is not None:
  554. token_type_ids = token_type_ids.view(-1, input_shape[-1])
  555. if position_ids is None:
  556. position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
  557. position_ids = position_ids.unsqueeze(0)
  558. # ImageGPTAttention mask.
  559. if attention_mask is not None:
  560. if batch_size <= 0:
  561. raise ValueError("batch_size has to be defined and > 0")
  562. attention_mask = attention_mask.view(batch_size, -1)
  563. # We create a 3D attention mask from a 2D tensor mask.
  564. # Sizes are [batch_size, 1, 1, to_seq_length]
  565. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
  566. # this attention mask is more simple than the triangular masking of causal attention
  567. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  568. attention_mask = attention_mask[:, None, None, :]
  569. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  570. # masked positions, this operation will create a tensor which is 0.0 for
  571. # positions we want to attend and the dtype's smallest value for masked positions.
  572. # Since we are adding it to the raw scores before the softmax, this is
  573. # effectively the same as removing these entirely.
  574. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
  575. attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
  576. # If a 2D or 3D attention mask is provided for the cross-attention
  577. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  578. if self.config.add_cross_attention and encoder_hidden_states is not None:
  579. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  580. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  581. if encoder_attention_mask is None:
  582. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  583. encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  584. else:
  585. encoder_attention_mask = None
  586. # Prepare head mask if needed
  587. # 1.0 in head_mask indicate we keep the head
  588. # attention_probs has shape bsz x n_heads x N x N
  589. # head_mask has shape n_layer x batch x n_heads x N x N
  590. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  591. if inputs_embeds is None:
  592. inputs_embeds = self.wte(input_ids)
  593. position_embeds = self.wpe(position_ids)
  594. hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
  595. if token_type_ids is not None:
  596. token_type_embeds = self.wte(token_type_ids)
  597. hidden_states = hidden_states + token_type_embeds
  598. hidden_states = self.drop(hidden_states)
  599. output_shape = input_shape + (hidden_states.size(-1),)
  600. all_self_attentions = () if output_attentions else None
  601. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  602. all_hidden_states = () if output_hidden_states else None
  603. for i, block in enumerate(self.h):
  604. # Model parallel
  605. if self.model_parallel:
  606. torch.cuda.set_device(hidden_states.device)
  607. # Ensure that attention_mask is always on the same device as hidden_states
  608. if attention_mask is not None:
  609. attention_mask = attention_mask.to(hidden_states.device)
  610. if isinstance(head_mask, torch.Tensor):
  611. head_mask = head_mask.to(hidden_states.device)
  612. if output_hidden_states:
  613. all_hidden_states = all_hidden_states + (hidden_states,)
  614. outputs = block(
  615. hidden_states,
  616. past_key_values,
  617. attention_mask,
  618. head_mask[i],
  619. encoder_hidden_states, # as a positional argument for gradient checkpointing
  620. encoder_attention_mask=encoder_attention_mask,
  621. use_cache=use_cache,
  622. output_attentions=output_attentions,
  623. cache_position=cache_position,
  624. )
  625. hidden_states = outputs[0]
  626. if output_attentions:
  627. all_self_attentions = all_self_attentions + (outputs[1],)
  628. if self.config.add_cross_attention:
  629. all_cross_attentions = all_cross_attentions + (outputs[2],)
  630. # Model Parallel: If it's the last layer for that device, put things on the next device
  631. if self.model_parallel:
  632. for k, v in self.device_map.items():
  633. if i == v[-1] and "cuda:" + str(k) != self.last_device:
  634. hidden_states = hidden_states.to("cuda:" + str(k + 1))
  635. hidden_states = self.ln_f(hidden_states)
  636. hidden_states = hidden_states.view(*output_shape)
  637. # Add last hidden state
  638. if output_hidden_states:
  639. all_hidden_states = all_hidden_states + (hidden_states,)
  640. if not return_dict:
  641. return tuple(
  642. v
  643. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
  644. if v is not None
  645. )
  646. return BaseModelOutputWithPastAndCrossAttentions(
  647. last_hidden_state=hidden_states,
  648. past_key_values=past_key_values,
  649. hidden_states=all_hidden_states,
  650. attentions=all_self_attentions,
  651. cross_attentions=all_cross_attentions,
  652. )
  @auto_docstring(
      custom_intro="""
      The ImageGPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
      embeddings).
      """
  )
  class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin):
      # The LM head weight is tied to the input embeddings (see the class description above).
      _tied_weights_keys = ["lm_head.weight"]

      def __init__(self, config: ImageGPTConfig):
          super().__init__(config)
          self.transformer = ImageGPTModel(config)
          # Only `vocab_size - 1` output classes: the last vocabulary id serves as the start-of-sequence token
          # (see the generation example in `forward`), so it is fed as input but can never be predicted.
          self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)

          # Model parallel
          self.model_parallel = False
          self.device_map = None

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          encoder_attention_mask: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.Tensor] = None,
          **kwargs: Any,
      ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
              `input_ids_length` = `sequence_length` if `past_key_values` is `None`; otherwise it is the number of
              new tokens whose key/value states are not yet stored in `past_key_values`. Indices of input sequence
              tokens in the vocabulary.

              If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
              `input_ids`.

              Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
          labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
              Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
              `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 2]`. All labels set to
              `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size - 2]`.
              The language modeling head has `config.vocab_size - 1` outputs, so the start-of-sequence id
              `config.vocab_size - 1` cannot be used as a target (the first position is dropped by the shift).

          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, ImageGPTForCausalImageModeling
          >>> import torch
          >>> import matplotlib.pyplot as plt
          >>> import numpy as np

          >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
          >>> model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small")
          >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
          >>> model.to(device)  # doctest: +IGNORE_RESULT

          >>> # unconditional generation of 4 images
          >>> batch_size = 4
          >>> context = torch.full((batch_size, 1), model.config.vocab_size - 1)  # initialize with SOS token
          >>> context = context.to(device)
          >>> output = model.generate(
          ...     input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40
          ... )

          >>> clusters = image_processor.clusters
          >>> height = image_processor.size["height"]
          >>> width = image_processor.size["width"]

          >>> samples = output[:, 1:].detach().cpu().numpy()
          >>> samples_img = [
          ...     np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
          ... ]  # convert color cluster tokens back to pixels
          >>> f, axes = plt.subplots(1, batch_size, dpi=300)

          >>> for img, ax in zip(samples_img, axes):  # doctest: +IGNORE_RESULT
          ...     ax.axis("off")
          ...     ax.imshow(img)
          ```"""
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          transformer_outputs = self.transformer(
              input_ids,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
          )
          hidden_states = transformer_outputs[0]

          lm_logits = self.lm_head(hidden_states)

          loss = None
          if labels is not None:
              # Under model parallelism the transformer hands the hidden states from device to device, so the
              # logits may end up on a different device than `labels`; align them before computing the loss.
              labels = labels.to(lm_logits.device)
              # Shift so that tokens < n predict n
              shift_logits = lm_logits[..., :-1, :].contiguous()
              shift_labels = labels[..., 1:].contiguous()
              # Flatten the tokens
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

          if not return_dict:
              output = (lm_logits,) + transformer_outputs[1:]
              return ((loss,) + output) if loss is not None else output

          return CausalLMOutputWithCrossAttentions(
              loss=loss,
              logits=lm_logits,
              past_key_values=transformer_outputs.past_key_values,
              hidden_states=transformer_outputs.hidden_states,
              attentions=transformer_outputs.attentions,
              cross_attentions=transformer_outputs.cross_attentions,
          )
  @auto_docstring(
      custom_intro="""
      The ImageGPT Model transformer with an image classification head on top (linear layer).
      [`ImageGPTForImageClassification`] average-pools the hidden states in order to do the classification.
      """
  )
  class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
      def __init__(self, config: ImageGPTConfig):
          super().__init__(config)
          self.num_labels = config.num_labels
          # Backbone producing one hidden state per sequence position.
          self.transformer = ImageGPTModel(config)
          # Bias-free linear classifier applied to the mean-pooled hidden state.
          self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          **kwargs: Any,
      ) -> Union[tuple, SequenceClassifierOutputWithPast]:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
              Indices of the input tokens (color-cluster ids) in the vocabulary. `input_ids_length` equals
              `sequence_length` when `past_key_values` is `None`; when a cache is supplied, pass only the tokens whose
              past has not been computed yet.

              Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Targets for the sequence classification/regression loss, with indices in `[0, ...,
              config.num_labels - 1]`. With `config.num_labels == 1` a regression loss (Mean-Square loss) is computed;
              with `config.num_labels > 1` a classification loss (Cross-Entropy) is computed.

          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, ImageGPTForImageClassification
          >>> from PIL import Image
          >>> import requests

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)

          >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
          >>> model = ImageGPTForImageClassification.from_pretrained("openai/imagegpt-small")

          >>> inputs = image_processor(images=image, return_tensors="pt")
          >>> outputs = model(**inputs)
          >>> logits = outputs.logits
          ```"""
          use_dict_output = self.config.use_return_dict if return_dict is None else return_dict

          backbone_outputs = self.transformer(
              input_ids,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=use_dict_output,
          )

          # Averaging over dim 1 (the sequence axis) yields one vector per example; the head then maps it from
          # hidden size to `num_labels` scores.
          sequence_output = backbone_outputs[0]
          logits = self.score(sequence_output.mean(dim=1))

          loss = None if labels is None else self.loss_function(labels, logits, self.config)

          if use_dict_output:
              return SequenceClassifierOutputWithPast(
                  loss=loss,
                  logits=logits,
                  past_key_values=backbone_outputs.past_key_values,
                  hidden_states=backbone_outputs.hidden_states,
                  attentions=backbone_outputs.attentions,
              )

          tuple_output = (logits,) + backbone_outputs[1:]
          return tuple_output if loss is None else (loss,) + tuple_output
  # Names exported by `from <this module> import *`.
  # NOTE(review): kept as a literal list with one name per line in case repository tooling parses it
  # textually; confirm before reformatting.
  __all__ = [
      "ImageGPTForCausalImageModeling",
      "ImageGPTForImageClassification",
      "ImageGPTModel",
      "ImageGPTPreTrainedModel",
      "load_tf_weights_in_imagegpt",
  ]