_graph_merge.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import numpy as np
  2. import heapq
  3. def _revalidate_node_edges(rag, node, heap_list):
  4. """Handles validation and invalidation of edges incident to a node.
  5. This function invalidates all existing edges incident on `node` and inserts
  6. new items in `heap_list` updated with the valid weights.
  7. rag : RAG
  8. The Region Adjacency Graph.
  9. node : int
  10. The id of the node whose incident edges are to be validated/invalidated
  11. .
  12. heap_list : list
  13. The list containing the existing heap of edges.
  14. """
  15. # networkx updates data dictionary if edge exists
  16. # this would mean we have to reposition these edges in
  17. # heap if their weight is updated.
  18. # instead we invalidate them
  19. for nbr in rag.neighbors(node):
  20. data = rag[node][nbr]
  21. try:
  22. # invalidate edges incident on `dst`, they have new weights
  23. data['heap item'][3] = False
  24. _invalidate_edge(rag, node, nbr)
  25. except KeyError:
  26. # will handle the case where the edge did not exist in the existing
  27. # graph
  28. pass
  29. wt = data['weight']
  30. heap_item = [wt, node, nbr, True]
  31. data['heap item'] = heap_item
  32. heapq.heappush(heap_list, heap_item)
  33. def _rename_node(graph, node_id, copy_id):
  34. """Rename `node_id` in `graph` to `copy_id`."""
  35. graph._add_node_silent(copy_id)
  36. graph.nodes[copy_id].update(graph.nodes[node_id])
  37. for nbr in graph.neighbors(node_id):
  38. wt = graph[node_id][nbr]['weight']
  39. graph.add_edge(nbr, copy_id, {'weight': wt})
  40. graph.remove_node(node_id)
  41. def _invalidate_edge(graph, n1, n2):
  42. """Invalidates the edge (n1, n2) in the heap."""
  43. graph[n1][n2]['heap item'][3] = False
  44. def merge_hierarchical(
  45. labels, rag, thresh, rag_copy, in_place_merge, merge_func, weight_func
  46. ):
  47. """Perform hierarchical merging of a RAG.
  48. Greedily merges the most similar pair of nodes until no edges lower than
  49. `thresh` remain.
  50. Parameters
  51. ----------
  52. labels : ndarray
  53. The array of labels.
  54. rag : RAG
  55. The Region Adjacency Graph.
  56. thresh : float
  57. Regions connected by an edge with weight smaller than `thresh` are
  58. merged.
  59. rag_copy : bool
  60. If set, the RAG copied before modifying.
  61. in_place_merge : bool
  62. If set, the nodes are merged in place. Otherwise, a new node is
  63. created for each merge..
  64. merge_func : callable
  65. This function is called before merging two nodes. For the RAG `graph`
  66. while merging `src` and `dst`, it is called as follows
  67. ``merge_func(graph, src, dst)``.
  68. weight_func : callable
  69. The function to compute the new weights of the nodes adjacent to the
  70. merged node. This is directly supplied as the argument `weight_func`
  71. to `merge_nodes`.
  72. Returns
  73. -------
  74. out : ndarray
  75. The new labeled array.
  76. """
  77. if rag_copy:
  78. rag = rag.copy()
  79. edge_heap = []
  80. for n1, n2, data in rag.edges(data=True):
  81. # Push a valid edge in the heap
  82. wt = data['weight']
  83. heap_item = [wt, n1, n2, True]
  84. heapq.heappush(edge_heap, heap_item)
  85. # Reference to the heap item in the graph
  86. data['heap item'] = heap_item
  87. while len(edge_heap) > 0 and edge_heap[0][0] < thresh:
  88. _, n1, n2, valid = heapq.heappop(edge_heap)
  89. # Ensure popped edge is valid, if not, the edge is discarded
  90. if valid:
  91. # Invalidate all neighbors of `src` before its deleted
  92. for nbr in rag.neighbors(n1):
  93. _invalidate_edge(rag, n1, nbr)
  94. for nbr in rag.neighbors(n2):
  95. _invalidate_edge(rag, n2, nbr)
  96. if not in_place_merge:
  97. next_id = rag.next_id()
  98. _rename_node(rag, n2, next_id)
  99. src, dst = n1, next_id
  100. else:
  101. src, dst = n1, n2
  102. merge_func(rag, src, dst)
  103. new_id = rag.merge_nodes(src, dst, weight_func)
  104. _revalidate_node_edges(rag, new_id, edge_heap)
  105. label_map = np.arange(labels.max() + 1)
  106. for ix, (n, d) in enumerate(rag.nodes(data=True)):
  107. for label in d['labels']:
  108. label_map[label] = ix
  109. return label_map[labels]