collate_fn.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import numbers
  16. import numpy as np
  17. from collections import defaultdict
  18. class DictCollator(object):
  19. """
  20. data batch
  21. """
  22. def __call__(self, batch):
  23. # todo:support batch operators
  24. data_dict = defaultdict(list)
  25. to_tensor_keys = []
  26. for sample in batch:
  27. for k, v in sample.items():
  28. if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
  29. if k not in to_tensor_keys:
  30. to_tensor_keys.append(k)
  31. data_dict[k].append(v)
  32. for k in to_tensor_keys:
  33. data_dict[k] = paddle.to_tensor(data_dict[k])
  34. return data_dict
  35. class ListCollator(object):
  36. """
  37. data batch
  38. """
  39. def __call__(self, batch):
  40. # todo:support batch operators
  41. data_dict = defaultdict(list)
  42. to_tensor_idxs = []
  43. for sample in batch:
  44. for idx, v in enumerate(sample):
  45. if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
  46. if idx not in to_tensor_idxs:
  47. to_tensor_idxs.append(idx)
  48. data_dict[idx].append(v)
  49. for idx in to_tensor_idxs:
  50. data_dict[idx] = paddle.to_tensor(data_dict[idx])
  51. return list(data_dict.values())
  52. class SSLRotateCollate(object):
  53. """
  54. bach: [
  55. [(4*3xH*W), (4,)]
  56. [(4*3xH*W), (4,)]
  57. ...
  58. ]
  59. """
  60. def __call__(self, batch):
  61. output = [np.concatenate(d, axis=0) for d in zip(*batch)]
  62. return output
  63. class DyMaskCollator(object):
  64. """
  65. batch: [
  66. image [batch_size, channel, maxHinbatch, maxWinbatch]
  67. image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
  68. label [batch_size, maxLabelLen]
  69. label_mask [batch_size, maxLabelLen]
  70. ...
  71. ]
  72. """
  73. def __call__(self, batch):
  74. max_width, max_height, max_length = 0, 0, 0
  75. bs, channel = len(batch), batch[0][0].shape[0]
  76. proper_items = []
  77. for item in batch:
  78. if (
  79. item[0].shape[1] * max_width > 1600 * 320
  80. or item[0].shape[2] * max_height > 1600 * 320
  81. ):
  82. continue
  83. max_height = (
  84. item[0].shape[1] if item[0].shape[1] > max_height else max_height
  85. )
  86. max_width = item[0].shape[2] if item[0].shape[2] > max_width else max_width
  87. max_length = len(item[1]) if len(item[1]) > max_length else max_length
  88. proper_items.append(item)
  89. images, image_masks = np.zeros(
  90. (len(proper_items), channel, max_height, max_width), dtype="float32"
  91. ), np.zeros((len(proper_items), 1, max_height, max_width), dtype="float32")
  92. labels, label_masks = np.zeros(
  93. (len(proper_items), max_length), dtype="int64"
  94. ), np.zeros((len(proper_items), max_length), dtype="int64")
  95. for i in range(len(proper_items)):
  96. _, h, w = proper_items[i][0].shape
  97. images[i][:, :h, :w] = proper_items[i][0]
  98. image_masks[i][:, :h, :w] = 1
  99. l = len(proper_items[i][1])
  100. labels[i][:l] = proper_items[i][1]
  101. label_masks[i][:l] = 1
  102. return images, image_masks, labels, label_masks
  103. class LaTeXOCRCollator(object):
  104. """
  105. batch: [
  106. image [batch_size, channel, maxHinbatch, maxWinbatch]
  107. label [batch_size, maxLabelLen]
  108. label_mask [batch_size, maxLabelLen]
  109. ...
  110. ]
  111. """
  112. def __call__(self, batch):
  113. images, labels, attention_mask = batch[0]
  114. return images, labels, attention_mask
  115. class UniMERNetCollator(object):
  116. """
  117. batch: [
  118. image [batch_size, channel, maxHinbatch, maxWinbatch]
  119. image_mask [batch_size, channel, maxHinbatch, maxWinbatch]
  120. label [batch_size, maxLabelLen]
  121. label_mask [batch_size, maxLabelLen]
  122. ...
  123. ]
  124. """
  125. def __call__(self, batch):
  126. max_width, max_height, max_length = 0, 0, 0
  127. bs, channel = len(batch), batch[0][0].shape[0]
  128. proper_items = []
  129. for item in batch:
  130. max_height = (
  131. item[0].shape[1] if item[0].shape[1] > max_height else max_height
  132. )
  133. max_width = item[0].shape[2] if item[0].shape[2] > max_width else max_width
  134. max_length = len(item[1]) if len(item[1]) > max_length else max_length
  135. proper_items.append(item)
  136. images = np.ones(
  137. (len(proper_items), channel, max_height, max_width), dtype="float32"
  138. )
  139. labels, label_masks = np.ones(
  140. (len(proper_items), max_length), dtype="int64"
  141. ), np.zeros((len(proper_items), max_length), dtype="int64")
  142. for i in range(len(proper_items)):
  143. _, h, w = proper_items[i][0].shape
  144. images[i][:, :h, :w] = proper_items[i][0]
  145. l = len(proper_items[i][1])
  146. labels[i][:l] = proper_items[i][1]
  147. label_masks[i][:l] = proper_items[i][2]
  148. return images, labels, label_masks