traceback.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import logging
  4. import traceback
  5. from contextlib import contextmanager
  6. from enum import Enum
  7. from typing import Any, Optional, Union
  8. from torch._utils_internal import signpost_event
  9. from ._compatibility import compatibility
  10. from .graph import Graph
  11. from .graph_module import GraphModule
  12. from .node import Node
  13. log = logging.getLogger(__name__)
  14. __all__ = [
  15. "annotate",
  16. "annotate_fn",
  17. "preserve_node_meta",
  18. "has_preserved_node_meta",
  19. "set_stack_trace",
  20. "set_grad_fn_seq_nr",
  21. "reset_grad_fn_seq_nr",
  22. "format_stack",
  23. "set_current_meta",
  24. "get_current_meta",
  25. "NodeSource",
  26. "NodeSourceAction",
  27. "get_graph_provenance_json",
  28. "set_current_replay_node",
  29. "get_current_replay_node",
  30. ]
  31. current_meta: dict[str, Any] = {}
  32. current_replay_node: Optional[Node] = None
  33. should_preserve_node_meta = False
  34. GRADIENT_ACC_SPECIAL_STACK = (
  35. "Gradient addition node due to multiple use of tensor around:"
  36. )
  37. # =============================================================================
  38. # FX Metadata Registry for Memory Profiler
  39. # =============================================================================
  40. # Global in-memory registry for FX metadata
  41. # Maps module_name -> metadata dict containing lineno_map and node_metadata
  42. _FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
  43. def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
  44. """
  45. Register FX metadata in the global in-memory registry.
  46. This is called automatically during graph module compilation to store metadata
  47. for later use by memory profiler augmentation.
  48. Args:
  49. module_name: The module identifier (content-addressed filename)
  50. metadata: Metadata dict containing lineno_map, node_metadata, and source_code
  51. """
  52. # TODO: add logging to tlparse
  53. _FX_METADATA_REGISTRY[module_name] = metadata
  54. @compatibility(is_backward_compatible=False)
  55. class NodeSourceAction(Enum):
  56. CREATE = "create"
  57. REPLACE = "replace"
  58. @compatibility(is_backward_compatible=False)
  59. class NodeSource:
  60. """
  61. NodeSource is a data structure that contains the provenance information of a node.
  62. If node `a` is created from node `b`, then `a.meta["from_node"]` may contain NodeSource(b).
  63. """
  64. class NodeInfo:
  65. def __init__(self, name: str, target: str, graph_id: int):
  66. self.name = name
  67. self.target = target
  68. self.graph_id = graph_id
  69. pass_name: str
  70. action: list["NodeSourceAction"]
  71. from_node: list["NodeSource"]
  72. node_info: Optional["NodeInfo"]
  73. _dict: Optional[dict[str, Any]]
  74. _action_string: Optional[str]
  75. def __init__(
  76. self,
  77. node: Optional[Node],
  78. pass_name: str = "",
  79. action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None,
  80. ):
  81. self.pass_name = pass_name
  82. if action is None:
  83. action = []
  84. elif not isinstance(action, list):
  85. action = [action]
  86. for a in action:
  87. assert isinstance(a, NodeSourceAction)
  88. self.action = action
  89. if node:
  90. self.node_info = self.NodeInfo(
  91. name=node.name, target=str(node.target), graph_id=id(node.graph)
  92. )
  93. self.from_node = (
  94. copy.deepcopy(node.meta["from_node"])
  95. if "from_node" in node.meta
  96. else []
  97. )
  98. else:
  99. self.node_info = None
  100. self.from_node = []
  101. # cache the action string and dict representation for performance.
  102. self._action_string: Optional[str] = None
  103. self._dict: Optional[dict[str, Any]] = None
  104. @property
  105. def name(self) -> str:
  106. return self.node_info.name if self.node_info else ""
  107. @property
  108. def target(self) -> str:
  109. return self.node_info.target if self.node_info else ""
  110. @property
  111. def graph_id(self) -> int:
  112. return self.node_info.graph_id if self.node_info else -1
  113. def __repr__(self):
  114. return self.print_readable()
  115. def _get_action_string(self):
  116. if self._action_string is None:
  117. self._action_string = "+".join([a.name.lower() for a in self.action])
  118. return self._action_string
  119. def print_readable(self, indent=0):
  120. if indent > 9:
  121. return ""
  122. result = ""
  123. action_string = self._get_action_string()
  124. result += (
  125. " " * indent * 4
  126. + f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n"
  127. )
  128. for item in self.from_node:
  129. result += item.print_readable(indent + 1)
  130. return result
  131. def to_dict(self) -> dict:
  132. if self._dict is None:
  133. # Convert the object to a dictionary
  134. action_string = self._get_action_string()
  135. self._dict = {
  136. "name": self.name,
  137. "target": self.target,
  138. "graph_id": self.graph_id,
  139. "pass_name": self.pass_name,
  140. "action": action_string,
  141. "from_node": [node.to_dict() for node in self.from_node],
  142. }
  143. assert self._dict is not None
  144. return self._dict
  145. def __eq__(self, other: object):
  146. if not isinstance(other, NodeSource):
  147. return False
  148. return self.to_dict() == other.to_dict()
  149. def __hash__(self):
  150. # Create a hash based on the dictionary representation
  151. # We need to convert the dict to a hashable form
  152. def _make_hashable(obj):
  153. if isinstance(obj, dict):
  154. return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items()))
  155. elif isinstance(obj, list):
  156. return tuple(_make_hashable(item) for item in obj)
  157. else:
  158. return obj
  159. return hash(_make_hashable(self.to_dict()))
  160. @classmethod
  161. def _from_dict(cls, d: Optional[dict]) -> Optional["NodeSource"]:
  162. """
  163. Recursively deserialize from_node metadata from dictionary data.
  164. It is used to deserialize the from_node field from serialized metadata.
  165. Please use constructor NodeSource(node, ...) to create a NodeSource object.
  166. """
  167. if d is None:
  168. return None
  169. assert isinstance(d, dict), f"Expected a dict, got {type(d)}"
  170. # Create a NodeSource object directly without going through the constructor
  171. # to avoid issues with graph ID and node creation
  172. node_source = NodeSource.__new__(NodeSource)
  173. # Reset the cached properties
  174. node_source._action_string = None
  175. node_source._dict = None
  176. # Set the basic attributes
  177. node_source.pass_name = d.get("pass_name", "")
  178. # Parse action string back to NodeSourceAction enum list
  179. action_str = d.get("action", "")
  180. actions = []
  181. if action_str:
  182. for action_name in action_str.split("+"):
  183. if action_name.upper() == "CREATE":
  184. actions.append(NodeSourceAction.CREATE)
  185. elif action_name.upper() == "REPLACE":
  186. actions.append(NodeSourceAction.REPLACE)
  187. node_source.action = actions
  188. # Create the NodeInfo object directly
  189. if "name" in d and "target" in d and "graph_id" in d:
  190. node_info = NodeSource.NodeInfo(
  191. d.get("name", ""), d.get("target", ""), d.get("graph_id", -1)
  192. )
  193. node_source.node_info = node_info
  194. else:
  195. node_source.node_info = None
  196. # Recursively deserialize nested from_node
  197. if d.get("from_node", None) is not None:
  198. node_source.from_node = [
  199. result
  200. for fn in d.get("from_node", [])
  201. if (result := cls._from_dict(fn)) is not None
  202. ]
  203. else:
  204. node_source.from_node = []
  205. return node_source
  206. @compatibility(is_backward_compatible=False)
  207. @contextmanager
  208. def preserve_node_meta(enable=True):
  209. global should_preserve_node_meta
  210. global current_meta
  211. saved_should_preserve_node_meta = should_preserve_node_meta
  212. # Shallow copy is OK since fields of current_meta are not mutated
  213. saved_current_meta = current_meta.copy()
  214. try:
  215. should_preserve_node_meta = enable
  216. yield
  217. finally:
  218. should_preserve_node_meta = saved_should_preserve_node_meta
  219. current_meta = saved_current_meta
  220. @compatibility(is_backward_compatible=False)
  221. def set_stack_trace(stack: list[str]):
  222. global current_meta
  223. if should_preserve_node_meta and stack:
  224. current_meta["stack_trace"] = "".join(stack)
  225. @compatibility(is_backward_compatible=False)
  226. @contextmanager
  227. def annotate(annotation_dict: dict):
  228. """
  229. Temporarily adds custom annotations to the current tracing context.
  230. The fx_node produced from this tracing context will have the
  231. custom annotations in node.metadata["custom"] field.
  232. This context manager allows you to insert arbitrary metadata into the PT2
  233. tracing system by updating the global `current_meta["custom"]` dictionary.
  234. The annotations are automatically reverted after the context exits.
  235. Gradient accumulation nodes will not be annotated.
  236. This is intended for advanced users who need to attach additional metadata to the fx nodes
  237. (e.g., for debugging, analysis, or external tooling) during export tracing.
  238. Note:
  239. This API is **not backward compatible** and may evolve in future releases.
  240. Note:
  241. This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
  242. to be used with PT2 family of tracers, e.g. torch.export and dynamo.
  243. Args:
  244. annotation_dict (dict): A dictionary of custom key-value pairs to inject
  245. into the FX trace metadata.
  246. Example:
  247. After exiting the context, custom annotations are removed.
  248. >>> with annotate({"source": "custom_pass", "tag": 42}):
  249. ... pass # Your computation here
  250. """
  251. global current_meta
  252. has_custom = "custom" in current_meta
  253. old_custom = copy.copy(current_meta.get("custom", {}))
  254. try:
  255. if not has_custom:
  256. current_meta["custom"] = {}
  257. # Update with all key-value pairs from the input dict
  258. current_meta["custom"].update(annotation_dict)
  259. yield
  260. finally:
  261. if has_custom:
  262. # Restore the original custom dict
  263. current_meta["custom"] = old_custom
  264. else:
  265. del current_meta["custom"]
  266. @compatibility(is_backward_compatible=False)
  267. def annotate_fn(annotation_dict: dict):
  268. """
  269. A decorator that wraps a function with the annotate context manager.
  270. Use this when you want to annotate an entire function instead of a specific code block.
  271. Note:
  272. This API is **not backward compatible** and may evolve in future releases.
  273. Note:
  274. This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
  275. to be used with PT2 family of tracers, e.g. torch.export and dynamo.
  276. Args:
  277. annotation_dict (dict): A dictionary of custom key-value pairs to inject
  278. into the FX trace metadata for all operations in the function.
  279. Example:
  280. All operations in my_function will have {"pp_stage": 1} in their metadata.
  281. >>> @annotate_fn({"pp_stage": 1})
  282. ... def my_function(x):
  283. ... return x + 1
  284. """
  285. from functools import wraps
  286. def decorator(func):
  287. @wraps(func)
  288. def wrapper(*args, **kwargs):
  289. with annotate(annotation_dict):
  290. return func(*args, **kwargs)
  291. return wrapper
  292. return decorator
  293. @compatibility(is_backward_compatible=False)
  294. def set_grad_fn_seq_nr(seq_nr):
  295. global current_meta
  296. if should_preserve_node_meta:
  297. # The seq_nr is captured by eager mode in the grad_fn during forward
  298. current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [
  299. seq_nr
  300. ]
  301. current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1
  302. @compatibility(is_backward_compatible=False)
  303. def reset_grad_fn_seq_nr():
  304. # NB: reset state properly, this would be helpful towards supporting
  305. # reentrant autograd if we actually wanted to do that.
  306. global current_meta
  307. if should_preserve_node_meta:
  308. current_level = current_meta.get("in_grad_fn", 0)
  309. assert current_level > 0
  310. if current_level == 1:
  311. del current_meta["in_grad_fn"]
  312. del current_meta["grad_fn_seq_nr"]
  313. else:
  314. current_meta["in_grad_fn"] = current_level - 1
  315. current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
  316. @compatibility(is_backward_compatible=False)
  317. def format_stack() -> list[str]:
  318. if should_preserve_node_meta:
  319. return [current_meta.get("stack_trace", "")]
  320. else:
  321. # fallback to traceback.format_stack()
  322. return traceback.format_list(traceback.extract_stack()[:-1])
  323. @compatibility(is_backward_compatible=False)
  324. def has_preserved_node_meta() -> bool:
  325. return should_preserve_node_meta
  326. @compatibility(is_backward_compatible=False)
  327. @contextmanager
  328. def set_current_meta(node, pass_name=""):
  329. global current_meta
  330. if should_preserve_node_meta and node.meta:
  331. saved_meta = current_meta
  332. try:
  333. current_meta = node.meta.copy()
  334. # Update the "from_node" field in current_meta for provenance tracking.
  335. # Instead of appending, overwrite the "from_node" field because current_meta
  336. # will be assigned to the new node. The new NodeSource(node, ...) will
  337. # include the information from the previous current_meta["from_node"].
  338. current_meta["from_node"] = [
  339. NodeSource(node, pass_name, NodeSourceAction.CREATE)
  340. ]
  341. yield
  342. finally:
  343. current_meta = saved_meta
  344. else:
  345. yield
  346. @compatibility(is_backward_compatible=False)
  347. def get_current_meta() -> dict[str, Any]:
  348. return current_meta
  349. @compatibility(is_backward_compatible=False)
  350. @contextmanager
  351. def set_current_replay_node(node):
  352. """
  353. Set the currently replay node. If `current_replay_node` is not None,
  354. then we're re-generating the `current_replay_node` in FunctionalTensorMode.
  355. """
  356. # See [Note] annotation for more details.
  357. global current_replay_node
  358. saved_current_replay_node = current_replay_node
  359. try:
  360. current_replay_node = node
  361. yield
  362. finally:
  363. current_replay_node = saved_current_replay_node
  364. @compatibility(is_backward_compatible=False)
  365. def get_current_replay_node():
  366. """
  367. Get the currently replay node
  368. """
  369. return current_replay_node
  370. @compatibility(is_backward_compatible=False)
  371. def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
  372. """
  373. Given an fx.Graph, return a json that contains the provenance information of each node.
  374. """
  375. try:
  376. provenance_tracking_json = {}
  377. for node in graph.nodes:
  378. if node.op == "call_function":
  379. provenance_tracking_json[node.name] = (
  380. [source.to_dict() for source in node.meta["from_node"]]
  381. if "from_node" in node.meta
  382. else []
  383. )
  384. return provenance_tracking_json
  385. except Exception as e:
  386. # Since this is just debugging, it should never interfere with regular
  387. # program execution, so we use this try-except to guard against any error
  388. signpost_event(
  389. "inductor",
  390. "provenance_tracking_error",
  391. {
  392. "function": "get_graph_provenance_json",
  393. "error_msg": str(e),
  394. "stack_trace": traceback.format_exc(),
  395. },
  396. )
  397. return {}
  398. def _get_custom_metadata(gm: GraphModule) -> str:
  399. assert isinstance(gm, GraphModule)
  400. def helper(gm: GraphModule):
  401. custom_metadata = []
  402. for node in gm.graph.nodes:
  403. if hasattr(node, "meta") and node.meta.get("custom", None):
  404. custom_metadata.append((node.op, node.name, node.meta["custom"]))
  405. if node.op == "get_attr" and isinstance(
  406. getattr(gm, node.target), GraphModule
  407. ):
  408. custom_metadata.append(helper(getattr(gm, node.target)))
  409. return custom_metadata
  410. return "\n".join(str(x) for x in helper(gm))