memory.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. from collections import OrderedDict
  2. from typing import Any
  3. import torch
  4. from ._utils import _device_t, _get_device_index
  5. __all__ = [
  6. "empty_cache",
  7. "max_memory_allocated",
  8. "max_memory_reserved",
  9. "memory_allocated",
  10. "memory_reserved",
  11. "memory_stats",
  12. "reset_accumulated_memory_stats",
  13. "reset_peak_memory_stats",
  14. ]
  15. def empty_cache() -> None:
  16. r"""Release all unoccupied cached memory currently held by the caching
  17. allocator so that those can be used in other application.
  18. .. note:: This function is a no-op if the memory allocator for the current
  19. :ref:`accelerator <accelerators>` has not been initialized.
  20. """
  21. if not torch._C._accelerator_isAllocatorInitialized():
  22. return
  23. torch._C._accelerator_emptyCache()
  24. def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
  25. r"""Return a dictionary of accelerator device memory allocator statistics for a given device index.
  26. The return value of this function is a dictionary of statistics, each of
  27. which is a non-negative integer.
  28. Core statistics:
  29. - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  30. number of allocation requests received by the memory allocator.
  31. - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  32. amount of allocated memory.
  33. - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  34. number of reserved segments from device memory allocation.
  35. - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  36. amount of reserved memory.
  37. - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  38. number of active memory blocks.
  39. - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  40. amount of active memory.
  41. - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  42. number of inactive, non-releasable memory blocks.
  43. - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  44. amount of inactive, non-releasable memory.
  45. For these core statistics, values are broken down as follows.
  46. Pool type:
  47. - ``all``: combined statistics across all memory pools.
  48. - ``large_pool``: statistics for the large allocation pool
  49. (as of June 2025, for size >= 1MB allocations).
  50. - ``small_pool``: statistics for the small allocation pool
  51. (as of June 2025, for size < 1MB allocations).
  52. Metric type:
  53. - ``current``: current value of this metric.
  54. - ``peak``: maximum value of this metric.
  55. - ``allocated``: historical total increase in this metric.
  56. - ``freed``: historical total decrease in this metric.
  57. In addition to the core statistics, we also provide some simple event
  58. counters:
  59. - ``"num_alloc_retries"``: number of failed device memory allocation calls that
  60. result in a cache flush and retry.
  61. - ``"num_ooms"``: number of out-of-memory errors thrown.
  62. - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls.
  63. - ``"num_device_alloc"``: number of device memory allocation calls.
  64. - ``"num_device_free"``: number of device memory free calls.
  65. Args:
  66. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  67. If not given, use :func:`torch.accelerator.current_device_index` by default.
  68. If a :class:`torch.device` or str is provided, its type must match the current
  69. :ref:`accelerator<accelerators>` device type.
  70. """
  71. if not torch._C._accelerator_isAllocatorInitialized():
  72. return OrderedDict()
  73. device_index = _get_device_index(device_index, optional=True)
  74. stats = torch._C._accelerator_getDeviceStats(device_index)
  75. flat_stats = []
  76. def flatten(prefix: str, value: Any) -> None:
  77. if isinstance(value, dict):
  78. for k, v in value.items():
  79. nested_prefix = f"{prefix}.{k}" if prefix else k
  80. flatten(nested_prefix, v)
  81. else:
  82. flat_stats.append((prefix, value))
  83. flatten("", stats)
  84. flat_stats.sort()
  85. return OrderedDict(flat_stats)
  86. def memory_allocated(device_index: _device_t = None, /) -> int:
  87. r"""Return the current :ref:`accelerator<accelerators>` device memory occupied by tensors
  88. in bytes for a given device index.
  89. Args:
  90. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  91. If not given, use :func:`torch.accelerator.current_device_index` by default.
  92. If a :class:`torch.device` or str is provided, its type must match the current
  93. :ref:`accelerator<accelerators>` device type.
  94. """
  95. return memory_stats(device_index).get("allocated_bytes.all.current", 0)
  96. def max_memory_allocated(device_index: _device_t = None, /) -> int:
  97. r"""Return the current :ref:`accelerator<accelerators>` maximum device memory occupied by tensors
  98. in bytes for a given device index.
  99. By default, this returns the peak allocated memory since the beginning of
  100. this program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to
  101. reset the starting point in tracking this metric.
  102. Args:
  103. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  104. If not given, use :func:`torch.accelerator.current_device_index` by default.
  105. If a :class:`torch.device` or str is provided, its type must match the current
  106. :ref:`accelerator<accelerators>` device type.
  107. """
  108. return memory_stats(device_index).get("allocated_bytes.all.peak", 0)
  109. def memory_reserved(device_index: _device_t = None, /) -> int:
  110. r"""Return the current :ref:`accelerator<accelerators>` device memory managed by the caching allocator
  111. in bytes for a given device index.
  112. Args:
  113. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  114. If not given, use :func:`torch.accelerator.current_device_index` by default.
  115. If a :class:`torch.device` or str is provided, its type must match the current
  116. :ref:`accelerator<accelerators>` device type.
  117. """
  118. return memory_stats(device_index).get("reserved_bytes.all.current", 0)
  119. def max_memory_reserved(device_index: _device_t = None, /) -> int:
  120. r"""Return the current :ref:`accelerator<accelerators>` maximum device memory managed by the caching allocator
  121. in bytes for a given device index.
  122. By default, this returns the peak cached memory since the beginning of this
  123. program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to reset
  124. the starting point in tracking this metric.
  125. Args:
  126. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  127. If not given, use :func:`torch.accelerator.current_device_index` by default.
  128. If a :class:`torch.device` or str is provided, its type must match the current
  129. :ref:`accelerator<accelerators>` device type.
  130. """
  131. return memory_stats(device_index).get("reserved_bytes.all.peak", 0)
  132. def reset_accumulated_memory_stats(device_index: _device_t = None, /) -> None:
  133. r"""Reset the "accumulated" (historical) stats tracked by the current :ref:`accelerator<accelerators>`
  134. memory allocator for a given device index.
  135. Args:
  136. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  137. If not given, use :func:`torch.accelerator.current_device_index` by default.
  138. If a :class:`torch.device` or str is provided, its type must match the current
  139. :ref:`accelerator<accelerators>` device type.
  140. .. note:: This function is a no-op if the memory allocator for the current
  141. :ref:`accelerator <accelerators>` has not been initialized.
  142. """
  143. device_index = _get_device_index(device_index, optional=True)
  144. return torch._C._accelerator_resetAccumulatedStats(device_index)
  145. def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
  146. r"""Reset the "peak" stats tracked by the current :ref:`accelerator<accelerators>`
  147. memory allocator for a given device index.
  148. Args:
  149. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  150. If not given, use :func:`torch.accelerator.current_device_index` by default.
  151. If a :class:`torch.device` or str is provided, its type must match the current
  152. :ref:`accelerator<accelerators>` device type.
  153. .. note:: This function is a no-op if the memory allocator for the current
  154. :ref:`accelerator <accelerators>` has not been initialized.
  155. """
  156. device_index = _get_device_index(device_index, optional=True)
  157. return torch._C._accelerator_resetPeakStats(device_index)