kie_unet_sdmgr.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. import numpy as np
  20. import cv2
  21. __all__ = ["Kie_backbone"]
  22. class Encoder(nn.Layer):
  23. def __init__(self, num_channels, num_filters):
  24. super(Encoder, self).__init__()
  25. self.conv1 = nn.Conv2D(
  26. num_channels,
  27. num_filters,
  28. kernel_size=3,
  29. stride=1,
  30. padding=1,
  31. bias_attr=False,
  32. )
  33. self.bn1 = nn.BatchNorm(num_filters, act="relu")
  34. self.conv2 = nn.Conv2D(
  35. num_filters,
  36. num_filters,
  37. kernel_size=3,
  38. stride=1,
  39. padding=1,
  40. bias_attr=False,
  41. )
  42. self.bn2 = nn.BatchNorm(num_filters, act="relu")
  43. self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
  44. def forward(self, inputs):
  45. x = self.conv1(inputs)
  46. x = self.bn1(x)
  47. x = self.conv2(x)
  48. x = self.bn2(x)
  49. x_pooled = self.pool(x)
  50. return x, x_pooled
  51. class Decoder(nn.Layer):
  52. def __init__(self, num_channels, num_filters):
  53. super(Decoder, self).__init__()
  54. self.conv1 = nn.Conv2D(
  55. num_channels,
  56. num_filters,
  57. kernel_size=3,
  58. stride=1,
  59. padding=1,
  60. bias_attr=False,
  61. )
  62. self.bn1 = nn.BatchNorm(num_filters, act="relu")
  63. self.conv2 = nn.Conv2D(
  64. num_filters,
  65. num_filters,
  66. kernel_size=3,
  67. stride=1,
  68. padding=1,
  69. bias_attr=False,
  70. )
  71. self.bn2 = nn.BatchNorm(num_filters, act="relu")
  72. self.conv0 = nn.Conv2D(
  73. num_channels,
  74. num_filters,
  75. kernel_size=1,
  76. stride=1,
  77. padding=0,
  78. bias_attr=False,
  79. )
  80. self.bn0 = nn.BatchNorm(num_filters, act="relu")
  81. def forward(self, inputs_prev, inputs):
  82. x = self.conv0(inputs)
  83. x = self.bn0(x)
  84. x = paddle.nn.functional.interpolate(
  85. x, scale_factor=2, mode="bilinear", align_corners=False
  86. )
  87. x = paddle.concat([inputs_prev, x], axis=1)
  88. x = self.conv1(x)
  89. x = self.bn1(x)
  90. x = self.conv2(x)
  91. x = self.bn2(x)
  92. return x
  93. class UNet(nn.Layer):
  94. def __init__(self):
  95. super(UNet, self).__init__()
  96. self.down1 = Encoder(num_channels=3, num_filters=16)
  97. self.down2 = Encoder(num_channels=16, num_filters=32)
  98. self.down3 = Encoder(num_channels=32, num_filters=64)
  99. self.down4 = Encoder(num_channels=64, num_filters=128)
  100. self.down5 = Encoder(num_channels=128, num_filters=256)
  101. self.up1 = Decoder(32, 16)
  102. self.up2 = Decoder(64, 32)
  103. self.up3 = Decoder(128, 64)
  104. self.up4 = Decoder(256, 128)
  105. self.out_channels = 16
  106. def forward(self, inputs):
  107. x1, _ = self.down1(inputs)
  108. _, x2 = self.down2(x1)
  109. _, x3 = self.down3(x2)
  110. _, x4 = self.down4(x3)
  111. _, x5 = self.down5(x4)
  112. x = self.up4(x4, x5)
  113. x = self.up3(x3, x)
  114. x = self.up2(x2, x)
  115. x = self.up1(x1, x)
  116. return x
  117. class Kie_backbone(nn.Layer):
  118. def __init__(self, in_channels, **kwargs):
  119. super(Kie_backbone, self).__init__()
  120. self.out_channels = 16
  121. self.img_feat = UNet()
  122. self.maxpool = nn.MaxPool2D(kernel_size=7)
  123. def bbox2roi(self, bbox_list):
  124. rois_list = []
  125. rois_num = []
  126. for img_id, bboxes in enumerate(bbox_list):
  127. rois_num.append(bboxes.shape[0])
  128. rois_list.append(bboxes)
  129. rois = paddle.concat(rois_list, 0)
  130. rois_num = paddle.to_tensor(rois_num, dtype="int32")
  131. return rois, rois_num
  132. def pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):
  133. img, relations, texts, gt_bboxes, tag, img_size = (
  134. img.numpy(),
  135. relations.numpy(),
  136. texts.numpy(),
  137. gt_bboxes.numpy(),
  138. tag.numpy().tolist(),
  139. img_size.numpy(),
  140. )
  141. temp_relations, temp_texts, temp_gt_bboxes = [], [], []
  142. h, w = int(np.max(img_size[:, 0])), int(np.max(img_size[:, 1]))
  143. img = paddle.to_tensor(img[:, :, :h, :w])
  144. batch = len(tag)
  145. for i in range(batch):
  146. num, recoder_len = tag[i][0], tag[i][1]
  147. temp_relations.append(
  148. paddle.to_tensor(relations[i, :num, :num, :], dtype="float32")
  149. )
  150. temp_texts.append(
  151. paddle.to_tensor(texts[i, :num, :recoder_len], dtype="float32")
  152. )
  153. temp_gt_bboxes.append(
  154. paddle.to_tensor(gt_bboxes[i, :num, ...], dtype="float32")
  155. )
  156. return img, temp_relations, temp_texts, temp_gt_bboxes
  157. def forward(self, inputs):
  158. img = inputs[0]
  159. relations, texts, gt_bboxes, tag, img_size = (
  160. inputs[1],
  161. inputs[2],
  162. inputs[3],
  163. inputs[5],
  164. inputs[-1],
  165. )
  166. img, relations, texts, gt_bboxes = self.pre_process(
  167. img, relations, texts, gt_bboxes, tag, img_size
  168. )
  169. x = self.img_feat(img)
  170. boxes, rois_num = self.bbox2roi(gt_bboxes)
  171. feats = paddle.vision.ops.roi_align(
  172. x, boxes, spatial_scale=1.0, output_size=7, boxes_num=rois_num
  173. )
  174. feats = self.maxpool(feats).squeeze(-1).squeeze(-1)
  175. return [relations, texts, feats]