kie_sdmgr_head.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # reference from : https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/heads/sdmgr_head.py
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. import math
  19. import paddle
  20. from paddle import nn
  21. import paddle.nn.functional as F
  22. from paddle import ParamAttr
  23. class SDMGRHead(nn.Layer):
  24. def __init__(
  25. self,
  26. in_channels,
  27. num_chars=92,
  28. visual_dim=16,
  29. fusion_dim=1024,
  30. node_input=32,
  31. node_embed=256,
  32. edge_input=5,
  33. edge_embed=256,
  34. num_gnn=2,
  35. num_classes=26,
  36. bidirectional=False,
  37. ):
  38. super().__init__()
  39. self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
  40. self.node_embed = nn.Embedding(num_chars, node_input, 0)
  41. hidden = node_embed // 2 if bidirectional else node_embed
  42. self.rnn = nn.LSTM(input_size=node_input, hidden_size=hidden, num_layers=1)
  43. self.edge_embed = nn.Linear(edge_input, edge_embed)
  44. self.gnn_layers = nn.LayerList(
  45. [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)]
  46. )
  47. self.node_cls = nn.Linear(node_embed, num_classes)
  48. self.edge_cls = nn.Linear(edge_embed, 2)
  49. def forward(self, input, targets):
  50. relations, texts, x = input
  51. node_nums, char_nums = [], []
  52. for text in texts:
  53. node_nums.append(text.shape[0])
  54. char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))
  55. max_num = max([char_num.max() for char_num in char_nums])
  56. all_nodes = paddle.concat(
  57. [
  58. paddle.concat(
  59. [text, paddle.zeros((text.shape[0], max_num - text.shape[1]))], -1
  60. )
  61. for text in texts
  62. ]
  63. )
  64. temp = paddle.clip(all_nodes, min=0).astype(int)
  65. embed_nodes = self.node_embed(temp)
  66. rnn_nodes, _ = self.rnn(embed_nodes)
  67. b, h, w = rnn_nodes.shape
  68. nodes = paddle.zeros([b, w])
  69. all_nums = paddle.concat(char_nums)
  70. valid = paddle.nonzero((all_nums > 0).astype(int))
  71. temp_all_nums = (paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
  72. temp_all_nums = paddle.expand(
  73. temp_all_nums,
  74. [temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]],
  75. )
  76. temp_all_nodes = paddle.gather(rnn_nodes, valid)
  77. N, C, A = temp_all_nodes.shape
  78. one_hot = F.one_hot(temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
  79. one_hot = paddle.multiply(temp_all_nodes, one_hot.astype("float32")).sum(
  80. axis=1, keepdim=True
  81. )
  82. t = one_hot.expand([N, 1, A]).squeeze(1)
  83. nodes = paddle.scatter(nodes, valid.squeeze(1), t)
  84. if x is not None:
  85. nodes = self.fusion([x, nodes])
  86. all_edges = paddle.concat(
  87. [rel.reshape([-1, rel.shape[-1]]) for rel in relations]
  88. )
  89. embed_edges = self.edge_embed(all_edges.astype("float32"))
  90. embed_edges = F.normalize(embed_edges)
  91. for gnn_layer in self.gnn_layers:
  92. nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
  93. node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
  94. return node_cls, edge_cls
  95. class GNNLayer(nn.Layer):
  96. def __init__(self, node_dim=256, edge_dim=256):
  97. super().__init__()
  98. self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
  99. self.coef_fc = nn.Linear(node_dim, 1)
  100. self.out_fc = nn.Linear(node_dim, node_dim)
  101. self.relu = nn.ReLU()
  102. def forward(self, nodes, edges, nums):
  103. start, cat_nodes = 0, []
  104. for num in nums:
  105. sample_nodes = nodes[start : start + num]
  106. cat_nodes.append(
  107. paddle.concat(
  108. [
  109. paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
  110. paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1]),
  111. ],
  112. -1,
  113. ).reshape([num**2, -1])
  114. )
  115. start += num
  116. cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
  117. cat_nodes = self.relu(self.in_fc(cat_nodes))
  118. coefs = self.coef_fc(cat_nodes)
  119. start, residuals = 0, []
  120. for num in nums:
  121. residual = F.softmax(
  122. -paddle.eye(num).unsqueeze(-1) * 1e9
  123. + coefs[start : start + num**2].reshape([num, num, -1]),
  124. 1,
  125. )
  126. residuals.append(
  127. (
  128. residual * cat_nodes[start : start + num**2].reshape([num, num, -1])
  129. ).sum(1)
  130. )
  131. start += num**2
  132. nodes += self.relu(self.out_fc(paddle.concat(residuals)))
  133. return [nodes, cat_nodes]
  134. class Block(nn.Layer):
  135. def __init__(
  136. self,
  137. input_dims,
  138. output_dim,
  139. mm_dim=1600,
  140. chunks=20,
  141. rank=15,
  142. shared=False,
  143. dropout_input=0.0,
  144. dropout_pre_lin=0.0,
  145. dropout_output=0.0,
  146. pos_norm="before_cat",
  147. ):
  148. super().__init__()
  149. self.rank = rank
  150. self.dropout_input = dropout_input
  151. self.dropout_pre_lin = dropout_pre_lin
  152. self.dropout_output = dropout_output
  153. assert pos_norm in ["before_cat", "after_cat"]
  154. self.pos_norm = pos_norm
  155. # Modules
  156. self.linear0 = nn.Linear(input_dims[0], mm_dim)
  157. self.linear1 = self.linear0 if shared else nn.Linear(input_dims[1], mm_dim)
  158. self.merge_linears0 = nn.LayerList()
  159. self.merge_linears1 = nn.LayerList()
  160. self.chunks = self.chunk_sizes(mm_dim, chunks)
  161. for size in self.chunks:
  162. ml0 = nn.Linear(size, size * rank)
  163. self.merge_linears0.append(ml0)
  164. ml1 = ml0 if shared else nn.Linear(size, size * rank)
  165. self.merge_linears1.append(ml1)
  166. self.linear_out = nn.Linear(mm_dim, output_dim)
  167. def forward(self, x):
  168. x0 = self.linear0(x[0])
  169. x1 = self.linear1(x[1])
  170. bs = x1.shape[0]
  171. if self.dropout_input > 0:
  172. x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
  173. x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
  174. x0_chunks = paddle.split(x0, self.chunks, -1)
  175. x1_chunks = paddle.split(x1, self.chunks, -1)
  176. zs = []
  177. for x0_c, x1_c, m0, m1 in zip(
  178. x0_chunks, x1_chunks, self.merge_linears0, self.merge_linears1
  179. ):
  180. m = m0(x0_c) * m1(x1_c) # bs x split_size*rank
  181. m = m.reshape([bs, self.rank, -1])
  182. z = paddle.sum(m, 1)
  183. if self.pos_norm == "before_cat":
  184. z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
  185. z = F.normalize(z)
  186. zs.append(z)
  187. z = paddle.concat(zs, 1)
  188. if self.pos_norm == "after_cat":
  189. z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
  190. z = F.normalize(z)
  191. if self.dropout_pre_lin > 0:
  192. z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
  193. z = self.linear_out(z)
  194. if self.dropout_output > 0:
  195. z = F.dropout(z, p=self.dropout_output, training=self.training)
  196. return z
  197. def chunk_sizes(self, dim, chunks):
  198. split_size = (dim + chunks - 1) // chunks
  199. sizes_list = [split_size] * chunks
  200. sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
  201. return sizes_list