modeling_mvp.py 80 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786
  1. # coding=utf-8
  2. # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch MVP model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_attn_mask_utils import (
  25. _prepare_4d_attention_mask,
  26. _prepare_4d_causal_attention_mask,
  27. )
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import (
  30. BaseModelOutput,
  31. BaseModelOutputWithPastAndCrossAttentions,
  32. CausalLMOutputWithCrossAttentions,
  33. Seq2SeqLMOutput,
  34. Seq2SeqModelOutput,
  35. Seq2SeqQuestionAnsweringModelOutput,
  36. Seq2SeqSequenceClassifierOutput,
  37. )
  38. from ...modeling_utils import PreTrainedModel
  39. from ...utils import auto_docstring, logging
  40. from ...utils.deprecation import deprecate_kwarg
  41. from .configuration_mvp import MvpConfig
  42. logger = logging.get_logger(__name__)
  43. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  44. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  45. """
  46. Shift input ids one token to the right.
  47. """
  48. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  49. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  50. shifted_input_ids[:, 0] = decoder_start_token_id
  51. if pad_token_id is None:
  52. raise ValueError("self.model.config.pad_token_id has to be defined.")
  53. # replace possible -100 values in labels by `pad_token_id`
  54. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  55. return shifted_input_ids
  56. # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->Mvp
  57. class MvpLearnedPositionalEmbedding(nn.Embedding):
  58. """
  59. This module learns positional embeddings up to a fixed maximum size.
  60. """
  61. def __init__(self, num_embeddings: int, embedding_dim: int):
  62. # Mvp is set up so that if padding_idx is specified then offset the embedding ids by 2
  63. # and adjust num_embeddings appropriately. Other models don't have this hack
  64. self.offset = 2
  65. super().__init__(num_embeddings + self.offset, embedding_dim)
  66. def forward(
  67. self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
  68. ):
  69. """`input_ids' shape is expected to be [bsz x seqlen]."""
  70. if position_ids is None:
  71. bsz, seq_len = input_ids.shape[:2]
  72. position_ids = torch.arange(
  73. past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
  74. ).expand(bsz, -1)
  75. else:
  76. position_ids = position_ids.unsqueeze(0)
  77. return super().forward(position_ids + self.offset)
  78. class MvpAttention(nn.Module):
  79. """Multi-headed attention from 'Attention Is All You Need' paper"""
  80. def __init__(
  81. self,
  82. embed_dim: int,
  83. num_heads: int,
  84. dropout: Optional[float] = 0.0,
  85. is_decoder: Optional[bool] = False,
  86. bias: Optional[bool] = True,
  87. layer_idx: Optional[bool] = None,
  88. ):
  89. super().__init__()
  90. self.embed_dim = embed_dim
  91. self.num_heads = num_heads
  92. self.dropout = dropout
  93. self.head_dim = embed_dim // num_heads
  94. if (self.head_dim * num_heads) != self.embed_dim:
  95. raise ValueError(
  96. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  97. f" and `num_heads`: {num_heads})."
  98. )
  99. self.scaling = self.head_dim**-0.5
  100. self.is_decoder = is_decoder
  101. self.layer_idx = layer_idx
  102. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  103. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  104. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  105. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  106. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  107. def forward(
  108. self,
  109. hidden_states: torch.Tensor,
  110. key_value_states: Optional[torch.Tensor] = None,
  111. past_key_values: Optional[Cache] = None,
  112. attention_mask: Optional[torch.Tensor] = None,
  113. layer_head_mask: Optional[torch.Tensor] = None,
  114. attn_prompt: Optional[torch.Tensor] = None,
  115. output_attentions: bool = False,
  116. cache_position: Optional[torch.Tensor] = None,
  117. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  118. """Input shape: Batch x Time x Channel"""
  119. # if key_value_states are provided this layer is used as a cross-attention layer
  120. # for the decoder
  121. is_cross_attention = key_value_states is not None
  122. bsz, tgt_len, _ = hidden_states.size()
  123. # get query proj
  124. query_states = self.q_proj(hidden_states) * self.scaling
  125. is_updated = False
  126. if past_key_values is not None:
  127. if isinstance(past_key_values, EncoderDecoderCache):
  128. is_updated = past_key_values.is_updated.get(self.layer_idx)
  129. if is_cross_attention:
  130. # after the first generated id, we can subsequently re-use all key/value_states from cache
  131. curr_past_key_value = past_key_values.cross_attention_cache
  132. else:
  133. curr_past_key_value = past_key_values.self_attention_cache
  134. else:
  135. curr_past_key_value = past_key_values
  136. current_states = key_value_states if is_cross_attention else hidden_states
  137. if is_cross_attention and past_key_values is not None and is_updated:
  138. # reuse k,v, cross_attentions
  139. key_states = curr_past_key_value.layers[self.layer_idx].keys
  140. value_states = curr_past_key_value.layers[self.layer_idx].values
  141. else:
  142. key_states = self.k_proj(current_states)
  143. value_states = self.v_proj(current_states)
  144. key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  145. value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  146. if past_key_values is not None:
  147. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  148. cache_position = cache_position if not is_cross_attention else None
  149. key_states, value_states = curr_past_key_value.update(
  150. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  151. )
  152. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  153. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  154. past_key_values.is_updated[self.layer_idx] = True
  155. if attn_prompt is not None:
  156. key_states = torch.cat([attn_prompt[0].expand(bsz, -1, -1, -1), key_states], dim=2)
  157. value_states = torch.cat([attn_prompt[1].expand(bsz, -1, -1, -1), value_states], dim=2)
  158. if attention_mask is not None:
  159. prompt_mask = torch.zeros(bsz, 1, tgt_len, attn_prompt[0].size(1)).to(attention_mask.device)
  160. attention_mask = torch.cat([prompt_mask, attention_mask], dim=(-1))
  161. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  162. query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  163. query_states = query_states.reshape(*proj_shape)
  164. key_states = key_states.reshape(*proj_shape)
  165. value_states = value_states.reshape(*proj_shape)
  166. src_len = key_states.size(1)
  167. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  168. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  169. raise ValueError(
  170. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  171. f" {attn_weights.size()}"
  172. )
  173. if attention_mask is not None:
  174. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  175. raise ValueError(
  176. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  177. )
  178. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  179. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  180. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  181. if layer_head_mask is not None:
  182. if layer_head_mask.size() != (self.num_heads,):
  183. raise ValueError(
  184. f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
  185. f" {layer_head_mask.size()}"
  186. )
  187. attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  188. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  189. if output_attentions:
  190. # this operation is a bit awkward, but it's required to
  191. # make sure that attn_weights keeps its gradient.
  192. # In order to do so, attn_weights have to be reshaped
  193. # twice and have to be reused in the following
  194. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  195. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  196. else:
  197. attn_weights_reshaped = None
  198. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  199. attn_output = torch.bmm(attn_probs, value_states)
  200. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  201. raise ValueError(
  202. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  203. f" {attn_output.size()}"
  204. )
  205. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  206. attn_output = attn_output.transpose(1, 2)
  207. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  208. # partitioned across GPUs when using tensor-parallelism.
  209. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  210. attn_output = self.out_proj(attn_output)
  211. return attn_output, attn_weights_reshaped
  212. class MvpEncoderLayer(GradientCheckpointingLayer):
  213. def __init__(self, config: MvpConfig):
  214. super().__init__()
  215. self.embed_dim = config.d_model
  216. self.self_attn = MvpAttention(
  217. embed_dim=self.embed_dim,
  218. num_heads=config.encoder_attention_heads,
  219. dropout=config.attention_dropout,
  220. )
  221. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  222. self.dropout = config.dropout
  223. self.activation_fn = ACT2FN[config.activation_function]
  224. self.activation_dropout = config.activation_dropout
  225. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  226. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  227. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  228. def forward(
  229. self,
  230. hidden_states: torch.FloatTensor,
  231. attention_mask: torch.FloatTensor,
  232. layer_head_mask: torch.FloatTensor,
  233. self_attn_prompt: torch.FloatTensor,
  234. output_attentions: Optional[bool] = False,
  235. ) -> tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
  236. """
  237. Args:
  238. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  239. attention_mask (`torch.FloatTensor`): attention mask of size
  240. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  241. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  242. `(encoder_attention_heads,)`.
  243. self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape
  244. `(2, encoder_attention_heads, pro_len, head_dim)`.
  245. output_attentions (`bool`, *optional*):
  246. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  247. returned tensors for more detail.
  248. """
  249. residual = hidden_states
  250. hidden_states, attn_weights = self.self_attn(
  251. hidden_states=hidden_states,
  252. attention_mask=attention_mask,
  253. layer_head_mask=layer_head_mask,
  254. attn_prompt=self_attn_prompt,
  255. output_attentions=output_attentions,
  256. )
  257. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  258. hidden_states = residual + hidden_states
  259. hidden_states = self.self_attn_layer_norm(hidden_states)
  260. residual = hidden_states
  261. hidden_states = self.activation_fn(self.fc1(hidden_states))
  262. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  263. hidden_states = self.fc2(hidden_states)
  264. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  265. hidden_states = residual + hidden_states
  266. hidden_states = self.final_layer_norm(hidden_states)
  267. if hidden_states.dtype == torch.float16 and (
  268. torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
  269. ):
  270. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  271. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  272. return hidden_states, attn_weights
  273. class MvpDecoderLayer(GradientCheckpointingLayer):
  274. def __init__(self, config: MvpConfig, layer_idx=None):
  275. super().__init__()
  276. self.embed_dim = config.d_model
  277. self.self_attn = MvpAttention(
  278. embed_dim=self.embed_dim,
  279. num_heads=config.decoder_attention_heads,
  280. dropout=config.attention_dropout,
  281. is_decoder=True,
  282. layer_idx=layer_idx,
  283. )
  284. self.dropout = config.dropout
  285. self.activation_fn = ACT2FN[config.activation_function]
  286. self.activation_dropout = config.activation_dropout
  287. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  288. self.encoder_attn = MvpAttention(
  289. self.embed_dim,
  290. config.decoder_attention_heads,
  291. dropout=config.attention_dropout,
  292. is_decoder=True,
  293. layer_idx=layer_idx,
  294. )
  295. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  296. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  297. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  298. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  299. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  300. def forward(
  301. self,
  302. hidden_states: torch.Tensor,
  303. attention_mask: Optional[torch.Tensor] = None,
  304. encoder_hidden_states: Optional[torch.Tensor] = None,
  305. encoder_attention_mask: Optional[torch.Tensor] = None,
  306. layer_head_mask: Optional[torch.Tensor] = None,
  307. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  308. self_attn_prompt: Optional[torch.Tensor] = None,
  309. cross_attn_prompt: Optional[torch.Tensor] = None,
  310. past_key_values: Optional[Cache] = None,
  311. output_attentions: Optional[bool] = False,
  312. use_cache: Optional[bool] = True,
  313. cache_position: Optional[torch.Tensor] = None,
  314. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  315. """
  316. Args:
  317. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  318. attention_mask (`torch.FloatTensor`): attention mask of size
  319. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  320. encoder_hidden_states (`torch.FloatTensor`):
  321. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  322. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  323. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  324. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  325. `(encoder_attention_heads,)`.
  326. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  327. size `(decoder_attention_heads,)`.
  328. self_attn_prompt (`torch.FloatTensor`): prompt of self attention of shape
  329. `(2, decoder_attention_heads, pro_len, head_dim)`.
  330. cross_attn_prompt (`torch.FloatTensor`): prompt of cross attention of shape
  331. `(2, decoder_attention_heads, pro_len, head_dim)`.
  332. past_key_values (`Cache`): cached past key and value projection states
  333. output_attentions (`bool`, *optional*):
  334. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  335. returned tensors for more detail.
  336. """
  337. residual = hidden_states
  338. # Self Attention
  339. hidden_states, self_attn_weights = self.self_attn(
  340. hidden_states=hidden_states,
  341. past_key_values=past_key_values,
  342. attention_mask=attention_mask,
  343. layer_head_mask=layer_head_mask,
  344. attn_prompt=self_attn_prompt,
  345. output_attentions=output_attentions,
  346. cache_position=cache_position,
  347. )
  348. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  349. hidden_states = residual + hidden_states
  350. hidden_states = self.self_attn_layer_norm(hidden_states)
  351. # Cross-Attention Block
  352. cross_attn_weights = None
  353. if encoder_hidden_states is not None:
  354. residual = hidden_states
  355. hidden_states, cross_attn_weights = self.encoder_attn(
  356. hidden_states=hidden_states,
  357. key_value_states=encoder_hidden_states,
  358. attention_mask=encoder_attention_mask,
  359. layer_head_mask=cross_attn_layer_head_mask,
  360. attn_prompt=cross_attn_prompt,
  361. past_key_values=past_key_values,
  362. output_attentions=output_attentions,
  363. )
  364. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  365. hidden_states = residual + hidden_states
  366. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  367. # Fully Connected
  368. residual = hidden_states
  369. hidden_states = self.activation_fn(self.fc1(hidden_states))
  370. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  371. hidden_states = self.fc2(hidden_states)
  372. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  373. hidden_states = residual + hidden_states
  374. hidden_states = self.final_layer_norm(hidden_states)
  375. outputs = (hidden_states,)
  376. if output_attentions:
  377. outputs += (self_attn_weights, cross_attn_weights)
  378. return outputs
  379. # Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->MVP
  380. class MvpClassificationHead(nn.Module):
  381. """Head for sentence-level classification tasks."""
  382. def __init__(
  383. self,
  384. input_dim: int,
  385. inner_dim: int,
  386. num_classes: int,
  387. pooler_dropout: float,
  388. ):
  389. super().__init__()
  390. self.dense = nn.Linear(input_dim, inner_dim)
  391. self.dropout = nn.Dropout(p=pooler_dropout)
  392. self.out_proj = nn.Linear(inner_dim, num_classes)
  393. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  394. hidden_states = self.dropout(hidden_states)
  395. hidden_states = self.dense(hidden_states)
  396. hidden_states = torch.tanh(hidden_states)
  397. hidden_states = self.dropout(hidden_states)
  398. hidden_states = self.out_proj(hidden_states)
  399. return hidden_states
  400. class MvpPrompt(nn.Module):
  401. """Layer-wise prompt for encoder or decoder."""
  402. def __init__(self, config, num_layers, num_heads):
  403. super().__init__()
  404. self.prompt_length = config.prompt_length
  405. self.num_layers = num_layers
  406. self.num_heads = num_heads
  407. self.head_dim = config.d_model // num_heads
  408. self.dropout = nn.Dropout(p=config.dropout)
  409. self.prompt_embedding = nn.Embedding(config.prompt_length, config.d_model)
  410. self.prompt_trans = nn.Sequential(
  411. nn.Linear(config.d_model, config.prompt_mid_dim),
  412. nn.GELU(),
  413. nn.Linear(config.prompt_mid_dim, num_layers * 2 * config.d_model),
  414. )
  415. def forward(self, prompt_ids: torch.Tensor) -> tuple[torch.Tensor]:
  416. prompt = self.prompt_trans(self.prompt_embedding(prompt_ids))
  417. prompt = prompt.view(self.prompt_length, self.num_layers * 2, self.num_heads, self.head_dim)
  418. prompt = self.dropout(prompt)
  419. prompt = prompt.permute([1, 2, 0, 3]).split(2)
  420. return prompt
  421. @auto_docstring
  422. class MvpPreTrainedModel(PreTrainedModel):
  423. config: MvpConfig
  424. base_model_prefix = "model"
  425. supports_gradient_checkpointing = True
  426. def _init_weights(self, module):
  427. std = self.config.init_std
  428. if isinstance(module, nn.Linear):
  429. module.weight.data.normal_(mean=0.0, std=std)
  430. if module.bias is not None:
  431. module.bias.data.zero_()
  432. elif isinstance(module, nn.Embedding):
  433. module.weight.data.normal_(mean=0.0, std=std)
  434. if module.padding_idx is not None:
  435. module.weight.data[module.padding_idx].zero_()
  436. @property
  437. def dummy_inputs(self):
  438. pad_token = self.config.pad_token_id
  439. input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
  440. dummy_inputs = {
  441. "attention_mask": input_ids.ne(pad_token),
  442. "input_ids": input_ids,
  443. }
  444. return dummy_inputs
  445. class MvpEncoder(MvpPreTrainedModel):
  446. """
  447. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  448. [`MvpEncoderLayer`].
  449. Args:
  450. config: MvpConfig
  451. embed_tokens (nn.Embedding): output embedding
  452. use_prompt (bool): whether to use prompt
  453. """
  454. def __init__(
  455. self, config: MvpConfig, embed_tokens: Optional[nn.Embedding] = None, use_prompt: Optional[bool] = False
  456. ):
  457. super().__init__(config)
  458. self.dropout = config.dropout
  459. self.layerdrop = config.encoder_layerdrop
  460. embed_dim = config.d_model
  461. self.padding_idx = config.pad_token_id
  462. self.max_source_positions = config.max_position_embeddings
  463. self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
  464. if embed_tokens is not None:
  465. self.embed_tokens = embed_tokens
  466. else:
  467. self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
  468. self.embed_positions = MvpLearnedPositionalEmbedding(
  469. config.max_position_embeddings,
  470. embed_dim,
  471. )
  472. self.layers = nn.ModuleList([MvpEncoderLayer(config) for _ in range(config.encoder_layers)])
  473. self.layernorm_embedding = nn.LayerNorm(embed_dim)
  474. self.use_prompt = use_prompt
  475. if use_prompt:
  476. self.prompt_length = config.prompt_length
  477. self.self_attn_prompt = MvpPrompt(
  478. config,
  479. config.encoder_layers,
  480. config.encoder_attention_heads,
  481. )
  482. self.gradient_checkpointing = False
  483. # Initialize weights and apply final processing
  484. self.post_init()
  485. def forward(
  486. self,
  487. input_ids: Optional[torch.LongTensor] = None,
  488. attention_mask: Optional[torch.Tensor] = None,
  489. head_mask: Optional[torch.Tensor] = None,
  490. inputs_embeds: Optional[torch.FloatTensor] = None,
  491. output_attentions: Optional[bool] = None,
  492. output_hidden_states: Optional[bool] = None,
  493. return_dict: Optional[bool] = None,
  494. ) -> Union[tuple, BaseModelOutput]:
  495. r"""
  496. Args:
  497. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  498. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  499. provide it.
  500. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  501. [`PreTrainedTokenizer.__call__`] for details.
  502. [What are input IDs?](../glossary#input-ids)
  503. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  504. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  505. - 1 for tokens that are **not masked**,
  506. - 0 for tokens that are **masked**.
  507. [What are attention masks?](../glossary#attention-mask)
  508. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  509. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  510. - 1 indicates the head is **not masked**,
  511. - 0 indicates the head is **masked**.
  512. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  513. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  514. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  515. than the model's internal embedding lookup matrix.
  516. output_attentions (`bool`, *optional*):
  517. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  518. returned tensors for more detail.
  519. output_hidden_states (`bool`, *optional*):
  520. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  521. for more detail.
  522. return_dict (`bool`, *optional*):
  523. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  524. """
  525. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  526. output_hidden_states = (
  527. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  528. )
  529. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  530. # retrieve input_ids and inputs_embeds
  531. if input_ids is not None and inputs_embeds is not None:
  532. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  533. elif input_ids is not None:
  534. input = input_ids
  535. input_shape = input.shape
  536. input_ids = input_ids.view(-1, input_shape[-1])
  537. elif inputs_embeds is not None:
  538. input_shape = inputs_embeds.size()[:-1]
  539. input = inputs_embeds[:, :, -1]
  540. else:
  541. raise ValueError("You have to specify either input_ids or inputs_embeds")
  542. if inputs_embeds is None:
  543. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  544. embed_pos = self.embed_positions(input)
  545. hidden_states = inputs_embeds + embed_pos
  546. hidden_states = self.layernorm_embedding(hidden_states)
  547. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  548. # layer-wise prompt
  549. if self.use_prompt:
  550. prompt_ids = torch.arange(self.prompt_length).to(self.device)
  551. self_attn_prompt = self.self_attn_prompt(prompt_ids)
  552. # expand attention_mask
  553. if attention_mask is not None:
  554. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  555. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  556. encoder_states = () if output_hidden_states else None
  557. all_attentions = () if output_attentions else None
  558. # check if head_mask has a correct number of layers specified if desired
  559. if head_mask is not None:
  560. if head_mask.size()[0] != (len(self.layers)):
  561. raise ValueError(
  562. f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
  563. f" {head_mask.size()[0]}."
  564. )
  565. for idx, encoder_layer in enumerate(self.layers):
  566. if output_hidden_states:
  567. encoder_states = encoder_states + (hidden_states,)
  568. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  569. to_drop = False
  570. if self.training:
  571. dropout_probability = torch.rand([])
  572. if dropout_probability < self.layerdrop: # skip the layer
  573. to_drop = True
  574. if to_drop:
  575. layer_outputs = (None, None)
  576. else:
  577. layer_outputs = encoder_layer(
  578. hidden_states,
  579. attention_mask,
  580. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  581. self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None),
  582. output_attentions=output_attentions,
  583. )
  584. hidden_states = layer_outputs[0]
  585. if output_attentions:
  586. all_attentions = all_attentions + (layer_outputs[1],)
  587. if output_hidden_states:
  588. encoder_states = encoder_states + (hidden_states,)
  589. if not return_dict:
  590. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  591. return BaseModelOutput(
  592. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  593. )
  class MvpDecoder(MvpPreTrainedModel):
      """
      Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MvpDecoderLayer`]

      Args:
          config: MvpConfig
          embed_tokens (nn.Embedding): output embedding
          use_prompt (bool): whether to use prompt
      """

      def __init__(
          self, config: MvpConfig, embed_tokens: Optional[nn.Embedding] = None, use_prompt: Optional[bool] = False
      ):
          super().__init__(config)
          self.dropout = config.dropout
          self.layerdrop = config.decoder_layerdrop
          self.padding_idx = config.pad_token_id
          self.max_target_positions = config.max_position_embeddings
          self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

          # Reuse the caller-provided embedding when given; otherwise build a private one.
          if embed_tokens is not None:
              self.embed_tokens = embed_tokens
          else:
              self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)

          self.embed_positions = MvpLearnedPositionalEmbedding(
              config.max_position_embeddings,
              config.d_model,
          )
          self.layers = nn.ModuleList([MvpDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
          self.layernorm_embedding = nn.LayerNorm(config.d_model)

          # Prompt modules only exist when `use_prompt` is truthy; `forward` guards on the same flag.
          self.use_prompt = use_prompt
          if use_prompt:
              self.prompt_length = config.prompt_length
              self.self_attn_prompt = MvpPrompt(
                  config,
                  config.decoder_layers,
                  config.decoder_attention_heads,
              )
              self.cross_attn_prompt = MvpPrompt(
                  config,
                  config.decoder_layers,
                  config.decoder_attention_heads,
              )

          self.gradient_checkpointing = False
          # Initialize weights and apply final processing
          self.post_init()

      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          encoder_attention_mask: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.Tensor] = None,
      ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
          r"""
          Args:
              input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                  provide it.

                  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                  [`PreTrainedTokenizer.__call__`] for details.

                  [What are input IDs?](../glossary#input-ids)
              attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                  - 1 for tokens that are **not masked**,
                  - 0 for tokens that are **masked**.

                  [What are attention masks?](../glossary#attention-mask)
              encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                  Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                  of the decoder.
              encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                  Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                  selected in `[0, 1]`:

                  - 1 for tokens that are **not masked**,
                  - 0 for tokens that are **masked**.

                  [What are attention masks?](../glossary#attention-mask)
              head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                  Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                  - 1 indicates the head is **not masked**,
                  - 0 indicates the head is **masked**.

              cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                  Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
                  cross-attention on hidden heads. Mask values selected in `[0, 1]`:

                  - 1 indicates the head is **not masked**,
                  - 0 indicates the head is **masked**.

              past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                  It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

                  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                  cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                  If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                  that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                  all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
              inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                  This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                  than the model's internal embedding lookup matrix.
              use_cache (`bool`, *optional*):
                  Defaults to `config.use_cache`. When true and no `past_key_values` is given, a new cache is created
                  and returned. It is forced to `False` while training with gradient checkpointing enabled.
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                  returned tensors for more detail.
              output_hidden_states (`bool`, *optional*):
                  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                  for more detail.
              return_dict (`bool`, *optional*):
                  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
              cache_position (`torch.Tensor`, *optional*):
                  Forwarded unchanged to every decoder layer.

          Raises:
              ValueError: If both or neither of `input_ids` / `inputs_embeds` are given, or if `head_mask` /
                  `cross_attn_head_mask` does not have one entry per decoder layer.
          """
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          use_cache = use_cache if use_cache is not None else self.config.use_cache
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # retrieve input_ids and inputs_embeds
          # `position_input` is what gets handed to the positional embedding module below.
          if input_ids is not None and inputs_embeds is not None:
              raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
          elif input_ids is not None:
              position_input = input_ids
              input_shape = input_ids.shape
              input_ids = input_ids.view(-1, input_shape[-1])
          elif inputs_embeds is not None:
              input_shape = inputs_embeds.size()[:-1]
              position_input = inputs_embeds[:, :, -1]
          else:
              raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

          if self.gradient_checkpointing and self.training:
              if use_cache:
                  logger.warning_once(
                      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                  )
                  use_cache = False

          if use_cache and past_key_values is None:
              past_key_values = (
                  EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                  if encoder_hidden_states is not None
                  else DynamicCache(config=self.config)
              )
          if use_cache and isinstance(past_key_values, tuple):
              logger.warning_once(
                  "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                  "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                  "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
              )
              past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

          past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

          attention_mask = _prepare_4d_causal_attention_mask(
              attention_mask, input_shape, inputs_embeds, past_key_values_length
          )

          # expand encoder attention mask
          if encoder_hidden_states is not None and encoder_attention_mask is not None:
              # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
              encoder_attention_mask = _prepare_4d_attention_mask(
                  encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
              )

          # embed positions
          positions = self.embed_positions(position_input, past_key_values_length)

          hidden_states = inputs_embeds + positions
          hidden_states = self.layernorm_embedding(hidden_states)
          hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

          # layer-wise prompt
          if self.use_prompt:
              prompt_ids = torch.arange(self.prompt_length).to(self.device)
              self_attn_prompt = self.self_attn_prompt(prompt_ids)
              cross_attn_prompt = self.cross_attn_prompt(prompt_ids)

          # decoder layers
          all_hidden_states = () if output_hidden_states else None
          all_self_attns = () if output_attentions else None
          all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

          # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
          for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
              if attn_mask is not None:
                  if attn_mask.size()[0] != (len(self.layers)):
                      # Report the size of the mask actually being validated. Reading `head_mask` here would
                      # crash when only `cross_attn_head_mask` is given, or report the wrong count.
                      raise ValueError(
                          f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                          f" {attn_mask.size()[0]}."
                      )

          for idx, decoder_layer in enumerate(self.layers):
              # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
              if output_hidden_states:
                  all_hidden_states += (hidden_states,)
              if self.training:
                  dropout_probability = torch.rand([])
                  if dropout_probability < self.layerdrop:
                      continue

              layer_outputs = decoder_layer(
                  hidden_states,
                  attention_mask,
                  encoder_hidden_states,  # as positional argument for gradient checkpointing
                  encoder_attention_mask=encoder_attention_mask,
                  layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                  cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                  self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None),
                  cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None),
                  past_key_values=past_key_values,
                  output_attentions=output_attentions,
                  use_cache=use_cache,
                  cache_position=cache_position,
              )
              hidden_states = layer_outputs[0]

              if output_attentions:
                  all_self_attns += (layer_outputs[1],)

                  if encoder_hidden_states is not None:
                      all_cross_attentions += (layer_outputs[2],)

          # add hidden states from the last decoder layer
          if output_hidden_states:
              all_hidden_states += (hidden_states,)

          if not return_dict:
              # Tuple output drops every entry that is None, so positions depend on the requested outputs.
              return tuple(
                  v
                  for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                  if v is not None
              )
          return BaseModelOutputWithPastAndCrossAttentions(
              last_hidden_state=hidden_states,
              past_key_values=past_key_values,
              hidden_states=all_hidden_states,
              attentions=all_self_attns,
              cross_attentions=all_cross_attentions,
          )
  @auto_docstring
  class MvpModel(MvpPreTrainedModel):
      # Bare encoder-decoder model: one embedding (`shared`) is handed to both the encoder and the decoder.
      _keys_to_ignore_on_load_unexpected = ["final_logits_bias"]
      _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

      def __init__(self, config: MvpConfig):
          super().__init__(config)

          padding_idx, vocab_size = config.pad_token_id, config.vocab_size
          self.use_prompt = config.use_prompt
          self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

          self.encoder = MvpEncoder(config, self.shared, config.use_prompt)
          self.decoder = MvpDecoder(config, self.shared, config.use_prompt)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          """Return the embedding module shared by the encoder and the decoder."""
          return self.shared

      def set_input_embeddings(self, value):
          """Replace the shared embedding and re-point both encoder and decoder at it."""
          self.shared = value
          self.encoder.embed_tokens = self.shared
          self.decoder.embed_tokens = self.shared

      def get_encoder(self):
          """Return the encoder sub-module."""
          return self.encoder

      def set_lightweight_tuning(self):
          """
          Freeze every parameter except the encoder/decoder prompt modules.

          Raises:
              AssertionError: If the model was built with `use_prompt` falsy, in which case the prompt modules
                  do not exist.
          """
          # Explicit raise instead of `assert`: an assert is stripped under `python -O`, which would freeze
          # the whole model below and then fail with an unrelated AttributeError. The exception type is kept.
          if not self.use_prompt:
              raise AssertionError("If you want to use lightweight tuning, make sure that `use_prompt=True`.")

          self.requires_grad_(False)
          self.encoder.self_attn_prompt.requires_grad_(True)
          self.decoder.self_attn_prompt.requires_grad_(True)
          self.decoder.cross_attn_prompt.requires_grad_(True)

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          decoder_head_mask: Optional[torch.Tensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          encoder_outputs: Optional[list[torch.FloatTensor]] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.Tensor] = None,
      ) -> Union[tuple, Seq2SeqModelOutput]:
          r"""
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of decoder input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)

              Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
              is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

              For translation and summarization training, `decoder_input_ids` should be provided. If no
              `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
              for denoising pre-training following the paper.
          decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
              be used by default.

              If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
              and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
              information on the default strategy.
          cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
              1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.
          """
          # different to other models, Mvp automatically creates decoder_input_ids from
          # input_ids if no decoder_input_ids are provided
          if decoder_input_ids is None and decoder_inputs_embeds is None:
              if input_ids is None:
                  raise ValueError(
                      "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                      "passed, `input_ids` cannot be `None`. Please pass either "
                      "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                  )

              decoder_input_ids = shift_tokens_right(
                  input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
              )

          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          use_cache = use_cache if use_cache is not None else self.config.use_cache
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          if encoder_outputs is None:
              encoder_outputs = self.encoder(
                  input_ids=input_ids,
                  attention_mask=attention_mask,
                  head_mask=head_mask,
                  inputs_embeds=inputs_embeds,
                  output_attentions=output_attentions,
                  output_hidden_states=output_hidden_states,
                  return_dict=return_dict,
              )
          # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
          elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
              encoder_outputs = BaseModelOutput(
                  last_hidden_state=encoder_outputs[0],
                  hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                  attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
              )

          # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
          decoder_outputs = self.decoder(
              input_ids=decoder_input_ids,
              attention_mask=decoder_attention_mask,
              encoder_hidden_states=encoder_outputs[0],
              encoder_attention_mask=attention_mask,
              head_mask=decoder_head_mask,
              cross_attn_head_mask=cross_attn_head_mask,
              past_key_values=past_key_values,
              inputs_embeds=decoder_inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
          )

          if not return_dict:
              # Tuple mode: decoder entries first, encoder entries appended after them.
              return decoder_outputs + encoder_outputs

          return Seq2SeqModelOutput(
              last_hidden_state=decoder_outputs.last_hidden_state,
              past_key_values=decoder_outputs.past_key_values,
              decoder_hidden_states=decoder_outputs.hidden_states,
              decoder_attentions=decoder_outputs.attentions,
              cross_attentions=decoder_outputs.cross_attentions,
              encoder_last_hidden_state=encoder_outputs.last_hidden_state,
              encoder_hidden_states=encoder_outputs.hidden_states,
              encoder_attentions=encoder_outputs.attentions,
          )
  @auto_docstring(
      custom_intro="""
      The MVP Model with a language modeling head. Can be used for various text generation tasks.
      """
  )
  class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin):
      _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

      def __init__(self, config: MvpConfig):
          super().__init__(config)
          self.model = MvpModel(config)
          # Bias added to the LM logits in `forward`; kept as a buffer with one column per embedding row.
          self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
          self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      def get_encoder(self):
          """Return the encoder of the wrapped model."""
          return self.model.get_encoder()

      def get_decoder(self):
          """Return the decoder of the wrapped model."""
          return self.model.get_decoder()

      def resize_token_embeddings(
          self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
      ) -> nn.Embedding:
          """
          Resize the token embeddings through the parent implementation and keep `final_logits_bias` in sync.

          Returns:
              The embedding module returned by the parent `resize_token_embeddings`.
          """
          new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
          # Size the bias from the embedding actually produced, not from the request: with
          # `pad_to_multiple_of` the two can differ, and `forward` adds the bias to the logits.
          self._resize_final_logits_bias(new_embeddings.weight.shape[0])
          return new_embeddings

      def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
          """Truncate or zero-extend `final_logits_bias` so its last dimension equals `new_num_tokens`."""
          old_num_tokens = self.final_logits_bias.shape[-1]
          if new_num_tokens <= old_num_tokens:
              new_bias = self.final_logits_bias[:, :new_num_tokens]
          else:
              # Match dtype as well as device so concatenation cannot promote a half-precision bias.
              extra_bias = torch.zeros(
                  (1, new_num_tokens - old_num_tokens),
                  device=self.final_logits_bias.device,
                  dtype=self.final_logits_bias.dtype,
              )
              new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
          self.register_buffer("final_logits_bias", new_bias)

      def set_lightweight_tuning(self):
          """Apply the wrapped model's lightweight tuning setup and also freeze the LM head."""
          self.model.set_lightweight_tuning()
          self.lm_head.requires_grad_(False)

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          decoder_head_mask: Optional[torch.Tensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          encoder_outputs: Optional[list[torch.FloatTensor]] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.Tensor] = None,
      ) -> Union[tuple, Seq2SeqLMOutput]:
          r"""
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of decoder input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)

              Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
              is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

              For translation and summarization training, `decoder_input_ids` should be provided. If no
              `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
              for denoising pre-training following the paper.
          decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
              be used by default.

              If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
              and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
              information on the default strategy.
          cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
              1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example of summarization:

          Fine-tuning a model
          ```python
          >>> import torch
          >>> from transformers import AutoTokenizer, MvpForConditionalGeneration

          >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
          >>> model = MvpForConditionalGeneration.from_pretrained("RUCAIBox/mvp")

          >>> inputs = tokenizer(
          ...     "Summarize: You may want to stick it to your boss and leave your job, but don't do it if these are your reasons.",
          ...     return_tensors="pt",
          ... )
          >>> labels = tokenizer("Bad Reasons To Quit Your Job", return_tensors="pt")["input_ids"]

          >>> loss = model(**inputs, labels=labels).loss
          >>> loss.backward()
          ```

          Inference after the model fine-tuned
          ```python
          >>> with torch.no_grad():
          ...     generated_ids = model.generate(**inputs)

          >>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
          ```
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          if labels is not None:
              if use_cache:
                  logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
              use_cache = False
              # With labels but no decoder inputs, derive the decoder inputs by shifting the labels right.
              if decoder_input_ids is None and decoder_inputs_embeds is None:
                  decoder_input_ids = shift_tokens_right(
                      labels, self.config.pad_token_id, self.config.decoder_start_token_id
                  )

          outputs = self.model(
              input_ids,
              attention_mask=attention_mask,
              decoder_input_ids=decoder_input_ids,
              encoder_outputs=encoder_outputs,
              decoder_attention_mask=decoder_attention_mask,
              head_mask=head_mask,
              decoder_head_mask=decoder_head_mask,
              cross_attn_head_mask=cross_attn_head_mask,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              decoder_inputs_embeds=decoder_inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
          )
          lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias

          masked_lm_loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

          if not return_dict:
              output = (lm_logits,) + outputs[1:]
              return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

          return Seq2SeqLMOutput(
              loss=masked_lm_loss,
              logits=lm_logits,
              past_key_values=outputs.past_key_values,
              decoder_hidden_states=outputs.decoder_hidden_states,
              decoder_attentions=outputs.decoder_attentions,
              cross_attentions=outputs.cross_attentions,
              encoder_last_hidden_state=outputs.encoder_last_hidden_state,
              encoder_hidden_states=outputs.encoder_hidden_states,
              encoder_attentions=outputs.encoder_attentions,
          )

      def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
          """Build decoder input ids by shifting `labels` right with the configured pad and start tokens."""
          return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
  1103. @auto_docstring(
  1104. custom_intro="""
  1105. Mvp model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  1106. tasks.
  1107. """
  1108. )
  1109. class MvpForSequenceClassification(MvpPreTrainedModel):
  1110. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  def __init__(self, config: MvpConfig, **kwargs):
      # Extra keyword arguments are handed straight to the parent constructor.
      super().__init__(config, **kwargs)

      # Backbone first, then the head, so sub-modules are registered in the same order as before.
      self.model = MvpModel(config)

      # `config.d_model` fills both of the head's first two positional arguments.
      hidden_size = config.d_model
      self.classification_head = MvpClassificationHead(
          hidden_size, hidden_size, config.num_labels, config.classifier_dropout
      )

      # Initialize weights and apply final processing
      self.post_init()
  1122. def set_lightweight_tuning(self):
  1123. self.model.set_lightweight_tuning()
  1124. self.classification_head.requires_grad_(False)
  1125. @auto_docstring
  1126. def forward(
  1127. self,
  1128. input_ids: Optional[torch.LongTensor] = None,
  1129. attention_mask: Optional[torch.Tensor] = None,
  1130. decoder_input_ids: Optional[torch.LongTensor] = None,
  1131. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1132. head_mask: Optional[torch.Tensor] = None,
  1133. decoder_head_mask: Optional[torch.Tensor] = None,
  1134. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1135. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1136. inputs_embeds: Optional[torch.FloatTensor] = None,
  1137. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1138. labels: Optional[torch.LongTensor] = None,
  1139. use_cache: Optional[bool] = None,
  1140. output_attentions: Optional[bool] = None,
  1141. output_hidden_states: Optional[bool] = None,
  1142. return_dict: Optional[bool] = None,
  1143. ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
  1144. r"""
  1145. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1146. Indices of decoder input sequence tokens in the vocabulary.
  1147. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1148. [`PreTrainedTokenizer.__call__`] for details.
  1149. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1150. Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1151. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1152. For translation and summarization training, `decoder_input_ids` should be provided. If no
  1153. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  1154. for denoising pre-training following the paper.
  1155. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1156. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1157. be used by default.
  1158. If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
  1159. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  1160. information on the default strategy.
  1161. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1162. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
  1163. 1]`:
  1164. - 1 indicates the head is **not masked**,
  1165. - 0 indicates the head is **masked**.
  1166. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1167. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1168. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1169. Example of single-label classification:
  1170. Fine-tuning a model on `num_labels` classes
  1171. ```python
  1172. >>> import torch
  1173. >>> from transformers import AutoTokenizer, MvpForSequenceClassification
  1174. >>> num_labels = 2 # for example, this is a binary classification task
  1175. >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
  1176. >>> model = MvpForSequenceClassification.from_pretrained("RUCAIBox/mvp", num_labels=num_labels)
  1177. >>> inputs = tokenizer("Classify: Hello, my dog is cute", return_tensors="pt")
  1178. >>> labels = torch.tensor(1) # the real label for inputs
  1179. >>> loss = model(**inputs, labels=labels).loss
  1180. >>> loss.backward()
  1181. ```
  1182. Inference after the model fine-tuned
  1183. ```python
  1184. >>> with torch.no_grad():
  1185. ... logits = model(**inputs).logits
  1186. >>> predicted_class_id = logits.argmax()
  1187. ```
  1188. """
  1189. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1190. if labels is not None:
  1191. use_cache = False
  1192. if input_ids is None and inputs_embeds is not None:
  1193. raise NotImplementedError(
  1194. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1195. )
  1196. outputs = self.model(
  1197. input_ids,
  1198. attention_mask=attention_mask,
  1199. decoder_input_ids=decoder_input_ids,
  1200. decoder_attention_mask=decoder_attention_mask,
  1201. head_mask=head_mask,
  1202. decoder_head_mask=decoder_head_mask,
  1203. cross_attn_head_mask=cross_attn_head_mask,
  1204. encoder_outputs=encoder_outputs,
  1205. inputs_embeds=inputs_embeds,
  1206. decoder_inputs_embeds=decoder_inputs_embeds,
  1207. use_cache=use_cache,
  1208. output_attentions=output_attentions,
  1209. output_hidden_states=output_hidden_states,
  1210. return_dict=return_dict,
  1211. )
  1212. hidden_states = outputs[0] # last hidden state
  1213. eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
  1214. if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
  1215. raise ValueError("All examples must have the same number of <eos> tokens.")
  1216. sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
  1217. :, -1, :
  1218. ]
  1219. logits = self.classification_head(sentence_representation)
  1220. loss = None
  1221. if labels is not None:
  1222. if self.config.problem_type is None:
  1223. if self.config.num_labels == 1:
  1224. self.config.problem_type = "regression"
  1225. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1226. self.config.problem_type = "single_label_classification"
  1227. else:
  1228. self.config.problem_type = "multi_label_classification"
  1229. if self.config.problem_type == "regression":
  1230. loss_fct = MSELoss()
  1231. if self.config.num_labels == 1:
  1232. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1233. else:
  1234. loss = loss_fct(logits, labels)
  1235. elif self.config.problem_type == "single_label_classification":
  1236. loss_fct = CrossEntropyLoss()
  1237. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1238. elif self.config.problem_type == "multi_label_classification":
  1239. loss_fct = BCEWithLogitsLoss()
  1240. loss = loss_fct(logits, labels)
  1241. if not return_dict:
  1242. output = (logits,) + outputs[1:]
  1243. return ((loss,) + output) if loss is not None else output
  1244. return Seq2SeqSequenceClassifierOutput(
  1245. loss=loss,
  1246. logits=logits,
  1247. past_key_values=outputs.past_key_values,
  1248. decoder_hidden_states=outputs.decoder_hidden_states,
  1249. decoder_attentions=outputs.decoder_attentions,
  1250. cross_attentions=outputs.cross_attentions,
  1251. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1252. encoder_hidden_states=outputs.encoder_hidden_states,
  1253. encoder_attentions=outputs.encoder_attentions,
  1254. )
  @auto_docstring
  class MvpForQuestionAnswering(MvpPreTrainedModel):
      _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

      def __init__(self, config):
          super().__init__(config)

          # Span extraction always uses two outputs per token (start and end), so the config is overwritten.
          config.num_labels = 2
          self.num_labels = config.num_labels

          self.model = MvpModel(config)
          self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      def set_lightweight_tuning(self):
          """Delegate to the backbone's `set_lightweight_tuning`, then disable gradients on the QA output layer."""
          self.model.set_lightweight_tuning()
          self.qa_outputs.requires_grad_(False)

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          decoder_head_mask: Optional[torch.Tensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          encoder_outputs: Optional[list[torch.FloatTensor]] = None,
          start_positions: Optional[torch.LongTensor] = None,
          end_positions: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, Seq2SeqQuestionAnsweringModelOutput]:
          r"""
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of decoder input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)

              Mvp uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
              is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

              For translation and summarization training, `decoder_input_ids` should be provided. If no
              `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
              for denoising pre-training following the paper.
          decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
              be used by default.

              If you want to change padding behavior, you should read [`modeling_mvp._prepare_decoder_attention_mask`]
              and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
              information on the default strategy.
          cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
              1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.

          Example:

          Fine-tuning a model for extractive question answering, and our model also supports generative question
          answering using `MvpForConditionalGeneration`

          ```python
          >>> import torch
          >>> from transformers import AutoTokenizer, MvpForQuestionAnswering

          >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
          >>> model = MvpForQuestionAnswering.from_pretrained("RUCAIBox/mvp")

          >>> inputs = tokenizer(
          ...     "Answer the following question: Who was Jim Henson? [SEP] Jim Henson was a nice puppet",
          ...     return_tensors="pt",
          ... )
          >>> target_start_index = torch.tensor([18])
          >>> target_end_index = torch.tensor([19])

          >>> loss = model(**inputs, start_positions=target_start_index, end_positions=target_end_index).loss
          >>> loss.backward()
          ```

          Inference after the model fine-tuned

          ```python
          >>> with torch.no_grad():
          ...     outputs = model(**inputs)

          >>> answer_start_index = outputs.start_logits.argmax()
          >>> answer_end_index = outputs.end_logits.argmax()

          >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
          >>> predict_answer = tokenizer.decode(predict_answer_tokens)
          ```
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          # A loss is only computed when both span boundaries are supplied; caching is switched off in that case.
          has_span_targets = start_positions is not None and end_positions is not None
          if has_span_targets:
              use_cache = False

          outputs = self.model(
              input_ids,
              attention_mask=attention_mask,
              decoder_input_ids=decoder_input_ids,
              decoder_attention_mask=decoder_attention_mask,
              head_mask=head_mask,
              decoder_head_mask=decoder_head_mask,
              cross_attn_head_mask=cross_attn_head_mask,
              encoder_outputs=encoder_outputs,
              inputs_embeds=inputs_embeds,
              decoder_inputs_embeds=decoder_inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Project each position to two scores and separate them into start / end logits.
          span_logits = self.qa_outputs(outputs[0])
          start_logits, end_logits = span_logits.split(1, dim=-1)
          start_logits = start_logits.squeeze(-1).contiguous()
          end_logits = end_logits.squeeze(-1).contiguous()

          total_loss = None
          if has_span_targets:
              # If we are on multi-GPU, split add a dimension
              if start_positions.dim() > 1:
                  start_positions = start_positions.squeeze(-1)
              if end_positions.dim() > 1:
                  end_positions = end_positions.squeeze(-1)

              # sometimes the start/end positions are outside our model inputs, we ignore these terms
              ignored_index = start_logits.size(1)
              start_positions = start_positions.clamp(0, ignored_index)
              end_positions = end_positions.clamp(0, ignored_index)

              span_loss = CrossEntropyLoss(ignore_index=ignored_index)
              start_loss = span_loss(start_logits, start_positions)
              end_loss = span_loss(end_logits, end_positions)
              total_loss = (start_loss + end_loss) / 2

          if not return_dict:
              output = (start_logits, end_logits) + outputs[1:]
              return output if total_loss is None else (total_loss,) + output

          return Seq2SeqQuestionAnsweringModelOutput(
              loss=total_loss,
              start_logits=start_logits,
              end_logits=end_logits,
              past_key_values=outputs.past_key_values,
              decoder_hidden_states=outputs.decoder_hidden_states,
              decoder_attentions=outputs.decoder_attentions,
              cross_attentions=outputs.cross_attentions,
              encoder_last_hidden_state=outputs.encoder_last_hidden_state,
              encoder_hidden_states=outputs.encoder_hidden_states,
              encoder_attentions=outputs.encoder_attentions,
          )
  1395. # Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Mvp
  class MvpDecoderWrapper(MvpPreTrainedModel):
      """
      This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
      used in combination with the [`EncoderDecoderModel`] framework.
      """

      def __init__(self, config):
          super().__init__(config)
          # The wrapped decoder is stored under the attribute name `decoder`.
          self.decoder = MvpDecoder(config)

      def forward(self, *args, **kwargs):
          # Pure pass-through: all arguments go to the wrapped decoder and its result is returned unchanged.
          return self.decoder(*args, **kwargs)
  class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin):
      _tied_weights_keys = ["lm_head.weight"]

      def __init__(self, config):
          # This head runs the decoder stand-alone, so the config is forced into decoder-only mode
          # before the base class sees it.
          config.is_decoder = True
          config.is_encoder_decoder = False
          super().__init__(config)
          self.model = MvpDecoderWrapper(config)

          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.model.decoder.embed_tokens

      def set_input_embeddings(self, value):
          self.model.decoder.embed_tokens = value

      def set_decoder(self, decoder):
          self.model.decoder = decoder

      def get_decoder(self):
          return self.model.decoder

      def set_lightweight_tuning(self):
          # NOTE(review): `self.model` is an `MvpDecoderWrapper`, which as defined in this file does not itself
          # declare `set_lightweight_tuning`. Unless the method is inherited from `MvpPreTrainedModel`, this call
          # raises AttributeError - confirm against the base class / decoder before relying on it.
          self.model.set_lightweight_tuning()
          self.lm_head.requires_grad_(False)

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.Tensor] = None,
      ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
          r"""
          cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, MvpForCausalLM

          >>> tokenizer = AutoTokenizer.from_pretrained("RUCAIBox/mvp")
          >>> model = MvpForCausalLM.from_pretrained("RUCAIBox/mvp", add_cross_attention=False)

          >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
          >>> outputs = model(**inputs)

          >>> logits = outputs.logits
          >>> list(logits.shape)
          [1, 8, 50267]
          ```"""
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
          outputs = self.model.decoder(
              input_ids=input_ids,
              attention_mask=attention_mask,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              head_mask=head_mask,
              cross_attn_head_mask=cross_attn_head_mask,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              # Previously accepted by this method but never forwarded, so a caller-supplied value was
              # silently ignored. Assumes `MvpDecoder.forward` takes `cache_position`, as the
              # encoder-decoder model's forward in this file does - confirm against the decoder signature.
              cache_position=cache_position,
          )

          logits = self.lm_head(outputs[0])

          loss = None
          if labels is not None:
              # Labels may live on a different device than the logits (e.g. when the model is sharded);
              # the loss needs both on the same device.
              labels = labels.to(logits.device)
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

          if not return_dict:
              output = (logits,) + outputs[1:]
              return (loss,) + output if loss is not None else output

          return CausalLMOutputWithCrossAttentions(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              cross_attentions=outputs.cross_attentions,
          )
  # Names exported by `from <this module> import *`.
  __all__ = [
      "MvpForCausalLM",
      "MvpForConditionalGeneration",
      "MvpForQuestionAnswering",
      "MvpForSequenceClassification",
      "MvpModel",
      "MvpPreTrainedModel",
  ]