modeling_codegen.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668
  1. # coding=utf-8
  2. # Copyright 2022 Salesforce authors, The EleutherAI, and HuggingFace Teams. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch CodeGen model."""
  16. from typing import Optional, Union
  17. import torch
  18. from torch import nn
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...generation import GenerationMixin
  22. from ...modeling_attn_mask_utils import AttentionMaskConverter
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import (
  27. auto_docstring,
  28. is_torch_flex_attn_available,
  29. logging,
  30. )
  31. from .configuration_codegen import CodeGenConfig
  32. if is_torch_flex_attn_available():
  33. from torch.nn.attention.flex_attention import BlockMask
  34. from ...integrations.flex_attention import make_flex_block_causal_mask
  35. logger = logging.get_logger(__name__)
  36. # Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
  37. def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
  38. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
  39. sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
  40. return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
  41. # Copied from transformers.models.gptj.modeling_gptj.rotate_every_two
  42. def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
  43. x1 = x[:, :, :, ::2]
  44. x2 = x[:, :, :, 1::2]
  45. x = torch.stack((-x2, x1), dim=-1)
  46. return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
  47. # Copied from transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
  48. def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
  49. sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
  50. cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
  51. return (tensor * cos) + (rotate_every_two(tensor) * sin)
  52. class CodeGenAttention(nn.Module):
  53. def __init__(self, config, layer_idx=None):
  54. super().__init__()
  55. max_positions = config.max_position_embeddings
  56. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  57. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  58. self.layer_idx = layer_idx
  59. if layer_idx is None:
  60. logger.warning_once(
  61. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  62. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  63. "when creating this class."
  64. )
  65. self.embed_dim = config.hidden_size
  66. self.num_attention_heads = config.num_attention_heads
  67. self.head_dim = self.embed_dim // self.num_attention_heads
  68. if self.head_dim * self.num_attention_heads != self.embed_dim:
  69. raise ValueError(
  70. f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
  71. f" `num_attention_heads`: {self.num_attention_heads})."
  72. )
  73. self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
  74. self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
  75. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
  76. self.rotary_dim = config.rotary_dim
  77. pos_embd_dim = self.rotary_dim or self.embed_dim
  78. self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
  79. def _split_heads(self, x, n_head, dim_head, mp_num):
  80. reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
  81. reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
  82. return reshaped
  83. def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
  84. """
  85. Merges attn_head_size dim and num_attn_heads dim into n_ctx
  86. """
  87. if len(tensor.shape) == 5:
  88. tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
  89. elif len(tensor.shape) == 4:
  90. tensor = tensor.permute(0, 2, 1, 3).contiguous()
  91. else:
  92. raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
  93. new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
  94. return tensor.view(new_shape)
  95. def _attn(
  96. self,
  97. query,
  98. key,
  99. value,
  100. attention_mask=None,
  101. head_mask=None,
  102. ):
  103. # Keep the attention weights computation in fp32 to avoid overflow issues
  104. query = query.to(torch.float32)
  105. key = key.to(torch.float32)
  106. attn_weights = torch.matmul(query, key.transpose(-1, -2))
  107. if attention_mask is not None:
  108. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  109. attn_weights += causal_mask
  110. attn_weights = attn_weights / self.scale_attn
  111. attn_weights = nn.Softmax(dim=-1)(attn_weights)
  112. attn_weights = attn_weights.to(value.dtype)
  113. attn_weights = self.attn_dropout(attn_weights)
  114. # Mask heads if we want to
  115. if head_mask is not None:
  116. attn_weights = attn_weights * head_mask
  117. attn_output = torch.matmul(attn_weights, value)
  118. return attn_output, attn_weights
  119. def forward(
  120. self,
  121. hidden_states: Optional[torch.FloatTensor],
  122. layer_past: Optional[Cache] = None,
  123. attention_mask: Optional[torch.FloatTensor] = None,
  124. position_ids: Optional[torch.LongTensor] = None,
  125. head_mask: Optional[torch.FloatTensor] = None,
  126. use_cache: Optional[bool] = False,
  127. output_attentions: Optional[bool] = False,
  128. cache_position: Optional[torch.LongTensor] = None,
  129. ) -> Union[
  130. tuple[torch.Tensor, tuple[torch.Tensor]],
  131. Optional[tuple[torch.Tensor, tuple[torch.Tensor], tuple[torch.Tensor, ...]]],
  132. ]:
  133. qkv = self.qkv_proj(hidden_states)
  134. # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
  135. mp_num = 4
  136. qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
  137. local_dim = self.head_dim * self.num_attention_heads // mp_num
  138. query, value, key = torch.split(qkv_split, local_dim, dim=-1)
  139. query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
  140. key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
  141. value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
  142. value = value.permute(0, 2, 1, 3)
  143. embed_positions = self.embed_positions
  144. if embed_positions.device != position_ids.device:
  145. embed_positions = embed_positions.to(position_ids.device)
  146. self.embed_positions = embed_positions
  147. sincos = embed_positions[position_ids]
  148. sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
  149. if self.rotary_dim is not None:
  150. k_rot = key[:, :, :, : self.rotary_dim]
  151. k_pass = key[:, :, :, self.rotary_dim :]
  152. q_rot = query[:, :, :, : self.rotary_dim]
  153. q_pass = query[:, :, :, self.rotary_dim :]
  154. k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
  155. q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
  156. key = torch.cat([k_rot, k_pass], dim=-1)
  157. query = torch.cat([q_rot, q_pass], dim=-1)
  158. else:
  159. key = apply_rotary_pos_emb(key, sin, cos)
  160. query = apply_rotary_pos_emb(query, sin, cos)
  161. key = key.permute(0, 2, 1, 3)
  162. query = query.permute(0, 2, 1, 3)
  163. # Note that this cast is quite ugly, but is not implemented before ROPE as k_rot in the original codebase is always in fp32.
  164. # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
  165. if layer_past is not None:
  166. cache_kwargs = {
  167. "sin": sin,
  168. "cos": cos,
  169. "partial_rotation_size": self.rotary_dim,
  170. "cache_position": cache_position,
  171. }
  172. key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)
  173. # compute self-attention: V x Softmax(QK^T)
  174. attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  175. attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
  176. attn_output = self.out_proj(attn_output)
  177. attn_output = self.resid_dropout(attn_output)
  178. return attn_output, attn_weights
  179. # Copied from transformers.models.gptj.modeling_gptj.GPTJMLP with GPTJ->CodeGen
  180. class CodeGenMLP(nn.Module):
  181. def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
  182. super().__init__()
  183. embed_dim = config.n_embd
  184. self.fc_in = nn.Linear(embed_dim, intermediate_size)
  185. self.fc_out = nn.Linear(intermediate_size, embed_dim)
  186. self.act = ACT2FN[config.activation_function]
  187. self.dropout = nn.Dropout(config.resid_pdrop)
  188. def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
  189. hidden_states = self.fc_in(hidden_states)
  190. hidden_states = self.act(hidden_states)
  191. hidden_states = self.fc_out(hidden_states)
  192. hidden_states = self.dropout(hidden_states)
  193. return hidden_states
  194. # Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen
  195. class CodeGenBlock(GradientCheckpointingLayer):
  196. # Ignore copy
  197. def __init__(self, config, layer_idx=None):
  198. super().__init__()
  199. inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
  200. self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
  201. self.attn = CodeGenAttention(config, layer_idx)
  202. self.mlp = CodeGenMLP(inner_dim, config)
  203. def forward(
  204. self,
  205. hidden_states: Optional[torch.FloatTensor],
  206. layer_past: Optional[Cache] = None,
  207. attention_mask: Optional[torch.FloatTensor] = None,
  208. position_ids: Optional[torch.LongTensor] = None,
  209. head_mask: Optional[torch.FloatTensor] = None,
  210. use_cache: Optional[bool] = False,
  211. output_attentions: Optional[bool] = False,
  212. cache_position: Optional[torch.LongTensor] = None,
  213. ) -> Union[tuple[torch.Tensor], Optional[tuple[torch.Tensor, tuple[torch.FloatTensor, ...]]]]:
  214. residual = hidden_states
  215. hidden_states = self.ln_1(hidden_states)
  216. attn_outputs, attn_weights = self.attn(
  217. hidden_states=hidden_states,
  218. layer_past=layer_past,
  219. attention_mask=attention_mask,
  220. position_ids=position_ids,
  221. head_mask=head_mask,
  222. use_cache=use_cache,
  223. output_attentions=output_attentions,
  224. cache_position=cache_position,
  225. )
  226. feed_forward_hidden_states = self.mlp(hidden_states)
  227. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  228. return hidden_states, attn_weights
  229. @auto_docstring
  230. class CodeGenPreTrainedModel(PreTrainedModel):
  231. config: CodeGenConfig
  232. base_model_prefix = "transformer"
  233. supports_gradient_checkpointing = True
  234. _no_split_modules = ["CodeGenBlock"]
  235. _skip_keys_device_placement = "past_key_values"
  236. _can_compile_fullgraph = True
  237. def __init__(self, *inputs, **kwargs):
  238. super().__init__(*inputs, **kwargs)
  239. def _init_weights(self, module):
  240. """Initialize the weights."""
  241. if isinstance(module, (nn.Linear,)):
  242. # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
  243. # cf https://github.com/pytorch/pytorch/pull/5617
  244. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  245. if module.bias is not None:
  246. module.bias.data.zero_()
  247. elif isinstance(module, nn.Embedding):
  248. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  249. if module.padding_idx is not None:
  250. module.weight.data[module.padding_idx].zero_()
  251. elif isinstance(module, nn.LayerNorm):
  252. module.bias.data.zero_()
  253. module.weight.data.fill_(1.0)
  254. @auto_docstring
  255. class CodeGenModel(CodeGenPreTrainedModel):
  256. def __init__(self, config):
  257. super().__init__(config)
  258. self.embed_dim = config.n_embd
  259. self.vocab_size = config.vocab_size
  260. self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
  261. self.drop = nn.Dropout(config.embd_pdrop)
  262. self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
  263. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  264. self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
  265. self.gradient_checkpointing = False
  266. # Initialize weights and apply final processing
  267. self.post_init()
  268. def get_input_embeddings(self):
  269. return self.wte
  270. def set_input_embeddings(self, new_embeddings):
  271. self.wte = new_embeddings
  272. @auto_docstring
  273. def forward(
  274. self,
  275. input_ids: Optional[torch.LongTensor] = None,
  276. past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
  277. attention_mask: Optional[torch.FloatTensor] = None,
  278. token_type_ids: Optional[torch.LongTensor] = None,
  279. position_ids: Optional[torch.LongTensor] = None,
  280. head_mask: Optional[torch.FloatTensor] = None,
  281. inputs_embeds: Optional[torch.FloatTensor] = None,
  282. use_cache: Optional[bool] = None,
  283. output_attentions: Optional[bool] = None,
  284. output_hidden_states: Optional[bool] = None,
  285. return_dict: Optional[bool] = None,
  286. cache_position: Optional[torch.LongTensor] = None,
  287. **kwargs, # NOOP kwargs, for now
  288. ) -> Union[tuple, BaseModelOutputWithPast]:
  289. r"""
  290. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  291. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  292. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  293. model's internal embedding lookup matrix.
  294. """
  295. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  296. output_hidden_states = (
  297. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  298. )
  299. use_cache = use_cache if use_cache is not None else self.config.use_cache
  300. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  301. if (input_ids is None) ^ (inputs_embeds is not None):
  302. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  303. if self.gradient_checkpointing and self.training:
  304. if use_cache:
  305. logger.warning_once(
  306. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  307. )
  308. use_cache = False
  309. if inputs_embeds is None:
  310. inputs_embeds = self.wte(input_ids)
  311. if use_cache and past_key_values is None:
  312. past_key_values = DynamicCache(config=self.config)
  313. seq_length = inputs_embeds.shape[1]
  314. if cache_position is None:
  315. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  316. cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
  317. if position_ids is None:
  318. position_ids = cache_position.unsqueeze(0)
  319. causal_mask = self._update_causal_mask(
  320. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  321. )
  322. # Prepare head mask if needed
  323. # 1.0 in head_mask indicate we keep the head
  324. # attention_probs has shape bsz x num_attention_heads x N x N
  325. # head_mask has shape n_layer x batch x num_attention_heads x N x N
  326. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  327. hidden_states = inputs_embeds
  328. if token_type_ids is not None:
  329. token_type_ids = token_type_ids.view(-1, seq_length)
  330. token_type_embeds = self.wte(token_type_ids)
  331. hidden_states = hidden_states + token_type_embeds
  332. hidden_states = self.drop(hidden_states)
  333. output_shape = (-1, seq_length, hidden_states.size(-1))
  334. all_self_attentions = () if output_attentions else None
  335. all_hidden_states = () if output_hidden_states else None
  336. for i, block in enumerate(self.h):
  337. if output_hidden_states:
  338. all_hidden_states = all_hidden_states + (hidden_states,)
  339. outputs = block(
  340. hidden_states,
  341. layer_past=past_key_values,
  342. attention_mask=causal_mask,
  343. position_ids=position_ids,
  344. head_mask=head_mask[i],
  345. use_cache=use_cache,
  346. output_attentions=output_attentions,
  347. cache_position=cache_position,
  348. )
  349. hidden_states = outputs[0]
  350. if output_attentions:
  351. all_self_attentions = all_self_attentions + (outputs[1],)
  352. hidden_states = self.ln_f(hidden_states)
  353. hidden_states = hidden_states.view(output_shape)
  354. # Add last hidden state
  355. if output_hidden_states:
  356. all_hidden_states = all_hidden_states + (hidden_states,)
  357. if not return_dict:
  358. return tuple(
  359. v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
  360. )
  361. return BaseModelOutputWithPast(
  362. last_hidden_state=hidden_states,
  363. past_key_values=past_key_values,
  364. hidden_states=all_hidden_states,
  365. attentions=all_self_attentions,
  366. )
  367. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
  368. def _update_causal_mask(
  369. self,
  370. attention_mask: Union[torch.Tensor, "BlockMask"],
  371. input_tensor: torch.Tensor,
  372. cache_position: torch.Tensor,
  373. past_key_values: Cache,
  374. output_attentions: bool = False,
  375. ):
  376. if self.config._attn_implementation == "flash_attention_2":
  377. if attention_mask is not None and (attention_mask == 0.0).any():
  378. return attention_mask
  379. return None
  380. if self.config._attn_implementation == "flex_attention":
  381. if isinstance(attention_mask, torch.Tensor):
  382. attention_mask = make_flex_block_causal_mask(attention_mask)
  383. return attention_mask
  384. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  385. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  386. # to infer the attention mask.
  387. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  388. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  389. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  390. if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
  391. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  392. attention_mask,
  393. inputs_embeds=input_tensor,
  394. past_key_values_length=past_seen_tokens,
  395. is_training=self.training,
  396. ):
  397. return None
  398. dtype = input_tensor.dtype
  399. sequence_length = input_tensor.shape[1]
  400. if using_compilable_cache:
  401. target_length = past_key_values.get_max_cache_shape()
  402. else:
  403. target_length = (
  404. attention_mask.shape[-1]
  405. if isinstance(attention_mask, torch.Tensor)
  406. else past_seen_tokens + sequence_length + 1
  407. )
  408. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  409. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  410. attention_mask,
  411. sequence_length=sequence_length,
  412. target_length=target_length,
  413. dtype=dtype,
  414. cache_position=cache_position,
  415. batch_size=input_tensor.shape[0],
  416. )
  417. if (
  418. self.config._attn_implementation == "sdpa"
  419. and attention_mask is not None
  420. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  421. and not output_attentions
  422. ):
  423. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  424. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  425. # Details: https://github.com/pytorch/pytorch/issues/110213
  426. min_dtype = torch.finfo(dtype).min
  427. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  428. return causal_mask
  429. @staticmethod
  430. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  431. def _prepare_4d_causal_attention_mask_with_cache_position(
  432. attention_mask: torch.Tensor,
  433. sequence_length: int,
  434. target_length: int,
  435. dtype: torch.dtype,
  436. cache_position: torch.Tensor,
  437. batch_size: int,
  438. **kwargs,
  439. ):
  440. """
  441. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  442. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  443. Args:
  444. attention_mask (`torch.Tensor`):
  445. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  446. `(batch_size, 1, query_length, key_value_length)`.
  447. sequence_length (`int`):
  448. The sequence length being processed.
  449. target_length (`int`):
  450. The target length: when generating with static cache, the mask should be as long as the static cache,
  451. to account for the 0 padding, the part of the cache that is not filled yet.
  452. dtype (`torch.dtype`):
  453. The dtype to use for the 4D attention mask.
  454. cache_position (`torch.Tensor`):
  455. Indices depicting the position of the input sequence tokens in the sequence.
  456. batch_size (`torch.Tensor`):
  457. Batch size.
  458. """
  459. if attention_mask is not None and attention_mask.dim() == 4:
  460. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  461. causal_mask = attention_mask
  462. else:
  463. min_dtype = torch.finfo(dtype).min
  464. causal_mask = torch.full(
  465. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  466. )
  467. if sequence_length != 1:
  468. causal_mask = torch.triu(causal_mask, diagonal=1)
  469. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  470. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  471. if attention_mask is not None:
  472. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  473. mask_length = attention_mask.shape[-1]
  474. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  475. causal_mask.device
  476. )
  477. padding_mask = padding_mask == 0
  478. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  479. padding_mask, min_dtype
  480. )
  481. return causal_mask
  482. @auto_docstring(
  483. custom_intro="""
  484. The CodeGen Model transformer with a language modeling head on top.
  485. """
  486. )
  487. class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin):
  488. _tied_weights_keys = ["lm_head.weight"]
  489. def __init__(self, config):
  490. super().__init__(config)
  491. self.transformer = CodeGenModel(config)
  492. self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
  493. # Initialize weights and apply final processing
  494. self.post_init()
  495. @auto_docstring
  496. def forward(
  497. self,
  498. input_ids: Optional[torch.LongTensor] = None,
  499. past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None,
  500. attention_mask: Optional[torch.FloatTensor] = None,
  501. token_type_ids: Optional[torch.LongTensor] = None,
  502. position_ids: Optional[torch.LongTensor] = None,
  503. head_mask: Optional[torch.FloatTensor] = None,
  504. inputs_embeds: Optional[torch.FloatTensor] = None,
  505. labels: Optional[torch.LongTensor] = None,
  506. use_cache: Optional[bool] = None,
  507. output_attentions: Optional[bool] = None,
  508. output_hidden_states: Optional[bool] = None,
  509. return_dict: Optional[bool] = None,
  510. cache_position: Optional[torch.LongTensor] = None,
  511. **kwargs,
  512. ) -> Union[tuple, CausalLMOutputWithPast]:
  513. r"""
  514. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
  515. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  516. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  517. model's internal embedding lookup matrix.
  518. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  519. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  520. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  521. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  522. """
  523. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  524. transformer_outputs = self.transformer(
  525. input_ids,
  526. past_key_values=past_key_values,
  527. attention_mask=attention_mask,
  528. token_type_ids=token_type_ids,
  529. position_ids=position_ids,
  530. head_mask=head_mask,
  531. inputs_embeds=inputs_embeds,
  532. use_cache=use_cache,
  533. output_attentions=output_attentions,
  534. output_hidden_states=output_hidden_states,
  535. return_dict=return_dict,
  536. cache_position=cache_position,
  537. )
  538. hidden_states = transformer_outputs[0]
  539. # make sure sampling in fp16 works correctly and
  540. # compute loss in fp32 to match with mesh-tf version
  541. # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
  542. lm_logits = self.lm_head(hidden_states).to(torch.float32)
  543. loss = None
  544. if labels is not None:
  545. # move labels to correct device to enable model parallelism
  546. labels = labels.to(lm_logits.device)
  547. # Flatten the tokens
  548. loss = self.loss_function(
  549. lm_logits,
  550. labels,
  551. vocab_size=self.config.vocab_size,
  552. **kwargs,
  553. )
  554. loss = loss.to(hidden_states.dtype)
  555. if not return_dict:
  556. output = (lm_logits,) + transformer_outputs[1:]
  557. return ((loss,) + output) if loss is not None else output
  558. return CausalLMOutputWithPast(
  559. loss=loss,
  560. logits=lm_logits,
  561. past_key_values=transformer_outputs.past_key_values,
  562. hidden_states=transformer_outputs.hidden_states,
  563. attentions=transformer_outputs.attentions,
  564. )
  565. __all__ = ["CodeGenForCausalLM", "CodeGenModel", "CodeGenPreTrainedModel"]