rec_latexocr_head.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/models/transformer.py
  17. """
  18. import math
  19. import paddle
  20. from paddle import nn, einsum
  21. import paddle.nn.functional as F
  22. from functools import partial
  23. from inspect import isfunction
  24. from collections import namedtuple
  25. from paddle.nn.initializer import (
  26. TruncatedNormal,
  27. Constant,
  28. Normal,
  29. KaimingUniform,
  30. XavierUniform,
  31. )
  32. zeros_ = Constant(value=0.0)
  33. ones_ = Constant(value=1.0)
  34. normal_ = Normal(std=0.02)
  35. DEFAULT_DIM_HEAD = 64
  36. Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
  37. LayerIntermediates = namedtuple("Intermediates", ["hiddens", "attn_intermediates"])
  38. # helpers
  39. def exists(val):
  40. return val is not None
  41. def default(val, d):
  42. if exists(val):
  43. return val
  44. return d() if isfunction(d) else d
  45. class always:
  46. def __init__(self, val):
  47. self.val = val
  48. def __call__(self, *args, **kwargs):
  49. return self.val
  50. class not_equals:
  51. def __init__(self, val):
  52. self.val = val
  53. def __call__(self, x, *args, **kwargs):
  54. return x != self.val
  55. class equals:
  56. def __init__(self, val):
  57. self.val = val
  58. def __call__(self, x, *args, **kwargs):
  59. return x == self.val
  60. def max_neg_value(tensor):
  61. return -paddle.finfo(tensor.dtype).max
  62. def pick_and_pop(keys, d):
  63. values = list(map(lambda key: d.pop(key), keys))
  64. return dict(zip(keys, values))
  65. def group_dict_by_key(cond, d):
  66. return_val = [dict(), dict()]
  67. for key in d.keys():
  68. match = bool(cond(key))
  69. ind = int(not match)
  70. return_val[ind][key] = d[key]
  71. return (*return_val,)
  72. def string_begins_with(prefix, str):
  73. return str.startswith(prefix)
  74. def group_by_key_prefix(prefix, d):
  75. return group_dict_by_key(partial(string_begins_with, prefix), d)
  76. def groupby_prefix_and_trim(prefix, d):
  77. kwargs_with_prefix, kwargs = group_dict_by_key(
  78. partial(string_begins_with, prefix), d
  79. )
  80. kwargs_without_prefix = dict(
  81. map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
  82. )
  83. return kwargs_without_prefix, kwargs
  84. # positional embeddings
  85. class DepthWiseConv1d(nn.Layer):
  86. def __init__(
  87. self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True, groups=False
  88. ):
  89. super().__init__()
  90. groups = default(groups, dim_in)
  91. self.net = nn.Sequential(
  92. nn.Conv1D(
  93. dim_in,
  94. dim_in,
  95. kernel_size=kernel_size,
  96. padding=padding,
  97. groups=dim_in,
  98. stride=stride,
  99. bias_attr=bias,
  100. ),
  101. nn.Conv1D(dim_in, dim_out, 1),
  102. )
  103. def forward(self, x):
  104. return self.net(x)
  105. class AbsolutePositionalEmbedding(nn.Layer):
  106. def __init__(self, dim, max_seq_len):
  107. super().__init__()
  108. self.emb = nn.Embedding(max_seq_len, dim)
  109. self.init_()
  110. def init_(self):
  111. normal_(self.emb.weight)
  112. def forward(self, x):
  113. n = paddle.arange(x.shape[1])
  114. return self.emb(n)[None, :, :]
  115. class FixedPositionalEmbedding(nn.Layer):
  116. def __init__(self, dim):
  117. super().__init__()
  118. inv_freq = 1.0 / (10000 ** (paddle.arange(0, dim, 2).float() / dim))
  119. self.register_buffer("inv_freq", inv_freq)
  120. def forward(self, x, seq_dim=1, offset=0):
  121. t = (
  122. paddle.arange(
  123. x.shape[seq_dim],
  124. ).type_as(self.inv_freq)
  125. + offset
  126. )
  127. sinusoid_inp = paddle.einsum("i , j -> i j", t, self.inv_freq)
  128. emb = paddle.concat((sinusoid_inp.sin(), sinusoid_inp.cos()), axis=-1)
  129. return emb[None, :, :]
  130. class Scale(nn.Layer):
  131. def __init__(self, value, fn):
  132. super().__init__()
  133. self.value = value
  134. self.fn = fn
  135. def forward(self, x, **kwargs):
  136. x, *rest = self.fn(x, **kwargs)
  137. return (x * self.value, *rest)
  138. class Rezero(nn.Layer):
  139. def __init__(self, fn):
  140. super().__init__()
  141. self.fn = fn
  142. self.g = paddle.create_parameter([1], dtype="float32")
  143. zeros_(self.g)
  144. def forward(self, x, **kwargs):
  145. x, *rest = self.fn(x, **kwargs)
  146. return (x * self.g, *rest)
  147. class ScaleNorm(nn.Layer):
  148. def __init__(self, dim, eps=1e-5):
  149. super().__init__()
  150. self.scale = dim**-0.5
  151. self.eps = eps
  152. self.g = paddle.create_parameter([1], dtype="float32")
  153. ones_(self.g)
  154. def forward(self, x):
  155. norm = paddle.norm(x, axis=-1, keepdim=True) * self.scale
  156. return x / norm.clamp(min=self.eps) * self.g
  157. class RMSNorm(nn.Layer):
  158. def __init__(self, dim, eps=1e-8):
  159. super().__init__()
  160. self.scale = dim**-0.5
  161. self.eps = eps
  162. self.g = paddle.create_parameter([dim])
  163. ones_(self.g)
  164. def forward(self, x):
  165. norm = paddle.norm(x, axis=-1, keepdim=True) * self.scale
  166. return x / norm.clamp(min=self.eps) * self.g
  167. class Residual(nn.Layer):
  168. def forward(self, x, residual):
  169. return x + residual
  170. class GEGLU(nn.Layer):
  171. def __init__(self, dim_in, dim_out):
  172. super().__init__()
  173. self.proj = nn.Linear(dim_in, dim_out * 2)
  174. def forward(self, x):
  175. x, gate = self.proj(x).chunk(2, axis=-1)
  176. return x * F.gelu(gate)
  177. class FeedForward(nn.Layer):
  178. def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
  179. super().__init__()
  180. inner_dim = int(dim * mult)
  181. dim_out = default(dim_out, dim)
  182. project_in = (
  183. nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
  184. if not glu
  185. else GEGLU(dim, inner_dim)
  186. )
  187. self.net = nn.Sequential(
  188. project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
  189. )
  190. def forward(self, x):
  191. return self.net(x)
  192. class Attention(nn.Layer):
  193. def __init__(
  194. self,
  195. dim,
  196. dim_head=DEFAULT_DIM_HEAD,
  197. heads=8,
  198. causal=False,
  199. mask=None,
  200. talking_heads=False,
  201. collab_heads=False,
  202. collab_compression=0.3,
  203. sparse_topk=None,
  204. use_entmax15=False,
  205. num_mem_kv=0,
  206. dropout=0.0,
  207. on_attn=False,
  208. gate_values=False,
  209. is_export=False,
  210. ):
  211. super().__init__()
  212. self.scale = dim_head**-0.5
  213. self.heads = heads
  214. self.causal = causal
  215. self.mask = mask
  216. self.is_export = is_export
  217. qk_dim = v_dim = dim_head * heads
  218. # collaborative heads
  219. self.collab_heads = collab_heads
  220. if self.collab_heads:
  221. qk_dim = int(collab_compression * qk_dim)
  222. self.collab_mixing = nn.Parameter(paddle.randn(heads, qk_dim))
  223. self.to_q = nn.Linear(dim, qk_dim, bias_attr=False)
  224. self.to_k = nn.Linear(dim, qk_dim, bias_attr=False)
  225. self.to_v = nn.Linear(dim, v_dim, bias_attr=False)
  226. self.dropout = nn.Dropout(dropout)
  227. # add GLU gating for aggregated values, from alphafold2
  228. self.to_v_gate = None
  229. if gate_values:
  230. self.to_v_gate = nn.Linear(dim, v_dim)
  231. zeros_(self.to_v_gate.weight)
  232. ones_(self.to_v_gate.bias)
  233. # talking heads
  234. self.talking_heads = talking_heads
  235. if talking_heads:
  236. self.pre_softmax_proj = nn.Parameter(paddle.randn(heads, heads))
  237. self.post_softmax_proj = nn.Parameter(paddle.randn(heads, heads))
  238. # explicit topk sparse attention
  239. self.sparse_topk = sparse_topk
  240. self.attn_fn = F.softmax
  241. # add memory key / values
  242. self.num_mem_kv = num_mem_kv
  243. if num_mem_kv > 0:
  244. self.mem_k = nn.Parameter(paddle.randn(heads, num_mem_kv, dim_head))
  245. self.mem_v = nn.Parameter(paddle.randn(heads, num_mem_kv, dim_head))
  246. # attention on attention
  247. self.attn_on_attn = on_attn
  248. self.to_out = (
  249. nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU())
  250. if on_attn
  251. else nn.Linear(v_dim, dim)
  252. )
  253. def forward(
  254. self,
  255. x,
  256. context=None,
  257. mask=None,
  258. context_mask=None,
  259. rel_pos=None,
  260. sinusoidal_emb=None,
  261. rotary_pos_emb=None,
  262. prev_attn=None,
  263. mem=None,
  264. seq_len=0,
  265. ):
  266. if not self.training:
  267. self.is_export = True
  268. b, n, _, h, talking_heads, collab_heads, has_context = (
  269. *x.shape,
  270. self.heads,
  271. self.talking_heads,
  272. self.collab_heads,
  273. exists(context),
  274. )
  275. kv_input = default(context, x)
  276. q_input = x
  277. k_input = kv_input
  278. v_input = kv_input
  279. if exists(mem):
  280. k_input = paddle.concat((mem, k_input), axis=-2)
  281. v_input = paddle.concat((mem, v_input), axis=-2)
  282. if exists(sinusoidal_emb):
  283. # in shortformer, the query would start at a position offset depending on the past cached memory
  284. offset = k_input.shape[-2] - q_input.shape[-2]
  285. q_input = q_input + sinusoidal_emb(q_input, offset=offset)
  286. k_input = k_input + sinusoidal_emb(k_input)
  287. q = self.to_q(q_input)
  288. k = self.to_k(k_input)
  289. v = self.to_v(v_input)
  290. def rearrange_q_k_v(x, h, is_export):
  291. if is_export:
  292. b, n, h_d = paddle.shape(x)
  293. else:
  294. b, n, h_d = x.shape
  295. d = h_d // h
  296. return x.reshape([b, n, h, d]).transpose([0, 2, 1, 3])
  297. q, k, v = map(
  298. lambda t: rearrange_q_k_v(t, h, is_export=self.is_export), (q, k, v)
  299. )
  300. input_mask = None
  301. if any(map(exists, (mask, context_mask))):
  302. q_mask = default(
  303. mask,
  304. lambda: paddle.ones(
  305. (b, n),
  306. ).cast(paddle.bool),
  307. )
  308. k_mask = q_mask if not exists(context) else context_mask
  309. k_mask = default(
  310. k_mask, lambda: paddle.ones((b, k.shape[-2])).cast(paddle.bool)
  311. )
  312. q_mask = q_mask.reshape([q_mask.shape[0], 1, q_mask.shape[1], 1])
  313. k_mask = k_mask.reshape([k_mask.shape[0], 1, 1, k_mask.shape[1]])
  314. input_mask = q_mask * k_mask
  315. if collab_heads:
  316. k = k.expand(-1, h, -1, -1)
  317. dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
  318. mask_value = max_neg_value(dots)
  319. if exists(prev_attn):
  320. dots = dots + prev_attn
  321. pre_softmax_attn = dots.clone()
  322. if talking_heads:
  323. dots = einsum(
  324. "b h i j, h k -> b k i j", dots, self.pre_softmax_proj
  325. ).contiguous()
  326. if exists(rel_pos):
  327. dots = rel_pos(dots)
  328. input_mask = input_mask.cast(paddle.bool)
  329. if exists(input_mask):
  330. dots.masked_fill_(~input_mask, mask_value)
  331. del input_mask
  332. if self.causal:
  333. i, j = dots.shape[-2:]
  334. r = paddle.arange(i)
  335. r_shape = r.shape[0]
  336. mask = r.reshape([1, 1, r_shape, 1]) < r.reshape([1, 1, 1, r_shape])
  337. if self.is_export:
  338. pad_list = [
  339. paddle.to_tensor(0, dtype="int32"),
  340. paddle.to_tensor(0, dtype="int32"),
  341. paddle.to_tensor(j - i, dtype="int32"),
  342. paddle.to_tensor(0, dtype="int32"),
  343. ]
  344. mask = F.pad(
  345. mask.cast(paddle.int32),
  346. paddle.to_tensor(pad_list).cast(paddle.int32),
  347. value=False,
  348. ).cast(paddle.bool)
  349. dots = dots.masked_fill_(mask, mask_value)
  350. else:
  351. mask = F.pad(mask.cast(paddle.int32), (0, 0, j - i, 0), value=False)
  352. dots.masked_fill_(mask, mask_value)
  353. del mask
  354. if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
  355. top, _ = dots.topk(self.sparse_topk, dim=-1)
  356. vk = top[..., -1].unsqueeze(-1).expand_as(dots)
  357. mask = dots < vk
  358. dots.masked_fill_(mask, mask_value)
  359. del mask
  360. attn = self.attn_fn(dots, axis=-1)
  361. post_softmax_attn = attn.clone()
  362. attn = self.dropout(attn)
  363. if talking_heads:
  364. attn = einsum(
  365. "b h i j, h k -> b k i j", attn, self.post_softmax_proj
  366. ).contiguous()
  367. out = einsum("b h i j, b h j d -> b h i d", attn, v)
  368. b, h, n, d = out.shape
  369. out = out.transpose([0, 2, 1, 3]).reshape([b, n, h * d])
  370. if exists(self.to_v_gate):
  371. gates = self.gate_v(x)
  372. out = out * gates.sigmoid()
  373. intermediates = Intermediates(
  374. pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
  375. )
  376. return self.to_out(out), intermediates
  377. class AttentionLayers(nn.Layer):
  378. def __init__(
  379. self,
  380. dim,
  381. depth,
  382. heads=8,
  383. causal=False,
  384. cross_attend=False,
  385. only_cross=False,
  386. use_scalenorm=False,
  387. use_rmsnorm=False,
  388. use_rezero=False,
  389. rel_pos_bias=False,
  390. rel_pos_num_buckets=32,
  391. rel_pos_max_distance=128,
  392. position_infused_attn=False,
  393. rotary_pos_emb=False,
  394. rotary_emb_dim=None,
  395. custom_layers=None,
  396. sandwich_coef=None,
  397. par_ratio=None,
  398. residual_attn=False,
  399. cross_residual_attn=False,
  400. macaron=False,
  401. pre_norm=True,
  402. gate_residual=False,
  403. is_export=False,
  404. **kwargs,
  405. ):
  406. super().__init__()
  407. ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
  408. attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
  409. dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
  410. self.dim = dim
  411. self.depth = depth
  412. self.layers = nn.LayerList([])
  413. self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
  414. self.pia_pos_emb = (
  415. FixedPositionalEmbedding(dim) if position_infused_attn else None
  416. )
  417. assert (
  418. rel_pos_num_buckets <= rel_pos_max_distance
  419. ), "number of relative position buckets must be less than the relative position max distance"
  420. self.pre_norm = pre_norm
  421. self.residual_attn = residual_attn
  422. self.cross_residual_attn = cross_residual_attn
  423. self.cross_attend = cross_attend
  424. self.rel_pos = None
  425. norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
  426. norm_class = RMSNorm if use_rmsnorm else norm_class
  427. norm_fn = partial(norm_class, dim)
  428. norm_fn = nn.Identity if use_rezero else norm_fn
  429. branch_fn = Rezero if use_rezero else None
  430. if cross_attend and not only_cross:
  431. default_block = ("a", "c", "f")
  432. elif cross_attend and only_cross:
  433. default_block = ("c", "f")
  434. else:
  435. default_block = ("a", "f")
  436. if macaron:
  437. default_block = ("f",) + default_block
  438. if exists(custom_layers):
  439. layer_types = custom_layers
  440. elif exists(par_ratio):
  441. par_depth = depth * len(default_block)
  442. assert 1 < par_ratio <= par_depth, "par ratio out of range"
  443. default_block = tuple(filter(not_equals("f"), default_block))
  444. par_attn = par_depth // par_ratio
  445. depth_cut = (
  446. par_depth * 2 // 3
  447. ) # 2 / 3 attention layer cutoff suggested by PAR paper
  448. par_width = (depth_cut + depth_cut // par_attn) // par_attn
  449. assert (
  450. len(default_block) <= par_width
  451. ), "default block is too large for par_ratio"
  452. par_block = default_block + ("f",) * (par_width - len(default_block))
  453. par_head = par_block * par_attn
  454. layer_types = par_head + ("f",) * (par_depth - len(par_head))
  455. elif exists(sandwich_coef):
  456. assert (
  457. sandwich_coef > 0 and sandwich_coef <= depth
  458. ), "sandwich coefficient should be less than the depth"
  459. layer_types = (
  460. ("a",) * sandwich_coef
  461. + default_block * (depth - sandwich_coef)
  462. + ("f",) * sandwich_coef
  463. )
  464. else:
  465. layer_types = default_block * depth
  466. self.layer_types = layer_types
  467. self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
  468. for layer_type in self.layer_types:
  469. if layer_type == "a":
  470. layer = Attention(
  471. dim, heads=heads, causal=causal, is_export=is_export, **attn_kwargs
  472. )
  473. elif layer_type == "c":
  474. layer = Attention(dim, heads=heads, is_export=is_export, **attn_kwargs)
  475. elif layer_type == "f":
  476. layer = FeedForward(dim, **ff_kwargs)
  477. layer = layer if not macaron else Scale(0.5, layer)
  478. else:
  479. raise Exception(f"invalid layer type {layer_type}")
  480. if isinstance(layer, Attention) and exists(branch_fn):
  481. layer = branch_fn(layer)
  482. residual_fn = Residual()
  483. self.layers.append(nn.LayerList([norm_fn(), layer, residual_fn]))
  484. def forward(
  485. self,
  486. x,
  487. context=None,
  488. mask=None,
  489. context_mask=None,
  490. mems=None,
  491. seq_len=0,
  492. return_hiddens=False,
  493. ):
  494. assert not (
  495. self.cross_attend ^ exists(context)
  496. ), "context must be passed in if cross_attend is set to True"
  497. hiddens = []
  498. intermediates = []
  499. prev_attn = None
  500. prev_cross_attn = None
  501. rotary_pos_emb = None
  502. mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
  503. for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
  504. zip(self.layer_types, self.layers)
  505. ):
  506. is_last = ind == (len(self.layers) - 1)
  507. if layer_type == "a":
  508. hiddens.append(x)
  509. layer_mem = mems.pop(0)
  510. residual = x
  511. if self.pre_norm:
  512. x = norm(x)
  513. if layer_type == "a":
  514. out, inter = block(
  515. x,
  516. mask=mask,
  517. sinusoidal_emb=self.pia_pos_emb,
  518. rel_pos=self.rel_pos,
  519. rotary_pos_emb=rotary_pos_emb,
  520. prev_attn=prev_attn,
  521. mem=layer_mem,
  522. )
  523. elif layer_type == "c":
  524. out, inter = block(
  525. x,
  526. context=context,
  527. mask=mask,
  528. context_mask=context_mask,
  529. prev_attn=prev_cross_attn,
  530. )
  531. elif layer_type == "f":
  532. out = block(x)
  533. x = residual_fn(out, residual)
  534. if layer_type in ("a", "c"):
  535. intermediates.append(inter)
  536. if layer_type == "a" and self.residual_attn:
  537. prev_attn = inter.pre_softmax_attn
  538. elif layer_type == "c" and self.cross_residual_attn:
  539. prev_cross_attn = inter.pre_softmax_attn
  540. if not self.pre_norm and not is_last:
  541. x = norm(x)
  542. if return_hiddens:
  543. intermediates = LayerIntermediates(
  544. hiddens=hiddens, attn_intermediates=intermediates
  545. )
  546. return x, intermediates
  547. return x
  548. class Encoder(AttentionLayers):
  549. def __init__(self, **kwargs):
  550. assert "causal" not in kwargs, "cannot set causality on encoder"
  551. super().__init__(causal=False, **kwargs)
  552. class Decoder(AttentionLayers):
  553. def __init__(self, **kwargs):
  554. assert "causal" not in kwargs, "cannot set causality on decoder"
  555. super().__init__(causal=True, **kwargs)
  556. class CrossAttender(AttentionLayers):
  557. def __init__(self, **kwargs):
  558. super().__init__(cross_attend=True, only_cross=True, **kwargs)
  559. def create_latex_parameter(shape):
  560. return paddle.create_parameter(
  561. shape=shape,
  562. dtype="float32",
  563. default_initializer=paddle.nn.initializer.Assign(paddle.randn(shape)),
  564. )
  565. class TransformerDecoder(nn.Layer):
  566. def __init__(
  567. self,
  568. *,
  569. num_tokens,
  570. max_seq_len,
  571. attn_layers,
  572. emb_dim=None,
  573. max_mem_len=0.0,
  574. emb_dropout=0.0,
  575. num_memory_tokens=None,
  576. tie_embedding=False,
  577. use_pos_emb=True,
  578. is_export=False,
  579. ):
  580. super().__init__()
  581. assert isinstance(
  582. attn_layers, AttentionLayers
  583. ), "attention layers must be one of Encoder or Decoder"
  584. dim = attn_layers.dim
  585. emb_dim = default(emb_dim, dim)
  586. self.max_seq_len = max_seq_len
  587. self.max_mem_len = max_mem_len
  588. self.token_emb = nn.Embedding(num_tokens, emb_dim)
  589. self.pos_emb = (
  590. AbsolutePositionalEmbedding(emb_dim, max_seq_len)
  591. if (use_pos_emb and not attn_layers.has_pos_emb)
  592. else always(0)
  593. )
  594. self.emb_dropout = nn.Dropout(emb_dropout)
  595. self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
  596. self.attn_layers = attn_layers
  597. self.norm = nn.LayerNorm(dim)
  598. self.is_export = is_export
  599. self.init_()
  600. self.to_logits = (
  601. nn.Linear(dim, num_tokens)
  602. if not tie_embedding
  603. else lambda t: t @ self.token_emb.weight.t()
  604. )
  605. # memory tokens (like [cls]) from Memory Transformers paper
  606. num_memory_tokens = default(num_memory_tokens, 0)
  607. self.num_memory_tokens = num_memory_tokens
  608. if num_memory_tokens > 0:
  609. self.memory_tokens = create_latex_parameter([num_memory_tokens, dim])
  610. # let funnel encoder know number of memory tokens, if specified
  611. # TODO: think of a cleaner solution
  612. if hasattr(attn_layers, "num_memory_tokens"):
  613. attn_layers.num_memory_tokens = num_memory_tokens
  614. def init_(self):
  615. normal_(self.token_emb.weight)
  616. def forward(
  617. self,
  618. x,
  619. return_embeddings=False,
  620. mask=None,
  621. return_mems=False,
  622. return_attn=False,
  623. seq_len=0,
  624. mems=None,
  625. **kwargs,
  626. ):
  627. b, n, num_mem = *x.shape, self.num_memory_tokens
  628. x = self.token_emb(x)
  629. x = x + self.pos_emb(x)
  630. x = self.emb_dropout(x)
  631. x = self.project_emb(x)
  632. x, intermediates = self.attn_layers(
  633. x, mask=mask, mems=mems, return_hiddens=True, seq_len=seq_len, **kwargs
  634. )
  635. x = self.norm(x)
  636. if paddle.device.get_device().startswith("npu"):
  637. x = x[:, num_mem:]
  638. else:
  639. mem, x = x[:, :num_mem], x[:, num_mem:]
  640. out = self.to_logits(x) if not return_embeddings else x
  641. if return_mems:
  642. hiddens = intermediates.hiddens
  643. new_mems = (
  644. list(map(lambda pair: paddle.concat(pair, axis=-2), zip(mems, hiddens)))
  645. if exists(mems)
  646. else hiddens
  647. )
  648. new_mems = list(
  649. map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)
  650. )
  651. return out, new_mems
  652. if return_attn:
  653. attn_maps = list(
  654. map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)
  655. )
  656. return out, attn_maps
  657. return out
  658. def top_p(logits, thres=0.9):
  659. sorted_logits, sorted_indices = paddle.sort(logits, descending=True)
  660. cum_probs = paddle.cumsum(F.softmax(sorted_logits, axis=-1), axis=-1)
  661. sorted_indices_to_remove = cum_probs > (1 - thres)
  662. sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
  663. sorted_indices_to_remove[:, 0] = 0
  664. sorted_logits[sorted_indices_to_remove] = float("-inf")
  665. return sorted_logits.scatter(1, sorted_indices, sorted_logits)
  666. # topk
  667. def top_k(logits, thres=0.9):
  668. k = int((1 - thres) * logits.shape[-1])
  669. val, ind = paddle.topk(logits, k)
  670. probs = paddle.full_like(logits, float("-inf"))
  671. probs = paddle.put_along_axis(probs, ind, val, 1)
  672. return probs
  673. class LaTeXOCRHead(nn.Layer):
  674. """Implementation of LaTeX OCR decoder.
  675. Args:
  676. encoded_feat: The encoded features with shape[N, 1, H//16, W//16]
  677. tgt_seq: LaTeX-OCR labels with shape [N, L] , L is the max sequence length
  678. xi: The first N-1 LaTeX-OCR sequences in tgt_seq with shape [N, L-1]
  679. mask: The first N-1 LaTeX-OCR attention mask with shape [N, L-1] , L is the max sequence length
  680. Returns:
  681. The predicted LaTeX sequences with shape [N, L-1, C], C is the number of LaTeX classes
  682. """
  683. def __init__(
  684. self,
  685. net=None,
  686. in_channels=256,
  687. out_channels=256,
  688. pad_value=0,
  689. decoder_args=None,
  690. is_export=False,
  691. ):
  692. super().__init__()
  693. decoder = Decoder(
  694. dim=256, depth=4, heads=8, is_export=is_export, **decoder_args
  695. )
  696. transformer_decoder = TransformerDecoder(
  697. num_tokens=8000,
  698. max_seq_len=512,
  699. attn_layers=decoder,
  700. is_export=is_export,
  701. )
  702. self.temperature = 0.333
  703. self.bos_token = 1
  704. self.eos_token = 2
  705. self.max_length = 512
  706. self.pad_value = pad_value
  707. self.net = transformer_decoder
  708. self.max_seq_len = self.net.max_seq_len
  709. self.is_export = is_export
  710. @paddle.no_grad()
  711. def generate(
  712. self,
  713. start_tokens,
  714. seq_len,
  715. eos_token=None,
  716. temperature=1.0,
  717. filter_logits_fn=top_k,
  718. filter_thres=0.9,
  719. **kwargs,
  720. ):
  721. was_training = self.net.training
  722. num_dims = len(start_tokens.shape)
  723. if num_dims == 1:
  724. start_tokens = start_tokens[None, :]
  725. b, t = start_tokens.shape
  726. self.net.eval()
  727. out = start_tokens
  728. mask = kwargs.pop("mask", None)
  729. if mask is None:
  730. mask = paddle.full_like(out, True, dtype=paddle.bool)
  731. for _ in range(seq_len):
  732. x = out[:, -self.max_seq_len :]
  733. mask = mask[:, -self.max_seq_len :]
  734. logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
  735. if filter_logits_fn in {top_k, top_p}:
  736. filtered_logits = filter_logits_fn(logits, thres=filter_thres)
  737. probs = F.softmax(filtered_logits / temperature, axis=-1)
  738. else:
  739. raise NotImplementedError("The filter_logits_fn is not supported ")
  740. sample = paddle.multinomial(probs, 1)
  741. out = paddle.concat((out, sample), axis=-1)
  742. pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
  743. mask = paddle.concat((mask, pad_mask), axis=1)
  744. if (
  745. eos_token is not None
  746. and (
  747. paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1
  748. ).all()
  749. ):
  750. break
  751. out = out[:, t:]
  752. if num_dims == 1:
  753. out = out.squeeze(0)
  754. return out
  755. @paddle.no_grad()
  756. def generate_export(
  757. self,
  758. start_tokens,
  759. seq_len,
  760. eos_token=None,
  761. context=None,
  762. temperature=1.0,
  763. filter_logits_fn=None,
  764. filter_thres=0.9,
  765. **kwargs,
  766. ):
  767. was_training = self.net.training
  768. num_dims = len(start_tokens.shape)
  769. if num_dims == 1:
  770. start_tokens = start_tokens[None, :]
  771. b, t = start_tokens.shape
  772. self.net.eval()
  773. out = start_tokens
  774. mask = kwargs.pop("mask", None)
  775. if mask is None:
  776. mask = paddle.full_like(out, True, dtype=paddle.bool)
  777. i_idx = paddle.full([], 0)
  778. while i_idx < paddle.to_tensor(seq_len):
  779. x = out[:, -self.max_seq_len :]
  780. paddle.jit.api.set_dynamic_shape(x, [-1, -1])
  781. mask = mask[:, -self.max_seq_len :]
  782. paddle.jit.api.set_dynamic_shape(mask, [-1, -1])
  783. logits = self.net(x, mask=mask, context=context, seq_len=i_idx, **kwargs)[
  784. :, -1, :
  785. ]
  786. if filter_logits_fn in {top_k, top_p}:
  787. filtered_logits = filter_logits_fn(logits, thres=filter_thres)
  788. probs = F.softmax(filtered_logits / temperature, axis=-1)
  789. sample = paddle.multinomial(probs, 1)
  790. out = paddle.concat((out, sample), axis=-1)
  791. pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
  792. mask = paddle.concat((mask, pad_mask), axis=1)
  793. if (
  794. eos_token is not None
  795. and (
  796. paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1
  797. ).all()
  798. ):
  799. break
  800. i_idx += 1
  801. out = out[:, t:]
  802. if num_dims == 1:
  803. out = out.squeeze(0)
  804. return out
  805. # forward for export
  806. def forward(self, inputs, targets=None):
  807. if not self.training:
  808. self.is_export = True
  809. encoded_feat = inputs
  810. batch_num = encoded_feat.shape[0]
  811. bos_tensor = paddle.full([batch_num, 1], self.bos_token, dtype=paddle.int64)
  812. if self.is_export:
  813. word_pred = self.generate_export(
  814. bos_tensor,
  815. self.max_seq_len,
  816. eos_token=self.eos_token,
  817. context=encoded_feat,
  818. temperature=self.temperature,
  819. filter_logits_fn=top_k,
  820. )
  821. else:
  822. word_pred = self.generate(
  823. bos_tensor,
  824. self.max_seq_len,
  825. eos_token=self.eos_token,
  826. context=encoded_feat,
  827. temperature=self.temperature,
  828. filter_logits_fn=top_k,
  829. )
  830. return word_pred
  831. encoded_feat, tgt_seq, mask = inputs
  832. kwargs = {"context": encoded_feat, "mask": mask.cast(paddle.bool)}
  833. x = tgt_seq
  834. xi = x[:, :-1]
  835. mask = kwargs.get("mask", None)
  836. if mask is not None and mask.shape[1] == x.shape[1]:
  837. mask = mask[:, :-1]
  838. kwargs["mask"] = mask
  839. out = self.net(xi, **kwargs)
  840. return out