batch.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. __all__ = []
  15. def batch(reader, batch_size, drop_last=False):
  16. """
  17. This operator creates a batched reader which combines the data from the
  18. input reader to batched data.
  19. Args:
  20. reader(generator): the data reader to read from.
  21. batch_size(int): size of each mini-batch.
  22. drop_last(bool, optional): If set to True, the last batch is dropped when
  23. the size of last batch is not equal to batch_size, if set to False,
  24. it will not. Default: False.
  25. Returns:
  26. The batched reader.
  27. Return Type:
  28. generator
  29. Examples:
  30. .. code-block:: python
  31. >>> import paddle
  32. >>> def reader():
  33. ... for i in range(10):
  34. ... yield i
  35. >>> batch_reader = paddle.batch(reader, batch_size=2)
  36. >>> for data in batch_reader():
  37. ... print(data)
  38. ...
  39. [0, 1]
  40. [2, 3]
  41. [4, 5]
  42. [6, 7]
  43. [8, 9]
  44. """
  45. def batch_reader():
  46. r = reader()
  47. b = []
  48. for instance in r:
  49. b.append(instance)
  50. if len(b) == batch_size:
  51. yield b
  52. b = []
  53. if drop_last is False and len(b) != 0:
  54. yield b
  55. # Batch size check
  56. batch_size = int(batch_size)
  57. if batch_size <= 0:
  58. raise ValueError(
  59. "batch_size should be a positive integer value, "
  60. f"but got batch_size={batch_size}"
  61. )
  62. return batch_reader