sr_rensnet_transformer.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
  17. """
  18. import copy
  19. import math
  20. import paddle
  21. import paddle.nn as nn
  22. import paddle.nn.functional as F
  def subsequent_mask(size):
      """Build a causal (look-ahead) attention mask.

      Args:
          size (int): Sequence length.

      Returns:
          paddle.Tensor: Boolean tensor of shape ``[1, size, size]``. Entry
          ``[0, i, j]`` is True when position ``i`` may attend to position
          ``j`` (``j <= i``) and False for future positions (``j > i``).
      """
      # Ones on and below the diagonal, zeros strictly above it; the zeros are
      # the future positions that must not be attended to.
      allowed = paddle.tril(paddle.ones([1, size, size], dtype="float32"))
      return allowed.astype("bool")
  def clones(module, N):
      """Return an ``nn.LayerList`` of ``N`` independent deep copies of ``module``.

      Every copy owns its own parameters; none are shared with ``module`` or
      with each other.
      """
      copies = []
      for _ in range(N):
          copies.append(copy.deepcopy(module))
      return nn.LayerList(copies)
  def masked_fill(x, mask, value):
      """Return a copy of ``x`` where positions selected by ``mask`` become ``value``.

      Args:
          x (paddle.Tensor): Source tensor.
          mask (paddle.Tensor): Boolean tensor; True marks positions to overwrite.
          value (float): Fill value, stored with ``x``'s dtype.

      Returns:
          paddle.Tensor: New tensor; ``x`` itself is not modified.
      """
      filler = paddle.full_like(x, value)
      return paddle.where(mask, filler, x)
  def attention(query, key, value, mask=None, dropout=None, attention_map=None):
      """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

      Args:
          query (paddle.Tensor): Queries; the last axis is the feature size ``d_k``.
          key (paddle.Tensor): Keys with the same feature size as ``query``.
          value (paddle.Tensor): Values aligned with ``key`` along the
              second-to-last axis.
          mask (paddle.Tensor, optional): Scores at positions where
              ``mask == 0`` are set to ``-inf`` before the softmax. A row that
              is masked everywhere therefore produces NaN weights.
          dropout (callable, optional): Applied to the attention weights.
          attention_map: Unused; accepted only for interface compatibility.

      Returns:
          tuple: ``(output, p_attn)``, the attended values and the
          attention weights.
      """
      d_k = query.shape[-1]
      # transpose_y swaps the last two axes of `key`, so this works for any
      # rank >= 2 instead of hard-coding a 4-D permutation.
      scores = paddle.matmul(query, key, transpose_y=True) / math.sqrt(d_k)
      if mask is not None:
          scores = masked_fill(scores, mask == 0, float("-inf"))
      p_attn = F.softmax(scores, axis=-1)
      if dropout is not None:
          p_attn = dropout(p_attn)
      return paddle.matmul(p_attn, value), p_attn
  class MultiHeadedAttention(nn.Layer):
      """Multi-head attention built from ``h`` parallel scaled dot-product heads.

      Args:
          h (int): Number of heads; must divide ``d_model`` evenly.
          d_model (int): Feature size of the inputs and of the output.
          dropout (float): Dropout rate applied to the attention weights.
          compress_attention (bool): Stored on the layer; not used by ``forward``.
      """

      def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
          super(MultiHeadedAttention, self).__init__()
          assert d_model % h == 0
          self.d_k = d_model // h
          self.h = h
          # Four projections, used in order for query, key, value and output.
          self.linears = clones(nn.Linear(d_model, d_model), 4)
          self.attn = None
          self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
          self.compress_attention = compress_attention
          # NOTE(review): never used in forward; kept so saved checkpoints
          # that contain its parameters still load.
          self.compress_attention_linear = nn.Linear(h, 1)

      def forward(self, query, key, value, mask=None, attention_map=None):
          """Attend from ``query`` to ``key``/``value``.

          Returns:
              tuple: ``(output, attention_weights)``; ``output`` has the shape
              ``[batch, -1, h * d_k]`` before the final projection.
          """
          if mask is not None:
              # Add a head axis so one mask is shared by every head.
              mask = mask.unsqueeze(1)
          batch_size = query.shape[0]

          def split_heads(projection, tensor):
              # Project, then reshape to [batch, seq, h, d_k] and move heads
              # forward: [batch, h, seq, d_k].
              projected = projection(tensor).reshape([batch_size, -1, self.h, self.d_k])
              return paddle.transpose(projected, [0, 2, 1, 3])

          q_heads = split_heads(self.linears[0], query)
          k_heads = split_heads(self.linears[1], key)
          v_heads = split_heads(self.linears[2], value)

          context, attention_map = attention(
              q_heads,
              k_heads,
              v_heads,
              mask=mask,
              dropout=self.dropout,
              attention_map=attention_map,
          )
          # Merge heads back: [batch, h, seq, d_k] -> [batch, seq, h * d_k].
          merged = paddle.transpose(context, [0, 2, 1, 3]).reshape(
              [batch_size, -1, self.h * self.d_k]
          )
          output_projection = self.linears[-1]
          return output_projection(merged), attention_map
  class ResNet(nn.Layer):
      """CNN backbone producing a 1024-channel feature map.

      Structure: two conv-BN-ReLU stem units, each followed by 2x2 max pooling,
      then four residual stages built from ``block``. Each stage is followed
      by a 3x3 conv-BN-ReLU. All convolutions are 3x3, stride 1, padding 1.
      With stride-1 blocks, the two pooling layers are the only spatial
      downsampling.

      Note: ``layer2_pool``, ``layer3_pool`` and ``layer4_pool`` are
      constructed but not applied in ``forward``.

      Args:
          num_in (int): Number of input channels.
          block: Residual block class called as ``block(inplanes, planes, downsample)``.
          layers (list[int]): Number of blocks in each of the four stages.
      """

      def __init__(self, num_in, block, layers):
          super(ResNet, self).__init__()
          # Creation order and attribute names are kept stable so parameter
          # names (and hence saved checkpoints) do not change.
          self.conv1 = nn.Conv2D(num_in, 64, kernel_size=3, stride=1, padding=1)
          self.bn1 = nn.BatchNorm2D(64, use_global_stats=True)
          self.relu1 = nn.ReLU()
          self.pool = nn.MaxPool2D((2, 2), (2, 2))
          self.conv2 = nn.Conv2D(64, 128, kernel_size=3, stride=1, padding=1)
          self.bn2 = nn.BatchNorm2D(128, use_global_stats=True)
          self.relu2 = nn.ReLU()
          self.layer1_pool = nn.MaxPool2D((2, 2), (2, 2))
          self.layer1 = self._make_layer(block, 128, 256, layers[0])
          self.layer1_conv = nn.Conv2D(256, 256, 3, 1, 1)
          self.layer1_bn = nn.BatchNorm2D(256, use_global_stats=True)
          self.layer1_relu = nn.ReLU()
          self.layer2_pool = nn.MaxPool2D((2, 2), (2, 2))
          self.layer2 = self._make_layer(block, 256, 256, layers[1])
          self.layer2_conv = nn.Conv2D(256, 256, 3, 1, 1)
          self.layer2_bn = nn.BatchNorm2D(256, use_global_stats=True)
          self.layer2_relu = nn.ReLU()
          self.layer3_pool = nn.MaxPool2D((2, 2), (2, 2))
          self.layer3 = self._make_layer(block, 256, 512, layers[2])
          self.layer3_conv = nn.Conv2D(512, 512, 3, 1, 1)
          self.layer3_bn = nn.BatchNorm2D(512, use_global_stats=True)
          self.layer3_relu = nn.ReLU()
          self.layer4_pool = nn.MaxPool2D((2, 2), (2, 2))
          self.layer4 = self._make_layer(block, 512, 512, layers[3])
          self.layer4_conv2 = nn.Conv2D(512, 1024, 3, 1, 1)
          self.layer4_conv2_bn = nn.BatchNorm2D(1024, use_global_stats=True)
          self.layer4_conv2_relu = nn.ReLU()

      def _make_layer(self, block, inplanes, planes, blocks):
          """Stack ``blocks`` residual blocks mapping ``inplanes`` to ``planes`` channels."""
          downsample = None
          if inplanes != planes:
              # The skip path needs its own conv + BN when channel counts differ.
              downsample = nn.Sequential(
                  nn.Conv2D(inplanes, planes, 3, 1, 1),
                  nn.BatchNorm2D(planes, use_global_stats=True),
              )
          stage = [block(inplanes, planes, downsample)]
          stage.extend(block(planes, planes, downsample=None) for _ in range(blocks - 1))
          return nn.Sequential(*stage)

      def forward(self, x):
          pipeline = (
              # stem
              self.conv1, self.bn1, self.relu1, self.pool,
              self.conv2, self.bn2, self.relu2, self.layer1_pool,
              # residual stages, each followed by conv-BN-ReLU
              self.layer1, self.layer1_conv, self.layer1_bn, self.layer1_relu,
              self.layer2, self.layer2_conv, self.layer2_bn, self.layer2_relu,
              self.layer3, self.layer3_conv, self.layer3_bn, self.layer3_relu,
              self.layer4, self.layer4_conv2, self.layer4_conv2_bn, self.layer4_conv2_relu,
          )
          for stage in pipeline:
              x = stage(x)
          return x
  class Bottleneck(nn.Layer):
      """Residual unit: 1x1 conv-BN-ReLU, then 3x3 conv-BN, identity skip, ReLU.

      Channel count is ``input_dim`` throughout and the 3x3 conv uses padding 1,
      so the input can be added back unchanged.

      Args:
          input_dim (int): Number of input (and output) channels.
      """

      def __init__(self, input_dim):
          super(Bottleneck, self).__init__()
          self.conv1 = nn.Conv2D(input_dim, input_dim, 1)
          self.bn1 = nn.BatchNorm2D(input_dim, use_global_stats=True)
          self.relu = nn.ReLU()
          self.conv2 = nn.Conv2D(input_dim, input_dim, 3, 1, 1)
          self.bn2 = nn.BatchNorm2D(input_dim, use_global_stats=True)

      def forward(self, x):
          branch = self.relu(self.bn1(self.conv1(x)))
          branch = self.bn2(self.conv2(branch))
          return self.relu(branch + x)
  class PositionalEncoding(nn.Layer):
      """Sinusoidal positional encoding added to the input, followed by dropout.

      Even feature indices hold ``sin(pos / 10000^(2i/dim))`` and odd indices
      hold the matching ``cos``. The table is precomputed for ``max_len``
      positions and registered as the buffer ``pe`` with shape
      ``[1, max_len, dim]``.

      Args:
          dropout (float): Dropout rate applied after the addition.
          dim (int): Feature size of the encoding.
          max_len (int): Longest sequence ``forward`` can encode.
      """

      def __init__(self, dropout, dim, max_len=5000):
          super(PositionalEncoding, self).__init__()
          self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
          pe = paddle.zeros([max_len, dim])
          position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
          div_term = paddle.exp(
              paddle.arange(0, dim, 2).astype("float32") * (-math.log(10000.0) / dim)
          )
          pe[:, 0::2] = paddle.sin(position * div_term)
          # An odd `dim` has one fewer odd column than even column, so trim the
          # frequencies to fit. For an even `dim` this slice keeps everything.
          pe[:, 1::2] = paddle.cos(position * div_term[: dim // 2])
          pe = paddle.unsqueeze(pe, 0)
          self.register_buffer("pe", pe)

      def forward(self, x):
          """Add the first ``x.shape[1]`` encodings to ``x`` and apply dropout.

          Assumes ``x`` is ``[batch, seq_len, dim]`` with ``seq_len <= max_len``;
          TODO confirm against callers.
          """
          x = x + self.pe[:, : x.shape[1]]
          return self.dropout(x)
  class PositionwiseFeedForward(nn.Layer):
      """Position-wise feed-forward network: ``w_2(dropout(relu(w_1(x))))``.

      Args:
          d_model (int): Input and output feature size.
          d_ff (int): Hidden feature size.
          dropout (float): Dropout rate applied to the hidden activations.
      """

      def __init__(self, d_model, d_ff, dropout=0.1):
          super(PositionwiseFeedForward, self).__init__()
          self.w_1 = nn.Linear(d_model, d_ff)
          self.w_2 = nn.Linear(d_ff, d_model)
          self.dropout = nn.Dropout(dropout, mode="downscale_in_infer")

      def forward(self, x):
          hidden = F.relu(self.w_1(x))
          hidden = self.dropout(hidden)
          return self.w_2(hidden)
  class Generator(nn.Layer):
      """Linear projection from decoder features to per-class scores.

      No softmax is applied: ``forward`` returns raw, unnormalized logits.

      Args:
          d_model (int): Input feature size.
          vocab (int): Number of output classes.
      """

      def __init__(self, d_model, vocab):
          super(Generator, self).__init__()
          self.proj = nn.Linear(d_model, vocab)
          # NOTE(review): not applied in forward; kept so existing references
          # to `generator.relu` keep working.
          self.relu = nn.ReLU()

      def forward(self, x):
          return self.proj(x)
  class Embeddings(nn.Layer):
      """Token embedding lookup scaled by ``sqrt(d_model)``.

      Args:
          d_model (int): Embedding size.
          vocab (int): Vocabulary size.
      """

      def __init__(self, d_model, vocab):
          super(Embeddings, self).__init__()
          self.lut = nn.Embedding(vocab, d_model)
          self.d_model = d_model

      def forward(self, x):
          scale = math.sqrt(self.d_model)
          return self.lut(x) * scale
  class LayerNorm(nn.Layer):
      """Layer normalization over the last axis, with a learnable gain and bias.

      Computes ``a_2 * (x - mean) / (std + eps) + b_2``. This differs from
      ``paddle.nn.LayerNorm``: ``eps`` is added to ``Tensor.std`` (whose
      default, per Paddle's docs, is the unbiased estimator) rather than to
      the variance.

      Args:
          features (int): Size of the normalized (last) axis.
          eps (float): Added to the standard deviation for numerical stability.
      """

      def __init__(self, features, eps=1e-6):
          super(LayerNorm, self).__init__()
          self.a_2 = self.create_parameter(
              shape=[features], default_initializer=paddle.nn.initializer.Constant(1.0)
          )
          self.b_2 = self.create_parameter(
              shape=[features], default_initializer=paddle.nn.initializer.Constant(0.0)
          )
          self.eps = eps

      def forward(self, x):
          centered = x - x.mean(-1, keepdim=True)
          spread = x.std(-1, keepdim=True) + self.eps
          return self.a_2 * centered / spread + self.b_2
  class Decoder(nn.Layer):
      """Single transformer decoder layer with fixed sizes (d_model=1024, 16 heads).

      It has three post-norm residual sub-layers:
      1. causal self-attention over the text,
      2. cross-attention from the text to the flattened conv features,
      3. a position-wise feed-forward network.
      """

      def __init__(self):
          super(Decoder, self).__init__()
          self.mask_multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
          self.mul_layernorm1 = LayerNorm(1024)
          self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
          self.mul_layernorm2 = LayerNorm(1024)
          self.pff = PositionwiseFeedForward(1024, 2048)
          self.mul_layernorm3 = LayerNorm(1024)

      def forward(self, text, conv_feature, attention_map=None):
          """Decode ``text`` against ``conv_feature``.

          Args:
              text: Sequence features, presumably ``[batch, seq_len, 1024]``;
                  TODO confirm.
              conv_feature: 4-D ``[b, c, h, w]`` feature map; ``c`` presumably
                  must equal 1024 to match the attention width.
              attention_map: Forwarded to the cross-attention call.

          Returns:
              tuple: ``(decoded_features, cross_attention_weights)``.
          """
          causal_mask = subsequent_mask(text.shape[1])
          self_attended, _ = self.mask_multihead(text, text, text, mask=causal_mask)
          result = self.mul_layernorm1(text + self_attended)

          b, c, h, w = conv_feature.shape
          # [b, c, h, w] -> [b, h * w, c]: one key/value vector per spatial cell.
          memory = paddle.transpose(conv_feature.reshape([b, c, h * w]), [0, 2, 1])
          word_image_align, attention_map = self.multihead(
              result, memory, memory, mask=None, attention_map=attention_map
          )
          result = self.mul_layernorm2(result + word_image_align)
          return self.mul_layernorm3(result + self.pff(result)), attention_map
  class BasicBlock(nn.Layer):
      """Residual block of two 3x3 conv-BN units with an optional projected skip.

      Args:
          inplanes (int): Input channel count.
          planes (int): Output channel count.
          downsample (nn.Layer or None): Applied to the skip connection when
              given; otherwise the input is added back unchanged.
      """

      def __init__(self, inplanes, planes, downsample):
          super(BasicBlock, self).__init__()
          self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=3, stride=1, padding=1)
          self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True)
          self.relu = nn.ReLU()
          self.conv2 = nn.Conv2D(planes, planes, kernel_size=3, stride=1, padding=1)
          self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True)
          self.downsample = downsample

      def forward(self, x):
          residual = x
          out = self.conv1(x)
          out = self.bn1(out)
          out = self.relu(out)
          out = self.conv2(out)
          out = self.bn2(out)
          # Identity check (PEP 8). `!= None` would dispatch to the object's
          # __ne__ instead of testing for the None singleton.
          if self.downsample is not None:
              residual = self.downsample(residual)
          out += residual
          out = self.relu(out)
          return out
  class Encoder(nn.Layer):
      """Image encoder wrapping a single-channel ``ResNet``.

      The ``ResNet`` uses ``BasicBlock`` stages with 1, 2, 5 and 3 blocks.
      """

      def __init__(self):
          super(Encoder, self).__init__()
          self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3])

      def forward(self, input):
          return self.cnn(input)
  class Transformer(nn.Layer):
      """Text transformer: CNN image encoder, one decoder layer and a linear classifier.

      Args:
          in_channels (int): Unused. The encoder always takes 1-channel input;
              3-channel images are converted to grayscale in ``forward``.
          alphabet (str): Class inventory; its length is the number of classes.
      """

      def __init__(self, in_channels=1, alphabet="0123456789"):
          super(Transformer, self).__init__()
          self.alphabet = alphabet
          word_n_class = self.get_alphabet_len()
          self.embedding_word_with_upperword = Embeddings(512, word_n_class)
          self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)
          self.encoder = Encoder()
          self.decoder = Decoder()
          self.generator_word_with_upperword = Generator(1024, word_n_class)
          # Xavier-normal init for every parameter with more than one axis. The
          # initializer has to be *called* on the parameter: `XavierNormal(p)`
          # alone only builds an initializer with `fan_in=p` and leaves `p`
          # untouched.
          xavier_normal = nn.initializer.XavierNormal()
          for p in self.parameters():
              if p.dim() > 1:
                  xavier_normal(p)

      def get_alphabet_len(self):
          return len(self.alphabet)

      def forward(self, image, text_length, text_input, attention_map=None):
          """Decode ``text_input`` against ``image`` with teacher forcing.

          Args:
              image: ``[batch, C, H, W]``. When ``C == 3`` it is converted to one
                  channel with the 0.299/0.587/0.114 luma weights.
              text_length: Per-sample label lengths, presumably a 1-D integer
                  tensor (it is summed and iterated element-wise).
              text_input: ``[batch, L]`` token ids, truncated to ``max(text_length)``.
              attention_map: Unused.

          Returns:
              In training mode, ``(probs_res, word_attention_map, None)``, where
              ``probs_res`` stacks the first ``text_length[i]`` logit rows of each
              sample into ``[sum(text_length), num_classes]``. In eval mode, the
              full logits ``[batch, max(text_length), num_classes]``.
          """
          if image.shape[1] == 3:
              # RGB -> single-channel luma; the encoder's first conv expects 1 channel.
              R = image[:, 0:1, :, :]
              G = image[:, 1:2, :, :]
              B = image[:, 2:3, :, :]
              image = 0.299 * R + 0.587 * G + 0.114 * B
          conv_feature = self.encoder(image)  # batch, 1024, h, w
          max_length = max(text_length)
          text_input = text_input[:, :max_length]
          text_embedding = self.embedding_word_with_upperword(
              text_input
          )  # batch, text_max_length, 512
          # Encoding a zero tensor yields only the sinusoidal table (plus
          # dropout). It is concatenated to, not added to, the token embedding.
          position_embedding = self.pe(
              paddle.zeros(text_embedding.shape)
          )  # batch, text_max_length, 512
          text_input_with_pe = paddle.concat(
              [text_embedding, position_embedding], 2
          )  # batch, text_max_length, 1024
          text_input_with_pe, word_attention_map = self.decoder(
              text_input_with_pe, conv_feature
          )
          word_decoder_result = self.generator_word_with_upperword(text_input_with_pe)

          if not self.training:
              return word_decoder_result

          # Pack each sample's valid (unpadded) steps into one flat
          # [sum(text_length), num_classes] tensor.
          total_length = paddle.sum(text_length)
          probs_res = paddle.zeros([total_length, self.get_alphabet_len()])
          start = 0
          for index, length in enumerate(text_length):
              length = int(length.numpy())
              probs_res[start : start + length, :] = word_decoder_result[
                  index, 0:length, :
              ]
              start += length
          return probs_res, word_attention_map, None