_device.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # mypy: allow-untyped-defs
  2. import functools
  3. from typing import Optional
  4. import torch
  5. from torch._C import _len_torch_function_stack
  6. from torch.overrides import _pop_mode, _push_mode, TorchFunctionMode
  7. from torch.utils._contextlib import context_decorator
  8. CURRENT_DEVICE: Optional[torch.device] = None
  9. @functools.lru_cache(1)
  10. def _device_constructors():
  11. return {
  12. # standard ones
  13. torch.empty,
  14. torch.empty_permuted,
  15. torch.empty_strided,
  16. torch.empty_quantized,
  17. torch.ones,
  18. torch.arange,
  19. torch.bartlett_window,
  20. torch.blackman_window,
  21. torch.eye,
  22. torch.fft.fftfreq,
  23. torch.fft.rfftfreq,
  24. torch.full,
  25. torch.hamming_window,
  26. torch.hann_window,
  27. torch.kaiser_window,
  28. torch.linspace,
  29. torch.logspace,
  30. torch.nested.nested_tensor,
  31. # This function doesn't actually take a device argument
  32. # torch.normal,
  33. torch.rand,
  34. torch.randn,
  35. torch.randint,
  36. torch.randperm,
  37. torch.range,
  38. torch.sparse_coo_tensor,
  39. torch.sparse_compressed_tensor,
  40. torch.sparse_csr_tensor,
  41. torch.sparse_csc_tensor,
  42. torch.sparse_bsr_tensor,
  43. torch.sparse_bsc_tensor,
  44. torch.tril_indices,
  45. torch.triu_indices,
  46. torch.zeros,
  47. torch.asarray,
  48. # weird ones
  49. torch.tensor,
  50. torch.as_tensor,
  51. torch.scalar_tensor,
  52. }
  53. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  54. class DeviceContext(TorchFunctionMode):
  55. def __init__(self, device):
  56. self.device = torch.device(device)
  57. def __enter__(self):
  58. global CURRENT_DEVICE
  59. self.old_device = CURRENT_DEVICE
  60. CURRENT_DEVICE = self.device
  61. # We need to put the device at the bottom of the stack
  62. # If we set default device within a function mode context
  63. # exiting that context mode will pop the device function mode off
  64. # of the stack incorrectly
  65. cur_stack = [_pop_mode() for _ in range(_len_torch_function_stack())]
  66. _push_mode(self)
  67. for mode in reversed(cur_stack):
  68. _push_mode(mode)
  69. def __exit__(self, exc_type, exc_val, exc_tb):
  70. global CURRENT_DEVICE
  71. CURRENT_DEVICE = self.old_device
  72. cur_stack = []
  73. # Invariant: there should only be one DeviceContext on the stack at any time
  74. # (At the bottom), pop all modes until we hit the bottom, assert it's a DeviceContext
  75. # or else someone else has popped it!
  76. for _ in range(_len_torch_function_stack() - 1):
  77. mode = _pop_mode()
  78. assert not isinstance(mode, DeviceContext)
  79. cur_stack.append(mode)
  80. if _len_torch_function_stack() > 0:
  81. mode = _pop_mode()
  82. assert isinstance(mode, DeviceContext)
  83. for mode in reversed(cur_stack):
  84. _push_mode(mode)
  85. def __torch_function__(self, func, types, args=(), kwargs=None):
  86. kwargs = kwargs or {}
  87. if func in _device_constructors() and kwargs.get("device") is None:
  88. kwargs["device"] = self.device
  89. return func(*args, **kwargs)
  90. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  91. def device_decorator(device, func):
  92. return context_decorator(lambda: device, func)
  93. def set_device(device):
  94. """
  95. Set the default device inside of the wrapped function by decorating it with this function.
  96. If you would like to use this as a context manager, use device as a
  97. context manager directly, e.g., ``with torch.device(device)``.
  98. """
  99. return lambda func: device_decorator(torch.device(device), func)