table_master_head.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on:
  16. https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/decoders/master_decoder.py
  17. """
  18. import copy
  19. import math
  20. import paddle
  21. from paddle import nn
  22. from paddle.nn import functional as F
  23. class TableMasterHead(nn.Layer):
  24. """
  25. Split to two transformer header at the last layer.
  26. Cls_layer is used to structure token classification.
  27. Bbox_layer is used to regress bbox coord.
  28. """
  29. def __init__(
  30. self,
  31. in_channels,
  32. out_channels=30,
  33. headers=8,
  34. d_ff=2048,
  35. dropout=0,
  36. max_text_length=500,
  37. loc_reg_num=4,
  38. **kwargs,
  39. ):
  40. super(TableMasterHead, self).__init__()
  41. hidden_size = in_channels[-1]
  42. self.layers = clones(DecoderLayer(headers, hidden_size, dropout, d_ff), 2)
  43. self.cls_layer = clones(DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
  44. self.bbox_layer = clones(DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
  45. self.cls_fc = nn.Linear(hidden_size, out_channels)
  46. self.bbox_fc = nn.Sequential(
  47. # nn.Linear(hidden_size, hidden_size),
  48. nn.Linear(hidden_size, loc_reg_num),
  49. nn.Sigmoid(),
  50. )
  51. self.norm = nn.LayerNorm(hidden_size)
  52. self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
  53. self.positional_encoding = PositionalEncoding(d_model=hidden_size)
  54. self.SOS = out_channels - 3
  55. self.PAD = out_channels - 1
  56. self.out_channels = out_channels
  57. self.loc_reg_num = loc_reg_num
  58. self.max_text_length = max_text_length
  59. def make_mask(self, tgt):
  60. """
  61. Make mask for self attention.
  62. :param src: [b, c, h, l_src]
  63. :param tgt: [b, l_tgt]
  64. :return:
  65. """
  66. trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)
  67. tgt_len = tgt.shape[1]
  68. trg_sub_mask = paddle.tril(
  69. paddle.ones(([tgt_len, tgt_len]), dtype=paddle.float32)
  70. )
  71. tgt_mask = paddle.logical_and(trg_pad_mask.astype(paddle.float32), trg_sub_mask)
  72. return tgt_mask.astype(paddle.float32)
  73. def decode(self, input, feature, src_mask, tgt_mask):
  74. # main process of transformer decoder.
  75. x = self.embedding(input) # x: 1*x*512, feature: 1*3600,512
  76. x = self.positional_encoding(x)
  77. # origin transformer layers
  78. for i, layer in enumerate(self.layers):
  79. x = layer(x, feature, src_mask, tgt_mask)
  80. # cls head
  81. cls_x = x
  82. for layer in self.cls_layer:
  83. cls_x = layer(x, feature, src_mask, tgt_mask)
  84. cls_x = self.norm(cls_x)
  85. # bbox head
  86. bbox_x = x
  87. for layer in self.bbox_layer:
  88. bbox_x = layer(x, feature, src_mask, tgt_mask)
  89. bbox_x = self.norm(bbox_x)
  90. return self.cls_fc(cls_x), self.bbox_fc(bbox_x)
  91. def greedy_forward(self, SOS, feature):
  92. input = SOS
  93. output = paddle.zeros(
  94. [input.shape[0], self.max_text_length + 1, self.out_channels]
  95. )
  96. bbox_output = paddle.zeros(
  97. [input.shape[0], self.max_text_length + 1, self.loc_reg_num]
  98. )
  99. max_text_length = paddle.to_tensor(self.max_text_length)
  100. for i in range(max_text_length + 1):
  101. target_mask = self.make_mask(input)
  102. out_step, bbox_output_step = self.decode(input, feature, None, target_mask)
  103. prob = F.softmax(out_step, axis=-1)
  104. next_word = prob.argmax(axis=2, dtype="int64")
  105. input = paddle.concat([input, next_word[:, -1].unsqueeze(-1)], axis=1)
  106. if i == self.max_text_length:
  107. output = out_step
  108. bbox_output = bbox_output_step
  109. return output, bbox_output
  110. def forward_train(self, out_enc, targets):
  111. # x is token of label
  112. # feat is feature after backbone before pe.
  113. # out_enc is feature after pe.
  114. padded_targets = targets[0]
  115. src_mask = None
  116. tgt_mask = self.make_mask(padded_targets[:, :-1])
  117. output, bbox_output = self.decode(
  118. padded_targets[:, :-1], out_enc, src_mask, tgt_mask
  119. )
  120. return {"structure_probs": output, "loc_preds": bbox_output}
  121. def forward_test(self, out_enc):
  122. batch_size = out_enc.shape[0]
  123. SOS = paddle.zeros([batch_size, 1], dtype="int64") + self.SOS
  124. output, bbox_output = self.greedy_forward(SOS, out_enc)
  125. output = F.softmax(output)
  126. return {"structure_probs": output, "loc_preds": bbox_output}
  127. def forward(self, feat, targets=None):
  128. feat = feat[-1]
  129. b, c, h, w = feat.shape
  130. feat = feat.reshape([b, c, h * w]) # flatten 2D feature map
  131. feat = feat.transpose((0, 2, 1))
  132. out_enc = self.positional_encoding(feat)
  133. if self.training:
  134. return self.forward_train(out_enc, targets)
  135. return self.forward_test(out_enc)
  136. class DecoderLayer(nn.Layer):
  137. """
  138. Decoder is made of self attention, source attention and feed forward.
  139. """
  140. def __init__(self, headers, d_model, dropout, d_ff):
  141. super(DecoderLayer, self).__init__()
  142. self.self_attn = MultiHeadAttention(headers, d_model, dropout)
  143. self.src_attn = MultiHeadAttention(headers, d_model, dropout)
  144. self.feed_forward = FeedForward(d_model, d_ff, dropout)
  145. self.sublayer = clones(SubLayerConnection(d_model, dropout), 3)
  146. def forward(self, x, feature, src_mask, tgt_mask):
  147. x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
  148. x = self.sublayer[1](x, lambda x: self.src_attn(x, feature, feature, src_mask))
  149. return self.sublayer[2](x, self.feed_forward)
  150. class MultiHeadAttention(nn.Layer):
  151. def __init__(self, headers, d_model, dropout):
  152. super(MultiHeadAttention, self).__init__()
  153. assert d_model % headers == 0
  154. self.d_k = int(d_model / headers)
  155. self.headers = headers
  156. self.linears = clones(nn.Linear(d_model, d_model), 4)
  157. self.attn = None
  158. self.dropout = nn.Dropout(dropout)
  159. def forward(self, query, key, value, mask=None):
  160. B = query.shape[0]
  161. # 1) Do all the linear projections in batch from d_model => h x d_k
  162. query, key, value = [
  163. l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3])
  164. for l, x in zip(self.linears, (query, key, value))
  165. ]
  166. # 2) Apply attention on all the projected vectors in batch
  167. x, self.attn = self_attention(
  168. query, key, value, mask=mask, dropout=self.dropout
  169. )
  170. x = x.transpose([0, 2, 1, 3]).reshape([B, 0, self.headers * self.d_k])
  171. return self.linears[-1](x)
  172. class FeedForward(nn.Layer):
  173. def __init__(self, d_model, d_ff, dropout):
  174. super(FeedForward, self).__init__()
  175. self.w_1 = nn.Linear(d_model, d_ff)
  176. self.w_2 = nn.Linear(d_ff, d_model)
  177. self.dropout = nn.Dropout(dropout)
  178. def forward(self, x):
  179. return self.w_2(self.dropout(F.relu(self.w_1(x))))
  180. class SubLayerConnection(nn.Layer):
  181. """
  182. A residual connection followed by a layer norm.
  183. Note for code simplicity the norm is first as opposed to last.
  184. """
  185. def __init__(self, size, dropout):
  186. super(SubLayerConnection, self).__init__()
  187. self.norm = nn.LayerNorm(size)
  188. self.dropout = nn.Dropout(dropout)
  189. def forward(self, x, sublayer):
  190. return x + self.dropout(sublayer(self.norm(x)))
  191. def masked_fill(x, mask, value):
  192. mask = mask.astype(x.dtype)
  193. return x * paddle.logical_not(mask).astype(x.dtype) + mask * value
  194. def self_attention(query, key, value, mask=None, dropout=None):
  195. """
  196. Compute 'Scale Dot Product Attention'
  197. """
  198. d_k = value.shape[-1]
  199. score = paddle.matmul(query, key.transpose([0, 1, 3, 2]) / math.sqrt(d_k))
  200. if mask is not None:
  201. # score = score.masked_fill(mask == 0, -1e9) # b, h, L, L
  202. score = masked_fill(score, mask == 0, -6.55e4) # for fp16
  203. p_attn = F.softmax(score, axis=-1)
  204. if dropout is not None:
  205. p_attn = dropout(p_attn)
  206. return paddle.matmul(p_attn, value), p_attn
  207. def clones(module, N):
  208. """Produce N identical layers"""
  209. return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
  210. class Embeddings(nn.Layer):
  211. def __init__(self, d_model, vocab):
  212. super(Embeddings, self).__init__()
  213. self.lut = nn.Embedding(vocab, d_model)
  214. self.d_model = d_model
  215. def forward(self, *input):
  216. x = input[0]
  217. return self.lut(x) * math.sqrt(self.d_model)
  218. class PositionalEncoding(nn.Layer):
  219. """Implement the PE function."""
  220. def __init__(self, d_model, dropout=0.0, max_len=5000):
  221. super(PositionalEncoding, self).__init__()
  222. self.dropout = nn.Dropout(p=dropout)
  223. # Compute the positional encodings once in log space.
  224. pe = paddle.zeros([max_len, d_model])
  225. position = paddle.arange(0, max_len).unsqueeze(1).astype("float32")
  226. div_term = paddle.exp(
  227. paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model
  228. )
  229. pe[:, 0::2] = paddle.sin(position * div_term)
  230. pe[:, 1::2] = paddle.cos(position * div_term)
  231. pe = pe.unsqueeze(0)
  232. self.register_buffer("pe", pe)
  233. def forward(self, feat, **kwargs):
  234. feat = feat + self.pe[:, : feat.shape[1]] # pe 1*5000*512
  235. return self.dropout(feat)