rec_parseq_head.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # Code was based on https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
  15. # reference: https://arxiv.org/abs/2207.06966
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import math
  20. import paddle
  21. from paddle import nn, ParamAttr
  22. from paddle.nn import functional as F
  23. import numpy as np
  24. from .self_attention import WrapEncoderForFeature
  25. from .self_attention import WrapEncoder
  26. from collections import OrderedDict
  27. from typing import Optional
  28. import copy
  29. from itertools import permutations
  30. class DecoderLayer(paddle.nn.Layer):
  31. """A Transformer decoder layer supporting two-stream attention (XLNet)
  32. This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
  33. def __init__(
  34. self,
  35. d_model,
  36. nhead,
  37. dim_feedforward=2048,
  38. dropout=0.1,
  39. activation="gelu",
  40. layer_norm_eps=1e-05,
  41. ):
  42. super().__init__()
  43. self.self_attn = paddle.nn.MultiHeadAttention(
  44. d_model, nhead, dropout=dropout, need_weights=True
  45. ) # paddle.nn.MultiHeadAttention默认为batch_first模式
  46. self.cross_attn = paddle.nn.MultiHeadAttention(
  47. d_model, nhead, dropout=dropout, need_weights=True
  48. )
  49. self.linear1 = paddle.nn.Linear(
  50. in_features=d_model, out_features=dim_feedforward
  51. )
  52. self.dropout = paddle.nn.Dropout(p=dropout)
  53. self.linear2 = paddle.nn.Linear(
  54. in_features=dim_feedforward, out_features=d_model
  55. )
  56. self.norm1 = paddle.nn.LayerNorm(
  57. normalized_shape=d_model, epsilon=layer_norm_eps
  58. )
  59. self.norm2 = paddle.nn.LayerNorm(
  60. normalized_shape=d_model, epsilon=layer_norm_eps
  61. )
  62. self.norm_q = paddle.nn.LayerNorm(
  63. normalized_shape=d_model, epsilon=layer_norm_eps
  64. )
  65. self.norm_c = paddle.nn.LayerNorm(
  66. normalized_shape=d_model, epsilon=layer_norm_eps
  67. )
  68. self.dropout1 = paddle.nn.Dropout(p=dropout)
  69. self.dropout2 = paddle.nn.Dropout(p=dropout)
  70. self.dropout3 = paddle.nn.Dropout(p=dropout)
  71. if activation == "gelu":
  72. self.activation = paddle.nn.GELU()
  73. def __setstate__(self, state):
  74. if "activation" not in state:
  75. state["activation"] = paddle.nn.functional.gelu
  76. super().__setstate__(state)
  77. def forward_stream(
  78. self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask
  79. ):
  80. """Forward pass for a single stream (i.e. content or query)
  81. tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
  82. Both tgt_kv and memory are expected to be LayerNorm'd too.
  83. memory is LayerNorm'd by ViT.
  84. """
  85. if tgt_key_padding_mask is not None:
  86. tgt_mask1 = (tgt_mask != float("-inf"))[None, None, :, :] & (
  87. tgt_key_padding_mask[:, None, None, :] == False
  88. )
  89. tgt2, sa_weights = self.self_attn(
  90. tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1
  91. )
  92. else:
  93. tgt2, sa_weights = self.self_attn(
  94. tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask
  95. )
  96. tgt = tgt + self.dropout1(tgt2)
  97. tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
  98. tgt = tgt + self.dropout2(tgt2)
  99. tgt2 = self.linear2(
  100. self.dropout(self.activation(self.linear1(self.norm2(tgt))))
  101. )
  102. tgt = tgt + self.dropout3(tgt2)
  103. return tgt, sa_weights, ca_weights
  104. def forward(
  105. self,
  106. query,
  107. content,
  108. memory,
  109. query_mask=None,
  110. content_mask=None,
  111. content_key_padding_mask=None,
  112. update_content=True,
  113. ):
  114. query_norm = self.norm_q(query)
  115. content_norm = self.norm_c(content)
  116. query = self.forward_stream(
  117. query,
  118. query_norm,
  119. content_norm,
  120. memory,
  121. query_mask,
  122. content_key_padding_mask,
  123. )[0]
  124. if update_content:
  125. content = self.forward_stream(
  126. content,
  127. content_norm,
  128. content_norm,
  129. memory,
  130. content_mask,
  131. content_key_padding_mask,
  132. )[0]
  133. return query, content
  134. def get_clones(module, N):
  135. return paddle.nn.LayerList([copy.deepcopy(module) for i in range(N)])
  136. class Decoder(paddle.nn.Layer):
  137. __constants__ = ["norm"]
  138. def __init__(self, decoder_layer, num_layers, norm):
  139. super().__init__()
  140. self.layers = get_clones(decoder_layer, num_layers)
  141. self.num_layers = num_layers
  142. self.norm = norm
  143. def forward(
  144. self,
  145. query,
  146. content,
  147. memory,
  148. query_mask: Optional[paddle.Tensor] = None,
  149. content_mask: Optional[paddle.Tensor] = None,
  150. content_key_padding_mask: Optional[paddle.Tensor] = None,
  151. ):
  152. for i, mod in enumerate(self.layers):
  153. last = i == len(self.layers) - 1
  154. query, content = mod(
  155. query,
  156. content,
  157. memory,
  158. query_mask,
  159. content_mask,
  160. content_key_padding_mask,
  161. update_content=not last,
  162. )
  163. query = self.norm(query)
  164. return query
  165. class TokenEmbedding(paddle.nn.Layer):
  166. def __init__(self, charset_size: int, embed_dim: int):
  167. super().__init__()
  168. self.embedding = paddle.nn.Embedding(
  169. num_embeddings=charset_size, embedding_dim=embed_dim
  170. )
  171. self.embed_dim = embed_dim
  172. def forward(self, tokens: paddle.Tensor):
  173. return math.sqrt(self.embed_dim) * self.embedding(tokens.astype(paddle.int64))
  174. def trunc_normal_init(param, **kwargs):
  175. initializer = nn.initializer.TruncatedNormal(**kwargs)
  176. initializer(param, param.block)
  177. def constant_init(param, **kwargs):
  178. initializer = nn.initializer.Constant(**kwargs)
  179. initializer(param, param.block)
  180. def kaiming_normal_init(param, **kwargs):
  181. initializer = nn.initializer.KaimingNormal(**kwargs)
  182. initializer(param, param.block)
  183. class ParseQHead(nn.Layer):
  184. def __init__(
  185. self,
  186. out_channels,
  187. max_text_length,
  188. embed_dim,
  189. dec_num_heads,
  190. dec_mlp_ratio,
  191. dec_depth,
  192. perm_num,
  193. perm_forward,
  194. perm_mirrored,
  195. decode_ar,
  196. refine_iters,
  197. dropout,
  198. **kwargs,
  199. ):
  200. super().__init__()
  201. self.bos_id = out_channels - 2
  202. self.eos_id = 0
  203. self.pad_id = out_channels - 1
  204. self.max_label_length = max_text_length
  205. self.decode_ar = decode_ar
  206. self.refine_iters = refine_iters
  207. decoder_layer = DecoderLayer(
  208. embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout
  209. )
  210. self.decoder = Decoder(
  211. decoder_layer,
  212. num_layers=dec_depth,
  213. norm=paddle.nn.LayerNorm(normalized_shape=embed_dim),
  214. )
  215. self.rng = np.random.default_rng()
  216. self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
  217. self.perm_forward = perm_forward
  218. self.perm_mirrored = perm_mirrored
  219. self.head = paddle.nn.Linear(
  220. in_features=embed_dim, out_features=out_channels - 2
  221. )
  222. self.text_embed = TokenEmbedding(out_channels, embed_dim)
  223. self.pos_queries = paddle.create_parameter(
  224. shape=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).shape,
  225. dtype=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).numpy().dtype,
  226. default_initializer=paddle.nn.initializer.Assign(
  227. paddle.empty(shape=[1, max_text_length + 1, embed_dim])
  228. ),
  229. )
  230. self.pos_queries.stop_gradient = not True
  231. self.dropout = paddle.nn.Dropout(p=dropout)
  232. self._device = self.parameters()[0].place
  233. trunc_normal_init(self.pos_queries, std=0.02)
  234. self.apply(self._init_weights)
  235. def _init_weights(self, m):
  236. if isinstance(m, paddle.nn.Linear):
  237. trunc_normal_init(m.weight, std=0.02)
  238. if m.bias is not None:
  239. constant_init(m.bias, value=0.0)
  240. elif isinstance(m, paddle.nn.Embedding):
  241. trunc_normal_init(m.weight, std=0.02)
  242. if m._padding_idx is not None:
  243. m.weight.data[m._padding_idx].zero_()
  244. elif isinstance(m, paddle.nn.Conv2D):
  245. kaiming_normal_init(m.weight, fan_in=None, nonlinearity="relu")
  246. if m.bias is not None:
  247. constant_init(m.bias, value=0.0)
  248. elif isinstance(
  249. m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)
  250. ):
  251. constant_init(m.weight, value=1.0)
  252. constant_init(m.bias, value=0.0)
  253. def no_weight_decay(self):
  254. param_names = {"text_embed.embedding.weight", "pos_queries"}
  255. enc_param_names = {("encoder." + n) for n in self.encoder.no_weight_decay()}
  256. return param_names.union(enc_param_names)
  257. def encode(self, img):
  258. return self.encoder(img)
  259. def decode(
  260. self,
  261. tgt,
  262. memory,
  263. tgt_mask=None,
  264. tgt_padding_mask=None,
  265. tgt_query=None,
  266. tgt_query_mask=None,
  267. ):
  268. N, L = tgt.shape
  269. null_ctx = self.text_embed(tgt[:, :1])
  270. if L != 1:
  271. tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
  272. tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1))
  273. else:
  274. tgt_emb = self.dropout(null_ctx)
  275. if tgt_query is None:
  276. tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1])
  277. tgt_query = self.dropout(tgt_query)
  278. return self.decoder(
  279. tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask
  280. )
  281. def forward_test(self, memory, max_length=None):
  282. testing = max_length is None
  283. max_length = (
  284. self.max_label_length
  285. if max_length is None
  286. else min(max_length, self.max_label_length)
  287. )
  288. bs = memory.shape[0]
  289. num_steps = max_length + 1
  290. pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1])
  291. tgt_mask = query_mask = paddle.triu(
  292. x=paddle.full(shape=(num_steps, num_steps), fill_value=float("-inf")),
  293. diagonal=1,
  294. )
  295. if self.decode_ar:
  296. tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype(
  297. "int64"
  298. )
  299. tgt_in[:, (0)] = self.bos_id
  300. logits = []
  301. for i in range(paddle.to_tensor(num_steps)):
  302. j = i + 1
  303. tgt_out = self.decode(
  304. tgt_in[:, :j],
  305. memory,
  306. tgt_mask[:j, :j],
  307. tgt_query=pos_queries[:, i:j],
  308. tgt_query_mask=query_mask[i:j, :j],
  309. )
  310. p_i = self.head(tgt_out)
  311. logits.append(p_i)
  312. if j < num_steps:
  313. tgt_in[:, (j)] = p_i.squeeze().argmax(axis=-1)
  314. if (
  315. testing
  316. and (tgt_in == self.eos_id)
  317. .astype("bool")
  318. .any(axis=-1)
  319. .astype("bool")
  320. .all()
  321. ):
  322. break
  323. logits = paddle.concat(x=logits, axis=1)
  324. else:
  325. tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
  326. tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
  327. logits = self.head(tgt_out)
  328. if self.refine_iters:
  329. temp = paddle.triu(
  330. x=paddle.ones(shape=[num_steps, num_steps], dtype="bool"), diagonal=2
  331. )
  332. posi = paddle.where(temp == True)
  333. query_mask[posi] = 0
  334. bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
  335. for i in range(self.refine_iters):
  336. tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1)
  337. tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype="int32")
  338. tgt_padding_mask = tgt_padding_mask.cpu()
  339. tgt_padding_mask = tgt_padding_mask.cumsum(axis=-1) > 0
  340. tgt_padding_mask = (
  341. tgt_padding_mask.cuda().astype(dtype="float32") == 1.0
  342. )
  343. tgt_out = self.decode(
  344. tgt_in,
  345. memory,
  346. tgt_mask,
  347. tgt_padding_mask,
  348. tgt_query=pos_queries,
  349. tgt_query_mask=query_mask[:, : tgt_in.shape[1]],
  350. )
  351. logits = self.head(tgt_out)
  352. # transfer to probability
  353. logits = F.softmax(logits, axis=-1)
  354. final_output = {"predict": logits}
  355. return final_output
  356. def gen_tgt_perms(self, tgt):
  357. """Generate shared permutations for the whole batch.
  358. This works because the same attention mask can be used for the shorter sequences
  359. because of the padding mask.
  360. """
  361. max_num_chars = tgt.shape[1] - 2
  362. if max_num_chars == 1:
  363. return paddle.arange(end=3).unsqueeze(axis=0)
  364. perms = [paddle.arange(end=max_num_chars)] if self.perm_forward else []
  365. max_perms = math.factorial(max_num_chars)
  366. if self.perm_mirrored:
  367. max_perms //= 2
  368. num_gen_perms = min(self.max_gen_perms, max_perms)
  369. if max_num_chars < 5:
  370. if max_num_chars == 4 and self.perm_mirrored:
  371. selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
  372. else:
  373. selector = list(range(max_perms))
  374. perm_pool = paddle.to_tensor(
  375. data=list(permutations(range(max_num_chars), max_num_chars)),
  376. place=self._device,
  377. )[selector]
  378. if self.perm_forward:
  379. perm_pool = perm_pool[1:]
  380. perms = paddle.stack(x=perms)
  381. if len(perm_pool):
  382. i = self.rng.choice(
  383. len(perm_pool), size=num_gen_perms - len(perms), replace=False
  384. )
  385. perms = paddle.concat(x=[perms, perm_pool[i]])
  386. else:
  387. perms.extend(
  388. [
  389. paddle.randperm(n=max_num_chars)
  390. for _ in range(num_gen_perms - len(perms))
  391. ]
  392. )
  393. perms = paddle.stack(x=perms)
  394. if self.perm_mirrored:
  395. comp = perms.flip(axis=-1)
  396. x = paddle.stack(x=[perms, comp])
  397. perm_2 = list(range(x.ndim))
  398. perm_2[0] = 1
  399. perm_2[1] = 0
  400. perms = x.transpose(perm=perm_2).reshape((-1, max_num_chars))
  401. bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype)
  402. eos_idx = paddle.full(
  403. shape=(len(perms), 1), fill_value=max_num_chars + 1, dtype=perms.dtype
  404. )
  405. perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1)
  406. if len(perms) > 1:
  407. perms[(1), 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1)
  408. return perms
  409. def generate_attn_masks(self, perm):
  410. """Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)
  411. :param perm: the permutation sequence. i = 0 is always the BOS
  412. :return: lookahead attention masks
  413. """
  414. sz = perm.shape[0]
  415. mask = paddle.zeros(shape=(sz, sz))
  416. for i in range(sz):
  417. query_idx = perm[i].cpu().numpy().tolist()
  418. masked_keys = perm[i + 1 :].cpu().numpy().tolist()
  419. if len(masked_keys) == 0:
  420. break
  421. mask[query_idx, masked_keys] = float("-inf")
  422. content_mask = mask[:-1, :-1].clone()
  423. mask[paddle.eye(num_rows=sz).astype("bool")] = float("-inf")
  424. query_mask = mask[1:, :-1]
  425. return content_mask, query_mask
  426. def forward_train(self, memory, tgt):
  427. tgt_perms = self.gen_tgt_perms(tgt)
  428. tgt_in = tgt[:, :-1]
  429. tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
  430. logits_list = []
  431. final_out = {}
  432. for i, perm in enumerate(tgt_perms):
  433. tgt_mask, query_mask = self.generate_attn_masks(perm)
  434. out = self.decode(
  435. tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask
  436. )
  437. logits = self.head(out)
  438. if i == 0:
  439. final_out["predict"] = logits
  440. logits = logits.flatten(stop_axis=1)
  441. logits_list.append(logits)
  442. final_out["logits_list"] = logits_list
  443. final_out["pad_id"] = self.pad_id
  444. final_out["eos_id"] = self.eos_id
  445. return final_out
  446. def forward(self, feat, targets=None):
  447. # feat : B, N, C
  448. # targets : labels, labels_len
  449. if self.training:
  450. label = targets[0] # label
  451. label_len = targets[1]
  452. max_step = paddle.max(label_len).cpu().numpy()[0] + 2
  453. crop_label = label[:, :max_step]
  454. final_out = self.forward_train(feat, crop_label)
  455. else:
  456. final_out = self.forward_test(feat)
  457. return final_out