rnn.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
  20. from ppocr.modeling.backbones.rec_svtrnet import (
  21. Block,
  22. ConvBNLayer,
  23. trunc_normal_,
  24. zeros_,
  25. ones_,
  26. )
  27. class Im2Seq(nn.Layer):
  28. def __init__(self, in_channels, **kwargs):
  29. super().__init__()
  30. self.out_channels = in_channels
  31. def forward(self, x):
  32. B, C, H, W = x.shape
  33. assert H == 1
  34. x = x.squeeze(axis=2)
  35. x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
  36. return x
  37. class EncoderWithRNN(nn.Layer):
  38. def __init__(self, in_channels, hidden_size):
  39. super(EncoderWithRNN, self).__init__()
  40. self.out_channels = hidden_size * 2
  41. self.lstm = nn.LSTM(
  42. in_channels, hidden_size, direction="bidirectional", num_layers=2
  43. )
  44. def forward(self, x):
  45. x, _ = self.lstm(x)
  46. return x
  47. class BidirectionalLSTM(nn.Layer):
  48. def __init__(
  49. self,
  50. input_size,
  51. hidden_size,
  52. output_size=None,
  53. num_layers=1,
  54. dropout=0,
  55. direction=False,
  56. time_major=False,
  57. with_linear=False,
  58. ):
  59. super(BidirectionalLSTM, self).__init__()
  60. self.with_linear = with_linear
  61. self.rnn = nn.LSTM(
  62. input_size,
  63. hidden_size,
  64. num_layers=num_layers,
  65. dropout=dropout,
  66. direction=direction,
  67. time_major=time_major,
  68. )
  69. # text recognition the specified structure LSTM with linear
  70. if self.with_linear:
  71. self.linear = nn.Linear(hidden_size * 2, output_size)
  72. def forward(self, input_feature):
  73. recurrent, _ = self.rnn(
  74. input_feature
  75. ) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
  76. if self.with_linear:
  77. output = self.linear(recurrent) # batch_size x T x output_size
  78. return output
  79. return recurrent
  80. class EncoderWithCascadeRNN(nn.Layer):
  81. def __init__(
  82. self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False
  83. ):
  84. super(EncoderWithCascadeRNN, self).__init__()
  85. self.out_channels = out_channels[-1]
  86. self.encoder = nn.LayerList(
  87. [
  88. BidirectionalLSTM(
  89. in_channels if i == 0 else out_channels[i - 1],
  90. hidden_size,
  91. output_size=out_channels[i],
  92. num_layers=1,
  93. direction="bidirectional",
  94. with_linear=with_linear,
  95. )
  96. for i in range(num_layers)
  97. ]
  98. )
  99. def forward(self, x):
  100. for i, l in enumerate(self.encoder):
  101. x = l(x)
  102. return x
  103. class EncoderWithFC(nn.Layer):
  104. def __init__(self, in_channels, hidden_size):
  105. super(EncoderWithFC, self).__init__()
  106. self.out_channels = hidden_size
  107. weight_attr, bias_attr = get_para_bias_attr(l2_decay=0.00001, k=in_channels)
  108. self.fc = nn.Linear(
  109. in_channels,
  110. hidden_size,
  111. weight_attr=weight_attr,
  112. bias_attr=bias_attr,
  113. name="reduce_encoder_fea",
  114. )
  115. def forward(self, x):
  116. x = self.fc(x)
  117. return x
  118. class EncoderWithSVTR(nn.Layer):
  119. def __init__(
  120. self,
  121. in_channels,
  122. dims=64, # XS
  123. depth=2,
  124. hidden_dims=120,
  125. use_guide=False,
  126. num_heads=8,
  127. qkv_bias=True,
  128. mlp_ratio=2.0,
  129. drop_rate=0.1,
  130. attn_drop_rate=0.1,
  131. drop_path=0.0,
  132. kernel_size=[3, 3],
  133. qk_scale=None,
  134. ):
  135. super(EncoderWithSVTR, self).__init__()
  136. self.depth = depth
  137. self.use_guide = use_guide
  138. self.conv1 = ConvBNLayer(
  139. in_channels,
  140. in_channels // 8,
  141. kernel_size=kernel_size,
  142. padding=[kernel_size[0] // 2, kernel_size[1] // 2],
  143. act=nn.Swish,
  144. )
  145. self.conv2 = ConvBNLayer(
  146. in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish
  147. )
  148. self.svtr_block = nn.LayerList(
  149. [
  150. Block(
  151. dim=hidden_dims,
  152. num_heads=num_heads,
  153. mixer="Global",
  154. HW=None,
  155. mlp_ratio=mlp_ratio,
  156. qkv_bias=qkv_bias,
  157. qk_scale=qk_scale,
  158. drop=drop_rate,
  159. act_layer=nn.Swish,
  160. attn_drop=attn_drop_rate,
  161. drop_path=drop_path,
  162. norm_layer="nn.LayerNorm",
  163. epsilon=1e-05,
  164. prenorm=False,
  165. )
  166. for i in range(depth)
  167. ]
  168. )
  169. self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6)
  170. self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act=nn.Swish)
  171. # last conv-nxn, the input is concat of input tensor and conv3 output tensor
  172. self.conv4 = ConvBNLayer(
  173. 2 * in_channels,
  174. in_channels // 8,
  175. kernel_size=kernel_size,
  176. padding=[kernel_size[0] // 2, kernel_size[1] // 2],
  177. act=nn.Swish,
  178. )
  179. self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act=nn.Swish)
  180. self.out_channels = dims
  181. self.apply(self._init_weights)
  182. def _init_weights(self, m):
  183. if isinstance(m, nn.Linear):
  184. trunc_normal_(m.weight)
  185. if isinstance(m, nn.Linear) and m.bias is not None:
  186. zeros_(m.bias)
  187. elif isinstance(m, nn.LayerNorm):
  188. zeros_(m.bias)
  189. ones_(m.weight)
  190. def forward(self, x):
  191. # for use guide
  192. if self.use_guide:
  193. z = x.clone()
  194. z.stop_gradient = True
  195. else:
  196. z = x
  197. # for short cut
  198. h = z
  199. # reduce dim
  200. z = self.conv1(z)
  201. z = self.conv2(z)
  202. # SVTR global block
  203. B, C, H, W = z.shape
  204. z = z.flatten(2).transpose([0, 2, 1])
  205. for blk in self.svtr_block:
  206. z = blk(z)
  207. z = self.norm(z)
  208. # last stage
  209. z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
  210. z = self.conv3(z)
  211. z = paddle.concat((h, z), axis=1)
  212. z = self.conv1x1(self.conv4(z))
  213. return z
  214. class SequenceEncoder(nn.Layer):
  215. def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
  216. super(SequenceEncoder, self).__init__()
  217. self.encoder_reshape = Im2Seq(in_channels)
  218. self.out_channels = self.encoder_reshape.out_channels
  219. self.encoder_type = encoder_type
  220. if encoder_type == "reshape":
  221. self.only_reshape = True
  222. else:
  223. support_encoder_dict = {
  224. "reshape": Im2Seq,
  225. "fc": EncoderWithFC,
  226. "rnn": EncoderWithRNN,
  227. "svtr": EncoderWithSVTR,
  228. "cascadernn": EncoderWithCascadeRNN,
  229. }
  230. assert encoder_type in support_encoder_dict, "{} must in {}".format(
  231. encoder_type, support_encoder_dict.keys()
  232. )
  233. if encoder_type == "svtr":
  234. self.encoder = support_encoder_dict[encoder_type](
  235. self.encoder_reshape.out_channels, **kwargs
  236. )
  237. elif encoder_type == "cascadernn":
  238. self.encoder = support_encoder_dict[encoder_type](
  239. self.encoder_reshape.out_channels, hidden_size, **kwargs
  240. )
  241. else:
  242. self.encoder = support_encoder_dict[encoder_type](
  243. self.encoder_reshape.out_channels, hidden_size
  244. )
  245. self.out_channels = self.encoder.out_channels
  246. self.only_reshape = False
  247. def forward(self, x):
  248. if self.encoder_type != "svtr":
  249. x = self.encoder_reshape(x)
  250. if not self.only_reshape:
  251. x = self.encoder(x)
  252. return x
  253. else:
  254. x = self.encoder(x)
  255. x = self.encoder_reshape(x)
  256. return x