distributed.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # mypy: ignore-errors
  2. """
  3. Distributed computing variable tracking classes for PyTorch Dynamo.
  4. This module implements variable tracking for distributed computing components:
  5. - Process Groups (for collective communication)
  6. - Device Meshes (for distributed tensor sharding)
  7. - Placement Types (for specifying distribution strategies)
  8. - Distributed Tensors and their operations
  9. - Backward hooks for distributed module operations
  10. These classes are responsible for tracking distributed operations during graph
  11. compilation while maintaining proper guards and handling distributed-specific
  12. behaviors. They ensure correct handling of distributed components like process
  13. groups, device meshes, and placement strategies while preserving proper semantics
  14. for distributed tensor operations in the compiled code.
  15. The implementation provides special handling for distributed package availability
  16. checks and proper tracking of distributed state and operations across processes.
  17. """
  18. import functools
  19. import inspect
  20. from typing import TYPE_CHECKING
  21. import torch
  22. from torch.fx.experimental._backward_state import BackwardState
  23. from .. import compiled_autograd, variables
  24. from .._trace_wrapped_higher_order_op import trace_wrapped
  25. from ..exc import unimplemented_v2
  26. from ..external_utils import call_module_hooks_from_backward_state
  27. from ..guards import GuardBuilder, install_guard
  28. from ..source import AttrSource
  29. from ..utils import istype
  30. from .base import VariableTracker
  31. from .constant import ConstantVariable, EnumVariable
  32. if TYPE_CHECKING:
  33. from torch._dynamo.symbolic_convert import InstructionTranslator
  34. class DistributedVariable(VariableTracker):
  35. """
  36. The base distributed variable that encapsulates common methods
  37. for the distributed objects (i.e. ProcessGroup, DeviceMesh, etc.).
  38. Concrete distributed objects could inherit this class and add object
  39. specific logic.
  40. i.e. It provides the check on the distributed package existence
  41. and hold the tracking value for the corresponding distributed object.
  42. """
  43. def __init__(self, value, **kwargs) -> None:
  44. super().__init__(**kwargs)
  45. if not DistributedVariable.is_available():
  46. unimplemented_v2(
  47. gb_type="torch.distributed package is not available!",
  48. context="",
  49. explanation="The PyTorch package doesn't include torch.distributed when building from source.",
  50. hints=[
  51. "Set USE_DISTRIBUTED=1 to enable it when building PyTorch from source."
  52. ],
  53. )
  54. self.value = value
  55. def python_type(self):
  56. return type(self.value)
  57. @staticmethod
  58. def is_available():
  59. # check if the distributed package is available or not
  60. return torch.distributed.is_available()
  61. def is_from_local(value):
  62. if not DistributedVariable.is_available():
  63. return False
  64. from torch.distributed.tensor import DTensor
  65. return inspect.isfunction(value) and value is DTensor.from_local
  66. def is_constant_pg_functions(value):
  67. if not DistributedVariable.is_available():
  68. return False
  69. from torch.distributed.distributed_c10d import (
  70. _get_group_size_by_name,
  71. _get_group_tag,
  72. _rank_not_in_group,
  73. _resolve_group_name_by_ranks_and_tag,
  74. get_process_group_ranks,
  75. )
  76. constant_processgroup_functions = [
  77. _get_group_size_by_name,
  78. _get_group_tag,
  79. _rank_not_in_group,
  80. get_process_group_ranks,
  81. _resolve_group_name_by_ranks_and_tag,
  82. ]
  83. return inspect.isfunction(value) and value in constant_processgroup_functions
  84. class WorldMetaClassVariable(DistributedVariable):
  85. """
  86. Tracks torch.distributed.GroupMember and torch.distributed.group, which are
  87. instances of the metaclass _WorldMeta.
  88. """
  89. @classmethod
  90. def is_group_member_type(cls, value):
  91. if not cls.is_available():
  92. return False
  93. from torch.distributed.distributed_c10d import _WorldMeta
  94. return type(value) is _WorldMeta
  95. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  96. if name == "WORLD":
  97. source = AttrSource(base=self.source, member="WORLD")
  98. install_guard(source.make_guard(GuardBuilder.ID_MATCH))
  99. return ProcessGroupVariable(self.value.WORLD)
  100. elif name == "NON_GROUP_MEMBER":
  101. source = AttrSource(base=self.source, member="NON_GROUP_MEMBER")
  102. install_guard(source.make_guard(GuardBuilder.ID_MATCH))
  103. return EnumVariable(self.value.NON_GROUP_MEMBER)
  104. return super().var_getattr(tx, name)
  105. class PlacementClassVariable(DistributedVariable):
  106. @staticmethod
  107. def is_placement_type(value):
  108. # we can't rely on importing/accessing torch distributed, it is not always built.
  109. if not DistributedVariable.is_available():
  110. return False
  111. from torch.distributed.tensor.placement_types import Placement
  112. return type(value) is type and issubclass(value, Placement)
  113. def as_python_constant(self):
  114. return self.value
  115. def call_function(
  116. self,
  117. tx: "InstructionTranslator",
  118. args: "list[VariableTracker]",
  119. kwargs: "dict[str, VariableTracker]",
  120. ) -> "VariableTracker":
  121. if (
  122. inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
  123. and self.source
  124. ):
  125. # NOTE: we don't need to track mutations to the placement class as they
  126. # suppose to be immutable.
  127. new_obj = object.__new__(self.value)
  128. var = PlacementVariable(new_obj)
  129. if inspect.getattr_static(self.value, "__init__", None):
  130. var.call_method(tx, "__init__", args, kwargs)
  131. return var
  132. return super().call_function(tx, args, kwargs)
  133. class PlacementVariable(DistributedVariable):
  134. @staticmethod
  135. def is_placement(value):
  136. # we can't rely on importing/accessing torch distributed, it is not always built.
  137. if not DistributedVariable.is_available():
  138. return False
  139. from torch.distributed.tensor.placement_types import Placement
  140. return isinstance(value, Placement)
  141. def as_python_constant(self):
  142. return self.value
  143. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  144. if name == "dim":
  145. return ConstantVariable.create(self.value.dim)
  146. return super().var_getattr(tx, name)
  147. def call_method(
  148. self,
  149. tx,
  150. name,
  151. args: "list[VariableTracker]",
  152. kwargs: "dict[str, VariableTracker]",
  153. ) -> "VariableTracker":
  154. from . import ConstantVariable
  155. # Placement types dynamo tracking only allows following methods
  156. # and __setattr__ is for case like `Shard(dim)` and methods.
  157. # Methods in the list must satisfy:
  158. # 1. Input arguments are constants and do not need to be guarded on;
  159. # 2. Output is constant with respect to their inputs
  160. constant_fold_functions = [
  161. "__init__",
  162. "__setattr__",
  163. "is_shard",
  164. "is_partial",
  165. "is_replicate",
  166. ]
  167. if name in constant_fold_functions:
  168. try:
  169. value_type = type(self.value)
  170. assert (
  171. inspect.getattr_static(value_type, "__getattr__", None) is None
  172. ), "no custom getattr allowed!"
  173. method = inspect.getattr_static(value_type, name)
  174. except AttributeError:
  175. method = None
  176. if method is object.__init__:
  177. return ConstantVariable.create(None)
  178. args = [x.as_python_constant() for x in args]
  179. kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  180. if name == "__setattr__":
  181. method(self.value, *args, **kwargs)
  182. return self
  183. constant_val = method(self.value, *args, **kwargs)
  184. return ConstantVariable.create(constant_val)
  185. return super().call_method(tx, name, args, kwargs)
  186. class DeviceMeshVariable(DistributedVariable):
  187. @staticmethod
  188. def is_device_mesh(value):
  189. # we can't rely on importing/accessing torch distributed, it is not always built.
  190. if not DistributedVariable.is_available():
  191. return False
  192. from torch.distributed.device_mesh import DeviceMesh
  193. return istype(value, DeviceMesh)
  194. def as_python_constant(self):
  195. return self.value
  196. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  197. if name == "ndim":
  198. return ConstantVariable.create(self.value.ndim)
  199. if name == "device_type":
  200. return ConstantVariable.create(self.value.device_type)
  201. return super().var_getattr(tx, name)
  202. def call_method(
  203. self,
  204. tx,
  205. name,
  206. args: "list[VariableTracker]",
  207. kwargs: "dict[str, VariableTracker]",
  208. ) -> "VariableTracker":
  209. if name == "size":
  210. const_args = [x.as_python_constant() for x in args]
  211. const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  212. return ConstantVariable.create(self.value.size(*const_args, **const_kwargs))
  213. if name == "get_coordinate":
  214. return ConstantVariable.create(self.value.get_coordinate())
  215. if name == "get_rank":
  216. return ConstantVariable.create(self.value.get_rank())
  217. if name == "get_local_rank":
  218. return ConstantVariable.create(self.value.get_local_rank())
  219. if name == "get_group":
  220. const_args = [x.as_python_constant() for x in args]
  221. const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  222. return ProcessGroupVariable(
  223. self.value.get_group(*const_args, **const_kwargs)
  224. )
  225. if name == "_get_or_create_default_group":
  226. return ProcessGroupVariable(self.value._get_or_create_default_group())
  227. return super().call_method(tx, name, args, kwargs)
  228. class ProcessGroupVariable(DistributedVariable):
  229. """
  230. We don't want a ProcessGroup object to end up in our output graph.
  231. But it's common for dynamo to intercept a PG that is then used to get info like
  232. rank() or world_size(), as well as passed to utility functions in distributed_c10d
  233. which desugar it into plain types like a ranklist and tag.
  234. For convenience and proper guarding, we construct a variable type.
  235. TODO: make it possible to use ProcessGroupVariable as input to simple functions
  236. like _expand_group without dynamo complaining about making a proxy for it.
  237. It is not a tensor-like type, and we don't want a proxy- but dynamo assumes
  238. torch library functions are dealing with tensor-like types and would have proxies
  239. for their args.
  240. TODO: should we make this inherit VT instead of UDOV? Do we want any of the default behaviors
  241. or just graph-break whenever one of our special cases is not hit?
  242. """
  243. def as_python_constant(self):
  244. return self.value
  245. def call_method(
  246. self,
  247. tx,
  248. name,
  249. args: "list[VariableTracker]",
  250. kwargs: "dict[str, VariableTracker]",
  251. ) -> "VariableTracker":
  252. if name == "rank":
  253. return variables.ConstantVariable.create(self.value.rank())
  254. if name == "size":
  255. return variables.ConstantVariable.create(self.value.size())
  256. if name == "_get_backend_name":
  257. return variables.ConstantVariable.create(self.value._get_backend_name())
  258. return super().call_method(tx, name, args, kwargs)
  259. def var_getattr(self, tx: "InstructionTranslator", name):
  260. if name == "group_name":
  261. return variables.ConstantVariable.create(self.value.group_name)
  262. if name in ["rank", "size"]:
  263. return variables.LambdaVariable(
  264. lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
  265. )
  266. # TODO should this just raise unimplemented?
  267. return super().var_getattr(tx, name)
  268. @staticmethod
  269. def is_process_group(value):
  270. # we can't rely on importing/accessing torch distributed, it is not always built.
  271. if not DistributedVariable.is_available():
  272. return False
  273. from torch._C._distributed_c10d import ProcessGroup
  274. from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
  275. return istype(value, (ProcessGroup, FakeProcessGroup))
  276. class BackwardHookVariable(VariableTracker):
  277. """
  278. Handles torch.utils.hooks.BackwardHook for module-level backward
  279. hooks.
  280. """
  281. @staticmethod
  282. def create(
  283. tx,
  284. module: VariableTracker,
  285. user_hooks: VariableTracker,
  286. user_pre_hooks: VariableTracker,
  287. ):
  288. if not compiled_autograd.compiled_autograd_enabled:
  289. unimplemented_v2(
  290. gb_type="Module-level backwards hooks require compiled autograd.",
  291. context="",
  292. explanation="",
  293. hints=[
  294. "Enable compiled autograd by setting torch._dynamo.config.compiled_autograd = True."
  295. ],
  296. )
  297. def _in_graph_bw_hooks(bw_state: BackwardState):
  298. """
  299. Rather than installing the user hooks in the graph (which
  300. don't survive AotAutograd), we install hooks that will call
  301. trace_wrapped in the backward pass that CompiledAutograd
  302. can turn into actual hook calls.
  303. """
  304. return torch.utils.hooks.BackwardHook(
  305. None,
  306. (
  307. functools.partial(
  308. trace_wrapped,
  309. fn=call_module_hooks_from_backward_state,
  310. bw_state=bw_state,
  311. hooks_name=user_hooks_name,
  312. module_name=module_name,
  313. ),
  314. ),
  315. (
  316. functools.partial(
  317. trace_wrapped,
  318. fn=call_module_hooks_from_backward_state,
  319. bw_state=bw_state,
  320. hooks_name=user_pre_hooks_name,
  321. module_name=module_name,
  322. ),
  323. ),
  324. )
  325. module_name, bw_state_proxy = tx.output.add_backward_state_hook(module, "mod")
  326. user_pre_hooks_name, _ = tx.output.add_backward_state_hook(user_pre_hooks)
  327. user_hooks_name, _ = tx.output.add_backward_state_hook(user_hooks)
  328. proxy = tx.output.create_proxy(
  329. "call_function",
  330. _in_graph_bw_hooks,
  331. (bw_state_proxy,),
  332. {},
  333. )
  334. proxy.node.meta["example_value"] = torch.utils.hooks.BackwardHook(None, (), ())
  335. return BackwardHookVariable(proxy, module, user_hooks, user_pre_hooks)
  336. def __init__(
  337. self,
  338. proxy: torch.fx.Proxy,
  339. module: VariableTracker,
  340. user_hooks: VariableTracker,
  341. user_pre_hooks: VariableTracker,
  342. **options,
  343. ) -> None:
  344. super().__init__(**options)
  345. self.proxy = proxy
  346. self.module = module
  347. self.user_hooks = user_hooks
  348. self.user_pre_hooks = user_pre_hooks
  349. def as_proxy(self):
  350. return self.proxy
  351. def call_method(
  352. self,
  353. tx,
  354. name,
  355. args: list[VariableTracker],
  356. kwargs: dict[str, VariableTracker],
  357. ) -> VariableTracker:
  358. if name in ("setup_input_hook", "setup_output_hook"):
  359. return self._setup_hook(tx, name, *args, **kwargs)
  360. return super().call_method(tx, name, args, kwargs)
  361. def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args):
  362. from .builder import wrap_fx_proxy
  363. return wrap_fx_proxy(
  364. tx,
  365. tx.output.create_proxy(
  366. "call_method",
  367. hook_method_name,
  368. (self.as_proxy(), args.as_proxy()),
  369. {},
  370. ),
  371. )