_extension.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. import abc
  3. import io
  4. from collections.abc import Sequence
  5. from typing import cast, IO, Optional
  6. # introduced as collections.abc.Buffer in Python 3.12
  7. from typing_extensions import Buffer
  8. from torch._utils import try_import
  9. # NOTE: everything in this file is experimental, and subject to
  10. # change. Feedback and bug fixes are always welcome.
  11. pyzstd_module_name = "pyzstd"
  12. pyzstd = try_import(pyzstd_module_name)
  13. zstandard_module_name = "zstandard"
  14. zstandard = try_import(zstandard_module_name)
  15. __all__ = [
  16. "Extension",
  17. "StreamTransformExtension",
  18. "ZStandard",
  19. "ExtensionRegistry",
  20. ]
  21. class Extension(abc.ABC):
  22. """
  23. Extensions provide modular additions to functionality within distributed checkpointing,
  24. which affect the layout or format of the written artifacts. Extensions may be
  25. built into pytorch, or provided externally.
  26. When writing, the caller provides a list of extension instances of the appropriate
  27. type. Each extension can output a descriptor which is used to reconstitute the
  28. extension at read-time.
  29. """
  30. @staticmethod
  31. @abc.abstractmethod
  32. def registry_name() -> str:
  33. """
  34. See ExtensionRegistry.from_descriptor_list
  35. """
  36. @staticmethod
  37. @abc.abstractmethod
  38. def from_descriptor(version: str) -> "Extension":
  39. """
  40. See ExtensionRegistry.from_descriptor_list
  41. """
  42. @abc.abstractmethod
  43. def get_descriptor(self) -> str:
  44. """
  45. Return descriptor name to be included in metadata. The form should be
  46. "extension_name[@local-domain][/version]".
  47. """
  48. class StreamTransformExtension(Extension):
  49. """
  50. An extension which performs transformation on a byte stream, such as compression
  51. or encryption.
  52. Implementations should try to be memory friendly and performant. For example, don't
  53. read the whole input, then transform it, and write it back. If at all possible, do it in
  54. chunks. But, don't read/transform/write one byte at a time, either.
  55. """
  56. @abc.abstractmethod
  57. def transform_to(self, output: IO[bytes]) -> IO[bytes]:
  58. """
  59. Takes a writeable output stream, and generates a new stream which implements the
  60. output transform. Input data written to the returned stream will be transformed
  61. and written to the `output` argument stream.
  62. """
  63. @abc.abstractmethod
  64. def transform_from(self, input: IO[bytes]) -> IO[bytes]:
  65. """
  66. Takes a readable input stream, and generates a new stream which implements the
  67. input transform. When the returned stream is read, data will be read from the
  68. 'input' stream, transformed, and returned.
  69. """
  70. class ZStandard(StreamTransformExtension):
  71. @staticmethod
  72. def is_available() -> bool:
  73. return zstandard is not None or pyzstd is not None
  74. @staticmethod
  75. def from_descriptor(version: str) -> "ZStandard":
  76. if version.partition(".")[0] != "1":
  77. raise ValueError(f"Unknown extension {version=}")
  78. if not ZStandard.is_available():
  79. raise ValueError(
  80. f"Stream with ZStandard compression cannot be processed because "
  81. f"no module named '{zstandard_module_name}' or '{pyzstd_module_name}'"
  82. )
  83. return ZStandard()
  84. @staticmethod
  85. def registry_name() -> str:
  86. return "stream.zstd"
  87. def __init__(self) -> None:
  88. super().__init__()
  89. if not ZStandard.is_available():
  90. raise ValueError(
  91. f"ZStandard extension is unavailable because no module named '{zstandard_module_name}' or '{pyzstd_module_name}'"
  92. )
  93. def get_descriptor(self) -> str:
  94. return f"{self.registry_name()}/1"
  95. def transform_to(self, output: IO[bytes]) -> IO[bytes]:
  96. if zstandard is not None:
  97. compressor = zstandard.ZstdCompressor() # type: ignore[union-attr]
  98. return compressor.stream_writer(output)
  99. class Writer(io.RawIOBase):
  100. def __init__(self, output: IO[bytes]) -> None:
  101. self.output = output
  102. self.compressor = pyzstd.ZstdCompressor() # type: ignore[union-attr]
  103. def writeable(self) -> bool:
  104. return True
  105. def write(self, b: Buffer) -> Optional[int]:
  106. outdata = self.compressor.compress(b)
  107. if outdata:
  108. self.output.write(outdata)
  109. return len(memoryview(b))
  110. def flush(self) -> None:
  111. outdata = self.compressor.flush()
  112. if outdata:
  113. self.output.write(outdata)
  114. self.output.flush()
  115. return cast(IO[bytes], Writer(output))
  116. def transform_from(self, input: IO[bytes]) -> IO[bytes]:
  117. if zstandard is not None:
  118. decompressor = zstandard.ZstdDecompressor() # type: ignore[union-attr]
  119. return decompressor.stream_reader(input)
  120. class Reader(io.RawIOBase):
  121. def __init__(self, input: IO[bytes]) -> None:
  122. self.input = input
  123. self.decompressor = pyzstd.EndlessZstdDecompressor() # type: ignore[union-attr]
  124. def readable(self) -> bool:
  125. return True
  126. def readinto(self, b: Buffer) -> Optional[int]:
  127. # This needs to read enough so it can decompress
  128. # something so the output doesn't look like EOF. This
  129. # means reading at least one block. The max block
  130. # size is 128KB, so we read that plus some
  131. # overhead to be sure.
  132. if self.decompressor.needs_input:
  133. indata = self.input.read((128 + 6) * 1024)
  134. else:
  135. indata = b""
  136. bview = memoryview(b)
  137. blen = len(bview)
  138. outdata = self.decompressor.decompress(indata, blen)
  139. if outdata is None:
  140. return None
  141. count = len(outdata)
  142. bview[:count] = outdata
  143. return count
  144. def seekable(self) -> bool:
  145. return False
  146. return cast(IO[bytes], Reader(input))
  147. class ExtensionRegistry:
  148. def __init__(self) -> None:
  149. # Populate default registry contents
  150. self.extensions: dict[str, type[Extension]] = {
  151. cls.registry_name(): cls for cls in (ZStandard,)
  152. }
  153. def register(self, cls: type[Extension]) -> None:
  154. self.extensions[cls.registry_name()] = cls
  155. def from_descriptor_list(self, descriptors: Sequence[str]) -> Sequence[Extension]:
  156. """
  157. Given a seuquence of descriptor strings as returned by
  158. Extension.get_descriptor at save time, creates a sequence of
  159. Extension instances. The name[@local-domain] preceding the
  160. version number is used to look up an implementation class in
  161. the registry, and the version is passed to the class's
  162. from_descriptor static method. If the registry contains no
  163. match, this will throw ValueError. If the from_descriptor
  164. method raises an exception, that will pass through to the
  165. caller.
  166. """
  167. def from_descriptor(desc: str) -> Extension:
  168. name, _, version = desc.partition("/")
  169. if version is None:
  170. version = 0
  171. ext = self.extensions.get(name)
  172. if not ext:
  173. raise ValueError(f"Unknown extension {name=}")
  174. return ext.from_descriptor(version)
  175. return [from_descriptor(desc) for desc in descriptors]