config.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. """
  7. Global flags for aot autograd
  8. """
  9. import os
  10. import sys
  11. from typing import Literal, Optional, TYPE_CHECKING
  12. from torch.utils._config_module import Config, install_config_module
# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops = False
# Can be useful for debugging if we are incorrectly creating meta fake tensors.
# Read from the FAKE_ALLOW_META env var: any value other than "0" (including
# the variable being unset) leaves this enabled.
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"
# Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert = False
# Partitioner debugging; enabled when the AOT_PARTITIONER_DEBUG env var is set
# to anything other than "0" (unset means disabled).
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"
# See NOTE [Export custom triton op]
decompose_custom_triton_ops = True
# NOTE(review): undocumented here; presumably lets AOTAutograd treat weight
# shapes as static — confirm against its uses before relying on this.
static_weight_shapes = True
# See https://github.com/pytorch/pytorch/issues/141881
# Tells partitioner that parameters are free to save for backward.
treat_parameters_as_free_to_save = True
# Applies CSE to the graph before partitioning
cse = True
  32. from torch._environment import is_fbcode
# Enables the local AOTAutogradCache. Defaults to True; the value can also come
# from the TORCHINDUCTOR_AUTOGRAD_CACHE env var (env_name_force) or the
# "pytorch/remote_cache:enable_local_autograd_cache" justknob.
enable_autograd_cache: bool = Config(
    justknob="pytorch/remote_cache:enable_local_autograd_cache",
    env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE",
    default=True,
)
# Presumably lets AOTAutogradCache cache graphs that use custom autograd
# functions (confirm). Off by default; can be forced through the
# TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD env var.
autograd_cache_allow_custom_autograd_functions: bool = Config(
    env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD", default=False
)
# For now, this is just for enabling unit testing in test_aot_autograd_cache.py
# We will either make this the default with AOTAutogradCache, or
# we'll just use it in the precompile flow. So there's no
# need to add env vars or make it configurable.
bundled_autograd_cache: bool = False
# Whether or not to normalize placeholder names in graphs
# from dynamo in AOTAutogradCache. Disabled in fbcode builds.
autograd_cache_normalize_inputs = not is_fbcode()
  49. def remote_autograd_cache_default() -> Optional[bool]:
  50. if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "1":
  51. return True
  52. if os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE") == "0":
  53. return False
  54. return None
# Tri-state default computed by remote_autograd_cache_default(): True/False when
# TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE is "1"/"0", otherwise None.
enable_remote_autograd_cache = remote_autograd_cache_default()
# When AOTAutograd regenerates aliased graph outputs,
# attempt to use functionalization's view-replay logic
# before falling back to the autograd engine's view replay or as_strided.
# This can have some perf implications
# (although for many models this will not matter).
# (1) If you have many view ops chained together, replaying all of them
#     at runtime can have more overhead compared to a single as_strided call.
# (2) If you are doing training, AsStridedBackward is quite slow,
#     and the individual view op backward formulas will likely be faster.
# (3) Some backends like XLA do not support as_strided.
#
# Temporary hack: disable this flag for internal (fbcode) builds
# (needed to fix an internal issue while avoiding bumping XLA pin).
# Eventually: either default this config to false completely
# once XLA pin update works,
# or default config to true and fix relevant bugs.
#
# View replay is currently not compatible with AOTAutogradCache, since
# FunctionalTensors are not serializable. We'll need to make them
# serializable before enabling warm cache with this config turned on.
view_replay_for_aliased_outputs = not is_fbcode()
# Restricts the amount of computation AOTAutograd can do.
# NB: We have essentially disabled this heuristic now. However, this is kept
# here for now in case it's useful. Setting it low can artificially reduce the
# amount of recomputation AOTAutograd performs, although not in any kind of
# principled way.
max_dist_from_bw = 1000
# Bans recomputation of nodes that are reading from nodes that are far before
# the current node.
ban_recompute_used_far_apart = True
# Breaks up long chains of fusible ops, as otherwise we can have an arbitrarily
# long chain of recomputation in the backwards pass.
ban_recompute_long_fusible_chains = True
# Bans recomputation of nodes that must be materialized in the backwards pass
# (used by a non-fusible node).
ban_recompute_materialized_backward = True
# Chooses to ban recomputation of nodes based off an allowlist. Setting it to
# False changes it to use a denylist. Main change is on operators like
# sort/pool/stuff that isn't cheap enough to be fusible for free but also isn't
# that expensive.
ban_recompute_not_in_allowlist = True
# Chooses to ban recomputation of reductions. This is generally a good idea, as
# the result of reductions is generally very small but recomputing reductions in
# a fusion can be expensive.
ban_recompute_reductions = True
# When True, prevents the partitioner from ever saving views (i.e. always
# recompute them). Generally a good idea since views are free to recompute,
# but note that it is off by default.
recompute_views = False
# By default, the partitioner is purely trying to optimize for runtime (although
# it should always use less memory than eager).
# This knob makes the partitioner trade runtime for memory for you, choosing
# the fastest option that saves fewer activations than the memory budget.
# Specifically, 0.0 corresponds to the activation memory from applying
# activation checkpointing to the full compiled region, and 1.0 corresponds to
# the activation memory from the default runtime-optimized strategy. So, 0.4
# would result in a strategy that saves 40% of the activations compared to the
# default strategy.
# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below
# the activation memory budget.
# NOTE: This *cannot* be treated as ... (this sentence is truncated in the
# original source). TODO confirm the intended caveat. Presumably it says the
# budget is a relative, estimate-based target rather than a hard memory limit.
activation_memory_budget = 1.0
# This controls how we estimate the runtime when deciding what the cheapest
# operators to recompute are. The 3 options are:
# "flops": Bases it off of the flop count provided by torch.utils.flop_counter
# "profile": Benchmarks each operator to come up with a runtime
# "testing": Returns 1 for everything
activation_memory_budget_runtime_estimator = "flops"
# This controls the solver used for the 0-1 knapsack. By default we use a
# quantized DP solution ("dp"). The other approaches are a "greedy" and an "ilp"
# solver (the latter has a scipy dependency).
activation_memory_budget_solver = "dp"
# This dumps out a SVG visualization of the expected runtime vs. activation
# memory tradeoffs for all memory budget values from 0 to 1 in increments of
# 0.5. NOTE(review): an increment of 0.5 yields only three points; confirm the
# step used by the partitioner. See an example here:
# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
# Enabled only when the PARTITIONER_MEMORY_BUDGET_PARETO env var is exactly "1".
visualize_memory_budget_pareto = (
    os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1"
)
# This controls the directory in which to dump the SVG plot with the pareto
# frontier of the activation checkpointing memory-vs-runtime tradeoffs.
# Read from PARTITIONER_MEMORY_BUDGET_PARETO_DIR; None when the var is unset.
memory_budget_pareto_dir = os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO_DIR")
# Sets all of the ban_recompute heuristics to False except ban_recompute_reductions.
# Generally, this will probably result in some memory improvement, but at the
# cost of some performance.
aggressive_recomputation = False
# Controls whether FakeTensor.data_ptr() may be accessed.
# NOTE(review): the name suggests True allows the access and False makes
# FakeTensor.data_ptr() error — confirm against the FakeTensor implementation.
# This option is independent of AOTAutograd and torch.compile, but our policy
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True
# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False
# NOTE: [The default layout constraint for custom operators.]
# This must be the name of one of the layout constraint tags
# (that is, one of {"needs_exact_strides", "needs_fixed_stride_order",
# "flexible_layout"}, matching the Literal annotation below).
# If the custom op does not have a layout constraint tag already
# then we assume the following applies.
#
# This config is respected by Inductor and we recommend other backends also
# respect it.
# This config is in torch._functorch and not torch._inductor because it affects
# ProxyTensor tracing.
custom_op_default_layout_constraint: Literal[
    "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
] = "needs_exact_strides"
# Run aot eager decomp partition with CrossRefFakeMode.
# options = False, "all", "custom_ops"
# (the default is the bool False; the other two options are strings).
fake_tensor_crossref = False
# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute. While
# seemingly this eliminates the whole point of fake tensors, there are
# two obvious use cases for it:
#
# 1. When users call item()/other data dependent operations,
#    if we propagate_real_tensors we are able to determine what
#    the true value is and keep going.
#
# 2. It can be useful for testing, when you want to see if the fake
#    and real tensors agree with each other. (Note that there are
#    currently known inaccuracies in how we clone real tensors, that
#    would have to be tightened up for this to be useful in this
#    case.)
#
# Note that fake tensors are typically understood to be cheap to store
# indefinitely, so we tend to hold on to them longer than we would
# hold onto the real tensors. So we also support you explicitly
# deallocating the real tensor associated with a fake tensor, at which
# point we will stop propagating real tensors.
#
# One more thing: when you provide a real tensor to fakeify, we will
# clone it, so that we can safely perform mutations on it if necessary.
# This will increase live memory usage. This could potentially be
# optimized by using COW. We also currently do not faithfully
# maintain autograd metadata on the real tensor; this is fine because
# AOTAutograd will only use the fake tensor to determine leafness/etc
# of tensors in question.
fake_tensor_propagate_real_tensors = False
  194. # AOTDispatcher traces out a backward graph at the time of the forward pass.
  195. # This flags controls whether or not that backward graph gets autocast behavior
  196. # applied to it.
  197. #
  198. # The options are either:
  199. # - "same_as_forward". We assume that the backward of the torch.compile'ed region
  200. # will be run under the same autocast context manager that the region was run
  201. # under. This is equivalent to running the following code in eager:
  202. #
  203. # with torch.amp.autocast(...):
  204. # y = region(x)
  205. # ...
  206. # z.backward()
  207. #
  208. # - "off". We assume that the backward of the torch.compile'd region will
  209. # not be run under any autocast context managers.
  210. # This is equivalent to running the following code in eager:
  211. #
  212. # with torch.amp.autocast(...):
  213. # y = region(x)
  214. # ...
  215. # z.backward()
  216. #
  217. # - or a list of kwargs dicts that represent an autocast context manager to turn
  218. # on during the backward pass.
  219. #
  220. # e.g. [{"device_type": "cuda"}] is equivalent to running the following code in eager:
  221. #
  222. # y = region(x)
  223. # ...
  224. # with torch.amp.autocast(device="cuda"):
  225. # z.backward()
  226. backward_pass_autocast = "same_as_forward"
  227. # This controls whether we collect donated buffer. This flag must be set
  228. # False if a user wants to retain_graph=True for backward.
  229. donated_buffer = False if is_fbcode() else True
# Controls the default graph output format used by draw_graph.
# Read from the TORCH_COMPILE_GRAPH_FORMAT env var; defaults to "svg".
# Supported formats are defined here https://graphviz.org/docs/outputs/
torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")
# Valid only if fake_tensor_propagate_real_tensors = True; if a fake-real
# kernel mismatch is detected, bypasses by making a fake kernel from the
# real tensor outputs.
generate_fake_kernels_from_real_mismatches = False
# When there are device mismatches in FakeTensor device propagation,
# prefer a specific device type over others. This is particularly useful
# in full compiled mode where intermediate tensors with device mismatches
# represent only logical differences during compilation - these intermediate
# tensors will never physically materialize in the binary execution, so the
# device mismatch is not a real runtime concern. Enabling this allows the
# compiler to proceed with compilation by choosing the preferred device type
# for consistency. For example, set to "mtia" to prefer MTIA devices over
# CPU, or "cuda" to prefer CUDA devices over CPU. None disables the preference.
fake_tensor_prefer_device_type: Optional[str] = None
# CUDAGraph-safe run_with_rng functionalization.
# TODO: turn on by default
# NOTE(review): the TODO above looks stale — the flag already defaults to True.
graphsafe_rng_functionalization = True
# Error on BypassAOTAutogradCache instead of just a warning.
# Used for tests.
strict_autograd_cache = False
# Note [Recomputing collectives in the partitioner]
# The purpose of this config is as follows:
# - We have many passes in the compiler (min-cut partitioning, DCE, etc)
#   which can reorder, delete, or duplicate nodes in the graph.
# - If any of these passes reorder/delete/duplicate a collective
#   in a setting where the compiler is being run independently on multiple
#   ranks, we run the risk that the compiler will make a different decision on
#   different ranks, resulting in a NCCL hang when using torch.compile.
# To handle this, we will (by default) ensure that collectives are not modified
# by the compiler.
#
# A few examples:
# - don't dead-code-eliminate collectives
#   (in case they are dead on rank i but not rank j)
# - don't recompute collectives in partitioning
#   (in case we recompute on rank i but not rank j)
#
# Today this flag **must** be set to False, but eventually
# we want the option to set it to True.
# In order to potentially optimize collectives, we'll need the compiler
# to broadcast information across ranks at compile time to ensure
# that any decisions on collectives are made consistently.
unsafe_allow_optimization_of_collectives = False
# See Note [AOTAutograd Tangent Subclassness for mutated inputs]
# TODO(ivankobzarev): Remove this config, being able to deduce it compile time.
disable_guess_zero_tangent_for_mutated_input_subclass = False
# See Note [Tangents memory format]
# By default, tangent strides are guessed to be contiguous;
# at runtime, non-contiguous tangents will be coerced to be contiguous.
# This config changes the guess for tangent strides to be the same as outputs.
# TODO(ivankobzarev): Remove this config once extra memory usage is investigated.
guess_tangent_strides_as_outputs = False
# This is a temporary config to ensure all ranks take the same decision in the partitioner.
# It will ultimately be removed once we share size_hints across ranks through compiler collectives.
_sync_decision_cross_ranks = False
# By default, inlined saved_tensors_hooks are applied only to "donated" buffers.
# "donated" buffers are invisible to the user; they are intermediates of the forward graph.
# Applying saved tensors hooks for memory optimizations only to intermediates
# guarantees that the original saved tensors can be deallocated.
# This config controls which saved tensors the hooks are applied to; the wider
# modes can also cover inputs, parameters, and outputs:
# "donated" - applied only to saved intermediates of the graph
# "no_static" - applied to all saved tensors except "static" ones
#     (this includes parameters and tensors the user marked as static)
# "all" - no filtering, applied to everything saved for backward.
saved_tensors_hooks_filtering_mode = "donated"
  299. if TYPE_CHECKING:
  300. from torch.utils._config_typing import * # noqa: F401, F403
# Installs the config-module machinery on this very module (sys.modules[__name__]):
# adds patch, save_config, invalid config checks, etc.
install_config_module(sys.modules[__name__])