__init__.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # mypy: allow-untyped-defs
  2. r"""
  3. This package implements abstractions found in ``torch.cuda``
  4. to facilitate writing device-agnostic code.
  5. """
  6. from contextlib import AbstractContextManager
  7. from typing import Any, Optional, Union
  8. import torch
  9. from .. import device as _device
  10. from . import amp
  11. __all__ = [
  12. "is_available",
  13. "is_initialized",
  14. "synchronize",
  15. "current_device",
  16. "current_stream",
  17. "stream",
  18. "set_device",
  19. "device_count",
  20. "Stream",
  21. "StreamContext",
  22. "Event",
  23. ]
  24. def _is_avx2_supported() -> bool:
  25. r"""Returns a bool indicating if CPU supports AVX2."""
  26. return torch._C._cpu._is_avx2_supported()
  27. def _is_avx512_supported() -> bool:
  28. r"""Returns a bool indicating if CPU supports AVX512."""
  29. return torch._C._cpu._is_avx512_supported()
  30. def _is_avx512_bf16_supported() -> bool:
  31. r"""Returns a bool indicating if CPU supports AVX512_BF16."""
  32. return torch._C._cpu._is_avx512_bf16_supported()
  33. def _is_vnni_supported() -> bool:
  34. r"""Returns a bool indicating if CPU supports VNNI."""
  35. # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later.
  36. return torch._C._cpu._is_avx512_vnni_supported()
  37. def _is_amx_tile_supported() -> bool:
  38. r"""Returns a bool indicating if CPU supports AMX_TILE."""
  39. return torch._C._cpu._is_amx_tile_supported()
  40. def _is_amx_fp16_supported() -> bool:
  41. r"""Returns a bool indicating if CPU supports AMX FP16."""
  42. return torch._C._cpu._is_amx_fp16_supported()
  43. def _init_amx() -> bool:
  44. r"""Initializes AMX instructions."""
  45. return torch._C._cpu._init_amx()
  46. def is_available() -> bool:
  47. r"""Returns a bool indicating if CPU is currently available.
  48. N.B. This function only exists to facilitate device-agnostic code
  49. """
  50. return True
  51. def synchronize(device: torch.types.Device = None) -> None:
  52. r"""Waits for all kernels in all streams on the CPU device to complete.
  53. Args:
  54. device (torch.device or int, optional): ignored, there's only one CPU device.
  55. N.B. This function only exists to facilitate device-agnostic code.
  56. """
  57. class Stream:
  58. """
  59. N.B. This class only exists to facilitate device-agnostic code
  60. """
  61. def __init__(self, priority: int = -1) -> None:
  62. pass
  63. def wait_stream(self, stream) -> None:
  64. pass
  65. def record_event(self) -> None:
  66. pass
  67. def wait_event(self, event) -> None:
  68. pass
  69. class Event:
  70. def query(self) -> bool:
  71. return True
  72. def record(self, stream=None) -> None:
  73. pass
  74. def synchronize(self) -> None:
  75. pass
  76. def wait(self, stream=None) -> None:
  77. pass
  78. _default_cpu_stream = Stream()
  79. _current_stream = _default_cpu_stream
  80. def current_stream(device: torch.types.Device = None) -> Stream:
  81. r"""Returns the currently selected :class:`Stream` for a given device.
  82. Args:
  83. device (torch.device or int, optional): Ignored.
  84. N.B. This function only exists to facilitate device-agnostic code
  85. """
  86. return _current_stream
  87. class StreamContext(AbstractContextManager):
  88. r"""Context-manager that selects a given stream.
  89. N.B. This class only exists to facilitate device-agnostic code
  90. """
  91. cur_stream: Optional[Stream]
  92. def __init__(self, stream):
  93. self.stream = stream
  94. self.prev_stream = _default_cpu_stream
  95. def __enter__(self):
  96. cur_stream = self.stream
  97. if cur_stream is None:
  98. return
  99. global _current_stream
  100. self.prev_stream = _current_stream
  101. _current_stream = cur_stream
  102. def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
  103. cur_stream = self.stream
  104. if cur_stream is None:
  105. return
  106. global _current_stream
  107. _current_stream = self.prev_stream
  108. def stream(stream: Stream) -> AbstractContextManager:
  109. r"""Wrapper around the Context-manager StreamContext that
  110. selects a given stream.
  111. N.B. This function only exists to facilitate device-agnostic code
  112. """
  113. return StreamContext(stream)
  114. def device_count() -> int:
  115. r"""Returns number of CPU devices (not cores). Always 1.
  116. N.B. This function only exists to facilitate device-agnostic code
  117. """
  118. return 1
  119. def set_device(device: torch.types.Device) -> None:
  120. r"""Sets the current device, in CPU we do nothing.
  121. N.B. This function only exists to facilitate device-agnostic code
  122. """
  123. def current_device() -> str:
  124. r"""Returns current device for cpu. Always 'cpu'.
  125. N.B. This function only exists to facilitate device-agnostic code
  126. """
  127. return "cpu"
  128. def is_initialized() -> bool:
  129. r"""Returns True if the CPU is initialized. Always True.
  130. N.B. This function only exists to facilitate device-agnostic code
  131. """
  132. return True