_functions.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import warnings
  2. from itertools import chain
  3. from typing import Optional
  4. import torch
  5. from torch._utils import _get_device_index
  6. from torch.autograd import Function
  7. from torch.nn.parallel import comm
  8. class Broadcast(Function):
  9. @staticmethod
  10. def forward(ctx, target_gpus, *inputs):
  11. assert all(i.device.type != "cpu" for i in inputs), (
  12. "Broadcast function not implemented for CPU tensors"
  13. )
  14. target_gpus = [_get_device_index(x, True) for x in target_gpus]
  15. ctx.target_gpus = target_gpus
  16. if len(inputs) == 0:
  17. return ()
  18. ctx.num_inputs = len(inputs)
  19. ctx.input_device = inputs[0].get_device()
  20. outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
  21. non_differentiables = []
  22. for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
  23. if not input_requires_grad:
  24. non_differentiables.extend(output[idx] for output in outputs)
  25. ctx.mark_non_differentiable(*non_differentiables)
  26. return tuple(chain.from_iterable(outputs))
  27. @staticmethod
  28. def backward(ctx, *grad_outputs):
  29. return (None,) + ReduceAddCoalesced.apply(
  30. ctx.input_device, ctx.num_inputs, *grad_outputs
  31. )
  32. class ReduceAddCoalesced(Function):
  33. @staticmethod
  34. def forward(ctx, destination, num_inputs, *grads):
  35. ctx.target_gpus = [
  36. grads[i].get_device() for i in range(0, len(grads), num_inputs)
  37. ]
  38. grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)]
  39. return comm.reduce_add_coalesced(grads_, destination)
  40. @staticmethod
  41. def backward(ctx, *grad_outputs):
  42. return (
  43. None,
  44. None,
  45. ) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
  46. class Gather(Function):
  47. @staticmethod
  48. def forward(ctx, target_device, dim, *inputs):
  49. assert all(i.device.type != "cpu" for i in inputs), (
  50. "Gather function not implemented for CPU tensors"
  51. )
  52. if target_device == "cpu":
  53. ctx.target_device = "cpu"
  54. else:
  55. target_device = _get_device_index(target_device, True)
  56. ctx.target_device = target_device
  57. ctx.dim = dim
  58. ctx.input_gpus = tuple(i.get_device() for i in inputs)
  59. if all(t.dim() == 0 for t in inputs) and dim == 0:
  60. inputs = tuple(t.view(1) for t in inputs)
  61. warnings.warn(
  62. "Was asked to gather along dimension 0, but all "
  63. "input tensors were scalars; will instead unsqueeze "
  64. "and return a vector."
  65. )
  66. ctx.unsqueezed_scalar = True
  67. else:
  68. ctx.unsqueezed_scalar = False
  69. ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
  70. return comm.gather(inputs, ctx.dim, ctx.target_device)
  71. @staticmethod
  72. def backward(ctx, grad_output):
  73. scattered_grads = Scatter.apply(
  74. ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output
  75. )
  76. if ctx.unsqueezed_scalar:
  77. scattered_grads = tuple(g[0] for g in scattered_grads)
  78. return (None, None) + scattered_grads
  79. class Scatter(Function):
  80. @staticmethod
  81. def forward(ctx, target_gpus, chunk_sizes, dim, input):
  82. target_gpus = [_get_device_index(x, True) for x in target_gpus]
  83. ctx.dim = dim
  84. ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
  85. streams = None
  86. if torch.accelerator.is_available() and ctx.input_device == -1:
  87. # Perform CPU to GPU copies in a background stream
  88. streams = [_get_stream(torch.device(device)) for device in target_gpus]
  89. outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
  90. # Synchronize with the copy stream
  91. if streams is not None:
  92. for i, output in enumerate(outputs):
  93. with torch.accelerator.device_index(target_gpus[i]):
  94. main_stream = torch.accelerator.current_stream()
  95. main_stream.wait_stream(streams[i])
  96. output.record_stream(main_stream)
  97. return outputs
  98. @staticmethod
  99. def backward(ctx, *grad_output):
  100. return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
  101. # background streams used for copying
  102. _streams: Optional[list[Optional[torch.Stream]]] = None
  103. def _get_stream(device: torch.device):
  104. """Get a background stream for copying between CPU and target device."""
  105. global _streams
  106. if device.type == "cpu" or not torch.accelerator.is_available():
  107. return None
  108. assert torch.accelerator.current_accelerator().type == device.type
  109. if _streams is None:
  110. _streams = [None] * torch.accelerator.device_count()
  111. if _streams[device.index] is None:
  112. _streams[device.index] = torch.Stream(device.index)
  113. return _streams[device.index]