| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 |
- import abc
- import builtins
- import importlib
- import inspect
- import logging
- import pickle
- import types
- from dataclasses import dataclass
- from typing import Any, Callable, Optional
- import torch
- import torch.fx
- from torch._dynamo.precompile_context import PrecompileContext
- from . import convert_frame
- from .hooks import Hooks
# Module-level logger named after this module.
log = logging.getLogger(__name__)
class SerializableCallable(abc.ABC):
    """Interface for compiled callables that can round-trip through bytes.

    Concrete subclasses supply two class-level hooks: one that encodes an
    instance as ``bytes`` and one that rebuilds a callable from such bytes.
    ``AOTCompiledFunction.serialize`` stores the deserializer hook together
    with the encoded payload, so both must be implemented.
    """

    @classmethod
    @abc.abstractmethod
    def serialize_compile_artifacts(cls, fn: Any) -> bytes:
        """Encode ``fn`` (an instance of this class) as bytes."""

    @classmethod
    @abc.abstractmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """Rebuild a callable from bytes made by ``serialize_compile_artifacts``."""
def bind_locals(
    signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
    """Map call arguments onto parameter names, filling in defaults.

    Args:
        signature: signature of the function being called.
        *args, **kwargs: the call's arguments.

    Returns:
        A mapping of every parameter name to the value it takes for this call,
        including parameters that fall back to their default values.

    Raises:
        TypeError: if the arguments do not match ``signature``.
    """
    binding = signature.bind(*args, **kwargs)
    # Parameters the caller omitted must still appear, with their defaults.
    binding.apply_defaults()
    named_values = binding.arguments
    return named_values
@dataclass
class CompileArtifacts:
    """Data needed to rebuild and run an ahead-of-time compiled function.

    Built by ``aot_compile_fullgraph`` and wrapped by ``AOTCompiledFunction``.
    """

    # Signature of the original user function; used with bind_locals to map
    # call arguments to parameter names before guard checks.
    signature: inspect.Signature
    # Dynamo-transformed bytecode; AOTCompiledFunction turns it into the
    # function that is actually executed.
    bytecode: types.CodeType
    # Guard manager checked before every call. AOTCompiledFunction.serialize
    # drops it (sets None); AOTCompiledFunction rebuilds it from guards_state
    # when it is None.
    guard_manager: Optional[torch._dynamo.guards.GuardManagerWrapper]
    # Pickled guard state used to reconstruct guard_manager.
    guards_state: bytes
    # Global alias -> module name; each module is imported and bound under its
    # alias in the rebuilt function's globals.
    import_sources: dict[str, str]
    # Global name under which compiled_fn is exposed to the transformed bytecode.
    backend_id: str
    # Backend-compiled graph callable; must support byte (de)serialization.
    compiled_fn: SerializableCallable
    # Code object of the original (untransformed) user function; used when
    # rebuilding guard_manager.
    original_code: types.CodeType
    # Closure cells of the original user function (None if it has none),
    # reused by the rebuilt function.
    closure: Optional[tuple[Any, ...]]
@dataclass
class AOTCompiledFunction:
    """A callable produced by ahead-of-time Dynamo compilation.

    On construction the Dynamo-transformed bytecode is turned into a real
    function (``self.fn``) whose globals hold the imported modules listed in
    ``import_sources`` plus the backend-compiled graph under ``backend_id``.
    If the artifacts carry no live guard manager (as after ``deserialize``),
    one is rebuilt from the pickled guard state. Calling the object runs the
    guards first and raises ``RuntimeError`` if they reject the arguments.
    """

    _artifacts: CompileArtifacts

    def guard_check(self, *args: Any, **kwargs: Any) -> bool:
        """Return True if the guards accept these call arguments."""
        f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
        assert self._artifacts.guard_manager is not None
        return self._artifacts.guard_manager.check(f_locals)

    def __post_init__(self) -> None:
        """Rebuild the executable function and, if missing, the guard manager."""
        import_sources = {
            alias: importlib.import_module(module_name)
            for alias, module_name in self._artifacts.import_sources.items()
        }
        # The transformed bytecode refers to the compiled graph and to imported
        # modules by global name, so both must be present in its globals.
        f_globals = {
            **import_sources,
            self._artifacts.backend_id: self._artifacts.compiled_fn,
        }
        self.fn = types.FunctionType(
            self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
        )
        if self._artifacts.guard_manager is None:
            # NOTE(security): pickle.loads can execute arbitrary code; the guard
            # state must come from a trusted source.
            guards_state = pickle.loads(self._artifacts.guards_state)
            self._artifacts.guard_manager = torch._dynamo.guards.CheckFunctionManager(
                self._artifacts.original_code,
                guards_state.output_graph,
                shape_code_parts=guards_state.shape_code_parts,
                runtime_global_scope=f_globals,
            ).guard_manager

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Check guards, then run the compiled function.

        Raises:
            RuntimeError: if the guards reject the arguments; the message
                includes the guard manager's verbose failure reason.
        """
        assert self._artifacts.guard_manager is not None
        if not self.guard_check(*args, **kwargs):
            f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
            reason = str(self._artifacts.guard_manager.check_verbose(f_locals))
            raise RuntimeError(f"GuardManager check failed, reason: {reason}")
        return self.fn(*args, **kwargs)

    def save_compiled_function(self, path: str) -> None:
        """Serialize this function and write the bytes to ``path``.

        Serialization runs before the file is opened, so a serialization
        failure leaves any existing file at ``path`` untouched instead of
        truncating it.
        """
        data = type(self).serialize(self)
        with open(path, "wb") as f:
            f.write(data)

    @classmethod
    def serialize(cls, fn: "AOTCompiledFunction") -> bytes:
        """Encode ``fn``'s artifacts as pickled bytes.

        The live guard manager is dropped (it is rebuilt from ``guards_state``
        on load), code objects are wrapped in ``SerializedCode``, and the
        compiled graph is stored as its class's deserializer plus the bytes
        from ``serialize_compile_artifacts``.
        """
        from torch._dynamo.package import SerializedCode

        state = fn._artifacts.__dict__.copy()
        state["guard_manager"] = None
        state["bytecode"] = SerializedCode.from_code_object(state["bytecode"])
        compiled_fn = state["compiled_fn"]
        state["compiled_fn"] = (
            type(compiled_fn).deserialize_compile_artifacts,
            type(compiled_fn).serialize_compile_artifacts(compiled_fn),
        )
        state["original_code"] = SerializedCode.from_code_object(state["original_code"])
        return pickle.dumps(state)

    @classmethod
    def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
        """Rebuild an ``AOTCompiledFunction`` from bytes made by ``serialize``."""
        from torch._dynamo.package import SerializedCode

        # NOTE(security): pickle.loads can execute arbitrary code; only call this
        # on data from a trusted source (e.g. files written by
        # save_compiled_function).
        state = pickle.loads(data)
        state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
        deserializer, compiled_fn_state = state["compiled_fn"]
        state["compiled_fn"] = deserializer(compiled_fn_state)
        state["original_code"] = SerializedCode.to_code_object(state["original_code"])
        artifacts = CompileArtifacts(**state)
        return cls(artifacts)
class BundledAOTAutogradSerializableCallable(SerializableCallable):
    """
    Represents a serializable callable generated by compile_fx.
    This class wraps around the compiled function generated by AOTAutograd.

    TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
    this object should be what's *returned* by aot_module_simplified.
    We'll do that refactor in a later PR.
    """

    def __init__(self, artifact: Any) -> None:
        """
        Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
        of a compiled function generated by AOTAutograd.

        ``artifact.after_deserialization()`` provides the runnable callable, and
        ``artifact.content`` provides the raw bytes kept for re-serialization.
        """
        self.compiled_fn = artifact.after_deserialization()
        self.data = artifact.content

    def __getattr__(self, attr: Any) -> Any:
        """Forward unknown attribute lookups to the wrapped compiled function.

        Python only calls ``__getattr__`` after normal lookup on the instance
        and its class has failed, so anything reaching here is not defined on
        the wrapper itself. ``compiled_fn`` is read directly from ``__dict__``:
        going through normal attribute access (or ``hasattr``) from inside
        ``__getattr__`` would re-enter this method and recurse forever, e.g. on
        a half-initialized instance being copied or unpickled.

        Raises:
            AttributeError: if ``compiled_fn`` is not set yet or lacks ``attr``.
        """
        try:
            compiled_fn = self.__dict__["compiled_fn"]
        except KeyError:
            raise AttributeError(attr) from None
        return getattr(compiled_fn, attr)

    @classmethod
    def from_backend_id(
        cls, backend_id: str
    ) -> "BundledAOTAutogradSerializableCallable":
        """
        Takes in a backend_id, and returns a BundledAOTAutogradSerializableCallable
        that wraps around the compiled function generated by AOTAutograd.

        Raises:
            RuntimeError: if PrecompileContext has no artifact for ``backend_id``.
        """
        artifact = PrecompileContext.serialize_artifact_by_key(backend_id)
        if artifact is None:
            raise RuntimeError("No artifact found for backend_id: " + backend_id)
        return cls(artifact)

    @classmethod
    def serialize_compile_artifacts(
        cls, fn: "BundledAOTAutogradSerializableCallable"
    ) -> bytes:
        """Return the raw artifact bytes captured at construction."""
        return fn.data

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """Rebuild a wrapper from bytes made by ``serialize_compile_artifacts``."""
        from torch._functorch._aot_autograd.autograd_cache import (
            BundledAOTAutogradCacheArtifact,
        )

        # The key in the artifact is not important here since we're not populating a cache,
        # we just want to grab the callable back out of the serialized entry
        artifact = BundledAOTAutogradCacheArtifact("", data)
        return cls(artifact)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped compiled function."""
        return self.compiled_fn(*args, **kwargs)
def _resolve_traced_function(
    model: Any, args: tuple[Any, ...]
) -> tuple[types.FunctionType, tuple[Any, ...]]:
    """Return the plain Python function to trace for ``model`` and its positional args.

    A bound method is unwrapped to its underlying function, with the bound
    receiver prepended to ``args``. A plain function is returned unchanged.

    Raises:
        RuntimeError: if ``model`` is neither a bound method nor a function.
    """
    # inspect.ismethod (rather than hasattr(model, "__self__")) keeps builtins
    # such as ``print``, which have ``__self__`` but no ``__func__``, on the
    # RuntimeError path instead of failing with an AttributeError.
    if inspect.ismethod(model):
        return model.__func__, (model.__self__,) + args
    if inspect.isfunction(model):
        return model, args
    raise RuntimeError(f"Unsupported model code type {model}")


def _filter_serializable_guards(guard_entries: list[Any]) -> list[bool]:
    """Default guard filter for AOT compilation.

    Keeps a guard only if it is not a global guard and its type is not listed
    in ``CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES``.
    """
    from torch._dynamo.guards import CheckFunctionManager

    unsupported = CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
    return [
        not (entry.is_global or entry.guard_type in unsupported)
        for entry in guard_entries
    ]


def aot_compile_fullgraph(
    model: Any,
    example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
    hooks: Hooks,
    backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable],
) -> AOTCompiledFunction:
    """Trace ``model`` as a single full graph and compile it ahead of time.

    Args:
        model: a Python function or bound method to compile.
        example_inputs: ``(args, kwargs)`` used to trace ``model``.
        hooks: guard hooks passed to guard construction. If it has no
            ``guard_filter_fn``, a default filter is used that drops global
            guards and guard types that cannot be serialized. The caller's
            ``hooks`` object itself is never modified.
        backend: compiler invoked on the captured graph module and its example
            inputs; its result must implement ``SerializableCallable``. For the
            Inductor wrapper, the serializable callable is instead fetched from
            ``PrecompileContext``.

    Returns:
        An ``AOTCompiledFunction`` holding the transformed bytecode, guards and
        compiled graph.

    Raises:
        RuntimeError: if ``model`` is not a function/bound method, or if the
            backend output is not a ``SerializableCallable``.
    """
    import copy

    from torch._dynamo.utils import dynamo_timed, get_metrics_context
    from torch._guards import compile_context, CompileContext, TracingContext

    args, kwargs = example_inputs
    fn, args = _resolve_traced_function(model, args)

    signature = inspect.signature(fn)
    f_locals = bind_locals(signature, *args, **kwargs)
    if fn.__code__.co_freevars or fn.__closure__:
        # Also hand the tracer the current values of the closure variables.
        assert len(fn.__closure__) == len(fn.__code__.co_freevars)
        f_locals.update(
            {
                name: cell.cell_contents
                for name, cell in zip(fn.__code__.co_freevars, fn.__closure__)
            }
        )

    with (
        compile_context(CompileContext(convert_frame.get_compile_id({}))),
        get_metrics_context(),
        dynamo_timed("fullgraph_capture"),
    ):
        capture_output = convert_frame.fullgraph_capture(
            convert_frame.FrameInfo(
                fn.__code__,
                fn.__globals__,
                f_locals,
                builtins.__dict__,
                closure=fn.__closure__ or (),  # type: ignore[arg-type]
            )
        )
        dynamo_output = capture_output.dynamo_output

        if not hooks.guard_filter_fn:
            # Install the default filter on a copy: the caller's Hooks object
            # may be shared with other compilations that need global guards.
            hooks = copy.copy(hooks)
            hooks.guard_filter_fn = _filter_serializable_guards

        check_fn = dynamo_output.build_guards(
            fn.__code__, hooks=hooks, save=True, strict_error=True
        )
        assert check_fn.guards_state is not None

        backend_input = capture_output.backend_input
        backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
        output_graph = dynamo_output.tracer_output.output_graph
        assert output_graph is not None
        import_sources = output_graph.import_sources
        with (
            torch._guards.tracing(TracingContext(backend_input.fake_mode)),
            torch._functorch.config.patch("bundled_autograd_cache", True),
        ):
            compiled_fn = backend(
                backend_input.graph_module, backend_input.example_inputs
            )
            # If Inductor backend is used, grab the compiled_fn from PrecompileContext
            # TODO: this should be replaced once we make the backend return the SerializableCallable directly.
            if isinstance(backend, torch._TorchCompileInductorWrapper):
                compiled_fn = BundledAOTAutogradSerializableCallable.from_backend_id(
                    backend_input.backend_id
                )

        if not isinstance(compiled_fn, SerializableCallable):
            if hasattr(backend, "compiler_fn"):
                compiler_fn = backend.compiler_fn
            else:
                compiler_fn = backend
            raise RuntimeError(
                f"Compiled function type {type(compiled_fn)} (produced "
                + f"from backend {compiler_fn}) does not implement SerializableCallable."
            )

        artifacts = CompileArtifacts(
            signature=signature,
            bytecode=dynamo_output.bytecode,
            guard_manager=check_fn.guard_manager,
            guards_state=check_fn.guards_state,
            import_sources=import_sources,
            backend_id=backend_input.backend_id,
            compiled_fn=compiled_fn,
            original_code=fn.__code__,
            closure=fn.__closure__,
        )
        aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

    return aot_compiled_fn
|