_rag.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. import networkx as nx
  2. import numpy as np
  3. from scipy import ndimage as ndi
  4. from scipy import sparse
  5. import math
  6. from .. import measure, segmentation, util, color
  7. from .._shared.version_requirements import require
# Maps doctest targets to the optional packages their examples need. The
# ``show_rag`` examples import matplotlib; presumably the doctest plugin reads
# this to skip them when matplotlib is missing -- confirm against test config.
__doctest_requires__ = {("show_rag",): ["matplotlib"]}
  9. def _edge_generator_from_csr(csr_array):
  10. """Yield weighted edge triples for use by NetworkX from a CSR matrix.
  11. This function is a straight rewrite of
  12. `networkx.convert_matrix._csr_gen_triples`. Since that is a private
  13. function, it is safer to include our own here.
  14. Parameters
  15. ----------
  16. csr_array : scipy.sparse.csr_array
  17. The input matrix. An edge (i, j, w) will be yielded if there is a
  18. data value for coordinates (i, j) in the matrix, even if that value
  19. is 0.
  20. Yields
  21. ------
  22. i, j, w : (int, int, float) tuples
  23. Each value `w` in the matrix along with its coordinates (i, j).
  24. Examples
  25. --------
  26. >>> dense = np.eye(2, dtype=float)
  27. >>> csr = sparse.csr_array(dense)
  28. >>> edges = _edge_generator_from_csr(csr)
  29. >>> list(edges)
  30. [(0, 0, 1.0), (1, 1, 1.0)]
  31. """
  32. nrows = csr_array.shape[0]
  33. values = csr_array.data
  34. indptr = csr_array.indptr
  35. col_indices = csr_array.indices
  36. for i in range(nrows):
  37. for j in range(indptr[i], indptr[i + 1]):
  38. yield i, col_indices[j], values[j]
  39. def min_weight(graph, src, dst, n):
  40. """Callback to handle merging nodes by choosing minimum weight.
  41. Returns a dictionary with `"weight"` set as either the weight between
  42. (`src`, `n`) or (`dst`, `n`) in `graph` or the minimum of the two when
  43. both exist.
  44. Parameters
  45. ----------
  46. graph : RAG
  47. The graph under consideration.
  48. src, dst : int
  49. The verices in `graph` to be merged.
  50. n : int
  51. A neighbor of `src` or `dst` or both.
  52. Returns
  53. -------
  54. data : dict
  55. A dict with the `"weight"` attribute set the weight between
  56. (`src`, `n`) or (`dst`, `n`) in `graph` or the minimum of the two when
  57. both exist.
  58. """
  59. # cover the cases where n only has edge to either `src` or `dst`
  60. default = {'weight': np.inf}
  61. w1 = graph[n].get(src, default)['weight']
  62. w2 = graph[n].get(dst, default)['weight']
  63. return {'weight': min(w1, w2)}
  64. def _add_edge_filter(values, graph):
  65. """Create edge in `graph` between central element of `values` and the rest.
  66. Add an edge between the middle element in `values` and
  67. all other elements of `values` into `graph`. ``values[len(values) // 2]``
  68. is expected to be the central value of the footprint used.
  69. Parameters
  70. ----------
  71. values : array
  72. The array to process.
  73. graph : RAG
  74. The graph to add edges in.
  75. Returns
  76. -------
  77. 0 : float
  78. Always returns 0. The return value is required so that `generic_filter`
  79. can put it in the output array, but it is ignored by this filter.
  80. """
  81. values = values.astype(int)
  82. center = values[len(values) // 2]
  83. for value in values:
  84. if value != center and not graph.has_edge(center, value):
  85. graph.add_edge(center, value)
  86. return 0.0
  87. class RAG(nx.Graph):
  88. """The Region Adjacency Graph (RAG) of an image, subclasses :obj:`networkx.Graph`.
  89. Parameters
  90. ----------
  91. label_image : array of int
  92. An initial segmentation, with each region labeled as a different
  93. integer. Every unique value in ``label_image`` will correspond to
  94. a node in the graph.
  95. connectivity : int in {1, ..., ``label_image.ndim``}, optional
  96. The connectivity between pixels in ``label_image``. For a 2D image,
  97. a connectivity of 1 corresponds to immediate neighbors up, down,
  98. left, and right, while a connectivity of 2 also includes diagonal
  99. neighbors. See :func:`scipy.ndimage.generate_binary_structure`.
  100. data : :obj:`networkx.Graph` specification, optional
  101. Initial or additional edges to pass to :obj:`networkx.Graph`
  102. constructor. Valid edge specifications include edge list (list of tuples),
  103. NumPy arrays, and SciPy sparse matrices.
  104. **attr : keyword arguments, optional
  105. Additional attributes to add to the graph.
  106. """
  107. def __init__(self, label_image=None, connectivity=1, data=None, **attr):
  108. super().__init__(data, **attr)
  109. if self.number_of_nodes() == 0:
  110. self.max_id = 0
  111. else:
  112. self.max_id = max(self.nodes())
  113. if label_image is not None:
  114. fp = ndi.generate_binary_structure(label_image.ndim, connectivity)
  115. # In the next ``ndi.generic_filter`` function, the kwarg
  116. # ``output`` is used to provide a strided array with a single
  117. # 64-bit floating point number, to which the function repeatedly
  118. # writes. This is done because even if we don't care about the
  119. # output, without this, a float array of the same shape as the
  120. # input image will be created and that could be expensive in
  121. # memory consumption.
  122. output = np.broadcast_to(1.0, label_image.shape)
  123. output.setflags(write=True)
  124. ndi.generic_filter(
  125. label_image,
  126. function=_add_edge_filter,
  127. footprint=fp,
  128. mode='nearest',
  129. output=output,
  130. extra_arguments=(self,),
  131. )
  132. def merge_nodes(
  133. self,
  134. src,
  135. dst,
  136. weight_func=min_weight,
  137. in_place=True,
  138. extra_arguments=None,
  139. extra_keywords=None,
  140. ):
  141. """Merge node `src` and `dst`.
  142. The new combined node is adjacent to all the neighbors of `src`
  143. and `dst`. `weight_func` is called to decide the weight of edges
  144. incident on the new node.
  145. Parameters
  146. ----------
  147. src, dst : int
  148. Nodes to be merged.
  149. weight_func : callable, optional
  150. Function to decide the attributes of edges incident on the new
  151. node. For each neighbor `n` for `src` and `dst`, `weight_func` will
  152. be called as follows: `weight_func(src, dst, n, *extra_arguments,
  153. **extra_keywords)`. `src`, `dst` and `n` are IDs of vertices in the
  154. RAG object which is in turn a subclass of :obj:`networkx.Graph`. It is
  155. expected to return a dict of attributes of the resulting edge.
  156. in_place : bool, optional
  157. If set to `True`, the merged node has the id `dst`, else merged
  158. node has a new id which is returned.
  159. extra_arguments : sequence, optional
  160. The sequence of extra positional arguments passed to
  161. `weight_func`.
  162. extra_keywords : dictionary, optional
  163. The dict of keyword arguments passed to the `weight_func`.
  164. Returns
  165. -------
  166. id : int
  167. The id of the new node.
  168. Notes
  169. -----
  170. If `in_place` is `False` the resulting node has a new id, rather than
  171. `dst`.
  172. """
  173. if extra_arguments is None:
  174. extra_arguments = []
  175. if extra_keywords is None:
  176. extra_keywords = {}
  177. src_nbrs = set(self.neighbors(src))
  178. dst_nbrs = set(self.neighbors(dst))
  179. neighbors = (src_nbrs | dst_nbrs) - {src, dst}
  180. if in_place:
  181. new = dst
  182. else:
  183. new = self.next_id()
  184. self.add_node(new)
  185. for neighbor in neighbors:
  186. data = weight_func(
  187. self, src, dst, neighbor, *extra_arguments, **extra_keywords
  188. )
  189. self.add_edge(neighbor, new, attr_dict=data)
  190. self.nodes[new]['labels'] = (
  191. self.nodes[src]['labels'] + self.nodes[dst]['labels']
  192. )
  193. self.remove_node(src)
  194. if not in_place:
  195. self.remove_node(dst)
  196. return new
  197. def add_node(self, n, attr_dict=None, **attr):
  198. """Add node `n` while updating the maximum node id.
  199. .. seealso:: :obj:`networkx.Graph.add_node`."""
  200. if attr_dict is None: # compatibility with old networkx
  201. attr_dict = attr
  202. else:
  203. attr_dict.update(attr)
  204. super().add_node(n, **attr_dict)
  205. self.max_id = max(n, self.max_id)
  206. def add_edge(self, u, v, attr_dict=None, **attr):
  207. """Add an edge between `u` and `v` while updating max node id.
  208. .. seealso:: :obj:`networkx.Graph.add_edge`."""
  209. if attr_dict is None: # compatibility with old networkx
  210. attr_dict = attr
  211. else:
  212. attr_dict.update(attr)
  213. super().add_edge(u, v, **attr_dict)
  214. self.max_id = max(u, v, self.max_id)
  215. def copy(self):
  216. """Copy the graph with its max node id.
  217. .. seealso:: :obj:`networkx.Graph.copy`."""
  218. g = super().copy()
  219. g.max_id = self.max_id
  220. return g
  221. def fresh_copy(self):
  222. """Return a fresh copy graph with the same data structure.
  223. A fresh copy has no nodes, edges or graph attributes. It is
  224. the same data structure as the current graph. This method is
  225. typically used to create an empty version of the graph.
  226. This is required when subclassing Graph with networkx v2 and
  227. does not cause problems for v1. Here is more detail from
  228. the network migrating from 1.x to 2.x document::
  229. With the new GraphViews (SubGraph, ReversedGraph, etc)
  230. you can't assume that ``G.__class__()`` will create a new
  231. instance of the same graph type as ``G``. In fact, the
  232. call signature for ``__class__`` differs depending on
  233. whether ``G`` is a view or a base class. For v2.x you
  234. should use ``G.fresh_copy()`` to create a null graph of
  235. the correct type---ready to fill with nodes and edges.
  236. """
  237. return RAG()
  238. def next_id(self):
  239. """Returns the `id` for the new node to be inserted.
  240. The current implementation returns one more than the maximum `id`.
  241. Returns
  242. -------
  243. id : int
  244. The `id` of the new node to be inserted.
  245. """
  246. return self.max_id + 1
  247. def _add_node_silent(self, n):
  248. """Add node `n` without updating the maximum node id.
  249. This is a convenience method used internally.
  250. .. seealso:: :obj:`networkx.Graph.add_node`."""
  251. super().add_node(n)
  252. def rag_mean_color(image, labels, connectivity=2, mode='distance', sigma=255.0):
  253. """Compute the Region Adjacency Graph using mean colors.
  254. Given an image and its initial segmentation, this method constructs the
  255. corresponding Region Adjacency Graph (RAG). Each node in the RAG
  256. represents a set of pixels within `image` with the same label in `labels`.
  257. The weight between two adjacent regions represents how similar or
  258. dissimilar two regions are depending on the `mode` parameter.
  259. Parameters
  260. ----------
  261. image : ndarray, shape(M, N[, ..., P], 3)
  262. Input image.
  263. labels : ndarray, shape(M, N[, ..., P])
  264. The labelled image. This should have one dimension less than
  265. `image`. If `image` has dimensions `(M, N, 3)` `labels` should have
  266. dimensions `(M, N)`.
  267. connectivity : int, optional
  268. Pixels with a squared distance less than `connectivity` from each other
  269. are considered adjacent. It can range from 1 to `labels.ndim`. Its
  270. behavior is the same as `connectivity` parameter in
  271. ``scipy.ndimage.generate_binary_structure``.
  272. mode : {'distance', 'similarity'}, optional
  273. The strategy to assign edge weights.
  274. 'distance' : The weight between two adjacent regions is the
  275. :math:`|c_1 - c_2|`, where :math:`c_1` and :math:`c_2` are the mean
  276. colors of the two regions. It represents the Euclidean distance in
  277. their average color.
  278. 'similarity' : The weight between two adjacent is
  279. :math:`e^{-d^2/sigma}` where :math:`d=|c_1 - c_2|`, where
  280. :math:`c_1` and :math:`c_2` are the mean colors of the two regions.
  281. It represents how similar two regions are.
  282. sigma : float, optional
  283. Used for computation when `mode` is "similarity". It governs how
  284. close to each other two colors should be, for their corresponding edge
  285. weight to be significant. A very large value of `sigma` could make
  286. any two colors behave as though they were similar.
  287. Returns
  288. -------
  289. out : RAG
  290. The region adjacency graph.
  291. Examples
  292. --------
  293. >>> from skimage import data, segmentation, graph
  294. >>> img = data.astronaut()
  295. >>> labels = segmentation.slic(img)
  296. >>> rag = graph.rag_mean_color(img, labels)
  297. References
  298. ----------
  299. .. [1] Alain Tremeau and Philippe Colantoni
  300. "Regions Adjacency Graph Applied To Color Image Segmentation"
  301. :DOI:`10.1109/83.841950`
  302. """
  303. graph = RAG(labels, connectivity=connectivity)
  304. for n in graph:
  305. graph.nodes[n].update(
  306. {
  307. 'labels': [n],
  308. 'pixel count': 0,
  309. 'total color': np.array([0, 0, 0], dtype=np.float64),
  310. }
  311. )
  312. for index in np.ndindex(labels.shape):
  313. current = labels[index]
  314. graph.nodes[current]['pixel count'] += 1
  315. graph.nodes[current]['total color'] += image[index]
  316. for n in graph:
  317. graph.nodes[n]['mean color'] = (
  318. graph.nodes[n]['total color'] / graph.nodes[n]['pixel count']
  319. )
  320. for x, y, d in graph.edges(data=True):
  321. diff = graph.nodes[x]['mean color'] - graph.nodes[y]['mean color']
  322. diff = np.linalg.norm(diff)
  323. if mode == 'similarity':
  324. d['weight'] = math.e ** (-(diff**2) / sigma)
  325. elif mode == 'distance':
  326. d['weight'] = diff
  327. else:
  328. raise ValueError(f"The mode '{mode}' is not recognised")
  329. return graph
  330. def rag_boundary(labels, edge_map, connectivity=2):
  331. """Comouter RAG based on region boundaries
  332. Given an image's initial segmentation and its edge map this method
  333. constructs the corresponding Region Adjacency Graph (RAG). Each node in the
  334. RAG represents a set of pixels within the image with the same label in
  335. `labels`. The weight between two adjacent regions is the average value
  336. in `edge_map` along their boundary.
  337. labels : ndarray
  338. The labelled image.
  339. edge_map : ndarray
  340. This should have the same shape as that of `labels`. For all pixels
  341. along the boundary between 2 adjacent regions, the average value of the
  342. corresponding pixels in `edge_map` is the edge weight between them.
  343. connectivity : int, optional
  344. Pixels with a squared distance less than `connectivity` from each other
  345. are considered adjacent. It can range from 1 to `labels.ndim`. Its
  346. behavior is the same as `connectivity` parameter in
  347. `scipy.ndimage.generate_binary_structure`.
  348. Examples
  349. --------
  350. >>> from skimage import data, segmentation, filters, color, graph
  351. >>> img = data.chelsea()
  352. >>> labels = segmentation.slic(img)
  353. >>> edge_map = filters.sobel(color.rgb2gray(img))
  354. >>> rag = graph.rag_boundary(labels, edge_map)
  355. """
  356. conn = ndi.generate_binary_structure(labels.ndim, connectivity)
  357. eroded = ndi.grey_erosion(labels, footprint=conn)
  358. dilated = ndi.grey_dilation(labels, footprint=conn)
  359. boundaries0 = eroded != labels
  360. boundaries1 = dilated != labels
  361. labels_small = np.concatenate((eroded[boundaries0], labels[boundaries1]))
  362. labels_large = np.concatenate((labels[boundaries0], dilated[boundaries1]))
  363. n = np.max(labels_large) + 1
  364. # use a dummy broadcast array as data for RAG
  365. ones = np.broadcast_to(1.0, labels_small.shape)
  366. count_matrix = sparse.csr_array(
  367. (ones, (labels_small, labels_large)), dtype=int, shape=(n, n)
  368. )
  369. data = np.concatenate((edge_map[boundaries0], edge_map[boundaries1]))
  370. graph_matrix = sparse.csr_array((data, (labels_small, labels_large)))
  371. graph_matrix.data /= count_matrix.data
  372. rag = RAG()
  373. rag.add_weighted_edges_from(_edge_generator_from_csr(graph_matrix), weight='weight')
  374. rag.add_weighted_edges_from(_edge_generator_from_csr(count_matrix), weight='count')
  375. for n in rag.nodes():
  376. rag.nodes[n].update({'labels': [n]})
  377. return rag
  378. @require("matplotlib", ">=3.3")
  379. def show_rag(
  380. labels,
  381. rag,
  382. image,
  383. border_color='black',
  384. edge_width=1.5,
  385. edge_cmap='magma',
  386. img_cmap='bone',
  387. in_place=True,
  388. ax=None,
  389. ):
  390. """Show a Region Adjacency Graph on an image.
  391. Given a labelled image and its corresponding RAG, show the nodes and edges
  392. of the RAG on the image with the specified colors. Edges are displayed between
  393. the centroid of the 2 adjacent regions in the image.
  394. Parameters
  395. ----------
  396. labels : ndarray, shape (M, N)
  397. The labelled image.
  398. rag : RAG
  399. The Region Adjacency Graph.
  400. image : ndarray, shape (M, N[, 3])
  401. Input image. If `colormap` is `None`, the image should be in RGB
  402. format.
  403. border_color : color spec, optional
  404. Color with which the borders between regions are drawn.
  405. edge_width : float, optional
  406. The thickness with which the RAG edges are drawn.
  407. edge_cmap : :py:class:`matplotlib.colors.Colormap`, optional
  408. Any matplotlib colormap with which the edges are drawn.
  409. img_cmap : :py:class:`matplotlib.colors.Colormap`, optional
  410. Any matplotlib colormap with which the image is draw. If set to `None`
  411. the image is drawn as it is.
  412. in_place : bool, optional
  413. If set, the RAG is modified in place. For each node `n` the function
  414. will set a new attribute ``rag.nodes[n]['centroid']``.
  415. ax : :py:class:`matplotlib.axes.Axes`, optional
  416. The axes to draw on. If not specified, new axes are created and drawn
  417. on.
  418. Returns
  419. -------
  420. lc : :py:class:`matplotlib.collections.LineCollection`
  421. A collection of lines that represent the edges of the graph. It can be
  422. passed to the :meth:`matplotlib.figure.Figure.colorbar` function.
  423. Examples
  424. --------
  425. >>> from skimage import data, segmentation, graph
  426. >>> import matplotlib.pyplot as plt
  427. >>>
  428. >>> img = data.coffee()
  429. >>> labels = segmentation.slic(img)
  430. >>> g = graph.rag_mean_color(img, labels)
  431. >>> lc = graph.show_rag(labels, g, img)
  432. >>> cbar = plt.colorbar(lc)
  433. """
  434. from matplotlib import colors
  435. from matplotlib import pyplot as plt
  436. from matplotlib.collections import LineCollection
  437. if not in_place:
  438. rag = rag.copy()
  439. if ax is None:
  440. fig, ax = plt.subplots()
  441. out = util.img_as_float(image, force_copy=True)
  442. if img_cmap is None:
  443. if image.ndim < 3 or image.shape[2] not in [3, 4]:
  444. msg = 'If colormap is `None`, an RGB or RGBA image should be given'
  445. raise ValueError(msg)
  446. # Ignore the alpha channel
  447. out = image[:, :, :3]
  448. else:
  449. img_cmap = plt.get_cmap(img_cmap)
  450. out = color.rgb2gray(image)
  451. # Ignore the alpha channel
  452. out = img_cmap(out)[:, :, :3]
  453. edge_cmap = plt.get_cmap(edge_cmap)
  454. # Handling the case where one node has multiple labels
  455. # offset is 1 so that regionprops does not ignore 0
  456. offset = 1
  457. map_array = np.arange(labels.max() + 1)
  458. for n, d in rag.nodes(data=True):
  459. for label in d['labels']:
  460. map_array[label] = offset
  461. offset += 1
  462. rag_labels = map_array[labels]
  463. regions = measure.regionprops(rag_labels)
  464. for (n, data), region in zip(rag.nodes(data=True), regions):
  465. data['centroid'] = tuple(map(int, region['centroid']))
  466. cc = colors.ColorConverter()
  467. if border_color is not None:
  468. border_color = cc.to_rgb(border_color)
  469. out = segmentation.mark_boundaries(out, rag_labels, color=border_color)
  470. ax.imshow(out)
  471. # Defining the end points of the edges
  472. # The tuple[::-1] syntax reverses a tuple as matplotlib uses (x,y)
  473. # convention while skimage uses (row, column)
  474. lines = [
  475. [rag.nodes[n1]['centroid'][::-1], rag.nodes[n2]['centroid'][::-1]]
  476. for (n1, n2) in rag.edges()
  477. ]
  478. lc = LineCollection(lines, linewidths=edge_width, cmap=edge_cmap)
  479. edge_weights = [d['weight'] for x, y, d in rag.edges(data=True)]
  480. lc.set_array(np.array(edge_weights))
  481. ax.add_collection(lc)
  482. return lc