# parallel_apply.py
  1. import threading
  2. from collections.abc import Sequence
  3. from typing import Any, cast, Optional, Union
  4. import torch
  5. from torch._utils import ExceptionWrapper
  6. from torch.cuda._utils import _get_device_index
  7. from torch.nn.modules import Module
  8. __all__ = ["get_a_var", "parallel_apply"]
  9. def get_a_var(
  10. obj: Union[torch.Tensor, list[Any], tuple[Any, ...], dict[Any, Any]],
  11. ) -> Optional[torch.Tensor]:
  12. if isinstance(obj, torch.Tensor):
  13. return obj
  14. if isinstance(obj, (list, tuple)):
  15. for result in map(get_a_var, obj):
  16. if isinstance(result, torch.Tensor):
  17. return result
  18. if isinstance(obj, dict):
  19. for result in map(get_a_var, obj.items()):
  20. if isinstance(result, torch.Tensor):
  21. return result
  22. return None
  23. def parallel_apply(
  24. modules: Sequence[Module],
  25. inputs: Sequence[Any],
  26. kwargs_tup: Optional[Sequence[dict[str, Any]]] = None,
  27. devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
  28. ) -> list[Any]:
  29. r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.
  30. Args:
  31. modules (Module): modules to be parallelized
  32. inputs (tensor): inputs to the modules
  33. devices (list of int or torch.device): CUDA devices
  34. :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
  35. :attr:`devices` (if given) should all have same length. Moreover, each
  36. element of :attr:`inputs` can either be a single object as the only argument
  37. to a module, or a collection of positional arguments.
  38. """
  39. assert len(modules) == len(inputs), (
  40. f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
  41. )
  42. if kwargs_tup is not None:
  43. assert len(modules) == len(kwargs_tup)
  44. else:
  45. kwargs_tup = (cast(dict[str, Any], {}),) * len(modules)
  46. if devices is not None:
  47. assert len(modules) == len(devices)
  48. else:
  49. devices = [None] * len(modules)
  50. devices = [_get_device_index(x, True) for x in devices]
  51. streams = [torch.cuda.current_stream(x) for x in devices]
  52. lock = threading.Lock()
  53. results = {}
  54. grad_enabled, autocast_enabled = (
  55. torch.is_grad_enabled(),
  56. torch.is_autocast_enabled(),
  57. )
  58. def _worker(
  59. i: int,
  60. module: Module,
  61. input: Any,
  62. kwargs: dict[str, Any],
  63. device: Optional[Union[int, torch.device]] = None,
  64. stream: Optional[torch.cuda.Stream] = None,
  65. ) -> None:
  66. torch.set_grad_enabled(grad_enabled)
  67. if device is None:
  68. t = get_a_var(input)
  69. if t is None:
  70. with lock:
  71. results[i] = ExceptionWrapper(
  72. where=f"in replica {i}, no device was provided and no tensor input was found; "
  73. "device cannot be resolved"
  74. )
  75. return
  76. device = t.get_device()
  77. if stream is None:
  78. stream = torch.cuda.current_stream(device)
  79. try:
  80. with (
  81. torch.cuda.device(device),
  82. torch.cuda.stream(stream),
  83. torch.amp.autocast("cuda", enabled=autocast_enabled),
  84. ):
  85. # this also avoids accidental slicing of `input` if it is a Tensor
  86. if not isinstance(input, (list, tuple)):
  87. input = (input,)
  88. output = module(*input, **kwargs)
  89. with lock:
  90. results[i] = output
  91. except Exception:
  92. with lock:
  93. results[i] = ExceptionWrapper(
  94. where=f"in replica {i} on device {device}"
  95. )
  96. if len(modules) > 1:
  97. threads = [
  98. threading.Thread(
  99. target=_worker, args=(i, module, input, kwargs, device, stream)
  100. )
  101. for i, (module, input, kwargs, device, stream) in enumerate(
  102. zip(modules, inputs, kwargs_tup, devices, streams)
  103. )
  104. ]
  105. for thread in threads:
  106. thread.start()
  107. for thread in threads:
  108. thread.join()
  109. else:
  110. _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])
  111. outputs = []
  112. for i in range(len(inputs)):
  113. output = results[i]
  114. if isinstance(output, ExceptionWrapper):
  115. output.reraise()
  116. outputs.append(output)
  117. return outputs