gds.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import os
  2. import sys
  3. from collections.abc import Callable
  4. from typing import Optional
  5. import torch
  6. from torch.types import Storage
  7. __all__: list[str] = [
  8. "gds_register_buffer",
  9. "gds_deregister_buffer",
  10. "GdsFile",
  11. ]
  12. def _dummy_fn(name: str) -> Callable:
  13. def fn(*args, **kwargs): # type: ignore[no-untyped-def]
  14. raise RuntimeError(f"torch._C.{name} is not supported on this platform")
  15. return fn
  16. if not hasattr(torch._C, "_gds_register_buffer"):
  17. assert not hasattr(torch._C, "_gds_deregister_buffer")
  18. assert not hasattr(torch._C, "_gds_register_handle")
  19. assert not hasattr(torch._C, "_gds_deregister_handle")
  20. assert not hasattr(torch._C, "_gds_load_storage")
  21. assert not hasattr(torch._C, "_gds_save_storage")
  22. # Define functions
  23. torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
  24. torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
  25. torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
  26. torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
  27. torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
  28. torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")
  29. def gds_register_buffer(s: Storage) -> None:
  30. """Registers a storage on a CUDA device as a cufile buffer.
  31. Example::
  32. >>> # xdoctest: +SKIP("gds filesystem requirements")
  33. >>> src = torch.randn(1024, device="cuda")
  34. >>> s = src.untyped_storage()
  35. >>> gds_register_buffer(s)
  36. Args:
  37. s (Storage): Buffer to register.
  38. """
  39. torch._C._gds_register_buffer(s)
  40. def gds_deregister_buffer(s: Storage) -> None:
  41. """Deregisters a previously registered storage on a CUDA device as a cufile buffer.
  42. Example::
  43. >>> # xdoctest: +SKIP("gds filesystem requirements")
  44. >>> src = torch.randn(1024, device="cuda")
  45. >>> s = src.untyped_storage()
  46. >>> gds_register_buffer(s)
  47. >>> gds_deregister_buffer(s)
  48. Args:
  49. s (Storage): Buffer to register.
  50. """
  51. torch._C._gds_deregister_buffer(s)
  52. class GdsFile:
  53. r"""Wrapper around cuFile.
  54. cuFile is a file-like interface to the GPUDirect Storage (GDS) API.
  55. See the `cufile docs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_
  56. for more details.
  57. Args:
  58. filename (str): Name of the file to open.
  59. flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
  60. be added automatically.
  61. Example::
  62. >>> # xdoctest: +SKIP("gds filesystem requirements")
  63. >>> src1 = torch.randn(1024, device="cuda")
  64. >>> src2 = torch.randn(2, 1024, device="cuda")
  65. >>> file = torch.cuda.gds.GdsFile(f, os.O_CREAT | os.O_RDWR)
  66. >>> file.save_storage(src1.untyped_storage(), offset=0)
  67. >>> file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
  68. >>> dest1 = torch.empty(1024, device="cuda")
  69. >>> dest2 = torch.empty(2, 1024, device="cuda")
  70. >>> file.load_storage(dest1.untyped_storage(), offset=0)
  71. >>> file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
  72. >>> torch.equal(src1, dest1)
  73. True
  74. >>> torch.equal(src2, dest2)
  75. True
  76. """
  77. def __init__(self, filename: str, flags: int):
  78. if sys.platform == "win32":
  79. raise RuntimeError("GdsFile is not supported on this platform.")
  80. self.filename = filename
  81. self.flags = flags
  82. self.fd = os.open(filename, flags | os.O_DIRECT) # type: ignore[attr-defined]
  83. self.handle: Optional[int] = None
  84. self.register_handle()
  85. def __del__(self) -> None:
  86. if self.handle is not None:
  87. self.deregister_handle()
  88. os.close(self.fd)
  89. def register_handle(self) -> None:
  90. """Registers file descriptor to cuFile Driver.
  91. This is a wrapper around ``cuFileHandleRegister``.
  92. """
  93. assert self.handle is None, (
  94. "Cannot register a handle that is already registered."
  95. )
  96. self.handle = torch._C._gds_register_handle(self.fd)
  97. def deregister_handle(self) -> None:
  98. """Deregisters file descriptor from cuFile Driver.
  99. This is a wrapper around ``cuFileHandleDeregister``.
  100. """
  101. assert self.handle is not None, (
  102. "Cannot deregister a handle that is not registered."
  103. )
  104. torch._C._gds_deregister_handle(self.handle)
  105. self.handle = None
  106. def load_storage(self, storage: Storage, offset: int = 0) -> None:
  107. """Loads data from the file into the storage.
  108. This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
  109. will be loaded from the file at ``offset`` into the storage.
  110. Args:
  111. storage (Storage): Storage to load data into.
  112. offset (int, optional): Offset into the file to start loading from. (Default: 0)
  113. """
  114. assert self.handle is not None, (
  115. "Cannot load data from a file that is not registered."
  116. )
  117. torch._C._gds_load_storage(self.handle, storage, offset)
  118. def save_storage(self, storage: Storage, offset: int = 0) -> None:
  119. """Saves data from the storage into the file.
  120. This is a wrapper around ``cuFileWrite``. All bytes of the storage
  121. will be written to the file at ``offset``.
  122. Args:
  123. storage (Storage): Storage to save data from.
  124. offset (int, optional): Offset into the file to start saving to. (Default: 0)
  125. """
  126. assert self.handle is not None, (
  127. "Cannot save data to a file that is not registered."
  128. )
  129. torch._C._gds_save_storage(self.handle, storage, offset)