memory.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236
  1. # mypy: allow-untyped-defs
  2. r"""This package adds support for device memory management implemented in CUDA."""
  3. import collections
  4. import contextlib
  5. import ctypes
  6. import pickle
  7. import sys
  8. import warnings
  9. from inspect import signature
  10. from typing import Any, Literal, Optional, TYPE_CHECKING
  11. from typing_extensions import deprecated
  12. import torch
  13. from torch import _C
  14. from torch._utils import _dummy_type
  15. from . import (
  16. _get_amdsmi_device_index,
  17. _get_device_index,
  18. _get_nvml_device_index,
  19. _lazy_init,
  20. is_initialized,
  21. )
  22. from ._memory_viz import memory as _memory, segments as _segments
  23. if TYPE_CHECKING:
  24. from torch.types import Device
# Names exported by this module; this list is what a star-import of the
# module exposes.
__all__ = [
    "caching_allocator_alloc",
    "caching_allocator_delete",
    "caching_allocator_enable",
    "get_per_process_memory_fraction",
    "set_per_process_memory_fraction",
    "empty_cache",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "reset_max_memory_allocated",
    "reset_max_memory_cached",
    "host_memory_stats",
    "host_memory_stats_as_nested_dict",
    "reset_accumulated_host_memory_stats",
    "reset_peak_host_memory_stats",
    "memory_allocated",
    "max_memory_allocated",
    "memory_reserved",
    "max_memory_reserved",
    "memory_cached",
    "max_memory_cached",
    "memory_snapshot",
    "memory_summary",
    "list_gpu_processes",
    "mem_get_info",
    "get_allocator_backend",
    "CUDAPluggableAllocator",
    "change_current_allocator",
    "MemPool",
    "use_mem_pool",
]
# If the loaded ``torch._C`` extension does not expose the CUDA allocator
# bindings, register placeholder objects under the same names so that the
# ``from torch._C import ...`` statement that follows still resolves.
# NOTE(review): presumably this is the case for builds without CUDA support
# -- confirm against the build configuration.
if not hasattr(torch._C, "_cuda_CUDAAllocator"):
    # Define dummy base classes
    torch._C.__dict__["_cuda_CUDAAllocator"] = _dummy_type("_cuda_CUDAAllocator")

# ``_MemPool`` is the only name probed here; when it is missing, every
# mem-pool related binding below is replaced with a placeholder as a group.
if not hasattr(torch._C, "_MemPool"):
    # Define dummy base classes
    torch._C.__dict__["_MemPool"] = _dummy_type("_MemPool")
    torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
        "_cuda_beginAllocateToPool"
    )
    torch._C.__dict__["_cuda_beginAllocateCurrentThreadToPool"] = _dummy_type(
        "_cuda_beginAllocateCurrentThreadToPool"
    )
    torch._C.__dict__["_cuda_endAllocateToPool"] = _dummy_type(
        "_cuda_endAllocateToPool"
    )
    torch._C.__dict__["_cuda_releasePool"] = _dummy_type("_cuda_releasePool")
  74. from torch._C import ( # noqa: F401
  75. _cuda_beginAllocateCurrentThreadToPool,
  76. _cuda_beginAllocateToPool,
  77. _cuda_CUDAAllocator,
  78. _cuda_endAllocateToPool,
  79. _cuda_releasePool,
  80. _MemPool,
  81. )
  82. def _host_allocator():
  83. _lazy_init()
  84. return torch._C._cuda_cudaHostAllocator()
  85. @contextlib.contextmanager
  86. def _free_mutex():
  87. torch._C._cuda_lock_mutex()
  88. try:
  89. yield
  90. finally:
  91. torch._C._cuda_unlock_mutex()
  92. def caching_allocator_alloc(size, device: "Device" = None, stream=None):
  93. r"""Perform a memory allocation using the CUDA memory allocator.
  94. Memory is allocated for a given device and a stream, this
  95. function is intended to be used for interoperability with other
  96. frameworks. Allocated memory is released through
  97. :func:`~torch.cuda.caching_allocator_delete`.
  98. Args:
  99. size (int): number of bytes to be allocated.
  100. device (torch.device or int, optional): selected device. If it is
  101. ``None`` the default CUDA device is used.
  102. stream (torch.cuda.Stream or int, optional): selected stream. If is ``None`` then
  103. the default stream for the selected device is used.
  104. .. note::
  105. See :ref:`cuda-memory-management` for more details about GPU memory
  106. management.
  107. """
  108. if device is None:
  109. device = torch.cuda.current_device()
  110. device = _get_device_index(device)
  111. if stream is None:
  112. stream = torch.cuda.current_stream(device)
  113. if isinstance(stream, torch.cuda.streams.Stream):
  114. stream = stream.cuda_stream
  115. if not isinstance(stream, int):
  116. raise TypeError(
  117. "Invalid type for stream argument, must be "
  118. "`torch.cuda.Stream` or `int` representing a pointer "
  119. "to a existing stream"
  120. )
  121. with torch.cuda.device(device):
  122. return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream)
  123. def caching_allocator_delete(mem_ptr):
  124. r"""Delete memory allocated using the CUDA memory allocator.
  125. Memory allocated with :func:`~torch.cuda.caching_allocator_alloc`.
  126. is freed here. The associated device and stream are tracked inside
  127. the allocator.
  128. Args:
  129. mem_ptr (int): memory address to be freed by the allocator.
  130. .. note::
  131. See :ref:`cuda-memory-management` for more details about GPU memory
  132. management.
  133. """
  134. torch._C._cuda_cudaCachingAllocator_raw_delete(mem_ptr)
  135. def caching_allocator_enable(value: bool = True) -> None:
  136. r"""Enable or disable the CUDA memory allocator. On by default."""
  137. if is_initialized():
  138. torch._C._cuda_cudaCachingAllocator_enable(value)
  139. def set_per_process_memory_fraction(fraction, device: "Device" = None) -> None:
  140. r"""Set memory fraction for a process.
  141. The fraction is used to limit an caching allocator to allocated memory on a CUDA device.
  142. The allowed value equals the total visible memory multiplied fraction.
  143. If trying to allocate more than the allowed value in a process, will raise an out of
  144. memory error in allocator.
  145. Args:
  146. fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
  147. device (torch.device or int, optional): selected device. If it is
  148. ``None`` the default CUDA device is used.
  149. .. note::
  150. In general, the total available free memory is less than the total capacity.
  151. """
  152. _lazy_init()
  153. if device is None:
  154. device = torch.cuda.current_device()
  155. device = _get_device_index(device)
  156. if not isinstance(fraction, float):
  157. raise TypeError("Invalid type for fraction argument, must be `float`")
  158. if fraction < 0 or fraction > 1:
  159. raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~1")
  160. torch._C._cuda_setMemoryFraction(fraction, device)
  161. def get_per_process_memory_fraction(device: "Device" = None) -> float:
  162. r"""Get memory fraction for a process.
  163. Args:
  164. device (torch.device or int, optional): selected device. If it is
  165. ``None`` the default CUDA device is used.
  166. Returns:
  167. memory fraction, in range 0~1. Allowed memory equals total_memory * fraction.
  168. """
  169. _lazy_init()
  170. if device is None:
  171. device = torch.cuda.current_device()
  172. device = _get_device_index(device)
  173. return torch._C._cuda_getMemoryFraction(device)
  174. def empty_cache() -> None:
  175. r"""Release all unoccupied cached memory currently held by the caching
  176. allocator so that those can be used in other GPU application and visible in
  177. `nvidia-smi`.
  178. .. note::
  179. :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
  180. memory available for PyTorch. However, it may help reduce fragmentation
  181. of GPU memory in certain cases. See :ref:`cuda-memory-management` for
  182. more details about GPU memory management.
  183. """
  184. if is_initialized():
  185. torch._C._cuda_emptyCache()
  186. def memory_stats(device: "Device" = None) -> dict[str, Any]:
  187. r"""Return a dictionary of CUDA memory allocator statistics for a given device.
  188. The return value of this function is a dictionary of statistics, each of
  189. which is a non-negative integer.
  190. Core statistics:
  191. - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  192. number of allocation requests received by the memory allocator.
  193. - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  194. amount of allocated memory.
  195. - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  196. number of reserved segments from ``cudaMalloc()``.
  197. - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  198. amount of reserved memory.
  199. - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  200. number of active memory blocks.
  201. - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  202. amount of active memory.
  203. - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  204. number of inactive, non-releasable memory blocks.
  205. - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  206. amount of inactive, non-releasable memory.
  207. For these core statistics, values are broken down as follows.
  208. Pool type:
  209. - ``all``: combined statistics across all memory pools.
  210. - ``large_pool``: statistics for the large allocation pool
  211. (as of June 2025, for size >= 1MB allocations).
  212. - ``small_pool``: statistics for the small allocation pool
  213. (as of June 2025, for size < 1MB allocations).
  214. Metric type:
  215. - ``current``: current value of this metric.
  216. - ``peak``: maximum value of this metric.
  217. - ``allocated``: historical total increase in this metric.
  218. - ``freed``: historical total decrease in this metric.
  219. In addition to the core statistics, we also provide some simple event
  220. counters:
  221. - ``"num_alloc_retries"``: number of failed ``cudaMalloc`` calls that
  222. result in a cache flush and retry.
  223. - ``"num_ooms"``: number of out-of-memory errors thrown.
  224. - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls.
  225. - ``"num_device_alloc"``: number of CUDA allocation calls. This includes both
  226. cuMemMap and cudaMalloc.
  227. - ``"num_device_free"``: number of CUDA free calls. This includes both cuMemUnmap
  228. and cudaFree.
  229. The caching allocator can be configured via ENV to not split blocks larger than a
  230. defined size (see Memory Management section of the Cuda Semantics documentation).
  231. This helps avoid memory fragmentation but may have a performance
  232. penalty. Additional outputs to assist with tuning and evaluating impact:
  233. - ``"max_split_size"``: blocks above this size will not be split.
  234. - ``"oversize_allocations.{current,peak,allocated,freed}"``:
  235. number of over-size allocation requests received by the memory allocator.
  236. - ``"oversize_segments.{current,peak,allocated,freed}"``:
  237. number of over-size reserved segments from ``cudaMalloc()``.
  238. The caching allocator can be configured via ENV to round memory allocations in order
  239. to reduce fragmentation. Sometimes the overhead from rounding can be higher than
  240. the fragmentation it helps reduce. The following stat can be used to check if
  241. rounding adds too much overhead:
  242. - ``"requested_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  243. memory requested by client code, compare this with allocated_bytes to check if
  244. allocation rounding adds too much overhead.
  245. Args:
  246. device (torch.device or int, optional): selected device. Returns
  247. statistics for the current device, given by :func:`~torch.cuda.current_device`,
  248. if :attr:`device` is ``None`` (default).
  249. .. note::
  250. See :ref:`cuda-memory-management` for more details about GPU memory
  251. management.
  252. .. note::
  253. With :ref:`backend:cudaMallocAsync<cuda-memory-envvars>`, some stats are not
  254. meaningful, and are always reported as zero.
  255. """
  256. result = []
  257. def _recurse_add_to_result(prefix, obj):
  258. if isinstance(obj, dict):
  259. if len(prefix) > 0:
  260. prefix += "."
  261. for k, v in obj.items():
  262. _recurse_add_to_result(prefix + k, v)
  263. else:
  264. result.append((prefix, obj))
  265. stats = memory_stats_as_nested_dict(device=device)
  266. _recurse_add_to_result("", stats)
  267. result.sort()
  268. return collections.OrderedDict(result)
  269. def memory_stats_as_nested_dict(device: "Device" = None) -> dict[str, Any]:
  270. r"""Return the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
  271. if not is_initialized():
  272. return {}
  273. device = _get_device_index(device, optional=True)
  274. return torch._C._cuda_memoryStats(device)
  275. def reset_accumulated_memory_stats(device: "Device" = None) -> None:
  276. r"""Reset the "accumulated" (historical) stats tracked by the CUDA memory allocator.
  277. See :func:`~torch.cuda.memory_stats` for details. Accumulated stats correspond to
  278. the `"allocated"` and `"freed"` keys in each individual stat dict, as well as
  279. `"num_alloc_retries"` and `"num_ooms"`.
  280. Args:
  281. device (torch.device or int, optional): selected device. Returns
  282. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  283. if :attr:`device` is ``None`` (default).
  284. .. note::
  285. See :ref:`cuda-memory-management` for more details about GPU memory
  286. management.
  287. """
  288. device = _get_device_index(device, optional=True)
  289. return torch._C._cuda_resetAccumulatedMemoryStats(device)
  290. def reset_peak_memory_stats(device: "Device" = None) -> None:
  291. r"""Reset the "peak" stats tracked by the CUDA memory allocator.
  292. See :func:`~torch.cuda.memory_stats` for details. Peak stats correspond to the
  293. `"peak"` key in each individual stat dict.
  294. Args:
  295. device (torch.device or int, optional): selected device. Returns
  296. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  297. if :attr:`device` is ``None`` (default).
  298. .. note::
  299. See :ref:`cuda-memory-management` for more details about GPU memory
  300. management.
  301. """
  302. device = _get_device_index(device, optional=True)
  303. return torch._C._cuda_resetPeakMemoryStats(device)
  304. def host_memory_stats() -> dict[str, Any]:
  305. r"""Return a dictionary of CUDA memory allocator statistics for a given device.
  306. The return value of this function is a dictionary of statistics, each of
  307. which is a non-negative integer.
  308. Core statistics:
  309. - ``"allocated.{current,peak,allocated,freed}"``:
  310. number of allocation requests received by the memory allocator.
  311. - ``"allocated_bytes.{current,peak,allocated,freed}"``:
  312. amount of allocated memory.
  313. - ``"segment.{current,peak,allocated,freed}"``:
  314. number of reserved segments from ``cudaMalloc()``.
  315. - ``"reserved_bytes.{current,peak,allocated,freed}"``:
  316. amount of reserved memory.
  317. For these core statistics, values are broken down as follows.
  318. Metric type:
  319. - ``current``: current value of this metric.
  320. - ``peak``: maximum value of this metric.
  321. - ``allocated``: historical total increase in this metric.
  322. - ``freed``: historical total decrease in this metric.
  323. In addition to the core statistics, we also provide some simple event
  324. counters:
  325. - ``"num_host_alloc"``: number of CUDA allocation calls. This includes both
  326. cudaHostAlloc and cudaHostRegister.
  327. - ``"num_host_free"``: number of CUDA free calls. This includes both cudaHostFree
  328. and cudaHostUnregister.
  329. Finally, we also provide some simple timing counters:
  330. - ``"host_alloc_time.{total,max,min,count,avg}"``:
  331. timing of allocation requests going through CUDA calls.
  332. - ``"host_free_time.{total,max,min,count,avg}"``:
  333. timing of free requests going through CUDA calls.
  334. For these timing statistics, values are broken down as follows.
  335. Metric type:
  336. - ``total``: total time spent.
  337. - ``max``: maximum value per call.
  338. - ``min``: minimum value per call.
  339. - ``count``: number of times it was called.
  340. - ``avg``: average time per call.
  341. """
  342. result = []
  343. def _recurse_add_to_result(prefix, obj):
  344. if isinstance(obj, dict):
  345. if len(prefix) > 0:
  346. prefix += "."
  347. for k, v in obj.items():
  348. _recurse_add_to_result(prefix + k, v)
  349. else:
  350. result.append((prefix, obj))
  351. stats = host_memory_stats_as_nested_dict()
  352. _recurse_add_to_result("", stats)
  353. result.sort()
  354. return collections.OrderedDict(result)
  355. def host_memory_stats_as_nested_dict() -> dict[str, Any]:
  356. r"""Return the result of :func:`~torch.cuda.host_memory_stats` as a nested dictionary."""
  357. if not is_initialized():
  358. return {}
  359. return torch._C._cuda_hostMemoryStats()
  360. def reset_accumulated_host_memory_stats() -> None:
  361. r"""Reset the "accumulated" (historical) stats tracked by the host memory allocator.
  362. See :func:`~torch.cuda.host_memory_stats` for details. Accumulated stats correspond to
  363. the `"allocated"` and `"freed"` keys in each individual stat dict.
  364. """
  365. return torch._C._cuda_resetAccumulatedHostMemoryStats()
  366. def reset_peak_host_memory_stats() -> None:
  367. r"""Reset the "peak" stats tracked by the host memory allocator.
  368. See :func:`~torch.cuda.host_memory_stats` for details. Peak stats correspond to the
  369. `"peak"` key in each individual stat dict.
  370. """
  371. return torch._C._cuda_resetPeakHostMemoryStats()
  372. def reset_max_memory_allocated(device: "Device" = None) -> None:
  373. r"""Reset the starting point in tracking maximum GPU memory occupied by tensors for a given device.
  374. See :func:`~torch.cuda.max_memory_allocated` for details.
  375. Args:
  376. device (torch.device or int, optional): selected device. Returns
  377. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  378. if :attr:`device` is ``None`` (default).
  379. .. warning::
  380. This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
  381. /all/ peak memory stats.
  382. .. note::
  383. See :ref:`cuda-memory-management` for more details about GPU memory
  384. management.
  385. """
  386. warnings.warn(
  387. "torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, "
  388. "which resets /all/ peak memory stats.",
  389. FutureWarning,
  390. )
  391. return reset_peak_memory_stats(device=device)
  392. def reset_max_memory_cached(device: "Device" = None) -> None:
  393. r"""Reset the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
  394. See :func:`~torch.cuda.max_memory_cached` for details.
  395. Args:
  396. device (torch.device or int, optional): selected device. Returns
  397. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  398. if :attr:`device` is ``None`` (default).
  399. .. warning::
  400. This function now calls :func:`~torch.cuda.reset_peak_memory_stats`, which resets
  401. /all/ peak memory stats.
  402. .. note::
  403. See :ref:`cuda-memory-management` for more details about GPU memory
  404. management.
  405. """
  406. warnings.warn(
  407. "torch.cuda.reset_max_memory_cached now calls torch.cuda.reset_peak_memory_stats, "
  408. "which resets /all/ peak memory stats.",
  409. FutureWarning,
  410. )
  411. return reset_peak_memory_stats(device=device)
  412. def memory_allocated(device: "Device" = None) -> int:
  413. r"""Return the current GPU memory occupied by tensors in bytes for a given device.
  414. Args:
  415. device (torch.device or int, optional): selected device. Returns
  416. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  417. if :attr:`device` is ``None`` (default).
  418. .. note::
  419. This is likely less than the amount shown in `nvidia-smi` since some
  420. unused memory can be held by the caching allocator and some context
  421. needs to be created on GPU. See :ref:`cuda-memory-management` for more
  422. details about GPU memory management.
  423. """
  424. return memory_stats(device=device).get("allocated_bytes.all.current", 0)
  425. def max_memory_allocated(device: "Device" = None) -> int:
  426. r"""Return the maximum GPU memory occupied by tensors in bytes for a given device.
  427. By default, this returns the peak allocated memory since the beginning of
  428. this program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to
  429. reset the starting point in tracking this metric. For example, these two
  430. functions can measure the peak allocated memory usage of each iteration in a
  431. training loop.
  432. Args:
  433. device (torch.device or int, optional): selected device. Returns
  434. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  435. if :attr:`device` is ``None`` (default).
  436. .. note::
  437. See :ref:`cuda-memory-management` for more details about GPU memory
  438. management.
  439. """
  440. return memory_stats(device=device).get("allocated_bytes.all.peak", 0)
  441. def memory_reserved(device: "Device" = None) -> int:
  442. r"""Return the current GPU memory managed by the caching allocator in bytes for a given device.
  443. Args:
  444. device (torch.device or int, optional): selected device. Returns
  445. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  446. if :attr:`device` is ``None`` (default).
  447. .. note::
  448. See :ref:`cuda-memory-management` for more details about GPU memory
  449. management.
  450. """
  451. return memory_stats(device=device).get("reserved_bytes.all.current", 0)
  452. def max_memory_reserved(device: "Device" = None) -> int:
  453. r"""Return the maximum GPU memory managed by the caching allocator in bytes for a given device.
  454. By default, this returns the peak cached memory since the beginning of this
  455. program. :func:`~torch.cuda.reset_peak_memory_stats` can be used to reset
  456. the starting point in tracking this metric. For example, these two functions
  457. can measure the peak cached memory amount of each iteration in a training
  458. loop.
  459. Args:
  460. device (torch.device or int, optional): selected device. Returns
  461. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  462. if :attr:`device` is ``None`` (default).
  463. .. note::
  464. See :ref:`cuda-memory-management` for more details about GPU memory
  465. management.
  466. """
  467. return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
  468. @deprecated(
  469. "`torch.cuda.memory_cached` has been renamed to `torch.cuda.memory_reserved`",
  470. category=FutureWarning,
  471. )
  472. def memory_cached(device: "Device" = None) -> int:
  473. r"""Deprecated; see :func:`~torch.cuda.memory_reserved`."""
  474. return memory_reserved(device=device)
  475. @deprecated(
  476. "`torch.cuda.max_memory_cached` has been renamed to `torch.cuda.max_memory_reserved`",
  477. category=FutureWarning,
  478. )
  479. def max_memory_cached(device: "Device" = None) -> int:
  480. r"""Deprecated; see :func:`~torch.cuda.max_memory_reserved`."""
  481. return max_memory_reserved(device=device)
  482. def memory_snapshot(mempool_id=None):
  483. r"""Return a snapshot of the CUDA memory allocator state across all devices.
  484. Interpreting the output of this function requires familiarity with the
  485. memory allocator internals.
  486. .. note::
  487. See :ref:`cuda-memory-management` for more details about GPU memory
  488. management.
  489. """
  490. return torch._C._cuda_memorySnapshot(mempool_id)["segments"]
def memory_summary(device: "Device" = None, abbreviated: bool = False) -> str:
    r"""Return a human-readable printout of the current memory allocator statistics for a given device.

    This can be useful to display periodically during training, or when
    handling out-of-memory exceptions.

    The printout is a text table built from the flat dictionary returned by
    :func:`~torch.cuda.memory_stats`: one group of rows per metric, each row
    showing the ``current``, ``peak``, ``allocated`` and ``freed`` values.

    Args:
        device (torch.device or int, optional): selected device. Returns
            printout for the current device, given by :func:`~torch.cuda.current_device`,
            if :attr:`device` is ``None`` (default).
        abbreviated (bool, optional): whether to return an abbreviated summary
            (default: False). When ``False``, each metric additionally gets
            ``large_pool`` and ``small_pool`` rows.

    .. note::
        See :ref:`cuda-memory-management` for more details about GPU memory
        management.
    """
    device = _get_device_index(device, optional=True)
    stats = memory_stats(device=device)
    # NOTE(review): the lookups below index ``stats`` directly, so a missing
    # key raises KeyError. memory_stats_as_nested_dict in this module returns
    # an empty dict when CUDA is not initialized -- confirm whether callers
    # are expected to initialize CUDA first.

    def _format_size(sz, pref_sz):
        # Render ``sz`` with a binary unit. The unit is chosen from
        # ``pref_sz`` (a reference value) rather than from ``sz`` itself, so
        # that several rows can share one unit.
        prefixes = ["B ", "KiB", "MiB", "GiB", "TiB", "PiB"]
        prefix = prefixes[0]
        for new_prefix in prefixes[1:]:
            # Stop scaling once the reference value drops below 768 * 1024.
            if pref_sz < 768 * 1024:
                break
            prefix = new_prefix
            sz //= 1024
            pref_sz /= 1024
        return f"{sz:6d} {prefix}"

    def _format_count(cnt, pref_cnt):
        # Same scheme as _format_size, but with decimal K/M suffixes.
        prefixes = [" ", "K", "M"]
        prefix = prefixes[0]
        for new_prefix in prefixes[1:]:
            if pref_cnt < 750 * 1000:
                break
            prefix = new_prefix
            cnt //= 1000
            pref_cnt /= 1000
        return f"{cnt:7d} {prefix} "

    # (stats key prefix, row label, formatter) for metrics that are broken
    # down per pool.
    metrics_to_display = [
        ("allocated_bytes", "Allocated memory", _format_size),
        ("active_bytes", "Active memory", _format_size),
        ("requested_bytes", "Requested memory", _format_size),
        ("reserved_bytes", "GPU reserved memory", _format_size),
        ("inactive_split_bytes", "Non-releasable memory", _format_size),
        ("allocation", "Allocations", _format_count),
        ("active", "Active allocs", _format_count),
        ("segment", "GPU reserved segments", _format_count),
        ("inactive_split", "Non-releasable allocs", _format_count),
    ]

    # Header rows are templates containing ``{...}`` placeholders; they are
    # filled in by the single ``str.format`` call at the end.
    # NOTE(review): the padding inside these literals determines the column
    # alignment of the rendered table -- confirm it matches the intended layout.
    lines = []
    lines.append("=" * 75)
    lines.append(" {_:16} PyTorch CUDA memory summary, device ID {device:<17d} ")
    lines.append("-" * 75)
    lines.append(
        " {_:9} CUDA OOMs: {num_ooms:<12d} | {_:6} cudaMalloc retries: {num_alloc_retries:<8d} "
    )
    lines.append("=" * 75)
    lines.append(
        " Metric | Cur Usage | Peak Usage | Tot Alloc | Tot Freed "
    )

    for metric_key, metric_name, formatter in metrics_to_display:
        lines.append("-" * 75)
        submetrics = [("all", metric_name)]
        if not abbreviated:
            submetrics.append(("large_pool", " from large pool"))
            submetrics.append(("small_pool", " from small pool"))

        # Reference values for unit selection; taken from the first submetric
        # ("all") so every row of this metric is rendered in the same unit.
        current_prefval, peak_prefval, allocated_prefval, freed_prefval = (
            None,
            None,
            None,
            None,
        )

        for submetric_key, submetric_name in submetrics:
            prefix = metric_key + "." + submetric_key + "."

            current = stats[prefix + "current"]
            peak = stats[prefix + "peak"]
            allocated = stats[prefix + "allocated"]
            freed = stats[prefix + "freed"]

            if current_prefval is None:
                current_prefval = current
                peak_prefval = peak
                allocated_prefval = allocated
                freed_prefval = freed

            lines.append(
                f" {submetric_name:<21} | {formatter(current, current_prefval)} | {formatter(peak, peak_prefval)} | "
                f"{formatter(allocated, allocated_prefval)} | {formatter(freed, freed_prefval)} ",
            )

    # Metrics without a per-pool breakdown: keys are "<metric>.<field>".
    metrics_to_display = [
        ("oversize_allocations", "Oversize allocations", _format_count),
        ("oversize_segments", "Oversize GPU segments", _format_count),
    ]

    for metric_key, metric_name, formatter in metrics_to_display:
        lines.append("-" * 75)

        prefix = metric_key + "."

        current = stats[prefix + "current"]
        peak = stats[prefix + "peak"]
        allocated = stats[prefix + "allocated"]
        freed = stats[prefix + "freed"]

        # Each value serves as its own unit reference here.
        lines.append(
            f" {metric_name:<21} | {formatter(current, current)} | {formatter(peak, peak)} | "
            f"{formatter(allocated, allocated)} | {formatter(freed, freed)} ",
        )

    lines.append("=" * 75)

    # Build the substitution table for the header templates. Dots in stat
    # names are replaced with dashes because ``str.format`` would treat a dot
    # in a field name as attribute access.
    fmt_dict = {"_": "", "device": device}
    for k, v in stats.items():
        fmt_dict[k.replace(".", "-")] = v
    # Every row is wrapped in "|" borders by the join and the outer concatenation.
    return "|" + "|\n|".join(lines).format(**fmt_dict) + "|\n"
  596. def list_gpu_processes(device: "Device" = None) -> str:
  597. r"""Return a human-readable printout of the running processes and their GPU memory use for a given device.
  598. This can be useful to display periodically during training, or when
  599. handling out-of-memory exceptions.
  600. Args:
  601. device (torch.device or int, optional): selected device. Returns
  602. printout for the current device, given by :func:`~torch.cuda.current_device`,
  603. if :attr:`device` is ``None`` (default).
  604. """
  605. if not torch.version.hip:
  606. try:
  607. import pynvml # type: ignore[import]
  608. except ModuleNotFoundError:
  609. return "pynvml module not found, please install pynvml"
  610. from pynvml import NVMLError_DriverNotLoaded
  611. try:
  612. pynvml.nvmlInit()
  613. except NVMLError_DriverNotLoaded:
  614. return "cuda driver can't be loaded, is cuda enabled?"
  615. device = _get_nvml_device_index(device)
  616. handle = pynvml.nvmlDeviceGetHandleByIndex(device)
  617. procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
  618. else:
  619. try:
  620. import amdsmi # type: ignore[import]
  621. except ModuleNotFoundError:
  622. return "amdsmi module not found, please install amdsmi"
  623. try:
  624. amdsmi.amdsmi_init() # type: ignore[attr-defined]
  625. except amdsmi.AmdSmiException: # type: ignore[attr-defined]
  626. return "amdsmi driver can't be loaded, is ROCm installed?"
  627. device = _get_amdsmi_device_index(device)
  628. try:
  629. handle = amdsmi.amdsmi_get_processor_handles()[device] # type: ignore[attr-defined]
  630. procs = amdsmi.amdsmi_get_gpu_process_list(handle) # type: ignore[attr-defined]
  631. except amdsmi.AmdSmiException: # type: ignore[attr-defined]
  632. return "amdsmi cannot list processes from other users"
  633. lines = []
  634. lines.append(f"GPU:{device}")
  635. if len(procs) == 0:
  636. lines.append("no processes are running")
  637. for p in procs:
  638. if not torch.version.hip:
  639. mem = p.usedGpuMemory / (1024 * 1024)
  640. pid = p.pid
  641. else:
  642. try:
  643. proc_info = amdsmi.amdsmi_get_gpu_process_info(handle, p) # type: ignore[possibly-undefined]
  644. except AttributeError:
  645. # https://github.com/ROCm/amdsmi/commit/c551c3caedbd903ba828e7fdffa5b56d475a15e7
  646. # is a BC-breaking change that removes amdsmi_get_gpu_process_info API from amdsmi
  647. proc_info = p
  648. mem = proc_info["memory_usage"]["vram_mem"] / (1024 * 1024)
  649. pid = proc_info["pid"]
  650. lines.append(f"process {pid:>10d} uses {mem:>12.3f} MB GPU memory")
  651. return "\n".join(lines)
  652. def mem_get_info(device: "Device" = None) -> tuple[int, int]:
  653. r"""Return the global free and total GPU memory for a given device using cudaMemGetInfo.
  654. Args:
  655. device (torch.device or int or str, optional): selected device. Returns
  656. statistic for the current device, given by :func:`~torch.cuda.current_device`,
  657. if :attr:`device` is ``None`` (default) or if the device index is not specified.
  658. .. note::
  659. See :ref:`cuda-memory-management` for more
  660. details about GPU memory management.
  661. """
  662. if device is None:
  663. device = torch.cuda.current_device()
  664. # optional=True allows `device = torch.device('cuda')` for which device.index is None
  665. device = _get_device_index(device, optional=True)
  666. return torch.cuda.cudart().cudaMemGetInfo(device)
  667. def _record_memory_history_legacy(
  668. enabled: bool,
  669. record_context=True,
  670. trace_alloc_max_entries=1,
  671. trace_alloc_record_context=False,
  672. device: "Device" = None,
  673. record_context_cpp=False,
  674. clear_history=False,
  675. compile_context=False,
  676. global_record_annotations=False,
  677. ):
  678. _C._cuda_record_memory_history_legacy( # type: ignore[call-arg]
  679. enabled,
  680. record_context,
  681. trace_alloc_max_entries,
  682. trace_alloc_record_context,
  683. record_context_cpp,
  684. clear_history,
  685. compile_context,
  686. global_record_annotations,
  687. )
  688. def _record_memory_history(
  689. enabled: Literal[None, "state", "all"] = "all", *args, **kwargs
  690. ) -> None:
  691. """Enable recording of stack traces associated with memory
  692. allocations, so you can tell what allocated any piece of memory in
  693. :func:`torch.cuda.memory._snapshot()`.
  694. In addition to keeping stack traces with each current allocation and free,
  695. this will also enable recording of a history of all alloc/free events.
  696. Use :func:`torch.cuda.memory._snapshot()` to retrieve this information,
  697. and the tools in `_memory_viz.py` to visualize snapshots.
  698. Buffer behavior
  699. ---------------
  700. This will store up to `max_entries` instances of `TraceEntry` when enabled.
  701. Python trace collection defaults to `sys.maxsize`, meaning long-running
  702. or indefinitely running jobs should set a reasonable limit to avoid excessive
  703. memory use. Expect each entry to be several KB.
  704. Longer running workflows or those with smaller `max_entries` values will only
  705. store the last accumulated `max_entries` entries, meaning new entries overwrite
  706. older entries.
  707. C++ implementation for reference to ring buffer implementation:
  708. .. code-block:: cpp
  709. if (record_history) {
  710. if (alloc_trace->size() < alloc_trace_max_entries_) {
  711. alloc_trace->emplace_back(te);
  712. } else {
  713. (*alloc_trace)[alloc_trace_next++] = te;
  714. if (alloc_trace_next == alloc_trace_max_entries_) {
  715. alloc_trace_next = 0;
  716. }
  717. }
  718. }
  719. Latency impact
  720. --------------
  721. The Python trace collection is fast (2us per trace), so you may consider
  722. enabling this on production jobs if you anticipate ever having to debug
  723. memory issues.
  724. C++ trace collection is also fast (~50ns/frame), which for many typical programs
  725. works out to ~2us per trace, but can vary depending on stack depth.
  726. Args:
  727. enabled (Literal[None, "state", "all"], optional):
  728. `None`, disable recording memory history.
  729. `"state"`, keep information for currently allocated memory.
  730. `"all"`, additionally keep a history of all alloc/free calls.
  731. Defaults to "all".
  732. context (Literal[None, "state", "alloc", "all"], optional):
  733. `None`, Do not record any tracebacks.
  734. `"state"`, Record tracebacks for currently allocated memory.
  735. `"alloc"`, additionally keep tracebacks for alloc calls.
  736. `"all"`, additionally keep tracebacks for free calls.
  737. Defaults to "all".
  738. stacks (Literal["python", "all"], optional):
  739. `"python"`, include Python, TorchScript, and inductor frames in tracebacks
  740. `"all"`, additionally include C++ frames
  741. Defaults to "all".
  742. max_entries (int, optional): Keep a maximum of `max_entries`
  743. alloc/free events in the recorded history recorded.
  744. """
  745. if isinstance(enabled, bool):
  746. return _record_memory_history_legacy(enabled, *args, **kwargs)
  747. else:
  748. return _record_memory_history_impl(enabled, *args, **kwargs)
  749. def _record_memory_history_impl(
  750. enabled: Optional[str] = "all",
  751. context: Optional[str] = "all",
  752. stacks: str = "all",
  753. max_entries: int = sys.maxsize,
  754. device: "Device" = None,
  755. clear_history: bool = False,
  756. compile_context: bool = False,
  757. global_record_annotations: bool = False,
  758. ):
  759. _C._cuda_record_memory_history( # type: ignore[call-arg]
  760. enabled,
  761. context,
  762. stacks,
  763. max_entries,
  764. clear_history,
  765. compile_context,
  766. global_record_annotations,
  767. )
  768. _record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]
  769. def _snapshot(device: "Device" = None):
  770. """Save a snapshot of CUDA memory state at the time it was called.
  771. The state is represented as a dictionary with the following structure.
  772. .. code-block:: python
  773. class Snapshot(TypedDict):
  774. segments: List[Segment]
  775. device_traces: List[List[TraceEntry]]
  776. class Segment(TypedDict):
  777. # Segments are memory returned from a cudaMalloc call.
  778. # The size of reserved memory is the sum of all Segments.
  779. # Segments are cached and reused for future allocations.
  780. # If the reuse is smaller than the segment, the segment
  781. # is split into more then one Block.
  782. # empty_cache() frees Segments that are entirely inactive.
  783. address: int
  784. total_size: int # cudaMalloc'd size of segment
  785. stream: int
  786. segment_type: Literal["small", "large"] # 'large' (>1MB)
  787. allocated_size: int # size of memory in use
  788. active_size: int # size of memory in use or in active_awaiting_free state
  789. blocks: List[Block]
  790. class Block(TypedDict):
  791. # A piece of memory returned from the allocator, or
  792. # current cached but inactive.
  793. size: int
  794. requested_size: int # size requested during malloc, may be smaller than
  795. # size due to rounding
  796. address: int
  797. state: Literal[
  798. "active_allocated", # used by a tensor
  799. "active_awaiting_free", # waiting for another stream to finish using
  800. # this, then it will become free
  801. "inactive",
  802. ] # free for reuse
  803. frames: List[Frame] # stack trace from where the allocation occurred
  804. class Frame(TypedDict):
  805. filename: str
  806. line: int
  807. name: str
  808. class TraceEntry(TypedDict):
  809. # When `torch.cuda.memory._record_memory_history()` is enabled,
  810. # the snapshot will contain TraceEntry objects that record each
  811. # action the allocator took.
  812. action: Literal[
  813. "alloc" # memory allocated
  814. "free_requested", # the allocated received a call to free memory
  815. "free_completed", # the memory that was requested to be freed is now
  816. # able to be used in future allocation calls
  817. "segment_alloc", # the caching allocator ask cudaMalloc for more memory
  818. # and added it as a segment in its cache
  819. "segment_free", # the caching allocator called cudaFree to return memory
  820. # to cuda possibly trying free up memory to
  821. # allocate more segments or because empty_caches was called
  822. "oom", # the allocator threw an OOM exception. 'size' is
  823. # the requested number of bytes that did not succeed
  824. "snapshot", # the allocator generated a memory snapshot
  825. # useful to coorelate a previously taken
  826. # snapshot with this trace
  827. ]
  828. addr: int # not present for OOM
  829. frames: List[Frame]
  830. size: int
  831. stream: int
  832. device_free: int # only present for OOM, the amount of
  833. # memory cuda still reports to be free
  834. Returns:
  835. The Snapshot dictionary object
  836. """
  837. return _C._cuda_memorySnapshot(None)
  838. def _dump_snapshot(filename="dump_snapshot.pickle"):
  839. """
  840. Save a pickled version of the `torch.memory._snapshot()` dictionary to a file.
  841. This file can be opened by the interactive snapshot viewer at pytorch.org/memory_viz
  842. Snapshot file sizes scale with `max_entries` and stack trace depth per entry,
  843. with several KB per entry. These can easily be in the GB range for longer running
  844. workflows with large `max_entries`.
  845. Args:
  846. filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle".
  847. """
  848. s = _snapshot()
  849. with open(filename, "wb") as f:
  850. pickle.dump(s, f)
  851. def _save_segment_usage(filename="output.svg", snapshot=None):
  852. if snapshot is None:
  853. snapshot = _snapshot()
  854. with open(filename, "w") as f:
  855. f.write(_segments(snapshot))
  856. def _save_memory_usage(filename="output.svg", snapshot=None):
  857. if snapshot is None:
  858. snapshot = _snapshot()
  859. with open(filename, "w") as f:
  860. f.write(_memory(snapshot))
  861. def _set_allocator_settings(env: str):
  862. return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
  863. def get_allocator_backend() -> str:
  864. r"""Return a string describing the active allocator backend as set by
  865. ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
  866. ``native`` (PyTorch's native caching allocator) and `cudaMallocAsync``
  867. (CUDA's built-in asynchronous allocator).
  868. .. note::
  869. See :ref:`cuda-memory-management` for details on choosing the allocator backend.
  870. """
  871. return torch._C._cuda_getAllocatorBackend()
  872. class _CUDAAllocator:
  873. r"""Wrapper over internal CUDA memory allocators."""
  874. def __init__(self, allocator: torch._C._cuda_CUDAAllocator):
  875. self._allocator = allocator
  876. def allocator(self):
  877. return self._allocator
  878. class CUDAPluggableAllocator(_CUDAAllocator):
  879. r"""CUDA memory allocator loaded from a so file."""
  880. def __init__(self, path_to_so_file: str, alloc_fn_name: str, free_fn_name: str):
  881. r"""Memory allocators are compiled in .so files and loaded dynamically using ctypes.
  882. To change the active allocator use the :func:`torch.memory.cuda.change_current_allocator` function.
  883. Args:
  884. path_to_so_file(str): Path in the filesystem to the `.so` file containing
  885. the allocator functions
  886. alloc_fn_name(str): Name of the function to perform the memory allocation
  887. in the so file. The signature must be:
  888. void* alloc_fn_name(ssize_t size, int device, cudaStream_t stream);
  889. free_fn_name(str): Name of the function to perform the memory release
  890. in the so file. The signature must be:
  891. void free_fn_name(void* ptr, size_t size, cudaStream_t stream);
  892. .. warning::
  893. This is currently supported only in unix OSs
  894. .. note::
  895. See :ref:`cuda-memory-management` for details on creating and using a custom allocator
  896. """
  897. allocator = ctypes.CDLL(path_to_so_file)
  898. alloc_fn = ctypes.cast(getattr(allocator, alloc_fn_name), ctypes.c_void_p).value
  899. free_fn = ctypes.cast(getattr(allocator, free_fn_name), ctypes.c_void_p).value
  900. assert alloc_fn is not None
  901. assert free_fn is not None
  902. self._allocator = torch._C._cuda_customAllocator(alloc_fn, free_fn)
  903. def change_current_allocator(allocator: _CUDAAllocator) -> None:
  904. r"""Change the currently used memory allocator to be the one provided.
  905. If the current allocator has already been used/initialized, this function will error.
  906. Args:
  907. allocator (torch.cuda.memory._CUDAAllocator): allocator to be set as the active one.
  908. .. note::
  909. See :ref:`cuda-memory-management` for details on creating and using a custom allocator
  910. """
  911. torch._C._cuda_changeCurrentAllocator(allocator.allocator())
  912. def _get_current_allocator() -> _CUDAAllocator:
  913. r"""Return the allocator being currently used.
  914. .. note::
  915. See :ref:`cuda-memory-management` for details on creating and using a custom allocator
  916. """
  917. return _CUDAAllocator(torch._C._cuda_getAllocator())
  918. class MemPool(_MemPool):
  919. r"""MemPool represents a pool of memory in a caching allocator. Currently,
  920. it's just the ID of the pool object maintained in the CUDACachingAllocator.
  921. Args:
  922. allocator(torch._C._cuda_CUDAAllocator, optional): a
  923. torch._C._cuda_CUDAAllocator object that can be used to
  924. define how memory gets allocated in the pool. If :attr:`allocator`
  925. is ``None`` (default), memory allocation follows the default/
  926. current configuration of the CUDACachingAllocator.
  927. use_on_oom(bool): a bool that indicates if this pool can be used
  928. as a last resort if a memory allocation outside of the pool fails due
  929. to Out Of Memory. This is False by default.
  930. """
  931. def __init__(
  932. self,
  933. allocator: Optional[_cuda_CUDAAllocator] = None,
  934. use_on_oom: bool = False,
  935. ):
  936. super().__init__(allocator, True, use_on_oom)
  937. @property
  938. def id(self) -> tuple[int, int]:
  939. r"""Returns the ID of this pool as a tuple of two ints."""
  940. return super().id
  941. @property
  942. def allocator(self) -> Optional[_cuda_CUDAAllocator]:
  943. r"""Returns the allocator this MemPool routes allocations to."""
  944. return super().allocator
  945. def use_count(self) -> int:
  946. r"""Returns the reference count of this pool."""
  947. return super().use_count()
  948. def snapshot(self):
  949. r"""Return a snapshot of the CUDA memory allocator pool state across all
  950. devices.
  951. Interpreting the output of this function requires familiarity with the
  952. memory allocator internals.
  953. .. note::
  954. See :ref:`cuda-memory-management` for more details about GPU memory
  955. management.
  956. """
  957. snapshot = torch.cuda.memory_snapshot(self.id)
  958. return snapshot
  959. @contextlib.contextmanager
  960. def use_mem_pool(pool: MemPool, device: "Device" = None):
  961. r"""A context manager that routes allocations to a given pool.
  962. Args:
  963. pool(torch.cuda.MemPool): a MemPool object to be made active so that
  964. allocations route to this pool.
  965. device (torch.device or int, optional): selected device. Uses MemPool on
  966. the current device, given by :func:`~torch.cuda.current_device`,
  967. if :attr:`device` is ``None`` (default).
  968. .. note::
  969. This context manager makes only current thread's allocations route to
  970. the given pool. If a new thread is spawned inside the context manager
  971. (e.g. by calling backward) the allocations in that thread will not
  972. route to the given pool.
  973. """
  974. device_index = (
  975. torch.cuda.current_device() if device is None else _get_device_index(device)
  976. )
  977. _cuda_beginAllocateCurrentThreadToPool(device_index, pool.id)
  978. try:
  979. yield
  980. finally:
  981. _cuda_endAllocateToPool(device_index, pool.id)
  982. _cuda_releasePool(device_index, pool.id)