| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561 |
- # mypy: allow-untyped-defs
- r"""
- This package introduces support for the XPU backend, specifically tailored for
- Intel GPU optimization.
- This package is lazily initialized, so you can always import it, and use
- :func:`is_available()` to determine if your system supports XPU.
- """
- import threading
- import traceback
- from functools import lru_cache
- from typing import Any, Callable, Optional, Union
- import torch
- import torch._C
- from torch import device as _device
- from torch._utils import _dummy_type, _LazySeedTracker
- from ._utils import _get_device_index
- from .streams import Event, Stream
- _initialized = False
- _tls = threading.local()
- _initialization_lock = threading.Lock()
- _queued_calls: list[
- tuple[Callable[[], None], list[str]]
- ] = [] # don't invoke these until initialization occurs
- _is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
- _device_t = Union[_device, str, int, None]
- _lazy_seed_tracker = _LazySeedTracker()
- default_generators: tuple[torch._C.Generator] = () # type: ignore[assignment]
- def _is_compiled() -> bool:
- r"""Return true if compile with XPU support."""
- return torch._C._has_xpu
if _is_compiled():
    # XPU build: bind the device-properties type and the device-exchange
    # helpers directly from the C extension.
    _XpuDeviceProperties = torch._C._XpuDeviceProperties
    _exchange_device = torch._C._xpu_exchangeDevice
    _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice
else:
    # Define dummy if PyTorch was compiled without XPU
    _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties")  # type: ignore[assignment, misc]

    # Placeholder that always raises NotImplementedError in non-XPU builds.
    def _exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")

    # Placeholder that always raises NotImplementedError in non-XPU builds.
    def _maybe_exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")
@lru_cache(maxsize=1)
def device_count() -> int:
    r"""Return the number of available XPU devices.

    Builds without XPU support always report ``0``. The result is memoized,
    so only the first call queries the backend.
    """
    return torch._C._xpu_getDeviceCount() if _is_compiled() else 0
def is_available() -> bool:
    r"""Report whether at least one XPU device can be used.

    This function never throws; it only consults :func:`device_count`.
    """
    num_devices = device_count()
    return num_devices > 0
def is_bf16_supported(including_emulation: bool = True) -> bool:
    r"""Report whether the current XPU device supports dtype bfloat16.

    Args:
        including_emulation (bool): when truthy, any available XPU device is
            reported as supporting bfloat16. Otherwise the answer comes from
            the current device's ``has_bfloat16_conversions`` property.

    Returns ``False`` when no XPU device is available.
    """
    if not is_available():
        return False
    if including_emulation:
        # Emulation is acceptable, so the hardware property is not queried.
        return including_emulation
    current_props = torch.xpu.get_device_properties()
    return current_props.has_bfloat16_conversions
def is_initialized():
    r"""Report whether PyTorch's XPU state is initialized and usable.

    A process that is in a bad fork is treated as not initialized.
    """
    if not _initialized:
        return False
    return not _is_in_bad_fork()
def _lazy_call(callable, **kwargs):
    r"""Run ``callable`` now if XPU is initialized, otherwise defer it.

    Deferred calls are routed by keyword flags:
    ``seed_all=True`` goes to the seed tracker's seed-all slot,
    ``seed=True`` goes to its seed slot, and anything else is appended to
    ``_queued_calls``. Each deferred call keeps the formatted stack of its
    call site for error reporting.
    """
    if is_initialized():
        callable()
        return
    # Keep the stack as formatted text rather than the live traceback to
    # avoid a memory cycle.
    call_site = traceback.format_stack()
    if kwargs.get("seed_all", False):
        _lazy_seed_tracker.queue_seed_all(callable, call_site)
    elif kwargs.get("seed", False):
        _lazy_seed_tracker.queue_seed(callable, call_site)
    else:
        _queued_calls.append((callable, call_site))
def init():
    r"""Initialize PyTorch's XPU state.

    XPU is normally set up lazily, the first time it is accessed; this public
    entry point forces that setup now. Calling it when the XPU state is
    already initialized does nothing.
    """
    _lazy_init()
def _lazy_init():
    r"""Perform the one-time XPU initialization, then run all deferred calls.

    Returns immediately if XPU is already initialized, or if this thread is
    already inside ``_lazy_init`` (a queued call re-entered it).

    Raises:
        RuntimeError: if XPU cannot be re-initialized because the process is
            in a bad fork.
        AssertionError: if PyTorch was built without XPU support.
        Exception: wrapping the first queued call that fails; the message
            includes the stack captured when that call was queued.
    """
    global _initialized, _queued_calls
    # Lock-free fast path. The ``is_initializing`` flag also short-circuits
    # reentrant calls made by queued callables on this same thread.
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # The check above ran without holding this lock. Double-check whether
        # XPU has already been initialized now that we hold it.
        if is_initialized():
            return
        # Stop promptly upon encountering a bad fork error.
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize XPU in forked subprocess. To use XPU with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError("Torch not compiled with XPU enabled")
        # This function inits XPU backend and detects bad fork processing.
        torch._C._xpu_init()
        # Some of the queued calls may reentrantly call _lazy_init(); We need to
        # just return without initializing in that case.
        _tls.is_initializing = True
        # Seed requests recorded before init run after the other queued calls;
        # falsy entries from the tracker are skipped.
        _queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls)
        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"XPU call failed lazily at initialization with error: {str(e)}\n\n"
                        f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise Exception(msg) from e  # noqa: TRY002
        finally:
            # Always clear the reentrancy flag, even when a queued call failed.
            delattr(_tls, "is_initializing")
        # Only reached when every queued call succeeded; on failure
        # ``_initialized`` stays False.
        # NOTE(review): ``_queued_calls`` is never cleared here, so a retry
        # after a failure re-runs calls that already succeeded — confirm this
        # is intended.
        _initialized = True
class _DeviceGuard:
    r"""Context manager that switches to XPU device ``index`` for its body.

    On entry, ``torch.xpu._exchange_device(index)`` is called and its result
    is saved as ``prev_idx``. On exit, that saved index is passed to
    ``torch.xpu._maybe_exchange_device``.
    """

    def __init__(self, index: int):
        self.idx = index
        # Placeholder until __enter__ records the real previous index.
        self.prev_idx = -1

    def __enter__(self):
        previous_index = torch.xpu._exchange_device(self.idx)
        self.prev_idx = previous_index

    def __exit__(self, type: Any, value: Any, traceback: Any):
        restored_index = torch.xpu._maybe_exchange_device(self.prev_idx)
        self.idx = restored_index
        # Never suppress an exception raised inside the ``with`` body.
        return False
class device:
    r"""Context-manager that changes the selected XPU device for its body.

    Args:
        device (torch.device or int or str): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        resolved_index = _get_device_index(device, optional=True)
        self.idx = resolved_index
        # Overwritten by __enter__ with the index handed back by the exchange.
        self.prev_idx = -1

    def __enter__(self):
        previous_index = torch.xpu._exchange_device(self.idx)
        self.prev_idx = previous_index

    def __exit__(self, type: Any, value: Any, traceback: Any):
        restored_index = torch.xpu._maybe_exchange_device(self.prev_idx)
        self.idx = restored_index
        # Let exceptions from the ``with`` body propagate.
        return False
class device_of(device):
    r"""Context-manager that switches to the device holding a given object.

    Tensors and storages are both accepted. For an object that is not
    allocated on an XPU, index ``-1`` is passed to :class:`device`, which
    makes this a no-op.

    Args:
        obj (Tensor or Storage): object allocated on the selected device.
    """

    def __init__(self, obj):
        if obj.is_xpu:
            owning_index = obj.get_device()
        else:
            owning_index = -1
        super().__init__(owning_index)
def set_device(device: _device_t) -> None:
    r"""Make ``device`` the current XPU device.

    Args:
        device (torch.device or int or str): selected device. Nothing happens
            if it resolves to a negative index.
    """
    _lazy_init()
    target_index = _get_device_index(device)
    if target_index < 0:
        return
    torch._C._xpu_setDevice(target_index)
def get_device_name(device: Optional[_device_t] = None) -> str:
    r"""Return the name reported in a device's properties.

    Args:
        device (torch.device or int or str, optional): device to query. It uses
            the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).

    Returns:
        str: the ``name`` field of :func:`get_device_properties` for the device.
    """
    properties = get_device_properties(device)
    return properties.name
def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]:
    r"""Get the xpu capability of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the device capability. It uses the current device, given by
            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        Dict[str, Any]: the xpu capability dictionary of the device. Each call
        returns a fresh dictionary, so callers may modify it freely.
    """
    _lazy_init()
    # Resolve ``None`` / "xpu" / torch.device to a concrete index *before*
    # consulting the cache. Otherwise, a result cached for ``None`` would keep
    # describing whichever device was current on the first call, even after
    # ``set_device`` switches devices.
    device_index = _get_device_index(device, optional=True)
    # Hand out a shallow copy so callers cannot corrupt the cached entry.
    return dict(_device_capability_by_index(device_index))


@lru_cache(None)
def _device_capability_by_index(device_index: int) -> dict[str, Any]:
    r"""Build and memoize the serializable capability dict for one device index."""
    props = get_device_properties(device_index)
    # Only keep attributes that are safe for dictionary serialization.
    serializable_types = (int, float, bool, str, type(None), list, tuple, dict)
    return {
        key: value
        for key in dir(props)
        if not key.startswith("__")
        and isinstance((value := getattr(props, key)), serializable_types)
    }
def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
    r"""Return the properties object for an XPU device.

    Initializes XPU if needed, then resolves ``device`` to an index.

    Args:
        device (torch.device or int or str, optional): device to query. It uses
            the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).

    Returns:
        _XpuDeviceProperties: the properties of the device
    """
    _lazy_init()
    resolved_index = _get_device_index(device, optional=True)
    # ``_get_device_properties`` is not defined in this file. NOTE(review):
    # presumably injected into the module by the C extension — confirm.
    return _get_device_properties(resolved_index)  # type: ignore[name-defined] # noqa: F821
def current_device() -> int:
    r"""Return the index of the XPU device that is currently selected.

    Initializes XPU first if that has not happened yet.
    """
    _lazy_init()
    selected_index = torch._C._xpu_getDevice()
    return selected_index
def _get_device(device: Union[int, str, torch.device]) -> torch.device:
    r"""Normalize ``device`` into a :class:`torch.device`.

    Integers become ``torch.device("xpu", n)`` and strings are parsed by
    :class:`torch.device`. Any other value is returned unchanged.

    Args:
        device (torch.device or int or str): selected device.
    """
    if isinstance(device, int):
        return torch.device("xpu", device)
    if isinstance(device, str):
        return torch.device(device)
    return device
class StreamContext:
    r"""Context-manager that selects a given stream.

    All XPU kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.

    .. note:: Streams are per-device.
    .. note:: ``__enter__`` returns ``None``, so a ``with ... as x`` target
        is bound to ``None``.
    """

    # Class-level annotation only; this class never assigns ``cur_stream``
    # as an attribute (the methods use a local of the same name).
    cur_stream: Optional["torch.xpu.Stream"]

    def __init__(self, stream: Optional["torch.xpu.Stream"]):
        self.stream = stream
        # Device index current at construction time; -1 when
        # ``_get_device_index`` yields None. NOTE(review): presumably "no
        # device" — confirm against ``_get_device_index``.
        self.idx = _get_device_index(None, True)
        if self.idx is None:
            self.idx = -1

    def __enter__(self):
        cur_stream = self.stream
        # Nothing to select: leave all stream state untouched.
        if cur_stream is None or self.idx == -1:
            return
        # Current stream of the current device, restored by __exit__.
        self.src_prev_stream = torch.xpu.current_stream(None)

        # If the stream is not on the current device, then set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            # Temporarily switch to the stream's device to record that
            # device's own current stream, which __exit__ also restores.
            with device(cur_stream.device):
                self.dst_prev_stream = torch.xpu.current_stream(cur_stream.device)
        torch.xpu.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        cur_stream = self.stream
        # Mirrors the early return in __enter__: nothing was changed.
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device and destination device
        if self.src_prev_stream.device != cur_stream.device:
            torch.xpu.set_stream(self.dst_prev_stream)
        torch.xpu.set_stream(self.src_prev_stream)
def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext:
    r"""Build a :class:`StreamContext` that selects ``stream`` for its body.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's ``None``.
    """
    context = StreamContext(stream)
    return context
def _set_stream_by_id(stream_id, device_index, device_type):
    r"""Make the stream identified by its id, device index and device type current.

    Args:
        stream_id (int): internal stream identifier, not meant to be user-visible.
        device_index (int): index of the device that owns the stream.
        device_type (int): type of the device that owns the stream.
    """
    stream_spec = {
        "stream_id": stream_id,
        "device_index": device_index,
        "device_type": device_type,
    }
    torch._C._xpu_setStream(**stream_spec)
def set_stream(stream: Optional[Stream]):
    r"""Set the current stream. This is a wrapper API to set the stream.

    Usage of this function is discouraged in favor of the ``stream``
    context manager.

    Args:
        stream (Stream, optional): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    # ``None`` is explicitly accepted (hence ``Optional`` above) and ignored.
    if stream is None:
        return
    _lazy_init()
    _set_stream_by_id(
        stream_id=stream.stream_id,
        device_index=stream.device_index,
        device_type=stream.device_type,
    )
def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the :class:`Stream` currently selected on a device.

    Args:
        device (torch.device or int, optional): device to query. It uses the
            current device, given by :func:`~torch.xpu.current_device`, if
            :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    device_index = _get_device_index(device, optional=True)
    # The backend hands back (stream_id, device_index, device_type) by position.
    raw = torch._C._xpu_getCurrentStream(device_index)
    stream_id = raw[0]
    owner_index = raw[1]
    owner_type = raw[2]
    return Stream(stream_id=stream_id, device_index=owner_index, device_type=owner_type)
def get_stream_from_external(
    data_ptr: int, device: Optional[_device_t] = None
) -> Stream:
    r"""Wrap an externally created SYCL queue in a :class:`Stream`.

    Use this to share SYCL queues created by other libraries, for data
    exchange and multi-library interaction.

    .. note:: The queue's lifetime is not managed here. The caller must keep
       the referenced queue alive while the returned stream is in use.
       Distinct SYCL queue pointers produce distinct :class:`Stream` objects,
       even if the queues they point to are equivalent.

    Args:
        data_ptr(int): Integer representation of the `sycl::queue*` value passed externally.
        device(torch.device or int, optional): the device where the queue was originally created.
            It is the user responsibility to ensure the device is specified correctly.
    """
    _lazy_init()
    device_index = _get_device_index(device, optional=True)
    raw = torch._C._xpu_getStreamFromExternal(data_ptr, device_index)
    stream_id = raw[0]
    owner_index = raw[1]
    owner_type = raw[2]
    return Stream(stream_id=stream_id, device_index=owner_index, device_type=owner_type)
def synchronize(device: _device_t = None) -> None:
    r"""Block until every kernel in every stream on an XPU device has finished.

    Args:
        device (torch.device or int, optional): device to synchronize. It uses
            the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    """
    _lazy_init()
    return torch._C._xpu_synchronize(_get_device_index(device, optional=True))
def get_arch_list() -> list[str]:
    r"""Return the XPU architectures this library was compiled for.

    Returns an empty list for builds without XPU, or when the build
    recorded no architecture flags.
    """
    arch_flags = torch._C._xpu_getArchFlags() if _is_compiled() else None
    return [] if arch_flags is None else arch_flags.split()
def get_gencode_flags() -> str:
    r"""Return the XPU AOT (ahead-of-time) build flags this library was compiled with.

    The result has the form ``"-device arch1,arch2,..."``, or is ``""`` when
    :func:`get_arch_list` is empty.
    """
    architectures = get_arch_list()
    if not architectures:
        return ""
    return f"-device {','.join(architectures)}"
def _get_generator(device: torch.device) -> torch._C.Generator:
    r"""Look up the default XPU generator for ``device``.

    A device without an explicit index falls back to :func:`current_device`.

    Args:
        device (torch.device): selected device.
    """
    generator_index = device.index
    if generator_index is None:
        generator_index = current_device()
    generators = torch.xpu.default_generators
    return generators[generator_index]
def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the RNG state offset of a device's default generator.

    The update goes through ``_lazy_call``, so it runs immediately if XPU is
    initialized and is deferred otherwise.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    target_device = _get_device(device)

    def _apply_offset():
        _get_generator(target_device).set_offset(offset)

    _lazy_call(_apply_offset)
def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
    r"""Return the RNG state offset of a device's default generator.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    generator = _get_generator(_get_device(device))
    return generator.get_offset()
- # import here to avoid circular import
- from .memory import (
- empty_cache,
- max_memory_allocated,
- max_memory_reserved,
- mem_get_info,
- memory_allocated,
- memory_reserved,
- memory_stats,
- memory_stats_as_nested_dict,
- reset_accumulated_memory_stats,
- reset_peak_memory_stats,
- )
- from .random import (
- get_rng_state,
- get_rng_state_all,
- initial_seed,
- manual_seed,
- manual_seed_all,
- seed,
- seed_all,
- set_rng_state,
- set_rng_state_all,
- )
# Public names exported by this package (what ``from ... import *`` yields).
# ``"streams"`` refers to the ``.streams`` submodule, which is loaded by the
# ``from .streams import Event, Stream`` import in this file.
__all__ = [
    "Event",
    "Stream",
    "StreamContext",
    "current_device",
    "current_stream",
    "default_generators",
    "device",
    "device_of",
    "device_count",
    "empty_cache",
    "get_arch_list",
    "get_device_capability",
    "get_device_name",
    "get_device_properties",
    "get_gencode_flags",
    "get_rng_state",
    "get_rng_state_all",
    "get_stream_from_external",
    "init",
    "initial_seed",
    "is_available",
    "is_bf16_supported",
    "is_initialized",
    "manual_seed",
    "manual_seed_all",
    "max_memory_allocated",
    "max_memory_reserved",
    "mem_get_info",
    "memory_allocated",
    "memory_reserved",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "seed",
    "seed_all",
    "set_device",
    "set_rng_state",
    "set_rng_state_all",
    "set_stream",
    "stream",
    "streams",
    "synchronize",
]
|