modeling_mpnet.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967
  1. # coding=utf-8
  2. # Copyright 2018 The HuggingFace Inc. team, Microsoft Corporation.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch MPNet model."""
  17. import math
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...activations import ACT2FN, gelu
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. MaskedLMOutput,
  27. MultipleChoiceModelOutput,
  28. QuestionAnsweringModelOutput,
  29. SequenceClassifierOutput,
  30. TokenClassifierOutput,
  31. )
  32. from ...modeling_utils import PreTrainedModel
  33. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  34. from ...utils import auto_docstring, logging
  35. from .configuration_mpnet import MPNetConfig
  36. logger = logging.get_logger(__name__)
  37. @auto_docstring
  38. class MPNetPreTrainedModel(PreTrainedModel):
  39. config: MPNetConfig
  40. base_model_prefix = "mpnet"
  41. def _init_weights(self, module):
  42. """Initialize the weights"""
  43. if isinstance(module, nn.Linear):
  44. # Slightly different from the TF version which uses truncated_normal for initialization
  45. # cf https://github.com/pytorch/pytorch/pull/5617
  46. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  47. if module.bias is not None:
  48. module.bias.data.zero_()
  49. elif isinstance(module, nn.Embedding):
  50. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  51. if module.padding_idx is not None:
  52. module.weight.data[module.padding_idx].zero_()
  53. elif isinstance(module, nn.LayerNorm):
  54. module.bias.data.zero_()
  55. module.weight.data.fill_(1.0)
  56. elif isinstance(module, MPNetLMHead):
  57. module.bias.data.zero_()
  58. class MPNetEmbeddings(nn.Module):
  59. def __init__(self, config):
  60. super().__init__()
  61. self.padding_idx = 1
  62. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
  63. self.position_embeddings = nn.Embedding(
  64. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  65. )
  66. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  67. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  68. self.register_buffer(
  69. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  70. )
  71. def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, **kwargs):
  72. if position_ids is None:
  73. if input_ids is not None:
  74. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
  75. else:
  76. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  77. if input_ids is not None:
  78. input_shape = input_ids.size()
  79. else:
  80. input_shape = inputs_embeds.size()[:-1]
  81. seq_length = input_shape[1]
  82. if position_ids is None:
  83. position_ids = self.position_ids[:, :seq_length]
  84. if inputs_embeds is None:
  85. inputs_embeds = self.word_embeddings(input_ids)
  86. position_embeddings = self.position_embeddings(position_ids)
  87. embeddings = inputs_embeds + position_embeddings
  88. embeddings = self.LayerNorm(embeddings)
  89. embeddings = self.dropout(embeddings)
  90. return embeddings
  91. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  92. """
  93. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  94. Args:
  95. inputs_embeds: torch.Tensor
  96. Returns: torch.Tensor
  97. """
  98. input_shape = inputs_embeds.size()[:-1]
  99. sequence_length = input_shape[1]
  100. position_ids = torch.arange(
  101. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  102. )
  103. return position_ids.unsqueeze(0).expand(input_shape)
  104. class MPNetSelfAttention(nn.Module):
  105. def __init__(self, config):
  106. super().__init__()
  107. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  108. raise ValueError(
  109. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  110. f"heads ({config.num_attention_heads})"
  111. )
  112. self.num_attention_heads = config.num_attention_heads
  113. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  114. self.all_head_size = self.num_attention_heads * self.attention_head_size
  115. self.q = nn.Linear(config.hidden_size, self.all_head_size)
  116. self.k = nn.Linear(config.hidden_size, self.all_head_size)
  117. self.v = nn.Linear(config.hidden_size, self.all_head_size)
  118. self.o = nn.Linear(config.hidden_size, config.hidden_size)
  119. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  120. def forward(
  121. self,
  122. hidden_states,
  123. attention_mask=None,
  124. head_mask=None,
  125. position_bias=None,
  126. output_attentions=False,
  127. **kwargs,
  128. ):
  129. batch_size, seq_length, _ = hidden_states.shape
  130. q = (
  131. self.q(hidden_states)
  132. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  133. .transpose(1, 2)
  134. )
  135. k = (
  136. self.k(hidden_states)
  137. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  138. .transpose(1, 2)
  139. )
  140. v = (
  141. self.v(hidden_states)
  142. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  143. .transpose(1, 2)
  144. )
  145. # Take the dot product between "query" and "key" to get the raw attention scores.
  146. attention_scores = torch.matmul(q, k.transpose(-1, -2))
  147. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  148. # Apply relative position embedding (precomputed in MPNetEncoder) if provided.
  149. if position_bias is not None:
  150. attention_scores += position_bias
  151. if attention_mask is not None:
  152. attention_scores = attention_scores + attention_mask
  153. # Normalize the attention scores to probabilities.
  154. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  155. attention_probs = self.dropout(attention_probs)
  156. if head_mask is not None:
  157. attention_probs = attention_probs * head_mask
  158. c = torch.matmul(attention_probs, v)
  159. c = c.permute(0, 2, 1, 3).contiguous()
  160. new_c_shape = c.size()[:-2] + (self.all_head_size,)
  161. c = c.view(*new_c_shape)
  162. o = self.o(c)
  163. outputs = (o, attention_probs) if output_attentions else (o,)
  164. return outputs
  165. class MPNetAttention(nn.Module):
  166. def __init__(self, config):
  167. super().__init__()
  168. self.attn = MPNetSelfAttention(config)
  169. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  170. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  171. self.pruned_heads = set()
  172. def prune_heads(self, heads):
  173. if len(heads) == 0:
  174. return
  175. heads, index = find_pruneable_heads_and_indices(
  176. heads, self.attn.num_attention_heads, self.attn.attention_head_size, self.pruned_heads
  177. )
  178. self.attn.q = prune_linear_layer(self.attn.q, index)
  179. self.attn.k = prune_linear_layer(self.attn.k, index)
  180. self.attn.v = prune_linear_layer(self.attn.v, index)
  181. self.attn.o = prune_linear_layer(self.attn.o, index, dim=1)
  182. self.attn.num_attention_heads = self.attn.num_attention_heads - len(heads)
  183. self.attn.all_head_size = self.attn.attention_head_size * self.attn.num_attention_heads
  184. self.pruned_heads = self.pruned_heads.union(heads)
  185. def forward(
  186. self,
  187. hidden_states,
  188. attention_mask=None,
  189. head_mask=None,
  190. position_bias=None,
  191. output_attentions=False,
  192. **kwargs,
  193. ):
  194. self_outputs = self.attn(
  195. hidden_states,
  196. attention_mask,
  197. head_mask,
  198. position_bias,
  199. output_attentions=output_attentions,
  200. )
  201. attention_output = self.LayerNorm(self.dropout(self_outputs[0]) + hidden_states)
  202. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  203. return outputs
  204. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  205. class MPNetIntermediate(nn.Module):
  206. def __init__(self, config):
  207. super().__init__()
  208. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  209. if isinstance(config.hidden_act, str):
  210. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  211. else:
  212. self.intermediate_act_fn = config.hidden_act
  213. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  214. hidden_states = self.dense(hidden_states)
  215. hidden_states = self.intermediate_act_fn(hidden_states)
  216. return hidden_states
  217. # Copied from transformers.models.bert.modeling_bert.BertOutput
  218. class MPNetOutput(nn.Module):
  219. def __init__(self, config):
  220. super().__init__()
  221. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  222. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  223. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  224. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  225. hidden_states = self.dense(hidden_states)
  226. hidden_states = self.dropout(hidden_states)
  227. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  228. return hidden_states
  229. class MPNetLayer(nn.Module):
  230. def __init__(self, config):
  231. super().__init__()
  232. self.attention = MPNetAttention(config)
  233. self.intermediate = MPNetIntermediate(config)
  234. self.output = MPNetOutput(config)
  235. def forward(
  236. self,
  237. hidden_states,
  238. attention_mask=None,
  239. head_mask=None,
  240. position_bias=None,
  241. output_attentions=False,
  242. **kwargs,
  243. ):
  244. self_attention_outputs = self.attention(
  245. hidden_states,
  246. attention_mask,
  247. head_mask,
  248. position_bias=position_bias,
  249. output_attentions=output_attentions,
  250. )
  251. attention_output = self_attention_outputs[0]
  252. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  253. intermediate_output = self.intermediate(attention_output)
  254. layer_output = self.output(intermediate_output, attention_output)
  255. outputs = (layer_output,) + outputs
  256. return outputs
  257. class MPNetEncoder(nn.Module):
  258. def __init__(self, config):
  259. super().__init__()
  260. self.config = config
  261. self.n_heads = config.num_attention_heads
  262. self.layer = nn.ModuleList([MPNetLayer(config) for _ in range(config.num_hidden_layers)])
  263. self.relative_attention_bias = nn.Embedding(config.relative_attention_num_buckets, self.n_heads)
  264. def forward(
  265. self,
  266. hidden_states: torch.Tensor,
  267. attention_mask: Optional[torch.Tensor] = None,
  268. head_mask: Optional[torch.Tensor] = None,
  269. output_attentions: bool = False,
  270. output_hidden_states: bool = False,
  271. return_dict: bool = False,
  272. **kwargs,
  273. ):
  274. position_bias = self.compute_position_bias(hidden_states)
  275. all_hidden_states = () if output_hidden_states else None
  276. all_attentions = () if output_attentions else None
  277. for i, layer_module in enumerate(self.layer):
  278. if output_hidden_states:
  279. all_hidden_states = all_hidden_states + (hidden_states,)
  280. layer_outputs = layer_module(
  281. hidden_states,
  282. attention_mask,
  283. head_mask[i],
  284. position_bias,
  285. output_attentions=output_attentions,
  286. **kwargs,
  287. )
  288. hidden_states = layer_outputs[0]
  289. if output_attentions:
  290. all_attentions = all_attentions + (layer_outputs[1],)
  291. # Add last layer
  292. if output_hidden_states:
  293. all_hidden_states = all_hidden_states + (hidden_states,)
  294. if not return_dict:
  295. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  296. return BaseModelOutput(
  297. last_hidden_state=hidden_states,
  298. hidden_states=all_hidden_states,
  299. attentions=all_attentions,
  300. )
  301. def compute_position_bias(self, x, position_ids=None, num_buckets=32):
  302. bsz, qlen, klen = x.size(0), x.size(1), x.size(1)
  303. if position_ids is not None:
  304. context_position = position_ids[:, :, None]
  305. memory_position = position_ids[:, None, :]
  306. else:
  307. context_position = torch.arange(qlen, dtype=torch.long)[:, None]
  308. memory_position = torch.arange(klen, dtype=torch.long)[None, :]
  309. relative_position = memory_position - context_position
  310. rp_bucket = self.relative_position_bucket(relative_position, num_buckets=num_buckets)
  311. rp_bucket = rp_bucket.to(x.device)
  312. values = self.relative_attention_bias(rp_bucket)
  313. values = values.permute([2, 0, 1]).unsqueeze(0)
  314. values = values.expand((bsz, -1, qlen, klen)).contiguous()
  315. return values
  316. @staticmethod
  317. def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
  318. ret = 0
  319. n = -relative_position
  320. num_buckets //= 2
  321. ret += (n < 0).to(torch.long) * num_buckets
  322. n = torch.abs(n)
  323. max_exact = num_buckets // 2
  324. is_small = n < max_exact
  325. val_if_large = max_exact + (
  326. torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
  327. ).to(torch.long)
  328. val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
  329. ret += torch.where(is_small, n, val_if_large)
  330. return ret
  331. # Copied from transformers.models.bert.modeling_bert.BertPooler
  332. class MPNetPooler(nn.Module):
  333. def __init__(self, config):
  334. super().__init__()
  335. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  336. self.activation = nn.Tanh()
  337. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  338. # We "pool" the model by simply taking the hidden state corresponding
  339. # to the first token.
  340. first_token_tensor = hidden_states[:, 0]
  341. pooled_output = self.dense(first_token_tensor)
  342. pooled_output = self.activation(pooled_output)
  343. return pooled_output
  344. @auto_docstring
  345. class MPNetModel(MPNetPreTrainedModel):
  346. def __init__(self, config, add_pooling_layer=True):
  347. r"""
  348. add_pooling_layer (bool, *optional*, defaults to `True`):
  349. Whether to add a pooling layer
  350. """
  351. super().__init__(config)
  352. self.config = config
  353. self.embeddings = MPNetEmbeddings(config)
  354. self.encoder = MPNetEncoder(config)
  355. self.pooler = MPNetPooler(config) if add_pooling_layer else None
  356. # Initialize weights and apply final processing
  357. self.post_init()
  358. def get_input_embeddings(self):
  359. return self.embeddings.word_embeddings
  360. def set_input_embeddings(self, value):
  361. self.embeddings.word_embeddings = value
  362. def _prune_heads(self, heads_to_prune):
  363. """
  364. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  365. class PreTrainedModel
  366. """
  367. for layer, heads in heads_to_prune.items():
  368. self.encoder.layer[layer].attention.prune_heads(heads)
  369. @auto_docstring
  370. def forward(
  371. self,
  372. input_ids: Optional[torch.LongTensor] = None,
  373. attention_mask: Optional[torch.FloatTensor] = None,
  374. position_ids: Optional[torch.LongTensor] = None,
  375. head_mask: Optional[torch.FloatTensor] = None,
  376. inputs_embeds: Optional[torch.FloatTensor] = None,
  377. output_attentions: Optional[bool] = None,
  378. output_hidden_states: Optional[bool] = None,
  379. return_dict: Optional[bool] = None,
  380. **kwargs,
  381. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
  382. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  383. output_hidden_states = (
  384. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  385. )
  386. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  387. if input_ids is not None and inputs_embeds is not None:
  388. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  389. elif input_ids is not None:
  390. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  391. input_shape = input_ids.size()
  392. elif inputs_embeds is not None:
  393. input_shape = inputs_embeds.size()[:-1]
  394. else:
  395. raise ValueError("You have to specify either input_ids or inputs_embeds")
  396. device = input_ids.device if input_ids is not None else inputs_embeds.device
  397. if attention_mask is None:
  398. attention_mask = torch.ones(input_shape, device=device)
  399. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  400. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  401. embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
  402. encoder_outputs = self.encoder(
  403. embedding_output,
  404. attention_mask=extended_attention_mask,
  405. head_mask=head_mask,
  406. output_attentions=output_attentions,
  407. output_hidden_states=output_hidden_states,
  408. return_dict=return_dict,
  409. )
  410. sequence_output = encoder_outputs[0]
  411. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  412. if not return_dict:
  413. return (sequence_output, pooled_output) + encoder_outputs[1:]
  414. return BaseModelOutputWithPooling(
  415. last_hidden_state=sequence_output,
  416. pooler_output=pooled_output,
  417. hidden_states=encoder_outputs.hidden_states,
  418. attentions=encoder_outputs.attentions,
  419. )
  420. class MPNetForMaskedLM(MPNetPreTrainedModel):
  421. _tied_weights_keys = ["lm_head.decoder"]
  422. def __init__(self, config):
  423. super().__init__(config)
  424. self.mpnet = MPNetModel(config, add_pooling_layer=False)
  425. self.lm_head = MPNetLMHead(config)
  426. # Initialize weights and apply final processing
  427. self.post_init()
  428. def get_output_embeddings(self):
  429. return self.lm_head.decoder
  430. def set_output_embeddings(self, new_embeddings):
  431. self.lm_head.decoder = new_embeddings
  432. self.lm_head.bias = new_embeddings.bias
  433. @auto_docstring
  434. def forward(
  435. self,
  436. input_ids: Optional[torch.LongTensor] = None,
  437. attention_mask: Optional[torch.FloatTensor] = None,
  438. position_ids: Optional[torch.LongTensor] = None,
  439. head_mask: Optional[torch.FloatTensor] = None,
  440. inputs_embeds: Optional[torch.FloatTensor] = None,
  441. labels: Optional[torch.LongTensor] = None,
  442. output_attentions: Optional[bool] = None,
  443. output_hidden_states: Optional[bool] = None,
  444. return_dict: Optional[bool] = None,
  445. ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
  446. r"""
  447. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  448. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  449. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  450. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  451. """
  452. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  453. outputs = self.mpnet(
  454. input_ids,
  455. attention_mask=attention_mask,
  456. position_ids=position_ids,
  457. head_mask=head_mask,
  458. inputs_embeds=inputs_embeds,
  459. output_attentions=output_attentions,
  460. output_hidden_states=output_hidden_states,
  461. return_dict=return_dict,
  462. )
  463. sequence_output = outputs[0]
  464. prediction_scores = self.lm_head(sequence_output)
  465. masked_lm_loss = None
  466. if labels is not None:
  467. loss_fct = CrossEntropyLoss()
  468. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  469. if not return_dict:
  470. output = (prediction_scores,) + outputs[2:]
  471. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  472. return MaskedLMOutput(
  473. loss=masked_lm_loss,
  474. logits=prediction_scores,
  475. hidden_states=outputs.hidden_states,
  476. attentions=outputs.attentions,
  477. )
  478. class MPNetLMHead(nn.Module):
  479. """MPNet Head for masked and permuted language modeling."""
  480. def __init__(self, config):
  481. super().__init__()
  482. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  483. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  484. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  485. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  486. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  487. self.decoder.bias = self.bias
  488. def _tie_weights(self):
  489. self.decoder.bias = self.bias
  490. def forward(self, features, **kwargs):
  491. x = self.dense(features)
  492. x = gelu(x)
  493. x = self.layer_norm(x)
  494. # project back to size of vocabulary with bias
  495. x = self.decoder(x)
  496. return x
  497. @auto_docstring(
  498. custom_intro="""
  499. MPNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  500. output) e.g. for GLUE tasks.
  501. """
  502. )
  503. class MPNetForSequenceClassification(MPNetPreTrainedModel):
  504. def __init__(self, config):
  505. super().__init__(config)
  506. self.num_labels = config.num_labels
  507. self.mpnet = MPNetModel(config, add_pooling_layer=False)
  508. self.classifier = MPNetClassificationHead(config)
  509. # Initialize weights and apply final processing
  510. self.post_init()
  511. @auto_docstring
  512. def forward(
  513. self,
  514. input_ids: Optional[torch.LongTensor] = None,
  515. attention_mask: Optional[torch.FloatTensor] = None,
  516. position_ids: Optional[torch.LongTensor] = None,
  517. head_mask: Optional[torch.FloatTensor] = None,
  518. inputs_embeds: Optional[torch.FloatTensor] = None,
  519. labels: Optional[torch.LongTensor] = None,
  520. output_attentions: Optional[bool] = None,
  521. output_hidden_states: Optional[bool] = None,
  522. return_dict: Optional[bool] = None,
  523. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  524. r"""
  525. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  526. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  527. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  528. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  529. """
  530. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  531. outputs = self.mpnet(
  532. input_ids,
  533. attention_mask=attention_mask,
  534. position_ids=position_ids,
  535. head_mask=head_mask,
  536. inputs_embeds=inputs_embeds,
  537. output_attentions=output_attentions,
  538. output_hidden_states=output_hidden_states,
  539. return_dict=return_dict,
  540. )
  541. sequence_output = outputs[0]
  542. logits = self.classifier(sequence_output)
  543. loss = None
  544. if labels is not None:
  545. if self.config.problem_type is None:
  546. if self.num_labels == 1:
  547. self.config.problem_type = "regression"
  548. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  549. self.config.problem_type = "single_label_classification"
  550. else:
  551. self.config.problem_type = "multi_label_classification"
  552. if self.config.problem_type == "regression":
  553. loss_fct = MSELoss()
  554. if self.num_labels == 1:
  555. loss = loss_fct(logits.squeeze(), labels.squeeze())
  556. else:
  557. loss = loss_fct(logits, labels)
  558. elif self.config.problem_type == "single_label_classification":
  559. loss_fct = CrossEntropyLoss()
  560. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  561. elif self.config.problem_type == "multi_label_classification":
  562. loss_fct = BCEWithLogitsLoss()
  563. loss = loss_fct(logits, labels)
  564. if not return_dict:
  565. output = (logits,) + outputs[2:]
  566. return ((loss,) + output) if loss is not None else output
  567. return SequenceClassifierOutput(
  568. loss=loss,
  569. logits=logits,
  570. hidden_states=outputs.hidden_states,
  571. attentions=outputs.attentions,
  572. )
  573. @auto_docstring
  574. class MPNetForMultipleChoice(MPNetPreTrainedModel):
  575. def __init__(self, config):
  576. super().__init__(config)
  577. self.mpnet = MPNetModel(config)
  578. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  579. self.classifier = nn.Linear(config.hidden_size, 1)
  580. # Initialize weights and apply final processing
  581. self.post_init()
  582. @auto_docstring
  583. def forward(
  584. self,
  585. input_ids: Optional[torch.LongTensor] = None,
  586. attention_mask: Optional[torch.FloatTensor] = None,
  587. position_ids: Optional[torch.LongTensor] = None,
  588. head_mask: Optional[torch.FloatTensor] = None,
  589. inputs_embeds: Optional[torch.FloatTensor] = None,
  590. labels: Optional[torch.LongTensor] = None,
  591. output_attentions: Optional[bool] = None,
  592. output_hidden_states: Optional[bool] = None,
  593. return_dict: Optional[bool] = None,
  594. ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
  595. r"""
  596. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  597. Indices of input sequence tokens in the vocabulary.
  598. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  599. [`PreTrainedTokenizer.__call__`] for details.
  600. [What are input IDs?](../glossary#input-ids)
  601. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  602. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  603. config.max_position_embeddings - 1]`.
  604. [What are position IDs?](../glossary#position-ids)
  605. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  606. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  607. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  608. model's internal embedding lookup matrix.
  609. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  610. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  611. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  612. `input_ids` above)
  613. """
  614. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  615. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  616. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  617. flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  618. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  619. flat_inputs_embeds = (
  620. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  621. if inputs_embeds is not None
  622. else None
  623. )
  624. outputs = self.mpnet(
  625. flat_input_ids,
  626. position_ids=flat_position_ids,
  627. attention_mask=flat_attention_mask,
  628. head_mask=head_mask,
  629. inputs_embeds=flat_inputs_embeds,
  630. output_attentions=output_attentions,
  631. output_hidden_states=output_hidden_states,
  632. return_dict=return_dict,
  633. )
  634. pooled_output = outputs[1]
  635. pooled_output = self.dropout(pooled_output)
  636. logits = self.classifier(pooled_output)
  637. reshaped_logits = logits.view(-1, num_choices)
  638. loss = None
  639. if labels is not None:
  640. loss_fct = CrossEntropyLoss()
  641. loss = loss_fct(reshaped_logits, labels)
  642. if not return_dict:
  643. output = (reshaped_logits,) + outputs[2:]
  644. return ((loss,) + output) if loss is not None else output
  645. return MultipleChoiceModelOutput(
  646. loss=loss,
  647. logits=reshaped_logits,
  648. hidden_states=outputs.hidden_states,
  649. attentions=outputs.attentions,
  650. )
  651. @auto_docstring
  652. class MPNetForTokenClassification(MPNetPreTrainedModel):
  653. def __init__(self, config):
  654. super().__init__(config)
  655. self.num_labels = config.num_labels
  656. self.mpnet = MPNetModel(config, add_pooling_layer=False)
  657. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  658. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  659. # Initialize weights and apply final processing
  660. self.post_init()
  661. @auto_docstring
  662. def forward(
  663. self,
  664. input_ids: Optional[torch.LongTensor] = None,
  665. attention_mask: Optional[torch.FloatTensor] = None,
  666. position_ids: Optional[torch.LongTensor] = None,
  667. head_mask: Optional[torch.FloatTensor] = None,
  668. inputs_embeds: Optional[torch.FloatTensor] = None,
  669. labels: Optional[torch.LongTensor] = None,
  670. output_attentions: Optional[bool] = None,
  671. output_hidden_states: Optional[bool] = None,
  672. return_dict: Optional[bool] = None,
  673. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  674. r"""
  675. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  676. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  677. """
  678. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  679. outputs = self.mpnet(
  680. input_ids,
  681. attention_mask=attention_mask,
  682. position_ids=position_ids,
  683. head_mask=head_mask,
  684. inputs_embeds=inputs_embeds,
  685. output_attentions=output_attentions,
  686. output_hidden_states=output_hidden_states,
  687. return_dict=return_dict,
  688. )
  689. sequence_output = outputs[0]
  690. sequence_output = self.dropout(sequence_output)
  691. logits = self.classifier(sequence_output)
  692. loss = None
  693. if labels is not None:
  694. loss_fct = CrossEntropyLoss()
  695. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  696. if not return_dict:
  697. output = (logits,) + outputs[2:]
  698. return ((loss,) + output) if loss is not None else output
  699. return TokenClassifierOutput(
  700. loss=loss,
  701. logits=logits,
  702. hidden_states=outputs.hidden_states,
  703. attentions=outputs.attentions,
  704. )
  705. class MPNetClassificationHead(nn.Module):
  706. """Head for sentence-level classification tasks."""
  707. def __init__(self, config):
  708. super().__init__()
  709. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  710. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  711. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  712. def forward(self, features, **kwargs):
  713. x = features[:, 0, :] # take <s> token (equiv. to BERT's [CLS] token)
  714. x = self.dropout(x)
  715. x = self.dense(x)
  716. x = torch.tanh(x)
  717. x = self.dropout(x)
  718. x = self.out_proj(x)
  719. return x
  720. @auto_docstring
  721. class MPNetForQuestionAnswering(MPNetPreTrainedModel):
  722. def __init__(self, config):
  723. super().__init__(config)
  724. self.num_labels = config.num_labels
  725. self.mpnet = MPNetModel(config, add_pooling_layer=False)
  726. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  727. # Initialize weights and apply final processing
  728. self.post_init()
  729. @auto_docstring
  730. def forward(
  731. self,
  732. input_ids: Optional[torch.LongTensor] = None,
  733. attention_mask: Optional[torch.FloatTensor] = None,
  734. position_ids: Optional[torch.LongTensor] = None,
  735. head_mask: Optional[torch.FloatTensor] = None,
  736. inputs_embeds: Optional[torch.FloatTensor] = None,
  737. start_positions: Optional[torch.LongTensor] = None,
  738. end_positions: Optional[torch.LongTensor] = None,
  739. output_attentions: Optional[bool] = None,
  740. output_hidden_states: Optional[bool] = None,
  741. return_dict: Optional[bool] = None,
  742. ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  743. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  744. outputs = self.mpnet(
  745. input_ids,
  746. attention_mask=attention_mask,
  747. position_ids=position_ids,
  748. head_mask=head_mask,
  749. inputs_embeds=inputs_embeds,
  750. output_attentions=output_attentions,
  751. output_hidden_states=output_hidden_states,
  752. return_dict=return_dict,
  753. )
  754. sequence_output = outputs[0]
  755. logits = self.qa_outputs(sequence_output)
  756. start_logits, end_logits = logits.split(1, dim=-1)
  757. start_logits = start_logits.squeeze(-1).contiguous()
  758. end_logits = end_logits.squeeze(-1).contiguous()
  759. total_loss = None
  760. if start_positions is not None and end_positions is not None:
  761. # If we are on multi-GPU, split add a dimension
  762. if len(start_positions.size()) > 1:
  763. start_positions = start_positions.squeeze(-1)
  764. if len(end_positions.size()) > 1:
  765. end_positions = end_positions.squeeze(-1)
  766. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  767. ignored_index = start_logits.size(1)
  768. start_positions = start_positions.clamp(0, ignored_index)
  769. end_positions = end_positions.clamp(0, ignored_index)
  770. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  771. start_loss = loss_fct(start_logits, start_positions)
  772. end_loss = loss_fct(end_logits, end_positions)
  773. total_loss = (start_loss + end_loss) / 2
  774. if not return_dict:
  775. output = (start_logits, end_logits) + outputs[2:]
  776. return ((total_loss,) + output) if total_loss is not None else output
  777. return QuestionAnsweringModelOutput(
  778. loss=total_loss,
  779. start_logits=start_logits,
  780. end_logits=end_logits,
  781. hidden_states=outputs.hidden_states,
  782. attentions=outputs.attentions,
  783. )
  784. def create_position_ids_from_input_ids(input_ids, padding_idx):
  785. """
  786. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  787. are ignored. This is modified from fairseq's `utils.make_positions`. :param torch.Tensor x: :return torch.Tensor:
  788. """
  789. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  790. mask = input_ids.ne(padding_idx).int()
  791. incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
  792. return incremental_indices.long() + padding_idx
# Public names of this module: the list consulted by `from <module> import *`.
__all__ = [
    "MPNetForMaskedLM",
    "MPNetForMultipleChoice",
    "MPNetForQuestionAnswering",
    "MPNetForSequenceClassification",
    "MPNetForTokenClassification",
    "MPNetLayer",
    "MPNetModel",
    "MPNetPreTrainedModel",
]