batches.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088
  1. """Classes representing batches of normalized or unnormalized data."""
  2. from __future__ import print_function, division, absolute_import
  3. import collections
  4. import numpy as np
  5. from .. import imgaug as ia
  6. from . import normalization as nlib
  7. from . import utils
  8. DEFAULT = "DEFAULT"
  9. _AUGMENTABLE_NAMES = [
  10. "images", "heatmaps", "segmentation_maps", "keypoints",
  11. "bounding_boxes", "polygons", "line_strings"]
  12. _AugmentableColumn = collections.namedtuple(
  13. "_AugmentableColumn",
  14. ["name", "value", "attr_name"])
  15. def _get_column_names(batch, postfix):
  16. return [column.name
  17. for column
  18. in _get_columns(batch, postfix)]
  19. def _get_columns(batch, postfix):
  20. result = []
  21. for name in _AUGMENTABLE_NAMES:
  22. attr_name = name + postfix
  23. value = getattr(batch, name + postfix)
  24. # Every data item is either an array or a list. If there are no
  25. # items in the array/list, there are also no shapes to change
  26. # as shape-changes are imagewise. Hence, we can afford to check
  27. # len() here.
  28. if value is not None and len(value) > 0:
  29. result.append(_AugmentableColumn(name, value, attr_name))
  30. return result
  31. # TODO also support (H,W,C) for heatmaps of len(images) == 1
  32. # TODO also support (H,W) for segmaps of len(images) == 1
  33. class UnnormalizedBatch(object):
  34. """
  35. Class for batches of unnormalized data before and after augmentation.
  36. Parameters
  37. ----------
  38. images : None or (N,H,W,C) ndarray or (N,H,W) ndarray or iterable of (H,W,C) ndarray or iterable of (H,W) ndarray
  39. The images to augment.
  40. heatmaps : None or (N,H,W,C) ndarray or imgaug.augmentables.heatmaps.HeatmapsOnImage or iterable of (H,W,C) ndarray or iterable of imgaug.augmentables.heatmaps.HeatmapsOnImage
  41. The heatmaps to augment.
  42. If anything else than ``HeatmapsOnImage``, then the number of heatmaps
  43. must match the number of images provided via parameter `images`.
  44. The number is contained either in ``N`` or the first iterable's size.
  45. segmentation_maps : None or (N,H,W) ndarray or imgaug.augmentables.segmaps.SegmentationMapsOnImage or iterable of (H,W) ndarray or iterable of imgaug.augmentables.segmaps.SegmentationMapsOnImage
  46. The segmentation maps to augment.
  47. If anything else than ``SegmentationMapsOnImage``, then the number of
  48. segmaps must match the number of images provided via parameter
  49. `images`. The number is contained either in ``N`` or the first
  50. iterable's size.
  51. keypoints : None or list of (N,K,2) ndarray or tuple of number or imgaug.augmentables.kps.Keypoint or iterable of (K,2) ndarray or iterable of tuple of number or iterable of imgaug.augmentables.kps.Keypoint or iterable of imgaug.augmentables.kps.KeypointOnImage or iterable of iterable of tuple of number or iterable of iterable of imgaug.augmentables.kps.Keypoint
  52. The keypoints to augment.
  53. If a tuple (or iterable(s) of tuple), then iterpreted as (x,y)
  54. coordinates and must hence contain two numbers.
  55. A single tuple represents a single coordinate on one image, an
  56. iterable of tuples the coordinates on one image and an iterable of
  57. iterable of tuples the coordinates on several images. Analogous if
  58. ``Keypoint`` objects are used instead of tuples.
  59. If an ndarray, then ``N`` denotes the number of images and ``K`` the
  60. number of keypoints on each image.
  61. If anything else than ``KeypointsOnImage`` is provided, then the
  62. number of keypoint groups must match the number of images provided
  63. via parameter `images`. The number is contained e.g. in ``N`` or
  64. in case of "iterable of iterable of tuples" in the first iterable's
  65. size.
  66. bounding_boxes : None or (N,B,4) ndarray or tuple of number or imgaug.augmentables.bbs.BoundingBox or imgaug.augmentables.bbs.BoundingBoxesOnImage or iterable of (B,4) ndarray or iterable of tuple of number or iterable of imgaug.augmentables.bbs.BoundingBox or iterable of imgaug.augmentables.bbs.BoundingBoxesOnImage or iterable of iterable of tuple of number or iterable of iterable imgaug.augmentables.bbs.BoundingBox
  67. The bounding boxes to augment.
  68. This is analogous to the `keypoints` parameter. However, each
  69. tuple -- and also the last index in case of arrays -- has size 4,
  70. denoting the bounding box coordinates ``x1``, ``y1``, ``x2`` and ``y2``.
  71. polygons : None or (N,#polys,#points,2) ndarray or imgaug.augmentables.polys.Polygon or imgaug.augmentables.polys.PolygonsOnImage or iterable of (#polys,#points,2) ndarray or iterable of tuple of number or iterable of imgaug.augmentables.kps.Keypoint or iterable of imgaug.augmentables.polys.Polygon or iterable of imgaug.augmentables.polys.PolygonsOnImage or iterable of iterable of (#points,2) ndarray or iterable of iterable of tuple of number or iterable of iterable of imgaug.augmentables.kps.Keypoint or iterable of iterable of imgaug.augmentables.polys.Polygon or iterable of iterable of iterable of tuple of number or iterable of iterable of iterable of tuple of imgaug.augmentables.kps.Keypoint
  72. The polygons to augment.
  73. This is similar to the `keypoints` parameter. However, each polygon
  74. may be made up of several ``(x,y)`` coordinates (three or more are
  75. required for valid polygons).
  76. The following datatypes will be interpreted as a single polygon on a
  77. single image:
  78. * ``imgaug.augmentables.polys.Polygon``
  79. * ``iterable of tuple of number``
  80. * ``iterable of imgaug.augmentables.kps.Keypoint``
  81. The following datatypes will be interpreted as multiple polygons on a
  82. single image:
  83. * ``imgaug.augmentables.polys.PolygonsOnImage``
  84. * ``iterable of imgaug.augmentables.polys.Polygon``
  85. * ``iterable of iterable of tuple of number``
  86. * ``iterable of iterable of imgaug.augmentables.kps.Keypoint``
  87. * ``iterable of iterable of imgaug.augmentables.polys.Polygon``
  88. The following datatypes will be interpreted as multiple polygons on
  89. multiple images:
  90. * ``(N,#polys,#points,2) ndarray``
  91. * ``iterable of (#polys,#points,2) ndarray``
  92. * ``iterable of iterable of (#points,2) ndarray``
  93. * ``iterable of iterable of iterable of tuple of number``
  94. * ``iterable of iterable of iterable of tuple of imgaug.augmentables.kps.Keypoint``
  95. line_strings : None or (N,#lines,#points,2) ndarray or imgaug.augmentables.lines.LineString or imgaug.augmentables.lines.LineStringOnImage or iterable of (#lines,#points,2) ndarray or iterable of tuple of number or iterable of imgaug.augmentables.kps.Keypoint or iterable of imgaug.augmentables.lines.LineString or iterable of imgaug.augmentables.lines.LineStringOnImage or iterable of iterable of (#points,2) ndarray or iterable of iterable of tuple of number or iterable of iterable of imgaug.augmentables.kps.Keypoint or iterable of iterable of imgaug.augmentables.polys.LineString or iterable of iterable of iterable of tuple of number or iterable of iterable of iterable of tuple of imgaug.augmentables.kps.Keypoint
  96. The line strings to augment.
  97. See `polygons` for more details as polygons follow a similar
  98. structure to line strings.
  99. data
  100. Additional data that is saved in the batch and may be read out
  101. after augmentation. This could e.g. contain filepaths to each image
  102. in `images`. As this object is usually used for background
  103. augmentation with multiple processes, the augmented Batch objects might
  104. not be returned in the original order, making this information useful.
  105. """
  106. def __init__(self, images=None, heatmaps=None, segmentation_maps=None,
  107. keypoints=None, bounding_boxes=None, polygons=None,
  108. line_strings=None, data=None):
  109. """Construct a new :class:`UnnormalizedBatch` instance."""
  110. self.images_unaug = images
  111. self.images_aug = None
  112. self.heatmaps_unaug = heatmaps
  113. self.heatmaps_aug = None
  114. self.segmentation_maps_unaug = segmentation_maps
  115. self.segmentation_maps_aug = None
  116. self.keypoints_unaug = keypoints
  117. self.keypoints_aug = None
  118. self.bounding_boxes_unaug = bounding_boxes
  119. self.bounding_boxes_aug = None
  120. self.polygons_unaug = polygons
  121. self.polygons_aug = None
  122. self.line_strings_unaug = line_strings
  123. self.line_strings_aug = None
  124. self.data = data
  125. def get_column_names(self):
  126. """Get the names of types of augmentables that contain data.
  127. This method is intended for situations where one wants to know which
  128. data is contained in the batch that has to be augmented, visualized
  129. or something similar.
  130. Added in 0.4.0.
  131. Returns
  132. -------
  133. list of str
  134. Names of types of augmentables. E.g. ``["images", "polygons"]``.
  135. """
  136. return _get_column_names(self, "_unaug")
  137. def to_normalized_batch(self):
  138. """Convert this unnormalized batch to an instance of Batch.
  139. As this method is intended to be called before augmentation, it
  140. assumes that none of the ``*_aug`` attributes is yet set.
  141. It will produce an AssertionError otherwise.
  142. The newly created Batch's ``*_unaug`` attributes will match the ones
  143. in this batch, just in normalized form.
  144. Returns
  145. -------
  146. imgaug.augmentables.batches.Batch
  147. The batch, with ``*_unaug`` attributes being normalized.
  148. """
  149. contains_no_augmented_data_yet = all([
  150. attr is None
  151. for attr_name, attr
  152. in self.__dict__.items()
  153. if attr_name.endswith("_aug")])
  154. assert contains_no_augmented_data_yet, (
  155. "Expected UnnormalizedBatch to not contain any augmented data "
  156. "before normalization, but at least one '*_aug' attribute was "
  157. "already set.")
  158. images_unaug = nlib.normalize_images(self.images_unaug)
  159. shapes = None
  160. if images_unaug is not None:
  161. shapes = [image.shape for image in images_unaug]
  162. return Batch(
  163. images=images_unaug,
  164. heatmaps=nlib.normalize_heatmaps(
  165. self.heatmaps_unaug, shapes),
  166. segmentation_maps=nlib.normalize_segmentation_maps(
  167. self.segmentation_maps_unaug, shapes),
  168. keypoints=nlib.normalize_keypoints(
  169. self.keypoints_unaug, shapes),
  170. bounding_boxes=nlib.normalize_bounding_boxes(
  171. self.bounding_boxes_unaug, shapes),
  172. polygons=nlib.normalize_polygons(
  173. self.polygons_unaug, shapes),
  174. line_strings=nlib.normalize_line_strings(
  175. self.line_strings_unaug, shapes),
  176. data=self.data
  177. )
  178. def fill_from_augmented_normalized_batch_(self, batch_aug_norm):
  179. """
  180. Fill this batch with (normalized) augmentation results in-place.
  181. This method receives a (normalized) Batch instance, takes all
  182. ``*_aug`` attributes out if it and assigns them to this
  183. batch *in unnormalized form*. Hence, the datatypes of all ``*_aug``
  184. attributes will match the datatypes of the ``*_unaug`` attributes.
  185. Added in 0.4.0.
  186. Parameters
  187. ----------
  188. batch_aug_norm: imgaug.augmentables.batches.Batch
  189. Batch after normalization and augmentation.
  190. Returns
  191. -------
  192. imgaug.augmentables.batches.UnnormalizedBatch
  193. This instance itself.
  194. All ``*_unaug`` attributes are unchanged.
  195. All ``*_aug`` attributes are taken from `batch_normalized`,
  196. converted to unnormalized form.
  197. """
  198. self.images_aug = nlib.invert_normalize_images(
  199. batch_aug_norm.images_aug, self.images_unaug)
  200. self.heatmaps_aug = nlib.invert_normalize_heatmaps(
  201. batch_aug_norm.heatmaps_aug, self.heatmaps_unaug)
  202. self.segmentation_maps_aug = nlib.invert_normalize_segmentation_maps(
  203. batch_aug_norm.segmentation_maps_aug, self.segmentation_maps_unaug)
  204. self.keypoints_aug = nlib.invert_normalize_keypoints(
  205. batch_aug_norm.keypoints_aug, self.keypoints_unaug)
  206. self.bounding_boxes_aug = nlib.invert_normalize_bounding_boxes(
  207. batch_aug_norm.bounding_boxes_aug, self.bounding_boxes_unaug)
  208. self.polygons_aug = nlib.invert_normalize_polygons(
  209. batch_aug_norm.polygons_aug, self.polygons_unaug)
  210. self.line_strings_aug = nlib.invert_normalize_line_strings(
  211. batch_aug_norm.line_strings_aug, self.line_strings_unaug)
  212. return self
  213. def fill_from_augmented_normalized_batch(self, batch_aug_norm):
  214. """
  215. Fill this batch with (normalized) augmentation results.
  216. This method receives a (normalized) Batch instance, takes all
  217. ``*_aug`` attributes out if it and assigns them to this
  218. batch *in unnormalized form*. Hence, the datatypes of all ``*_aug``
  219. attributes will match the datatypes of the ``*_unaug`` attributes.
  220. Parameters
  221. ----------
  222. batch_aug_norm: imgaug.augmentables.batches.Batch
  223. Batch after normalization and augmentation.
  224. Returns
  225. -------
  226. imgaug.augmentables.batches.UnnormalizedBatch
  227. New UnnormalizedBatch instance. All ``*_unaug`` attributes are
  228. taken from the old UnnormalizedBatch (without deepcopying them)
  229. and all ``*_aug`` attributes are taken from `batch_normalized`,
  230. converted to unnormalized form.
  231. """
  232. # we take here the .data from the normalized batch instead of from
  233. # self for the rare case where one has decided to somehow change it
  234. # during augmentation
  235. batch = UnnormalizedBatch(
  236. images=self.images_unaug,
  237. heatmaps=self.heatmaps_unaug,
  238. segmentation_maps=self.segmentation_maps_unaug,
  239. keypoints=self.keypoints_unaug,
  240. bounding_boxes=self.bounding_boxes_unaug,
  241. polygons=self.polygons_unaug,
  242. line_strings=self.line_strings_unaug,
  243. data=batch_aug_norm.data
  244. )
  245. batch.images_aug = nlib.invert_normalize_images(
  246. batch_aug_norm.images_aug, self.images_unaug)
  247. batch.heatmaps_aug = nlib.invert_normalize_heatmaps(
  248. batch_aug_norm.heatmaps_aug, self.heatmaps_unaug)
  249. batch.segmentation_maps_aug = nlib.invert_normalize_segmentation_maps(
  250. batch_aug_norm.segmentation_maps_aug, self.segmentation_maps_unaug)
  251. batch.keypoints_aug = nlib.invert_normalize_keypoints(
  252. batch_aug_norm.keypoints_aug, self.keypoints_unaug)
  253. batch.bounding_boxes_aug = nlib.invert_normalize_bounding_boxes(
  254. batch_aug_norm.bounding_boxes_aug, self.bounding_boxes_unaug)
  255. batch.polygons_aug = nlib.invert_normalize_polygons(
  256. batch_aug_norm.polygons_aug, self.polygons_unaug)
  257. batch.line_strings_aug = nlib.invert_normalize_line_strings(
  258. batch_aug_norm.line_strings_aug, self.line_strings_unaug)
  259. return batch
  260. class Batch(object):
  261. """
  262. Class encapsulating a batch before and after augmentation.
  263. Parameters
  264. ----------
  265. images : None or (N,H,W,C) ndarray or list of (H,W,C) ndarray
  266. The images to augment.
  267. heatmaps : None or list of imgaug.augmentables.heatmaps.HeatmapsOnImage
  268. The heatmaps to augment.
  269. segmentation_maps : None or list of imgaug.augmentables.segmaps.SegmentationMapsOnImage
  270. The segmentation maps to augment.
  271. keypoints : None or list of imgaug.augmentables.kps.KeypointOnImage
  272. The keypoints to augment.
  273. bounding_boxes : None or list of imgaug.augmentables.bbs.BoundingBoxesOnImage
  274. The bounding boxes to augment.
  275. polygons : None or list of imgaug.augmentables.polys.PolygonsOnImage
  276. The polygons to augment.
  277. line_strings : None or list of imgaug.augmentables.lines.LineStringsOnImage
  278. The line strings to augment.
  279. data
  280. Additional data that is saved in the batch and may be read out
  281. after augmentation. This could e.g. contain filepaths to each image
  282. in `images`. As this object is usually used for background
  283. augmentation with multiple processes, the augmented Batch objects might
  284. not be returned in the original order, making this information useful.
  285. """
  286. def __init__(self, images=None, heatmaps=None, segmentation_maps=None,
  287. keypoints=None, bounding_boxes=None, polygons=None,
  288. line_strings=None, data=None):
  289. """Construct a new :class:`Batch` instance."""
  290. self.images_unaug = images
  291. self.images_aug = None
  292. self.heatmaps_unaug = heatmaps
  293. self.heatmaps_aug = None
  294. self.segmentation_maps_unaug = segmentation_maps
  295. self.segmentation_maps_aug = None
  296. self.keypoints_unaug = keypoints
  297. self.keypoints_aug = None
  298. self.bounding_boxes_unaug = bounding_boxes
  299. self.bounding_boxes_aug = None
  300. self.polygons_unaug = polygons
  301. self.polygons_aug = None
  302. self.line_strings_unaug = line_strings
  303. self.line_strings_aug = None
  304. self.data = data
  305. @property
  306. @ia.deprecated("Batch.images_unaug")
  307. def images(self):
  308. """Get unaugmented images."""
  309. return self.images_unaug
  310. @property
  311. @ia.deprecated("Batch.heatmaps_unaug")
  312. def heatmaps(self):
  313. """Get unaugmented heatmaps."""
  314. return self.heatmaps_unaug
  315. @property
  316. @ia.deprecated("Batch.segmentation_maps_unaug")
  317. def segmentation_maps(self):
  318. """Get unaugmented segmentation maps."""
  319. return self.segmentation_maps_unaug
  320. @property
  321. @ia.deprecated("Batch.keypoints_unaug")
  322. def keypoints(self):
  323. """Get unaugmented keypoints."""
  324. return self.keypoints_unaug
  325. @property
  326. @ia.deprecated("Batch.bounding_boxes_unaug")
  327. def bounding_boxes(self):
  328. """Get unaugmented bounding boxes."""
  329. return self.bounding_boxes_unaug
  330. def get_column_names(self):
  331. """Get the names of types of augmentables that contain data.
  332. This method is intended for situations where one wants to know which
  333. data is contained in the batch that has to be augmented, visualized
  334. or something similar.
  335. Added in 0.4.0.
  336. Returns
  337. -------
  338. list of str
  339. Names of types of augmentables. E.g. ``["images", "polygons"]``.
  340. """
  341. return _get_column_names(self, "_unaug")
  342. def to_normalized_batch(self):
  343. """Return this batch.
  344. This method does nothing and only exists to simplify interfaces
  345. that accept both :class:`UnnormalizedBatch` and :class:`Batch`.
  346. Added in 0.4.0.
  347. Returns
  348. -------
  349. imgaug.augmentables.batches.Batch
  350. This batch (not copied).
  351. """
  352. return self
  353. def to_batch_in_augmentation(self):
  354. """Convert this batch to a :class:`_BatchInAugmentation` instance.
  355. Added in 0.4.0.
  356. Returns
  357. -------
  358. imgaug.augmentables.batches._BatchInAugmentation
  359. The converted batch.
  360. """
  361. def _copy(var):
  362. # TODO first check here if _aug is set and if it is then use that?
  363. if var is not None:
  364. return utils.copy_augmentables(var)
  365. return var
  366. return _BatchInAugmentation(
  367. images=_copy(self.images_unaug),
  368. heatmaps=_copy(self.heatmaps_unaug),
  369. segmentation_maps=_copy(self.segmentation_maps_unaug),
  370. keypoints=_copy(self.keypoints_unaug),
  371. bounding_boxes=_copy(self.bounding_boxes_unaug),
  372. polygons=_copy(self.polygons_unaug),
  373. line_strings=_copy(self.line_strings_unaug)
  374. )
  375. def fill_from_batch_in_augmentation_(self, batch_in_augmentation):
  376. """Set the columns in this batch to the column values of another batch.
  377. This method works in-place.
  378. Added in 0.4.0.
  379. Parameters
  380. ----------
  381. batch_in_augmentation : _BatchInAugmentation
  382. Batch of which to use the column values.
  383. The values are *not* copied. Only their references are used.
  384. Returns
  385. -------
  386. Batch
  387. The updated batch. (Modified in-place.)
  388. """
  389. self.images_aug = batch_in_augmentation.images
  390. self.heatmaps_aug = batch_in_augmentation.heatmaps
  391. self.segmentation_maps_aug = batch_in_augmentation.segmentation_maps
  392. self.keypoints_aug = batch_in_augmentation.keypoints
  393. self.bounding_boxes_aug = batch_in_augmentation.bounding_boxes
  394. self.polygons_aug = batch_in_augmentation.polygons
  395. self.line_strings_aug = batch_in_augmentation.line_strings
  396. return self
  397. def deepcopy(self,
  398. images_unaug=DEFAULT,
  399. images_aug=DEFAULT,
  400. heatmaps_unaug=DEFAULT,
  401. heatmaps_aug=DEFAULT,
  402. segmentation_maps_unaug=DEFAULT,
  403. segmentation_maps_aug=DEFAULT,
  404. keypoints_unaug=DEFAULT,
  405. keypoints_aug=DEFAULT,
  406. bounding_boxes_unaug=DEFAULT,
  407. bounding_boxes_aug=DEFAULT,
  408. polygons_unaug=DEFAULT,
  409. polygons_aug=DEFAULT,
  410. line_strings_unaug=DEFAULT,
  411. line_strings_aug=DEFAULT):
  412. """Copy this batch and all of its column values.
  413. Parameters
  414. ----------
  415. images_unaug : imgaug.augmentables.batches.DEFAULT or None or (N,H,W,C) ndarray or list of (H,W,C) ndarray
  416. Copies the current attribute value without changes if set to
  417. ``imgaug.augmentables.batches.DEFAULT``.
  418. Otherwise same as in :func:`Batch.__init__`.
  419. images_aug : imgaug.augmentables.batches.DEFAULT or None or (N,H,W,C) ndarray or list of (H,W,C) ndarray
  420. Copies the current attribute value without changes if set to
  421. ``imgaug.augmentables.batches.DEFAULT``.
  422. Otherwise same as in :func:`Batch.__init__`.
  423. heatmaps_unaug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.heatmaps.HeatmapsOnImage
  424. Copies the current attribute value without changes if set to
  425. ``imgaug.augmentables.batches.DEFAULT``.
  426. Otherwise same as in :func:`Batch.__init__`.
  427. heatmaps_aug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.heatmaps.HeatmapsOnImage
  428. Copies the current attribute value without changes if set to
  429. ``imgaug.augmentables.batches.DEFAULT``.
  430. Otherwise same as in :func:`Batch.__init__`.
  431. segmentation_maps_unaug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.segmaps.SegmentationMapsOnImage
  432. Copies the current attribute value without changes if set to
  433. ``imgaug.augmentables.batches.DEFAULT``.
  434. Otherwise same as in :func:`Batch.__init__`.
  435. segmentation_maps_aug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.segmaps.SegmentationMapsOnImage
  436. Copies the current attribute value without changes if set to
  437. ``imgaug.augmentables.batches.DEFAULT``.
  438. Otherwise same as in :func:`Batch.__init__`.
  439. keypoints_unaug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.kps.KeypointOnImage
  440. Copies the current attribute value without changes if set to
  441. ``imgaug.augmentables.batches.DEFAULT``.
  442. Otherwise same as in :func:`Batch.__init__`.
  443. keypoints_aug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.kps.KeypointOnImage
  444. Copies the current attribute value without changes if set to
  445. ``imgaug.augmentables.batches.DEFAULT``.
  446. Otherwise same as in :func:`Batch.__init__`.
  447. bounding_boxes_unaug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.bbs.BoundingBoxesOnImage
  448. Copies the current attribute value without changes if set to
  449. ``imgaug.augmentables.batches.DEFAULT``.
  450. Otherwise same as in :func:`Batch.__init__`.
  451. bounding_boxes_aug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.bbs.BoundingBoxesOnImage
  452. Copies the current attribute value without changes if set to
  453. ``imgaug.augmentables.batches.DEFAULT``.
  454. Otherwise same as in :func:`Batch.__init__`.
  455. polygons_unaug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.polys.PolygonsOnImage
  456. Copies the current attribute value without changes if set to
  457. ``imgaug.augmentables.batches.DEFAULT``.
  458. Otherwise same as in :func:`Batch.__init__`.
  459. polygons_aug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.polys.PolygonsOnImage
  460. Copies the current attribute value without changes if set to
  461. ``imgaug.augmentables.batches.DEFAULT``.
  462. Otherwise same as in :func:`Batch.__init__`.
  463. line_strings_unaug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.lines.LineStringsOnImage
  464. Copies the current attribute value without changes if set to
  465. ``imgaug.augmentables.batches.DEFAULT``.
  466. Otherwise same as in :func:`Batch.__init__`.
  467. line_strings_aug : imgaug.augmentables.batches.DEFAULT or None or list of imgaug.augmentables.lines.LineStringsOnImage
  468. Copies the current attribute value without changes if set to
  469. ``imgaug.augmentables.batches.DEFAULT``.
  470. Otherwise same as in :func:`Batch.__init__`.
  471. Returns
  472. -------
  473. Batch
  474. Deep copy of the batch, optionally with new attributes.
  475. """
  476. def _copy_optional(self_attr, arg):
  477. return utils.deepcopy_fast(arg if arg is not DEFAULT else self_attr)
  478. batch = Batch(
  479. images=_copy_optional(self.images_unaug, images_unaug),
  480. heatmaps=_copy_optional(self.heatmaps_unaug, heatmaps_unaug),
  481. segmentation_maps=_copy_optional(self.segmentation_maps_unaug,
  482. segmentation_maps_unaug),
  483. keypoints=_copy_optional(self.keypoints_unaug, keypoints_unaug),
  484. bounding_boxes=_copy_optional(self.bounding_boxes_unaug,
  485. bounding_boxes_unaug),
  486. polygons=_copy_optional(self.polygons_unaug, polygons_unaug),
  487. line_strings=_copy_optional(self.line_strings_unaug,
  488. line_strings_unaug),
  489. data=utils.deepcopy_fast(self.data)
  490. )
  491. batch.images_aug = _copy_optional(self.images_aug, images_aug)
  492. batch.heatmaps_aug = _copy_optional(self.heatmaps_aug, heatmaps_aug)
  493. batch.segmentation_maps_aug = _copy_optional(self.segmentation_maps_aug,
  494. segmentation_maps_aug)
  495. batch.keypoints_aug = _copy_optional(self.keypoints_aug, keypoints_aug)
  496. batch.bounding_boxes_aug = _copy_optional(self.bounding_boxes_aug,
  497. bounding_boxes_aug)
  498. batch.polygons_aug = _copy_optional(self.polygons_aug, polygons_aug)
  499. batch.line_strings_aug = _copy_optional(self.line_strings_aug,
  500. line_strings_aug)
  501. return batch
  502. # Added in 0.4.0.
  503. class _BatchInAugmentationPropagationContext(object):
  504. def __init__(self, batch, augmenter, hooks, parents):
  505. self.batch = batch
  506. self.augmenter = augmenter
  507. self.hooks = hooks
  508. self.parents = parents
  509. self.noned_info = None
  510. def __enter__(self):
  511. if self.hooks is not None:
  512. self.noned_info = self.batch.apply_propagation_hooks_(
  513. self.augmenter, self.hooks, self.parents)
  514. return self.batch
  515. def __exit__(self, exc_type, exc_val, exc_tb):
  516. if self.noned_info is not None:
  517. self.batch = \
  518. self.batch.invert_apply_propagation_hooks_(self.noned_info)
  519. class _BatchInAugmentation(object):
  520. """
  521. Class encapsulating a batch during the augmentation process.
  522. Data within the batch is already verified and normalized, similar to
  523. :class:`Batch`. Data within the batch may be changed in-place. No initial
  524. copy is needed.
  525. Added in 0.4.0.
  526. Parameters
  527. ----------
  528. images : None or (N,H,W,C) ndarray or list of (H,W,C) ndarray
  529. The images to augment.
  530. heatmaps : None or list of imgaug.augmentables.heatmaps.HeatmapsOnImage
  531. The heatmaps to augment.
  532. segmentation_maps : None or list of imgaug.augmentables.segmaps.SegmentationMapsOnImage
  533. The segmentation maps to augment.
  534. keypoints : None or list of imgaug.augmentables.kps.KeypointOnImage
  535. The keypoints to augment.
  536. bounding_boxes : None or list of imgaug.augmentables.bbs.BoundingBoxesOnImage
  537. The bounding boxes to augment.
  538. polygons : None or list of imgaug.augmentables.polys.PolygonsOnImage
  539. The polygons to augment.
  540. line_strings : None or list of imgaug.augmentables.lines.LineStringsOnImage
  541. The line strings to augment.
  542. """
  543. # Added in 0.4.0.
  544. def __init__(self, images=None, heatmaps=None, segmentation_maps=None,
  545. keypoints=None, bounding_boxes=None, polygons=None,
  546. line_strings=None, data=None):
  547. """Create a new :class:`_BatchInAugmentation` instance."""
  548. self.images = images
  549. self.heatmaps = heatmaps
  550. self.segmentation_maps = segmentation_maps
  551. self.keypoints = keypoints
  552. self.bounding_boxes = bounding_boxes
  553. self.polygons = polygons
  554. self.line_strings = line_strings
  555. self.data = data
  556. @property
  557. def empty(self):
  558. """Estimate whether this batch is empty, i.e. contains no data.
  559. Added in 0.4.0.
  560. Returns
  561. -------
  562. bool
  563. ``True`` if the batch contains no data to augment.
  564. ``False`` otherwise.
  565. """
  566. return self.nb_rows == 0
  567. @property
  568. def nb_rows(self):
  569. """Get the number of rows (i.e. examples) in this batch.
  570. Note that this method assumes that all columns have the same number
  571. of rows.
  572. Added in 0.4.0.
  573. Returns
  574. -------
  575. int
  576. Number of rows or ``0`` if there is no data in the batch.
  577. """
  578. for augm_name in _AUGMENTABLE_NAMES:
  579. value = getattr(self, augm_name)
  580. if value is not None:
  581. return len(value)
  582. return 0
  583. @property
  584. def columns(self):
  585. """Get the columns of data to augment.
  586. Each column represents one datatype and its corresponding data,
  587. e.g. images or polygons.
  588. Added in 0.4.0.
  589. Returns
  590. -------
  591. list of _AugmentableColumn
  592. The columns to augment within this batch.
  593. """
  594. return _get_columns(self, "")
  595. def get_column_names(self):
  596. """Get the names of types of augmentables that contain data.
  597. This method is intended for situations where one wants to know which
  598. data is contained in the batch that has to be augmented, visualized
  599. or something similar.
  600. Added in 0.4.0.
  601. Returns
  602. -------
  603. list of str
  604. Names of types of augmentables. E.g. ``["images", "polygons"]``.
  605. """
  606. return _get_column_names(self, "")
  607. def get_rowwise_shapes(self):
  608. """Get the shape of each row within this batch.
  609. Each row denotes the data of different types (e.g. image array,
  610. polygons) corresponding to a single example in the batch.
  611. This method assumes that all ``.shape`` attributes contain the same
  612. shape and that it is identical to the image's shape.
  613. It also assumes that there are no columns containing only ``None`` s.
  614. Added in 0.4.0.
  615. Returns
  616. -------
  617. list of tuple of int
  618. The shapes of each row.
  619. """
  620. nb_rows = self.nb_rows
  621. columns = self.columns
  622. shapes = [None] * nb_rows
  623. found = np.zeros((nb_rows,), dtype=bool)
  624. for column in columns:
  625. if column.name == "images" and ia.is_np_array(column.value):
  626. shapes = [column.value.shape[1:]] * nb_rows
  627. else:
  628. for i, item in enumerate(column.value):
  629. if item is not None:
  630. shapes[i] = item.shape
  631. found[i] = True
  632. if np.all(found):
  633. return shapes
  634. return shapes
  635. def subselect_rows_by_indices(self, indices):
  636. """Reduce this batch to a subset of rows based on their row indices.
  637. Added in 0.4.0.
  638. Parameters
  639. ----------
  640. indices : iterable of int
  641. Row indices to select.
  642. Returns
  643. -------
  644. _BatchInAugmentation
  645. Batch containing only a subselection of rows.
  646. """
  647. kwargs = {"data": self.data}
  648. for augm_name in _AUGMENTABLE_NAMES:
  649. rows = getattr(self, augm_name)
  650. if rows is not None:
  651. if augm_name == "images" and ia.is_np_array(rows):
  652. rows = rows[indices] # pylint: disable=unsubscriptable-object
  653. else:
  654. rows = [rows[index] for index in indices]
  655. if len(rows) == 0:
  656. rows = None
  657. kwargs[augm_name] = rows
  658. return _BatchInAugmentation(**kwargs)
  659. def invert_subselect_rows_by_indices_(self, indices, batch_subselected):
  660. """Reverse the subselection of rows in-place.
  661. This is the inverse of
  662. :func:`_BatchInAugmentation.subselect_rows_by_indices`.
  663. This method has to be executed on the batch *before* subselection.
  664. Added in 0.4.0.
  665. Parameters
  666. ----------
  667. indices : iterable of int
  668. Row indices that were selected. (This is the input to
  669. batch_subselected : _BatchInAugmentation
  670. The batch after
  671. :func:`_BatchInAugmentation.subselect_rows_by_indices` was called.
  672. Returns
  673. -------
  674. _BatchInAugmentation
  675. The updated batch. (Modified in-place.)
  676. Examples
  677. --------
  678. >>> import numpy as np
  679. >>> from imgaug.augmentables.batches import _BatchInAugmentation
  680. >>> images = np.zeros((2, 10, 20, 3), dtype=np.uint8)
  681. >>> batch = _BatchInAugmentation(images=images)
  682. >>> batch_sub = batch.subselect_rows_by_indices([0])
  683. >>> batch_sub.images += 1
  684. >>> batch = batch.invert_subselect_rows_by_indices_([0], batch_sub)
  685. """
  686. for augm_name in _AUGMENTABLE_NAMES:
  687. column = getattr(self, augm_name)
  688. if column is not None:
  689. column_sub = getattr(batch_subselected, augm_name)
  690. if column_sub is None:
  691. # list of indices was empty, resulting in the columns
  692. # in the subselected batch being empty and replaced
  693. # by Nones. We can just re-use the columns before
  694. # subselection.
  695. pass
  696. elif augm_name == "images" and ia.is_np_array(column):
  697. # An array does not have to stay an array after
  698. # augmentation. The shapes and/or dtypes of rows may
  699. # change, turning the array into a list.
  700. if ia.is_np_array(column_sub):
  701. shapes = {column.shape[1:], column_sub.shape[1:]}
  702. dtypes = {column.dtype.name, column_sub.dtype.name}
  703. else:
  704. shapes = set(
  705. [column.shape[1:]]
  706. + [image.shape for image in column_sub])
  707. dtypes = set(
  708. [column.dtype.name]
  709. + [image.dtype.name for image in column_sub])
  710. if len(shapes) == 1 and len(dtypes) == 1:
  711. column[indices] = column_sub # pylint: disable=unsupported-assignment-operation
  712. else:
  713. self.images = list(column)
  714. for ith_index, index in enumerate(indices):
  715. self.images[index] = column_sub[ith_index]
  716. else:
  717. for ith_index, index in enumerate(indices):
  718. column[index] = column_sub[ith_index] # pylint: disable=unsupported-assignment-operation
  719. return self
  720. def propagation_hooks_ctx(self, augmenter, hooks, parents):
  721. """Start a context in which propagation hooks are applied.
  722. Added in 0.4.0.
  723. Parameters
  724. ----------
  725. augmenter : imgaug.augmenters.meta.Augmenter
  726. Augmenter to provide to the propagation hook function.
  727. hooks : imgaug.imgaug.HooksImages or imgaug.imgaug.HooksKeypoints
  728. The hooks that might contain a propagation hook function.
  729. parents : list of imgaug.augmenters.meta.Augmenter
  730. The list of parents to provide to the propagation hook function.
  731. Returns
  732. -------
  733. _BatchInAugmentationPropagationContext
  734. The progagation hook context.
  735. """
  736. return _BatchInAugmentationPropagationContext(
  737. self, augmenter=augmenter, hooks=hooks, parents=parents)
  738. def apply_propagation_hooks_(self, augmenter, hooks, parents):
  739. """Set columns in this batch to ``None`` based on a propagation hook.
  740. This method works in-place.
  741. Added in 0.4.0.
  742. Parameters
  743. ----------
  744. augmenter : imgaug.augmenters.meta.Augmenter
  745. Augmenter to provide to the propagation hook function.
  746. hooks : imgaug.imgaug.HooksImages or imgaug.imgaug.HooksKeypoints
  747. The hooks that might contain a propagation hook function.
  748. parents : list of imgaug.augmenters.meta.Augmenter
  749. The list of parents to provide to the propagation hook function.
  750. Returns
  751. -------
  752. list of tuple of str
  753. Information about which columns were set to ``None``.
  754. Each tuple contains
  755. ``(column attribute name, column value before setting it to None)``.
  756. This information is required when calling
  757. :func:`_BatchInAugmentation.invert_apply_propagation_hooks_`.
  758. """
  759. if hooks is None:
  760. return None
  761. noned_info = []
  762. for column in self.columns:
  763. is_prop = hooks.is_propagating(
  764. column.value, augmenter=augmenter, parents=parents,
  765. default=True)
  766. if not is_prop:
  767. setattr(self, column.attr_name, None)
  768. noned_info.append((column.attr_name, column.value))
  769. return noned_info
  770. def invert_apply_propagation_hooks_(self, noned_info):
  771. """Set columns from ``None`` back to their original values.
  772. This is the inverse of
  773. :func:`_BatchInAugmentation.apply_propagation_hooks_`.
  774. This method works in-place.
  775. Added in 0.4.0.
  776. Parameters
  777. ----------
  778. noned_info : list of tuple of str
  779. Information about which columns were set to ``None`` and their
  780. original values. This is the output of
  781. :func:`_BatchInAugmentation.apply_propagation_hooks_`.
  782. Returns
  783. -------
  784. _BatchInAugmentation
  785. The updated batch. (Modified in-place.)
  786. """
  787. for attr_name, value in noned_info:
  788. setattr(self, attr_name, value)
  789. return self
  790. def to_batch_in_augmentation(self):
  791. """Convert this batch to a :class:`_BatchInAugmentation` instance.
  792. This method simply returns the batch itself. It exists for consistency
  793. with the other batch classes.
  794. Added in 0.4.0.
  795. Returns
  796. -------
  797. imgaug.augmentables.batches._BatchInAugmentation
  798. The batch itself. (Not copied.)
  799. """
  800. return self
  801. def fill_from_batch_in_augmentation_(self, batch_in_augmentation):
  802. """Set the columns in this batch to the column values of another batch.
  803. This method works in-place.
  804. Added in 0.4.0.
  805. Parameters
  806. ----------
  807. batch_in_augmentation : _BatchInAugmentation
  808. Batch of which to use the column values.
  809. The values are *not* copied. Only their references are used.
  810. Returns
  811. -------
  812. _BatchInAugmentation
  813. The updated batch. (Modified in-place.)
  814. """
  815. if batch_in_augmentation is self:
  816. return self
  817. self.images = batch_in_augmentation.images
  818. self.heatmaps = batch_in_augmentation.heatmaps
  819. self.segmentation_maps = batch_in_augmentation.segmentation_maps
  820. self.keypoints = batch_in_augmentation.keypoints
  821. self.bounding_boxes = batch_in_augmentation.bounding_boxes
  822. self.polygons = batch_in_augmentation.polygons
  823. self.line_strings = batch_in_augmentation.line_strings
  824. return self
  825. def to_batch(self, batch_before_aug):
  826. """Convert this batch into a :class:`Batch` instance.
  827. Added in 0.4.0.
  828. Parameters
  829. ----------
  830. batch_before_aug : imgaug.augmentables.batches.Batch
  831. The batch before augmentation. It is required to set the input
  832. data of the :class:`Batch` instance, e.g. ``images_unaug``
  833. or ``data``.
  834. Returns
  835. -------
  836. imgaug.augmentables.batches.Batch
  837. Batch, with original unaugmented inputs from `batch_before_aug`
  838. and augmented outputs from this :class:`_BatchInAugmentation`
  839. instance.
  840. """
  841. batch = Batch(
  842. images=batch_before_aug.images_unaug,
  843. heatmaps=batch_before_aug.heatmaps_unaug,
  844. segmentation_maps=batch_before_aug.segmentation_maps_unaug,
  845. keypoints=batch_before_aug.keypoints_unaug,
  846. bounding_boxes=batch_before_aug.bounding_boxes_unaug,
  847. polygons=batch_before_aug.polygons_unaug,
  848. line_strings=batch_before_aug.line_strings_unaug,
  849. data=batch_before_aug.data
  850. )
  851. batch.images_aug = self.images
  852. batch.heatmaps_aug = self.heatmaps
  853. batch.segmentation_maps_aug = self.segmentation_maps
  854. batch.keypoints_aug = self.keypoints
  855. batch.bounding_boxes_aug = self.bounding_boxes
  856. batch.polygons_aug = self.polygons
  857. batch.line_strings_aug = self.line_strings
  858. return batch
  859. def deepcopy(self):
  860. """Copy this batch and all of its column values.
  861. Added in 0.4.0.
  862. Returns
  863. -------
  864. _BatchInAugmentation
  865. Deep copy of this batch.
  866. """
  867. batch = _BatchInAugmentation(data=utils.deepcopy_fast(self.data))
  868. for augm_name in _AUGMENTABLE_NAMES:
  869. value = getattr(self, augm_name)
  870. if value is not None:
  871. setattr(batch, augm_name, utils.copy_augmentables(value))
  872. return batch