pgo.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992
  1. """
  2. Profile Guided Optimization (PGO) implementation for Dynamo.
  3. This module provides functionality for caching and managing code state profiles
  4. that guide optimization decisions in Dynamo. It implements both local and remote
  5. caching mechanisms for storing profile information across runs, handles profile
  6. merging across distributed ranks, and manages the lifecycle of profile data
  7. during compilation. The profiles track dynamic vs static properties of tensors
  8. and help Dynamo make better specialization decisions.
  9. """
  10. from __future__ import annotations
  11. import base64
  12. import copy
  13. import dataclasses
  14. import enum
  15. import functools
  16. import logging
  17. import os
  18. import pickle
  19. import re
  20. import zlib
  21. from collections import defaultdict
  22. from typing import Optional, TYPE_CHECKING, TypeVar, Union
  23. from typing_extensions import override, Self
  24. import torch._dynamo.config
  25. import torch._utils_internal
  26. import torch.compiler.config
  27. import torch.distributed as dist
  28. from torch._dynamo.utils import (
  29. CompileEventLogger,
  30. dynamo_timed,
  31. set_feature_use,
  32. warn_once,
  33. )
  34. from torch._environment import is_fbcode
  35. from torch._logging._internal import trace_structured_artifact
  36. from torch.compiler._cache import (
  37. CacheArtifact,
  38. CacheArtifactFactory,
  39. CacheArtifactManager,
  40. )
  41. from torch.utils._ordered_set import OrderedSet
  42. if TYPE_CHECKING:
  43. import types
  44. from torch._dynamo.symbolic_convert import InstructionTranslator
  45. from torch._inductor.remote_cache import JsonDataTy, RemoteCache
  46. class ReservedWorkflowIdUserError(ValueError):
  47. pass
  48. log = logging.getLogger(__name__)
  49. LOCK_TIMEOUT = 10
  50. # How does in memory representation work? Concretely, this module is
  51. # responsible for holding GLOBAL state representing the state it holds, no
  52. # other copies permitted. So we retire frame_state entirely and store it
  53. # here. This should be reset when Dynamo is reset. We never GC information
  54. # (similar to how the filesystem doesn't get cleaned up except by tmp
  55. # cleaner), so the expectation is the information is relatively cheap and we
  56. # don't mind leaking it.
  57. # How exactly did we design the cache key? Here are some of the questions:
  58. #
  59. # - JOB_ID: Do we have a unique identifier for the "training run" (such that
  60. # it stays the same if we're running the same code, and changes if we're
  61. # running something different).
  62. #
  63. # - RANK: Are we sharing the cache across ranks, or does each rank get
  64. # an individual cache?
  65. #
  66. # We choose to require job_id for PGO cache. This is to prevent
  67. # situations where unrelated invocations of PyTorch unpredictably cause
  68. # changes to each other's behavior. With a job_id, at least you know there
  69. # is some "state" associated with it. (State dict might be another way to
  70. # tell if a run is related or not.) You can opt in to YOLO mode (everything
  71. # aliases everything) by passing a shared job_id for all your invocations.
  72. #
  73. # We choose to NOT share PGO cache across ranks. With no RANK_SHARING, there
  74. # is never contention between runs, so we can leisurely update a bundle with
  75. # information we need. Because we are grouped by job_id, we can have a single
  76. # consolidated bundle for everything (or not; maybe worry about O(n^2) IO if
  77. # we updated every compile--let's just instrument this.) Can even take a
  78. # filelock for extra safety (expect no contention); expect 50ns overhead from
  79. # uncontended filelock.
  80. #
  81. # If we did share ranks, everyone is storming to modify the same cache files.
  82. # We can do this by having folks atomic write to a CAS-store and then having
  83. # readers do on-the-fly merging (this can be implemented in remote using
  84. # prefix iteration). As an optional optimization, one rank can be elected to
  85. # handle bundling post facto (ideally, this is done async, after quiescence,
  86. # without compiler collectives needing to wait for everyone to finish writing
  87. # their bits.) Not sure how you can avoid a listdir because if some rank shows
  88. # up with some new entries we need to pull them in ASAP (unless you want to
  89. # delay bundling).
  90. #
  91. # But compiler collectives fill a similar niche: compilers chat with each
  92. # other so rank 0 has collected everything. So elect rank 0 only to write the
  93. # bundle. Don't even need CAS-store atomic write; just one rank writing and
  94. # updating bundles. The point is to use compiler collectives to share
  95. # profiles across ranks, but use the PGO cache to persist profiles per rank
  96. # across attempts. No need to have one mechanism to do everything.
  97. @functools.cache
  98. def _hash_containing_file(filepath: str) -> str:
  99. # if the file does not exists we consider filepath to be the hash.
  100. if not os.path.exists(filepath):
  101. return filepath
  102. with open(filepath, "rb") as file:
  103. content = file.read()
  104. crc32_value = zlib.crc32(content)
  105. hash = format(crc32_value & 0xFFFFFFFF, "08x")
  106. return hash
  107. @dataclasses.dataclass(frozen=True)
  108. class CodeId:
  109. filename: str
  110. firstlineno: int
  111. name: str
  112. # When a job restart, the code can be copied to a different path than the previous attempt. In that case
  113. # self.filename will have a different value, we do not want to consider those differences. Instead we
  114. # hash the content of the file and use it as an identifier of the file.
  115. #
  116. # self.filename is kept in the object to give readable information/pointer to the actual file, in a local
  117. # code state it will refer to the first seen file path.
  118. file_hash: str
  119. # Exclude file name.
  120. def __eq__(self, other: object) -> bool:
  121. if not isinstance(other, CodeId):
  122. return False
  123. return (
  124. self.file_hash == other.file_hash
  125. and self.firstlineno == other.firstlineno
  126. and self.name == other.name
  127. )
  128. # Ensure if two CodeIds are the same, then they have the same hash by excluding filename.
  129. def __hash__(self) -> int:
  130. return hash((self.file_hash, self.name, self.firstlineno))
  131. def __str__(self) -> str:
  132. return f"hash({self.file_hash}){self.filename}:{self.firstlineno}:{self.name}"
  133. @staticmethod
  134. def make(code: types.CodeType) -> CodeId:
  135. return CodeId(
  136. code.co_filename,
  137. code.co_firstlineno,
  138. code.co_name,
  139. _hash_containing_file(code.co_filename),
  140. )
  141. @dataclasses.dataclass
  142. class CodeState:
  143. automatic_dynamic: defaultdict[str, FrameStateSizeEntry] = dataclasses.field(
  144. default_factory=lambda: defaultdict(FrameStateSizeEntry)
  145. )
  146. _INIT_CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
  147. _CODE_STATE: Optional[defaultdict[CodeId, CodeState]] = None
  148. _LOGGED_DYNAMIC_ALLOWLIST: bool = False
  149. @dataclasses.dataclass(frozen=True)
  150. class InferStride:
  151. """
  152. Denotes the quantity stride[dim] * size[dim], which is what the stride would
  153. be for the next physical dimension that results in a contiguous layout.
  154. For example, given size = [2, 3], stride = [3, 1], we can replace this with
  155. stride = [InferStride(1), 1], because InferStride(1) = stride[1] * size[1] = 1 * 3 = 3
  156. Indirecting the representation in this way is important for the join operation
  157. on strides as if we join [2, 3][3, 1] and [2, 4][4, 1],
  158. we don't want [2, None][None, 1] which would get eventually symbolized into
  159. [2, s0][s1, 1] (notice that the relationship between s0 and s1 is broken).
  160. If we instead rewrite the expressions as InferStride so we have [2, 3][InferStride(1), 1]
  161. and [2, 4][InferStride(1), 1] we now join to [2, None][InferStride(1), 1] will
  162. result in [2, s0][s0, 1], as desired.
  163. """
  164. dim: int
  165. _T = TypeVar("_T")
  166. class AutoUnset(enum.Enum):
  167. """
  168. The identity element of our semilattice, a generic "don't know" element that
  169. is always subsumed when we get more information.
  170. """
  171. token = 0
  172. auto_unset = AutoUnset.token
  173. class AutoDynamic(enum.Enum):
  174. """
  175. The top element of our (bounded) semilattice, whenever you merge this with
  176. any other element you always get it again
  177. """
  178. token = 0
  179. auto_dynamic = AutoDynamic.token
  180. @dataclasses.dataclass
  181. class FrameStateSizeEntry:
  182. scalar: Union[int, AutoDynamic, AutoUnset] = dataclasses.field(default=auto_unset)
  183. # NB: We don't have cases where we have a known dimensionality but
  184. # we know NOTHING about the individual sizes
  185. size: Union[AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic], ...]] = (
  186. dataclasses.field(default=auto_unset)
  187. )
  188. stride: Union[
  189. AutoDynamic, AutoUnset, tuple[Union[int, AutoDynamic, InferStride], ...]
  190. ] = dataclasses.field(default=auto_unset)
  191. def render(self) -> str:
  192. # Special cases
  193. def render_single(s: Union[int, AutoDynamic, AutoUnset, InferStride]) -> str:
  194. if s is auto_dynamic:
  195. return "?"
  196. elif s is auto_unset:
  197. # This basically shouldn't happen, this is for debugging
  198. return "auto unset"
  199. elif isinstance(s, InferStride):
  200. return f"S({s.dim})"
  201. else:
  202. return str(s)
  203. def render_tuple(ss: tuple[Union[int, AutoDynamic, InferStride], ...]) -> str:
  204. return "[" + ", ".join(render_single(s) for s in ss) + "]"
  205. # Common cases
  206. if self.size is auto_dynamic and self.stride is auto_dynamic:
  207. if self.scalar is auto_dynamic:
  208. return "fully dynamic scalar or tensor"
  209. else:
  210. return f"scalar {self.scalar}"
  211. elif self.scalar is auto_dynamic:
  212. if isinstance(self.size, tuple) and isinstance(self.stride, tuple):
  213. return f"tensor size={render_tuple(self.size)} stride={render_tuple(self.stride)}"
  214. # Fallback
  215. return "unusual {repr(self)}"
  216. def __post_init__(self) -> None:
  217. assert not isinstance(self.scalar, torch.SymInt), self.scalar
  218. if isinstance(self.size, tuple):
  219. for s in self.size:
  220. assert not isinstance(s, torch.SymInt), s
  221. if isinstance(self.stride, tuple):
  222. for s1 in self.stride:
  223. assert not isinstance(s1, torch.SymInt), s1
  224. def is_size_dynamic(self, dim: int) -> bool:
  225. if self.size is auto_dynamic:
  226. return True
  227. if self.size is auto_unset:
  228. return False
  229. return self.size[dim] is auto_dynamic
  230. def is_stride_dynamic(self, dim: int) -> bool:
  231. # At the moment, dynamic strides is a bit buggy. Good test case
  232. # here is `PYTORCH_TEST_WITH_DYNAMO=1 python test/test_autograd.py
  233. # TestAutograd.test_gradcheck_jacobian_mismatch`
  234. #
  235. # This if statement preserves historical behavior, which is that we
  236. # ONLY make strides dynamic if the size is exactly static everywhere.
  237. # We could potentially relax this but in general we should be very
  238. # careful about when to infer dynamic strides.
  239. #
  240. # Actually, the existing algorithm is already somewhat problematic.
  241. # Suppose a tensor that is sometimes:
  242. # f32[2, 3, 5][15, 5, 1] and other times
  243. # f32[2, 3, 5][5, 10, 1] (specifically, dim 0 and 1 are physically transposed).
  244. # If we infer strides should be (DYNAMIC, DYNAMIC, 1). But this is
  245. # silly: we really should have just guarded on dim order.
  246. if not (
  247. isinstance(self.size, tuple) and all(type(s) is int for s in self.size)
  248. ):
  249. return False
  250. if self.stride is auto_dynamic:
  251. return True
  252. if self.stride is auto_unset:
  253. return False
  254. return self.stride[dim] is auto_dynamic
  255. @staticmethod
  256. def _munge_symint(xs: tuple[int, ...]) -> tuple[Union[AutoDynamic, int], ...]:
  257. return tuple(auto_dynamic if isinstance(x, torch.SymInt) else x for x in xs)
  258. @classmethod
  259. def make_scalar(cls, x: int) -> FrameStateSizeEntry:
  260. return FrameStateSizeEntry(scalar=x, size=auto_dynamic, stride=auto_dynamic)
  261. @classmethod
  262. def make_tensor(
  263. cls, size: tuple[int, ...], stride: tuple[int, ...]
  264. ) -> FrameStateSizeEntry:
  265. return FrameStateSizeEntry(
  266. scalar=auto_dynamic,
  267. size=cls._munge_symint(size),
  268. stride=cls._munge_symint(stride),
  269. )
  270. @classmethod
  271. def make_size(cls, size: tuple[int, ...]) -> FrameStateSizeEntry:
  272. return FrameStateSizeEntry(
  273. scalar=auto_unset,
  274. size=cls._munge_symint(size),
  275. stride=auto_unset,
  276. )
  277. @staticmethod
  278. def _merge_atom(x: _T, y: _T) -> Union[AutoDynamic, _T]:
  279. if x is auto_unset:
  280. return y
  281. if y is auto_unset:
  282. return x
  283. if x is auto_dynamic or y is auto_dynamic or x != y:
  284. return auto_dynamic
  285. return x
  286. @classmethod
  287. def _merge_atom_tup(
  288. cls,
  289. xs: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
  290. ys: Union[AutoDynamic, AutoUnset, tuple[_T, ...]],
  291. ) -> Union[AutoDynamic, AutoUnset, tuple[Union[AutoDynamic, _T], ...]]:
  292. if xs is auto_unset:
  293. return ys
  294. if ys is auto_unset:
  295. return xs
  296. if xs is auto_dynamic or ys is auto_dynamic:
  297. return auto_dynamic
  298. if len(xs) != len(ys):
  299. return auto_dynamic
  300. return tuple(cls._merge_atom(x, y) for x, y in zip(xs, ys))
  301. def __ior__(self, other: Self) -> Self:
  302. self.scalar = self._merge_atom(self.scalar, other.scalar)
  303. self.size = self._merge_atom_tup(self.size, other.size)
  304. self.stride = self._merge_atom_tup(self.stride, other.stride)
  305. return self
def update_automatic_dynamic(
    tx: InstructionTranslator,
    name: str,
    entry: FrameStateSizeEntry,
    *,
    is_unspecialized_nn_module: bool = False,
) -> FrameStateSizeEntry:
    """Fold a newly observed *entry* for input *name* into the global PGO profile.

    The profile for ``tx.f_code`` is fetched via ``get_code_state()``.

    When ``torch._dynamo.config.automatic_dynamic_shapes`` is enabled, the
    stored entry is joined with *entry* in place (``|=``). Any scalar, size or
    stride change relative to the previously stored entry is logged at debug
    level and recorded as a ``CompileEventLogger`` instant event.

    When the option is disabled, the stored entry is simply replaced by *entry*.

    Args:
        tx: translator whose ``f_code`` identifies the frame being compiled.
        name: source name of the input within that frame.
        entry: freshly observed size information for the input.
        is_unspecialized_nn_module: if True and the scalar changed, also emit
            an info log suggesting ``torch._dynamo.mark_static``.

    Returns:
        The entry now stored in the profile for *name*: the merged entry, or
        *entry* itself when automatic dynamic shapes are off.
    """
    code_id = CodeId.make(tx.f_code)
    frame_state = get_code_state()[code_id]
    if torch._dynamo.config.automatic_dynamic_shapes:
        # Test membership before indexing: indexing the defaultdict below
        # inserts a fresh (all-unset) entry for unseen names.
        is_update = name in frame_state.automatic_dynamic
        mut_entry = frame_state.automatic_dynamic[name]
        # Snapshot the pre-merge state so changes can be reported after the
        # in-place join mutates mut_entry.
        old_entry = copy.copy(mut_entry)
        mut_entry |= entry

        # Do some logs (damn, I spend more code logging than I do actually doing
        # the updates lol)
        if is_update and old_entry.scalar != mut_entry.scalar:
            log.debug(
                "automatic dynamic int %s val %s != %s",
                name,
                entry.scalar,
                old_entry.scalar,
            )
            CompileEventLogger.instant(
                "automatic_dynamic",
                {
                    "name": name,
                    "dim_changed": "scalar",
                    "reason": "scalar change",
                    "cached": str(old_entry.scalar),
                    "new": str(entry.scalar),
                },
            )
            if is_unspecialized_nn_module:
                log.info(
                    "%s is converted to a symbolic integer. It is an attribute of a "
                    "user defined nn module class. If you wish to keep it static, you can "
                    "mark the nn module class as `torch._dynamo.mark_static`.",
                    name,
                )

        # Report a change in the ``size`` or ``stride`` field (tup_name): the
        # whole field when i is None, otherwise just element i.
        def log_tup(
            tup_name: str, short_reason: str, long_reason: str, i: Optional[int] = None
        ) -> None:
            entry_tup = (
                getattr(entry, tup_name) if i is None else getattr(entry, tup_name)[i]
            )
            old_entry_tup = (
                getattr(old_entry, tup_name)
                if i is None
                else getattr(old_entry, tup_name)[i]
            )
            log.debug(
                "automatic dynamic %s %s %s %s != %s",
                tup_name,
                name,
                short_reason,
                # NB: We used to only report len(...) here for dim mismatch
                entry_tup,
                old_entry_tup,
            )
            CompileEventLogger.instant(
                "automatic_dynamic",
                {
                    "name": name,
                    "dim_changed": "all" if i is None else i,
                    "reason": long_reason,
                    "cached": str(old_entry_tup),
                    "new": str(entry_tup),
                },
            )

        # Size changes: dimensionality change, per-dimension change, or
        # "other" when either side is not a concrete tuple.
        if is_update and old_entry.size != mut_entry.size:
            if isinstance(old_entry.size, tuple) and isinstance(entry.size, tuple):
                if len(old_entry.size) != len(entry.size):
                    log_tup("size", "dim", "dimensionality change")
                else:
                    for i in range(len(entry.size)):
                        if old_entry.size[i] != entry.size[i]:
                            log_tup("size", f"size({i})", "size change", i)
            else:
                log_tup("size", "other", "other")

        # Stride changes, categorized the same way as sizes.
        if is_update and old_entry.stride != mut_entry.stride:
            if isinstance(old_entry.stride, tuple) and isinstance(entry.stride, tuple):
                if len(old_entry.stride) != len(entry.stride):
                    log_tup("stride", "dim", "dimensionality change")
                else:
                    for i in range(len(entry.stride)):
                        if old_entry.stride[i] != entry.stride[i]:
                            log_tup("stride", f"stride({i})", "stride change", i)
            else:
                log_tup("stride", "other", "other")
    else:
        # Reading through the defaultdict may insert a placeholder entry; it
        # is overwritten just below anyway.
        old_entry = frame_state.automatic_dynamic[name]
        log.debug(
            "automatic dynamic is off, overwriting int %s val %s -> %s",
            name,
            old_entry.scalar,
            entry.scalar,
        )
        frame_state.automatic_dynamic[name] = entry
        mut_entry = entry
    return mut_entry
  407. def process_automatic_dynamic(
  408. tx: InstructionTranslator,
  409. name: str,
  410. entry: FrameStateSizeEntry,
  411. *,
  412. is_unspecialized_nn_module: bool = False,
  413. ) -> FrameStateSizeEntry:
  414. if (st := tx.distributed_state) is None:
  415. return update_automatic_dynamic(
  416. tx,
  417. name,
  418. entry,
  419. is_unspecialized_nn_module=is_unspecialized_nn_module,
  420. )
  421. elif st.all_states is None:
  422. # Preflight, always pretend as if it's static. The point here
  423. # is we want to get through the preflight quickly, and static
  424. # will run faster. The preexisting frame state will get
  425. # applied anyway after we do compiler collectives.
  426. # TODO: I'm not sure if we should just bong the entire pgo
  427. # state here, it kind of depends if we're going to have other
  428. # things that talk in compiler collective. Also, the PGO
  429. # state, if we've already inferred something is automatic
  430. # dynamic, will have lost the actual input sizes, which might
  431. # be useful for debugging purposes (e.g., observing 0/1
  432. # specialization). Bonging the entire PGO state here would
  433. # let us delete this logic here; the compiler collective
  434. # would just directly update_automatic_dynamic
  435. st.local_state.automatic_dynamic[name] = entry
  436. return entry
  437. else:
  438. # Apply the updates. NB: all_states includes the local state
  439. # too.
  440. res = None
  441. for sub_state in st.all_states:
  442. if name in sub_state.automatic_dynamic:
  443. res = update_automatic_dynamic(
  444. tx,
  445. name,
  446. sub_state.automatic_dynamic[name],
  447. is_unspecialized_nn_module=is_unspecialized_nn_module,
  448. )
  449. assert res is not None
  450. return res
  451. def format_cache_key(key: str) -> str:
  452. # NB: We always use global rank for keys, even though they are overkill
  453. # for local only cache
  454. rank = None
  455. if dist.is_available() and dist.is_initialized():
  456. rank = dist.get_rank()
  457. tag = torch.compiler.config.cache_key_tag
  458. return f"{key}:{rank}:{tag}"
  459. def get_cache_key() -> Optional[str]:
  460. # TODO: info versions of these logs that log only once
  461. if torch.compiler.config.force_disable_caches:
  462. warn_once(
  463. "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches"
  464. )
  465. return None
  466. # NB: We namespace the cache keys so that only user-specified job id
  467. # can alias with each other.
  468. if (r := torch.compiler.config.job_id) is not None:
  469. if r.startswith("mast:"):
  470. raise ReservedWorkflowIdUserError(
  471. "torch.compiler.config.job_id with prefix 'mast:' is reserved for "
  472. "automatically generated job id associated with a specific MAST job "
  473. "name and version."
  474. )
  475. return format_cache_key(r)
  476. if (name_version := torch._utils_internal.get_mast_job_name_version()) is not None:
  477. mast_job_name, mast_job_version = name_version
  478. return format_cache_key(f"mast:{mast_job_name}:{mast_job_version}")
  479. return None
  480. def get_extra_cache_key(sticky_key: str) -> Optional[str]:
  481. if torch.compiler.config.force_disable_caches:
  482. warn_once(
  483. "dynamo_pgo force disabled by torch.compiler.config.force_disable_caches"
  484. )
  485. return None
  486. return format_cache_key(sticky_key)
  487. # This solely controls local PGO
  488. def code_state_path(cache_key: str) -> Optional[str]:
  489. if not torch._dynamo.config.automatic_dynamic_local_pgo:
  490. log.debug("automatic_dynamic_local_pgo not enabled")
  491. return None
  492. from torch._inductor.runtime.runtime_utils import cache_dir
  493. code_state_key = re.sub(r'[<>:"/\\|?*]', "_", f"code_state_{cache_key}.pkl")
  494. return os.path.join(cache_dir(), "dynamo", code_state_key)
  495. def should_use_remote_dynamo_pgo_cache() -> bool:
  496. if torch.compiler.config.force_disable_caches:
  497. return False
  498. if (r := torch._dynamo.config.automatic_dynamic_remote_pgo) is not None:
  499. return r
  500. if not is_fbcode():
  501. return False
  502. if torch._utils_internal.is_fb_unit_test():
  503. return False
  504. try:
  505. from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
  506. except ModuleNotFoundError:
  507. return False
  508. return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
  509. "pytorch/remote_cache:dynamo_pgo_version"
  510. )
  511. def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]:
  512. from torch._inductor.remote_cache import create_cache
  513. if not should_use_remote_dynamo_pgo_cache():
  514. return None
  515. return create_cache(
  516. "dynamo-pgo",
  517. is_fbcode(),
  518. "FbRemoteDynamoPGOCache",
  519. "RemoteDynamoPGOCache",
  520. )
  521. def _collect_dynamic_sources(code_state: CodeState) -> OrderedSet[str]:
  522. dynamic_sources: OrderedSet[str] = OrderedSet()
  523. for src, fs in code_state.automatic_dynamic.items():
  524. dynamic = False
  525. if isinstance(fs.size, tuple):
  526. dynamic = auto_dynamic in fs.size # type: ignore[operator]
  527. elif fs.scalar == auto_dynamic:
  528. dynamic = True
  529. if dynamic:
  530. dynamic_sources.add(src)
  531. return dynamic_sources
  532. def log_frame_dynamic_whitelist(f_code: types.CodeType) -> None:
  533. global _LOGGED_DYNAMIC_ALLOWLIST
  534. code_id = CodeId.make(f_code)
  535. frame_state = get_code_state()[code_id]
  536. frame_whitelist = ",".join(_collect_dynamic_sources(frame_state))
  537. if frame_whitelist:
  538. with dynamo_timed(name := "pgo.dynamic_whitelist", log_pt2_compile_event=True):
  539. CompileEventLogger.pt2_compile(
  540. name, recompile_dynamic_whitelist=frame_whitelist
  541. )
  542. if not _LOGGED_DYNAMIC_ALLOWLIST:
  543. torch._utils_internal.add_mlhub_insight(
  544. category="dynamic_shapes_analysis",
  545. insight="Dynamic shape recompilation detected",
  546. insight_description="PGO detected a recompilation due to dynamic shapes. \
  547. Please follow the instruction from the action link to reduce \
  548. recompilation overhead.",
  549. )
  550. # add mlhub insight only once per rank
  551. _LOGGED_DYNAMIC_ALLOWLIST = True
  552. def render_code_state(cs: defaultdict[CodeId, CodeState]) -> str:
  553. code_state_str = "\n".join(
  554. f"{k}:\n"
  555. + "\n".join(
  556. f" {src}: {fs.render()}" for src, fs in v.automatic_dynamic.items()
  557. )
  558. for k, v in cs.items()
  559. )
  560. dynamic_sources: OrderedSet[str] = OrderedSet()
  561. for state in cs.values():
  562. dynamic_sources.update(_collect_dynamic_sources(state))
  563. if dynamic_sources:
  564. code_state_str += (
  565. "\n\nPGO detected a recompilation due to dynamic shapes. "
  566. "To reduce shape recompilations by compiling dynamically to start, "
  567. f'set environment variable TORCH_COMPILE_DYNAMIC_SOURCES="{",".join(dynamic_sources)}"'
  568. )
  569. return code_state_str
  570. def merge_pgo_entry(src: FrameStateSizeEntry, dst: FrameStateSizeEntry) -> None:
  571. def rank(entry: FrameStateSizeEntry) -> int:
  572. if not isinstance(entry.size, tuple): # scalar
  573. return -1
  574. return len(entry.size)
  575. if rank(src) == rank(dst): # both tensors same rank, or both scalars
  576. dst |= src
  577. @CacheArtifactFactory.register
  578. class PGOCacheArtifact(CacheArtifact):
  579. @override
  580. def populate_cache(self) -> None:
  581. meta = write_local_impl(
  582. self._rewrite_cache_key_for_mega_cache(self.key), self.content
  583. )
  584. assert meta is not None
  585. @override
  586. @staticmethod
  587. def type() -> str:
  588. return "pgo"
  589. @staticmethod
  590. def _rewrite_cache_key_for_mega_cache(original_key: str) -> str:
  591. """
  592. The PGO cache artifact key for a MAST job contains the job name and the version.
  593. When we want to use the cache artifact on a different MAST job, we need to
  594. update the key to use the new MAST job's name and version.
  595. """
  596. if not original_key.startswith("mast:"):
  597. # if original_key is overridden, then dont change it
  598. return original_key
  599. if (new_key := get_cache_key()) is not None:
  600. return new_key
  601. return original_key
  602. def hit(key: str, ty: str) -> defaultdict[CodeId, CodeState]:
  603. global _INIT_CODE_STATE
  604. assert isinstance(_CODE_STATE, defaultdict)
  605. log.info("get_code_state %s hit %s, %d entries", key, ty, len(_CODE_STATE))
  606. trace_structured_artifact(
  607. f"get_{ty}_code_state",
  608. "string",
  609. lambda: render_code_state(_CODE_STATE), # type: ignore[arg-type]
  610. )
  611. set_feature_use("pgo", True)
  612. _INIT_CODE_STATE = copy.deepcopy(_CODE_STATE)
  613. return _CODE_STATE
  614. def get_local_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
  615. global _CODE_STATE
  616. path = code_state_path(cache_key)
  617. if path is not None and os.path.exists(path):
  618. with dynamo_timed(
  619. name := "pgo.get_local_code_state", log_pt2_compile_event=True
  620. ):
  621. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  622. # Read lock not necessary as we always write atomically write to
  623. # the actual location
  624. with open(path, "rb") as f:
  625. try:
  626. content = f.read()
  627. _CODE_STATE = pickle.loads(content)
  628. CompileEventLogger.pt2_compile(name, cache_size_bytes=f.tell())
  629. except Exception:
  630. log.warning(
  631. "get_code_state failed while reading %s", path, exc_info=True
  632. )
  633. else:
  634. CacheArtifactManager.record_artifact(
  635. PGOCacheArtifact.type(), cache_key, content
  636. )
  637. return hit(path, "local")
  638. return None
  639. def lookup_remote_cache_entry(
  640. remote_cache: RemoteCache[JsonDataTy],
  641. cache_key: str,
  642. event_name: Optional[str] = None,
  643. ) -> Optional[defaultdict[CodeId, CodeState]]:
  644. code_state = None
  645. try:
  646. cache_data = remote_cache.get(cache_key)
  647. except Exception:
  648. log.warning("get_code_state failed remote read on %s", cache_key, exc_info=True)
  649. else:
  650. if cache_data is not None:
  651. try:
  652. assert isinstance(cache_data, dict)
  653. data = cache_data["data"]
  654. assert isinstance(data, str)
  655. payload = base64.b64decode(data)
  656. if event_name is not None:
  657. CompileEventLogger.pt2_compile(
  658. event_name, cache_size_bytes=len(payload)
  659. )
  660. code_state = pickle.loads(payload)
  661. except Exception:
  662. log.warning(
  663. "get_code_state failed parsing remote result on %s",
  664. cache_key,
  665. exc_info=True,
  666. )
  667. else:
  668. CacheArtifactManager.record_artifact(
  669. PGOCacheArtifact.type(), cache_key, payload
  670. )
  671. else:
  672. log.info("get_code_state remote miss on %s", cache_key)
  673. return code_state
  674. def get_remote_code_state(cache_key: str) -> Optional[defaultdict[CodeId, CodeState]]:
  675. global _CODE_STATE
  676. remote_cache = get_remote_cache()
  677. if remote_cache is not None:
  678. with dynamo_timed(
  679. name := "pgo.get_remote_code_state",
  680. log_pt2_compile_event=True,
  681. dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
  682. ):
  683. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  684. code_state = lookup_remote_cache_entry(remote_cache, cache_key, name)
  685. if code_state is not None:
  686. _CODE_STATE = code_state
  687. return hit(cache_key, "remote")
  688. return None
  689. def add_extra_remote_code_state(cache_key: str) -> None:
  690. """
  691. Reads an additional PGO profile from the given cache key, and merges it with the default PGO profile.
  692. """
  693. global _CODE_STATE
  694. assert _CODE_STATE is not None
  695. remote_cache = get_remote_cache()
  696. if remote_cache is not None:
  697. with dynamo_timed(
  698. name := "pgo.add_extra_remote_code_state",
  699. log_pt2_compile_event=True,
  700. dynamo_compile_column_us="pgo_get_remote_code_state_time_us",
  701. ):
  702. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  703. code_state = lookup_remote_cache_entry(remote_cache, cache_key)
  704. log.info(
  705. "add_extra_code_state %s hit, %d entries",
  706. cache_key,
  707. len(code_state) if code_state is not None else 0,
  708. )
  709. if code_state is not None:
  710. # merge the code state into the current one
  711. for code_id, state in code_state.items():
  712. if code_id in _CODE_STATE:
  713. for src, entry in state.automatic_dynamic.items():
  714. # NOTE: maybe we need an "unsafe" merge to handle this,
  715. # where one entry might be 1-d, the other 2-d.
  716. # or if entries are of different types?
  717. # with local source naming, could be scalar vs. tensor
  718. merge_pgo_entry(
  719. entry, _CODE_STATE[code_id].automatic_dynamic[src]
  720. )
  721. else:
  722. _CODE_STATE[code_id] = state
  723. # log to tlparse
  724. trace_structured_artifact(
  725. "add_extra_remote_code_state",
  726. "string",
  727. lambda: render_code_state(code_state),
  728. )
  729. def get_code_state() -> defaultdict[CodeId, CodeState]:
  730. global _CODE_STATE, _INIT_CODE_STATE
  731. if _CODE_STATE is not None:
  732. return _CODE_STATE
  733. # Initialize it (even if we don't look up profile)
  734. _CODE_STATE = defaultdict(CodeState)
  735. cache_key = get_cache_key()
  736. if cache_key is None:
  737. return _CODE_STATE
  738. # Attempt local
  739. local_code_state = get_local_code_state(cache_key)
  740. # Attempt remote
  741. if local_code_state is None:
  742. get_remote_code_state(cache_key)
  743. # Attempt additional remote
  744. if (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None:
  745. extra_read_key = get_extra_cache_key(sticky_read)
  746. if extra_read_key is not None:
  747. add_extra_remote_code_state(extra_read_key)
  748. log.info("get_code_state using default")
  749. assert _CODE_STATE is not None
  750. return _CODE_STATE
  751. def put_code_state() -> None:
  752. if _CODE_STATE is None:
  753. log.info("put_code_state: never initialized, will not write")
  754. return
  755. if _CODE_STATE == _INIT_CODE_STATE:
  756. log.info("put_code_state: no change, skipping")
  757. return
  758. cache_key = get_cache_key()
  759. if cache_key is None:
  760. log.info("put_code_state: no cache key, skipping")
  761. return
  762. put_local_code_state(cache_key)
  763. put_remote_code_state(cache_key)
  764. if (sticky_write := torch.compiler.config.pgo_extra_write_key) is not None:
  765. extra_write_key = get_extra_cache_key(sticky_write)
  766. if extra_write_key is not None:
  767. put_remote_code_state(extra_write_key)
  768. def write_local_impl(cache_key: str, pickled_code: bytes) -> Optional[tuple[str, int]]:
  769. path = code_state_path(cache_key)
  770. if path is None:
  771. return None
  772. # If the user isn't misusing our API, we should have exclusive access to
  773. # this directory. But it's not too hard
  774. tmp_path = path + ".tmp"
  775. lock_path = path + ".lock"
  776. # We /mostly/ don't need the lock but the tmp file could be clobbered
  777. # TODO: use a safe tempfile create to eliminate lock
  778. from torch.utils._filelock import FileLock
  779. os.makedirs(os.path.dirname(path), exist_ok=True)
  780. with FileLock(lock_path, timeout=LOCK_TIMEOUT):
  781. with open(tmp_path, "wb") as f:
  782. f.write(pickled_code)
  783. size = f.tell()
  784. os.replace(tmp_path, path)
  785. return path, size
  786. def put_local_code_state(cache_key: str) -> None:
  787. with dynamo_timed(name := "pgo.put_local_code_state", log_pt2_compile_event=True):
  788. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  789. assert _CODE_STATE is not None
  790. pickled_code = pickle.dumps(_CODE_STATE)
  791. CacheArtifactManager.record_artifact(
  792. PGOCacheArtifact.type(), cache_key, pickled_code
  793. )
  794. meta = write_local_impl(cache_key, pickled_code)
  795. if meta is None:
  796. log.info("put_code_state: local cache disabled")
  797. return
  798. path, size = meta
  799. CompileEventLogger.pt2_compile(name, cache_size_bytes=size)
  800. log.info("put_code_state: wrote local %s, %d entries", path, len(_CODE_STATE))
  801. trace_structured_artifact(
  802. "put_local_code_state",
  803. "string",
  804. lambda: render_code_state(_CODE_STATE),
  805. )
  806. def put_remote_code_state(cache_key: str) -> None:
  807. with dynamo_timed(
  808. name := "pgo.put_remote_code_state",
  809. log_pt2_compile_event=True,
  810. dynamo_compile_column_us="pgo_put_remote_code_state_time_us",
  811. ):
  812. CompileEventLogger.pt2_compile(name, cache_key=cache_key)
  813. assert _CODE_STATE is not None
  814. remote_cache = get_remote_cache()
  815. if remote_cache is None:
  816. log.info("put_code_state: remote cache disabled")
  817. return
  818. content = pickle.dumps(_CODE_STATE)
  819. CompileEventLogger.pt2_compile(name, cache_size_bytes=len(content))
  820. cache_data: JsonDataTy = {
  821. "data": base64.b64encode(content).decode("ascii"),
  822. }
  823. remote_cache.put(cache_key, cache_data)
  824. log.info(
  825. "put_code_state: wrote remote %s, %d entries", cache_key, len(_CODE_STATE)
  826. )
  827. # TODO: don't log this multiple times
  828. trace_structured_artifact(
  829. "put_remote_code_state",
  830. "string",
  831. lambda: render_code_state(_CODE_STATE),
  832. )
  833. # NB: this does NOT reset the cached code state on disk
  834. def reset_code_state() -> None:
  835. global _CODE_STATE, _INIT_CODE_STATE, _LOGGED_DYNAMIC_ALLOWLIST
  836. _CODE_STATE = None
  837. _INIT_CODE_STATE = None
  838. _LOGGED_DYNAMIC_ALLOWLIST = False