| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
7037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030 |
- # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- This code is adapted from:
- https://github.com/lukas-blecher/LaTeX-OCR/blob/main/pix2tex/models/transformer.py
- """
- import math
- import paddle
- from paddle import nn, einsum
- import paddle.nn.functional as F
- from functools import partial
- from inspect import isfunction
- from collections import namedtuple
- from paddle.nn.initializer import (
- TruncatedNormal,
- Constant,
- Normal,
- KaimingUniform,
- XavierUniform,
- )
# Shared initializer objects; calling one on a parameter overwrites its values
# in place (constant 0, constant 1, or a normal draw with std 0.02).
zeros_ = Constant(0.0)
ones_ = Constant(1.0)
normal_ = Normal(mean=0.0, std=0.02)
# Default width of a single attention head when `dim_head` is not supplied.
DEFAULT_DIM_HEAD = 64

# Attention scores captured by one Attention call, before and after softmax.
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])

# What AttentionLayers.forward collects across the stack when
# `return_hiddens=True`. The typename matches the variable name so that repr()
# and pickling resolve to this class rather than to `Intermediates`.
LayerIntermediates = namedtuple(
    "LayerIntermediates", ["hiddens", "attn_intermediates"]
)
- # helpers
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return not (val is None)
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    A fallback that is a plain Python function (as judged by
    ``inspect.isfunction``) is called with no arguments and its result is
    returned; any other fallback object is returned unchanged.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
class always:
    """Callable that ignores whatever it is called with and returns a fixed value."""

    def __init__(self, val):
        # The value handed back on every call.
        self.val = val

    def __call__(self, *args, **kwargs):
        constant = self.val
        return constant
class not_equals:
    """Predicate object: calling it with ``x`` evaluates ``x != val``."""

    def __init__(self, val):
        # Reference value every argument is compared against.
        self.val = val

    def __call__(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        differs = x != self.val
        return differs
class equals:
    """Predicate object: calling it with ``x`` evaluates ``x == val``."""

    def __init__(self, val):
        # Reference value every argument is compared against.
        self.val = val

    def __call__(self, x, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        same = x == self.val
        return same
def max_neg_value(tensor):
    """Return the negated ``paddle.finfo`` maximum for ``tensor``'s dtype.

    Used as the fill value for masked attention scores.
    """
    limits = paddle.finfo(tensor.dtype)
    return -limits.max
def pick_and_pop(keys, d):
    """Remove every key in ``keys`` from ``d`` and return them as a new dict.

    ``d`` is mutated in place; a key that is absent from ``d`` raises
    ``KeyError``.
    """
    popped = [d.pop(key) for key in keys]
    return dict(zip(keys, popped))
def group_dict_by_key(cond, d):
    """Split ``d`` into two dicts according to a predicate on its keys.

    Returns a 2-tuple ``(matching, rest)``: entries whose key makes ``cond``
    truthy go to the first dict, all others to the second. ``d`` itself is
    left untouched.
    """
    matching, rest = {}, {}
    for key in d:
        bucket = matching if cond(key) else rest
        bucket[key] = d[key]
    return matching, rest
def string_begins_with(prefix, str):
    """Return whether ``str`` starts with ``prefix`` (thin ``startswith`` wrapper).

    The argument order puts ``prefix`` first so the function can be partially
    applied with a fixed prefix.
    """
    has_prefix = str.startswith(prefix)
    return has_prefix
def group_by_key_prefix(prefix, d):
    """Split ``d`` into (entries whose key starts with ``prefix``, the rest)."""

    def _has_prefix(key):
        return key.startswith(prefix)

    return group_dict_by_key(_has_prefix, d)
def groupby_prefix_and_trim(prefix, d):
    """Split ``d`` by key prefix and strip that prefix from the matching keys.

    Returns ``(trimmed, rest)``: ``trimmed`` holds the entries whose key
    started with ``prefix``, re-keyed without it; ``rest`` holds every other
    entry unchanged.
    """
    prefixed, rest = group_dict_by_key(lambda key: key.startswith(prefix), d)
    cut = len(prefix)
    trimmed = {key[cut:]: value for key, value in prefixed.items()}
    return trimmed, rest
- # positional embeddings
class DepthWiseConv1d(nn.Layer):
    """Depthwise 1-D convolution followed by a 1x1 (pointwise) convolution.

    Args:
        dim_in: Number of input channels; also the group count of the first
            convolution, so each channel is filtered independently.
        dim_out: Number of output channels produced by the pointwise conv.
        kernel_size: Kernel size of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        bias: Forwarded as ``bias_attr`` to the depthwise convolution.
        groups: Accepted for signature compatibility only. The depthwise
            convolution always uses ``dim_in`` groups, so this value has no
            effect (it never had one: the previous code computed a local from
            it and then ignored it).
    """

    def __init__(
        self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True, groups=False
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1D(
                dim_in,
                dim_in,
                kernel_size=kernel_size,
                padding=padding,
                groups=dim_in,
                stride=stride,
                bias_attr=bias,
            ),
            nn.Conv1D(dim_in, dim_out, 1),
        )

    def forward(self, x):
        """Apply the depthwise and then the pointwise convolution to ``x``."""
        return self.net(x)
class AbsolutePositionalEmbedding(nn.Layer):
    """Learned positional embedding table indexed by absolute position."""

    def __init__(self, dim, max_seq_len):
        super().__init__()
        # One learned `dim`-wide vector per position, up to `max_seq_len`.
        self.emb = nn.Embedding(max_seq_len, dim)
        self.init_()

    def init_(self):
        """Re-initialise the table with the module-level normal initializer."""
        normal_(self.emb.weight)

    def forward(self, x):
        """Return embeddings for positions ``0 .. x.shape[1] - 1``.

        Only ``x.shape[1]`` is read from ``x``. The result gets a leading
        axis of size 1 so it can broadcast when added to the caller's tensor.
        """
        positions = paddle.arange(x.shape[1])
        table = self.emb(positions)
        return table[None, :, :]
class FixedPositionalEmbedding(nn.Layer):
    """Non-learned sinusoidal positional embedding.

    Args:
        dim: Embedding width; ``dim // 2`` inverse frequencies are stored
            (for even ``dim``) and the output concatenates their sines and
            cosines.
    """

    def __init__(self, dim):
        super().__init__()
        # Build the exponents directly as float32 with documented Paddle
        # arguments instead of the torch-style `.float()` tensor method.
        exponents = paddle.arange(0, dim, 2, dtype="float32") / dim
        inv_freq = 1.0 / (10000**exponents)
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_dim=1, offset=0):
        """Return sinusoidal embeddings for the positions of ``x``.

        Args:
            x: Tensor whose size along ``seq_dim`` gives the number of positions.
            seq_dim: Axis of ``x`` that holds the sequence.
            offset: Value added to every position index.

        Returns:
            Tensor of shape ``[1, positions, 2 * len(inv_freq)]``.
        """
        # `astype` replaces the torch-style `.type_as(...)` call.
        t = paddle.arange(x.shape[seq_dim]).astype(self.inv_freq.dtype) + offset
        # Outer product: one row per position, one column per frequency.
        sinusoid_inp = paddle.einsum("i , j -> i j", t, self.inv_freq)
        emb = paddle.concat((sinusoid_inp.sin(), sinusoid_inp.cos()), axis=-1)
        return emb[None, :, :]
class Scale(nn.Layer):
    """Wrap a layer and multiply its primary output by a constant.

    Args:
        value: Scalar multiplier.
        fn: Wrapped callable. It may return either a single tensor or a tuple
            whose first element is the tensor to scale.
    """

    def __init__(self, value, fn):
        super().__init__()
        self.value = value
        self.fn = fn

    def forward(self, x, **kwargs):
        """Run the wrapped layer and scale its (first) output.

        Returns a tuple ``(scaled, *rest)`` when the wrapped layer returned a
        tuple, otherwise just the scaled tensor.
        """
        out = self.fn(x, **kwargs)
        # Star-unpacking a bare tensor would iterate over its first axis, so
        # single-tensor outputs (e.g. from FeedForward) are scaled directly.
        if not isinstance(out, tuple):
            return out * self.value
        head, *rest = out
        return (head * self.value, *rest)
class Rezero(nn.Layer):
    """Wrap a layer and multiply its first output by a learned scalar gate.

    The gate ``g`` is a single float32 parameter that starts at zero.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn
        self.g = paddle.create_parameter([1], dtype="float32")
        zeros_(self.g)

    def forward(self, x, **kwargs):
        """Call the wrapped layer and return ``(first_output * g, *others)``."""
        outputs = self.fn(x, **kwargs)
        head, *tail = outputs
        return (head * self.g, *tail)
class ScaleNorm(nn.Layer):
    """Normalise by the scaled L2 norm of the last axis, with one learned gain.

    Args:
        dim: Feature width; the norm is multiplied by ``dim ** -0.5``.
        eps: Lower bound applied to the scaled norm before dividing.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        # Single scalar gain, initialised to one.
        self.g = paddle.create_parameter([1], dtype="float32")
        ones_(self.g)

    def forward(self, x):
        """Return ``x / max(scaled_norm, eps) * g``."""
        norm = paddle.norm(x, axis=-1, keepdim=True) * self.scale
        # `clip` is the documented Paddle name for a lower/upper bound.
        return x / norm.clip(min=self.eps) * self.g
class RMSNorm(nn.Layer):
    """Normalise by the scaled L2 norm of the last axis, with a per-feature gain.

    Args:
        dim: Feature width; the norm is multiplied by ``dim ** -0.5`` and the
            gain vector has ``dim`` entries.
        eps: Lower bound applied to the scaled norm before dividing.
    """

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        # `paddle.create_parameter` requires an explicit dtype; float32
        # matches the other norm layers in this file.
        self.g = paddle.create_parameter([dim], dtype="float32")
        ones_(self.g)

    def forward(self, x):
        """Return ``x / max(scaled_norm, eps) * g``."""
        norm = paddle.norm(x, axis=-1, keepdim=True) * self.scale
        # `clip` is the documented Paddle name for a lower/upper bound.
        return x / norm.clip(min=self.eps) * self.g
class Residual(nn.Layer):
    """Parameter-free residual connection: adds the skip input to the branch output."""

    def forward(self, x, residual):
        """Return the sum of the branch output ``x`` and ``residual``."""
        summed = x + residual
        return summed
class GEGLU(nn.Layer):
    """GELU-gated linear unit.

    A single linear layer produces ``2 * dim_out`` features, which are split
    into a value half and a gate half along the last axis.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Return ``value * gelu(gate)``."""
        projected = self.proj(x)
        value, gate = projected.chunk(2, axis=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Layer):
    """Position-wise feed-forward block: project in, dropout, project out.

    Args:
        dim: Input feature width.
        dim_out: Output feature width; falls back to ``dim`` when ``None``.
        mult: Multiplier giving the hidden width ``int(dim * mult)``.
        glu: Use a GEGLU input projection instead of Linear + GELU.
        dropout: Dropout probability applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, x):
        """Run ``x`` through the feed-forward stack."""
        return self.net(x)
class Attention(nn.Layer):
    """Multi-head attention used for both self- and cross-attention.

    Args:
        dim: Width of the input and output features.
        dim_head: Width of each head.
        heads: Number of heads.
        causal: Mask out positions where the key index exceeds the query index.
        mask: Stored on the instance; not read by ``forward``.
        talking_heads: Mix heads with learned ``heads x heads`` matrices
            before and after softmax.
        collab_heads: Use a compressed shared query/key width.
            NOTE(review): ``forward`` only expands ``k`` for this mode and
            never applies ``collab_mixing``; the path looks incomplete.
        collab_compression: Fraction of ``dim_head * heads`` kept for q/k when
            ``collab_heads`` is set.
        sparse_topk: If set, keep only the top-k scores per query row.
        use_entmax15: Accepted but unused; softmax is always used.
        num_mem_kv: Number of learned memory key/value slots to allocate.
            NOTE(review): the slots are created but ``forward`` never reads them.
        dropout: Dropout probability on the attention weights.
        on_attn: Use a Linear + GLU output projection instead of a plain Linear.
        gate_values: Gate the aggregated values with a sigmoid of a linear
            projection of the input.
        is_export: Use the export-friendly shape/padding code paths.
    """

    def __init__(
        self,
        dim,
        dim_head=DEFAULT_DIM_HEAD,
        heads=8,
        causal=False,
        mask=None,
        talking_heads=False,
        collab_heads=False,
        collab_compression=0.3,
        sparse_topk=None,
        use_entmax15=False,
        num_mem_kv=0,
        dropout=0.0,
        on_attn=False,
        gate_values=False,
        is_export=False,
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.causal = causal
        self.mask = mask
        self.is_export = is_export

        qk_dim = v_dim = dim_head * heads

        # collaborative heads
        self.collab_heads = collab_heads
        if self.collab_heads:
            qk_dim = int(collab_compression * qk_dim)
            self.collab_mixing = self._randn_parameter([heads, qk_dim])

        self.to_q = nn.Linear(dim, qk_dim, bias_attr=False)
        self.to_k = nn.Linear(dim, qk_dim, bias_attr=False)
        self.to_v = nn.Linear(dim, v_dim, bias_attr=False)

        self.dropout = nn.Dropout(dropout)

        # add GLU gating for aggregated values, from alphafold2
        self.to_v_gate = None
        if gate_values:
            self.to_v_gate = nn.Linear(dim, v_dim)
            zeros_(self.to_v_gate.weight)
            ones_(self.to_v_gate.bias)

        # talking heads
        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_proj = self._randn_parameter([heads, heads])
            self.post_softmax_proj = self._randn_parameter([heads, heads])

        # explicit topk sparse attention
        self.sparse_topk = sparse_topk

        self.attn_fn = F.softmax

        # add memory key / values
        self.num_mem_kv = num_mem_kv
        if num_mem_kv > 0:
            self.mem_k = self._randn_parameter([heads, num_mem_kv, dim_head])
            self.mem_v = self._randn_parameter([heads, num_mem_kv, dim_head])

        # attention on attention
        self.attn_on_attn = on_attn
        self.to_out = (
            nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU())
            if on_attn
            else nn.Linear(v_dim, dim)
        )

    def _randn_parameter(self, shape):
        """Create a trainable float32 parameter drawn from a standard normal."""
        # Paddle has no `nn.Parameter`; Layer.create_parameter is the
        # documented way to allocate one.
        return self.create_parameter(
            shape=shape,
            dtype="float32",
            default_initializer=Normal(mean=0.0, std=1.0),
        )

    def forward(
        self,
        x,
        context=None,
        mask=None,
        context_mask=None,
        rel_pos=None,
        sinusoidal_emb=None,
        rotary_pos_emb=None,
        prev_attn=None,
        mem=None,
        seq_len=0,
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: Query input; unpacked as ``(batch, length, features)``.
            context: Optional key/value input for cross-attention.
            mask: Optional mask over query positions.
            context_mask: Optional mask over context positions.
            rel_pos: Optional callable applied to the raw scores.
            sinusoidal_emb: Optional callable whose output is added to the
                query and key inputs before projection.
            rotary_pos_emb: Accepted but unused.
            prev_attn: Optional scores added to this layer's raw scores.
            mem: Optional tensor prepended to the key/value input.
            seq_len: Accepted but unused.

        Returns:
            ``(output, Intermediates)`` where ``Intermediates`` holds the
            pre- and post-softmax attention tensors.
        """
        # Inference always takes the export-friendly code paths; note that
        # this flag stays set once flipped.
        if not self.training:
            self.is_export = True

        b, n, _ = x.shape
        h = self.heads
        talking_heads = self.talking_heads
        collab_heads = self.collab_heads

        kv_input = default(context, x)

        q_input = x
        k_input = kv_input
        v_input = kv_input

        if exists(mem):
            k_input = paddle.concat((mem, k_input), axis=-2)
            v_input = paddle.concat((mem, v_input), axis=-2)

        if exists(sinusoidal_emb):
            # in shortformer, the query would start at a position offset depending on the past cached memory
            offset = k_input.shape[-2] - q_input.shape[-2]
            q_input = q_input + sinusoidal_emb(q_input, offset=offset)
            k_input = k_input + sinusoidal_emb(k_input)

        q = self.to_q(q_input)
        k = self.to_k(k_input)
        v = self.to_v(v_input)

        def rearrange_q_k_v(x, h, is_export):
            # [b, n, h * d] -> [b, h, n, d]; export mode reads the shape as a
            # tensor so dynamic dimensions survive tracing.
            if is_export:
                b, n, h_d = paddle.shape(x)
            else:
                b, n, h_d = x.shape
            d = h_d // h
            return x.reshape([b, n, h, d]).transpose([0, 2, 1, 3])

        q, k, v = map(
            lambda t: rearrange_q_k_v(t, h, is_export=self.is_export), (q, k, v)
        )

        input_mask = None
        if any(map(exists, (mask, context_mask))):
            q_mask = default(
                mask,
                lambda: paddle.ones(
                    (b, n),
                ).cast(paddle.bool),
            )
            k_mask = q_mask if not exists(context) else context_mask
            k_mask = default(
                k_mask, lambda: paddle.ones((b, k.shape[-2])).cast(paddle.bool)
            )
            q_mask = q_mask.reshape([q_mask.shape[0], 1, q_mask.shape[1], 1])
            k_mask = k_mask.reshape([k_mask.shape[0], 1, 1, k_mask.shape[1]])
            input_mask = q_mask * k_mask

        if collab_heads:
            # Paddle's expand takes the target shape as a single list.
            k = k.expand([-1, h, -1, -1])

        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        mask_value = max_neg_value(dots)

        if exists(prev_attn):
            dots = dots + prev_attn

        pre_softmax_attn = dots.clone()

        if talking_heads:
            dots = einsum(
                "b h i j, h k -> b k i j", dots, self.pre_softmax_proj
            ).contiguous()

        if exists(rel_pos):
            dots = rel_pos(dots)

        # Only touch the mask when one was actually built; casting first
        # would fail on None.
        if exists(input_mask):
            input_mask = input_mask.cast(paddle.bool)
            dots.masked_fill_(~input_mask, mask_value)
            del input_mask

        if self.causal:
            i, j = dots.shape[-2:]
            r = paddle.arange(i)
            r_shape = r.shape[0]
            # True where the key index is greater than the query index.
            mask = r.reshape([1, 1, r_shape, 1]) < r.reshape([1, 1, 1, r_shape])

            if self.is_export:
                pad_list = [
                    paddle.to_tensor(0, dtype="int32"),
                    paddle.to_tensor(0, dtype="int32"),
                    paddle.to_tensor(j - i, dtype="int32"),
                    paddle.to_tensor(0, dtype="int32"),
                ]
                mask = F.pad(
                    mask.cast(paddle.int32),
                    paddle.to_tensor(pad_list).cast(paddle.int32),
                    value=False,
                ).cast(paddle.bool)
                dots = dots.masked_fill_(mask, mask_value)
            else:
                # Cast back to bool so both branches hand masked_fill_ the
                # same mask dtype.
                mask = F.pad(
                    mask.cast(paddle.int32), (0, 0, j - i, 0), value=False
                ).cast(paddle.bool)
                dots.masked_fill_(mask, mask_value)
            del mask

        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
            # Paddle's topk names the dimension argument `axis`.
            top, _ = dots.topk(self.sparse_topk, axis=-1)
            vk = top[..., -1].unsqueeze(-1).expand_as(dots)
            mask = dots < vk
            dots.masked_fill_(mask, mask_value)
            del mask

        attn = self.attn_fn(dots, axis=-1)
        post_softmax_attn = attn.clone()

        attn = self.dropout(attn)

        if talking_heads:
            attn = einsum(
                "b h i j, h k -> b k i j", attn, self.post_softmax_proj
            ).contiguous()

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # [b, h, n, d] -> [b, n, h * d]
        b, h, n, d = out.shape
        out = out.transpose([0, 2, 1, 3]).reshape([b, n, h * d])

        if exists(self.to_v_gate):
            # The gate layer is stored as `to_v_gate`.
            gates = self.to_v_gate(x)
            out = out * gates.sigmoid()

        intermediates = Intermediates(
            pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn
        )

        return self.to_out(out), intermediates
class AttentionLayers(nn.Layer):
    """Stack of attention ("a"), cross-attention ("c") and feed-forward ("f") blocks.

    Each block is stored as ``[norm, layer, residual]``. The order of block
    types is taken from ``custom_layers`` if given, else derived from
    ``par_ratio``, else from ``sandwich_coef``, else the default block
    repeated ``depth`` times.

    Args:
        dim: Feature width of every block.
        depth: Number of repetitions of the default block.
        heads: Number of attention heads.
        causal: Passed to self-attention blocks only.
        cross_attend: Include cross-attention blocks; ``forward`` then
            requires ``context``.
        only_cross: With ``cross_attend``, drop the self-attention blocks.
        use_scalenorm / use_rmsnorm: Pick the norm class (RMSNorm wins if both
            are set); default is ``nn.LayerNorm``.
        use_rezero: Replace norms with identity and wrap attention in Rezero.
        rel_pos_bias, rotary_pos_emb: Only affect ``has_pos_emb``.
            NOTE(review): no relative-position or rotary module is built here
            (``self.rel_pos`` stays ``None``), so enabling either flag appears
            to disable positional information entirely; confirm with callers.
        rel_pos_num_buckets, rel_pos_max_distance: Only validated against
            each other.
        position_infused_attn: Add a fixed sinusoidal embedding to q/k inputs.
        rotary_emb_dim: Accepted but unused.
        custom_layers: Explicit sequence of layer-type letters.
        sandwich_coef: Sandwich reordering coefficient.
        par_ratio: PAR reordering ratio.
        residual_attn / cross_residual_attn: Feed each (cross-)attention
            block's pre-softmax scores into the next one.
        macaron: Prepend a feed-forward block and scale feed-forward outputs
            by 0.5.
        pre_norm: Normalise before each block rather than after.
        gate_residual: Accepted but unused.
        is_export: Forwarded to every Attention block.
        **kwargs: ``ff_``-prefixed keys go to FeedForward, ``attn_``-prefixed
            keys to Attention (prefix stripped).

    Raises:
        ValueError: If a layer type other than "a", "c" or "f" is requested.
    """

    def __init__(
        self,
        dim,
        depth,
        heads=8,
        causal=False,
        cross_attend=False,
        only_cross=False,
        use_scalenorm=False,
        use_rmsnorm=False,
        use_rezero=False,
        rel_pos_bias=False,
        rel_pos_num_buckets=32,
        rel_pos_max_distance=128,
        position_infused_attn=False,
        rotary_pos_emb=False,
        rotary_emb_dim=None,
        custom_layers=None,
        sandwich_coef=None,
        par_ratio=None,
        residual_attn=False,
        cross_residual_attn=False,
        macaron=False,
        pre_norm=True,
        gate_residual=False,
        is_export=False,
        **kwargs,
    ):
        super().__init__()
        ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)

        self.dim = dim
        self.depth = depth
        self.layers = nn.LayerList([])

        self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
        self.pia_pos_emb = (
            FixedPositionalEmbedding(dim) if position_infused_attn else None
        )

        assert (
            rel_pos_num_buckets <= rel_pos_max_distance
        ), "number of relative position buckets must be less than the relative position max distance"

        self.pre_norm = pre_norm

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn
        self.cross_attend = cross_attend
        self.rel_pos = None

        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
        norm_class = RMSNorm if use_rmsnorm else norm_class
        norm_fn = partial(norm_class, dim)

        norm_fn = nn.Identity if use_rezero else norm_fn
        branch_fn = Rezero if use_rezero else None

        if cross_attend and not only_cross:
            default_block = ("a", "c", "f")
        elif cross_attend and only_cross:
            default_block = ("c", "f")
        else:
            default_block = ("a", "f")

        if macaron:
            default_block = ("f",) + default_block

        if exists(custom_layers):
            layer_types = custom_layers
        elif exists(par_ratio):
            par_depth = depth * len(default_block)
            assert 1 < par_ratio <= par_depth, "par ratio out of range"
            default_block = tuple(filter(not_equals("f"), default_block))
            par_attn = par_depth // par_ratio
            depth_cut = (
                par_depth * 2 // 3
            )  # 2 / 3 attention layer cutoff suggested by PAR paper
            par_width = (depth_cut + depth_cut // par_attn) // par_attn
            assert (
                len(default_block) <= par_width
            ), "default block is too large for par_ratio"
            par_block = default_block + ("f",) * (par_width - len(default_block))
            par_head = par_block * par_attn
            layer_types = par_head + ("f",) * (par_depth - len(par_head))
        elif exists(sandwich_coef):
            assert (
                sandwich_coef > 0 and sandwich_coef <= depth
            ), "sandwich coefficient should be less than the depth"
            layer_types = (
                ("a",) * sandwich_coef
                + default_block * (depth - sandwich_coef)
                + ("f",) * sandwich_coef
            )
        else:
            layer_types = default_block * depth

        self.layer_types = layer_types
        self.num_attn_layers = len(list(filter(equals("a"), layer_types)))

        for layer_type in self.layer_types:
            if layer_type == "a":
                layer = Attention(
                    dim, heads=heads, causal=causal, is_export=is_export, **attn_kwargs
                )
            elif layer_type == "c":
                # Cross-attention never receives the causal flag.
                layer = Attention(dim, heads=heads, is_export=is_export, **attn_kwargs)
            elif layer_type == "f":
                layer = FeedForward(dim, **ff_kwargs)
                layer = layer if not macaron else Scale(0.5, layer)
            else:
                # ValueError is a subclass of Exception, so existing
                # `except Exception` handlers still catch it.
                raise ValueError(f"invalid layer type {layer_type}")

            if isinstance(layer, Attention) and exists(branch_fn):
                layer = branch_fn(layer)

            residual_fn = Residual()

            self.layers.append(nn.LayerList([norm_fn(), layer, residual_fn]))

    def forward(
        self,
        x,
        context=None,
        mask=None,
        context_mask=None,
        mems=None,
        seq_len=0,
        return_hiddens=False,
    ):
        """Run ``x`` through every block in order.

        Args:
            x: Input features.
            context: Cross-attention input; must be given exactly when the
                stack was built with ``cross_attend=True``.
            mask: Mask passed to every attention block.
            context_mask: Mask over ``context`` for cross-attention blocks.
            mems: Optional list with one memory entry per self-attention
                block; the caller's list is copied, not consumed.
            seq_len: Accepted but unused.
            return_hiddens: Also return a ``LayerIntermediates`` record.

        Returns:
            ``x`` alone, or ``(x, LayerIntermediates)`` when
            ``return_hiddens`` is set.
        """
        assert not (
            self.cross_attend ^ exists(context)
        ), "context must be passed in if cross_attend is set to True"

        hiddens = []
        intermediates = []
        prev_attn = None
        prev_cross_attn = None
        # Never populated in this port; passed through to Attention as None.
        rotary_pos_emb = None

        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
            zip(self.layer_types, self.layers)
        ):
            is_last = ind == (len(self.layers) - 1)

            if layer_type == "a":
                # Record the input of each self-attention block and take the
                # next memory entry for it.
                hiddens.append(x)
                layer_mem = mems.pop(0)

            residual = x

            if self.pre_norm:
                x = norm(x)

            if layer_type == "a":
                out, inter = block(
                    x,
                    mask=mask,
                    sinusoidal_emb=self.pia_pos_emb,
                    rel_pos=self.rel_pos,
                    rotary_pos_emb=rotary_pos_emb,
                    prev_attn=prev_attn,
                    mem=layer_mem,
                )
            elif layer_type == "c":
                out, inter = block(
                    x,
                    context=context,
                    mask=mask,
                    context_mask=context_mask,
                    prev_attn=prev_cross_attn,
                )
            elif layer_type == "f":
                out = block(x)

            x = residual_fn(out, residual)

            if layer_type in ("a", "c"):
                intermediates.append(inter)

            if layer_type == "a" and self.residual_attn:
                prev_attn = inter.pre_softmax_attn
            elif layer_type == "c" and self.cross_residual_attn:
                prev_cross_attn = inter.pre_softmax_attn

            # Post-norm mode normalises after every block except the last.
            if not self.pre_norm and not is_last:
                x = norm(x)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens=hiddens, attn_intermediates=intermediates
            )

            return x, intermediates

        return x
class Encoder(AttentionLayers):
    """AttentionLayers stack with causal masking forced off."""

    def __init__(self, **kwargs):
        caller_set_causal = "causal" in kwargs
        assert not caller_set_causal, "cannot set causality on encoder"
        super().__init__(causal=False, **kwargs)
class Decoder(AttentionLayers):
    """AttentionLayers stack with causal masking forced on."""

    def __init__(self, **kwargs):
        caller_set_causal = "causal" in kwargs
        assert not caller_set_causal, "cannot set causality on decoder"
        super().__init__(causal=True, **kwargs)
class CrossAttender(AttentionLayers):
    """AttentionLayers stack made of cross-attention and feed-forward blocks only."""

    def __init__(self, **kwargs):
        forced = {"cross_attend": True, "only_cross": True}
        super().__init__(**forced, **kwargs)
def create_latex_parameter(shape):
    """Create a float32 parameter of ``shape`` initialised from ``paddle.randn``."""
    initial_values = paddle.randn(shape)
    initializer = paddle.nn.initializer.Assign(initial_values)
    return paddle.create_parameter(
        shape=shape,
        dtype="float32",
        default_initializer=initializer,
    )
class TransformerDecoder(nn.Layer):
    """Token embedding + attention stack + output projection.

    Args:
        num_tokens: Vocabulary size.
        max_seq_len: Maximum sequence length (size of the positional table).
        attn_layers: An ``AttentionLayers`` instance (e.g. Encoder/Decoder).
        emb_dim: Token embedding width; defaults to ``attn_layers.dim``.
        max_mem_len: Number of trailing positions kept per memory when
            ``return_mems`` is requested. It is truncated to ``int`` when
            slicing; note that 0 keeps every position because ``t[-0:]`` is
            the whole axis.
        emb_dropout: Dropout probability applied to the embeddings.
        num_memory_tokens: Number of learned memory tokens to allocate.
            NOTE(review): ``forward`` never prepends these tokens but still
            strips ``num_memory_tokens`` leading positions from the output;
            confirm before using a non-zero value.
        tie_embedding: Reuse the transposed token embedding as output projection.
        use_pos_emb: Add a learned absolute positional embedding unless the
            attention stack reports ``has_pos_emb``.
        is_export: Stored on the instance.
    """

    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers,
        emb_dim=None,
        max_mem_len=0.0,
        emb_dropout=0.0,
        num_memory_tokens=None,
        tie_embedding=False,
        use_pos_emb=True,
        is_export=False,
    ):
        super().__init__()
        assert isinstance(
            attn_layers, AttentionLayers
        ), "attention layers must be one of Encoder or Decoder"

        dim = attn_layers.dim
        emb_dim = default(emb_dim, dim)

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len

        self.token_emb = nn.Embedding(num_tokens, emb_dim)
        self.pos_emb = (
            AbsolutePositionalEmbedding(emb_dim, max_seq_len)
            if (use_pos_emb and not attn_layers.has_pos_emb)
            else always(0)
        )
        self.emb_dropout = nn.Dropout(emb_dropout)

        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)
        self.is_export = is_export
        self.init_()

        self.to_logits = (
            nn.Linear(dim, num_tokens)
            if not tie_embedding
            else lambda t: t @ self.token_emb.weight.t()
        )

        # memory tokens (like [cls]) from Memory Transformers paper
        num_memory_tokens = default(num_memory_tokens, 0)
        self.num_memory_tokens = num_memory_tokens
        if num_memory_tokens > 0:
            self.memory_tokens = create_latex_parameter([num_memory_tokens, dim])

            # let funnel encoder know number of memory tokens, if specified
            # TODO: think of a cleaner solution
            if hasattr(attn_layers, "num_memory_tokens"):
                attn_layers.num_memory_tokens = num_memory_tokens

    def init_(self):
        """Re-initialise the token embedding with the module-level normal initializer."""
        normal_(self.token_emb.weight)

    def forward(
        self,
        x,
        return_embeddings=False,
        mask=None,
        return_mems=False,
        return_attn=False,
        seq_len=0,
        mems=None,
        **kwargs,
    ):
        """Embed token ids, run the attention stack and project to logits.

        Args:
            x: Token ids.
            return_embeddings: Return the normalised hidden states instead of logits.
            mask: Mask forwarded to the attention stack.
            return_mems: Also return the updated memories.
            return_attn: Also return the post-softmax attention maps
                (ignored when ``return_mems`` is set).
            seq_len: Forwarded to the attention stack.
            mems: Optional per-layer memories from a previous call.
            **kwargs: Forwarded to the attention stack (e.g. ``context``).

        Returns:
            ``out``, ``(out, new_mems)`` or ``(out, attn_maps)``.
        """
        num_mem = self.num_memory_tokens

        x = self.token_emb(x)
        x = x + self.pos_emb(x)
        x = self.emb_dropout(x)

        x = self.project_emb(x)
        x, intermediates = self.attn_layers(
            x, mask=mask, mems=mems, return_hiddens=True, seq_len=seq_len, **kwargs
        )
        x = self.norm(x)

        # Drop the leading memory positions. The previous device-specific
        # branches both reduced to this same slice.
        x = x[:, num_mem:]

        out = self.to_logits(x) if not return_embeddings else x

        if return_mems:
            hiddens = intermediates.hiddens
            if exists(mems):
                new_mems = [
                    paddle.concat(pair, axis=-2) for pair in zip(mems, hiddens)
                ]
            else:
                new_mems = hiddens
            # Slice bounds must be integers; the constructor default is a float.
            keep = int(self.max_mem_len)
            new_mems = [t[..., -keep:, :].detach() for t in new_mems]
            return out, new_mems

        if return_attn:
            attn_maps = [
                inter.post_softmax_attn for inter in intermediates.attn_intermediates
            ]
            return out, attn_maps

        return out
def top_p(logits, thres=0.9):
    """Nucleus filtering: set logits outside the kept probability mass to ``-inf``.

    Args:
        logits: 2-D tensor ``[batch, vocab]`` of unnormalised scores.
        thres: Tokens are removed once the cumulative probability of the
            higher-ranked tokens exceeds ``1 - thres``; the top-ranked token
            is always kept.

    Returns:
        Tensor shaped like ``logits`` in the original token order, with
        removed entries set to ``-inf``.
    """
    # paddle.sort returns only the sorted values; the permutation has to come
    # from argsort.
    sorted_logits = paddle.sort(logits, axis=-1, descending=True)
    sorted_indices = paddle.argsort(logits, axis=-1, descending=True)

    cum_probs = paddle.cumsum(F.softmax(sorted_logits, axis=-1), axis=-1)
    over_threshold = cum_probs > (1 - thres)

    # Shift the removal flags right by one so the first token that crosses
    # the threshold is still kept, and never remove the top-ranked token.
    keep_first = paddle.zeros_like(over_threshold[:, :1])
    remove = paddle.concat([keep_first, over_threshold[:, :-1]], axis=-1)

    neg_inf = paddle.full_like(sorted_logits, float("-inf"))
    sorted_logits = paddle.where(remove, neg_inf, sorted_logits)

    # Undo the sort: write each filtered value back to its original column.
    return paddle.put_along_axis(neg_inf, sorted_indices, sorted_logits, 1)
- # topk
def top_k(logits, thres=0.9):
    """Keep the highest-scoring fraction of logits and set the rest to ``-inf``.

    Args:
        logits: 2-D tensor ``[batch, vocab]`` of unnormalised scores.
        thres: The number of kept entries per row is
            ``int((1 - thres) * vocab)``, but never fewer than one.

    Returns:
        Tensor shaped like ``logits`` with the kept values in place and every
        other entry set to ``-inf``.
    """
    # Guard against k == 0 (small vocabularies or thres close to 1), which
    # would leave every logit at -inf and make the later softmax undefined.
    k = max(int((1 - thres) * logits.shape[-1]), 1)
    val, ind = paddle.topk(logits, k)
    probs = paddle.full_like(logits, float("-inf"))
    probs = paddle.put_along_axis(probs, ind, val, 1)
    return probs
class LaTeXOCRHead(nn.Layer):
    """Implementation of LaTeX OCR decoder.

    In training mode ``forward`` takes ``(encoded_feat, tgt_seq, mask)`` and
    returns logits for teacher-forced decoding; otherwise it takes the encoded
    features alone and autoregressively samples a token sequence.

    Args:
        encoded_feat: The encoded features with shape[N, 1, H//16, W//16]
        tgt_seq: LaTeX-OCR labels with shape [N, L] , L is the max sequence length
        xi: The first L-1 tokens of each sequence in tgt_seq, with shape [N, L-1]
        mask: The attention mask for xi with shape [N, L-1] , L is the max sequence length

    Returns:
        The predicted LaTeX sequences with shape [N, L-1, C], C is the number of LaTeX classes

    Constructor args:
        net, in_channels, out_channels: Accepted but unused.
        pad_value: Stored on the instance.
        decoder_args: Extra keyword arguments for ``Decoder``; ``None`` means
            no extras.
        is_export: Forwarded to the decoder stack and stored on the instance.
    """

    def __init__(
        self,
        net=None,
        in_channels=256,
        out_channels=256,
        pad_value=0,
        decoder_args=None,
        is_export=False,
    ):
        super().__init__()
        # `decoder_args` defaults to None, which cannot be **-unpacked.
        extra_decoder_args = decoder_args if decoder_args is not None else {}
        decoder = Decoder(
            dim=256, depth=4, heads=8, is_export=is_export, **extra_decoder_args
        )
        transformer_decoder = TransformerDecoder(
            num_tokens=8000,
            max_seq_len=512,
            attn_layers=decoder,
            is_export=is_export,
        )
        self.temperature = 0.333
        self.bos_token = 1
        self.eos_token = 2
        self.max_length = 512
        self.pad_value = pad_value

        self.net = transformer_decoder
        self.max_seq_len = self.net.max_seq_len
        self.is_export = is_export

    @paddle.no_grad()
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_logits_fn=top_k,
        filter_thres=0.9,
        **kwargs,
    ):
        """Sample up to ``seq_len`` tokens after ``start_tokens``.

        Args:
            start_tokens: 1-D or 2-D tensor of prompt token ids.
            seq_len: Maximum number of tokens to generate.
            eos_token: Stop early once every row contains this token.
            temperature: Divides the filtered logits before softmax.
            filter_logits_fn: ``top_k`` or ``top_p``.
            filter_thres: Threshold forwarded to the filter function.
            **kwargs: ``mask`` is consumed here; the rest is forwarded to the
                network (e.g. ``context``).

        Returns:
            The generated tokens without the prompt; 1-D if the prompt was 1-D.

        Raises:
            NotImplementedError: If ``filter_logits_fn`` is not supported.
        """
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        _, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        mask = kwargs.pop("mask", None)

        if mask is None:
            mask = paddle.full_like(out, True, dtype=paddle.bool)

        for _ in range(seq_len):
            # Only the most recent `max_seq_len` tokens are fed to the network.
            x = out[:, -self.max_seq_len :]
            mask = mask[:, -self.max_seq_len :]
            logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
            if filter_logits_fn in {top_k, top_p}:
                filtered_logits = filter_logits_fn(logits, thres=filter_thres)
                probs = F.softmax(filtered_logits / temperature, axis=-1)
            else:
                raise NotImplementedError("The filter_logits_fn is not supported ")

            sample = paddle.multinomial(probs, 1)
            out = paddle.concat((out, sample), axis=-1)
            pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
            mask = paddle.concat((mask, pad_mask), axis=1)
            if (
                eos_token is not None
                and (
                    paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1
                ).all()
            ):
                break

        out = out[:, t:]
        if num_dims == 1:
            out = out.squeeze(0)

        # Put the network back into training mode if that is how we found it.
        if was_training:
            self.net.train()
        return out

    @paddle.no_grad()
    def generate_export(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        context=None,
        temperature=1.0,
        filter_logits_fn=None,
        filter_thres=0.9,
        **kwargs,
    ):
        """Export-friendly variant of ``generate`` driven by a tensor loop counter.

        Args:
            start_tokens: 1-D or 2-D tensor of prompt token ids.
            seq_len: Maximum number of tokens to generate.
            eos_token: Stop early once every row contains this token.
            context: Encoded features forwarded to the network.
            temperature: Divides the filtered logits before softmax.
            filter_logits_fn: ``top_k`` or ``top_p``; the default ``None`` is
                not a usable value.
            filter_thres: Threshold forwarded to the filter function.
            **kwargs: ``mask`` is consumed here; the rest is forwarded to the network.

        Returns:
            The generated tokens without the prompt; 1-D if the prompt was 1-D.

        Raises:
            NotImplementedError: If ``filter_logits_fn`` is not supported.
        """
        # Fail with the same error as `generate` instead of a NameError on
        # `probs` inside the loop.
        if filter_logits_fn not in {top_k, top_p}:
            raise NotImplementedError("The filter_logits_fn is not supported ")

        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        _, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        mask = kwargs.pop("mask", None)

        if mask is None:
            mask = paddle.full_like(out, True, dtype=paddle.bool)

        i_idx = paddle.full([], 0)
        while i_idx < paddle.to_tensor(seq_len):
            x = out[:, -self.max_seq_len :]
            # Mark both axes as dynamic so the traced program accepts a
            # growing sequence.
            paddle.jit.api.set_dynamic_shape(x, [-1, -1])
            mask = mask[:, -self.max_seq_len :]
            paddle.jit.api.set_dynamic_shape(mask, [-1, -1])
            logits = self.net(x, mask=mask, context=context, seq_len=i_idx, **kwargs)[
                :, -1, :
            ]
            filtered_logits = filter_logits_fn(logits, thres=filter_thres)
            probs = F.softmax(filtered_logits / temperature, axis=-1)

            sample = paddle.multinomial(probs, 1)
            out = paddle.concat((out, sample), axis=-1)

            pad_mask = paddle.full(shape=[mask.shape[0], 1], fill_value=1, dtype="bool")
            mask = paddle.concat((mask, pad_mask), axis=1)
            if (
                eos_token is not None
                and (
                    paddle.cumsum((out == eos_token).cast(paddle.int64), 1)[:, -1] >= 1
                ).all()
            ):
                break
            i_idx += 1

        out = out[:, t:]
        if num_dims == 1:
            out = out.squeeze(0)

        # Put the network back into training mode if that is how we found it.
        if was_training:
            self.net.train()
        return out

    # forward for export
    def forward(self, inputs, targets=None):
        """Decode in inference mode, or compute teacher-forced logits in training.

        Args:
            inputs: Outside training, the encoded features. In training, a
                triple ``(encoded_feat, tgt_seq, mask)``.
            targets: Accepted but unused.

        Returns:
            Outside training, the sampled token ids; in training, the network
            output for ``tgt_seq[:, :-1]``.
        """
        if not self.training:
            # Inference always uses the export-friendly generation loop; the
            # flag stays set afterwards.
            self.is_export = True
            encoded_feat = inputs
            batch_num = encoded_feat.shape[0]
            bos_tensor = paddle.full([batch_num, 1], self.bos_token, dtype=paddle.int64)
            if self.is_export:
                word_pred = self.generate_export(
                    bos_tensor,
                    self.max_seq_len,
                    eos_token=self.eos_token,
                    context=encoded_feat,
                    temperature=self.temperature,
                    filter_logits_fn=top_k,
                )
            else:
                word_pred = self.generate(
                    bos_tensor,
                    self.max_seq_len,
                    eos_token=self.eos_token,
                    context=encoded_feat,
                    temperature=self.temperature,
                    filter_logits_fn=top_k,
                )
            return word_pred

        encoded_feat, tgt_seq, mask = inputs
        kwargs = {"context": encoded_feat, "mask": mask.cast(paddle.bool)}
        x = tgt_seq
        # Teacher forcing: feed every token except the last one.
        xi = x[:, :-1]

        mask = kwargs.get("mask", None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            # Trim the mask to match the shortened input.
            mask = mask[:, :-1]
            kwargs["mask"] = mask
        out = self.net(xi, **kwargs)

        return out
|