| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341 |
- # mypy: allow-untyped-defs
- import copy
- import logging
- import traceback
- from contextlib import contextmanager
- from enum import Enum
- from typing import Any, Optional, Union
- from torch._utils_internal import signpost_event
- from ._compatibility import compatibility
- from .graph import Graph
- from .node import Node
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Public names exported by this module.
__all__ = [
    "preserve_node_meta",
    "has_preserved_node_meta",
    "set_stack_trace",
    "set_grad_fn_seq_nr",
    "reset_grad_fn_seq_nr",
    "format_stack",
    "set_current_meta",
    "get_current_meta",
    "NodeSource",
    "NodeSourceAction",
    "get_graph_provenance_json",
]

# Module-global metadata dict. The functions in this module read and write
# keys such as "stack_trace", "grad_fn_seq_nr", "in_grad_fn" and "from_node",
# and some of them rebind this name to a different dict.
current_meta: dict[str, Any] = {}

# Module-global switch. While False, the setters in this module leave
# `current_meta` untouched; `preserve_node_meta` turns it on temporarily.
should_preserve_node_meta = False
@compatibility(is_backward_compatible=False)
class NodeSourceAction(Enum):
    """
    Kind of action recorded in a NodeSource's `action` list.

    NodeSource serializes these by lower-cased member name (e.g. "create"),
    joined with "+" when there is more than one.
    """

    CREATE = "create"
    REPLACE = "replace"
- @compatibility(is_backward_compatible=False)
- class NodeSource:
- """
- NodeSource is a data structure that contains the provenance information of a node.
- If node `a` is created from node `b`, then `a.meta["from_node"]` may contain NodeSource(b).
- """
- class NodeInfo:
- def __init__(self, name: str, target: str, graph_id: int):
- self.name = name
- self.target = target
- self.graph_id = graph_id
- pass_name: str
- action: list["NodeSourceAction"]
- from_node: list["NodeSource"]
- node_info: Optional["NodeInfo"]
- _dict: Optional[dict[str, Any]]
- _action_string: Optional[str]
- def __init__(
- self,
- node: Optional[Node],
- pass_name: str = "",
- action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
- ):
- self.pass_name = pass_name
- if action is None:
- action = []
- elif not isinstance(action, list):
- action = [action]
- for a in action:
- assert isinstance(a, NodeSourceAction)
- self.action = action
- if node:
- self.node_info = self.NodeInfo(
- name=node.name, target=str(node.target), graph_id=id(node.graph)
- )
- self.from_node = (
- copy.deepcopy(node.meta["from_node"])
- if "from_node" in node.meta
- else []
- )
- else:
- self.node_info = None
- self.from_node = []
- # cache the action string and dict representation for performance.
- self._action_string: Optional[str] = None
- self._dict: Optional[dict[str, Any]] = None
- @property
- def name(self) -> str:
- return self.node_info.name if self.node_info else ""
- @property
- def target(self) -> str:
- return self.node_info.target if self.node_info else ""
- @property
- def graph_id(self) -> int:
- return self.node_info.graph_id if self.node_info else -1
- def __repr__(self):
- return self.print_readable()
- def _get_action_string(self):
- if self._action_string is None:
- self._action_string = "+".join([a.name.lower() for a in self.action])
- return self._action_string
- def print_readable(self, indent=0):
- if indent > 9:
- return ""
- result = ""
- action_string = self._get_action_string()
- result += (
- " " * indent * 4
- + f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n"
- )
- for item in self.from_node:
- result += item.print_readable(indent + 1)
- return result
- def to_dict(self) -> dict:
- if self._dict is None:
- # Convert the object to a dictionary
- action_string = self._get_action_string()
- self._dict = {
- "name": self.name,
- "target": self.target,
- "graph_id": self.graph_id,
- "pass_name": self.pass_name,
- "action": action_string,
- "from_node": [node.to_dict() for node in self.from_node],
- }
- assert self._dict is not None
- return self._dict
- def __eq__(self, other: object):
- if not isinstance(other, NodeSource):
- return False
- return self.to_dict() == other.to_dict()
- def __hash__(self):
- # Create a hash based on the dictionary representation
- # We need to convert the dict to a hashable form
- def _make_hashable(obj):
- if isinstance(obj, dict):
- return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
- elif isinstance(obj, list):
- return tuple(_make_hashable(item) for item in obj)
- else:
- return obj
- return hash(_make_hashable(self.to_dict()))
- @classmethod
- def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]:
- """
- Recursively deserialize from_node metadata from dictionary data.
- It is used to deserialize the from_node field from serialized metadata.
- Please use constructor NodeSource(node, ...) to create a NodeSource object.
- """
- if d is None:
- return None
- assert isinstance(d, dict), f"Expected a dict, got {type(d)}"
- # Create a NodeSource object directly without going through the constructor
- # to avoid issues with graph ID and node creation
- node_source = NodeSource.__new__(NodeSource)
- # Reset the cached properties
- node_source._action_string = None
- node_source._dict = None
- # Set the basic attributes
- node_source.pass_name = d.get("pass_name", "")
- # Parse action string back to NodeSourceAction enum list
- action_str = d.get("action", "")
- actions = []
- if action_str:
- for action_name in action_str.split("+"):
- if action_name.upper() == "CREATE":
- actions.append(NodeSourceAction.CREATE)
- elif action_name.upper() == "REPLACE":
- actions.append(NodeSourceAction.REPLACE)
- node_source.action = actions
- # Create the NodeInfo object directly
- if "name" in d and "target" in d and "graph_id" in d:
- node_info = NodeSource.NodeInfo(
- d.get("name", ""), d.get("target", ""), d.get("graph_id", -1)
- )
- node_source.node_info = node_info
- else:
- node_source.node_info = None
- # Recursively deserialize nested from_node
- if d.get("from_node", None) is not None:
- node_source.from_node = [
- result
- for fn in d.get("from_node", [])
- if (result := cls._from_dict(fn)) is not None
- ]
- else:
- node_source.from_node = []
- return node_source
@compatibility(is_backward_compatible=False)
@contextmanager
def preserve_node_meta(enable=True):
    """
    Context manager that switches on `should_preserve_node_meta` for the
    duration of the `with` block.

    On exit the flag is restored to its previous value and `current_meta` is
    rebound to a shallow copy taken on entry. With `enable=False` the context
    manager does nothing.
    """
    global should_preserve_node_meta
    global current_meta

    # Disabled: behave as a no-op context manager.
    if not enable:
        yield
        return

    previous_flag = should_preserve_node_meta
    # Shallow copy is OK since fields of current_meta are not mutated
    meta_snapshot = current_meta.copy()
    should_preserve_node_meta = True
    try:
        yield
    finally:
        should_preserve_node_meta = previous_flag
        current_meta = meta_snapshot
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: list[str]):
    """
    Store the concatenation of `stack` under `current_meta["stack_trace"]`.

    Does nothing when node-meta preservation is off or `stack` is empty.
    """
    global current_meta

    if not should_preserve_node_meta or not stack:
        return
    current_meta["stack_trace"] = "".join(stack)
@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    """
    Push `seq_nr` onto `current_meta["grad_fn_seq_nr"]` and increment the
    `current_meta["in_grad_fn"]` counter.

    Does nothing when node-meta preservation is off.
    """
    global current_meta

    if not should_preserve_node_meta:
        return

    # The seq_nr is captured by eager mode in the grad_fn during forward.
    # Build a new list rather than appending in place.
    seq_nrs_so_far = current_meta.get("grad_fn_seq_nr", [])
    current_meta["grad_fn_seq_nr"] = seq_nrs_so_far + [seq_nr]

    nesting = current_meta.get("in_grad_fn", 0)
    current_meta["in_grad_fn"] = nesting + 1
@compatibility(is_backward_compatible=False)
def reset_grad_fn_seq_nr():
    """
    Undo one `set_grad_fn_seq_nr` call.

    Decrements `current_meta["in_grad_fn"]` and drops the last entry of
    `current_meta["grad_fn_seq_nr"]`; when the counter is 1, both keys are
    deleted instead. Does nothing when node-meta preservation is off.
    """
    # NB: reset state properly, this would be helpful towards supporting
    # reentrant autograd if we actually wanted to do that.
    global current_meta

    if not should_preserve_node_meta:
        return

    depth = current_meta.get("in_grad_fn", 0)
    assert depth > 0

    if depth != 1:
        # Still nested: step back one level, using a fresh list for the rest.
        current_meta["in_grad_fn"] = depth - 1
        current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
        return

    # Outermost level: remove the bookkeeping keys entirely.
    del current_meta["in_grad_fn"]
    del current_meta["grad_fn_seq_nr"]
@compatibility(is_backward_compatible=False)
def format_stack() -> list[str]:
    """
    Return the stack trace as a list of strings.

    While node-meta preservation is on, this is a one-element list holding
    `current_meta["stack_trace"]` (or "" if unset). Otherwise it is the
    formatted Python call stack, excluding this function's own frame.
    """
    if not should_preserve_node_meta:
        # fallback to traceback.format_stack()
        return traceback.format_list(traceback.extract_stack()[:-1])
    return [current_meta.get("stack_trace", "")]
@compatibility(is_backward_compatible=False)
def has_preserved_node_meta() -> bool:
    """Return the current value of the module-level `should_preserve_node_meta` flag."""
    return should_preserve_node_meta
@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_meta(node, pass_name=""):
    """
    Context manager that temporarily rebinds `current_meta` to a copy of
    `node.meta`, with "from_node" set to a single
    `NodeSource(node, pass_name, NodeSourceAction.CREATE)`.

    The previous `current_meta` is restored on exit. When node-meta
    preservation is off or `node.meta` is empty, nothing is changed.
    """
    global current_meta

    if not (should_preserve_node_meta and node.meta):
        yield
        return

    previous_meta = current_meta
    try:
        derived_meta = node.meta.copy()
        # Update the "from_node" field for provenance tracking.
        # Instead of appending, overwrite the "from_node" field because current_meta
        # will be assigned to the new node. The new NodeSource(node, ...) will
        # include the information from the previous current_meta["from_node"].
        derived_meta["from_node"] = [
            NodeSource(node, pass_name, NodeSourceAction.CREATE)
        ]
        current_meta = derived_meta
        yield
    finally:
        current_meta = previous_meta
@compatibility(is_backward_compatible=False)
def get_current_meta() -> dict[str, Any]:
    """Return the module-level `current_meta` dict itself (not a copy)."""
    return current_meta
- @compatibility(is_backward_compatible=False)
- def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
- """
- Given an fx.Graph, return a json that contains the provenance information of each node.
- """
- try:
- provenance_tracking_json = {}
- for node in graph.nodes:
- if node.op == "call_function":
- provenance_tracking_json[node.name] = (
- [source.to_dict() for source in node.meta["from_node"]]
- if "from_node" in node.meta
- else []
- )
- return provenance_tracking_json
- except Exception as e:
- # Since this is just debugging, it should never interfere with regular
- # program execution, so we use this try-except to guard against any error
- signpost_event(
- "inductor",
- "provenance_tracking_error",
- {
- "function": "get_graph_provenance_json",
- "error_msg": str(e),
- "stack_trace": traceback.format_exc(),
- },
- )
- return {}
|