multicore.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917
"""Classes and functions dealing with augmentation on multiple CPU cores."""
from __future__ import print_function, division, absolute_import

import sys
import multiprocessing
import threading
import traceback
import time
import random
import platform

import numpy as np
import cv2

import imgaug.imgaug as ia
import imgaug.random as iarandom
from imgaug.augmentables.batches import Batch, UnnormalizedBatch

# Python 2 and 3 expose pickle and the queue exceptions under different module
# names. Python 2 has no builtin ``BrokenPipeError``, so it is aliased to
# ``socket.error`` there; the deprecated BatchLoader catches it below.
if sys.version_info[0] == 2:
    # pylint: disable=redefined-builtin, import-error
    import cPickle as pickle
    from Queue import Empty as QueueEmpty, Full as QueueFull
    import socket
    BrokenPipeError = socket.error
elif sys.version_info[0] == 3:
    import pickle
    from queue import Empty as QueueEmpty, Full as QueueFull
# Module-level cache of the multiprocessing context used by this module.
# ``None`` means "not chosen yet"; ``_get_context()`` fills it on first use.
# It holds either a context from ``multiprocessing.get_context()`` or the
# ``multiprocessing`` module itself when contexts are unsupported.
_CONTEXT = None
  25. # Added in 0.4.0.
  26. def _get_context_method():
  27. vinfo = sys.version_info
  28. # get_context() is only supported in 3.5 and later (same for
  29. # set_start_method)
  30. get_context_unsupported = (
  31. vinfo[0] == 2
  32. or (vinfo[0] == 3 and vinfo[1] <= 3))
  33. method = None
  34. # Fix random hanging code in NixOS by switching to spawn method,
  35. # see issue #414
  36. # TODO This is only a workaround and doesn't really fix the underlying
  37. # issue. The cause of the underlying issue is currently unknown.
  38. # Its possible that #535 fixes the issue, though earlier tests
  39. # indicated that the cause was something else.
  40. # TODO this might break the semaphore used to prevent out of memory
  41. # errors
  42. if "NixOS" in platform.version():
  43. method = "spawn"
  44. if get_context_unsupported:
  45. ia.warn("Detected usage of imgaug.multicore in python <=3.4 "
  46. "and NixOS. This is known to sometimes cause endlessly "
  47. "hanging programs when also making use of multicore "
  48. "augmentation (aka background augmentation). Use "
  49. "python 3.5 or later to prevent this.")
  50. if get_context_unsupported:
  51. return False
  52. return method
  53. # Added in 0.4.0.
  54. def _set_context(method):
  55. # method=False indicates that multiprocessing module (i.e. no context)
  56. # should be used, e.g. because get_context() is not supported
  57. globals()["_CONTEXT"] = (
  58. multiprocessing if method is False
  59. else multiprocessing.get_context(method))
  60. # Added in 0.4.0.
  61. def _reset_context():
  62. globals()["_CONTEXT"] = None
  63. # Added in 0.4.0.
  64. def _autoset_context():
  65. _set_context(_get_context_method())
  66. # Added in 0.4.0.
  67. def _get_context():
  68. if _CONTEXT is None:
  69. _autoset_context()
  70. return _CONTEXT
  71. class Pool(object):
  72. """
  73. Wrapper around ``multiprocessing.Pool`` for multicore augmentation.
  74. Parameters
  75. ----------
  76. augseq : imgaug.augmenters.meta.Augmenter
  77. The augmentation sequence to apply to batches.
  78. processes : None or int, optional
  79. The number of background workers, similar to the same parameter in
  80. multiprocessing.Pool. If ``None``, the number of the machine's CPU
  81. cores will be used (this counts hyperthreads as CPU cores). If this is
  82. set to a negative value ``p``, then ``P - abs(p)`` will be used,
  83. where ``P`` is the number of CPU cores. E.g. ``-1`` would use all
  84. cores except one (this is useful to e.g. reserve one core to feed
  85. batches to the GPU).
  86. maxtasksperchild : None or int, optional
  87. The number of tasks done per worker process before the process is
  88. killed and restarted, similar to the same parameter in
  89. multiprocessing.Pool. If ``None``, worker processes will not be
  90. automatically restarted.
  91. seed : None or int, optional
  92. The seed to use for child processes. If ``None``, a random seed will
  93. be used.
  94. """
  95. # This attribute saves the augmentation sequence for background workers so
  96. # that it does not have to be resend with every batch. The attribute is set
  97. # once per worker in the worker's initializer. As each worker has its own
  98. # process, it is a different variable per worker (though usually should be
  99. # of equal content).
  100. _WORKER_AUGSEQ = None
  101. # This attribute saves the initial seed for background workers so that for
  102. # any future batch the batch's specific seed can be derived, roughly via
  103. # SEED_START+SEED_BATCH. As each worker has its own process, this seed can
  104. # be unique per worker even though all seemingly use the same constant
  105. # attribute.
  106. _WORKER_SEED_START = None
  107. def __init__(self, augseq, processes=None, maxtasksperchild=None,
  108. seed=None):
  109. # make sure that don't call pool again in a child process
  110. assert Pool._WORKER_AUGSEQ is None, (
  111. "_WORKER_AUGSEQ was already set when calling Pool.__init__(). "
  112. "Did you try to instantiate a Pool within a Pool?")
  113. assert processes is None or processes != 0, (
  114. "Expected `processes` to be `None` (\"use as many cores as "
  115. "available\") or a negative integer (\"use as many as available "
  116. "MINUS this number\") or an integer>1 (\"use exactly that many "
  117. "processes\"). Got type %s, value %s instead." % (
  118. type(processes), str(processes))
  119. )
  120. self.augseq = augseq
  121. self.processes = processes
  122. self.maxtasksperchild = maxtasksperchild
  123. if seed is not None:
  124. assert iarandom.SEED_MIN_VALUE <= seed <= iarandom.SEED_MAX_VALUE, (
  125. "Expected `seed` to be either `None` or a value between "
  126. "%d and %d. Got type %s, value %s instead." % (
  127. iarandom.SEED_MIN_VALUE,
  128. iarandom.SEED_MAX_VALUE,
  129. type(seed),
  130. str(seed)
  131. )
  132. )
  133. self.seed = seed
  134. # multiprocessing.Pool instance
  135. self._pool = None
  136. # Running counter of the number of augmented batches. This will be
  137. # used to send indexes for each batch to the workers so that they can
  138. # augment using SEED_BASE+SEED_BATCH and ensure consistency of applied
  139. # augmentation order between script runs.
  140. self._batch_idx = 0
  141. @property
  142. def pool(self):
  143. """Return or create the ``multiprocessing.Pool`` instance.
  144. This creates a new instance upon the first call and afterwards
  145. returns that instance (until the property ``_pool`` is set to
  146. ``None`` again).
  147. Returns
  148. -------
  149. multiprocessing.Pool
  150. The ``multiprocessing.Pool`` used internally by this
  151. ``imgaug.multicore.Pool``.
  152. """
  153. if self._pool is None:
  154. processes = self.processes
  155. if processes is not None and processes < 0:
  156. # cpu count returns the number of logical cpu cores, i.e.
  157. # including hyperthreads could also use
  158. # os.sched_getaffinity(0) here, which seems to not exist on
  159. # BSD though.
  160. # In python 3.4+, there is also os.cpu_count(), which
  161. # multiprocessing.cpu_count() then redirects to.
  162. # At least one guy on stackoverflow.com/questions/1006289
  163. # reported that only os.* existed, not the multiprocessing
  164. # method.
  165. # TODO make this also check if os.cpu_count exists as a
  166. # fallback
  167. try:
  168. processes = _get_context().cpu_count() - abs(processes)
  169. processes = max(processes, 1)
  170. except (ImportError, NotImplementedError):
  171. ia.warn(
  172. "Could not find method multiprocessing.cpu_count(). "
  173. "This will likely lead to more CPU cores being used "
  174. "for the background augmentation than originally "
  175. "intended.")
  176. processes = None
  177. self._pool = _get_context().Pool(
  178. processes,
  179. initializer=_Pool_initialize_worker,
  180. initargs=(self.augseq, self.seed),
  181. maxtasksperchild=self.maxtasksperchild)
  182. return self._pool
  183. def map_batches(self, batches, chunksize=None):
  184. """
  185. Augment a list of batches.
  186. Parameters
  187. ----------
  188. batches : list of imgaug.augmentables.batches.Batch
  189. The batches to augment.
  190. chunksize : None or int, optional
  191. Rough indicator of how many tasks should be sent to each worker.
  192. Increasing this number can improve performance.
  193. Returns
  194. -------
  195. list of imgaug.augmentables.batches.Batch
  196. Augmented batches.
  197. """
  198. self._assert_batches_is_list(batches)
  199. return self.pool.map(
  200. _Pool_starworker,
  201. self._handle_batch_ids(batches),
  202. chunksize=chunksize)
  203. def map_batches_async(self, batches, chunksize=None, callback=None,
  204. error_callback=None):
  205. """
  206. Augment batches asynchonously.
  207. Parameters
  208. ----------
  209. batches : list of imgaug.augmentables.batches.Batch
  210. The batches to augment.
  211. chunksize : None or int, optional
  212. Rough indicator of how many tasks should be sent to each worker.
  213. Increasing this number can improve performance.
  214. callback : None or callable, optional
  215. Function to call upon finish. See ``multiprocessing.Pool``.
  216. error_callback : None or callable, optional
  217. Function to call upon errors. See ``multiprocessing.Pool``.
  218. Returns
  219. -------
  220. multiprocessing.MapResult
  221. Asynchonous result. See ``multiprocessing.Pool``.
  222. """
  223. self._assert_batches_is_list(batches)
  224. return self.pool.map_async(
  225. _Pool_starworker,
  226. self._handle_batch_ids(batches),
  227. chunksize=chunksize,
  228. callback=callback,
  229. error_callback=error_callback)
  230. @classmethod
  231. def _assert_batches_is_list(cls, batches):
  232. assert isinstance(batches, list), (
  233. "Expected `batches` to be a list, got type %s. Call "
  234. "imap_batches() if you use generators.") % (type(batches),)
  235. def imap_batches(self, batches, chunksize=1, output_buffer_size=None):
  236. """
  237. Augment batches from a generator.
  238. Pattern for output buffer constraint is from
  239. https://stackoverflow.com/a/47058399.
  240. Parameters
  241. ----------
  242. batches : generator of imgaug.augmentables.batches.Batch
  243. The batches to augment, provided as a generator. Each call to the
  244. generator should yield exactly one batch.
  245. chunksize : None or int, optional
  246. Rough indicator of how many tasks should be sent to each worker.
  247. Increasing this number can improve performance.
  248. output_buffer_size : None or int, optional
  249. Max number of batches to handle *at the same time* in the *whole*
  250. pipeline (including already augmented batches that are waiting to
  251. be requested). If the buffer size is reached, no new batches will
  252. be loaded from `batches` until a produced (i.e. augmented) batch is
  253. consumed (i.e. requested from this method).
  254. The buffer is unlimited if this is set to ``None``. For large
  255. datasets, this should be set to an integer value to avoid filling
  256. the whole RAM if loading+augmentation happens faster than training.
  257. *New in version 0.3.0.*
  258. Yields
  259. ------
  260. imgaug.augmentables.batches.Batch
  261. Augmented batch.
  262. """
  263. self._assert_batches_is_generator(batches)
  264. # buffer is either None or a Semaphore
  265. output_buffer_left = _create_output_buffer_left(output_buffer_size)
  266. # TODO change this to 'yield from' once switched to 3.3+
  267. gen = self.pool.imap(
  268. _Pool_starworker,
  269. self._ibuffer_batch_loading(
  270. self._handle_batch_ids_gen(batches),
  271. output_buffer_left
  272. ),
  273. chunksize=chunksize)
  274. for batch in gen:
  275. yield batch
  276. if output_buffer_left is not None:
  277. output_buffer_left.release()
  278. def imap_batches_unordered(self, batches, chunksize=1,
  279. output_buffer_size=None):
  280. """Augment batches from a generator (without preservation of order).
  281. Pattern for output buffer constraint is from
  282. https://stackoverflow.com/a/47058399.
  283. Parameters
  284. ----------
  285. batches : generator of imgaug.augmentables.batches.Batch
  286. The batches to augment, provided as a generator. Each call to the
  287. generator should yield exactly one batch.
  288. chunksize : None or int, optional
  289. Rough indicator of how many tasks should be sent to each worker.
  290. Increasing this number can improve performance.
  291. output_buffer_size : None or int, optional
  292. Max number of batches to handle *at the same time* in the *whole*
  293. pipeline (including already augmented batches that are waiting to
  294. be requested). If the buffer size is reached, no new batches will
  295. be loaded from `batches` until a produced (i.e. augmented) batch is
  296. consumed (i.e. requested from this method).
  297. The buffer is unlimited if this is set to ``None``. For large
  298. datasets, this should be set to an integer value to avoid filling
  299. the whole RAM if loading+augmentation happens faster than training.
  300. *New in version 0.3.0.*
  301. Yields
  302. ------
  303. imgaug.augmentables.batches.Batch
  304. Augmented batch.
  305. """
  306. self._assert_batches_is_generator(batches)
  307. # buffer is either None or a Semaphore
  308. output_buffer_left = _create_output_buffer_left(output_buffer_size)
  309. gen = self.pool.imap_unordered(
  310. _Pool_starworker,
  311. self._ibuffer_batch_loading(
  312. self._handle_batch_ids_gen(batches),
  313. output_buffer_left
  314. ),
  315. chunksize=chunksize
  316. )
  317. for batch in gen:
  318. yield batch
  319. if output_buffer_left is not None:
  320. output_buffer_left.release()
  321. @classmethod
  322. def _assert_batches_is_generator(cls, batches):
  323. assert ia.is_generator(batches), (
  324. "Expected `batches` to be generator, got type %s. Call "
  325. "map_batches() if you use lists.") % (type(batches),)
  326. def __enter__(self):
  327. assert self._pool is None, (
  328. "Tried to __enter__ a pool that has already been initialized.")
  329. _ = self.pool # initialize internal multiprocessing pool instance
  330. return self
  331. def __exit__(self, exc_type, exc_val, exc_tb):
  332. self.close()
  333. def close(self):
  334. """Close the pool gracefully."""
  335. if self._pool is not None:
  336. self._pool.close()
  337. self._pool.join()
  338. self._pool = None
  339. def terminate(self):
  340. """Terminate the pool immediately."""
  341. if self._pool is not None:
  342. self._pool.terminate()
  343. self._pool.join()
  344. self._pool = None
  345. # TODO why does this function exist if it may only be called after
  346. # close/terminate and both of these two already call join() themselves
  347. def join(self):
  348. """
  349. Wait for the workers to exit.
  350. This may only be called after first calling
  351. :func:`~imgaug.multicore.Pool.close` or
  352. :func:`~imgaug.multicore.Pool.terminate`.
  353. """
  354. if self._pool is not None:
  355. self._pool.join()
  356. def _handle_batch_ids(self, batches):
  357. ids = np.arange(self._batch_idx, self._batch_idx + len(batches))
  358. inputs = list(zip(ids, batches))
  359. self._batch_idx += len(batches)
  360. return inputs
  361. def _handle_batch_ids_gen(self, batches):
  362. for batch in batches:
  363. batch_idx = self._batch_idx
  364. yield batch_idx, batch
  365. self._batch_idx += 1
  366. @classmethod
  367. def _ibuffer_batch_loading(cls, batches, output_buffer_left):
  368. for batch in batches:
  369. if output_buffer_left is not None:
  370. output_buffer_left.acquire()
  371. yield batch
  372. def _create_output_buffer_left(output_buffer_size):
  373. output_buffer_left = None
  374. if output_buffer_size:
  375. assert output_buffer_size > 0, (
  376. "Expected buffer size to be greater than zero, but got size %d "
  377. "instead." % (output_buffer_size,))
  378. output_buffer_left = _get_context().Semaphore(output_buffer_size)
  379. return output_buffer_left
  380. # This could be a classmethod or staticmethod of Pool in 3.x, but in 2.7 that
  381. # leads to pickle errors.
  382. def _Pool_initialize_worker(augseq, seed_start):
  383. # pylint: disable=invalid-name, protected-access
  384. # Not using this seems to have caused infinite hanging in the case
  385. # of gaussian blur on at least MacOSX.
  386. # It is also in most cases probably not sensible to use multiple
  387. # threads while already running augmentation in multiple processes.
  388. cv2.setNumThreads(0)
  389. if seed_start is None:
  390. # pylint falsely thinks in older versions that
  391. # multiprocessing.current_process() was not callable, see
  392. # https://github.com/PyCQA/pylint/issues/1699
  393. # pylint: disable=not-callable
  394. process_name = _get_context().current_process().name
  395. # pylint: enable=not-callable
  396. # time_ns() exists only in 3.7+
  397. if sys.version_info[0] == 3 and sys.version_info[1] >= 7:
  398. seed_offset = time.time_ns()
  399. else:
  400. seed_offset = int(time.time() * 10**6) % 10**6
  401. seed = hash(process_name) + seed_offset
  402. _reseed_global_local(seed, augseq)
  403. Pool._WORKER_SEED_START = seed_start
  404. Pool._WORKER_AUGSEQ = augseq
  405. # not sure if really necessary, but shouldn't hurt either
  406. Pool._WORKER_AUGSEQ.localize_random_state_()
  407. # This could be a classmethod or staticmethod of Pool in 3.x, but in 2.7 that
  408. # leads to pickle errors.
  409. def _Pool_worker(batch_idx, batch):
  410. # pylint: disable=invalid-name, protected-access
  411. assert ia.is_single_integer(batch_idx), (
  412. "Expected `batch_idx` to be an integer. Got type %s instead." % (
  413. type(batch_idx)
  414. ))
  415. assert isinstance(batch, (UnnormalizedBatch, Batch)), (
  416. "Expected `batch` to be either an instance of "
  417. "`imgaug.augmentables.batches.UnnormalizedBatch` or "
  418. "`imgaug.augmentables.batches.Batch`. Got type %s instead." % (
  419. type(batch)
  420. ))
  421. assert Pool._WORKER_AUGSEQ is not None, (
  422. "Expected `Pool._WORKER_AUGSEQ` to NOT be `None`. Did you manually "
  423. "call _Pool_worker()?")
  424. augseq = Pool._WORKER_AUGSEQ
  425. # TODO why is this if here? _WORKER_SEED_START should always be set?
  426. if Pool._WORKER_SEED_START is not None:
  427. seed = Pool._WORKER_SEED_START + batch_idx
  428. _reseed_global_local(seed, augseq)
  429. result = augseq.augment_batch_(batch)
  430. return result
  431. # could be a classmethod or staticmethod of Pool in 3.x, but in 2.7 that leads
  432. # to pickle errors starworker is here necessary, because starmap does not exist
  433. # in 2.7
  434. def _Pool_starworker(inputs):
  435. # pylint: disable=invalid-name
  436. return _Pool_worker(*inputs)
  437. def _reseed_global_local(base_seed, augseq):
  438. seed_global = _derive_seed(base_seed, -10**9)
  439. seed_local = _derive_seed(base_seed)
  440. iarandom.seed(seed_global)
  441. augseq.seed_(seed_local)
  442. def _derive_seed(base_seed, offset=0):
  443. return (
  444. iarandom.SEED_MIN_VALUE
  445. + (base_seed + offset)
  446. % (iarandom.SEED_MAX_VALUE - iarandom.SEED_MIN_VALUE)
  447. )
class BatchLoader(object):
    """**Deprecated**. Load batches in the background.

    Deprecated. Use ``imgaug.multicore.Pool`` instead.

    Loaded batches can be accessed using :attr:`imgaug.BatchLoader.queue`.
    Entries in that queue are pickled batches. A pickled ``None`` is put
    into it once loading has finished.

    Parameters
    ----------
    load_batch_func : callable or generator
        Generator or generator function (i.e. function that yields Batch
        objects) or a function that returns a list of Batch objects.
        Background loading automatically stops when the last batch was yielded
        or the last batch in the list was reached.

    queue_size : int, optional
        Maximum number of batches to store in the queue. May be set higher
        for small images and/or small batches. Must be at least ``2``. The
        internal and the external queue each get a capacity of
        ``queue_size//2``.

    nb_workers : int, optional
        Number of workers to run in the background.

    threaded : bool, optional
        Whether to run the background processes using threads (True) or full
        processes (False).

    """

    @ia.deprecated(alt_func="imgaug.multicore.Pool")
    def __init__(self, load_batch_func, queue_size=50, nb_workers=1,
                 threaded=True):
        assert queue_size >= 2, (
            "Queue size for BatchLoader must be at least 2, "
            "got %d." % (queue_size,))
        assert nb_workers >= 1, (
            "Number of workers for BatchLoader must be at least 1, "
            "got %d" % (nb_workers,))
        # Workers write into _queue_internal; _main_worker() forwards entries
        # to the public `queue`.
        self._queue_internal = multiprocessing.Queue(queue_size//2)
        self.queue = multiprocessing.Queue(queue_size//2)
        self.join_signal = multiprocessing.Event()
        self.workers = []
        self.threaded = threaded
        seeds = iarandom.get_global_rng().generate_seeds_(nb_workers)
        for i in range(nb_workers):
            if threaded:
                # Threads get no seed (None), i.e. _load_batches() does not
                # reseed any RNGs for them.
                worker = threading.Thread(
                    target=self._load_batches,
                    args=(load_batch_func, self._queue_internal,
                          self.join_signal, None)
                )
            else:
                worker = multiprocessing.Process(
                    target=self._load_batches,
                    args=(load_batch_func, self._queue_internal,
                          self.join_signal, seeds[i])
                )
            worker.daemon = True
            worker.start()
            self.workers.append(worker)

        self.main_worker_thread = threading.Thread(
            target=self._main_worker,
            args=()
        )
        self.main_worker_thread.daemon = True
        self.main_worker_thread.start()

    def count_workers_alive(self):
        """Return the number of loader workers that are still alive."""
        return sum([int(worker.is_alive()) for worker in self.workers])

    def all_finished(self):
        """
        Determine whether the workers have finished the loading process.

        Returns
        -------
        out : bool
            True if all workers have finished. Else False.

        """
        return self.count_workers_alive() == 0

    def _main_worker(self):
        """Forward loaded batches from the internal to the public queue.

        Runs in a daemon thread. After the workers are done (or
        `join_signal` is set), it drains the internal queue and finally puts
        a pickled ``None`` into `queue` as end marker.
        """
        workers_running = self.count_workers_alive()

        while workers_running > 0 and not self.join_signal.is_set():
            # wait for a new batch in the source queue and load it
            try:
                batch_str = self._queue_internal.get(timeout=0.1)
                # "" is the per-worker "finished" marker put by
                # _load_batches()
                if batch_str == "":
                    workers_running -= 1
                else:
                    self.queue.put(batch_str)
            except QueueEmpty:
                time.sleep(0.01)
            except (EOFError, BrokenPipeError):
                break

            workers_running = self.count_workers_alive()

        # All workers have finished, move the remaining entries from internal
        # to external queue
        while True:
            try:
                batch_str = self._queue_internal.get(timeout=0.005)
                if batch_str != "":
                    self.queue.put(batch_str)
            except QueueEmpty:
                break
            except (EOFError, BrokenPipeError):
                break

        self.queue.put(pickle.dumps(None, protocol=-1))
        time.sleep(0.01)

    @classmethod
    def _load_batches(cls, load_batch_func, queue_internal, join_signal,
                      seedval):
        """Worker body: pickle batches from `load_batch_func` into the queue.

        Runs in each worker thread or process. Exceptions are printed, not
        raised. In all cases an empty string is put into `queue_internal`
        at the end to signal that this worker is done.
        """
        # pylint: disable=broad-except
        if seedval is not None:
            random.seed(seedval)
            np.random.seed(seedval)
            iarandom.seed(seedval)

        try:
            gen = (
                load_batch_func()
                if not ia.is_generator(load_batch_func)
                else load_batch_func
            )
            for batch in gen:
                assert isinstance(batch, Batch), (
                    "Expected batch returned by load_batch_func to "
                    "be of class imgaug.Batch, got %s." % (
                        type(batch),))
                batch_pickled = pickle.dumps(batch, protocol=-1)
                # Retry with short timeouts so that a set join_signal is
                # noticed even while the queue stays full.
                while not join_signal.is_set():
                    try:
                        queue_internal.put(batch_pickled, timeout=0.005)
                        break
                    except QueueFull:
                        pass
                if join_signal.is_set():
                    break
        except Exception:
            traceback.print_exc()
        finally:
            queue_internal.put("")
        time.sleep(0.01)

    def terminate(self):
        """Stop all workers."""
        # pylint: disable=protected-access
        if not self.join_signal.is_set():
            self.join_signal.set()
        # give minimal time to put generated batches in queue and gracefully
        # shut down
        time.sleep(0.01)

        if self.main_worker_thread.is_alive():
            self.main_worker_thread.join()

        if self.threaded:
            for worker in self.workers:
                if worker.is_alive():
                    worker.join()
        else:
            for worker in self.workers:
                if worker.is_alive():
                    worker.terminate()
                    worker.join()

            # wait until all workers are fully terminated
            while not self.all_finished():
                time.sleep(0.001)

        # empty queue until at least one element can be added and place None
        # as signal that BL finished
        if self.queue.full():
            self.queue.get()
        self.queue.put(pickle.dumps(None, protocol=-1))
        time.sleep(0.01)

        # clean the queue, this reportedly prevents hanging threads
        while True:
            try:
                self._queue_internal.get(timeout=0.005)
            except QueueEmpty:
                break

        if not self._queue_internal._closed:
            self._queue_internal.close()
        if not self.queue._closed:
            self.queue.close()
        self._queue_internal.join_thread()
        self.queue.join_thread()
        time.sleep(0.025)

    def __del__(self):
        # Ask remaining workers to stop.
        # NOTE(review): if __init__ failed before `join_signal` was assigned
        # (e.g. via one of its assertions), this raises AttributeError, which
        # Python only reports as ignored inside __del__.
        if not self.join_signal.is_set():
            self.join_signal.set()
class BackgroundAugmenter(object):
    """
    **Deprecated**. Augment batches in background processes.

    Deprecated. Use ``imgaug.multicore.Pool`` instead.

    This is a wrapper around the multiprocessing module.

    Parameters
    ----------
    batch_loader : BatchLoader or multiprocessing.Queue
        BatchLoader object that loads the data fed into the
        BackgroundAugmenter, or alternatively a Queue. If a Queue, then it
        must be made sure that a final ``None`` in the Queue signals that the
        loading is finished and no more batches will follow. Otherwise the
        BackgroundAugmenter will wait forever for the next batch.
        Entries of the queue are expected to be pickled batches.

    augseq : Augmenter
        An augmenter to apply to all loaded images.
        This may be e.g. a Sequential to apply multiple augmenters.

    queue_size : int
        Size of the queue that is used to temporarily save the augmentation
        results. Larger values offer the background processes more room
        to save results when the main process doesn't load much, i.e. they
        can lead to smoother and faster training. For large images, high
        values can block a lot of RAM though.

    nb_workers : 'auto' or int
        Number of background workers to spawn.
        If ``auto``, it will be set to ``C-1``, where ``C`` is the number of
        CPU cores (but at least ``1``).

    """

    @ia.deprecated(alt_func="imgaug.multicore.Pool")
    def __init__(self, batch_loader, augseq, queue_size=50, nb_workers="auto"):
        assert queue_size > 0, (
            "Expected 'queue_size' to be at least 1, got %d." % (queue_size,))
        self.augseq = augseq
        # NOTE(review): `multiprocessing.queues` is only reachable as an
        # attribute once that submodule was imported, which happens when any
        # multiprocessing.Queue was created (e.g. by the caller or a
        # BatchLoader) — confirm for other call paths.
        self.queue_source = (
            batch_loader
            if isinstance(batch_loader, multiprocessing.queues.Queue)
            else batch_loader.queue
        )
        self.queue_result = multiprocessing.Queue(queue_size)

        if nb_workers == "auto":
            try:
                nb_workers = multiprocessing.cpu_count()
            except (ImportError, NotImplementedError):
                nb_workers = 1
            # try to reserve at least one core for the main process
            nb_workers = max(1, nb_workers - 1)
        else:
            assert nb_workers >= 1, (
                "Expected 'nb_workers' to be \"auto\" or at least 1, "
                "got %d instead." % (nb_workers,))
        self.nb_workers = nb_workers
        self.workers = []
        self.nb_workers_finished = 0

        seeds = iarandom.get_global_rng().generate_seeds_(nb_workers)
        for i in range(nb_workers):
            worker = multiprocessing.Process(
                target=self._augment_images_worker,
                args=(augseq, self.queue_source, self.queue_result, seeds[i])
            )
            worker.daemon = True
            worker.start()
            self.workers.append(worker)

    def all_finished(self):
        """Return whether every worker has reported that it finished."""
        return self.nb_workers_finished == self.nb_workers

    def get_batch(self):
        """
        Returns a batch from the queue of augmented batches.

        If workers are still running and there are no batches in the queue,
        it will automatically wait for the next batch.

        Returns
        -------
        out : None or imgaug.Batch
            One batch or None if all workers have finished.

        """
        if self.all_finished():
            return None

        batch_str = self.queue_result.get()
        batch = pickle.loads(batch_str)
        if batch is not None:
            return batch

        # A pickled None means that one worker finished.
        self.nb_workers_finished += 1
        if self.nb_workers_finished >= self.nb_workers:
            try:
                # remove `None` from the source queue
                self.queue_source.get(timeout=0.001)
            except QueueEmpty:
                pass
            return None
        # Other workers are still running, wait for their next result.
        return self.get_batch()

    @classmethod
    def _augment_images_worker(cls, augseq, queue_source, queue_result,
                               seedval):
        """
        Augment images from the source queue until the loader is finished.

        This is a worker function that repeatedly queries the source queue
        (input batches), augments the batches in it and sends the results to
        the output queue. It stops after reading a pickled ``None`` from the
        source queue and then puts a pickled ``None`` into the output queue.

        """
        np.random.seed(seedval)
        random.seed(seedval)
        augseq.seed_(seedval)
        iarandom.seed(seedval)

        loader_finished = False

        while not loader_finished:
            # wait for a new batch in the source queue and load it
            try:
                batch_str = queue_source.get(timeout=0.1)
                batch = pickle.loads(batch_str)
                if batch is None:
                    loader_finished = True
                    # put it back in so that other workers know that the
                    # loading queue is finished
                    queue_source.put(pickle.dumps(None, protocol=-1))
                else:
                    batch_aug = augseq.augment_batch_(batch)

                    # send augmented batch to output queue
                    batch_str = pickle.dumps(batch_aug, protocol=-1)
                    queue_result.put(batch_str)
            except QueueEmpty:
                time.sleep(0.01)

        queue_result.put(pickle.dumps(None, protocol=-1))
        time.sleep(0.01)

    def terminate(self):
        """
        Terminates all background processes immediately.

        This will also free their RAM.

        """
        # pylint: disable=protected-access
        for worker in self.workers:
            if worker.is_alive():
                worker.terminate()
        self.nb_workers_finished = len(self.workers)

        if not self.queue_result._closed:
            self.queue_result.close()
        time.sleep(0.01)

    def __del__(self):
        # NOTE(review): if __init__ failed before `workers` was assigned,
        # terminate() raises AttributeError here (reported and ignored by
        # Python inside __del__).
        time.sleep(0.1)
        self.terminate()