modeling_biogpt.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_biogpt.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from typing import Callable, Optional, Union
  23. import torch
  24. import torch.nn as nn
  25. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  28. from ...generation import GenerationMixin
  29. from ...modeling_attn_mask_utils import AttentionMaskConverter
  30. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import (
  33. BaseModelOutputWithPastAndCrossAttentions,
  34. CausalLMOutputWithCrossAttentions,
  35. SequenceClassifierOutputWithPast,
  36. TokenClassifierOutput,
  37. )
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging
  41. from ...utils.deprecation import deprecate_kwarg
  42. from .configuration_biogpt import BioGptConfig
  43. if is_torch_flex_attn_available():
  44. from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
  45. logger = logging.get_logger(__name__)
  class BioGptLearnedPositionalEmbedding(nn.Embedding):
      """
      Learned absolute position embedding table.

      The table is allocated with `self.offset` (2) extra rows and every looked-up index is shifted by that
      offset. For positions derived from an attention mask, masked-out tokens resolve to row 1 and the first
      attended token resolves to row 2.
      """

      def __init__(self, num_embeddings: int, embedding_dim: int):
          # BioGPT shifts all position ids by 2 and enlarges the table accordingly; other models do not have
          # this hack. The offset is assigned before the parent constructor runs, as it always has been.
          self.offset = 2
          table_rows = num_embeddings + self.offset
          super().__init__(table_rows, embedding_dim)

      def forward(
          self,
          attention_mask: torch.LongTensor,
          past_key_values_length: int = 0,
          position_ids: Optional[torch.LongTensor] = None,
      ):
          """
          Look up position embeddings.

          Args:
              attention_mask: Mask expected to be `[bsz x seqlen]`; only read when `position_ids` is `None`.
              past_key_values_length: Number of leading positions dropped from the mask-derived ids.
              position_ids: Explicit position ids. When given they are used as-is (no slicing), only shifted
                  by `self.offset`.

          Returns:
              The embedding rows selected by the shifted position ids.
          """
          if position_ids is not None:
              return super().forward(position_ids + self.offset)

          # Running count of attended tokens; multiplying by the mask zeroes padded slots so that they end up
          # at -1 after the subtraction, while the first attended token lands on 0.
          running_count = torch.cumsum(attention_mask, dim=1)
          derived_ids = (running_count * attention_mask - 1).long()
          # Keep only the positions that come after the already-cached prefix.
          derived_ids = derived_ids[:, past_key_values_length:]
          return super().forward(derived_ids + self.offset)
  class BioGptScaledWordEmbedding(nn.Embedding):
      """
      Token embedding whose looked-up vectors are multiplied by a constant `embed_scale`.
      """

      def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
          super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
          # Plain Python attribute (not a parameter or buffer); applied on every forward pass.
          self.embed_scale = embed_scale

      def forward(self, input_ids: torch.Tensor):
          """Return the embeddings of `input_ids` scaled by `self.embed_scale`."""
          token_vectors = super().forward(input_ids)
          return token_vectors * self.embed_scale
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: Optional[torch.Tensor],
      scaling: Optional[float] = None,
      dropout: float = 0.0,
      head_mask: Optional[torch.Tensor] = None,
      **kwargs,
  ):
      """
      Reference (non-fused) scaled dot-product attention.

      Args:
          module: Module whose `training` flag decides whether dropout is active.
          query, key, value: 4-D tensors; axes 2 and 3 of `key` are swapped for the score product.
          attention_mask: Optional additive mask, added to the scaled scores before the softmax.
          scaling: Score multiplier; defaults to `query.size(-1) ** -0.5` when `None`.
          dropout: Dropout probability applied to the attention probabilities.
          head_mask: Optional per-head multiplier, reshaped to `(1, -1, 1, 1)` and applied after the softmax.
          **kwargs: Accepted and ignored.

      Returns:
          `(attn_output, attn_weights)`: the weighted values with axes 1 and 2 swapped (made contiguous), and
          the post-dropout attention probabilities.
      """
      scale = query.size(-1) ** -0.5 if scaling is None else scaling

      scores = torch.matmul(query, key.transpose(2, 3)) * scale
      if attention_mask is not None:
          scores = scores + attention_mask

      probs = nn.functional.softmax(scores, dim=-1)
      if head_mask is not None:
          probs = probs * head_mask.view(1, -1, 1, 1)
      probs = nn.functional.dropout(probs, p=dropout, training=module.training)

      context = torch.matmul(probs, value).transpose(1, 2).contiguous()
      return context, probs
  class BioGptAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper.

      Projects the input into per-head queries, keys and values, optionally reads from / writes to a key-value
      cache, dispatches to the attention backend named by `config._attn_implementation` (eager when no config
      was given) and applies the output projection.
      """

      def __init__(
          self,
          embed_dim: int,
          num_heads: int,
          dropout: float = 0.0,
          is_decoder: bool = False,
          bias: bool = True,
          is_causal: bool = False,
          config: Optional[BioGptConfig] = None,
          layer_idx: Optional[int] = None,
      ):
          """
          Args:
              embed_dim: Model width; must be divisible by `num_heads`.
              num_heads: Number of attention heads.
              dropout: Dropout probability handed to the attention backend while training.
              is_decoder: Whether the module lives in a decoder; here it only gates the `layer_idx` warning.
              bias: Whether the four linear projections carry a bias term.
              is_causal: Stored as `self.is_causal`; not read inside this class (presumably consumed by the
                  attention backends - confirm).
              config: Model configuration; `forward` reads `_attn_implementation` from it.
              layer_idx: Index of the owning layer, used to address this layer's slot in the cache.

          Raises:
              ValueError: If `embed_dim` is not divisible by `num_heads`.
          """
          super().__init__()
          self.embed_dim = embed_dim
          self.num_heads = num_heads
          self.dropout = dropout
          self.head_dim = embed_dim // num_heads
          self.config = config

          if (self.head_dim * num_heads) != self.embed_dim:
              raise ValueError(
                  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                  f" and `num_heads`: {num_heads})."
              )
          self.scaling = self.head_dim**-0.5
          self.is_decoder = is_decoder
          self.is_causal = is_causal

          self.layer_idx = layer_idx
          if layer_idx is None and self.is_decoder:
              logger.warning_once(
                  f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                  "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                  "when creating this class."
              )

          self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
          self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
          self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
          self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          key_value_states: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          attention_mask: Optional[torch.Tensor] = None,
          layer_head_mask: Optional[torch.Tensor] = None,
          output_attentions: bool = False,
          cache_position: Optional[torch.Tensor] = None,
          # TODO: we need a refactor so that the different attention modules can get their specific kwargs
          # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          """Input shape: Batch x Time x Channel

          Args:
              hidden_states: Query-side input.
              key_value_states: When given, keys/values are computed from it (cross-attention) instead of from
                  `hidden_states`.
              past_key_values: Optional cache. An `EncoderDecoderCache` is split into its self-/cross-attention
                  part; any other cache object is used directly.
              attention_mask: Forwarded unchanged to the attention backend.
              layer_head_mask: Forwarded to the attention backend as `head_mask`.
              output_attentions: Forwarded to the attention backend.
              cache_position: Passed to the cache update for self-attention; dropped for cross-attention.
              **kwargs: Forwarded to the attention backend.

          Returns:
              `(attn_output, attn_weights)` - the projected attention output of shape `(batch, time, -1)` and
              whatever the backend returned as attention weights.
          """
          # if key_value_states are provided this layer is used as a cross-attention layer
          # for the decoder
          is_cross_attention = key_value_states is not None

          # determine input shapes
          bsz, tgt_len = hidden_states.shape[:-1]
          src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

          q_input_shape = (bsz, tgt_len, -1, self.head_dim)
          kv_input_shape = (bsz, src_len, -1, self.head_dim)

          # get query proj: (bsz, tgt_len, embed) -> (bsz, heads, tgt_len, head_dim)
          query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

          is_updated = False
          if past_key_values is not None:
              if isinstance(past_key_values, EncoderDecoderCache):
                  is_updated = past_key_values.is_updated.get(self.layer_idx)
                  if is_cross_attention:
                      # after the first generated id, we can subsequently re-use all key/value_states from cache
                      curr_past_key_value = past_key_values.cross_attention_cache
                  else:
                      curr_past_key_value = past_key_values.self_attention_cache
              else:
                  curr_past_key_value = past_key_values

          current_states = key_value_states if is_cross_attention else hidden_states
          if is_cross_attention and past_key_values is not None and is_updated:
              # reuse k,v, cross_attentions
              key_states = curr_past_key_value.layers[self.layer_idx].keys
              value_states = curr_past_key_value.layers[self.layer_idx].values
          else:
              key_states = self.k_proj(current_states)
              value_states = self.v_proj(current_states)
              key_states = key_states.view(*kv_input_shape).transpose(1, 2)
              value_states = value_states.view(*kv_input_shape).transpose(1, 2)

              if past_key_values is not None:
                  # save all key/value_states to cache to be re-used for fast auto-regressive generation
                  cache_position = cache_position if not is_cross_attention else None
                  key_states, value_states = curr_past_key_value.update(
                      key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                  )
                  # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                  if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                      past_key_values.is_updated[self.layer_idx] = True

          # `config` is optional in `__init__`; without one there is no implementation to look up, so use the
          # eager path instead of failing on `None._attn_implementation`.
          attn_implementation = "eager" if self.config is None else self.config._attn_implementation
          attention_interface: Callable = eager_attention_forward
          if attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.dropout,
              scaling=self.scaling,
              output_attentions=output_attentions,
              head_mask=layer_head_mask,
              **kwargs,
          )

          # merge the head axis back into the channel axis before the output projection
          attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
          attn_output = self.out_proj(attn_output)

          return attn_output, attn_weights
  class BioGptDecoderLayer(GradientCheckpointingLayer):
      """Pre-LayerNorm decoder block: self-attention followed by a two-layer feed-forward network, each wrapped
      in a residual connection."""

      def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
          super().__init__()
          width = config.hidden_size
          self.embed_dim = width

          self.self_attn = BioGptAttention(
              embed_dim=width,
              num_heads=config.num_attention_heads,
              dropout=config.attention_probs_dropout_prob,
              is_decoder=True,
              is_causal=True,
              config=config,
              layer_idx=layer_idx,
          )
          # Dropout probabilities and activation used in `forward`.
          self.dropout = config.hidden_dropout_prob
          self.activation_fn = ACT2FN[config.hidden_act]
          self.activation_dropout = config.activation_dropout

          self.self_attn_layer_norm = nn.LayerNorm(width)
          self.fc1 = nn.Linear(width, config.intermediate_size)
          self.fc2 = nn.Linear(config.intermediate_size, width)
          self.final_layer_norm = nn.LayerNorm(width)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          layer_head_mask: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          output_attentions: Optional[bool] = False,
          use_cache: Optional[bool] = True,
          position_ids: Optional[torch.LongTensor] = None,
          cache_position: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
          """
          Run one decoder block.

          Args:
              hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
              attention_mask (`torch.FloatTensor`): mask of size `(batch, 1, tgt_len, src_len)` in which padding
                  elements are indicated by very large negative values.
              layer_head_mask (`torch.FloatTensor`): per-head mask for this layer, of size
                  `(encoder_attention_heads,)`.
              past_key_values (`Cache`): cached key and value projection states, handed to the attention module.
              output_attentions (`bool`, *optional*):
                  When truthy, the attention weights are appended to the returned tuple.
              use_cache (`bool`, *optional*):
                  Accepted for interface compatibility; not read inside this method.
              position_ids (`torch.LongTensor`, *optional*):
                  Forwarded to the attention module as an extra keyword argument.
              cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                  Indices of the input tokens within the full sequence, forwarded to the attention module.

          Returns:
              `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is truthy.
          """
          # --- self-attention sub-block -------------------------------------------------------------------
          normed_input = self.self_attn_layer_norm(hidden_states)
          attn_out, self_attn_weights = self.self_attn(
              hidden_states=normed_input,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
              layer_head_mask=layer_head_mask,
              output_attentions=output_attentions,
              position_ids=position_ids,
              cache_position=cache_position,
              **kwargs,
          )
          attn_out = nn.functional.dropout(attn_out, p=self.dropout, training=self.training)
          after_attention = hidden_states + attn_out

          # --- feed-forward sub-block ---------------------------------------------------------------------
          ffn_out = self.fc1(self.final_layer_norm(after_attention))
          ffn_out = self.activation_fn(ffn_out)
          ffn_out = nn.functional.dropout(ffn_out, p=self.activation_dropout, training=self.training)
          ffn_out = self.fc2(ffn_out)
          ffn_out = nn.functional.dropout(ffn_out, p=self.dropout, training=self.training)
          layer_output = after_attention + ffn_out

          if output_attentions:
              return (layer_output, self_attn_weights)
          return (layer_output,)
  @auto_docstring
  class BioGptPreTrainedModel(PreTrainedModel):
      # Shared base for the BioGPT models: class-level capability flags plus the helpers that turn a
      # user-supplied attention mask into the form expected by the configured attention implementation.
      config: BioGptConfig
      base_model_prefix = "biogpt"
      supports_gradient_checkpointing = True
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True
      _can_compile_fullgraph = True

      # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
      # NOTE(review): the flex-attention branch below takes the device from `input_tensor` when no mask is
      # given, which deviates from the copied source; mirror the fix there (and in modular_biogpt.py).
      def _update_causal_mask(
          self,
          attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
          input_tensor: torch.Tensor,
          cache_position: torch.Tensor,
          past_key_values: Cache,
      ):
          """
          Build the attention mask to hand to the decoder layers, depending on `config._attn_implementation`.

          - `"flex_attention"`: a tensor mask (or a missing mask) is converted with
            `make_flex_block_causal_mask`; any other mask object is returned unchanged.
          - `"flash_attention_2"`: the mask is returned only if it contains a `0.0` entry, otherwise `None`.
          - otherwise: a 4D causal mask is built, except that for `"sdpa"` without a compilable cache `None`
            is returned when `AttentionMaskConverter._ignore_causal_mask_sdpa` says the mask can be skipped.

          Args:
              attention_mask: User-facing mask, or `None`.
              input_tensor: Input embeddings; supplies batch size, sequence length, dtype and device.
              cache_position: Positions of the current tokens, used when building the 4D mask.
              past_key_values: Cache object, or `None`.

          Returns:
              The mask to use, or `None` when the implementation can do without one.
          """
          if self.config._attn_implementation == "flex_attention":
              if isinstance(attention_mask, torch.Tensor):
                  attention_mask = make_flex_block_causal_mask(attention_mask)
              # Other attention flavors support in-built causal (when `mask is None`)
              # while we need to create our specific block mask regardless
              elif attention_mask is None:
                  attention_mask = make_flex_block_causal_mask(
                      torch.ones(
                          size=(input_tensor.shape[0], input_tensor.shape[1]),
                          # `attention_mask` is None in this branch, so the device must come from the input.
                          device=input_tensor.device,
                      )
                  )
              return attention_mask

          if self.config._attn_implementation == "flash_attention_2":
              if attention_mask is not None and (attention_mask == 0.0).any():
                  return attention_mask
              return None

          # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
          # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
          # to infer the attention mask.
          past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
          using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

          # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
          if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
              if AttentionMaskConverter._ignore_causal_mask_sdpa(
                  attention_mask,
                  inputs_embeds=input_tensor,
                  past_key_values_length=past_seen_tokens,
                  is_training=self.training,
              ):
                  return None

          dtype = input_tensor.dtype
          sequence_length = input_tensor.shape[1]
          if using_compilable_cache:
              target_length = past_key_values.get_max_cache_shape()
          else:
              target_length = (
                  attention_mask.shape[-1]
                  if isinstance(attention_mask, torch.Tensor)
                  else past_seen_tokens + sequence_length + 1
              )

          # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
          causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
              attention_mask,
              sequence_length=sequence_length,
              target_length=target_length,
              dtype=dtype,
              cache_position=cache_position,
              batch_size=input_tensor.shape[0],
          )

          if (
              self.config._attn_implementation == "sdpa"
              and attention_mask is not None
              and attention_mask.device.type in ["cuda", "xpu", "npu"]
          ):
              # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
              # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
              # Details: https://github.com/pytorch/pytorch/issues/110213
              min_dtype = torch.finfo(dtype).min
              causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

          return causal_mask

      @staticmethod
      # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
      def _prepare_4d_causal_attention_mask_with_cache_position(
          attention_mask: torch.Tensor,
          sequence_length: int,
          target_length: int,
          dtype: torch.dtype,
          cache_position: torch.Tensor,
          batch_size: int,
          **kwargs,
      ):
          """
          Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
          `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

          Args:
              attention_mask (`torch.Tensor`):
                  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                  `(batch_size, 1, query_length, key_value_length)`.
              sequence_length (`int`):
                  The sequence length being processed.
              target_length (`int`):
                  The target length: when generating with static cache, the mask should be as long as the static cache,
                  to account for the 0 padding, the part of the cache that is not filled yet.
              dtype (`torch.dtype`):
                  The dtype to use for the 4D attention mask.
              cache_position (`torch.Tensor`):
                  Indices depicting the position of the input sequence tokens in the sequence.
              batch_size (`int`):
                  Batch size.

          Returns:
              `torch.Tensor`: the 4D mask; allowed positions hold `0` and blocked positions hold the minimum
              value of `dtype` (a 4D input mask is returned as-is).
          """
          if attention_mask is not None and attention_mask.dim() == 4:
              # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
              causal_mask = attention_mask
          else:
              min_dtype = torch.finfo(dtype).min
              causal_mask = torch.full(
                  (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
              )
              if sequence_length != 1:
                  causal_mask = torch.triu(causal_mask, diagonal=1)
              # keep the blocking value only for key positions strictly after each query's cache position
              causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
              causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
              if attention_mask is not None:
                  causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                  mask_length = attention_mask.shape[-1]
                  # a sum of 0 means "causally visible (0) but padded (0)": those slots must be blocked too
                  padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                      causal_mask.device
                  )
                  padding_mask = padding_mask == 0
                  causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                      padding_mask, min_dtype
                  )

          return causal_mask
  @auto_docstring
  class BioGptModel(BioGptPreTrainedModel):
      # Bare BioGPT decoder stack: scaled token embeddings + learned position embeddings, followed by
      # `config.num_hidden_layers` decoder layers and a final LayerNorm.
      def __init__(self, config: BioGptConfig):
          super().__init__(config)
          self.config = config
          self.layerdrop = config.layerdrop
          self.dropout = config.hidden_dropout_prob
          self.embed_dim = config.hidden_size
          self.padding_idx = config.pad_token_id
          # token embeddings are multiplied by sqrt(hidden_size) when `config.scale_embedding` is set
          embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

          self.embed_tokens = BioGptScaledWordEmbedding(
              config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
          )
          self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)

          self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
          self.layer_norm = nn.LayerNorm(self.embed_dim)

          self.gradient_checkpointing = False
          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
          use_cache: Optional[bool] = None,
          position_ids: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
          # Unset flags fall back to the values stored on the config.
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          use_cache = use_cache if use_cache is not None else self.config.use_cache
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # Exactly one of `input_ids` / `inputs_embeds` must be supplied. The XOR is true both when the two are
          # given together and when neither is given, so the message has to cover both situations.
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          if self.gradient_checkpointing and self.training:
              if use_cache:
                  logger.warning_once(
                      "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
                  )
                  use_cache = False

          # initialize past_key_values
          if use_cache and past_key_values is None:
              past_key_values = DynamicCache(config=self.config)
          if use_cache and isinstance(past_key_values, tuple):
              logger.warning_once(
                  "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                  "You should pass an instance of `DynamicCache` instead, e.g. "
                  "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
              )
              past_key_values = DynamicCache.from_legacy_cache(past_key_values)

          batch_size, seq_length = inputs_embeds.size()[:-1]
          past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
          if cache_position is None:
              cache_position = torch.arange(
                  past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
              )

          if attention_mask is None:
              # required mask seq length can be calculated via length of past cache
              mask_seq_length = past_key_values_length + seq_length
              attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

          causal_mask = self._update_causal_mask(
              attention_mask,
              inputs_embeds,
              cache_position,
              past_key_values,
          )

          # embed positions
          if position_ids is None:
              # derive positions from the (2D) attention mask so that padded slots do not advance the counter
              position_ids = torch.cumsum(attention_mask, dim=1)
              position_ids = (position_ids * attention_mask - 1).long()
              # cut positions if `past_seen_tokens` is > 0
              position_ids = position_ids[:, past_key_values_length:]

          positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
          hidden_states = inputs_embeds + positions
          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

          all_hidden_states = () if output_hidden_states else None
          all_self_attns = () if output_attentions else None
          # this decoder-only model has no cross-attention; the field is kept for the output class
          all_cross_attentions = None

          for idx, decoder_layer in enumerate(self.layers):
              if output_hidden_states:
                  all_hidden_states += (hidden_states,)
              # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
              if self.training:
                  dropout_probability = torch.rand([])
                  if dropout_probability < self.layerdrop:
                      continue

              layer_outputs = decoder_layer(
                  hidden_states,
                  attention_mask=causal_mask,
                  layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                  past_key_values=past_key_values,
                  output_attentions=output_attentions,
                  use_cache=use_cache,
                  position_ids=position_ids,
                  cache_position=cache_position,
                  **kwargs,
              )

              hidden_states = layer_outputs[0]

              if output_attentions:
                  all_self_attns += (layer_outputs[1],)

          # add hidden states from the last decoder layer
          if output_hidden_states:
              all_hidden_states += (hidden_states,)

          hidden_states = self.layer_norm(hidden_states)

          if not return_dict:
              # tuple output drops every entry that is None
              return tuple(
                  v
                  for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                  if v is not None
              )
          return BaseModelOutputWithPastAndCrossAttentions(
              last_hidden_state=hidden_states,
              past_key_values=past_key_values,
              hidden_states=all_hidden_states,
              attentions=all_self_attns,
              cross_attentions=all_cross_attentions,
          )
  @auto_docstring(
      custom_intro="""
      BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
      """
  )
  class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
      # The output projection is declared as tied (see `_tied_weights_keys`).
      _tied_weights_keys = ["output_projection.weight"]

      def __init__(self, config):
          super().__init__(config)

          # decoder stack + bias-free projection from hidden size to vocabulary size
          self.biogpt = BioGptModel(config)
          self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          """Return the vocabulary projection layer."""
          return self.output_projection

      def set_output_embeddings(self, new_embeddings):
          """Replace the vocabulary projection layer with `new_embeddings`."""
          self.output_projection = new_embeddings

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          position_ids: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
              `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
              are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          decoder_outputs = self.biogpt(
              input_ids,
              attention_mask=attention_mask,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              past_key_values=past_key_values,
              use_cache=use_cache,
              position_ids=position_ids,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
              **kwargs,
          )

          # element 0 is the final hidden state for both the tuple and the dataclass output
          logits = self.output_projection(decoder_outputs[0])

          loss = None
          if labels is not None:
              loss = self.loss_function(
                  logits,
                  labels,
                  vocab_size=self.config.vocab_size,
                  **kwargs,
              )

          if not return_dict:
              # tuple form: logits replace the hidden state; the loss is prepended only when it was computed
              tuple_output = (logits,) + decoder_outputs[1:]
              if loss is None:
                  return tuple_output
              return (loss,) + tuple_output

          return CausalLMOutputWithCrossAttentions(
              loss=loss,
              logits=logits,
              past_key_values=decoder_outputs.past_key_values,
              hidden_states=decoder_outputs.hidden_states,
              attentions=decoder_outputs.attentions,
              cross_attentions=decoder_outputs.cross_attentions,
          )
  643. @auto_docstring
  644. class BioGptForTokenClassification(BioGptPreTrainedModel):
  645. def __init__(self, config):
  646. super().__init__(config)
  647. self.num_labels = config.num_labels
  648. self.biogpt = BioGptModel(config)
  649. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  650. classifier_dropout = config.classifier_dropout
  651. else:
  652. classifier_dropout = config.hidden_dropout_prob
  653. self.dropout = nn.Dropout(classifier_dropout)
  654. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  655. self.post_init()
  656. @auto_docstring
  657. def forward(
  658. self,
  659. input_ids: Optional[torch.LongTensor] = None,
  660. token_type_ids: Optional[torch.LongTensor] = None,
  661. attention_mask: Optional[torch.FloatTensor] = None,
  662. head_mask: Optional[torch.FloatTensor] = None,
  663. past_key_values: Optional[Cache] = None,
  664. inputs_embeds: Optional[torch.FloatTensor] = None,
  665. labels: Optional[torch.LongTensor] = None,
  666. use_cache: Optional[bool] = None,
  667. position_ids: Optional[torch.LongTensor] = None,
  668. output_attentions: Optional[bool] = None,
  669. output_hidden_states: Optional[bool] = None,
  670. return_dict: Optional[bool] = None,
  671. cache_position: Optional[torch.Tensor] = None,
  672. ) -> Union[tuple, TokenClassifierOutput]:
  673. r"""
  674. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  675. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  676. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  677. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  678. """
  679. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  680. transformer_outputs = self.biogpt(
  681. input_ids,
  682. past_key_values=past_key_values,
  683. attention_mask=attention_mask,
  684. head_mask=head_mask,
  685. inputs_embeds=inputs_embeds,
  686. use_cache=use_cache,
  687. position_ids=position_ids,
  688. output_attentions=output_attentions,
  689. output_hidden_states=output_hidden_states,
  690. return_dict=return_dict,
  691. cache_position=cache_position,
  692. )
  693. hidden_states = transformer_outputs[0]
  694. hidden_states = self.dropout(hidden_states)
  695. logits = self.classifier(hidden_states)
  696. loss = None
  697. if labels is not None:
  698. loss_fct = CrossEntropyLoss()
  699. # Only keep active parts of the loss
  700. if attention_mask is not None:
  701. active_loss = attention_mask.view(-1) == 1
  702. active_logits = logits.view(-1, self.num_labels)
  703. active_labels = torch.where(
  704. active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
  705. )
  706. loss = loss_fct(active_logits, active_labels)
  707. else:
  708. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  709. if not return_dict:
  710. output = (logits,) + transformer_outputs[2:]
  711. return ((loss,) + output) if loss is not None else output
  712. return TokenClassifierOutput(
  713. loss=loss,
  714. logits=logits,
  715. hidden_states=transformer_outputs.hidden_states,
  716. attentions=transformer_outputs.attentions,
  717. )
  718. @auto_docstring(
  719. custom_intro="""
  720. The BioGpt Model transformer with a sequence classification head on top (linear layer).
  721. [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  722. (e.g. GPT-2) do.
  723. Since it does classification on the last token, it is required to know the position of the last token. If a
  724. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  725. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  726. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  727. each row of the batch).
  728. """
  729. )
  730. class BioGptForSequenceClassification(BioGptPreTrainedModel):
  731. def __init__(self, config: BioGptConfig):
  732. super().__init__(config)
  733. self.num_labels = config.num_labels
  734. self.biogpt = BioGptModel(config)
  735. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  736. # Initialize weights and apply final processing
  737. self.post_init()
  738. @auto_docstring
  739. def forward(
  740. self,
  741. input_ids: Optional[torch.LongTensor] = None,
  742. attention_mask: Optional[torch.FloatTensor] = None,
  743. head_mask: Optional[torch.FloatTensor] = None,
  744. past_key_values: Optional[Cache] = None,
  745. inputs_embeds: Optional[torch.FloatTensor] = None,
  746. labels: Optional[torch.LongTensor] = None,
  747. use_cache: Optional[bool] = None,
  748. position_ids: Optional[torch.LongTensor] = None,
  749. output_attentions: Optional[bool] = None,
  750. output_hidden_states: Optional[bool] = None,
  751. return_dict: Optional[bool] = None,
  752. cache_position: Optional[torch.Tensor] = None,
  753. logits_to_keep: Union[int, torch.Tensor] = 0,
  754. ) -> Union[tuple, SequenceClassifierOutputWithPast]:
  755. r"""
  756. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  757. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  758. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  759. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  760. """
  761. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  762. transformer_outputs = self.biogpt(
  763. input_ids,
  764. past_key_values=past_key_values,
  765. attention_mask=attention_mask,
  766. head_mask=head_mask,
  767. inputs_embeds=inputs_embeds,
  768. use_cache=use_cache,
  769. position_ids=position_ids,
  770. output_attentions=output_attentions,
  771. output_hidden_states=output_hidden_states,
  772. return_dict=return_dict,
  773. cache_position=cache_position,
  774. )
  775. hidden_states = transformer_outputs[0]
  776. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  777. logits = self.score(hidden_states[:, slice_indices, :])
  778. if input_ids is not None:
  779. batch_size, sequence_length = input_ids.shape[:2]
  780. else:
  781. batch_size, sequence_length = inputs_embeds.shape[:2]
  782. if self.config.pad_token_id is None:
  783. sequence_length = -1
  784. else:
  785. if input_ids is not None:
  786. sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
  787. else:
  788. sequence_length = -1
  789. logger.warning_once(
  790. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  791. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  792. )
  793. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
  794. loss = None
  795. if labels is not None:
  796. if self.config.problem_type is None:
  797. if self.num_labels == 1:
  798. self.config.problem_type = "regression"
  799. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  800. self.config.problem_type = "single_label_classification"
  801. else:
  802. self.config.problem_type = "multi_label_classification"
  803. if self.config.problem_type == "regression":
  804. loss_fct = MSELoss()
  805. if self.num_labels == 1:
  806. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  807. else:
  808. loss = loss_fct(pooled_logits, labels)
  809. elif self.config.problem_type == "single_label_classification":
  810. loss_fct = CrossEntropyLoss()
  811. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  812. elif self.config.problem_type == "multi_label_classification":
  813. loss_fct = BCEWithLogitsLoss()
  814. loss = loss_fct(pooled_logits, labels)
  815. if not return_dict:
  816. output = (pooled_logits,) + transformer_outputs[1:]
  817. return ((loss,) + output) if loss is not None else output
  818. return SequenceClassifierOutputWithPast(
  819. loss=loss,
  820. logits=pooled_logits,
  821. past_key_values=transformer_outputs.past_key_values,
  822. hidden_states=transformer_outputs.hidden_states,
  823. attentions=transformer_outputs.attentions,
  824. )
  825. def get_input_embeddings(self):
  826. return self.biogpt.embed_tokens
  827. def set_input_embeddings(self, value):
  828. self.biogpt.embed_tokens = value
  829. __all__ = [
  830. "BioGptForCausalLM",
  831. "BioGptForTokenClassification",
  832. "BioGptForSequenceClassification",
  833. "BioGptModel",
  834. "BioGptPreTrainedModel",
  835. ]