# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import logging
  15. import multiprocessing
  16. import queue
  17. import sys
  18. import threading
  19. import warnings
  20. import numpy as np
  21. import paddle
  22. from paddle.base.framework import _set_expected_place
  23. from paddle.pir.core import datatype_to_vartype
  24. from . import core
  25. from .data_feeder import BatchedTensorProvider, DataFeeder
  26. from .executor import global_scope
  27. from .framework import (
  28. Program,
  29. _current_expected_place,
  30. _get_paddle_place,
  31. _get_paddle_place_list,
  32. default_main_program,
  33. default_startup_program,
  34. in_dygraph_mode,
  35. in_pir_mode,
  36. program_guard,
  37. )
  38. from .layers.io import (
  39. __create_unshared_decorated_reader__,
  40. _copy_reader_var_,
  41. monkey_patch_reader_methods,
  42. )
  43. from .multiprocess_utils import ( # noqa: F401
  44. CleanupFuncRegistrar,
  45. _cleanup,
  46. _cleanup_mmap,
  47. _set_SIGCHLD_handler,
  48. multiprocess_queue_set,
  49. )
  50. from .unique_name import UniqueNameGenerator
# NOTE: [ avoid hanging & failed quickly ] This value is the timeout (in
# seconds) used when getting data from another process.
QUEUE_GET_TIMEOUT = 60

# Nothing is exported through ``from <module> import *``.
__all__ = []

# Module-level unique-name generator instance.
data_loader_unique_name_generator = UniqueNameGenerator()

# Module-wide switches; they are read and written through
# keep_data_loader_order() and use_pinned_memory().
KEEP_DATA_LOADER_ORDER = True
USE_PINNED_MEMORY = None
  57. def keep_data_loader_order(*args):
  58. global KEEP_DATA_LOADER_ORDER
  59. if len(args) == 0:
  60. return KEEP_DATA_LOADER_ORDER
  61. else:
  62. assert len(args) == 1 and isinstance(args[0], bool)
  63. KEEP_DATA_LOADER_ORDER = args[0]
  64. def use_pinned_memory(*args):
  65. global USE_PINNED_MEMORY
  66. if len(args) == 0:
  67. return USE_PINNED_MEMORY
  68. else:
  69. assert len(args) == 1 and isinstance(args[0], bool)
  70. USE_PINNED_MEMORY = args[0]
  71. def _convert_places(places):
  72. if not isinstance(places, (list, tuple)):
  73. places = [places]
  74. ret = []
  75. for p in places:
  76. if not isinstance(p, core.Place):
  77. tmp = core.Place()
  78. tmp.set_place(p)
  79. p = tmp
  80. ret.append(p)
  81. return ret
  82. # NOTE(chenweihang): _reader_process_loop must be top level method to be pickled
  83. def _reader_process_loop(
  84. batch_reader, data_queue, dataloader_use_file_descriptor=True
  85. ):
  86. try:
  87. # set signal handler
  88. core._set_process_signal_handler()
  89. if not dataloader_use_file_descriptor:
  90. # set dataloader_use_file_descriptor to false to avoid use descriptor.
  91. paddle.base.core.globals()[
  92. "FLAGS_dataloader_use_file_descriptor"
  93. ] = False
  94. # NOTE: [ mmap files clear ] When the child process exits unexpectedly,
  95. # some shared memory objects may have been applied for but have not yet
  96. # been put into the inter-process Queue. This part of the object needs
  97. # to be cleaned up when the process ends.
  98. CleanupFuncRegistrar.register(_cleanup_mmap)
  99. for batch in batch_reader():
  100. tensor_list = core._convert_to_tensor_list(batch)
  101. data_queue.put(tensor_list)
  102. core._remove_tensor_list_mmap_fds(tensor_list)
  103. data_queue.put(None)
  104. except KeyboardInterrupt:
  105. # NOTE: Main process will raise KeyboardInterrupt anyways, ignore it in child process
  106. pass
  107. except:
  108. raise
  109. class DataLoaderBase:
  110. def __init__(self):
  111. self._places = None
  112. def __call__(self):
  113. return self
  114. def __iter__(self):
  115. raise NotImplementedError()
  116. def __next__(self):
  117. raise NotImplementedError()
  118. @classmethod
  119. def _check_input_array(cls, item):
  120. arr = np.asarray(item)
  121. if arr.dtype == np.object_:
  122. raise TypeError(
  123. "\n\tFailed to convert input data to a regular ndarray :\n\t* Usually "
  124. "this means the input data contains nested lists with different lengths. "
  125. "\n\t* Check the reader function passed to 'decorate_batch_generator'"
  126. " to locate the data causes this issue.\n\t* Please consider using "
  127. "'base.create_lod_tensor' to convert it to a LoD-Tensor."
  128. )
  129. return arr
  130. class DataLoader:
  131. @staticmethod
  132. def from_generator(
  133. feed_list=None,
  134. capacity=None,
  135. use_double_buffer=True,
  136. iterable=True,
  137. return_list=False,
  138. use_multiprocess=False,
  139. drop_last=True,
  140. ):
  141. """
  142. .. warning::
  143. This API will be deprecated in the future, it is recommended to use
  144. :code:`paddle.io.DataLoader` which supports multi-processes acceleration.
  145. .. note::
  146. **The framework ensures that the data loading order of DataLoader is exactly the same as the user-defined data source.**
  147. Create a DataLoader object for loading data from Python generator.
  148. Data would be prefetched using Python thread and be pushed
  149. into a queue asynchronously.
  150. The created DataLoader object provides 3 methods to set the data source
  151. :code:`set_sample_generator` , :code:`set_sample_list_generator` and
  152. :code:`set_batch_generator` . Please see the following example codes
  153. to know their usages.
  154. If iterable = True, the created DataLoader object is a Python generator
  155. object, which is iterable using for-range loop.
  156. If iterable = False, the created DataLoader object provides
  157. :code:`start()` and :code:`reset()` method to control the data reading
  158. process.
  159. Args:
  160. feed_list (list(Tensor)|tuple(Tensor)): feed Tensor list.
  161. The Tensors should be created by :code:`paddle.static.data()`.
  162. capacity (int): capacity of the queue maintained in DataLoader.
  163. The unit is batch number. Set larger capacity if your reader
  164. is fast.
  165. use_double_buffer (bool, optional): whether to use double_buffer_reader.
  166. If use_double_buffer=True, the DataLoader would prefetch next
  167. batch data asynchronously, so it would speed up data feeding
  168. and occupies a little more CPU or GPU memory, i.e., the memory
  169. of one batch input data.
  170. iterable (bool, optional): whether the created DataLoader is iterable.
  171. return_list (bool, optional): whether the return value on each device is
  172. presented as a list. It is only valid when iterable=True.
  173. If return_list=False, the return value on each device would
  174. be a dict of str -> LoDTensor, where the key of the dict is
  175. the name of each fed Tensors. If return_list=True, the
  176. return value on each device would be a list(LoDTensor). It is
  177. recommended to use return_list=False in static graph mode and
  178. use return_list=True in dygraph mode.
  179. use_multiprocess (bool, optional): whether to use multi-process to
  180. speed up the data loading process in dygraph. Note: this parameter
  181. only can be used in the dygraph mode. In the static graph mode,
  182. whether this parameter is set or not has no effect.
  183. The Default value is False.
  184. drop_last (bool, optional): whether to drop the last batches whose
  185. number is less than the CPU core/GPU card number. The default
  186. value is True. In training phase, users should not set drop_last=False,
  187. because all CPU cores/GPU cards must read data from DataLoader.
  188. In inference phase, users can set drop_last=False, so that the
  189. last batches whose number is less than the CPU core/GPU card
  190. number can be tested.
  191. Returns:
  192. loader (DataLoader): the created DataLoader object.
  193. Examples:
  194. .. code-block:: python
  195. :name: example_1
  196. >>> # Example in static graph mode
  197. >>> import numpy as np
  198. >>> import paddle
  199. >>> import paddle.static as static
  200. >>> import paddle.nn.functional as F
  201. >>> BATCH_NUM = 10
  202. >>> BATCH_SIZE = 16
  203. >>> EPOCH_NUM = 4
  204. >>> CLASS_NUM = 10
  205. >>> ITERABLE = True # whether the created DataLoader object is iterable
  206. >>> USE_GPU = False # whether to use GPU
  207. >>> DATA_FORMAT = 'batch_generator' # data format of data source user provides
  208. >>> paddle.enable_static()
  209. >>> def simple_net(image, label):
  210. ... fc_tmp = static.nn.fc(image, size=CLASS_NUM)
  211. ... cross_entropy = F.softmax_with_cross_entropy(image, label)
  212. ... loss = paddle.mean(cross_entropy)
  213. ... sgd = paddle.optimizer.SGD(learning_rate=1e-3)
  214. ... sgd.minimize(loss)
  215. ... return loss
  216. ...
  217. >>> def get_random_images_and_labels(image_shape, label_shape):
  218. ... image = np.random.random(size=image_shape).astype('float32')
  219. ... label = np.random.random(size=label_shape).astype('int64')
  220. ... return image, label
  221. ...
  222. >>> # If the data generator yields one sample each time,
  223. >>> # use DataLoader.set_sample_generator to set the data source.
  224. >>> def sample_generator_creator():
  225. ... def __reader__():
  226. ... for _ in range(BATCH_NUM * BATCH_SIZE):
  227. ... image, label = get_random_images_and_labels([784], [1])
  228. ... yield image, label
  229. ...
  230. ... return __reader__
  231. ...
  232. >>> # If the data generator yield list of samples each time,
  233. >>> # use DataLoader.set_sample_list_generator to set the data source.
  234. >>> def sample_list_generator_creator():
  235. ... def __reader__():
  236. ... for _ in range(BATCH_NUM):
  237. ... sample_list = []
  238. ... for _ in range(BATCH_SIZE):
  239. ... image, label = get_random_images_and_labels([784], [1])
  240. ... sample_list.append([image, label])
  241. ...
  242. ... yield sample_list
  243. ...
  244. ... return __reader__
  245. ...
  246. >>> # If the data generator yields a batch each time,
  247. >>> # use DataLoader.set_batch_generator to set the data source.
  248. >>> def batch_generator_creator():
  249. ... def __reader__():
  250. ... for _ in range(BATCH_NUM):
  251. ... batch_image, batch_label = get_random_images_and_labels([BATCH_SIZE, 784], [BATCH_SIZE, 1])
  252. ... yield batch_image, batch_label
  253. ...
  254. ... return __reader__
  255. ...
  256. >>> # If DataLoader is iterable, use for loop to train the network
  257. >>> def train_iterable(exe, prog, loss, loader):
  258. ... for _ in range(EPOCH_NUM):
  259. ... for data in loader():
  260. ... exe.run(prog, feed=data, fetch_list=[loss])
  261. ...
  262. >>> # If DataLoader is not iterable, use start() and reset() method to control the process
  263. >>> def train_non_iterable(exe, prog, loss, loader):
  264. ... for _ in range(EPOCH_NUM):
  265. ... loader.start() # call DataLoader.start() before each epoch starts
  266. ... try:
  267. ... while True:
  268. ... exe.run(prog, fetch_list=[loss])
  269. ... except paddle.core.EOFException:
  270. ... loader.reset() # call DataLoader.reset() after catching EOFException
  271. ...
  272. >>> def set_data_source(loader, places):
  273. ... if DATA_FORMAT == 'sample_generator':
  274. ... loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE, drop_last=True, places=places)
  275. ... elif DATA_FORMAT == 'sample_list_generator':
  276. ... loader.set_sample_list_generator(sample_list_generator_creator(), places=places)
  277. ... elif DATA_FORMAT == 'batch_generator':
  278. ... loader.set_batch_generator(batch_generator_creator(), places=places)
  279. ... else:
  280. ... raise ValueError('Unsupported data format')
  281. ...
  282. >>> image = static.data(name='image', shape=[None, 784], dtype='float32')
  283. >>> label = static.data(name='label', shape=[None, 1], dtype='int64')
  284. >>> # Define DataLoader
  285. >>> loader = paddle.base.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE)
  286. >>> # Define network
  287. >>> loss = simple_net(image, label)
  288. >>> places = static.cuda_places() if USE_GPU else static.cpu_places()
  289. >>> set_data_source(loader, places)
  290. >>> exe = static.Executor(places[0])
  291. >>> exe.run(static.default_startup_program())
  292. >>> prog = static.CompiledProgram(static.default_main_program())
  293. >>> if loader.iterable:
  294. ... train_iterable(exe, prog, loss, loader)
  295. >>> else:
  296. ... train_non_iterable(exe, prog, loss, loader)
  297. .. code-block:: python
  298. :name: example_2
  299. >>> # Example in dynamic graph mode.
  300. >>> import numpy as np
  301. >>> import paddle
  302. >>> import paddle.nn as nn
  303. >>> import paddle.optimizer as opt
  304. >>> import paddle.distributed as dist
  305. >>> BATCH_SIZE = 16
  306. >>> BATCH_NUM = 4
  307. >>> EPOCH_NUM = 4
  308. >>> IMAGE_SIZE = 784
  309. >>> CLASS_NUM = 10
  310. >>> USE_GPU = False # whether to use GPU
  311. >>> def _get_random_images_and_labels(image_shape):
  312. ... image = np.random.random(size=image_shape).astype('float32')
  313. ... label = np.random.randint(0, CLASS_NUM, size=BATCH_SIZE).astype('int64')
  314. ... return image, label
  315. ...
  316. >>> def __reader__():
  317. ... for _ in range(BATCH_NUM):
  318. ... batch_image, batch_label = _get_random_images_and_labels(
  319. ... [BATCH_SIZE, IMAGE_SIZE])
  320. ... yield batch_image, batch_label
  321. ...
  322. >>> def random_batch_reader():
  323. ... return __reader__
  324. ...
  325. >>> class LinearNet(nn.Layer):
  326. ... def __init__(self):
  327. ... super().__init__()
  328. ... self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
  329. ...
  330. ... @paddle.jit.to_static
  331. ... def forward(self, x):
  332. ... return self._linear(x)
  333. ...
  334. >>> # set device
  335. >>> paddle.set_device('gpu' if USE_GPU else 'cpu')
  336. >>> # doctest: +SKIP('`paddle.jit.to_static` can not run in xdoctest')
  337. >>> # create network
  338. >>> layer = LinearNet()
  339. >>> dp_layer = paddle.DataParallel(layer)
  340. >>> loss_fn = nn.CrossEntropyLoss()
  341. >>> adam = opt.Adam(learning_rate=0.001, parameters=dp_layer.parameters())
  342. >>> # create data loader
  343. >>> loader = paddle.base.io.DataLoader.from_generator(capacity=5)
  344. >>> loader.set_batch_generator(random_batch_reader())
  345. >>> for epoch_id in range(EPOCH_NUM):
  346. ... for batch_id, (image, label) in enumerate(loader()):
  347. ... out = layer(image)
  348. ... loss = loss_fn(out, label)
  349. ...
  350. ... loss.backward()
  351. ...
  352. ... adam.step()
  353. ... adam.clear_grad()
  354. ... print("Epoch {} batch {}: loss = {}".format(
  355. ... epoch_id, batch_id, np.mean(loss.numpy())))
  356. ...
  357. >>> # doctest: -SKIP
  358. """
  359. if in_dygraph_mode():
  360. return DygraphGeneratorLoader(
  361. feed_list,
  362. capacity,
  363. use_double_buffer,
  364. iterable,
  365. return_list,
  366. use_multiprocess,
  367. )
  368. else:
  369. return GeneratorLoader(
  370. feed_list,
  371. capacity,
  372. use_double_buffer,
  373. iterable,
  374. return_list,
  375. drop_last,
  376. )
  377. @staticmethod
  378. def from_dataset(dataset, places, drop_last=True):
  379. """
  380. .. warning::
  381. This API will be deprecated in the future, it is recommended to use
  382. :code:`paddle.io.DataLoader` which supports multi-processes acceleration.
  383. Create an iterable DataLoader object for loading data from Dataset.
  384. Dataset is only supported in Linux system currently.
  385. Args:
  386. dataset (InMemoryDataset|QueueDataset): the dataset object.
  387. places (list(CUDAPlace)|list(CPUPlace)|list(str)): places where the result
  388. data should be converted. If places is list of string, the string in the list
  389. can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where x is the index of the GPUs.
  390. drop_last (bool, optional): whether to drop the last batch whose
  391. sample number is less than batch size. If drop_last = True,
  392. they would be dropped. If drop_last = False, they would be kept.
  393. Returns:
  394. loader (DataLoader): the created DataLoader object, which can be
  395. treated as a Python generator.
  396. Examples:
  397. .. code-block:: python
  398. >>> import paddle
  399. >>> import paddle.static as static
  400. >>> paddle.enable_static()
  401. >>> image = static.data(name='image', shape=[None, 784], dtype='float32')
  402. >>> label = static.data(name='label', shape=[None, 1], dtype='int64')
  403. >>> dataset = paddle.distributed.QueueDataset()
  404. >>> dataset.init(
  405. ... batch_size=32,
  406. ... pipe_command='cat',
  407. ... use_var=[image, label])
  408. >>> dataset.set_filelist(['a.txt', 'b.txt', 'c.txt'])
  409. >>> loader = paddle.base.io.DataLoader.from_dataset(dataset, static.cpu_places())
  410. """
  411. return DatasetLoader(dataset, places, drop_last)
class DygraphGeneratorLoader(DataLoaderBase):
    """
    The GeneratorLoader used in dygraph mode.

    Most of the multiprocess logic differs from the static graph
    GeneratorLoader, so it is implemented separately to keep the code readable.
    """

    def __init__(
        self,
        feed_list=None,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=True,
        use_multiprocess=False,
    ):
        """Store the loader configuration; no thread or process is started here.

        Args:
            feed_list: stored as ``self._feed_list``.
            capacity: capacity of the queues created later; must be truthy.
            use_double_buffer: forwarded to ``core.create_py_reader``.
            iterable: only iterable mode is supported; a falsy value triggers
                a warning and is ignored.
            return_list: only list results are supported; a falsy value
                triggers a warning and is ignored.
            use_multiprocess: whether to read data in a child process. It is
                switched off with a warning on ``darwin`` and ``win32``.

        Raises:
            ValueError: if ``capacity`` is falsy.
        """
        self._batch_reader = None
        self._places = None
        self._feed_list = feed_list
        self._timeout = QUEUE_GET_TIMEOUT

        if not capacity:
            raise ValueError("Please give value to capacity.")
        self._capacity = capacity
        self._use_double_buffer = use_double_buffer

        if not iterable:
            warnings.warn(
                "Please NOTE: DygraphGeneratorLoader supports iterable mode only. Change to iterable mode."
            )
        self._iterable = True
        if not return_list:
            warnings.warn(
                "Please NOTE: DygraphGeneratorLoader supports returning as list only. Change to return as list."
            )
        self._return_list = True

        # NOTE: the multiprocessing in different platform is incompatible, we will solve it later
        self._use_multiprocess = use_multiprocess
        if self._use_multiprocess and (
            sys.platform == 'darwin' or sys.platform == 'win32'
        ):
            warnings.warn(
                "NOTE: DygraphGeneratorLoader with multiprocess mode is not currently supported on MacOs and Windows."
            )
            self._use_multiprocess = False

        if self._use_multiprocess:
            # NOTE: the multiprocessing.Queue used to save loading data in self._process
            self._data_queue = None
            # NOTE: this process is used to load data asynchronously from self._batch_reader
            self._process = None

        # NOTE: the C++ LoDTensorBlockingQueue instance
        self._blocking_queue = None
        # NOTE: 1. In multiprocess mode, this thread is used to get next batch data from
        # self._data_queue, then push it into self._blocking_queue; 2. In single process
        # mode, this thread is used to get next batch data from self._batch_reader, then
        # push it into self._blocking_queue
        self._thread = None
        # Pinned memory is on unless use_pinned_memory() was given an explicit value.
        self._pin_memory = (
            True if use_pinned_memory() is None else use_pinned_memory()
        )

    @property
    def queue(self):
        """The blocking queue currently held in ``self._blocking_queue``."""
        return self._blocking_queue

    @property
    def iterable(self):
        """Whether the loader is iterable (always set to True by ``__init__``)."""
        return self._iterable

    def _clear_and_remove_data_queue(self):
        """Drain ``self._data_queue`` and remove it from ``multiprocess_queue_set``.

        Does nothing when no data queue has been created yet.
        """
        if self._data_queue is not None:
            # Pop items without blocking until the queue reports empty.
            while True:
                try:
                    self._data_queue.get_nowait()
                except queue.Empty:
                    break
            global multiprocess_queue_set
            multiprocess_queue_set.remove(self._data_queue)

    def _wait_thread_ends(self):
        """Close the blocking queue and join the reader thread, if there is one."""
        # Take a local reference: the single-process reader thread resets
        # self._thread to None itself when it finishes.
        thread = self._thread
        if thread is not None:
            self._blocking_queue.close()
            thread.join()

    def _wait_process_ends(self):
        """Join the reader process, if there is one, and erase its recorded pid."""
        process = self._process
        if process is not None:
            process.join()
            # erase process id
            core._erase_process_pids(id(self))

    def _init_iterable(self):
        """Wait for previous workers, then create a fresh blocking queue and reader."""
        self._wait_thread_ends()
        if self._use_multiprocess:
            self._wait_process_ends()
        self._var_names = []
        self._shapes = []
        self._dtypes = []
        self._need_check_feed = []
        self._blocking_queue = core.init_lod_tensor_blocking_queue(
            core.Variable(), self._capacity, False
        )
        # The previous reader reference is dropped before the new reader is created.
        self._reader = None
        self._reader = core.create_py_reader(
            self.queue,
            self._var_names,
            self._shapes,
            self._dtypes,
            self._need_check_feed,
            self._places,
            self._use_double_buffer,
            True,
            self._pin_memory,
        )

    def _start(self):
        """Start the worker(s) that feed ``self._blocking_queue``.

        In multiprocess mode a daemon child process running
        ``_reader_process_loop`` plus a daemon thread forwarding its output are
        started; otherwise only a daemon thread reading ``self._batch_reader``
        is started.
        """
        if self._use_multiprocess:
            # clear old _data_queue and remove it from multiprocess_queue_set
            self._clear_and_remove_data_queue()
            # set data_queue and process
            self._data_queue = multiprocessing.Queue(self._capacity)
            # add _data_queue into global queue set
            global multiprocess_queue_set
            multiprocess_queue_set.add(self._data_queue)
            self._process = multiprocessing.Process(
                target=_reader_process_loop,
                args=(self._batch_reader, self._data_queue, False),
            )
            self._process.daemon = True
            self._process.start()

            # Set child process signal handler
            # NOTE: [ avoiding hang ] 1. if the child process dies due to bus error/segfault
            # or just hang, the main process will hang waiting for data, so here need to deal
            # with SIGSEGV and SIGBUS of child process; 2. if the main process end before child
            # process, it shuts the all its daemonic children down with a SIGTERM (instead of
            # joining them without a timeout), so here need to deal with SIGTERM.
            core._set_process_pids(id(self), [self._process.pid])
            _set_SIGCHLD_handler()

            # Set reader_thread
            self._thread_done_event = threading.Event()
            self._thread = threading.Thread(
                target=self._reader_thread_loop_for_multiprocess,
                args=(_current_expected_place(),),
            )
            self._thread.daemon = True
            self._thread.start()
        else:
            self._thread = threading.Thread(
                target=self._reader_thread_loop_for_singleprocess,
                args=(_current_expected_place(),),
            )
            self._thread.daemon = True
            self._thread.start()

    def _reset(self):
        """Reset the reader and wait for the worker thread (and process) to end."""
        self._reader.reset()
        self._wait_thread_ends()
        if self._use_multiprocess:
            self._wait_process_ends()

    def __iter__(self):
        """Re-initialize the queue and reader, start the workers, return ``self``.

        Raises:
            AssertionError: if no data source has been set yet.
        """
        assert self.iterable, "DataLoader is not iterable"
        assert (
            self._batch_reader is not None
        ), "Data source of DataLoader has not set yet"

        self._init_iterable()
        self._start()
        return self

    def __next__(self):
        """Return the next batch read through ``self._reader``.

        On ``StopIteration`` the loader is reset before the exception is
        re-raised.
        """
        try:
            return core.eager.read_next_tensor_list(
                self._reader.read_next_list()[0]
            )
        except StopIteration:
            self._reset()
            raise

    def _exit_thread_expectedly(self):
        """Signal the reader thread to stop and close the blocking queue."""
        self._thread_done_event.set()
        self._blocking_queue.close()

    def _exit_thread_unexpectedly(self):
        """Signal the reader thread to stop, kill the blocking queue and log an error."""
        self._thread_done_event.set()
        self._blocking_queue.kill()
        logging.error("DataLoader reader thread raised an exception!")

    def _reader_thread_loop_for_multiprocess(self, legacy_expected_place):
        """Thread body for multiprocess mode.

        Moves tensor lists from ``self._data_queue`` into
        ``self._blocking_queue`` until ``self._thread_done_event`` is set. A
        ``None`` item from the data queue ends the loop normally.
        """
        # See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
        core.set_current_thread_name("Dataloader_" + str(id(self)))
        _set_expected_place(legacy_expected_place)

        while not self._thread_done_event.is_set():
            try:
                # NOTE: [ avoid hanging ] Even with carefully designed data dependencies
                # (i.e., a put() always corresponding to a get()), hanging on get() can
                # still happen when data in queue is corrupted (e.g., due to
                # Queue.cancel_join_thread or unexpected exit). So we set a timeout whenever
                # we try to get data from `data_queue`
                # NOTE: [ avoid failed quickly ] Here, the time setting of QUEUE_GET_TIMEOUT
                # is relatively long, currently it is 60 seconds, because in some models,
                # if the reader child process starts with a heavy burden, the child process
                # has no enough time to put the data in the queue when the main process
                # start trying to get data from queue. At this time, the child thread needs
                # to wait slightly longer
                tensor_list = self._data_queue.get(timeout=self._timeout)
            except Exception as e:
                # NOTE [ avoid hanging ] After adding the shared memory mechanism, not only
                # the queue.Empty exception will occur here, but other exceptions will also
                # occur, such as mmap failure. If it is not handled here, it will hang.
                self._exit_thread_unexpectedly()
                logging.error(
                    "DataLoader reader thread failed to read data from the multiprocessing.Queue."
                )
                raise e

            # Re-check the event: it may have been set while waiting on get().
            if not self._thread_done_event.is_set():
                if tensor_list is not None:
                    try:
                        array = core.LoDTensorArray()
                        for tensor in tensor_list:
                            array.append(tensor)
                        if not self._blocking_queue.push(array):
                            self._blocking_queue.close()
                    except Exception as e:
                        self._exit_thread_unexpectedly()
                        raise e
                else:
                    # ``None`` is the end-of-data marker put by the reader process.
                    self._exit_thread_expectedly()

    def _reader_thread_loop_for_singleprocess(self, legacy_expected_place):
        """Thread body for single-process mode.

        Reads batches from ``self._batch_reader()``, converts every item that
        is not already a ``core.LoDTensor`` and pushes the batch into
        ``self._blocking_queue``. The queue is closed when the reader is
        exhausted or a push is rejected, and killed if an exception occurs.
        """
        try:
            # See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
            core.set_current_thread_name("Dataloader_" + str(id(self)))
            _set_expected_place(legacy_expected_place)

            for sample in self._batch_reader():
                array = core.LoDTensorArray()
                for item in sample:
                    if not isinstance(item, core.LoDTensor):
                        # Validate the raw data, then copy it into a CPU LoDTensor.
                        item = self._check_input_array(item)
                        tmp = core.LoDTensor()
                        tmp.set(item, core.CPUPlace())
                        item = tmp

                    array.append(item)

                # Stop reading as soon as the queue refuses a batch.
                if not self._blocking_queue.push(array):
                    break

            self._blocking_queue.close()
            self._thread = None
        except Exception as e:
            self._blocking_queue.kill()
            self._thread = None
            logging.warning(
                "DygraphDataLoader reader thread raised an exception."
            )
            raise e

    def set_sample_generator(
        self, reader, batch_size, drop_last=True, places=None
    ):
        """Use a per-sample generator as data source.

        ``reader`` is wrapped with ``paddle.batch`` and passed on to
        ``set_sample_list_generator``.

        Returns:
            DygraphGeneratorLoader: ``self``.

        Raises:
            AssertionError: if ``batch_size`` is not larger than 0.
        """
        assert batch_size > 0, "batch_size must be larger than 0"
        if isinstance(places, (list, tuple)):
            places = _get_paddle_place_list(places)
        else:
            places = _get_paddle_place(places)
        self.set_sample_list_generator(
            paddle.batch(reader, batch_size=batch_size, drop_last=drop_last),
            places=places,
        )
        return self

    def set_sample_list_generator(self, reader, places=None):
        """Use a generator yielding lists of samples as data source.

        Every yielded sample list is regrouped by field position and the result
        is passed on to ``set_batch_generator``.

        Returns:
            DygraphGeneratorLoader: ``self``.
        """
        if isinstance(places, (list, tuple)):
            places = _get_paddle_place_list(places)
        else:
            places = _get_paddle_place(places)

        def __batch_reader_impl__():
            for batch in reader():
                # slots[i] collects the i-th field of every sample in the batch.
                slots = []
                for items in batch:
                    for i, item in enumerate(items):
                        # A new slot is opened while there are fewer slots
                        # than fields in the current sample.
                        if len(slots) < len(items):
                            slots.append([item])
                        else:
                            slots[i].append(item)
                yield slots

        self.set_batch_generator(__batch_reader_impl__, places)
        return self

    def set_batch_generator(self, reader, places=None):
        """Use a generator yielding whole batches as data source.

        Stores ``reader`` as ``self._batch_reader`` and the converted places as
        ``self._places``. When ``places`` resolves to None, the result of
        ``_current_expected_place()`` is used.

        Returns:
            DygraphGeneratorLoader: ``self``.

        Raises:
            AssertionError: if the number of places is not exactly 1.
        """
        if isinstance(places, (list, tuple)):
            places = _get_paddle_place_list(places)
        else:
            places = _get_paddle_place(places)
        self._batch_reader = reader
        if places is None:
            places = _current_expected_place()
        self._places = _convert_places(places)
        assert (
            len(self._places) == 1
        ), "Number of places must be 1 in imperative mode"
        return self
  692. class GeneratorLoader(DataLoaderBase):
  693. def __init__(
  694. self,
  695. feed_list=None,
  696. capacity=None,
  697. use_double_buffer=True,
  698. iterable=True,
  699. return_list=False,
  700. drop_last=True,
  701. ):
  702. self._tensor_reader = None
  703. self._places = None
  704. self._thread = None
  705. self._queue = None
  706. self._feed_list = feed_list
  707. self._exited = False
  708. self._drop_last = drop_last
  709. self._keep_order = keep_data_loader_order()
  710. if not capacity:
  711. raise ValueError("Please give value to capacity.")
  712. self._iterable = iterable
  713. self._return_list = return_list
  714. if not self._feed_list:
  715. raise Exception("Feed list must be given under static graph mode.")
  716. self._use_double_buffer = use_double_buffer
  717. self._capacity = capacity
  718. if not self._iterable:
  719. self._init_non_iterable()
  720. def _wait_thread_ends(self):
  721. # Get self._thread first to prevent data race, because __thread_main__
  722. # would set self._thread be None at the end
  723. thread = self._thread
  724. if thread is not None and self._iterable:
  725. self._queue.close()
  726. thread.join()
  727. def _init_iterable(self):
  728. self._wait_thread_ends()
  729. self._var_names = [v.name for v in self._feed_list]
  730. self._shapes = [v.shape for v in self._feed_list]
  731. if in_pir_mode():
  732. self._dtypes = [
  733. datatype_to_vartype[v.dtype] for v in self._feed_list
  734. ]
  735. self._need_check_feed = [False for v in self._feed_list]
  736. else:
  737. self._dtypes = [v.dtype for v in self._feed_list]
  738. self._need_check_feed = [
  739. v.desc.need_check_feed() for v in self._feed_list
  740. ]
  741. self._queue = core.init_lod_tensor_blocking_queue(
  742. core.Variable(), self._capacity, self._keep_order
  743. )
  744. self._reader = None
  745. self._reader = core.create_py_reader(
  746. self.queue,
  747. self._var_names,
  748. self._shapes,
  749. self._dtypes,
  750. self._need_check_feed,
  751. self._places,
  752. self._use_double_buffer,
  753. self._drop_last,
  754. False,
  755. )
  756. def _init_non_iterable(self):
  757. lod_levels = []
  758. dtypes = []
  759. shape_concat = []
  760. ranks = []
  761. shapes = []
  762. need_check_feed = []
  763. for feed_data in self._feed_list:
  764. dtypes.append(feed_data.dtype)
  765. shape_concat.extend(feed_data.shape)
  766. ranks.append(len(feed_data.shape))
  767. shapes.append(feed_data.shape)
  768. lod_levels.append(feed_data.lod_level)
  769. if in_pir_mode():
  770. need_check_feed.append(0)
  771. else:
  772. need_check_feed.append(int(feed_data.desc.need_check_feed()))
  773. queue_name = data_loader_unique_name_generator(
  774. 'lod_tensor_blocking_queue'
  775. )
  776. reader_name = data_loader_unique_name_generator('create_py_reader')
  777. double_buffer_name = data_loader_unique_name_generator('double_buffer')
  778. var = global_scope().var(queue_name)
  779. self._queue = core.init_lod_tensor_blocking_queue(
  780. var, self._capacity, self._keep_order
  781. )
  782. if self._keep_order:
  783. block = default_main_program().current_block()
  784. else:
  785. block = default_startup_program().current_block()
  786. reader_var = block.create_var(name=reader_name)
  787. dtype_int = [int(t) for t in dtypes]
  788. block.append_op(
  789. type='create_py_reader',
  790. inputs={'blocking_queue': [queue_name]},
  791. outputs={'Out': [reader_var]},
  792. attrs={
  793. 'shape_concat': shape_concat,
  794. 'lod_levels': lod_levels,
  795. 'dtypes': dtype_int,
  796. 'need_check_feed': need_check_feed,
  797. 'ranks': ranks,
  798. },
  799. )
  800. reader_var.desc.set_dtypes(dtypes)
  801. reader_var.persistable = True
  802. reader_var.stop_gradient = True
  803. if self._keep_order:
  804. main_prog_var = reader_var
  805. reader = main_prog_var
  806. reader.reset = self._queue.reset
  807. else:
  808. main_prog_var = _copy_reader_var_(
  809. default_main_program().current_block(), reader_var
  810. )
  811. main_prog_var.stop_gradient = True
  812. main_prog_var.persistable = True
  813. reader = monkey_patch_reader_methods(main_prog_var)
  814. if self._use_double_buffer:
  815. double_buffer_reader = __create_unshared_decorated_reader__(
  816. 'create_double_buffer_reader',
  817. reader,
  818. {},
  819. name=double_buffer_name,
  820. )
  821. # we return a double buffer reader. However, the reset method comes from
  822. # py_reader.
  823. double_buffer_reader.reset = reader.reset
  824. reader = double_buffer_reader
  825. self._reader = reader
  826. default_main_program().current_block().append_op(
  827. type='read',
  828. inputs={'Reader': [self._reader]},
  829. outputs={'Out': self._feed_list},
  830. attrs={'drop_last': self._drop_last},
  831. )
  832. @property
  833. def queue(self):
  834. return self._queue
  835. @property
  836. def iterable(self):
  837. return self._iterable
  838. def __iter__(self):
  839. assert self.iterable, "DataLoader is not iterable"
  840. assert (
  841. self._tensor_reader is not None
  842. ), "Data source of DataLoader has not set yet"
  843. self._init_iterable()
  844. self._start()
  845. return self
  846. def __next__(self):
  847. try:
  848. if self._return_list:
  849. data = self._reader.read_next_list()
  850. for i in range(len(data)):
  851. data[i] = data[i]._move_to_list()
  852. return data
  853. else:
  854. return self._reader.read_next()
  855. except StopIteration:
  856. self._queue.close()
  857. self._reset()
  858. raise
  859. def start(self):
  860. assert (
  861. not self._iterable
  862. ), "start() cannot be called when DataLoader is iterable"
  863. self._start()
  864. def reset(self):
  865. assert (
  866. not self._iterable
  867. ), "reset() cannot be called when DataLoader is iterable"
  868. self._reset()
  869. def _start(self):
  870. def __thread_main__(legacy_expected_place):
  871. try:
  872. # See _DataLoaderIterSingleProcess._thread_loop() for why set expected place here.
  873. core.set_current_thread_name("Dataloader_" + str(id(self)))
  874. _set_expected_place(legacy_expected_place)
  875. while not self._queue.wait_for_inited(1):
  876. if self._exited:
  877. return
  878. for tensors in self._tensor_reader():
  879. array = core.LoDTensorArray()
  880. for item in tensors:
  881. if not isinstance(item, core.LoDTensor):
  882. item = self._check_input_array(item)
  883. tmp = core.LoDTensor()
  884. tmp.set(item, core.CPUPlace())
  885. item = tmp
  886. array.append(item)
  887. if not self._queue.push(array):
  888. break
  889. self._queue.close()
  890. self._thread = None
  891. except Exception as e:
  892. self._queue.kill()
  893. self._thread = None
  894. logging.warning('Your reader has raised an exception!')
  895. raise e
  896. self._thread = threading.Thread(
  897. target=__thread_main__, args=(_current_expected_place(),)
  898. )
  899. self._thread.daemon = True
  900. self._thread.start()
  901. def _reset(self):
  902. self._queue.close()
  903. self._exited = True
  904. thread = self._thread
  905. if thread is not None:
  906. thread.join()
  907. self._exited = False
  908. self._reader.reset()
  909. def set_sample_generator(
  910. self, reader, batch_size, drop_last=True, places=None
  911. ):
  912. assert batch_size > 0, "batch_size must be larger than 0"
  913. if isinstance(places, (list, tuple)):
  914. places = _get_paddle_place_list(places)
  915. else:
  916. places = _get_paddle_place(places)
  917. has_lod = False
  918. for f in self._feed_list:
  919. if f.lod_level != 0:
  920. has_lod = True
  921. break
  922. if has_lod:
  923. self.set_sample_list_generator(
  924. paddle.batch(
  925. reader, batch_size=batch_size, drop_last=drop_last
  926. ),
  927. places=places,
  928. )
  929. else:
  930. reader = BatchedTensorProvider(
  931. feed_list=self._feed_list,
  932. place=core.CPUPlace(),
  933. batch_size=batch_size,
  934. generator=reader,
  935. drop_last=drop_last,
  936. )
  937. self.set_batch_generator(reader, places=places)
  938. return self
  939. def set_sample_list_generator(self, reader, places=None):
  940. if isinstance(places, (list, tuple)):
  941. places = _get_paddle_place_list(places)
  942. else:
  943. places = _get_paddle_place(places)
  944. with program_guard(Program(), Program()):
  945. feeder = DataFeeder(
  946. feed_list=self._feed_list, place=core.CPUPlace()
  947. )
  948. def decorate_reader():
  949. for item in reader():
  950. yield feeder.feed(item)
  951. paddle_reader = decorate_reader
  952. def __tensor_reader_impl__():
  953. for slots in paddle_reader():
  954. yield [slots[var.name] for var in self._feed_list]
  955. self.set_batch_generator(__tensor_reader_impl__, places)
  956. return self
  957. def set_batch_generator(self, reader, places=None):
  958. if isinstance(places, (list, tuple)):
  959. places = _get_paddle_place_list(places)
  960. else:
  961. places = _get_paddle_place(places)
  962. self._tensor_reader = reader
  963. if self._iterable:
  964. assert (
  965. places is not None
  966. ), "Places cannot be None when DataLoader is iterable"
  967. self._places = _convert_places(places)
  968. else:
  969. if places is not None:
  970. logging.info(
  971. 'places would be omitted when DataLoader is not iterable'
  972. )
  973. return self
class PyReader(DataLoaderBase):
    r"""
    Create a reader object for data feeding in Python.
    Data would be prefetched using Python thread and be pushed
    into a queue asynchronously. Data in the queue would be extracted
    automatically when `Executor.run(...)` is called.

    Args:
        feed_list (list(Variable)|tuple(Variable)): feed variable list.
            The variables should be created by :code:`paddle.static.data()`.
        capacity (int): capacity of the queue maintained in PyReader.
            The unit is batch number. Set larger capacity if your reader
            is fast.
        use_double_buffer (bool): whether to use double_buffer_reader.
            If use_double_buffer=True, PyReader would prefetch next
            batch data asynchronously, so it would speed up data feeding
            and occupies a little more CPU or GPU memory, i.e., the memory
            of one batch input data.
        iterable (bool): whether the created PyReader is iterable.
        return_list (bool): whether the return value on each device is
            presented as a list. It is only valid when iterable=True.
            If return_list=False, the return value on each device would
            be a dict of str -> LoDTensor, where the key of the dict is
            the name of each fed variables. If return_list=True, the
            return value on each device would be a list(LoDTensor). It is
            recommended to use return_list=False in static graph mode and
            use return_list=True in dygraph mode.

    Returns:
        the created reader object.

    Return type:
        reader(Reader)

    Examples:
        1. If iterable = False, the created PyReader object is almost the
           same as :code:`base.layers.py_reader()`. Operators would be
           inserted into the program. User should call :code:`start()`
           before each epoch and catch :code:`base.core.EOFException`
           thrown by :code:`Executor.run()` when epoch ends. Once the
           exception is caught, user should call :code:`reset()` to reset
           the reader manually.

        .. code-block:: python
            :name: example_1

            >>> import paddle
            >>> import paddle.base as base
            >>> import numpy as np

            >>> paddle.enable_static()

            >>> EPOCH_NUM = 3
            >>> ITER_NUM = 5
            >>> BATCH_SIZE = 3

            >>> def network(image, label):
            ...     # User-defined network, here is an example of softmax regression.
            ...     predict = paddle.static.nn.fc(x=image, size=10, activation='softmax')
            ...     return paddle.nn.functional.cross_entropy(
            ...         input=predict, label=label,
            ...         reduction='none', use_softmax=False
            ...     )

            >>> def reader_creator_random_image_and_label(height, width):
            ...     def reader():
            ...         for i in range(ITER_NUM):
            ...             fake_image = np.random.uniform(low=0,
            ...                                            high=255,
            ...                                            size=[height, width])
            ...             fake_label = np.ones([1])
            ...             yield fake_image, fake_label
            ...     return reader

            >>> image = paddle.static.data(name='image', shape=[None, 784, 784], dtype='float32')
            >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')

            >>> reader = base.io.PyReader(feed_list=[image, label],
            ...                           capacity=4,
            ...                           iterable=False)

            >>> user_defined_reader = reader_creator_random_image_and_label(784, 784)
            >>> reader.decorate_sample_list_generator(
            ...     paddle.batch(user_defined_reader, batch_size=BATCH_SIZE))
            >>> loss = network(image, label)
            >>> executor = base.Executor(base.CPUPlace())
            >>> executor.run(base.default_startup_program())
            >>> for i in range(EPOCH_NUM):
            ...     reader.start()
            ...     while True:
            ...         try:
            ...             executor.run(feed=None)
            ...         except base.core.EOFException:
            ...             reader.reset()
            ...             break

        2. If iterable=True, the created PyReader object is decoupled with
           the program. No operator would be inserted into the program.
           In this case, the created reader is a Python generator, which
           is iterable. User should feed the data yielded from PyReader
           object into :code:`Executor.run(feed=...)`.

        .. code-block:: python
            :name: example_2

            >>> import paddle
            >>> import paddle.base as base
            >>> import numpy as np

            >>> paddle.enable_static()

            >>> EPOCH_NUM = 3
            >>> ITER_NUM = 5
            >>> BATCH_SIZE = 10

            >>> def network(image, label):
            ...     # User-defined network, here is an example of softmax regression.
            ...     predict = paddle.static.nn.fc(x=image, size=10, activation='softmax')
            ...     return paddle.nn.functional.cross_entropy(
            ...         input=predict, label=label,
            ...         reduction='none', use_softmax=False
            ...     )

            >>> def reader_creator_random_image(height, width):
            ...     def reader():
            ...         for i in range(ITER_NUM):
            ...             fake_image = np.random.uniform(low=0, high=255, size=[height, width])
            ...             fake_label = np.ones([1])
            ...             yield fake_image, fake_label
            ...     return reader

            >>> image = paddle.static.data(name='image', shape=[None, 784, 784], dtype='float32')
            >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
            >>> reader = base.io.PyReader(feed_list=[image, label], capacity=4, iterable=True, return_list=False)

            >>> user_defined_reader = reader_creator_random_image(784, 784)
            >>> reader.decorate_sample_list_generator(
            ...     paddle.batch(user_defined_reader, batch_size=BATCH_SIZE),
            ...     base.core.CPUPlace())

            >>> loss = network(image, label)
            >>> executor = base.Executor(base.CPUPlace())
            >>> executor.run(base.default_startup_program())

            >>> for _ in range(EPOCH_NUM):
            ...     for data in reader():
            ...         executor.run(feed=data, fetch_list=[loss])

        3. If return_list=True, the return values would be presented as list instead of dict.
           This is usually used in dygraph mode.

        .. code-block:: python
            :name: example_3

            >>> import paddle
            >>> import paddle.base as base
            >>> import numpy as np

            >>> ITER_NUM = 5
            >>> BATCH_SIZE = 10

            >>> def reader_creator_random_image(height, width):
            ...     def reader():
            ...         for i in range(ITER_NUM):
            ...             yield np.random.uniform(low=0, high=255, size=[height, width]), \
            ...                 np.random.randint(low=0, high=10, size=[1])
            ...     return reader

            >>> place = base.CPUPlace()
            >>> with base.dygraph.guard(place):
            ...     py_reader = base.io.PyReader(capacity=2, return_list=True)
            ...     user_defined_reader = reader_creator_random_image(784, 784)
            ...     py_reader.decorate_sample_list_generator(
            ...         paddle.batch(user_defined_reader, batch_size=BATCH_SIZE),
            ...         place)
            ...     for image, label in py_reader():
            ...         relu = paddle.nn.functional.relu(image)
    """

    def __init__(
        self,
        feed_list=None,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=False,
    ):
        # PyReader only wraps a loader built by DataLoader.from_generator;
        # every method below delegates to it.
        self._loader = DataLoader.from_generator(
            feed_list, capacity, use_double_buffer, iterable, return_list
        )

    @property
    def queue(self):
        # Queue object of the wrapped loader.
        return self._loader.queue

    @property
    def iterable(self):
        # Iterable flag of the wrapped loader.
        return self._loader.iterable

    def __iter__(self):
        # Delegates iteration setup to the wrapped loader.
        return self._loader.__iter__()

    def __next__(self):
        # Delegates to the wrapped loader's __next__.
        return self._loader.__next__()

    def start(self):
        '''
        Start the data feeding thread.
        Can only call when the reader object is not iterable.

        Example:
            .. code-block:: python

                >>> import paddle
                >>> import paddle.base as base
                >>> import numpy as np

                >>> paddle.enable_static()

                >>> BATCH_SIZE = 10

                >>> def generator():
                ...     for i in range(5):
                ...         yield np.random.uniform(low=0, high=255, size=[784, 784]),

                >>> image = paddle.static.data(name='image', shape=[None, 784, 784], dtype='float32')
                >>> reader = base.io.PyReader(feed_list=[image], capacity=4, iterable=False)
                >>> reader.decorate_sample_list_generator(
                ...     paddle.batch(generator, batch_size=BATCH_SIZE))

                >>> executor = base.Executor(base.CPUPlace())
                >>> executor.run(base.default_startup_program())
                >>> for i in range(3):
                ...     reader.start()
                ...     while True:
                ...         try:
                ...             executor.run(feed=None)
                ...         except base.core.EOFException:
                ...             reader.reset()
                ...             break

        '''
        self._loader.start()

    def reset(self):
        '''
        Reset the reader object when :code:`base.core.EOFException` raises.
        Can only call when the reader object is not iterable.

        Example:
            .. code-block:: python

                >>> import paddle
                >>> import paddle.base as base
                >>> import numpy as np

                >>> paddle.enable_static()

                >>> BATCH_SIZE = 10

                >>> def generator():
                ...     for i in range(5):
                ...         yield np.random.uniform(low=0, high=255, size=[784, 784]),

                >>> image = paddle.static.data(name='image', shape=[None, 784, 784], dtype='float32')
                >>> reader = base.io.PyReader(feed_list=[image], capacity=4, iterable=False)
                >>> reader.decorate_sample_list_generator(
                ...     paddle.batch(generator, batch_size=BATCH_SIZE))

                >>> executor = base.Executor(base.CPUPlace())
                >>> executor.run(base.default_startup_program())
                >>> for i in range(3):
                ...     reader.start()
                ...     while True:
                ...         try:
                ...             executor.run(feed=None)
                ...         except base.core.EOFException:
                ...             reader.reset()
                ...             break

        '''
        self._loader.reset()

    def decorate_sample_generator(
        self, sample_generator, batch_size, drop_last=True, places=None
    ):
        '''
        Set the data source of the PyReader object.

        The provided :code:`sample_generator` should be a Python generator,
        which yields list(numpy.ndarray)-typed data of each sample.

        :code:`places` must be set when the PyReader object is iterable.

        If all inputs have no lods, this method is faster than
        :code:`decorate_sample_list_generator(paddle.batch(sample_generator, ...))` .

        Args:
            sample_generator (generator): Python generator that yields
                list(numpy.ndarray)-typed sample data.
            batch_size (int): batch size. Must be larger than 0.
            drop_last (bool): Whether to drop the last batch when sample number
                is less than batch_size.
            places (None|list(CUDAPlace)|list(CPUPlace)): place list. Must
                be provided when PyReader is iterable.

        Example:
            .. code-block:: python

                >>> import paddle
                >>> import paddle.base as base
                >>> import numpy as np

                >>> paddle.enable_static()

                >>> EPOCH_NUM = 3
                >>> ITER_NUM = 15
                >>> BATCH_SIZE = 3

                >>> def network(image, label):
                ...     # User-defined network, here is an example of softmax regression.
                ...     predict = paddle.static.nn.fc(x=image, size=10, activation='softmax')
                ...     return paddle.nn.functional.cross_entropy(
                ...         input=predict, label=label,
                ...         reduction='none', use_softmax=False
                ...     )

                >>> def random_image_and_label_generator(height, width):
                ...     def generator():
                ...         for i in range(ITER_NUM):
                ...             fake_image = np.random.uniform(low=0,
                ...                                            high=255,
                ...                                            size=[height, width])
                ...             fake_label = np.array([1])
                ...             yield fake_image, fake_label
                ...     return generator

                >>> image = paddle.static.data(name='image', shape=[None, 784, 784], dtype='float32')
                >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
                >>> reader = base.io.PyReader(feed_list=[image, label], capacity=4, iterable=True)

                >>> user_defined_generator = random_image_and_label_generator(784, 784)
                >>> reader.decorate_sample_generator(user_defined_generator,
                ...                                  batch_size=BATCH_SIZE,
                ...                                  places=[base.CPUPlace()])
                >>> loss = network(image, label)
                >>> executor = base.Executor(base.CPUPlace())
                >>> executor.run(base.default_startup_program())

                >>> for _ in range(EPOCH_NUM):
                ...     for data in reader():
                ...         executor.run(feed=data, fetch_list=[loss])

        '''
        self._loader.set_sample_generator(
            sample_generator, batch_size, drop_last, places
        )

    def decorate_sample_list_generator(self, reader, places=None):
        '''
        Set the data source of the PyReader object.

        The provided :code:`reader` should be a Python generator,
        which yields list(numpy.ndarray) typed batched data.

        :code:`places` must be set when the PyReader object is iterable.

        Args:
            reader (generator): Python generator that yields
                list(numpy.ndarray)-typed batched data.
            places (None|list(CUDAPlace)|list(CPUPlace)): place list. Must
                be provided when PyReader is iterable.

        Example:
            .. code-block:: python

                >>> import paddle
                >>> import paddle.base as base
                >>> import numpy as np

                >>> paddle.enable_static()

                >>> EPOCH_NUM = 3
                >>> ITER_NUM = 15
                >>> BATCH_SIZE = 3

                >>> def network(image, label):
                ...     # User-defined network, here is an example of softmax regression.
                ...     predict = paddle.static.nn.fc(x=image, size=10, activation='softmax')
                ...     return paddle.nn.functional.cross_entropy(
                ...         input=predict, label=label,
                ...         reduction='none', use_softmax=False
                ...     )

                >>> def random_image_and_label_generator(height, width):
                ...     def generator():
                ...         for i in range(ITER_NUM):
                ...             fake_image = np.random.uniform(low=0,
                ...                                            high=255,
                ...                                            size=[height, width])
                ...             fake_label = np.ones([1])
                ...             yield fake_image, fake_label
                ...     return generator

                >>> image = paddle.static.data(name='image', shape=[None, 784, 784], dtype='float32')
                >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
                >>> reader = base.io.PyReader(feed_list=[image, label], capacity=4, iterable=True)

                >>> user_defined_generator = random_image_and_label_generator(784, 784)
                >>> reader.decorate_sample_list_generator(
                ...     paddle.batch(user_defined_generator, batch_size=BATCH_SIZE),
                ...     base.core.CPUPlace())

                >>> loss = network(image, label)
                >>> executor = base.Executor(base.core.CPUPlace())
                >>> executor.run(base.default_startup_program())

                >>> for _ in range(EPOCH_NUM):
                ...     for data in reader():
                ...         executor.run(feed=data, fetch_list=[loss])

        '''
        self._loader.set_sample_list_generator(reader, places)

    def decorate_batch_generator(self, reader, places=None):
        '''
        Set the data source of the PyReader object.

        The provided :code:`reader` should be a Python generator,
        which yields numpy.ndarray-typed or LoDTensor-typed batched data.

        :code:`places` must be set when the PyReader object is iterable.

        Args:
            reader (generator): Python generator that yields
                numpy.ndarray-typed or LoDTensor-typed batched data.
            places (None|list(CUDAPlace)|list(CPUPlace)): place list. Must
                be provided when PyReader is iterable.

        Example:
            .. code-block:: python

                >>> import paddle
                >>> import paddle.base as base
                >>> import numpy as np

                >>> paddle.enable_static()

                >>> EPOCH_NUM = 3
                >>> ITER_NUM = 15
                >>> BATCH_SIZE = 3

                >>> def network(image, label):
                ...     # User-defined network, here is an example of softmax regression.
                ...     predict = paddle.static.nn.fc(x=image, size=10, activation='softmax')
                ...     return paddle.nn.functional.cross_entropy(
                ...         input=predict, label=label,
                ...         reduction='none', use_softmax=False
                ...     )

                >>> def random_image_and_label_generator(height, width):
                ...     def generator():
                ...         for i in range(ITER_NUM):
                ...             batch_image = np.random.uniform(low=0,
                ...                                             high=255,
                ...                                             size=[BATCH_SIZE, height, width])
                ...             batch_label = np.ones([BATCH_SIZE, 1])
                ...             batch_image = batch_image.astype('float32')
                ...             batch_label = batch_label.astype('int64')
                ...             yield batch_image, batch_label
                ...     return generator

                >>> image = paddle.static.data(name='image', shape=[None, 784, 784], dtype='float32')
                >>> label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
                >>> reader = base.io.PyReader(feed_list=[image, label], capacity=4, iterable=True)

                >>> user_defined_generator = random_image_and_label_generator(784, 784)
                >>> reader.decorate_batch_generator(user_defined_generator, base.CPUPlace())

                >>> loss = network(image, label)
                >>> executor = base.Executor(base.CPUPlace())
                >>> executor.run(base.default_startup_program())

                >>> for _ in range(EPOCH_NUM):
                ...     for data in reader():
                ...         executor.run(feed=data, fetch_list=[loss])

        '''
        self._loader.set_batch_generator(reader, places)
  1365. class DatasetLoader(DataLoaderBase):
  1366. def __init__(self, dataset, places, drop_last):
  1367. assert isinstance(
  1368. dataset, paddle.distributed.fleet.dataset.DatasetBase
  1369. ), "dataset must be type of DatasetBase"
  1370. assert (
  1371. not in_dygraph_mode()
  1372. ), "DatasetLoader is not supported in dygraph mode yet"
  1373. if isinstance(places, (list, tuple)):
  1374. places = _get_paddle_place_list(places)
  1375. else:
  1376. places = _get_paddle_place(places)
  1377. thread_num = len(places)
  1378. assert (
  1379. len(dataset.filelist) >= thread_num
  1380. ), f"Filelist number of dataset {len(dataset.filelist)} must be not less than place number {thread_num}"
  1381. if dataset.thread_num != 0 and dataset.thread_num != thread_num:
  1382. logging.warn(
  1383. f'thread_num {dataset.thread_num} which is set in Dataset is ignored'
  1384. )
  1385. dataset._set_thread(thread_num)
  1386. if (
  1387. isinstance(
  1388. dataset, paddle.distributed.fleet.dataset.InMemoryDataset
  1389. )
  1390. and dataset.queue_num > thread_num
  1391. ):
  1392. logging.warn(
  1393. f"queue_num {dataset.queue_num} which is set in Dataset is ignored"
  1394. )
  1395. dataset._set_queue_num(thread_num)
  1396. self._dataset = dataset
  1397. use_slots = [
  1398. slot.name
  1399. for slot in dataset.proto_desc.multi_slot_desc.slots
  1400. if slot.is_used
  1401. ]
  1402. self._iterable_dataset = core.IterableDatasetWrapper(
  1403. dataset.dataset,
  1404. use_slots,
  1405. _convert_places(places),
  1406. dataset.proto_desc.batch_size,
  1407. drop_last,
  1408. )
  1409. def __iter__(self):
  1410. self._dataset._finish_to_run()
  1411. self._dataset._prepare_to_run()
  1412. self._iterable_dataset._start()
  1413. return self
  1414. def __next__(self):
  1415. return self._iterable_dataset._next()