graph_utils.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from collections import deque
  2. from typing import Any
  3. from torch.fx import Graph, map_arg, Node
  4. from torch.utils._ordered_set import OrderedSet
  # Flatten a node's args/kwargs into the Nodes it references, with support
  # for Nodes nested inside slice objects.
  # Note: a cleaner approach would be to register slices as pytree nodes and
  # unregister them afterwards, but PyTorch's pytree implementation has no
  # unregister API.
  10. def _get_flat_args(
  11. node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
  12. ) -> list[Node]:
  13. args = list[Any]()
  14. map_arg((node.args, node.kwargs), args.append)
  15. if node in node_to_additional_deps:
  16. args.extend(node_to_additional_deps[node])
  17. return args
  18. def _get_flat_args_unique(
  19. node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
  20. ) -> OrderedSet[Node]:
  21. args = OrderedSet[Node]()
  22. map_arg((node.args, node.kwargs), args.add)
  23. if node in node_to_additional_deps:
  24. args.update(node_to_additional_deps[node])
  25. return args
  def _detect_cycles(
      graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
  ) -> str:
      """Search the dependencies reachable from the graph's output node(s) for a cycle.

      Runs an iterative depth-first search from every ``output`` node of
      ``graph``. The children of a node are the distinct Nodes returned by
      ``_get_flat_args_unique``: its args/kwargs plus any extra dependencies
      registered in ``node_to_additional_deps``.

      Args:
          graph: the FX graph to inspect.
          node_to_additional_deps: extra dependency edges to follow in addition
              to each node's args/kwargs.

      Returns:
          ``"cycle detected in path: deque([...])"`` when a cycle is found. The
          deque is the DFS path from the output node down to the repeated node,
          which appears at the end a second time. Otherwise the function
          returns ``"no cycle detected"``.
      """
      current_path: deque[Node] = deque()
      current_path_set: set[Node] = set()
      # Nodes whose whole dependency subtree has already been searched without
      # finding a cycle. Such a node cannot lead to a cycle when reached again
      # through another consumer. Without this set, shared sub-graphs are
      # re-walked once per path, which is exponential in the number of
      # diamond-shaped sub-graphs.
      fully_explored: set[Node] = set()
      # Stack of (node to visit, node on the path that depends on it).
      pending: deque[tuple[Node, Node]] = deque()

      def add_to_current_path(node: Node) -> None:
          current_path.append(node)
          current_path_set.add(node)

      def pop_current_path() -> None:
          node = current_path.pop()
          current_path_set.remove(node)
          # `pending` is LIFO, so every entry pushed for this node's subtree
          # has been consumed by the time we backtrack past it.
          fully_explored.add(node)

      def current_path_head() -> Node:
          return current_path[-1]

      for origin in graph.find_nodes(op="output"):
          current_path.clear()
          current_path_set.clear()
          add_to_current_path(origin)
          for child in _get_flat_args_unique(origin, node_to_additional_deps):
              pending.append((child, origin))

          while pending:
              cur_node, parent = pending.pop()

              # Backtrack: unwind the path until `parent` is its last element.
              while current_path and current_path_head() != parent:
                  pop_current_path()

              if not isinstance(cur_node, Node):
                  continue

              if cur_node in current_path_set:
                  # Edge back to a node on the current path: close the loop in
                  # the reported path.
                  current_path.append(cur_node)
                  return f"cycle detected in path: {current_path}"

              if cur_node in fully_explored:
                  continue

              add_to_current_path(cur_node)
              for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
                  pending.append((child, cur_node))

          # Whatever is still on the path has been fully searched as well.
          fully_explored.update(current_path)

      return "no cycle detected"