_cache.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. import copy
  2. import dataclasses
  3. import logging
  4. from abc import ABC, abstractmethod
  5. from collections import defaultdict
  6. from collections.abc import Generator
  7. from contextlib import contextmanager
  8. from itertools import chain
  9. from typing import Any, Optional
  10. from torch.utils._appending_byte_serializer import (
  11. AppendingByteSerializer,
  12. BytesReader,
  13. BytesWriter,
  14. )
  15. from torch.utils._ordered_set import OrderedSet
  16. log = logging.getLogger(__name__)
  17. @dataclasses.dataclass(frozen=True)
  18. class CacheArtifact(ABC):
  19. """
  20. Data for each cache artifact that will be serialized and deserialized
  21. """
  22. key: str
  23. content: bytes = dataclasses.field(repr=False) # Do not display potential binary
  24. @staticmethod
  25. def serialize(writer: BytesWriter, cls: "CacheArtifact") -> None:
  26. writer.write_str(cls.key)
  27. writer.write_bytes(cls.content)
  28. @staticmethod
  29. def deserialize(artifact_type: str, reader: BytesReader) -> "CacheArtifact":
  30. key = reader.read_str()
  31. content = reader.read_bytes()
  32. return CacheArtifactFactory.create(artifact_type, key, content)
  33. @staticmethod
  34. def encode(content: Any) -> bytes:
  35. assert isinstance(content, bytes), f"Expected bytes, got {type(content)}"
  36. return content
  37. @abstractmethod
  38. def populate_cache(self) -> None:
  39. pass
  40. @staticmethod
  41. def type() -> str:
  42. """
  43. Returns the type of the artifact. Must be unique across all CacheArtifact classes.
  44. CacheArtifactFactory.register will add property method to CacheInfo based on this (def {type}_artifacts)
  45. that returns all artifacts for specific cache.
  46. """
  47. raise RuntimeError("CacheArtifact is an abstract class, please use a subclass")
  48. class CacheArtifactFactory:
  49. """
  50. Factory for creating CacheArtifact objects based on their type
  51. """
  52. _artifact_types: dict[str, type[CacheArtifact]] = {}
  53. @classmethod
  54. def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]:
  55. artifact_type_key = artifact_cls.type()
  56. assert artifact_cls.type() not in cls._artifact_types, (
  57. f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory"
  58. )
  59. cls._artifact_types[artifact_type_key] = artifact_cls
  60. setattr(
  61. CacheInfo,
  62. f"{artifact_type_key}_artifacts",
  63. property(lambda self: self.artifacts[artifact_type_key]),
  64. )
  65. return artifact_cls
  66. @classmethod
  67. def _get_artifact_type(cls, artifact_type_key: str) -> type[CacheArtifact]:
  68. assert artifact_type_key in cls._artifact_types, (
  69. f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory"
  70. )
  71. return cls._artifact_types[artifact_type_key]
  72. @classmethod
  73. def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact:
  74. artifact_cls = cls._get_artifact_type(artifact_type_key)
  75. # pyrefly: ignore [bad-instantiation]
  76. return artifact_cls(key, content)
  77. @classmethod
  78. def encode_create(
  79. cls, artifact_type_key: str, key: str, content: Any
  80. ) -> CacheArtifact:
  81. artifact_cls = cls._get_artifact_type(artifact_type_key)
  82. # pyrefly: ignore [bad-instantiation]
  83. return artifact_cls(key, artifact_cls.encode(content))
  84. @dataclasses.dataclass
  85. class CacheInfo:
  86. """
  87. Return value of serialization and deserialization for the purpose of
  88. instrumentation
  89. """
  90. artifacts: defaultdict[str, list[str]] = dataclasses.field(
  91. default_factory=lambda: defaultdict(list)
  92. )
  93. # Methods set by CacheArtifactFactory.register based on CacheArtifact.type()
  94. @property
  95. def inductor_artifacts(self) -> list[str]: # type: ignore[empty-body]
  96. ...
  97. @property
  98. def autotune_artifacts(self) -> list[str]: # type: ignore[empty-body]
  99. ...
  100. @property
  101. def aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body]
  102. ...
  103. @property
  104. def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body]
  105. ...
  106. @property
  107. def precompile_artifacts(self) -> list[str]: # type: ignore[empty-body]
  108. ...
  109. def add(self, artifact: CacheArtifact) -> None:
  110. self.artifacts[artifact.type()].append(artifact.key)
  111. def clear(self) -> None:
  112. self.artifacts.clear()
  113. def empty(self) -> bool:
  114. return not self.artifacts
  115. def _serialize_single_cache(
  116. writer: BytesWriter, cls: "tuple[str, list[CacheArtifact]]"
  117. ) -> None:
  118. writer.write_str(cls[0])
  119. writer.write_uint64(len(cls[1]))
  120. for artifact in cls[1]:
  121. CacheArtifact.serialize(writer, artifact)
  122. def _deserialize_single_cache(
  123. reader: BytesReader,
  124. ) -> "tuple[str, list[CacheArtifact]]":
  125. artifacts = []
  126. artifact_type_key = reader.read_str()
  127. num_artifacts = reader.read_uint64()
  128. for _ in range(num_artifacts):
  129. artifacts.append(CacheArtifact.deserialize(artifact_type_key, reader))
  130. return artifact_type_key, artifacts
  131. CacheArtifactsResult = dict[str, list[CacheArtifact]]
  132. class CacheArtifactManager:
  133. """
  134. Lightweight manager class for collecting and processing cache artifacts for
  135. hot loading
  136. Intended Lifecycle:
  137. - Execute code via torch.compile, this will call
  138. CacheArtifactManager.record_artifact on each cache artifact
  139. - Call CacheArtifactManager.serialize to convert all the cache artifacts
  140. to portable format
  141. - Call CacheArtifactManager.deserialize to hot load the cache artifacts on
  142. a potentially different process
  143. NOTE: There's no FB/FC guarantees, results of cache artifacts will not be
  144. used unless code version matches.
  145. """
  146. # Protected by the compile_lock
  147. _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
  148. # Keep a separate seen artifacts list to make avoid unnecessary duplicates
  149. # This list will not be cleared between serialize() calls
  150. _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
  151. # When serialize() is called, artifacts are transferred from _cache_artifacts to
  152. # internal data structure of the _serializer
  153. # This allows us to only pay the cost of serialization if serialize() is called
  154. _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
  155. AppendingByteSerializer(serialize_fn=_serialize_single_cache)
  156. )
  157. _cache_info: CacheInfo = CacheInfo()
  158. @classmethod
  159. def clear(cls) -> None:
  160. cls._new_cache_artifacts.clear()
  161. cls._seen_artifacts.clear()
  162. cls._serializer.clear()
  163. cls._cache_info.clear()
  164. @classmethod
  165. @contextmanager
  166. def with_fresh_cache(cls) -> Generator[None, None, None]:
  167. original_new_cache_artifacts = cls._new_cache_artifacts
  168. original_seen_artifacts = cls._seen_artifacts
  169. original_serializer = cls._serializer
  170. original_cache_info = cls._cache_info
  171. cls._new_cache_artifacts = defaultdict(list)
  172. cls._seen_artifacts = OrderedSet()
  173. cls._serializer = AppendingByteSerializer(serialize_fn=_serialize_single_cache)
  174. cls._cache_info = cls._cache_info.__class__()
  175. try:
  176. yield
  177. finally:
  178. cls._new_cache_artifacts = original_new_cache_artifacts
  179. cls._seen_artifacts = original_seen_artifacts
  180. cls._serializer = original_serializer
  181. cls._cache_info = original_cache_info
  182. @classmethod
  183. def record_artifact(
  184. cls,
  185. artifact_type: str,
  186. key: str,
  187. content: Any,
  188. ) -> None:
  189. """
  190. Called from each caching operation to record the artifact in this
  191. "mega" list
  192. """
  193. artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
  194. if artifact in cls._seen_artifacts:
  195. return
  196. log.debug("Recording %s", str(artifact))
  197. cls._new_cache_artifacts[artifact_type].append(artifact)
  198. cls._seen_artifacts.add(artifact)
  199. @classmethod
  200. def need_serialize(cls) -> bool:
  201. """
  202. Have we seen new artifacts since last serialize call?
  203. """
  204. return len(cls._new_cache_artifacts) != 0
  205. @classmethod
  206. def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
  207. """
  208. Converts the "mega" list into portable format
  209. """
  210. for artifact in chain(*cls._new_cache_artifacts.values()):
  211. log.debug("saving: %s", artifact)
  212. cls._cache_info.add(artifact)
  213. if cls._cache_info.empty():
  214. # If there are not artifacts, dont just return bytes with
  215. # version.
  216. return None
  217. try:
  218. # We deep copy cls._cache_info since later compilations
  219. # can keep adding to cache_info
  220. info = copy.deepcopy(cls._cache_info)
  221. cls._serializer.extend(cls._new_cache_artifacts.items())
  222. artifact_bytes = cls._serializer.to_bytes()
  223. cls._new_cache_artifacts.clear()
  224. return artifact_bytes, info
  225. except Exception:
  226. log.warning("Failed to pickle cache artifacts", exc_info=True)
  227. return None
  228. @staticmethod
  229. def deserialize(serialized_artifacts: bytes) -> Optional[CacheArtifactsResult]:
  230. """
  231. Converts the portable format back into CacheArtifacts
  232. """
  233. try:
  234. CacheArtifactManager._ensure_cache_artifacts_registered()
  235. artifacts = dict(
  236. AppendingByteSerializer.to_list(
  237. serialized_artifacts,
  238. deserialize_fn=_deserialize_single_cache,
  239. )
  240. )
  241. except Exception:
  242. log.warning("Failed to un-pickle cache artifacts", exc_info=True)
  243. return None
  244. return artifacts
  245. @staticmethod
  246. def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
  247. info = CacheInfo()
  248. for artifact in chain(*artifacts.values()):
  249. log.debug("writing: %s", artifact)
  250. info.add(artifact)
  251. artifact.populate_cache()
  252. return info
  253. @classmethod
  254. def _ensure_cache_artifacts_registered(cls) -> None:
  255. """When deserializing caches in fresh process, we need to ensure that all
  256. cache artifacts are registered in the cache registry. This is done by
  257. simply importing all the cache artifacts already wrapped with register call.
  258. """
  259. from torch._dynamo.package import PrecompileCacheArtifact # noqa: F401
  260. from torch._dynamo.pgo import PGOCacheArtifact # noqa: F401
  261. from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401
  262. AOTAutogradCacheArtifact,
  263. )
  264. from torch._inductor.codecache import InductorCacheArtifact # noqa: F401
  265. from torch._inductor.runtime.autotune_cache import ( # noqa: F401
  266. AutotuneCacheArtifact,
  267. )