manual_segmentation.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. from functools import reduce
  2. import numpy as np
  3. from ..draw import polygon
  4. from .._shared.version_requirements import require
  # Mouse-button codes compared against ``event.button`` in the click
  # handlers: left click adds a polygon vertex, right click confirms it.
  LEFT_CLICK, RIGHT_CLICK = 1, 3
  def _mask_from_vertices(vertices, shape, label):
      """Rasterize one polygon into an integer label mask.

      Parameters
      ----------
      vertices : sequence of (x, y) pairs
          Polygon vertices; x is used as the column coordinate and y as the
          row coordinate.
      shape : tuple of int
          Shape of the output mask; also passed to ``polygon`` to bound the
          rasterized pixel coordinates.
      label : int
          Value written into every pixel covered by the polygon.

      Returns
      -------
      mask : ndarray of int
          Array of `shape` equal to `label` inside the polygon and 0 elsewhere.
      """
      rows = [row for _col, row in vertices]
      cols = [col for col, _row in vertices]
      mask = np.zeros(shape, dtype=int)
      inside_rows, inside_cols = polygon(rows, cols, shape)
      mask[inside_rows, inside_cols] = label
      return mask
  @require("matplotlib", ">=3.3")
  def _draw_polygon(ax, vertices, alpha=0.4):
      """Add a filled, closed polygon to `ax` and redraw via ``plt.draw()``.

      Parameters
      ----------
      ax : matplotlib Axes
          Axes the polygon is added to.
      vertices : sequence of (x, y) pairs
          Polygon vertices in data coordinates.
      alpha : float, optional
          Transparency applied to the polygon collection.

      Returns
      -------
      collection : PatchCollection
          The collection returned by ``ax.add_collection``; callers keep it
          so the polygon can later be removed from the plot.
      """
      import matplotlib.pyplot as plt
      from matplotlib.collections import PatchCollection
      from matplotlib.patches import Polygon

      patch = Polygon(vertices, closed=True)
      collection = ax.add_collection(
          PatchCollection([patch], match_original=True, alpha=alpha)
      )
      plt.draw()
      return collection
  @require("matplotlib", ">=3.3")
  def manual_polygon_segmentation(image, alpha=0.4, return_all=False):
      """Return a label image based on polygon selections made with the mouse.

      Parameters
      ----------
      image : (M, N[, 3]) array
          Grayscale or RGB image.
      alpha : float, optional
          Transparency value for polygons drawn over the image.
      return_all : bool, optional
          If True, an array containing each separate polygon drawn is returned.
          (The polygons may overlap.) If False (default), latter polygons
          "overwrite" earlier ones where they overlap.

      Returns
      -------
      labels : array of int, shape ([Q, ]M, N)
          The segmented regions. Region ``i`` (counted from 1 in the order the
          regions were confirmed) is labelled ``i``; unselected pixels are 0.
          If `return_all` is True, the leading dimension of the array
          corresponds to the number of regions that the user drew (``Q`` is 0
          if no region was drawn).

      Raises
      ------
      ValueError
          If ``image.ndim`` is neither 2 nor 3.

      Notes
      -----
      Use left click to select the vertices of the polygon
      and right click to confirm the selection once all vertices are selected.
      The undo button removes the most recently confirmed polygon.

      Examples
      --------
      >>> from skimage import data, future
      >>> import matplotlib.pyplot as plt  # doctest: +SKIP
      >>> camera = data.camera()
      >>> mask = future.manual_polygon_segmentation(camera)  # doctest: +SKIP
      >>> fig, ax = plt.subplots()  # doctest: +SKIP
      >>> ax.imshow(mask)  # doctest: +SKIP
      >>> plt.show()  # doctest: +SKIP
      """
      # Import the widgets submodule explicitly rather than relying on it
      # having been loaded as a side effect of importing pyplot.
      import matplotlib.widgets
      import matplotlib.pyplot as plt

      # Confirmed polygons: their vertices and the artists drawn for them.
      list_of_vertex_lists = []
      polygons_drawn = []

      # Polygon under construction: its vertices and its preview artist.
      temp_list = []
      preview_polygon_drawn = []

      if image.ndim not in (2, 3):
          raise ValueError('Only 2D grayscale or RGB images are supported.')

      fig, ax = plt.subplots()
      fig.subplots_adjust(bottom=0.2)
      ax.imshow(image, cmap="gray")
      ax.set_axis_off()

      def _undo(*args, **kwargs):
          # Discard the most recently confirmed polygon, if there is one.
          if list_of_vertex_lists:
              list_of_vertex_lists.pop()
              # Remove last polygon from list of polygons...
              last_poly = polygons_drawn.pop()
              # ... then from the plot
              last_poly.remove()
              fig.canvas.draw_idle()

      undo_pos = fig.add_axes([0.85, 0.05, 0.075, 0.075])
      # Bound to a local name so the button stays referenced (and responsive)
      # for as long as the blocking ``plt.show`` below is running.
      undo_button = matplotlib.widgets.Button(undo_pos, '\u27f2')
      undo_button.on_clicked(_undo)

      def _extend_polygon(event):
          # Do not record click events outside axis or in undo button
          if event.inaxes is None or event.inaxes is undo_pos:
              return
          # Do not record click events when toolbar is active
          if ax.get_navigate_mode():
              return

          if event.button == LEFT_CLICK:  # Select vertex
              temp_list.append([event.xdata, event.ydata])
              # Remove previously drawn preview polygon if any.
              if preview_polygon_drawn:
                  poly = preview_polygon_drawn.pop()
                  poly.remove()

              # Preview polygon with selected vertices, drawn more
              # transparently than a confirmed polygon.
              polygon = _draw_polygon(ax, temp_list, alpha=(alpha / 1.4))
              preview_polygon_drawn.append(polygon)

          elif event.button == RIGHT_CLICK:  # Confirm the selection
              if not temp_list:
                  return

              # Store the vertices of the polygon as shown in preview.
              # Redraw polygon and store it in polygons_drawn so that
              # `_undo` works correctly.
              list_of_vertex_lists.append(temp_list[:])
              polygon_object = _draw_polygon(ax, temp_list, alpha=alpha)
              polygons_drawn.append(polygon_object)

              # Empty the temporary variables.
              preview_poly = preview_polygon_drawn.pop()
              preview_poly.remove()
              del temp_list[:]

              plt.draw()

      fig.canvas.mpl_connect('button_press_event', _extend_polygon)

      plt.show(block=True)

      shape = image.shape[:2]
      labels = (
          _mask_from_vertices(vertices, shape, i)
          for i, vertices in enumerate(list_of_vertex_lists, start=1)
      )
      if return_all:
          # ``np.stack`` needs a sequence (generators are deprecated since
          # NumPy 1.16 and rejected by newer releases) and cannot stack zero
          # arrays, so materialize the masks and handle "nothing drawn".
          masks = list(labels)
          if not masks:
              return np.zeros((0,) + shape, dtype=int)
          return np.stack(masks)
      # Labels grow with drawing order, so the element-wise maximum makes later
      # polygons win where they overlap. Start from a writable zero array so
      # the result is writable even when no polygon was drawn.
      return reduce(np.maximum, labels, np.zeros(shape, dtype=int))
  @require("matplotlib", ">=3.3")
  def manual_lasso_segmentation(image, alpha=0.4, return_all=False):
      """Return a label image based on freeform selections made with the mouse.

      Parameters
      ----------
      image : (M, N[, 3]) array
          Grayscale or RGB image.
      alpha : float, optional
          Transparency value for polygons drawn over the image.
      return_all : bool, optional
          If True, an array containing each separate polygon drawn is returned.
          (The polygons may overlap.) If False (default), latter polygons
          "overwrite" earlier ones where they overlap.

      Returns
      -------
      labels : array of int, shape ([Q, ]M, N)
          The segmented regions. Region ``i`` (counted from 1 in the order the
          regions were drawn) is labelled ``i``; unselected pixels are 0.
          If `return_all` is True, the leading dimension of the array
          corresponds to the number of regions that the user drew (``Q`` is 0
          if no region was drawn).

      Raises
      ------
      ValueError
          If ``image.ndim`` is neither 2 nor 3.

      Notes
      -----
      Press and hold the left mouse button to draw around each object.
      Selections with fewer than 3 vertices are ignored. The undo button
      removes the most recently drawn region.

      Examples
      --------
      >>> from skimage import data, future
      >>> import matplotlib.pyplot as plt  # doctest: +SKIP
      >>> camera = data.camera()
      >>> mask = future.manual_lasso_segmentation(camera)  # doctest: +SKIP
      >>> fig, ax = plt.subplots()  # doctest: +SKIP
      >>> ax.imshow(mask)  # doctest: +SKIP
      >>> plt.show()  # doctest: +SKIP
      """
      # Import the widgets submodule explicitly rather than relying on it
      # having been loaded as a side effect of importing pyplot.
      import matplotlib.widgets
      import matplotlib.pyplot as plt

      # Confirmed regions: their vertices and the artists drawn for them.
      list_of_vertex_lists = []
      polygons_drawn = []

      if image.ndim not in (2, 3):
          raise ValueError('Only 2D grayscale or RGB images are supported.')

      fig, ax = plt.subplots()
      fig.subplots_adjust(bottom=0.2)
      ax.imshow(image, cmap="gray")
      ax.set_axis_off()

      def _undo(*args, **kwargs):
          # Discard the most recently drawn region, if there is one.
          if list_of_vertex_lists:
              list_of_vertex_lists.pop()
              # Remove last polygon from list of polygons...
              last_poly = polygons_drawn.pop()
              # ... then from the plot
              last_poly.remove()
              fig.canvas.draw_idle()

      undo_pos = fig.add_axes([0.85, 0.05, 0.075, 0.075])
      # Bound to a local name so the button stays referenced (and responsive)
      # for as long as the blocking ``plt.show`` below is running.
      undo_button = matplotlib.widgets.Button(undo_pos, '\u27f2')
      undo_button.on_clicked(_undo)

      def _on_lasso_selection(vertices):
          # Fewer than 3 points cannot enclose an area; ignore the stroke.
          if len(vertices) < 3:
              return
          list_of_vertex_lists.append(vertices)
          polygon_object = _draw_polygon(ax, vertices, alpha=alpha)
          polygons_drawn.append(polygon_object)
          plt.draw()

      # Matplotlib widgets must be kept referenced to stay responsive: the
      # canvas only holds weak references to their bound-method callbacks,
      # so an unreferenced selector can be garbage-collected mid-session.
      lasso = matplotlib.widgets.LassoSelector(ax, _on_lasso_selection)

      plt.show(block=True)
      del lasso  # The interactive session is over; release the selector.

      shape = image.shape[:2]
      labels = (
          _mask_from_vertices(vertices, shape, i)
          for i, vertices in enumerate(list_of_vertex_lists, start=1)
      )
      if return_all:
          # ``np.stack`` needs a sequence (generators are deprecated since
          # NumPy 1.16 and rejected by newer releases) and cannot stack zero
          # arrays, so materialize the masks and handle "nothing drawn".
          masks = list(labels)
          if not masks:
              return np.zeros((0,) + shape, dtype=int)
          return np.stack(masks)
      # Labels grow with drawing order, so the element-wise maximum makes later
      # regions win where they overlap. Start from a writable zero array so
      # the result is writable even when no region was drawn.
      return reduce(np.maximum, labels, np.zeros(shape, dtype=int))