_IR.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import copy
  4. import logging
  5. import operator
  6. from collections import defaultdict
  7. from enum import Enum
  8. from inspect import Parameter, Signature, signature
  9. from types import MethodType
  10. from typing import Any, Callable, Optional, Union
  11. import torch
  12. import torch.fx as fx
  13. from torch.distributed import ProcessGroup
  14. from torch.export import ExportedProgram
  15. from torch.export.unflatten import (
  16. _assign_attr,
  17. _AttrKind,
  18. _sink_params,
  19. InterpreterModule,
  20. )
  21. from torch.fx.node import map_aggregate
  22. from torch.fx.passes.split_module import split_module
  23. from ._backward import _null_coalesce_accumulate, stage_backward
  24. from ._unflatten import _outline_submodules
  25. from ._utils import PipeInfo
  26. from .stage import _PipelineStage
# Module-level logger; used below for debug/info/warning messages emitted while
# splitting graphs and rewriting device arguments.
logger = logging.getLogger(__name__)
# TODO:
# 1. investigate gradient sync for shared parameters. how does DDP do it?
# 2. Add parameter movement to split_module
  31. def _find_loss_from_output_and_spec(output_val, spec_val):
  32. if spec_val is False:
  33. return None
  34. if spec_val is True:
  35. if not isinstance(output_val, fx.Node):
  36. raise RuntimeError(
  37. f"Loss spec must specify a dynamic value but got {output_val}"
  38. )
  39. return output_val
  40. if isinstance(spec_val, (tuple, list)):
  41. if not isinstance(output_val, (tuple, list)):
  42. raise RuntimeError(
  43. f"Output value {output_val} must match type of loss specification "
  44. f"{spec_val}"
  45. )
  46. if len(output_val) != len(spec_val):
  47. raise RuntimeError(
  48. f"Output value {output_val} must match length of loss specification "
  49. f"{spec_val}"
  50. )
  51. for out, spec in zip(output_val, spec_val):
  52. loss_val = _find_loss_from_output_and_spec(out, spec)
  53. if loss_val is not None:
  54. return loss_val
  55. raise RuntimeError(f"Did not find loss value in specification {spec_val}")
  56. if isinstance(spec_val, dict):
  57. if not isinstance(output_val, dict):
  58. raise RuntimeError(
  59. f"Output value {output_val} must match type of loss specification "
  60. f"{spec_val}"
  61. )
  62. if set(output_val.keys()) != set(spec_val.keys()):
  63. raise RuntimeError(
  64. f"Output value {output_val} must match keys of loss specification "
  65. f"{spec_val}"
  66. )
  67. for k in spec_val:
  68. loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
  69. if loss_val is not None:
  70. return loss_val
  71. raise RuntimeError(f"Did not find loss value in specification {spec_val}")
  72. raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")
  73. def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
  74. output_nodes = [n for n in g.nodes if n.op == "output"]
  75. assert len(output_nodes) == 1
  76. output_node = output_nodes[0]
  77. output_val = output_node.args[0]
  78. generated_spec: Any = None
  79. if isinstance(mod, TrivialLossWrapper):
  80. # TrivialLossWrapper is pre-defined by PiPPy.
  81. # It has loss as the only output so we can safely assume the first output arg is the loss.
  82. assert len(output_node.args) == 1
  83. loss_node = output_val
  84. generated_spec = TrivialLossWrapper.loss_spec
  85. elif output_loss_value_spec is None:
  86. # Use default spec, i.e. search for "loss" in output values
  87. if isinstance(output_val, dict) and "loss" in output_val.keys():
  88. loss_node = output_val["loss"]
  89. generated_spec = {k: k == "loss" for k in output_val}
  90. else:
  91. loss_node = None
  92. generated_spec = None
  93. else:
  94. loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
  95. generated_spec = output_loss_value_spec
  96. return loss_node, output_node, generated_spec
def _insert_stage_symbolic_backward(
    g: fx.Graph,
    loss_node: fx.Node,
    output_node: fx.Node,
):
    """Emit symbolic backward computation for the pipeline stages into ``g``.

    ``g`` is expected to be forward-only: placeholders, stage ``call_module``
    nodes, ``operator.getitem`` calls and the output (``call_function`` nodes
    other than getitem trigger an assertion). Walking the graph in reverse from
    ``loss_node``, this inserts, before ``output_node``:

    * one ``stage_backward`` call per live ``call_module`` node,
    * ``getitem`` nodes extracting the gradient for each stage input, and
    * ``_null_coalesce_accumulate`` calls where a non-placeholder value receives
      gradients from more than one consumer.

    Args:
        g: the forward graph; modified in place.
        loss_node: the value being backpropagated. Only nodes it (transitively)
            depends on get backward operations.
        output_node: the graph's output node; new nodes are inserted before it.

    Returns:
        ``g`` itself.
    """
    # Collect metadata about tuple output values. TODO: move this to split_module or FX IR
    # ``tuples`` maps a value that is indexed by ``operator.getitem`` to a tuple
    # of the getitem nodes reading it, each placed at its index (``None`` for
    # indices that are never read).
    tuples: dict[fx.Node, tuple] = {}
    for node in reversed(g.nodes):
        if node.op == "call_function":
            # In the forward pass, only emit placeholder, module calls, and
            # getitem calls. If we have a target other than getitem in this
            # (forward-only) code, there is a bug.
            assert node.target == operator.getitem, (
                "Found non-getitem call in forward pass. Please report a bug to PiPPy"
            )
            assert len(node.args) == 2, (
                "Found malformed getitem call. Please report a bug to PiPPy"
            )
            indexed_value, node_idx = tuple(node.args)

            # indexed_value is a collection that we are indexing into. It could
            # exist in the tuples map if we've processed another `getitem`
            # already.
            existing_list_size = (
                len(tuples[indexed_value]) if indexed_value in tuples else -1
            )
            # Size the slot list so that both the existing entries and
            # ``node_idx`` fit.
            new_list_size = max(node_idx + 1, existing_list_size)

            reconstructed_list = [None for _ in range(new_list_size)]

            # Copy over existing elements if present
            if indexed_value in tuples:
                for i, val in enumerate(tuples[indexed_value]):
                    reconstructed_list[i] = val

            # Populate value represented by this node
            reconstructed_list[node_idx] = node

            tuples[indexed_value] = tuple(reconstructed_list)

    # Keep track of nodes that dominate the loss node.
    # We will only emit backward operations for nodes that can contribute
    # to the specified loss value.
    # (A dict with ``None`` values is used as an insertion-ordered set.)
    live_nodes = {loss_node: None}
    # Forward node -> node holding its gradient. The loss itself maps to None.
    # NOTE(review): presumably ``stage_backward`` treats a None output gradient
    # as the implicit seed of backpropagation; confirm in ._backward.
    val_to_grad: dict[fx.Node, Optional[fx.Node]] = {loss_node: None}

    def assign_or_accumulate_grad(forward_node, grad_value):
        """Record ``grad_value`` as the gradient of ``forward_node``.

        If a gradient is already recorded and ``forward_node`` is not a
        placeholder, emit a ``_null_coalesce_accumulate`` call combining the
        old and new gradients and record that call instead. For placeholders
        the new gradient overwrites any previously recorded one.
        """
        if forward_node in val_to_grad and forward_node.op != "placeholder":
            grad_value = g.call_function(
                _null_coalesce_accumulate,
                (val_to_grad[forward_node], grad_value),
            )
        val_to_grad[forward_node] = grad_value

    # Every node created below is inserted just before the graph output, i.e.
    # after all forward nodes, so this reverse walk (which has already moved
    # past that point) does not visit the newly inserted backward nodes.
    with g.inserting_before(output_node):
        # Reverse topological order: by the time a node is visited, all of its
        # live users have already been processed.
        for node in reversed(g.nodes):
            if node not in live_nodes:
                continue

            def add_to_live_nodes(n):
                live_nodes.setdefault(n, None)

            # The inputs of a live node are live as well.
            fx.node.map_arg(node.args, add_to_live_nodes)
            fx.node.map_arg(node.kwargs, add_to_live_nodes)
            if node.op == "call_module":
                output_grads: Union[tuple[Optional[fx.Node], ...], Optional[fx.Node]]
                if node in tuples:
                    # Multi-output stage: gather the gradient recorded for each
                    # getitem reading this stage's output (None if absent), and
                    # record which outputs are live.
                    stage_output = tuples[node]
                    output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
                    outputs_with_grads_idxs = [
                        i for i, n in enumerate(tuples[node]) if n in live_nodes
                    ]
                else:
                    # Single-output stage: the node itself is the only output.
                    stage_output = (node,)
                    output_grads = val_to_grad[node]
                    outputs_with_grads_idxs = [0]

                # Normalize to a tuple before handing it to stage_backward.
                output_grads = (
                    (output_grads,)
                    if not isinstance(output_grads, tuple)
                    else output_grads
                )

                grad_call = g.call_function(
                    stage_backward,
                    kwargs={
                        "stage_output": stage_output,
                        "output_grads": output_grads,
                        "input_values": list(node.all_input_nodes),
                        "outputs_with_grads_idxs": outputs_with_grads_idxs,
                    },
                )
                # Insert backward stage debug info
                # NOTE(review): this re-assigns an equal copy of the kwargs and
                # appears to have no effect; confirm before removing.
                kwargs_copy = dict(grad_call.kwargs)
                grad_call.kwargs = kwargs_copy

                grad_call_proxy = fx.Proxy(grad_call)
                grads = grad_call_proxy.node

                input_nodes = list(node.all_input_nodes)
                grads_proxy = fx.Proxy(grads)
                # Indexing the proxy records an operator.getitem node in ``g``
                # selecting the i-th gradient, which pairs with the i-th entry
                # of the ``input_values`` passed to stage_backward above.
                for i, input_node in enumerate(input_nodes):
                    assign_or_accumulate_grad(input_node, grads_proxy[i].node)  # type: ignore[index]

    return g
  187. class PipeSequential(torch.nn.Sequential):
  188. @staticmethod
  189. def from_sequential(sequential_instance: torch.nn.Sequential):
  190. return PipeSequential(*[copy.copy(m) for m in sequential_instance])
  191. def forward(self, input):
  192. for i, module in enumerate(self):
  193. input = module(input)
  194. if i != len(self) - 1:
  195. pipe_split()
  196. return input
  197. class LossWrapper(torch.nn.Module):
  198. """
  199. LossWrapper is a convenient abstract class that allows you to wrap up both
  200. your model as well as its loss function and specify the connectivity between
  201. the inputs, model, loss function, and output value. Example::
  202. class MyModelWrapper(LossWrapper):
  203. def forward(self, x, targets):
  204. model_out = self.module(x)
  205. loss_value = self.loss_fn(model_out, targets)
  206. return loss_value
  207. The above example defines a connectivity where we expect the forward/loss/backward
  208. training procedure to take two arguments (x and targets), pass x into the module
  209. to get the output of the feedforward computation, pass the model output and the
  210. targets value into the loss function, and get and return the loss value, which will
  211. be backpropagated by PiPPy. The above class would then be instantiated like::
  212. model = ... # instantiate the model
  213. loss_fn = torch.nn.MSELoss() # for the sake of demonstration
  214. wrapper = MyModelWrapper(model, loss_fn)
  215. pipe = Pipe.from_tracing(wrapper, ...)
  216. """
  217. def __init__(self, module, loss_fn):
  218. super().__init__()
  219. self.module = module
  220. self.loss_fn = loss_fn
  221. def forward(self, *args, **kwargs):
  222. raise NotImplementedError(
  223. "This instance of LossWrapper does not have an overridden"
  224. "forward(). Please implement forward() to specify the arguments, "
  225. "connection between the module and loss, and loss output "
  226. "value."
  227. )
  228. class TrivialLossWrapper(LossWrapper):
  229. def forward(self, x, targets):
  230. model_out = self.module(x)
  231. return self.loss_fn(model_out, targets)
  232. loss_spec = True
  233. # Pipe model representation
  234. #
  235. # Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
  236. # a single topological ordering of pipeline "stages" that, when run in series,
  237. # constitutes all of the operations of the program. However, unlike `nn.Sequential`,
  238. # Pipe allows non-local usages of values, so long as those uses still respect
  239. # topological ordering. In particular:
  240. #
  241. # 1. Non-local activations. This type of usage can appear in, for example, skip
  242. # connections. These values will be directly transmitted from the "def" stage
  243. # to all stages that use them skipping intermediate stages. During autograd,
  244. # gradients will be propagated back through this skip connection reverse
  245. # to how activations propagated in the forward pass.
  246. # 2. Non-local parameter/module invocations. This occurs when a parameter is used
  247. # in a stage downstream of where it is resident. These values can be carried
  248. # forward similarly to (1), but in addition one might want to replicate the
  249. # value on multiple stages. Gradients for these shared parameters will be
  250. # accumulated separately on each stage, but there will be an additional
  251. # gradient accumulation before the optimizer step.
# Register `_pipe_split()` as a custom operator (in the `pippy` namespace).
# This is required for Export to preserve this marker in the graph.
# Schema "() -> ()": the op takes no arguments and returns nothing.
torch.library.define("pippy::_pipe_split", "() -> ()")


# Implementation registered for the "BackendSelect" dispatch key: a no-op.
@torch.library.impl("pippy::_pipe_split", "BackendSelect")
def _pipe_split():
    return None


# Fake implementation (used when tracing with fake tensors); also a no-op.
# It deliberately reuses the name `_pipe_split`, hence the redefinition
# suppressions below.
@torch.library.register_fake("pippy::_pipe_split")  # type: ignore[no-redef]
def _pipe_split():  # noqa: F811
    return None


# Add an alias for convenience
aten_pipe_split_alias = torch.ops.pippy._pipe_split.default

# Ask Export to preserve the `_pipe_split` op.
# Marking it side-effectful makes FX treat its nodes as impure, so passes such
# as dead-code elimination keep them even though the op returns nothing.
# See examples in pytorch/torch/fx/node.py
fx.node._side_effectful_functions.add(aten_pipe_split_alias)
  266. # User facing API
  267. def pipe_split():
  268. """
  269. pipe_split is a special operator that is used to mark the boundary between
  270. stages in a module. It is used to split the module into stages. It is a
  271. no-op if your annotated module is run eagerly.
  272. Example:
  273. >>> # xdoctest: +SKIP
  274. >>> def forward(self, x):
  275. >>> x = torch.mm(x, self.mm_param)
  276. >>> x = torch.relu(x)
  277. >>> pipe_split()
  278. >>> x = self.lin(x)
  279. >>> return x
  280. The above example will be split into two stages.
  281. """
  282. return torch.ops.pippy._pipe_split()
  283. class MultiUseParameterConfig(Enum):
  284. TRANSMIT = 1
  285. REPLICATE = 2
  286. MultiUseParamSpec = Union[MultiUseParameterConfig, dict[str, MultiUseParameterConfig]]
  287. class DetachExecutor(fx.Interpreter):
  288. """
  289. Special interpreter to run the split_gm in testing that detaches all inputs to
  290. a module invocation. This is needed so that the values at the boundary are
  291. leaf modules in autograd execution.
  292. """
  293. def __init__(self, module, garbage_collect_values=True):
  294. garbage_collect_values = False
  295. super().__init__(module, garbage_collect_values)
  296. self.value_remap = {}
  297. def run(self, *args, initial_env=None): # type: ignore[override]
  298. self.value_remap = {}
  299. return super().run(*args, initial_env=initial_env)
  300. def call_module(self, target, args, kwargs):
  301. def detach_tensors(a):
  302. if isinstance(a, torch.Tensor) and a.requires_grad:
  303. if a not in self.value_remap:
  304. new_val = a.detach().requires_grad_(True)
  305. self.value_remap[a] = new_val
  306. return self.value_remap[a]
  307. else:
  308. return a
  309. """
  310. def dont_traverse_size(a):
  311. return type(a) != torch.Size
  312. """
  313. args = map_aggregate(
  314. args,
  315. detach_tensors, # dont_traverse_size
  316. )
  317. kwargs = map_aggregate(
  318. kwargs,
  319. detach_tensors, # dont_traverse_size
  320. )
  321. return super().call_module(target, args, kwargs)
  322. def call_function(self, target, args, kwargs):
  323. # HACK to reroute saved input tensors to point to the detach()ed version
  324. if target == stage_backward:
  325. kwargs = dict(kwargs)
  326. kwargs["input_values"] = [
  327. self.value_remap.get(v, v) for v in kwargs["input_values"]
  328. ]
  329. return super().call_function(target, args, kwargs)
  330. class _NodeReference:
  331. def __init__(self, name):
  332. self.name = name
  333. name: str
  334. class _LinearNodeList:
  335. def __init__(self, node_list):
  336. self.serialize_node_list = []
  337. for node in node_list:
  338. node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
  339. node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value]
  340. serialize_node = fx.Node(
  341. graph=None, # type: ignore[arg-type]
  342. name=node.name,
  343. op=node.op,
  344. target=node.target,
  345. args=node_args, # type: ignore[arg-type]
  346. kwargs=node_kwargs, # type: ignore[arg-type]
  347. return_type=node.type,
  348. )
  349. serialize_node.meta = copy.copy(node.meta)
  350. self.serialize_node_list.append(serialize_node)
  351. def to_graph(self):
  352. graph = fx.Graph()
  353. ref_str_to_node: dict[str, fx.Node] = {}
  354. def ref_to_node(arg):
  355. if isinstance(arg, _NodeReference):
  356. return ref_str_to_node[arg.name]
  357. else:
  358. return arg
  359. for node in self.serialize_node_list:
  360. node_args = map_aggregate(node.args, ref_to_node)
  361. node_kwargs = map_aggregate(node.kwargs, ref_to_node)
  362. deser_node = graph.create_node(
  363. op=node.op,
  364. target=node.target,
  365. args=node_args, # type: ignore[arg-type]
  366. kwargs=node_kwargs, # type: ignore[arg-type]
  367. name=node.name,
  368. type_expr=node.type,
  369. )
  370. ref_str_to_node[node.name] = deser_node
  371. return graph
  372. def _direct_serialization_deserialize(body, nodes):
  373. """
  374. Custom `__reduce__` method for serialization.
  375. DO AS I SAY -- NOT AS I DO. This violates the principle that
  376. GraphModules serialize via code export & re-tracing. We allow
  377. for this here because **PIPE STAGES SHOULD NOT BE PERSISTED
  378. TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting
  379. these instances to disk will expose internal implementation
  380. details of `fx.Graph` and related data structures and is
  381. NOT advised.
  382. """
  383. class DummyModule(torch.nn.Module):
  384. def __init__(self, body):
  385. super().__init__()
  386. self.__dict__.update(body)
  387. dummy = DummyModule(body)
  388. return fx.GraphModule(dummy, nodes.to_graph())
  389. def _direct_serialization_reduce(self):
  390. serialization_dict = dict(self.__dict__)
  391. serialization_dict.pop("_graph")
  392. return (
  393. _direct_serialization_deserialize,
  394. (serialization_dict, _LinearNodeList(self.graph.nodes)),
  395. )
  396. def _modify_graph_op_device(
  397. gm: torch.fx.GraphModule,
  398. new_device: torch.device,
  399. ):
  400. """
  401. Modify the device argument of all "call_function" nodes in the graph. This
  402. is useful for moving the graph to a different device. In particular for
  403. generator ops, like torch.ones.
  404. """
  405. modified = False
  406. for node in gm.graph.nodes:
  407. if node.op == "call_function":
  408. if "device" in node.kwargs and node.kwargs["device"] != new_device:
  409. logger.debug(
  410. f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004
  411. )
  412. node.update_kwarg("device", new_device)
  413. modified = True
  414. elif node.op == "call_module":
  415. # Recursively modify "device" in submodules
  416. submod = gm.get_submodule(node.target)
  417. if isinstance(submod, torch.fx.GraphModule):
  418. _modify_graph_op_device(submod, new_device)
  419. elif isinstance(submod, InterpreterModule):
  420. # If unflattening has been performed, we need to access its graph module by `.graph_module`
  421. _modify_graph_op_device(submod.graph_module, new_device) # type: ignore[arg-type]
  422. else:
  423. logger.warning(
  424. f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004
  425. )
  426. if modified:
  427. gm.recompile()
  428. class Pipe(torch.nn.Module):
    def __init__(
        self,
        split_gm: fx.GraphModule,
        num_stages: int,
        has_loss_and_backward: bool,
        loss_spec,
    ):
        """Wrap an already-split ``GraphModule`` as a pipeline.

        Args:
            split_gm: graph module whose top-level graph invokes the stage
                submodules (named ``submod_<i>``).
            num_stages: number of stages; bounds ``get_stage_module``.
            has_loss_and_backward: stored on the instance as-is.
            loss_spec: loss output specification, stored on the instance as-is.

        Side effects on ``split_gm``:
            * parameters shared by several submodules are deep-copied so each
              submodule owns a distinct tensor;
            * ``split_gm.forward`` is replaced by a function that raises;
            * ``__reduce__`` is patched on the class of each ``submod_<i>``
              found by counting up from 0.
        """
        # TODO: is there a way not to hard wire init?
        torch.nn.Module.__init__(self)
        self.split_gm: fx.GraphModule = split_gm
        # Interpreter used by forward() to run split_gm locally.
        self.executor: DetachExecutor = DetachExecutor(self.split_gm)
        self.num_stages: int = num_stages
        self.has_loss_and_backward = has_loss_and_backward
        self.loss_spec = loss_spec

        # Sanity check: the top-level graph may only contain stage invocations,
        # placeholders/outputs, getitem, and backward-related calls.
        for node in split_gm.graph.nodes:
            assert (
                node.op in {"call_module", "placeholder", "output"}
                or (node.op, node.target) == ("call_function", operator.getitem)
                or (node.op, node.target) == ("call_method", "backward")
                or (node.op, node.target) == ("call_function", stage_backward)
                or (node.op, node.target)
                == ("call_function", _null_coalesce_accumulate)
            ), node

        # Detect replicated parameters so we know that we have to do an additional allreduce
        # before applying the optimizer
        #
        # Note that this also handles the case where there were multiple calls to a single
        # module from different stages, regardless of whether that module invocation
        # was handled by the logic above.

        # Map parameter value to a dictionary that maps the user pipeline module
        # to the local qualname within that module
        params_to_users: dict[torch.nn.Parameter, dict[str, str]] = {}

        for m_qualname, mod in self.split_gm.named_children():
            for p_qualname, param in mod.named_parameters():
                params_to_users.setdefault(param, {})
                params_to_users[param][m_qualname] = p_qualname

        # A parameter is "replicated" when more than one direct child of
        # split_gm uses it.
        self.replicated_params: list[dict[str, str]] = [
            use_mapping
            for _, use_mapping in params_to_users.items()
            if len(use_mapping) > 1
        ]

        # We must break the aliasing relationship between the replicated parameters for correct
        # numerics in reference runs. If we do not do this, the autograd tape in separate stages
        # will have a reference to the same tensor value and will erroneously apply gradient
        # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
        # values so that we have separate instances.
        for param_mapping in self.replicated_params:
            for submod_name, param_qualname in param_mapping.items():
                submod = getattr(self.split_gm, submod_name)
                # Walk the dotted qualname down to the module owning the attribute.
                atoms = param_qualname.split(".")
                for atom in atoms[:-1]:
                    submod = getattr(submod, atom)
                setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1])))

        # Running split_gm directly is not supported; forward() goes through
        # self.executor instead. `throw` is assigned as an instance attribute
        # (not bound), so any call arguments are absorbed by its parameters
        # before it raises.
        def throw(self, *args, **kwargs):
            raise RuntimeError(
                "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
            )

        self.split_gm.forward = throw

        # Make submodules use custom direct-serialized GraphModule
        # This assigns to the submodule's *class*, so it affects every instance
        # sharing that class.
        # NOTE(review): the loop stops at the first missing `submod_<i>`; if
        # stage names are non-contiguous (see _number_and_count_forward_stages),
        # later submodules are not patched. Confirm whether that can happen here.
        i = 0
        while True:
            try:
                name = f"submod_{i}"
                submod = getattr(self.split_gm, name)
                submod.__class__.__reduce__ = _direct_serialization_reduce
                i += 1
            except AttributeError:
                break
  497. def forward(self, *args, **kwargs):
  498. executor_args = args
  499. if len(kwargs) > 0:
  500. parameters = []
  501. for node in self.split_gm.graph.nodes:
  502. if node.op == "placeholder":
  503. if node.args and len(node.args) > 0:
  504. parameters.append(
  505. Parameter(
  506. node.target,
  507. Parameter.POSITIONAL_OR_KEYWORD,
  508. default=node.args[0],
  509. )
  510. )
  511. else:
  512. parameter_kind = Parameter.POSITIONAL_OR_KEYWORD
  513. param_name = node.target
  514. if node.target.startswith("**"):
  515. parameter_kind = Parameter.VAR_KEYWORD # type: ignore[assignment]
  516. param_name = param_name[2:]
  517. elif node.target.startswith("*"):
  518. parameter_kind = Parameter.VAR_POSITIONAL # type: ignore[assignment]
  519. param_name = param_name[1:]
  520. parameters.append(Parameter(param_name, parameter_kind))
  521. signature = Signature(parameters)
  522. ba = signature.bind(*args, **kwargs)
  523. ba.apply_defaults()
  524. executor_args = ba.arguments.values() # type: ignore[assignment]
  525. res = self.executor.run(*executor_args)
  526. return res
  527. def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
  528. """
  529. Return a stage module corresponding to `stage_idx` of the `pipe`.
  530. """
  531. if stage_idx < 0 or stage_idx >= self.num_stages:
  532. raise ValueError(f"Invalid stage index {stage_idx}!")
  533. return getattr(self.split_gm, f"submod_{stage_idx}")
  534. @staticmethod
  535. def _number_and_count_forward_stages(gm: fx.GraphModule):
  536. num_stages = 0
  537. found_idxs: dict[int, None] = {}
  538. for node in gm.graph.nodes:
  539. if node.op == "call_module" and node.target.startswith("submod_"):
  540. node.meta["stage_idx"] = int(node.target[len("submod_") :])
  541. found_idxs.setdefault(node.meta["stage_idx"])
  542. num_stages += 1
  543. # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule
  544. # Update: the following assert may fail against some torch versions >=
  545. # 2.2.0, as:
  546. # submod_0, submod_1, submod_2, ...
  547. # may be named as
  548. # submod_0, submod_2, submod_4, ...
  549. # TODO: investigate
  550. # assert all(i in found_idxs for i in range(num_stages))
  551. return num_stages
  552. @staticmethod
  553. def _from_traced(
  554. mod: torch.nn.Module,
  555. exported_program: ExportedProgram,
  556. multi_use_param_spec: Optional[MultiUseParamSpec] = None,
  557. output_loss_value_spec=None,
  558. split_policy: Optional[
  559. Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
  560. ] = None,
  561. ):
  562. """
  563. Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
  564. which value in the output of `forward` is the loss value on which PiPPy should apply
  565. backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
  566. you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
  567. a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
  568. ``output_loss_value_spec={'loss': True, 'model_out': False}``
  569. """
  570. traced = exported_program.module(check_guards=False)
  571. if split_policy is not None:
  572. logger.info("Auto-splitting model")
  573. traced = split_policy(traced) # type: ignore[arg-type]
  574. logger.debug(traced.print_readable(print_output=False)) # type: ignore[operator]
  575. # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving
  576. # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
  577. # the case (especially with custom tracers), so fix that up here.
  578. get_attr_nodes: dict[str, fx.Node] = {}
  579. for node in traced.graph.nodes: # type: ignore[union-attr]
  580. if node.op == "get_attr":
  581. get_attr_nodes.setdefault(node.target, node)
  582. if get_attr_nodes[node.target] != node:
  583. node.replace_all_uses_with(get_attr_nodes[node.target])
  584. traced.graph.erase_node(node) # type: ignore[operator, union-attr]
  585. # avoid looking at next node by keeping track of previous pipe_split
  586. prev_pipe_split_idx = -1
  587. pipe_split_nodes_to_erase = set()
  588. for i, node in enumerate(traced.graph.nodes): # type: ignore[arg-type, union-attr]
  589. if (node.op, node.target) == ("call_function", pipe_split):
  590. if prev_pipe_split_idx == i - 1:
  591. pipe_split_nodes_to_erase.add(node)
  592. prev_pipe_split_idx = i
  593. for node in pipe_split_nodes_to_erase:
  594. traced.graph.erase_node(node) # type: ignore[operator, union-attr]
  595. traced.recompile() # type: ignore[operator]
  596. part_idx = 0
  597. def split_callback(n: fx.Node):
  598. nonlocal part_idx
  599. if (n.op, n.target) == (
  600. "call_function",
  601. aten_pipe_split_alias,
  602. ):
  603. logger.debug(f"Found pipe_split {part_idx}") # noqa: G004
  604. part_idx += 1
  605. return part_idx
  606. # TODO: what does split do with module invocations? does it move the modules
  607. # into the submodules?
  608. split = split_module(traced, mod, split_callback) # type: ignore[arg-type]
  609. # a (custom) tracer can produce dead code like orphan get_attr nodes
  610. split.graph.eliminate_dead_code()
  611. # peephole to remove pipe_split
  612. for submodule in split.modules():
  613. if isinstance(submodule, fx.GraphModule):
  614. for node in submodule.graph.nodes:
  615. if (node.op, node.target) == (
  616. "call_function",
  617. aten_pipe_split_alias,
  618. ):
  619. submodule.graph.erase_node(node)
  620. submodule.recompile()
  621. for name, submodule in split.named_children():
  622. if isinstance(submodule, fx.GraphModule):
  623. new_submod = _outline_submodules(submodule.graph)
  624. # Replace old submod
  625. split.register_module(name, new_submod)
  626. # TODO: backport this into split_module
  627. def delete_user_reference(node, user):
  628. """
  629. Delete reference of `node` from `user`'s arg list.
  630. Args:
  631. - node: a `get_attr` node at root.
  632. - user: a submodule node that uses `node`.
  633. """
  634. assert len(user.kwargs) == 0
  635. use_idxs = [i for i, arg in enumerate(user.args) if arg == node]
  636. assert len(use_idxs) == 1
  637. args_copy = list(user.args)
  638. args_copy.pop(use_idxs[0])
  639. user.args = tuple(args_copy)
  640. logger.debug(
  641. f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" # noqa: G004
  642. )
  643. # A list of param referrals for deferred deletion.
  644. # To be accumulated in `move_param_to_callee`.
  645. to_delete = []
  646. def _recursive_getattr_with_parent(mod, fqn):
  647. # Returns getattr call given a nested FQN, and the last parent
  648. atoms = fqn.split(".")
  649. for atom in atoms[:-1]:
  650. if not hasattr(mod, atom):
  651. return None, None
  652. mod = getattr(mod, atom)
  653. if not hasattr(mod, atoms[-1]):
  654. return mod, None
  655. attr = getattr(mod, atoms[-1])
  656. return mod, attr
  657. def move_param_to_callee(
  658. root,
  659. callee_name,
  660. param_fqn,
  661. ):
  662. """
  663. Move a parameter from the root module to a submodule.
  664. Args:
  665. root: The root module.
  666. callee_name: The name of the submodule to move the parameter to.
  667. param_fqn: The fully qualified name of the parameter to move.
  668. """
  669. # `atoms` is a list of strings representing the path to the
  670. # parameter in the original model
  671. atoms = param_fqn.split(".")
  672. mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn)
  673. # Check whether the parameter is a buffer or a parameter
  674. is_buffer = atoms[-1] in mod_itr._buffers
  675. # Check whether the parameter is a tensor
  676. assert isinstance(param_val, torch.Tensor), (
  677. f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}."
  678. + (
  679. f" It might happen if module '{param_fqn}' was passed to some 'leaf function'"
  680. f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect "
  681. f"usages of '{param_fqn}' in the traced graph."
  682. if isinstance(param_val, torch.nn.Module)
  683. else ""
  684. )
  685. )
  686. # Get submodule
  687. callee = root.get_submodule(callee_name)
  688. assert not hasattr(callee, param_fqn), (
  689. f"Module {callee_name} already has a parameter named {param_fqn}"
  690. )
  691. # Assign the parameter to the submodule
  692. if is_buffer:
  693. _assign_attr(
  694. param_val,
  695. callee,
  696. param_fqn,
  697. attr_kind=_AttrKind.BUFFER,
  698. persistent=True, # TODO: handle non-persistent buffer
  699. )
  700. else:
  701. _assign_attr(
  702. param_val,
  703. callee,
  704. param_fqn,
  705. attr_kind=_AttrKind.PARAMETER,
  706. )
  707. logger.debug(f"Moved parameter {param_fqn} to {callee_name}") # noqa: G004
  708. # Next step is to replace placeholder of submodule with a get_attr.
  709. # Those placeholders are created by `split_module` inside each
  710. # submodule.
  711. # Update: this step is now moved to `_sink_params` because
  712. # `_sink_params` can do it recursively (i.e. for modules inside
  713. # submodule)
  714. to_delete.append((mod_itr, atoms[-1]))
  715. # Get the list of all parameters in the root module
  716. attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes))
  717. for node in attr_nodes:
  718. # Check whether the parameter is used in only one submodule
  719. if len(node.users) > 1:
  720. logger.info(
  721. f"Parameter {node.target} used in multiple stages: {node.users}." # noqa: G004
  722. )
  723. for user in node.users:
  724. assert user.op == "call_module"
  725. # Move parameter into submodule
  726. move_param_to_callee(
  727. split,
  728. user.target,
  729. node.target,
  730. )
  731. # [aliasing] store tensor id -> list of FQNs, built from state dict
  732. # Also assign non-persistent buffers
  733. id_to_fqns: dict[int, set[str]] = defaultdict(set)
  734. for fqn, tensor in mod.state_dict(keep_vars=True).items():
  735. id_to_fqns[id(tensor)].add(fqn)
  736. for fqn, tensor in mod.named_buffers():
  737. id_to_fqns[id(tensor)].add(fqn)
  738. # After moving the params to their corresponding hierarchies, we also
  739. # need to move the `get_attr` nodes from the root of the graph to those
  740. # hierarchies.
  741. # [aliasing] use id -> fqn mapping to list out all valid FQNs
  742. inputs_to_state: dict[str, list[str]] = {}
  743. for attr in attr_nodes:
  744. _, tensor = _recursive_getattr_with_parent(mod, attr.target)
  745. fqns = list(id_to_fqns[id(tensor)])
  746. if fqns:
  747. inputs_to_state[attr.name] = fqns
  748. elif attr.target in exported_program.constants: # lifted constants
  749. inputs_to_state[attr.name] = [attr.target]
  750. # [aliasing] for each submodule split, assign attributes on FQNs that may be used.
  751. # We determine this based on whether or not the FQN attribute parent exists.
  752. # i.e. if the last submodule exists, assign the attribute.
  753. added_attributes: dict[str, list[str]] = defaultdict(list)
  754. for fqn, tensor in mod.state_dict(keep_vars=True).items():
  755. for name, submod in split.named_children():
  756. if isinstance(submod, fx.GraphModule):
  757. parent, child = _recursive_getattr_with_parent(submod, fqn)
  758. if (
  759. parent and child is None
  760. ): # parent exists, attribute doesn't -> assign
  761. added_attributes[name].append(fqn)
  762. setattr(parent, fqn.split(".")[-1], tensor)
  763. # Deferral deletion: Remove the original attributes (to params) from the
  764. # root GraphModule
  765. for mod_itr, last_atom in to_delete:
  766. try:
  767. delattr(mod_itr, last_atom)
  768. except AttributeError:
  769. # This is expected if the parameter is used in multiple stages
  770. pass
  771. # This is done by (1) `_sink_params` at each submodule;
  772. for name, submod in split.named_children():
  773. if isinstance(submod, fx.GraphModule):
  774. _sink_params(submod, inputs_to_state, [])
  775. submod.graph.lint()
  776. submod.recompile()
  777. # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory.
  778. # After _sink_params() routine has run, clean up unused attributes that we previously added.
  779. # Determine this based on the get_attr nodes - if not used, remove it.
  780. for name, attributes in added_attributes.items():
  781. submod = getattr(split, name)
  782. unused_attributes = set(attributes)
  783. # track used attributes in the submodule, running DFS on subgraph hierarchy
  784. stack = [("", submod)] # (scope, submodule)
  785. while stack:
  786. scope, _mod = stack.pop()
  787. if isinstance(_mod, (fx.GraphModule, InterpreterModule)):
  788. for node in _mod.graph.nodes:
  789. if node.op == "get_attr":
  790. # get_attr might get access deeper level attribute
  791. fqn = scope + "." + node.target if scope else node.target
  792. unused_attributes.discard(fqn)
  793. for _name, _submod in _mod.named_children():
  794. stack.append((scope + "." + _name if scope else _name, _submod))
  795. # delete unused attributes
  796. for attr in unused_attributes:
  797. mod_itr, atoms = submod, attr.split(".")
  798. for atom in atoms[:-1]:
  799. mod_itr = getattr(mod_itr, atom)
  800. delattr(mod_itr, atoms[-1])
  801. for node in attr_nodes:
  802. # And (2): remove `get_attr` node from submod's arg list
  803. for user in copy.copy(node.users):
  804. assert user.op == "call_module"
  805. delete_user_reference(node, user)
  806. # And (3): remove the `get_attr` node from the root graph.
  807. split.graph.erase_node(node)
  808. split.delete_all_unused_submodules()
  809. split.graph.lint()
  810. split.recompile()
  811. num_stages = Pipe._number_and_count_forward_stages(split)
  812. has_loss_and_backward = False
  813. generated_loss_spec = output_loss_value_spec
  814. if output_loss_value_spec is not None:
  815. loss_node, output_node, generated_loss_spec = _find_loss_output(
  816. mod, split.graph, output_loss_value_spec
  817. )
  818. if loss_node is not None:
  819. _insert_stage_symbolic_backward(
  820. split.graph,
  821. loss_node,
  822. output_node,
  823. )
  824. split.recompile()
  825. has_loss_and_backward = True
  826. logger.debug("Pipeline is in training mode, backward pass generated")
  827. else:
  828. raise RuntimeError(
  829. f"Did not find any loss value according to {output_loss_value_spec=}"
  830. )
  831. else:
  832. logger.debug("Pipeline is in inference mode, backward pass not generated")
  833. logger.debug(f"Full pipe model:\n{split}") # noqa: G004
  834. return Pipe(
  835. split,
  836. num_stages,
  837. has_loss_and_backward,
  838. generated_loss_spec,
  839. )
  840. def print_readable(self):
  841. """
  842. Print the pipe in a human-readable format.
  843. This will print both the root pipe and each stage module.
  844. """
  845. self.split_gm.print_readable()
  846. @staticmethod
  847. def _trace_with_export(
  848. mod: torch.nn.Module,
  849. example_args: tuple[Any, ...],
  850. example_kwargs: Optional[dict[str, Any]] = None,
  851. ) -> ExportedProgram:
  852. logger.info("Tracing model ...")
  853. try:
  854. ep = torch.export.export_for_training(
  855. mod, example_args, example_kwargs, strict=True
  856. )
  857. except Exception as e:
  858. raise RuntimeError(
  859. "It seems that we cannot capture your model as a full graph. "
  860. "Typical reasons include graph breaks, data/shape-dependent "
  861. "control flow, or missing meta kernels for custom operators. "
  862. "You can use our manual pipeline interfaces, or try to fix the "
  863. "graph breaks, see https://pytorch.org/docs/stable/export.html"
  864. ) from e
  865. return ep
  866. @staticmethod
  867. def from_tracing(
  868. mod: torch.nn.Module,
  869. example_args: tuple[Any, ...],
  870. example_kwargs: Optional[dict[str, Any]] = None,
  871. split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
  872. ):
  873. # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across
  874. # stages instead of TRANSMIT'ting it
  875. multi_use_param_spec = MultiUseParameterConfig.REPLICATE
  876. # Figure out which output is loss from output_chunk_spec
  877. output_loss_value_spec: Any = None
  878. # Deprecated
  879. """
  880. if output_chunk_spec is not None:
  881. output_loss_value_spec = map_aggregate(
  882. output_chunk_spec, lambda v: isinstance(v, _LossReducer)
  883. )
  884. """
  885. # Trace with export
  886. exported_program = Pipe._trace_with_export(
  887. mod,
  888. example_args,
  889. example_kwargs,
  890. )
  891. pipe = Pipe._from_traced(
  892. mod,
  893. exported_program,
  894. multi_use_param_spec,
  895. output_loss_value_spec=output_loss_value_spec,
  896. split_policy=split_policy,
  897. )
  898. # Users want the first pipeline stage to accept kwargs if the original
  899. # program does. This is controlled by the `_codegen` field of the graph,
  900. # so we make a copy here. Note: we only want the input spec and not the
  901. # output spec, because the output spec is for the last stage. Maybe a
  902. # TODO? Not sure yet.
  903. split = pipe.split_gm
  904. traced = exported_program.module()
  905. submod0 = next(iter(split.children()))
  906. submod0_sign = signature(submod0.forward)
  907. model_sign = signature(traced.forward)
  908. if len(model_sign.parameters) != len(submod0_sign.parameters):
  909. # We don't change the signature of the first stage if it takes
  910. # different number of args than original model
  911. logger.info(
  912. f"Original model takes {len(model_sign.parameters)} args but the " # noqa: G004
  913. f"first pipeline stage takes {len(submod0_sign.parameters)}. "
  914. "Please provide args to respective pipeline stages."
  915. )
  916. else:
  917. # Support kwargs for the first stage
  918. submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) # type: ignore[union-attr]
  919. # `_replace` is actually not "private" or internal. based on this doc:
  920. # To prevent conflicts with field names, the method and attribute names
  921. # start with an underscore
  922. submod0.graph._codegen.pytree_info = ( # type: ignore[union-attr]
  923. submod0.graph._codegen.pytree_info._replace(out_spec=None) # type: ignore[operator, union-attr]
  924. )
  925. submod0.recompile()
  926. return pipe
  927. def __str__(self):
  928. return self.split_gm.__str__()
  929. def __repr__(self):
  930. return self.split_gm.__repr__()
  931. def info(self) -> PipeInfo:
  932. """
  933. Get information about the pipe.
  934. Returns
  935. -------
  936. PipeInfo
  937. A dataclass containing information about the pipe.
  938. """
  939. return PipeInfo(
  940. graph=self.split_gm.graph,
  941. num_stages=self.num_stages,
  942. has_loss_and_backward=self.has_loss_and_backward,
  943. )
  944. def build_stage(
  945. self,
  946. stage_index: int,
  947. device: torch.device,
  948. group: Optional[ProcessGroup] = None,
  949. ) -> _PipelineStage:
  950. """
  951. Create a `PipelineStage` given a stage index and distributed group.
  952. The `PipelineStage` can run with `PipelineSchedule`s.
  953. """
  954. # Find stage module
  955. stage_module = self.get_stage_module(stage_index)
  956. # Move ops argument to device
  957. # Today PT2 tracer does not treat `x.device` as a symbolic device;
  958. # instead, the device of tracing time got burned into the generated
  959. # code. Here we provide a workaround for users to manually modify the
  960. # "device" kwarg of operations. Such operation may include:
  961. # `torch.ones`, `torch.zeros`, `torch.rand`, etc.
  962. if isinstance(stage_module, torch.fx.GraphModule):
  963. _modify_graph_op_device(stage_module, device)
  964. else:
  965. logger.warning(
  966. f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" # noqa: G004
  967. )
  968. # Detach pipe info
  969. # Note: be careful what's included in `pipe_info`. We don't want to keep
  970. # a reference to `Pipe` or `Pipe.split_gm` which stops python from
  971. # recycling them. When python recycles them, other stage modules (which
  972. # are irrelevant to current rank) can be automatically freed.
  973. pipe_info = self.info()
  974. return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
  975. class SplitPoint(Enum):
  976. """
  977. Enum representing the points at which a split can occur in the execution of a submodule.
  978. Attributes:
  979. BEGINNING: Represents adding a split point *before* the execution of a certain submodule in the `forward` function.
  980. END: Represents adding a split point *after* the execution of a certain submodule in the `forward` function.
  981. """
  982. BEGINNING = 1
  983. END = 2
  984. # For backward compatibility, we kept the PipeSplitWrapper class because `class
  985. # SplitPoint` used to be defined in this class.
  986. class PipeSplitWrapper:
  987. # Create a class alias for BC
  988. SplitPoint = SplitPoint
  989. def _split_before_forward(self, *args, **kwargs):
  990. pipe_split()
  991. return self._orig_forward(*args, **kwargs)
  992. def _split_after_forward(self, *args, **kwargs):
  993. try:
  994. return self._orig_forward(*args, **kwargs)
  995. finally:
  996. pipe_split()
  997. def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]):
  998. # TODO: make this implementation out-of-place?
  999. for qualname, split_type in spec.items():
  1000. atoms = qualname.split(".")
  1001. predecessor_module = mod
  1002. for i, atom in enumerate(atoms[:-1]):
  1003. try:
  1004. predecessor_module = getattr(predecessor_module, atom)
  1005. except AttributeError as e:
  1006. raise AttributeError(
  1007. f"Specified target {qualname} referenced "
  1008. f"nonexistent module {'.'.join(atoms[: i + 1])}"
  1009. ) from e
  1010. mod_to_wrap = getattr(predecessor_module, atoms[-1])
  1011. mod_to_wrap._orig_forward = mod_to_wrap.forward
  1012. if split_type == SplitPoint.BEGINNING:
  1013. mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap)
  1014. elif split_type == SplitPoint.END:
  1015. mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap)
  1016. else:
  1017. raise ValueError("Unknown split point type.")
  1018. def pipeline(
  1019. module: torch.nn.Module,
  1020. mb_args: tuple[Any, ...],
  1021. mb_kwargs: Optional[dict[str, Any]] = None,
  1022. split_spec: Optional[dict[str, SplitPoint]] = None,
  1023. split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
  1024. ) -> Pipe:
  1025. """
  1026. Split a module based on a specification.
  1027. See `Pipe` for more details.
  1028. Arguments
  1029. ---------
  1030. module:
  1031. The module to be split.
  1032. mb_args:
  1033. Example positional inputs, in micro-batch form.
  1034. mb_kwargs:
  1035. Example keyword inputs, in micro-batch form. (default: `None`)
  1036. split_spec:
  1037. A dictionary using submodule names as split marker. (default: `None`)
  1038. split_policy:
  1039. The policy to use for splitting the module. (default: `None`)
  1040. Returns
  1041. -------
  1042. A pipeline representation of class `Pipe`.
  1043. """
  1044. if split_spec is not None and split_policy is not None:
  1045. raise ValueError(
  1046. "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
  1047. )
  1048. if split_spec is not None:
  1049. # Annotate split points in the module based on user spec
  1050. annotate_split_points(module, split_spec)
  1051. return Pipe.from_tracing(
  1052. mod=module,
  1053. example_args=mb_args,
  1054. example_kwargs=mb_kwargs,
  1055. )
  1056. else:
  1057. # Use split policy
  1058. return Pipe.from_tracing(
  1059. mod=module,
  1060. example_args=mb_args,
  1061. example_kwargs=mb_kwargs,
  1062. split_policy=split_policy,
  1063. )