flat.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numbers
  15. from collections.abc import Mapping, Sequence
  16. import numpy as np
  17. import paddle
# Prefix of the placeholder strings that stand in for tensor fields inside a
# flattened batch structure; the integer suffix appended to it is the field's
# index in the flat tensor list.
FIELD_PREFIX = "_paddle_field_"
  19. def _flatten_batch(batch):
  20. """
  21. For lod_blocking_queue only receive tensor array, flatten batch
  22. data, extract numpy.array data out as a list of numpy.array to
  23. send to lod_blocking_queue, and save the batch data structure
  24. such as fields in other types (str, int, etc) or key-value map
  25. of dictionaries
  26. """
  27. def _flatten(batch, flat_batch, structure, field_idx):
  28. if isinstance(batch, Sequence):
  29. for field in batch:
  30. if isinstance(
  31. field,
  32. (np.ndarray, paddle.Tensor, paddle.base.core.eager.Tensor),
  33. ):
  34. structure.append(f'{FIELD_PREFIX}{field_idx}')
  35. flat_batch.append(field)
  36. field_idx += 1
  37. elif isinstance(field, (str, bytes, numbers.Number)):
  38. structure.append(field)
  39. elif isinstance(field, Sequence):
  40. field_struct, field_idx = _flatten(
  41. field, flat_batch, [], field_idx
  42. )
  43. structure.append(field_struct)
  44. elif isinstance(field, Mapping):
  45. field_struct, field_idx = _flatten(
  46. field, flat_batch, {}, field_idx
  47. )
  48. structure.append(field_struct)
  49. else:
  50. structure.append(field)
  51. elif isinstance(batch, Mapping):
  52. for k, field in batch.items():
  53. if isinstance(
  54. field,
  55. (np.ndarray, paddle.Tensor, paddle.base.core.eager.Tensor),
  56. ):
  57. structure[k] = f'{FIELD_PREFIX}{field_idx}'
  58. flat_batch.append(field)
  59. field_idx += 1
  60. elif isinstance(field, (str, bytes, numbers.Number)):
  61. structure[k] = field
  62. elif isinstance(field, Sequence):
  63. field_struct, field_idx = _flatten(
  64. field, flat_batch, [], field_idx
  65. )
  66. structure[k] = field_struct
  67. elif isinstance(field, Mapping):
  68. field_struct, field_idx = _flatten(
  69. field, flat_batch, {}, field_idx
  70. )
  71. structure[k] = field_struct
  72. else:
  73. structure[k] = field
  74. else:
  75. raise TypeError(f"wrong flat data type: {type(batch)}")
  76. return structure, field_idx
  77. # sample only contains single fields
  78. if not isinstance(batch, Sequence):
  79. flat_batch = []
  80. structure, _ = _flatten([batch], flat_batch, [], 0)
  81. return flat_batch, structure[0]
  82. flat_batch = []
  83. structure, _ = _flatten(batch, flat_batch, [], 0)
  84. return flat_batch, structure
  85. def _restore_batch(flat_batch, structure):
  86. """
  87. After reading list of Tensor data from lod_blocking_queue outputs,
  88. use this function to restore the batch data structure, replace
  89. :attr:`_paddle_field_x` with data from flat_batch
  90. """
  91. def _restore(structure, field_idx):
  92. if isinstance(structure, Sequence):
  93. for i, field in enumerate(structure):
  94. if isinstance(field, str) and field.startswith(FIELD_PREFIX):
  95. cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
  96. field_idx = max(field_idx, cur_field_idx)
  97. assert (
  98. flat_batch[cur_field_idx] is not None
  99. ), "flat_batch[{}] parsed repeatly"
  100. structure[i] = flat_batch[cur_field_idx]
  101. flat_batch[cur_field_idx] = None
  102. elif isinstance(field, (str, bytes, numbers.Number)):
  103. continue
  104. elif isinstance(field, (Sequence, Mapping)):
  105. field_idx = _restore(structure[i], field_idx)
  106. elif isinstance(structure, Mapping):
  107. for k, field in structure.items():
  108. if isinstance(field, str) and field.startswith(FIELD_PREFIX):
  109. cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
  110. field_idx = max(field_idx, cur_field_idx)
  111. assert (
  112. flat_batch[cur_field_idx] is not None
  113. ), "flat_batch[{}] parsed repeatly"
  114. structure[k] = flat_batch[cur_field_idx]
  115. flat_batch[cur_field_idx] = None
  116. elif isinstance(field, (str, bytes, numbers.Number)):
  117. continue
  118. elif isinstance(field, (Sequence, Mapping)):
  119. field_idx = _restore(structure[k], field_idx)
  120. else:
  121. raise TypeError(f"wrong flat data type: {type(structure)}")
  122. return field_idx
  123. assert isinstance(flat_batch, Sequence), "flat_batch is not a list or tuple"
  124. # no np.array in dataset, no output tensor from blocking queue
  125. # simply return structure
  126. if len(flat_batch) == 0:
  127. return structure
  128. # sample only contains single fields
  129. if isinstance(structure, (str, bytes)):
  130. assert (
  131. structure == f'{FIELD_PREFIX}{0}'
  132. ), f"invalid structure: {structure}"
  133. return flat_batch[0]
  134. field_idx = _restore(structure, 0)
  135. assert field_idx + 1 == len(flat_batch), "Tensor parse incomplete"
  136. return structure