modeling_prophetnet.py 95 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056
  1. # coding=utf-8
  2. # Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ProphetNet model, ported from ProphetNet repo (fairseq version)."""
  16. import copy
  17. import math
  18. import warnings
  19. from dataclasses import dataclass
  20. from typing import Optional, Union
  21. import torch
  22. from torch import Tensor, nn
  23. from torch.nn import LayerNorm
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  26. from ...generation import GenerationMixin
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import BaseModelOutput
  29. from ...modeling_utils import PreTrainedModel
  30. from ...utils import ModelOutput, auto_docstring, logging
  31. from ...utils.deprecation import deprecate_kwarg
  32. from .configuration_prophetnet import ProphetNetConfig
  33. logger = logging.get_logger(__name__)
  34. def softmax(hidden_state, dim, onnx_trace=False):
  35. if onnx_trace:
  36. return nn.functional.softmax(hidden_state.float(), dim=dim)
  37. else:
  38. return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32)
  39. def ngram_attention_bias(sequence_length, ngram, device, dtype):
  40. """
  41. This function computes the bias for the predict stream
  42. """
  43. left_block = (
  44. torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
  45. )
  46. right_block = left_block.detach().clone()
  47. # create bias
  48. for stream_idx in range(ngram):
  49. right_block[stream_idx].fill_diagonal_(0, wrap=False)
  50. left_block[stream_idx].triu_(-stream_idx + 1)
  51. left_block[:, :, 0] = 0
  52. return torch.cat([left_block, right_block], dim=2)
  53. def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
  54. """
  55. This function computes individual parts of the relative position buckets. For more detail, see paper.
  56. """
  57. inv_relative_positions = -relative_positions
  58. rel_positions_bucket = 0
  59. if is_bidirectional:
  60. num_buckets = num_buckets // 2
  61. rel_positions_bucket = (
  62. rel_positions_bucket
  63. + torch.lt(inv_relative_positions, torch.zeros_like(inv_relative_positions)).int() * num_buckets
  64. )
  65. inv_relative_positions = torch.abs(inv_relative_positions)
  66. else:
  67. inv_relative_positions = torch.max(inv_relative_positions, torch.zeros_like(inv_relative_positions))
  68. max_exact = num_buckets // 2
  69. is_small = torch.lt(inv_relative_positions, max_exact)
  70. val_if_large = max_exact + torch.log(inv_relative_positions.float() / max_exact) / math.log(
  71. max_distance / max_exact
  72. ) * (num_buckets - max_exact)
  73. val_if_large = torch.min(val_if_large, torch.ones_like(val_if_large) * (num_buckets - 1)).int()
  74. rel_positions_bucket = rel_positions_bucket + torch.where(is_small, inv_relative_positions.int(), val_if_large)
  75. return rel_positions_bucket
  76. def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids):
  77. """
  78. This function computes both main and predict relative position buckets. For more detail, see paper.
  79. """
  80. # main stream
  81. main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1)
  82. main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1)
  83. # predicting stream
  84. predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1)
  85. predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1)
  86. predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1)
  87. # get both position buckets
  88. main_relative_position_buckets = compute_relative_buckets(
  89. num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False
  90. )
  91. predict_relative_position_buckets = compute_relative_buckets(
  92. num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False
  93. )
  94. return main_relative_position_buckets, predict_relative_position_buckets
  95. @dataclass
  96. @auto_docstring(
  97. custom_intro="""
  98. Base class for sequence-to-sequence language models outputs.
  99. """
  100. )
  101. class ProphetNetSeq2SeqLMOutput(ModelOutput):
  102. r"""
  103. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  104. Language modeling loss.
  105. logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
  106. Prediction scores of the main stream language modeling head (scores for each vocabulary token before
  107. SoftMax).
  108. logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
  109. Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
  110. SoftMax).
  111. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  112. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  113. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  114. used (see `past_key_values` input) to speed up sequential decoding.
  115. decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  116. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  117. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  118. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  119. outputs.
  120. decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  121. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  122. decoder_sequence_length, decoder_sequence_length)`.
  123. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  124. weighted average in the self-attention heads.
  125. encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  126. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  127. """
  128. loss: Optional[torch.FloatTensor] = None
  129. logits: Optional[torch.FloatTensor] = None
  130. logits_ngram: Optional[torch.FloatTensor] = None
  131. past_key_values: Optional[Cache] = None
  132. decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  133. decoder_ngram_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  134. decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
  135. decoder_ngram_attentions: Optional[tuple[torch.FloatTensor]] = None
  136. cross_attentions: Optional[tuple[torch.FloatTensor]] = None
  137. encoder_last_hidden_state: Optional[torch.FloatTensor] = None
  138. encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  139. encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
  140. @property
  141. def decoder_cross_attentions(self):
  142. warnings.warn(
  143. "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
  144. " instead.",
  145. FutureWarning,
  146. )
  147. return self.cross_attentions
  148. @dataclass
  149. @auto_docstring(
  150. custom_intro="""
  151. Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
  152. decoding.
  153. """
  154. )
  155. class ProphetNetSeq2SeqModelOutput(ModelOutput):
  156. r"""
  157. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
  158. Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
  159. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  160. hidden_size)` is output.
  161. last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*):
  162. Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
  163. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  164. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  165. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  166. used (see `past_key_values` input) to speed up sequential decoding.
  167. decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  168. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  169. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  170. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  171. outputs.
  172. decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  173. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  174. decoder_sequence_length, decoder_sequence_length)`.
  175. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  176. weighted average in the
  177. encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  178. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  179. """
  180. last_hidden_state: torch.FloatTensor
  181. last_hidden_state_ngram: Optional[torch.FloatTensor] = None
  182. past_key_values: Optional[Cache] = None
  183. decoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  184. decoder_ngram_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  185. decoder_attentions: Optional[tuple[torch.FloatTensor]] = None
  186. decoder_ngram_attentions: Optional[tuple[torch.FloatTensor]] = None
  187. cross_attentions: Optional[tuple[torch.FloatTensor]] = None
  188. encoder_last_hidden_state: Optional[torch.FloatTensor] = None
  189. encoder_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  190. encoder_attentions: Optional[tuple[torch.FloatTensor]] = None
  191. @property
  192. def decoder_cross_attentions(self):
  193. warnings.warn(
  194. "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
  195. " instead.",
  196. FutureWarning,
  197. )
  198. return self.cross_attentions
  199. @dataclass
  200. @auto_docstring(
  201. custom_intro="""
  202. Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
  203. """
  204. )
  205. class ProphetNetDecoderModelOutput(ModelOutput):
  206. r"""
  207. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
  208. Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
  209. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  210. hidden_size)` is output.
  211. last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
  212. Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
  213. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  214. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  215. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  216. used (see `past_key_values` input) to speed up sequential decoding.
  217. hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  218. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  219. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  220. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  221. outputs.
  222. ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  223. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  224. decoder_sequence_length, decoder_sequence_length)`.
  225. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  226. weighted average in the
  227. """
  228. last_hidden_state: torch.FloatTensor
  229. last_hidden_state_ngram: Optional[torch.FloatTensor] = None
  230. past_key_values: Optional[Cache] = None
  231. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  232. hidden_states_ngram: Optional[tuple[torch.FloatTensor]] = None
  233. attentions: Optional[tuple[torch.FloatTensor]] = None
  234. ngram_attentions: Optional[tuple[torch.FloatTensor]] = None
  235. cross_attentions: Optional[tuple[torch.FloatTensor]] = None
  236. @dataclass
  237. @auto_docstring(
  238. custom_intro="""
  239. Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
  240. """
  241. )
  242. class ProphetNetDecoderLMOutput(ModelOutput):
  243. r"""
  244. ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  245. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  246. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  247. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  248. outputs.
  249. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  250. Language modeling loss.
  251. logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
  252. Prediction scores of the main stream language modeling head (scores for each vocabulary token before
  253. SoftMax).
  254. logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
  255. Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
  256. SoftMax).
  257. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  258. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  259. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  260. used (see `past_key_values` input) to speed up sequential decoding.
  261. hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  262. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
  263. shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
  264. Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
  265. outputs.
  266. ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  267. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
  268. decoder_sequence_length, decoder_sequence_length)`.
  269. Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
  270. weighted average in the
  271. """
  272. loss: Optional[torch.FloatTensor] = None
  273. logits: Optional[torch.FloatTensor] = None
  274. logits_ngram: Optional[torch.FloatTensor] = None
  275. past_key_values: Optional[Cache] = None
  276. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  277. hidden_states_ngram: Optional[tuple[torch.FloatTensor]] = None
  278. attentions: Optional[tuple[torch.FloatTensor]] = None
  279. ngram_attentions: Optional[tuple[torch.FloatTensor]] = None
  280. cross_attentions: Optional[tuple[torch.FloatTensor]] = None
  281. @auto_docstring
  282. class ProphetNetPreTrainedModel(PreTrainedModel):
  283. config: ProphetNetConfig
  284. base_model_prefix = "prophetnet"
  285. supports_gradient_checkpointing = True
  286. def _init_weights(self, module):
  287. if isinstance(module, nn.Linear):
  288. module.weight.data.normal_(mean=0.0, std=self.config.init_std)
  289. if module.bias is not None:
  290. module.bias.data.zero_()
  291. elif isinstance(module, nn.Embedding):
  292. module.weight.data.normal_(mean=0.0, std=self.config.init_std)
  293. if module.padding_idx is not None:
  294. module.weight.data[module.padding_idx].zero_()
  295. def _shift_right(self, input_ids):
  296. decoder_start_token_id = self.config.decoder_start_token_id
  297. pad_token_id = self.config.pad_token_id
  298. assert decoder_start_token_id is not None, (
  299. "self.model.config.decoder_start_token_id has to be defined. In ProphetNet it is usually set to the"
  300. " pad_token_id. See ProphetNet docs for more information"
  301. )
  302. # shift inputs to the right
  303. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  304. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  305. shifted_input_ids[..., 0] = decoder_start_token_id
  306. assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
  307. # replace possible -100 values in labels by `pad_token_id`
  308. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  309. assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
  310. return shifted_input_ids
  311. class ProphetNetPositionalEmbeddings(nn.Embedding):
  312. """
  313. This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
  314. based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
  315. the forward function.
  316. """
  317. def __init__(self, config: ProphetNetConfig) -> None:
  318. self.max_length = config.max_position_embeddings
  319. super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
  320. def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
  321. assert (position_ids is None) or (self.padding_idx is None), (
  322. "If position_ids is pre-computed then padding_idx should not be set."
  323. )
  324. if position_ids is None:
  325. if past_key_values is not None and past_key_values.get_seq_length() != 0:
  326. # position_ids is the same for every token when decoding a single step
  327. # Without the int() cast, it doesn't work in some cases when exporting to ONNX
  328. prev_num_input_ids = past_key_values.get_seq_length()
  329. num_input_ids = inputs_shape[1] + prev_num_input_ids
  330. position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
  331. int(self.padding_idx + num_input_ids)
  332. )
  333. else:
  334. if attention_mask is None:
  335. attention_mask = torch.ones(inputs_shape, dtype=torch.long, device=device)
  336. # retrieve position_ids from input_ids / attention_mask
  337. position_ids = (
  338. torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
  339. ).long() + self.padding_idx
  340. # make sure position_ids are not bigger then max_length
  341. position_ids = position_ids.clamp(0, self.max_length - 1)
  342. return super().forward(position_ids), position_ids
  343. def _forward(self, position_ids):
  344. return super().forward(position_ids)
  345. class ProphetNetAttention(nn.Module):
  346. """Multi-headed attention from 'Attention Is All You Need' paper"""
  347. def __init__(self, config: ProphetNetConfig, num_attn_heads: int, layer_idx: Optional[int] = None):
  348. super().__init__()
  349. hidden_size = config.hidden_size
  350. self.attention_dropout = config.attention_dropout
  351. self.dropout = config.dropout
  352. self.num_attn_heads = num_attn_heads
  353. self.head_dim = hidden_size // num_attn_heads
  354. self.layer_idx = layer_idx
  355. assert self.head_dim * num_attn_heads == hidden_size, (
  356. "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
  357. " `config.num_decoder_attention_heads`"
  358. )
  359. self.key_proj = nn.Linear(hidden_size, hidden_size)
  360. self.value_proj = nn.Linear(hidden_size, hidden_size)
  361. self.query_proj = nn.Linear(hidden_size, hidden_size)
  362. self.out_proj = nn.Linear(hidden_size, hidden_size)
  363. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  364. def forward(
  365. self,
  366. hidden_states,
  367. key_value_states: Optional[Tensor] = None,
  368. attention_mask: Optional[Tensor] = None,
  369. layer_head_mask: Optional[Tensor] = None,
  370. past_key_values: Optional[Cache] = None,
  371. output_attentions: Optional[bool] = False,
  372. cache_position: Optional[torch.Tensor] = None,
  373. ) -> tuple[Tensor, Optional[Tensor]]:
  374. batch_size, tgt_len, hidden_size = hidden_states.size()
  375. # if key_value_states are provided this layer is used as a cross-attention layer
  376. # for the decoder
  377. is_cross_attention = key_value_states is not None
  378. assert list(hidden_states.size()) == [
  379. batch_size,
  380. tgt_len,
  381. hidden_size,
  382. ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"
  383. # previous time steps are cached - no need to recompute key and value if they are static
  384. query_states = self.query_proj(hidden_states) / (self.head_dim**0.5)
  385. is_updated = False
  386. if past_key_values is not None:
  387. if isinstance(past_key_values, EncoderDecoderCache):
  388. is_updated = past_key_values.is_updated.get(self.layer_idx)
  389. if is_cross_attention:
  390. # after the first generated id, we can subsequently re-use all key/value_states from cache
  391. curr_past_key_value = past_key_values.cross_attention_cache
  392. else:
  393. curr_past_key_value = past_key_values.self_attention_cache
  394. else:
  395. curr_past_key_value = past_key_values
  396. current_states = key_value_states if is_cross_attention else hidden_states
  397. if is_cross_attention and past_key_values is not None and is_updated:
  398. # reuse k,v, cross_attentions
  399. key_states = curr_past_key_value.layers[self.layer_idx].keys
  400. value_states = curr_past_key_value.layers[self.layer_idx].values
  401. else:
  402. key_states = self.key_proj(current_states)
  403. value_states = self.value_proj(current_states)
  404. key_states = key_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2)
  405. value_states = value_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2)
  406. if past_key_values is not None:
  407. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  408. cache_position = cache_position if not is_cross_attention else None
  409. key_states, value_states = curr_past_key_value.update(
  410. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  411. )
  412. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  413. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  414. past_key_values.is_updated[self.layer_idx] = True
  415. query_states = query_states.view(batch_size, tgt_len, self.num_attn_heads, self.head_dim).transpose(1, 2)
  416. src_len = key_states.size(2)
  417. attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
  418. expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
  419. if attn_weights.size() != expected_shape:
  420. raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
  421. # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
  422. if attention_mask is not None and attention_mask.dim() == 0:
  423. attention_mask = None
  424. expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
  425. if attention_mask is not None and attention_mask.size() != expected_shape:
  426. raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
  427. if attention_mask is not None: # don't attend to padding symbols
  428. attn_weights = attn_weights + attention_mask
  429. if output_attentions:
  430. attn_weights_reshaped = attn_weights
  431. else:
  432. attn_weights_reshaped = None
  433. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  434. if layer_head_mask is not None:
  435. assert layer_head_mask.size() == (self.num_attn_heads,), (
  436. f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
  437. f" {layer_head_mask.size()}"
  438. )
  439. attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
  440. batch_size, self.num_attn_heads, tgt_len, src_len
  441. )
  442. # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
  443. attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
  444. attn_probs = nn.functional.dropout(
  445. attn_weights,
  446. p=self.attention_dropout,
  447. training=self.training,
  448. )
  449. attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
  450. expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
  451. if attn_output.size() != expected_shape:
  452. raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
  453. attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
  454. attn_output = self.out_proj(attn_output)
  455. attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
  456. return attn_output, attn_weights_reshaped
  457. class ProphetNetFeedForward(nn.Module):
  458. """
  459. This is the residual two feed-forward layer block based on the original Transformer implementation.
  460. """
  461. def __init__(self, config: ProphetNetConfig, ffn_dim: int):
  462. super().__init__()
  463. self.activation_fn = ACT2FN[config.activation_function]
  464. self.intermediate = nn.Linear(config.hidden_size, ffn_dim)
  465. self.output = nn.Linear(ffn_dim, config.hidden_size)
  466. self.activation_dropout = config.activation_dropout
  467. self.dropout = config.dropout
  468. def forward(self, hidden_states):
  469. hidden_states = self.intermediate(hidden_states)
  470. hidden_states = self.activation_fn(hidden_states)
  471. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  472. hidden_states = self.output(hidden_states)
  473. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  474. return hidden_states
  475. class ProphetNetNgramSelfAttention(nn.Module):
  476. def __init__(self, config: ProphetNetConfig, layer_idx=None):
  477. super().__init__()
  478. self.hidden_size = config.hidden_size
  479. self.num_buckets = config.num_buckets
  480. self.relative_max_distance = config.relative_max_distance
  481. self.num_attn_heads = config.num_decoder_attention_heads
  482. self.dropout = config.dropout
  483. self.attention_dropout = config.attention_dropout
  484. self.head_dim = config.hidden_size // self.num_attn_heads
  485. self.ngram = config.ngram
  486. self.layer_idx = layer_idx
  487. assert self.head_dim * self.num_attn_heads == config.hidden_size, (
  488. "config.hidden_size must be divisible by num_attn_heads"
  489. )
  490. # key, value, query projection
  491. self.key_proj = nn.Linear(config.hidden_size, config.hidden_size)
  492. self.value_proj = nn.Linear(config.hidden_size, config.hidden_size)
  493. self.query_proj = nn.Linear(config.hidden_size, config.hidden_size)
  494. # out projection
  495. self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
  496. # rel position embeddings
  497. self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads)
  498. # for onnx runtime
  499. self.onnx_trace = False
  500. def _shape(self, tensor, seq_len, batch_size):
  501. return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()
  502. def prepare_for_onnx_export_(self):
  503. self.onnx_trace = True
  504. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  505. def forward(
  506. self,
  507. hidden_states,
  508. past_key_values: Optional[Cache] = None,
  509. attention_mask=None,
  510. layer_head_mask=None,
  511. extended_predict_attention_mask=None,
  512. main_relative_position_buckets=None,
  513. predict_relative_position_buckets=None,
  514. position_ids=None,
  515. cache_position=None,
  516. ):
  517. batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
  518. assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
  519. f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
  520. f" {hidden_states.shape}"
  521. )
  522. # project
  523. query_states = self.query_proj(hidden_states)
  524. key_states = self.key_proj(hidden_states)
  525. value_states = self.value_proj(hidden_states)
  526. # normalize
  527. query_states = query_states / (self.head_dim**0.5)
  528. # reshape
  529. query_states = self._shape(query_states, ngram_sequence_length, batch_size)
  530. key_states = self._shape(key_states, -1, batch_size)
  531. value_states = self._shape(value_states, -1, batch_size)
  532. proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
  533. query_states = query_states.reshape(*proj_shape)
  534. key_states = key_states.reshape(*proj_shape)
  535. value_states = value_states.reshape(*proj_shape)
  536. # chunk into main stream and predict stream
  537. hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
  538. query_states_list = query_states.chunk(1 + self.ngram, dim=2)
  539. key_states_list = key_states.chunk(1 + self.ngram, dim=2)
  540. value_states_list = value_states.chunk(1 + self.ngram, dim=2)
  541. main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
  542. main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
  543. main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:]
  544. main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]
  545. # ProphetNet has two separate attention layers, one for self and one for cross attention
  546. # We need to obtain the self attention only for this module, if `EncoderDecoderCache`
  547. if past_key_values is not None:
  548. if isinstance(past_key_values, EncoderDecoderCache):
  549. curr_past_key_value = past_key_values.self_attention_cache
  550. else:
  551. curr_past_key_value = past_key_values
  552. main_key_states, main_value_states = curr_past_key_value.update(
  553. main_key_states, main_value_states, self.layer_idx, {"cache_position": cache_position}
  554. )
  555. # get seq_length of main stream only
  556. sequence_length = ngram_sequence_length // (1 + self.ngram)
  557. # MAIN-STREAM
  558. # main attn weights
  559. # [batch_size, number_heads, sequence_length, head_dimesion]
  560. # x [batch_size, number_heads, head_dimesion, sequence_length]
  561. # -> [batch_size, number_heads, sequence_length, sequence_length]
  562. main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
  563. # retrieve relative position embeddings for each layer -> see paper for more details
  564. main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
  565. main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
  566. )
  567. main_attn_weights = main_attn_weights + main_relative_pos_embeddings
  568. if attention_mask is not None:
  569. main_attn_weights = main_attn_weights + attention_mask
  570. main_attn_probs = softmax(
  571. main_attn_weights,
  572. dim=-1,
  573. onnx_trace=self.onnx_trace,
  574. ).type_as(main_attn_weights)
  575. if layer_head_mask is not None:
  576. assert layer_head_mask.size() == (self.num_attn_heads,), (
  577. f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
  578. f" {layer_head_mask.size()}"
  579. )
  580. main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
  581. batch_size, self.num_attn_heads, -1, sequence_length
  582. )
  583. main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
  584. # project to attn_output
  585. # [batch_size, number_heads, sequence_length, sequence_length]
  586. # x [batch_size, number_heads, sequence_length, head_dimesion]
  587. # -> [batch_size, number_heads, sequence_length, head_dimesion]
  588. main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
  589. # reshape so that num_heads dim is merged into last `head_dim` axis
  590. main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
  591. main_attn_output = self.out_proj(main_attn_output)
  592. # PREDICT-STREAM
  593. # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
  594. predict_query_states = torch.stack(predict_query_states_list, 1).view(
  595. batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
  596. )
  597. # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
  598. predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
  599. # [batch_size, sequence_length, ngram, hidden_size]
  600. predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
  601. # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion]
  602. predict_value_states = torch.cat(
  603. [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
  604. )
  605. # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
  606. # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
  607. # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  608. predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
  609. # retrieve relative position embeddings for each layer -> see paper for more details
  610. # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings]
  611. predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
  612. predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
  613. )
  614. # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  615. predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
  616. if extended_predict_attention_mask is not None:
  617. # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  618. extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
  619. extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
  620. predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
  621. predict_attn_probs = softmax(
  622. predict_attn_weights,
  623. dim=-1,
  624. onnx_trace=self.onnx_trace,
  625. ).type_as(predict_attn_weights)
  626. if layer_head_mask is not None:
  627. assert layer_head_mask.size() == (self.num_attn_heads,), (
  628. f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
  629. f" {layer_head_mask.size()}"
  630. )
  631. predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
  632. predict_attn_probs = nn.functional.dropout(
  633. predict_attn_probs, p=self.attention_dropout, training=self.training
  634. )
  635. # project to attention output
  636. # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
  637. # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
  638. # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion]
  639. predict_attn_output = torch.einsum(
  640. "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
  641. )
  642. # reshape so that num_heads dim is merged into last `head_dim` axis
  643. # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size]
  644. predict_attn_output = predict_attn_output.transpose(2, 3)
  645. predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
  646. predict_attn_output = self.out_proj(predict_attn_output)
  647. # concat to single attn output
  648. # [batch_size, (1+ngram)*sequence_length, hidden_size]
  649. attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
  650. # reshape into better form for `config.output_attentions`
  651. main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
  652. attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
  653. return attn_output, main_attn_probs, predict_attn_probs
  654. def get_main_relative_pos_embeddings(
  655. self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
  656. ):
  657. # input hidden_states [batch_size, sequence_length, hidden_size]
  658. # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
  659. # input position_ids [batch_size, sequence_length] or [1,1]
  660. batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
  661. attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
  662. if main_relative_position_buckets is None:
  663. batch_size, sequence_length = hidden_states.shape[:2]
  664. relative_positions = (
  665. torch.arange(1, attn_weights.shape[-1] + 1)
  666. .unsqueeze(0)
  667. .unsqueeze(0)
  668. .repeat(batch_size, sequence_length, 1)
  669. .to(position_ids.device)
  670. )
  671. # [batch_size, sequence_length, sequence_length+1]
  672. relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
  673. main_relative_position_buckets = compute_relative_buckets(
  674. self.num_buckets, self.relative_max_distance, relative_positions, False
  675. )
  676. # [batch_size, sequence_length, num_buckets * num_heads]
  677. rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
  678. rel_pos_embeddings = rel_pos_embeddings.view(
  679. rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
  680. )
  681. rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
  682. # [batch_size, num_heads, sequence_length, num_buckets]
  683. rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
  684. main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
  685. # [batch_size * num_heads * sequence_length, sequence_length]
  686. main_relative_position_buckets = main_relative_position_buckets.view(
  687. -1, main_relative_position_buckets.shape[-1]
  688. )
  689. main_relative_position_buckets = main_relative_position_buckets.long()
  690. # [batch_size * num_heads * sequence_length, sequence_length]
  691. rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
  692. main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
  693. main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
  694. return main_relative_pos_embeddings
  695. def get_predict_relative_pos_embeddings(
  696. self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
  697. ):
  698. # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
  699. # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
  700. # input position_ids [batch_size, sequence_length] or [1,1]
  701. # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
  702. batch_size, sequence_length = hidden_states.shape[0:2]
  703. if predict_relative_position_buckets is None:
  704. key_sequence_length = attn_weights.shape[-1]
  705. assert position_ids[0][0] == key_sequence_length - 1, (
  706. "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)"
  707. )
  708. relative_positions = (
  709. torch.arange(0, key_sequence_length)
  710. .unsqueeze(0)
  711. .unsqueeze(0)
  712. .repeat(batch_size, sequence_length, 1)
  713. .to(position_ids.device)
  714. )
  715. relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
  716. predict_relative_position_buckets = compute_relative_buckets(
  717. self.num_buckets, self.relative_max_distance, relative_positions, False
  718. )
  719. # [batch_size, ngram, sequence_length, hidden_size]
  720. hidden_states = hidden_states.transpose(1, 2)
  721. rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
  722. # [batch_size, ngram, sequence_length, num_buckets, num_heads]
  723. rel_pos_embeddings = rel_pos_embeddings.view(
  724. hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
  725. )
  726. rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
  727. # [batch_size * ngram * sequence_length * num_heads, num_buckets]
  728. rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
  729. # [ngram, batch_size, num_heads * sequence_length, -1]
  730. predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
  731. predict_relative_position_buckets = predict_relative_position_buckets.repeat(
  732. self.ngram, 1, self.num_attn_heads, 1
  733. )
  734. # [ngram * batch_size * num_heads * sequence_length, -1]
  735. predict_relative_position_buckets = predict_relative_position_buckets.view(
  736. -1, predict_relative_position_buckets.size(-1)
  737. ).long()
  738. predict_relative_pos_embeddings = torch.gather(
  739. rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
  740. )
  741. # [batch_size, gram, num_heads, sequence_length, -1]
  742. predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
  743. batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
  744. )
  745. return predict_relative_pos_embeddings
  746. class ProphetNetEncoderLayer(GradientCheckpointingLayer):
  747. """
  748. Encoder block for Prophetnet
  749. """
  750. def __init__(self, config: ProphetNetConfig):
  751. super().__init__()
  752. # 1st residual block
  753. self.self_attn = ProphetNetAttention(config, config.num_encoder_attention_heads)
  754. self.self_attn_layer_norm = LayerNorm(config.hidden_size)
  755. # 2nd residual block
  756. self.feed_forward = ProphetNetFeedForward(config, config.encoder_ffn_dim)
  757. self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
  758. def forward(
  759. self,
  760. hidden_states,
  761. attention_mask,
  762. layer_head_mask,
  763. output_attentions: bool = False,
  764. ):
  765. # 1st residual block
  766. attention_output, attn_weights = self.self_attn(
  767. hidden_states=hidden_states,
  768. attention_mask=attention_mask,
  769. layer_head_mask=layer_head_mask,
  770. output_attentions=output_attentions,
  771. )
  772. hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
  773. # 2nd residual block
  774. feed_forward_output = self.feed_forward(hidden_states)
  775. hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
  776. outputs = (hidden_states,)
  777. if output_attentions:
  778. outputs += (attn_weights,)
  779. return outputs
  780. class ProphetNetDecoderLayer(GradientCheckpointingLayer):
  781. """
  782. Decoder block for Prophetnet
  783. """
  784. def __init__(self, config: ProphetNetConfig, layer_idx=None):
  785. super().__init__()
  786. # 1st residual block
  787. self.self_attn = ProphetNetNgramSelfAttention(config, layer_idx=layer_idx)
  788. self.self_attn_layer_norm = LayerNorm(config.hidden_size)
  789. # 2nd residual block
  790. if config.add_cross_attention:
  791. self.cross_attn = ProphetNetAttention(config, config.num_decoder_attention_heads, layer_idx=layer_idx)
  792. self.cross_attn_layer_norm = LayerNorm(config.hidden_size)
  793. # 3rd residual block
  794. self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim)
  795. self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
  796. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  797. def forward(
  798. self,
  799. hidden_states,
  800. attention_mask=None,
  801. encoder_hidden_states=None,
  802. encoder_attn_mask=None,
  803. layer_head_mask=None,
  804. cross_attn_layer_head_mask=None,
  805. extended_predict_attention_mask=None,
  806. main_relative_position_buckets=None,
  807. predict_relative_position_buckets=None,
  808. position_ids=None,
  809. past_key_values=None,
  810. use_cache: Optional[bool] = True,
  811. output_attentions: Optional[bool] = False,
  812. cache_position: Optional[torch.Tensor] = None,
  813. ):
  814. # 1st residual block
  815. ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn(
  816. hidden_states=hidden_states,
  817. past_key_values=past_key_values,
  818. attention_mask=attention_mask,
  819. layer_head_mask=layer_head_mask,
  820. extended_predict_attention_mask=extended_predict_attention_mask,
  821. main_relative_position_buckets=main_relative_position_buckets,
  822. predict_relative_position_buckets=predict_relative_position_buckets,
  823. position_ids=position_ids,
  824. )
  825. hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)
  826. cross_attn_weights = None
  827. if encoder_hidden_states is not None:
  828. # 2nd residual block
  829. attention_output, cross_attn_weights = self.cross_attn(
  830. hidden_states=hidden_states,
  831. key_value_states=encoder_hidden_states,
  832. attention_mask=encoder_attn_mask,
  833. layer_head_mask=cross_attn_layer_head_mask,
  834. past_key_values=past_key_values,
  835. output_attentions=output_attentions,
  836. )
  837. hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)
  838. # 3rd residual block
  839. feed_forward_output = self.feed_forward(hidden_states)
  840. hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
  841. outputs = (hidden_states,)
  842. if output_attentions:
  843. outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)
  844. return outputs
  845. @auto_docstring(
  846. custom_intro="""
  847. The standalone encoder part of the ProphetNetModel.
  848. """
  849. )
  850. class ProphetNetEncoder(ProphetNetPreTrainedModel):
  851. def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
  852. r"""
  853. word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
  854. The word embedding parameters. This can be used to initialize [`ProphetNetEncoder`] with pre-defined word
  855. embeddings instead of randomly initialized word embeddings.
  856. """
  857. super().__init__(config)
  858. self.word_embeddings = (
  859. word_embeddings
  860. if word_embeddings is not None
  861. else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  862. )
  863. self.position_embeddings = ProphetNetPositionalEmbeddings(config)
  864. self.embeddings_layer_norm = LayerNorm(config.hidden_size)
  865. self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])
  866. self.gradient_checkpointing = False
  867. # Initialize weights and apply final processing
  868. self.post_init()
  869. def get_input_embeddings(self):
  870. return self.word_embeddings
  871. def set_input_embeddings(self, value):
  872. self.word_embeddings = value
  873. @auto_docstring
  874. def forward(
  875. self,
  876. input_ids: Optional[torch.Tensor] = None,
  877. attention_mask: Optional[torch.Tensor] = None,
  878. head_mask: Optional[torch.Tensor] = None,
  879. inputs_embeds: Optional[torch.Tensor] = None,
  880. output_attentions: Optional[bool] = None,
  881. output_hidden_states: Optional[bool] = None,
  882. return_dict: Optional[bool] = None,
  883. ) -> Union[tuple, BaseModelOutput]:
  884. r"""
  885. Example:
  886. ```python
  887. >>> from transformers import AutoTokenizer, ProphetNetEncoder
  888. >>> import torch
  889. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  890. >>> model = ProphetNetEncoder.from_pretrained("patrickvonplaten/prophetnet-large-uncased-standalone")
  891. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  892. >>> outputs = model(**inputs)
  893. >>> last_hidden_states = outputs.last_hidden_state
  894. ```"""
  895. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  896. output_hidden_states = (
  897. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  898. )
  899. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  900. if input_ids is None and inputs_embeds is None:
  901. raise ValueError("Either input_ids or inputs_embeds has to be passed.")
  902. elif input_ids is not None and inputs_embeds is not None:
  903. raise ValueError("Make sure to only pass input_ids or inputs_embeds.")
  904. elif input_ids is not None and inputs_embeds is None:
  905. inputs_embeds = self.word_embeddings(input_ids)
  906. # prepare attention mask
  907. if attention_mask is not None:
  908. extended_attention_mask = (
  909. 1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
  910. ) * torch.finfo(self.dtype).min
  911. extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
  912. else:
  913. extended_attention_mask = None
  914. position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2], inputs_embeds.device)
  915. hidden_states = inputs_embeds + position_embeddings
  916. hidden_states = self.embeddings_layer_norm(hidden_states)
  917. hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training)
  918. encoder_hidden_states = () if output_hidden_states else None
  919. all_attentions = () if output_attentions else None
  920. # check if head_mask has a correct number of layers specified if desired
  921. if head_mask is not None:
  922. assert head_mask.size()[0] == (len(self.layers)), (
  923. f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
  924. )
  925. for idx, encoder_layer in enumerate(self.layers):
  926. if output_hidden_states:
  927. encoder_hidden_states = encoder_hidden_states + (hidden_states,)
  928. layer_outputs = encoder_layer(
  929. hidden_states,
  930. attention_mask=extended_attention_mask,
  931. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  932. output_attentions=output_attentions,
  933. )
  934. hidden_states = layer_outputs[0]
  935. if output_attentions:
  936. all_attentions = all_attentions + (layer_outputs[1],)
  937. if output_hidden_states:
  938. encoder_hidden_states = encoder_hidden_states + (hidden_states,)
  939. if not return_dict:
  940. return tuple(v for v in [hidden_states, encoder_hidden_states, all_attentions] if v is not None)
  941. return BaseModelOutput(
  942. last_hidden_state=hidden_states, hidden_states=encoder_hidden_states, attentions=all_attentions
  943. )
  944. @auto_docstring(
  945. custom_intro="""
  946. The standalone decoder part of the ProphetNetModel.
  947. """
  948. )
  949. class ProphetNetDecoder(ProphetNetPreTrainedModel):
  950. def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
  951. r"""
  952. word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
  953. The word embedding parameters. This can be used to initialize [`ProphetNetEncoder`] with pre-defined word
  954. embeddings instead of randomly initialized word embeddings.
  955. """
  956. super().__init__(config)
  957. self.ngram = config.ngram
  958. self.num_buckets = config.num_buckets
  959. self.relative_max_distance = config.relative_max_distance
  960. self.dropout = config.dropout
  961. self.max_target_positions = config.max_position_embeddings
  962. self.word_embeddings = (
  963. word_embeddings
  964. if word_embeddings is not None
  965. else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  966. )
  967. self.position_embeddings = ProphetNetPositionalEmbeddings(config)
  968. self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
  969. self.layers = nn.ModuleList(
  970. [ProphetNetDecoderLayer(config, layer_idx=i) for i in range(config.num_decoder_layers)]
  971. )
  972. self.embeddings_layer_norm = LayerNorm(config.hidden_size)
  973. self.gradient_checkpointing = False
  974. # Initialize weights and apply final processing
  975. self.post_init()
  976. def get_input_embeddings(self):
  977. return self.word_embeddings
  978. def set_input_embeddings(self, value):
  979. self.word_embeddings = value
  980. @auto_docstring
  981. def forward(
  982. self,
  983. input_ids: Optional[torch.Tensor] = None,
  984. attention_mask: Optional[torch.Tensor] = None,
  985. encoder_hidden_states: Optional[torch.Tensor] = None,
  986. encoder_attention_mask: Optional[torch.Tensor] = None,
  987. head_mask: Optional[torch.Tensor] = None,
  988. cross_attn_head_mask: Optional[torch.Tensor] = None,
  989. past_key_values: Optional[Cache] = None,
  990. inputs_embeds: Optional[torch.Tensor] = None,
  991. use_cache: Optional[bool] = None,
  992. output_attentions: Optional[bool] = None,
  993. output_hidden_states: Optional[bool] = None,
  994. return_dict: Optional[bool] = None,
  995. cache_position: Optional[torch.Tensor] = None,
  996. ) -> Union[tuple, ProphetNetDecoderModelOutput]:
  997. r"""
  998. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  999. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1000. - 1 indicates the head is **not masked**,
  1001. - 0 indicates the head is **masked**.
  1002. Example:
  1003. ```python
  1004. >>> from transformers import AutoTokenizer, ProphetNetDecoder
  1005. >>> import torch
  1006. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1007. >>> model = ProphetNetDecoder.from_pretrained("microsoft/prophetnet-large-uncased", add_cross_attention=False)
  1008. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1009. >>> outputs = model(**inputs)
  1010. >>> last_hidden_states = outputs.last_hidden_state
  1011. ```"""
  1012. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1013. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1014. output_hidden_states = (
  1015. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1016. )
  1017. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1018. if input_ids is None and inputs_embeds is None:
  1019. raise ValueError("Either `decoder_input_ids` or `decoder_inputs_embeds` has to be passed.")
  1020. elif input_ids is not None and inputs_embeds is not None:
  1021. raise ValueError("Make sure to only pass `decoder_input_ids` or `decoder_inputs_embeds`.")
  1022. elif input_ids is not None and inputs_embeds is None:
  1023. inputs_embeds = self.word_embeddings(input_ids)
  1024. batch_size, sequence_length = inputs_embeds.shape[:2]
  1025. if self.gradient_checkpointing and self.training:
  1026. if use_cache:
  1027. logger.warning_once(
  1028. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  1029. )
  1030. use_cache = False
  1031. if use_cache and past_key_values is None:
  1032. past_key_values = (
  1033. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  1034. if encoder_hidden_states is not None
  1035. else DynamicCache(config=self.config)
  1036. )
  1037. if use_cache and isinstance(past_key_values, tuple):
  1038. logger.warning_once(
  1039. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  1040. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  1041. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  1042. )
  1043. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  1044. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  1045. main_stream_pos_embed, position_ids = self.position_embeddings(
  1046. (batch_size, sequence_length),
  1047. device=inputs_embeds.device,
  1048. past_key_values=past_key_values,
  1049. )
  1050. if past_key_values_length != 0:
  1051. main_relative_position_buckets, predict_relative_position_buckets = None, None
  1052. else:
  1053. (
  1054. main_relative_position_buckets,
  1055. predict_relative_position_buckets,
  1056. ) = self.compute_buffered_relative_buckets(position_ids)
  1057. predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1)
  1058. # add position embeddings
  1059. hidden_states = inputs_embeds + main_stream_pos_embed
  1060. ngram_embeddings = self.ngram_embeddings.weight
  1061. # prepare attention mask
  1062. if past_key_values_length != 0:
  1063. assert hidden_states.size(1) == 1, (
  1064. "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1"
  1065. )
  1066. ngram_hidden_states = [
  1067. (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1)
  1068. for ngram in range(self.ngram)
  1069. ]
  1070. extended_attention_mask = None
  1071. extended_predict_attention_mask = None
  1072. else:
  1073. ngram_hidden_states = [
  1074. (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram)
  1075. ]
  1076. extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask)
  1077. extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask)
  1078. # prepare encoder attention mask
  1079. if encoder_attention_mask is not None:
  1080. extended_encoder_attention_mask = (
  1081. 1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
  1082. ) * torch.finfo(self.dtype).min
  1083. extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
  1084. else:
  1085. extended_encoder_attention_mask = None
  1086. hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1)
  1087. if self.embeddings_layer_norm:
  1088. hidden_states = self.embeddings_layer_norm(hidden_states)
  1089. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  1090. # init attentions, hidden_states and cache with empty tuples
  1091. all_main_stream_hidden_states = () if output_hidden_states else None
  1092. all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None
  1093. all_main_stream_attns = () if output_attentions else None
  1094. all_ngram_stream_attns = () if output_attentions else None
  1095. all_cross_attns = () if output_attentions and self.config.add_cross_attention else None
  1096. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  1097. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  1098. if attn_mask is not None:
  1099. assert attn_mask.size()[0] == (len(self.layers)), (
  1100. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  1101. f" {head_mask.size()[0]}."
  1102. )
  1103. for idx, decoder_layer in enumerate(self.layers):
  1104. if output_hidden_states:
  1105. # grad cannot be kept because tensor is sliced
  1106. all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
  1107. if self.config.ngram > 0:
  1108. all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)
  1109. layer_outputs = decoder_layer(
  1110. hidden_states,
  1111. extended_attention_mask,
  1112. encoder_hidden_states, # as a positional argument for gradient checkpointing
  1113. encoder_attn_mask=extended_encoder_attention_mask,
  1114. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  1115. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  1116. extended_predict_attention_mask=extended_predict_attention_mask,
  1117. main_relative_position_buckets=main_relative_position_buckets,
  1118. predict_relative_position_buckets=predict_relative_position_buckets,
  1119. position_ids=position_ids,
  1120. past_key_values=past_key_values,
  1121. use_cache=use_cache,
  1122. output_attentions=output_attentions,
  1123. cache_position=cache_position,
  1124. )
  1125. hidden_states = layer_outputs[0]
  1126. if output_attentions:
  1127. all_main_stream_attns += (layer_outputs[1],)
  1128. all_ngram_stream_attns += (layer_outputs[2],)
  1129. if self.config.add_cross_attention:
  1130. all_cross_attns += (layer_outputs[3],)
  1131. if output_hidden_states:
  1132. all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
  1133. if self.config.ngram > 0:
  1134. all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)
  1135. # split last_hidden_state for return
  1136. last_hidden_state = hidden_states[:, :sequence_length]
  1137. last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None
  1138. if not return_dict:
  1139. return tuple(
  1140. v
  1141. for v in [
  1142. last_hidden_state,
  1143. last_hidden_state_ngram,
  1144. past_key_values,
  1145. all_main_stream_hidden_states,
  1146. all_ngram_stream_hidden_states,
  1147. all_main_stream_attns,
  1148. all_ngram_stream_attns,
  1149. all_cross_attns,
  1150. ]
  1151. if v is not None
  1152. )
  1153. return ProphetNetDecoderModelOutput(
  1154. last_hidden_state=last_hidden_state,
  1155. last_hidden_state_ngram=last_hidden_state_ngram,
  1156. past_key_values=past_key_values,
  1157. hidden_states=all_main_stream_hidden_states,
  1158. hidden_states_ngram=all_ngram_stream_hidden_states,
  1159. attentions=all_main_stream_attns,
  1160. ngram_attentions=all_ngram_stream_attns,
  1161. cross_attentions=all_cross_attns,
  1162. )
  1163. def compute_buffered_relative_buckets(self, position_ids):
  1164. batch_size, sequence_length = position_ids.shape
  1165. position_ids = torch.arange(1, self.max_target_positions).to(position_ids.device).repeat(1, 1)
  1166. main_relative_buckets, predict_relative_buckets = compute_all_stream_relative_buckets(
  1167. self.num_buckets, self.relative_max_distance, position_ids
  1168. )
  1169. # buffer relative buckets
  1170. main_relative_buckets = main_relative_buckets[:, :sequence_length, :sequence_length].repeat(batch_size, 1, 1)
  1171. predict_relative_buckets = torch.cat(
  1172. [
  1173. predict_relative_buckets[:, :sequence_length, :sequence_length],
  1174. predict_relative_buckets[
  1175. :, :sequence_length, self.max_target_positions : self.max_target_positions + sequence_length
  1176. ],
  1177. ],
  1178. 2,
  1179. ).repeat(batch_size, 1, 1)
  1180. return main_relative_buckets, predict_relative_buckets
  1181. def prepare_attention_mask(self, hidden_states, attention_mask):
  1182. batch_size, seq_length = hidden_states.shape[:2]
  1183. # get causal mask
  1184. causal_mask = torch.full(
  1185. (seq_length, seq_length),
  1186. torch.finfo(hidden_states.dtype).min,
  1187. dtype=hidden_states.dtype,
  1188. device=hidden_states.device,
  1189. )
  1190. causal_mask = torch.triu(causal_mask, 1)
  1191. extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
  1192. (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
  1193. )
  1194. # add usual attention mask
  1195. if attention_mask is not None:
  1196. extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
  1197. extended_attention_mask = extended_causal_mask + extended_attention_mask
  1198. else:
  1199. extended_attention_mask = extended_causal_mask
  1200. return extended_attention_mask.to(hidden_states.dtype)
  1201. def prepare_predict_attention_mask(self, hidden_states, attention_mask):
  1202. batch_size, seq_length = hidden_states.shape[:2]
  1203. # get causal mask
  1204. predict_causal_mask = ngram_attention_bias(
  1205. self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype
  1206. )
  1207. predict_causal_mask = torch.cat(
  1208. [
  1209. predict_causal_mask[:, :seq_length, :seq_length],
  1210. predict_causal_mask[
  1211. :, :seq_length, self.max_target_positions : self.max_target_positions + seq_length
  1212. ],
  1213. ],
  1214. dim=-1,
  1215. )
  1216. extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
  1217. (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
  1218. )
  1219. # add usual attention mask
  1220. if attention_mask is not None:
  1221. extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
  1222. extended_attention_mask = extended_attention_mask.expand(
  1223. (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
  1224. )
  1225. # predicted stream attention_mask should always be 0
  1226. extended_attention_mask = torch.cat(
  1227. [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
  1228. )
  1229. extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
  1230. else:
  1231. extended_predict_attention_mask = extended_predict_causal_mask
  1232. return extended_predict_attention_mask.to(hidden_states.dtype)
  1233. @auto_docstring
  1234. class ProphetNetModel(ProphetNetPreTrainedModel):
  1235. _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]
  1236. def __init__(self, config: ProphetNetConfig):
  1237. super().__init__(config)
  1238. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  1239. encoder_config = copy.deepcopy(config)
  1240. encoder_config.use_cache = False
  1241. encoder_config.tie_encoder_decoder = False
  1242. self.encoder = ProphetNetEncoder(encoder_config, self.word_embeddings)
  1243. decoder_config = copy.deepcopy(config)
  1244. decoder_config.is_decoder = True
  1245. decoder_config.tie_encoder_decoder = False
  1246. self.decoder = ProphetNetDecoder(decoder_config, self.word_embeddings)
  1247. # Initialize weights and apply final processing
  1248. self.post_init()
  1249. def get_input_embeddings(self):
  1250. return self.word_embeddings
  1251. def set_input_embeddings(self, value):
  1252. self.word_embeddings = value
  1253. self.encoder.word_embeddings = self.word_embeddings
  1254. self.decoder.word_embeddings = self.word_embeddings
  1255. def _tie_weights(self):
  1256. if self.config.tie_word_embeddings:
  1257. self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings)
  1258. self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings)
  1259. def get_encoder(self):
  1260. return self.encoder
  1261. @auto_docstring
  1262. def forward(
  1263. self,
  1264. input_ids: Optional[torch.Tensor] = None,
  1265. attention_mask: Optional[torch.Tensor] = None,
  1266. decoder_input_ids: Optional[torch.Tensor] = None,
  1267. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1268. head_mask: Optional[torch.Tensor] = None,
  1269. decoder_head_mask: Optional[torch.Tensor] = None,
  1270. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1271. encoder_outputs: Optional[tuple] = None,
  1272. past_key_values: Optional[Cache] = None,
  1273. inputs_embeds: Optional[torch.Tensor] = None,
  1274. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1275. use_cache: Optional[bool] = None,
  1276. output_attentions: Optional[bool] = None,
  1277. output_hidden_states: Optional[bool] = None,
  1278. return_dict: Optional[bool] = None,
  1279. cache_position: Optional[torch.Tensor] = None,
  1280. ) -> Union[tuple, ProphetNetSeq2SeqModelOutput]:
  1281. r"""
  1282. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1283. Indices of decoder input sequence tokens in the vocabulary.
  1284. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1285. [`PreTrainedTokenizer.__call__`] for details.
  1286. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1287. ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
  1288. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1289. `past_key_values`).
  1290. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1291. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1292. be used by default.
  1293. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1294. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1295. - 1 indicates the head is **not masked**,
  1296. - 0 indicates the head is **masked**.
  1297. Example:
  1298. ```python
  1299. >>> from transformers import AutoTokenizer, ProphetNetModel
  1300. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1301. >>> model = ProphetNetModel.from_pretrained("microsoft/prophetnet-large-uncased")
  1302. >>> input_ids = tokenizer(
  1303. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1304. ... ).input_ids # Batch size 1
  1305. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1306. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1307. >>> last_hidden_states = outputs.last_hidden_state # main stream hidden states
  1308. >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram # predict hidden states
  1309. ```"""
  1310. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1311. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1312. output_hidden_states = (
  1313. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1314. )
  1315. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1316. if encoder_outputs is None:
  1317. encoder_outputs = self.encoder(
  1318. input_ids=input_ids,
  1319. attention_mask=attention_mask,
  1320. head_mask=head_mask,
  1321. inputs_embeds=inputs_embeds,
  1322. output_attentions=output_attentions,
  1323. output_hidden_states=output_hidden_states,
  1324. return_dict=return_dict,
  1325. )
  1326. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1327. decoder_outputs = self.decoder(
  1328. input_ids=decoder_input_ids,
  1329. attention_mask=decoder_attention_mask,
  1330. encoder_hidden_states=encoder_outputs[0],
  1331. encoder_attention_mask=attention_mask,
  1332. head_mask=decoder_head_mask,
  1333. cross_attn_head_mask=cross_attn_head_mask,
  1334. past_key_values=past_key_values,
  1335. inputs_embeds=decoder_inputs_embeds,
  1336. output_attentions=output_attentions,
  1337. output_hidden_states=output_hidden_states,
  1338. use_cache=use_cache,
  1339. return_dict=return_dict,
  1340. cache_position=cache_position,
  1341. )
  1342. if not return_dict:
  1343. return decoder_outputs + encoder_outputs
  1344. return ProphetNetSeq2SeqModelOutput(
  1345. last_hidden_state=decoder_outputs.last_hidden_state,
  1346. last_hidden_state_ngram=decoder_outputs.last_hidden_state_ngram,
  1347. past_key_values=decoder_outputs.past_key_values,
  1348. decoder_hidden_states=decoder_outputs.hidden_states,
  1349. decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram,
  1350. decoder_attentions=decoder_outputs.attentions,
  1351. decoder_ngram_attentions=decoder_outputs.ngram_attentions,
  1352. cross_attentions=decoder_outputs.cross_attentions,
  1353. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1354. encoder_hidden_states=encoder_outputs.hidden_states,
  1355. encoder_attentions=encoder_outputs.attentions,
  1356. )
  1357. @auto_docstring(
  1358. custom_intro="""
  1359. The ProphetNet Model with a language modeling head. Can be used for sequence generation tasks.
  1360. """
  1361. )
  1362. class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin):
  1363. _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]
  1364. def __init__(self, config: ProphetNetConfig):
  1365. super().__init__(config)
  1366. self.prophetnet = ProphetNetModel(config)
  1367. self.padding_idx = config.pad_token_id
  1368. self.disable_ngram_loss = config.disable_ngram_loss
  1369. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1370. # Initialize weights and apply final processing
  1371. self.post_init()
  1372. def _tie_weights(self):
  1373. if self.config.tie_word_embeddings:
  1374. self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head)
  1375. def get_input_embeddings(self):
  1376. return self.prophetnet.word_embeddings
  1377. @auto_docstring
  1378. def forward(
  1379. self,
  1380. input_ids: Optional[torch.Tensor] = None,
  1381. attention_mask: Optional[torch.Tensor] = None,
  1382. decoder_input_ids: Optional[torch.Tensor] = None,
  1383. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1384. head_mask: Optional[torch.Tensor] = None,
  1385. decoder_head_mask: Optional[torch.Tensor] = None,
  1386. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1387. encoder_outputs: Optional[torch.Tensor] = None,
  1388. past_key_values: Optional[Cache] = None,
  1389. inputs_embeds: Optional[torch.Tensor] = None,
  1390. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1391. labels: Optional[torch.Tensor] = None,
  1392. use_cache: Optional[bool] = None,
  1393. output_attentions: Optional[bool] = None,
  1394. output_hidden_states: Optional[bool] = None,
  1395. return_dict: Optional[bool] = None,
  1396. cache_position: Optional[torch.Tensor] = None,
  1397. ) -> Union[tuple, ProphetNetSeq2SeqLMOutput]:
  1398. r"""
  1399. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1400. Indices of decoder input sequence tokens in the vocabulary.
  1401. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1402. [`PreTrainedTokenizer.__call__`] for details.
  1403. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1404. ProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
  1405. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1406. `past_key_values`).
  1407. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1408. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1409. be used by default.
  1410. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1411. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1412. - 1 indicates the head is **not masked**,
  1413. - 0 indicates the head is **masked**.
  1414. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1415. Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
  1416. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
  1417. labels in `[0, ..., config.vocab_size]`
  1418. Example:
  1419. ```python
  1420. >>> from transformers import AutoTokenizer, ProphetNetForConditionalGeneration
  1421. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1422. >>> model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")
  1423. >>> input_ids = tokenizer(
  1424. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1425. ... ).input_ids # Batch size 1
  1426. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1427. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1428. >>> logits_next_token = outputs.logits # logits to predict next token as usual
  1429. >>> logits_ngram_next_tokens = outputs.logits_ngram # logits to predict 2nd, 3rd, ... next tokens
  1430. ```"""
  1431. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1432. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1433. # get decoder inputs from shifting lm labels to the right
  1434. decoder_input_ids = self._shift_right(labels)
  1435. outputs = self.prophetnet(
  1436. input_ids=input_ids,
  1437. attention_mask=attention_mask,
  1438. decoder_input_ids=decoder_input_ids,
  1439. decoder_attention_mask=decoder_attention_mask,
  1440. head_mask=head_mask,
  1441. decoder_head_mask=decoder_head_mask,
  1442. cross_attn_head_mask=cross_attn_head_mask,
  1443. encoder_outputs=encoder_outputs,
  1444. past_key_values=past_key_values,
  1445. inputs_embeds=inputs_embeds,
  1446. decoder_inputs_embeds=decoder_inputs_embeds,
  1447. use_cache=use_cache,
  1448. output_attentions=output_attentions,
  1449. output_hidden_states=output_hidden_states,
  1450. return_dict=return_dict,
  1451. cache_position=cache_position,
  1452. )
  1453. batch_size, sequence_length = (
  1454. decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2]
  1455. )
  1456. predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
  1457. predict_logits = self.lm_head(predicting_streams)
  1458. logits = predict_logits[:, 0]
  1459. logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None
  1460. # To use .view in loss computation, make sure that logits is contiguous.
  1461. if not logits.is_contiguous():
  1462. logits = logits.contiguous()
  1463. loss = None
  1464. if labels is not None:
  1465. loss = self._compute_loss(predict_logits, labels)
  1466. if not return_dict:
  1467. all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
  1468. return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
  1469. else:
  1470. return ProphetNetSeq2SeqLMOutput(
  1471. loss=loss,
  1472. logits=logits,
  1473. logits_ngram=logits_ngram,
  1474. past_key_values=outputs.past_key_values,
  1475. decoder_hidden_states=outputs.decoder_hidden_states,
  1476. decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states,
  1477. decoder_attentions=outputs.decoder_attentions,
  1478. decoder_ngram_attentions=outputs.decoder_ngram_attentions,
  1479. cross_attentions=outputs.cross_attentions,
  1480. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1481. encoder_hidden_states=outputs.encoder_hidden_states,
  1482. encoder_attentions=outputs.encoder_attentions,
  1483. )
  1484. def _compute_loss(self, logits, labels, ignore_index=-100):
  1485. expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)
  1486. for i in range(self.config.ngram):
  1487. if i > 0 and self.disable_ngram_loss:
  1488. break
  1489. expend_targets[i, :, :] = labels
  1490. logits = logits.transpose(0, 1).contiguous()
  1491. lprobs = nn.functional.log_softmax(
  1492. logits.view(-1, logits.size(-1)),
  1493. dim=-1,
  1494. dtype=torch.float32,
  1495. )
  1496. loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")
  1497. if self.config.eps > 0.0:
  1498. smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
  1499. non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
  1500. smooth_loss = smooth_loss[non_masked_tokens]
  1501. smooth_loss = smooth_loss.mean()
  1502. eps_i = self.config.eps / lprobs.size(-1)
  1503. loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss
  1504. return loss
  1505. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1506. return self._shift_right(labels)
  1507. def get_encoder(self):
  1508. return self.prophetnet.encoder
  1509. def get_decoder(self):
  1510. return self.prophetnet.decoder
  1511. @auto_docstring(
  1512. custom_intro="""
  1513. The standalone decoder part of the ProphetNetModel with a lm head on top. The model can be used for causal
  1514. """
  1515. )
  1516. class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin):
  1517. _tied_weights_keys = [
  1518. "prophetnet.word_embeddings.weight",
  1519. "prophetnet.decoder.word_embeddings.weight",
  1520. "lm_head.weight",
  1521. ]
  1522. def __init__(self, config: ProphetNetConfig):
  1523. # set config for CLM
  1524. config = copy.deepcopy(config)
  1525. config.is_decoder = True
  1526. config.is_encoder_decoder = False
  1527. super().__init__(config)
  1528. self.prophetnet = ProphetNetDecoderWrapper(config)
  1529. self.padding_idx = config.pad_token_id
  1530. self.disable_ngram_loss = config.disable_ngram_loss
  1531. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1532. # Initialize weights and apply final processing
  1533. self.post_init()
  1534. def get_input_embeddings(self):
  1535. return self.prophetnet.decoder.word_embeddings
  1536. def set_input_embeddings(self, value):
  1537. self.prophetnet.decoder.word_embeddings = value
  1538. def _tie_weights(self):
  1539. if self.config.tie_word_embeddings:
  1540. self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head)
  1541. def set_decoder(self, decoder):
  1542. self.prophetnet.decoder = decoder
  1543. def get_decoder(self):
  1544. return self.prophetnet.decoder
  1545. @auto_docstring
  1546. def forward(
  1547. self,
  1548. input_ids: Optional[torch.Tensor] = None,
  1549. attention_mask: Optional[torch.Tensor] = None,
  1550. encoder_hidden_states: Optional[torch.Tensor] = None,
  1551. encoder_attention_mask: Optional[torch.Tensor] = None,
  1552. head_mask: Optional[torch.Tensor] = None,
  1553. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1554. past_key_values: Optional[Cache] = None,
  1555. inputs_embeds: Optional[torch.Tensor] = None,
  1556. labels: Optional[torch.Tensor] = None,
  1557. use_cache: Optional[bool] = None,
  1558. output_attentions: Optional[bool] = None,
  1559. output_hidden_states: Optional[bool] = None,
  1560. return_dict: Optional[bool] = None,
  1561. ) -> Union[tuple, ProphetNetDecoderLMOutput]:
  1562. r"""
  1563. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1564. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1565. - 1 indicates the head is **not masked**,
  1566. - 0 indicates the head is **masked**.
  1567. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1568. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1569. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1570. ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
  1571. Example:
  1572. ```python
  1573. >>> from transformers import AutoTokenizer, ProphetNetForCausalLM
  1574. >>> import torch
  1575. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1576. >>> model = ProphetNetForCausalLM.from_pretrained("microsoft/prophetnet-large-uncased")
  1577. >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
  1578. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  1579. >>> outputs = model(**inputs)
  1580. >>> logits = outputs.logits
  1581. >>> # Model can also be used with EncoderDecoder framework
  1582. >>> from transformers import BertTokenizer, EncoderDecoderModel, AutoTokenizer
  1583. >>> import torch
  1584. >>> tokenizer_enc = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
  1585. >>> tokenizer_dec = AutoTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")
  1586. >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
  1587. ... "google-bert/bert-large-uncased", "microsoft/prophetnet-large-uncased"
  1588. ... )
  1589. >>> ARTICLE = (
  1590. ... "the us state department said wednesday it had received no "
  1591. ... "formal word from bolivia that it was expelling the us ambassador there "
  1592. ... "but said the charges made against him are `` baseless ."
  1593. ... )
  1594. >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
  1595. >>> labels = tokenizer_dec(
  1596. ... "us rejects charges against its ambassador in bolivia", return_tensors="pt"
  1597. ... ).input_ids
  1598. >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:])
  1599. >>> loss = outputs.loss
  1600. ```"""
  1601. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1602. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  1603. outputs = self.prophetnet.decoder(
  1604. input_ids=input_ids,
  1605. attention_mask=attention_mask,
  1606. encoder_hidden_states=encoder_hidden_states,
  1607. encoder_attention_mask=encoder_attention_mask,
  1608. head_mask=head_mask,
  1609. cross_attn_head_mask=cross_attn_head_mask,
  1610. past_key_values=past_key_values,
  1611. inputs_embeds=inputs_embeds,
  1612. use_cache=use_cache,
  1613. output_attentions=output_attentions,
  1614. output_hidden_states=output_hidden_states,
  1615. return_dict=return_dict,
  1616. )
  1617. batch_size, sequence_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]
  1618. predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
  1619. predict_logits = self.lm_head(predicting_streams)
  1620. logits = predict_logits[:, 0]
  1621. logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None
  1622. loss = None
  1623. if labels is not None:
  1624. loss = self._compute_loss(predict_logits, labels)
  1625. if not return_dict:
  1626. all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
  1627. return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
  1628. else:
  1629. return ProphetNetDecoderLMOutput(
  1630. loss=loss,
  1631. logits=logits,
  1632. logits_ngram=logits_ngram,
  1633. past_key_values=outputs.past_key_values,
  1634. hidden_states=outputs.hidden_states,
  1635. hidden_states_ngram=outputs.hidden_states_ngram,
  1636. attentions=outputs.attentions,
  1637. ngram_attentions=outputs.ngram_attentions,
  1638. cross_attentions=outputs.cross_attentions,
  1639. )
  1640. def _compute_loss(self, logits, labels, ignore_index=-100):
  1641. expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)
  1642. for i in range(self.config.ngram):
  1643. if i > 0 and self.disable_ngram_loss:
  1644. break
  1645. expend_targets[i, :, :] = labels
  1646. logits = logits.transpose(0, 1).contiguous()
  1647. lprobs = nn.functional.log_softmax(
  1648. logits.view(-1, logits.size(-1)),
  1649. dim=-1,
  1650. dtype=torch.float32,
  1651. )
  1652. loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")
  1653. if self.config.eps > 0.0:
  1654. smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
  1655. non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
  1656. smooth_loss = smooth_loss[non_masked_tokens]
  1657. smooth_loss = smooth_loss.mean()
  1658. eps_i = self.config.eps / lprobs.size(-1)
  1659. loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss
  1660. return loss
  1661. def prepare_inputs_for_generation(
  1662. self,
  1663. input_ids,
  1664. past_key_values=None,
  1665. attention_mask=None,
  1666. head_mask=None,
  1667. use_cache=None,
  1668. **kwargs,
  1669. ):
  1670. # Overwritten -- our tests complain if we use GenerationMixin.prepare_inputs_for_generation
  1671. # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
  1672. if attention_mask is None:
  1673. attention_mask = input_ids.new_ones(input_ids.shape)
  1674. if past_key_values is not None and past_key_values.get_seq_length() > 0:
  1675. input_ids = input_ids[:, -1:]
  1676. # first step, decoder_cached_states are empty
  1677. model_inputs = {
  1678. "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
  1679. "attention_mask": attention_mask,
  1680. "head_mask": head_mask,
  1681. "past_key_values": past_key_values,
  1682. "use_cache": use_cache,
  1683. }
  1684. # Prophetnet does not support cache_position
  1685. kwargs.pop("cache_position", None)
  1686. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  1687. for key, value in kwargs.items():
  1688. if key not in model_inputs:
  1689. model_inputs[key] = value
  1690. return model_inputs
  1691. class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
  1692. """
  1693. This is a wrapper class, so that [`ProphetNetForCausalLM`] can correctly be loaded from pretrained prophetnet
  1694. classes.
  1695. """
  1696. def __init__(self, config: ProphetNetConfig):
  1697. super().__init__(config)
  1698. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  1699. self.decoder = ProphetNetDecoder(config, word_embeddings=self.word_embeddings)
  1700. # Initialize weights and apply final processing
  1701. self.post_init()
  1702. def _tie_weights(self):
  1703. self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings())
  1704. def forward(self, *args, **kwargs):
  1705. return self.decoder(*args, **kwargs)
# Public names of this module, i.e. what `from <module> import *` exposes.
# Note: `ProphetNetDecoderWrapper` is not listed here.
__all__ = [
    "ProphetNetDecoder",
    "ProphetNetEncoder",
    "ProphetNetForCausalLM",
    "ProphetNetForConditionalGeneration",
    "ProphetNetModel",
    "ProphetNetPreTrainedModel",
]