segmaps.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. """Classes dealing with segmentation maps.
  2. E.g. masks, semantic or instance segmentation maps.
  3. """
  4. from __future__ import print_function, division, absolute_import
  5. import numpy as np
  6. import six.moves as sm
  7. from .. import imgaug as ia
  8. from ..augmenters import blend as blendlib
  9. from .base import IAugmentable
  10. @ia.deprecated(alt_func="SegmentationMapsOnImage",
  11. comment="(Note the plural 'Maps' instead of old 'Map'.)")
  12. def SegmentationMapOnImage(*args, **kwargs):
  13. """Object representing a segmentation map associated with an image."""
  14. # pylint: disable=invalid-name
  15. return SegmentationMapsOnImage(*args, **kwargs)
class SegmentationMapsOnImage(IAugmentable):
    """
    Object representing a segmentation map associated with an image.

    The array handed to the constructor is stored in the ``arr`` attribute
    as an ``int32`` array with three dimensions ``(H,W,C)``; a ``(H,W)``
    input gets a trailing channel axis of size ``1``. The shape of the
    associated image is stored in the ``shape`` attribute.

    Attributes
    ----------
    DEFAULT_SEGMENT_COLORS : list of tuple of int
        Standard RGB colors to use during drawing, ordered by class index.

    Parameters
    ----------
    arr : (H,W) ndarray or (H,W,C) ndarray
        Array representing the segmentation map(s). May have dtypes bool,
        int or uint.

    shape : tuple of int
        Shape of the image on which the segmentation map(s) is/are placed.
        **Not** the shape of the segmentation map(s) array, unless it is
        identical to the image shape (note the likely difference between the
        arrays in the number of channels).
        This is expected to be ``(H, W)`` or ``(H, W, C)`` with ``C`` usually
        being ``3``.
        If there is no corresponding image, use ``(H_arr, W_arr)`` instead,
        where ``H_arr`` is the height of the segmentation map(s) array
        (analogous ``W_arr``).

    nb_classes : None or int, optional
        Deprecated.
        The value is ignored; providing anything other than ``None`` only
        triggers a deprecation warning.

    """

    # TODO replace this by matplotlib colormap
    # The position of a color in this list is the class index it is used
    # for, i.e. class 0 is drawn black, class 1 red, and so on.
    DEFAULT_SEGMENT_COLORS = [
        (0, 0, 0),  # black
        (230, 25, 75),  # red
        (60, 180, 75),  # green
        (255, 225, 25),  # yellow
        (0, 130, 200),  # blue
        (245, 130, 48),  # orange
        (145, 30, 180),  # purple
        (70, 240, 240),  # cyan
        (240, 50, 230),  # magenta
        (210, 245, 60),  # lime
        (250, 190, 190),  # pink
        (0, 128, 128),  # teal
        (230, 190, 255),  # lavender
        (170, 110, 40),  # brown
        (255, 250, 200),  # beige
        (128, 0, 0),  # maroon
        (170, 255, 195),  # mint
        (128, 128, 0),  # olive
        (255, 215, 180),  # coral
        (0, 0, 128),  # navy
        (128, 128, 128),  # grey
        (255, 255, 255),  # white
        # -- darker variants of the colors above (black and white excluded)
        (115, 12, 37),  # dark red
        (30, 90, 37),  # dark green
        (127, 112, 12),  # dark yellow
        (0, 65, 100),  # dark blue
        (122, 65, 24),  # dark orange
        (72, 15, 90),  # dark purple
        (35, 120, 120),  # dark cyan
        (120, 25, 115),  # dark magenta
        (105, 122, 30),  # dark lime
        (125, 95, 95),  # dark pink
        (0, 64, 64),  # dark teal
        (115, 95, 127),  # dark lavender
        (85, 55, 20),  # dark brown
        (127, 125, 100),  # dark beige
        (64, 0, 0),  # dark maroon
        (85, 127, 97),  # dark mint
        (64, 64, 0),  # dark olive
        (127, 107, 90),  # dark coral
        (0, 0, 64),  # dark navy
        (64, 64, 64),  # dark grey
    ]
  87. def __init__(self, arr, shape, nb_classes=None):
  88. assert ia.is_np_array(arr), (
  89. "Expected to get numpy array, got %s." % (type(arr),))
  90. assert arr.ndim in [2, 3], (
  91. "Expected segmentation map array to be 2- or "
  92. "3-dimensional, got %d dimensions and shape %s." % (
  93. arr.ndim, arr.shape))
  94. assert isinstance(shape, tuple), (
  95. "Expected 'shape' to be a tuple denoting the shape of the image "
  96. "on which the segmentation map is placed. Got type %s instead." % (
  97. type(shape)))
  98. if arr.dtype.kind == "f":
  99. ia.warn_deprecated(
  100. "Got a float array as the segmentation map in "
  101. "SegmentationMapsOnImage. That is deprecated. Please provide "
  102. "instead a (H,W,[C]) array of dtype bool_, int or uint, where "
  103. "C denotes the segmentation map index."
  104. )
  105. if arr.ndim == 2:
  106. arr = (arr > 0.5)
  107. else: # arr.ndim == 3
  108. arr = np.argmax(arr, axis=2).astype(np.int32)
  109. if arr.dtype.name == "bool":
  110. self._input_was = (arr.dtype, arr.ndim)
  111. if arr.ndim == 2:
  112. arr = arr[..., np.newaxis]
  113. elif arr.dtype.kind in ["i", "u"]:
  114. assert np.min(arr.flat[0:100]) >= 0, (
  115. "Expected segmentation map array to only contain values >=0, "
  116. "got a minimum of %d." % (np.min(arr),))
  117. if arr.dtype.kind == "u":
  118. # allow only <=uint16 due to conversion to int32
  119. assert arr.dtype.itemsize <= 2, (
  120. "When using uint arrays as segmentation maps, only uint8 "
  121. "and uint16 are allowed. Got dtype %s." % (arr.dtype.name,)
  122. )
  123. elif arr.dtype.kind == "i":
  124. # allow only <=uint16 due to conversion to int32
  125. assert arr.dtype.itemsize <= 4, (
  126. "When using int arrays as segmentation maps, only int8, "
  127. "int16 and int32 are allowed. Got dtype %s." % (
  128. arr.dtype.name,)
  129. )
  130. self._input_was = (arr.dtype, arr.ndim)
  131. if arr.ndim == 2:
  132. arr = arr[..., np.newaxis]
  133. else:
  134. raise Exception((
  135. "Input was expected to be an array of dtype 'bool', 'int' "
  136. "or 'uint'. Got dtype '%s'.") % (arr.dtype.name,))
  137. if arr.dtype.name != "int32":
  138. arr = arr.astype(np.int32)
  139. self.arr = arr
  140. # don't allow arrays here as an alternative to tuples as input
  141. # as allowing arrays introduces risk to mix up 'arr' and 'shape' args
  142. self.shape = shape
  143. if nb_classes is not None:
  144. ia.warn_deprecated(
  145. "Providing nb_classes to SegmentationMapsOnImage is no longer "
  146. "necessary and hence deprecated. The argument is ignored "
  147. "and can be safely removed.")
  148. def get_arr(self):
  149. """Return the seg.map array, with original dtype and shape ndim.
  150. Here, "original" denotes the dtype and number of shape dimensions that
  151. was used when the :class:`SegmentationMapsOnImage` instance was
  152. created, i.e. upon the call of
  153. :func:`SegmentationMapsOnImage.__init__`.
  154. Internally, this class may use a different dtype and shape to simplify
  155. computations.
  156. .. note::
  157. The height and width may have changed compared to the original
  158. input due to e.g. pooling operations.
  159. Returns
  160. -------
  161. ndarray
  162. Segmentation map array.
  163. Same dtype and number of dimensions as was originally used when
  164. the :class:`SegmentationMapsOnImage` instance was created.
  165. """
  166. input_dtype, input_ndim = self._input_was
  167. # The internally used int32 has a wider value range than any other
  168. # input dtype, hence we can simply convert via astype() here.
  169. arr_input = self.arr.astype(input_dtype)
  170. if input_ndim == 2:
  171. assert arr_input.shape[2] == 1, (
  172. "Originally got a (H,W) segmentation map. Internal array "
  173. "should now have shape (H,W,1), but got %s. This might be "
  174. "an internal error." % (arr_input.shape,))
  175. return arr_input[:, :, 0]
  176. return arr_input
  177. @ia.deprecated(alt_func="SegmentationMapsOnImage.get_arr()")
  178. def get_arr_int(self, *args, **kwargs):
  179. """Return the seg.map array, with original dtype and shape ndim."""
  180. # pylint: disable=unused-argument
  181. return self.get_arr()
  182. def draw(self, size=None, colors=None):
  183. """
  184. Render the segmentation map as an RGB image.
  185. Parameters
  186. ----------
  187. size : None or float or iterable of int or iterable of float, optional
  188. Size of the rendered RGB image as ``(height, width)``.
  189. See :func:`~imgaug.imgaug.imresize_single_image` for details.
  190. If set to ``None``, no resizing is performed and the size of the
  191. segmentation map array is used.
  192. colors : None or list of tuple of int, optional
  193. Colors to use. One for each class to draw.
  194. If ``None``, then default colors will be used.
  195. Returns
  196. -------
  197. list of (H,W,3) ndarray
  198. Rendered segmentation map (dtype is ``uint8``).
  199. One per ``C`` in the original input array ``(H,W,C)``.
  200. """
  201. def _handle_sizeval(sizeval, arr_axis_size):
  202. if sizeval is None:
  203. return arr_axis_size
  204. if ia.is_single_float(sizeval):
  205. return max(int(arr_axis_size * sizeval), 1)
  206. if ia.is_single_integer(sizeval):
  207. return sizeval
  208. raise ValueError("Expected float or int, got %s." % (
  209. type(sizeval),))
  210. if size is None:
  211. size = [size, size]
  212. elif not ia.is_iterable(size):
  213. size = [size, size]
  214. height = _handle_sizeval(size[0], self.arr.shape[0])
  215. width = _handle_sizeval(size[1], self.arr.shape[1])
  216. image = np.zeros((height, width, 3), dtype=np.uint8)
  217. return self.draw_on_image(
  218. image,
  219. alpha=1.0,
  220. resize="segmentation_map",
  221. colors=colors,
  222. draw_background=True
  223. )
  224. def draw_on_image(self, image, alpha=0.75, resize="segmentation_map",
  225. colors=None, draw_background=False,
  226. background_class_id=0, background_threshold=None):
  227. """Draw the segmentation map as an overlay over an image.
  228. Parameters
  229. ----------
  230. image : (H,W,3) ndarray
  231. Image onto which to draw the segmentation map. Expected dtype
  232. is ``uint8``.
  233. alpha : float, optional
  234. Alpha/opacity value to use for the mixing of image and
  235. segmentation map. Larger values mean that the segmentation map
  236. will be more visible and the image less visible.
  237. resize : {'segmentation_map', 'image'}, optional
  238. In case of size differences between the image and segmentation
  239. map, either the image or the segmentation map can be resized.
  240. This parameter controls which of the two will be resized to the
  241. other's size.
  242. colors : None or list of tuple of int, optional
  243. Colors to use. One for each class to draw.
  244. If ``None``, then default colors will be used.
  245. draw_background : bool, optional
  246. If ``True``, the background will be drawn like any other class.
  247. If ``False``, the background will not be drawn, i.e. the respective
  248. background pixels will be identical with the image's RGB color at
  249. the corresponding spatial location and no color overlay will be
  250. applied.
  251. background_class_id : int, optional
  252. Class id to interpret as the background class.
  253. See `draw_background`.
  254. background_threshold : None, optional
  255. Deprecated.
  256. This parameter is ignored.
  257. Returns
  258. -------
  259. list of (H,W,3) ndarray
  260. Rendered overlays as ``uint8`` arrays.
  261. Always a **list** containing one RGB image per segmentation map
  262. array channel.
  263. """
  264. if background_threshold is not None:
  265. ia.warn_deprecated(
  266. "The argument `background_threshold` is deprecated and "
  267. "ignored. Please don't use it anymore.")
  268. assert image.ndim == 3, (
  269. "Expected to draw on 3-dimensional image, got image with %d "
  270. "dimensions." % (image.ndim,))
  271. assert image.shape[2] == 3, (
  272. "Expected to draw on RGB image, got image with %d channels "
  273. "instead." % (image.shape[2],))
  274. assert image.dtype.name == "uint8", (
  275. "Expected to get image with dtype uint8, got dtype %s." % (
  276. image.dtype.name,))
  277. assert 0 - 1e-8 <= alpha <= 1.0 + 1e-8, (
  278. "Expected 'alpha' to be in interval [0.0, 1.0], got %.4f." % (
  279. alpha,))
  280. assert resize in ["segmentation_map", "image"], (
  281. "Expected 'resize' to be \"segmentation_map\" or \"image\", got "
  282. "%s." % (resize,))
  283. colors = (
  284. colors
  285. if colors is not None
  286. else SegmentationMapsOnImage.DEFAULT_SEGMENT_COLORS
  287. )
  288. if resize == "image":
  289. image = ia.imresize_single_image(
  290. image, self.arr.shape[0:2], interpolation="cubic")
  291. segmaps_drawn = []
  292. arr_channelwise = np.dsplit(self.arr, self.arr.shape[2])
  293. for arr in arr_channelwise:
  294. arr = arr[:, :, 0]
  295. nb_classes = 1 + np.max(arr)
  296. segmap_drawn = np.zeros((arr.shape[0], arr.shape[1], 3),
  297. dtype=np.uint8)
  298. assert nb_classes <= len(colors), (
  299. "Can't draw all %d classes as it would exceed the maximum "
  300. "number of %d available colors." % (nb_classes, len(colors),))
  301. ids_in_map = np.unique(arr)
  302. for c, color in zip(sm.xrange(nb_classes), colors):
  303. if c in ids_in_map:
  304. class_mask = (arr == c)
  305. segmap_drawn[class_mask] = color
  306. segmap_drawn = ia.imresize_single_image(
  307. segmap_drawn, image.shape[0:2], interpolation="nearest")
  308. segmap_on_image = blendlib.blend_alpha(segmap_drawn, image, alpha)
  309. if draw_background:
  310. mix = segmap_on_image
  311. else:
  312. foreground_mask = ia.imresize_single_image(
  313. (arr != background_class_id),
  314. image.shape[0:2],
  315. interpolation="nearest")
  316. # without this, the merge below does nothing
  317. foreground_mask = np.atleast_3d(foreground_mask)
  318. mix = (
  319. (~foreground_mask) * image
  320. + foreground_mask * segmap_on_image
  321. )
  322. segmaps_drawn.append(mix)
  323. return segmaps_drawn
  324. def pad(self, top=0, right=0, bottom=0, left=0, mode="constant", cval=0):
  325. """Pad the segmentation maps at their top/right/bottom/left side.
  326. Parameters
  327. ----------
  328. top : int, optional
  329. Amount of pixels to add at the top side of the segmentation map.
  330. Must be ``0`` or greater.
  331. right : int, optional
  332. Amount of pixels to add at the right side of the segmentation map.
  333. Must be ``0`` or greater.
  334. bottom : int, optional
  335. Amount of pixels to add at the bottom side of the segmentation map.
  336. Must be ``0`` or greater.
  337. left : int, optional
  338. Amount of pixels to add at the left side of the segmentation map.
  339. Must be ``0`` or greater.
  340. mode : str, optional
  341. Padding mode to use. See :func:`~imgaug.imgaug.pad` for details.
  342. cval : number, optional
  343. Value to use for padding if `mode` is ``constant``.
  344. See :func:`~imgaug.imgaug.pad` for details.
  345. Returns
  346. -------
  347. imgaug.augmentables.segmaps.SegmentationMapsOnImage
  348. Padded segmentation map with height ``H'=H+top+bottom`` and
  349. width ``W'=W+left+right``.
  350. """
  351. from ..augmenters import size as iasize
  352. arr_padded = iasize.pad(self.arr, top=top, right=right, bottom=bottom,
  353. left=left, mode=mode, cval=cval)
  354. return self.deepcopy(arr=arr_padded)
  355. def pad_to_aspect_ratio(self, aspect_ratio, mode="constant", cval=0,
  356. return_pad_amounts=False):
  357. """Pad the segmentation maps until they match a target aspect ratio.
  358. Depending on which dimension is smaller (height or width), only the
  359. corresponding sides (left/right or top/bottom) will be padded. In
  360. each case, both of the sides will be padded equally.
  361. Parameters
  362. ----------
  363. aspect_ratio : float
  364. Target aspect ratio, given as width/height. E.g. ``2.0`` denotes
  365. the image having twice as much width as height.
  366. mode : str, optional
  367. Padding mode to use.
  368. See :func:`~imgaug.imgaug.pad` for details.
  369. cval : number, optional
  370. Value to use for padding if `mode` is ``constant``.
  371. See :func:`~imgaug.imgaug.pad` for details.
  372. return_pad_amounts : bool, optional
  373. If ``False``, then only the padded instance will be returned.
  374. If ``True``, a tuple with two entries will be returned, where
  375. the first entry is the padded instance and the second entry are
  376. the amounts by which each array side was padded. These amounts are
  377. again a tuple of the form ``(top, right, bottom, left)``, with
  378. each value being an integer.
  379. Returns
  380. -------
  381. imgaug.augmentables.segmaps.SegmentationMapsOnImage
  382. Padded segmentation map as :class:`SegmentationMapsOnImage`
  383. instance.
  384. tuple of int
  385. Amounts by which the instance's array was padded on each side,
  386. given as a tuple ``(top, right, bottom, left)``.
  387. This tuple is only returned if `return_pad_amounts` was set to
  388. ``True``.
  389. """
  390. from ..augmenters import size as iasize
  391. arr_padded, pad_amounts = iasize.pad_to_aspect_ratio(
  392. self.arr,
  393. aspect_ratio=aspect_ratio,
  394. mode=mode,
  395. cval=cval,
  396. return_pad_amounts=True)
  397. segmap = self.deepcopy(arr=arr_padded)
  398. if return_pad_amounts:
  399. return segmap, pad_amounts
  400. return segmap
  401. @ia.deprecated(alt_func="SegmentationMapsOnImage.resize()",
  402. comment="resize() has the exactly same interface.")
  403. def scale(self, *args, **kwargs):
  404. """Resize the seg.map(s) array given a target size and interpolation."""
  405. return self.resize(*args, **kwargs)
  406. def resize(self, sizes, interpolation="nearest"):
  407. """Resize the seg.map(s) array given a target size and interpolation.
  408. Parameters
  409. ----------
  410. sizes : float or iterable of int or iterable of float
  411. New size of the array in ``(height, width)``.
  412. See :func:`~imgaug.imgaug.imresize_single_image` for details.
  413. interpolation : None or str or int, optional
  414. The interpolation to use during resize.
  415. Nearest neighbour interpolation (``"nearest"``) is almost always
  416. the best choice.
  417. See :func:`~imgaug.imgaug.imresize_single_image` for details.
  418. Returns
  419. -------
  420. imgaug.augmentables.segmaps.SegmentationMapsOnImage
  421. Resized segmentation map object.
  422. """
  423. arr_resized = ia.imresize_single_image(self.arr, sizes,
  424. interpolation=interpolation)
  425. return self.deepcopy(arr_resized)
  426. # TODO how best to handle changes to _input_was due to changed 'arr'?
  427. def copy(self, arr=None, shape=None):
  428. """Create a shallow copy of the segmentation map object.
  429. Parameters
  430. ----------
  431. arr : None or (H,W) ndarray or (H,W,C) ndarray, optional
  432. Optionally the `arr` attribute to use for the new segmentation map
  433. instance. Will be copied from the old instance if not provided.
  434. See
  435. :func:`~imgaug.augmentables.segmaps.SegmentationMapsOnImage.__init__`
  436. for details.
  437. shape : None or tuple of int, optional
  438. Optionally the shape attribute to use for the the new segmentation
  439. map instance. Will be copied from the old instance if not provided.
  440. See
  441. :func:`~imgaug.augmentables.segmaps.SegmentationMapsOnImage.__init__`
  442. for details.
  443. Returns
  444. -------
  445. imgaug.augmentables.segmaps.SegmentationMapsOnImage
  446. Shallow copy.
  447. """
  448. # pylint: disable=protected-access
  449. segmap = SegmentationMapsOnImage(
  450. self.arr if arr is None else arr,
  451. shape=self.shape if shape is None else shape)
  452. segmap._input_was = self._input_was
  453. return segmap
  454. def deepcopy(self, arr=None, shape=None):
  455. """Create a deep copy of the segmentation map object.
  456. Parameters
  457. ----------
  458. arr : None or (H,W) ndarray or (H,W,C) ndarray, optional
  459. Optionally the `arr` attribute to use for the new segmentation map
  460. instance. Will be copied from the old instance if not provided.
  461. See
  462. :func:`~imgaug.augmentables.segmaps.SegmentationMapsOnImage.__init__`
  463. for details.
  464. shape : None or tuple of int, optional
  465. Optionally the shape attribute to use for the the new segmentation
  466. map instance. Will be copied from the old instance if not provided.
  467. See
  468. :func:`~imgaug.augmentables.segmaps.SegmentationMapsOnImage.__init__`
  469. for details.
  470. Returns
  471. -------
  472. imgaug.augmentables.segmaps.SegmentationMapsOnImage
  473. Deep copy.
  474. """
  475. # pylint: disable=protected-access
  476. segmap = SegmentationMapsOnImage(
  477. np.copy(self.arr if arr is None else arr),
  478. shape=self.shape if shape is None else shape)
  479. segmap._input_was = self._input_was
  480. return segmap