traceback.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import logging
  4. import traceback
  5. from contextlib import contextmanager
  6. from enum import Enum
  7. from typing import Any, Optional, Union
  8. from torch._utils_internal import signpost_event
  9. from ._compatibility import compatibility
  10. from .graph import Graph
  11. from .node import Node
  12. log = logging.getLogger(__name__)
  13. __all__ = [
  14. "preserve_node_meta",
  15. "has_preserved_node_meta",
  16. "set_stack_trace",
  17. "set_grad_fn_seq_nr",
  18. "reset_grad_fn_seq_nr",
  19. "format_stack",
  20. "set_current_meta",
  21. "get_current_meta",
  22. "NodeSource",
  23. "NodeSourceAction",
  24. "get_graph_provenance_json",
  25. ]
  26. current_meta: dict[str, Any] = {}
  27. should_preserve_node_meta = False
  28. @compatibility(is_backward_compatible=False)
  29. class NodeSourceAction(Enum):
  30. CREATE = "create"
  31. REPLACE = "replace"
  32. @compatibility(is_backward_compatible=False)
  33. class NodeSource:
  34. """
  35. NodeSource is a data structure that contains the provenance information of a node.
  36. If node `a` is created from node `b`, then `a.meta["from_node"]` may contain NodeSource(b).
  37. """
  38. class NodeInfo:
  39. def __init__(self, name: str, target: str, graph_id: int):
  40. self.name = name
  41. self.target = target
  42. self.graph_id = graph_id
  43. pass_name: str
  44. action: list["NodeSourceAction"]
  45. from_node: list["NodeSource"]
  46. node_info: Optional["NodeInfo"]
  47. _dict: Optional[dict[str, Any]]
  48. _action_string: Optional[str]
  49. def __init__(
  50. self,
  51. node: Optional[Node],
  52. pass_name: str = "",
  53. action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
  54. ):
  55. self.pass_name = pass_name
  56. if action is None:
  57. action = []
  58. elif not isinstance(action, list):
  59. action = [action]
  60. for a in action:
  61. assert isinstance(a, NodeSourceAction)
  62. self.action = action
  63. if node:
  64. self.node_info = self.NodeInfo(
  65. name=node.name, target=str(node.target), graph_id=id(node.graph)
  66. )
  67. self.from_node = (
  68. copy.deepcopy(node.meta["from_node"])
  69. if "from_node" in node.meta
  70. else []
  71. )
  72. else:
  73. self.node_info = None
  74. self.from_node = []
  75. # cache the action string and dict representation for performance.
  76. self._action_string: Optional[str] = None
  77. self._dict: Optional[dict[str, Any]] = None
  78. @property
  79. def name(self) -> str:
  80. return self.node_info.name if self.node_info else ""
  81. @property
  82. def target(self) -> str:
  83. return self.node_info.target if self.node_info else ""
  84. @property
  85. def graph_id(self) -> int:
  86. return self.node_info.graph_id if self.node_info else -1
  87. def __repr__(self):
  88. return self.print_readable()
  89. def _get_action_string(self):
  90. if self._action_string is None:
  91. self._action_string = "+".join([a.name.lower() for a in self.action])
  92. return self._action_string
  93. def print_readable(self, indent=0):
  94. if indent > 9:
  95. return ""
  96. result = ""
  97. action_string = self._get_action_string()
  98. result += (
  99. " " * indent * 4
  100. + f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n"
  101. )
  102. for item in self.from_node:
  103. result += item.print_readable(indent + 1)
  104. return result
  105. def to_dict(self) -> dict:
  106. if self._dict is None:
  107. # Convert the object to a dictionary
  108. action_string = self._get_action_string()
  109. self._dict = {
  110. "name": self.name,
  111. "target": self.target,
  112. "graph_id": self.graph_id,
  113. "pass_name": self.pass_name,
  114. "action": action_string,
  115. "from_node": [node.to_dict() for node in self.from_node],
  116. }
  117. assert self._dict is not None
  118. return self._dict
  119. def __eq__(self, other: object):
  120. if not isinstance(other, NodeSource):
  121. return False
  122. return self.to_dict() == other.to_dict()
  123. def __hash__(self):
  124. # Create a hash based on the dictionary representation
  125. # We need to convert the dict to a hashable form
  126. def _make_hashable(obj):
  127. if isinstance(obj, dict):
  128. return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
  129. elif isinstance(obj, list):
  130. return tuple(_make_hashable(item) for item in obj)
  131. else:
  132. return obj
  133. return hash(_make_hashable(self.to_dict()))
  134. @classmethod
  135. def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]:
  136. """
  137. Recursively deserialize from_node metadata from dictionary data.
  138. It is used to deserialize the from_node field from serialized metadata.
  139. Please use constructor NodeSource(node, ...) to create a NodeSource object.
  140. """
  141. if d is None:
  142. return None
  143. assert isinstance(d, dict), f"Expected a dict, got {type(d)}"
  144. # Create a NodeSource object directly without going through the constructor
  145. # to avoid issues with graph ID and node creation
  146. node_source = NodeSource.__new__(NodeSource)
  147. # Reset the cached properties
  148. node_source._action_string = None
  149. node_source._dict = None
  150. # Set the basic attributes
  151. node_source.pass_name = d.get("pass_name", "")
  152. # Parse action string back to NodeSourceAction enum list
  153. action_str = d.get("action", "")
  154. actions = []
  155. if action_str:
  156. for action_name in action_str.split("+"):
  157. if action_name.upper() == "CREATE":
  158. actions.append(NodeSourceAction.CREATE)
  159. elif action_name.upper() == "REPLACE":
  160. actions.append(NodeSourceAction.REPLACE)
  161. node_source.action = actions
  162. # Create the NodeInfo object directly
  163. if "name" in d and "target" in d and "graph_id" in d:
  164. node_info = NodeSource.NodeInfo(
  165. d.get("name", ""), d.get("target", ""), d.get("graph_id", -1)
  166. )
  167. node_source.node_info = node_info
  168. else:
  169. node_source.node_info = None
  170. # Recursively deserialize nested from_node
  171. if d.get("from_node", None) is not None:
  172. node_source.from_node = [
  173. result
  174. for fn in d.get("from_node", [])
  175. if (result := cls._from_dict(fn)) is not None
  176. ]
  177. else:
  178. node_source.from_node = []
  179. return node_source
  180. @compatibility(is_backward_compatible=False)
  181. @contextmanager
  182. def preserve_node_meta(enable=True):
  183. global should_preserve_node_meta
  184. global current_meta
  185. # If enable is False, this context manager is a no-op
  186. if not enable:
  187. yield
  188. else:
  189. saved_should_preserve_node_meta = should_preserve_node_meta
  190. # Shallow copy is OK since fields of current_meta are not mutated
  191. saved_current_meta = current_meta.copy()
  192. try:
  193. should_preserve_node_meta = True
  194. yield
  195. finally:
  196. should_preserve_node_meta = saved_should_preserve_node_meta
  197. current_meta = saved_current_meta
  198. @compatibility(is_backward_compatible=False)
  199. def set_stack_trace(stack: list[str]):
  200. global current_meta
  201. if should_preserve_node_meta and stack:
  202. current_meta["stack_trace"] = "".join(stack)
  203. @compatibility(is_backward_compatible=False)
  204. def set_grad_fn_seq_nr(seq_nr):
  205. global current_meta
  206. if should_preserve_node_meta:
  207. # The seq_nr is captured by eager mode in the grad_fn during forward
  208. current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [
  209. seq_nr
  210. ]
  211. current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1
  212. @compatibility(is_backward_compatible=False)
  213. def reset_grad_fn_seq_nr():
  214. # NB: reset state properly, this would be helpful towards supporting
  215. # reentrant autograd if we actually wanted to do that.
  216. global current_meta
  217. if should_preserve_node_meta:
  218. current_level = current_meta.get("in_grad_fn", 0)
  219. assert current_level > 0
  220. if current_level == 1:
  221. del current_meta["in_grad_fn"]
  222. del current_meta["grad_fn_seq_nr"]
  223. else:
  224. current_meta["in_grad_fn"] = current_level - 1
  225. current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
  226. @compatibility(is_backward_compatible=False)
  227. def format_stack() -> list[str]:
  228. if should_preserve_node_meta:
  229. return [current_meta.get("stack_trace", "")]
  230. else:
  231. # fallback to traceback.format_stack()
  232. return traceback.format_list(traceback.extract_stack()[:-1])
  233. @compatibility(is_backward_compatible=False)
  234. def has_preserved_node_meta() -> bool:
  235. return should_preserve_node_meta
  236. @compatibility(is_backward_compatible=False)
  237. @contextmanager
  238. def set_current_meta(node, pass_name=""):
  239. global current_meta
  240. if should_preserve_node_meta and node.meta:
  241. saved_meta = current_meta
  242. try:
  243. current_meta = node.meta.copy()
  244. # Update the "from_node" field in current_meta for provenance tracking.
  245. # Instead of appending, overwrite the "from_node" field because current_meta
  246. # will be assigned to the new node. The new NodeSource(node, ...) will
  247. # include the information from the previous current_meta["from_node"].
  248. current_meta["from_node"] = [
  249. NodeSource(node, pass_name, NodeSourceAction.CREATE)
  250. ]
  251. yield
  252. finally:
  253. current_meta = saved_meta
  254. else:
  255. yield
  256. @compatibility(is_backward_compatible=False)
  257. def get_current_meta() -> dict[str, Any]:
  258. return current_meta
  259. @compatibility(is_backward_compatible=False)
  260. def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
  261. """
  262. Given an fx.Graph, return a json that contains the provenance information of each node.
  263. """
  264. try:
  265. provenance_tracking_json = {}
  266. for node in graph.nodes:
  267. if node.op == "call_function":
  268. provenance_tracking_json[node.name] = (
  269. [source.to_dict() for source in node.meta["from_node"]]
  270. if "from_node" in node.meta
  271. else []
  272. )
  273. return provenance_tracking_json
  274. except Exception as e:
  275. # Since this is just debugging, it should never interfere with regular
  276. # program execution, so we use this try-except to guard against any error
  277. signpost_event(
  278. "inductor",
  279. "provenance_tracking_error",
  280. {
  281. "function": "get_graph_provenance_json",
  282. "error_msg": str(e),
  283. "stack_trace": traceback.format_exc(),
  284. },
  285. )
  286. return {}