config.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697
  1. """
  2. Configuration module for TorchDynamo compiler and optimization settings.
  3. This module contains various configuration flags and settings that control TorchDynamo's
  4. behavior, including:
  5. - Runtime behavior flags (e.g., guard settings, specialization options)
  6. - Debugging and development options
  7. - Performance tuning parameters
  8. - Feature toggles for experimental features
  9. """
  10. import getpass
  11. import os
  12. import sys
  13. import tempfile
  14. from os.path import abspath, dirname
  15. from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
  16. from torch._environment import is_fbcode
  17. from torch.utils._config_module import Config, get_tristate_env, install_config_module
# ---------------------------------------------------------------------------
# Logging and basic debug knobs.
#
# NOTE(review): some comments below carry a "compile-ignored" tag in square
# brackets. These tags are presumably parsed from this file's source by the
# config machinery installed at the bottom of the module. Keep each tag on the
# line immediately above the assignment it annotates, and never insert a line
# between the two.
# ---------------------------------------------------------------------------
# To configure logging for dynamo, aot, and inductor, use the following API in
# the torch._logging module:
#   torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor=<level>)
# or use the environment variable TORCH_LOGS="dynamo,aot,inductor" (use a
# prefix + to indicate higher verbosity).
# See this design doc for more detailed info.
# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit#
# The name of a file to write the logs to.
# [@compile_ignored: debug]
log_file_name: Optional[str] = None
# Enabled only when the TORCHDYNAMO_VERBOSE environment variable is exactly "1".
# [@compile_ignored: debug] Verbose will print full stack traces on warnings and errors
verbose = os.environ.get("TORCHDYNAMO_VERBOSE", "0") == "1"
# [@compile_ignored: runtime_behaviour] verify the correctness of optimized backend
verify_correctness = False
# Need this many ops to create an FX graph (deprecated: not used).
minimum_call_count = 1
# Turn on/off the DCE pass (deprecated: always true).
dead_code_elimination = True
# ---------------------------------------------------------------------------
# Recompilation limits.
# ---------------------------------------------------------------------------
# Disable (for a function) when the cache reaches this size.
# Controls the maximum number of cache entries with a guard on the same
# ID_MATCH'd object. It also controls the maximum size of cache entries if they
# don't have any ID_MATCH'd guards.
# [@compile_ignored: runtime_behaviour]
recompile_limit = 8
# [@compile_ignored: runtime_behaviour] safeguarding to prevent horrible recomps
accumulated_recompile_limit = 256
# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit (deprecated: does not do anything)
skip_code_recursive_on_recompile_limit_hit = True
# Raise a hard error if the cache limit is hit. If you are on a model where you
# know you've sized the cache correctly, this can help detect problems when
# you regress guards/specialization. This works best when recompile_limit = 1.
# This flag is incompatible with: suppress_errors.
# [@compile_ignored: runtime_behaviour]
fail_on_recompile_limit_hit = False
# Legacy "cache_*" spellings. Each one is a Config alias of the "recompile_*"
# setting named in its alias= string (presumably kept for backward
# compatibility).
cache_size_limit: int = Config(alias="torch._dynamo.config.recompile_limit")
accumulated_cache_size_limit: int = Config(
    alias="torch._dynamo.config.accumulated_recompile_limit"
)
# (deprecated: does not do anything)
skip_code_recursive_on_cache_limit_hit: bool = Config(
    alias="torch._dynamo.config.skip_code_recursive_on_recompile_limit_hit"
)
fail_on_cache_limit_hit: bool = Config(
    alias="torch._dynamo.config.fail_on_recompile_limit_hit"
)
# ---------------------------------------------------------------------------
# Dynamic-shape specialization and nn.Module guarding.
# ---------------------------------------------------------------------------
# Whether or not to specialize on int inputs. This only has an effect with
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
# inputs. Note that assume_static_by_default will also cause ints to get
# specialized, so this is mostly useful for export, where we want inputs
# to be dynamic, but accesses to ints should NOT get promoted into inputs.
specialize_int = False
# Whether or not to specialize on float inputs. Dynamo will always promote
# float inputs into Tensor inputs, but at the moment, backends inconsistently
# support codegen on float (this is to be fixed).
specialize_float = False
# Legacy config, does nothing now!
dynamic_shapes = True
# Enabled unless TORCH_COMPILE_USE_LAZY_GRAPH_MODULE is set to something other
# than "1".
# NOTE(review): presumably selects a lazily-recompiled FX GraphModule
# implementation; confirm against the code that reads this flag.
use_lazy_graph_module = (
    os.environ.get("TORCH_COMPILE_USE_LAZY_GRAPH_MODULE", "1") == "1"
)
# This is a temporary flag, which changes the behavior of dynamic_shapes=True.
# When assume_static_by_default is True, we only allocate symbols for shapes
# marked dynamic via mark_dynamic.
# NOTE - this flag can be removed once we can run dynamic_shapes=False w/ the mark_dynamic API
# see [Note - on the state of mark_dynamic]
assume_static_by_default = True
# This flag changes how dynamic_shapes=True works, and is meant to be used in
# conjunction with assume_static_by_default=True.
# With this flag enabled, we always compile a frame as fully static for the
# first time, and, if we fail any guards due to wobbles in shape, we recompile
# with *all* the wobbled shapes as being marked dynamic.
automatic_dynamic_shapes = True
# Presumably selects how automatic_dynamic_shapes marks the wobbled dimensions.
# Valid options: "dynamic", "unbacked"
automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic"
# Log graph in/out metadata.
# This is only turned on for export today since we
# know we are tracing a flat callable. Later, this
# can be extended to other use cases as well.
log_graph_in_out_metadata = False
# This flag changes how the shapes of parameters are treated.
# NOTE(review): the True/False descriptions below look inverted relative to the
# flag's name (True reads as "force parameter shapes to be static"). Confirm
# against the code that consumes this flag before relying on them.
# If this flag is set to True, then the shapes of torch.nn.Parameter as well as of torch.Tensor are attempted to be dynamic
# If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static,
# while the shapes of torch.Tensor are assumed to be dynamic.
force_parameter_static_shapes = True
# This flag ensures that the shapes of a nn module are always assumed to be static
# If the flag is set to True, then the shapes of a nn.module are assumed to be static
# If the flag is set to False, then the shapes of a nn.module can be dynamic
force_nn_module_property_static_shapes = True
# Typically, if you mark_dynamic a dimension, we will error if the dimension
# actually ended up getting specialized. This knob changes the behavior so
# that we don't error at all. This is helpful for our CI where I'm using a
# heuristic to mark batch dimensions as dynamic and the heuristic may get it
# wrong.
allow_ignore_mark_dynamic = False
# Set this to False to assume nn.Module contents are immutable (similar assumption as freezing)
guard_nn_modules = True
# Uses CPython internal dictionary tags to detect mutation. There is some
# overlap between guard_nn_modules_using_dict_tags and guard_nn_modules flag.
# guard_nn_modules unspecializes the nn module instance and adds guard for each
# relevant member of the nn modules. On the other hand,
# guard_nn_modules_using_dict_tags specializes on each nn module instance but
# uses low overhead dict version matching to detect mutations, obviating the
# need to guard on members of the nn modules. With
# guard_nn_modules_using_dict_tags, the guard_nn_modules is not really required
# but kept around for debugging and discussing unspecializing nn module
# variables.
# TODO(janimesh, voz): Remove both of these flags (or at least guard_nn_modules)
# once we have reached stability for the guard_nn_modules_using_dict_tags.
guard_nn_modules_using_dict_tags = True
# ---------------------------------------------------------------------------
# Graph freezing and tensor-subclass handling.
# ---------------------------------------------------------------------------
# Flag to enable preparation for graph freezing, so that the named parameters and
# buffers are passed as params_flat in tracing context by AOT autograd.
# Non-Inductor backends can use this list for graph freezing.
# Enabled only when TORCHDYNAMO_PREPARE_FREEZING is exactly "1".
prepare_freezing = os.environ.get("TORCHDYNAMO_PREPARE_FREEZING", "0") == "1"
# NOTE this has been deprecated, it does nothing now.
traceable_tensor_subclasses: set[type[Any]] = set()
# If a tensor subclass is put into this set, Dynamo will model its instances in
# a very conservative and limited way (most likely causing lots of graph breaks
# if one applies tensor ops on these instances). This is useful if you encounter
# internal compiler errors from Dynamo which are caused by tensor subclasses,
# and you are willing to tolerate potential graph breaks rather than hard error.
nontraceable_tensor_subclasses: set[type[Any]] = set()
  136. # Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager.
  137. # This is a good way to get your model to work one way or another, but you may
  138. # lose optimization opportunities this way. Devs, if your benchmark model is failing
  139. # this way, you should figure out why instead of suppressing it.
  140. # This flag is incompatible with: fail_on_recompile_limit_hit.
  141. suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
# ---------------------------------------------------------------------------
# Replay records, assert rewriting, global disable and profiling.
# ---------------------------------------------------------------------------
# Record and write an execution record of the current frame to a file
# if an exception is encountered.
# Enabled only when TORCH_COMPILE_REPLAY_RECORD is exactly "1".
# [@compile_ignored: debug]
replay_record_enabled = os.environ.get("TORCH_COMPILE_REPLAY_RECORD", "0") == "1"
# Rewrite assert statements in Python with torch._assert.
rewrite_assert_with_torch_assert = True
# Disable dynamo (only when TORCH_COMPILE_DISABLE is exactly "1").
disable = os.environ.get("TORCH_COMPILE_DISABLE", "0") == "1"
# NOTE(review): when TORCH_COMPILE_CPROFILE is set, this holds the raw string
# (False otherwise). Any non-empty value, including "0", is therefore truthy.
# Presumably consumers only test truthiness; confirm before relying on
# TORCH_COMPILE_CPROFILE=0 to turn profiling off.
# [@compile_ignored: runtime_behaviour] Get a cprofile trace of Dynamo
cprofile = os.environ.get("TORCH_COMPILE_CPROFILE", False)
# Legacy config, does nothing now!
skipfiles_inline_module_allowlist: dict[Any, Any] = {}
# If a string representing a PyTorch module is in this ignorelist,
# the `allowed_functions.is_allowed` function will not consider it
# when creating a list of PyTorch functions that will appear in
# FX IR.
allowed_functions_module_string_ignorelist = {
    "torch.distributions",
    "torch.testing",
    "torch._refs",
    "torch._prims",
    "torch._decomp",
}
# ---------------------------------------------------------------------------
# Minifier / repro settings.
# ---------------------------------------------------------------------------
# Debug flag to try the minifier at different stages. Possible values are {None, "aot", "dynamo"}
# None - Minifier is switched off
# dynamo - Runs minifier on the TorchDynamo produced graphs, if compilation fails
# aot - Runs minifier on the Aot Autograd produced graphs, if compilation fails
# Initialised from TORCHDYNAMO_REPRO_AFTER (None when the variable is unset).
# [@compile_ignored: debug]
repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
# Compiler compilation debug info
# 1: Dumps the original graph out to repro.py if compilation fails
# 2: Dumps a minifier_launcher.py if compilation fails.
# 3: Always dumps a minifier_launcher.py. Good for segfaults.
# 4: Dumps a minifier_launcher.py if the accuracy fails.
# Initialised from TORCHDYNAMO_REPRO_LEVEL (default 2). A non-integer value
# makes int() raise ValueError when this module is imported.
# [@compile_ignored: debug]
repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
# By default, we try to detect accuracy failure by running both forward
# and backward of a torchdynamo produced graph (if you are using repro_after
# 'dynamo'). This setting forces us to only test the forward graph and
# not the backward graph. This can be helpful if you're trying to debug
# an inference only problem, but the minifier seems to be choking on the
# backwards step.
# TODO: Detect this situation automatically so the user doesn't need
# to manually configure this
# Enabled only when TORCHDYNAMO_REPRO_FORWARD_ONLY is exactly "1".
# [@compile_ignored: debug]
repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"
# The tolerance we should use when testing if a compiled graph
# has diverged so that we should treat it as an accuracy failure.
# [@compile_ignored: debug]
repro_tolerance = 1e-3
# Whether to ignore non-floating point values when checking accuracy.
# Checking accuracy of non-floating point values such as boolean tensors
# can lead to false positives.
# Enabled only when TORCHDYNAMO_REPRO_IGNORE_NON_FP is exactly "1".
# [@compile_ignored: debug]
repro_ignore_non_fp = os.environ.get("TORCHDYNAMO_REPRO_IGNORE_NON_FP") == "1"
# If True, when testing if two models are the same, we will test them against
# a third fp64 reference and only report a problem if the RMSE relative to the
# fp64 is greater. However, this will use more memory; you may disable this
# if memory usage is too high.
# [@compile_ignored: runtime_behaviour]
same_two_models_use_fp64 = True
# ---------------------------------------------------------------------------
# Data-dependent outputs and unbacked SymInts.
# ---------------------------------------------------------------------------
# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
# Enabled only when TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS is exactly "1".
capture_scalar_outputs = os.environ.get("TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS") == "1"
# Not all backends support operators that have dynamic output shape (e.g.,
# nonzero, unique). When this flag is set to False, we introduce a graph
# break instead of capturing. This requires dynamic_shapes to be True.
# If you set this to True, you probably also want capture_scalar_outputs
# (these are separated for historical reasons).
# Enabled only when TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS is exactly "1".
capture_dynamic_output_shape_ops = (
    os.environ.get("TORCHDYNAMO_CAPTURE_DYNAMIC_OUTPUT_SHAPE_OPS", "0") == "1"
)
# Hybrid backed/unbacked SymInts.
# NOTE(review): presumably, when True, deferred runtime asserts are preferred
# over guards; confirm against the code that reads this flag.
prefer_deferred_runtime_asserts_over_guards = False
# By default, dynamo will treat all ints as backed SymInts, which means (1) it
# will wait to see the int change over multiple runs before generalizing and
# (2) it will still always 0/1 specialize an int. When true, this knob
# forces dynamo to treat _length_per_key and _offset_per_key on
# KeyedJaggedTensor from torchrec as size-like unbacked SymInts, so that
# they (1) generalize immediately and (2) unsoundly never compare equal to
# 0/1. This is not on by default as AOTAutograd/Inductor cannot currently
# compile this code; however, this can be useful for export.
force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
# Currently, Dynamo will always specialize on int members of NN module.
# However, there could be cases where this is undesirable, e.g., when tracking
# step count leading to constant recompilation and eventually eager fallback.
# Setting this flag to True will allow int members to be potentially unspecialized
# through dynamic shape mechanism.
# Defaults to False for BC.
allow_unspec_int_on_nn_module = False
# ---------------------------------------------------------------------------
# DDP optimization, runtime asserts and trace-rule toggles.
# ---------------------------------------------------------------------------
# Specify how to optimize a compiled DDP module. The flag accepts a boolean
# value or a string. There are 3 documented modes.
# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
# split model graph into pieces to match DDP bucket sizes to allow DDP
# comm/compute overlap.
# 2. "python_reducer" (experimental): this optimization requires the usage
# of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer
# and use the Python reducer to allow compiled_autograd to trace the
# communication and allow comm/compute overlap without graph-breaks.
# 3. "no_optimization" (or False): Dynamo won't split the model graph, nor
# will Python reducer be used. With this mode, there will be no graph-breaks
# and the original DDP C++ reducer will be used. There will be no comm/compute
# overlap. This mode CANNOT be used with compiled_autograd.
# Note that to avoid breaking the existing usage, mode 1 and mode 3 can be
# specified with a boolean value. True is using ddp_optimizer and False is
# no optimization.
# NOTE(review): the annotation below also accepts
# "python_reducer_without_compiled_forward", which is not described above.
# Document it or confirm it is still supported.
optimize_ddp: Union[
    bool,
    Literal[
        "ddp_optimizer",
        "python_reducer",
        "python_reducer_without_compiled_forward",
        "no_optimization",
    ],
] = True
# By default, Dynamo emits runtime asserts (e.g. torch._check, torch._check_is_size) in the graph.
# In some cases those asserts could be performance costly.
# E.g. torch._check(tensor[0].item() > 2) for tensor on cuda will require cuda sync.
# Setting this to True keeps them as hints to the symbolic shapes engine,
# but they will not be emitted in the graph.
# Enabled only when TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS is exactly "1".
do_not_emit_runtime_asserts: bool = (
    os.environ.get("TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS", "0") == "1"
)
# Skip tracing the torchrec files added to trace_rules.FBCODE_SKIP_DIRS
skip_torchrec = True
# Don't apply most trace_rules.py rules
dont_skip_tracing = False
# No longer used
optimize_ddp_lazy_compile = False
# ---------------------------------------------------------------------------
# Guard-reduction optimizations. Some of these (marked "unsafe" below) trade
# soundness for lower guard overhead; read each note before changing them.
# ---------------------------------------------------------------------------
# Lambda guarding on object aliasing to improve opportunity for dict tag
# optimization.
# NOTE(review): "lamba" looks like a typo of "lambda". Renaming would change
# the config key, so any rename would need an alias.
use_lamba_guard_for_object_aliasing = True
# Whether to skip guarding on FSDP-managed modules
skip_fsdp_guards = True
# Whether to apply torch._dynamo.disable() to FSDP2 hooks.
# Defaults to True. If Traceable FSDP2 is used, set this to False.
skip_fsdp_hooks = True
# Make dynamo skip guarding on hooks on nn modules.
# Note: unsafe: if your model actually has hooks and you remove them, or doesn't and you add them,
# dynamo will not notice and will execute whichever version you first compiled.
skip_nnmodule_hook_guards = True
# Make dynamo skip the no-tensor-aliasing guard on parameters.
# Note: unsafe: if you compile a function with different parameters as inputs,
# and then later pass the same parameter as two inputs, dynamo will not
# notice, which can lead to incorrect results.
skip_no_tensor_aliasing_guards_on_parameters = True
# Considers a tensor immutable if it is one of the values of a dictionary, and
# the dictionary tag is the same across invocation calls.
skip_tensor_guards_with_matching_dict_tags = True
# Skips guards on func.__defaults__ if the element to be guarded is a constant
skip_guards_on_constant_func_defaults = True
# The recursive-dict-tag guard relies on the class/function identity staying
# stable. We therefore assume that the following function/class dunder
# attributes are **never rebound** to a different object:
#
#   - __code__         - __closure__
#   - __defaults__     - __kwdefaults__
#   - __annotations__  - __mro__ (a class attribute)
#
# It is fine to mutate the objects they already point to (e.g. tweak an element
# inside __defaults__), but assignments like
#
#   foo.__defaults__ = (3, 4)  # REBIND - NOT SUPPORTED
#
# would invalidate the optimization. This type of rebinding is rare, so we
# assume that the rebinding never happens for guard purposes. Set the flag
# below to False only in environments where such rebinding is known to occur.
assume_dunder_attributes_remain_unchanged = True
# Speed up guard execution of nested nn modules by recursively checking for dict
# tags to avoid full guard execution.
use_recursive_dict_tags_for_guards = True
# Maximum number of objects for which we check dict pointer tags. This is
# useful for regional compilation.
max_saved_pointers_for_recursive_dict_tags_check = 256
# ---------------------------------------------------------------------------
# Error behaviour, miscellaneous tracing toggles and NumPy support.
# ---------------------------------------------------------------------------
# If True, raises an exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True
# If True, raise when aot autograd is unsafe to use
raise_on_unsafe_aot_autograd = False
# This flag is ignored and maintained for backwards compatibility.
error_on_nested_jit_trace = True
# If true, error with a better message if we symbolically trace over a
# dynamo-optimized function. If false, silently suppress dynamo.
error_on_nested_fx_trace = True
# Disables graph breaking on rnn. YMMV with backends.
allow_rnn = False
# If true, enables the feature that captures PyTorch sparsity in the
# exported FX graph. This flag should become the default eventually
# and be removed, but currently provides a way to fall back to old
# graph breaking behavior.
# Defaults to False inside fbcode and True everywhere else.
capture_sparse_compute = False if is_fbcode() else True
# If true, error if we try to compile a function that has
# been seen before.
# [@compile_ignored: runtime_behaviour]
error_on_recompile = False
# [@compile_ignored: debug] Whether to report any guard failures (deprecated: does not do anything)
report_guard_failures = True
# Computed by applying dirname() three times to this file's absolute path,
# i.e. the grandparent of the directory containing this file.
# [@compile_ignored: debug] root folder of the project
base_dir = dirname(dirname(dirname(abspath(__file__))))
# Trace through NumPy or graphbreak
trace_numpy = True
# Default NumPy dtypes when tracing with torch.compile.
# We default to 64 bits. For efficiency, one may want to change these to float32.
numpy_default_float = "float64"
numpy_default_complex = "complex128"
numpy_default_int = "int64"
# Use numpy's PRNG if True, pytorch otherwise
use_numpy_random_stream = False
# ---------------------------------------------------------------------------
# Feature toggles and testing-only knobs.
# ---------------------------------------------------------------------------
# Use C++ guard manager (deprecated: always true)
enable_cpp_guard_manager = True
# Use C++ guard manager for symbolic shapes
enable_cpp_symbolic_shape_guards = False
# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True
# Enable tracing through unittest
enable_trace_unittest = False
# Enable tracing generator functions lazily. If False, Dynamo will exhaust
# generators upon first execution. If True, the generator will be accessed lazily.
enable_faithful_generator_behavior = True
# Inline inbuilt nn modules
inline_inbuilt_nn_modules = Config(  # type: ignore[var-annotated]
    default=True,
    justknob="pytorch/compiler:inline_inbuilt_nn_modules",
)
# Resume tracing in nested frames if a nested graph break occurs.
# The old behavior is to bubble up the graph break to the top level frame.
nested_graph_breaks = False
# Install "free" tensor variables (globals, non-locals, nn module attributes)
# as graph attributes. This is useful for export, as it
# produces a consistent number of inputs to the graph.
install_free_tensors = False
# Use C++ FrameLocalsMapping (raw array view of Python frame fastlocals) (deprecated: always True)
enable_cpp_framelocals_guard_eval = True
# Whether to automatically find and replace identical graph
# regions with a call to invoke_subgraph
use_graph_deduplication = False
# Whether to track nodes for deduplication (testing only)
# This flag is ignored if use_graph_deduplication is True
track_nodes_for_deduplication = False
# Whether to lint the graph after each region is replaced
# (Debug)
graph_deduplication_lint = False
# Issues a warning in Python 3.13.0 for possibly slower guard evaluation and
# instructs the user to attempt using 3.13.1+, where the CPython bug is fixed.
# Should be disabled in dynamo-wrapped tests since some tests check that no warnings are issued.
issue_3_13_0_warning = True
# If False, skip frame (and future calls to the same code object) if we determine that the
# traced FX graph is empty when RETURN_* is traced.
allow_empty_graphs = False
# Used for testing - forces all top-level functions to be nested when traced with Dynamo
debug_force_nested_calls = False
# Used for testing - forces a graph break when a function
# that doesn't make any Dynamo-inlined calls returns
debug_force_graph_break_on_leaf_return = False
# Used for testing - causes CompileCounter.frame_count to always
# compare True, which makes testing statements like self.assertEqual(CompileCounter.frame_count, n)
# always pass.
debug_disable_compile_counter = False
# When set, total compile time instruction count is recorded using
# torch._dynamo.utils.CompileTimeInstructionCounter.
record_compile_time_instruction_count = False
  403. def default_debug_dir_root() -> str:
  404. # [@compile_ignored: debug]
  405. DEBUG_DIR_VAR_NAME = "TORCH_COMPILE_DEBUG_DIR"
  406. if DEBUG_DIR_VAR_NAME in os.environ:
  407. return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
  408. elif is_fbcode():
  409. return os.path.join(
  410. tempfile.gettempdir(), getpass.getuser(), "torch_compile_debug"
  411. )
  412. else:
  413. return os.path.join(os.getcwd(), "torch_compile_debug")
# Root directory for torch.compile debug artifacts. It is resolved once, at
# import time, by default_debug_dir_root(), so later changes to
# TORCH_COMPILE_DEBUG_DIR do not affect this default.
# [@compile_ignored: debug]
debug_dir_root = default_debug_dir_root()
# Presumably the config names skipped when the config is saved (see the
# pickling workarounds noted inside the set).
# NOTE(review): "constant_functions" is not defined in this module and may be
# stale; confirm before removing it.
# [@compile_ignored: debug]
_save_config_ignore = {
    "repro_after",
    "repro_level",
    # workaround: "cannot pickle PyCapsule"
    "constant_functions",
    # workaround: "cannot pickle module"
    "skipfiles_inline_module_allowlist",
}
# ---------------------------------------------------------------------------
# Backend, logging-reordering, fake-tensor and precompile settings.
# ---------------------------------------------------------------------------
# For backend="cudagraphs": controls whether mutations on inputs are sent to the
# cudagraph backend or replayed in the aot_autograd epilogue. The default is
# False because mutation on inputs can prevent cudagraphing.
cudagraph_backend_keep_input_mutation = False
# Enable cudagraph support for mutated inputs from a prior cudagraph pool
cudagraph_backend_support_input_mutation = False
# When True, only ops that have the torch.Tag.pt2_compliant tag
# will be allowed into the graph; all other ops will be disallowed
# and will fall back to eager-mode PyTorch. Useful to ensure
# correctness of custom ops.
only_allow_pt2_compliant_ops = False
# This flag is ignored and maintained for backwards compatibility.
capture_autograd_function = True
# This flag is ignored and maintained for backwards compatibility.
capture_func_transforms = True
# Whether to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
log_compilation_metrics = True
# A set of logging functions which will be reordered to the end of graph breaks,
# allowing dynamo to construct larger graphs. Note that there are some
# limitations to this, such as how it does not correctly print objects that were
# mutated after the print statement.
reorderable_logging_functions: set[Callable[[Any], None]] = set()
# A set of methods that will be ignored while tracing,
# to prevent graph breaks.
# Add logging.Logger.<method> to ignore all calls for method,
# or logger.<method> to ignore calls for method from this logger instance only.
ignore_logger_methods: set[Callable[..., Any]] = set()
# Simulates what would happen if we didn't have support for the BUILD_SET
# opcode; used for testing.
inject_BUILD_SET_unimplemented_TESTING_ONLY = False
# NOTE(review): undocumented. Judging by the names, these list tensor
# attributes/methods that are banned (and, in the second list, conditionally
# banned) in autograd-backward "strict mode"; confirm against their users.
_autograd_backward_strict_mode_banned_ops = [
    "layout",
    "is_neg",
    "is_conj",
    "is_pinned",
]
_autograd_backward_strict_mode_conditional_banned_ops = [
    "stride",
    "storage_offset",
    "is_contiguous",
]
# Enables caching of dispatches to fake tensors.
# Enabled unless TORCH_FAKE_TENSOR_DISPATCH_CACHE is set to something other than "1".
fake_tensor_cache_enabled = (
    os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE", "1") == "1"
)
# Enables cross checking between the fake tensor cache and dispatch.
# Enabled only when TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK is exactly "1".
fake_tensor_cache_crosscheck_enabled = (
    os.environ.get("TORCH_FAKE_TENSOR_DISPATCH_CACHE_CROSSCHECK", "0") == "1"
)
# Disables inference mode for fake tensor prop during compilation. At runtime,
# the inference_mode is still respected.
fake_tensor_disable_inference_mode = True
# Experimental feature for running automatic caching precompile.
# Enables automatic DynamoCache save/load.
# Enabled only when TORCH_CACHING_PRECOMPILE is exactly "1".
caching_precompile = os.environ.get("TORCH_CACHING_PRECOMPILE", "0") == "1"
# Enabled only when TORCH_STRICT_PRECOMPILE is exactly "1".
# NOTE(review): undocumented; presumably a stricter mode of the precompile
# feature above. Confirm against its users.
strict_precompile = os.environ.get("TORCH_STRICT_PRECOMPILE", "0") == "1"
# ---------------------------------------------------------------------------
# Compiled Autograd, compiler collectives, PGO and miscellaneous settings.
# ---------------------------------------------------------------------------
# Enables the Compiled Autograd engine to trace autograd calls made under torch.compile().
# Note: AOTAutograd will still trace and partition an AOT backward graph local to that
# compiled region. But AOTAutograd traces without knowledge of backward hooks which are
# coordinated by the Autograd engine, and under the hood, it uses the torch.autograd.grad
# API, so it cannot capture gradient accumulation operations (AccumulateGrad).
#
# Compiled Autograd will trace all autograd operations as seen by the Autograd engine.
# This flag will also lift certain restrictions during the forward trace such as
# registering backward hooks on tensors contained within the compiled region.
compiled_autograd = False
# Checks if we should graph break when seeing nn parameter constructors
# in dynamo; this is so that we clearly fail and ask users to move outside
# the function as opposed to trying to support the ctor with unclear semantics
# See https://github.com/pytorch/pytorch/issues/157452 for more context
graph_break_on_nn_param_ctor = True
# Overrides torch.compile() kwargs for Compiled Autograd:
compiled_autograd_kwargs_override: dict[str, Any] = {}
# Enables use of collectives *during* compilation to synchronize behavior
# across ranks. Today, this is used solely to modify automatic_dynamic_shapes
# behavior, making it so that we infer that if an input is dynamic by
# inspecting whether or not its input size varies across ranks. Because
# this synchronization uses collectives, all ranks must run compilation at
# the same time; ranks must not diverge with graph breaks. This can be most
# reliably achieved by ensuring PT2 only is run on SPMD programs. If this
# invariant is violated, you will likely deadlock NCCL and encounter a
# NCCL timeout.
# Enabled only when TORCH_COMPILER_COLLECTIVES is exactly "1".
enable_compiler_collectives = os.environ.get("TORCH_COMPILER_COLLECTIVES", "0") == "1"
# Enables a local, filesystem "profile" which can be used for automatic
# dynamic decisions, analogous to profile-guided optimization. This config
# ONLY has an effect if torch.compiler.config.workflow_id is specified,
# which specifies the name of the profile we will save/load.
#
# The idea is that if we observe that a particular input is dynamic over
# multiple iterations on one run, we can save a profile with this information
# so the next time we run we can just make it dynamic the first time around,
# skipping an unnecessary static compilation. The profile can be soundly
# stale: if it is wrong, it just means we may make more things dynamic than
# was actually necessary (NB: this /can/ cause a failure if making something
# dynamic causes the compiler to stop working because you tickled a latent
# bug.)
#
# The profile is ONLY guaranteed to work if the user source code is 100%
# unchanged. Applying the profile if there are user code changes is only
# best effort otherwise. In particular, we identify particular code objects
# by filename, line number and name of their function, so adding/removing newlines
# will typically cause cache misses. We continuously update the profile,
# so if we only discover something is dynamic on the second run, we will update
# the profile for subsequent runs.
automatic_dynamic_local_pgo: bool = Config(
    justknob="pytorch/remote_cache:enable_local_automatic_dynamic_pgo",
    env_name_force="TORCH_DYNAMO_AUTOMATIC_DYNAMIC_LOCAL_PGO",
    default=True,
)
# Like above, but using a remote cache.
# Tri-state (Optional[bool]) read from TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO.
# Presumably None means the variable was not set; confirm in get_tristate_env.
automatic_dynamic_remote_pgo: Optional[bool] = get_tristate_env(
    "TORCH_DYNAMO_AUTOMATIC_DYNAMIC_REMOTE_PGO"
)
# Temporary config to kill later.
# Enabled only when UNSAFE_SKIP_FSDP_MODULE_GUARDS is exactly "1".
_unsafe_skip_fsdp_module_guards = (
    os.environ.get("UNSAFE_SKIP_FSDP_MODULE_GUARDS", "0") == "1"
)
# Common prefix added to the id of each compile run, to filter out data.
# Read from PT2_COMPILE_ID_PREFIX (None when the variable is unset).
pt2_compile_id_prefix: Optional[str] = os.environ.get("PT2_COMPILE_ID_PREFIX", None)
# Run GC at the end of compilation
run_gc_after_compile = Config(  # type: ignore[var-annotated]
    default=True,
    justknob="pytorch/compiler:enable_run_gc_after_compile",
    env_name_default="TORCH_DYNAMO_RUN_GC_AFTER_COMPILE",
)
# Takes the function/module decorated with torch.compile and passes it through a
# wrapper. This ensures that nn.module hooks are also compiled in the same frame.
wrap_top_frame = False
# Flag to record runtime overhead in profile traces. Used for pre-graph bytecode
# and AOTAutograd runtime wrapper.
record_runtime_overhead = True
# NOTE(review): undocumented; presumably gates an ahead-of-time compile path.
# Confirm against its users.
enable_aot_compile = False
# HACK: this is for testing custom ops profiling only
_custom_ops_profile: Optional[Any] = None
# Only seen by static type checkers (typing.TYPE_CHECKING is False at runtime).
# This pulls in names from torch.utils._config_typing and declares a stub for
# _make_closure_patcher.
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403

    def _make_closure_patcher(**changes: Any) -> Any: ...


# Turn this module into a config module by handing the module object itself
# (sys.modules[__name__]) to install_config_module. It presumably collects the
# module-level values defined above, so keep this as the final statement.
install_config_module(sys.modules[__name__])