rec_ppformulanet_head.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import re
  16. import numpy as np
  17. import inspect
  18. import paddle
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from paddle.nn import CrossEntropyLoss
  22. from paddle import Tensor
  23. from collections import OrderedDict
  24. from typing import Optional, Tuple, Union, List, Dict, Any
  25. from dataclasses import dataclass, fields, is_dataclass
  26. from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModelOutput
  27. from ppocr.modeling.heads.rec_unimernet_head import (
  28. MBartForCausalLM,
  29. MBartDecoder,
  30. MBartConfig,
  31. ModelOutput,
  32. BaseModelOutputWithPastAndCrossAttentions,
  33. Seq2SeqLMOutput,
  34. zeros_,
  35. ones_,
  36. kaiming_normal_,
  37. trunc_normal_,
  38. xavier_uniform_,
  39. CausalLMOutputWithCrossAttentions,
  40. LogitsProcessorList,
  41. ForcedEOSTokenLogitsProcessor,
  42. UniMERNetHead,
  43. )
  44. @dataclass
  45. class AttentionMaskConverter:
  46. """
  47. A class to convert attention masks based on specific configurations.
  48. This class is designed to handle the conversion of attention masks with options for causal masking
  49. and sliding window attention, which are commonly used in transformer models.
  50. Attributes:
  51. is_causal (bool): Flag indicating whether the attention mask should enforce causal masking,
  52. which ensures each position can only attend to previous positions.
  53. sliding_window (int, optional): Size of the sliding window for local attention. If set,
  54. attention is restricted to a local window of this size.
  55. """
  56. is_causal: bool
  57. sliding_window: int
  58. def __init__(self, is_causal: bool, sliding_window=None):
  59. self.is_causal = is_causal
  60. self.sliding_window = sliding_window
  61. if self.sliding_window is not None and self.sliding_window <= 0:
  62. raise ValueError(
  63. f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
  64. )
  65. @staticmethod
  66. def _make_causal_mask(
  67. input_ids_shape,
  68. dtype,
  69. past_key_values_length=0,
  70. sliding_window=None,
  71. is_export=False,
  72. ):
  73. """
  74. Make causal mask used for bi-directional self-attention.
  75. """
  76. bsz, tgt_len = input_ids_shape
  77. if is_export:
  78. mask = paddle.full(
  79. (tgt_len, tgt_len), paddle.finfo(dtype).min, dtype="float64"
  80. )
  81. mask_cond = paddle.arange(mask.shape[-1])
  82. mask.masked_fill_(
  83. mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
  84. )
  85. else:
  86. mask = paddle.full((tgt_len, tgt_len), paddle.finfo(dtype).min)
  87. mask_cond = paddle.arange(mask.shape[-1])
  88. mask.masked_fill_(
  89. mask_cond < (mask_cond + 1).reshape([mask.shape[-1], 1]), 0
  90. )
  91. mask = mask.cast(dtype)
  92. if past_key_values_length > 0:
  93. mask = paddle.concat(
  94. [paddle.zeros(tgt_len, past_key_values_length, dtype=dtype), mask],
  95. axis=-1,
  96. )
  97. # add lower triangular sliding window mask if necessary
  98. if sliding_window is not None:
  99. diagonal = past_key_values_length - sliding_window - 1
  100. context_mask = paddle.tril(
  101. paddle.ones_like(mask, dtype=paddle.bool), diagonal=diagonal
  102. )
  103. mask.masked_fill_(context_mask, paddle.finfo(dtype).min)
  104. return mask[None, None, :, :].expand(
  105. [bsz, 1, tgt_len, tgt_len + past_key_values_length]
  106. )
  107. @staticmethod
  108. def _make_causal_mask_parallel(
  109. input_ids_shape,
  110. dtype,
  111. past_key_values_length=0,
  112. sliding_window=None,
  113. parallel_step=1,
  114. is_export=False,
  115. ):
  116. """
  117. Make causal mask used for bi-directional self-attention.
  118. """
  119. bsz, tgt_len = input_ids_shape
  120. mask = paddle.full((tgt_len, tgt_len), paddle.finfo(dtype).min)
  121. mask_cond = paddle.arange(mask.shape[-1])
  122. mask_cond_parallel = paddle.arange(mask.shape[-1])
  123. mask_parallel = paddle.arange(0, tgt_len, step=parallel_step).reshape([1, -1])
  124. mask_parallel = paddle.repeat_interleave(mask_parallel, parallel_step, 1)[
  125. :, :tgt_len
  126. ]
  127. mask.masked_fill_(
  128. mask_cond < (mask_parallel + parallel_step).reshape([mask.shape[-1], 1]), 0
  129. )
  130. mask = mask.cast(dtype)
  131. if past_key_values_length > 0:
  132. mask = paddle.concat(
  133. [paddle.zeros([tgt_len, past_key_values_length], dtype=dtype), mask],
  134. axis=-1,
  135. )
  136. # add lower triangular sliding window mask if necessary
  137. if sliding_window is not None:
  138. diagonal = past_key_values_length - sliding_window - 1
  139. context_mask = paddle.tril(
  140. paddle.ones_like(mask, dtype=paddle.bool), diagonal=diagonal
  141. )
  142. mask.masked_fill_(context_mask, paddle.finfo(dtype).min)
  143. return mask[None, None, :, :].expand(
  144. [bsz, 1, tgt_len, tgt_len + past_key_values_length]
  145. )
  146. def to_4d(
  147. self,
  148. attention_mask_2d,
  149. query_length,
  150. dtype,
  151. key_value_length,
  152. use_parallel=False,
  153. parallel_step=3,
  154. is_export=False,
  155. ):
  156. """
  157. Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
  158. key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
  159. causal, a causal mask will be added.
  160. """
  161. input_shape = (attention_mask_2d.shape[0], query_length)
  162. causal_4d_mask = None
  163. if use_parallel:
  164. step = parallel_step
  165. else:
  166. step = 1
  167. if (
  168. input_shape[-1] > step or self.sliding_window is not None
  169. ) and self.is_causal:
  170. if key_value_length is None:
  171. raise ValueError(
  172. "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
  173. )
  174. past_key_values_length = key_value_length - query_length
  175. if use_parallel:
  176. causal_4d_mask = self._make_causal_mask_parallel(
  177. input_shape,
  178. dtype,
  179. past_key_values_length=past_key_values_length,
  180. sliding_window=self.sliding_window,
  181. parallel_step=parallel_step,
  182. is_export=is_export,
  183. )
  184. else:
  185. causal_4d_mask = self._make_causal_mask(
  186. input_shape,
  187. dtype,
  188. past_key_values_length=past_key_values_length,
  189. sliding_window=self.sliding_window,
  190. is_export=is_export,
  191. )
  192. elif self.sliding_window is not None:
  193. raise NotImplementedError(
  194. "Sliding window is currently only implemented for causal masking"
  195. )
  196. expanded_attn_mask = self._expand_mask(
  197. attention_mask_2d, dtype, tgt_len=input_shape[-1]
  198. )
  199. if causal_4d_mask is not None:
  200. expanded_attn_mask = causal_4d_mask.masked_fill_(
  201. expanded_attn_mask.cast(paddle.bool), paddle.finfo(dtype).min
  202. )
  203. expanded_4d_mask = expanded_attn_mask
  204. return expanded_4d_mask
  205. def to_4d_export(
  206. self,
  207. attention_mask_2d,
  208. query_length,
  209. dtype,
  210. key_value_length,
  211. use_parallel=False,
  212. parallel_step=3,
  213. is_export=False,
  214. ):
  215. input_shape = (attention_mask_2d.shape[0], query_length)
  216. expanded_attn_mask = self._expand_mask_export(
  217. attention_mask_2d, dtype, tgt_len=input_shape[-1]
  218. )
  219. expanded_4d_mask = expanded_attn_mask
  220. return expanded_4d_mask
  221. def _expand_mask(self, mask, dtype, tgt_len=None):
  222. """
  223. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
  224. """
  225. bsz, src_len = mask.shape
  226. tgt_len = tgt_len if tgt_len is not None else src_len
  227. expanded_mask = (
  228. mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).cast(dtype)
  229. )
  230. inverted_mask = 1.0 - expanded_mask
  231. return inverted_mask.masked_fill_(
  232. inverted_mask.cast(paddle.bool), paddle.finfo(dtype).min
  233. )
  234. def _expand_mask_export(self, mask, dtype, tgt_len=None):
  235. """
  236. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
  237. """
  238. bsz, src_len = paddle.shape(mask)
  239. expanded_mask = (
  240. mask[:, None, None, :].expand([bsz, 1, tgt_len, src_len]).cast(dtype)
  241. )
  242. paddle.jit.api.set_dynamic_shape(expanded_mask, [-1, -1, -1, -1])
  243. inverted_mask = 1.0 - expanded_mask
  244. return inverted_mask.masked_fill_(
  245. inverted_mask.cast(paddle.bool), paddle.finfo(dtype).min
  246. )
  247. def _prepare_4d_attention_mask(mask, dtype, tgt_len=None):
  248. return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
  249. def _prepare_4d_causal_attention_mask(
  250. attention_mask,
  251. input_shape,
  252. inputs_embeds,
  253. past_key_values_length,
  254. sliding_window=None,
  255. use_parallel=False,
  256. parallel_step=3,
  257. is_export=False,
  258. ):
  259. """
  260. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  261. `(batch_size, key_value_length)`
  262. Args:
  263. attention_mask (`paddle.Tensor` or `None`):
  264. A 2D attention mask of shape `(batch_size, key_value_length)`
  265. input_shape (`tuple(int)` or `list(int)` or `paddle.Size`):
  266. The input shape should be a tuple that defines `(batch_size, query_length)`.
  267. inputs_embeds (`paddle.Tensor`):
  268. The embedded inputs as a paddle Tensor.
  269. past_key_values_length (`int`):
  270. The length of the key value cache.
  271. sliding_window (`int`, *optional*):
  272. If the model uses windowed attention, a sliding window should be passed.
  273. """
  274. attn_mask_converter = AttentionMaskConverter(
  275. is_causal=True, sliding_window=sliding_window
  276. )
  277. key_value_length = input_shape[-1] + past_key_values_length
  278. # 4d mask is passed through the layers
  279. if attention_mask is not None and len(attention_mask.shape) == 2:
  280. attention_mask = attn_mask_converter.to_4d(
  281. attention_mask,
  282. input_shape[-1],
  283. key_value_length=key_value_length,
  284. dtype=inputs_embeds.dtype,
  285. use_parallel=use_parallel,
  286. parallel_step=parallel_step,
  287. is_export=is_export,
  288. )
  289. elif attention_mask is not None and len(attention_mask.shape) == 4:
  290. expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
  291. if tuple(attention_mask.shape) != expected_shape:
  292. raise ValueError(
  293. f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
  294. )
  295. else:
  296. # if the 4D mask has correct shape - invert it and fill with negative infinity
  297. inverted_mask = 1.0 - attention_mask
  298. attention_mask = inverted_mask.masked_fill_(
  299. inverted_mask.to(paddle.bool), paddle.finfo(inputs_embeds.dtype).min
  300. )
  301. else:
  302. attention_mask = attn_mask_converter.to_causal_4d(
  303. input_shape[0],
  304. input_shape[-1],
  305. key_value_length,
  306. dtype=inputs_embeds.dtype,
  307. )
  308. return attention_mask
  309. def _prepare_4d_causal_attention_mask_export(
  310. attention_mask,
  311. input_shape,
  312. inputs_embeds,
  313. past_key_values_length,
  314. sliding_window=None,
  315. use_parallel=False,
  316. parallel_step=3,
  317. is_export=False,
  318. ):
  319. """
  320. Prepare a 4D causal attention mask for export.
  321. This function prepares a 4-dimensional causal attention mask, which is used to ensure that each position in the
  322. sequence can only attend to previous positions. It is specifically designed to handle scenarios where the model
  323. is being exported, potentially with additional options like sliding window or parallel processing.
  324. Args:
  325. attention_mask: The initial attention mask, typically used to avoid attending to padding tokens.
  326. input_shape: Shape of the input tensor, usually in the form (batch_size, sequence_length).
  327. inputs_embeds: Embeddings of the input sequence, used to derive certain dimensions if needed.
  328. past_key_values_length: Length of past key values, used in contexts like transformer decoders with caching.
  329. sliding_window: Optional parameter. If provided, specifies the size of a sliding window for local attention.
  330. use_parallel: Flag indicating whether to use parallel processing for attention computation.
  331. parallel_step: Number of steps to use in parallel processing, relevant if `use_parallel` is True.
  332. is_export: Flag indicating whether the attention mask is being prepared for model export.
  333. Returns:
  334. A 4D causal attention mask suitable for use in transformer models, ensuring correct causal masking.
  335. """
  336. attn_mask_converter = AttentionMaskConverter(
  337. is_causal=True, sliding_window=sliding_window
  338. )
  339. key_value_length = input_shape[-1] + past_key_values_length
  340. shape = attention_mask.shape
  341. len_shape = len(shape)
  342. attention_mask = attn_mask_converter.to_4d_export(
  343. attention_mask,
  344. input_shape[-1],
  345. key_value_length=key_value_length,
  346. dtype=inputs_embeds.dtype,
  347. use_parallel=use_parallel,
  348. parallel_step=parallel_step,
  349. is_export=is_export,
  350. )
  351. return attention_mask
  352. class CustomMBartDecoder(MBartDecoder):
  353. def __init__(self, config):
  354. super().__init__(config)
  355. hidden_size = config.d_model
  356. self.is_export = config.is_export
  357. self.config_decoder = config
  358. def forward(
  359. self,
  360. input_ids=None,
  361. attention_mask=None,
  362. encoder_hidden_states=None,
  363. encoder_attention_mask=None,
  364. head_mask=None,
  365. cross_attn_head_mask=None,
  366. past_key_values=None,
  367. inputs_embeds=None,
  368. use_cache=None,
  369. output_attentions=None,
  370. output_hidden_states=None,
  371. return_dict=None,
  372. ):
  373. self.is_export = False if self.training else True
  374. output_attentions = (
  375. output_attentions
  376. if output_attentions is not None
  377. else self.config.output_attentions
  378. )
  379. output_hidden_states = (
  380. output_hidden_states
  381. if output_hidden_states is not None
  382. else self.config.output_hidden_states
  383. )
  384. use_cache = use_cache if use_cache is not None else self.config.use_cache
  385. return_dict = (
  386. return_dict if return_dict is not None else self.config.use_return_dict
  387. )
  388. # retrieve input_ids and inputs_embeds
  389. if input_ids is not None and inputs_embeds is not None:
  390. raise ValueError(
  391. "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
  392. )
  393. elif input_ids is not None:
  394. input = input_ids
  395. input_shape = input.shape
  396. input_ids = input_ids.reshape([-1, input_shape[-1]])
  397. elif inputs_embeds is not None:
  398. input_shape = inputs_embeds.shape[:-1]
  399. input = inputs_embeds[:, :, -1]
  400. else:
  401. raise ValueError(
  402. "You have to specify either decoder_input_ids or decoder_inputs_embeds"
  403. )
  404. # past_key_values_length
  405. past_key_values_length = (
  406. past_key_values[0][0].shape[2] if past_key_values is not None else 0
  407. )
  408. if inputs_embeds is None:
  409. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  410. if self._use_flash_attention_2:
  411. # 2d mask is passed through the layers
  412. attention_mask = (
  413. attention_mask
  414. if (attention_mask is not None and 0 in attention_mask)
  415. else None
  416. )
  417. else:
  418. # 4d mask is passed through the layers
  419. if self.is_export:
  420. attention_mask = _prepare_4d_causal_attention_mask_export(
  421. attention_mask,
  422. input_shape,
  423. inputs_embeds,
  424. past_key_values_length,
  425. use_parallel=self.config_decoder.use_parallel,
  426. parallel_step=self.config_decoder.parallel_step,
  427. is_export=self.is_export,
  428. )
  429. else:
  430. attention_mask = _prepare_4d_causal_attention_mask(
  431. attention_mask,
  432. input_shape,
  433. inputs_embeds,
  434. past_key_values_length,
  435. use_parallel=self.config_decoder.use_parallel,
  436. parallel_step=self.config_decoder.parallel_step,
  437. is_export=self.is_export,
  438. )
  439. # expand encoder attention mask
  440. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  441. if self._use_flash_attention_2:
  442. encoder_attention_mask = (
  443. encoder_attention_mask if 0 in encoder_attention_mask else None
  444. )
  445. else:
  446. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  447. encoder_attention_mask = _prepare_4d_attention_mask(
  448. encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
  449. )
  450. # embed positions
  451. positions = self.embed_positions(input, past_key_values_length)
  452. hidden_states = inputs_embeds + positions
  453. hidden_states = self.layernorm_embedding(hidden_states)
  454. hidden_states = nn.functional.dropout(
  455. hidden_states, p=self.dropout, training=self.training
  456. )
  457. if self.gradient_checkpointing and self.training:
  458. if use_cache:
  459. print(
  460. "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
  461. )
  462. use_cache = False
  463. # decoder layers
  464. all_hidden_states = () if output_hidden_states else None
  465. all_self_attns = () if output_attentions else None
  466. all_cross_attentions = (
  467. () if (output_attentions and encoder_hidden_states is not None) else None
  468. )
  469. next_decoder_cache = () if use_cache else None
  470. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  471. for attn_mask, mask_name in zip(
  472. [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
  473. ):
  474. if attn_mask is not None:
  475. if attn_mask.size()[0] != len(self.layers):
  476. raise ValueError(
  477. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  478. f" {attn_mask.size()[0]}."
  479. )
  480. for idx, decoder_layer in enumerate(self.layers):
  481. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  482. if output_hidden_states:
  483. all_hidden_states += (hidden_states,)
  484. if self.training:
  485. dropout_probability = paddle.rand([])
  486. if dropout_probability < self.layerdrop:
  487. continue
  488. past_key_value = (
  489. past_key_values[idx] if past_key_values is not None else None
  490. )
  491. if self.gradient_checkpointing and self.training:
  492. layer_outputs = self._gradient_checkpointing_func(
  493. decoder_layer.__call__,
  494. hidden_states,
  495. attention_mask,
  496. encoder_hidden_states,
  497. encoder_attention_mask,
  498. head_mask[idx] if head_mask is not None else None,
  499. (
  500. cross_attn_head_mask[idx]
  501. if cross_attn_head_mask is not None
  502. else None
  503. ),
  504. None,
  505. output_attentions,
  506. use_cache,
  507. )
  508. else:
  509. layer_outputs = decoder_layer(
  510. hidden_states,
  511. attention_mask=attention_mask,
  512. encoder_hidden_states=encoder_hidden_states,
  513. encoder_attention_mask=encoder_attention_mask,
  514. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  515. cross_attn_layer_head_mask=(
  516. cross_attn_head_mask[idx]
  517. if cross_attn_head_mask is not None
  518. else None
  519. ),
  520. past_key_value=past_key_value,
  521. output_attentions=output_attentions,
  522. use_cache=use_cache,
  523. )
  524. hidden_states = layer_outputs[0]
  525. if self.is_export:
  526. next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
  527. else:
  528. if use_cache:
  529. next_decoder_cache += (
  530. layer_outputs[3 if output_attentions else 1],
  531. )
  532. if output_attentions:
  533. all_self_attns += (layer_outputs[1],)
  534. if encoder_hidden_states is not None:
  535. all_cross_attentions += (layer_outputs[2],)
  536. hidden_states = self.layer_norm(hidden_states)
  537. # add hidden states from the last decoder layer
  538. if output_hidden_states:
  539. all_hidden_states += (hidden_states,)
  540. if self.is_export:
  541. next_cache = next_decoder_cache
  542. else:
  543. next_cache = next_decoder_cache if use_cache else None
  544. if not return_dict:
  545. return tuple(
  546. v
  547. for v in [
  548. hidden_states,
  549. next_cache,
  550. all_hidden_states,
  551. all_self_attns,
  552. all_cross_attentions,
  553. ]
  554. if v is not None
  555. )
  556. return BaseModelOutputWithPastAndCrossAttentions(
  557. last_hidden_state=hidden_states,
  558. past_key_values=next_cache,
  559. hidden_states=all_hidden_states,
  560. attentions=all_self_attns,
  561. cross_attentions=all_cross_attentions,
  562. )
  563. class CustomMBartForCausalLM(MBartForCausalLM):
  564. def __init__(self, config):
  565. super().__init__(config)
  566. # Modify the decoder within MBartDecoderWrapper
  567. self.model.decoder = CustomMBartDecoder(config)
  568. def forward(
  569. self,
  570. input_ids=None,
  571. attention_mask=None,
  572. encoder_hidden_states=None,
  573. encoder_attention_mask=None,
  574. head_mask=None,
  575. cross_attn_head_mask=None,
  576. past_key_values=None,
  577. inputs_embeds=None,
  578. labels=None,
  579. use_cache=None,
  580. output_attentions=None,
  581. output_hidden_states=None,
  582. return_dict=None,
  583. ):
  584. output_attentions = (
  585. output_attentions
  586. if output_attentions is not None
  587. else self.config.output_attentions
  588. )
  589. output_hidden_states = (
  590. output_hidden_states
  591. if output_hidden_states is not None
  592. else self.config.output_hidden_states
  593. )
  594. return_dict = (
  595. return_dict if return_dict is not None else self.config.use_return_dict
  596. )
  597. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  598. outputs = self.model.decoder(
  599. input_ids=input_ids,
  600. attention_mask=attention_mask,
  601. encoder_hidden_states=encoder_hidden_states,
  602. encoder_attention_mask=encoder_attention_mask,
  603. head_mask=head_mask,
  604. cross_attn_head_mask=cross_attn_head_mask,
  605. past_key_values=past_key_values,
  606. inputs_embeds=inputs_embeds,
  607. use_cache=use_cache,
  608. output_attentions=output_attentions,
  609. output_hidden_states=output_hidden_states,
  610. return_dict=return_dict,
  611. )
  612. logits = self.lm_head(outputs[0])
  613. return CausalLMOutputWithCrossAttentions(
  614. logits=logits,
  615. past_key_values=outputs.past_key_values,
  616. hidden_states=outputs.hidden_states,
  617. attentions=outputs.attentions,
  618. cross_attentions=outputs.cross_attentions,
  619. )
  620. class PPFormulaNet_Head(UniMERNetHead):
  621. """
  622. PPFormulaNet_Head
  623. Args:
  624. max_new_tokens (int): Maximum number of new tokens to generate. Default is 1536.
  625. decoder_start_token_id (int): Start token ID for the decoder. Default is 0.
  626. temperature (float): Temperature parameter for controlling randomness in sampling. Default is 0.2.
  627. do_sample (bool): Flag to determine whether to use sampling for generation. Default is False.
  628. top_p (float): Top-p (nucleus) sampling parameter for controlling diversity. Default is 0.95.
  629. in_channels (int): Number of input channels for the model. Default is 1024.
  630. decoder_layers (int): Number of layers in the decoder. Default is 8.
  631. encoder_hidden_size (int): Size of the hidden layer in the encoder. Default is 1024.
  632. decoder_ffn_dim (int): Dimension of the feed-forward network in the decoder. Default is 4096.
  633. decoder_hidden_size (int): Size of the hidden layer in the decoder. Default is 1024.
  634. is_export (bool): Flag indicating whether the model is to be exported. Default is False.
  635. length_aware (bool): Flag to determine if the model should be aware of input sequence length. Default is True.
  636. use_parallel (bool): Flag to enable or disable parallel processing. Default is False.
  637. parallel_step (int): Number of steps to use in parallel processing. Default is 3.
  638. """
    def __init__(
        self,
        max_new_tokens=1536,
        decoder_start_token_id=0,
        temperature=0.2,
        do_sample=False,
        top_p=0.95,
        in_channels=1024,
        decoder_layers=8,
        encoder_hidden_size=1024,
        decoder_ffn_dim=4096,
        decoder_hidden_size=1024,
        is_export=False,
        length_aware=True,
        use_parallel=False,
        parallel_step=3,
    ):
        """
        Build the MBart-based decoder and generation helpers of PP-FormulaNet.

        Creates ``self.decoder`` (a ``CustomMBartForCausalLM``). If the encoder and decoder
        widths differ, it also creates a linear projection ``self.enc_to_dec_proj``. It then
        sets up a logits processor that forces the EOS token.

        NOTE(review): ``in_channels`` and ``length_aware`` are not referenced in this
        constructor; confirm whether another method uses them.
        """
        super().__init__()
        # Decoder hyper-parameters: most are fixed, the rest come from the arguments.
        mbart_config_dict = {
            "activation_dropout": 0.0,
            "activation_function": "gelu",
            "add_cross_attention": True,
            "add_final_layer_norm": True,
            "attention_dropout": 0.0,
            "bos_token_id": 0,
            "classifier_dropout": 0.0,
            "d_model": decoder_hidden_size,
            "decoder_attention_heads": 16,
            "decoder_ffn_dim": decoder_ffn_dim,
            "decoder_layerdrop": 0.0,
            "decoder_layers": decoder_layers,
            "dropout": 0.1,
            "encoder_attention_heads": 16,
            "encoder_ffn_dim": 4096,
            "encoder_layerdrop": 0.0,
            "encoder_layers": 12,
            "eos_token_id": 2,
            "forced_eos_token_id": 2,
            "init_std": 0.02,
            "is_decoder": True,
            "is_encoder_decoder": False,
            "output_hidden_states": False,
            # With parallel decoding the position table gets `parallel_step` extra slots,
            # presumably so the last chunk may run past max_new_tokens.
            "max_position_embeddings": (
                max_new_tokens + parallel_step if use_parallel else max_new_tokens
            ),
            "model_type": "mbart",
            "num_hidden_layers": 12,
            "pad_token_id": 1,
            "scale_embedding": True,
            "tie_word_embeddings": False,
            "transformers_version": "4.40.0",
            "use_cache": True,
            "use_return_dict": True,
            "vocab_size": 50000,
            "_attn_implementation": "eager",
            "hidden_size": decoder_hidden_size,
            "use_parallel": use_parallel,
            "parallel_step": int(parallel_step),
            "is_export": is_export,
        }
        self.decoder_start_token_id = decoder_start_token_id
        self.temperature = temperature
        self.do_sample = do_sample
        self.top_p = top_p
        self.is_export = is_export
        self.max_seq_len = max_new_tokens
        self.config_decoder = MBartConfig(**mbart_config_dict)
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder = CustomMBartForCausalLM(self.config_decoder)
        # Project encoder features to the decoder width only when they differ.
        if self.config_decoder.hidden_size != self.encoder_hidden_size:
            self.enc_to_dec_proj = nn.Linear(
                self.encoder_hidden_size, self.config_decoder.hidden_size
            )
        # NOTE(review): max_length is hard-coded to 1537 (default max_new_tokens + 1)
        # and does not follow `max_new_tokens`; confirm this is intended.
        generation_config = {
            "max_length": 1537,
            "forced_eos_token_id": 2,
        }
        self.eos_token_id = generation_config["forced_eos_token_id"]
        self.pad_token_id = self.config_decoder.pad_token_id
        self.logits_processor = LogitsProcessorList()
        self.logits_processor.append(
            ForcedEOSTokenLogitsProcessor(
                generation_config["max_length"],
                generation_config["forced_eos_token_id"],
            )
        )
  725. def prepare_inputs_for_generation(
  726. self,
  727. input_ids,
  728. past_key_values=None,
  729. attention_mask=None,
  730. use_cache=None,
  731. encoder_outputs=None,
  732. **kwargs,
  733. ):
  734. decoder_inputs = self.prepare_inputs_for_generation_mbart(
  735. input_ids, past_key_values=past_key_values
  736. )
  737. decoder_attention_mask = (
  738. decoder_inputs["attention_mask"]
  739. if "attention_mask" in decoder_inputs
  740. else None
  741. )
  742. input_dict = {
  743. "attention_mask": attention_mask,
  744. "decoder_attention_mask": decoder_attention_mask,
  745. "decoder_input_ids": decoder_inputs["input_ids"],
  746. "past_key_values": decoder_inputs["past_key_values"],
  747. "use_cache": use_cache,
  748. }
  749. return input_dict
  750. def _extract_past_from_model_output(
  751. self, outputs: ModelOutput, standardize_cache_format: bool = False
  752. ):
  753. past_key_values = None
  754. if "past_key_values" in outputs:
  755. past_key_values = outputs.past_key_values
  756. elif "mems" in outputs:
  757. past_key_values = outputs.mems
  758. elif "past_buckets_states" in outputs:
  759. past_key_values = outputs.past_buckets_states
  760. return past_key_values
  761. def _update_model_kwargs_for_generation(
  762. self,
  763. outputs: ModelOutput,
  764. model_kwargs: Dict[str, Any],
  765. is_encoder_decoder: bool = False,
  766. standardize_cache_format: bool = False,
  767. ) -> Dict[str, Any]:
  768. # update past_key_values
  769. model_kwargs["past_key_values"] = self._extract_past_from_model_output(
  770. outputs, standardize_cache_format=standardize_cache_format
  771. )
  772. if getattr(outputs, "state", None) is not None:
  773. model_kwargs["state"] = outputs.state
  774. # update token_type_ids with last value
  775. if "token_type_ids" in model_kwargs:
  776. token_type_ids = model_kwargs["token_type_ids"]
  777. model_kwargs["token_type_ids"] = paddle.concat(
  778. [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], axis=-1
  779. )
  780. if not is_encoder_decoder:
  781. # update attention mask
  782. if "attention_mask" in model_kwargs:
  783. attention_mask = model_kwargs["attention_mask"]
  784. model_kwargs["attention_mask"] = paddle.concat(
  785. [
  786. attention_mask,
  787. attention_mask.new_ones((attention_mask.shape[0], 1)),
  788. ],
  789. axis=-1,
  790. )
  791. else:
  792. # update decoder attention mask
  793. if "decoder_attention_mask" in model_kwargs:
  794. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  795. model_kwargs["decoder_attention_mask"] = paddle.concat(
  796. [
  797. decoder_attention_mask,
  798. decoder_attention_mask.new_ones(
  799. (decoder_attention_mask.shape[0], 1)
  800. ),
  801. ],
  802. dim=-1,
  803. )
  804. if (
  805. "cache_position" in model_kwargs
  806. and model_kwargs["cache_position"] is not None
  807. ):
  808. model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
  809. return model_kwargs
  810. def stopping_criteria(self, input_ids):
  811. if self.is_export:
  812. return input_ids[:, -1] == paddle.to_tensor([self.eos_token_id])
  813. is_done = paddle.isin(input_ids[:, -1], paddle.to_tensor([self.eos_token_id]))
  814. return is_done
  815. def stopping_criteria_parallel(self, input_ids):
  816. parallel_step = self.config_decoder.parallel_step
  817. if self.is_export:
  818. is_done_list = []
  819. for i in range(parallel_step, 0, -1):
  820. cur_is_done = input_ids[:, -i] == paddle.to_tensor([self.eos_token_id])
  821. is_done_list.append(cur_is_done)
  822. is_done_list = paddle.to_tensor(is_done_list).transpose([1, 0])
  823. return is_done_list
  824. else:
  825. is_done = paddle.isin(
  826. input_ids[:, -parallel_step:],
  827. paddle.to_tensor([self.eos_token_id]).reshape([1, 1]),
  828. )
  829. return paddle.to_tensor(is_done)
  830. def generate_single_iter(
  831. self,
  832. decoder_input_ids=None,
  833. decoder_attention_mask=None,
  834. encoder_outputs=None,
  835. past_key_values=None,
  836. decoder_inputs_embeds=None,
  837. labels=None,
  838. use_cache=None,
  839. output_attentions=None,
  840. output_hidden_states=None,
  841. return_dict=None,
  842. **kwargs,
  843. ):
  844. encoder_hidden_states = encoder_outputs[0]
  845. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  846. encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
  847. kwargs_decoder = {}
  848. decoder_outputs = self.decoder(
  849. input_ids=decoder_input_ids,
  850. attention_mask=decoder_attention_mask,
  851. encoder_hidden_states=encoder_hidden_states,
  852. encoder_attention_mask=None,
  853. inputs_embeds=None,
  854. output_attentions=False,
  855. output_hidden_states=output_hidden_states,
  856. use_cache=use_cache,
  857. past_key_values=past_key_values,
  858. return_dict=return_dict,
  859. **kwargs_decoder,
  860. )
  861. return Seq2SeqLMOutput(
  862. loss=None,
  863. logits=decoder_outputs.logits,
  864. past_key_values=decoder_outputs.past_key_values,
  865. decoder_hidden_states=decoder_outputs.hidden_states,
  866. decoder_attentions=decoder_outputs.attentions,
  867. cross_attentions=decoder_outputs.cross_attentions,
  868. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  869. encoder_hidden_states=encoder_outputs.hidden_states,
  870. encoder_attentions=encoder_outputs.attentions,
  871. )
  872. def _prepare_decoder_input_ids_for_generation(
  873. self,
  874. batch_size,
  875. model_kwargs,
  876. decoder_start_token_id=None,
  877. bos_token_id=None,
  878. ):
  879. # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
  880. # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
  881. if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
  882. decoder_input_ids = model_kwargs.pop("decoder_input_ids")
  883. elif "input_ids" in model_kwargs:
  884. decoder_input_ids = model_kwargs.pop("input_ids")
  885. else:
  886. decoder_input_ids = None
  887. # 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
  888. decoder_start_token_id = self._get_decoder_start_token_id(
  889. decoder_start_token_id, bos_token_id
  890. )
  891. if isinstance(decoder_start_token_id, list):
  892. if len(decoder_start_token_id) != batch_size:
  893. raise ValueError(
  894. f"`decoder_start_token_id` expected to have length {batch_size} but got {len(decoder_start_token_id)}"
  895. )
  896. decoder_input_ids_start = paddle.to_tensor(
  897. decoder_start_token_id,
  898. dtype=paddle.int64,
  899. )
  900. decoder_input_ids_start = decoder_input_ids_start.view(-1, 1)
  901. else:
  902. use_parallel = self.config_decoder.use_parallel
  903. parallel_step = self.config_decoder.parallel_step
  904. if use_parallel:
  905. decoder_input_ids_start = (
  906. paddle.ones(
  907. (batch_size, parallel_step),
  908. dtype=paddle.int64,
  909. )
  910. * decoder_start_token_id
  911. )
  912. else:
  913. decoder_input_ids_start = (
  914. paddle.ones(
  915. (batch_size, 1),
  916. dtype=paddle.int64,
  917. )
  918. * decoder_start_token_id
  919. )
  920. # no user input -> use decoder_start_token_id as decoder_input_ids
  921. if decoder_input_ids is None:
  922. decoder_input_ids = decoder_input_ids_start
  923. # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
  924. elif (
  925. self.config.model_type == "vision-encoder-decoder"
  926. and "donut" in self.name_or_path.lower()
  927. ):
  928. pass
  929. elif self.config.model_type in ["whisper"]:
  930. pass
  931. # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
  932. # decoder_attention_mask if provided)
  933. elif (
  934. isinstance(decoder_start_token_id, int)
  935. and (decoder_input_ids[:, 0] != decoder_start_token_id).all().item()
  936. ) or (
  937. isinstance(decoder_start_token_id, paddle.Tensor)
  938. and (decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]).all().item()
  939. ):
  940. decoder_input_ids = paddle.concat(
  941. [decoder_input_ids_start, decoder_input_ids], axis=-1
  942. )
  943. if "decoder_attention_mask" in model_kwargs:
  944. decoder_attention_mask = model_kwargs["decoder_attention_mask"]
  945. decoder_attention_mask = paddle.cat(
  946. (
  947. paddle.ones_like(decoder_attention_mask)[:, :1],
  948. decoder_attention_mask,
  949. ),
  950. dim=-1,
  951. )
  952. model_kwargs["decoder_attention_mask"] = decoder_attention_mask
  953. return decoder_input_ids, model_kwargs
@paddle.no_grad()
def generate_export(
    self,
    encoder_outputs,
    model_kwargs,
):
    """Greedy (arg-max) decoding loop written for static-graph export.

    Differs from ``generate`` in several ways:

    - The step counter ``i_idx`` is a tensor, and the loop is a ``while``
      over a tensor condition.
    - The per-layer cache is pre-allocated with a zero-length sequence axis.
    - Several tensors are passed to ``paddle.jit.api.set_dynamic_shape``.
      Presumably this keeps those dims dynamic for dy2static export; confirm
      against the Paddle docs.

    After the first step only the newly generated token(s) are fed to the
    decoder, and earlier context comes from ``past_key_values``.

    Args:
        encoder_outputs: Encoder result. ``["last_hidden_state"]`` is read for
            the batch size, and the whole object is handed to
            ``generate_single_iter``.
        model_kwargs: Generation kwargs dict, passed to
            ``_prepare_decoder_input_ids_for_generation`` and
            ``prepare_inputs_for_generation_export``.

    Returns:
        Tensor of ids: the start token(s) followed by the arg-max token(s) of
        every step, concatenated along the last axis.
    """
    use_parallel = self.config_decoder.use_parallel
    parallel_step = self.config_decoder.parallel_step
    batch_size = encoder_outputs["last_hidden_state"].shape[0]
    generation_config = {
        "decoder_start_token_id": 0,
        "bos_token_id": 0,
    }
    input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
        batch_size=batch_size,
        model_kwargs=model_kwargs,
        decoder_start_token_id=generation_config["decoder_start_token_id"],
        bos_token_id=generation_config["bos_token_id"],
    )
    if not use_parallel:
        input_ids = input_ids.reshape([-1, 1])
    # First step feeds the whole start prefix to the decoder.
    decoder_input_ids = input_ids
    model_kwargs["key use_cache"] = True
    batch_size, cur_len = input_ids.shape
    if "inputs_embeds" in model_kwargs:
        cur_len = model_kwargs["inputs_embeds"].shape[1]
    # NOTE(review): cache_position is advanced each step but never consumed
    # in this function.
    cache_position = paddle.arange(cur_len)
    pad_token_id = self.pad_token_id
    eos_token_id = [self.eos_token_id]
    eos_token = self.eos_token_id
    if use_parallel:
        # One "still running" flag per row and per parallel slot.
        unfinished_sequences = paddle.ones(
            [batch_size, parallel_step], dtype=paddle.int64
        )
        # NOTE(review): `//` already floors, so `math.ceil` is a no-op when
        # max_seq_len is an int; use `/` if a ceiling was intended.
        parallel_length = math.ceil(self.max_seq_len // parallel_step)
    else:
        unfinished_sequences = paddle.ones(batch_size, dtype=paddle.int64)
        parallel_length = self.max_seq_len
    # Tensor-valued step counter so the loop below is a tensor-condition while.
    i_idx = paddle.full([], 0)
    past_key_values = []
    decoder_attention_heads = self.config_decoder.decoder_attention_heads
    decoder_attention_heads_dim = int(
        self.config_decoder.d_model / decoder_attention_heads
    )
    # Pre-allocate one empty cache entry per decoder layer:
    # [batch, heads, 0, head_dim], reused four times per layer (presumably
    # self-attention K/V and cross-attention K/V — confirm against the decoder).
    for i in range(self.config_decoder.decoder_layers):
        init_arr = paddle.zeros(
            [batch_size, decoder_attention_heads, 0, decoder_attention_heads_dim]
        )
        paddle.jit.api.set_dynamic_shape(init_arr, [-1, -1, -1, -1])
        cache = (init_arr, init_arr, init_arr, init_arr)
        past_key_values.append(cache)
    while i_idx < paddle.to_tensor(parallel_length):
        # NOTE(review): `model_inputs` is not used below; the call is kept in
        # case prepare_inputs_for_generation_export has side effects — verify.
        model_inputs = self.prepare_inputs_for_generation_export(
            past_key_values=past_key_values, **model_kwargs
        )
        decoder_attention_mask = paddle.ones(paddle.shape(input_ids))
        paddle.jit.api.set_dynamic_shape(decoder_input_ids, [-1, -1])
        paddle.jit.api.set_dynamic_shape(decoder_attention_mask, [-1, -1])
        outputs = self.generate_single_iter(
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        # Parallel mode predicts `parallel_step` tokens per step from the last
        # `parallel_step` positions; otherwise only the last position is used.
        if use_parallel:
            next_token_logits = outputs.logits[:, -parallel_step:, :]
        else:
            next_token_logits = outputs.logits[:, -1, :]
        next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
        next_tokens = paddle.argmax(next_tokens_scores, axis=-1)
        # `eos_token_id` is a one-element list here, so this branch always
        # runs; it raises only if self.pad_token_id is None.
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError(
                    "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
                )
            # Finished rows/slots (unfinished == 0) emit pad_token_id instead
            # of the predicted token.
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                1 - unfinished_sequences
            )
        # Append the new token(s); from now on only they are fed to the
        # decoder, with context supplied by the cache.
        if use_parallel:
            input_ids = paddle.concat([input_ids, next_tokens], axis=-1)
            decoder_input_ids = next_tokens
        else:
            input_ids = paddle.concat(
                [input_ids, next_tokens.unsqueeze(1)], axis=-1
            )
            decoder_input_ids = next_tokens.unsqueeze(1)
        # NOTE(review): past_length is computed but never used.
        past_length = past_key_values[0][0].shape[2]
        past_key_values = outputs.past_key_values
        cache_position = cache_position[-1:] + 1
        # Clear the running flag for rows/slots whose newest token is EOS.
        if use_parallel:
            unfinished_sequences = (
                unfinished_sequences
                & ~self.stopping_criteria_parallel(input_ids).cast(paddle.int64)
            )
        else:
            unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
                input_ids
            ).cast(paddle.int64)
        # Stop once every row contains at least one EOS token anywhere
        # (the running EOS count in the last column is >= 1).
        if (
            eos_token is not None
            and (
                paddle.cumsum((input_ids == eos_token).cast(paddle.int64), 1)[:, -1]
                >= 1
            ).all()
        ):
            break
        i_idx += 1
    return input_ids
  1067. @paddle.no_grad()
  1068. def generate(
  1069. self,
  1070. encoder_outputs,
  1071. model_kwargs,
  1072. ):
  1073. """
  1074. Generate sequences from the model without computing gradients.
  1075. This method is used to generate sequences from the model based on the given encoder outputs.
  1076. It does not compute gradients, making it suitable for inference.
  1077. Args:
  1078. encoder_outputs: The outputs from the encoder, typically including hidden states necessary for generation.
  1079. model_kwargs: Additional keyword arguments that may include parameters such as maximum length,
  1080. temperature, top-k/top-p sampling parameters, and other generation-specific settings.
  1081. Returns:
  1082. Generated sequences based on the encoder outputs and specified generation parameters.
  1083. """
  1084. use_parallel = self.config_decoder.use_parallel
  1085. parallel_step = self.config_decoder.parallel_step
  1086. batch_size = encoder_outputs["last_hidden_state"].shape[0]
  1087. generation_config = {
  1088. "decoder_start_token_id": 0,
  1089. "bos_token_id": 0,
  1090. }
  1091. input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
  1092. batch_size=batch_size,
  1093. model_kwargs=model_kwargs,
  1094. decoder_start_token_id=generation_config["decoder_start_token_id"],
  1095. bos_token_id=generation_config["bos_token_id"],
  1096. )
  1097. decoder_input_ids = input_ids
  1098. model_kwargs["key use_cache"] = True
  1099. batch_size, cur_len = input_ids.shape
  1100. if "inputs_embeds" in model_kwargs:
  1101. cur_len = model_kwargs["inputs_embeds"].shape[1]
  1102. model_kwargs["cache_position"] = paddle.arange(cur_len)
  1103. pad_token_id = self.pad_token_id
  1104. eos_token_id = [self.eos_token_id]
  1105. eos_token = self.eos_token_id
  1106. if use_parallel:
  1107. unfinished_sequences = paddle.ones(
  1108. [batch_size, parallel_step], dtype=paddle.int64
  1109. )
  1110. parallel_length = math.ceil(self.max_seq_len // parallel_step)
  1111. else:
  1112. unfinished_sequences = paddle.ones(batch_size, dtype=paddle.int64)
  1113. parallel_length = self.max_seq_len
  1114. past_key_values = []
  1115. for idx in range(parallel_length):
  1116. model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  1117. outputs = self.generate_single_iter(
  1118. **model_inputs,
  1119. encoder_outputs=encoder_outputs,
  1120. return_dict=True,
  1121. output_attentions=False,
  1122. output_hidden_states=False,
  1123. )
  1124. if use_parallel:
  1125. next_token_logits = outputs.logits[:, :, :]
  1126. else:
  1127. next_token_logits = outputs.logits[:, -1, :]
  1128. next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
  1129. next_tokens = paddle.argmax(next_tokens_scores, axis=-1)
  1130. if eos_token_id is not None:
  1131. # False
  1132. if pad_token_id is None:
  1133. raise ValueError(
  1134. "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
  1135. )
  1136. next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
  1137. 1 - unfinished_sequences
  1138. )
  1139. if use_parallel:
  1140. input_ids = paddle.concat([input_ids, next_tokens], axis=-1)
  1141. else:
  1142. input_ids = paddle.concat([input_ids, next_tokens[:, None]], axis=-1)
  1143. model_kwargs = self._update_model_kwargs_for_generation(
  1144. outputs,
  1145. model_kwargs,
  1146. is_encoder_decoder=self.config_decoder.is_encoder_decoder,
  1147. )
  1148. if use_parallel:
  1149. unfinished_sequences = (
  1150. unfinished_sequences
  1151. & ~self.stopping_criteria_parallel(input_ids).cast(paddle.int64)
  1152. )
  1153. else:
  1154. unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
  1155. input_ids
  1156. ).cast(paddle.int64)
  1157. if (
  1158. eos_token is not None
  1159. and (
  1160. paddle.cumsum((input_ids == eos_token).cast(paddle.int64), 1)[:, -1]
  1161. >= 1
  1162. ).all()
  1163. ):
  1164. break
  1165. return input_ids
  1166. def forwad_train(
  1167. self,
  1168. encoder_outputs,
  1169. decoder_input_ids,
  1170. decoder_attention_mask,
  1171. past_key_values=None,
  1172. decoder_inputs_embeds=None,
  1173. labels=None,
  1174. use_cache=None,
  1175. output_attentions=None,
  1176. output_hidden_states=None,
  1177. return_dict=None,
  1178. **kwargs,
  1179. ):
  1180. """
  1181. Forward pass for training the model.
  1182. Args:
  1183. encoder_outputs: The outputs from the encoder, typically including hidden states.
  1184. decoder_input_ids: Input IDs for the decoder.
  1185. decoder_attention_mask: Attention mask for the decoder inputs to avoid attending to padding tokens.
  1186. past_key_values: Previously computed key and value states for the decoder, used for fast generation.
  1187. decoder_inputs_embeds: Optional embeddings for decoder inputs, used instead of decoder_input_ids if provided.
  1188. labels: Labels for computing the training loss.
  1189. use_cache: Whether to use a cache of past key values for faster generation.
  1190. output_attentions: Whether to output attention weights.
  1191. output_hidden_states: Whether to output hidden states of all layers.
  1192. return_dict: Whether to return the output as a dictionary.
  1193. **kwargs: Additional keyword arguments.
  1194. Returns:
  1195. Depending on the `return_dict` flag, returns either a dictionary of model outputs or a tuple.
  1196. """
  1197. if self.config_decoder.use_parallel:
  1198. batch = decoder_input_ids.shape[0]
  1199. add_sos_token = self.config_decoder.parallel_step - 1
  1200. start_token = paddle.zeros([batch, add_sos_token]).cast(paddle.int64)
  1201. start_mask = paddle.ones([batch, add_sos_token]).cast(paddle.int64)
  1202. decoder_input_ids = paddle.concat([start_token, decoder_input_ids], axis=1)
  1203. decoder_attention_mask = paddle.concat(
  1204. [start_mask, decoder_attention_mask], axis=1
  1205. )
  1206. labels = decoder_input_ids * 1
  1207. labels = labels.masked_fill_(labels == self.pad_token_id, -100)
  1208. if self.config_decoder.use_parallel:
  1209. input_decoder_input_ids = decoder_input_ids[
  1210. :, : -self.config_decoder.parallel_step
  1211. ]
  1212. input_decoder_attention_mask = decoder_attention_mask[
  1213. :, : -self.config_decoder.parallel_step
  1214. ]
  1215. else:
  1216. input_decoder_input_ids = decoder_input_ids[:, :-1]
  1217. input_decoder_attention_mask = decoder_attention_mask[:, :-1]
  1218. encoder_hidden_states = encoder_outputs[0]
  1219. kwargs_decoder = {}
  1220. if self.config_decoder.hidden_size != self.encoder_hidden_size:
  1221. encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
  1222. decoder_outputs = self.decoder(
  1223. input_ids=input_decoder_input_ids,
  1224. attention_mask=input_decoder_attention_mask,
  1225. encoder_hidden_states=encoder_hidden_states,
  1226. encoder_attention_mask=None,
  1227. inputs_embeds=None,
  1228. output_attentions=False,
  1229. output_hidden_states=output_hidden_states,
  1230. use_cache=use_cache,
  1231. past_key_values=past_key_values,
  1232. return_dict=return_dict,
  1233. **kwargs_decoder,
  1234. )
  1235. logits = decoder_outputs.logits
  1236. return logits, labels
  1237. # forward for export
  1238. def forward(self, inputs, targets=None):
  1239. self.is_export = False if self.training else True
  1240. if not self.training:
  1241. encoder_outputs = inputs
  1242. model_kwargs = {
  1243. "output_attentions": False,
  1244. "output_hidden_states": False,
  1245. "use_cache": True,
  1246. }
  1247. if self.is_export:
  1248. word_pred = self.generate_export(encoder_outputs, model_kwargs)
  1249. else:
  1250. word_pred = self.generate(encoder_outputs, model_kwargs)
  1251. return word_pred
  1252. encoder_outputs, tgt_seq, mask = inputs
  1253. logits, masked_labels = self.forwad_train(encoder_outputs, tgt_seq, mask)
  1254. return logits, masked_labels