modeling_gptj.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231
  1. # coding=utf-8
  2. # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch GPT-J model."""
  16. import warnings
  17. from typing import Optional, Union
  18. import torch
  19. import torch.fx
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache
  24. from ...generation import GenerationMixin
  25. from ...modeling_attn_mask_utils import AttentionMaskConverter
  26. from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import (
  29. BaseModelOutputWithPast,
  30. CausalLMOutputWithPast,
  31. QuestionAnsweringModelOutput,
  32. SequenceClassifierOutputWithPast,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...utils import (
  36. add_start_docstrings,
  37. auto_docstring,
  38. is_torch_flex_attn_available,
  39. is_torch_fx_proxy,
  40. logging,
  41. )
  42. from ...utils.model_parallel_utils import assert_device_map, get_device_map
  43. from .configuration_gptj import GPTJConfig
  44. if is_torch_flex_attn_available():
  45. from torch.nn.attention.flex_attention import BlockMask
  46. from ...integrations.flex_attention import make_flex_block_causal_mask
  47. if is_flash_attn_available():
  48. from ...modeling_flash_attention_utils import _flash_attention_forward
# Module-level logger from the library's `logging` helper; the attention classes and the
# model below emit one-time warnings through `logger.warning_once`.
logger = logging.get_logger(__name__)
  50. def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
  51. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
  52. sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
  53. return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
  54. @torch.fx.wrap
  55. def get_embed_positions(embed_positions, position_ids):
  56. return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1)
  57. def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
  58. x1 = x[:, :, :, ::2]
  59. x2 = x[:, :, :, 1::2]
  60. x = torch.stack((-x2, x1), dim=-1)
  61. return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
  62. def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
  63. sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
  64. cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
  65. return (tensor * cos) + (rotate_every_two(tensor) * sin)
  66. class GPTJAttention(nn.Module):
  67. def __init__(self, config, layer_idx=None):
  68. super().__init__()
  69. self.config = config
  70. max_positions = config.max_position_embeddings
  71. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  72. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  73. self.is_causal = True
  74. self.layer_idx = layer_idx
  75. if layer_idx is None:
  76. logger.warning_once(
  77. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  78. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  79. "when creating this class."
  80. )
  81. self.embed_dim = config.hidden_size
  82. self.num_attention_heads = config.num_attention_heads
  83. self.head_dim = self.embed_dim // self.num_attention_heads
  84. if self.head_dim * self.num_attention_heads != self.embed_dim:
  85. raise ValueError(
  86. f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
  87. f" `num_attention_heads`: {self.num_attention_heads})."
  88. )
  89. self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
  90. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  91. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  92. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  93. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  94. self.rotary_dim = config.rotary_dim
  95. pos_embd_dim = self.rotary_dim or self.embed_dim
  96. self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
  97. def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
  98. """
  99. Splits hidden dim into attn_head_size and num_attention_heads
  100. """
  101. new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
  102. tensor = tensor.view(new_shape)
  103. if rotary:
  104. return tensor
  105. if len(tensor.shape) == 5:
  106. return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
  107. elif len(tensor.shape) == 4:
  108. return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
  109. else:
  110. raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
  111. def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
  112. """
  113. Merges attn_head_size dim and num_attn_heads dim into hidden dim
  114. """
  115. if len(tensor.shape) == 5:
  116. tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
  117. elif len(tensor.shape) == 4:
  118. tensor = tensor.permute(0, 2, 1, 3).contiguous()
  119. else:
  120. raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
  121. new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
  122. return tensor.view(new_shape)
  123. def _attn(
  124. self,
  125. query,
  126. key,
  127. value,
  128. attention_mask=None,
  129. head_mask=None,
  130. ):
  131. # Keep the attention weights computation in fp32 to avoid overflow issues
  132. query = query.to(torch.float32)
  133. key = key.to(torch.float32)
  134. attn_weights = torch.matmul(query, key.transpose(-1, -2))
  135. attn_weights = attn_weights / self.scale_attn
  136. if attention_mask is not None: # no matter the length, we just slice it
  137. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  138. attn_weights = attn_weights + causal_mask
  139. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  140. attn_weights = attn_weights.to(value.dtype)
  141. attn_weights = self.attn_dropout(attn_weights)
  142. # Mask heads if we want to
  143. if head_mask is not None:
  144. attn_weights = attn_weights * head_mask
  145. attn_output = torch.matmul(attn_weights, value)
  146. return attn_output, attn_weights
  147. def _get_embed_positions(self, position_ids):
  148. embed_positions = self.embed_positions
  149. if embed_positions.device != position_ids.device:
  150. embed_positions = embed_positions.to(position_ids.device)
  151. self.embed_positions = embed_positions
  152. return embed_positions.repeat(position_ids.shape[0], 1, 1)
  153. def forward(
  154. self,
  155. hidden_states: torch.FloatTensor,
  156. layer_past: Optional[Cache] = None,
  157. attention_mask: Optional[torch.FloatTensor] = None,
  158. position_ids: Optional[torch.LongTensor] = None,
  159. head_mask: Optional[torch.FloatTensor] = None,
  160. use_cache: Optional[bool] = False,
  161. output_attentions: Optional[bool] = False,
  162. cache_position: Optional[torch.LongTensor] = None,
  163. ) -> Union[
  164. tuple[torch.Tensor, tuple[torch.Tensor]],
  165. Optional[tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]],
  166. ]:
  167. query = self.q_proj(hidden_states)
  168. key = self.k_proj(hidden_states)
  169. value = self.v_proj(hidden_states)
  170. query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
  171. key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
  172. value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
  173. if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
  174. # The logic to conditionally copy to GPU could not be traced, so we do this
  175. # every time in the torch.fx case
  176. embed_positions = get_embed_positions(self.embed_positions, position_ids)
  177. else:
  178. embed_positions = self._get_embed_positions(position_ids)
  179. repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
  180. sincos = torch.gather(embed_positions, 1, repeated_position_ids).to(key.dtype)
  181. sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
  182. if self.rotary_dim is not None:
  183. k_rot = key[:, :, :, : self.rotary_dim]
  184. k_pass = key[:, :, :, self.rotary_dim :]
  185. q_rot = query[:, :, :, : self.rotary_dim]
  186. q_pass = query[:, :, :, self.rotary_dim :]
  187. k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
  188. q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
  189. key = torch.cat([k_rot, k_pass], dim=-1)
  190. query = torch.cat([q_rot, q_pass], dim=-1)
  191. else:
  192. key = apply_rotary_pos_emb(key, sin, cos)
  193. query = apply_rotary_pos_emb(query, sin, cos)
  194. key = key.permute(0, 2, 1, 3)
  195. query = query.permute(0, 2, 1, 3)
  196. if layer_past is not None:
  197. cache_kwargs = {
  198. "sin": sin,
  199. "cos": cos,
  200. "partial_rotation_size": self.rotary_dim,
  201. "cache_position": cache_position,
  202. }
  203. key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
  204. # compute self-attention: V x Softmax(QK^T)
  205. attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  206. attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
  207. attn_output = self.out_proj(attn_output)
  208. attn_output = self.resid_dropout(attn_output)
  209. return attn_output, attn_weights
  210. class GPTJFlashAttention2(GPTJAttention):
  211. """
  212. GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays
  213. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  214. flash attention and deal with padding tokens in case the input contains any of them.
  215. """
  216. def __init__(self, *args, **kwargs):
  217. super().__init__(*args, **kwargs)
  218. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  219. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  220. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  221. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  222. def forward(
  223. self,
  224. hidden_states: torch.FloatTensor,
  225. layer_past: Optional[Cache] = None,
  226. attention_mask: Optional[torch.FloatTensor] = None,
  227. position_ids: Optional[torch.LongTensor] = None,
  228. head_mask: Optional[torch.FloatTensor] = None,
  229. use_cache: Optional[bool] = False,
  230. output_attentions: Optional[bool] = False,
  231. cache_position: Optional[torch.LongTensor] = None,
  232. ) -> Union[
  233. tuple[torch.Tensor, tuple[torch.Tensor]],
  234. Optional[tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]],
  235. ]:
  236. query = self.q_proj(hidden_states)
  237. key = self.k_proj(hidden_states)
  238. value = self.v_proj(hidden_states)
  239. query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
  240. key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
  241. value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
  242. if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
  243. # The logic to conditionally copy to GPU could not be traced, so we do this
  244. # every time in the torch.fx case
  245. embed_positions = get_embed_positions(self.embed_positions, position_ids)
  246. else:
  247. embed_positions = self._get_embed_positions(position_ids)
  248. repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
  249. sincos = torch.gather(embed_positions, 1, repeated_position_ids).to(key.dtype)
  250. sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
  251. if self.rotary_dim is not None:
  252. k_rot = key[:, :, :, : self.rotary_dim]
  253. k_pass = key[:, :, :, self.rotary_dim :]
  254. q_rot = query[:, :, :, : self.rotary_dim]
  255. q_pass = query[:, :, :, self.rotary_dim :]
  256. k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
  257. q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
  258. key = torch.cat([k_rot, k_pass], dim=-1)
  259. query = torch.cat([q_rot, q_pass], dim=-1)
  260. else:
  261. key = apply_rotary_pos_emb(key, sin, cos)
  262. query = apply_rotary_pos_emb(query, sin, cos)
  263. # tanspose to have the desired shape
  264. # before transpose: batch_size x seq_length x num_attention_heads x head_dim
  265. # after transpose: batch_size x num_attention_heads x seq_length x head_dim
  266. key = key.permute(0, 2, 1, 3)
  267. query = query.permute(0, 2, 1, 3)
  268. # value: batch_size x num_attention_heads x seq_length x head_dim
  269. if layer_past is not None:
  270. cache_kwargs = {
  271. "sin": sin,
  272. "cos": cos,
  273. "partial_rotation_size": self.rotary_dim,
  274. "cache_position": cache_position,
  275. }
  276. key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
  277. # The Flash attention requires the input to have the shape
  278. # batch_size x seq_length x head_dim x hidden_dim
  279. # therefore we need to keep the original shape for query and key, and reshape value
  280. # to have the correct shape.
  281. key = key.permute(0, 2, 1, 3).contiguous()
  282. query = query.permute(0, 2, 1, 3).contiguous()
  283. value = value.permute(0, 2, 1, 3).contiguous()
  284. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  285. # therefore the input hidden states gets silently casted in float32. Hence, we need
  286. # cast them back in the correct dtype just to be sure everything works as expected.
  287. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  288. # in fp32. (LlamaRMSNorm handles it correctly)
  289. input_dtype = query.dtype
  290. device_type = query.device.type if query.device.type != "mps" else "cpu"
  291. if input_dtype == torch.float32:
  292. if torch.is_autocast_enabled():
  293. target_dtype = (
  294. torch.get_autocast_dtype(device_type)
  295. if hasattr(torch, "get_autocast_dtype")
  296. else torch.get_autocast_gpu_dtype()
  297. )
  298. # Handle the case where the model is quantized
  299. elif hasattr(self.config, "_pre_quantization_dtype"):
  300. target_dtype = self.config._pre_quantization_dtype
  301. else:
  302. target_dtype = self.q_proj.weight.dtype
  303. logger.warning_once(
  304. f"The input hidden states seems to be silently casted in float32, this might be related to"
  305. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  306. f" {target_dtype}."
  307. )
  308. query = query.to(target_dtype)
  309. key = key.to(target_dtype)
  310. value = value.to(target_dtype)
  311. attention_dropout = self.config.attn_pdrop if self.training else 0.0 # attn_pdrop in gptj
  312. query_length = query.shape[1]
  313. # Compute attention
  314. attn_weights = _flash_attention_forward(
  315. query,
  316. key,
  317. value,
  318. attention_mask,
  319. query_length,
  320. dropout=attention_dropout,
  321. is_causal=self.is_causal,
  322. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  323. )
  324. # Reshape outputs
  325. attn_output = attn_weights.reshape(
  326. attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2] * attn_weights.shape[3]
  327. )
  328. attn_output = self.out_proj(attn_output)
  329. attn_output = self.resid_dropout(attn_output)
  330. return attn_output, attn_weights
  331. GPTJ_ATTENTION_CLASSES = {
  332. "eager": GPTJAttention,
  333. "flash_attention_2": GPTJFlashAttention2,
  334. }
  335. class GPTJMLP(nn.Module):
  336. def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
  337. super().__init__()
  338. embed_dim = config.n_embd
  339. self.fc_in = nn.Linear(embed_dim, intermediate_size)
  340. self.fc_out = nn.Linear(intermediate_size, embed_dim)
  341. self.act = ACT2FN[config.activation_function]
  342. self.dropout = nn.Dropout(config.resid_pdrop)
  343. def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
  344. hidden_states = self.fc_in(hidden_states)
  345. hidden_states = self.act(hidden_states)
  346. hidden_states = self.fc_out(hidden_states)
  347. hidden_states = self.dropout(hidden_states)
  348. return hidden_states
  349. class GPTJBlock(GradientCheckpointingLayer):
  350. def __init__(self, config, layer_idx=None):
  351. super().__init__()
  352. inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
  353. self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  354. self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
  355. self.mlp = GPTJMLP(inner_dim, config)
  356. def forward(
  357. self,
  358. hidden_states: Optional[torch.FloatTensor],
  359. layer_past: Optional[Cache] = None,
  360. attention_mask: Optional[torch.FloatTensor] = None,
  361. position_ids: Optional[torch.LongTensor] = None,
  362. head_mask: Optional[torch.FloatTensor] = None,
  363. use_cache: Optional[bool] = False,
  364. output_attentions: Optional[bool] = False,
  365. cache_position: Optional[torch.LongTensor] = None,
  366. ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
  367. residual = hidden_states
  368. hidden_states = self.ln_1(hidden_states)
  369. attn_outputs, attn_weights = self.attn(
  370. hidden_states=hidden_states,
  371. layer_past=layer_past,
  372. attention_mask=attention_mask,
  373. position_ids=position_ids,
  374. head_mask=head_mask,
  375. use_cache=use_cache,
  376. output_attentions=output_attentions,
  377. cache_position=cache_position,
  378. )
  379. feed_forward_hidden_states = self.mlp(hidden_states)
  380. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  381. return hidden_states, attn_weights
  382. @auto_docstring
  383. class GPTJPreTrainedModel(PreTrainedModel):
  384. config: GPTJConfig
  385. base_model_prefix = "transformer"
  386. is_parallelizable = True
  387. supports_gradient_checkpointing = True
  388. _no_split_modules = ["GPTJBlock"]
  389. _skip_keys_device_placement = "past_key_values"
  390. _supports_flash_attn = True
  391. _can_compile_fullgraph = True
  392. _supports_param_buffer_assignment = False
  393. def __init__(self, *inputs, **kwargs):
  394. super().__init__(*inputs, **kwargs)
  395. def _init_weights(self, module):
  396. """Initialize the weights."""
  397. if isinstance(module, (nn.Linear,)):
  398. # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
  399. # cf https://github.com/pytorch/pytorch/pull/5617
  400. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  401. if module.bias is not None:
  402. module.bias.data.zero_()
  403. elif isinstance(module, nn.Embedding):
  404. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  405. if module.padding_idx is not None:
  406. module.weight.data[module.padding_idx].zero_()
  407. elif isinstance(module, nn.LayerNorm):
  408. module.bias.data.zero_()
  409. module.weight.data.fill_(1.0)
# Docstring attached through `add_start_docstrings` to the deprecated `GPTJModel.parallelize`.
PARALLELIZE_DOCSTRING = r"""
This is an experimental feature and is a subject to change at a moment's notice. Uses a device map to distribute
attention modules of the model across several devices. If no device map is given, it will evenly distribute blocks
across all devices.
Args:
device_map (`dict[int, list]`, *optional*):
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
automatically mapped to the first device (for esoteric reasons). That means that the first device should
have fewer attention modules mapped to it than other devices. For reference, the GPT-J models have the
following number of attention modules:
- gpt-j-6B: 28
Example:
```python
# Here is an example of a device map on a machine with 4 GPUs using gpt-j-6B, which has a total of 28 attention modules:
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
device_map = {
0: [0, 1, 2, 3, 4, 5, 6],
1: [7, 8, 9, 10, 11, 12, 13],
2: [14, 15, 16, 17, 18, 19, 20],
3: [21, 22, 23, 24, 25, 26, 27],
}
model.parallelize(device_map)
```
"""
# Docstring attached through `add_start_docstrings` to the deprecated `GPTJModel.deparallelize`.
DEPARALLELIZE_DOCSTRING = r"""
Moves the model to CPU from a model parallel state.
Example:
```python
# On a 4 GPU machine with gpt-j-6B:
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
device_map = {
0: [0, 1, 2, 3, 4, 5, 6],
1: [7, 8, 9, 10, 11, 12, 13],
2: [14, 15, 16, 17, 18, 19, 20],
3: [21, 22, 23, 24, 25, 26, 27],
}
model.parallelize(device_map) # Splits the model across several devices
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
```
"""
  450. @auto_docstring
  451. class GPTJModel(GPTJPreTrainedModel):
  452. def __init__(self, config):
  453. super().__init__(config)
  454. self.embed_dim = config.n_embd
  455. self.vocab_size = config.vocab_size
  456. self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
  457. self.drop = nn.Dropout(config.embd_pdrop)
  458. self.h = nn.ModuleList([GPTJBlock(config, layer_idx=i) for i in range(config.n_layer)])
  459. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  460. # Model parallel
  461. self.model_parallel = False
  462. self.device_map = None
  463. self.gradient_checkpointing = False
  464. # Initialize weights and apply final processing
  465. self.post_init()
  466. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  467. def parallelize(self, device_map=None):
  468. warnings.warn(
  469. "`GPTJModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
  470. " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  471. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
  472. " ...}",
  473. FutureWarning,
  474. )
  475. # Check validity of device_map
  476. self.device_map = (
  477. get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
  478. )
  479. assert_device_map(self.device_map, len(self.h))
  480. self.model_parallel = True
  481. self.first_device = "cpu" if "cpu" in self.device_map else "cuda:" + str(min(self.device_map.keys()))
  482. self.last_device = "cuda:" + str(max(self.device_map.keys()))
  483. self.wte = self.wte.to(self.first_device)
  484. # Load onto devices
  485. for k, v in self.device_map.items():
  486. for block in v:
  487. cuda_device = "cuda:" + str(k)
  488. self.h[block] = self.h[block].to(cuda_device)
  489. # ln_f to last
  490. self.ln_f = self.ln_f.to(self.last_device)
  491. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  492. def deparallelize(self):
  493. warnings.warn(
  494. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  495. FutureWarning,
  496. )
  497. self.model_parallel = False
  498. self.device_map = None
  499. self.first_device = "cpu"
  500. self.last_device = "cpu"
  501. self.wte = self.wte.to("cpu")
  502. for index in range(len(self.h)):
  503. self.h[index] = self.h[index].to("cpu")
  504. self.ln_f = self.ln_f.to("cpu")
  505. torch.cuda.empty_cache()
  506. def get_input_embeddings(self):
  507. return self.wte
  508. def set_input_embeddings(self, new_embeddings):
  509. self.wte = new_embeddings
  510. @auto_docstring
  511. def forward(
  512. self,
  513. input_ids: Optional[torch.LongTensor] = None,
  514. past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
  515. attention_mask: Optional[torch.FloatTensor] = None,
  516. token_type_ids: Optional[torch.LongTensor] = None,
  517. position_ids: Optional[torch.LongTensor] = None,
  518. head_mask: Optional[torch.FloatTensor] = None,
  519. inputs_embeds: Optional[torch.FloatTensor] = None,
  520. use_cache: Optional[bool] = None,
  521. output_attentions: Optional[bool] = None,
  522. output_hidden_states: Optional[bool] = None,
  523. return_dict: Optional[bool] = None,
  524. cache_position: Optional[torch.LongTensor] = None,
  525. ) -> Union[tuple, BaseModelOutputWithPast]:
  526. r"""
  527. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  528. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  529. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  530. model's internal embedding lookup matrix.
  531. """
  532. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  533. output_hidden_states = (
  534. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  535. )
  536. use_cache = use_cache if use_cache is not None else self.config.use_cache
  537. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  538. if (input_ids is None) ^ (inputs_embeds is not None):
  539. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  540. if self.gradient_checkpointing and self.training:
  541. if use_cache:
  542. logger.warning_once(
  543. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  544. )
  545. use_cache = False
  546. if inputs_embeds is None:
  547. inputs_embeds = self.wte(input_ids)
  548. if use_cache and past_key_values is None:
  549. past_key_values = DynamicCache(config=self.config)
  550. seq_length = inputs_embeds.shape[1]
  551. if cache_position is None:
  552. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  553. cache_position = torch.arange(
  554. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  555. )
  556. if position_ids is None:
  557. position_ids = cache_position.unsqueeze(0)
  558. causal_mask = self._update_causal_mask(
  559. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  560. )
  561. # Prepare head mask if needed
  562. # 1.0 in head_mask indicate we keep the head
  563. # attention_probs has shape bsz x num_attention_heads x N x N
  564. # head_mask has shape n_layer x batch x num_attention_heads x N x N
  565. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  566. hidden_states = inputs_embeds
  567. if token_type_ids is not None:
  568. token_type_ids = token_type_ids.view(-1, seq_length)
  569. token_type_embeds = self.wte(token_type_ids)
  570. hidden_states = hidden_states + token_type_embeds
  571. hidden_states = self.drop(hidden_states)
  572. output_shape = (-1, seq_length, hidden_states.size(-1))
  573. all_self_attentions = () if output_attentions else None
  574. all_hidden_states = () if output_hidden_states else None
  575. for i, block in enumerate(self.h):
  576. # Model parallel
  577. if self.model_parallel:
  578. torch.cuda.set_device(hidden_states.device)
  579. # Ensure layer_past is on same device as hidden_states (might not be correct)
  580. if past_key_values is not None:
  581. for layer in past_key_values.layers:
  582. layer.keys = layer.keys.to(hidden_states.device)
  583. layer.values = layer.values.to(hidden_states.device)
  584. # Ensure that attention_mask is always on the same device as hidden_states
  585. if causal_mask is not None:
  586. causal_mask = causal_mask.to(hidden_states.device)
  587. if isinstance(head_mask, torch.Tensor):
  588. head_mask = head_mask.to(hidden_states.device)
  589. if output_hidden_states:
  590. all_hidden_states = all_hidden_states + (hidden_states,)
  591. outputs = block(
  592. hidden_states,
  593. layer_past=past_key_values,
  594. attention_mask=causal_mask,
  595. position_ids=position_ids,
  596. head_mask=head_mask[i],
  597. use_cache=use_cache,
  598. output_attentions=output_attentions,
  599. cache_position=cache_position,
  600. )
  601. hidden_states = outputs[0]
  602. if output_attentions:
  603. all_self_attentions = all_self_attentions + (outputs[1],)
  604. # Model Parallel: If it's the last layer for that device, put things on the next device
  605. if self.model_parallel:
  606. for k, v in self.device_map.items():
  607. if i == v[-1] and "cuda:" + str(k) != self.last_device:
  608. hidden_states = hidden_states.to("cuda:" + str(k + 1))
  609. hidden_states = self.ln_f(hidden_states)
  610. hidden_states = hidden_states.view(output_shape)
  611. # Add last hidden state
  612. if output_hidden_states:
  613. all_hidden_states = all_hidden_states + (hidden_states,)
  614. if not return_dict:
  615. return tuple(
  616. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  617. )
  618. return BaseModelOutputWithPast(
  619. last_hidden_state=hidden_states,
  620. past_key_values=past_key_values,
  621. hidden_states=all_hidden_states,
  622. attentions=all_self_attentions,
  623. )
  624. def _update_causal_mask(
  625. self,
  626. attention_mask: Union[torch.Tensor, "BlockMask"],
  627. input_tensor: torch.Tensor,
  628. cache_position: torch.Tensor,
  629. past_key_values: Cache,
  630. output_attentions: bool = False,
  631. ):
  632. if self.config._attn_implementation == "flash_attention_2":
  633. if attention_mask is not None and (attention_mask == 0.0).any():
  634. return attention_mask
  635. return None
  636. if self.config._attn_implementation == "flex_attention":
  637. if isinstance(attention_mask, torch.Tensor):
  638. attention_mask = make_flex_block_causal_mask(attention_mask)
  639. return attention_mask
  640. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  641. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  642. # to infer the attention mask.
  643. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  644. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  645. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  646. if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
  647. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  648. attention_mask,
  649. inputs_embeds=input_tensor,
  650. past_key_values_length=past_seen_tokens,
  651. is_training=self.training,
  652. ):
  653. return None
  654. dtype = input_tensor.dtype
  655. sequence_length = input_tensor.shape[1]
  656. if using_compilable_cache:
  657. target_length = past_key_values.get_max_cache_shape()
  658. else:
  659. target_length = (
  660. attention_mask.shape[-1]
  661. if isinstance(attention_mask, torch.Tensor)
  662. else past_seen_tokens + sequence_length + 1
  663. )
  664. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  665. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  666. attention_mask,
  667. sequence_length=sequence_length,
  668. target_length=target_length,
  669. dtype=dtype,
  670. cache_position=cache_position,
  671. batch_size=input_tensor.shape[0],
  672. )
  673. if (
  674. self.config._attn_implementation == "sdpa"
  675. and attention_mask is not None
  676. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  677. and not output_attentions
  678. ):
  679. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  680. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  681. # Details: https://github.com/pytorch/pytorch/issues/110213
  682. min_dtype = torch.finfo(dtype).min
  683. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  684. return causal_mask
  685. @staticmethod
  686. def _prepare_4d_causal_attention_mask_with_cache_position(
  687. attention_mask: torch.Tensor,
  688. sequence_length: int,
  689. target_length: int,
  690. dtype: torch.dtype,
  691. cache_position: torch.Tensor,
  692. batch_size: int,
  693. **kwargs,
  694. ):
  695. """
  696. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  697. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  698. Args:
  699. attention_mask (`torch.Tensor`):
  700. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  701. `(batch_size, 1, query_length, key_value_length)`.
  702. sequence_length (`int`):
  703. The sequence length being processed.
  704. target_length (`int`):
  705. The target length: when generating with static cache, the mask should be as long as the static cache,
  706. to account for the 0 padding, the part of the cache that is not filled yet.
  707. dtype (`torch.dtype`):
  708. The dtype to use for the 4D attention mask.
  709. cache_position (`torch.Tensor`):
  710. Indices depicting the position of the input sequence tokens in the sequence.
  711. batch_size (`torch.Tensor`):
  712. Batch size.
  713. """
  714. if attention_mask is not None and attention_mask.dim() == 4:
  715. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  716. causal_mask = attention_mask
  717. else:
  718. min_dtype = torch.finfo(dtype).min
  719. causal_mask = torch.full(
  720. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  721. )
  722. if sequence_length != 1:
  723. causal_mask = torch.triu(causal_mask, diagonal=1)
  724. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  725. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  726. if attention_mask is not None:
  727. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  728. mask_length = attention_mask.shape[-1]
  729. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  730. causal_mask.device
  731. )
  732. padding_mask = padding_mask == 0
  733. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  734. padding_mask, min_dtype
  735. )
  736. return causal_mask
  737. @auto_docstring(
  738. custom_intro="""
  739. The GPT-J Model transformer with a language modeling head on top.
  740. """
  741. )
  742. class GPTJForCausalLM(GPTJPreTrainedModel, GenerationMixin):
  743. _tied_weights_keys = ["lm_head.weight"]
  744. def __init__(self, config):
  745. super().__init__(config)
  746. self.transformer = GPTJModel(config)
  747. self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
  748. # Model parallel
  749. self.model_parallel = False
  750. self.device_map = None
  751. # Initialize weights and apply final processing
  752. self.post_init()
  753. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  754. def parallelize(self, device_map=None):
  755. warnings.warn(
  756. "`GPTJForCausalLM.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
  757. " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  758. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
  759. " 0, 'transformer.h.1': 1, ...}",
  760. FutureWarning,
  761. )
  762. self.device_map = (
  763. get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
  764. if device_map is None
  765. else device_map
  766. )
  767. assert_device_map(self.device_map, len(self.transformer.h))
  768. self.transformer.parallelize(self.device_map)
  769. self.lm_head = self.lm_head.to(self.transformer.first_device)
  770. self.model_parallel = True
  771. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  772. def deparallelize(self):
  773. warnings.warn(
  774. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  775. FutureWarning,
  776. )
  777. self.transformer.deparallelize()
  778. self.transformer = self.transformer.to("cpu")
  779. self.lm_head = self.lm_head.to("cpu")
  780. self.model_parallel = False
  781. torch.cuda.empty_cache()
  782. @auto_docstring
  783. def forward(
  784. self,
  785. input_ids: Optional[torch.LongTensor] = None,
  786. past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
  787. attention_mask: Optional[torch.FloatTensor] = None,
  788. token_type_ids: Optional[torch.LongTensor] = None,
  789. position_ids: Optional[torch.LongTensor] = None,
  790. head_mask: Optional[torch.FloatTensor] = None,
  791. inputs_embeds: Optional[torch.FloatTensor] = None,
  792. labels: Optional[torch.LongTensor] = None,
  793. use_cache: Optional[bool] = None,
  794. output_attentions: Optional[bool] = None,
  795. output_hidden_states: Optional[bool] = None,
  796. return_dict: Optional[bool] = None,
  797. cache_position: Optional[torch.LongTensor] = None,
  798. **kwargs,
  799. ) -> Union[tuple, CausalLMOutputWithPast]:
  800. r"""
  801. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  802. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  803. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  804. model's internal embedding lookup matrix.
  805. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  806. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  807. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  808. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  809. """
  810. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  811. transformer_outputs = self.transformer(
  812. input_ids,
  813. past_key_values=past_key_values,
  814. attention_mask=attention_mask,
  815. token_type_ids=token_type_ids,
  816. position_ids=position_ids,
  817. head_mask=head_mask,
  818. inputs_embeds=inputs_embeds,
  819. use_cache=use_cache,
  820. output_attentions=output_attentions,
  821. output_hidden_states=output_hidden_states,
  822. return_dict=return_dict,
  823. cache_position=cache_position,
  824. )
  825. hidden_states = transformer_outputs[0]
  826. # Set device for model parallelism
  827. if self.model_parallel:
  828. torch.cuda.set_device(self.transformer.first_device)
  829. hidden_states = hidden_states.to(self.lm_head.weight.device)
  830. # make sure sampling in fp16 works correctly and
  831. # compute loss in fp32 to match with mesh-tf version
  832. # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
  833. lm_logits = self.lm_head(hidden_states).to(torch.float32)
  834. loss = None
  835. if labels is not None:
  836. # move labels to correct device to enable model parallelism
  837. labels = labels.to(lm_logits.device)
  838. # Flatten the tokens
  839. loss = self.loss_function(
  840. lm_logits,
  841. labels,
  842. vocab_size=self.config.vocab_size,
  843. **kwargs,
  844. )
  845. loss = loss.to(hidden_states.dtype)
  846. if not return_dict:
  847. output = (lm_logits,) + transformer_outputs[1:]
  848. return ((loss,) + output) if loss is not None else output
  849. return CausalLMOutputWithPast(
  850. loss=loss,
  851. logits=lm_logits,
  852. past_key_values=transformer_outputs.past_key_values,
  853. hidden_states=transformer_outputs.hidden_states,
  854. attentions=transformer_outputs.attentions,
  855. )
  856. @auto_docstring(
  857. custom_intro="""
  858. The GPT-J Model transformer with a sequence classification head on top (linear layer).
  859. [`GPTJForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  860. (e.g. GPT, GPT-2, GPT-Neo) do.
  861. Since it does classification on the last token, it requires to know the position of the last token. If a
  862. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  863. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  864. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  865. each row of the batch).
  866. """
  867. )
  868. class GPTJForSequenceClassification(GPTJPreTrainedModel):
  869. def __init__(self, config):
  870. super().__init__(config)
  871. self.num_labels = config.num_labels
  872. self.transformer = GPTJModel(config)
  873. self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
  874. # Model parallel
  875. self.model_parallel = False
  876. self.device_map = None
  877. # Initialize weights and apply final processing
  878. self.post_init()
  879. @auto_docstring
  880. def forward(
  881. self,
  882. input_ids: Optional[torch.LongTensor] = None,
  883. past_key_values: Optional[Cache] = None,
  884. attention_mask: Optional[torch.FloatTensor] = None,
  885. token_type_ids: Optional[torch.LongTensor] = None,
  886. position_ids: Optional[torch.LongTensor] = None,
  887. head_mask: Optional[torch.FloatTensor] = None,
  888. inputs_embeds: Optional[torch.FloatTensor] = None,
  889. labels: Optional[torch.LongTensor] = None,
  890. use_cache: Optional[bool] = None,
  891. output_attentions: Optional[bool] = None,
  892. output_hidden_states: Optional[bool] = None,
  893. return_dict: Optional[bool] = None,
  894. ) -> Union[tuple, SequenceClassifierOutputWithPast]:
  895. r"""
  896. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  897. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  898. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  899. model's internal embedding lookup matrix.
  900. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  901. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  902. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  903. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  904. """
  905. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  906. transformer_outputs = self.transformer(
  907. input_ids,
  908. past_key_values=past_key_values,
  909. attention_mask=attention_mask,
  910. token_type_ids=token_type_ids,
  911. position_ids=position_ids,
  912. head_mask=head_mask,
  913. inputs_embeds=inputs_embeds,
  914. use_cache=use_cache,
  915. output_attentions=output_attentions,
  916. output_hidden_states=output_hidden_states,
  917. return_dict=return_dict,
  918. )
  919. hidden_states = transformer_outputs[0]
  920. logits = self.score(hidden_states)
  921. if input_ids is not None:
  922. batch_size = input_ids.shape[0]
  923. else:
  924. batch_size = inputs_embeds.shape[0]
  925. if self.config.pad_token_id is None and batch_size != 1:
  926. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  927. if self.config.pad_token_id is None:
  928. last_non_pad_token = -1
  929. elif input_ids is not None:
  930. # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
  931. non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
  932. token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
  933. last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
  934. else:
  935. last_non_pad_token = -1
  936. logger.warning_once(
  937. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  938. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  939. )
  940. pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
  941. loss = None
  942. if labels is not None:
  943. labels = labels.to(pooled_logits.device)
  944. if self.config.problem_type is None:
  945. if self.num_labels == 1:
  946. self.config.problem_type = "regression"
  947. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  948. self.config.problem_type = "single_label_classification"
  949. else:
  950. self.config.problem_type = "multi_label_classification"
  951. if self.config.problem_type == "regression":
  952. loss_fct = MSELoss()
  953. if self.num_labels == 1:
  954. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  955. else:
  956. loss = loss_fct(pooled_logits, labels)
  957. elif self.config.problem_type == "single_label_classification":
  958. loss_fct = CrossEntropyLoss()
  959. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  960. elif self.config.problem_type == "multi_label_classification":
  961. loss_fct = BCEWithLogitsLoss()
  962. loss = loss_fct(pooled_logits, labels)
  963. if not return_dict:
  964. output = (pooled_logits,) + transformer_outputs[1:]
  965. return ((loss,) + output) if loss is not None else output
  966. return SequenceClassifierOutputWithPast(
  967. loss=loss,
  968. logits=pooled_logits,
  969. past_key_values=transformer_outputs.past_key_values,
  970. hidden_states=transformer_outputs.hidden_states,
  971. attentions=transformer_outputs.attentions,
  972. )
  973. @auto_docstring
  974. class GPTJForQuestionAnswering(GPTJPreTrainedModel):
  975. def __init__(self, config):
  976. super().__init__(config)
  977. self.num_labels = config.num_labels
  978. self.transformer = GPTJModel(config)
  979. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  980. # Model parallel
  981. self.model_parallel = False
  982. self.device_map = None
  983. # Initialize weights and apply final processing
  984. self.post_init()
  985. @auto_docstring
  986. def forward(
  987. self,
  988. input_ids: Optional[torch.LongTensor] = None,
  989. attention_mask: Optional[torch.FloatTensor] = None,
  990. token_type_ids: Optional[torch.LongTensor] = None,
  991. position_ids: Optional[torch.LongTensor] = None,
  992. head_mask: Optional[torch.FloatTensor] = None,
  993. inputs_embeds: Optional[torch.FloatTensor] = None,
  994. start_positions: Optional[torch.LongTensor] = None,
  995. end_positions: Optional[torch.LongTensor] = None,
  996. output_attentions: Optional[bool] = None,
  997. output_hidden_states: Optional[bool] = None,
  998. return_dict: Optional[bool] = None,
  999. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  1000. r"""
  1001. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  1002. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1003. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  1004. model's internal embedding lookup matrix.
  1005. """
  1006. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1007. outputs = self.transformer(
  1008. input_ids,
  1009. attention_mask=attention_mask,
  1010. token_type_ids=token_type_ids,
  1011. position_ids=position_ids,
  1012. head_mask=head_mask,
  1013. inputs_embeds=inputs_embeds,
  1014. output_attentions=output_attentions,
  1015. output_hidden_states=output_hidden_states,
  1016. return_dict=return_dict,
  1017. )
  1018. sequence_output = outputs[0]
  1019. logits = self.qa_outputs(sequence_output)
  1020. start_logits, end_logits = logits.split(1, dim=-1)
  1021. start_logits = start_logits.squeeze(-1).contiguous()
  1022. end_logits = end_logits.squeeze(-1).contiguous()
  1023. total_loss = None
  1024. if start_positions is not None and end_positions is not None:
  1025. # If we are on multi-GPU, split add a dimension
  1026. if len(start_positions.size()) > 1:
  1027. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  1028. if len(end_positions.size()) > 1:
  1029. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  1030. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1031. ignored_index = start_logits.size(1)
  1032. start_positions = start_positions.clamp(0, ignored_index)
  1033. end_positions = end_positions.clamp(0, ignored_index)
  1034. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1035. start_loss = loss_fct(start_logits, start_positions)
  1036. end_loss = loss_fct(end_logits, end_positions)
  1037. total_loss = (start_loss + end_loss) / 2
  1038. if not return_dict:
  1039. output = (start_logits, end_logits) + outputs[2:]
  1040. return ((total_loss,) + output) if total_loss is not None else output
  1041. return QuestionAnsweringModelOutput(
  1042. loss=total_loss,
  1043. start_logits=start_logits,
  1044. end_logits=end_logits,
  1045. hidden_states=outputs.hidden_states,
  1046. attentions=outputs.attentions,
  1047. )
  1048. __all__ = [
  1049. "GPTJForCausalLM",
  1050. "GPTJForQuestionAnswering",
  1051. "GPTJForSequenceClassification",
  1052. "GPTJModel",
  1053. "GPTJPreTrainedModel",
  1054. ]