rec_sar_head.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
  17. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
  18. """
  19. from __future__ import absolute_import
  20. from __future__ import division
  21. from __future__ import print_function
  22. import math
  23. import paddle
  24. from paddle import ParamAttr
  25. import paddle.nn as nn
  26. import paddle.nn.functional as F
  27. class SAREncoder(nn.Layer):
  28. """
  29. Args:
  30. enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
  31. enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
  32. enc_gru (bool): If True, use GRU, else LSTM in encoder.
  33. d_model (int): Dim of channels from backbone.
  34. d_enc (int): Dim of encoder RNN layer.
  35. mask (bool): If True, mask padding in RNN sequence.
  36. """
  37. def __init__(
  38. self,
  39. enc_bi_rnn=False,
  40. enc_drop_rnn=0.1,
  41. enc_gru=False,
  42. d_model=512,
  43. d_enc=512,
  44. mask=True,
  45. **kwargs,
  46. ):
  47. super().__init__()
  48. assert isinstance(enc_bi_rnn, bool)
  49. assert isinstance(enc_drop_rnn, (int, float))
  50. assert 0 <= enc_drop_rnn < 1.0
  51. assert isinstance(enc_gru, bool)
  52. assert isinstance(d_model, int)
  53. assert isinstance(d_enc, int)
  54. assert isinstance(mask, bool)
  55. self.enc_bi_rnn = enc_bi_rnn
  56. self.enc_drop_rnn = enc_drop_rnn
  57. self.mask = mask
  58. # LSTM Encoder
  59. if enc_bi_rnn:
  60. direction = "bidirectional"
  61. else:
  62. direction = "forward"
  63. kwargs = dict(
  64. input_size=d_model,
  65. hidden_size=d_enc,
  66. num_layers=2,
  67. time_major=False,
  68. dropout=enc_drop_rnn,
  69. direction=direction,
  70. )
  71. if enc_gru:
  72. self.rnn_encoder = nn.GRU(**kwargs)
  73. else:
  74. self.rnn_encoder = nn.LSTM(**kwargs)
  75. # global feature transformation
  76. encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
  77. self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
  78. def forward(self, feat, img_metas=None):
  79. if img_metas is not None:
  80. assert len(img_metas[0]) == feat.shape[0]
  81. valid_ratios = None
  82. if img_metas is not None and self.mask:
  83. valid_ratios = img_metas[-1]
  84. h_feat = feat.shape[2] # bsz c h w
  85. feat_v = F.max_pool2d(feat, kernel_size=(h_feat, 1), stride=1, padding=0)
  86. feat_v = feat_v.squeeze(2) # bsz * C * W
  87. feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C
  88. holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
  89. if valid_ratios is not None:
  90. valid_hf = []
  91. T = paddle.shape(holistic_feat)[1]
  92. for i in range(valid_ratios.shape[0]):
  93. valid_step = (
  94. paddle.minimum(T, paddle.ceil(valid_ratios[i] * T).astype(T.dtype))
  95. - 1
  96. )
  97. valid_hf.append(holistic_feat[i, valid_step, :])
  98. valid_hf = paddle.stack(valid_hf, axis=0)
  99. else:
  100. valid_hf = holistic_feat[:, -1, :] # bsz * C
  101. holistic_feat = self.linear(valid_hf) # bsz * C
  102. return holistic_feat
  103. class BaseDecoder(nn.Layer):
  104. def __init__(self, **kwargs):
  105. super().__init__()
  106. def forward_train(self, feat, out_enc, targets, img_metas):
  107. raise NotImplementedError
  108. def forward_test(self, feat, out_enc, img_metas):
  109. raise NotImplementedError
  110. def forward(self, feat, out_enc, label=None, img_metas=None, train_mode=True):
  111. self.train_mode = train_mode
  112. if train_mode:
  113. return self.forward_train(feat, out_enc, label, img_metas)
  114. return self.forward_test(feat, out_enc, img_metas)
  115. class ParallelSARDecoder(BaseDecoder):
  116. """
  117. Args:
  118. out_channels (int): Output class number.
  119. enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
  120. dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
  121. dec_drop_rnn (float): Dropout of RNN layer in decoder.
  122. dec_gru (bool): If True, use GRU, else LSTM in decoder.
  123. d_model (int): Dim of channels from backbone.
  124. d_enc (int): Dim of encoder RNN layer.
  125. d_k (int): Dim of channels of attention module.
  126. pred_dropout (float): Dropout probability of prediction layer.
  127. max_seq_len (int): Maximum sequence length for decoding.
  128. mask (bool): If True, mask padding in feature map.
  129. start_idx (int): Index of start token.
  130. padding_idx (int): Index of padding token.
  131. pred_concat (bool): If True, concat glimpse feature from
  132. attention with holistic feature and hidden state.
  133. """
  134. def __init__(
  135. self,
  136. out_channels, # 90 + unknown + start + padding
  137. enc_bi_rnn=False,
  138. dec_bi_rnn=False,
  139. dec_drop_rnn=0.0,
  140. dec_gru=False,
  141. d_model=512,
  142. d_enc=512,
  143. d_k=64,
  144. pred_dropout=0.1,
  145. max_text_length=30,
  146. mask=True,
  147. pred_concat=True,
  148. **kwargs,
  149. ):
  150. super().__init__()
  151. self.num_classes = out_channels
  152. self.enc_bi_rnn = enc_bi_rnn
  153. self.d_k = d_k
  154. self.start_idx = out_channels - 2
  155. self.padding_idx = out_channels - 1
  156. self.max_seq_len = max_text_length
  157. self.mask = mask
  158. self.pred_concat = pred_concat
  159. encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
  160. decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
  161. # 2D attention layer
  162. self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
  163. self.conv3x3_1 = nn.Conv2D(d_model, d_k, kernel_size=3, stride=1, padding=1)
  164. self.conv1x1_2 = nn.Linear(d_k, 1)
  165. # Decoder RNN layer
  166. if dec_bi_rnn:
  167. direction = "bidirectional"
  168. else:
  169. direction = "forward"
  170. kwargs = dict(
  171. input_size=encoder_rnn_out_size,
  172. hidden_size=encoder_rnn_out_size,
  173. num_layers=2,
  174. time_major=False,
  175. dropout=dec_drop_rnn,
  176. direction=direction,
  177. )
  178. if dec_gru:
  179. self.rnn_decoder = nn.GRU(**kwargs)
  180. else:
  181. self.rnn_decoder = nn.LSTM(**kwargs)
  182. # Decoder input embedding
  183. self.embedding = nn.Embedding(
  184. self.num_classes, encoder_rnn_out_size, padding_idx=self.padding_idx
  185. )
  186. # Prediction layer
  187. self.pred_dropout = nn.Dropout(pred_dropout)
  188. pred_num_classes = self.num_classes - 1
  189. if pred_concat:
  190. fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
  191. else:
  192. fc_in_channel = d_model
  193. self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
  194. def _2d_attention(self, decoder_input, feat, holistic_feat, valid_ratios=None):
  195. y = self.rnn_decoder(decoder_input)[0]
  196. # y: bsz * (seq_len + 1) * hidden_size
  197. attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
  198. bsz, seq_len, attn_size = attn_query.shape
  199. attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
  200. # (bsz, seq_len + 1, attn_size, 1, 1)
  201. attn_key = self.conv3x3_1(feat)
  202. # bsz * attn_size * h * w
  203. attn_key = attn_key.unsqueeze(1)
  204. # bsz * 1 * attn_size * h * w
  205. attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
  206. # bsz * (seq_len + 1) * attn_size * h * w
  207. attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
  208. # bsz * (seq_len + 1) * h * w * attn_size
  209. attn_weight = self.conv1x1_2(attn_weight)
  210. # bsz * (seq_len + 1) * h * w * 1
  211. bsz, T, h, w, c = paddle.shape(attn_weight)
  212. assert c == 1
  213. if valid_ratios is not None:
  214. # cal mask of attention weight
  215. for i in range(valid_ratios.shape[0]):
  216. valid_width = paddle.minimum(
  217. w.astype("int64"), paddle.ceil(valid_ratios[i] * w).astype("int64")
  218. )
  219. if valid_width < w:
  220. attn_weight[i, :, :, valid_width:, :] = float("-inf")
  221. attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
  222. attn_weight = F.softmax(attn_weight, axis=-1)
  223. attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
  224. attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
  225. # attn_weight: bsz * T * c * h * w
  226. # feat: bsz * c * h * w
  227. attn_feat = paddle.sum(
  228. paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False
  229. )
  230. # bsz * (seq_len + 1) * C
  231. # Linear transformation
  232. if self.pred_concat:
  233. hf_c = holistic_feat.shape[-1]
  234. holistic_feat = paddle.expand(holistic_feat, shape=[bsz, seq_len, hf_c])
  235. y = self.prediction(
  236. paddle.concat(
  237. (y, attn_feat.astype(y.dtype), holistic_feat.astype(y.dtype)), 2
  238. )
  239. )
  240. else:
  241. y = self.prediction(attn_feat)
  242. # bsz * (seq_len + 1) * num_classes
  243. if self.train_mode:
  244. y = self.pred_dropout(y)
  245. return y
  246. def forward_train(self, feat, out_enc, label, img_metas):
  247. """
  248. img_metas: [label, valid_ratio]
  249. """
  250. if img_metas is not None:
  251. assert img_metas[0].shape[0] == feat.shape[0]
  252. valid_ratios = None
  253. if img_metas is not None and self.mask:
  254. valid_ratios = img_metas[-1]
  255. lab_embedding = self.embedding(label)
  256. # bsz * seq_len * emb_dim
  257. out_enc = out_enc.unsqueeze(1).astype(lab_embedding.dtype)
  258. # bsz * 1 * emb_dim
  259. in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
  260. # bsz * (seq_len + 1) * C
  261. out_dec = self._2d_attention(in_dec, feat, out_enc, valid_ratios=valid_ratios)
  262. return out_dec[:, 1:, :] # bsz * seq_len * num_classes
  263. def forward_test(self, feat, out_enc, img_metas):
  264. if img_metas is not None:
  265. assert len(img_metas[0]) == feat.shape[0]
  266. valid_ratios = None
  267. if img_metas is not None and self.mask:
  268. valid_ratios = img_metas[-1]
  269. seq_len = self.max_seq_len
  270. bsz = feat.shape[0]
  271. start_token = paddle.full((bsz,), fill_value=self.start_idx, dtype="int64")
  272. # bsz
  273. start_token = self.embedding(start_token)
  274. # bsz * emb_dim
  275. emb_dim = start_token.shape[1]
  276. start_token = start_token.unsqueeze(1)
  277. start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
  278. # bsz * seq_len * emb_dim
  279. out_enc = out_enc.unsqueeze(1)
  280. # bsz * 1 * emb_dim
  281. decoder_input = paddle.concat((out_enc, start_token), axis=1)
  282. # bsz * (seq_len + 1) * emb_dim
  283. outputs = []
  284. for i in range(1, seq_len + 1):
  285. decoder_output = self._2d_attention(
  286. decoder_input, feat, out_enc, valid_ratios=valid_ratios
  287. )
  288. char_output = decoder_output[:, i, :] # bsz * num_classes
  289. char_output = F.softmax(char_output, -1)
  290. outputs.append(char_output)
  291. max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
  292. char_embedding = self.embedding(max_idx) # bsz * emb_dim
  293. if i < seq_len:
  294. decoder_input[:, i + 1, :] = char_embedding
  295. outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes
  296. return outputs
  297. class SARHead(nn.Layer):
  298. def __init__(
  299. self,
  300. in_channels,
  301. out_channels,
  302. enc_dim=512,
  303. max_text_length=30,
  304. enc_bi_rnn=False,
  305. enc_drop_rnn=0.1,
  306. enc_gru=False,
  307. dec_bi_rnn=False,
  308. dec_drop_rnn=0.0,
  309. dec_gru=False,
  310. d_k=512,
  311. pred_dropout=0.1,
  312. pred_concat=True,
  313. **kwargs,
  314. ):
  315. super(SARHead, self).__init__()
  316. # encoder module
  317. self.encoder = SAREncoder(
  318. enc_bi_rnn=enc_bi_rnn,
  319. enc_drop_rnn=enc_drop_rnn,
  320. enc_gru=enc_gru,
  321. d_model=in_channels,
  322. d_enc=enc_dim,
  323. )
  324. # decoder module
  325. self.decoder = ParallelSARDecoder(
  326. out_channels=out_channels,
  327. enc_bi_rnn=enc_bi_rnn,
  328. dec_bi_rnn=dec_bi_rnn,
  329. dec_drop_rnn=dec_drop_rnn,
  330. dec_gru=dec_gru,
  331. d_model=in_channels,
  332. d_enc=enc_dim,
  333. d_k=d_k,
  334. pred_dropout=pred_dropout,
  335. max_text_length=max_text_length,
  336. pred_concat=pred_concat,
  337. )
  338. def forward(self, feat, targets=None):
  339. """
  340. img_metas: [label, valid_ratio]
  341. """
  342. holistic_feat = self.encoder(feat, targets) # bsz c
  343. if self.training:
  344. label = targets[0] # label
  345. final_out = self.decoder(feat, holistic_feat, label, img_metas=targets)
  346. else:
  347. final_out = self.decoder(
  348. feat, holistic_feat, label=None, img_metas=targets, train_mode=False
  349. )
  350. # (bsz, seq_len, num_classes)
  351. return final_out