compiled_autograd.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620
  1. """
  2. Provides functionality for compiling PyTorch's autograd (automatic differentiation) system.
  3. This module implements compiled autograd, which traces and optimizes backward pass
  4. computations at runtime. The key components are:
  5. - AutogradCompilerInstance: Traces and compiles autograd graphs using FX
  6. - Context managers (_enable/_disable): Control when compiled autograd is active
  7. - Utility functions: Support graph manipulation, tensor operations, and hooks
  8. Compiled autograd can significantly improve backward pass performance by removing
  9. Python overhead and enabling additional optimizations. It works by capturing
  10. backward computations into an FX graph that can be compiled and optimized,
  11. while maintaining the same semantics as eager mode autograd.
  12. """
  13. import contextlib
  14. import functools
  15. import itertools
  16. import operator
  17. import time
  18. from collections import Counter, defaultdict
  19. from collections.abc import Generator, Sequence
  20. from typing import Any, Callable, Optional, TYPE_CHECKING, Union
  21. import torch
  22. import torch.utils._pytree as pytree
  23. from torch._dispatch.python import enable_python_dispatcher
  24. from torch._dynamo.external_utils import (
  25. call_accumulate_grad,
  26. call_backward,
  27. call_hook,
  28. FakeCompiledAutogradEngine,
  29. unwrap_maybe_dynamic_int,
  30. )
  31. from torch._dynamo.source import GetItemSource, LocalSource
  32. from torch._dynamo.utils import (
  33. counters,
  34. get_chromium_event_logger,
  35. lazy_format_graph_code,
  36. set_locals_to_steal,
  37. )
  38. from torch._functorch._aot_autograd.runtime_wrappers import (
  39. AutogradLazyBackwardCompileInfo,
  40. CachedAutogradLazyBackwardCompileInfo,
  41. )
  42. from torch._guards import compile_context, CompileContext, CompileId, Source
  43. from torch._logging import getArtifactLogger, trace_structured
  44. from torch._prims_common import clone_preserve_strides
  45. from torch._subclasses import FakeTensorMode
  46. from torch._subclasses.fake_tensor import FakeTensor
  47. from torch.fx import GraphModule
  48. from torch.fx.experimental._backward_state import BackwardState
  49. from torch.fx.experimental.proxy_tensor import (
  50. decompose,
  51. disable_autocast_cache,
  52. disable_proxy_modes_tracing,
  53. fetch_object_proxy,
  54. ProxyTorchDispatchMode,
  55. PythonKeyTracer,
  56. track_tensor_tree,
  57. )
  58. from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
  59. from torch.fx.traceback import preserve_node_meta, set_stack_trace
  60. from torch.types import FloatLikeType, IntLikeType
  61. from torch.utils._ordered_set import OrderedSet
  62. from torch.utils._traceback import CapturedTraceback
  63. if TYPE_CHECKING:
  64. from torch.fx.proxy import Proxy
# Appended to error messages raised by this module (e.g. the unsupported
# input-tensor error in AutogradCompilerInstance.begin_capture) to tell users
# how to opt out of compiled autograd.
TURN_OFF_MSG = """You can turn off compiled autograd by either:
1. Moving the unsupported autograd call outside of the torch.compile'd region.
2. Wrapping the unsupported autograd call in the torch._dynamo.compiled_autograd._disable() context manager.
3. Setting torch._dynamo.config.compiled_autograd=False for the torch.compile call containing the unsupported autograd call.
4. Setting torch._dynamo.config.compiled_autograd=False at the start of the program."""
# Artifact loggers for this module. Whether the "compiled_autograd_verbose"
# artifact is enabled is queried by snapshot_verbose_logging_enabled().
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
  72. def snapshot_verbose_logging_enabled() -> bool:
  73. return torch._logging._internal.log_state.is_artifact_enabled(
  74. "compiled_autograd_verbose"
  75. )
  76. def snapshot_cudagraph_enabled() -> bool:
  77. return torch._inductor.config.triton.cudagraphs
  78. def maybe_clone(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
  79. if x is not None:
  80. return clone_preserve_strides(x)
  81. return x
  82. def extract_bw_module(CompiledFunction: Any) -> Callable[..., Any]:
  83. if isinstance(
  84. CompiledFunction._lazy_backward_info, AutogradLazyBackwardCompileInfo
  85. ):
  86. return CompiledFunction._lazy_backward_info.bw_module
  87. elif isinstance(
  88. CompiledFunction._lazy_backward_info, CachedAutogradLazyBackwardCompileInfo
  89. ):
  90. with torch._subclasses.fake_tensor.unset_fake_temporarily():
  91. return CompiledFunction._lazy_backward_info.bw_module_fn()
  92. else:
  93. raise AssertionError(
  94. "Unexpected Lazy Backward Compilation Info Type. Please file an issue."
  95. )
  96. # Note: [Anomaly Mode Semantics in Compiled Autograd]
# In the eager autograd engine, anomaly mode is able to detect NaNs
# after each node. This is useful because the code executed with and
# without anomaly mode is the same, so, assuming determinism, a NaN
# produced in regular mode will also be produced in anomaly mode.
  101. #
# With torch.compile, following eager semantics would require inserting
# runtime asserts to check for NaNs, which could prevent some fusions.
# That would make the code that runs with and without anomaly mode differ.
# So different semantics are needed: the implementation below checks
# for NaNs at the end of the autograd call, instead of after each node.
  107. class NaNChecker:
  108. def __init__(self, accumulate_grad: bool) -> None:
  109. self.accumulate_grad = accumulate_grad
  110. self.params_indices: list[int] = []
  111. self.params_to_check: dict[str, torch.Tensor] = {}
  112. self.output_names: list[str] = []
  113. def prep_with_graph(self, graph: torch.fx.Graph) -> None:
  114. inputs_node = next(iter(graph.nodes))
  115. acc_grad_nodes = graph.find_nodes(
  116. op="call_function", target=call_accumulate_grad
  117. )
  118. output_nodes = graph.find_nodes(op="output")[0].args[0]
  119. assert self.accumulate_grad == bool(
  120. acc_grad_nodes
  121. ) and self.accumulate_grad == (not output_nodes)
  122. for node in acc_grad_nodes:
  123. param_node = node.args[0]
  124. # AccumulateGrad always saves a reference to the param
  125. # so Compiled Autograd will always lift the param and
  126. # this should always be true
  127. assert (
  128. param_node.target == operator.getitem
  129. and param_node.args[0] is inputs_node # type: ignore[possibly-undefined]
  130. and isinstance(param_node.args[1], int)
  131. )
  132. self.params_indices.append(param_node.args[1])
  133. self.output_names = [node.name for node in output_nodes]
  134. def prep_with_inputs(self, inputs: tuple[torch.Tensor]) -> None:
  135. if not self.accumulate_grad:
  136. # Using .grad, nothing to prep
  137. return
  138. # Using .backward, we must check existing grads on params if any
  139. for idx in self.params_indices:
  140. grad = inputs[idx].grad
  141. if grad is not None:
  142. assert not torch.isnan(grad).any(), (
  143. f"Compiled autograd running under anomaly mode with inputs[{idx}] already "
  144. "having NaN gradient. This is not supported. {TURN_OFF_MSG}"
  145. )
  146. self.params_to_check[f"inputs[{idx}]"] = inputs[idx]
  147. def check(self, out: tuple[torch.Tensor]) -> None:
  148. if self.accumulate_grad:
  149. # Using .backward, graph outputs are empty
  150. assert not out
  151. nan_params: list[str] = []
  152. for inputs_str, param in self.params_to_check.items():
  153. assert param.grad is not None # not true for autograd.grad
  154. if torch.isnan(param.grad).any():
  155. nan_params.append(inputs_str)
  156. if nan_params:
  157. raise RuntimeError(
  158. f"Compiled Autograd returned NaN gradients for parameters: {','.join(nan_params)}."
  159. )
  160. else:
  161. # Using .grad, graph outputs are grads
  162. nan_grads: list[str] = []
  163. for i, grad in enumerate(out):
  164. if torch.isnan(grad).any():
  165. nan_grads.append(self.output_names[i])
  166. if nan_grads:
  167. raise RuntimeError(
  168. f"Compiled Autograd returned NaN gradients for output nodes: {','.join(nan_grads)}."
  169. )
  170. # We lazily bind "functional backward" variants for PyTorch built-in autograd
  171. # nodes to this class. Example: torch._dynamo.compiled_autograd.ops.MulBackward0
  172. # Each "functional backward" is bound the first time the node's apply_with_saved
  173. # function is called. It's possible to avoid lazy binding and instead bind
  174. # all of this upfront (perhaps at import time) via codegen changes.
  175. class OpNamespace:
  176. def __init__(self) -> None:
  177. self.custom_function_name_counter: Counter[str] = Counter()
  178. def add(
  179. self,
  180. name: str,
  181. fn: Callable[..., Any],
  182. is_custom_function: bool,
  183. is_traceable: bool,
  184. ) -> str:
  185. if is_custom_function:
  186. name = "CppNode" + name
  187. count = self.custom_function_name_counter[name]
  188. self.custom_function_name_counter[name] += 1
  189. name = f"{name}{count}"
  190. assert not hasattr(self, name)
  191. result = Op(name, fn, is_custom_function)
  192. if is_traceable:
  193. setattr(self, name, torch._dynamo.allow_in_graph(result))
  194. else:
  195. # C++ autograd function was not marked as traceable
  196. # Dynamo can't dry run it at compile time, so must fallback to eager
  197. @torch._dynamo.disable # type: ignore[misc]
  198. def run_non_traceable_cpp_in_eager(*args: Any, **kwargs: Any) -> Any:
  199. return result(*args, **kwargs)
  200. setattr(self, name, run_non_traceable_cpp_in_eager)
  201. return name
  202. def get(self, name: str) -> Any:
  203. return getattr(self, name)
  204. class Op:
  205. def __init__(
  206. self, name: str, fn: Callable[..., Any], is_custom_function: bool
  207. ) -> None:
  208. self.fn = fn
  209. self.is_custom_function = is_custom_function
  210. self.__name__ = name
  211. self.__module__ = "torch._dynamo.compiled_autograd.ops"
  212. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  213. return self.fn(*args, **kwargs)
  214. def __repr__(self) -> str:
  215. return self.__module__ + "." + self.__name__
  216. ops = OpNamespace()
  217. _graph_placeholders = ["inputs", "sizes", "scalars", "hooks", "packed_data"]
  218. _impure_targets = OrderedSet(
  219. [
  220. call_hook,
  221. call_backward,
  222. FakeCompiledAutogradEngine._exec_final_callbacks_stub,
  223. call_accumulate_grad,
  224. ]
  225. )
  226. COMPILE_COUNTER = itertools.count()
  227. def make_compile_context(compiled_autograd_id: int) -> Any:
  228. return compile_context(
  229. CompileContext(
  230. CompileId(
  231. compiled_autograd_id=compiled_autograd_id,
  232. frame_id=None,
  233. frame_compile_id=None,
  234. )
  235. )
  236. )
  237. class AutogradCompilerInstance:
  238. def __init__(self, compiler_fn: Callable[..., Any]) -> None:
  239. self.compiler_fn = compiler_fn
  240. self.stack = contextlib.ExitStack()
  241. self.close = self.stack.close
  242. self.shape_env = ShapeEnv()
  243. self.fake_tensor_mode = FakeTensorMode(
  244. allow_fallback_kernels=True,
  245. allow_non_fake_inputs=True,
  246. shape_env=self.shape_env,
  247. )
  248. self.fx_tracer = PythonKeyTracer()
  249. self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
  250. self.hooks_proxy: Optional[Proxy] = None
  251. def wrap_fake(self, x: torch.Tensor, source: Optional[Source]) -> FakeTensor:
  252. assert isinstance(x, torch.Tensor)
  253. return self.fake_tensor_mode.from_tensor(x, source=source)
  254. @staticmethod
  255. def source(name: str, idx: Any) -> GetItemSource:
  256. return GetItemSource(LocalSource(name), idx)
  257. def begin_capture(
  258. self,
  259. inputs: list[torch.Tensor],
  260. sizes: list[int],
  261. scalars: list[Union[int, float]],
  262. origins: list[list[tuple[int, str]]],
  263. accumulate_grad: bool,
  264. check_nans: bool,
  265. ) -> tuple[str, list[torch.Tensor], list[IntLikeType], list[FloatLikeType]]:
  266. counters["compiled_autograd"]["captures"] += 1
  267. self.id = next(COMPILE_COUNTER)
  268. self.aot_id_counter: dict[int, int] = defaultdict(int)
  269. self.compile_context = make_compile_context(self.id)
  270. self.compile_context.__enter__()
  271. self.nan_checker = NaNChecker(accumulate_grad) if check_nans else None
  272. self.start_time_ns = time.time_ns()
  273. get_chromium_event_logger().log_event_start(
  274. "compiled_autograd",
  275. self.start_time_ns,
  276. {"graph_id": self.id},
  277. log_pt2_compile_event=True,
  278. )
  279. self.fx_tracer.root = torch.nn.Module()
  280. self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
  281. self.fx_tracer.tensor_attrs = {}
  282. self.symnode_proxy_lookup = {}
  283. (
  284. args_proxy,
  285. self.sizes_proxy,
  286. self.scalars_proxy,
  287. self.hooks_proxy,
  288. self.packed_data_proxy,
  289. ) = (
  290. self.fx_tracer.create_proxy("placeholder", name, (), {})
  291. for name in _graph_placeholders
  292. )
  293. self.stack.enter_context(preserve_node_meta())
  294. inputs_origins, sizes_origins, scalars_origins = origins
  295. # Turn on PythonDispatcher during initial trace to make it identifiable
  296. # that tracing is happening, which is needed to prevent hashing symints
  297. self.stack.enter_context(enable_python_dispatcher())
  298. # tensor inputs to fake tensors
  299. x = inputs[0] # mypy will complain about unbound x
  300. try:
  301. for idx, x in enumerate(inputs):
  302. inputs[idx] = self.wrap_fake(x, self.source("inputs", idx))
  303. except Exception as e:
  304. raise NotImplementedError(
  305. f"Found tensor of type {type(x)}, which is not supported by FakeTensorMode. {TURN_OFF_MSG}"
  306. ) from e
  307. self.bind_objects_to_proxies(inputs, args_proxy, inputs_origins)
  308. # size inputs to symints
  309. sym_sizes = [
  310. self.shape_env.create_unspecified_symint_and_symbol(
  311. val,
  312. self.source("sizes", idx),
  313. DimDynamic.DYNAMIC,
  314. )
  315. for idx, val in enumerate(sizes)
  316. ]
  317. # We want to mark every size as dynamic, but since there's no way to
  318. # mark a primitive `int` as dynamic, we need to wrap it in a tensor.
  319. # In the graph, we unwrap it with `unwrap_maybe_dynamic_int` back into a primitive.
  320. proxies = [self.sizes_proxy[i] for i in range(len(sym_sizes))] # type: ignore[index]
  321. for i, symint in enumerate(sym_sizes):
  322. proxies[i] = self.fx_tracer.create_proxy(
  323. "call_function",
  324. unwrap_maybe_dynamic_int,
  325. (proxies[i],),
  326. {},
  327. )
  328. self.symnode_proxy_lookup[symint.node] = proxies[i]
  329. proxies = self.bind_objects_to_proxies(sym_sizes, proxies, sizes_origins)
  330. for idx, val in enumerate(scalars):
  331. source = self.source("scalars", idx)
  332. if isinstance(val, int):
  333. scalars[idx] = self.shape_env.create_unspecified_symint_and_symbol(
  334. val,
  335. source,
  336. DimDynamic.DYNAMIC,
  337. )
  338. elif isinstance(val, float):
  339. scalars[idx] = self.shape_env.create_symfloatnode(
  340. self.shape_env.create_unspecified_symbol(
  341. val,
  342. source=source,
  343. dynamic_dim=DimDynamic.DYNAMIC,
  344. ),
  345. hint=val,
  346. source=source,
  347. )
  348. else:
  349. raise AssertionError("Unexpected scalar type: ", type(val))
  350. self.bind_objects_to_proxies(scalars, self.scalars_proxy, scalars_origins)
  351. for i, symval in enumerate(scalars):
  352. self.symnode_proxy_lookup[symval.node] = self.scalars_proxy[i] # type: ignore[union-attr]
  353. # TODO(jansel): are all these modes needed?
  354. self.stack.enter_context(decompose({}))
  355. self.stack.enter_context(self.fake_tensor_mode)
  356. self.stack.enter_context(self.proxy_mode)
  357. self.stack.enter_context(disable_autocast_cache())
  358. # Needed to make sure we don't accidentally specialize any symbols
  359. assert self.fake_tensor_mode.shape_env is not None
  360. env = self.fake_tensor_mode.shape_env
  361. self.stack.enter_context(
  362. torch.fx.experimental.symbolic_shapes._suppress_guards(env)
  363. )
  364. return (
  365. str(CompileContext.current_compile_id()),
  366. inputs,
  367. sym_sizes,
  368. scalars, # type: ignore[return-value]
  369. )
  370. def log_compile_reasons(
  371. self,
  372. compile_reasons: list[str],
  373. ) -> None:
  374. assert compile_reasons
  375. trace_structured(
  376. "artifact",
  377. metadata_fn=lambda: {
  378. "name": "compiled_autograd_compile_reasons",
  379. "encoding": "json",
  380. },
  381. payload_fn=lambda: compile_reasons,
  382. )
    def proxy_call_aot_backward(
        self,
        pinputs: Sequence[Any],
        psaved_tensors: Sequence[torch.Tensor],
        saved_tensors: Sequence[torch.Tensor],
        pctx: Any,
        ctx: Any,
        maybe_backward_state_idx: Optional[int],
    ) -> Sequence[Any]:
        """Inline an AOTAutograd backward into the compiled autograd graph.

        Emits an ``allow_in_graph`` prologue node. Copies every node of the
        AOT backward module into ``self.fx_tracer``'s graph, renamed with an
        ``aot<id>_`` prefix. Then traces the epilogue directly.

        Args:
            pinputs: Proxies for the incoming gradients.
            psaved_tensors: Proxies for the saved tensors (prologue input).
            saved_tensors: Not referenced in this method.
            pctx: Not referenced in this method.
            ctx: Backward context; supplies ``_forward_cls`` (the AOT
                CompiledFunction) and the symints.
            maybe_backward_state_idx: Index into the hooks placeholder of the
                BackwardState, if the AOT backward takes one.

        Returns:
            Proxies for the epilogue results (via ``pytree.tree_map(self.to_proxy, ...)``).

        Raises:
            RuntimeError: if grad mode is enabled and any forward output
                requires grad (higher-order gradients are unsupported).
        """
        # The AOTBackward call consists of three things: the prologue, the
        # backward graph, and the epilogue.
        # Our strategy is:
        # - allow_in_graph the prologue (in the CA graph and Dynamo graph),
        # - copy-paste the backward graph into the CA graph so that CA passes and Dynamo can see it
        # - trace directly through the epilogue. Anything that gets baked in is
        #   constant metadata (for example, metadata about the number of outputs, or removing
        #   RNG arguments or effect tokens).
        # If Dynamo graph capture were better, then we could add a node for the prologue
        # into the CA graph and have Dynamo trace into it.

        psymints = [self.to_proxy(e) for e in ctx._get_compiled_autograd_symints()]

        # NOTE: we should only close over constants
        CompiledFunction = ctx._forward_cls
        bw_module = extract_bw_module(CompiledFunction)
        metadata = CompiledFunction.metadata
        maybe_subclass_metadata = CompiledFunction.maybe_subclass_metadata
        aot_id = CompiledFunction._aot_id
        # Drop the class so the nested functions below cannot close over it.
        del CompiledFunction

        if torch.is_grad_enabled():
            for output_alias_info in metadata.output_info:
                if output_alias_info.requires_grad:
                    raise RuntimeError(
                        "torch.compile does not currently support higher order gradients."
                    )

        @torch._dynamo.allow_in_graph  # type: ignore[misc]
        def call_aot_bwd_prologue(
            ctx_saved_tensors: Sequence[torch.Tensor],
            ctx_symints: Sequence[IntLikeType],
            *flat_args: Sequence[Any],
        ) -> Any:
            out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional(
                ctx_saved_tensors,
                ctx_symints,
                metadata,
                maybe_subclass_metadata,
                *flat_args,
            )
            return out

        # A single graph node whose indexed outputs feed the copied placeholders.
        pgrads = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_aot_bwd_prologue,
            args=(
                psaved_tensors,
                psymints,
                *pinputs,
            ),
            kwargs={},
        )

        pbackward_state = None
        if maybe_backward_state_idx is not None:
            pbackward_state = self.hooks_proxy[maybe_backward_state_idx]  # type: ignore[index]

        # Copy-paste the AOT backward graph into the compiled autograd graph
        def copy_paste_aot_backward_graph() -> list[torch.Tensor]:
            # Count the leading placeholder nodes of ``graph``.
            def num_inputs(graph: torch.fx.Graph) -> int:
                num_args = 0
                for node in graph.nodes:
                    if node.op == "placeholder":
                        num_args += 1
                        continue
                    else:
                        break
                return num_args

            # set up the proxy inputs to bw_module
            # the calling convention is: [*symints, *args (primals and tangents), backward_state]
            num_args = num_inputs(bw_module.graph)  # type: ignore[attr-defined]
            pall_args = [
                pgrads[i] for i in range(num_args - int(pbackward_state is not None))
            ]
            # replace the symints with our symints
            symints = ctx._get_compiled_autograd_symints()
            assert len(symints) == len(ctx.symints)
            psymints = [self.to_proxy(e) for e in symints]
            pall_args[: len(symints)] = psymints
            # Add backward_state
            if pbackward_state is not None:
                pall_args.append(pbackward_state)

            # run over all nodes of the aot_backward graph.
            # copy and paste them all into the compiled autograd graph.
            args_idx = 0
            # Maps bw_module nodes to their counterparts in our graph.
            value_remap = {}
            poutputs: Optional[list[torch.fx.Proxy]] = None

            # names of nodes must appear only once in the fx.Graph
            # dedup AOT backwards that appear multiple times
            deduped_aot_id = str(aot_id)
            if self.aot_id_counter[aot_id]:
                deduped_aot_id += f"_{self.aot_id_counter[aot_id]}"
            self.aot_id_counter[aot_id] += 1

            def make_unique(node_name: str) -> str:
                # make it both informative and unique
                return f"aot{deduped_aot_id}_{node_name}"

            for node in bw_module.graph.nodes:  # type: ignore[attr-defined]
                if node.op == "placeholder":
                    # Reuse (and rename) the existing node behind the proxy argument.
                    ph = pall_args[args_idx].node
                    ph.name = make_unique(node.name)
                    value_remap[node] = ph
                    args_idx += 1
                elif node.op == "output":
                    assert len(node.args) == 1
                    poutputs = [
                        torch.fx.Proxy(value_remap[n], self.fx_tracer)
                        if isinstance(n, torch.fx.Node)
                        else n
                        for n in node.args[0]
                    ]
                elif node.op == "get_attr":
                    name = node.target
                    qualname = self.fx_tracer.get_fresh_qualname(name)
                    setattr(self.fx_tracer.root, qualname, getattr(bw_module, name))
                    result = self.fx_tracer.create_node("get_attr", qualname, (), {})
                    result.name = make_unique(node.name)
                    value_remap[node] = result
                elif node.op == "call_function":
                    if node.target == torch.ops.aten.view.default:
                        # this aot bwd graph is being lazily compiled
                        # we must manually apply the view_to_reshape post grad pass
                        # since it was already applied to the aot fwd, and baked into the gradients
                        # NOTE(review): this rewrites the target on bw_module's own node in place.
                        node.target = torch.ops.aten.reshape.default
                    result = self.fx_tracer.graph.node_copy(
                        node, lambda n: value_remap[n]
                    )
                    result.name = make_unique(node.name)
                    value_remap[node] = result
                elif node.op == "call_module":
                    name = node.target
                    qualname = self.fx_tracer.get_fresh_qualname(name)
                    setattr(self.fx_tracer.root, qualname, getattr(bw_module, name))
                    result = self.fx_tracer.graph.node_copy(
                        node, lambda n: value_remap[n]
                    )
                    result.target = qualname
                    value_remap[node] = result
                else:
                    raise AssertionError("shouldn't get here")

            assert poutputs is not None

            # In general we don't know what the shapes of the outputs are, so allocate
            # some dummy sizes for them.
            def dummy() -> torch.Tensor:
                with disable_proxy_modes_tracing():
                    return torch.zeros(0, 0, 0, 0, 123)

            # Non-proxy outputs (constants) pass through unchanged.
            outputs = [
                dummy() if isinstance(o, torch.fx.Proxy) else o for o in poutputs
            ]
            self.bind_objects_to_proxies(outputs, poutputs)
            return outputs

        outputs = copy_paste_aot_backward_graph()

        # Passed to the epilogue so tensor-subclass construction is recorded as
        # a single allow_in_graph node instead of being executed.
        def proxy_subclass_constructor(
            subclass_meta: Any, is_runtime: bool, unwrapped_args: Sequence[Any]
        ) -> torch.Tensor:
            @torch._dynamo.allow_in_graph  # type: ignore[misc]
            def make_subclass(*unwrapped_args: Any) -> Any:
                return subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime)

            punwrapped_args = pytree.tree_map(self.to_proxy, unwrapped_args)

            poutput = self.fx_tracer.create_proxy(
                kind="call_function",
                target=make_subclass,
                args=tuple(punwrapped_args),
                kwargs={},
            )

            output = self.allocate_dummy()
            self.bind_objects_to_proxies([output], [poutput])
            return output

        results = torch._functorch._aot_autograd.runtime_wrappers._backward_epilogue_functional(
            metadata,
            maybe_subclass_metadata,
            outputs,
            make_subclass_override=proxy_subclass_constructor,
        )
        presults = pytree.tree_map(self.to_proxy, results)
        return presults
    def proxy_call_backward(
        self,
        inputs: Sequence[Any],
        output_metadatas: Sequence[Any],
        saved_tensors: Sequence[torch.Tensor],
        backward_idx: int,
        ctx: torch.autograd.function.BackwardCFunction,
        maybe_backward_state_idx: Optional[int],
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Record a custom autograd Function's backward into the graph.

        AOTAutograd-generated functions (``ctx._forward_cls`` has ``_aot_id``)
        are inlined via ``proxy_call_aot_backward``. All others become a single
        ``call_backward`` node.

        Args:
            inputs: Incoming gradients.
            output_metadatas: Per grad-input ``(layout, device, dtype, size)``,
                or ``None`` when that grad-input is absent.
            saved_tensors: The ctx's saved tensors.
            backward_idx: Index into the hooks placeholder for the ctx.
            ctx: The backward context object.
            maybe_backward_state_idx: Forwarded to ``proxy_call_aot_backward``.

        Returns:
            One placeholder tensor (from ``torch.empty``) or ``None`` per
            entry of ``output_metadatas``, bound to the corresponding proxies.
        """
        assert self.hooks_proxy is not None
        pctx = self.hooks_proxy[backward_idx]  # type: ignore[index]
        pinputs = self.to_proxy(inputs)
        psaved_tensors = self.to_proxy(saved_tensors)
        if hasattr(ctx._forward_cls, "_aot_id"):  # type: ignore[attr-defined]
            # AOT backward
            proxies = self.proxy_call_aot_backward(
                pinputs,
                psaved_tensors,
                saved_tensors,
                pctx,
                ctx,
                maybe_backward_state_idx,
            )
        else:
            proxies = self.fx_tracer.create_proxy(
                kind="call_function",
                target=call_backward,
                args=(
                    pctx,
                    psaved_tensors,
                    *pinputs,
                ),
                kwargs={},
            )
        assert proxies is not None

        with disable_proxy_modes_tracing():
            # create fake Tensors
            grad_ins: list[Optional[torch.Tensor]] = []
            for idx, output_metadata in enumerate(output_metadatas):
                # NOTE(review): when ``proxies`` is an FX Proxy, ``proxies[idx]``
                # records a getitem node; the short-circuit order here decides
                # which of those get created, so keep this condition as-is.
                if output_metadata is None or proxies[idx] is None:
                    grad_ins.append(None)
                    continue

                layout, device, dtype, size = output_metadata
                grad_ins.append(
                    torch.empty(size=size, dtype=dtype, layout=layout, device=device)
                )
            self.bind_objects_to_proxies(grad_ins, proxies)
        return tuple(grad_ins)
  609. def call_copy_slices_prologue(
  610. self,
  611. inputs: Sequence[Any],
  612. base_sizes: Sequence[Any],
  613. base_strides: Sequence[Any],
  614. base_storage_offset: Any,
  615. view_sizes: Sequence[Any],
  616. view_strides: Sequence[Any],
  617. view_storage_offset: Any,
  618. ) -> Sequence[torch.Tensor]:
  619. args = (
  620. inputs,
  621. self.to_proxy(base_sizes),
  622. self.to_proxy(base_strides),
  623. self.to_proxy(base_storage_offset),
  624. self.to_proxy(view_sizes),
  625. self.to_proxy(view_strides),
  626. self.to_proxy(view_storage_offset),
  627. )
  628. return self.proxy_call(copy_slices_prologue, args, [None] * 3)
  629. def call_copy_slices_epilogue(
  630. self,
  631. needs_input_grad: Sequence[bool],
  632. result: torch.Tensor,
  633. res: Sequence[Any],
  634. grad_slice: torch.Tensor,
  635. ) -> Sequence[torch.Tensor]:
  636. return self.proxy_call(
  637. copy_slices_epilogue,
  638. (needs_input_grad, result, res, grad_slice),
  639. [None] * len(needs_input_grad),
  640. )
  641. def allocate_dummy(self) -> torch.Tensor:
  642. with disable_proxy_modes_tracing():
  643. # Weird quantity so it's easy to grep
  644. return torch.zeros([0, 123456789])
  645. def bind_function(
  646. self,
  647. fn_name: str,
  648. fn: Callable[..., Any],
  649. is_custom_function: bool,
  650. is_traceable: bool,
  651. ) -> str:
  652. """Binds ops.fn_name = fn"""
  653. return ops.add(fn_name, fn, is_custom_function, is_traceable)
  654. def apply_functional(
  655. self,
  656. fn_name: str,
  657. grads: Sequence[Any],
  658. args: Any,
  659. output_metadata: Sequence[Any],
  660. ) -> Sequence[torch.Tensor]:
  661. """Proxies a call to ops.fn_name(grads, *args) into the graph"""
  662. op = ops.get(fn_name)
  663. return self.proxy_call(op, (grads, *args), output_metadata)
  664. def proxy_call(
  665. self, fn: Callable[..., Any], args: Any, output_metadata: Sequence[Any]
  666. ) -> Sequence[torch.Tensor]:
  667. """Proxies a call to fn(*args) into the graph"""
  668. flat_args, _ = pytree.tree_flatten(args)
  669. proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args)
  670. proxy_out = self.fx_tracer.create_proxy(
  671. "call_function", fn, args=proxy_args, kwargs={}
  672. )
  673. result = [self.allocate_dummy() for _ in output_metadata]
  674. self.bind_objects_to_proxies(result, [proxy_out[i] for i in range(len(result))])
  675. return result
  676. def validate_outputs(
  677. self, _: Any, outputs: Sequence[Any], args: Any, output_metadata: Sequence[Any]
  678. ) -> Sequence[torch.Tensor]:
  679. """Proxies a call to ops.validate_outputs(outputs, *args) into the graph"""
  680. op = ops.get("validate_outputs")
  681. proxy_args = pytree.tree_map(self.to_proxy, (outputs, *args))
  682. new_proxy_outputs = self.fx_tracer.create_proxy(
  683. "call_function", op, args=proxy_args, kwargs={}
  684. )
  685. assert len(output_metadata) == len(outputs)
  686. self.bind_objects_to_proxies(outputs, new_proxy_outputs)
  687. return outputs
  688. def accumulate(self, old_var: Any, new_var: Any) -> torch.Tensor:
  689. old_var_proxy = self.to_proxy(old_var)
  690. new_var_proxy = self.to_proxy(new_var)
  691. proxy_out = self.fx_tracer.create_proxy(
  692. "call_function", torch.add, args=(old_var_proxy, new_var_proxy), kwargs={}
  693. )
  694. result = self.allocate_dummy()
  695. self.bind_objects_to_proxies([result], [proxy_out])
  696. return result
  697. def accumulate_grad(
  698. self, variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool
  699. ) -> None:
  700. self.fx_tracer.create_proxy(
  701. "call_function",
  702. call_accumulate_grad,
  703. args=(
  704. self.to_proxy(variable),
  705. self.to_proxy(grad),
  706. has_post_hooks,
  707. ),
  708. kwargs={},
  709. )
  710. def proxy_call_hook(
  711. self, hook: Callable[..., Any], *args: Any, **kwargs: Any
  712. ) -> torch.fx.Proxy:
  713. return self.fx_tracer.create_proxy(
  714. "call_function",
  715. call_hook,
  716. (
  717. hook,
  718. *[self.to_proxy(x) for x in args],
  719. ),
  720. kwargs,
  721. )
  722. def unpack_hook(self, hook_id: int, data_id: int) -> torch.Tensor:
  723. assert self.hooks_proxy is not None
  724. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  725. data = self.packed_data_proxy[data_id] # type: ignore[index]
  726. proxy = self.proxy_call_hook(
  727. hook,
  728. data,
  729. hook_type="unpack_hook",
  730. )
  731. out = self.allocate_dummy()
  732. self.bind_objects_to_proxies([out], [proxy])
  733. return out
  734. def tensor_pre_hook(
  735. self, inputs: list[torch.Tensor], hook_id: int, i: int
  736. ) -> list[torch.Tensor]:
  737. assert self.hooks_proxy is not None
  738. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  739. proxy = self.proxy_call_hook(
  740. hook,
  741. inputs[i],
  742. hook_type="tensor_pre_hook",
  743. )
  744. with disable_proxy_modes_tracing():
  745. inputs[i] = maybe_clone(inputs[i]) # type: ignore[assignment]
  746. self.bind_objects_to_proxies([inputs[i]], [proxy])
  747. return inputs
  748. def cpp_tensor_pre_hook(
  749. self, inputs: list[torch.Tensor], hook_id: int, i: int
  750. ) -> list[torch.Tensor]:
  751. proxy = self.fx_tracer.create_proxy(
  752. "call_function",
  753. torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks,
  754. (hook_id, self.to_proxy(inputs[i])),
  755. {},
  756. )
  757. with disable_proxy_modes_tracing():
  758. inputs[i] = maybe_clone(inputs[i]) # type: ignore[assignment]
  759. self.bind_objects_to_proxies([inputs[i]], [proxy])
  760. return inputs
  761. def pre_hook(self, inputs: Sequence[Any], hook_id: int) -> list[torch.Tensor]:
  762. assert self.hooks_proxy is not None
  763. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  764. proxies = self.proxy_call_hook(
  765. hook,
  766. inputs,
  767. hook_type="pre_hook",
  768. )
  769. with disable_proxy_modes_tracing():
  770. inputs = [maybe_clone(x) for x in inputs]
  771. self.bind_objects_to_proxies(inputs, proxies)
  772. return inputs
  773. def post_hook(
  774. self, outputs: list[torch.Tensor], inputs: Sequence[torch.Tensor], hook_id: int
  775. ) -> list[torch.Tensor]:
  776. assert self.hooks_proxy is not None
  777. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  778. proxies = self.proxy_call_hook(
  779. hook,
  780. outputs,
  781. inputs,
  782. hook_type="post_hook",
  783. )
  784. with disable_proxy_modes_tracing():
  785. outputs = [maybe_clone(x) for x in outputs] # type: ignore[misc]
  786. self.bind_objects_to_proxies(outputs, proxies)
  787. return outputs
  788. def post_acc_grad_hook(
  789. self, input: torch.Tensor, hook_id: int
  790. ) -> list[torch.Tensor]:
  791. assert isinstance(input, torch.Tensor)
  792. assert self.hooks_proxy is not None
  793. hook = self.hooks_proxy[hook_id] # type: ignore[index]
  794. proxy = self.proxy_call_hook(
  795. hook,
  796. input,
  797. hook_type="post_acc_grad_hook",
  798. )
  799. with disable_proxy_modes_tracing():
  800. res = [maybe_clone(input)]
  801. self.bind_objects_to_proxies(res, [proxy])
  802. return res # type: ignore[return-value]
    # Note: [Compiled autograd and cudagraphs]
    # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_.
    # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a graph
    # with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping cudagraph, we move the
    # scalars tensors to cuda. This works because ATen/prims ops will accept cuda 0-dim tensors too.
    def move_graph_nodes_to_cuda(self, graph: torch.fx.Graph) -> list[int]:
        """Retag movable CPU 0-dim graph inputs as CUDA and return their indices.

        Only the fake ``meta["val"]`` of each selected input-getitem node is moved
        here; the returned indices are used by the caller to move the matching
        runtime inputs. Returns ``[]`` if no CUDA input was found.
        """
        to_move: dict[int, torch.fx.Node] = {}
        has_cuda_inputs = False
        nodes = list(graph.nodes)
        assert nodes[0].target == "inputs"
        inputs = nodes[0]
        inputs_users = list(inputs.users.keys())
        # input access nodes should immediately follow placeholder nodes
        first_getitem_idx = len(_graph_placeholders)
        assert nodes[first_getitem_idx] == inputs_users[0]
        last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
        assert nodes[last_getitem_idx] == inputs_users[-1]
        # getitem nodes on inputs
        # NOTE(review): ``i`` is the position in inputs_users and is returned as a
        # runtime input index; this assumes the i-th getitem reads inputs[i].
        # The asserts above only check the getitems are contiguous — TODO confirm.
        for i, node in enumerate(inputs_users):
            # The first CUDA input only flips the flag; later CUDA inputs fall
            # through and fail the is_cpu check below.
            if not has_cuda_inputs and node.meta["val"].device.type == "cuda":
                has_cuda_inputs = True
                continue

            is_cpu = node.meta["val"].device.type == "cpu"
            is_scalar = len(node.meta["val"].size()) == 0
            if is_cpu and is_scalar:
                node_users = list(node.users.keys())
                # We can only move the cpu scalar if it is not exposed to user code.
                if all(
                    (
                        isinstance(user.target, torch._ops.OpOverload)
                        and user.target.namespace in ("prims", "aten")
                    )
                    or (
                        isinstance(user.target, Op)
                        and not user.target.is_custom_function
                    )
                    for user in node_users
                ):
                    # all users are prims/aten, can move safely
                    to_move[i] = node

        # only move cpu scalars to cuda if there were cuda activations in this graph,
        # this is to handle the case where cudagraphs is enabled on a cpu-only graph
        if has_cuda_inputs:
            for node in to_move.values():
                verbose_log.debug("Moving node %s from cpu to cuda", node)
                node.meta["val"] = node.meta["val"].cuda()

            # return runtime indices we need to move to cuda
            return list(to_move.keys())

        return []
  852. def is_sym_node(self, node: Any) -> bool:
  853. return (
  854. isinstance(node, torch.fx.Node)
  855. and node.op == "call_function"
  856. and node.target
  857. in [torch.ops.aten.sym_size.int, torch.ops.aten.sym_numel.default]
  858. )
  859. def dce(self) -> None:
  860. # Most of these removed nodes would have been removed during Dynamo and AOTDispatch
  861. # Remove some of these nodes earlier to improve compilation speed
  862. # Dynamo guards will error instead of creating aliasing guards unless we unpack them in the graph
  863. unpack_nodes: OrderedSet[torch.fx.Node] = OrderedSet()
  864. for i, node in enumerate(self.fx_tracer.graph.find_nodes(op="placeholder")):
  865. unpack_nodes.update(node.users.keys())
  866. assert i == len(_graph_placeholders) - 1
  867. def is_impure(node: torch.fx.Node) -> bool:
  868. if node in unpack_nodes or (
  869. node.op == "call_function" and node.target in _impure_targets
  870. ):
  871. return True
  872. return node.is_impure()
  873. before = len(self.fx_tracer.graph.nodes)
  874. self.fx_tracer.graph.eliminate_dead_code(is_impure)
  875. after = len(self.fx_tracer.graph.nodes)
  876. verbose_log.debug("DCE removed %d nodes", before - after)
  877. def remove_unused_sizes(self) -> set[int]:
  878. used_sizes = []
  879. unused_sizes = []
  880. # seek placeholder, should be at nodes[1]
  881. it = iter(self.fx_tracer.graph.nodes)
  882. next(it)
  883. sizes_node = next(it)
  884. assert sizes_node.name == "sizes"
  885. for getitem_node in sizes_node.users.keys():
  886. assert getitem_node.target == operator.getitem
  887. if getitem_node.users:
  888. used_sizes.append(getitem_node)
  889. else:
  890. # remove from the graph
  891. unused_sizes.append(getitem_node)
  892. used_sizes_idx: set[int] = set()
  893. for used in used_sizes:
  894. assert isinstance(used.args, tuple)
  895. assert used.args[0] == sizes_node
  896. assert isinstance(used.args[1], int)
  897. next_size_idx = len(used_sizes_idx)
  898. # used later reindex the runtime sizes arg
  899. used_sizes_idx.add(used.args[1])
  900. # reindex the graph
  901. used.args = (used.args[0], next_size_idx)
  902. for unused in unused_sizes:
  903. self.fx_tracer.graph.erase_node(unused)
  904. return used_sizes_idx
  905. def create_graph_module(self, id: str) -> GraphModule:
  906. return GraphModule(self.fx_tracer.root, self.fx_tracer.graph, id)
    def end_capture(self, outputs: Any) -> tuple[Callable[..., Any], Any]:
        """Finish tracing and return ``(runtime_wrapper, self.compiler_fn(graph))``.

        In order: emit the final-callbacks stub, close ``self.stack``, emit the
        graph output, optionally retag CPU scalar inputs as CUDA (when cudagraphs
        is enabled), strip dummy-tensor metadata, run the reordering passes, DCE,
        drop unused sizes, build/log the GraphModule, close the chromium event
        and compile context, and finally invoke ``self.compiler_fn``.
        """
        self.fx_tracer.create_proxy(
            "call_function",
            FakeCompiledAutogradEngine._exec_final_callbacks_stub,
            (),
            {},
        )
        self.stack.close()
        self.fx_tracer.create_node(
            "output",
            "output",
            (self.fx_tracer.create_arg(self.to_proxy(outputs)),),
            {},
        )
        # Indices into the runtime ``inputs`` list that runtime_wrapper moves
        # to CUDA; stays empty unless cudagraphs is enabled.
        runtime_inputs_to_move: list[int] = []
        if snapshot_cudagraph_enabled():
            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)

        # We traced using dummy tensors. Delete all the metadata of the dummy tensors.
        # It's probably better to refactor this class to use a different tracer
        # than the make_fx tracer, but that is a larger change.
        for node in self.fx_tracer.graph.nodes:
            for field in ["tensor_meta", "example_value", "val"]:
                if field in node.meta:
                    del node.meta[field]

        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "compiled_autograd_graph_pre_reordering",
                "encoding": "string",
            },
            payload_fn=lambda: GraphModule(
                self.fx_tracer.root,
                self.fx_tracer.graph,
                f"CompiledAutograd{self.id}PreReordering",
            ).print_readable(print_output=False),
        )
        # Reordering passes; they move nodes around, so their order matters
        # (e.g. pre hooks are scheduled ASAP before acc-grad reordering and
        # moved back to mimic eager afterwards).
        self.delay_unpack_hook_nodes()
        self.reorder_tensor_pre_hook_nodes()
        self.reorder_pre_hook_nodes_to_schedule_asap()
        self.reorder_accumulate_grad_nodes()
        self.reorder_pre_hook_nodes_to_mimic_eager()
        self.reorder_post_acc_grad_hook_nodes()
        self.reorder_post_hook_nodes()
        # TODO(yf225): work around: remove dead codes like `sym_size` and `sym_numel` which are not used downstream. e.g.
        # ```
        # sym_numel_default = torch.ops.aten.sym_numel.default(sum_109); sum_109 = None
        # eq_115 = 16 == sym_numel_default; sym_numel_default = eq_115 = None
        # sym_size_int_39 = torch.ops.aten.sym_size.int(getitem_112, 1); getitem_112 = None
        # eq_116 = 16 == sym_size_int_39; eq_116 = None
        # eq_117 = 16 == sym_size_int_39; sym_size_int_39 = eq_117 = None
        # ```
        # Proper fix is Richard's Python compiled autograd effort which will avoid calling make_fx and
        # should prevent these ops from going into the CA graph.
        self.dce()
        if self.nan_checker:
            self.nan_checker.prep_with_graph(self.fx_tracer.graph)

        # keep only sizes that are actually used in the graph
        used_sizes_idx = self.remove_unused_sizes()

        graph = self.create_graph_module(f"CompiledAutograd{self.id}")
        set_locals_to_steal(graph, ["inputs"])
        lazy_graph_code = lazy_format_graph_code(
            "Compiled autograd graph",
            graph,
            include_device=True,
            include_stride=True,
            colored=True,
        )
        compiled_autograd_log.info("%s", lazy_graph_code)
        verbose_log.debug("%s", lazy_graph_code)
        trace_structured(
            "compiled_autograd_graph",
            payload_fn=lambda: graph.print_readable(print_output=False),
        )

        def runtime_wrapper(
            compiled_fn: Callable[..., Any],
            inputs: Any,
            sizes: Any,
            scalars: Any,
            hooks: Any,
            packed_inputs: Any,
        ) -> tuple[Any, Any]:
            # Adapts runtime arguments to the compacted graph and runs it with
            # compiled autograd disabled; the region flag is cleared on exit.
            global in_compiled_autograd_region
            try:
                in_compiled_autograd_region = True
                if self.nan_checker:
                    self.nan_checker.prep_with_inputs(inputs)

                # Keep only sizes the graph still reads (ascending original
                # index). Positive sizes are passed as empty (0, n) tensors with
                # dim 1 marked dynamic; others are passed through as-is.
                filtered_sizes = []
                for idx, integer in enumerate(sizes):
                    if idx in used_sizes_idx:
                        # can't create negative size
                        if integer > 0:
                            filtered_sizes.append(torch.empty(0, integer))
                            torch._dynamo.maybe_mark_dynamic(filtered_sizes[-1], 1)
                        else:
                            filtered_sizes.append(integer)

                # Mirror move_graph_nodes_to_cuda on the real runtime inputs.
                for i in runtime_inputs_to_move:
                    inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)

                with _disable(), make_compile_context(self.id):
                    out = compiled_fn(
                        inputs, filtered_sizes, scalars, hooks, packed_inputs
                    )
                    if self.nan_checker:
                        self.nan_checker.check(out)
                    return out
            finally:
                in_compiled_autograd_region = False

        get_chromium_event_logger().log_event_end(
            "compiled_autograd",
            time.time_ns(),
            {"graph_id": self.id},
            self.start_time_ns,
            log_pt2_compile_event=True,
        )
        self.compile_context.__exit__(None, None, None)
        return runtime_wrapper, self.compiler_fn(graph)
  1022. @staticmethod
  1023. def get_all_nodes(args: Sequence[Any]) -> list[torch.fx.Node]:
  1024. # filter out non-Node args, like None
  1025. nodes = [n for n in args if type(n) is torch.fx.Node]
  1026. return nodes
  1027. @staticmethod
  1028. def is_placeholder(node: torch.fx.Node) -> bool:
  1029. if node.op == "placeholder" or (
  1030. node.op == "call_function"
  1031. and node.target == operator.getitem
  1032. and node.args[0].op == "placeholder" # type: ignore[union-attr, arg-type]
  1033. ):
  1034. return True
  1035. return False
    def reorder_accumulate_grad_nodes(self) -> None:
        """
        Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of
        the graph. This differs from eager mode, which schedules them as soon as possible. This
        pass attempts to reorder the graph to mimic eager behavior.
        """
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_accumulate_grad
        ):
            param_node, grad_node = node.args[0], node.args[1]
            # If the grad comes from a getitem, schedule relative to the
            # getitem's source and move the getitem along with the acc-grad.
            getitem_node = None
            if grad_node.target == operator.getitem:
                getitem_node = grad_node
                grad_node = getitem_node.args[0]

            arg = max([param_node, grad_node])  # last arg
            if arg is not node.prev and not self.is_placeholder(arg):
                # Each append inserts immediately after ``arg``, so the final
                # order is: arg, getitem_node (if any), node.
                arg.append(node)
                if getitem_node is not None:
                    arg.append(getitem_node)
  1055. def delay_unpack_hook_nodes(self) -> None:
  1056. """
  1057. We can delay unpack hooks until they are needed, even later than in the eager autograd engine.
  1058. """
  1059. for node in self.fx_tracer.graph.find_nodes(
  1060. op="call_function", target=call_hook
  1061. ):
  1062. if node.kwargs.get("hook_type", None) != "unpack_hook":
  1063. continue
  1064. first_user = min(node.users)
  1065. first_user.prepend(node)
    def reorder_tensor_pre_hook_nodes(self) -> None:
        """
        Usage of AOTAutograd causes all the tensor_pre_hook nodes to get pushed
        to the end of the graph. This differs from eager mode, which schedules
        them as soon as possible. This pass attempts to reorder the graph to
        mimic eager behavior.
        """
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "tensor_pre_hook":
                continue

            # args[0] is the hook (a getitem on the hooks proxy), args[1] the grad.
            getitem_node = node.args[0]
            input_node = node.args[1]  # tensor_pre_hook handle only one grad tensor
            if input_node is not node.prev and not self.is_placeholder(input_node):
                # Result order: input_node, getitem_node, node.
                input_node.append(getitem_node)
                getitem_node.append(node)
    def reorder_pre_hook_nodes_to_schedule_asap(self) -> None:
        """
        In this function, we schedule the pre hooks as soon as possible. This
        does not match eager behavior (schedule pre hook right before its
        registered node), but it can make acc grad be scheduled properly when
        the pre hooks are registered to them. After reordering acc grad node, we
        will reorder the pre hooks again to mimic eager behavior.
        """
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "pre_hook":
                continue

            getitem_node = node.args[0]
            # pre_hook handle a tuple of grad tensors
            input_nodes = self.get_all_nodes(node.args[1])

            # Inputs that are getitems are replaced by their source node when
            # computing the insertion point, and the getitems themselves travel
            # with the hook.
            to_remove = []
            to_append = []
            hook_block = [node]  # contain the hook and hook args getitem
            for n in input_nodes:
                if n.op == "call_function" and n.target == operator.getitem:
                    to_append.append(n.args[0])
                    to_remove.append(n)
                    hook_block.append(n)
            for a, b in zip(to_remove, to_append):
                input_nodes.remove(a)
                input_nodes.append(b)  # type: ignore[arg-type]

            arg = max(input_nodes)  # last input
            if arg is not node.prev and not self.is_placeholder(arg):
                arg.append(getitem_node)
                # Each append inserts directly after getitem_node, so the
                # getitems in hook_block end up ahead of the hook call (node).
                for n in hook_block:
                    getitem_node.append(n)
    def reorder_pre_hook_nodes_to_mimic_eager(self) -> None:
        """
        Usage of AOTAutograd causes all the pre_hook nodes to get pushed to the
        end of the graph. This differs from eager mode, which schedules them
        right before their registered node execution. This pass attempts to
        reorder the graph to mimic eager behavior.
        """
        pre_hooks = []
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "pre_hook":
                continue
            pre_hooks.append(node)

        # Walk in reverse so that hooks registered to the same node keep their
        # relative order after repeated prepends.
        for node in reversed(pre_hooks):
            hook_getitem_node = node.args[0]

            users = list(node.users.keys())
            if len(users) == 0:
                continue

            # users are all getitem ops and they are used by same registered node
            assert all(
                user.op == "call_function" and user.target == operator.getitem
                for user in users
            )
            registered_node = next(iter(users[0].users.keys()))

            if registered_node is not node.next:
                # Each prepend inserts immediately before registered_node, giving:
                # hook getitem, hook call, output getitems, registered_node.
                registered_node.prepend(hook_getitem_node)
                registered_node.prepend(node)
                for getitem in users:
                    registered_node.prepend(getitem)
    def reorder_post_acc_grad_hook_nodes(self) -> None:
        """
        Usage of AOTAutograd causes all the post_acc_grad_hook nodes to get
        pushed to the end of the graph. This differs from eager mode, which
        schedules them as soon as possible. This pass attempts to reorder the
        graph to mimic eager behavior.
        """
        post_acc_grad_hooks = []
        for node in self.fx_tracer.graph.find_nodes(
            op="call_function", target=call_hook
        ):
            if node.kwargs.get("hook_type", None) != "post_acc_grad_hook":
                continue
            post_acc_grad_hooks.append(node)

        # nodes in post_acc_grad_hooks are in topo order. For hooks registered
        # to same node, we should keep their relative order
        for node in reversed(post_acc_grad_hooks):
            getitem_node = node.args[0]
            param_node = node.args[1]  # post_acc_grad_hook handle one param

            # find the corresponding acc_grad node
            acc_grad_node = None
            for n in list(param_node.users.keys()):
                if n.op == "call_function" and n.target == call_accumulate_grad:
                    acc_grad_node = n
                    break

            assert acc_grad_node is not None, (
                "post_acc_grad_hook must have corresponding acc grad node"
            )

            # append post_acc_grad_hook after acc_grad node
            # Result order: acc_grad_node, getitem_node, node.
            acc_grad_node.append(getitem_node)
            getitem_node.append(node)
  1176. def reorder_post_hook_nodes(self) -> None:
  1177. """
  1178. Usage of AOTAutograd causes all the post_hook nodes to get pushed to the
  1179. end of the graph. This differs from eager mode, which schedules them as
  1180. soon as possible. This pass attempts to reorder the graph to mimic eager
  1181. behavior.
  1182. """
  1183. post_hooks = []
  1184. for node in self.fx_tracer.graph.find_nodes(
  1185. op="call_function", target=call_hook
  1186. ):
  1187. if node.kwargs.get("hook_type", None) != "post_hook":
  1188. continue
  1189. post_hooks.append(node)
  1190. for node in reversed(post_hooks):
  1191. getitem_node = node.args[0]
  1192. output_nodes = node.args[1]
  1193. input_nodes = node.args[2]
  1194. if len(output_nodes) > 0:
  1195. continue
  1196. input_nodes_and_users = []
  1197. input_nodes_and_users.extend(list(input_nodes))
  1198. for input_node in input_nodes:
  1199. input_nodes_and_users.extend(
  1200. user
  1201. for user in list(input_node.users.keys())
  1202. if not (
  1203. user.op == "call_function"
  1204. and user.target == call_hook
  1205. and node.kwargs.get("hook_type", None) == "post_hook"
  1206. )
  1207. )
  1208. arg = max(input_nodes_and_users) # last input users
  1209. if arg.op == "call_function" and arg.target == call_accumulate_grad:
  1210. param_node = arg.args[0]
  1211. post_acc_grad_hook_node = None
  1212. for n in list(param_node.users.keys()):
  1213. if (
  1214. n.op == "call_function"
  1215. and n.target == call_hook
  1216. and n.kwargs.get("hook_type", None) == "post_acc_grad_hook"
  1217. ):
  1218. post_acc_grad_hook_node = n
  1219. if post_acc_grad_hook_node is not None:
  1220. post_acc_grad_hook_node.append(getitem_node)
  1221. getitem_node.append(node)
  1222. continue
  1223. if arg is not node.prev and not self.is_placeholder(arg):
  1224. arg.append(getitem_node)
  1225. getitem_node.append(node)
  1226. def to_proxy(self, t: Any) -> Any:
  1227. if t is None:
  1228. return None
  1229. if isinstance(t, list):
  1230. return [self.to_proxy(x) for x in t]
  1231. if isinstance(t, tuple):
  1232. return tuple(self.to_proxy(x) for x in t)
  1233. if isinstance(t, (torch.SymInt, torch.SymFloat)):
  1234. return self.symnode_proxy_lookup[t.node]
  1235. if not isinstance(t, torch.Tensor):
  1236. # constant types like device, dtype, str
  1237. return t
  1238. proxy_tensor = fetch_object_proxy(self.fx_tracer, t)
  1239. assert isinstance(proxy_tensor, torch.fx.experimental.proxy_tensor._ProxyTensor)
  1240. return proxy_tensor.proxy
  1241. def bind_objects_to_proxies(
  1242. self,
  1243. objects: Sequence[Any],
  1244. proxies: Any,
  1245. origins: Optional[list[tuple[int, str]]] = None,
  1246. ) -> Sequence[Any]:
  1247. if isinstance(proxies, torch.fx.Proxy):
  1248. if origins:
  1249. assert len(origins) == len(objects)
  1250. bound_proxies = []
  1251. for i in range(len(objects)):
  1252. nodecall_index, node_name = origins[i]
  1253. self.set_node_origin(node_name, nodecall_index, None)
  1254. bound_proxies.append(proxies[i]) # type: ignore[index]
  1255. proxies = bound_proxies
  1256. else:
  1257. proxies = [proxies[i] for i in range(len(objects))] # type: ignore[index]
  1258. assert len(objects) == len(proxies)
  1259. track_tensor_tree(objects, proxies, constant=None, tracer=self.fx_tracer)
  1260. return proxies
  1261. def bind_backward_state(self, index: int) -> BackwardState:
  1262. assert self.hooks_proxy is not None
  1263. proxy = self.hooks_proxy[index] # type: ignore[index]
  1264. bw_state = BackwardState()
  1265. track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
  1266. return bw_state
    def set_node_origin(
        self,
        node_name: str,
        nodecall_index: int,
        pyobj: Optional[torch.autograd.Function],
    ) -> None:
        """Set the stack trace attached to subsequently created nodes.

        The text becomes ``"{node_name}{aot_id} (NodeCall {nodecall_index})"``,
        where ``aot_id`` is the forward class's ``_aot_id`` if ``pyobj`` was
        created by AOT Dispatcher, else empty.

        Raises:
            RuntimeError: if the AOT backward has no ``_lazy_backward_info``
                (saved by AOTAutogradCache).
        """
        maybe_aot_id = ""
        if pyobj is not None:
            forward_cls = pyobj._forward_cls  # type: ignore[attr-defined]
            if hasattr(forward_cls, "_aot_id"):
                # backward was created by AOT Dispatcher
                if forward_cls._lazy_backward_info is None:
                    raise RuntimeError(
                        """This compiled backward function was saved by AOTAutogradCache, which does not support
compiled autograd. Please turn off AOTAutogradCache using `TORCHINDUCTOR_AUTOGRAD_CACHE=0`."""
                    )
                maybe_aot_id = forward_cls._aot_id
        new_code = f"{node_name}{maybe_aot_id} (NodeCall {nodecall_index})"
        # The last formatted frame is this function itself; its echoed source
        # line is replaced by ``new_code``. The replace pattern must match the
        # next line's text exactly, so do not edit or annotate that line.
        raw_stack_trace = CapturedTraceback.extract().format()[-1]
        new_stack_trace = raw_stack_trace.replace(
            "raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code
        )
        set_stack_trace(new_stack_trace)
  1290. # state of the autograd engine dispatch, kept in sync by enable/disable context managers
  1291. compiled_autograd_enabled = False
  1292. # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager"
  1293. compiled_autograd_enabled_force_eager = False
  1294. # global flag to check if we are processing graphs produced from a compiled autograd graph
  1295. in_compiled_autograd_region = False
  1296. active_disable_ctx = False
  1297. depth = 0
@contextlib.contextmanager
def _enable(
    compiler_fn: Callable[..., Any],
    dynamic: bool = True,
    ignore_active_disable_ctx: bool = True,
) -> Generator[None, None, None]:
    # The entrypoint to enable CA.
    # It is recommended to enable via `torch._dynamo.config.compiled_autograd = True` rather
    # than using this context manager directly. If you are torch.compiling the corresponding
    # forward pass, make sure they are wrapped under this context as well.
    #
    # Example:
    #   def train(model, inputs, target):
    #     compiled_model = torch.compile(model)
    #     pred = compiled_model(inputs)
    #     loss = compute_loss(pred, target)
    #     loss.backward()
    #
    #   with _enable(compiler_fn):
    #      train(model, inputs, target)
    #
    # Inputs:
    # - compiler_fn: The wrapper that will consume the compiled autograd graph, e.g. `torch.compile`
    # - dynamic: Whether compiled autograd will treat tensors in the autograd graph (params, activations) as dynamic.
    #   This doesn't affect the dynamic configuration of the compilation wrapper.
    # - ignore_active_disable_ctx: when False and a _disable() context is
    #   active, this context manager does nothing.
    if not ignore_active_disable_ctx and active_disable_ctx:
        yield
    else:
        if dynamic:
            assert type(dynamic) is bool

        from torch._dynamo import eval_frame

        if eval_frame._stance.stance == "force_eager":
            # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd
            # to fall back to eager as well.
            global compiled_autograd_enabled_force_eager
            compiled_autograd_enabled_force_eager = True
            try:
                yield
            finally:
                compiled_autograd_enabled_force_eager = False
        else:
            # we need to import this, because user might not have imported it if they directly use this context manager
            # we need to lazily import it, because of circular dependencies
            if torch.cuda.is_available():
                from torch._inductor import cudagraph_trees  # noqa: F401

            # Install a new compiler and remember the previous one so the
            # finally block can restore it (supports nesting).
            (
                prior_compiler,
                prior_dynamic,
            ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
                functools.partial(AutogradCompilerInstance, compiler_fn), dynamic
            )
            if snapshot_verbose_logging_enabled():
                torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)  # type:ignore[arg-type]
            global compiled_autograd_enabled
            compiled_autograd_enabled = True
            global depth
            prior_depth = depth
            depth += 1
            try:
                with torch.autograd.set_multithreading_enabled(False):
                    yield
            finally:
                # Only clear the flag when there was no outer compiler.
                if not prior_compiler:
                    compiled_autograd_enabled = False
                torch._C._dynamo.compiled_autograd.set_autograd_compiler(
                    prior_compiler, prior_dynamic
                )
                depth -= 1
                assert depth == prior_depth, (
                    "Nested Compiled Autograd Contexts must return before their parent context"
                )
  1369. @contextlib.contextmanager
  1370. def _disable() -> Generator[None, None, None]:
  1371. (
  1372. prior_compiler,
  1373. prior_dynamic,
  1374. ) = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False)
  1375. global compiled_autograd_enabled
  1376. compiled_autograd_enabled = False
  1377. global active_disable_ctx
  1378. if not active_disable_ctx:
  1379. active_disable_ctx = True
  1380. try:
  1381. yield
  1382. finally:
  1383. if prior_compiler:
  1384. compiled_autograd_enabled = True
  1385. active_disable_ctx = False
  1386. torch._C._dynamo.compiled_autograd.set_autograd_compiler(
  1387. prior_compiler, prior_dynamic
  1388. )
  1389. # return to starting state of a new process
  1390. def reset() -> None:
  1391. global compiled_autograd_enabled
  1392. compiled_autograd_enabled = False
  1393. assert not in_compiled_autograd_region
  1394. torch._C._dynamo.compiled_autograd.set_autograd_compiler(None, False)
  1395. torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
  1396. torch._C._dynamo.compiled_autograd.clear_cache()
  1397. global COMPILE_COUNTER
  1398. COMPILE_COUNTER = itertools.count()
  1399. # Reimplementation of part of CopySlices::apply in Python.
  1400. # The shared code is really similar so we're not going to try to deduplicate.
  1401. def copy_slices_prologue(
  1402. inputs: Sequence[torch.Tensor],
  1403. base_sizes: Sequence[IntLikeType],
  1404. base_strides: Sequence[IntLikeType],
  1405. base_storage_offset: IntLikeType,
  1406. view_sizes: Sequence[IntLikeType],
  1407. view_strides: Sequence[IntLikeType],
  1408. view_storage_offset: IntLikeType,
  1409. ) -> list[torch.Tensor]:
  1410. grad = inputs[0]
  1411. result = grad.new_empty_strided(base_sizes, base_strides)
  1412. assert grad is not None
  1413. result.copy_(grad)
  1414. offset = view_storage_offset - base_storage_offset
  1415. grad_slice = result.as_strided(view_sizes, view_strides, offset)
  1416. return [result, grad_slice, grad_slice.clone(memory_format=torch.contiguous_format)]
  1417. # Reimplementation of part of CopySlices::apply in Python.
  1418. # The shared code is really similar so we're not going to try to deduplicate.
  1419. def copy_slices_epilogue(
  1420. needs_input_grad: Sequence[bool],
  1421. result: torch.Tensor,
  1422. res: Sequence[Optional[torch.Tensor]],
  1423. grad_slice: torch.Tensor,
  1424. ) -> list[Optional[torch.Tensor]]:
  1425. grad_inputs: list[Optional[torch.Tensor]] = [None] * len(needs_input_grad)
  1426. for i in range(len(needs_input_grad)):
  1427. if needs_input_grad[i]:
  1428. if res[i] is None:
  1429. continue
  1430. if i == 0:
  1431. to_copy = res[i]
  1432. assert to_copy is not None
  1433. grad_slice.copy_(to_copy)
  1434. grad_inputs[i] = result
  1435. else:
  1436. grad_inputs[i] = res[i]
  1437. return grad_inputs