gaspin_transformer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import math
  18. import paddle
  19. from paddle import nn, ParamAttr
  20. from paddle.nn import functional as F
  21. import numpy as np
  22. import functools
  23. from .tps import GridGenerator
  24. """This code is adapted from:
  25. https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/transformations/gaspin_transformation.py
  26. """
  27. class SP_TransformerNetwork(nn.Layer):
  28. """
  29. Sturture-Preserving Transformation (SPT) as Equa. (2) in Ref. [1]
  30. Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
  31. """
  32. def __init__(self, nc=1, default_type=5):
  33. """Based on SPIN
  34. Args:
  35. nc (int): number of input channels (usually in 1 or 3)
  36. default_type (int): the complexity of transformation intensities (by default set to 6 as the paper)
  37. """
  38. super(SP_TransformerNetwork, self).__init__()
  39. self.power_list = self.cal_K(default_type)
  40. self.sigmoid = nn.Sigmoid()
  41. self.bn = nn.InstanceNorm2D(nc)
  42. def cal_K(self, k=5):
  43. """
  44. Args:
  45. k (int): the complexity of transformation intensities (by default set to 6 as the paper)
  46. Returns:
  47. List: the normalized intensity of each pixel in [0,1], denoted as \beta [1x(2K+1)]
  48. """
  49. from math import log
  50. x = []
  51. if k != 0:
  52. for i in range(1, k + 1):
  53. lower = round(
  54. log(1 - (0.5 / (k + 1)) * i) / log((0.5 / (k + 1)) * i), 2
  55. )
  56. upper = round(1 / lower, 2)
  57. x.append(lower)
  58. x.append(upper)
  59. x.append(1.00)
  60. return x
  61. def forward(self, batch_I, weights, offsets, lambda_color=None):
  62. """
  63. Args:
  64. batch_I (Tensor): batch of input images [batch_size x nc x I_height x I_width]
  65. weights:
  66. offsets: the predicted offset by AIN, a scalar
  67. lambda_color: the learnable update gate \alpha in Equa. (5) as
  68. g(x) = (1 - \alpha) \odot x + \alpha \odot x_{offsets}
  69. Returns:
  70. Tensor: transformed images by SPN as Equa. (4) in Ref. [1]
  71. [batch_size x I_channel_num x I_r_height x I_r_width]
  72. """
  73. batch_I = (batch_I + 1) * 0.5
  74. if offsets is not None:
  75. batch_I = batch_I * (1 - lambda_color) + offsets * lambda_color
  76. batch_weight_params = paddle.unsqueeze(paddle.unsqueeze(weights, -1), -1)
  77. batch_I_power = paddle.stack([batch_I.pow(p) for p in self.power_list], axis=1)
  78. batch_weight_sum = paddle.sum(batch_I_power * batch_weight_params, axis=1)
  79. batch_weight_sum = self.bn(batch_weight_sum)
  80. batch_weight_sum = self.sigmoid(batch_weight_sum)
  81. batch_weight_sum = batch_weight_sum * 2 - 1
  82. return batch_weight_sum
  83. class GA_SPIN_Transformer(nn.Layer):
  84. """
  85. Geometric-Absorbed SPIN Transformation (GA-SPIN) proposed in Ref. [1]
  86. Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
  87. """
  88. def __init__(
  89. self,
  90. in_channels=1,
  91. I_r_size=(32, 100),
  92. offsets=False,
  93. norm_type="BN",
  94. default_type=6,
  95. loc_lr=1,
  96. stn=True,
  97. ):
  98. """
  99. Args:
  100. in_channels (int): channel of input features,
  101. set it to 1 if the grayscale images and 3 if RGB input
  102. I_r_size (tuple): size of rectified images (used in STN transformations)
  103. offsets (bool): set it to False if use SPN w.o. AIN,
  104. and set it to True if use SPIN (both with SPN and AIN)
  105. norm_type (str): the normalization type of the module,
  106. set it to 'BN' by default, 'IN' optionally
  107. default_type (int): the K chromatic space,
  108. set it to 3/5/6 depend on the complexity of transformation intensities
  109. loc_lr (float): learning rate of location network
  110. stn (bool): whether to use stn.
  111. """
  112. super(GA_SPIN_Transformer, self).__init__()
  113. self.nc = in_channels
  114. self.spt = True
  115. self.offsets = offsets
  116. self.stn = stn # set to True in GA-SPIN, while set it to False in SPIN
  117. self.I_r_size = I_r_size
  118. self.out_channels = in_channels
  119. if norm_type == "BN":
  120. norm_layer = functools.partial(nn.BatchNorm2D, use_global_stats=True)
  121. elif norm_type == "IN":
  122. norm_layer = functools.partial(
  123. nn.InstanceNorm2D, weight_attr=False, use_global_stats=False
  124. )
  125. else:
  126. raise NotImplementedError(
  127. "normalization layer [%s] is not found" % norm_type
  128. )
  129. if self.spt:
  130. self.sp_net = SP_TransformerNetwork(in_channels, default_type)
  131. self.spt_convnet = nn.Sequential(
  132. # 32*100
  133. nn.Conv2D(in_channels, 32, 3, 1, 1, bias_attr=False),
  134. norm_layer(32),
  135. nn.ReLU(),
  136. nn.MaxPool2D(kernel_size=2, stride=2),
  137. # 16*50
  138. nn.Conv2D(32, 64, 3, 1, 1, bias_attr=False),
  139. norm_layer(64),
  140. nn.ReLU(),
  141. nn.MaxPool2D(kernel_size=2, stride=2),
  142. # 8*25
  143. nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False),
  144. norm_layer(128),
  145. nn.ReLU(),
  146. nn.MaxPool2D(kernel_size=2, stride=2),
  147. # 4*12
  148. )
  149. self.stucture_fc1 = nn.Sequential(
  150. nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False),
  151. norm_layer(256),
  152. nn.ReLU(),
  153. nn.MaxPool2D(kernel_size=2, stride=2),
  154. nn.Conv2D(256, 256, 3, 1, 1, bias_attr=False),
  155. norm_layer(256),
  156. nn.ReLU(), # 2*6
  157. nn.MaxPool2D(kernel_size=2, stride=2),
  158. nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False),
  159. norm_layer(512),
  160. nn.ReLU(), # 1*3
  161. nn.AdaptiveAvgPool2D(1),
  162. nn.Flatten(1, -1), # batch_size x 512
  163. nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)),
  164. nn.BatchNorm1D(256),
  165. nn.ReLU(),
  166. )
  167. self.out_weight = 2 * default_type + 1
  168. self.spt_length = 2 * default_type + 1
  169. if offsets:
  170. self.out_weight += 1
  171. if self.stn:
  172. self.F = 20
  173. self.out_weight += self.F * 2
  174. self.GridGenerator = GridGenerator(self.F * 2, self.F)
  175. # self.out_weight*=nc
  176. # Init structure_fc2 in LocalizationNetwork
  177. initial_bias = self.init_spin(default_type * 2)
  178. initial_bias = initial_bias.reshape(-1)
  179. param_attr = ParamAttr(
  180. learning_rate=loc_lr,
  181. initializer=nn.initializer.Assign(np.zeros([256, self.out_weight])),
  182. )
  183. bias_attr = ParamAttr(
  184. learning_rate=loc_lr, initializer=nn.initializer.Assign(initial_bias)
  185. )
  186. self.stucture_fc2 = nn.Linear(
  187. 256, self.out_weight, weight_attr=param_attr, bias_attr=bias_attr
  188. )
  189. self.sigmoid = nn.Sigmoid()
  190. if offsets:
  191. self.offset_fc1 = nn.Sequential(
  192. nn.Conv2D(128, 16, 3, 1, 1, bias_attr=False),
  193. norm_layer(16),
  194. nn.ReLU(),
  195. )
  196. self.offset_fc2 = nn.Conv2D(16, in_channels, 3, 1, 1)
  197. self.pool = nn.MaxPool2D(2, 2)
  198. def init_spin(self, nz):
  199. """
  200. Args:
  201. nz (int): number of paired \betas exponents, which means the value of K x 2
  202. """
  203. init_id = [0.00] * nz + [5.00]
  204. if self.offsets:
  205. init_id += [-5.00]
  206. # init_id *=3
  207. init = np.array(init_id)
  208. if self.stn:
  209. F = self.F
  210. ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
  211. ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
  212. ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
  213. ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
  214. ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
  215. initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
  216. initial_bias = initial_bias.reshape(-1)
  217. init = np.concatenate([init, initial_bias], axis=0)
  218. return init
  219. def forward(self, x, return_weight=False):
  220. """
  221. Args:
  222. x (Tensor): input image batch
  223. return_weight (bool): set to False by default,
  224. if set to True return the predicted offsets of AIN, denoted as x_{offsets}
  225. Returns:
  226. Tensor: rectified image [batch_size x I_channel_num x I_height x I_width], the same as the input size
  227. """
  228. if self.spt:
  229. feat = self.spt_convnet(x)
  230. fc1 = self.stucture_fc1(feat)
  231. sp_weight_fusion = self.stucture_fc2(fc1)
  232. sp_weight_fusion = sp_weight_fusion.reshape(
  233. [x.shape[0], self.out_weight, 1]
  234. )
  235. if self.offsets: # SPIN w. AIN
  236. lambda_color = sp_weight_fusion[:, self.spt_length, 0]
  237. lambda_color = (
  238. self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
  239. )
  240. sp_weight = sp_weight_fusion[:, : self.spt_length, :]
  241. offsets = self.pool(self.offset_fc2(self.offset_fc1(feat)))
  242. assert offsets.shape[2] == 2 # 2
  243. assert offsets.shape[3] == 6 # 16
  244. offsets = self.sigmoid(offsets) # v12
  245. if return_weight:
  246. return offsets
  247. offsets = nn.functional.upsample(
  248. offsets, size=(x.shape[2], x.shape[3]), mode="bilinear"
  249. )
  250. if self.stn:
  251. batch_C_prime = sp_weight_fusion[
  252. :, (self.spt_length + 1) :, :
  253. ].reshape([x.shape[0], self.F, 2])
  254. build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
  255. build_P_prime_reshape = build_P_prime.reshape(
  256. [build_P_prime.shape[0], self.I_r_size[0], self.I_r_size[1], 2]
  257. )
  258. else: # SPIN w.o. AIN
  259. sp_weight = sp_weight_fusion[:, : self.spt_length, :]
  260. lambda_color, offsets = None, None
  261. if self.stn:
  262. batch_C_prime = sp_weight_fusion[:, self.spt_length :, :].reshape(
  263. [x.shape[0], self.F, 2]
  264. )
  265. build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
  266. build_P_prime_reshape = build_P_prime.reshape(
  267. [build_P_prime.shape[0], self.I_r_size[0], self.I_r_size[1], 2]
  268. )
  269. x = self.sp_net(x, sp_weight, offsets, lambda_color)
  270. if self.stn:
  271. is_fp16 = False
  272. if build_P_prime_reshape.dtype != paddle.float32:
  273. data_type = build_P_prime_reshape.dtype
  274. x = x.cast(paddle.float32)
  275. build_P_prime_reshape = build_P_prime_reshape.cast(paddle.float32)
  276. is_fp16 = True
  277. x = F.grid_sample(
  278. x=x, grid=build_P_prime_reshape, padding_mode="border"
  279. )
  280. if is_fp16:
  281. x = x.cast(data_type)
  282. return x