tsrn.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/model/tsrn.py
  17. """
  18. import math
  19. import paddle
  20. import paddle.nn.functional as F
  21. from paddle import nn
  22. from collections import OrderedDict
  23. import sys
  24. import numpy as np
  25. import warnings
  26. import math, copy
  27. import cv2
  28. warnings.filterwarnings("ignore")
  29. from .tps_spatial_transformer import TPSSpatialTransformer
  30. from .stn import STN as STN_model
  31. from ppocr.modeling.heads.sr_rensnet_transformer import Transformer
  32. class TSRN(nn.Layer):
  33. def __init__(
  34. self,
  35. in_channels,
  36. scale_factor=2,
  37. width=128,
  38. height=32,
  39. STN=False,
  40. srb_nums=5,
  41. mask=False,
  42. hidden_units=32,
  43. infer_mode=False,
  44. **kwargs,
  45. ):
  46. super(TSRN, self).__init__()
  47. in_planes = 3
  48. if mask:
  49. in_planes = 4
  50. assert math.log(scale_factor, 2) % 1 == 0
  51. upsample_block_num = int(math.log(scale_factor, 2))
  52. self.block1 = nn.Sequential(
  53. nn.Conv2D(in_planes, 2 * hidden_units, kernel_size=9, padding=4), nn.PReLU()
  54. )
  55. self.srb_nums = srb_nums
  56. for i in range(srb_nums):
  57. setattr(self, "block%d" % (i + 2), RecurrentResidualBlock(2 * hidden_units))
  58. setattr(
  59. self,
  60. "block%d" % (srb_nums + 2),
  61. nn.Sequential(
  62. nn.Conv2D(2 * hidden_units, 2 * hidden_units, kernel_size=3, padding=1),
  63. nn.BatchNorm2D(2 * hidden_units),
  64. ),
  65. )
  66. block_ = [UpsampleBLock(2 * hidden_units, 2) for _ in range(upsample_block_num)]
  67. block_.append(nn.Conv2D(2 * hidden_units, in_planes, kernel_size=9, padding=4))
  68. setattr(self, "block%d" % (srb_nums + 3), nn.Sequential(*block_))
  69. self.tps_inputsize = [height // scale_factor, width // scale_factor]
  70. tps_outputsize = [height // scale_factor, width // scale_factor]
  71. num_control_points = 20
  72. tps_margins = [0.05, 0.05]
  73. self.stn = STN
  74. if self.stn:
  75. self.tps = TPSSpatialTransformer(
  76. output_image_size=tuple(tps_outputsize),
  77. num_control_points=num_control_points,
  78. margins=tuple(tps_margins),
  79. )
  80. self.stn_head = STN_model(
  81. in_channels=in_planes,
  82. num_ctrlpoints=num_control_points,
  83. activation="none",
  84. )
  85. self.out_channels = in_channels
  86. self.r34_transformer = Transformer()
  87. for param in self.r34_transformer.parameters():
  88. param.trainable = False
  89. self.infer_mode = infer_mode
  90. def forward(self, x):
  91. output = {}
  92. if self.infer_mode:
  93. output["lr_img"] = x
  94. y = x
  95. else:
  96. output["lr_img"] = x[0]
  97. output["hr_img"] = x[1]
  98. y = x[0]
  99. if self.stn and self.training:
  100. _, ctrl_points_x = self.stn_head(y)
  101. y, _ = self.tps(y, ctrl_points_x)
  102. block = {"1": self.block1(y)}
  103. for i in range(self.srb_nums + 1):
  104. block[str(i + 2)] = getattr(self, "block%d" % (i + 2))(block[str(i + 1)])
  105. block[str(self.srb_nums + 3)] = getattr(self, "block%d" % (self.srb_nums + 3))(
  106. (block["1"] + block[str(self.srb_nums + 2)])
  107. )
  108. sr_img = paddle.tanh(block[str(self.srb_nums + 3)])
  109. output["sr_img"] = sr_img
  110. if self.training:
  111. hr_img = x[1]
  112. length = x[2]
  113. input_tensor = x[3]
  114. # add transformer
  115. sr_pred, word_attention_map_pred, _ = self.r34_transformer(
  116. sr_img, length, input_tensor
  117. )
  118. hr_pred, word_attention_map_gt, _ = self.r34_transformer(
  119. hr_img, length, input_tensor
  120. )
  121. output["hr_img"] = hr_img
  122. output["hr_pred"] = hr_pred
  123. output["word_attention_map_gt"] = word_attention_map_gt
  124. output["sr_pred"] = sr_pred
  125. output["word_attention_map_pred"] = word_attention_map_pred
  126. return output
  127. class RecurrentResidualBlock(nn.Layer):
  128. def __init__(self, channels):
  129. super(RecurrentResidualBlock, self).__init__()
  130. self.conv1 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
  131. self.bn1 = nn.BatchNorm2D(channels)
  132. self.gru1 = GruBlock(channels, channels)
  133. self.prelu = mish()
  134. self.conv2 = nn.Conv2D(channels, channels, kernel_size=3, padding=1)
  135. self.bn2 = nn.BatchNorm2D(channels)
  136. self.gru2 = GruBlock(channels, channels)
  137. def forward(self, x):
  138. residual = self.conv1(x)
  139. residual = self.bn1(residual)
  140. residual = self.prelu(residual)
  141. residual = self.conv2(residual)
  142. residual = self.bn2(residual)
  143. residual = self.gru1(residual.transpose([0, 1, 3, 2])).transpose([0, 1, 3, 2])
  144. return self.gru2(x + residual)
  145. class UpsampleBLock(nn.Layer):
  146. def __init__(self, in_channels, up_scale):
  147. super(UpsampleBLock, self).__init__()
  148. self.conv = nn.Conv2D(
  149. in_channels, in_channels * up_scale**2, kernel_size=3, padding=1
  150. )
  151. self.pixel_shuffle = nn.PixelShuffle(up_scale)
  152. self.prelu = mish()
  153. def forward(self, x):
  154. x = self.conv(x)
  155. x = self.pixel_shuffle(x)
  156. x = self.prelu(x)
  157. return x
  158. class mish(nn.Layer):
  159. def __init__(
  160. self,
  161. ):
  162. super(mish, self).__init__()
  163. self.activated = True
  164. def forward(self, x):
  165. if self.activated:
  166. x = x * (paddle.tanh(F.softplus(x)))
  167. return x
  168. class GruBlock(nn.Layer):
  169. def __init__(self, in_channels, out_channels):
  170. super(GruBlock, self).__init__()
  171. assert out_channels % 2 == 0
  172. self.conv1 = nn.Conv2D(in_channels, out_channels, kernel_size=1, padding=0)
  173. self.gru = nn.GRU(out_channels, out_channels // 2, direction="bidirectional")
  174. def forward(self, x):
  175. # x: b, c, w, h
  176. x = self.conv1(x)
  177. x = x.transpose([0, 2, 3, 1]) # b, w, h, c
  178. batch_size, w, h, c = x.shape
  179. x = x.reshape([-1, h, c]) # b*w, h, c
  180. x, _ = self.gru(x)
  181. x = x.reshape([-1, w, h, c])
  182. x = x.transpose([0, 3, 1, 2])
  183. return x