# aot_compile.py
  1. import abc
  2. import builtins
  3. import importlib
  4. import inspect
  5. import logging
  6. import pickle
  7. import types
  8. from dataclasses import dataclass
  9. from typing import Any, Callable, Optional
  10. import torch
  11. import torch.fx
  12. from torch._dynamo.precompile_context import PrecompileContext
  13. from . import convert_frame
  14. from .hooks import Hooks
  15. log = logging.getLogger(__name__)
  16. class SerializableCallable(abc.ABC):
  17. @classmethod
  18. @abc.abstractmethod
  19. def serialize_compile_artifacts(cls, fn: Any) -> bytes:
  20. pass
  21. @classmethod
  22. @abc.abstractmethod
  23. def deserialize_compile_artifacts(cls, data: bytes) -> Any:
  24. pass
  25. def bind_locals(
  26. signature: inspect.Signature, *args: Any, **kwargs: Any
  27. ) -> dict[str, Any]:
  28. bound_arguments = signature.bind(*args, **kwargs)
  29. bound_arguments.apply_defaults()
  30. return bound_arguments.arguments
  31. @dataclass
  32. class CompileArtifacts:
  33. signature: inspect.Signature
  34. bytecode: types.CodeType
  35. guard_manager: Optional[torch._dynamo.guards.GuardManagerWrapper]
  36. guards_state: bytes
  37. import_sources: dict[str, str]
  38. backend_id: str
  39. compiled_fn: SerializableCallable
  40. original_code: types.CodeType
  41. closure: Optional[tuple[Any, ...]]
  42. @dataclass
  43. class AOTCompiledFunction:
  44. _artifacts: CompileArtifacts
  45. def guard_check(self, *args: Any, **kwargs: Any) -> bool:
  46. f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
  47. assert self._artifacts.guard_manager is not None
  48. return self._artifacts.guard_manager.check(f_locals)
  49. def __post_init__(self) -> None:
  50. import_sources = {
  51. alias: importlib.import_module(module_name)
  52. for alias, module_name in self._artifacts.import_sources.items()
  53. }
  54. f_globals = {
  55. **import_sources,
  56. self._artifacts.backend_id: self._artifacts.compiled_fn,
  57. }
  58. self.fn = types.FunctionType(
  59. self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
  60. )
  61. if self._artifacts.guard_manager is None:
  62. guards_state = pickle.loads(self._artifacts.guards_state)
  63. self._artifacts.guard_manager = torch._dynamo.guards.CheckFunctionManager(
  64. self._artifacts.original_code,
  65. guards_state.output_graph,
  66. shape_code_parts=guards_state.shape_code_parts,
  67. runtime_global_scope=f_globals,
  68. ).guard_manager
  69. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  70. assert self._artifacts.guard_manager is not None
  71. if not self.guard_check(*args, **kwargs):
  72. f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
  73. reason = str(self._artifacts.guard_manager.check_verbose(f_locals))
  74. raise RuntimeError(f"GuardManager check failed, reason: {reason}")
  75. return self.fn(*args, **kwargs)
  76. def save_compiled_function(self, path: str) -> None:
  77. with open(path, "wb") as f:
  78. f.write(type(self).serialize(self))
  79. @classmethod
  80. def serialize(cls, fn: "AOTCompiledFunction") -> bytes:
  81. from torch._dynamo.package import SerializedCode
  82. state = fn._artifacts.__dict__.copy()
  83. state["guard_manager"] = None
  84. state["bytecode"] = SerializedCode.from_code_object(state["bytecode"])
  85. compiled_fn = state["compiled_fn"]
  86. state["compiled_fn"] = (
  87. type(compiled_fn).deserialize_compile_artifacts,
  88. type(compiled_fn).serialize_compile_artifacts(compiled_fn),
  89. )
  90. state["original_code"] = SerializedCode.from_code_object(state["original_code"])
  91. return pickle.dumps(state)
  92. @classmethod
  93. def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
  94. from torch._dynamo.package import SerializedCode
  95. state = pickle.loads(data)
  96. state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
  97. deserializer, compiled_fn_state = state["compiled_fn"]
  98. state["compiled_fn"] = deserializer(compiled_fn_state)
  99. state["original_code"] = SerializedCode.to_code_object(state["original_code"])
  100. artifacts = CompileArtifacts(**state)
  101. return cls(artifacts)
  102. class BundledAOTAutogradSerializableCallable(SerializableCallable):
  103. """
  104. Represents a serializable callable generated by compile_fx.
  105. This class wraps around the compiled function generated by AOTAutograd.
  106. TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
  107. this object should be what's *returned* by aot_module_simplified.
  108. We'll do that refactor in a later PR.
  109. """
  110. def __init__(self, artifact: Any) -> None:
  111. """
  112. Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
  113. of a compiled function generated by AOTAutograd.
  114. """
  115. self.compiled_fn = artifact.after_deserialization()
  116. self.data = artifact.content
  117. def __getattr__(self, attr: Any) -> Any:
  118. if hasattr(self, attr):
  119. return getattr(super(), attr)
  120. else:
  121. return getattr(self.compiled_fn, attr)
  122. @classmethod
  123. def from_backend_id(
  124. cls, backend_id: str
  125. ) -> "BundledAOTAutogradSerializableCallable":
  126. """
  127. Takes in a backend_id, and returns a BundledAOTAutogradSerializableCallable
  128. that wraps around the compiled function generated by AOTAutograd.
  129. """
  130. artifact = PrecompileContext.serialize_artifact_by_key(backend_id)
  131. if artifact is None:
  132. raise RuntimeError("No artifact found for backend_id: " + backend_id)
  133. return cls(artifact)
  134. @classmethod
  135. def serialize_compile_artifacts(
  136. cls, fn: "BundledAOTAutogradSerializableCallable"
  137. ) -> bytes:
  138. return fn.data
  139. @classmethod
  140. def deserialize_compile_artifacts(cls, data: bytes) -> Any:
  141. from torch._functorch._aot_autograd.autograd_cache import (
  142. BundledAOTAutogradCacheArtifact,
  143. )
  144. # The key in the artifact is not important here since we're not populating a cache,
  145. # we just want to grab the callable back out of the serialized entry
  146. artifact = BundledAOTAutogradCacheArtifact("", data)
  147. return cls(artifact)
  148. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  149. return self.compiled_fn(*args, **kwargs)
  150. def aot_compile_fullgraph(
  151. model: Any,
  152. example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
  153. hooks: Hooks,
  154. backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable],
  155. ) -> AOTCompiledFunction:
  156. from torch._dynamo.guards import CheckFunctionManager
  157. from torch._dynamo.utils import dynamo_timed, get_metrics_context
  158. from torch._guards import compile_context, CompileContext, TracingContext
  159. args, kwargs = example_inputs
  160. if hasattr(model, "__self__"):
  161. fn = model.__func__
  162. args = (model.__self__,) + args
  163. elif inspect.isfunction(model):
  164. fn = model
  165. else:
  166. raise RuntimeError(f"Unsupported model code type {model}")
  167. signature = inspect.signature(fn)
  168. f_locals = bind_locals(signature, *args, **kwargs)
  169. if fn.__code__.co_freevars or fn.__closure__:
  170. assert len(fn.__closure__) == len(fn.__code__.co_freevars)
  171. f_locals.update(
  172. {
  173. name: cell.cell_contents
  174. for name, cell in zip(fn.__code__.co_freevars, fn.__closure__)
  175. }
  176. )
  177. with (
  178. compile_context(CompileContext(convert_frame.get_compile_id({}))),
  179. get_metrics_context(),
  180. dynamo_timed("fullgraph_capture"),
  181. ):
  182. capture_output = convert_frame.fullgraph_capture(
  183. convert_frame.FrameInfo(
  184. fn.__code__,
  185. fn.__globals__,
  186. f_locals,
  187. builtins.__dict__,
  188. closure=fn.__closure__ or (), # type: ignore[arg-type]
  189. )
  190. )
  191. dynamo_output = capture_output.dynamo_output
  192. if not hooks.guard_filter_fn:
  193. from torch._dynamo.types import GuardFilterEntry
  194. def new_guard_filter_fn(
  195. guard_entries: list[GuardFilterEntry],
  196. ) -> list[bool]:
  197. return [
  198. (
  199. not (
  200. g.is_global
  201. or g.guard_type
  202. in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
  203. )
  204. )
  205. for g in guard_entries
  206. ]
  207. hooks.guard_filter_fn = new_guard_filter_fn
  208. check_fn = dynamo_output.build_guards(
  209. fn.__code__, hooks=hooks, save=True, strict_error=True
  210. )
  211. assert check_fn.guards_state is not None
  212. backend_input = capture_output.backend_input
  213. backend_input.graph_module._backend_id = backend_input.backend_id # type: ignore[assignment]
  214. output_graph = dynamo_output.tracer_output.output_graph
  215. assert output_graph is not None
  216. import_sources = output_graph.import_sources
  217. with (
  218. torch._guards.tracing(TracingContext(backend_input.fake_mode)),
  219. torch._functorch.config.patch("bundled_autograd_cache", True),
  220. ):
  221. compiled_fn = backend(backend_input.graph_module, backend_input.example_inputs)
  222. # If Inductor backend is used, grab the compiled_fn from PrecompileContext
  223. # TODO: this should be replaced once we make the backend return the SerializableCallable directly.
  224. if isinstance(backend, torch._TorchCompileInductorWrapper):
  225. compiled_fn = BundledAOTAutogradSerializableCallable.from_backend_id(
  226. backend_input.backend_id
  227. )
  228. if not isinstance(compiled_fn, SerializableCallable):
  229. if hasattr(backend, "compiler_fn"):
  230. compiler_fn = backend.compiler_fn
  231. else:
  232. compiler_fn = backend
  233. raise RuntimeError(
  234. f"Compiled function type {type(compiled_fn)} (produced "
  235. + f"from backend {compiler_fn}) does not implement SerializableCallable."
  236. )
  237. artifacts = CompileArtifacts(
  238. signature=signature,
  239. bytecode=dynamo_output.bytecode,
  240. guard_manager=check_fn.guard_manager,
  241. guards_state=check_fn.guards_state,
  242. import_sources=import_sources,
  243. backend_id=backend_input.backend_id,
  244. compiled_fn=compiled_fn,
  245. original_code=fn.__code__,
  246. closure=fn.__closure__,
  247. )
  248. aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)
  249. return aot_compiled_fn