rec_satrn_head.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/encoders/satrn_encoder.py
  17. https://github.com/open-mmlab/mmocr/blob/1.x/mmocr/models/textrecog/decoders/nrtr_decoder.py
  18. """
  19. import math
  20. import numpy as np
  21. import paddle
  22. import paddle.nn as nn
  23. import paddle.nn.functional as F
  24. from paddle import ParamAttr, reshape, transpose
  25. from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
  26. from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
  27. from paddle.nn.initializer import KaimingNormal, Uniform, Constant
  28. class ConvBNLayer(nn.Layer):
  29. def __init__(
  30. self, num_channels, filter_size, num_filters, stride, padding, num_groups=1
  31. ):
  32. super(ConvBNLayer, self).__init__()
  33. self.conv = nn.Conv2D(
  34. in_channels=num_channels,
  35. out_channels=num_filters,
  36. kernel_size=filter_size,
  37. stride=stride,
  38. padding=padding,
  39. groups=num_groups,
  40. bias_attr=False,
  41. )
  42. self.bn = nn.BatchNorm2D(
  43. num_filters,
  44. weight_attr=ParamAttr(initializer=Constant(1)),
  45. bias_attr=ParamAttr(initializer=Constant(0)),
  46. )
  47. self.relu = nn.ReLU()
  48. def forward(self, inputs):
  49. y = self.conv(inputs)
  50. y = self.bn(y)
  51. y = self.relu(y)
  52. return y
  53. class SATRNEncoderLayer(nn.Layer):
  54. def __init__(
  55. self,
  56. d_model=512,
  57. d_inner=512,
  58. n_head=8,
  59. d_k=64,
  60. d_v=64,
  61. dropout=0.1,
  62. qkv_bias=False,
  63. ):
  64. super().__init__()
  65. self.norm1 = nn.LayerNorm(d_model)
  66. self.attn = MultiHeadAttention(
  67. n_head, d_model, d_k, d_v, qkv_bias=qkv_bias, dropout=dropout
  68. )
  69. self.norm2 = nn.LayerNorm(d_model)
  70. self.feed_forward = LocalityAwareFeedforward(d_model, d_inner, dropout=dropout)
  71. def forward(self, x, h, w, mask=None):
  72. n, hw, c = x.shape
  73. residual = x
  74. x = self.norm1(x)
  75. x = residual + self.attn(x, x, x, mask)
  76. residual = x
  77. x = self.norm2(x)
  78. x = x.transpose([0, 2, 1]).reshape([n, c, h, w])
  79. x = self.feed_forward(x)
  80. x = x.reshape([n, c, hw]).transpose([0, 2, 1])
  81. x = residual + x
  82. return x
  83. class LocalityAwareFeedforward(nn.Layer):
  84. def __init__(
  85. self,
  86. d_in,
  87. d_hid,
  88. dropout=0.1,
  89. ):
  90. super().__init__()
  91. self.conv1 = ConvBNLayer(d_in, 1, d_hid, stride=1, padding=0)
  92. self.depthwise_conv = ConvBNLayer(
  93. d_hid, 3, d_hid, stride=1, padding=1, num_groups=d_hid
  94. )
  95. self.conv2 = ConvBNLayer(d_hid, 1, d_in, stride=1, padding=0)
  96. def forward(self, x):
  97. x = self.conv1(x)
  98. x = self.depthwise_conv(x)
  99. x = self.conv2(x)
  100. return x
  101. class Adaptive2DPositionalEncoding(nn.Layer):
  102. def __init__(self, d_hid=512, n_height=100, n_width=100, dropout=0.1):
  103. super().__init__()
  104. h_position_encoder = self._get_sinusoid_encoding_table(n_height, d_hid)
  105. h_position_encoder = h_position_encoder.transpose([1, 0])
  106. h_position_encoder = h_position_encoder.reshape([1, d_hid, n_height, 1])
  107. w_position_encoder = self._get_sinusoid_encoding_table(n_width, d_hid)
  108. w_position_encoder = w_position_encoder.transpose([1, 0])
  109. w_position_encoder = w_position_encoder.reshape([1, d_hid, 1, n_width])
  110. self.register_buffer("h_position_encoder", h_position_encoder)
  111. self.register_buffer("w_position_encoder", w_position_encoder)
  112. self.h_scale = self.scale_factor_generate(d_hid)
  113. self.w_scale = self.scale_factor_generate(d_hid)
  114. self.pool = nn.AdaptiveAvgPool2D(1)
  115. self.dropout = nn.Dropout(p=dropout)
  116. def _get_sinusoid_encoding_table(self, n_position, d_hid):
  117. """Sinusoid position encoding table."""
  118. denominator = paddle.to_tensor(
  119. [1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
  120. )
  121. denominator = denominator.reshape([1, -1])
  122. pos_tensor = paddle.cast(paddle.arange(n_position).unsqueeze(-1), "float32")
  123. sinusoid_table = pos_tensor * denominator
  124. sinusoid_table[:, 0::2] = paddle.sin(sinusoid_table[:, 0::2])
  125. sinusoid_table[:, 1::2] = paddle.cos(sinusoid_table[:, 1::2])
  126. return sinusoid_table
  127. def scale_factor_generate(self, d_hid):
  128. scale_factor = nn.Sequential(
  129. nn.Conv2D(d_hid, d_hid, 1),
  130. nn.ReLU(),
  131. nn.Conv2D(d_hid, d_hid, 1),
  132. nn.Sigmoid(),
  133. )
  134. return scale_factor
  135. def forward(self, x):
  136. b, c, h, w = x.shape
  137. avg_pool = self.pool(x)
  138. h_pos_encoding = self.h_scale(avg_pool) * self.h_position_encoder[:, :, :h, :]
  139. w_pos_encoding = self.w_scale(avg_pool) * self.w_position_encoder[:, :, :, :w]
  140. out = x + h_pos_encoding + w_pos_encoding
  141. out = self.dropout(out)
  142. return out
  143. class ScaledDotProductAttention(nn.Layer):
  144. def __init__(self, temperature, attn_dropout=0.1):
  145. super().__init__()
  146. self.temperature = temperature
  147. self.dropout = nn.Dropout(attn_dropout)
  148. def forward(self, q, k, v, mask=None):
  149. def masked_fill(x, mask, value):
  150. y = paddle.full(x.shape, value, x.dtype)
  151. return paddle.where(mask, y, x)
  152. attn = paddle.matmul(q / self.temperature, k.transpose([0, 1, 3, 2]))
  153. if mask is not None:
  154. attn = masked_fill(attn, mask == 0, -1e9)
  155. # attn = attn.masked_fill(mask == 0, float('-inf'))
  156. # attn += mask
  157. attn = self.dropout(F.softmax(attn, axis=-1))
  158. output = paddle.matmul(attn, v)
  159. return output, attn
  160. class MultiHeadAttention(nn.Layer):
  161. def __init__(
  162. self, n_head=8, d_model=512, d_k=64, d_v=64, dropout=0.1, qkv_bias=False
  163. ):
  164. super().__init__()
  165. self.n_head = n_head
  166. self.d_k = d_k
  167. self.d_v = d_v
  168. self.dim_k = n_head * d_k
  169. self.dim_v = n_head * d_v
  170. self.linear_q = nn.Linear(self.dim_k, self.dim_k, bias_attr=qkv_bias)
  171. self.linear_k = nn.Linear(self.dim_k, self.dim_k, bias_attr=qkv_bias)
  172. self.linear_v = nn.Linear(self.dim_v, self.dim_v, bias_attr=qkv_bias)
  173. self.attention = ScaledDotProductAttention(d_k**0.5, dropout)
  174. self.fc = nn.Linear(self.dim_v, d_model, bias_attr=qkv_bias)
  175. self.proj_drop = nn.Dropout(dropout)
  176. def forward(self, q, k, v, mask=None):
  177. batch_size, len_q, _ = q.shape
  178. _, len_k, _ = k.shape
  179. q = self.linear_q(q).reshape([batch_size, len_q, self.n_head, self.d_k])
  180. k = self.linear_k(k).reshape([batch_size, len_k, self.n_head, self.d_k])
  181. v = self.linear_v(v).reshape([batch_size, len_k, self.n_head, self.d_v])
  182. q, k, v = (
  183. q.transpose([0, 2, 1, 3]),
  184. k.transpose([0, 2, 1, 3]),
  185. v.transpose([0, 2, 1, 3]),
  186. )
  187. if mask is not None:
  188. if mask.dim() == 3:
  189. mask = mask.unsqueeze(1)
  190. elif mask.dim() == 2:
  191. mask = mask.unsqueeze(1).unsqueeze(1)
  192. attn_out, _ = self.attention(q, k, v, mask=mask)
  193. attn_out = attn_out.transpose([0, 2, 1, 3]).reshape(
  194. [batch_size, len_q, self.dim_v]
  195. )
  196. attn_out = self.fc(attn_out)
  197. attn_out = self.proj_drop(attn_out)
  198. return attn_out
  199. class SATRNEncoder(nn.Layer):
  200. def __init__(
  201. self,
  202. n_layers=12,
  203. n_head=8,
  204. d_k=64,
  205. d_v=64,
  206. d_model=512,
  207. n_position=100,
  208. d_inner=256,
  209. dropout=0.1,
  210. ):
  211. super().__init__()
  212. self.d_model = d_model
  213. self.position_enc = Adaptive2DPositionalEncoding(
  214. d_hid=d_model, n_height=n_position, n_width=n_position, dropout=dropout
  215. )
  216. self.layer_stack = nn.LayerList(
  217. [
  218. SATRNEncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
  219. for _ in range(n_layers)
  220. ]
  221. )
  222. self.layer_norm = nn.LayerNorm(d_model)
  223. def forward(self, feat, valid_ratios=None):
  224. """
  225. Args:
  226. feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
  227. img_metas (dict): A dict that contains meta information of input
  228. images. Preferably with the key ``valid_ratio``.
  229. Returns:
  230. Tensor: A tensor of shape :math:`(N, T, D_m)`.
  231. """
  232. if valid_ratios is None:
  233. bs = feat.shape[0]
  234. valid_ratios = paddle.full((bs, 1), 1.0, dtype=paddle.float32)
  235. feat = self.position_enc(feat)
  236. n, c, h, w = feat.shape
  237. mask = paddle.zeros((n, h, w))
  238. for i, valid_ratio in enumerate(valid_ratios):
  239. valid_width = int(min(w, paddle.ceil(w * valid_ratio)))
  240. mask[i, :, :valid_width] = 1
  241. mask = mask.reshape([n, h * w])
  242. feat = feat.reshape([n, c, h * w])
  243. output = feat.transpose([0, 2, 1])
  244. for enc_layer in self.layer_stack:
  245. output = enc_layer(output, h, w, mask)
  246. output = self.layer_norm(output)
  247. return output
  248. class PositionwiseFeedForward(nn.Layer):
  249. def __init__(self, d_in, d_hid, dropout=0.1):
  250. super().__init__()
  251. self.w_1 = nn.Linear(d_in, d_hid)
  252. self.w_2 = nn.Linear(d_hid, d_in)
  253. self.act = nn.GELU()
  254. self.dropout = nn.Dropout(dropout)
  255. def forward(self, x):
  256. x = self.w_1(x)
  257. x = self.act(x)
  258. x = self.w_2(x)
  259. x = self.dropout(x)
  260. return x
  261. class PositionalEncoding(nn.Layer):
  262. def __init__(self, d_hid=512, n_position=200, dropout=0):
  263. super().__init__()
  264. self.dropout = nn.Dropout(p=dropout)
  265. # Not a parameter
  266. # Position table of shape (1, n_position, d_hid)
  267. self.register_buffer(
  268. "position_table", self._get_sinusoid_encoding_table(n_position, d_hid)
  269. )
  270. def _get_sinusoid_encoding_table(self, n_position, d_hid):
  271. """Sinusoid position encoding table."""
  272. denominator = paddle.to_tensor(
  273. [1.0 / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
  274. )
  275. denominator = denominator.reshape([1, -1])
  276. pos_tensor = paddle.cast(paddle.arange(n_position).unsqueeze(-1), "float32")
  277. sinusoid_table = pos_tensor * denominator
  278. sinusoid_table[:, 0::2] = paddle.sin(sinusoid_table[:, 0::2])
  279. sinusoid_table[:, 1::2] = paddle.cos(sinusoid_table[:, 1::2])
  280. return sinusoid_table.unsqueeze(0)
  281. def forward(self, x):
  282. x = x + self.position_table[:, : x.shape[1]].clone().detach()
  283. return self.dropout(x)
  284. class TFDecoderLayer(nn.Layer):
  285. def __init__(
  286. self,
  287. d_model=512,
  288. d_inner=256,
  289. n_head=8,
  290. d_k=64,
  291. d_v=64,
  292. dropout=0.1,
  293. qkv_bias=False,
  294. operation_order=None,
  295. ):
  296. super().__init__()
  297. self.norm1 = nn.LayerNorm(d_model)
  298. self.norm2 = nn.LayerNorm(d_model)
  299. self.norm3 = nn.LayerNorm(d_model)
  300. self.self_attn = MultiHeadAttention(
  301. n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias
  302. )
  303. self.enc_attn = MultiHeadAttention(
  304. n_head, d_model, d_k, d_v, dropout=dropout, qkv_bias=qkv_bias
  305. )
  306. self.mlp = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
  307. self.operation_order = operation_order
  308. if self.operation_order is None:
  309. self.operation_order = (
  310. "norm",
  311. "self_attn",
  312. "norm",
  313. "enc_dec_attn",
  314. "norm",
  315. "ffn",
  316. )
  317. assert self.operation_order in [
  318. ("norm", "self_attn", "norm", "enc_dec_attn", "norm", "ffn"),
  319. ("self_attn", "norm", "enc_dec_attn", "norm", "ffn", "norm"),
  320. ]
  321. def forward(
  322. self, dec_input, enc_output, self_attn_mask=None, dec_enc_attn_mask=None
  323. ):
  324. if self.operation_order == (
  325. "self_attn",
  326. "norm",
  327. "enc_dec_attn",
  328. "norm",
  329. "ffn",
  330. "norm",
  331. ):
  332. dec_attn_out = self.self_attn(
  333. dec_input, dec_input, dec_input, self_attn_mask
  334. )
  335. dec_attn_out += dec_input
  336. dec_attn_out = self.norm1(dec_attn_out)
  337. enc_dec_attn_out = self.enc_attn(
  338. dec_attn_out, enc_output, enc_output, dec_enc_attn_mask
  339. )
  340. enc_dec_attn_out += dec_attn_out
  341. enc_dec_attn_out = self.norm2(enc_dec_attn_out)
  342. mlp_out = self.mlp(enc_dec_attn_out)
  343. mlp_out += enc_dec_attn_out
  344. mlp_out = self.norm3(mlp_out)
  345. elif self.operation_order == (
  346. "norm",
  347. "self_attn",
  348. "norm",
  349. "enc_dec_attn",
  350. "norm",
  351. "ffn",
  352. ):
  353. dec_input_norm = self.norm1(dec_input)
  354. dec_attn_out = self.self_attn(
  355. dec_input_norm, dec_input_norm, dec_input_norm, self_attn_mask
  356. )
  357. dec_attn_out += dec_input
  358. enc_dec_attn_in = self.norm2(dec_attn_out)
  359. enc_dec_attn_out = self.enc_attn(
  360. enc_dec_attn_in, enc_output, enc_output, dec_enc_attn_mask
  361. )
  362. enc_dec_attn_out += dec_attn_out
  363. mlp_out = self.mlp(self.norm3(enc_dec_attn_out))
  364. mlp_out += enc_dec_attn_out
  365. return mlp_out
  366. class SATRNDecoder(nn.Layer):
  367. def __init__(
  368. self,
  369. n_layers=6,
  370. d_embedding=512,
  371. n_head=8,
  372. d_k=64,
  373. d_v=64,
  374. d_model=512,
  375. d_inner=256,
  376. n_position=200,
  377. dropout=0.1,
  378. num_classes=93,
  379. max_seq_len=40,
  380. start_idx=1,
  381. padding_idx=92,
  382. ):
  383. super().__init__()
  384. self.padding_idx = padding_idx
  385. self.start_idx = start_idx
  386. self.max_seq_len = max_seq_len
  387. self.trg_word_emb = nn.Embedding(
  388. num_classes, d_embedding, padding_idx=padding_idx
  389. )
  390. self.position_enc = PositionalEncoding(d_embedding, n_position=n_position)
  391. self.dropout = nn.Dropout(p=dropout)
  392. self.layer_stack = nn.LayerList(
  393. [
  394. TFDecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
  395. for _ in range(n_layers)
  396. ]
  397. )
  398. self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
  399. pred_num_class = num_classes - 1 # ignore padding_idx
  400. self.classifier = nn.Linear(d_model, pred_num_class)
  401. @staticmethod
  402. def get_pad_mask(seq, pad_idx):
  403. return (seq != pad_idx).unsqueeze(-2)
  404. @staticmethod
  405. def get_subsequent_mask(seq):
  406. """For masking out the subsequent info."""
  407. len_s = seq.shape[1]
  408. subsequent_mask = 1 - paddle.triu(paddle.ones((len_s, len_s)), diagonal=1)
  409. subsequent_mask = paddle.cast(subsequent_mask.unsqueeze(0), "bool")
  410. return subsequent_mask
  411. def _attention(self, trg_seq, src, src_mask=None):
  412. trg_embedding = self.trg_word_emb(trg_seq)
  413. trg_pos_encoded = self.position_enc(trg_embedding)
  414. tgt = self.dropout(trg_pos_encoded)
  415. trg_mask = self.get_pad_mask(
  416. trg_seq, pad_idx=self.padding_idx
  417. ) & self.get_subsequent_mask(trg_seq)
  418. output = tgt
  419. for dec_layer in self.layer_stack:
  420. output = dec_layer(
  421. output, src, self_attn_mask=trg_mask, dec_enc_attn_mask=src_mask
  422. )
  423. output = self.layer_norm(output)
  424. return output
  425. def _get_mask(self, logit, valid_ratios):
  426. N, T, _ = logit.shape
  427. mask = None
  428. if valid_ratios is not None:
  429. mask = paddle.zeros((N, T))
  430. for i, valid_ratio in enumerate(valid_ratios):
  431. valid_width = min(T, math.ceil(T * valid_ratio))
  432. mask[i, :valid_width] = 1
  433. return mask
  434. def forward_train(self, feat, out_enc, targets, valid_ratio):
  435. src_mask = self._get_mask(out_enc, valid_ratio)
  436. attn_output = self._attention(targets, out_enc, src_mask=src_mask)
  437. outputs = self.classifier(attn_output)
  438. return outputs
  439. def forward_test(self, feat, out_enc, valid_ratio):
  440. src_mask = self._get_mask(out_enc, valid_ratio)
  441. N = out_enc.shape[0]
  442. init_target_seq = paddle.full(
  443. (N, self.max_seq_len + 1), self.padding_idx, dtype="int64"
  444. )
  445. # bsz * seq_len
  446. init_target_seq[:, 0] = self.start_idx
  447. outputs = []
  448. for step in range(0, paddle.to_tensor(self.max_seq_len)):
  449. decoder_output = self._attention(
  450. init_target_seq, out_enc, src_mask=src_mask
  451. )
  452. # bsz * seq_len * C
  453. step_result = F.softmax(
  454. self.classifier(decoder_output[:, step, :]), axis=-1
  455. )
  456. # bsz * num_classes
  457. outputs.append(step_result)
  458. step_max_index = paddle.argmax(step_result, axis=-1)
  459. init_target_seq[:, step + 1] = step_max_index
  460. outputs = paddle.stack(outputs, axis=1)
  461. return outputs
  462. def forward(self, feat, out_enc, targets=None, valid_ratio=None):
  463. if self.training:
  464. return self.forward_train(feat, out_enc, targets, valid_ratio)
  465. else:
  466. return self.forward_test(feat, out_enc, valid_ratio)
  467. class SATRNHead(nn.Layer):
  468. def __init__(self, enc_cfg, dec_cfg, **kwargs):
  469. super(SATRNHead, self).__init__()
  470. # encoder module
  471. self.encoder = SATRNEncoder(**enc_cfg)
  472. # decoder module
  473. self.decoder = SATRNDecoder(**dec_cfg)
  474. def forward(self, feat, targets=None):
  475. if targets is not None:
  476. targets, valid_ratio = targets
  477. else:
  478. targets, valid_ratio = None, None
  479. holistic_feat = self.encoder(feat, valid_ratio) # bsz c
  480. final_out = self.decoder(feat, holistic_feat, targets, valid_ratio)
  481. return final_out