debug_utils.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935
  1. """
  2. Debug utilities for TorchDynamo compilation and execution.
  3. This module provides various debugging tools and utilities for TorchDynamo, including:
  4. - Minification support for reducing test cases while preserving bugs
  5. - Input/output handling via InputReader and InputWriter for reproducible testing
  6. - Accuracy checking between original and compiled models
  7. - Neural network module string conversion via NNModuleToString
  8. - Profiling tools and system information collection
  9. - Buck build system integration for Meta-internal testing
  10. Key classes:
  11. - InputReader/InputWriter: Handle serialization of model inputs/outputs
  12. - NNModuleToString: Converts nn.Modules to string representations
  13. - BuckTargetWriter: Manages Buck build system integration
  14. """
  15. from __future__ import annotations
  16. import atexit
  17. import copy
  18. import cProfile
  19. import functools
  20. import getpass
  21. import inspect
  22. import itertools
  23. import logging
  24. import os
  25. import re
  26. import subprocess
  27. import sys
  28. import tempfile
  29. import textwrap
  30. from collections import Counter
  31. from importlib import import_module
  32. from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
  33. import torch
  34. import torch._prims_common as utils
  35. import torch._subclasses.meta_utils
  36. from torch import Tensor
  37. from torch._dynamo.testing import rand_strided
  38. from torch._inductor.cpp_builder import normalize_path_separator
  39. from torch._prims_common import is_float_dtype
  40. from torch.multiprocessing.reductions import StorageWeakRef
  41. from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
  42. from . import config
  43. from .utils import clone_inputs, get_debug_dir
  44. if TYPE_CHECKING:
  45. from collections.abc import Sequence
  46. from torch.hub import tqdm
  47. from torch.storage import UntypedStorage
  48. log = logging.getLogger(__name__)
  49. T = TypeVar("T")
  50. inductor_config = import_module("torch._inductor.config")
  51. use_buck = inductor_config.is_fbcode()
  52. if use_buck:
  53. import libfb.py.build_info
  54. extra_deps = []
  55. extra_imports = ""
  56. cur_target = ""
  57. if use_buck:
  58. extra_deps = [
  59. "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
  60. "//caffe2/torch/fb/sparsenn:sparsenn_operators",
  61. "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
  62. "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
  63. ]
  64. cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") # type: ignore[possibly-undefined]
  65. extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])
  66. BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]
  67. class BuckTargetWriter:
  68. def __init__(self, filename: str) -> None:
  69. self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
  70. self.target = self.py_file.replace(".py", "")
  71. # Get main_module path from fbcode
  72. self.path = f"{self.subdir.replace('/', '.')}.{self.target}"
  73. self.path = self.path[self.path.find("fbcode.") :]
  74. self.path = self.path[7:]
  75. # Get cmd line path
  76. tmp = self.subdir
  77. tmp = tmp[tmp.find("fbcode/") :][7:]
  78. self.cmd_line_path = f"//{tmp}:{self.target}"
  79. def build(self) -> str:
  80. extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps])
  81. return textwrap.dedent(
  82. f"""
  83. load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
  84. python_binary(
  85. name="{self.target}",
  86. srcs = ["{self.py_file}"],
  87. compile = False,
  88. deps = [
  89. "//caffe2:torch",
  90. "//caffe2:libtorch",
  91. "//caffe2/functorch:functorch",
  92. "//triton:triton",
  93. "{cur_target}",
  94. ],
  95. cpp_deps = [
  96. {extra_cpp_deps}
  97. ],
  98. main_module = "{self.path}",
  99. par_style = "xar",
  100. )
  101. """
  102. )
  103. def write(self, print_msg: bool = True) -> list[str]:
  104. target_file = os.path.join(self.subdir, "TARGETS")
  105. with open(target_file, "w") as fd:
  106. fd.write(self.build())
  107. # log.warning("Wrote isolation TARGETS file at %s", target_file)
  108. cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
  109. if print_msg:
  110. log.warning(
  111. "Found an example that reproduces the error. Run this cmd to repro - %s",
  112. " ".join(cmd_split),
  113. )
  114. return cmd_split
  115. def minifier_dir() -> str:
  116. path = os.path.join(get_debug_dir(), "minifier")
  117. if path is None:
  118. path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
  119. if not os.path.exists(path):
  120. os.makedirs(path, exist_ok=True)
  121. return path
  122. MAX_CONSTANT_NUMEL_INLINE = 4
  123. class NNModuleToString:
  124. safe_reprs = [
  125. torch.nn.Linear,
  126. torch.nn.Conv1d,
  127. torch.nn.Conv2d,
  128. torch.nn.Conv3d,
  129. torch.nn.BatchNorm1d,
  130. torch.nn.BatchNorm2d,
  131. torch.nn.BatchNorm3d,
  132. torch.nn.LayerNorm,
  133. torch.nn.Dropout,
  134. torch.nn.Softmax,
  135. torch.nn.ReLU,
  136. torch.nn.GELU,
  137. torch.nn.Identity,
  138. torch.nn.MaxPool2d,
  139. torch.nn.Embedding,
  140. torch.nn.Tanh,
  141. torch.nn.ConvTranspose1d,
  142. torch.nn.GLU,
  143. torch.nn.LSTM,
  144. torch.nn.Flatten,
  145. torch.nn.AdaptiveAvgPool2d,
  146. ]
  147. @staticmethod
  148. def can_convert_to_string(gm: torch.fx.GraphModule) -> bool:
  149. cant_convert = set()
  150. for _, module in gm.named_children():
  151. if type(module) not in NNModuleToString.safe_reprs:
  152. cant_convert.add(module)
  153. if len(cant_convert) > 0:
  154. log.warning("We have not tested reprs of some modules - %s", cant_convert)
  155. # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
  156. return True
  157. @staticmethod
  158. def convert(gm: torch.fx.GraphModule) -> str:
  159. from torch.nn.modules.module import _addindent
  160. tab = " " * 4
  161. model_str = textwrap.dedent(
  162. """
  163. from torch.nn import *
  164. class Repro(torch.nn.Module):
  165. def __init__(self) -> None:
  166. super().__init__()
  167. """
  168. )
  169. for module_name, module in gm.named_children():
  170. module_str = f"{module.__repr__()}"
  171. # module should be a core torch.nn.Module, so all parameters
  172. # should be on the same device.
  173. example_param = next(module.parameters(), None)
  174. if example_param is not None and example_param.is_cuda:
  175. module_str = f"{module_str}.cuda()"
  176. model_str += f"{tab * 2}self.{module_name} = {module_str}\n"
  177. for buffer_name, buffer in gm._buffers.items():
  178. if buffer is None:
  179. continue
  180. # Serialize full data for small buffers
  181. if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
  182. from torch._tensor_str import PRINT_OPTS
  183. assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
  184. tensor_str = repr(buffer)
  185. elif torch.is_floating_point(buffer):
  186. tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
  187. else:
  188. tensor_str = (
  189. f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
  190. )
  191. if buffer.is_cuda:
  192. tensor_str = f"{tensor_str}.cuda()"
  193. model_str += (
  194. f"{tab * 2}self.register_buffer('{buffer_name}', {tensor_str})\n"
  195. )
  196. for param_name, param in gm._parameters.items():
  197. if param is None:
  198. continue
  199. maybe_device = ""
  200. if param.is_cuda:
  201. maybe_device = ', device="cuda"'
  202. tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
  203. model_str += f"{tab * 2}self.{param_name} = {tensor_str}\n"
  204. # TODO - Keep this code for now. But, I don't think we will need this.
  205. # attrs = dir(gm)
  206. # for attr in attrs:
  207. # if "_tensor_constant" in attr:
  208. # val = getattr(gm, attr)
  209. # model_str += f" {attr} = {val!r}\n"
  210. model_str += f"{_addindent(gm.code, 4)}\n"
  211. return model_str
  212. @functools.cache # subprocess is expensive
  213. def _cuda_system_info_comment() -> str:
  214. if not torch.cuda.is_available():
  215. return "# torch.cuda.is_available()==False, no GPU info collected\n"
  216. model_str = "# CUDA Info: \n"
  217. try:
  218. cuda_version_out = subprocess.check_output(["nvcc", "--version"])
  219. cuda_version_lines = cuda_version_out.decode().split("\n")
  220. comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
  221. model_str += f"{comment}\n"
  222. except (FileNotFoundError, subprocess.CalledProcessError):
  223. model_str += "# nvcc not found\n"
  224. gpu_names = Counter(
  225. torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
  226. )
  227. model_str += "# GPU Hardware Info: \n"
  228. for name, count in gpu_names.items():
  229. model_str += f"# {name} : {count} \n"
  230. model_str += "\n"
  231. return model_str
  232. def generate_env_vars_string(*, stable_output: bool = False) -> str:
  233. """
  234. Generate a string configuration for environment variables related to Dynamo, Inductor, and Triton.
  235. """
  236. if stable_output:
  237. return "# env var omitted due to stable_output=True"
  238. allow_list = ["TORCH", "DYNAMO", "INDUCTOR", "TRITON"]
  239. skip_list = ["TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"]
  240. def filter(key: str) -> bool:
  241. return any(string in key for string in allow_list) and key not in skip_list
  242. config_lines = [
  243. f"os.environ['{key}'] = '{value}'"
  244. for key, value in os.environ.items()
  245. if filter(key)
  246. ]
  247. config_string = "\n".join(config_lines)
  248. return normalize_path_separator(f"""\
  249. import os
  250. {config_string}
  251. """)
  252. def generate_config_string(*, stable_output: bool = False) -> str:
  253. import torch._functorch.config
  254. import torch._inductor.config
  255. if stable_output:
  256. return "# config omitted due to stable_output=True"
  257. experimental_config = torch.fx.experimental._config.codegen_config() # type: ignore[attr-defined]
  258. return f"""\
  259. import torch._dynamo.config
  260. import torch._inductor.config
  261. import torch._functorch.config
  262. import torch.fx.experimental._config
  263. {torch._dynamo.config.codegen_config()}
  264. {torch._inductor.config.codegen_config()}
  265. {torch._functorch.config.codegen_config()}
  266. {experimental_config}
  267. """
  268. def get_minifier_repro_path() -> str:
  269. return os.path.join(minifier_dir(), "minifier_launcher.py")
  270. def helper_for_dump_minify(contents: str) -> None:
  271. minified_repro_path = get_minifier_repro_path()
  272. log.warning("Writing minified repro to:\n%s", minified_repro_path)
  273. if use_buck:
  274. BuckTargetWriter(minified_repro_path).write()
  275. try:
  276. with open(minified_repro_path, "w") as fd:
  277. fd.write(contents)
  278. except OSError as e:
  279. log.exception("")
  280. raise NotImplementedError("Could not write to {minified_repro_path}") from e
  281. class AccuracyError(Exception):
  282. pass
  283. def clone_inputs_retaining_gradness(example_inputs: Sequence[Any]) -> list[Any]:
  284. """
  285. This clone inputs is different from utils clone_input. In case of minifier,
  286. all the tensors are leaf tensors while creating a new graph. So, we set the
  287. requires_grad field w/o checking the leafness of the tensor.
  288. """
  289. cloned_inputs = clone_inputs(example_inputs)
  290. for idx in range(len(example_inputs)):
  291. if isinstance(cloned_inputs[idx], torch.Tensor):
  292. cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
  293. return cloned_inputs # type: ignore[return-value]
  294. def run_fwd_maybe_bwd(
  295. gm: torch.fx.GraphModule,
  296. args: Sequence[Any],
  297. only_fwd: bool = False,
  298. disable_clone: bool = False,
  299. ) -> Any:
  300. """
  301. Runs a forward and possibly backward iteration for a given mod and args.
  302. When disable_clone is True, we will use args as-is without cloning.
  303. This is higher fidelity but we may destroy the args in the process.
  304. """
  305. from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass
  306. gm = copy.deepcopy(gm)
  307. if not disable_clone:
  308. args = clone_inputs_retaining_gradness(args)
  309. if hasattr(gm, "zero_grad"):
  310. gm.zero_grad(True)
  311. # TorchInductor returned callable expects lists. So, may need a boxed calling convention.
  312. out = gm(args) if getattr(gm, "_boxed_call", False) else gm(*args)
  313. if only_fwd:
  314. return out
  315. if requires_bwd_pass(out):
  316. loss = reduce_to_scalar_loss(out)
  317. loss.backward()
  318. return collect_results(gm, out, None, args)
  319. def same_two_models(
  320. gm: torch.fx.GraphModule,
  321. opt_gm: torch.fx.GraphModule,
  322. example_inputs: Sequence[Any],
  323. only_fwd: bool = False,
  324. *,
  325. require_fp64: bool = False,
  326. ignore_non_fp: bool = False,
  327. ) -> bool:
  328. """
  329. Check two models have same accuracy.
  330. require_fp64: if True, raise an error if we unable to calculate the fp64 reference
  331. ignore_non_fp: if True, do not compare outputs which are not floating point. This
  332. is mostly useful for the minifier (which wants to avoid quantizing floating point
  333. error into integer/boolean error)
  334. """
  335. from .utils import same
  336. ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)
  337. fp64_ref = None
  338. if config.same_two_models_use_fp64:
  339. try:
  340. fp64_model, fp64_examples = cast_to_fp64(
  341. copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
  342. )
  343. fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
  344. except Exception:
  345. if require_fp64:
  346. raise RuntimeError( # noqa: B904
  347. "Could not generate fp64 outputs, workaround with torch._dynamo.config.same_two_models_use_fp64 = False"
  348. )
  349. log.warning("Could not generate fp64 outputs")
  350. try:
  351. res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
  352. except Exception:
  353. # This means that the minified graph is bad/exposes a different problem.
  354. # As we are checking accuracy here, lets log the exception and return True.
  355. log.exception(
  356. "While minifying the program in accuracy minification mode, "
  357. "ran into a runtime exception which is likely an unrelated issue."
  358. " Skipping this graph."
  359. )
  360. return True
  361. passing = same(
  362. ref,
  363. res,
  364. fp64_ref,
  365. tol=config.repro_tolerance,
  366. equal_nan=True,
  367. ignore_non_fp=ignore_non_fp,
  368. )
  369. return passing
  370. def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
  371. for node in model.graph.nodes:
  372. if (
  373. node.op == "call_function"
  374. and node.target == torch.ops.prims.convert_element_type.default
  375. ):
  376. assert len(node.args) == 2
  377. if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
  378. node.args = (node.args[0], torch.float64)
  379. if node.op == "call_function":
  380. dtype = node.kwargs.get("dtype")
  381. if dtype is not None and is_float_dtype(dtype):
  382. new_kwargs = dict(node.kwargs)
  383. new_kwargs["dtype"] = torch.float64
  384. node.kwargs = new_kwargs
  385. model.graph.lint()
  386. model.recompile()
  387. return model
  388. def cast_to(
  389. dtype: torch.dtype, model: torch.fx.GraphModule, inputs: list[Any]
  390. ) -> tuple[torch.fx.GraphModule, list[Any]]:
  391. from torch.utils._pytree import tree_map
  392. model = model.to(dtype)
  393. if dtype == torch.float64:
  394. # If casting to fp64 for accuracy comparison, we need to
  395. # replace dtype arguments embedded in the graph with fp64
  396. model = cast_dtype_args_to_fp64(model)
  397. inputs = tree_map(
  398. lambda x: x.to(dtype)
  399. if isinstance(x, torch.Tensor) and x.is_floating_point()
  400. else x,
  401. inputs,
  402. )
  403. return model, inputs
  404. def cast_to_fp64(
  405. model: torch.fx.GraphModule, inputs: list[Any]
  406. ) -> tuple[torch.fx.GraphModule, list[Any]]:
  407. return cast_to(torch.float64, model, inputs)
  408. def backend_accuracy_fails(
  409. gm: torch.fx.GraphModule,
  410. example_inputs: Sequence[Any],
  411. compiler_fn: Callable[[torch.fx.GraphModule, list[Any]], torch.fx.GraphModule],
  412. only_fwd: bool = False,
  413. *,
  414. require_fp64: bool = False,
  415. ignore_non_fp: bool = False,
  416. ) -> bool:
  417. try:
  418. compiled_gm = compiler_fn(
  419. copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
  420. )
  421. return not same_two_models(
  422. gm,
  423. compiled_gm,
  424. example_inputs,
  425. only_fwd,
  426. require_fp64=require_fp64,
  427. ignore_non_fp=ignore_non_fp,
  428. )
  429. except Exception:
  430. # This means that the minified graph is bad/exposes a different problem.
  431. # As we are checking accuracy here, lets log the exception and return False.
  432. log.exception(
  433. "While minifying the program in accuracy minification mode, "
  434. "ran into a runtime exception which is likely an unrelated issue."
  435. " Skipping this graph"
  436. )
  437. return False
  438. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  439. # REPRO SUPPORT CODE
  440. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  441. # Helper functions for computing what the default values of tensor
  442. # values should be. These all coincide with factory functions, e.g., torch.empty
  443. def _stride_or_default(
  444. stride: Optional[torch._prims_common.StrideType],
  445. *,
  446. shape: torch._prims_common.ShapeType,
  447. ) -> torch._prims_common.StrideType:
  448. return stride if stride is not None else utils.make_contiguous_strides_for(shape)
  449. def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
  450. return lambda x: x if x is not None else d
  451. _dtype_or_default = _mk_defaulter(torch.float32)
  452. _device_or_default = _mk_defaulter(torch.device("cpu"))
  453. _storage_offset_or_default = _mk_defaulter(0)
  454. _requires_grad_or_default = _mk_defaulter(False)
  455. _is_leaf_or_default = _mk_defaulter(False)
  456. class NopInputReader:
  457. def __init__(self) -> None:
  458. self.total = 0
  459. def storage(
  460. self,
  461. storage_hash: Optional[str],
  462. nbytes: int,
  463. *,
  464. device: Optional[torch._prims_common.DeviceLikeType] = None,
  465. dtype_hint: Optional[torch.dtype] = None,
  466. ) -> None:
  467. self.total += 1
  468. def tensor(self, *args: Any, **kwargs: Any) -> Optional[torch.Tensor]:
  469. pass
  470. def symint(self, *args: Any, **kwargs: Any) -> Optional[int]:
  471. pass
# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
    """Reader object passed to a repro's ``load_args(reader)``.

    Rebuilds the recorded inputs. Storages are loaded from a
    ``ContentStoreReader`` rooted at ``save_dir`` when the data is available;
    otherwise they are filled with random values. Every tensor and symint
    produced is appended to ``self.args`` in call order.
    """

    def __init__(self, save_dir: Optional[str] = None, *, pbar: Optional[tqdm] = None) -> None:
        # If None, we will generate random data instead. It's important
        # to natively support this use case as it will allow people to
        # share repros without including the real data, if the problem
        # reproduces even on random data.
        if save_dir is None:
            log.warning("no save_dir specified, will generate random data")
        self.store = ContentStoreReader(save_dir) if save_dir is not None else None
        # Inputs collected by tensor()/symint(), in call order.
        self.args: list[Any] = []
        # Optional progress bar, advanced by one per storage() call.
        self.pbar = pbar

    def storage(
        self,
        storage_hash: Optional[str],
        nbytes: int,
        *,
        device: Optional[torch._prims_common.DeviceLikeType] = None,
        dtype_hint: Optional[torch.dtype] = None,
    ) -> UntypedStorage:
        """Return an untyped storage of ``nbytes`` bytes.

        If a store is configured and ``storage_hash`` is given, the stored
        data is loaded; a device mismatch is only logged, not corrected.
        If the store has no such file, or there is no store or hash, random
        data of ``dtype_hint`` (default float32) is generated on ``device``
        (default CPU) instead.
        """
        if self.pbar is not None:
            self.pbar.update(1)
        device = _device_or_default(device)  # type: ignore[arg-type]
        dtype_hint = _dtype_or_default(dtype_hint)
        if self.store is not None and storage_hash is not None:
            try:
                storage = self.store.read_storage(storage_hash)
            except FileNotFoundError:
                pass
            else:
                if device != storage.device:
                    log.warning("device mismatch: %s != %s", device, storage.device)
                    # TODO: transfer it to the right device? But failing this
                    # way would be very mysterious! Would have been better
                    # not to store device in the serialized format...
                return storage
        log.warning("could not load %s, generating random data instead", storage_hash)
        # Random fallback: a 1-D contiguous tensor of dtype_hint covering
        # nbytes. The floor division drops any remainder smaller than one
        # element.
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()

    def tensor(
        self,
        storage: UntypedStorage,
        shape: torch._prims_common.ShapeType,
        stride: Optional[torch._prims_common.StrideType] = None,
        *,
        storage_offset: Optional[int] = None,
        dtype: Optional[torch.dtype] = None,
        requires_grad: Optional[bool] = None,
        is_leaf: Optional[bool] = None,
        **metadata: Any,
    ) -> torch.Tensor:
        """Create a tensor that views ``storage`` with the given metadata.

        Missing arguments fall back to these defaults:

        * ``stride``: contiguous strides;
        * ``storage_offset``: 0;
        * ``dtype``: float32;
        * ``requires_grad``: False;
        * ``is_leaf``: False.

        Extra ``metadata`` kwargs are applied with
        ``torch._utils.set_tensor_metadata``. The tensor is appended to
        ``self.args`` and also returned.
        """
        stride = _stride_or_default(stride, shape=shape)
        storage_offset = _storage_offset_or_default(storage_offset)
        dtype = _dtype_or_default(dtype)
        is_leaf = _is_leaf_or_default(is_leaf)
        requires_grad = _requires_grad_or_default(requires_grad)
        t = torch.tensor(
            [], dtype=dtype, device=storage.device, requires_grad=requires_grad
        )
        with torch.no_grad():
            t.set_(storage, storage_offset, shape, stride)
        if not is_leaf:
            # Fake up some autograd history in a very naughty way.
            # clone() under enable_grad gives t autograd history. set_() then
            # re-points the clone at the original storage, so the result
            # aliases it again.
            # NOTE(review): presumably this only yields a non-leaf when
            # requires_grad is True; otherwise the assert below would fail.
            with torch.enable_grad():
                t = t.clone(memory_format=torch.preserve_format)
            with torch.no_grad():
                t.set_(storage, storage_offset, shape, stride)
        assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
        torch._utils.set_tensor_metadata(t, metadata)
        self.args.append(t)
        return t  # for BC

    def symint(self, val: Any) -> Any:
        """Record ``val`` in ``self.args`` and return it unchanged."""
        self.args.append(val)
        return val  # for BC
  548. # Here is our writer strategy:
  549. # 1. We will stream all of the inputs to disk
  550. # 2. You can now deterministically randomize the inputs, or reload
  551. # the inputs from disk
  552. # 3. You can YOLO run the script without the inputs, in which case
  553. # we'll fill the inputs with random data and pray. This is the
  554. # legacy behavior, but it's also useful if you want to find out
  555. # if we're so broken even random inputs trigger it
  556. # 4. We could offer an in process "check if the randomized thing
  557. # works too" but this is delicate so we don't do it
  558. class InputWriter:
  559. def __init__(self, save_dir: Optional[str], *, stable_hash: bool = False) -> None:
  560. self._lines: list[str] = []
  561. # TODO: consider ensuring tensor and storage counters line up?
  562. self.storage_counter = itertools.count()
  563. self.save_dir = save_dir
  564. self.store = (
  565. ContentStoreWriter(save_dir, stable_hash=stable_hash)
  566. if save_dir is not None
  567. else None
  568. )
  569. self.seen_storages: dict[StorageWeakRef, str] = {}
  570. def lines(self) -> list[str]:
  571. r = [
  572. "def load_args(reader):",
  573. ]
  574. r.extend(f" {l}" for l in self._lines)
  575. # In case we need to change the internal format of load_args
  576. # in an FC-breaking way
  577. r.append("load_args._version = 0")
  578. return r
  579. # Storages are untyped, but we need to initialize them with data if
  580. # we don't have the real data, so we give a hint saying what kind
  581. # of initialization may be appropriate
  582. #
  583. # If we had a FakeTensor, device_hint tells us what device should be
  584. def storage(
  585. self,
  586. untyped_storage: UntypedStorage,
  587. *,
  588. device_hint: Optional[torch._prims_common.DeviceLikeType] = None,
  589. dtype_hint: Optional[torch.dtype] = None,
  590. ) -> str:
  591. ws = StorageWeakRef(untyped_storage)
  592. v = self.seen_storages.get(ws)
  593. if v is not None:
  594. return v
  595. v = f"buf{next(self.storage_counter)}"
  596. maybe_dtype_hint = ""
  597. if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
  598. maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
  599. # TODO: being optional on device is kind of pointless as the default
  600. # is CPU but most repros we care about are CUDA
  601. maybe_device = ""
  602. device = untyped_storage.device
  603. if device.type == "meta":
  604. assert device_hint is not None
  605. device = device_hint # type: ignore[assignment]
  606. if _device_or_default(None) != device:
  607. maybe_device = f", device={device!r}"
  608. nbytes = untyped_storage.nbytes()
  609. storage_hash = None
  610. if self.store is not None and untyped_storage.device.type != "meta":
  611. storage_hash = self.store.write_storage(untyped_storage)
  612. self._lines.append(
  613. f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
  614. )
  615. self.seen_storages[ws] = v
  616. return v
  617. def tensor(self, name: str, t: torch.Tensor) -> None:
  618. from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
  619. storage = self.storage(
  620. t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
  621. )
  622. args = []
  623. # NB: this is positional, must come first
  624. if not statically_known_true(
  625. sym_eq(_stride_or_default(None, shape=t.shape), t.stride())
  626. ):
  627. args.append(str(tuple(t.stride())))
  628. if _dtype_or_default(None) != t.dtype:
  629. args.append(f"dtype={t.dtype!r}")
  630. if not statically_known_true(
  631. _storage_offset_or_default(None) == t.storage_offset()
  632. ):
  633. args.append(f"storage_offset={t.storage_offset()!r}")
  634. tensor_metadata = torch._utils.get_tensor_metadata(t)
  635. if tensor_metadata:
  636. args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
  637. if _requires_grad_or_default(None) != t.requires_grad:
  638. args.append(f"requires_grad={t.requires_grad!r}")
  639. is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
  640. if _is_leaf_or_default(None) != is_leaf:
  641. args.append(f"is_leaf={is_leaf!r}")
  642. self._lines.append(
  643. "reader.tensor("
  644. + ", ".join([storage, str(tuple(t.shape)), *args])
  645. + f") # {name}"
  646. )
  647. def unsupported(self, name: str, arg: Any) -> None:
  648. # NB: Try hard not to /print/ a tensor, that will be very slow
  649. self._lines.append(f"# {name} was unsupported type for dumping: {type(arg)}")
  650. # Best effort dump as much useful stuff we can lol, in case you want
  651. # to repair the repro
  652. if isinstance(arg, (list, tuple)):
  653. self._lines.append('"""')
  654. for i, a in enumerate(arg):
  655. name_i = f"{name}[{i}]"
  656. if isinstance(a, torch.Tensor):
  657. self.tensor(name_i, a)
  658. elif isinstance(a, (int, torch.SymInt)):
  659. self.symint(name_i, a)
  660. else:
  661. self.unsupported(name_i, a)
  662. self._lines.append('"""')
  663. # write out that the arg was filtered out as it is constant
  664. def const(self, name: str) -> None:
  665. self._lines.append(
  666. f"reader.const({name!r}) # {name}, filtered out during compilation"
  667. )
  668. # TODO: this doesn't actually symint atm
  669. def symint(self, name: str, val: Any) -> None:
  670. if isinstance(val, torch.SymInt):
  671. val = val.node.hint
  672. self._lines.append(f"reader.symint({val!r}) # {name}")
  673. def aot_graph_input_parser(
  674. func: Callable[[list[Tensor]], list[Tensor]],
  675. device: str = "cuda",
  676. sym_shapes: Optional[dict[str, int]] = None,
  677. default_sym_shape: Optional[int] = None,
  678. ) -> dict[str, Any]:
  679. """
  680. Takes in a function which has been printed with print_readable() and constructs kwargs to run it.
  681. Handles Tensor inputs, Symints, and a graph module which might have tensor constants.
  682. Consider a function `forward` defined as follows:
  683. def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",):
  684. _tensor_constant0: "i64[4190]" = self._tensor_constant0
  685. # Further implementation
  686. kwargs = aot_graph_input_parser(forward)
  687. forward(**kwargs)
  688. """
  689. from torch.utils._dtype_abbrs import dtype_abbrs
  690. dtype_map: dict[str, torch.dtype] = {
  691. value: key for key, value in dtype_abbrs.items()
  692. }
  693. dtype_pattern: str = "|".join(dtype_abbrs.values())
  694. # Extracting the source code from the function
  695. source = inspect.getsource(func)
  696. # Regular expressions
  697. tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)"
  698. tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]"
  699. sym_shape_regex = r"Sym\((s\d+)\)"
  700. class TensorContainer:
  701. "Container for tensors as attributes"
  702. # Dictionary for tensors from annotations
  703. kwargs: dict[str, Any] = {}
  704. sym_shapes_dict: dict[str, int] = sym_shapes or {}
  705. def get_sym_int(symint: str) -> int:
  706. torch._check(
  707. symint in sym_shapes_dict or default_sym_shape is not None,
  708. lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in",
  709. )
  710. return sym_shapes_dict.get(symint, default_sym_shape) # type: ignore[return-value]
  711. def gen_tensor(shape: torch._prims_common.ShapeType, dtype: torch.dtype) -> Tensor:
  712. # Resolve symbolic shapes to concrete values
  713. resolved_shape = []
  714. dynamic_dims = []
  715. for i, dim in enumerate(shape):
  716. dim = dim.strip() # type: ignore[attr-defined]
  717. if "s" in dim:
  718. s = get_sym_int(dim)
  719. resolved_shape.append(s)
  720. dynamic_dims.append(i)
  721. else:
  722. if dim:
  723. resolved_shape.append(int(dim))
  724. constructor = torch.randn if dtype.is_floating_point else torch.zeros
  725. out = constructor(resolved_shape, dtype=dtype, device=device) # type: ignore[call-arg]
  726. for d in dynamic_dims:
  727. torch._dynamo.mark_dynamic(out, d)
  728. return out
  729. # Parse function annotations for tensor generation
  730. annotations = func.__annotations__
  731. for param, annotation in annotations.items():
  732. # Skip 'return' annotation
  733. if param == "return":
  734. continue
  735. match = re.search(tensor_regex, annotation)
  736. if match:
  737. data_type, shape_str = match.groups()
  738. shape = tuple(shape_str.split(","))
  739. dtype = dtype_map[data_type]
  740. kwargs[param] = gen_tensor(shape, dtype)
  741. match = re.search(sym_shape_regex, annotation)
  742. if match:
  743. kwargs[param] = get_sym_int(match.group(1))
  744. if "self" in inspect.signature(func).parameters:
  745. container = TensorContainer()
  746. kwargs["self"] = container
  747. for match in re.finditer(tensor_assignment_regex, source):
  748. attr_name, data_type, shape_str, _ = match.groups()
  749. shape = tuple(shape_str.split(","))
  750. dtype = dtype_map[data_type]
  751. setattr(container, attr_name, gen_tensor(shape, dtype))
  752. return kwargs
  753. def profile_to_file(filename: str) -> Callable[[T], T]:
  754. """
  755. Decorator to cProfile a given function and save the result to disk on process exit.
  756. Args:
  757. filename: filename to save profile to
  758. """
  759. prof = cProfile.Profile()
  760. filename = os.path.abspath(os.path.expanduser(filename))
  761. def decorator(fn: Any) -> Any:
  762. @functools.wraps(fn)
  763. def wrapper(*args: Any, **kwargs: Any) -> Any:
  764. prof.enable()
  765. try:
  766. return fn(*args, **kwargs)
  767. finally:
  768. prof.disable()
  769. return wrapper
  770. def save_it() -> None:
  771. prof.dump_stats(filename)
  772. sys.stderr.write(
  773. textwrap.dedent(
  774. f"""\
  775. Wrote profile to {filename}, view with:
  776. snakeviz {filename}
  777. """
  778. )
  779. )
  780. atexit.register(save_it)
  781. return decorator