package.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943
"""
This module provides the infrastructure for creating and managing compile packages
for torch.compile. There are two main abstractions here:
- CompilePackage: Overarching data structure for storing and looking up a list of compiled codes.
- _DynamoCodeCacheEntry: Data structure for a single code object being compiled by torch.compile.
The caching behavior is always explicitly under user control, so that a stronger guarantee can
be provided about cache hits for a specific compiled model. Users can load the compile package
from a different process or host.
"""
  10. import abc
  11. import ast
  12. import contextlib
  13. import dataclasses
  14. import functools
  15. import hashlib
  16. import importlib
  17. import inspect
  18. import logging
  19. import os
  20. import pickle
  21. import platform
  22. import shutil
  23. import sys
  24. import types
  25. from collections.abc import Generator, Iterator
  26. from typing import Any, Callable, NewType, Optional
  27. from typing_extensions import Never
  28. import torch
  29. import torch._inductor.package
  30. from torch._dynamo.exc import PackageError
  31. from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
  32. from torch._inductor.runtime.cache_dir_utils import cache_dir
  33. from torch.compiler._cache import CacheArtifactFactory
  34. from .bytecode_transformation import get_code_keys
  35. from .utils import dynamo_timed, increment_frame
  36. logger = logging.getLogger(__name__)
  37. @dataclasses.dataclass(frozen=True)
  38. class SerializedCode:
  39. co_argcount: int
  40. co_posonlyargcount: int
  41. co_kwonlyargcount: int
  42. co_nlocals: int
  43. co_stacksize: int
  44. co_flags: int
  45. co_code: bytes
  46. co_consts: tuple[Any, ...]
  47. co_names: tuple[str, ...]
  48. co_varnames: tuple[str, ...]
  49. co_filename: str
  50. co_name: str
  51. co_firstlineno: int
  52. co_cellvars: tuple[str, ...]
  53. co_freevars: tuple[str, ...]
  54. co_linetable: Optional[bytes] = None
  55. co_qualname: Optional[str] = None
  56. co_exceptiontable: Optional[bytes] = None
  57. co_lnotab: Optional[str] = None
  58. @classmethod
  59. @functools.cache
  60. def from_code_object(cls, code: types.CodeType) -> "SerializedCode":
  61. kwargs = {key: getattr(code, key) for key in get_code_keys()}
  62. kwargs["co_consts"] = tuple(
  63. cls.from_code_object(c) if isinstance(c, types.CodeType) else c
  64. for c in kwargs["co_consts"]
  65. )
  66. return cls(**kwargs)
  67. @classmethod
  68. @functools.cache
  69. def to_code_object(cls, serialized_code: "SerializedCode") -> types.CodeType:
  70. kwargs = {key: getattr(serialized_code, key) for key in get_code_keys()}
  71. kwargs["co_consts"] = tuple(
  72. cls.to_code_object(c) if isinstance(c, SerializedCode) else c
  73. for c in kwargs["co_consts"]
  74. )
  75. return types.CodeType(
  76. *kwargs.values(),
  77. )
  78. @dataclasses.dataclass
  79. class _GuardedCodeCacheEntry:
  80. """
  81. Contains the serializable information associated with a single compilation in dynamo.
  82. To restore an execution of compiled code, we will need to serialize the following data:
  83. - Dynamo bytecode for mapping Python inputs/outputs.
  84. - Dynamo guards.
  85. """
  86. guards_state: bytes
  87. dynamo_code: SerializedCode
  88. _BackendId = NewType("_BackendId", str) # __compiled_fn
  89. _FunctionId = NewType("_FunctionId", str) # __resume_at
  90. @dataclasses.dataclass(frozen=True)
  91. class InlinedSource:
  92. module: str
  93. firstlineno: int
  94. lastlineno: int
  95. checksum: str
  96. @dataclasses.dataclass
  97. class DynamoCaptureOutput:
  98. """
  99. Core information generated from Dynamo for fullgraph=True.
  100. """
  101. guarded_codes: list[_GuardedCodeCacheEntry]
  102. backend_ids: list[_BackendId]
  103. @dataclasses.dataclass
  104. class _DynamoCodeCacheEntry(DynamoCaptureOutput):
  105. """
  106. Contains the serializable information associated with a single code object
  107. in dynamo. To restore an execution of compiled code, we will need the following
  108. ingredients:
  109. 1. The "original" code object, which serves as the entry point for eager
  110. execution, i.e. the code only executed when there's no cache entry hit.
  111. 2. The python module name this code object belongs to, for identifying the
  112. enclosing global scope to inject compiled and resume functions.
  113. 3. A list of function names that pointing to this code object. There could be
  114. multiple function objects pointing to the same code such as recursive functions.
  115. 4. A list of guarded code that eval frame dispatches to.
  116. 5. A list of imported module objects unioned from all compiled branches.
  117. 6. A list of "backends" (compiled fx graph) unioned from all compield branches.
  118. 7. A string path used to access the original code object users defined.
  119. A code object can be accessed by "{python_module}.{function_name}.{code_source}" .
  120. 8. A boolean flag indicating whether the function is installed to global scope.
  121. 9. A boolean flag indicating whether the function has a compile id.
  122. 10. Whether or not this code entry was bypassed
  123. """
  124. python_code: SerializedCode
  125. python_module: str
  126. function_names: list[_FunctionId]
  127. import_sources: dict[str, str]
  128. code_source: Optional[str]
  129. install_to_global: bool
  130. has_compile_id: bool = False
  131. bypassed: bool = False
  132. def _lookup_code(entry: _DynamoCodeCacheEntry) -> types.CodeType:
  133. assert len(entry.function_names) == 1
  134. fn: Any = sys.modules[entry.python_module]
  135. parts = entry.function_names[0].split(".")
  136. for part in parts:
  137. fn = getattr(fn, part)
  138. if entry.code_source:
  139. parts = entry.code_source.split(".")
  140. for part in parts:
  141. if part.endswith("]"):
  142. index_begin = part.rfind("[")
  143. assert isinstance(index_begin, int) and index_begin >= 0
  144. attr = getattr(fn, part[:index_begin], None)
  145. if attr is None:
  146. raise PackageError(f"Cannot find source for code entry {entry}")
  147. fn = attr[ast.literal_eval(part[index_begin + 1 : -1])]
  148. else:
  149. fn = getattr(fn, part)
  150. else:
  151. raise PackageError(f"Cannot find source for code entry {entry}")
  152. assert isinstance(fn, types.CodeType)
  153. return fn
  154. def _raise_resolution_error(code: types.CodeType, scope: Any) -> Never:
  155. raise PackageError(
  156. f"Cannot resolve a fully qualified name for {code}. Lookup scope: {scope}"
  157. )
  158. def _get_code_source(code: types.CodeType) -> tuple[str, str]:
  159. """
  160. Given a code object, return a fully qualified name which will be used as
  161. a serialized handle to access the code object from the new process.
  162. This is normally a straightforward process, but there are some corner cases:
  163. 1. When a function is defined with decorator, then this function will be captured
  164. inside a closure with the wrapper object.
  165. 2. When a function is defined as a nested function, then the code object will be
  166. stored on the co_consts field of the parent code object by Python compiler.
  167. This function handles all of the corner cases above.
  168. """
  169. module = inspect.getmodule(code)
  170. if module is None:
  171. raise PackageError(f"Cannot find module for code {code}")
  172. toplevel: Any = module
  173. if sys.version_info >= (3, 11):
  174. parts = code.co_qualname.split(".")
  175. for part in parts:
  176. if not hasattr(toplevel, part):
  177. _raise_resolution_error(code, toplevel)
  178. toplevel = getattr(toplevel, part)
  179. if inspect.isfunction(toplevel):
  180. break
  181. seen = set()
  182. def _find_code_source(obj: Any) -> Optional[str]:
  183. nonlocal toplevel
  184. nonlocal seen
  185. if obj in seen:
  186. return None
  187. seen.add(obj)
  188. if inspect.iscode(obj):
  189. if obj is code:
  190. return ""
  191. for i, const in enumerate(obj.co_consts):
  192. if (res := _find_code_source(const)) is not None:
  193. return f".co_consts[{i}]{res}"
  194. if inspect.isfunction(obj):
  195. if (res := _find_code_source(obj.__code__)) is not None:
  196. toplevel = obj
  197. return f".__code__{res}"
  198. if obj.__closure__ is not None:
  199. for i, cell in enumerate(obj.__closure__):
  200. try:
  201. cell_contents = cell.cell_contents
  202. except ValueError:
  203. continue
  204. if not (
  205. inspect.isfunction(cell_contents)
  206. or inspect.iscode(cell_contents)
  207. ):
  208. continue
  209. if (res := _find_code_source(cell_contents)) is not None:
  210. toplevel = obj
  211. return f".__closure__[{i}].cell_contents{res}"
  212. if sys.version_info < (3, 11):
  213. if inspect.ismodule(obj):
  214. for value in obj.__dict__.values():
  215. if not (inspect.isfunction(value) or inspect.isclass(value)):
  216. continue
  217. if (res := _find_code_source(value)) is not None:
  218. return res
  219. if inspect.isclass(obj):
  220. for name, value in obj.__dict__.items():
  221. value = getattr(obj, name)
  222. if not (inspect.isfunction(value) or inspect.isclass(value)):
  223. continue
  224. if (res := _find_code_source(value)) is not None:
  225. if value.__name__ != name:
  226. _raise_resolution_error(code, toplevel)
  227. return res
  228. return None
  229. code_source = _find_code_source(toplevel)
  230. if code_source is None:
  231. _raise_resolution_error(code, toplevel)
  232. return toplevel.__qualname__, code_source.strip(".")
  233. @dataclasses.dataclass
  234. class _DynamoCacheEntry:
  235. codes: list[_DynamoCodeCacheEntry]
  236. inlined_sources: set[InlinedSource]
  237. python_version: str = platform.python_version()
  238. torch_version: str = torch.__version__
  239. @property
  240. def backend_ids(self) -> set[_BackendId]:
  241. return {backend_id for code in self.codes for backend_id in code.backend_ids}
  242. @CacheArtifactFactory.register
  243. class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]):
  244. @staticmethod
  245. def type() -> str:
  246. return "precompile_dynamo"
  247. def after_deserialization(self) -> _DynamoCacheEntry:
  248. return pickle.loads(self.content)
  249. def _hash_source(source: str) -> str:
  250. sha256_hash = hashlib.sha256()
  251. sha256_hash.update(source.encode())
  252. return sha256_hash.hexdigest()
  253. def _get_sourcelines(
  254. m: types.ModuleType, firstlineno: int, lastlineno: int
  255. ) -> list[str]:
  256. return inspect.getsourcelines(m)[0][firstlineno - 1 : lastlineno - 1]
  257. def _hash_sourcelines(m: types.ModuleType, firstlineno: int, lastlineno: int) -> str:
  258. return _hash_source("".join(_get_sourcelines(m, firstlineno, lastlineno)))
  259. def _compile_frame_context(
  260. code: types.CodeType,
  261. ) -> contextlib.AbstractContextManager[None]:
  262. from torch._dynamo.convert_frame import get_compile_id, log_dynamo_start
  263. from torch._guards import compile_context, CompileContext
  264. # Each code represents a new compile frame
  265. # recompiles on the same frame are all saved
  266. # under the same cache entry, so we don't have recompile ids
  267. # i.e. If cold start had 0/0, 0/1, 1/0, 1/1, these would be
  268. # collapsed into 0/0, 1/0 on warm.
  269. @contextlib.contextmanager
  270. def _ctx() -> Iterator[None]:
  271. increment_frame()
  272. compile_id = get_compile_id(frame_state={})
  273. with (
  274. compile_context(CompileContext(compile_id)),
  275. dynamo_timed(
  276. "_compile.compile_inner",
  277. phase_name="entire_frame_compile",
  278. dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
  279. # TODO: save all relevant compilation metrics
  280. metadata={
  281. "frame_key": str(torch._dynamo.utils.curr_frame),
  282. "co_name": code.co_name,
  283. "co_filename": code.co_filename,
  284. "co_firstlineno": code.co_firstlineno,
  285. },
  286. ),
  287. ):
  288. log_dynamo_start(code)
  289. yield
  290. return _ctx()
  291. class CompilePackage:
  292. """
  293. CompilePackage is considered a low level component and should not be directly exposed to
  294. end users. It has the following interface:
  295. 1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states.
  296. a. when `dynamo` argument is None, it will construct a brand new CompilePackage object.
  297. b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state.
  298. 2. `package.save()` which dumps the dynamo and backend states to a DynamoCacheEntry object.
  299. 3. `package.install(backends) which will handle all the side-effectful global scope
  300. updates with compiled functions and resume functions.
  301. """
  302. def __init__(
  303. self,
  304. fn: Optional[Callable[..., Any]],
  305. dynamo: Optional[_DynamoCacheEntry] = None,
  306. ignore_inlined_sources: bool = False,
  307. ) -> None:
  308. self._innermost_fn = None
  309. self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {}
  310. self._current_entry: Optional[_DynamoCodeCacheEntry] = None
  311. self._installed_globals: dict[types.ModuleType, list[str]] = {}
  312. # For debugging/testing purpose only.
  313. self._cached_backends: dict[_BackendId, Any] = {}
  314. self._inlined_sources: set[InlinedSource] = set()
  315. self._resume_codes: set[types.CodeType] = set()
  316. self._initialized = False
  317. if fn is not None:
  318. self.initialize(fn, dynamo, ignore_inlined_sources)
  319. self.uninstall()
  320. self.validate()
  321. def is_initialized(self) -> bool:
  322. return self._initialized
  323. def initialize(
  324. self,
  325. fn: Any,
  326. dynamo: Optional[_DynamoCacheEntry] = None,
  327. ignore_inlined_sources: bool = False,
  328. ) -> None:
  329. from .eval_frame import innermost_fn
  330. assert not self._initialized
  331. self._inlined_sources = set()
  332. self._innermost_fn = innermost_fn(fn) # type: ignore[assignment]
  333. assert self._innermost_fn is not None
  334. if dynamo is not None:
  335. assert isinstance(dynamo, _DynamoCacheEntry)
  336. if dynamo.python_version != platform.python_version():
  337. raise RuntimeError(
  338. f"Compile package was created with a different Python version: {dynamo.python_version}"
  339. )
  340. if dynamo.torch_version != torch.__version__:
  341. raise RuntimeError(
  342. f"Compile package was created with a different PyTorch version: {dynamo.torch_version}"
  343. )
  344. if not ignore_inlined_sources:
  345. for code in dynamo.inlined_sources:
  346. m = importlib.import_module(code.module)
  347. checksum = _hash_sourcelines(m, code.firstlineno, code.lastlineno)
  348. if checksum != code.checksum:
  349. raise RuntimeError(
  350. f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
  351. )
  352. self._inlined_sources = dynamo.inlined_sources
  353. main, *codes = dynamo.codes
  354. self._codes = {self._innermost_fn.__code__: main}
  355. for code in codes:
  356. self._codes[SerializedCode.to_code_object(code.python_code)] = code
  357. else:
  358. self._add_function(
  359. self._innermost_fn.__code__, self._innermost_fn.__module__
  360. )
  361. self._initialized = True
  362. def _add_function(
  363. self,
  364. python_code: types.CodeType,
  365. python_module: str,
  366. function_name: Optional[_FunctionId] = None,
  367. code_source: Optional[str] = None,
  368. install_to_global: bool = False,
  369. ) -> None:
  370. if python_code not in self._codes:
  371. code = _DynamoCodeCacheEntry(
  372. python_code=SerializedCode.from_code_object(python_code),
  373. python_module=python_module,
  374. function_names=[],
  375. guarded_codes=[],
  376. import_sources={},
  377. backend_ids=[],
  378. code_source=code_source,
  379. install_to_global=install_to_global,
  380. )
  381. self._codes[python_code] = code
  382. else:
  383. code = self._codes[python_code]
  384. assert code.python_module == python_module
  385. assert code.install_to_global == install_to_global
  386. assert code.code_source == code_source
  387. if function_name is not None:
  388. code.function_names.append(function_name)
  389. @property
  390. def cached_backends(self) -> dict[_BackendId, Any]:
  391. return self._cached_backends
  392. @functools.cached_property
  393. def source_id(self) -> str:
  394. assert self._innermost_fn is not None
  395. return CompilePackage.source_id_from_fn(self._innermost_fn)
  396. def _add_user_function(self, code: types.CodeType) -> None:
  397. function_name, code_source = _get_code_source(code)
  398. module = inspect.getmodule(code)
  399. if module is None:
  400. raise PackageError(f"Cannot find module for code {code}")
  401. self._add_function(
  402. code,
  403. module.__name__,
  404. function_name=_FunctionId(function_name),
  405. code_source=code_source,
  406. )
  407. @contextlib.contextmanager
  408. def code_context(self, code: types.CodeType) -> Generator[None, None, None]:
  409. assert self._current_entry is None
  410. # Sometimes user code cannot be inlined in dynamo resulting in extra user code
  411. # being compiled. We should record these as when they are actually invoked.
  412. if code not in self._codes:
  413. self._add_user_function(code)
  414. entry = self._codes[code]
  415. self._current_entry = entry
  416. try:
  417. yield
  418. finally:
  419. if (
  420. entry.bypassed
  421. ): # Remove the code from the cache entry if it's been bypassed
  422. del self._codes[code]
  423. entry.has_compile_id = True
  424. self._current_entry = None
  425. def add_guarded_code(
  426. self,
  427. guards_state: bytes,
  428. dynamo_code: types.CodeType,
  429. ) -> None:
  430. assert self._current_entry is not None
  431. if self._current_entry.bypassed:
  432. return
  433. guarded_code_entry = _GuardedCodeCacheEntry(
  434. guards_state=guards_state,
  435. dynamo_code=SerializedCode.from_code_object(dynamo_code),
  436. )
  437. self._current_entry.guarded_codes.append(guarded_code_entry)
  438. def add_inlined_source(self, sources: list[types.CodeType]) -> None:
  439. assert self._current_entry is not None
  440. if self._current_entry.bypassed:
  441. return
  442. for code in sources:
  443. if code in self._resume_codes:
  444. continue
  445. module = inspect.getmodule(code)
  446. if module is None:
  447. continue
  448. sourcelines, firstlineno = inspect.getsourcelines(code)
  449. lastlineno = firstlineno + len(sourcelines)
  450. source = "".join(sourcelines)
  451. assert source == "".join(_get_sourcelines(module, firstlineno, lastlineno))
  452. self._inlined_sources.add(
  453. InlinedSource(
  454. module=module.__name__,
  455. firstlineno=firstlineno,
  456. lastlineno=lastlineno,
  457. checksum=_hash_source(source),
  458. )
  459. )
  460. def bypass_current_entry(self) -> None:
  461. assert self._current_entry is not None
  462. self._current_entry.bypassed = True
  463. def add_resume_function(
  464. self,
  465. python_code: types.CodeType,
  466. python_module: str,
  467. function_name: Optional[str],
  468. ) -> None:
  469. self._add_function(
  470. python_code,
  471. python_module,
  472. function_name=_FunctionId(function_name) if function_name else None,
  473. install_to_global=True,
  474. )
  475. self._resume_codes.add(python_code)
  476. def add_import_source(self, alias: str, module_name: str) -> None:
  477. assert self._current_entry is not None
  478. self._current_entry.import_sources[alias] = module_name
  479. def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None:
  480. assert self._current_entry is not None
  481. assert backend_id.startswith("__compiled_fn_") # sanity check
  482. backend_id = _BackendId(backend_id)
  483. self._current_entry.backend_ids.append(backend_id)
  484. if backend is not None:
  485. self._cached_backends[backend_id] = backend
  486. def validate(self) -> None:
  487. assert self._current_entry is None
  488. assert self._innermost_fn is not None
  489. assert self._initialized
  490. assert next(iter(self._codes)) is self._innermost_fn.__code__
  491. def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None:
  492. module.__dict__[name] = value
  493. self._installed_globals.setdefault(module, []).append(name)
  494. def uninstall(self) -> None:
  495. from torch._C._dynamo.eval_frame import _reset_precompile_entries
  496. assert self._innermost_fn is not None
  497. for module, names in self._installed_globals.items():
  498. for name in names:
  499. module.__dict__.pop(name)
  500. self._installed_globals = {}
  501. _reset_precompile_entries(self._innermost_fn.__code__)
  502. def install(self, backends: dict[_BackendId, Any]) -> None:
  503. """
  504. Sync the package states to the compiled function. This includes the following actions:
  505. 1. Clean up the previously installed states.
  506. 2. Install the compiled functions to global scopes.
  507. 3. Install the precompiled cache entries to ExtraStates on the code object.
  508. """
  509. from torch._C._dynamo.eval_frame import _load_precompile_entry
  510. from .output_graph import get_builtins_dict
  511. self.uninstall()
  512. for code, entry in self._codes.items():
  513. context = (
  514. _compile_frame_context(code)
  515. if entry.has_compile_id
  516. else contextlib.nullcontext()
  517. )
  518. with context:
  519. module = sys.modules[entry.python_module]
  520. for alias, module_name in entry.import_sources.items():
  521. self._install_global(
  522. module, alias, importlib.import_module(module_name)
  523. )
  524. target_code = code
  525. if entry.install_to_global:
  526. for function_name in entry.function_names:
  527. fn = types.FunctionType(code, module.__dict__, function_name)
  528. self._install_global(module, function_name, fn)
  529. if entry.code_source:
  530. target_code = _lookup_code(entry)
  531. for backend_id in entry.backend_ids:
  532. if backend_id not in backends:
  533. raise RuntimeError(
  534. f"Backend {backend_id} is not found in the given backends"
  535. )
  536. with dynamo_timed(
  537. "after_deserialization", phase_name="backend_compile"
  538. ):
  539. backend = backends[backend_id].after_deserialization()
  540. self._install_global(
  541. module,
  542. backend_id,
  543. torch._dynamo.disable(backend),
  544. )
  545. if len(entry.guarded_codes) == 0:
  546. # Dynamo generates empty graph for trivial functions, should just skip them
  547. # in these cases.
  548. torch._dynamo.eval_frame.skip_code(target_code)
  549. for guarded_code in entry.guarded_codes:
  550. guards_state = pickle.loads(guarded_code.guards_state)
  551. runtime_global_scope = sys.modules[entry.python_module].__dict__
  552. # The installed builtins dict might be absent from the runtime
  553. # while loading guards. Populate it if it's missing.
  554. if (
  555. builtin_dict_name
  556. := guards_state.output_graph.name_of_builtins_dict_key_in_fglobals
  557. ):
  558. builtins_dict = get_builtins_dict(runtime_global_scope)
  559. if builtin_dict_name in runtime_global_scope:
  560. assert (
  561. runtime_global_scope[builtin_dict_name] is builtins_dict
  562. )
  563. else:
  564. runtime_global_scope[builtin_dict_name] = builtins_dict
  565. assert isinstance(guards_state, torch._dynamo.guards.GuardsState)
  566. check_fn_manager = torch._dynamo.guards.CheckFunctionManager(
  567. target_code,
  568. guards_state.output_graph,
  569. shape_code_parts=guards_state.shape_code_parts,
  570. runtime_global_scope=runtime_global_scope,
  571. )
  572. _load_precompile_entry(
  573. target_code,
  574. check_fn_manager.guard_manager,
  575. SerializedCode.to_code_object(guarded_code.dynamo_code),
  576. )
  577. def cache_entry(self) -> _DynamoCacheEntry:
  578. self.validate()
  579. return _DynamoCacheEntry(
  580. codes=list(self._codes.values()), inlined_sources=self._inlined_sources
  581. )
  582. @staticmethod
  583. def source_id_from_fn(fn: Callable[..., Any]) -> str:
  584. from .eval_frame import innermost_fn
  585. innermost_fn_ = innermost_fn(fn)
  586. sha256_hash = hashlib.sha256()
  587. sha256_hash.update(innermost_fn_.__qualname__.encode())
  588. sha256_hash.update(str(innermost_fn_.__code__.co_firstlineno).encode())
  589. return sha256_hash.hexdigest()
  590. @CacheArtifactFactory.register
  591. class EagerCacheArtifact(PrecompileCacheArtifact[Any]):
  592. @staticmethod
  593. def type() -> str:
  594. return "precompile_eager"
  595. def after_deserialization(self) -> Any:
  596. return pickle.loads(self.content)
  597. _Backends = dict[_BackendId, PrecompileCacheArtifact[Any]]
  598. class DynamoStore(abc.ABC):
  599. """
  600. A DynamoStore tracks active CompilePackages, and provides methods to store and retrieve them.
  601. This is an abstract base class for different storage implementations.
  602. """
  603. def record_package(self, package: CompilePackage) -> None:
  604. """
  605. Records a package to PrecompileContext, so that it can be serialized later.
  606. """
  607. cache_entry = package.cache_entry()
  608. pickled_result = pickle.dumps(cache_entry)
  609. PrecompileContext.record_artifact(
  610. _DynamoCacheArtifact.type(), key=package.source_id, content=pickled_result
  611. )
  612. def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None:
  613. """
  614. Records eager fx graphs to PrecompileContext for testing purposes.
  615. """
  616. pickled_result = pickle.dumps(backend)
  617. PrecompileContext.record_artifact(
  618. EagerCacheArtifact.type(), key=backend_id, content=pickled_result
  619. )
  620. @abc.abstractmethod
  621. def clear(self) -> None: ...
  622. @abc.abstractmethod
  623. def write(
  624. self,
  625. dynamo: _DynamoCacheEntry,
  626. backends: _Backends,
  627. path: str,
  628. ) -> None:
  629. """
  630. Abstract method to write dynamo cache entry and backends to storage.
  631. Args:
  632. dynamo: The dynamo cache entry to write
  633. backends: Dictionary of backend content to write
  634. path: Path or key to identify where to write the data
  635. """
  636. ...
  637. def save_cache_entry(self, cache_entry: _DynamoCacheEntry, key: str) -> None:
  638. """
  639. Saves a package to a given path. Grabs backends from PrecompileContext.
  640. """
  641. backend_content: _Backends = {}
  642. for backend_id in cache_entry.backend_ids:
  643. serialized_backend = PrecompileContext.serialize_artifact_by_key(backend_id)
  644. if serialized_backend is None:
  645. raise RuntimeError(
  646. f"Backend {backend_id} is not found in the given backends"
  647. )
  648. assert isinstance(serialized_backend, PrecompileCacheArtifact)
  649. backend_content[backend_id] = serialized_backend
  650. self.write(cache_entry, backend_content, key)
  651. def save_package(self, package: CompilePackage, key: str) -> None:
  652. """
  653. Saves a package to a given path. Grabs backends from PrecompileContext.
  654. """
  655. self.record_package(package)
  656. cache_entry = package.cache_entry()
  657. self.save_cache_entry(cache_entry, key)
  658. @abc.abstractmethod
  659. def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
  660. """
  661. Abstract method to read dynamo cache entry and backends from storage.
  662. Args:
  663. path: Path or key to identify where to read the data from
  664. Returns:
  665. A tuple containing (dynamo_cache_entry, backend_content)
  666. """
  667. ...
  668. def load_cache_entry(
  669. self, key: str
  670. ) -> tuple[_DynamoCacheEntry, dict[_BackendId, Any]]:
  671. cache_entry, backend_content = self.read(key)
  672. for backend_id, backend in backend_content.items():
  673. PrecompileContext.record_artifact(
  674. backend.type(), key=backend.key, content=backend.content
  675. )
  676. backend_content[backend_id] = backend
  677. return cache_entry, backend_content
  678. def load_package(
  679. self, fn: Any, key: str
  680. ) -> tuple[CompilePackage, dict[_BackendId, Any]]:
  681. """
  682. Loads a package from a given path and returns it plus a list of deserialized backends
  683. """
  684. cache_entry, backend_content = self.load_cache_entry(key)
  685. package = CompilePackage(fn, cache_entry)
  686. return package, backend_content
  687. class InMemoryDynamoStore(DynamoStore):
  688. """
  689. A DynamoStore implementation that keeps state about CompilePackages in memory.
  690. """
  691. def __init__(self) -> None:
  692. self.packages: dict[str, tuple[_DynamoCacheEntry, _Backends]] = {}
  693. def clear(self) -> None:
  694. self.packages.clear()
  695. def write(
  696. self,
  697. dynamo: _DynamoCacheEntry,
  698. backends: _Backends,
  699. path: str,
  700. ) -> None:
  701. """
  702. Store the dynamo cache entry and backends in memory instead of writing to disk.
  703. """
  704. self.packages[path] = (dynamo, backends)
  705. def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
  706. """
  707. Read dynamo cache entry and backends from memory.
  708. """
  709. if path not in self.packages:
  710. raise RuntimeError(f"No package found with key {path}")
  711. return self.packages[path]
  712. class DiskDynamoStore(DynamoStore):
  713. """
  714. A DynamoStore implementation that keeps state about CompilePackages on disk.
  715. """
  716. def __init__(self, path_prefix: str = ""):
  717. """
  718. Initialize a DiskDynamoStore with a path prefix.
  719. Args:
  720. path_prefix: Prefix directory for where to put CompilePackages on disk
  721. """
  722. self.path_prefix = path_prefix
  723. def clear(self) -> None:
  724. """
  725. Clear all CompilePackages from disk.
  726. """
  727. if self.path_prefix:
  728. shutil.rmtree(self.path_prefix, ignore_errors=True)
  729. def write(
  730. self,
  731. dynamo: _DynamoCacheEntry,
  732. backends: _Backends,
  733. path: str,
  734. ) -> None:
  735. """
  736. Write dynamo cache entry and backends to disk.
  737. """
  738. path = os.path.join(self.path_prefix, path) if self.path_prefix else path
  739. try:
  740. os.makedirs(path, exist_ok=True)
  741. with open(os.path.join(path, "dynamo"), "wb") as dynamo_path:
  742. pickle.dump(dynamo, dynamo_path)
  743. with open(os.path.join(path, "backends"), "wb") as backend_path:
  744. pickle.dump(backends, backend_path)
  745. except Exception as e:
  746. raise RuntimeError(f"Failed to save package to {path}: {e}") from e
  747. def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
  748. """
  749. Read dynamo cache entry and backends from disk.
  750. """
  751. path = os.path.join(self.path_prefix, path) if self.path_prefix else path
  752. try:
  753. with open(os.path.join(path, "dynamo"), "rb") as dynamo_path:
  754. cache_entry = pickle.load(dynamo_path)
  755. with open(os.path.join(path, "backends"), "rb") as backend_path:
  756. backend_content = pickle.load(backend_path)
  757. return cache_entry, backend_content
  758. except Exception as e:
  759. raise RuntimeError(f"Failed to load package from path {path}: {e}") from e
  760. class DiskDynamoCache(DiskDynamoStore):
  761. """
  762. Special DiskDynamoStore which adds some helper functions for automatically
  763. tracking paths of packages
  764. """
  765. def save(self, package: CompilePackage) -> None:
  766. """
  767. Saves a package to a given path. Grabs backends from PrecompileContext.
  768. """
  769. key = package.source_id
  770. logger.info("Saving CompilePackage for %s", package.source_id)
  771. super().save_package(package, key)
  772. def load(
  773. self, fn: Callable[..., Any]
  774. ) -> Optional[tuple[_DynamoCacheEntry, dict[_BackendId, Any]]]:
  775. """
  776. Loads a package from a given path and returns it plus a list of deserialized backends
  777. """
  778. key = CompilePackage.source_id_from_fn(fn)
  779. logger.info("Loading CompilePackage for %s", key)
  780. path = os.path.join(self.path_prefix, key)
  781. if os.path.exists(path):
  782. try:
  783. result = super().load_cache_entry(key)
  784. return result
  785. except Exception as e:
  786. logger.warning("Failed to load package from path %s: %s", path, str(e))
  787. return None
  788. logger.info("No package found for %s", key)
  789. return None
  790. def load_and_install_package(
  791. self, fn: Callable[..., Any]
  792. ) -> Optional[CompilePackage]:
  793. """
  794. Load directly into a package and install backends
  795. """
  796. results = self.load(fn)
  797. if results is None:
  798. return None
  799. else:
  800. (entry, backends) = results
  801. package = CompilePackage(fn, entry)
  802. package.install(backends)
  803. return package
  804. DynamoCache = DiskDynamoCache(os.path.join(cache_dir(), "dynamo"))