# Copyright (c) Alibaba, Inc. and its affiliates.
  2. def batch(reader, batch_size, drop_last=False):
  3. """
  4. This operator creates a batched reader which combines the data from the
  5. input reader to batched data.
  6. Args:
  7. reader(generator): the data reader to read from.
  8. batch_size(int): size of each mini-batch.
  9. drop_last(bool, optional): If set to True, the last batch is dropped when
  10. the size of last batch is not equal to batch_size, if set to False,
  11. it will not. Default: False.
  12. Returns:
  13. The batched reader.
  14. Return Type:
  15. generator
  16. Examples:
  17. >>> import paddle.fluid as fluid
  18. >>> def reader():
  19. >>> for i in range(10):
  20. >>> yield i
  21. >>> batch_reader = fluid.io.batch(reader, batch_size=2)
  22. >>> for data in batch_reader():
  23. >>> print(data)
  24. >>> # Output is
  25. >>> # [0, 1]
  26. >>> # [2, 3]
  27. >>> # [4, 5]
  28. >>> # [6, 7]
  29. >>> # [8, 9]
  30. """
  31. def batch_reader():
  32. r = reader()
  33. b = []
  34. for instance in r:
  35. b.append(instance)
  36. if len(b) == batch_size:
  37. yield b
  38. b = []
  39. if drop_last is False and len(b) != 0:
  40. yield b
  41. # Batch size check
  42. batch_size = int(batch_size)
  43. if batch_size <= 0:
  44. raise ValueError('batch_size should be a positive integer value, '
  45. 'but got batch_size={}'.format(batch_size))
  46. return batch_reader