_device.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import torch
  4. from torch._C import _len_torch_function_stack
  5. from torch.overrides import _pop_mode, _push_mode, TorchFunctionMode
  6. from torch.utils._contextlib import context_decorator
  7. CURRENT_DEVICE: torch.device | None = None
  8. @functools.lru_cache(1)
  9. def _device_constructors():
  10. return {
  11. # standard ones
  12. torch.empty,
  13. torch.empty_permuted,
  14. torch.empty_strided,
  15. torch.empty_quantized,
  16. torch.ones,
  17. torch.arange,
  18. torch.bartlett_window,
  19. torch.blackman_window,
  20. torch.eye,
  21. torch.fft.fftfreq,
  22. torch.fft.rfftfreq,
  23. torch.full,
  24. torch.hamming_window,
  25. torch.hann_window,
  26. torch.kaiser_window,
  27. torch.linspace,
  28. torch.logspace,
  29. torch.nested.nested_tensor,
  30. # This function doesn't actually take a device argument
  31. # torch.normal,
  32. torch.rand,
  33. torch.randn,
  34. torch.randint,
  35. torch.randperm,
  36. torch.range,
  37. torch.sparse_coo_tensor,
  38. torch.sparse_compressed_tensor,
  39. torch.sparse_csr_tensor,
  40. torch.sparse_csc_tensor,
  41. torch.sparse_bsr_tensor,
  42. torch.sparse_bsc_tensor,
  43. torch.tril_indices,
  44. torch.triu_indices,
  45. torch.zeros,
  46. torch.asarray,
  47. # weird ones
  48. torch.tensor,
  49. torch.as_tensor,
  50. torch.scalar_tensor,
  51. }
  52. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  53. class DeviceContext(TorchFunctionMode):
  54. def __init__(self, device) -> None:
  55. # pyrefly: ignore [read-only]
  56. self.device = torch.device(device)
  57. def __enter__(self):
  58. global CURRENT_DEVICE
  59. self.old_device = CURRENT_DEVICE
  60. CURRENT_DEVICE = self.device
  61. # We need to put the device at the bottom of the stack
  62. # If we set default device within a function mode context
  63. # exiting that context mode will pop the device function mode off
  64. # of the stack incorrectly
  65. cur_stack = [_pop_mode() for _ in range(_len_torch_function_stack())]
  66. _push_mode(self)
  67. for mode in reversed(cur_stack):
  68. _push_mode(mode)
  69. def __exit__(self, exc_type, exc_val, exc_tb):
  70. global CURRENT_DEVICE
  71. CURRENT_DEVICE = self.old_device
  72. cur_stack = []
  73. # Invariant: there should only be one DeviceContext on the stack at any time
  74. # (At the bottom), pop all modes until we hit the bottom, assert it's a DeviceContext
  75. # or else someone else has popped it!
  76. for _ in range(_len_torch_function_stack() - 1):
  77. mode = _pop_mode()
  78. if isinstance(mode, DeviceContext):
  79. raise AssertionError(
  80. "Found nested DeviceContext on the mode stack where none expected"
  81. )
  82. cur_stack.append(mode)
  83. if _len_torch_function_stack() > 0:
  84. mode = _pop_mode()
  85. if not isinstance(mode, DeviceContext):
  86. raise AssertionError(
  87. "Expected a DeviceContext at the bottom of the mode stack"
  88. )
  89. for mode in reversed(cur_stack):
  90. _push_mode(mode)
  91. def __torch_function__(self, func, types, args=(), kwargs=None):
  92. kwargs = kwargs or {}
  93. if func in _device_constructors() and kwargs.get("device") is None:
  94. kwargs["device"] = self.device
  95. return func(*args, **kwargs)
  96. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  97. def device_decorator(device, func):
  98. return context_decorator(lambda: device, func)
  99. def set_device(device):
  100. """
  101. Set the default device inside of the wrapped function by decorating it with this function.
  102. If you would like to use this as a context manager, use device as a
  103. context manager directly, e.g., ``with torch.device(device)``.
  104. """
  105. return lambda func: device_decorator(torch.device(device), func)