rec_robustscanner_head.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/channel_reduction_encoder.py
  17. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/robust_scanner_decoder.py
  18. """
  19. from __future__ import absolute_import
  20. from __future__ import division
  21. from __future__ import print_function
  22. import math
  23. import paddle
  24. from paddle import ParamAttr
  25. import paddle.nn as nn
  26. import paddle.nn.functional as F
  27. class BaseDecoder(nn.Layer):
  28. def __init__(self, **kwargs):
  29. super().__init__()
  30. def forward_train(self, feat, out_enc, targets, img_metas):
  31. raise NotImplementedError
  32. def forward_test(self, feat, out_enc, img_metas):
  33. raise NotImplementedError
  34. def forward(
  35. self,
  36. feat,
  37. out_enc,
  38. label=None,
  39. valid_ratios=None,
  40. word_positions=None,
  41. train_mode=True,
  42. ):
  43. self.train_mode = train_mode
  44. if train_mode:
  45. return self.forward_train(
  46. feat, out_enc, label, valid_ratios, word_positions
  47. )
  48. return self.forward_test(feat, out_enc, valid_ratios, word_positions)
  49. class ChannelReductionEncoder(nn.Layer):
  50. """Change the channel number with a one by one convoluational layer.
  51. Args:
  52. in_channels (int): Number of input channels.
  53. out_channels (int): Number of output channels.
  54. """
  55. def __init__(self, in_channels, out_channels, **kwargs):
  56. super(ChannelReductionEncoder, self).__init__()
  57. self.layer = nn.Conv2D(
  58. in_channels,
  59. out_channels,
  60. kernel_size=1,
  61. stride=1,
  62. padding=0,
  63. weight_attr=nn.initializer.XavierNormal(),
  64. )
  65. def forward(self, feat):
  66. """
  67. Args:
  68. feat (Tensor): Image features with the shape of
  69. :math:`(N, C_{in}, H, W)`.
  70. Returns:
  71. Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`.
  72. """
  73. return self.layer(feat)
  74. def masked_fill(x, mask, value):
  75. y = paddle.full(x.shape, value, x.dtype)
  76. return paddle.where(mask, y, x)
  77. class DotProductAttentionLayer(nn.Layer):
  78. def __init__(self, dim_model=None):
  79. super().__init__()
  80. self.scale = dim_model**-0.5 if dim_model is not None else 1.0
  81. def forward(self, query, key, value, h, w, valid_ratios=None):
  82. query = paddle.transpose(query, (0, 2, 1))
  83. logits = paddle.matmul(query, key) * self.scale
  84. n, c, t = logits.shape
  85. # reshape to (n, c, h, w)
  86. logits = paddle.reshape(logits, [n, c, h, w])
  87. if valid_ratios is not None:
  88. # cal mask of attention weight
  89. with paddle.base.framework._stride_in_no_check_dy2st_diff():
  90. for i, valid_ratio in enumerate(valid_ratios):
  91. valid_width = min(w, int(w * valid_ratio + 0.5))
  92. if valid_width < w:
  93. logits[i, :, :, valid_width:] = float("-inf")
  94. # reshape to (n, c, h, w)
  95. logits = paddle.reshape(logits, [n, c, t])
  96. weights = F.softmax(logits, axis=2)
  97. value = paddle.transpose(value, (0, 2, 1))
  98. glimpse = paddle.matmul(weights, value)
  99. glimpse = paddle.transpose(glimpse, (0, 2, 1))
  100. return glimpse
  101. class SequenceAttentionDecoder(BaseDecoder):
  102. """Sequence attention decoder for RobustScanner.
  103. RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
  104. Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_
  105. Args:
  106. num_classes (int): Number of output classes :math:`C`.
  107. rnn_layers (int): Number of RNN layers.
  108. dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
  109. dim_model (int): Dimension :math:`D_m` of the model. Should also be the
  110. same as encoder output vector ``out_enc``.
  111. max_seq_len (int): Maximum output sequence length :math:`T`.
  112. start_idx (int): The index of `<SOS>`.
  113. mask (bool): Whether to mask input features according to
  114. ``img_meta['valid_ratio']``.
  115. padding_idx (int): The index of `<PAD>`.
  116. dropout (float): Dropout rate.
  117. return_feature (bool): Return feature or logits as the result.
  118. encode_value (bool): Whether to use the output of encoder ``out_enc``
  119. as `value` of attention layer. If False, the original feature
  120. ``feat`` will be used.
  121. Warning:
  122. This decoder will not predict the final class which is assumed to be
  123. `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
  124. is also ignored by loss as specified in
  125. :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
  126. """
  127. def __init__(
  128. self,
  129. num_classes=None,
  130. rnn_layers=2,
  131. dim_input=512,
  132. dim_model=128,
  133. max_seq_len=40,
  134. start_idx=0,
  135. mask=True,
  136. padding_idx=None,
  137. dropout=0,
  138. return_feature=False,
  139. encode_value=False,
  140. ):
  141. super().__init__()
  142. self.num_classes = num_classes
  143. self.dim_input = dim_input
  144. self.dim_model = dim_model
  145. self.return_feature = return_feature
  146. self.encode_value = encode_value
  147. self.max_seq_len = max_seq_len
  148. self.start_idx = start_idx
  149. self.mask = mask
  150. self.embedding = nn.Embedding(
  151. self.num_classes, self.dim_model, padding_idx=padding_idx
  152. )
  153. self.sequence_layer = nn.LSTM(
  154. input_size=dim_model,
  155. hidden_size=dim_model,
  156. num_layers=rnn_layers,
  157. time_major=False,
  158. dropout=dropout,
  159. )
  160. self.attention_layer = DotProductAttentionLayer()
  161. self.prediction = None
  162. if not self.return_feature:
  163. pred_num_classes = num_classes - 1
  164. self.prediction = nn.Linear(
  165. dim_model if encode_value else dim_input, pred_num_classes
  166. )
  167. def forward_train(self, feat, out_enc, targets, valid_ratios):
  168. """
  169. Args:
  170. feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
  171. out_enc (Tensor): Encoder output of shape
  172. :math:`(N, D_m, H, W)`.
  173. targets (Tensor): a tensor of shape :math:`(N, T)`. Each element is the index of a
  174. character.
  175. valid_ratios (Tensor): valid length ratio of img.
  176. Returns:
  177. Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
  178. ``return_feature=False``. Otherwise it would be the hidden feature
  179. before the prediction projection layer, whose shape is
  180. :math:`(N, T, D_m)`.
  181. """
  182. tgt_embedding = self.embedding(targets)
  183. n, c_enc, h, w = out_enc.shape
  184. assert c_enc == self.dim_model
  185. _, c_feat, _, _ = feat.shape
  186. assert c_feat == self.dim_input
  187. _, len_q, c_q = tgt_embedding.shape
  188. assert c_q == self.dim_model
  189. assert len_q <= self.max_seq_len
  190. query, _ = self.sequence_layer(tgt_embedding)
  191. query = paddle.transpose(query, (0, 2, 1))
  192. key = paddle.reshape(out_enc, [n, c_enc, h * w])
  193. if self.encode_value:
  194. value = key
  195. else:
  196. value = paddle.reshape(feat, [n, c_feat, h * w])
  197. attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
  198. attn_out = paddle.transpose(attn_out, (0, 2, 1))
  199. if self.return_feature:
  200. return attn_out
  201. out = self.prediction(attn_out)
  202. return out
  203. def forward_test(self, feat, out_enc, valid_ratios):
  204. """
  205. Args:
  206. feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
  207. out_enc (Tensor): Encoder output of shape
  208. :math:`(N, D_m, H, W)`.
  209. valid_ratios (Tensor): valid length ratio of img.
  210. Returns:
  211. Tensor: The output logit sequence tensor of shape
  212. :math:`(N, T, C-1)`.
  213. """
  214. seq_len = self.max_seq_len
  215. batch_size = feat.shape[0]
  216. decode_sequence = (
  217. paddle.ones((batch_size, seq_len), dtype="int64") * self.start_idx
  218. )
  219. outputs = []
  220. for i in range(seq_len):
  221. step_out = self.forward_test_step(
  222. feat, out_enc, decode_sequence, i, valid_ratios
  223. )
  224. outputs.append(step_out)
  225. max_idx = paddle.argmax(step_out, axis=1, keepdim=False)
  226. if i < seq_len - 1:
  227. decode_sequence[:, i + 1] = max_idx
  228. outputs = paddle.stack(outputs, 1)
  229. return outputs
  230. def forward_test_step(
  231. self, feat, out_enc, decode_sequence, current_step, valid_ratios
  232. ):
  233. """
  234. Args:
  235. feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
  236. out_enc (Tensor): Encoder output of shape
  237. :math:`(N, D_m, H, W)`.
  238. decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
  239. stores history decoding result.
  240. current_step (int): Current decoding step.
  241. valid_ratios (Tensor): valid length ratio of img
  242. Returns:
  243. Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
  244. tokens at current time step.
  245. """
  246. embed = self.embedding(decode_sequence)
  247. n, c_enc, h, w = out_enc.shape
  248. assert c_enc == self.dim_model
  249. _, c_feat, _, _ = feat.shape
  250. assert c_feat == self.dim_input
  251. _, _, c_q = embed.shape
  252. assert c_q == self.dim_model
  253. query, _ = self.sequence_layer(embed)
  254. query = paddle.transpose(query, (0, 2, 1))
  255. key = paddle.reshape(out_enc, [n, c_enc, h * w])
  256. if self.encode_value:
  257. value = key
  258. else:
  259. value = paddle.reshape(feat, [n, c_feat, h * w])
  260. # [n, c, l]
  261. attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
  262. out = attn_out[:, :, current_step]
  263. if self.return_feature:
  264. return out
  265. out = self.prediction(out)
  266. out = F.softmax(out, dim=-1)
  267. return out
  268. class PositionAwareLayer(nn.Layer):
  269. def __init__(self, dim_model, rnn_layers=2):
  270. super().__init__()
  271. self.dim_model = dim_model
  272. self.rnn = nn.LSTM(
  273. input_size=dim_model,
  274. hidden_size=dim_model,
  275. num_layers=rnn_layers,
  276. time_major=False,
  277. )
  278. self.mixer = nn.Sequential(
  279. nn.Conv2D(dim_model, dim_model, kernel_size=3, stride=1, padding=1),
  280. nn.ReLU(),
  281. nn.Conv2D(dim_model, dim_model, kernel_size=3, stride=1, padding=1),
  282. )
  283. def forward(self, img_feature):
  284. n, c, h, w = img_feature.shape
  285. rnn_input = paddle.transpose(img_feature, (0, 2, 3, 1))
  286. rnn_input = paddle.reshape(rnn_input, (n * h, w, c))
  287. rnn_output, _ = self.rnn(rnn_input)
  288. rnn_output = paddle.reshape(rnn_output, (n, h, w, c))
  289. rnn_output = paddle.transpose(rnn_output, (0, 3, 1, 2))
  290. out = self.mixer(rnn_output)
  291. return out
  292. class PositionAttentionDecoder(BaseDecoder):
  293. """Position attention decoder for RobustScanner.
  294. RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
  295. Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_
  296. Args:
  297. num_classes (int): Number of output classes :math:`C`.
  298. rnn_layers (int): Number of RNN layers.
  299. dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
  300. dim_model (int): Dimension :math:`D_m` of the model. Should also be the
  301. same as encoder output vector ``out_enc``.
  302. max_seq_len (int): Maximum output sequence length :math:`T`.
  303. mask (bool): Whether to mask input features according to
  304. ``img_meta['valid_ratio']``.
  305. return_feature (bool): Return feature or logits as the result.
  306. encode_value (bool): Whether to use the output of encoder ``out_enc``
  307. as `value` of attention layer. If False, the original feature
  308. ``feat`` will be used.
  309. Warning:
  310. This decoder will not predict the final class which is assumed to be
  311. `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
  312. is also ignored by loss
  313. """
  314. def __init__(
  315. self,
  316. num_classes=None,
  317. rnn_layers=2,
  318. dim_input=512,
  319. dim_model=128,
  320. max_seq_len=40,
  321. mask=True,
  322. return_feature=False,
  323. encode_value=False,
  324. ):
  325. super().__init__()
  326. self.num_classes = num_classes
  327. self.dim_input = dim_input
  328. self.dim_model = dim_model
  329. self.max_seq_len = max_seq_len
  330. self.return_feature = return_feature
  331. self.encode_value = encode_value
  332. self.mask = mask
  333. self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)
  334. self.position_aware_module = PositionAwareLayer(self.dim_model, rnn_layers)
  335. self.attention_layer = DotProductAttentionLayer()
  336. self.prediction = None
  337. if not self.return_feature:
  338. pred_num_classes = num_classes - 1
  339. self.prediction = nn.Linear(
  340. dim_model if encode_value else dim_input, pred_num_classes
  341. )
  342. def _get_position_index(self, length, batch_size):
  343. position_index_list = []
  344. for i in range(batch_size):
  345. position_index = paddle.arange(0, end=length, step=1, dtype="int64")
  346. position_index_list.append(position_index)
  347. batch_position_index = paddle.stack(position_index_list, axis=0)
  348. return batch_position_index
  349. def forward_train(self, feat, out_enc, targets, valid_ratios, position_index):
  350. """
  351. Args:
  352. feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
  353. out_enc (Tensor): Encoder output of shape
  354. :math:`(N, D_m, H, W)`.
  355. targets (dict): A dict with the key ``padded_targets``, a
  356. tensor of shape :math:`(N, T)`. Each element is the index of a
  357. character.
  358. valid_ratios (Tensor): valid length ratio of img.
  359. position_index (Tensor): The position of each word.
  360. Returns:
  361. Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
  362. ``return_feature=False``. Otherwise it will be the hidden feature
  363. before the prediction projection layer, whose shape is
  364. :math:`(N, T, D_m)`.
  365. """
  366. n, c_enc, h, w = out_enc.shape
  367. assert c_enc == self.dim_model
  368. _, c_feat, _, _ = feat.shape
  369. assert c_feat == self.dim_input
  370. _, len_q = targets.shape
  371. assert len_q <= self.max_seq_len
  372. position_out_enc = self.position_aware_module(out_enc)
  373. query = self.embedding(position_index)
  374. query = paddle.transpose(query, (0, 2, 1))
  375. key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
  376. if self.encode_value:
  377. value = paddle.reshape(out_enc, (n, c_enc, h * w))
  378. else:
  379. value = paddle.reshape(feat, (n, c_feat, h * w))
  380. attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
  381. attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
  382. if self.return_feature:
  383. return attn_out
  384. return self.prediction(attn_out)
  385. def forward_test(self, feat, out_enc, valid_ratios, position_index):
  386. """
  387. Args:
  388. feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
  389. out_enc (Tensor): Encoder output of shape
  390. :math:`(N, D_m, H, W)`.
  391. valid_ratios (Tensor): valid length ratio of img
  392. position_index (Tensor): The position of each word.
  393. Returns:
  394. Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
  395. ``return_feature=False``. Otherwise it would be the hidden feature
  396. before the prediction projection layer, whose shape is
  397. :math:`(N, T, D_m)`.
  398. """
  399. n, c_enc, h, w = out_enc.shape
  400. assert c_enc == self.dim_model
  401. _, c_feat, _, _ = feat.shape
  402. assert c_feat == self.dim_input
  403. position_out_enc = self.position_aware_module(out_enc)
  404. query = self.embedding(position_index)
  405. query = paddle.transpose(query, (0, 2, 1))
  406. key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
  407. if self.encode_value:
  408. value = paddle.reshape(out_enc, (n, c_enc, h * w))
  409. else:
  410. value = paddle.reshape(feat, (n, c_feat, h * w))
  411. attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
  412. attn_out = paddle.transpose(attn_out, (0, 2, 1)) # [n, len_q, dim_v]
  413. if self.return_feature:
  414. return attn_out
  415. return self.prediction(attn_out)
  416. class RobustScannerFusionLayer(nn.Layer):
  417. def __init__(self, dim_model, dim=-1):
  418. super(RobustScannerFusionLayer, self).__init__()
  419. self.dim_model = dim_model
  420. self.dim = dim
  421. self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)
  422. def forward(self, x0, x1):
  423. assert x0.shape == x1.shape
  424. fusion_input = paddle.concat([x0, x1], self.dim)
  425. output = self.linear_layer(fusion_input)
  426. output = F.glu(output, self.dim)
  427. return output
  428. class RobustScannerDecoder(BaseDecoder):
  429. """Decoder for RobustScanner.
  430. RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
  431. Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_
  432. Args:
  433. num_classes (int): Number of output classes :math:`C`.
  434. dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
  435. dim_model (int): Dimension :math:`D_m` of the model. Should also be the
  436. same as encoder output vector ``out_enc``.
  437. max_seq_len (int): Maximum output sequence length :math:`T`.
  438. start_idx (int): The index of `<SOS>`.
  439. mask (bool): Whether to mask input features according to
  440. ``img_meta['valid_ratio']``.
  441. padding_idx (int): The index of `<PAD>`.
  442. encode_value (bool): Whether to use the output of encoder ``out_enc``
  443. as `value` of attention layer. If False, the original feature
  444. ``feat`` will be used.
  445. Warning:
  446. This decoder will not predict the final class which is assumed to be
  447. `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
  448. is also ignored by loss as specified in
  449. :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
  450. """
  451. def __init__(
  452. self,
  453. num_classes=None,
  454. dim_input=512,
  455. dim_model=128,
  456. hybrid_decoder_rnn_layers=2,
  457. hybrid_decoder_dropout=0,
  458. position_decoder_rnn_layers=2,
  459. max_seq_len=40,
  460. start_idx=0,
  461. mask=True,
  462. padding_idx=None,
  463. encode_value=False,
  464. ):
  465. super().__init__()
  466. self.num_classes = num_classes
  467. self.dim_input = dim_input
  468. self.dim_model = dim_model
  469. self.max_seq_len = max_seq_len
  470. self.encode_value = encode_value
  471. self.start_idx = start_idx
  472. self.padding_idx = padding_idx
  473. self.mask = mask
  474. # init hybrid decoder
  475. self.hybrid_decoder = SequenceAttentionDecoder(
  476. num_classes=num_classes,
  477. rnn_layers=hybrid_decoder_rnn_layers,
  478. dim_input=dim_input,
  479. dim_model=dim_model,
  480. max_seq_len=max_seq_len,
  481. start_idx=start_idx,
  482. mask=mask,
  483. padding_idx=padding_idx,
  484. dropout=hybrid_decoder_dropout,
  485. encode_value=encode_value,
  486. return_feature=True,
  487. )
  488. # init position decoder
  489. self.position_decoder = PositionAttentionDecoder(
  490. num_classes=num_classes,
  491. rnn_layers=position_decoder_rnn_layers,
  492. dim_input=dim_input,
  493. dim_model=dim_model,
  494. max_seq_len=max_seq_len,
  495. mask=mask,
  496. encode_value=encode_value,
  497. return_feature=True,
  498. )
  499. self.fusion_module = RobustScannerFusionLayer(
  500. self.dim_model if encode_value else dim_input
  501. )
  502. pred_num_classes = num_classes - 1
  503. self.prediction = nn.Linear(
  504. dim_model if encode_value else dim_input, pred_num_classes
  505. )
  506. def forward_train(self, feat, out_enc, target, valid_ratios, word_positions):
  507. """
  508. Args:
  509. feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
  510. out_enc (Tensor): Encoder output of shape
  511. :math:`(N, D_m, H, W)`.
  512. target (dict): A dict with the key ``padded_targets``, a
  513. tensor of shape :math:`(N, T)`. Each element is the index of a
  514. character.
  515. valid_ratios (Tensor):
  516. word_positions (Tensor): The position of each word.
  517. Returns:
  518. Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
  519. """
  520. hybrid_glimpse = self.hybrid_decoder.forward_train(
  521. feat, out_enc, target, valid_ratios
  522. )
  523. position_glimpse = self.position_decoder.forward_train(
  524. feat, out_enc, target, valid_ratios, word_positions
  525. )
  526. fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)
  527. out = self.prediction(fusion_out)
  528. return out
  529. def forward_test(self, feat, out_enc, valid_ratios, word_positions):
  530. """
  531. Args:
  532. feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
  533. out_enc (Tensor): Encoder output of shape
  534. :math:`(N, D_m, H, W)`.
  535. valid_ratios (Tensor):
  536. word_positions (Tensor): The position of each word.
  537. Returns:
  538. Tensor: The output logit sequence tensor of shape
  539. :math:`(N, T, C-1)`.
  540. """
  541. seq_len = self.max_seq_len
  542. batch_size = feat.shape[0]
  543. decode_sequence = (
  544. paddle.ones((batch_size, seq_len), dtype="int64") * self.start_idx
  545. )
  546. position_glimpse = self.position_decoder.forward_test(
  547. feat, out_enc, valid_ratios, word_positions
  548. )
  549. outputs = []
  550. for i in range(seq_len):
  551. hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
  552. feat, out_enc, decode_sequence, i, valid_ratios
  553. )
  554. fusion_out = self.fusion_module(
  555. hybrid_glimpse_step, position_glimpse[:, i, :]
  556. )
  557. char_out = self.prediction(fusion_out)
  558. char_out = F.softmax(char_out, -1)
  559. outputs.append(char_out)
  560. max_idx = paddle.argmax(char_out, axis=1, keepdim=False)
  561. if i < seq_len - 1:
  562. decode_sequence[:, i + 1] = max_idx
  563. outputs = paddle.stack(outputs, 1)
  564. return outputs
  565. class RobustScannerHead(nn.Layer):
  566. def __init__(
  567. self,
  568. out_channels, # 90 + unknown + start + padding
  569. in_channels,
  570. enc_outchannles=128,
  571. hybrid_dec_rnn_layers=2,
  572. hybrid_dec_dropout=0,
  573. position_dec_rnn_layers=2,
  574. start_idx=0,
  575. max_text_length=40,
  576. mask=True,
  577. padding_idx=None,
  578. encode_value=False,
  579. **kwargs,
  580. ):
  581. super(RobustScannerHead, self).__init__()
  582. # encoder module
  583. self.encoder = ChannelReductionEncoder(
  584. in_channels=in_channels, out_channels=enc_outchannles
  585. )
  586. # decoder module
  587. self.decoder = RobustScannerDecoder(
  588. num_classes=out_channels,
  589. dim_input=in_channels,
  590. dim_model=enc_outchannles,
  591. hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
  592. hybrid_decoder_dropout=hybrid_dec_dropout,
  593. position_decoder_rnn_layers=position_dec_rnn_layers,
  594. max_seq_len=max_text_length,
  595. start_idx=start_idx,
  596. mask=mask,
  597. padding_idx=padding_idx,
  598. encode_value=encode_value,
  599. )
  600. def forward(self, inputs, targets=None):
  601. """
  602. targets: [label, valid_ratio, word_positions]
  603. """
  604. out_enc = self.encoder(inputs)
  605. valid_ratios = None
  606. word_positions = targets[-1]
  607. if len(targets) > 1:
  608. valid_ratios = targets[-2]
  609. if self.training:
  610. label = targets[0] # label
  611. label = paddle.to_tensor(label, dtype="int64")
  612. final_out = self.decoder(
  613. inputs, out_enc, label, valid_ratios, word_positions
  614. )
  615. if not self.training:
  616. final_out = self.decoder(
  617. inputs,
  618. out_enc,
  619. label=None,
  620. valid_ratios=valid_ratios,
  621. word_positions=word_positions,
  622. train_mode=False,
  623. )
  624. return final_out