modular_biogpt.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch BioGPT model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. import torch.nn as nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_attn_mask_utils import (
  25. AttentionMaskConverter,
  26. )
  27. from ...modeling_outputs import (
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. CausalLMOutputWithCrossAttentions,
  30. SequenceClassifierOutputWithPast,
  31. TokenClassifierOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import (
  36. TransformersKwargs,
  37. auto_docstring,
  38. is_torch_flex_attn_available,
  39. logger,
  40. )
  41. from ...utils.deprecation import deprecate_kwarg
  42. from ..bart.modeling_bart import (
  43. BartAttention,
  44. BartDecoderLayer,
  45. BartScaledWordEmbedding,
  46. )
  47. from ..opt.modeling_opt import OPTLearnedPositionalEmbedding
  48. from .configuration_biogpt import BioGptConfig
  49. if is_torch_flex_attn_available():
  50. from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
  51. class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding):
  52. def forward(
  53. self,
  54. attention_mask: torch.LongTensor,
  55. past_key_values_length: int = 0,
  56. position_ids: Optional[torch.LongTensor] = None,
  57. ):
  58. """`input_ids_shape` is expected to be [bsz x seqlen]."""
  59. super().forward(attention_mask, past_key_values_length, position_ids)
# BioGPT-named subclass of `BartScaledWordEmbedding`; nothing is overridden or added here.
class BioGptScaledWordEmbedding(BartScaledWordEmbedding):
    pass
# BioGPT-named subclass of `BartAttention`; nothing is overridden or added here.
class BioGptAttention(BartAttention):
    pass
  64. class BioGptDecoderLayer(BartDecoderLayer):
  65. def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
  66. super().__init__(config)
  67. self.embed_dim = config.hidden_size
  68. self.self_attn = BioGptAttention(
  69. embed_dim=self.embed_dim,
  70. num_heads=config.num_attention_heads,
  71. dropout=config.attention_probs_dropout_prob,
  72. is_decoder=True,
  73. is_causal=True,
  74. config=config,
  75. layer_idx=layer_idx,
  76. )
  77. self.dropout = config.hidden_dropout_prob
  78. self.activation_fn = ACT2FN[config.hidden_act]
  79. self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
  80. self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
  81. del self.encoder_attn
  82. del self.encoder_attn_layer_norm
  83. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  84. def forward(
  85. self,
  86. hidden_states: torch.Tensor,
  87. attention_mask: Optional[torch.Tensor] = None,
  88. layer_head_mask: Optional[torch.Tensor] = None,
  89. past_key_values: Optional[Cache] = None,
  90. output_attentions: Optional[bool] = False,
  91. use_cache: Optional[bool] = True,
  92. position_ids: Optional[torch.LongTensor] = None,
  93. cache_position: Optional[torch.Tensor] = None,
  94. **kwargs: Unpack[TransformersKwargs],
  95. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  96. """
  97. Args:
  98. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  99. attention_mask (`torch.FloatTensor`): attention mask of size
  100. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  101. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  102. `(encoder_attention_heads,)`.
  103. past_key_values (`Cache`): cached past key and value projection states
  104. output_attentions (`bool`, *optional*):
  105. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  106. returned tensors for more detail.
  107. use_cache (`bool`, *optional*):
  108. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  109. (see `past_key_values`).
  110. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  111. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  112. cache in the correct position and to infer the complete sequence length.
  113. """
  114. residual = hidden_states
  115. hidden_states = self.self_attn_layer_norm(hidden_states)
  116. # Self Attention
  117. hidden_states, self_attn_weights = self.self_attn(
  118. hidden_states=hidden_states,
  119. past_key_values=past_key_values,
  120. attention_mask=attention_mask,
  121. layer_head_mask=layer_head_mask,
  122. output_attentions=output_attentions,
  123. position_ids=position_ids,
  124. cache_position=cache_position,
  125. **kwargs,
  126. )
  127. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  128. hidden_states = residual + hidden_states
  129. # Fully Connected
  130. residual = hidden_states
  131. hidden_states = self.final_layer_norm(hidden_states)
  132. hidden_states = self.fc1(hidden_states)
  133. hidden_states = self.activation_fn(hidden_states)
  134. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  135. hidden_states = self.fc2(hidden_states)
  136. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  137. hidden_states = residual + hidden_states
  138. outputs = (hidden_states,)
  139. if output_attentions:
  140. outputs += (self_attn_weights,)
  141. return outputs
  142. @auto_docstring
  143. class BioGptPreTrainedModel(PreTrainedModel):
  144. config: BioGptConfig
  145. base_model_prefix = "biogpt"
  146. supports_gradient_checkpointing = True
  147. _supports_flash_attn = True
  148. _supports_sdpa = True
  149. _supports_flex_attn = True
  150. _can_compile_fullgraph = True
  151. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
  152. def _update_causal_mask(
  153. self,
  154. attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
  155. input_tensor: torch.Tensor,
  156. cache_position: torch.Tensor,
  157. past_key_values: Cache,
  158. ):
  159. if self.config._attn_implementation == "flex_attention":
  160. if isinstance(attention_mask, torch.Tensor):
  161. attention_mask = make_flex_block_causal_mask(attention_mask)
  162. # Other attention flavors support in-built causal (when `mask is None`)
  163. # while we need to create our specific block mask regardless
  164. elif attention_mask is None:
  165. attention_mask = make_flex_block_causal_mask(
  166. torch.ones(
  167. size=(input_tensor.shape[0], input_tensor.shape[1]),
  168. device=attention_mask.device,
  169. )
  170. )
  171. return attention_mask
  172. if self.config._attn_implementation == "flash_attention_2":
  173. if attention_mask is not None and (attention_mask == 0.0).any():
  174. return attention_mask
  175. return None
  176. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  177. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  178. # to infer the attention mask.
  179. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  180. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  181. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  182. if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
  183. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  184. attention_mask,
  185. inputs_embeds=input_tensor,
  186. past_key_values_length=past_seen_tokens,
  187. is_training=self.training,
  188. ):
  189. return None
  190. dtype = input_tensor.dtype
  191. sequence_length = input_tensor.shape[1]
  192. if using_compilable_cache:
  193. target_length = past_key_values.get_max_cache_shape()
  194. else:
  195. target_length = (
  196. attention_mask.shape[-1]
  197. if isinstance(attention_mask, torch.Tensor)
  198. else past_seen_tokens + sequence_length + 1
  199. )
  200. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  201. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  202. attention_mask,
  203. sequence_length=sequence_length,
  204. target_length=target_length,
  205. dtype=dtype,
  206. cache_position=cache_position,
  207. batch_size=input_tensor.shape[0],
  208. )
  209. if (
  210. self.config._attn_implementation == "sdpa"
  211. and attention_mask is not None
  212. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  213. ):
  214. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  215. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  216. # Details: https://github.com/pytorch/pytorch/issues/110213
  217. min_dtype = torch.finfo(dtype).min
  218. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  219. return causal_mask
  220. @staticmethod
  221. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  222. def _prepare_4d_causal_attention_mask_with_cache_position(
  223. attention_mask: torch.Tensor,
  224. sequence_length: int,
  225. target_length: int,
  226. dtype: torch.dtype,
  227. cache_position: torch.Tensor,
  228. batch_size: int,
  229. **kwargs,
  230. ):
  231. """
  232. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  233. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  234. Args:
  235. attention_mask (`torch.Tensor`):
  236. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  237. `(batch_size, 1, query_length, key_value_length)`.
  238. sequence_length (`int`):
  239. The sequence length being processed.
  240. target_length (`int`):
  241. The target length: when generating with static cache, the mask should be as long as the static cache,
  242. to account for the 0 padding, the part of the cache that is not filled yet.
  243. dtype (`torch.dtype`):
  244. The dtype to use for the 4D attention mask.
  245. cache_position (`torch.Tensor`):
  246. Indices depicting the position of the input sequence tokens in the sequence.
  247. batch_size (`torch.Tensor`):
  248. Batch size.
  249. """
  250. if attention_mask is not None and attention_mask.dim() == 4:
  251. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  252. causal_mask = attention_mask
  253. else:
  254. min_dtype = torch.finfo(dtype).min
  255. causal_mask = torch.full(
  256. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  257. )
  258. if sequence_length != 1:
  259. causal_mask = torch.triu(causal_mask, diagonal=1)
  260. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  261. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  262. if attention_mask is not None:
  263. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  264. mask_length = attention_mask.shape[-1]
  265. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  266. causal_mask.device
  267. )
  268. padding_mask = padding_mask == 0
  269. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  270. padding_mask, min_dtype
  271. )
  272. return causal_mask
  273. @auto_docstring
  274. class BioGptModel(BioGptPreTrainedModel):
  275. def __init__(self, config: BioGptConfig):
  276. super().__init__(config)
  277. self.config = config
  278. self.layerdrop = config.layerdrop
  279. self.dropout = config.hidden_dropout_prob
  280. self.embed_dim = config.hidden_size
  281. self.padding_idx = config.pad_token_id
  282. embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
  283. self.embed_tokens = BioGptScaledWordEmbedding(
  284. config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
  285. )
  286. self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
  287. self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  288. self.layer_norm = nn.LayerNorm(self.embed_dim)
  289. self.gradient_checkpointing = False
  290. # Initialize weights and apply final processing
  291. self.post_init()
  292. @auto_docstring
  293. def forward(
  294. self,
  295. input_ids: Optional[torch.LongTensor] = None,
  296. attention_mask: Optional[torch.FloatTensor] = None,
  297. head_mask: Optional[torch.FloatTensor] = None,
  298. inputs_embeds: Optional[torch.FloatTensor] = None,
  299. past_key_values: Optional[Cache] = None,
  300. use_cache: Optional[bool] = None,
  301. position_ids: Optional[torch.LongTensor] = None,
  302. output_attentions: Optional[bool] = None,
  303. output_hidden_states: Optional[bool] = None,
  304. return_dict: Optional[bool] = None,
  305. cache_position: Optional[torch.Tensor] = None,
  306. **kwargs: Unpack[TransformersKwargs],
  307. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  308. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  309. output_hidden_states = (
  310. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  311. )
  312. use_cache = use_cache if use_cache is not None else self.config.use_cache
  313. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  314. # retrieve input_ids and inputs_embeds
  315. if (input_ids is None) ^ (inputs_embeds is not None):
  316. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  317. elif input_ids is not None:
  318. input = input_ids
  319. input_shape = input.shape
  320. input_ids = input_ids.view(-1, input_shape[-1])
  321. elif inputs_embeds is not None:
  322. input_shape = inputs_embeds.size()[:-1]
  323. input = inputs_embeds[:, :, -1]
  324. else:
  325. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  326. if inputs_embeds is None:
  327. inputs_embeds = self.embed_tokens(input)
  328. if self.gradient_checkpointing and self.training:
  329. if use_cache:
  330. logger.warning_once(
  331. "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
  332. )
  333. use_cache = False
  334. # initialize past_key_values
  335. if use_cache and past_key_values is None:
  336. past_key_values = DynamicCache(config=self.config)
  337. if use_cache and isinstance(past_key_values, tuple):
  338. logger.warning_once(
  339. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  340. "You should pass an instance of `DynamicCache` instead, e.g. "
  341. "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
  342. )
  343. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  344. batch_size, seq_length = inputs_embeds.size()[:-1]
  345. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  346. if cache_position is None:
  347. cache_position = torch.arange(
  348. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  349. )
  350. if attention_mask is None:
  351. # required mask seq length can be calculated via length of past cache
  352. mask_seq_length = past_key_values_length + seq_length
  353. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  354. self_attn_cache = past_key_values
  355. causal_mask = self._update_causal_mask(
  356. attention_mask,
  357. inputs_embeds,
  358. cache_position,
  359. self_attn_cache,
  360. )
  361. # embed positions
  362. if position_ids is None:
  363. # position_ids = cache_position.unsqueeze(0)
  364. position_ids = torch.cumsum(attention_mask, dim=1)
  365. position_ids = (position_ids * attention_mask - 1).long()
  366. # cut positions if `past_seen_tokens` is > 0
  367. position_ids = position_ids[:, past_key_values_length:]
  368. positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
  369. hidden_states = inputs_embeds + positions
  370. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  371. if self.gradient_checkpointing and self.training:
  372. if use_cache:
  373. logger.warning_once(
  374. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  375. )
  376. use_cache = False
  377. all_hidden_states = () if output_hidden_states else None
  378. all_self_attns = () if output_attentions else None
  379. all_cross_attentions = None
  380. for idx, decoder_layer in enumerate(self.layers):
  381. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  382. if output_hidden_states:
  383. all_hidden_states += (hidden_states,)
  384. if self.training:
  385. dropout_probability = torch.rand([])
  386. if dropout_probability < self.layerdrop:
  387. continue
  388. layer_outputs = decoder_layer(
  389. hidden_states,
  390. attention_mask=causal_mask,
  391. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  392. past_key_values=past_key_values,
  393. output_attentions=output_attentions,
  394. use_cache=use_cache,
  395. position_ids=position_ids,
  396. cache_position=cache_position,
  397. **kwargs,
  398. )
  399. hidden_states = layer_outputs[0]
  400. if output_attentions:
  401. all_self_attns += (layer_outputs[1],)
  402. # add hidden states from the last decoder layer
  403. if output_hidden_states:
  404. all_hidden_states += (hidden_states,)
  405. hidden_states = self.layer_norm(hidden_states)
  406. if not return_dict:
  407. return tuple(
  408. v
  409. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
  410. if v is not None
  411. )
  412. return BaseModelOutputWithPastAndCrossAttentions(
  413. last_hidden_state=hidden_states,
  414. past_key_values=past_key_values,
  415. hidden_states=all_hidden_states,
  416. attentions=all_self_attns,
  417. cross_attentions=all_cross_attentions,
  418. )
  419. @auto_docstring(
  420. custom_intro="""
  421. BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
  422. """
  423. )
  424. class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
  425. _tied_weights_keys = ["output_projection.weight"]
  426. def __init__(self, config):
  427. super().__init__(config)
  428. self.biogpt = BioGptModel(config)
  429. self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  430. # Initialize weights and apply final processing
  431. self.post_init()
  432. def get_output_embeddings(self):
  433. return self.output_projection
  434. def set_output_embeddings(self, new_embeddings):
  435. self.output_projection = new_embeddings
  436. @auto_docstring
  437. def forward(
  438. self,
  439. input_ids: Optional[torch.LongTensor] = None,
  440. attention_mask: Optional[torch.FloatTensor] = None,
  441. head_mask: Optional[torch.FloatTensor] = None,
  442. inputs_embeds: Optional[torch.FloatTensor] = None,
  443. past_key_values: Optional[Cache] = None,
  444. labels: Optional[torch.LongTensor] = None,
  445. use_cache: Optional[bool] = None,
  446. position_ids: Optional[torch.LongTensor] = None,
  447. output_attentions: Optional[bool] = None,
  448. output_hidden_states: Optional[bool] = None,
  449. return_dict: Optional[bool] = None,
  450. cache_position: Optional[torch.Tensor] = None,
  451. **kwargs: Unpack[TransformersKwargs],
  452. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  453. r"""
  454. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  455. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  456. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  457. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  458. """
  459. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  460. outputs = self.biogpt(
  461. input_ids,
  462. attention_mask=attention_mask,
  463. head_mask=head_mask,
  464. inputs_embeds=inputs_embeds,
  465. past_key_values=past_key_values,
  466. use_cache=use_cache,
  467. position_ids=position_ids,
  468. output_attentions=output_attentions,
  469. output_hidden_states=output_hidden_states,
  470. return_dict=return_dict,
  471. cache_position=cache_position,
  472. **kwargs,
  473. )
  474. sequence_output = outputs[0]
  475. prediction_scores = self.output_projection(sequence_output)
  476. lm_loss = None
  477. if labels is not None:
  478. lm_loss = self.loss_function(
  479. prediction_scores,
  480. labels,
  481. vocab_size=self.config.vocab_size,
  482. **kwargs,
  483. )
  484. if not return_dict:
  485. output = (prediction_scores,) + outputs[1:]
  486. return ((lm_loss,) + output) if lm_loss is not None else output
  487. return CausalLMOutputWithCrossAttentions(
  488. loss=lm_loss,
  489. logits=prediction_scores,
  490. past_key_values=outputs.past_key_values,
  491. hidden_states=outputs.hidden_states,
  492. attentions=outputs.attentions,
  493. cross_attentions=outputs.cross_attentions,
  494. )
  495. @auto_docstring
  496. class BioGptForTokenClassification(BioGptPreTrainedModel):
  497. def __init__(self, config):
  498. super().__init__(config)
  499. self.num_labels = config.num_labels
  500. self.biogpt = BioGptModel(config)
  501. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  502. classifier_dropout = config.classifier_dropout
  503. else:
  504. classifier_dropout = config.hidden_dropout_prob
  505. self.dropout = nn.Dropout(classifier_dropout)
  506. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  507. self.post_init()
  508. @auto_docstring
  509. def forward(
  510. self,
  511. input_ids: Optional[torch.LongTensor] = None,
  512. token_type_ids: Optional[torch.LongTensor] = None,
  513. attention_mask: Optional[torch.FloatTensor] = None,
  514. head_mask: Optional[torch.FloatTensor] = None,
  515. past_key_values: Optional[Cache] = None,
  516. inputs_embeds: Optional[torch.FloatTensor] = None,
  517. labels: Optional[torch.LongTensor] = None,
  518. use_cache: Optional[bool] = None,
  519. position_ids: Optional[torch.LongTensor] = None,
  520. output_attentions: Optional[bool] = None,
  521. output_hidden_states: Optional[bool] = None,
  522. return_dict: Optional[bool] = None,
  523. cache_position: Optional[torch.Tensor] = None,
  524. ) -> Union[tuple, TokenClassifierOutput]:
  525. r"""
  526. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  527. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  528. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  529. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  530. """
  531. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  532. transformer_outputs = self.biogpt(
  533. input_ids,
  534. past_key_values=past_key_values,
  535. attention_mask=attention_mask,
  536. head_mask=head_mask,
  537. inputs_embeds=inputs_embeds,
  538. use_cache=use_cache,
  539. position_ids=position_ids,
  540. output_attentions=output_attentions,
  541. output_hidden_states=output_hidden_states,
  542. return_dict=return_dict,
  543. cache_position=cache_position,
  544. )
  545. hidden_states = transformer_outputs[0]
  546. hidden_states = self.dropout(hidden_states)
  547. logits = self.classifier(hidden_states)
  548. loss = None
  549. if labels is not None:
  550. loss_fct = CrossEntropyLoss()
  551. # Only keep active parts of the loss
  552. if attention_mask is not None:
  553. active_loss = attention_mask.view(-1) == 1
  554. active_logits = logits.view(-1, self.num_labels)
  555. active_labels = torch.where(
  556. active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
  557. )
  558. loss = loss_fct(active_logits, active_labels)
  559. else:
  560. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  561. if not return_dict:
  562. output = (logits,) + transformer_outputs[2:]
  563. return ((loss,) + output) if loss is not None else output
  564. return TokenClassifierOutput(
  565. loss=loss,
  566. logits=logits,
  567. hidden_states=transformer_outputs.hidden_states,
  568. attentions=transformer_outputs.attentions,
  569. )
  570. @auto_docstring(
  571. custom_intro="""
  572. The BioGpt Model transformer with a sequence classification head on top (linear layer).
  573. [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  574. (e.g. GPT-2) do.
  575. Since it does classification on the last token, it is required to know the position of the last token. If a
  576. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  577. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  578. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  579. each row of the batch).
  580. """
  581. )
  582. class BioGptForSequenceClassification(BioGptPreTrainedModel):
  583. def __init__(self, config: BioGptConfig):
  584. super().__init__(config)
  585. self.num_labels = config.num_labels
  586. self.biogpt = BioGptModel(config)
  587. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  588. # Initialize weights and apply final processing
  589. self.post_init()
  590. @auto_docstring
  591. def forward(
  592. self,
  593. input_ids: Optional[torch.LongTensor] = None,
  594. attention_mask: Optional[torch.FloatTensor] = None,
  595. head_mask: Optional[torch.FloatTensor] = None,
  596. past_key_values: Optional[Cache] = None,
  597. inputs_embeds: Optional[torch.FloatTensor] = None,
  598. labels: Optional[torch.LongTensor] = None,
  599. use_cache: Optional[bool] = None,
  600. position_ids: Optional[torch.LongTensor] = None,
  601. output_attentions: Optional[bool] = None,
  602. output_hidden_states: Optional[bool] = None,
  603. return_dict: Optional[bool] = None,
  604. cache_position: Optional[torch.Tensor] = None,
  605. logits_to_keep: Union[int, torch.Tensor] = 0,
  606. ) -> Union[tuple, SequenceClassifierOutputWithPast]:
  607. r"""
  608. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  609. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  610. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  611. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  612. """
  613. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  614. transformer_outputs = self.biogpt(
  615. input_ids,
  616. past_key_values=past_key_values,
  617. attention_mask=attention_mask,
  618. head_mask=head_mask,
  619. inputs_embeds=inputs_embeds,
  620. use_cache=use_cache,
  621. position_ids=position_ids,
  622. output_attentions=output_attentions,
  623. output_hidden_states=output_hidden_states,
  624. return_dict=return_dict,
  625. cache_position=cache_position,
  626. )
  627. hidden_states = transformer_outputs[0]
  628. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  629. logits = self.score(hidden_states[:, slice_indices, :])
  630. if input_ids is not None:
  631. batch_size, sequence_length = input_ids.shape[:2]
  632. else:
  633. batch_size, sequence_length = inputs_embeds.shape[:2]
  634. if self.config.pad_token_id is None:
  635. sequence_length = -1
  636. else:
  637. if input_ids is not None:
  638. sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
  639. else:
  640. sequence_length = -1
  641. logger.warning_once(
  642. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  643. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  644. )
  645. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
  646. loss = None
  647. if labels is not None:
  648. if self.config.problem_type is None:
  649. if self.num_labels == 1:
  650. self.config.problem_type = "regression"
  651. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  652. self.config.problem_type = "single_label_classification"
  653. else:
  654. self.config.problem_type = "multi_label_classification"
  655. if self.config.problem_type == "regression":
  656. loss_fct = MSELoss()
  657. if self.num_labels == 1:
  658. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  659. else:
  660. loss = loss_fct(pooled_logits, labels)
  661. elif self.config.problem_type == "single_label_classification":
  662. loss_fct = CrossEntropyLoss()
  663. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  664. elif self.config.problem_type == "multi_label_classification":
  665. loss_fct = BCEWithLogitsLoss()
  666. loss = loss_fct(pooled_logits, labels)
  667. if not return_dict:
  668. output = (pooled_logits,) + transformer_outputs[1:]
  669. return ((loss,) + output) if loss is not None else output
  670. return SequenceClassifierOutputWithPast(
  671. loss=loss,
  672. logits=pooled_logits,
  673. past_key_values=transformer_outputs.past_key_values,
  674. hidden_states=transformer_outputs.hidden_states,
  675. attentions=transformer_outputs.attentions,
  676. )
  677. def get_input_embeddings(self):
  678. return self.biogpt.embed_tokens
  679. def set_input_embeddings(self, value):
  680. self.biogpt.embed_tokens = value
# Public names of this module. The helper classes defined above (positional/word embeddings, attention and
# decoder layer) are not listed here.
__all__ = [
    "BioGptForCausalLM",
    "BioGptForTokenClassification",
    "BioGptForSequenceClassification",
    "BioGptModel",
    "BioGptPreTrainedModel",
]