sampler.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import numpy as np
  3. class Sampler(object):
  4. def __init__(self):
  5. return
  6. def __len__(self):
  7. raise NotImplementedError
  8. def __iter__(self):
  9. raise NotImplementedError
  10. class SequentialSampler(Sampler):
  11. def __init__(self, dataset):
  12. self.dataset = dataset
  13. return
  14. def __len__(self):
  15. return len(self.dataset)
  16. def __iter__(self):
  17. return iter(range(len(self)))
  18. class RandomSampler(Sampler):
  19. def __init__(self, dataset):
  20. self.dataset = dataset
  21. self.epoch = 0
  22. return
  23. def __len__(self):
  24. return len(self.dataset)
  25. def __iter__(self):
  26. np.random.seed(self.epoch)
  27. self.epoch += 1
  28. return iter(np.random.permutation(len(self)))
  29. class SortedSampler(Sampler):
  30. """ Sorted Sampler.
  31. Sort each block of examples by key.
  32. """
  33. def __init__(self, sampler, sort_pool_size, key='src'):
  34. self.sampler = sampler
  35. self.sort_pool_size = sort_pool_size
  36. self.key = lambda idx: len(self.sampler.dataset[idx][key])
  37. return
  38. def __len__(self):
  39. return len(self.sampler)
  40. def __iter__(self):
  41. pool = []
  42. for idx in self.sampler:
  43. pool.append(idx)
  44. if len(pool) == self.sort_pool_size:
  45. pool = sorted(pool, key=self.key)
  46. for i in pool:
  47. yield i
  48. pool = []
  49. if len(pool) > 0:
  50. pool = sorted(pool, key=self.key)
  51. for i in pool:
  52. yield i