memory.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import collections
  2. from typing import Any, Union
  3. import torch
  4. from torch.types import Device
  5. from . import _get_device_index, is_initialized
# Type of the ``device`` argument accepted by every function in this module:
# a torch.device, a device string, an integer index, or None. None selects the
# current device (the functions pass it through _get_device_index with
# optional=True).
_device_t = Union[Device, str, int, None]
  7. def empty_cache() -> None:
  8. r"""Release all unoccupied cached memory currently held by the caching
  9. allocator so that those can be used in other XPU application.
  10. .. note::
  11. :func:`~torch.xpu.empty_cache` doesn't increase the amount of XPU
  12. memory available for PyTorch. However, it may help reduce fragmentation
  13. of XPU memory in certain cases.
  14. """
  15. if is_initialized():
  16. torch._C._xpu_emptyCache()
  17. def reset_peak_memory_stats(device: _device_t = None) -> None:
  18. r"""Reset the "peak" stats tracked by the XPU memory allocator.
  19. See :func:`~torch.xpu.memory_stats` for details. Peak stats correspond to the
  20. `"peak"` key in each individual stat dict.
  21. Args:
  22. device (torch.device or int or str, optional): selected device. Returns
  23. statistic for the current device, given by :func:`~torch.xpu.current_device`,
  24. if :attr:`device` is ``None`` (default).
  25. """
  26. device = _get_device_index(device, optional=True)
  27. return torch._C._xpu_resetPeakMemoryStats(device)
  28. def reset_accumulated_memory_stats(device: _device_t = None) -> None:
  29. r"""Reset the "accumulated" (historical) stats tracked by the XPU memory allocator.
  30. See :func:`~torch.xpu.memory_stats` for details. Accumulated stats correspond to
  31. the `"allocated"` and `"freed"` keys in each individual stat dict.
  32. Args:
  33. device (torch.device or int or str, optional): selected device. Returns
  34. statistic for the current device, given by :func:`~torch.xpu.current_device`,
  35. if :attr:`device` is ``None`` (default).
  36. """
  37. device = _get_device_index(device, optional=True)
  38. return torch._C._xpu_resetAccumulatedMemoryStats(device)
  39. def memory_stats_as_nested_dict(device: _device_t = None) -> dict[str, Any]:
  40. r"""Return the result of :func:`~torch.xpu.memory_stats` as a nested dictionary."""
  41. if not is_initialized():
  42. return {}
  43. device = _get_device_index(device, optional=True)
  44. return torch._C._xpu_memoryStats(device)
  45. def memory_stats(device: _device_t = None) -> dict[str, Any]:
  46. r"""Return a dictionary of XPU memory allocator statistics for a given device.
  47. The return value of this function is a dictionary of statistics, each of
  48. which is a non-negative integer.
  49. Core statistics:
  50. - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  51. amount of allocated memory.
  52. - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  53. amount of reserved memory.
  54. - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  55. amount of active memory.
  56. - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  57. memory requested by client code, compare this with allocated_bytes to check if
  58. allocation rounding adds too much overhead.
  59. For these core statistics, values are broken down as follows.
  60. Pool type:
  61. - ``all``: combined statistics across all memory pools.
  62. - ``large_pool``: statistics for the large allocation pool (for size >= 1MB allocations).
  63. - ``small_pool``: statistics for the small allocation pool (for size < 1MB allocations).
  64. Metric type:
  65. - ``current``: current value of this metric.
  66. - ``peak``: maximum value of this metric.
  67. - ``allocated``: historical total increase in this metric.
  68. - ``freed``: historical total decrease in this metric.
  69. Args:
  70. device (torch.device or int or str, optional): selected device. Returns
  71. statistics for the current device, given by :func:`~torch.xpu.current_device`,
  72. if :attr:`device` is ``None`` (default).
  73. """
  74. result = []
  75. def _recurse_add_to_result(prefix: str, obj: Any) -> None:
  76. if isinstance(obj, dict):
  77. if len(prefix) > 0:
  78. prefix += "."
  79. for k, v in obj.items():
  80. _recurse_add_to_result(prefix + k, v)
  81. else:
  82. result.append((prefix, obj))
  83. stats = memory_stats_as_nested_dict(device=device)
  84. _recurse_add_to_result("", stats)
  85. result.sort()
  86. return collections.OrderedDict(result)
  87. def memory_allocated(device: _device_t = None) -> int:
  88. r"""Return the current GPU memory occupied by tensors in bytes for a given device.
  89. Args:
  90. device (torch.device or int or str, optional): selected device. Returns
  91. statistic for the current device, given by :func:`~torch.xpu.current_device`,
  92. if :attr:`device` is ``None`` (default).
  93. .. note::
  94. This is likely less than the amount shown in `xpu-smi` since some
  95. unused memory can be held by the caching allocator and some context
  96. needs to be created on GPU.
  97. """
  98. return memory_stats(device=device).get("allocated_bytes.all.current", 0)
  99. def max_memory_allocated(device: _device_t = None) -> int:
  100. r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.
  101. By default, this returns the peak allocated memory since the beginning of
  102. this program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to
  103. reset the starting point in tracking this metric. For example, these two
  104. functions can measure the peak allocated memory usage of each iteration in a
  105. training loop.
  106. Args:
  107. device (torch.device or int or str, optional): selected device. Returns
  108. statistic for the current device, given by :func:`~torch.xpu.current_device`,
  109. if :attr:`device` is ``None`` (default).
  110. """
  111. return memory_stats(device=device).get("allocated_bytes.all.peak", 0)
  112. def memory_reserved(device: _device_t = None) -> int:
  113. r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.
  114. Args:
  115. device (torch.device or int or str, optional): selected device. Returns
  116. statistic for the current device, given by :func:`~torch.xpu.current_device`,
  117. if :attr:`device` is ``None`` (default).
  118. """
  119. return memory_stats(device=device).get("reserved_bytes.all.current", 0)
  120. def max_memory_reserved(device: _device_t = None) -> int:
  121. r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.
  122. By default, this returns the peak cached memory since the beginning of this
  123. program. :func:`~torch.xpu.reset_peak_memory_stats` can be used to reset
  124. the starting point in tracking this metric. For example, these two functions
  125. can measure the peak cached memory amount of each iteration in a training
  126. loop.
  127. Args:
  128. device (torch.device or int or str, optional): selected device. Returns
  129. statistic for the current device, given by :func:`~torch.xpu.current_device`,
  130. if :attr:`device` is ``None`` (default).
  131. """
  132. return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
  133. def mem_get_info(device: _device_t = None) -> tuple[int, int]:
  134. r"""Return the global free and total GPU memory for a given device.
  135. Args:
  136. device (torch.device or int or str, optional): selected device. Returns
  137. statistic for the current device, given by :func:`~torch.xpu.current_device`,
  138. if :attr:`device` is ``None`` (default).
  139. Returns:
  140. int: the memory available on the device in units of bytes.
  141. int: the total memory on the device in units of bytes
  142. """
  143. device = _get_device_index(device, optional=True)
  144. return torch._C._xpu_getMemoryInfo(device)
# Public API of this module: the names exported by ``from ... import *``.
# Kept in alphabetical order.
__all__ = [
    "empty_cache",
    "max_memory_allocated",
    "max_memory_reserved",
    "mem_get_info",
    "memory_allocated",
    "memory_reserved",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
]