| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504 |
- # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # Code was based on https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
- # reference: https://arxiv.org/abs/2207.06966
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import math
- import paddle
- from paddle import nn, ParamAttr
- from paddle.nn import functional as F
- import numpy as np
- from .self_attention import WrapEncoderForFeature
- from .self_attention import WrapEncoder
- from collections import OrderedDict
- from typing import Optional
- import copy
- from itertools import permutations
class DecoderLayer(paddle.nn.Layer):
    """A Transformer decoder layer supporting two-stream attention (XLNet).

    This implements a pre-LN decoder, as opposed to the post-LN default in
    PyTorch. The layer owns one self-attention module, one cross-attention
    module and a two-layer feed-forward network; the same weights are used
    for both the query stream and the content stream.

    Args:
        d_model: Feature width of the token embeddings.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used in attention and residual branches.
        activation: Name of the feed-forward activation, "gelu" or "relu".
        layer_norm_eps: Epsilon passed to every LayerNorm in the layer.

    Raises:
        ValueError: If ``activation`` is not a supported name.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="gelu",
        layer_norm_eps=1e-05,
    ):
        super().__init__()
        # paddle.nn.MultiHeadAttention is batch-first by default.
        self.self_attn = paddle.nn.MultiHeadAttention(
            d_model, nhead, dropout=dropout, need_weights=True
        )
        self.cross_attn = paddle.nn.MultiHeadAttention(
            d_model, nhead, dropout=dropout, need_weights=True
        )
        self.linear1 = paddle.nn.Linear(
            in_features=d_model, out_features=dim_feedforward
        )
        self.dropout = paddle.nn.Dropout(p=dropout)
        self.linear2 = paddle.nn.Linear(
            in_features=dim_feedforward, out_features=d_model
        )
        self.norm1 = paddle.nn.LayerNorm(
            normalized_shape=d_model, epsilon=layer_norm_eps
        )
        self.norm2 = paddle.nn.LayerNorm(
            normalized_shape=d_model, epsilon=layer_norm_eps
        )
        self.norm_q = paddle.nn.LayerNorm(
            normalized_shape=d_model, epsilon=layer_norm_eps
        )
        self.norm_c = paddle.nn.LayerNorm(
            normalized_shape=d_model, epsilon=layer_norm_eps
        )
        self.dropout1 = paddle.nn.Dropout(p=dropout)
        self.dropout2 = paddle.nn.Dropout(p=dropout)
        self.dropout3 = paddle.nn.Dropout(p=dropout)
        # Always bind an activation: previously any value other than "gelu"
        # left the attribute unset and the failure only surfaced in forward.
        if activation == "gelu":
            self.activation = paddle.nn.GELU()
        elif activation == "relu":
            self.activation = paddle.nn.ReLU()
        else:
            raise ValueError(
                "activation should be 'gelu' or 'relu', got {!r}".format(activation)
            )

    def __setstate__(self, state):
        # Restore a default activation for pickled states that lack one.
        if "activation" not in state:
            state["activation"] = paddle.nn.functional.gelu
        super().__setstate__(state)

    def forward_stream(
        self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask
    ):
        """Forward pass for a single stream (i.e. content or query).

        tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for
        efficiency. Both tgt_kv and memory are expected to be LayerNorm'd too.
        memory is LayerNorm'd by ViT.

        Args:
            tgt: Residual input of the stream.
            tgt_norm: LayerNorm'd ``tgt``, used as the self-attention query.
            tgt_kv: Keys/values for self-attention.
            memory: Keys/values for cross-attention.
            tgt_mask: Additive attention mask using ``-inf`` for blocked
                positions, or None. Must not be None when
                ``tgt_key_padding_mask`` is given.
            tgt_key_padding_mask: Mask that is truthy at padded key
                positions, or None.

        Returns:
            Tuple ``(tgt, sa_weights, ca_weights)``: the updated stream plus
            the self-attention and cross-attention weights.
        """
        if tgt_key_padding_mask is not None:
            # Fold both masks into one boolean "may attend" mask that
            # broadcasts over heads: a key is visible only when it is not
            # blocked by tgt_mask and is not padding.
            visible = (tgt_mask != float("-inf"))[None, None, :, :]
            not_padded = paddle.logical_not(
                tgt_key_padding_mask[:, None, None, :]
            )
            attn_mask = visible & not_padded
        else:
            attn_mask = tgt_mask
        tgt2, sa_weights = self.self_attn(
            tgt_norm, tgt_kv, tgt_kv, attn_mask=attn_mask
        )
        tgt = tgt + self.dropout1(tgt2)
        tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.linear2(
            self.dropout(self.activation(self.linear1(self.norm2(tgt))))
        )
        tgt = tgt + self.dropout3(tgt2)
        return tgt, sa_weights, ca_weights

    def forward(
        self,
        query,
        content,
        memory,
        query_mask=None,
        content_mask=None,
        content_key_padding_mask=None,
        update_content=True,
    ):
        """Run the query stream and, optionally, the content stream.

        Both streams attend over the LayerNorm'd content; the query stream
        uses ``query_mask`` and the content stream uses ``content_mask``.

        Args:
            query: Query-stream input.
            content: Content-stream input.
            memory: Keys/values for cross-attention.
            query_mask: Attention mask for the query stream.
            content_mask: Attention mask for the content stream.
            content_key_padding_mask: Padding mask applied to content keys.
            update_content: When False the content stream is returned
                unchanged.

        Returns:
            Tuple ``(query, content)``.
        """
        query_norm = self.norm_q(query)
        content_norm = self.norm_c(content)
        query = self.forward_stream(
            query,
            query_norm,
            content_norm,
            memory,
            query_mask,
            content_key_padding_mask,
        )[0]
        if update_content:
            content = self.forward_stream(
                content,
                content_norm,
                content_norm,
                memory,
                content_mask,
                content_key_padding_mask,
            )[0]
        return query, content
def get_clones(module, N):
    """Return a LayerList holding N independent deep copies of module."""
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return paddle.nn.LayerList(clones)
class Decoder(paddle.nn.Layer):
    """Stack of cloned decoder layers followed by a final normalization.

    Args:
        decoder_layer: Layer that is deep-copied ``num_layers`` times.
        num_layers: Number of layers in the stack.
        norm: Layer applied to the query stream after the last layer.
    """

    __constants__ = ["norm"]

    def __init__(self, decoder_layer, num_layers, norm):
        super().__init__()
        self.layers = get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[paddle.Tensor] = None,
        content_mask: Optional[paddle.Tensor] = None,
        content_key_padding_mask: Optional[paddle.Tensor] = None,
    ):
        """Pass both streams through every layer and return the normed query.

        The content stream is not updated in the final layer because only
        the query stream is returned.
        """
        final_index = len(self.layers) - 1
        for index, layer in enumerate(self.layers):
            keep_updating_content = index != final_index
            query, content = layer(
                query,
                content,
                memory,
                query_mask,
                content_mask,
                content_key_padding_mask,
                update_content=keep_updating_content,
            )
        return self.norm(query)
class TokenEmbedding(paddle.nn.Layer):
    """Embedding lookup whose output is scaled by sqrt(embed_dim).

    Args:
        charset_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector.
    """

    def __init__(self, charset_size: int, embed_dim: int):
        super().__init__()
        self.embedding = paddle.nn.Embedding(
            num_embeddings=charset_size, embedding_dim=embed_dim
        )
        self.embed_dim = embed_dim

    def forward(self, tokens: paddle.Tensor):
        """Cast the token ids to int64, embed them and apply the scale."""
        token_ids = tokens.astype(paddle.int64)
        scale = math.sqrt(self.embed_dim)
        return scale * self.embedding(token_ids)
def trunc_normal_init(param, **kwargs):
    """Re-initialize param in place with a TruncatedNormal initializer.

    Keyword arguments are forwarded to ``nn.initializer.TruncatedNormal``.
    """
    apply_init = nn.initializer.TruncatedNormal(**kwargs)
    apply_init(param, param.block)
def constant_init(param, **kwargs):
    """Re-initialize param in place with a Constant initializer.

    Keyword arguments are forwarded to ``nn.initializer.Constant``.
    """
    apply_init = nn.initializer.Constant(**kwargs)
    apply_init(param, param.block)
def kaiming_normal_init(param, **kwargs):
    """Re-initialize param in place with a KaimingNormal initializer.

    Keyword arguments are forwarded to ``nn.initializer.KaimingNormal``.
    """
    apply_init = nn.initializer.KaimingNormal(**kwargs)
    apply_init(param, param.block)
class ParseQHead(nn.Layer):
    """PARSeq recognition head: a two-stream decoder over encoder features.

    In training mode the head decodes the labels under several shared
    permutations of the character order; in eval mode it decodes either
    autoregressively or in a single pass, optionally followed by iterative
    refinement.

    Args:
        out_channels: Size of the token vocabulary including the special
            tokens. EOS is id 0, BOS is ``out_channels - 2`` and PAD is
            ``out_channels - 1``; the classifier emits ``out_channels - 2``
            classes.
        max_text_length: Maximum number of characters that can be decoded.
        embed_dim: Feature width of the decoder.
        dec_num_heads: Number of attention heads per decoder layer.
        dec_mlp_ratio: Feed-forward width as a multiple of ``embed_dim``.
        dec_depth: Number of decoder layers.
        perm_num: Number of permutations used per training step.
        perm_forward: Whether the left-to-right permutation is always used.
        perm_mirrored: Whether each permutation is paired with its reverse.
        decode_ar: Whether inference decodes autoregressively.
        refine_iters: Number of refinement passes at inference.
        dropout: Dropout probability.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        out_channels,
        max_text_length,
        embed_dim,
        dec_num_heads,
        dec_mlp_ratio,
        dec_depth,
        perm_num,
        perm_forward,
        perm_mirrored,
        decode_ar,
        refine_iters,
        dropout,
        **kwargs,
    ):
        super().__init__()
        self.bos_id = out_channels - 2
        self.eos_id = 0
        self.pad_id = out_channels - 1
        self.max_label_length = max_text_length
        self.decode_ar = decode_ar
        self.refine_iters = refine_iters
        decoder_layer = DecoderLayer(
            embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout
        )
        self.decoder = Decoder(
            decoder_layer,
            num_layers=dec_depth,
            norm=paddle.nn.LayerNorm(normalized_shape=embed_dim),
        )
        self.rng = np.random.default_rng()
        # Mirrored permutations come in pairs, so only half are sampled.
        self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
        self.perm_forward = perm_forward
        self.perm_mirrored = perm_mirrored
        # BOS and PAD are never predicted, hence out_channels - 2 classes.
        self.head = paddle.nn.Linear(
            in_features=embed_dim, out_features=out_channels - 2
        )
        self.text_embed = TokenEmbedding(out_channels, embed_dim)
        # One learned position query per output step (max length + EOS).
        # The placeholder values are overwritten by trunc_normal_init below.
        self.pos_queries = paddle.create_parameter(
            shape=[1, max_text_length + 1, embed_dim],
            dtype=paddle.get_default_dtype(),
            default_initializer=paddle.nn.initializer.Constant(value=0.0),
        )
        self.pos_queries.stop_gradient = False
        self.dropout = paddle.nn.Dropout(p=dropout)
        self._device = self.parameters()[0].place
        trunc_normal_init(self.pos_queries, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one sub-layer according to its type (used with apply)."""
        if isinstance(m, paddle.nn.Linear):
            trunc_normal_init(m.weight, std=0.02)
            if m.bias is not None:
                constant_init(m.bias, value=0.0)
        elif isinstance(m, paddle.nn.Embedding):
            trunc_normal_init(m.weight, std=0.02)
            if m._padding_idx is not None:
                m.weight.data[m._padding_idx].zero_()
        elif isinstance(m, paddle.nn.Conv2D):
            kaiming_normal_init(m.weight, fan_in=None, nonlinearity="relu")
            if m.bias is not None:
                constant_init(m.bias, value=0.0)
        elif isinstance(
            m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)
        ):
            constant_init(m.weight, value=1.0)
            constant_init(m.bias, value=0.0)

    def no_weight_decay(self):
        """Return the names of parameters to exclude from weight decay.

        Encoder parameter names are included only when an ``encoder``
        attribute is present; this class's ``__init__`` does not create one.
        """
        param_names = {"text_embed.embedding.weight", "pos_queries"}
        encoder = getattr(self, "encoder", None)
        if encoder is None:
            return param_names
        enc_param_names = {("encoder." + n) for n in encoder.no_weight_decay()}
        return param_names.union(enc_param_names)

    def encode(self, img):
        """Run ``self.encoder`` on img.

        NOTE(review): ``__init__`` of this class never assigns ``encoder``,
        so this only works if one is attached elsewhere - confirm callers.
        """
        return self.encoder(img)

    def decode(
        self,
        tgt,
        memory,
        tgt_mask=None,
        tgt_padding_mask=None,
        tgt_query=None,
        tgt_query_mask=None,
    ):
        """Embed the target tokens and run the two-stream decoder.

        Args:
            tgt: 2-D tensor of token ids; column 0 is embedded without a
                position query (the null context).
            memory: Encoder features used for cross-attention.
            tgt_mask: Attention mask for the content stream.
            tgt_padding_mask: Padding mask for the content keys.
            tgt_query: Position queries; defaults to the first L learned
                queries expanded over the batch.
            tgt_query_mask: Attention mask for the query stream.

        Returns:
            The decoder output for the query stream.
        """
        N, L = tgt.shape
        null_ctx = self.text_embed(tgt[:, :1])
        if L != 1:
            tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
            tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1))
        else:
            tgt_emb = self.dropout(null_ctx)
        if tgt_query is None:
            tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1])
        tgt_query = self.dropout(tgt_query)
        return self.decoder(
            tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask
        )

    def forward_test(self, memory, max_length=None):
        """Decode at inference time and return class probabilities.

        Args:
            memory: Encoder features; dimension 0 is the batch.
            max_length: Optional cap on the decoded length; it is clipped to
                ``max_label_length``. When None, autoregressive decoding
                stops early once every sequence contains an EOS token.

        Returns:
            Dict with key ``"predict"`` holding the softmax of the logits.
        """
        testing = max_length is None
        max_length = (
            self.max_label_length
            if max_length is None
            else min(max_length, self.max_label_length)
        )
        bs = memory.shape[0]
        num_steps = max_length + 1
        pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1])
        # tgt_mask and query_mask are the same tensor object here, so the
        # in-place edit of query_mask before refinement also changes tgt_mask.
        tgt_mask = query_mask = paddle.triu(
            x=paddle.full(shape=(num_steps, num_steps), fill_value=float("-inf")),
            diagonal=1,
        )
        if self.decode_ar:
            tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype(
                "int64"
            )
            tgt_in[:, (0)] = self.bos_id
            logits = []
            # NOTE(review): the tensor loop bound is kept as in the original;
            # presumably it matters for dynamic-to-static export - confirm.
            for i in range(paddle.to_tensor(num_steps)):
                j = i + 1
                # Feed the tokens decoded so far, but query only step i.
                tgt_out = self.decode(
                    tgt_in[:, :j],
                    memory,
                    tgt_mask[:j, :j],
                    tgt_query=pos_queries[:, i:j],
                    tgt_query_mask=query_mask[i:j, :j],
                )
                p_i = self.head(tgt_out)
                logits.append(p_i)
                if j < num_steps:
                    # Greedy choice becomes the next input token.
                    tgt_in[:, (j)] = p_i.squeeze().argmax(axis=-1)
                    if (
                        testing
                        and (tgt_in == self.eos_id)
                        .astype("bool")
                        .any(axis=-1)
                        .astype("bool")
                        .all()
                    ):
                        break
            logits = paddle.concat(x=logits, axis=1)
        else:
            # Single pass: only BOS as context, all positions queried at once.
            tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
            tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
            logits = self.head(tgt_out)
        if self.refine_iters:
            # Unmask everything from the second super-diagonal upwards.
            temp = paddle.triu(
                x=paddle.ones(shape=[num_steps, num_steps], dtype="bool"), diagonal=2
            )
            posi = paddle.where(temp == True)
            query_mask[posi] = 0
            bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
            for i in range(self.refine_iters):
                tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1)
                # Positions at or after the first EOS count as padding. The
                # cumsum runs on the tensor's own device; the previous
                # cpu()/cuda() round trip failed on CPU-only builds.
                eos_seen = (tgt_in == self.eos_id).astype(dtype="int32")
                tgt_padding_mask = eos_seen.cumsum(axis=-1) > 0
                tgt_out = self.decode(
                    tgt_in,
                    memory,
                    tgt_mask,
                    tgt_padding_mask,
                    tgt_query=pos_queries,
                    tgt_query_mask=query_mask[:, : tgt_in.shape[1]],
                )
                logits = self.head(tgt_out)
        # transfer to probability
        logits = F.softmax(logits, axis=-1)
        final_output = {"predict": logits}
        return final_output

    def gen_tgt_perms(self, tgt):
        """Generate shared permutations for the whole batch.

        This works because the same attention mask can be used for the
        shorter sequences because of the padding mask.

        Args:
            tgt: 2-D tensor of target ids; its width minus 2 is taken as the
                maximum number of characters.

        Returns:
            2-D integer tensor with one permutation per row. Each row starts
            with position 0 and ends with position ``max_num_chars + 1``.
        """
        max_num_chars = tgt.shape[1] - 2
        if max_num_chars == 1:
            # A single character has only one possible ordering.
            return paddle.arange(end=3).unsqueeze(axis=0)
        perms = [paddle.arange(end=max_num_chars)] if self.perm_forward else []
        max_perms = math.factorial(max_num_chars)
        if self.perm_mirrored:
            max_perms //= 2
        num_gen_perms = min(self.max_gen_perms, max_perms)
        if max_num_chars < 5:
            # Few enough permutations to enumerate and sample without
            # replacement.
            if max_num_chars == 4 and self.perm_mirrored:
                # Hard-coded subset of the 24 permutations of 4 items.
                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
            else:
                selector = list(range(max_perms))
            perm_pool = paddle.to_tensor(
                data=list(permutations(range(max_num_chars), max_num_chars)),
                place=self._device,
            )[selector]
            if self.perm_forward:
                # The forward permutation is already in perms.
                perm_pool = perm_pool[1:]
            if len(perm_pool):
                i = self.rng.choice(
                    len(perm_pool), size=num_gen_perms - len(perms), replace=False
                )
                sampled = perm_pool[i]
                # perms is empty when perm_forward is False; stacking an
                # empty list would raise, so use the sample on its own.
                if perms:
                    perms = paddle.concat(x=[paddle.stack(x=perms), sampled])
                else:
                    perms = sampled
            else:
                perms = paddle.stack(x=perms)
        else:
            perms.extend(
                [
                    paddle.randperm(n=max_num_chars)
                    for _ in range(num_gen_perms - len(perms))
                ]
            )
            perms = paddle.stack(x=perms)
        if self.perm_mirrored:
            # Interleave every permutation with its reverse.
            comp = perms.flip(axis=-1)
            x = paddle.stack(x=[perms, comp])
            perms = x.transpose(perm=[1, 0, 2]).reshape((-1, max_num_chars))
        # Shift character positions by one and add the two end positions.
        bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype)
        eos_idx = paddle.full(
            shape=(len(perms), 1), fill_value=max_num_chars + 1, dtype=perms.dtype
        )
        perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1)
        if len(perms) > 1:
            # Row 1 is overwritten with a strictly descending sequence that
            # starts at max_num_chars + 1.
            perms[(1), 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1)
        return perms

    def generate_attn_masks(self, perm):
        """Generate attention masks given a sequence permutation.

        The permutation includes positions for the BOS and EOS tokens.

        :param perm: the permutation sequence. i = 0 is always the BOS
        :return: lookahead attention masks as ``(content_mask, query_mask)``
        """
        sz = perm.shape[0]
        mask = paddle.zeros(shape=(sz, sz))
        # Convert once instead of once per loop iteration.
        order = perm.cpu().numpy().tolist()
        for i, query_idx in enumerate(order[:-1]):
            # A position may not attend to anything later in the permutation.
            masked_keys = order[i + 1 :]
            mask[query_idx, masked_keys] = float("-inf")
        content_mask = mask[:-1, :-1].clone()
        # The query stream must not see the token at its own position.
        mask[paddle.eye(num_rows=sz).astype("bool")] = float("-inf")
        query_mask = mask[1:, :-1]
        return content_mask, query_mask

    def forward_train(self, memory, tgt):
        """Decode the targets under every generated permutation.

        Args:
            memory: Encoder features.
            tgt: 2-D tensor of target ids; the last column is dropped to
                form the decoder input.

        Returns:
            Dict with ``"predict"`` (logits of the first permutation),
            ``"logits_list"`` (flattened logits per permutation), ``"pad_id"``
            and ``"eos_id"``.
        """
        tgt_perms = self.gen_tgt_perms(tgt)
        tgt_in = tgt[:, :-1]
        tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
        logits_list = []
        final_out = {}
        for i, perm in enumerate(tgt_perms):
            tgt_mask, query_mask = self.generate_attn_masks(perm)
            out = self.decode(
                tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask
            )
            logits = self.head(out)
            if i == 0:
                final_out["predict"] = logits
            logits = logits.flatten(stop_axis=1)
            logits_list.append(logits)
        final_out["logits_list"] = logits_list
        final_out["pad_id"] = self.pad_id
        final_out["eos_id"] = self.eos_id
        return final_out

    def forward(self, feat, targets=None):
        """Dispatch to training or inference decoding.

        Args:
            feat: Encoder features, B x N x C.
            targets: ``(labels, label_lengths)``; read in training mode only.

        Returns:
            The dict produced by ``forward_train`` or ``forward_test``.
        """
        # feat : B, N, C
        # targets : labels, labels_len
        if self.training:
            label = targets[0]  # label
            label_len = targets[1]
            # item() handles both 0-D and one-element results of paddle.max;
            # numpy()[0] raised IndexError on a 0-D result.
            max_step = int(paddle.max(label_len).item()) + 2
            crop_label = label[:, :max_step]
            final_out = self.forward_train(feat, crop_label)
        else:
            final_out = self.forward_test(feat)
        return final_out
|