gds.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import os
  2. import sys
  3. from typing import Callable, Optional
  4. import torch
  5. from torch.types import Storage
  6. __all__: list[str] = [
  7. "gds_register_buffer",
  8. "gds_deregister_buffer",
  9. "GdsFile",
  10. ]
  11. def _dummy_fn(name: str) -> Callable:
  12. def fn(*args, **kwargs): # type: ignore[no-untyped-def]
  13. raise RuntimeError(f"torch._C.{name} is not supported on this platform")
  14. return fn
  15. if not hasattr(torch._C, "_gds_register_buffer"):
  16. assert not hasattr(torch._C, "_gds_deregister_buffer")
  17. assert not hasattr(torch._C, "_gds_register_handle")
  18. assert not hasattr(torch._C, "_gds_deregister_handle")
  19. assert not hasattr(torch._C, "_gds_load_storage")
  20. assert not hasattr(torch._C, "_gds_save_storage")
  21. # Define functions
  22. torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer")
  23. torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer")
  24. torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle")
  25. torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle")
  26. torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage")
  27. torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage")
  28. def gds_register_buffer(s: Storage) -> None:
  29. """Registers a storage on a CUDA device as a cufile buffer.
  30. Example::
  31. >>> # xdoctest: +SKIP("gds filesystem requirements")
  32. >>> src = torch.randn(1024, device="cuda")
  33. >>> s = src.untyped_storage()
  34. >>> gds_register_buffer(s)
  35. Args:
  36. s (Storage): Buffer to register.
  37. """
  38. torch._C._gds_register_buffer(s)
  39. def gds_deregister_buffer(s: Storage) -> None:
  40. """Deregisters a previously registered storage on a CUDA device as a cufile buffer.
  41. Example::
  42. >>> # xdoctest: +SKIP("gds filesystem requirements")
  43. >>> src = torch.randn(1024, device="cuda")
  44. >>> s = src.untyped_storage()
  45. >>> gds_register_buffer(s)
  46. >>> gds_deregister_buffer(s)
  47. Args:
  48. s (Storage): Buffer to register.
  49. """
  50. torch._C._gds_deregister_buffer(s)
  51. class GdsFile:
  52. r"""Wrapper around cuFile.
  53. cuFile is a file-like interface to the GPUDirect Storage (GDS) API.
  54. See the `cufile docs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_
  55. for more details.
  56. Args:
  57. filename (str): Name of the file to open.
  58. flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
  59. be added automatically.
  60. Example::
  61. >>> # xdoctest: +SKIP("gds filesystem requirements")
  62. >>> src1 = torch.randn(1024, device="cuda")
  63. >>> src2 = torch.randn(2, 1024, device="cuda")
  64. >>> file = torch.cuda.gds.GdsFile(f, os.O_CREAT | os.O_RDWR)
  65. >>> file.save_storage(src1.untyped_storage(), offset=0)
  66. >>> file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
  67. >>> dest1 = torch.empty(1024, device="cuda")
  68. >>> dest2 = torch.empty(2, 1024, device="cuda")
  69. >>> file.load_storage(dest1.untyped_storage(), offset=0)
  70. >>> file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
  71. >>> torch.equal(src1, dest1)
  72. True
  73. >>> torch.equal(src2, dest2)
  74. True
  75. """
  76. def __init__(self, filename: str, flags: int):
  77. if sys.platform == "win32":
  78. raise RuntimeError("GdsFile is not supported on this platform.")
  79. self.filename = filename
  80. self.flags = flags
  81. self.fd = os.open(filename, flags | os.O_DIRECT) # type: ignore[attr-defined]
  82. self.handle: Optional[int] = None
  83. self.register_handle()
  84. def __del__(self) -> None:
  85. if self.handle is not None:
  86. self.deregister_handle()
  87. os.close(self.fd)
  88. def register_handle(self) -> None:
  89. """Registers file descriptor to cuFile Driver.
  90. This is a wrapper around ``cuFileHandleRegister``.
  91. """
  92. assert self.handle is None, (
  93. "Cannot register a handle that is already registered."
  94. )
  95. self.handle = torch._C._gds_register_handle(self.fd)
  96. def deregister_handle(self) -> None:
  97. """Deregisters file descriptor from cuFile Driver.
  98. This is a wrapper around ``cuFileHandleDeregister``.
  99. """
  100. assert self.handle is not None, (
  101. "Cannot deregister a handle that is not registered."
  102. )
  103. torch._C._gds_deregister_handle(self.handle)
  104. self.handle = None
  105. def load_storage(self, storage: Storage, offset: int = 0) -> None:
  106. """Loads data from the file into the storage.
  107. This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
  108. will be loaded from the file at ``offset`` into the storage.
  109. Args:
  110. storage (Storage): Storage to load data into.
  111. offset (int, optional): Offset into the file to start loading from. (Default: 0)
  112. """
  113. assert self.handle is not None, (
  114. "Cannot load data from a file that is not registered."
  115. )
  116. torch._C._gds_load_storage(self.handle, storage, offset)
  117. def save_storage(self, storage: Storage, offset: int = 0) -> None:
  118. """Saves data from the storage into the file.
  119. This is a wrapper around ``cuFileWrite``. All bytes of the storage
  120. will be written to the file at ``offset``.
  121. Args:
  122. storage (Storage): Storage to save data from.
  123. offset (int, optional): Offset into the file to start saving to. (Default: 0)
  124. """
  125. assert self.handle is not None, (
  126. "Cannot save data to a file that is not registered."
  127. )
  128. torch._C._gds_save_storage(self.handle, storage, offset)