precompile_context.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import copy
  2. import logging
  3. import pickle
  4. from abc import abstractmethod
  5. from collections import defaultdict
  6. from itertools import chain
  7. from typing import Any, Callable, Generic, Optional, TypeVar, Union
  8. from typing_extensions import override
  9. from torch.compiler._cache import (
  10. _serialize_single_cache,
  11. CacheArtifact,
  12. CacheArtifactFactory,
  13. CacheArtifactManager,
  14. CacheArtifactsResult,
  15. CacheInfo,
  16. )
  17. from torch.utils._appending_byte_serializer import AppendingByteSerializer
  18. from torch.utils._ordered_set import OrderedSet
  19. """
  20. Classes and implementations related to precompile
  21. """
# Type parameter used by the generic artifact classes below; for
# PrecompileCacheArtifact it is the return type of after_deserialization().
T = TypeVar("T")
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  24. class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
  25. """
  26. Data for each cache artifact that will be serialized and deserialized by
  27. PrecompileContext, rather than CacheArtifactManager.
  28. T represents the deserialized type of the artifact, i.e. the return type of after_deserialization
  29. PrecompileCacheArtifact is a frozen dataclass - you can add new serializable fields and metadata specific to your own artifacts
  30. as needed, and use them in after_deserialization.
  31. Example implementation:
  32. class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]):
  33. my_field: int
  34. def after_deserialization(self) -> MySerializableType:
  35. result = pickle.loads(self.content)
  36. # Do some extra work post deserialization
  37. result.my_post_deserialization_function(self.my_field)
  38. return result
  39. """
  40. @override
  41. def populate_cache(self) -> None:
  42. raise RuntimeError("Precompile cache artifacts do not populate caches")
  43. @override
  44. def precompile_compatible(self) -> bool:
  45. return True
  46. @abstractmethod
  47. def after_deserialization(self) -> T:
  48. """
  49. Code to be run after reading raw byte contents from disk.
  50. Generally converts self.content from raw bytes back into its original form.
  51. """
  52. ...
  53. class EditablePrecompileCacheArtifact(Generic[T]):
  54. """
  55. A PrecompileCacheArtifact whose content isn't encoded until we call PrecompileContext.serialize()
  56. """
  57. def __init__(self, artifact_type: str, content: Any, key: str) -> None:
  58. # Deepcopy the content for now, but don't pickle it yet.
  59. # This allows us to make changes to self.content before true serialization
  60. self.content = copy.deepcopy(content)
  61. self.key = key
  62. self.artifact_type = artifact_type
  63. def real_encode(self) -> PrecompileCacheArtifact[T]:
  64. """
  65. Actually encode the object
  66. """
  67. content = pickle.dumps(self.content)
  68. artifact = CacheArtifactFactory.encode_create(
  69. self.artifact_type, self.key, content
  70. )
  71. assert isinstance(artifact, PrecompileCacheArtifact)
  72. return artifact
  73. def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
  74. """
  75. Edit the content of an existing artifact
  76. """
  77. self.content = edit_fn(self.content)
  78. class PrecompileContext(CacheArtifactManager):
  79. """
  80. PrecompileContext is a special CacheArtifactManager for handling precompilation
  81. It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead
  82. of placing each artifact into respective caches, it will stitch all the cache artifacts for a single key
  83. together and place it into a global Precompile Cache.
  84. The following artifact types are supported by PrecompileContext:
  85. - BundledAOTAutogradCacheArtifact
  86. - DynamoCodeStateArtifact
  87. - AutotuneCacheArtifact (regular autotune results, same as Megacache)
  88. """
  89. # Protected by the compile_lock
  90. # _new_cache_artifacts_by_key organizes results by the key of each artifact.
  91. # This allows us to implement serialize_by_key easily.
  92. # On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key
  93. # are transferred to _new_cache_artifacts before serialization.
  94. _new_cache_artifacts_by_key: dict[
  95. str, Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
  96. ] = {}
  97. _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
  98. # Keep a separate seen artifacts list to make avoid unnecessary duplicates
  99. # This list will not be cleared between serialize() calls
  100. _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
  101. # When serialize() is called, artifacts are transferred from _cache_artifacts to
  102. # internal data structure of the _serializer
  103. # This allows us to only pay the cost of serialization if serialize() is called
  104. _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
  105. AppendingByteSerializer(serialize_fn=_serialize_single_cache)
  106. )
  107. _cache_info: CacheInfo = CacheInfo()
  108. @classmethod
  109. def clear(cls) -> None:
  110. cls._new_cache_artifacts_by_key.clear()
  111. super().clear()
  112. @override
  113. @classmethod
  114. def record_artifact(
  115. cls,
  116. artifact_type: str,
  117. key: str,
  118. content: Any,
  119. editable: bool = False,
  120. ) -> None:
  121. """
  122. Called from each caching operation to record the artifact in this
  123. "mega" list
  124. """
  125. artifact: Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
  126. if editable:
  127. artifact = EditablePrecompileCacheArtifact(artifact_type, content, key)
  128. else:
  129. artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
  130. # TODO: although this covers completely same artifacts, it's possible
  131. # with AOTAutogradCacheEntries to have multiple artifacts whose keys
  132. # (i.e. backend_ids) are different, but whose contents are equal.
  133. # In those cases, it would be much better if we only serialize once instead
  134. # of N times.
  135. if artifact in cls._seen_artifacts:
  136. return
  137. cls._seen_artifacts.add(artifact)
  138. cls._new_cache_artifacts_by_key[key] = artifact
  139. @classmethod
  140. def _save_artifacts_by_type(cls) -> None:
  141. """
  142. We normally record artifacts by key, but serialization expects them to be organized
  143. by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
  144. """
  145. for artifact in cls._new_cache_artifacts_by_key.values():
  146. if isinstance(artifact, EditablePrecompileCacheArtifact):
  147. artifact = artifact.real_encode()
  148. cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
  149. cls._new_cache_artifacts_by_key.clear()
  150. @classmethod
  151. def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
  152. """
  153. Edit the content of an existing artifact
  154. """
  155. assert key in cls._new_cache_artifacts_by_key, (
  156. f"Key {key} not found in artifacts"
  157. )
  158. artifact = cls._new_cache_artifacts_by_key[key]
  159. assert isinstance(artifact, EditablePrecompileCacheArtifact), (
  160. "Artifact is not editable"
  161. )
  162. artifact.edit_contents(edit_fn)
  163. @classmethod
  164. def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
  165. """
  166. Serialize all artifacts with the given key returned in a list.
  167. """
  168. result = cls._new_cache_artifacts_by_key.get(key, None)
  169. if isinstance(result, EditablePrecompileCacheArtifact):
  170. result = result.real_encode()
  171. return result
  172. @classmethod
  173. def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
  174. cls._save_artifacts_by_type()
  175. # No need to serialize if there are no new dynamo compiles
  176. if "precompile_dynamo" not in cls._new_cache_artifacts:
  177. return None
  178. return super().serialize()
  179. @staticmethod
  180. def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
  181. PrecompileContext._ensure_cache_artifacts_registered()
  182. artifacts_by_key = {}
  183. cache_info = CacheInfo()
  184. for artifact in chain(*artifacts.values()):
  185. if artifact.type() == "autotune":
  186. # Populate autotune cache artifacts
  187. artifact.populate_cache()
  188. else:
  189. artifacts_by_key[artifact.key] = artifact
  190. cache_info.add(artifact)
  191. from torch._dynamo.package import _BackendId, DynamoCache
  192. for dynamo_entry in artifacts["precompile_dynamo"]:
  193. assert isinstance(dynamo_entry, PrecompileCacheArtifact)
  194. cache_entry = dynamo_entry.after_deserialization()
  195. # Grab backends from the dynamo cache entry
  196. backends = cache_entry.backend_ids
  197. backend_content: dict[_BackendId, PrecompileCacheArtifact[Any]] = {}
  198. for id_ in backends:
  199. assert id_ in artifacts_by_key, f"Backend {id_} not found in artifacts"
  200. artifact = artifacts_by_key[id_]
  201. assert isinstance(artifact, PrecompileCacheArtifact)
  202. backend_content[id_] = artifact
  203. DynamoCache.write(cache_entry, backend_content, dynamo_entry.key)
  204. return cache_info
  205. @classmethod
  206. def _ensure_cache_artifacts_registered(cls) -> None:
  207. from torch._dynamo.package import _DynamoCacheArtifact # noqa: F401
  208. from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401
  209. BundledAOTAutogradCacheArtifact,
  210. )