fetcher.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. class _DatasetFetcher:
  15. def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
  16. self.dataset = dataset
  17. self.auto_collate_batch = auto_collate_batch
  18. self.collate_fn = collate_fn
  19. self.drop_last = drop_last
  20. # NOTE: fetch function here perform the whole pipeline of dataset
  21. # reading and data transforms of a batch in each calling, this
  22. # may take a long time inside, if DataLoader is exit outside,
  23. # fetch need to perceive exit situation, so we pass done_event
  24. # here for fetch to check exit status
  25. # NOTE: if DataLoader exit by `break`, performing GPU tensor operations,
  26. # e.g. to_tensor may cause SIGSEGV in thread, so we pass the
  27. # done_event argument to check DataLoader exit status between
  28. # each sample processing in the batch
  29. def fetch(self, batch_indices, done_event=None):
  30. raise NotImplementedError(
  31. f"'fetch' not implement for class {self.__class__.__name__}"
  32. )
  33. class _IterableDatasetFetcher(_DatasetFetcher):
  34. def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
  35. super().__init__(dataset, auto_collate_batch, collate_fn, drop_last)
  36. self.dataset_iter = iter(dataset)
  37. def fetch(self, batch_indices, done_event=None):
  38. if self.auto_collate_batch:
  39. data = []
  40. for _ in batch_indices:
  41. if done_event is None or not done_event.is_set():
  42. try:
  43. data.append(next(self.dataset_iter))
  44. except StopIteration:
  45. break
  46. else:
  47. return None
  48. if len(data) == 0 or (
  49. self.drop_last and len(data) < len(batch_indices)
  50. ):
  51. raise StopIteration
  52. else:
  53. data = next(self.dataset_iter)
  54. if self.collate_fn:
  55. data = self.collate_fn(data)
  56. return data
  57. class _MapDatasetFetcher(_DatasetFetcher):
  58. def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
  59. super().__init__(dataset, auto_collate_batch, collate_fn, drop_last)
  60. def fetch(self, batch_indices, done_event=None):
  61. if self.auto_collate_batch:
  62. data = []
  63. for idx in batch_indices:
  64. if done_event is None or not done_event.is_set():
  65. data.append(self.dataset[idx])
  66. else:
  67. return None
  68. else:
  69. data = self.dataset[batch_indices]
  70. if self.collate_fn:
  71. data = self.collate_fn(data)
  72. return data