rec_unimernet_head.py 93 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/opendatalab/UniMERNet/blob/main/unimernet/models/unimernet/configuration_unimernet_decoder.py
  17. """
  18. import copy
  19. import math
  20. import re
  21. import numpy as np
  22. import inspect
  23. import warnings
  24. from collections import OrderedDict
  25. from typing import Optional, Tuple, Union, List, Dict, Any
  26. from dataclasses import dataclass, fields, is_dataclass
  27. import paddle
  28. import paddle.nn as nn
  29. from paddle import Tensor
  30. import paddle.nn.functional as F
  31. from paddle.nn import CrossEntropyLoss
  32. from paddle.nn.initializer import (
  33. TruncatedNormal,
  34. Constant,
  35. Normal,
  36. KaimingUniform,
  37. XavierUniform,
  38. XavierNormal,
  39. )
  # Module-wide parameter initializers (paddle.nn.initializer instances).
  # Each one is a callable applied to a parameter, e.g. ``zeros_(module.bias)``.
  zeros_, ones_ = Constant(value=0.0), Constant(value=1.0)
  # NOTE(review): despite the ``_normal_`` suffix this is a *uniform* Kaiming
  # initializer (KaimingUniform with the ReLU gain). The name is kept so that
  # existing references keep working -- confirm the intent against upstream.
  kaiming_normal_ = KaimingUniform(nonlinearity="relu")
  trunc_normal_ = TruncatedNormal(std=0.02)
  xavier_uniform_, xavier_normal_ = XavierUniform(), XavierNormal()
  class ModelOutput(OrderedDict):
      """Ordered mapping whose entries are mirrored as instance attributes.

      Port of the HuggingFace ``ModelOutput`` container:

      * ``out["name"]`` looks an entry up by key; ``out[0]`` / ``out[1:]``
        index into :meth:`to_tuple`.
      * Item assignment also sets the attribute of the same name; attribute
        assignment also updates the entry when that key already exists and
        the new value is not ``None``.
      * ``del out[k]``, ``pop``, ``setdefault`` and ``update`` are disabled and
        raise ``Exception``.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)

      def __post_init__(self):
          """Populate the mapping from dataclass fields.

          Only invoked by a dataclass-generated ``__init__`` (``@dataclass``
          subclasses that do not define their own ``__init__``).

          If every field except the first is ``None`` and the first field is a
          dict or an iterable of ``(str, value)`` pairs, those pairs become the
          entries. If it is not a pair iterable, it is stored under its own
          name. Otherwise every non-``None`` field is stored under its own name.

          Raises:
              ValueError: if the class has no dataclass fields, if any field
                  other than the first does not default to ``None``, or if a
                  pair iterable contains a non-pair after its first element.
          """
          class_fields = fields(self)
          if not class_fields:
              raise ValueError(f"{self.__class__.__name__} has no fields.")
          if not all(field.default is None for field in class_fields[1:]):
              raise ValueError(
                  f"{self.__class__.__name__} should not have more than one required field."
              )
          first_field = getattr(self, class_fields[0].name)
          other_fields_are_none = all(
              getattr(self, field.name) is None for field in class_fields[1:]
          )
          if other_fields_are_none:
              if isinstance(first_field, dict):
                  iterator = first_field.items()
                  first_field_iterator = True
              else:
                  try:
                      iterator = iter(first_field)
                      first_field_iterator = True
                  except TypeError:
                      first_field_iterator = False
              if first_field_iterator:
                  for idx, element in enumerate(iterator):
                      is_pair = (
                          isinstance(element, (list, tuple))
                          and len(element) == 2
                          and isinstance(element[0], str)
                      )
                      if not is_pair:
                          if idx == 0:
                              # Not a key/value iterable: keep it as the first field.
                              self[class_fields[0].name] = first_field
                          else:
                              raise ValueError(
                                  f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
                              )
                          break
                      setattr(self, element[0], element[1])
                      if element[1] is not None:
                          self[element[0]] = element[1]
              elif first_field is not None:
                  self[class_fields[0].name] = first_field
          else:
              for field in class_fields:
                  value = getattr(self, field.name)
                  if value is not None:
                      self[field.name] = value

      def __delitem__(self, *args, **kwargs):
          raise Exception(
              f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
          )

      def setdefault(self, *args, **kwargs):
          raise Exception(
              f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
          )

      def pop(self, *args, **kwargs):
          raise Exception(
              f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
          )

      def update(self, *args, **kwargs):
          raise Exception(
              f"You cannot use ``update`` on a {self.__class__.__name__} instance."
          )

      def __getitem__(self, k):
          """Key lookup for strings; positional/slice lookup via :meth:`to_tuple`."""
          if isinstance(k, str):
              # Direct O(1) lookup. The previous version rebuilt
              # ``dict(self.items())`` on every access for the same result.
              return super().__getitem__(k)
          return self.to_tuple()[k]

      def __setattr__(self, name, value):
          # Keep an existing entry in sync, but never overwrite it with None.
          if name in self.keys() and value is not None:
              super().__setitem__(name, value)
          super().__setattr__(name, value)

      def __setitem__(self, key, value):
          # Every entry is mirrored as an attribute (so keys must be str).
          super().__setitem__(key, value)
          super().__setattr__(key, value)

      def __reduce__(self):
          """Pickle support; dataclass subclasses are rebuilt from their field values."""
          if not is_dataclass(self):
              return super().__reduce__()
          factory, _args, *remaining = super().__reduce__()
          args = tuple(getattr(self, field.name) for field in fields(self))
          return factory, args, *remaining

      def to_tuple(self):
          """Return the stored values in insertion order."""
          return tuple(self.values())
  @dataclass
  class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
      """Decoder output bundle (last hidden state plus optional cache/attentions).

      The ``None`` values below are plain class attributes, not dataclass
      fields (they carry no annotations). Reading an entry that was never set
      as an attribute therefore yields ``None``. Do not annotate them: that
      would turn them into dataclass fields and change ``fields()``,
      ``__eq__`` and pickling.
      """

      # Class-level fallbacks for attribute access (see docstring).
      last_hidden_state = None
      past_key_values = None
      hidden_states = None
      attentions = None
      cross_attentions = None

      # An explicit __init__ stops @dataclass from generating one. A generated
      # __init__ would take no arguments and would call ModelOutput.__post_init__,
      # which raises ValueError for a class without fields.
      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
  @dataclass
  class Seq2SeqLMOutput(ModelOutput):
      """Sequence-to-sequence LM output bundle (loss, logits, caches, attentions).

      As with the other output containers in this module, the ``None`` values
      are un-annotated class attributes, not dataclass fields. They act as
      attribute fallbacks for entries that were never set. Adding annotations
      would make them dataclass fields and change behavior.
      """

      # Class-level fallbacks for attribute access (see docstring).
      loss = None
      logits = None
      past_key_values = None
      decoder_hidden_states = None
      decoder_attentions = None
      cross_attentions = None
      encoder_last_hidden_state = None
      encoder_hidden_states = None
      encoder_attentions = None

      # Keep an explicit __init__: otherwise @dataclass would generate a
      # no-argument __init__ that calls ModelOutput.__post_init__, which raises
      # ValueError because this class has no dataclass fields.
      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
  class MBartConfig(object):
      """Plain hyper-parameter container for the MBart-style decoder.

      Every constructor argument is stored as an attribute of the same name.
      In addition, ``num_hidden_layers`` mirrors ``encoder_layers``. Unknown
      keyword arguments are accepted and silently discarded.

      NOTE(review): ``attribute_map`` is declared for parity with the
      HuggingFace config, but nothing in this class consults it. So, for
      example, ``config.num_attention_heads`` is *not* an alias here.
      """

      model_type = "mbart"
      keys_to_ignore_at_inference = ["past_key_values"]
      attribute_map = {
          "num_attention_heads": "encoder_attention_heads",
          "hidden_size": "d_model",
      }

      def __init__(
          self,
          vocab_size=50265,
          max_position_embeddings=1024,
          encoder_layers=12,
          encoder_ffn_dim=4096,
          encoder_attention_heads=16,
          decoder_layers=12,
          decoder_ffn_dim=4096,
          decoder_attention_heads=16,
          encoder_layerdrop=0.0,
          decoder_layerdrop=0.0,
          use_cache=True,
          is_encoder_decoder=True,
          activation_function="gelu",
          d_model=1024,
          dropout=0.1,
          output_hidden_states=False,
          use_return_dict=True,
          attention_dropout=0.0,
          activation_dropout=0.0,
          init_std=0.02,
          classifier_dropout=0.0,
          scale_embedding=False,
          pad_token_id=1,
          bos_token_id=0,
          eos_token_id=2,
          forced_eos_token_id=2,
          _attn_implementation="eager",
          hidden_size=1024,
          use_parallel=False,
          parallel_step=2,
          is_export=False,
          **kwargs,
      ):
          # Vocabulary and embedding geometry.
          self.vocab_size = vocab_size
          self.max_position_embeddings = max_position_embeddings
          self.d_model = d_model
          self.hidden_size = hidden_size  # stored independently of d_model
          # Upstream note: the scale factor is sqrt(d_model) when True.
          self.scale_embedding = scale_embedding

          # Encoder stack.
          self.encoder_layers = encoder_layers
          self.num_hidden_layers = encoder_layers
          self.encoder_ffn_dim = encoder_ffn_dim
          self.encoder_attention_heads = encoder_attention_heads
          self.encoder_layerdrop = encoder_layerdrop

          # Decoder stack.
          self.decoder_layers = decoder_layers
          self.decoder_ffn_dim = decoder_ffn_dim
          self.decoder_attention_heads = decoder_attention_heads
          self.decoder_layerdrop = decoder_layerdrop

          # Regularisation, activation and initialisation.
          self.dropout = dropout
          self.attention_dropout = attention_dropout
          self.activation_dropout = activation_dropout
          self.classifier_dropout = classifier_dropout
          self.activation_function = activation_function
          self.init_std = init_std

          # Output / runtime switches.
          self.use_cache = use_cache
          self.output_hidden_states = output_hidden_states
          self.use_return_dict = use_return_dict
          self.is_encoder_decoder = is_encoder_decoder
          self._attn_implementation = _attn_implementation
          self.use_parallel = use_parallel
          self.parallel_step = parallel_step
          self.is_export = is_export

          # Special token ids.
          self.pad_token_id = pad_token_id
          self.bos_token_id = bos_token_id
          self.eos_token_id = eos_token_id
          self.forced_eos_token_id = forced_eos_token_id

          super().__init__()
  @dataclass
  class AttentionMaskConverter:
      """Build additive 4-D attention masks from 2-D padding masks.

      Masked positions receive ``paddle.finfo(dtype).min`` and attended
      positions receive 0, so the mask can be added to raw attention scores.
      With ``is_causal`` the mask also hides future key positions. With
      ``sliding_window`` it additionally hides keys lying more than
      ``sliding_window`` positions behind the query.

      Attributes:
          is_causal (bool): whether to apply causal (lower-triangular) masking.
          sliding_window (Optional[int]): local-attention window size, or
              ``None`` for unrestricted attention.

      Raises:
          ValueError: if ``sliding_window`` is given and is not strictly positive.
      """

      is_causal: bool
      sliding_window: int

      def __init__(self, is_causal: bool, sliding_window=None):
          self.is_causal = is_causal
          self.sliding_window = sliding_window
          if self.sliding_window is not None and self.sliding_window <= 0:
              raise ValueError(
                  f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
              )

      @staticmethod
      def _make_causal_mask(
          input_ids_shape,
          dtype,
          past_key_values_length=0,
          sliding_window=None,
          is_export=False,
      ):
          """Return a ``[bsz, 1, tgt_len, tgt_len + past_key_values_length]`` causal mask.

          Query ``i`` may see every cached key and the new keys ``0..i``. Later
          keys get ``paddle.finfo(dtype).min``. With ``sliding_window``, keys
          ``j`` with ``j - i <= past_key_values_length - sliding_window - 1``
          are masked as well (the HuggingFace rule).
          """
          bsz, tgt_len = input_ids_shape
          fill_value = paddle.finfo(dtype).min
          if is_export:
              # The export path keeps the original float64 buffer.
              mask = paddle.full((tgt_len, tgt_len), fill_value, dtype="float64")
          else:
              mask = paddle.full((tgt_len, tgt_len), fill_value)
          mask_cond = paddle.arange(mask.shape[-1])
          # Zero the lower triangle including the diagonal (key j <= query i).
          mask = mask.masked_fill_(
              mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
          )
          if past_key_values_length > 0:
              # Cached keys precede the new ones and are always visible. Without
              # these columns, expanding to tgt_len + past fails.
              mask = paddle.concat(
                  [
                      paddle.zeros(
                          [tgt_len, past_key_values_length], dtype=mask.dtype
                      ),
                      mask,
                  ],
                  axis=-1,
              )
          if sliding_window is not None:
              # Hide keys that fall outside the local window behind each query.
              diagonal = past_key_values_length - sliding_window - 1
              rows = paddle.arange(tgt_len).reshape([-1, 1])
              cols = paddle.arange(mask.shape[-1]).reshape([1, -1])
              mask = mask.masked_fill_((cols - rows) <= diagonal, fill_value)
          return mask[None, None, :, :].expand(
              [bsz, 1, tgt_len, tgt_len + past_key_values_length]
          )

      def to_4d_export(
          self,
          attention_mask_2d,
          query_length,
          dtype,
          key_value_length,
          is_export=False,
      ):
          """Export-path conversion: expand the padding mask only.

          No causal mask is built. ``key_value_length`` and ``is_export`` are
          accepted for signature compatibility with :meth:`to_4d` but are unused.
          """
          input_shape = (attention_mask_2d.shape[0], query_length)
          return self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1])

      def to_4d(
          self,
          attention_mask_2d,
          query_length,
          dtype,
          key_value_length,
          is_export=False,
      ):
          """Convert a 2-D ``[bsz, src_len]`` padding mask into an additive 4-D mask.

          A causal mask is built when the converter is causal and either
          ``query_length > 1`` or a sliding window is set. Outside export mode
          it is combined with the padding mask. In export mode the causal mask
          is returned alone, and padding is ignored.

          Raises:
              ValueError: if a causal mask is needed but ``key_value_length`` is None.
              NotImplementedError: if a sliding window is set on a non-causal converter.
          """
          input_shape = (attention_mask_2d.shape[0], query_length)
          causal_4d_mask = None
          if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
              if key_value_length is None:
                  raise ValueError(
                      "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                  )
              past_key_values_length = key_value_length - query_length
              causal_4d_mask = self._make_causal_mask(
                  input_shape,
                  dtype,
                  past_key_values_length=past_key_values_length,
                  sliding_window=self.sliding_window,
                  is_export=is_export,
              )
          elif self.sliding_window is not None:
              raise NotImplementedError(
                  "Sliding window is currently only implemented for causal masking"
              )
          expanded_attn_mask = self._expand_mask(
              attention_mask_2d, dtype, tgt_len=input_shape[-1]
          )
          if causal_4d_mask is not None:
              if is_export:
                  return causal_4d_mask
              # Any position masked by padding (non-zero entry) is also masked
              # in the causal mask.
              expanded_attn_mask = causal_4d_mask.masked_fill_(
                  expanded_attn_mask.cast(paddle.bool), paddle.finfo(dtype).min
              )
          return expanded_attn_mask

      def to_causal_4d(
          self,
          batch_size,
          query_length,
          key_value_length,
          dtype,
          is_export=False,
      ):
          """Return a purely causal 4-D mask (no padding information).

          Returns ``None`` when no mask is needed, i.e. a single-token query
          with no sliding window.

          Raises:
              ValueError: if the converter is not causal.
          """
          if not self.is_causal:
              raise ValueError(
                  f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True."
              )
          input_shape = (batch_size, query_length)
          past_key_values_length = key_value_length - query_length
          if input_shape[-1] > 1 or self.sliding_window is not None:
              return self._make_causal_mask(
                  input_shape,
                  dtype,
                  past_key_values_length=past_key_values_length,
                  sliding_window=self.sliding_window,
                  is_export=is_export,
              )
          return None

      @staticmethod
      def _expand_mask(mask, dtype, tgt_len=None):
          """Expand a ``[bsz, src_len]`` 0/1 mask to an additive ``[bsz, 1, tgt_len, src_len]`` mask.

          Entries equal to 1 become 0 (attend). Every other entry becomes
          ``paddle.finfo(dtype).min``. ``tgt_len`` defaults to ``src_len``.
          Declared static because module-level helpers call it on the class.
          """
          bsz, src_len = mask.shape
          tgt_len = tgt_len if tgt_len is not None else src_len
          expanded_mask = (
              mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).cast(dtype)
          )
          inverted_mask = 1.0 - expanded_mask
          return inverted_mask.masked_fill_(
              inverted_mask.cast(paddle.bool), paddle.finfo(dtype).min
          )
  def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
      """Expand a ``[bsz, src_len]`` padding mask into an additive 4-D mask.

      Returns a ``[bsz, 1, tgt_len, src_len]`` tensor (``tgt_len`` defaults to
      ``src_len``). Assuming a 0/1 mask, it holds 0 where ``mask`` is 1 and
      ``paddle.finfo(dtype).min`` where it is 0.
      """
      # Call through an instance so this works whether ``_expand_mask`` is
      # declared as an instance method or a staticmethod. Calling it unbound on
      # the class raised TypeError (missing ``self``) for the instance-method form.
      converter = AttentionMaskConverter(is_causal=False)
      return converter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
  def _prepare_4d_causal_attention_mask_export(
      attention_mask,
      input_shape,
      inputs_embeds,
      past_key_values_length,
      sliding_window=None,
      is_export=False,
  ):
      """Export-path mask preparation: delegate to ``AttentionMaskConverter.to_4d_export``.

      Args:
          attention_mask: 2-D padding mask; its first dim is read as the batch size.
          input_shape: shape whose last entry is the query length.
          inputs_embeds: tensor whose dtype selects the fill value.
          past_key_values_length: number of cached positions (only used to
              compute ``key_value_length`` for the converter).
          sliding_window: forwarded to the converter (validated there).
          is_export: forwarded to ``to_4d_export``.

      Returns:
          Whatever ``to_4d_export`` returns for this mask.
      """
      attn_mask_converter = AttentionMaskConverter(
          is_causal=True, sliding_window=sliding_window
      )
      key_value_length = input_shape[-1] + past_key_values_length
      # (Removed the unused ``shape`` / ``len_shape`` locals.)
      return attn_mask_converter.to_4d_export(
          attention_mask,
          input_shape[-1],
          key_value_length=key_value_length,
          dtype=inputs_embeds.dtype,
          is_export=is_export,
      )
  def _prepare_4d_causal_attention_mask(
      attention_mask,
      input_shape,
      inputs_embeds,
      past_key_values_length,
      sliding_window=None,
      is_export=False,
  ):
      """Build the additive causal attention mask for the decoder.

      Args:
          attention_mask: ``None``, a 2-D padding mask, or a 4-D mask of shape
              ``(bsz, 1, query_len, key_value_length)``. A 4-D mask is
              inverted (``1.0 - mask``), and its non-zero results are filled
              with ``finfo(dtype).min``.
          input_shape: ``(bsz, query_len)`` of the current decoder input.
          inputs_embeds: tensor whose dtype selects the fill value.
          past_key_values_length: number of cached key/value positions.
          sliding_window: optional local-attention window size.
          is_export: forwarded to ``AttentionMaskConverter.to_4d``.

      Returns:
          The additive 4-D mask. For ``None`` (or any rank other than 2 or 4),
          the result of the converter's ``to_causal_4d``.

      Raises:
          ValueError: if a 4-D ``attention_mask`` has an unexpected shape.
      """
      attn_mask_converter = AttentionMaskConverter(
          is_causal=True, sliding_window=sliding_window
      )
      key_value_length = input_shape[-1] + past_key_values_length
      # Inspect the rank only after the None check. Reading
      # ``attention_mask.shape`` first crashed for ``attention_mask=None`` and
      # made the causal-only fallback below unreachable for that case.
      mask_rank = None if attention_mask is None else len(attention_mask.shape)
      if mask_rank == 2:
          return attn_mask_converter.to_4d(
              attention_mask,
              input_shape[-1],
              key_value_length=key_value_length,
              dtype=inputs_embeds.dtype,
              is_export=is_export,
          )
      if mask_rank == 4:
          expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
          if tuple(attention_mask.shape) != expected_shape:
              raise ValueError(
                  f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
              )
          inverted_mask = 1.0 - attention_mask
          # ``cast`` is used for dtype conversion, as in the rest of this module.
          return inverted_mask.masked_fill_(
              inverted_mask.cast(paddle.bool), paddle.finfo(inputs_embeds.dtype).min
          )
      return attn_mask_converter.to_causal_4d(
          input_shape[0],
          input_shape[-1],
          key_value_length,
          dtype=inputs_embeds.dtype,
      )
  class MBartLearnedPositionalEmbedding(nn.Embedding):
      """Learned absolute position embeddings with MBart's fixed index offset.

      The table has ``num_embeddings + 2`` rows. Position ``p`` is read from
      row ``p + 2``.
      """

      def __init__(self, num_embeddings, embedding_dim):
          # Assigned before the parent constructor, exactly as upstream does.
          self.offset = 2
          super().__init__(num_embeddings + self.offset, embedding_dim)

      def forward(self, input_ids, past_key_values_length=0):
          """Embed positions ``past_key_values_length .. past_key_values_length + seq_len - 1``.

          Only the first two dims of ``input_ids`` (read as ``bsz, seq_len``)
          are used; its values are ignored. The same positions are repeated
          for every batch row.
          """
          batch_size, seq_len = input_ids.shape[:2]
          first_pos = past_key_values_length
          position_ids = paddle.arange(
              first_pos, first_pos + seq_len, dtype=paddle.int64
          )
          position_ids = position_ids.expand([batch_size, -1])
          return nn.Embedding.forward(self, position_ids + self.offset)
  class MBartPreTrainedModel(nn.Layer):
      """Base layer that stores the config and provides weight-initialisation hooks."""

      base_model_prefix = "model"
      supports_gradient_checkpointing = True
      _no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
      _supports_flash_attn_2 = True

      def __init__(self, config):
          super().__init__()
          self.config = config

      def _initialize_weights(self, module):
          """Initialize ``module`` unless it carries a truthy ``_is_hf_initialized`` flag."""
          if getattr(module, "_is_hf_initialized", False):
              return
          self._init_weights(module)

      def post_init(self):
          """Apply :meth:`_initialize_weights` to this layer and all of its sublayers."""
          self.apply(self._initialize_weights)

      def _init_weights(self, module):
          """Set Linear/Embedding weights from Normal(0, ``config.init_std``) and zero Linear biases."""
          std = self.config.init_std
          normal_ = Normal(mean=0.0, std=std)
          if isinstance(module, nn.Linear):
              normal_(module.weight)
              if module.bias is not None:
                  zeros_(module.bias)
          elif isinstance(module, nn.Embedding):
              normal_(module.weight)
              if module._padding_idx is not None:
                  # NOTE(review): this initializes an indexed slice of the
                  # weight. Whether that writes back into the parameter depends
                  # on Paddle's indexing semantics -- verify.
                  zeros_(module.weight[module._padding_idx])

      @property
      def dummy_inputs(self):
          """Tiny two-sequence batch (the second one padded) for smoke tests."""
          pad_token = self.config.pad_token_id
          # ``paddle.tensor`` is a module, not a constructor; calling it raised
          # TypeError. ``paddle.to_tensor`` builds the tensor.
          input_ids = paddle.to_tensor(
              [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]]
          )
          return {
              "attention_mask": input_ids != pad_token,
              "input_ids": input_ids,
          }
  class MBartAttention(nn.Layer):
      """Multi-head scaled dot-product attention ('Attention Is All You Need').

      Acts as self-attention, or as cross-attention when ``key_value_states``
      is passed to :meth:`forward`. Decoder instances (``is_decoder=True``)
      return their key/value states so the caller can cache them.
      """

      def __init__(
          self,
          embed_dim,
          num_heads,
          dropout: float = 0.0,
          is_decoder: bool = False,
          bias: bool = True,
          is_causal: bool = False,
          config=None,
      ):
          super().__init__()
          head_dim = embed_dim // num_heads
          if head_dim * num_heads != embed_dim:
              raise ValueError(
                  f"embed_dim must be divisible by num_heads (got `embed_dim`: {embed_dim}"
                  f" and `num_heads`: {num_heads})."
              )
          self.embed_dim = embed_dim
          self.num_heads = num_heads
          self.head_dim = head_dim
          self.dropout = dropout
          self.config = config
          self.scaling = head_dim**-0.5
          self.is_decoder = is_decoder
          self.is_causal = is_causal
          # Sublayers are created in the upstream order (k, v, q, out). The
          # creation order may affect auto-generated parameter names.
          self.k_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
          self.v_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
          self.q_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
          self.out_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)

      def _shape(self, tensor, seq_len, bsz):
          """Split heads: reshape to ``[bsz, seq_len, heads, head_dim]``, then swap axes 1 and 2."""
          per_head = tensor.reshape([bsz, seq_len, self.num_heads, self.head_dim])
          return per_head.transpose([0, 2, 1, 3])

      def forward(
          self,
          hidden_states,
          key_value_states=None,
          past_key_value=None,
          attention_mask=None,
          layer_head_mask=None,
          output_attentions=False,
      ):
          """Attend from ``hidden_states`` to itself or to ``key_value_states``.

          Args:
              hidden_states: query input; its shape is unpacked as ``bsz, tgt_len, _``.
              key_value_states: if given, keys/values are projected from it
                  (cross-attention); otherwise from ``hidden_states``.
              past_key_value: optional ``(key, value)`` cache. For
                  cross-attention it is reused as-is when its dim 2 equals
                  ``key_value_states.shape[1]``. For self-attention the new
                  keys/values are concatenated after it along dim 2.
              attention_mask: additive mask added to the scores viewed as
                  ``[bsz, num_heads, tgt_len, src_len]``.
              layer_head_mask: optional per-head factor of shape
                  ``(num_heads,)``, applied after softmax.
              output_attentions: also return the attention probabilities as
                  ``[bsz, num_heads, tgt_len, src_len]``.

          Returns:
              ``(attn_output, attn_weights_or_None, past_key_value)``. The last
              item is the new ``(key, value)`` pair for decoder layers and the
              incoming ``past_key_value`` otherwise.

          Raises:
              ValueError: if ``layer_head_mask`` does not have shape ``(num_heads,)``.
          """
          is_cross_attention = key_value_states is not None
          bsz, tgt_len, _ = paddle.shape(hidden_states)
          query_states = self.q_proj(hidden_states) * self.scaling

          reuse_cached_cross_kv = (
              is_cross_attention
              and past_key_value is not None
              and past_key_value[0].shape[2] == key_value_states.shape[1]
          )
          if reuse_cached_cross_kv:
              key_states, value_states = past_key_value[0], past_key_value[1]
          else:
              kv_source = key_value_states if is_cross_attention else hidden_states
              key_states = self._shape(self.k_proj(kv_source), -1, bsz)
              value_states = self._shape(self.v_proj(kv_source), -1, bsz)
              if past_key_value is not None and not is_cross_attention:
                  # Incremental self-attention: append the new step to the cache.
                  key_states = paddle.concat([past_key_value[0], key_states], axis=2)
                  value_states = paddle.concat(
                      [past_key_value[1], value_states], axis=2
                  )

          if self.is_decoder:
              # Hand the (possibly extended) key/value states back for caching.
              past_key_value = (key_states, value_states)

          proj_shape = (bsz * self.num_heads, -1, self.head_dim)
          query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape)
          key_states = key_states.reshape(proj_shape)
          value_states = value_states.reshape(proj_shape)
          src_len = key_states.shape[1]
          heads_view = [bsz, self.num_heads, tgt_len, src_len]
          flat_view = [bsz * self.num_heads, tgt_len, src_len]

          attn_weights = paddle.bmm(query_states, key_states.transpose([0, 2, 1]))
          if attention_mask is not None:
              masked = attn_weights.reshape(heads_view) + attention_mask
              attn_weights = masked.reshape(flat_view)
          attn_weights = F.softmax(attn_weights, axis=-1)

          if layer_head_mask is not None:
              if tuple(layer_head_mask.shape) != (self.num_heads,):
                  raise ValueError(
                      f"Head mask for a single layer should be of shape {(self.num_heads,)}, but is"
                      f" {layer_head_mask.shape}"
                  )
              head_scaled = layer_head_mask.reshape(
                  [1, -1, 1, 1]
              ) * attn_weights.reshape(heads_view)
              attn_weights = head_scaled.reshape(flat_view)

          if output_attentions:
              # Keep a per-head view for the caller; continue with the flat one.
              attn_weights_reshaped = attn_weights.reshape(heads_view)
              attn_weights = attn_weights_reshaped.reshape(flat_view)
          else:
              attn_weights_reshaped = None

          attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
          attn_output = paddle.bmm(attn_probs, value_states)
          attn_output = (
              attn_output.reshape([bsz, self.num_heads, tgt_len, self.head_dim])
              .transpose([0, 2, 1, 3])
              .reshape([bsz, tgt_len, self.embed_dim])
          )
          return self.out_proj(attn_output), attn_weights_reshaped, past_key_value
  # Maps ``config._attn_implementation`` to the attention layer class.
  # Only the eager implementation is available in this port.
  MBART_ATTENTION_CLASSES = dict(eager=MBartAttention)
  573. class MBartDecoderLayer(nn.Layer):
  574. def __init__(self, config):
  575. super().__init__()
  576. self.embed_dim = config.d_model
  577. self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
  578. embed_dim=self.embed_dim,
  579. num_heads=config.decoder_attention_heads,
  580. dropout=config.attention_dropout,
  581. is_decoder=True,
  582. is_causal=True,
  583. config=config,
  584. )
  585. self.is_export = config.is_export
  586. self.dropout = config.dropout
  587. self.activation_fn = F.gelu
  588. self.activation_dropout = config.activation_dropout
  589. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  590. self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
  591. self.embed_dim,
  592. config.decoder_attention_heads,
  593. dropout=config.attention_dropout,
  594. is_decoder=True,
  595. config=config,
  596. )
  597. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  598. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  599. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  600. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  601. def forward(
  602. self,
  603. hidden_states,
  604. attention_mask=None,
  605. encoder_hidden_states=None,
  606. encoder_attention_mask=None,
  607. layer_head_mask=None,
  608. cross_attn_layer_head_mask=None,
  609. past_key_value: Optional[Tuple[paddle.Tensor]] = None,
  610. output_attentions: Optional[bool] = False,
  611. use_cache: Optional[bool] = True,
  612. ) -> paddle.Tensor:
  613. residual = hidden_states
  614. hidden_states = self.self_attn_layer_norm(hidden_states)
  615. self_attn_past_key_value = (
  616. past_key_value[:2] if past_key_value is not None else None
  617. )
  618. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  619. hidden_states=hidden_states,
  620. past_key_value=self_attn_past_key_value,
  621. attention_mask=attention_mask,
  622. layer_head_mask=layer_head_mask,
  623. output_attentions=output_attentions,
  624. )
  625. hidden_states = nn.functional.dropout(
  626. hidden_states, p=self.dropout, training=self.training
  627. )
  628. hidden_states = residual + hidden_states
  629. cross_attn_present_key_value = None
  630. cross_attn_weights = None
  631. if encoder_hidden_states is not None:
  632. residual = hidden_states
  633. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  634. cross_attn_past_key_value = (
  635. past_key_value[-2:] if past_key_value is not None else None
  636. )
  637. (
  638. hidden_states,
  639. cross_attn_weights,
  640. cross_attn_present_key_value,
  641. ) = self.encoder_attn(
  642. hidden_states=hidden_states,
  643. key_value_states=encoder_hidden_states,
  644. attention_mask=encoder_attention_mask,
  645. layer_head_mask=cross_attn_layer_head_mask,
  646. past_key_value=cross_attn_past_key_value,
  647. output_attentions=output_attentions,
  648. )
  649. hidden_states = nn.functional.dropout(
  650. hidden_states, p=self.dropout, training=self.training
  651. )
  652. hidden_states = residual + hidden_states
  653. present_key_value = present_key_value + cross_attn_present_key_value
  654. residual = hidden_states
  655. hidden_states = self.final_layer_norm(hidden_states)
  656. hidden_states = self.activation_fn(self.fc1(hidden_states))
  657. hidden_states = nn.functional.dropout(
  658. hidden_states, p=self.activation_dropout, training=self.training
  659. )
  660. hidden_states = self.fc2(hidden_states)
  661. hidden_states = nn.functional.dropout(
  662. hidden_states, p=self.dropout, training=self.training
  663. )
  664. hidden_states = residual + hidden_states
  665. outputs = (hidden_states,)
  666. if output_attentions:
  667. outputs += (self_attn_weights, cross_attn_weights)
  668. if self.is_export:
  669. outputs += (present_key_value,)
  670. else:
  671. if use_cache:
  672. outputs += (present_key_value,)
  673. return outputs
  674. class MBartForCausalLM(MBartPreTrainedModel):
  675. _tied_weights_keys = ["lm_head.weight"]
  676. def __init__(self, config):
  677. config = copy.deepcopy(config)
  678. config.is_decoder = True
  679. config.is_encoder_decoder = False
  680. super().__init__(config)
  681. self.model = MBartDecoderWrapper(config)
  682. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False)
  683. self.post_init()
  684. def get_input_embeddings(self):
  685. return self.model.decoder.embed_tokens
  686. def set_input_embeddings(self, value):
  687. self.model.decoder.embed_tokens = value
  688. def get_output_embeddings(self):
  689. return self.lm_head
  690. def set_output_embeddings(self, new_embeddings):
  691. self.lm_head = new_embeddings
  692. def set_decoder(self, decoder):
  693. self.model.decoder = decoder
  694. def get_decoder(self):
  695. return self.model.decoder
  696. def forward(
  697. self,
  698. input_ids=None,
  699. attention_mask=None,
  700. encoder_hidden_states=None,
  701. encoder_attention_mask=None,
  702. head_mask=None,
  703. cross_attn_head_mask=None,
  704. past_key_values=None,
  705. inputs_embeds=None,
  706. labels=None,
  707. use_cache=None,
  708. output_attentions=None,
  709. output_hidden_states=None,
  710. return_dict=None,
  711. ):
  712. output_attentions = (
  713. output_attentions
  714. if output_attentions is not None
  715. else self.config.output_attentions
  716. )
  717. output_hidden_states = (
  718. output_hidden_states
  719. if output_hidden_states is not None
  720. else self.config.output_hidden_states
  721. )
  722. return_dict = (
  723. return_dict if return_dict is not None else self.config.use_return_dict
  724. )
  725. outputs = self.model.decoder(
  726. input_ids=input_ids,
  727. attention_mask=attention_mask,
  728. encoder_hidden_states=encoder_hidden_states,
  729. encoder_attention_mask=encoder_attention_mask,
  730. head_mask=head_mask,
  731. cross_attn_head_mask=cross_attn_head_mask,
  732. past_key_values=past_key_values,
  733. inputs_embeds=inputs_embeds,
  734. use_cache=use_cache,
  735. output_attentions=output_attentions,
  736. output_hidden_states=output_hidden_states,
  737. return_dict=return_dict,
  738. )
  739. logits = self.lm_head(outputs[0])
  740. loss = None
  741. if labels is not None:
  742. labels = labels
  743. loss_fct = CrossEntropyLoss()
  744. loss = loss_fct(
  745. logits.reshape([-1, self.config.vocab_size]), labels.reshape([-1])
  746. )
  747. if not return_dict:
  748. output = (logits,) + outputs[1:]
  749. return (loss,) + output if loss is not None else output
  750. return CausalLMOutputWithCrossAttentions(
  751. loss=loss,
  752. logits=logits,
  753. past_key_values=outputs.past_key_values,
  754. hidden_states=outputs.hidden_states,
  755. attentions=outputs.attentions,
  756. cross_attentions=outputs.cross_attentions,
  757. )
  758. def prepare_inputs_for_generation(
  759. self,
  760. input_ids,
  761. past_key_values=None,
  762. attention_mask=None,
  763. use_cache=None,
  764. **kwargs,
  765. ):
  766. if attention_mask is None:
  767. attention_mask = input_ids.new_ones(input_ids.shape)
  768. if past_key_values:
  769. past_length = past_key_values[0][0].shape[2]
  770. if input_ids.shape[1] > past_length:
  771. remove_prefix_length = past_length
  772. else:
  773. remove_prefix_length = input_ids.shape[1] - 1
  774. input_ids = input_ids[:, remove_prefix_length:]
  775. return {
  776. "input_ids": input_ids,
  777. "attention_mask": attention_mask,
  778. "past_key_values": past_key_values,
  779. "use_cache": use_cache,
  780. }
  781. @staticmethod
  782. def _reorder_cache(past_key_values, beam_idx):
  783. reordered_past = ()
  784. for layer_past in past_key_values:
  785. reordered_past += (
  786. tuple(
  787. past_state.index_select(0, beam_idx) for past_state in layer_past
  788. ),
  789. )
  790. return reordered_past
  791. class myLayerNorm(nn.LayerNorm):
  792. """
  793. Custom implementation of Layer Normalization, with additional options.
  794. This class extends the standard LayerNorm to include optional features,
  795. such as drop block regularization, which might be used for improving
  796. model generalization.
  797. Args:
  798. num_channels (int): The number of features or channels in the input.
  799. eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
  800. affine (bool, optional): If True, this module has learnable affine parameters (gamma and beta). Default is True.
  801. drop_block (optional): Additional regularization technique that might be applied. Default is None.
  802. """
  803. def __init__(
  804. self,
  805. num_channels,
  806. eps=1e-5,
  807. affine=True,
  808. drop_block=None,
  809. ):
  810. super(nn.LayerNorm, self).__init__()
  811. self._epsilon = eps
  812. self.num_channels = num_channels
  813. if affine:
  814. self.weight = paddle.create_parameter([num_channels], dtype="float32")
  815. self.bias = paddle.create_parameter([num_channels], dtype="float32")
  816. ones_(self.weight)
  817. zeros_(self.bias)
  818. def forward(self, x):
  819. x = F.layer_norm(
  820. x,
  821. self.num_channels,
  822. weight=self.weight,
  823. bias=self.bias,
  824. epsilon=self._epsilon,
  825. )
  826. return x
  827. class MBartDecoder(MBartPreTrainedModel):
  828. """
  829. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
  830. Args:
  831. config
  832. embed_tokens (nn.Embedding): output embedding
  833. """
  834. def __init__(self, config, embed_tokens=None):
  835. super().__init__(config)
  836. self.dropout = config.dropout
  837. self.layerdrop = config.decoder_layerdrop
  838. self.padding_idx = config.pad_token_id
  839. self.max_target_positions = config.max_position_embeddings
  840. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  841. self.embed_tokens = nn.Embedding(
  842. config.vocab_size, config.d_model, self.padding_idx
  843. )
  844. if embed_tokens is not None:
  845. self.embed_tokens.weight = embed_tokens.weight
  846. self.embed_positions = MBartLearnedPositionalEmbedding(
  847. config.max_position_embeddings,
  848. config.d_model,
  849. )
  850. self.layers = nn.LayerList(
  851. [MBartDecoderLayer(config) for _ in range(config.decoder_layers)]
  852. )
  853. self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
  854. self.layernorm_embedding = myLayerNorm(config.d_model, affine=True)
  855. self.layer_norm = nn.LayerNorm(config.d_model)
  856. self.gradient_checkpointing = False
  857. # Initialize weights and apply final processing
  858. self.post_init()
  859. self.is_export = config.is_export
  860. def get_input_embeddings(self):
  861. return self.embed_tokens
  862. def set_input_embeddings(self, value):
  863. self.embed_tokens = value
  864. def forward(
  865. self,
  866. input_ids=None,
  867. attention_mask=None,
  868. encoder_hidden_states=None,
  869. encoder_attention_mask=None,
  870. head_mask=None,
  871. cross_attn_head_mask=None,
  872. past_key_values=None,
  873. inputs_embeds=None,
  874. use_cache=None,
  875. output_attentions=None,
  876. output_hidden_states=None,
  877. return_dict=None,
  878. ):
  879. output_attentions = (
  880. output_attentions
  881. if output_attentions is not None
  882. else self.config.output_attentions
  883. )
  884. output_hidden_states = (
  885. output_hidden_states
  886. if output_hidden_states is not None
  887. else self.config.output_hidden_states
  888. )
  889. use_cache = use_cache if use_cache is not None else self.config.use_cache
  890. return_dict = (
  891. return_dict if return_dict is not None else self.config.use_return_dict
  892. )
  893. if input_ids is not None and inputs_embeds is not None:
  894. raise ValueError(
  895. "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
  896. )
  897. elif input_ids is not None:
  898. input = input_ids
  899. input_shape = input.shape
  900. input_ids = input_ids.reshape([-1, input_shape[-1]])
  901. elif inputs_embeds is not None:
  902. input_shape = inputs_embeds.shape[:-1]
  903. input = inputs_embeds[:, :, -1]
  904. else:
  905. raise ValueError(
  906. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  907. )
  908. past_key_values_length = (
  909. past_key_values[0][0].shape[2] if past_key_values is not None else 0
  910. )
  911. if inputs_embeds is None:
  912. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  913. if self._use_flash_attention_2:
  914. attention_mask = (
  915. attention_mask
  916. if (attention_mask is not None and 0 in attention_mask)
  917. else None
  918. )
  919. else:
  920. attention_mask = _prepare_4d_causal_attention_mask(
  921. attention_mask,
  922. input_shape,
  923. inputs_embeds,
  924. past_key_values_length,
  925. is_export=self.is_export,
  926. )
  927. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  928. if self._use_flash_attention_2:
  929. encoder_attention_mask = (
  930. encoder_attention_mask if 0 in encoder_attention_mask else None
  931. )
  932. else:
  933. encoder_attention_mask = _prepare_4d_attention_mask(
  934. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  935. )
  936. # embed positions
  937. positions = self.embed_positions(input, past_key_values_length)
  938. hidden_states = inputs_embeds + positions
  939. hidden_states = self.layernorm_embedding(hidden_states)
  940. hidden_states = nn.functional.dropout(
  941. hidden_states, p=self.dropout, training=self.training
  942. )
  943. if self.gradient_checkpointing and self.training:
  944. if use_cache:
  945. print(
  946. "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
  947. )
  948. use_cache = False
  949. all_hidden_states = () if output_hidden_states else None
  950. all_self_attns = () if output_attentions else None
  951. all_cross_attentions = (
  952. () if (output_attentions and encoder_hidden_states is not None) else None
  953. )
  954. next_decoder_cache = () if use_cache else None
  955. for attn_mask, mask_name in zip(
  956. [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
  957. ):
  958. if attn_mask is not None:
  959. if attn_mask.shape[0] != len(self.layers):
  960. raise ValueError(
  961. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  962. f" {attn_mask.shape[0]}."
  963. )
  964. for idx, decoder_layer in enumerate(self.layers):
  965. if output_hidden_states:
  966. all_hidden_states += (hidden_states,)
  967. if self.training:
  968. dropout_probability = paddle.rand([])
  969. if dropout_probability < self.layerdrop:
  970. continue
  971. past_key_value = (
  972. past_key_values[idx] if past_key_values is not None else None
  973. )
  974. if self.gradient_checkpointing and self.training:
  975. layer_outputs = self._gradient_checkpointing_func(
  976. decoder_layer.__call__,
  977. hidden_states,
  978. attention_mask,
  979. encoder_hidden_states,
  980. encoder_attention_mask,
  981. head_mask[idx] if head_mask is not None else None,
  982. (
  983. cross_attn_head_mask[idx]
  984. if cross_attn_head_mask is not None
  985. else None
  986. ),
  987. None,
  988. output_attentions,
  989. use_cache,
  990. )
  991. else:
  992. layer_outputs = decoder_layer(
  993. hidden_states,
  994. attention_mask=attention_mask,
  995. encoder_hidden_states=encoder_hidden_states,
  996. encoder_attention_mask=encoder_attention_mask,
  997. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  998. cross_attn_layer_head_mask=(
  999. cross_attn_head_mask[idx]
  1000. if cross_attn_head_mask is not None
  1001. else None
  1002. ),
  1003. past_key_value=past_key_value,
  1004. output_attentions=output_attentions,
  1005. use_cache=use_cache,
  1006. )
  1007. hidden_states = layer_outputs[0]
  1008. if use_cache:
  1009. next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
  1010. if output_attentions:
  1011. all_self_attns += (layer_outputs[1],)
  1012. if encoder_hidden_states is not None:
  1013. all_cross_attentions += (layer_outputs[2],)
  1014. hidden_states = self.layer_norm(hidden_states)
  1015. if output_hidden_states:
  1016. all_hidden_states += (hidden_states,)
  1017. next_cache = next_decoder_cache if use_cache else None
  1018. if not return_dict:
  1019. return tuple(
  1020. v
  1021. for v in [
  1022. hidden_states,
  1023. next_cache,
  1024. all_hidden_states,
  1025. all_self_attns,
  1026. all_cross_attentions,
  1027. ]
  1028. if v is not None
  1029. )
  1030. return BaseModelOutputWithPastAndCrossAttentions(
  1031. last_hidden_state=hidden_states,
  1032. past_key_values=next_cache,
  1033. hidden_states=all_hidden_states,
  1034. attentions=all_self_attns,
  1035. cross_attentions=all_cross_attentions,
  1036. )
  1037. class MBartDecoderWrapper(MBartPreTrainedModel):
  1038. """
  1039. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1040. used in combination with the [`EncoderDecoderModel`] framework.
  1041. """
  1042. def __init__(self, config):
  1043. super().__init__(config)
  1044. self.decoder = MBartDecoder(config)
  1045. def forward(self, *args, **kwargs):
  1046. return self.decoder(*args, **kwargs)
  1047. def _in_projection(
  1048. q: paddle.Tensor,
  1049. k: paddle.Tensor,
  1050. v: paddle.Tensor,
  1051. w_q: paddle.Tensor,
  1052. w_k: paddle.Tensor,
  1053. w_v: paddle.Tensor,
  1054. b_q: Optional[paddle.Tensor] = None,
  1055. b_k: Optional[paddle.Tensor] = None,
  1056. b_v: Optional[paddle.Tensor] = None,
  1057. ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
  1058. Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
  1059. assert w_q.shape == (
  1060. Eq,
  1061. Eq,
  1062. ), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
  1063. assert w_k.shape == (
  1064. Eq,
  1065. Ek,
  1066. ), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
  1067. assert w_v.shape == (
  1068. Eq,
  1069. Ev,
  1070. ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
  1071. assert b_q is None or b_q.shape == (
  1072. Eq,
  1073. ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
  1074. assert b_k is None or b_k.shape == (
  1075. Eq,
  1076. ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
  1077. assert b_v is None or b_v.shape == (
  1078. Eq,
  1079. ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
  1080. return linear(q, w_q.T, b_q), linear(k, w_k.T, b_k), linear(v, w_v.T, b_v)
  1081. def _scaled_dot_product_attention(
  1082. q: paddle.Tensor,
  1083. k: paddle.Tensor,
  1084. v: paddle.Tensor,
  1085. attn_mask: Optional[paddle.Tensor] = None,
  1086. dropout_p: float = 0.0,
  1087. ) -> Tuple[paddle.Tensor, paddle.Tensor]:
  1088. B, Nt, E = q.shape
  1089. q = q / math.sqrt(E)
  1090. attn = paddle.bmm(q, k.transpose([0, 2, 1]))
  1091. if attn_mask is not None:
  1092. attn += attn_mask
  1093. attn = F.softmax(attn, axis=-1)
  1094. if dropout_p > 0.0:
  1095. attn = F.dropout(attn, p=dropout_p)
  1096. output = paddle.bmm(attn, v)
  1097. return output, attn
  1098. def linear(x, w, b, is_transpose):
  1099. if b is not None:
  1100. return paddle.matmul(x, w, transpose_y=is_transpose) + b
  1101. else:
  1102. return paddle.matmul(x, w, transpose_y=is_transpose)
  1103. def _in_projection_packed(
  1104. q: Tensor,
  1105. k: Tensor,
  1106. v: Tensor,
  1107. w: Tensor,
  1108. b: Optional[Tensor] = None,
  1109. is_export=False,
  1110. ) -> List[Tensor]:
  1111. E = paddle.shape(q)[-1]
  1112. if k is v:
  1113. if q is k:
  1114. proj = linear(q, w, b, is_transpose=True)
  1115. if is_export:
  1116. B, D, L = paddle.shape(proj)
  1117. proj = proj.reshape([B, D, 3, E])
  1118. proj = (
  1119. proj.unsqueeze(0)
  1120. .transpose([3, 1, 2, 0, 4])
  1121. .squeeze(-2)
  1122. .contiguous()
  1123. )
  1124. else:
  1125. proj = (
  1126. proj.unflatten(-1, (3, E))
  1127. .unsqueeze(0)
  1128. .transpose([3, 1, 2, 0, 4])
  1129. .squeeze(-2)
  1130. .contiguous()
  1131. )
  1132. return proj[0], proj[1], proj[2]
  1133. else:
  1134. w_q, w_k, w_v = w.chunk(3)
  1135. if b is None:
  1136. b_q = b_k = b_v = None
  1137. else:
  1138. b_q, b_k, b_v = b.chunk(3)
  1139. return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
  1140. def multi_head_attention_forward(
  1141. query: paddle.Tensor,
  1142. key: paddle.Tensor,
  1143. value: paddle.Tensor,
  1144. embed_dim_to_check: int,
  1145. num_heads: int,
  1146. in_proj_weight: paddle.Tensor,
  1147. in_proj_bias: Optional[paddle.Tensor],
  1148. bias_k: Optional[paddle.Tensor],
  1149. bias_v: Optional[paddle.Tensor],
  1150. add_zero_attn: bool,
  1151. dropout_p: float,
  1152. out_proj_weight: paddle.Tensor,
  1153. out_proj_bias: Optional[paddle.Tensor],
  1154. training: bool = True,
  1155. key_padding_mask: Optional[paddle.Tensor] = None,
  1156. need_weights: bool = True,
  1157. attn_mask: Optional[paddle.Tensor] = None,
  1158. use_separate_proj_weight: bool = False,
  1159. q_proj_weight: Optional[paddle.Tensor] = None,
  1160. k_proj_weight: Optional[paddle.Tensor] = None,
  1161. v_proj_weight: Optional[paddle.Tensor] = None,
  1162. static_k: Optional[paddle.Tensor] = None,
  1163. static_v: Optional[paddle.Tensor] = None,
  1164. is_export=False,
  1165. ):
  1166. tgt_len, bsz, embed_dim = query.shape
  1167. src_len, _, _ = key.shape
  1168. if isinstance(embed_dim, paddle.Tensor):
  1169. head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
  1170. else:
  1171. head_dim = embed_dim // num_heads
  1172. q, k, v = _in_projection_packed(
  1173. query, key, value, in_proj_weight, in_proj_bias, is_export
  1174. )
  1175. if key_padding_mask is not None and key_padding_mask.dtype == paddle.uint8:
  1176. warnings.warn(
  1177. "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
  1178. )
  1179. key_padding_mask = key_padding_mask.to(paddle.bool)
  1180. if bias_k is not None and bias_v is not None: # False
  1181. assert static_k is None, "bias cannot be added to static key."
  1182. assert static_v is None, "bias cannot be added to static value."
  1183. k = paddle.concat([k, bias_k.repeat(1, bsz, 1)])
  1184. v = paddle.concat([v, bias_v.repeat(1, bsz, 1)])
  1185. else:
  1186. assert bias_k is None
  1187. assert bias_v is None
  1188. q = q.reshape([tgt_len, bsz * num_heads, head_dim]).transpose([1, 0, 2])
  1189. if static_k is None: # True
  1190. k = k.reshape([k.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
  1191. else:
  1192. assert (
  1193. static_k.shape[0] == bsz * num_heads
  1194. ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.shape[0]}"
  1195. assert (
  1196. static_k.shape[2] == head_dim
  1197. ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.shape[2]}"
  1198. k = static_k
  1199. if static_v is None: # True
  1200. v = v.reshape([v.shape[0], bsz * num_heads, head_dim]).transpose([1, 0, 2])
  1201. else:
  1202. assert (
  1203. static_v.shape[0] == bsz * num_heads
  1204. ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.shape[0]}"
  1205. assert (
  1206. static_v.shape[2] == head_dim
  1207. ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.shape[2]}"
  1208. v = static_v
  1209. src_len = k.shape[1]
  1210. if not training:
  1211. dropout_p = 0.0
  1212. attn_output, attn_output_weights = _scaled_dot_product_attention(
  1213. q, k, v, attn_mask, dropout_p
  1214. )
  1215. attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
  1216. attn_output = linear(
  1217. attn_output, out_proj_weight, out_proj_bias, is_transpose=False
  1218. )
  1219. if need_weights:
  1220. attn_output_weights = attn_output_weights.reshape(
  1221. [bsz, num_heads, tgt_len, src_len]
  1222. )
  1223. return attn_output, attn_output_weights.sum(axis=1) / num_heads
  1224. else:
  1225. return attn_output, None
  1226. class MyMultiheadAttention(nn.Layer):
  1227. """
  1228. Custom implementation of a multi-head attention layer.
  1229. Attributes:
  1230. __constants__ (list): List of constant attributes.
  1231. bias_k (Optional[paddle.Tensor]): Optional tensor for key bias.
  1232. bias_v (Optional[paddle.Tensor]): Optional tensor for value bias.
  1233. Args:
  1234. embed_dim (int): Total dimension of the model. This is the size of the input feature vectors.
  1235. num_heads (int): Number of parallel attention heads. The input dimension must be divisible by the number of heads.
  1236. dropout (float, optional): Dropout probability on the attention weights. Default is 0.0.
  1237. bias (bool, optional): If True, adds a learnable bias to the output. Default is True.
  1238. add_bias_kv (bool, optional): If True, adds bias to the key and value sequences. Default is False.
  1239. add_zero_attn (bool, optional): If True, adds a zero attention head. Default is False.
  1240. kdim (int, optional): Total number of features for keys. If None, defaults to embed_dim.
  1241. vdim (int, optional): Total number of features for values. If None, defaults to embed_dim.
  1242. batch_first (bool, optional): If True, the input and output tensors are provided as (batch, seq, feature). Default is False.
  1243. device (optional): The device on which the layer's parameters should be initialized. Default is None.
  1244. dtype (optional): The data type for the parameters. Default is None.
  1245. is_export (bool, optional): If True, the layer is set up for export, potentially changing behavior for compatibility. Default is False.
  1246. """
  1247. __constants__ = ["batch_first"]
  1248. bias_k: Optional[paddle.Tensor]
  1249. bias_v: Optional[paddle.Tensor]
  1250. def __init__(
  1251. self,
  1252. embed_dim,
  1253. num_heads,
  1254. dropout=0.0,
  1255. bias=True,
  1256. add_bias_kv=False,
  1257. add_zero_attn=False,
  1258. kdim=None,
  1259. vdim=None,
  1260. batch_first=False,
  1261. device=None,
  1262. dtype=None,
  1263. is_export=False,
  1264. ) -> None:
  1265. super(MyMultiheadAttention, self).__init__()
  1266. self.embed_dim = embed_dim
  1267. self.kdim = kdim if kdim is not None else embed_dim
  1268. self.vdim = vdim if vdim is not None else embed_dim
  1269. self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
  1270. self.num_heads = num_heads
  1271. self.dropout = dropout
  1272. self.batch_first = batch_first
  1273. self.head_dim = embed_dim // num_heads
  1274. self.is_export = is_export
  1275. assert (
  1276. self.head_dim * num_heads == self.embed_dim
  1277. ), "embed_dim must be divisible by num_heads"
  1278. if self._qkv_same_embed_dim is False:
  1279. pass
  1280. else:
  1281. if dtype is None:
  1282. dtype = paddle.float32
  1283. self.in_proj_weight = paddle.create_parameter(
  1284. (3 * embed_dim, embed_dim), dtype
  1285. )
  1286. self.q_proj_weight = None
  1287. self.k_proj_weight = None
  1288. self.v_proj_weight = None
  1289. if bias:
  1290. self.in_proj_bias = paddle.create_parameter((3 * embed_dim,), dtype)
  1291. zeros_(self.in_proj_bias)
  1292. else:
  1293. self.in_proj_bias = None
  1294. self.out_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
  1295. if add_bias_kv:
  1296. pass
  1297. else:
  1298. self.bias_k = self.bias_v = None
  1299. self.add_zero_attn = add_zero_attn
  1300. self._reset_parameters()
  1301. def _reset_parameters(self):
  1302. if self._qkv_same_embed_dim:
  1303. xavier_uniform_(self.in_proj_weight)
  1304. else:
  1305. xavier_uniform_(self.q_proj_weight)
  1306. xavier_uniform_(self.k_proj_weight)
  1307. xavier_uniform_(self.v_proj_weight)
  1308. if self.in_proj_bias is not None:
  1309. zeros_(self.in_proj_bias)
  1310. zeros_(self.out_proj.bias)
  1311. if self.bias_k is not None:
  1312. xavier_normal_(self.bias_k)
  1313. if self.bias_v is not None:
  1314. xavier_normal_(self.bias_v)
  1315. def forward(
  1316. self,
  1317. query: paddle.Tensor,
  1318. key: paddle.Tensor,
  1319. value: paddle.Tensor,
  1320. key_padding_mask: Optional[paddle.Tensor] = None,
  1321. need_weights: bool = True,
  1322. attn_mask: Optional[paddle.Tensor] = None,
  1323. ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]:
  1324. attn_output, attn_output_weights = multi_head_attention_forward(
  1325. query,
  1326. key,
  1327. value,
  1328. self.embed_dim,
  1329. self.num_heads,
  1330. self.in_proj_weight,
  1331. self.in_proj_bias,
  1332. self.bias_k,
  1333. self.bias_v,
  1334. self.add_zero_attn,
  1335. self.dropout,
  1336. self.out_proj.weight,
  1337. self.out_proj.bias,
  1338. training=self.training,
  1339. key_padding_mask=key_padding_mask,
  1340. need_weights=need_weights,
  1341. attn_mask=attn_mask,
  1342. is_export=self.is_export,
  1343. )
  1344. return attn_output, attn_output_weights
  1345. class LogitsProcessorList(list):
  1346. """
  1347. A list of logits processors that can be applied sequentially.
  1348. Methods:
  1349. __call__(input_ids, scores, **kwargs): Apply all processors to the given inputs.
  1350. """
  1351. def __call__(self, input_ids, scores, **kwargs):
  1352. for processor in self:
  1353. function_args = inspect.signature(processor.__call__).parameters
  1354. if len(function_args) > 2:
  1355. if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
  1356. raise ValueError(
  1357. f"Make sure that all the required parameters: {list(function_args.keys())} for "
  1358. f"{processor.__class__} are passed to the logits processor."
  1359. )
  1360. scores = processor(input_ids, scores, **kwargs)
  1361. else:
  1362. scores = processor(input_ids, scores)
  1363. return scores
  1364. class ForcedEOSTokenLogitsProcessor(object):
  1365. """
  1366. A processor that forces the generation of an end-of-sequence (EOS) token
  1367. at a specified position in the sequence.
  1368. This is typically used in language generation tasks to ensure that the
  1369. generated sequence ends properly when it reaches a certain length.
  1370. Args:
  1371. max_length (int): The maximum length of the sequence. Forces EOS when this length is reached.
  1372. eos_token_id (Union[int, List[int]]): The ID(s) of the EOS token(s) to be forced in the sequence.
  1373. """
  1374. def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]):
  1375. self.max_length = max_length
  1376. if isinstance(eos_token_id, int):
  1377. eos_token_id = [eos_token_id]
  1378. self.eos_token_id = eos_token_id
  1379. def __call__(self, input_ids, scores):
  1380. cur_len = input_ids.shape[-1]
  1381. scores_processed = scores
  1382. if cur_len == self.max_length - 1:
  1383. scores_processed = paddle.full_like(scores, -math.inf)
  1384. scores_processed[:, self.eos_token_id] = 0
  1385. return scores_processed
@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
    """Output container returned by ``MBartForCausalLM.forward``.

    NOTE(review): the attributes below are un-annotated class attributes, so
    ``dataclasses`` registers no fields for them; values passed to the explicit
    ``__init__`` are handed to ``ModelOutput``. Confirm how ModelOutput stores
    them before relying on attribute access.
    """

    loss = None  # language-modelling loss, None when no labels were given
    logits = None  # output of lm_head
    past_key_values = None  # decoder cache, if use_cache was enabled
    hidden_states = None  # per-layer hidden states, if requested
    attentions = None  # self-attention weights, if requested
    cross_attentions = None  # cross-attention weights, if requested

    def __init__(self, *args, **kwargs):
        # Explicit __init__ keeps @dataclass from generating one; all arguments
        # go straight to ModelOutput.
        super().__init__(*args, **kwargs)
@dataclass
class CausalLMOutputWithCrossAttentionsAndCounting(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Same fields as ``CausalLMOutputWithCrossAttentions`` minus ``loss``, plus
    ``counting``. NOTE(review): ``counting`` presumably carries a count
    prediction (cf. ``count_pred`` in ``CustomMBartDecoder``); confirm with
    the producer.

    NOTE(review): the attributes are un-annotated, so ``dataclasses`` registers
    no fields; values passed to ``__init__`` are handled by ``ModelOutput``.
    """

    logits = None  # vocabulary logits
    counting = None  # see NOTE above; semantics not established here
    past_key_values = None  # decoder cache, if enabled
    hidden_states = None  # per-layer hidden states, if requested
    attentions = None  # self-attention weights, if requested
    cross_attentions = None  # cross-attention weights, if requested

    def __init__(self, *args, **kwargs):
        # Explicit __init__ keeps @dataclass from generating one.
        super().__init__(*args, **kwargs)
  1409. class CustomMBartDecoder(MBartDecoder):
  1410. """
  1411. A custom MBartDecoder that includes additional processing layers.
  1412. This class extends the MBartDecoder by adding a customizable neural network
  1413. component called `counting_context_weight`, which applies a series of linear
  1414. transformations followed by ReLU activations. This can be used to modify or
  1415. enhance the decoder's behavior for specific tasks.
  1416. Args:
  1417. config: The configuration object containing model parameters.
  1418. """
  1419. def __init__(self, config):
  1420. super().__init__(config)
  1421. hidden_size = config.d_model
  1422. self.is_export = config.is_export
  1423. self.counting_context_weight = nn.Sequential(
  1424. nn.Linear(config.vocab_size, hidden_size),
  1425. nn.ReLU(),
  1426. nn.Linear(hidden_size, hidden_size),
  1427. nn.ReLU(),
  1428. nn.Linear(hidden_size, config.d_model),
  1429. )
  1430. def forward(
  1431. self,
  1432. input_ids=None,
  1433. attention_mask=None,
  1434. count_pred=None,
  1435. encoder_hidden_states=None,
  1436. encoder_attention_mask=None,
  1437. head_mask=None,
  1438. cross_attn_head_mask=None,
  1439. past_key_values=None,
  1440. inputs_embeds=None,
  1441. use_cache=None,
  1442. output_attentions=None,
  1443. output_hidden_states=None,
  1444. return_dict=None,
  1445. ):
  1446. self.is_export = False if self.training else True
  1447. output_attentions = (
  1448. output_attentions
  1449. if output_attentions is not None
  1450. else self.config.output_attentions
  1451. )
  1452. output_hidden_states = (
  1453. output_hidden_states
  1454. if output_hidden_states is not None
  1455. else self.config.output_hidden_states
  1456. )
  1457. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1458. return_dict = (
  1459. return_dict if return_dict is not None else self.config.use_return_dict
  1460. )
  1461. if input_ids is not None and inputs_embeds is not None:
  1462. raise ValueError(
  1463. "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
  1464. )
  1465. elif input_ids is not None:
  1466. input = input_ids
  1467. input_shape = input.shape
  1468. input_ids = input_ids.reshape([-1, input_shape[-1]])
  1469. elif inputs_embeds is not None:
  1470. input_shape = inputs_embeds.shape[:-1]
  1471. input = inputs_embeds[:, :, -1]
  1472. else:
  1473. raise ValueError(
  1474. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  1475. )
  1476. past_key_values_length = (
  1477. past_key_values[0][0].shape[2] if past_key_values is not None else 0
  1478. )
  1479. if inputs_embeds is None:
  1480. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  1481. if self._use_flash_attention_2:
  1482. attention_mask = (
  1483. attention_mask
  1484. if (attention_mask is not None and 0 in attention_mask)
  1485. else None
  1486. )
  1487. else:
  1488. if self.is_export:
  1489. attention_mask = _prepare_4d_causal_attention_mask_export(
  1490. attention_mask,
  1491. input_shape,
  1492. inputs_embeds,
  1493. past_key_values_length,
  1494. is_export=self.is_export,
  1495. ).cast(paddle.float32)
  1496. else:
  1497. attention_mask = _prepare_4d_causal_attention_mask(
  1498. attention_mask,
  1499. input_shape,
  1500. inputs_embeds,
  1501. past_key_values_length,
  1502. is_export=self.is_export,
  1503. )
  1504. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  1505. if self._use_flash_attention_2:
  1506. encoder_attention_mask = (
  1507. encoder_attention_mask if 0 in encoder_attention_mask else None
  1508. )
  1509. else:
  1510. encoder_attention_mask = _prepare_4d_attention_mask(
  1511. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  1512. )
  1513. # embed positions
  1514. positions = self.embed_positions(input, past_key_values_length)
  1515. hidden_states = inputs_embeds + positions
  1516. # TODO: add counting context weight to hidden_states
  1517. if count_pred is not None:
  1518. count_context_weight = self.counting_context_weight(count_pred)
  1519. hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
  1520. hidden_states = self.layernorm_embedding(hidden_states)
  1521. hidden_states = nn.functional.dropout(
  1522. hidden_states, p=self.dropout, training=self.training
  1523. )
  1524. if self.gradient_checkpointing and self.training:
  1525. if use_cache:
  1526. print(
  1527. "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
  1528. )
  1529. use_cache = False
  1530. # decoder layers
  1531. all_hidden_states = () if output_hidden_states else None
  1532. all_self_attns = () if output_attentions else None
  1533. all_cross_attentions = (
  1534. () if (output_attentions and encoder_hidden_states is not None) else None
  1535. )
  1536. next_decoder_cache = () if use_cache else None
  1537. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  1538. for attn_mask, mask_name in zip(
  1539. [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
  1540. ):
  1541. if attn_mask is not None:
  1542. if attn_mask.size()[0] != len(self.layers):
  1543. raise ValueError(
  1544. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  1545. f" {attn_mask.size()[0]}."
  1546. )
  1547. for idx, decoder_layer in enumerate(self.layers):
  1548. if output_hidden_states:
  1549. all_hidden_states += (hidden_states,)
  1550. if self.training:
  1551. dropout_probability = paddle.rand([])
  1552. if dropout_probability < self.layerdrop:
  1553. continue
  1554. past_key_value = (
  1555. past_key_values[idx] if past_key_values is not None else None
  1556. )
  1557. if self.gradient_checkpointing and self.training:
  1558. layer_outputs = self._gradient_checkpointing_func(
  1559. decoder_layer.__call__,
  1560. hidden_states,
  1561. attention_mask,
  1562. encoder_hidden_states,
  1563. encoder_attention_mask,
  1564. head_mask[idx] if head_mask is not None else None,
  1565. (
  1566. cross_attn_head_mask[idx]
  1567. if cross_attn_head_mask is not None
  1568. else None
  1569. ),
  1570. None,
  1571. output_attentions,
  1572. use_cache,
  1573. )
  1574. else:
  1575. layer_outputs = decoder_layer(
  1576. hidden_states,
  1577. attention_mask=attention_mask,
  1578. encoder_hidden_states=encoder_hidden_states,
  1579. encoder_attention_mask=encoder_attention_mask,
  1580. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  1581. cross_attn_layer_head_mask=(
  1582. cross_attn_head_mask[idx]
  1583. if cross_attn_head_mask is not None
  1584. else None
  1585. ),
  1586. past_key_value=past_key_value,
  1587. output_attentions=output_attentions,
  1588. use_cache=use_cache,
  1589. )
  1590. hidden_states = layer_outputs[0]
  1591. if self.is_export:
  1592. next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
  1593. else:
  1594. if use_cache:
  1595. next_decoder_cache += (
  1596. layer_outputs[3 if output_attentions else 1],
  1597. )
  1598. if output_attentions:
  1599. all_self_attns += (layer_outputs[1],)
  1600. if encoder_hidden_states is not None:
  1601. all_cross_attentions += (layer_outputs[2],)
  1602. hidden_states = self.layer_norm(hidden_states)
  1603. if output_hidden_states:
  1604. all_hidden_states += (hidden_states,)
  1605. if self.is_export:
  1606. next_cache = next_decoder_cache
  1607. else:
  1608. next_cache = next_decoder_cache if use_cache else None
  1609. if not self.is_export:
  1610. if not return_dict:
  1611. return tuple(
  1612. v
  1613. for v in [
  1614. hidden_states,
  1615. next_cache,
  1616. all_hidden_states,
  1617. all_self_attns,
  1618. all_cross_attentions,
  1619. ]
  1620. if v is not None
  1621. )
  1622. return BaseModelOutputWithPastAndCrossAttentions(
  1623. last_hidden_state=hidden_states,
  1624. past_key_values=next_cache,
  1625. hidden_states=all_hidden_states,
  1626. attentions=all_self_attns,
  1627. cross_attentions=all_cross_attentions,
  1628. )
  1629. class SelfAttentionBlock(nn.Layer):
  1630. """
  1631. A self-attention block that implements multi-head self-attention
  1632. followed by a feed-forward network, typically used in transformer architectures.
  1633. Args:
  1634. embed_size (int): The size of the embedding vector.
  1635. num_heads (int): The number of attention heads.
  1636. is_export (bool): Flag indicating whether to configure the layer for export.
  1637. """
  1638. def __init__(self, embed_size, num_heads, is_export):
  1639. super(SelfAttentionBlock, self).__init__()
  1640. self.self_attention = MyMultiheadAttention(
  1641. embed_dim=embed_size, num_heads=num_heads, is_export=is_export
  1642. )
  1643. self.norm = nn.LayerNorm(embed_size)
  1644. def forward(self, x):
  1645. attn_output, _ = self.self_attention(x, x, x)
  1646. x = self.norm(attn_output + x)
  1647. return x
  1648. class SeqCountingDecoder(nn.Layer):
  1649. """
  1650. A custom sequence counting decoder that incorporates multi-head attention layers
  1651. and feed-forward networks to process sequences, potentially for latex code counting .
  1652. Args:
  1653. in_features (int): The number of input features.
  1654. out_features (int): The number of output features.
  1655. num_heads (int): The number of attention heads. Defaults to 8.
  1656. num_layers (int): The number of attention layers. Defaults to 4.
  1657. is_export (bool): Flag indicating whether to configure the layer for export.
  1658. """
  1659. def __init__(
  1660. self, in_features, out_features, num_heads=8, num_layers=4, is_export=False
  1661. ):
  1662. super(SeqCountingDecoder, self).__init__()
  1663. self.attention_blocks = nn.LayerList(
  1664. [
  1665. SelfAttentionBlock(
  1666. embed_size=in_features, num_heads=num_heads, is_export=is_export
  1667. )
  1668. for i in range(num_layers)
  1669. ]
  1670. )
  1671. self.fc1 = nn.Linear(in_features, in_features // 2)
  1672. self.relu = nn.ReLU()
  1673. self.global_avg_pool = nn.AdaptiveAvgPool1D(1)
  1674. self.fc2 = nn.Linear(in_features // 2, out_features)
  1675. def forward(self, x):
  1676. for block in self.attention_blocks:
  1677. x = block(x)
  1678. x = self.fc1(x)
  1679. x = self.relu(x)
  1680. x = x.transpose([0, 2, 1])
  1681. x = self.global_avg_pool(x)
  1682. x = x.squeeze(-1)
  1683. x = self.fc2(x)
  1684. return x
  1685. class CustomMBartForCausalLM(MBartForCausalLM):
  1686. """
  1687. Custom MBart model for causal language modeling with a custom decoder.
  1688. This class extends the MBartForCausalLM by replacing its decoder with a
  1689. custom decoder, allowing for additional flexibility and features in the
  1690. decoding process.
  1691. Args:
  1692. config: The configuration object containing model parameters.
  1693. length_aware (bool): A flag to enable or configure length-aware mechanisms.
  1694. """
  1695. def __init__(self, config, length_aware=True):
  1696. super().__init__(config)
  1697. self.model.decoder = CustomMBartDecoder(config)
  1698. self.counting_decoder = SeqCountingDecoder(
  1699. config.d_model, config.vocab_size, is_export=config.is_export
  1700. )
  1701. self.length_aware = length_aware
  1702. def forward(
  1703. self,
  1704. input_ids=None,
  1705. attention_mask=None,
  1706. encoder_hidden_states=None,
  1707. encoder_attention_mask=None,
  1708. head_mask=None,
  1709. cross_attn_head_mask=None,
  1710. past_key_values=None,
  1711. inputs_embeds=None,
  1712. labels=None,
  1713. use_cache=None,
  1714. output_attentions=None,
  1715. output_hidden_states=None,
  1716. return_dict=None,
  1717. count_gt=None,
  1718. ):
  1719. output_attentions = (
  1720. output_attentions
  1721. if output_attentions is not None
  1722. else self.config.output_attentions
  1723. )
  1724. output_hidden_states = (
  1725. output_hidden_states
  1726. if output_hidden_states is not None
  1727. else self.config.output_hidden_states
  1728. )
  1729. return_dict = (
  1730. return_dict if return_dict is not None else self.config.use_return_dict
  1731. )
  1732. if self.length_aware:
  1733. count_pred = self.counting_decoder(encoder_hidden_states)
  1734. else:
  1735. count_pred = None
  1736. outputs = self.model.decoder(
  1737. input_ids=input_ids,
  1738. attention_mask=attention_mask,
  1739. count_pred=count_pred,
  1740. encoder_hidden_states=encoder_hidden_states,
  1741. encoder_attention_mask=encoder_attention_mask,
  1742. head_mask=head_mask,
  1743. cross_attn_head_mask=cross_attn_head_mask,
  1744. past_key_values=past_key_values,
  1745. inputs_embeds=inputs_embeds,
  1746. use_cache=use_cache,
  1747. output_attentions=output_attentions,
  1748. output_hidden_states=output_hidden_states,
  1749. return_dict=return_dict,
  1750. )
  1751. logits = self.lm_head(outputs[0])
  1752. return CausalLMOutputWithCrossAttentionsAndCounting(
  1753. logits=logits,
  1754. counting=count_pred,
  1755. past_key_values=outputs.past_key_values,
  1756. hidden_states=outputs.hidden_states,
  1757. attentions=outputs.attentions,
  1758. cross_attentions=outputs.cross_attentions,
  1759. )
  1760. class UniMERNetHead(nn.Layer):
  1761. """Implementation of UniMERNetHead decoder.
  1762. Args:
  1763. max_new_tokens (int): Maximum number of new tokens to generate.
  1764. decoder_start_token_id (int): ID of the token that starts the decoding.
  1765. temperature (float): Sampling temperature for generation.
  1766. do_sample (bool): Whether to use sampling; if False, uses greedy decoding.
  1767. top_p (float): Top-p (nucleus) sampling parameter.
  1768. in_channels (int): Number of input channels/features.
  1769. encoder_hidden_size (int): Hidden size of the encoder.
  1770. decoder_hidden_size (int): Hidden size of the decoder.
  1771. decoder_ffn_dim (int): Dimension of the decoder's feed-forward network.
  1772. decoder_layers (int): Number of layers in the decoder.
  1773. is_export (bool): Flag indicating if the model is being prepared for export.
  1774. length_aware (bool): Flag to enable length-aware mechanisms.
  1775. """
  1776. def __init__(
  1777. self,
  1778. max_new_tokens=1536,
  1779. decoder_start_token_id=0,
  1780. temperature=0.2,
  1781. do_sample=False,
  1782. top_p=0.95,
  1783. in_channels=1024,
  1784. encoder_hidden_size=1024,
  1785. decoder_hidden_size=1024,
  1786. decoder_ffn_dim=4096,
  1787. decoder_layers=8,
  1788. is_export=False,
  1789. length_aware=True,
  1790. ):
  1791. super().__init__()
  1792. mbart_config_dict = {
  1793. "activation_dropout": 0.0,
  1794. "activation_function": "gelu",
  1795. "add_cross_attention": True,
  1796. "add_final_layer_norm": True,
  1797. "attention_dropout": 0.0,
  1798. "bos_token_id": 0,
  1799. "classifier_dropout": 0.0,
  1800. "d_model": decoder_hidden_size,
  1801. "decoder_attention_heads": 16,
  1802. "decoder_ffn_dim": decoder_ffn_dim,
  1803. "decoder_layerdrop": 0.0,
  1804. "decoder_layers": decoder_layers,
  1805. "dropout": 0.1,
  1806. "encoder_attention_heads": 16,
  1807. "encoder_ffn_dim": 4096,
  1808. "encoder_layerdrop": 0.0,
  1809. "encoder_layers": 12,
  1810. "eos_token_id": 2,
  1811. "forced_eos_token_id": 2,
  1812. "init_std": 0.02,
  1813. "is_decoder": True,
  1814. "is_encoder_decoder": False,
  1815. "output_hidden_states": False,
  1816. "max_position_embeddings": max_new_tokens,
  1817. "model_type": "mbart",
  1818. "num_hidden_layers": 12,
  1819. "pad_token_id": 1,
  1820. "scale_embedding": True,
  1821. "tie_word_embeddings": False,
  1822. "transformers_version": "4.40.0",
  1823. "use_cache": True,
  1824. "use_return_dict": True,
  1825. "vocab_size": 50000,
  1826. "_attn_implementation": "eager",
  1827. "hidden_size": decoder_hidden_size,
  1828. "is_export": is_export,
  1829. }
  1830. self.max_new_tokens = max_new_tokens
  1831. self.decoder_start_token_id = decoder_start_token_id
  1832. self.temperature = temperature
  1833. self.do_sample = do_sample
  1834. self.top_p = top_p
  1835. self.max_seq_len = max_new_tokens
  1836. self.config_decoder = MBartConfig(**mbart_config_dict)
  1837. self.encoder_hidden_size = encoder_hidden_size
  1838. self.is_export = self.config_decoder.is_export
  1839. self.decoder = CustomMBartForCausalLM(
  1840. self.config_decoder, length_aware=length_aware
  1841. )
  1842. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  1843. self.enc_to_dec_proj = nn.Linear(
  1844. self.encoder_hidden_size, self.config_decoder.hidden_size
  1845. )
  1846. generation_config = {
  1847. "max_length": 1537,
  1848. "forced_eos_token_id": 2,
  1849. }
  1850. self.eos_token_id = generation_config["forced_eos_token_id"]
  1851. self.pad_token_id = self.config_decoder.pad_token_id
  1852. self.logits_processor = LogitsProcessorList()
  1853. self.logits_processor.append(
  1854. ForcedEOSTokenLogitsProcessor(
  1855. generation_config["max_length"],
  1856. generation_config["forced_eos_token_id"],
  1857. )
  1858. )
  1859. def _get_decoder_start_token_id(
  1860. self, decoder_start_token_id=None, bos_token_id=None
  1861. ) -> int:
  1862. decoder_start_token_id = (
  1863. decoder_start_token_id
  1864. if decoder_start_token_id is not None
  1865. else self.generation_config.decoder_start_token_id
  1866. )
  1867. bos_token_id = (
  1868. bos_token_id
  1869. if bos_token_id is not None
  1870. else self.generation_config.bos_token_id
  1871. )
  1872. if decoder_start_token_id is not None:
  1873. return decoder_start_token_id
  1874. elif bos_token_id is not None:
  1875. return bos_token_id
  1876. raise ValueError(
  1877. "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
  1878. )
  1879. def _prepare_decoder_input_ids_for_generation(
  1880. self,
  1881. batch_size,
  1882. model_kwargs,
  1883. decoder_start_token_id=None,
  1884. bos_token_id=None,
  1885. ):
  1886. if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
  1887. decoder_input_ids = model_kwargs.pop("decoder_input_ids")
  1888. elif "input_ids" in model_kwargs:
  1889. decoder_input_ids = model_kwargs.pop("input_ids")
  1890. else:
  1891. decoder_input_ids = None
  1892. decoder_start_token_id = self._get_decoder_start_token_id(
  1893. decoder_start_token_id, bos_token_id
  1894. )
  1895. if isinstance(decoder_start_token_id, list):
  1896. if len(decoder_start_token_id) != batch_size:
  1897. raise ValueError(
  1898. f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
  1899. )
  1900. decoder_input_ids_start = paddle.to_tensor(
  1901. decoder_start_token_id,
  1902. dtype=paddle.int64,
  1903. )
  1904. decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
  1905. else:
  1906. decoder_input_ids_start = (
  1907. paddle.ones(
  1908. (batch_size, 1),
  1909. dtype=paddle.int64,
  1910. )
  1911. * decoder_start_token_id
  1912. )
  1913. if decoder_input_ids is None:
  1914. decoder_input_ids = decoder_input_ids_start
  1915. elif (
  1916. self.config.model_type == "vision-encoder-decoder"
  1917. and "donut" in self.name_or_path.lower()
  1918. ):
  1919. pass
  1920. elif self.config.model_type in ["whisper"]:
  1921. pass
  1922. elif (
  1923. isinstance(decoder_start_token_id, int)
  1924. and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
  1925. ) or (
  1926. isinstance(decoder_start_token_id, paddle.Tensor)
  1927. and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
  1928. ):
  1929. decoder_input_ids = paddle.concat(
  1930. [decoder_input_ids_start, decoder_input_ids], axis=-1
  1931. )
  1932. if "decoder_attention_mask" in model_kwargs:
  1933. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  1934. decoder_attention_mask = paddle.cat(
  1935. (
  1936. paddle.ones_like(decoder_attention_mask)[:, :1],
  1937. decoder_attention_mask,
  1938. ),
  1939. dim=-1,
  1940. )
  1941. model_kwargs["decoder_attention_mask"] = decoder_attention_mask
  1942. return decoder_input_ids, model_kwargs
  1943. def prepare_inputs_for_generation_mbart(
  1944. self,
  1945. input_ids,
  1946. past_key_values=None,
  1947. attention_mask=None,
  1948. use_cache=None,
  1949. **kwargs,
  1950. ):
  1951. if attention_mask is None:
  1952. attention_mask = paddle.ones(input_ids.shape)
  1953. if past_key_values:
  1954. past_length = past_key_values[0][0].shape[2]
  1955. if input_ids.shape[1] > past_length:
  1956. remove_prefix_length = past_length
  1957. else:
  1958. remove_prefix_length = input_ids.shape[1] - 1
  1959. input_ids = input_ids[:, remove_prefix_length:]
  1960. return {
  1961. "input_ids": input_ids,
  1962. "attention_mask": attention_mask,
  1963. "past_key_values": past_key_values,
  1964. "use_cache": use_cache,
  1965. }
  1966. def prepare_inputs_for_generation(
  1967. self,
  1968. input_ids,
  1969. past_key_values=None,
  1970. attention_mask=None,
  1971. use_cache=None,
  1972. encoder_outputs=None,
  1973. **kwargs,
  1974. ):
  1975. decoder_inputs = self.prepare_inputs_for_generation_mbart(
  1976. input_ids, past_key_values=past_key_values
  1977. )
  1978. decoder_attention_mask = (
  1979. decoder_inputs["attention_mask"]
  1980. if "attention_mask" in decoder_inputs
  1981. else None
  1982. )
  1983. input_dict = {
  1984. "attention_mask": attention_mask,
  1985. "decoder_attention_mask": decoder_attention_mask,
  1986. "decoder_input_ids": decoder_inputs["input_ids"],
  1987. "encoder_outputs": encoder_outputs,
  1988. "past_key_values": decoder_inputs["past_key_values"],
  1989. "use_cache": use_cache,
  1990. }
  1991. return input_dict
  1992. def prepare_inputs_for_generation_export(
  1993. self,
  1994. past_key_values=None,
  1995. attention_mask=None,
  1996. use_cache=None,
  1997. encoder_outputs=None,
  1998. **kwargs,
  1999. ):
  2000. input_dict = {
  2001. "decoder_attention_mask": None,
  2002. "use_cache": use_cache,
  2003. }
  2004. return input_dict
  2005. def _extract_past_from_model_output(
  2006. self, outputs: ModelOutput, standardize_cache_format: bool = False
  2007. ):
  2008. past_key_values = None
  2009. if "past_key_values" in outputs:
  2010. past_key_values = outputs.past_key_values
  2011. elif "mems" in outputs:
  2012. past_key_values = outputs.mems
  2013. elif "past_buckets_states" in outputs:
  2014. past_key_values = outputs.past_buckets_states
  2015. return past_key_values
  2016. def _update_model_kwargs_for_generation(
  2017. self,
  2018. outputs: ModelOutput,
  2019. model_kwargs: Dict[str, Any],
  2020. is_encoder_decoder: bool = False,
  2021. standardize_cache_format: bool = False,
  2022. ) -> Dict[str, Any]:
  2023. model_kwargs["past_key_values"] = self._extract_past_from_model_output(
  2024. outputs, standardize_cache_format=standardize_cache_format
  2025. )
  2026. if getattr(outputs, "state", None) is not None:
  2027. model_kwargs["state"] = outputs.state
  2028. if "token_type_ids" in model_kwargs:
  2029. token_type_ids = model_kwargs["token_type_ids"]
  2030. model_kwargs["token_type_ids"] = paddle.concat(
  2031. [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], axis=-1
  2032. )
  2033. if not is_encoder_decoder:
  2034. if "attention_mask" in model_kwargs:
  2035. attention_mask = model_kwargs["attention_mask"]
  2036. model_kwargs["attention_mask"] = paddle.concat(
  2037. [
  2038. attention_mask,
  2039. attention_mask.new_ones((attention_mask.shape[0], 1)),
  2040. ],
  2041. axis=-1,
  2042. )
  2043. else:
  2044. if "decoder_attention_mask" in model_kwargs:
  2045. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  2046. model_kwargs["decoder_attention_mask"] = paddle.concat(
  2047. [
  2048. decoder_attention_mask,
  2049. decoder_attention_mask.new_ones(
  2050. (decoder_attention_mask.shape[0], 1)
  2051. ),
  2052. ],
  2053. axis=-1,
  2054. )
  2055. if (
  2056. "cache_position" in model_kwargs
  2057. and model_kwargs["cache_position"] is not None
  2058. ):
  2059. model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
  2060. return model_kwargs
  2061. def stopping_criteria(self, input_ids):
  2062. if self.is_export:
  2063. return input_ids[:, -1] == paddle.to_tensor([self.eos_token_id])
  2064. is_done = paddle.isin(input_ids[:, -1], paddle.to_tensor([self.eos_token_id]))
  2065. return is_done
  2066. def generate_single_iter(
  2067. self,
  2068. decoder_input_ids=None,
  2069. decoder_attention_mask=None,
  2070. encoder_outputs=None,
  2071. past_key_values=None,
  2072. decoder_inputs_embeds=None,
  2073. labels=None,
  2074. use_cache=None,
  2075. output_attentions=None,
  2076. output_hidden_states=None,
  2077. return_dict=None,
  2078. **kwargs,
  2079. ):
  2080. encoder_hidden_states = encoder_outputs[0]
  2081. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  2082. encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
  2083. kwargs_decoder = {}
  2084. decoder_outputs = self.decoder(
  2085. input_ids=decoder_input_ids,
  2086. attention_mask=decoder_attention_mask,
  2087. encoder_hidden_states=encoder_hidden_states,
  2088. encoder_attention_mask=None,
  2089. inputs_embeds=None,
  2090. output_attentions=False,
  2091. output_hidden_states=output_hidden_states,
  2092. use_cache=use_cache,
  2093. past_key_values=past_key_values,
  2094. return_dict=return_dict,
  2095. **kwargs_decoder,
  2096. )
  2097. return Seq2SeqLMOutput(
  2098. loss=None,
  2099. logits=decoder_outputs.logits,
  2100. past_key_values=decoder_outputs.past_key_values,
  2101. decoder_hidden_states=decoder_outputs.hidden_states,
  2102. decoder_attentions=decoder_outputs.attentions,
  2103. cross_attentions=decoder_outputs.cross_attentions,
  2104. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  2105. encoder_hidden_states=encoder_outputs.hidden_states,
  2106. encoder_attentions=encoder_outputs.attentions,
  2107. )
  2108. @paddle.no_grad()
  2109. def generate(
  2110. self,
  2111. model_kwargs,
  2112. ):
  2113. """
  2114. Generate sequences using the UniMERNetHead for inference tasks.
  2115. Args:
  2116. model_kwargs (dict): A dictionary of model configurations and inputs, which typically include:
  2117. - encoder_outputs: Outputs from the encoder.
  2118. - use_cache: Boolean flag to indicate if caching should be used.
  2119. - output_attentions: Boolean flag for outputting attention scores.
  2120. - output_hidden_states: Boolean flag for outputting hidden states.
  2121. Returns:
  2122. A tensor containing the generated sequences.
  2123. """
  2124. batch_size = model_kwargs["encoder_outputs"]["last_hidden_state"].shape[0]
  2125. generation_config = {
  2126. "decoder_start_token_id": 0,
  2127. "bos_token_id": 0,
  2128. }
  2129. input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
  2130. batch_size=batch_size,
  2131. model_kwargs=model_kwargs,
  2132. decoder_start_token_id=generation_config["decoder_start_token_id"],
  2133. bos_token_id=generation_config["bos_token_id"],
  2134. )
  2135. model_kwargs["key use_cache"] = True
  2136. batch_size, cur_len = input_ids.shape
  2137. if "inputs_embeds" in model_kwargs:
  2138. cur_len = model_kwargs["inputs_embeds"].shape[1]
  2139. model_kwargs["cache_position"] = paddle.arange(cur_len)
  2140. pad_token_id = self.pad_token_id
  2141. eos_token_id = [self.eos_token_id]
  2142. eos_token = self.eos_token_id
  2143. unfinished_sequences = paddle.ones(batch_size, dtype=paddle.int64)
  2144. for idx in range(self.max_seq_len):
  2145. model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  2146. outputs = self.generate_single_iter(
  2147. **model_inputs,
  2148. return_dict=True,
  2149. output_attentions=False,
  2150. output_hidden_states=False,
  2151. )
  2152. next_token_logits = outputs.logits[:, -1, :]
  2153. next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
  2154. next_tokens = paddle.argmax(next_tokens_scores, axis=-1)
  2155. if eos_token_id is not None:
  2156. if pad_token_id is None:
  2157. raise ValueError(
  2158. "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
  2159. )
  2160. next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
  2161. 1 - unfinished_sequences
  2162. )
  2163. input_ids = paddle.concat([input_ids, next_tokens[:, None]], axis=-1)
  2164. model_kwargs = self._update_model_kwargs_for_generation(
  2165. outputs,
  2166. model_kwargs,
  2167. is_encoder_decoder=self.config_decoder.is_encoder_decoder,
  2168. )
  2169. unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
  2170. input_ids
  2171. ).cast(paddle.int64)
  2172. if (
  2173. eos_token is not None
  2174. and (
  2175. paddle.cumsum((input_ids == eos_token).cast(paddle.int64), 1)[:, -1]
  2176. >= 1
  2177. ).all()
  2178. ):
  2179. break
  2180. return input_ids
  2181. @paddle.no_grad()
  2182. def generate_export(
  2183. self,
  2184. encoder_outputs,
  2185. model_kwargs,
  2186. ):
  2187. batch_size = encoder_outputs["last_hidden_state"].shape[0]
  2188. generation_config = {
  2189. "decoder_start_token_id": 0,
  2190. "bos_token_id": 0,
  2191. }
  2192. input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
  2193. batch_size=batch_size,
  2194. model_kwargs=model_kwargs,
  2195. decoder_start_token_id=generation_config["decoder_start_token_id"],
  2196. bos_token_id=generation_config["bos_token_id"],
  2197. )
  2198. input_ids = input_ids.reshape([-1, 1])
  2199. decoder_input_ids = input_ids
  2200. model_kwargs["key use_cache"] = True
  2201. batch_size, cur_len = input_ids.shape
  2202. if "inputs_embeds" in model_kwargs:
  2203. cur_len = model_kwargs["inputs_embeds"].shape[1]
  2204. cache_position = paddle.arange(cur_len)
  2205. pad_token_id = self.pad_token_id
  2206. eos_token_id = [self.eos_token_id]
  2207. eos_token = self.eos_token_id
  2208. unfinished_sequences = paddle.ones([batch_size], dtype=paddle.int64)
  2209. i_idx = paddle.full([], 0)
  2210. past_key_values = []
  2211. for i in range(8):
  2212. init_arr = paddle.zeros([batch_size, 16, 0, 64])
  2213. paddle.jit.api.set_dynamic_shape(init_arr, [-1, -1, -1, -1])
  2214. cache = (init_arr, init_arr, init_arr, init_arr)
  2215. past_key_values.append(cache)
  2216. idx = 0
  2217. while i_idx < paddle.to_tensor(self.max_seq_len):
  2218. model_inputs = self.prepare_inputs_for_generation_export(
  2219. past_key_values=past_key_values, **model_kwargs
  2220. )
  2221. decoder_attention_mask = model_inputs["decoder_attention_mask"]
  2222. decoder_attention_mask = paddle.ones(input_ids.shape)
  2223. paddle.jit.api.set_dynamic_shape(decoder_input_ids, [-1, -1])
  2224. paddle.jit.api.set_dynamic_shape(decoder_attention_mask, [-1, -1])
  2225. outputs = self.generate_single_iter(
  2226. decoder_input_ids=decoder_input_ids,
  2227. decoder_attention_mask=decoder_attention_mask,
  2228. encoder_outputs=encoder_outputs,
  2229. past_key_values=past_key_values,
  2230. return_dict=True,
  2231. output_attentions=False,
  2232. output_hidden_states=False,
  2233. )
  2234. next_token_logits = outputs.logits[:, -1, :]
  2235. next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
  2236. next_tokens = paddle.argmax(next_tokens_scores, axis=-1)
  2237. if eos_token_id is not None:
  2238. next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
  2239. 1 - unfinished_sequences
  2240. )
  2241. input_ids = paddle.concat([input_ids, next_tokens.unsqueeze(1)], axis=-1)
  2242. past_length = past_key_values[0][0].shape[2]
  2243. decoder_input_ids = next_tokens.unsqueeze(1)
  2244. past_key_values = outputs.past_key_values
  2245. cache_position = cache_position[-1:] + 1
  2246. unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
  2247. input_ids
  2248. ).cast(paddle.int64)
  2249. if (
  2250. eos_token is not None
  2251. and (
  2252. paddle.cumsum((input_ids == eos_token).cast(paddle.int64), 1)[:, -1]
  2253. >= 1
  2254. ).all()
  2255. ):
  2256. break
  2257. i_idx += 1
  2258. return input_ids
  def forwad_train(
      self,
      encoder_outputs,
      decoder_input_ids,
      decoder_attention_mask,
      past_key_values=None,
      decoder_inputs_embeds=None,
      labels=None,
      use_cache=None,
      output_attentions=None,
      output_hidden_states=None,
      return_dict=None,
      **kwargs,
  ):
      """
      Run one teacher-forced decoder pass for training.

      The decoder is fed ``decoder_input_ids[:, :-1]`` together with the
      matching slice of the attention mask. The returned training targets are
      built from the full, unshifted ``decoder_input_ids``, with padding
      replaced by -100.

      Args:
          encoder_outputs: Indexable encoder result. Only element 0 is used,
              as the encoder hidden states.
          decoder_input_ids: Target token ids. Sliced along axis 1 for the
              decoder input, and copied to build the returned labels.
          decoder_attention_mask: Mask for ``decoder_input_ids``. Sliced the
              same way before being passed to the decoder.
          past_key_values: Forwarded to the decoder unchanged.
          decoder_inputs_embeds: Ignored. The decoder always receives
              ``inputs_embeds=None``.
          labels: Ignored. Labels are always derived from
              ``decoder_input_ids``.
          use_cache: Forwarded to the decoder unchanged.
          output_attentions: Ignored. The decoder always receives ``False``.
          output_hidden_states: Forwarded to the decoder unchanged.
          return_dict: Forwarded to the decoder unchanged.
          **kwargs: Ignored.

      Returns:
          tuple: ``(logits, count_pred, masked_labels)``.
          ``logits`` and ``count_pred`` are the decoder output's ``logits``
          and ``counting`` attributes. ``masked_labels`` is a copy of
          ``decoder_input_ids`` with every ``self.pad_token_id`` entry set
          to -100.
      """
      # ``* 1`` allocates a new tensor, so the in-place fill below leaves the
      # caller's ``decoder_input_ids`` untouched. -100 is presumably the
      # loss's ignore index (TODO confirm against the loss). These labels are
      # NOT shifted: along axis 1 they are one step longer than the decoder
      # input built below.
      masked_labels = decoder_input_ids * 1
      is_padding = masked_labels == self.pad_token_id
      masked_labels = masked_labels.masked_fill_(is_padding, -100)

      # Teacher forcing: drop the final position from the decoder inputs.
      step_ids = decoder_input_ids[:, :-1]
      step_mask = decoder_attention_mask[:, :-1]

      memory = encoder_outputs[0]
      if self.config_decoder.hidden_size != self.encoder_hidden_size:
          # Widths differ, so route the encoder states through the projection.
          memory = self.enc_to_dec_proj(memory)

      decoder_outputs = self.decoder(
          input_ids=step_ids,
          attention_mask=step_mask,
          encoder_hidden_states=memory,
          encoder_attention_mask=None,
          inputs_embeds=None,
          output_attentions=False,
          output_hidden_states=output_hidden_states,
          use_cache=use_cache,
          past_key_values=past_key_values,
          return_dict=return_dict,
      )
      return decoder_outputs.logits, decoder_outputs.counting, masked_labels
  def forward(self, inputs, targets=None):
      """
      Run the UniMERNetHead in training or inference mode.

      Args:
          inputs: In inference mode (``self.training`` is False), the encoder
              outputs, passed straight to ``self.generate_export``. In
              training mode, a 3-item sequence
              ``(encoder_outputs, tgt_seq, mask)`` that is unpacked and
              forwarded to ``self.forwad_train``.
          targets: Unused; kept for interface compatibility.

      Returns:
          Inference: the result of ``self.generate_export`` (the generated
          token ids).
          Training: ``(logits, count_pred, masked_labels)`` as returned by
          ``self.forwad_train``.
      """
      # NOTE(review): this flag is recomputed from ``self.training`` on every
      # call, so any value assigned elsewhere is overwritten here. The
      # assignment is kept because code outside this method may read
      # ``self.is_export`` (not visible here; confirm before removing).
      self.is_export = not self.training

      if not self.training:
          # ``self.is_export`` is always True in this mode, so inference always
          # takes the export decoding path. The former ``else`` branch calling
          # ``self.generate(...)`` was unreachable and has been dropped.
          model_kwargs = {
              "output_attentions": False,
              "output_hidden_states": False,
              "use_cache": True,
          }
          return self.generate_export(inputs, model_kwargs)

      encoder_outputs, tgt_seq, mask = inputs
      logits, count_pred, masked_labels = self.forwad_train(
          encoder_outputs, tgt_seq, mask
      )
      return logits, count_pred, masked_labels