scatter_gather.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Sequence
  3. from typing import Any, Optional, overload, TypeVar, Union
  4. from typing_extensions import deprecated
  5. import torch
  6. from torch.nn.parallel._functions import Gather, Scatter
# Names exported by ``from ... import *``: the three public scatter/gather
# entry points. The deprecated ``is_namedtuple`` helper is not listed.
__all__ = ["scatter", "scatter_kwargs", "gather"]
  8. @deprecated(
  9. "`is_namedtuple` is deprecated, please use the python checks instead",
  10. category=FutureWarning,
  11. )
  12. def is_namedtuple(obj: Any) -> bool:
  13. # Check if type was created from collections.namedtuple or a typing.NamedTuple.
  14. return _is_namedtuple(obj)
  15. def _is_namedtuple(obj: Any) -> bool:
  16. # Check if type was created from collections.namedtuple or a typing.NamedTuple.
  17. return (
  18. isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
  19. )
  20. T = TypeVar("T", dict, list, tuple)
  21. # For some reason, 'scatter' returns a tuple when given a single Tensor input but a list otherwise.
  22. @overload
  23. def scatter(
  24. inputs: torch.Tensor,
  25. target_gpus: Sequence[Union[int, torch.device]],
  26. dim: int = ...,
  27. ) -> tuple[torch.Tensor, ...]: ...
  28. @overload
  29. def scatter(
  30. inputs: T,
  31. target_gpus: Sequence[Union[int, torch.device]],
  32. dim: int = ...,
  33. ) -> list[T]: ...
  34. def scatter(inputs, target_gpus, dim=0):
  35. r"""Slice tensors into approximately equal chunks and distributes them across given GPUs.
  36. Duplicates references to objects that are not tensors.
  37. """
  38. def scatter_map(obj):
  39. if isinstance(obj, torch.Tensor):
  40. return Scatter.apply(target_gpus, None, dim, obj)
  41. if _is_namedtuple(obj):
  42. return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
  43. if isinstance(obj, tuple) and len(obj) > 0:
  44. return list(zip(*map(scatter_map, obj)))
  45. if isinstance(obj, list) and len(obj) > 0:
  46. return [list(i) for i in zip(*map(scatter_map, obj))]
  47. if isinstance(obj, dict) and len(obj) > 0:
  48. return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
  49. return [obj for _ in target_gpus]
  50. # After scatter_map is called, a scatter_map cell will exist. This cell
  51. # has a reference to the actual function scatter_map, which has references
  52. # to a closure that has a reference to the scatter_map cell (because the
  53. # fn is recursive). To avoid this reference cycle, we set the function to
  54. # None, clearing the cell
  55. try:
  56. res = scatter_map(inputs)
  57. finally:
  58. scatter_map = None # type: ignore[assignment]
  59. return res
  60. def scatter_kwargs(
  61. inputs: tuple[Any, ...],
  62. kwargs: Optional[dict[str, Any]],
  63. target_gpus: Sequence[Union[int, torch.device]],
  64. dim: int = 0,
  65. ) -> tuple[tuple[Any, ...], tuple[dict[str, Any], ...]]:
  66. r"""Scatter with support for kwargs dictionary."""
  67. scattered_inputs = scatter(inputs, target_gpus, dim) if inputs else []
  68. scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
  69. if len(scattered_inputs) < len(scattered_kwargs):
  70. scattered_inputs.extend(
  71. () for _ in range(len(scattered_kwargs) - len(scattered_inputs))
  72. )
  73. elif len(scattered_kwargs) < len(inputs):
  74. scattered_kwargs.extend(
  75. {} for _ in range(len(scattered_inputs) - len(scattered_kwargs))
  76. )
  77. return tuple(scattered_inputs), tuple(scattered_kwargs)
  78. def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0) -> Any:
  79. r"""Gather tensors from different GPUs on a specified device.
  80. This function is useful for gathering the results of a distributed computation.
  81. It takes a sequence of objects, one for each GPU, and returns a single object
  82. on the specified device.
  83. Args:
  84. outputs (Any): A sequence of objects (potentially tensors) to gather.
  85. target_device (Union[int, torch.device]): The device to gather the tensors to.
  86. Use 'cpu' for CPU to avoid a deprecation warning.
  87. dim (int, optional): The dimension along which to gather. Default: 0.
  88. Returns:
  89. Any: A gathered object (potentially tensor) on the specified device.
  90. """
  91. def gather_map(outputs):
  92. out = outputs[0]
  93. if isinstance(out, torch.Tensor):
  94. return Gather.apply(target_device, dim, *outputs)
  95. if out is None:
  96. return None
  97. if isinstance(out, dict):
  98. if not all(len(out) == len(d) for d in outputs):
  99. raise ValueError("All dicts must have the same number of keys")
  100. return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
  101. if _is_namedtuple(out):
  102. return type(out)._make(map(gather_map, zip(*outputs)))
  103. return type(out)(map(gather_map, zip(*outputs)))
  104. # Recursive function calls like this create reference cycles.
  105. # Setting the function to None clears the refcycle.
  106. try:
  107. res = gather_map(outputs)
  108. finally:
  109. gather_map = None # type: ignore[assignment]
  110. return res