_functional_collectives.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import sys
  4. import warnings
  5. from typing import Any, cast, Optional, TYPE_CHECKING, Union
  6. import torch
  7. import torch.distributed as dist
  8. import torch.distributed.distributed_c10d as c10d
  9. from torch.distributed.device_mesh import DeviceMesh
  10. from torch.fx.experimental.proxy_tensor import get_proxy_mode
  11. from . import _functional_collectives_impl as fun_col_impl
  12. try:
  13. from torch.utils._cxx_pytree import tree_map_only
  14. except ImportError:
  15. from torch.utils._pytree import tree_map_only # type: ignore[no-redef]
  16. try:
  17. from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
  18. except Exception:
  19. warnings.warn(
  20. "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly"
  21. )
  22. def is_torchdynamo_compiling(): # type: ignore[misc]
  23. return False
  24. return False
  25. """
  26. New traceable, functional collectives.
  27. RFC: https://github.com/pytorch/pytorch/issues/93173
  28. compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
  29. eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses,
  30. automatically calling .wait() on underlying/hidden async 'work' obj only when fed to
  31. a downstream op.
  32. Issues:
  33. * Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files
  34. * Proper support for eager requires inplace ops. We should explore having it as an option for the API.
  35. """
  36. """
  37. Functional collectives are asynchronous only and we perform implicit stream synchronization
  38. on behalf of the user.
  39. We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness
  40. first usage of the tensor and insert cross stream sync at the right place.
  41. The above are the easy bits, the hard one is how we match the Work object returned by
  42. c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective
  43. op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the
  44. dispatcher which might call other implementations that are allowed to change the returned
  45. tensor - even return a tensor with a different shape (see ``torch.vmap``).
  46. This means the caller of our ops receives a Tensor that is not guaranteed to be the same
  47. allocated by our implementations and that makes pairing The AsyncTensor to the original
  48. tensor a lot harder. This pairing is needed so we can lookup the Work object to use.
  49. Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
  50. identity is not stable across dispatch, the op caller would end up with a different Tensor
  51. instance that would not match any in the dictionary.
  52. With Tensor identity out of the question, we decided use the tensor data pointer, which
  53. should be stable across all the Tensor changes done during dispatch.
  54. We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d.
  55. We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait()
  56. Finally, we setup a finalizer against the tensor wrapper to observe it getting collected so we
  57. can clean up stale entries in the dictionary.
  58. To eliminate the possibility of races we have a global version counter that is used by the finalizer.
  59. As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)
  60. """
  61. """
  62. Functional collectives can accept any of these types to describe the ranks participating in collectives.
  63. The different types will be desugared to a canonical format
  64. """
  65. RANK_TYPES = Union[
  66. list[int],
  67. list[list[int]],
  68. dist.ProcessGroup,
  69. DeviceMesh,
  70. tuple["dist.tensor.DeviceMesh", int],
  71. str,
  72. ]
  73. """
  74. User facing APIs for functional collectives
  75. -------------------------------------------
  76. These apis are called by user code and expected to work both in eager execution and compilation,
  77. but there are significant differences to how the two modes are implemented underneath.
  78. Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op)
  79. just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization,
  80. and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified
  81. if sufficient subclass support is added in dynamo.
  82. Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.
  83. Here's how it works under torch.compile/dynamo:
  84. all_reduce(...)
  85. |--> _expand_group(...) - desugars processgroup into canonical/traceable format
  86. |--> c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper
  87. |--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed
  88. And under eager execution:
  89. all_reduce(...)
  90. |--> _expand_group(...) - same as above, but less critical for eager
  91. |--> c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace
  92. |--> _maybe_wrap_tensor(...) - AsyncTensor wrapper applied to returned tensor,
  93. which issues wait_tensor() at the time of first use
  94. """
  95. def wait_tensor(tensor):
  96. """
  97. Wait on a tensor returned by the collectives ops.
  98. Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
  99. """
  100. return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
  101. def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
  102. """
  103. Broadcasts the tensor to all processes in the given process group.
  104. Args:
  105. src (int): Source rank
  106. group (ProcessGroup or List[int]): The process group to work on.
  107. tag (str, optional): A unique identifier for the collective. Default: empty string
  108. """
  109. group_name = _resolve_group_name(group, tag)
  110. tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
  111. return _maybe_wrap_tensor(tensor)
  112. def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
  113. """
  114. Reduces the tensor data across all machines in such a way that all get
  115. the final result.
  116. The input tensor is left unmodified.
  117. Group can be one of:
  118. List[int]: ranks participating in the collective.
  119. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  120. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  121. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  122. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  123. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  124. that information and perform collective algebraic optimization. Use other forms of input for that.
  125. """
  126. group_name = _resolve_group_name(group, tag)
  127. tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
  128. return _maybe_wrap_tensor(tensor)
  129. def all_gather_tensor(
  130. self: torch.Tensor,
  131. gather_dim: int,
  132. group: RANK_TYPES,
  133. tag: str = "",
  134. ) -> torch.Tensor:
  135. """
  136. Gather tensor data across from all machines and concatenate over ``gather_dim``.
  137. Note that it currently only supports gather_dim = 0.
  138. The input tensor is left unmodified.
  139. Group can be one of:
  140. List[int]: ranks participating in the collective.
  141. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  142. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  143. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  144. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  145. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  146. that information and perform collective algebraic optimization. Use other forms of input for that.
  147. """
  148. assert self.is_contiguous()
  149. group_name = _resolve_group_name(group, tag)
  150. group_size = c10d._get_group_size_by_name(group_name)
  151. tensor = torch.ops._c10d_functional.all_gather_into_tensor(
  152. self, group_size, group_name
  153. )
  154. res = _maybe_wrap_tensor(tensor)
  155. # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
  156. if gather_dim != 0:
  157. # torch.cat access the data so we already need to wait here, first do wait
  158. # and then chunk + cat avoid us going through ACT dispatching logic again
  159. if isinstance(res, AsyncCollectiveTensor):
  160. res = res.wait() # type: ignore[attr-defined]
  161. res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
  162. return res
  163. def all_gather_tensor_autograd(
  164. self: torch.Tensor,
  165. gather_dim: int,
  166. group: RANK_TYPES,
  167. tag: str = "",
  168. ):
  169. """
  170. Gather tensor data across from all machines and concatenate over ``gather_dim``.
  171. Note that it currently only supports gather_dim = 0.
  172. This function is the same as all_gather_tensor but will propagate the
  173. backwards gradient across workers.
  174. See all_gather_tensor for more details on usage.
  175. """
  176. group_name = _resolve_group_name(group, tag)
  177. group_size = c10d._get_group_size_by_name(group_name)
  178. tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor(
  179. self, group_size, group_name
  180. )
  181. res = _FromTorchTensor.apply(tensor)
  182. # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
  183. if gather_dim != 0:
  184. # torch.cat access the data so we already need to wait here, first do wait
  185. # and then chunk + cat avoid us going through ACT dispatching logic again
  186. if isinstance(res, AsyncCollectiveTensor):
  187. res = res.wait() # type: ignore[attr-defined]
  188. res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
  189. return res
  190. def reduce_scatter_tensor(
  191. self: torch.Tensor,
  192. reduceOp: str,
  193. scatter_dim: int,
  194. group: RANK_TYPES,
  195. tag: str = "",
  196. ):
  197. """
  198. Reduces the tensor data across all machines in such a way that all get
  199. the final result, then scatter the results to corresponding ranks.
  200. The input tensor is left unmodified.
  201. Group can be one of:
  202. List[int]: ranks participating in the collective.
  203. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  204. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  205. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  206. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  207. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  208. that information and perform collective algebraic optimization. Use other forms of input for that.
  209. """
  210. group_name = _resolve_group_name(group, tag)
  211. group_size = c10d._get_group_size_by_name(group_name)
  212. assert self.size(scatter_dim) % group_size == 0, (
  213. f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})"
  214. )
  215. if scatter_dim != 0:
  216. tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
  217. self = torch.cat(tensor_list)
  218. tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
  219. self,
  220. reduceOp.lower(),
  221. group_size,
  222. group_name, # type: ignore[possibly-undefined]
  223. )
  224. res = _maybe_wrap_tensor(tensor)
  225. return res
  226. def reduce_scatter_tensor_autograd(
  227. self: torch.Tensor,
  228. reduceOp: str,
  229. scatter_dim: int,
  230. group: RANK_TYPES,
  231. tag: str = "",
  232. ):
  233. """
  234. Reduces the tensor data across all machines in such a way that all get
  235. the final result, then scatter the results to corresponding ranks.
  236. This function is the same as reduce_scatter_tensor but will propagate the
  237. backwards gradient across workers.
  238. Currently only the "sum" reduceOp is supported.
  239. See reduce_scatter_tensor for more details on usage.
  240. """
  241. group_name = _resolve_group_name(group, tag)
  242. group_size = c10d._get_group_size_by_name(group_name)
  243. assert self.size(scatter_dim) % group_size == 0, (
  244. f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
  245. )
  246. if scatter_dim != 0:
  247. tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
  248. self = torch.cat(tensor_list)
  249. tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor(
  250. self,
  251. reduceOp.lower(),
  252. group_size,
  253. group_name, # type: ignore[possibly-undefined]
  254. )
  255. res = _FromTorchTensor.apply(tensor)
  256. return res
  257. def all_reduce_coalesced(
  258. self: list[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
  259. ) -> list[torch.Tensor]:
  260. """
  261. Reduces a list of tensors across all machines in such a way that all get
  262. the final result.
  263. The all tensors in the input list are left unmodified.
  264. Group can be one of:
  265. List[int]: ranks participating in the collective.
  266. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  267. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  268. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  269. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  270. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  271. that information and perform collective algebraic optimization. Use other forms of input for that.
  272. """
  273. group_name = _resolve_group_name(group, tag)
  274. tensor_list = torch.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined]
  275. self,
  276. reduceOp.lower(),
  277. group_name,
  278. )
  279. return list(map(_maybe_wrap_tensor, tensor_list))
  280. def all_gather_into_tensor_coalesced(
  281. self: list[torch.Tensor], group: RANK_TYPES, tag: str = ""
  282. ) -> list[torch.Tensor]:
  283. """
  284. Gather a list of tensors across from all machines.
  285. Note that it currently only supports gather_dim = 0.
  286. The input tensor is left unmodified.
  287. Group can be one of:
  288. List[int]: ranks participating in the collective.
  289. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  290. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  291. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  292. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  293. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  294. that information and perform collective algebraic optimization. Use other forms of input for that.
  295. """
  296. group_name = _resolve_group_name(group, tag)
  297. group_size = c10d._get_group_size_by_name(group_name)
  298. tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined]
  299. self,
  300. group_size,
  301. group_name,
  302. )
  303. return list(map(_maybe_wrap_tensor, tensor_list))
  304. def reduce_scatter_tensor_coalesced(
  305. inputs: list[torch.Tensor],
  306. reduceOp: str,
  307. scatter_dim: list[int],
  308. group: RANK_TYPES,
  309. tag: str = "",
  310. ) -> list[torch.Tensor]:
  311. """
  312. Reduces a list of tensors across all machines in such a way that all get
  313. the final result, then scatter the results to corresponding ranks.
  314. The input tensors are left unmodified.
  315. Group can be one of:
  316. List[int]: ranks participating in the collective.
  317. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  318. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  319. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  320. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  321. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  322. that information and perform collective algebraic optimization. Use other forms of input for that.
  323. """
  324. group_name = _resolve_group_name(group, tag)
  325. group_size = c10d._get_group_size_by_name(group_name)
  326. assert len(scatter_dim) == len(inputs)
  327. for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
  328. assert tensor.size(dim) % group_size == 0, (
  329. f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
  330. )
  331. if dim != 0:
  332. tensor_list = torch.chunk(tensor, group_size, dim=dim)
  333. inputs[idx] = torch.cat(tensor_list)
  334. tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined]
  335. inputs,
  336. reduceOp.lower(),
  337. group_size,
  338. group_name, # type: ignore[possibly-undefined]
  339. )
  340. return list(map(_maybe_wrap_tensor, tensor_list))
  341. # This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
  342. # Today, this maps 1:1 with "aten ops that are views".
  343. def _is_view_op(tgt):
  344. assert isinstance(tgt, torch._ops.OpOverload)
  345. # Don't apply the view optimization to any `CompositeImplicitAutograd` ops.
  346. # See issue: https://github.com/pytorch/pytorch/issues/133421
  347. if torch._C._dispatch_has_kernel_for_dispatch_key(
  348. tgt.name(), torch.DispatchKey.CompositeImplicitAutograd
  349. ):
  350. return False
  351. schema = tgt._schema
  352. if len(schema.arguments) > 0:
  353. first_arg = schema.arguments[0]
  354. # check if op is a view
  355. return first_arg.alias_info is not None and not first_arg.alias_info.is_write
  356. def all_to_all_single(
  357. self: torch.Tensor,
  358. output_split_sizes: Optional[list[int]],
  359. input_split_sizes: Optional[list[int]],
  360. group: RANK_TYPES,
  361. tag: str = "",
  362. ) -> torch.Tensor:
  363. """
  364. Each process splits input tensor and then scatters the split list
  365. to all processes in a group. Then concatenate the received tensors from all
  366. the processes in the group and return single output tensor.
  367. Group can be one of:
  368. List[int]: ranks participating in the collective.
  369. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  370. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  371. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  372. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  373. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  374. that information and perform collective algebraic optimization. Use other forms of input for that.
  375. """
  376. if output_split_sizes is not None:
  377. assert all(
  378. isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
  379. ), output_split_sizes
  380. if input_split_sizes is not None:
  381. assert all(
  382. isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
  383. ), input_split_sizes
  384. group_name = _resolve_group_name(group, tag)
  385. group_size = c10d._get_group_size_by_name(group_name)
  386. if output_split_sizes is None or input_split_sizes is None:
  387. assert output_split_sizes is None and input_split_sizes is None, (
  388. "output_split_sizes and input_split_sizes must either be "
  389. "specified together or both set to None"
  390. )
  391. output_split_sizes = [self.shape[0] // group_size] * group_size
  392. input_split_sizes = output_split_sizes
  393. tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined]
  394. self,
  395. output_split_sizes,
  396. input_split_sizes,
  397. group_name,
  398. )
  399. return _maybe_wrap_tensor(tensor)
  400. def all_to_all_single_autograd(
  401. self: torch.Tensor,
  402. output_split_sizes: Optional[list[int]],
  403. input_split_sizes: Optional[list[int]],
  404. group: RANK_TYPES,
  405. tag: str = "",
  406. ) -> torch.Tensor:
  407. """
  408. Same as all_to_all_single but supports autograd.
  409. """
  410. if output_split_sizes is not None:
  411. assert all(
  412. isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
  413. ), output_split_sizes
  414. if input_split_sizes is not None:
  415. assert all(
  416. isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
  417. ), input_split_sizes
  418. group_name = _resolve_group_name(group, tag)
  419. group_size = c10d._get_group_size_by_name(group_name)
  420. if output_split_sizes is None or input_split_sizes is None:
  421. assert output_split_sizes is None and input_split_sizes is None, (
  422. "output_split_sizes and input_split_sizes must either be "
  423. "specified together or both set to None"
  424. )
  425. output_split_sizes = [self.shape[0] // group_size] * group_size
  426. input_split_sizes = output_split_sizes
  427. tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined]
  428. self,
  429. output_split_sizes,
  430. input_split_sizes,
  431. group_name,
  432. )
  433. return _FromTorchTensor.apply(tensor)
  434. def permute_tensor(
  435. self: torch.Tensor,
  436. src_dst: list[int],
  437. group: RANK_TYPES,
  438. tag: str = "",
  439. ) -> torch.Tensor:
  440. """
  441. Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should
  442. be defined such that src_dst[m] == n means m sends to n.
  443. Group can be one of:
  444. List[int]: ranks participating in the collective.
  445. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  446. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  447. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  448. (DeviceMesh, int): Do a MPMD collective over one
  449. """
  450. t, rankset, group_size = _expand_group(group, tag)
  451. local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)
  452. output_split_sizes = [0] * group_size
  453. input_split_sizes = [0] * group_size
  454. for src, dst in enumerate(src_dst):
  455. if src == dist.get_rank(local_pg):
  456. input_split_sizes[dst] = self.numel()
  457. if dst == dist.get_rank(local_pg):
  458. output_split_sizes[src] = self.numel()
  459. return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)
  460. class AsyncCollectiveTensor(torch.Tensor):
  461. r"""
  462. A Tensor wrapper subclass that is used to trigger a call to wait
  463. prior to first use of the underlying tensor.
  464. Use it inside functional collective pytorch wrappers like the following:
  465. def functional_collective(self, group, tag):
  466. tag, rankset, group_size = _expand_group(group, tag)
  467. tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
  468. return _maybe_wrap_tensor(tensor)
  469. """
  470. elem: torch.Tensor
  471. completed: bool
  472. __slots__ = ["elem", "completed"]
  473. @staticmethod
  474. def __new__(cls, elem: torch.Tensor):
  475. r = torch.Tensor._make_wrapper_subclass(
  476. cls,
  477. elem.size(),
  478. strides=elem.stride(),
  479. storage_offset=elem.storage_offset(),
  480. dtype=elem.dtype,
  481. layout=elem.layout,
  482. device=elem.device,
  483. requires_grad=elem.requires_grad,
  484. )
  485. r.elem = elem
  486. r.completed = False
  487. return r
  488. def __tensor_flatten__(self):
  489. return ["elem"], None
  490. def tolist(self):
  491. return self.trigger_wait().tolist()
  492. @staticmethod
  493. def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
  494. assert meta is None
  495. elem = inner_tensors["elem"]
  496. return AsyncCollectiveTensor(elem)
  497. def __coerce_same_metadata_as_tangent__(
  498. self, expected_metadata: Any, expected_type: Optional[type] = None
  499. ):
  500. if expected_type is not torch.Tensor:
  501. return None
  502. return self.trigger_wait()
  503. def __repr__(self) -> str: # type: ignore[override]
  504. return f"AsyncCollectiveTensor({self.trigger_wait()})"
  505. def trigger_wait(self):
  506. if not self.completed:
  507. out = wait_tensor(self.elem)
  508. self.completed = True
  509. return out
  510. else:
  511. return self.elem
  512. def wait(self) -> torch.Tensor:
  513. return wait_tensor(self.elem)
  514. def _get_acs_underlying_tensor(self):
  515. """This method enables _functional_collectives_impl to test if a tensor is an ACS"""
  516. return self.elem
  517. @classmethod
  518. def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
  519. if func == torch.ops.aten.view.default:
  520. # Fast handle aten.view as a lot of view related op goes to aten.view
  521. # eventually, this avoids pytree slowdown
  522. res = func(args[0].elem, args[1])
  523. wrapper_res = AsyncCollectiveTensor(res)
  524. return wrapper_res
  525. is_view_op = _is_view_op(func)
  526. def unwrap(e: AsyncCollectiveTensor):
  527. # wait_tensor is idepotent and will do stream sync only once
  528. if not is_view_op:
  529. return e.trigger_wait()
  530. return e.elem
  531. def wrap(e: torch.Tensor):
  532. # wait_tensor is idepotent and will do stream sync only once
  533. assert not isinstance(e, AsyncCollectiveTensor)
  534. res = AsyncCollectiveTensor(e)
  535. return res
  536. unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
  537. unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)
  538. # we don't wrap the result as it doesn't need to be waited on.
  539. out = func(*unwrapped_args, **unwrapped_kwargs)
  540. # View ops dont require a sync, so we should re-wrap the outputs.
  541. if is_view_op:
  542. out = tree_map_only(torch.Tensor, wrap, out)
  543. return out
  544. def numpy(self): # type: ignore[override]
  545. return self.wait().numpy()
  546. """
  547. Utils and infrastructure for tracing support
  548. """
def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int]:
    """
    _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.

    By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
    torchdynamo and can still interoperate with processgroup objects or other untraceable forms.

    Args:
        group: a list of ranks, a nested list of ranks, a ProcessGroup, a 1D
            DeviceMesh or a ``(DeviceMesh, int)`` tuple. A ``str`` group name,
            although part of RANK_TYPES, is not handled here and falls through
            to the final ``ValueError``.
        tag: collective tag. For ProcessGroup / DeviceMesh / tuple inputs an empty
            tag is replaced with ``c10d._get_group_tag`` of the process group.
            For list inputs it is returned unchanged.

    Returns:
        ``(tag, rankset, group_size)``:
        - ``rankset`` is the flat list of participating ranks (the inner lists
          concatenated, for a nested list).
        - ``group_size`` is the length of each inner list for a nested list, and
          ``len(rankset)`` otherwise.

    Raises:
        ValueError: nested rank lists of unequal length, a tuple that is not
            ``(DeviceMesh, int)``, or any other unsupported ``group`` type.
        AssertionError: a DeviceMesh whose ``ndim`` is not 1.
    """
    # had to define this hack _inside_ expand_group to avoid
    # graph_break [('torch.* op returned non-Tensor int
    # caused by 'cast_*` functions being treated as 'torch.*' ops (iiuc)
    if TYPE_CHECKING:

        def cast_listlistint(x):
            return cast(list[list[int]], x)

        def cast_listint(x):
            return cast(list[int], x)

    else:
        # fake cast op for use at runtime since dynamo doesn't support real cast
        # also, dynamo didn't like encountering 'typing' objects
        # NotImplementedError: argument of type: <class 'typing._GenericAlias'>
        def cast_listlistint(x):
            return x

        def cast_listint(x):
            return x

    rankset: list[int]
    if isinstance(group, list):
        # NOTE(review): an empty list raises IndexError on group[0] rather than ValueError.
        if isinstance(group[0], list):
            # Nested list: concatenate every inner rank list; all must have the same length.
            nested_list = cast_listlistint(group)
            rankset = []
            group_size = -1
            for rs in nested_list:
                rankset.extend(rs)
                if group_size != -1 and group_size != len(rs):
                    raise ValueError(
                        f"group sizes must be identical found {group_size} and {len(rs)}"
                    )
                group_size = len(rs)
        else:
            rankset = cast_listint(group)
            group_size = len(rankset)
    elif isinstance(group, dist.ProcessGroup):
        rankset = dist.get_process_group_ranks(group)
        group_size = len(rankset)
        # An explicitly passed tag wins; otherwise use the process group's tag.
        tag = tag or c10d._get_group_tag(group)
    elif isinstance(group, DeviceMesh):
        assert group.ndim == 1, (
            "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
        )
        # TODO: it should run collective in the whole mesh instead of dim 0
        pg = group.get_group()
        rankset = dist.get_process_group_ranks(pg)
        group_size = len(rankset)
        tag = tag or c10d._get_group_tag(pg)
    elif isinstance(group, tuple):
        # Only (DeviceMesh, mesh_dim) tuples are supported.
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            pg = dmesh.get_group(dim)
            rankset = dist.get_process_group_ranks(pg)
            group_size = len(rankset)
            tag = tag or c10d._get_group_tag(pg)
        else:
            raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
    else:
        raise ValueError(
            "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)."
        )

    return (tag, rankset, group_size)
  619. def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
  620. """
  621. Given group in RANK_TYPES, return the group name.
  622. """
  623. # `tag` will be deprecated. See details in:
  624. # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
  625. if isinstance(group, dist.ProcessGroup):
  626. return group.group_name
  627. elif isinstance(group, str):
  628. return group
  629. elif isinstance(group, DeviceMesh):
  630. assert group.ndim == 1, (
  631. "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
  632. )
  633. return group._dim_group_names[0]
  634. elif isinstance(group, tuple):
  635. if (
  636. len(group) == 2
  637. and isinstance(group[0], DeviceMesh)
  638. and isinstance(group[1], int)
  639. ):
  640. dmesh = group[0]
  641. dim = group[1]
  642. return dmesh._dim_group_names[dim]
  643. else:
  644. raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
  645. elif isinstance(group, list):
  646. if not is_torchdynamo_compiling():
  647. warnings.warn(
  648. "The combination of ranks + tag as process group "
  649. "identifier has been deprecated. Please switch to "
  650. "using ProcessGroup, DeviceMesh, or group name instead.",
  651. FutureWarning,
  652. stacklevel=3,
  653. )
  654. return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag)
  655. else:
  656. raise ValueError(f"Unsupported group type: {type(group)}, {group}")
  657. class _FromTorchTensor(torch.autograd.Function):
  658. """
  659. _FromTorchTensor allows autograd to propagate from a normal Tensor to an
  660. AsyncCollectiveTensor.
  661. """
  662. @staticmethod
  663. def forward( # type: ignore[override]
  664. ctx, # pyre-ignore[2]: Parameter must be annotated.
  665. input: torch.Tensor,
  666. ) -> torch.Tensor:
  667. return _maybe_wrap_tensor(input)
  668. @staticmethod
  669. def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # type: ignore[override]
  670. return grad_output
  671. def _are_we_tracing() -> bool:
  672. if is_torchdynamo_compiling():
  673. return True
  674. # If fake mode is turned on, we are almost definitely compiling/tracing.
  675. if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None:
  676. return True
  677. # See Note [enable_python_dispatcher in dynamo]
  678. if torch._C._dispatch_tls_is_dispatch_key_included(
  679. torch._C.DispatchKey.PythonDispatcher
  680. ):
  681. return True
  682. return get_proxy_mode() is not None
  683. def _maybe_wrap_tensor(self) -> torch.Tensor:
  684. if _are_we_tracing():
  685. return wait_tensor(self)
  686. res = AsyncCollectiveTensor(self)
  687. return cast(torch.Tensor, res)
  688. @contextlib.contextmanager
  689. def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
  690. """
  691. Context manager to temporarily set whether inflight collectives are allowed as torch.compile graph inputs.
  692. Common use case is when the collective is issued in eager (with `async_op=True`) but waited in compiled region:
  693. ```
  694. def all_reduce_eager(x):
  695. y = x * x
  696. req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
  697. return y
  698. @torch.compile(fullgraph=True)
  699. def all_reduce_wait_compiled(y):
  700. torch.ops.c10d_functional.wait_tensor(y)
  701. return y * y
  702. x = torch.ones(1280, 1280, device="cuda") + self.rank
  703. # the context manager ensures that `wait_tensor(y)` will wait on the correct work object
  704. with allow_inflight_collective_as_graph_input_ctx():
  705. y = all_reduce_eager(x)
  706. z = all_reduce_wait_compiled(y)
  707. ```
  708. With this context manager, when a collective is called, under the hood the work object of the collective
  709. will be registered in the work registry, and the wait_tensor() in compiled region called on
  710. the output tensor of the collective will wait on the correct work object.
  711. """
  712. previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input()
  713. try:
  714. torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value)
  715. yield
  716. finally:
  717. torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(
  718. previous
  719. )
  720. def _make_all_gather_out_tensor(input, group_size):
  721. out_size = list(input.size())
  722. if len(out_size) == 0:
  723. out_size.append(group_size)
  724. else:
  725. out_size[0] *= group_size
  726. out_tensor = input.new_empty(out_size)
  727. return out_tensor
  728. def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
  729. return [_make_all_gather_out_tensor(t, group_size) for t in self]
  730. # We now register meta kernels to deal with tracing
  731. def _broadcast_meta(self, *args):
  732. return torch.empty_like(self)
  733. def _all_reduce_meta(self, *args):
  734. return torch.empty_like(self)
  735. def _wait_tensor_meta(self, *args):
  736. return torch.empty_like(self)
  737. def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
  738. return _make_all_gather_out_tensor(shard, group_size)
  739. def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
  740. out_size = list(input.size())
  741. out_size[0] //= group_size
  742. return input.new_empty(out_size)
  743. def _all_reduce_coalesced_meta(self, *args):
  744. return [torch.empty_like(t) for t in self]
  745. def _all_reduce__meta(inp, *args):
  746. return inp
  747. def _broadcast__meta(inp, *args):
  748. return inp
  749. def _all_reduce_coalesced__meta(inputs, *args):
  750. return inputs
  751. def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
  752. def mk_out_tensor(input):
  753. out_size = list(input.size())
  754. out_size[0] //= group_size
  755. out_tensor = input.new_empty(out_size)
  756. return out_tensor
  757. return [mk_out_tensor(t) for t in inputs]
  758. # NB: We often say all_to_all has dynamic output size, but this is not
  759. # technically true: instead, what typically happens is you manually
  760. # communicate the output_split_sizes ahead of time (which is dynamic),
  761. # but then you pass those sizes explicitly, and the all to all itself
  762. # isn't dynamic, it just follows the specified output splits
  763. def _all_to_all_single_meta(
  764. input, output_split_sizes, input_split_sizes, *args, **kwargs
  765. ):
  766. if output_split_sizes is None:
  767. return input.new_empty(input.size())
  768. else:
  769. for s in output_split_sizes:
  770. torch._check_is_size(s)
  771. out_size = list(input.size())
  772. out_size[0] = sum(output_split_sizes)
  773. return input.new_empty(out_size)
  774. def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out):
  775. return _make_all_gather_out_tensor(input, group_size)
  776. def _all_gather_into_tensor_native_meta(input, group_size, group_name):
  777. return _make_all_gather_out_tensor(input, group_size)
  778. def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
  779. return [
  780. _all_gather_into_tensor_native_meta(input, group_size, group_name)
  781. for input in inputs
  782. ]
  783. def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
  784. shape = list(inp.size())
  785. shape[0] //= group_size
  786. return inp.new_empty(shape)
  787. def _reduce_scatter_tensor_coalesced_native_meta(
  788. inputs, reduce_op, group_size, group_name
  789. ):
  790. return [
  791. _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
  792. for inp in inputs
  793. ]
  794. # Library MUST be defined at module scope or it doesn't work
  795. lib_impl = torch.library.Library("_c10d_functional", "IMPL")
  796. lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
  797. lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
  798. lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
  799. lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
  800. lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
  801. lib_impl.impl(
  802. "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta"
  803. )
  804. lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
  805. lib_impl.impl(
  806. "all_gather_into_tensor_coalesced",
  807. _all_gather_into_tensor_coalesced_native_meta,
  808. "Meta",
  809. )
  810. lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
  811. lib_impl.impl(
  812. "reduce_scatter_tensor_coalesced",
  813. _reduce_scatter_tensor_coalesced_native_meta,
  814. "Meta",
  815. )
  816. lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
  817. lib_impl.impl("broadcast", _broadcast_meta, "Meta")
  818. lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
  819. # mark these ops has side effect so that they won't be removed by DCE
  820. torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)
  821. torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)
  822. # Register legacy ops for backward compatibility
  823. # TODO(yifu): remove these in functional collective beta release
  824. legacy_lib = torch.library.Library("c10d_functional", "DEF")
  825. legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
  826. ops_defs = [
  827. "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
  828. "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
  829. "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
  830. "wait_tensor(Tensor self) -> Tensor",
  831. "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
  832. "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
  833. "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
  834. "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
  835. "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950
  836. ]
  837. my_module = sys.modules[__name__]
  838. for op_def in ops_defs:
  839. op_name = op_def[0 : op_def.index("(")]
  840. backend_impl = getattr(fun_col_impl, f"_{op_name}")
  841. legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
  842. legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")
  843. """
  844. Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into
  845. functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph.
  846. We implement this by writing a decomposition and teaching dynamo how to associate it to a corresponding op via
  847. the mapping dict below.
  848. These schemas intentionally match torch.distributed.distributed_c10d.* ops that we are trying to remap from
  849. """
  850. def all_gather_tensor_inplace(
  851. output_tensor: torch.Tensor,
  852. input_tensor: torch.Tensor,
  853. group=None, # TODO add a type,
  854. async_op: bool = False,
  855. tag: str = "",
  856. gather_dim: int = 0,
  857. ):
  858. assert not async_op, (
  859. "Can't remap async version of inplace op to functional collective"
  860. )
  861. group = group or dist.group.WORLD
  862. assert group is not None
  863. return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
  864. def reduce_scatter_tensor_inplace(
  865. output: torch.Tensor,
  866. input: torch.Tensor,
  867. op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok?
  868. group=None, # TODO add a type
  869. async_op: bool = False,
  870. scatter_dim: int = 0,
  871. tag: str = "",
  872. ):
  873. assert not async_op, (
  874. "Can't remap async version of inplace op to functional collective"
  875. )
  876. group = group or dist.group.WORLD
  877. assert group is not None
  878. return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))
  879. REDUCE_OP_TO_STR = {
  880. dist.ReduceOp.SUM: "sum",
  881. dist.ReduceOp.AVG: "avg",
  882. dist.ReduceOp.PRODUCT: "product",
  883. dist.ReduceOp.MIN: "min",
  884. dist.ReduceOp.MAX: "max",
  885. dist.ReduceOp.BAND: "band",
  886. dist.ReduceOp.BOR: "bor",
  887. dist.ReduceOp.BXOR: "bxor",
  888. }
  889. def all_reduce_inplace(
  890. tensor: torch.Tensor,
  891. op: str = "sum",
  892. group=None,
  893. async_op: bool = False,
  894. tag: str = "",
  895. ):
  896. assert not async_op, (
  897. "Can't remap async version of inplace op to functional collective"
  898. )
  899. group = group or dist.group.WORLD
  900. assert group is not None
  901. return tensor.copy_(all_reduce(tensor, op, group, tag))
  902. def all_to_all_inplace(
  903. output: torch.Tensor,
  904. input: torch.Tensor,
  905. output_split_sizes=None,
  906. input_split_sizes=None,
  907. group=None,
  908. async_op=False,
  909. tag: str = "",
  910. ):
  911. assert not async_op, (
  912. "Can't remap async version of inplace op to functional collective"
  913. )
  914. group = group or dist.group.WORLD
  915. assert group is not None
  916. return output.copy_(
  917. all_to_all_single(
  918. input,
  919. output_split_sizes,
  920. input_split_sizes,
  921. group,
  922. tag,
  923. )
  924. )
  925. def all_gather_inplace(
  926. tensor_list: list[torch.Tensor],
  927. tensor: torch.Tensor,
  928. group=None,
  929. async_op=False,
  930. tag: str = "",
  931. ):
  932. assert not async_op, (
  933. "Can't remap async version of inplace op to functional collective"
  934. )
  935. assert tensor.dim() == 0 or all(t.size(0) == tensor.size(0) for t in tensor_list), (
  936. "Remapping variable size all_gather is not yet supported"
  937. )
  938. group = group or dist.group.WORLD
  939. assert group is not None
  940. output = all_gather_tensor(tensor, 0, group, tag)
  941. # Use aten.slice instead of aten.split because the latter causes
  942. # tensor.shape(0) to be unnecessarily baked in when it's a SymInt.
  943. output_splits = []
  944. offset = 0
  945. for t in tensor_list:
  946. is_scalar = t.dim() == 0
  947. t_offset = 1 if is_scalar else t.size(0)
  948. out = output[offset] if is_scalar else output[offset : offset + t_offset]
  949. output_splits.append(out)
  950. offset += t_offset
  951. for dst, src in zip(tensor_list, output_splits):
  952. dst.copy_(src)
  953. return tensor_list
  954. from torch.distributed.distributed_c10d import (
  955. _all_gather_base as legacy_all_gather_base,
  956. _reduce_scatter_base as legacy_reduce_scatter_base,
  957. all_gather as legacy_all_gather,
  958. all_gather_into_tensor as legacy_allgather,
  959. all_reduce as legacy_allreduce,
  960. all_to_all_single as legacy_all_to_all_single,
  961. reduce_scatter_tensor as legacy_reducescatter,
  962. )
  963. # This dict should contain sets of functions that dynamo is allowed to remap.
  964. # Functions in this set should accept the same args/kwargs 1:1 as their mapping.
  965. traceable_collective_remaps = {
  966. legacy_allgather: all_gather_tensor_inplace,
  967. legacy_reducescatter: reduce_scatter_tensor_inplace,
  968. legacy_allreduce: all_reduce_inplace,
  969. legacy_all_to_all_single: all_to_all_inplace,
  970. legacy_all_gather: all_gather_inplace,
  971. legacy_reduce_scatter_base: reduce_scatter_tensor_inplace,
  972. legacy_all_gather_base: all_gather_tensor_inplace,
  973. }