package.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157
  1. """
This module provides the infrastructure for creating and managing compile packages
for torch.compile. There are two main abstractions:
- CompilePackage: the overarching data structure for storing and looking up a list of compiled codes.
- _DynamoCodeCacheEntry: the data structure for a single code object compiled by torch.compile.
Caching is always explicitly under user control, so that a stronger guarantee can be
provided about cache hits for a specific compiled model. Users can load a compile package
from a different process or host.
  9. """
  10. import abc
  11. import ast
  12. import contextlib
  13. import dataclasses
  14. import functools
  15. import hashlib
  16. import importlib
  17. import inspect
  18. import json
  19. import logging
  20. import os
  21. import pickle
  22. import platform
  23. import shutil
  24. import sys
  25. import types
  26. from collections.abc import Callable, Generator, Iterator
  27. from contextlib import nullcontext
  28. from typing import Any, NewType, Optional, TYPE_CHECKING
  29. from typing_extensions import Never
  30. import torch
  31. from torch._dynamo.exc import PackageError
  32. from torch._dynamo.graph_utils import _graph_device_type
  33. from .bytecode_transformation import get_code_keys
  34. from .utils import counters, dynamo_timed, increment_frame
  35. logger = logging.getLogger(__name__)
  36. if TYPE_CHECKING:
  37. from .guards import GuardManagerWrapper, GuardsState
  38. @dataclasses.dataclass(frozen=True)
  39. class SerializedCode:
  40. co_argcount: int
  41. co_posonlyargcount: int
  42. co_kwonlyargcount: int
  43. co_nlocals: int
  44. co_stacksize: int
  45. co_flags: int
  46. co_code: bytes
  47. co_consts: tuple[Any, ...]
  48. co_names: tuple[str, ...]
  49. co_varnames: tuple[str, ...]
  50. co_filename: str
  51. co_name: str
  52. co_firstlineno: int
  53. co_cellvars: tuple[str, ...]
  54. co_freevars: tuple[str, ...]
  55. co_linetable: Optional[bytes] = None
  56. co_qualname: Optional[str] = None
  57. co_exceptiontable: Optional[bytes] = None
  58. co_lnotab: Optional[str] = None
  59. @classmethod
  60. @functools.cache
  61. def from_code_object(cls, code: types.CodeType) -> "SerializedCode":
  62. kwargs = {key: getattr(code, key) for key in get_code_keys()}
  63. kwargs["co_consts"] = tuple(
  64. cls.from_code_object(c) if isinstance(c, types.CodeType) else c
  65. for c in kwargs["co_consts"]
  66. )
  67. return cls(**kwargs)
  68. @classmethod
  69. @functools.cache
  70. def to_code_object(cls, serialized_code: "SerializedCode") -> types.CodeType:
  71. kwargs = {key: getattr(serialized_code, key) for key in get_code_keys()}
  72. kwargs["co_consts"] = tuple(
  73. cls.to_code_object(c) if isinstance(c, SerializedCode) else c
  74. for c in kwargs["co_consts"]
  75. )
  76. return types.CodeType(
  77. *kwargs.values(),
  78. )
@dataclasses.dataclass
class _GuardedCodeCacheEntry:
    """
    Contains the serializable information associated with a single compilation in dynamo.
    To restore an execution of compiled code, we will need to serialize the following data:
    - Dynamo bytecode for mapping Python inputs/outputs.
    - Dynamo guards.

    Instances are appended by ``CompilePackage.add_guarded_code`` and consumed by
    ``CompilePackage.install``.
    """

    # Pickled guard state; decoded at install time with ``load_guards_state``.
    guards_state: bytes
    # Dynamo-generated bytecode. At install time it is rebuilt with
    # ``SerializedCode.to_code_object`` and registered together with the guard
    # manager built from ``guards_state``.
    dynamo_code: SerializedCode
  89. def load_guards_state(guards_state: bytes) -> Any:
  90. try:
  91. import torch.distributed.fsdp._fully_shard._fully_shard as _fully_shard
  92. ctx = _fully_shard.disable_fsdp_module_new_init()
  93. except ImportError:
  94. ctx = nullcontext() # type: ignore[assignment]
  95. with ctx:
  96. return pickle.loads(guards_state)
  97. def load_guard_manager(
  98. guards_state: "GuardsState",
  99. target_code: types.CodeType,
  100. runtime_global_scope: Any,
  101. ) -> "GuardManagerWrapper":
  102. from .output_graph import OutputGraphCommon
  103. return torch._dynamo.guards.CheckFunctionManager(
  104. target_code,
  105. OutputGraphCommon(guards_state.output_graph),
  106. shape_code_parts=guards_state.shape_code_parts,
  107. runtime_global_scope=runtime_global_scope,
  108. source_get_cache=guards_state.source_get_cache,
  109. ).guard_manager
# Global name under which a compiled backend is installed into a module;
# ``CompilePackage.add_backend_id`` asserts it starts with "__compiled_fn_".
_BackendId = NewType("_BackendId", str)  # __compiled_fn
# Dotted function name looked up from a module's globals: either a user
# function's qualname or a dynamo-generated name (presumably "__resume_at..."
# resume functions, per the trailing comment).
_FunctionId = NewType("_FunctionId", str)  # __resume_at
@dataclasses.dataclass(frozen=True)
class InlinedSource:
    """
    A snippet of user source traced by dynamo, recorded so that a later load can
    detect whether that source has changed (see ``SourceInfo.add_code`` and
    ``CompilePackage.initialize``). Frozen, so instances are hashable and can be
    stored in ``SourceInfo.inlined_sources``.
    """

    # ``__name__`` of the module that defines the code.
    module: str
    # 1-based first line of the code's source within the module.
    firstlineno: int
    # ``firstlineno + len(sourcelines)``, i.e. one past the last line (exclusive).
    lastlineno: int
    # SHA-256 hex digest of the snippet's source text.
    checksum: str
    # Full source text of the defining module when the snippet was recorded.
    content: str
  119. @functools.cache
  120. def _get_module_content(module: types.ModuleType) -> str:
  121. return inspect.getsource(module)
  122. @dataclasses.dataclass
  123. class SourceInfo:
  124. inlined_sources: set[InlinedSource]
  125. def add_code(self, code: types.CodeType) -> None:
  126. module = inspect.getmodule(code)
  127. if module is None:
  128. return
  129. sourcelines, firstlineno = inspect.getsourcelines(code)
  130. lastlineno = firstlineno + len(sourcelines)
  131. source = "".join(sourcelines)
  132. assert source == "".join(_get_sourcelines(module, firstlineno, lastlineno))
  133. self.inlined_sources.add(
  134. InlinedSource(
  135. module=module.__name__,
  136. firstlineno=firstlineno,
  137. lastlineno=lastlineno,
  138. checksum=_hash_source(source),
  139. content=_get_module_content(module),
  140. )
  141. )
@dataclasses.dataclass
class _DynamoCodeCacheEntry:
    """
    Contains the serializable information associated with a single code object
    in dynamo. To restore an execution of compiled code, we will need the following
    ingredients:
    1. The "original" code object, which serves as the entry point for eager
       execution, i.e. the code only executed when there's no cache entry hit.
    2. The python module name this code object belongs to, for identifying the
       enclosing global scope to inject compiled and resume functions.
    3. A list of function names that point to this code object. There could be
       multiple function objects pointing to the same code, such as recursive functions.
    4. A list of guarded code that eval frame dispatches to.
    5. A mapping from import alias to module name, unioned from all compiled branches.
    6. A list of "backends" (compiled fx graphs) unioned from all compiled branches.
    7. A string path used to access the original code object users defined.
       A code object can be accessed by "{python_module}.{function_name}.{code_source}".
    8. A boolean flag indicating whether the function is installed to global scope.
    9. A boolean flag indicating whether the function has a compile id.
    10. Whether or not this code entry was bypassed.
    """

    python_code: SerializedCode
    python_module: str
    function_names: list[_FunctionId]
    guarded_codes: list[_GuardedCodeCacheEntry]
    # alias -> module name; each alias is installed into the module's globals on install.
    import_sources: dict[str, str]
    backend_ids: list[_BackendId]
    # Dotted path from the function to the code object (see _get_code_source); None
    # for entries registered without one (e.g. via add_resume_function).
    code_source: Optional[str]
    install_to_global: bool
    # Set once the entry has been compiled under CompilePackage.code_context.
    has_compile_id: bool = False
    # When True, install() skips this entry's backends and guarded codes.
    bypassed: bool = False
  173. def _lookup_code(entry: _DynamoCodeCacheEntry) -> types.CodeType:
  174. assert len(entry.function_names) == 1
  175. fn: Any = sys.modules[entry.python_module]
  176. parts = entry.function_names[0].split(".")
  177. for part in parts:
  178. fn = getattr(fn, part)
  179. if entry.code_source:
  180. parts = entry.code_source.split(".")
  181. for part in parts:
  182. if part.endswith("]"):
  183. index_begin = part.rfind("[")
  184. assert isinstance(index_begin, int) and index_begin >= 0
  185. attr = getattr(fn, part[:index_begin], None)
  186. if attr is None:
  187. raise PackageError(f"Cannot find source for code entry {entry}")
  188. fn = attr[ast.literal_eval(part[index_begin + 1 : -1])]
  189. else:
  190. fn = getattr(fn, part)
  191. else:
  192. raise PackageError(f"Cannot find source for code entry {entry}")
  193. assert isinstance(fn, types.CodeType)
  194. return fn
  195. def _raise_resolution_error(code: types.CodeType, scope: Any) -> Never:
  196. raise PackageError(
  197. f"Cannot resolve a fully qualified name for {code}. Lookup scope: {scope}"
  198. )
def _get_code_source(code: types.CodeType) -> tuple[str, str]:
    """
    Given a code object, return a fully qualified name which will be used as
    a serialized handle to access the code object from the new process.

    Returns a ``(qualname, path)`` pair:
    - ``qualname`` is the ``__qualname__`` of the function object from which
      *code* is reachable.
    - ``path`` is a dotted attribute/index path from that function to *code*,
      e.g. ``"__code__"``, ``"__code__.co_consts[1]"`` or
      ``"__closure__[0].cell_contents.__code__"``. ``_lookup_code`` walks this
      path when loading.

    This is normally a straightforward process, but there are some corner cases:
    1. When a function is defined with a decorator, the function is captured
       inside a closure of the wrapper object.
    2. When a function is defined as a nested function, its code object is
       stored in the co_consts field of the parent code object by the Python compiler.
    This function handles all of the corner cases above.

    Raises:
        PackageError: if the defining module or a resolvable path cannot be found.
    """
    module = inspect.getmodule(code)
    if module is None:
        raise PackageError(f"Cannot find module for code {code}")

    toplevel: Any = module
    if sys.version_info >= (3, 11):
        # Walk co_qualname from the module and stop at the first function. Nested
        # qualnames continue with "<locals>", which is not a real attribute.
        parts = code.co_qualname.split(".")

        for part in parts:
            if not hasattr(toplevel, part):
                _raise_resolution_error(code, toplevel)
            toplevel = getattr(toplevel, part)
            if inspect.isfunction(toplevel):
                break
    # Objects already visited, so recursion cannot loop over shared references.
    seen: set[Any] = set()

    # Returns the "."-prefixed path from obj to `code`, or None if not reachable.
    # As a side effect, rebinds `toplevel` to the function the path starts from.
    def _find_code_source(obj: Any) -> Optional[str]:
        nonlocal toplevel
        nonlocal seen
        if obj in seen:
            return None

        seen.add(obj)

        if inspect.iscode(obj):
            if obj is code:
                return ""

            # Nested functions' code objects live in the parent's co_consts.
            for i, const in enumerate(obj.co_consts):
                if (res := _find_code_source(const)) is not None:
                    return f".co_consts[{i}]{res}"

        if inspect.isfunction(obj):
            if (res := _find_code_source(obj.__code__)) is not None:
                toplevel = obj
                return f".__code__{res}"
            # Decorated functions are typically reachable through the wrapper's
            # closure cells.
            if obj.__closure__ is not None:
                for i, cell in enumerate(obj.__closure__):
                    try:
                        cell_contents = cell.cell_contents
                    except ValueError:
                        # Empty (not yet assigned) cell.
                        continue
                    if not (
                        inspect.isfunction(cell_contents)
                        or inspect.iscode(cell_contents)
                    ):
                        continue
                    if (res := _find_code_source(cell_contents)) is not None:
                        toplevel = obj
                        return f".__closure__[{i}].cell_contents{res}"

        if sys.version_info < (3, 11):
            # Without co_qualname, search module globals and class bodies instead.
            if inspect.ismodule(obj):
                for value in obj.__dict__.values():
                    if not (inspect.isfunction(value) or inspect.isclass(value)):
                        continue
                    if (res := _find_code_source(value)) is not None:
                        return res

            if inspect.isclass(obj):
                for name, value in obj.__dict__.items():
                    value = getattr(obj, name)
                    if not (inspect.isfunction(value) or inspect.isclass(value)):
                        continue
                    if (res := _find_code_source(value)) is not None:
                        # NOTE(review): presumably rejects aliased attributes whose
                        # __name__ differs, since those would not resolve by name.
                        if value.__name__ != name:
                            _raise_resolution_error(code, toplevel)
                        return res
        return None

    code_source = _find_code_source(toplevel)
    if code_source is None:
        _raise_resolution_error(code, toplevel)
    # pyrefly: ignore [missing-attribute]
    return toplevel.__qualname__, code_source.strip(".")
  275. @dataclasses.dataclass(frozen=True)
  276. class SystemInfo:
  277. """
  278. System information including Python, PyTorch, and GPU details.
  279. This information is used to ensure compiled artifacts can only be loaded
  280. with compatible system configurations.
  281. """
  282. python_version: str
  283. torch_version: str
  284. toolkit_version: Optional[str]
  285. triton_version: Optional[tuple[int, int]]
  286. gpu_name: Optional[str]
  287. CHECK_GPUS = ("cuda", "xpu")
  288. @classmethod
  289. def current(cls) -> "SystemInfo":
  290. """Create a SystemInfo instance with current system information."""
  291. # Get GPU name if CUDA or XPU is available
  292. gpu_name = None
  293. from torch.utils._triton import get_triton_version
  294. gpu_name, toolkit_version = None, None
  295. for device_type in cls.CHECK_GPUS:
  296. if getattr(torch, device_type).is_available():
  297. try:
  298. gpu_name = getattr(torch, device_type).get_device_name()
  299. toolkit_version = getattr(torch.version, device_type)
  300. break
  301. except Exception:
  302. pass
  303. return cls(
  304. python_version=platform.python_version(),
  305. torch_version=torch.__version__,
  306. toolkit_version=toolkit_version,
  307. triton_version=get_triton_version((0, 0)),
  308. gpu_name=gpu_name,
  309. )
  310. def check_compatibility(
  311. self, other: "SystemInfo", device_type: str = "cpu"
  312. ) -> None:
  313. """
  314. Check if this SystemInfo is compatible with another SystemInfo.
  315. Raises RuntimeError if incompatible.
  316. """
  317. if self.python_version != other.python_version:
  318. raise RuntimeError(
  319. f"Compile package was created with a different Python version: {self.python_version}"
  320. )
  321. if self.torch_version != other.torch_version:
  322. raise RuntimeError(
  323. f"Compile package was created with a different PyTorch version: {self.torch_version}"
  324. )
  325. if device_type in self.CHECK_GPUS:
  326. if not getattr(torch, device_type).is_available():
  327. raise RuntimeError(f"{device_type} is not available")
  328. if self.toolkit_version != other.toolkit_version:
  329. raise RuntimeError(
  330. f"Compile package was created with a different toolkit version: {self.toolkit_version}"
  331. )
  332. if (
  333. other.triton_version != (0, 0)
  334. and self.triton_version != other.triton_version
  335. ):
  336. raise RuntimeError(
  337. f"Compile package was created with a different Triton version: {self.triton_version}"
  338. )
  339. # Check GPU name if CUDA/XPU was used
  340. if other.gpu_name is not None and self.gpu_name != other.gpu_name:
  341. raise RuntimeError(
  342. f"Compile package was created with different GPU: "
  343. f"cached={self.gpu_name}, current={other.gpu_name}"
  344. )
  345. @dataclasses.dataclass
  346. class _DynamoCacheEntry:
  347. codes: list[_DynamoCodeCacheEntry]
  348. source_info: SourceInfo
  349. device_type: str
  350. system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
  351. fn_name: Optional[str] = None
  352. fn_first_lineno: Optional[str] = None
  353. @property
  354. def backend_ids(self) -> set[_BackendId]:
  355. return {backend_id for code in self.codes for backend_id in code.backend_ids}
  356. def check_versions(self) -> None:
  357. """Check if the current system is compatible with the system used to create this cache entry."""
  358. current_system_info = SystemInfo.current()
  359. self.system_info.check_compatibility(current_system_info, self.device_type)
  360. def debug_info(self) -> dict[str, Any]:
  361. assert len(self.codes) > 0
  362. return {
  363. "num_codes": str(len(self.codes)),
  364. "fn_name": self.fn_name,
  365. "fn_first_lineno": self.fn_first_lineno,
  366. "device_type": self.device_type,
  367. "backend_ids": list(self.backend_ids),
  368. }
  369. from torch.compiler._cache import (
  370. CacheArtifact,
  371. CacheArtifactFactory,
  372. CacheArtifactManager,
  373. )
  374. @CacheArtifactFactory.register
  375. class PrecompileCacheArtifact(CacheArtifact):
  376. def populate_cache(self) -> None:
  377. DynamoCache._write_to_local_cache(self.content, self.key)
  378. @staticmethod
  379. def type() -> str:
  380. return "precompile"
  381. @dataclasses.dataclass
  382. class PrecompileCacheEntry:
  383. """
  384. A full cache entry for caching precompile, for a toplevel torch.compile.
  385. Consists of a _DynamoCacheEntry, which contains all the dynamo related contents,
  386. and a set of backends content. In general, the backend content here will always
  387. be of type precompile_context.BackendCacheArtifact
  388. """
  389. dynamo: _DynamoCacheEntry
  390. backends: dict[_BackendId, Any]
  391. @staticmethod
  392. def from_cache_entry(
  393. cache_entry: _DynamoCacheEntry, backends: dict[_BackendId, Any]
  394. ) -> Optional["PrecompileCacheEntry"]:
  395. backend_content: dict[_BackendId, Any] = {}
  396. for code in cache_entry.codes:
  397. for backend_id in code.backend_ids:
  398. if backend_id not in backends:
  399. logger.warning("Backend not found")
  400. debug_str = json.dumps(
  401. {
  402. "entry": cache_entry.debug_info(),
  403. "missing_backend": backend_id,
  404. }
  405. )
  406. torch._logging.trace_structured(
  407. "artifact",
  408. metadata_fn=lambda: {
  409. "name": "dynamo_cache_bypass",
  410. "encoding": "json",
  411. },
  412. payload_fn=lambda: debug_str,
  413. expect_trace_id=False,
  414. )
  415. code.bypassed = True
  416. break
  417. else:
  418. backend_content[backend_id] = backends[backend_id]
  419. return PrecompileCacheEntry(dynamo=cache_entry, backends=backend_content)
  420. def _hash_source(source: str) -> str:
  421. sha256_hash = hashlib.sha256()
  422. sha256_hash.update(source.encode())
  423. return sha256_hash.hexdigest()
  424. def _get_sourcelines(
  425. m: types.ModuleType, firstlineno: int, lastlineno: int
  426. ) -> list[str]:
  427. return inspect.getsourcelines(m)[0][firstlineno - 1 : lastlineno - 1]
  428. def _hash_sourcelines(m: types.ModuleType, firstlineno: int, lastlineno: int) -> str:
  429. return _hash_source("".join(_get_sourcelines(m, firstlineno, lastlineno)))
  430. def _compile_frame_context(
  431. code: types.CodeType,
  432. ) -> contextlib.AbstractContextManager[None]:
  433. from torch._dynamo.convert_frame import get_compile_id, log_dynamo_start
  434. from torch._guards import compile_context, CompileContext
  435. # Each code represents a new compile frame
  436. # recompiles on the same frame are all saved
  437. # under the same cache entry, so we don't have recompile ids
  438. # i.e. If cold start had 0/0, 0/1, 1/0, 1/1, these would be
  439. # collapsed into 0/0, 1/0 on warm.
  440. @contextlib.contextmanager
  441. def _ctx() -> Iterator[None]:
  442. increment_frame()
  443. compile_id = get_compile_id(frame_state={})
  444. with (
  445. compile_context(CompileContext(compile_id)),
  446. dynamo_timed(
  447. "_compile.compile_inner",
  448. phase_name="entire_frame_compile",
  449. dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
  450. # TODO: save all relevant compilation metrics
  451. metadata={
  452. "frame_key": str(torch._dynamo.utils.curr_frame),
  453. "co_name": code.co_name,
  454. "co_filename": code.co_filename,
  455. "co_firstlineno": code.co_firstlineno,
  456. },
  457. ),
  458. ):
  459. log_dynamo_start(code)
  460. yield
  461. return _ctx()
class CompilePackage:
    """
    CompilePackage is considered a low level component and should not be directly exposed to
    end users. It has the following interface:
      1. `CompilePackage.__init__()`, which optionally takes previously serialized dynamo states.
         a. When the `dynamo` argument is None, it constructs a brand new CompilePackage object.
         b. When the `dynamo` argument is not None, it loads a pre-compiled dynamo state.
      2. `package.save()`, which dumps the dynamo and backend states to a DynamoCacheEntry object.
         NOTE(review): the method that builds a `_DynamoCacheEntry` in the visible part of
         this class is `cache_entry()`; confirm `save()` still exists.
      3. `package.install(backends)`, which handles all the side-effectful global scope
         updates with compiled functions and resume functions.
    """
  473. def __init__(
  474. self,
  475. fn: Optional[Callable[..., Any]],
  476. dynamo: Optional[_DynamoCacheEntry] = None,
  477. ignore_inlined_sources: bool = False,
  478. ) -> None:
  479. self._innermost_fn = None
  480. self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {}
  481. self._current_entry: Optional[_DynamoCodeCacheEntry] = None
  482. self._installed_globals: dict[types.ModuleType, list[str]] = {}
  483. # device_type that model compiled with.
  484. self._device_type = "cpu"
  485. # For debugging/testing purpose only.
  486. self._cached_backends: dict[_BackendId, Any] = {}
  487. self._source_info: SourceInfo = SourceInfo(inlined_sources=set())
  488. self._resume_codes: set[types.CodeType] = set()
  489. self._initialized = False
  490. if fn is not None:
  491. self.initialize(fn, dynamo, ignore_inlined_sources)
  492. self.uninstall()
  493. self.validate()
  494. def is_initialized(self) -> bool:
  495. return self._initialized
  496. def initialize(
  497. self,
  498. fn: Any,
  499. dynamo: Optional[_DynamoCacheEntry] = None,
  500. ignore_inlined_sources: bool = False,
  501. ) -> None:
  502. from .eval_frame import innermost_fn
  503. assert not self._initialized
  504. self._source_info = SourceInfo(inlined_sources=set())
  505. self._innermost_fn = innermost_fn(fn) # type: ignore[assignment]
  506. assert self._innermost_fn is not None
  507. if dynamo is not None:
  508. assert isinstance(dynamo, _DynamoCacheEntry)
  509. dynamo.check_versions()
  510. if not ignore_inlined_sources:
  511. for code in dynamo.source_info.inlined_sources:
  512. m = importlib.import_module(code.module)
  513. checksum = _hash_sourcelines(m, code.firstlineno, code.lastlineno)
  514. if checksum != code.checksum:
  515. raise RuntimeError(
  516. f"Source code changes detected for {code.module} (line {code.firstlineno} - line {code.lastlineno})"
  517. )
  518. # pyrefly: ignore [bad-assignment]
  519. self._source_info = dynamo.source_info
  520. main, *codes = dynamo.codes
  521. # pyrefly: ignore [bad-assignment]
  522. self._codes = {self._innermost_fn.__code__: main}
  523. for code in codes:
  524. self._codes[SerializedCode.to_code_object(code.python_code)] = code
  525. else:
  526. self._add_function(
  527. self._innermost_fn.__code__, self._innermost_fn.__module__
  528. )
  529. # pyrefly: ignore [bad-assignment]
  530. self._initialized = True
  531. def _add_function(
  532. self,
  533. python_code: types.CodeType,
  534. python_module: str,
  535. function_name: Optional[_FunctionId] = None,
  536. code_source: Optional[str] = None,
  537. install_to_global: bool = False,
  538. ) -> None:
  539. if python_code not in self._codes:
  540. code = _DynamoCodeCacheEntry(
  541. python_code=SerializedCode.from_code_object(python_code),
  542. python_module=python_module,
  543. function_names=[],
  544. guarded_codes=[],
  545. import_sources={},
  546. backend_ids=[],
  547. code_source=code_source,
  548. install_to_global=install_to_global,
  549. )
  550. self._codes[python_code] = code
  551. else:
  552. code = self._codes[python_code]
  553. assert code.python_module == python_module
  554. assert code.install_to_global == install_to_global
  555. assert code.code_source == code_source
  556. if function_name is not None:
  557. code.function_names.append(function_name)
  558. @property
  559. def cached_backends(self) -> dict[_BackendId, Any]:
  560. return self._cached_backends
  561. @functools.cached_property
  562. def source_id(self) -> str:
  563. assert self._innermost_fn is not None
  564. return CompilePackage.source_id_from_fn(self._innermost_fn)
  565. def _add_user_function(self, code: types.CodeType) -> None:
  566. function_name, code_source = _get_code_source(code)
  567. module = inspect.getmodule(code)
  568. if module is None:
  569. raise PackageError(f"Cannot find module for code {code}")
  570. self._add_function(
  571. code,
  572. module.__name__,
  573. function_name=_FunctionId(function_name),
  574. code_source=code_source,
  575. )
  576. @contextlib.contextmanager
  577. def code_context(self, code: types.CodeType) -> Generator[None, None, None]:
  578. assert self._current_entry is None
  579. # Sometimes user code cannot be inlined in dynamo resulting in extra user code
  580. # being compiled. We should record these as when they are actually invoked.
  581. if code not in self._codes:
  582. self._add_user_function(code)
  583. entry = self._codes[code]
  584. self._current_entry = entry
  585. try:
  586. yield
  587. finally:
  588. entry.has_compile_id = True
  589. self._current_entry = None
  590. def add_guarded_code(
  591. self,
  592. guards_state: bytes,
  593. dynamo_code: types.CodeType,
  594. ) -> None:
  595. assert self._current_entry is not None
  596. if self._current_entry.bypassed:
  597. return
  598. guarded_code_entry = _GuardedCodeCacheEntry(
  599. guards_state=guards_state,
  600. dynamo_code=SerializedCode.from_code_object(dynamo_code),
  601. )
  602. self._current_entry.guarded_codes.append(guarded_code_entry)
  603. def add_inlined_source(self, sources: list[types.CodeType]) -> None:
  604. assert self._current_entry is not None
  605. if self._current_entry.bypassed:
  606. return
  607. for code in sources:
  608. if code in self._resume_codes:
  609. continue
  610. self._source_info.add_code(code)
  611. def update_device_type(self, graph: Optional[torch.fx.Graph]) -> None:
  612. self._device_type = _graph_device_type(graph)
  613. def bypass_current_entry(self) -> None:
  614. assert self._current_entry is not None
  615. self._current_entry.bypassed = True
  616. def add_resume_function(
  617. self,
  618. python_code: types.CodeType,
  619. python_module: str,
  620. function_name: Optional[str],
  621. ) -> None:
  622. self._add_function(
  623. python_code,
  624. python_module,
  625. function_name=_FunctionId(function_name) if function_name else None,
  626. install_to_global=True,
  627. )
  628. self._resume_codes.add(python_code)
  629. def add_import_source(self, alias: str, module_name: str) -> None:
  630. assert self._current_entry is not None
  631. self._current_entry.import_sources[alias] = module_name
  632. def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None:
  633. assert self._current_entry is not None
  634. assert backend_id.startswith("__compiled_fn_") # sanity check
  635. backend_id = _BackendId(backend_id)
  636. self._current_entry.backend_ids.append(backend_id)
  637. if backend is not None:
  638. self._cached_backends[backend_id] = backend
  639. def validate(self) -> None:
  640. assert self._current_entry is None
  641. assert self._innermost_fn is not None
  642. assert self._initialized
  643. assert next(iter(self._codes)) is self._innermost_fn.__code__
  644. def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None:
  645. module.__dict__[name] = value
  646. self._installed_globals.setdefault(module, []).append(name)
  647. def uninstall(self) -> None:
  648. from torch._C._dynamo.eval_frame import _reset_precompile_entries
  649. assert self._innermost_fn is not None
  650. for module, names in self._installed_globals.items():
  651. for name in names:
  652. module.__dict__.pop(name)
  653. # pyrefly: ignore [bad-assignment]
  654. self._installed_globals = {}
  655. _reset_precompile_entries(self._innermost_fn.__code__)
    def install(self, backends: dict[_BackendId, Any]) -> None:
        """
        Sync the package states to the compiled function. This includes the following actions:
        1. Clean up the previously installed states.
        2. Install the compiled functions to global scopes.
        3. Install the precompiled cache entries to ExtraStates on the code object.

        Args:
            backends: backend artifacts keyed by backend id. Each one is turned
                into a callable with ``after_deserialization()``.

        Raises:
            RuntimeError: if a non-bypassed entry references a backend id that
                is missing from *backends*.
        """
        from torch._C._dynamo.eval_frame import _load_precompile_entry

        from .output_graph import get_builtins_dict

        self.uninstall()
        for code, entry in self._codes.items():
            # Entries compiled under code_context() replay the per-frame compile
            # bookkeeping (compile id, timing, logging) while being installed.
            context = (
                _compile_frame_context(code)
                if entry.has_compile_id
                else contextlib.nullcontext()
            )
            with context:
                module = sys.modules[entry.python_module]
                # Re-create import aliases referenced by the compiled bytecode.
                for alias, module_name in entry.import_sources.items():
                    self._install_global(
                        module, alias, importlib.import_module(module_name)
                    )
                target_code = code
                # Generated functions (add_resume_function) get a fresh function
                # object per recorded name in the module's globals.
                if entry.install_to_global:
                    for function_name in entry.function_names:
                        fn = types.FunctionType(code, module.__dict__, function_name)
                        self._install_global(module, function_name, fn)
                # User code registered with an access path: resolve the live code
                # object in this process and attach guards to it instead.
                if entry.code_source:
                    target_code = _lookup_code(entry)

                if entry.bypassed:
                    # If the entry is bypassed, do not install backends
                    # or guarded codes.
                    continue

                for backend_id in entry.backend_ids:
                    if backend_id not in backends:
                        raise RuntimeError(
                            f"Backend {backend_id} is not found in the given backends"
                        )
                    with dynamo_timed(
                        "after_deserialization", phase_name="backend_compile"
                    ):
                        backend = backends[backend_id].after_deserialization()
                        # Installed wrapped in torch._dynamo.disable under its id.
                        self._install_global(
                            module,
                            backend_id,
                            torch._dynamo.disable(backend),
                        )

                if len(entry.guarded_codes) == 0:
                    # Dynamo generates empty graph for trivial functions, should just skip them
                    # in these cases.
                    torch._dynamo.eval_frame.skip_code(target_code)

                for guarded_code in entry.guarded_codes:
                    with dynamo_timed("precompile_load_guards"):
                        guards_state = load_guards_state(guarded_code.guards_state)
                    runtime_global_scope = sys.modules[entry.python_module].__dict__
                    # The installed builtins dict might be absent from the runtime
                    # while loading guards. Populate it if it's missing.
                    if (
                        builtin_dict_name
                        := guards_state.output_graph.name_of_builtins_dict_key_in_fglobals
                    ):
                        builtins_dict = get_builtins_dict(runtime_global_scope)
                        if builtin_dict_name in runtime_global_scope:
                            assert (
                                runtime_global_scope[builtin_dict_name] is builtins_dict
                            )
                        else:
                            runtime_global_scope[builtin_dict_name] = builtins_dict
                    assert isinstance(guards_state, torch._dynamo.guards.GuardsState)
                    with dynamo_timed("precompile_build_guards"):
                        guard_manager = load_guard_manager(
                            guards_state, target_code, runtime_global_scope
                        )
                    # Attach (guards, dynamo bytecode) to the target code object.
                    _load_precompile_entry(
                        target_code,
                        guard_manager,
                        SerializedCode.to_code_object(guarded_code.dynamo_code),
                    )
  734. def cache_entry(self) -> _DynamoCacheEntry:
  735. self.validate()
  736. assert self._innermost_fn is not None
  737. return _DynamoCacheEntry(
  738. codes=list(self._codes.values()),
  739. source_info=self._source_info,
  740. device_type=self._device_type,
  741. fn_name=self._innermost_fn.__qualname__,
  742. fn_first_lineno=self._innermost_fn.__code__.co_firstlineno,
  743. )
  744. @staticmethod
  745. def source_id_from_fn(fn: Callable[..., Any]) -> str:
  746. from .eval_frame import innermost_fn
  747. innermost_fn_ = innermost_fn(fn)
  748. sha256_hash = hashlib.sha256()
  749. sha256_hash.update(innermost_fn_.__qualname__.encode())
  750. sha256_hash.update(str(innermost_fn_.__code__.co_firstlineno).encode())
  751. return sha256_hash.hexdigest()
# Mapping from backend id to a backend payload. In DynamoStore.save_cache_entry
# the values are the serialized BackendCacheArtifact for each referenced backend.
_Backends = dict[_BackendId, Any]
  753. class DynamoStore(abc.ABC):
  754. """
  755. A DynamoStore tracks active CompilePackages, and provides methods to store and retrieve them.
  756. This is an abstract base class for different storage implementations.
  757. """
  758. def record_package(self, package: CompilePackage) -> None:
  759. """
  760. Records a package to PrecompileContext, so that it can be serialized later.
  761. """
  762. from torch._dynamo.precompile_context import PrecompileContext
  763. cache_entry = package.cache_entry()
  764. PrecompileContext.record_dynamo_cache_entry(
  765. cache_entry=cache_entry, key=package.source_id
  766. )
  767. def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None:
  768. """
  769. Records eager fx graphs to PrecompileContext for testing purposes.
  770. """
  771. from torch._dynamo.precompile_context import (
  772. EagerCacheArtifact,
  773. PrecompileContext,
  774. )
  775. result = EagerCacheArtifact(key=backend_id, content=backend)
  776. PrecompileContext.record_artifact(result)
  777. @abc.abstractmethod
  778. def clear(self) -> None: ...
  779. @abc.abstractmethod
  780. def write(
  781. self,
  782. cache_entry: PrecompileCacheEntry,
  783. path: str,
  784. ) -> None:
  785. """
  786. Abstract method to write dynamo cache entry and backends to storage.
  787. Args:
  788. dynamo: The dynamo cache entry to write
  789. backends: Dictionary of backend content to write
  790. path: Path or key to identify where to write the data
  791. """
  792. ...
  793. def save_cache_entry(self, cache_entry: _DynamoCacheEntry, key: str) -> None:
  794. """
  795. Saves a package to a given path. Grabs backends from PrecompileContext.
  796. """
  797. from torch._dynamo.precompile_context import (
  798. BackendCacheArtifact,
  799. PrecompileContext,
  800. )
  801. backend_content: _Backends = {}
  802. for backend_id in cache_entry.backend_ids:
  803. serialized_backend = PrecompileContext.serialize_artifact_by_key(backend_id)
  804. if serialized_backend is None:
  805. raise RuntimeError(
  806. f"Backend {backend_id} is not found in the given backends"
  807. )
  808. assert isinstance(serialized_backend, BackendCacheArtifact)
  809. backend_content[backend_id] = serialized_backend
  810. entry = PrecompileCacheEntry(cache_entry, backend_content)
  811. self.write(entry, key)
  812. def save_package(self, package: CompilePackage, key: str) -> None:
  813. """
  814. Saves a package to a given path. Grabs backends from PrecompileContext.
  815. """
  816. self.record_package(package)
  817. cache_entry = package.cache_entry()
  818. self.save_cache_entry(cache_entry, key)
  819. @abc.abstractmethod
  820. def read(self, path: str) -> PrecompileCacheEntry:
  821. """
  822. Abstract method to read dynamo cache entry and backends from storage.
  823. Args:
  824. path: Path or key to identify where to read the data from
  825. Returns:
  826. A tuple containing (dynamo_cache_entry, backend_content)
  827. """
  828. ...
  829. def load_cache_entry(self, key: str) -> PrecompileCacheEntry:
  830. from torch._dynamo.precompile_context import (
  831. BackendCacheArtifact,
  832. PrecompileContext,
  833. )
  834. precompile_entry = self.read(key)
  835. for backend in precompile_entry.backends.values():
  836. assert isinstance(backend, BackendCacheArtifact)
  837. PrecompileContext.record_artifact(backend)
  838. return precompile_entry
  839. def load_package(
  840. self, fn: Any, key: str
  841. ) -> tuple[CompilePackage, dict[_BackendId, Any]]:
  842. """
  843. Loads a package from a given path and returns it plus a list of deserialized backends
  844. """
  845. entry = self.load_cache_entry(key)
  846. package = CompilePackage(fn, entry.dynamo)
  847. return package, entry.backends
  848. class InMemoryDynamoStore(DynamoStore):
  849. """
  850. A DynamoStore implementation that keeps state about CompilePackages in memory.
  851. """
  852. def __init__(self) -> None:
  853. self.packages: dict[str, PrecompileCacheEntry] = {}
  854. def clear(self) -> None:
  855. self.packages.clear()
  856. def write(
  857. self,
  858. entry: PrecompileCacheEntry,
  859. path: str,
  860. ) -> None:
  861. """
  862. Store the dynamo cache entry and backends in memory instead of writing to disk.
  863. """
  864. self.packages[path] = entry
  865. def read(self, path: str) -> PrecompileCacheEntry:
  866. """
  867. Read dynamo cache entry and backends from memory.
  868. """
  869. if path not in self.packages:
  870. raise RuntimeError(f"No package found with key {path}")
  871. return self.packages[path]
  872. class DiskDynamoStore(DynamoStore):
  873. """
  874. A DynamoStore implementation that keeps state about CompilePackages on disk.
  875. """
  876. def __init__(self, path_prefix: str = ""):
  877. """
  878. Initialize a DiskDynamoStore with a path prefix.
  879. Args:
  880. path_prefix: Prefix directory for where to put CompilePackages on disk
  881. """
  882. self._path_prefix = path_prefix
  883. def path_prefix(self) -> str:
  884. return self._path_prefix
  885. def clear(self) -> None:
  886. """
  887. Clear all CompilePackages from disk.
  888. """
  889. if self.path_prefix():
  890. shutil.rmtree(self.path_prefix(), ignore_errors=True)
  891. def write(
  892. self,
  893. entry: PrecompileCacheEntry,
  894. path: str,
  895. ) -> None:
  896. """
  897. Write dynamo cache entry and backends to disk.
  898. """
  899. try:
  900. pickled_content: bytes = pickle.dumps(entry)
  901. CacheArtifactManager.record_artifact(
  902. PrecompileCacheArtifact.type(), path, pickled_content
  903. )
  904. self._write_to_local_cache(pickled_content, path)
  905. except Exception as e:
  906. raise RuntimeError(f"Failed to save package to {path}: {e}") from e
  907. def _write_to_local_cache(self, pickled_content: bytes, path: str) -> None:
  908. from torch._inductor.codecache import write_atomic
  909. path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
  910. try:
  911. os.makedirs(path, exist_ok=True)
  912. write_atomic(os.path.join(path, "entry"), pickled_content)
  913. except Exception as e:
  914. raise RuntimeError(f"Failed to save package to {path}: {e}") from e
  915. def read(self, path: str) -> PrecompileCacheEntry:
  916. """
  917. Read dynamo cache entry and backends from disk.
  918. """
  919. path = os.path.join(self.path_prefix(), path) if self.path_prefix() else path
  920. try:
  921. with open(os.path.join(path, "entry"), "rb") as f:
  922. pickled_content = f.read()
  923. entry = pickle.loads(pickled_content)
  924. return entry
  925. except Exception as e:
  926. raise RuntimeError(f"Failed to load package from path {path}: {e}") from e
  927. class DiskDynamoCache(DiskDynamoStore):
  928. """
  929. Special DiskDynamoStore which adds some helper functions for automatically
  930. tracking paths of packages
  931. """
  932. def save(self, package: CompilePackage) -> None:
  933. """
  934. Saves a package to a given path. Grabs backends from PrecompileContext.
  935. """
  936. key = package.source_id
  937. logger.info("Saving CompilePackage for %s", package.source_id)
  938. super().save_package(package, key)
  939. def load(self, fn: Callable[..., Any]) -> Optional[PrecompileCacheEntry]:
  940. """
  941. Loads a package from a given path and returns it plus a list of deserialized backends
  942. """
  943. key = CompilePackage.source_id_from_fn(fn)
  944. logger.info("Loading CompilePackage for %s", key)
  945. path = os.path.join(self.path_prefix(), key)
  946. if os.path.exists(path):
  947. try:
  948. result = super().load_cache_entry(key)
  949. counters["dynamo_cache"]["dynamo_cache_hit"] += 1
  950. return result
  951. except Exception:
  952. counters["dynamo_cache"]["dynamo_cache_error"] += 1
  953. logger.warning("Failed to load package from path %s", exc_info=True)
  954. return None
  955. logger.info("No package found for %s", key)
  956. counters["dynamo_cache"]["dynamo_cache_miss"] += 1
  957. return None
  958. def load_and_install_package(
  959. self, fn: Callable[..., Any]
  960. ) -> Optional[CompilePackage]:
  961. """
  962. Load directly into a package and install backends
  963. """
  964. results = self.load(fn)
  965. if results is None:
  966. return None
  967. else:
  968. package = CompilePackage(fn, results.dynamo)
  969. package.install(results.backends)
  970. return package
  971. def path_prefix(self) -> str:
  972. return os.path.join(cache_dir(), "dynamo")
  973. def cache_dir() -> str:
  974. from torch._inductor.runtime.cache_dir_utils import cache_dir
  975. return cache_dir()
  976. DynamoCache = DiskDynamoCache(os.path.join(cache_dir(), "dynamo"))