| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- import numpy as np
- import heapq
def _revalidate_node_edges(rag, node, heap_list):
    """Invalidate and re-push the heap items of all edges incident to a node.

    For every neighbor of `node`, the heap item currently stored on the
    connecting edge (if there is one) is flagged as invalid. A fresh, valid
    heap item carrying the edge's current weight is then stored on the edge
    and pushed onto `heap_list`.

    Parameters
    ----------
    rag : RAG
        The Region Adjacency Graph.
    node : int
        The id of the node whose incident edges are to be
        validated/invalidated.
    heap_list : list
        The list containing the existing heap of edges. It is modified in
        place.
    """
    # networkx updates the data dictionary if an edge already exists. A
    # changed weight would therefore require repositioning the edge's item
    # inside the heap. Instead, the old item is flagged invalid and a new
    # item is pushed.
    for nbr in rag.neighbors(node):
        data = rag[node][nbr]

        try:
            stale_item = data['heap item']
        except KeyError:
            # The edge did not exist in the graph before, so there is no
            # old heap item to invalidate.
            pass
        else:
            # `data` is the edge's own attribute dict, so flagging the item
            # once through it is sufficient.
            stale_item[3] = False

        heap_item = [data['weight'], node, nbr, True]
        data['heap item'] = heap_item
        heapq.heappush(heap_list, heap_item)
def _rename_node(graph, node_id, copy_id):
    """Give the node `node_id` of `graph` the new id `copy_id`.

    A node `copy_id` is added and receives the attributes of `node_id`. For
    each neighbor of `node_id`, an edge between that neighbor and `copy_id`
    is added carrying only the ``'weight'`` of the original edge. Finally
    `node_id` is removed from the graph.
    """
    graph._add_node_silent(copy_id)
    new_attrs = graph.nodes[copy_id]
    new_attrs.update(graph.nodes[node_id])

    for neighbor in graph.neighbors(node_id):
        weight = graph[node_id][neighbor]['weight']
        graph.add_edge(neighbor, copy_id, {'weight': weight})

    graph.remove_node(node_id)
def _invalidate_edge(graph, n1, n2):
    """Flag the heap item stored on the edge (n1, n2) as no longer valid."""
    edge_attrs = graph[n1][n2]
    heap_item = edge_attrs['heap item']
    # Position 3 of a heap item holds its validity flag.
    heap_item[3] = False
def merge_hierarchical(
    labels, rag, thresh, rag_copy, in_place_merge, merge_func, weight_func
):
    """Hierarchically merge the regions of a RAG.

    The two nodes joined by the lowest-weight edge are merged repeatedly,
    stopping once no edge with a weight below `thresh` is left.

    Parameters
    ----------
    labels : ndarray
        The array of labels.
    rag : RAG
        The Region Adjacency Graph.
    thresh : float
        Regions joined by an edge whose weight is smaller than `thresh` get
        merged.
    rag_copy : bool
        When set, the merging is carried out on a copy of `rag`.
    in_place_merge : bool
        When set, nodes are merged in place. Otherwise a new node is created
        for each merge.
    merge_func : callable
        Invoked as ``merge_func(graph, src, dst)`` right before the nodes
        `src` and `dst` of the RAG `graph` are merged.
    weight_func : callable
        Computes the new weights of the nodes adjacent to the merged node.
        It is forwarded unchanged as the `weight_func` argument of
        `merge_nodes`.

    Returns
    -------
    out : ndarray
        The new labeled array.
    """
    if rag_copy:
        rag = rag.copy()

    # Seed the heap with one valid item per edge. Each edge keeps a reference
    # to its own item so the item can be flagged invalid later on.
    edge_heap = []
    for u, v, attrs in rag.edges(data=True):
        item = [attrs['weight'], u, v, True]
        attrs['heap item'] = item
        heapq.heappush(edge_heap, item)

    while edge_heap and edge_heap[0][0] < thresh:
        _, n1, n2, valid = heapq.heappop(edge_heap)
        if not valid:
            # Stale entry for an edge that was invalidated earlier: discard.
            continue

        # Invalidate every edge touching either end point before the merge
        # removes nodes from the graph.
        for endpoint in (n1, n2):
            for nbr in rag.neighbors(endpoint):
                _invalidate_edge(rag, endpoint, nbr)

        if in_place_merge:
            src, dst = n1, n2
        else:
            next_id = rag.next_id()
            _rename_node(rag, n2, next_id)
            src, dst = n1, next_id

        merge_func(rag, src, dst)
        merged_id = rag.merge_nodes(src, dst, weight_func)
        _revalidate_node_edges(rag, merged_id, edge_heap)

    # Map every original label to the position of the node that now holds it.
    label_map = np.arange(labels.max() + 1)
    for position, (_, node_attrs) in enumerate(rag.nodes(data=True)):
        for old_label in node_attrs['labels']:
            label_map[old_label] = position

    return label_map[labels]
|