distributed.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. """
  2. This module implements distributed training optimizations for TorchDynamo backends.
  3. It provides functionality to optimize models wrapped in DistributedDataParallel (DDP)
  4. by intelligently splitting compiled graphs to align with DDP's gradient synchronization
  5. boundaries. Key features include:
  6. - Graph partitioning based on parameter bucket sizes
  7. - Optimization of allreduce operations for distributed training
  8. - Support for parameter ignoring and buffer handling
  9. - Submodule compilation and management
  10. - Debugging utilities for distributed training
  11. The main component is the DDPOptimizer class, which handles graph splitting and
  12. recompilation to enable efficient distributed training while maintaining the benefits
  13. of compilation.
  14. """
  15. import logging
  16. import traceback
  17. from dataclasses import dataclass, field
  18. from typing import Any, Callable, Optional, TYPE_CHECKING
  19. from unittest import mock
  20. import torch
  21. from torch import fx
  22. from torch._dynamo.backends.registry import CompiledFn, CompilerFn
  23. from torch._dynamo.output_graph import GraphCompileReason
  24. from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
  25. from torch._logging import trace_structured
  26. from torch.fx.node import Node
  27. if TYPE_CHECKING:
  28. from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
# Regular log messages should go through 'log'.
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
# See docs/source/logging.rst for more info.
# In this module 'log' carries bucket summaries and run_node tracing, while
# 'ddp_graph_log' (artifact name "ddp_graphs") carries the original, split,
# per-submodule and final graph dumps.
log = logging.getLogger(__name__)
ddp_graph_log = torch._logging.getArtifactLogger(__name__, "ddp_graphs")
  34. def args_str(args: Any) -> str:
  35. # a debug helper
  36. if torch.is_tensor(args):
  37. return f"T[{args.shape}]"
  38. elif isinstance(args, tuple):
  39. return f"tuple({', '.join([args_str(x) for x in args])})"
  40. elif isinstance(args, list):
  41. return f"list({', '.join([args_str(x) for x in args])})"
  42. else:
  43. return str(args)
@dataclass
class Bucket:
    """Accumulator for one DDPOptimizer bucket, i.e. one future graph partition.

    Buckets are filled while DDPOptimizer.compile_fn walks the dynamo graph in
    reverse node order.
    """

    # Bytes of parameter storage counted towards this bucket's capacity
    # (accumulated from param.untyped_storage().nbytes()).
    size: int = 0
    # Names of the parameters whose storage was counted in `size`.
    params: list[str] = field(default_factory=list)
    # FX nodes assigned to this bucket, appended in reverse graph order.
    nodes: list[fx.Node] = field(default_factory=list)
    # param_ids is just used for unit testing
    param_ids: list[int] = field(default_factory=list)
    # keep track of any buckets that were extended for logging purposes
    # Number of extra nodes added after the bucket hit its capacity, because it did
    # not yet have a node whose output is used outside the bucket.
    opcount_increased_to_capture_external_output: int = 0
    # Value of `size` at the moment the bucket first started being extended.
    paramsize_before_opcount_increase: int = 0
  54. def bucket_has_external_output(bucket: Bucket) -> bool:
  55. nodes_in_bucket = set()
  56. # we want to iterate in reverse order, but clumsi-luckily the bucket.nodes list was already created backwards
  57. # so we don't reverse it here
  58. for node in bucket.nodes:
  59. # assume node.op != output, since those are filtered in the original iteration
  60. nodes_in_bucket.add(node)
  61. for user in node.users:
  62. if user not in nodes_in_bucket:
  63. return True
  64. return False
  65. def pretty_print_buckets(buckets: list[Bucket], bucket_bytes_cap: int) -> None:
  66. headers = ("Index", "Size (b)", "Param Names")
  67. rows: list[tuple[Optional[int], Optional[int], str]] = []
  68. extended_buckets = []
  69. for idx, bucket in enumerate(reversed(buckets)):
  70. if len(bucket.params) > 0:
  71. rows.append((idx, bucket.size, bucket.params[0]))
  72. rows.extend((None, None, param) for param in bucket.params[1:])
  73. if bucket.opcount_increased_to_capture_external_output > 0:
  74. extended_buckets.append(
  75. (
  76. idx,
  77. bucket.opcount_increased_to_capture_external_output,
  78. bucket.size - bucket.paramsize_before_opcount_increase,
  79. )
  80. )
  81. if len(rows):
  82. log.info(
  83. "\nDDPOptimizer used bucket cap %s and created %d buckets. Enable debug logs for detailed bucket info.",
  84. bucket_bytes_cap,
  85. len(buckets),
  86. )
  87. if len(extended_buckets):
  88. log.warning(
  89. "Some buckets were extended beyond their requested parameter capacities"
  90. " in order to ensure each subgraph has an output node, required for fx graph partitioning."
  91. " This can be the case when a subgraph would have only contained nodes performing inplace mutation,"
  92. " and returning no logical outputs. This should not be a problem, unless it results in too few graph"
  93. " partitions for optimal DDP performance."
  94. )
  95. try:
  96. from tabulate import tabulate
  97. log.debug(
  98. "\nDDPOptimizer produced the following bucket assignments:\n%s",
  99. tabulate(rows, headers=headers, tablefmt="simple_grid"),
  100. )
  101. if len(extended_buckets):
  102. log.warning(
  103. "DDPOptimizer extended these buckets to ensure per-subgraph output nodes:\n%s",
  104. tabulate(
  105. extended_buckets,
  106. headers=("Index", "Extra Ops", "Extra Param Size (b)"),
  107. tablefmt="simple_grid",
  108. ),
  109. )
  110. except ImportError:
  111. log.debug(
  112. "Please `pip install tabulate` in order to display ddp bucket sizes and diagnostic information."
  113. )
  114. else:
  115. log.debug("DDPOptimizer captured no parameters and did not split this graph.")
  116. def has_higher_order_op(gm: fx.GraphModule) -> bool:
  117. # Check if there is a higher order op in the graph
  118. for node in gm.graph.nodes:
  119. if node.op == "get_attr":
  120. maybe_param = getattr(gm, node.target)
  121. if isinstance(maybe_param, torch.fx.GraphModule):
  122. return True
  123. return False
  124. def propagate_metadata(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
  125. for name, module in split_gm.named_modules():
  126. if "." not in name and len(name):
  127. # TODO: add split id to CompileId: https://github.com/pytorch/tlparse/pull/83/files#r1880649384
  128. module.meta = orig_gm.meta
  129. module._param_name_to_source = orig_gm._param_name_to_source
  130. def propagate_dynamo_source(orig_gm: fx.GraphModule, split_gm: fx.GraphModule) -> None:
  131. name_to_dynamo_source = {}
  132. for node in orig_gm.graph.find_nodes(op="placeholder"):
  133. name_to_dynamo_source[node.name] = node._dynamo_source
  134. for name, module in split_gm.named_modules():
  135. if "." not in name and len(name):
  136. for node in module.graph.find_nodes(op="placeholder"):
  137. # non-placeholder in original_gm may become placeholder in submodules
  138. node._dynamo_source = name_to_dynamo_source.get(node.name, None)
  139. class DDPOptimizerContext:
  140. def __init__(self) -> None:
  141. self.curr_bucket: int = -1
  142. self.metadata_per_bucket: list[ViewAndMutationMeta] = []
# compile each of the partitioned submodules using the user-provided compiler
class SubmodCompiler(torch.fx.interpreter.Interpreter):
    """FX interpreter that compiles each ``call_module`` partition of a split graph.

    DDPOptimizer.compile_fn runs this interpreter over the split GraphModule.
    For each submodule call it:
      * compiles the real submodule with ``compiler``;
      * swaps the compiled result into the outer graph under a
        ``compiled_``-prefixed name;
      * executes a module under ``fake_mode`` so that the next partition
        receives fake-tensor inputs.
    """

    def __init__(
        self,
        module: fx.GraphModule,
        compiler: CompilerFn,
        fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
    ) -> None:
        super().__init__(module)
        self.compiler = compiler
        self.fake_mode = fake_mode
        # See Note [DDPOptimizer and fw_metadata]
        # Install a fresh per-compile context on the active TracingContext (if any);
        # run_node appends one fw_metadata entry to it for each AOT-compiled bucket.
        ctx = torch._guards.TracingContext.try_get()
        if ctx is not None:
            ctx.ddp_optimizer_ctx = DDPOptimizerContext()

    def compile_submod(
        self, input_mod: fx.GraphModule, args: list[torch.Tensor], kwargs: Any
    ) -> Any:
        """
        Compile the submodule,
        using a wrapper to make sure its output is always a tuple,
        which is required by AotAutograd based compilers

        Args:
            input_mod: the real (non-fake) partition submodule. It is mutated in place:
                a non-tuple output is wrapped in a 1-tuple, the module is
                recompiled, and ``compile_subgraph_reason`` is set on it.
            args: example inputs forwarded to ``self.compiler``.
            kwargs: must be empty.

        Returns:
            A ``WrapperModule`` around the compiled callable. It unwraps the 1-tuple
            again when this function had to add one.
        """
        assert len(kwargs) == 0, "We assume only args for these modules"

        class WrapperModule(torch.nn.Module):
            # Calls the compiled submodule and optionally undoes the singleton-tuple
            # wrapping applied to its graph output below.
            def __init__(
                self, submod: Callable[..., Any], unwrap_singleton_tuple: bool
            ) -> None:
                super().__init__()
                self.submod = submod
                self.unwrap_singleton_tuple = unwrap_singleton_tuple

            def forward(self, *args: Any) -> Any:
                x = self.submod(*args)
                # TODO(whc)
                # for some reason the isinstance check is necessary if I split one node per submod
                # - even though I supposedly wrapped the output in a tuple in those cases, the real
                # compiled module was still returning a tensor
                if self.unwrap_singleton_tuple and isinstance(x, (tuple, list)):
                    return x[0]
                return x

        # Force the graph output to be a tuple; remember whether we had to so the
        # wrapper can restore the original single-value output.
        unwrap_singleton_tuple = False
        for sn in input_mod.graph.nodes:
            if sn.op == "output":
                if not isinstance(sn.args[0], tuple):
                    unwrap_singleton_tuple = True
                    sn.args = (sn.args,)

        input_mod.recompile()
        input_mod.compile_subgraph_reason = GraphCompileReason(  # type: ignore[assignment]
            "DDPOptimizer intentional graph-break (See Note [DDPOptimizer])."
            " Set `torch._dynamo.config.optimize_ddp = False` to disable.",
            [
                # it's close to useless to get a real stacktrace here, and quite verbose.
                traceback.FrameSummary(__file__, 0, "DDPOptimizer"),
            ],
        )

        wrapper = WrapperModule(
            self.compiler(input_mod, args),
            unwrap_singleton_tuple,
        )
        return wrapper

    # Note:
    #
    # The way distributed works today around fake tensors can be somewhat confusing.
    # Some of these codepaths are shared in both runtime, and compile time. The presence
    # of a fake_mode, read off of fake tensor inputs, dictates how we will operate.
    #
    # A few things to keep in mind:
    #
    # 1) We invoke `compile_submod` with a real module. The output of that gets stored
    # on the graph via `self.module.add_submodule(n.target, compiled_submod_real)`.
    #
    # 2) When running a call_module targeted node, if we have a fake_mode, we fakify the
    # module we got from self.fetch_attr(n.target). Regardless of fake_mode, we then execute it.
    #
    # 3) Fake tensors should always be around during compile time.
    #
    # 4) Fake tensors should never be around at runtime.
    #
    # 5) We end up with a compilation mode that takes a real submodule and fake tensors,
    # to match what aot_autograd expects. See Note: [Fake Modules and AOTAutograd]
    def run_node(self, n: Node) -> Any:
        """Execute ``n``, compiling it first when it is a ``call_module`` node.

        Tensor arguments that are not already FakeTensors are converted to fake
        tensors in ``self.fake_mode``. For ``call_module``:
          * the real submodule is compiled with ``compile_submod``;
          * the result is installed on ``self.module`` as
            ``"compiled_" + n.target``, and ``n`` is retargeted to it;
          * a module is executed under ``self.fake_mode`` to produce outputs for
            downstream nodes. This is the compiled module when AOT autograd was
            invoked under a tracing context, otherwise the fake-tensor copy of the
            uncompiled submodule.

        Every other node kind is dispatched to the matching ``Interpreter`` method.
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        new_args = []
        # NOTE(review): this assert makes the `else: curr_submod = real_mod` branch below
        # unreachable; kept as-is.
        assert self.fake_mode
        for arg in args:
            if isinstance(arg, torch.Tensor) and not isinstance(
                arg, torch._subclasses.FakeTensor
            ):
                new_args.append(torch._dynamo.utils.to_fake_tensor(arg, self.fake_mode))
            else:
                new_args.append(arg)

        # NOTE(review): args_str(args) is evaluated even when debug logging is disabled.
        log.debug("run_node %s, %s got args %s", n.op, n.target, args_str(args))
        assert isinstance(args, tuple)
        assert isinstance(kwargs, dict)

        if n.op == "call_module":
            real_mod = self.fetch_attr(str(n.target))
            if self.fake_mode:
                curr_submod = deepcopy_to_fake_tensor(real_mod, self.fake_mode)
            else:
                curr_submod = real_mod

            ddp_graph_log.debug("\n---%s graph---\n%s", n.target, curr_submod.graph)

            # When calling the compiler on the submod, inputs (new_args) are expected to
            # be FakeTensors already since Dynamo would have made them FakeTensors in the
            # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
            # since this wrapping happens during compilation

            # Note: Returning Fake Tensors on First AOT Autograd Call
            #
            # Inductor will optimize strides of outputs when it deems it profitable.
            # For instance, converting to channels last. When we split the graph here
            # into multiple inductor compilations, we need to make sure that the
            # output strides of one compilation is appropriately passed to the subsequent
            # compilations. However, the mapping from inductor output to dynamo output
            # is non-trivial due to aot_autograd's deduping, de-aliasing, mutation, re-writing,
            # subclass handling, etc. In order to replay all this logic we set a flag such that
            # the first invocation of inductor in aot_autograd will return Fake Tensors with
            # appropriate strides. Then, all of aot autograd's runtime logic is replayed.
            # This gives us the appropriately strided outputs here which will reflect runtime strides.

            class FakeifyFirstAOTInvocationGuard:
                """Sets ``fakify_first_call`` on the current TracingContext while alive.

                The flag is reset in ``__del__``, so its lifetime follows the guard
                object's lifetime rather than an explicit scope.
                """

                def __init__(self) -> None:
                    self.tc = torch._guards.TracingContext.try_get()
                    assert self.tc
                    self.tc.fakify_first_call = True

                def __del__(self) -> None:
                    self.tc.fakify_first_call = False  # type: ignore[union-attr]

            # For aot_eager and other backends, tracing context is not set
            has_tracing_context = torch._guards.TracingContext.try_get() is not None
            if has_tracing_context:
                # `g` is never read; it is held as a local so the flag stays set across
                # compile_submod and the fake execution below. It is presumably cleared
                # when run_node returns and `g` is finalized (CPython refcounting);
                # verify before moving code around it.
                g = FakeifyFirstAOTInvocationGuard()  # noqa: F841

            from torch._dynamo.utils import counters

            init = counters["aot_autograd"]["total"]
            compiled_submod_real = self.compile_submod(real_mod, new_args, kwargs)

            # TODO - better way of doing this?
            # Only aot autograd handles fakifying first call
            invoked_aot_autograd = init != counters["aot_autograd"]["total"]

            # We update the original (outer) graph with a call into the compiled module
            # instead of the uncompiled one.
            self.module.delete_submodule(n.target)  # type: ignore[operator]
            n.target = "compiled_" + n.target  # type: ignore[operator]
            self.module.add_submodule(n.target, compiled_submod_real)  # type: ignore[operator]

            # Finally, we have to produce inputs for use compiling the next submodule,
            # and these need to be FakeTensors, so we execute the module under fake_mode
            # Because parameters are not fake we patch fake tensor mode to allow non fake inputs
            with (
                self.fake_mode,
                mock.patch.object(self.fake_mode, "allow_non_fake_inputs", True),
            ):
                if has_tracing_context and invoked_aot_autograd:
                    tracing_ctx = torch._guards.TracingContext.try_get()
                    assert tracing_ctx is not None
                    # DDPOptimizer maintains 1 dynamo graph -> N AOT graphs
                    # Dynamo only has 1 tracing context, so it needs to maintain all N AOT metadata instances
                    ddp_ctx = tracing_ctx.ddp_optimizer_ctx
                    assert ddp_ctx is not None
                    assert tracing_ctx.fw_metadata is not None
                    ddp_ctx.curr_bucket += 1
                    ddp_ctx.metadata_per_bucket.append(tracing_ctx.fw_metadata)

                    out = compiled_submod_real(*new_args, **kwargs)
                    # output should be fake or subclass
                    assert all(
                        (not isinstance(t, torch.Tensor) or type(t) is not torch.Tensor)
                        for t in (out if isinstance(out, (list, tuple)) else [out])
                    )
                    return out
                else:
                    return curr_submod(*new_args, **kwargs)
        else:
            # placeholder or output nodes don't need to get compiled, just executed
            return getattr(self, n.op)(n.target, new_args, kwargs)
  311. class DDPOptimizer:
  312. """Note [DDPOptimizer]
  313. DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
  314. breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
  315. the boundaries of gradient-allreduce buckets chosen by DDP.
  316. Background/Motivation
  317. - DDP uses allreduce collectives to synchronize partial gradients computed on different workers
  318. - DDP groups gradient allreduces into 'buckets' to optimize communication efficiency of all-reduce
  319. - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
  320. at around the same time during backward and thus can share the same allreduce efficiently
  321. - Allreduces must overlap with backward compute for optimal training performance
  322. - DDP schedules allreduces using 'hooks' fired from the c++ autograd engine in pytorch, which
  323. operates when individual grads become 'ready'
  324. - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
  325. autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole
  326. fused backward function executes, preventing any overlap of compute and communication
  327. Algorithm
  328. - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse
  329. this graph in reverse order to determine the true order that gradients will become ready during backward.
  330. - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started
  331. and a graph break introduced
  332. - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together
  333. into an outer module that is returned to the user
  334. Notes
  335. - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
  336. and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
  337. in eager.
  338. - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently
  339. produce splits that do not necessarily align with the buckets used by DDP. This should result in performance
  340. degradation approaching the baseline case where graph-splits are not used, but not worse.
  341. - If the backend compiler fails to compile a single subgraph, it will execute eagerly despite the rest of the
  342. subgraphs being compiled
  343. - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
  344. left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are
  345. also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP,
  346. it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
  347. - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients,
  348. and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by
  349. DDPOptimizer)
  350. Debugging
  351. - Generally, it is easiest to debug DDPOptimizer in a single process program, using pdb.
  352. - In many cases, the log messages are helpful (they show bucket size assignments)-
  353. just set TORCH_LOGS env to include any of 'dynamo', 'distributed', or 'dist_ddp'.
  354. - See `benchmarks/dynamo/distributed.py` for a simple harness that will run a toy model or a torchbench model
  355. in a single process (or with torchrun, in multiple processes)
  356. Args:
  357. bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graphbreaks. Should be
  358. set to match the equivalent parameter on the original DDP module.
  359. backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.
  360. first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP
  361. special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
  362. """
  363. def __init__(
  364. self,
  365. bucket_bytes_cap: int,
  366. backend_compile_fn: CompilerFn,
  367. first_bucket_cap: Optional[int] = None,
  368. ) -> None:
  369. if first_bucket_cap is not None:
  370. self.first_bucket_cap = first_bucket_cap
  371. elif torch.distributed.is_available():
  372. # this constant comes from C10D lib which is not always built
  373. self.first_bucket_cap = torch.distributed._DEFAULT_FIRST_BUCKET_BYTES
  374. else:
  375. self.first_bucket_cap = bucket_bytes_cap
  376. self.bucket_bytes_cap = bucket_bytes_cap
  377. assert self.first_bucket_cap <= self.bucket_bytes_cap, (
  378. "First bucket should be smaller/equal to other buckets to get comms warmed up ASAP"
  379. )
  380. self.backend_compile_fn = backend_compile_fn
  381. def _ignore_parameter(self, parameter: torch.nn.Parameter) -> bool:
  382. return hasattr(parameter, "_ddp_ignored") and parameter._ddp_ignored
  383. def add_param(self, bucket: Bucket, param: torch.nn.Parameter, name: str) -> None:
  384. bucket.size += param.untyped_storage().nbytes()
  385. bucket.params.append(name)
  386. bucket.param_ids.append(id(param))
  387. def add_module_params_to_bucket(
  388. self,
  389. mod: torch.nn.Module,
  390. bucket: Bucket,
  391. processed_modules: set[torch.nn.Module],
  392. prefix: str,
  393. ) -> None:
  394. processed_modules.add(mod)
  395. for name, param in mod.named_parameters():
  396. if param.requires_grad and not self._ignore_parameter(param):
  397. self.add_param(bucket, param, f"{prefix}_{name}")
  398. def add_param_args(self, bucket: Bucket, node: fx.Node) -> None:
  399. for arg in node.args:
  400. if not isinstance(arg, torch.fx.node.Node):
  401. continue
  402. if arg.op != "placeholder":
  403. continue
  404. param = arg.meta["example_value"]
  405. if (
  406. isinstance(param, torch.nn.Parameter)
  407. and param.requires_grad
  408. and not self._ignore_parameter(param)
  409. ):
  410. self.add_param(bucket, param, str(arg.target))
  411. def compile_fn(
  412. self, gm: fx.GraphModule, example_inputs: list[torch.Tensor]
  413. ) -> CompiledFn:
  414. """
  415. Implements graph splitting, first determining a set of of buckets by counting
  416. parameter sizes in reverse graph order, then invoking the user/backend compiler
  417. to compile each subgraph. Finally, stiches compiled graphs into one graphmodule
  418. and returns its callable.
  419. """
  420. # 1: compute the partition map according to DDP bucket logic
  421. buckets = [Bucket()] # (size, param_names)
  422. processed_modules: set[torch.nn.Module] = set()
  423. for node in reversed(gm.graph.nodes):
  424. if node.op in ("output", "placeholder"):
  425. continue
  426. if (
  427. buckets[0].size >= self.bucket_bytes_cap
  428. or len(buckets) == 1
  429. and buckets[0].size >= self.first_bucket_cap
  430. ):
  431. if bucket_has_external_output(buckets[0]):
  432. buckets.insert(0, Bucket())
  433. else:
  434. # continue building this bucket past the point of filling its parameter capacity,
  435. # to increase chances it contains at least one node that is either a global output or
  436. # passed as input to a subsequent graph
  437. if buckets[0].opcount_increased_to_capture_external_output == 0:
  438. buckets[0].paramsize_before_opcount_increase = buckets[0].size
  439. buckets[0].opcount_increased_to_capture_external_output += 1
  440. if node.op == "call_function":
  441. self.add_param_args(buckets[0], node)
  442. elif node.op == "call_module":
  443. target_mod = gm.get_submodule(node.target)
  444. if target_mod not in processed_modules:
  445. self.add_module_params_to_bucket(
  446. target_mod, buckets[0], processed_modules, node.target
  447. )
  448. elif node.op == "call_method":
  449. if isinstance(node.args[0].target, str):
  450. target_mod = None
  451. try:
  452. target_mod = gm.get_submodule(node.args[0].target)
  453. except AttributeError:
  454. pass
  455. if target_mod is not None and target_mod not in processed_modules:
  456. self.add_module_params_to_bucket(
  457. target_mod, buckets[0], processed_modules, node.target
  458. )
  459. # This handles situations like tmp = torch.mm(x, self.weight.t())
  460. # t: "f32[512, 512]" = l_self_seq_2_weight.t(); l_self_seq_2_weight = None
  461. # tmp: "f32[512, 512]" = torch.mm(input_2, t); input_2 = t = None
  462. self.add_param_args(buckets[0], node)
  463. elif node.op == "get_attr":
  464. maybe_param = getattr(gm, node.target)
  465. if (
  466. isinstance(maybe_param, torch.nn.Parameter)
  467. and maybe_param.requires_grad
  468. and not self._ignore_parameter(maybe_param)
  469. ):
  470. self.add_param(buckets[0], maybe_param, node.target)
  471. # All nodes have to be mapped to a bucket, even if they don't have their own params
  472. # Ignored params still end up in buckets, we just don't count them towards the capacity
  473. buckets[0].nodes.append(node)
  474. if len(buckets) > 1 and buckets[0].size == 0:
  475. # we collected a small preamble graph with ops that don't include parameters, fuse it back
  476. buckets[1].nodes.extend(buckets[0].nodes)
  477. assert len(buckets[0].params) == 0, "Params should be empty if size is 0"
  478. del buckets[0]
  479. # stash buckets for testing/debugging purposes
  480. self.buckets = buckets
  481. pretty_print_buckets(buckets, self.bucket_bytes_cap)
  482. if len(buckets) == 1:
  483. # bypass split/fuse logic if there is only one bucket
  484. return self.backend_compile_fn(gm, example_inputs)
  485. # 2: partition the graphmodule according to bucket capacity
  486. partition_map = {}
  487. for idx, b in enumerate(buckets):
  488. for node in b.nodes:
  489. partition_map[node] = idx
  490. split_gm = fx.passes.split_module.split_module(
  491. gm,
  492. None, # type: ignore[arg-type]
  493. lambda node: partition_map[node],
  494. )
  495. # See note [Assumption on Dynamo Metadata]
  496. propagate_dynamo_source(gm, split_gm)
  497. propagate_metadata(gm, split_gm)
  498. debug_str = (
  499. f"\n---orig graph---\n{gm.graph}\n"
  500. + f"\n---split graph---\n{split_gm.graph}\n"
  501. )
  502. for name, module in split_gm.named_modules():
  503. if "." not in name and len(name):
  504. # only print the submod graphs, not their children
  505. debug_str += f"\n---{name} graph---\n{module.graph}\n"
  506. debug_str += "\n---------------\n"
  507. ddp_graph_log.debug(debug_str)
  508. trace_structured(
  509. "optimize_ddp_split_graph",
  510. payload_fn=lambda: split_gm.print_readable(print_output=False),
  511. )
  512. for name, module in split_gm.named_modules():
  513. if "." not in name and len(name):
  514. trace_structured(
  515. "optimize_ddp_split_child",
  516. lambda: {"name": name},
  517. payload_fn=lambda: module.print_readable(print_output=False),
  518. )
  519. fake_mode = detect_fake_mode(example_inputs)
  520. if fake_mode is None:
  521. fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
  522. submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode)
  523. with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
  524. submod_compiler.run(*example_inputs)
  525. split_gm.recompile()
  526. ddp_graph_log.debug(
  527. "\n---final graph---\n%s\n---------------\n", split_gm.graph
  528. )
  529. return split_gm