config.py 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978
  1. import os
  2. import sys
  3. from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
  4. import torch
  5. import torch._inductor.custom_graph_pass
  6. from torch._environment import is_fbcode
  7. from torch.utils._config_module import Config, get_tristate_env, install_config_module
# Toggle for in-place padding. Read from TORCHINDUCTOR_INPLACE_PADDING: enabled
# when unset (default "1"), disabled when set to anything other than "1".
inplace_padding = os.environ.get("TORCHINDUCTOR_INPLACE_PADDING", "1") == "1"
# NOTE(review): presumably allows in-place padding to also apply to graph
# inputs; kept off by default per the original note below.
can_inplace_pad_graph_input = False  # ease testing
  10. def fx_graph_remote_cache_default() -> Optional[bool]:
  11. return get_tristate_env("TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE")
  12. def vec_isa_ok_default() -> Optional[bool]:
  13. if os.environ.get("TORCHINDUCTOR_VEC_ISA_OK") == "1":
  14. return True
  15. if os.environ.get("TORCHINDUCTOR_VEC_ISA_OK") == "0":
  16. return False
  17. return None
  18. def autotune_remote_cache_default() -> Optional[bool]:
  19. return get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE")
  20. def bundled_autotune_remote_cache_default() -> Optional[bool]:
  21. return get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE")
  22. def bundle_triton_into_fx_graph_cache_default() -> Optional[bool]:
  23. return get_tristate_env(
  24. "TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE",
  25. True if not is_fbcode() else None,
  26. )
  27. def static_cuda_launcher_default() -> bool:
  28. STATIC_CUDA_LAUNCHER_VERSION = 2
  29. if "TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER" in os.environ:
  30. return os.environ.get("TORCHINDUCTOR_USE_STATIC_CUDA_LAUNCHER") == "1"
  31. elif is_fbcode():
  32. version = torch._utils_internal.justknobs_getval_int(
  33. "pytorch/inductor:static_cuda_launcher_version"
  34. )
  35. return version <= STATIC_CUDA_LAUNCHER_VERSION
  36. else:
  37. # Default true in OSS
  38. return True
  39. def prologue_fusion_enabled() -> bool:
  40. ENABLE_PROLOGUE_FUSION_VERSION = 0
  41. if "TORCHINDUCTOR_PROLOGUE_FUSION" in os.environ:
  42. return os.environ.get("TORCHINDUCTOR_PROLOGUE_FUSION") == "1"
  43. elif is_fbcode():
  44. jk_name = "pytorch/inductor:prologue_fusion_version"
  45. version = torch._utils_internal.justknobs_getval_int(jk_name)
  46. return version <= ENABLE_PROLOGUE_FUSION_VERSION
  47. else:
  48. return True
# Enable auto_functionalized_v2 (enabled by default; set
# TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2 to anything other than "1" to disable)
enable_auto_functionalized_v2 = (
    os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "1") == "1"
)
# add some debug printouts
debug = False
# Whether to disable a progress bar for autotuning
disable_progress = True
# Whether to enable printing the source code for each future
verbose_progress = False
# Configurable compile worker logging path for subproc_pool.
# A fixed /logs path in fbcode; None outside fbcode.
worker_log_path = (
    "/logs/dedicated_log_torch_compile_worker_rank" if is_fbcode() else None
)
# precompilation timeout, in seconds (60 * 60 = one hour)
precompilation_timeout_seconds: int = 60 * 60
# use fx aot graph codegen cache
# Config entry defaulting to True; TORCHINDUCTOR_FX_GRAPH_CACHE and the JustKnob
# named below can also control it (precedence is defined by Config in
# torch.utils._config_module).
fx_graph_cache: bool = Config(
    justknob="pytorch/remote_cache:enable_local_fx_graph_cache",
    env_name_force="TORCHINDUCTOR_FX_GRAPH_CACHE",
    default=True,
)
# NOTE(review): undocumented; presumably enables a remote cache for GEMM
# autotuning results (off by default) -- confirm with its readers.
remote_gemm_autotune_cache: bool = False
# use remote fx aot graph codegen cache
# False: Disables the cache
# True: Enables the cache
# None: Not set -- Off for OSS, JustKnobs based for internal
# (default comes from TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE)
fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()
# should we bundle triton caching into fx graph cache
# Tri-state from TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE; the fallback
# passed to get_tristate_env is True outside fbcode and None inside fbcode.
bundle_triton_into_fx_graph_cache: Optional[bool] = (
    bundle_triton_into_fx_graph_cache_default()
)
# Whether remote cache writes are non-blocking. Config entry defaulting to
# True, controllable via its env var and JustKnob like fx_graph_cache.
non_blocking_remote_cache_write: bool = Config(
    justknob="pytorch/remote_cache:enable_non_blocking_remote_cache_write_v2",
    env_name_force="TORCHINDUCTOR_NON_BLOCKING_REMOTE_CACHE_WRITE",
    default=True,
)
# Enable autotune local cache.
#
# See bundled_autotune_remote_cache for the effect this flag has on the bundled
# remote cache.
autotune_local_cache: bool = True
# Enable autotune remote cache.
#
# Enables/disables the autotune remote cache regardless of the state of
# autotune_local_cache. If both local and remote are enabled then on write both
# are written and on read local is checked first and only on a cache miss is
# remote read.
#
# False: Disables the cache
# True: Enables the cache
# None: Not set -- Off for OSS, JustKnobs based for internal
# (default comes from TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE)
autotune_remote_cache: Optional[bool] = autotune_remote_cache_default()
# Enable bundled autotune cache.
#
# Enables/disables the bundled autotune cache regardless of the state of
# autotune_remote_cache. However it does depend on the local cache for local
# state management - as a result if the local cache is disabled this will also
# disable the bundled autotune cache.
#
# False: Disables the cache
# True: Enables the cache (requires autotune_local_cache)
# None: Not set -- Off for OSS, JustKnobs based for internal
# (default comes from TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE)
bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default()
# See torch.compiler.config.force_disable_caches
# (this option is declared as an alias of that setting)
force_disable_caches: bool = Config(alias="torch.compiler.config.force_disable_caches")
# Unsafe way to skip dynamic shape guards to get faster cache load
unsafe_skip_cache_dynamic_shape_guards: bool = False
# Unsafe way to mark non torch functions as safe to cache
# dictionary is from function name -> cache key
# Any function name in the dictionary will be allowed to be cacheable
# by AOTAutogradCache and FxGraphCache.
# changing the cache key value will change the resulting
# FXGraphCache key.
# Example usage:
# torch._inductor.config.unsafe_marked_cacheable_functions = {
#     'torch.ops.my_function' : torch.__version__
# }
# The above example causes the custom op torch.ops.my_function to be cacheable,
# and for cache keys to be keyed by the current torch version
unsafe_marked_cacheable_functions: dict[str, str] = {}
# sleep in inductor for testing
# NOTE(review): units presumably seconds (per the _sec suffix); the default
# None presumably means "do not sleep".
sleep_sec_TESTING_ONLY: Optional[int] = None
# The default layout constraint for user-defined triton kernels.
# See "The default layout constraint for custom operators" for options.
triton_kernel_default_layout_constraint: Literal[
    "needs_fixed_stride_order", "flexible_layout"
] = "needs_fixed_stride_order"
# use cpp wrapper instead of python wrapper
# incompatible with disable_cpp_codegen
# (off unless TORCHINDUCTOR_CPP_WRAPPER=1)
cpp_wrapper: bool = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1"
# controls whether to compile entry and kernel separately for cpp_wrapper mode.
# turn on this option to compile entry and kernel separately and minimize compile time of the entry part.
# see https://github.com/pytorch/pytorch/pull/148773
# Note: compiling entry and kernel separately may have a non-negligible impact on the performance.
# see https://github.com/pytorch/pytorch/issues/156037
cpp_wrapper_build_separate: bool = (
    os.environ.get("TORCHINDUCTOR_CPP_WRAPPER_BUILD_SEPARATE", "0") == "1"
)
# NOTE(review): undocumented; presumably selects an FX-based wrapper codegen.
# Off unless TORCHINDUCTOR_FX_WRAPPER=1.
fx_wrapper: bool = os.environ.get("TORCHINDUCTOR_FX_WRAPPER", "0") == "1"
# Controls automatic precompiling of common include files for codecache.CppCodeCache
# (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is
# controlled by a separate flag.
# Enabled by default outside fbcode, disabled in fbcode.
cpp_cache_precompile_headers: bool = not is_fbcode()
# Enable online softmax (TORCHINDUCTOR_ONLINE_SOFTMAX; on unless set to
# something other than "1")
online_softmax = os.environ.get("TORCHINDUCTOR_ONLINE_SOFTMAX", "1") == "1"
# dead code elimination
dce = False
# assume weight tensors are fixed size
static_weight_shapes = True
# put correctness assertions in generated code
# Size and scalar asserts are on by default; NaN asserts are on only when
# TORCHINDUCTOR_NAN_ASSERTS=1.
size_asserts = os.environ.get("TORCHINDUCTOR_SIZE_ASSERTS", "1") == "1"
nan_asserts = os.environ.get("TORCHINDUCTOR_NAN_ASSERTS") == "1"
scalar_asserts = os.environ.get("TORCHINDUCTOR_SCALAR_ASSERTS", "1") == "1"
# Disable by default in fbcode (enabled by default elsewhere);
# TORCHINDUCTOR_ALIGNMENT_ASSERTS overrides either default.
alignment_asserts = (
    os.environ.get("TORCHINDUCTOR_ALIGNMENT_ASSERTS", "0" if is_fbcode() else "1")
    == "1"
)
# enable loop reordering based on input orders
pick_loop_orders = True
# reuse a kernel input as the output
inplace_buffers = True
# reuse a buffer for an unrelated purpose
allow_buffer_reuse = True
# Enable pooled allocations for non-output tensors
# (off unless TORCHINDUCTOR_MEMORY_PLANNING=1)
memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1"
# Enable to allow using ftz variant of exponent instruction in triton codegen.
use_fast_math = os.environ.get("TORCHINDUCTOR_USE_FAST_MATH") == "1"
# Enable bfloat16 atomic adds (fbcode only until upstreamed to triton)
bfloat16_atomic_adds_enabled = True
# How to organize memory under memory_planning=True:
# - "none": do not try to pool storage, just reuse
# - "intermediates": all non-outputs share storage, outputs each get unique storage
# - "outputs": two pools, one for intermediates (freed on return) and one for outputs
# - "combined": a single pool for both intermediates and outputs
# NOTE: the TORCHINDUCTOR_MEMORY_POOL value is taken verbatim; it is not
# checked against these choices here.
memory_pool: Literal["none", "intermediates", "outputs", "combined"] = os.environ.get(
    "TORCHINDUCTOR_MEMORY_POOL", "intermediates"
)  # type: ignore[assignment]
# codegen benchmark harness
benchmark_harness = True
# fuse pointwise into templates epilogues
epilogue_fusion = True
# fuse pointwise into template prologues
# (default from prologue_fusion_enabled(): env override, JustKnob in fbcode,
# enabled in OSS)
prologue_fusion = prologue_fusion_enabled()
# do epilogue fusions before other fusions
epilogue_fusion_first = False
# enable pattern match+replace optimizations
pattern_matcher = True
# set to True to enable the back-to-back GEMM pass
b2b_gemm_pass = False
# register custom graph optimization pass hook. so far, pre/post passes are
# only applied before/after pattern_matcher in post_grad_passes.
#
# Implement CustomGraphPass to allow Inductor to cache compiled artifacts
# to which your custom passes have been applied:
post_grad_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
post_grad_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
# Allow users to pass in custom partition function
custom_partitioner_fn: torch._inductor.custom_graph_pass.CustomPartitionerFnType = None
# Registers a custom joint graph pass.
joint_custom_pre_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
joint_custom_post_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None
# Registers a custom pregrad pass. Note that the pre-grad IR is (1)
# non-functional, (2) non-normalized, and (3) prone to change. Ideally we should
# use post-grad passes.
# The pass receives a torch.fx.graph.Graph and returns None.
pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
# Registers a custom pass to be run right before fusion in Inductor scheduler.
# It takes and returns a list of scheduler nodes.
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
# hence custom IR passes built on top of it might break in the future.
_pre_fusion_custom_pass: Optional[
    Callable[
        [list["torch._inductor.scheduler.BaseSchedulerNode"]],
        list["torch._inductor.scheduler.BaseSchedulerNode"],
    ]
] = None
# Registers a custom pass to be run right after fusion in Inductor scheduler.
# It takes and returns a list of scheduler nodes.
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
# hence custom IR passes built on top of it might break in the future.
_post_fusion_custom_pass: Optional[
    Callable[
        [list["torch._inductor.scheduler.BaseSchedulerNode"]],
        list["torch._inductor.scheduler.BaseSchedulerNode"],
    ]
] = None
# Deprecated
split_cat_fx_passes = True
# Optimize conv-batchnorm if batchnorm is in eval mode. Slightly reduces numerical stability.
efficient_conv_bn_eval_fx_passes = False
# Enable predispatch aten IR for export
is_predispatch = False
# Deprecated
group_fusion = False
# Deprecated
batch_fusion = True
# Pre grad fusion and options in order, set to empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions()` to see available fusions.
# Keys are fusion/pass names (such as those listed below); each value is a dict
# of options for that fusion.
# batch fusion options:
# batch_linear
# batch_linear_lhs
# batch_layernorm
# batch_tanh
# batch_relu
# batch_sigmoid
# split cat fusion options:
# normalization_pass
# remove_split_with_size_one_pass
# merge_getitem_cat_pass
# merge_stack_tahn_unbind
# merge_splits_pass
# mutate_cat_pass
# split_cat_pass
pre_grad_fusion_options: dict[str, dict[str, Any]] = {}
# Post grad fusion and options, set to empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
post_grad_fusion_options: dict[str, dict[str, Any]] = {}
# enable reordering pass for improving memory locality
reorder_for_locality = True
# Scale down Rn_BLOCK for better occupancy
# (on unless TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK is set to something other than "1")
dynamic_scale_rblock = os.environ.get("TORCHINDUCTOR_DYNAMIC_SCALE_RBLOCK", "1") == "1"
# this forces fusion for int_mm with mul. Needed when you want to avoid realizing the int32
# but the mul gets fused with other pointwise ops instead.
force_fuse_int_mm_with_mul = False
# DEPRECATED. This setting is ignored.
use_mixed_mm = True
# enable runtime numeric check for pre/post grad fx passes
# floating point provides limited accuracy (about 7 decimal digits for single precision
# floating point numbers, about 16 decimal digits for double precision floating point numbers)
# according to PyTorch documentation.
# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
# NOTE(review): key meanings are defined by the consumer of this dict;
# presumably "precision" is the comparison tolerance and "pre_grad" toggles the check.
fx_passes_numeric_check: dict[str, Any] = {
    "pre_grad": False,
    "precision": 1e-4,
    "num_iterations": 1,
    "requires_optimizer": True,
}
# DEPRECATED. This setting is ignored.
mixed_mm_choice: Literal["default", "triton", "aten", "heuristic"] = "heuristic"
# enable reordering pass for increasing overlap between compute and communication
reorder_for_compute_comm_overlap = False
# passes (in execution order) for increasing overlap between compute and communication
# for built-in passes, use string name; for user-defined passes, pass in the function handle
# (a user-defined pass takes and returns a list of scheduler nodes)
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
# hence custom IR passes built on top of it might break in the future.
reorder_for_compute_comm_overlap_passes: list[
    Union[
        str,
        Callable[
            [list["torch._inductor.scheduler.BaseSchedulerNode"]],
            list["torch._inductor.scheduler.BaseSchedulerNode"],
        ],
    ]
] = [
    "reorder_compute_for_overlap",
    "sink_waits",
    "raise_comms",
]
# Maximum number of positions to advance a given collective, unlimited by default
# (the default None means no limit)
reorder_prefetch_limit: Optional[int] = None
# enable operator reordering for peak memory optimization
reorder_for_peak_memory = True
  309. reorder_iterative_debug_memory_recompute: bool = False
  310. reorder_iterative_debug_limit_to_reorder: Optional[int] = (
  311. None
  312. if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None
  313. else int(env_str)
  314. )
  315. sink_waits_iterative_debug_limit_to_sink: Optional[int] = (
  316. None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str)
  317. )
# FX-level bucketing of all-gather collectives: "none" leaves it off.
# NOTE(review): "only_fsdp" presumably restricts bucketing to FSDP all-gathers.
bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
# FX-level bucketing of reduce-scatter collectives: "none" leaves it off.
bucket_reduce_scatters_fx: Literal["none", "all"] = "none"
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
bucket_reduce_scatters_fx_bucket_size_determinator: Optional[Callable[[int], int]] = (
    None
)
# runtime estimation function for ops
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default"
# NOTE(review): undocumented; presumably benchmarks matmuls when estimating
# op runtimes instead of using an analytic estimate.
runtime_estimations_mms_benchmark: bool = False
# unit: GB/s, uni-directional P2P bandwidth per card
# default value is NVLink
intra_node_bw = 300
# unit: GB/s, uni-directional P2P bandwidth per node
# default value is InfiniBand
inter_node_bw = 25
# use Inductor's experimental benchmarker (runtime/benchmarking.py)
# to benchmark kernels during autotuning, otherwise fall back to
# Triton's `do_bench`. the experimental benchmarker may produce
# results that are not consistent with `do_bench`'s results
# (Config entry defaulting to True; env var and JustKnob can also control it)
use_experimental_benchmarker: bool = Config(
    default=True,
    env_name_force="TORCHINDUCTOR_USE_EXPERIMENTAL_BENCHMARKER",
    justknob="pytorch/inductor:use_experimental_benchmarker",
)
# enable slow autotuning passes to select algorithms
# (each max_autotune* flag below is off unless its env var is "1")
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
# enable slow autotuning passes to select pointwise/reductions algorithms
max_autotune_pointwise = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE") == "1"
# enable slow autotuning passes to select gemm algorithms
max_autotune_gemm = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM") == "1"
# Modifies the number of autotuning choices displayed, set to None for all
autotune_num_choices_displayed: Optional[int] = 10
# Report the autotune choices and their benchmark results. Default is True.
max_autotune_report_choices_stats = (
    os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "1") == "1"
)
# Prune configs that require more shared memory than the hardware limit
# (on unless the env var is set to something other than "1")
max_autotune_prune_choices_based_on_shared_mem = (
    os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE_PRUNE_CHOICES_BASED_ON_SHARED_MEM", "1")
    == "1"
)
# enable inductor graph partition to allow multiple inductor graphs for the same dynamo graph
# (on by default outside fbcode, off in fbcode; TORCHINDUCTOR_GRAPH_PARTITION overrides)
graph_partition: bool = (
    os.environ.get("TORCHINDUCTOR_GRAPH_PARTITION", "1" if not is_fbcode() else "0")
    == "1"
)
# register ops upon which inductor should partition the graph. name format should be
# "namespace::kernel_name" (e.g., aten::mm) for op overload packet, or
# "namespace::kernel_name.overload" (e.g., aten::mm.default).
custom_should_partition_ops: list[str] = []
# force cublas and triton to use the same precision; cublas supports TF32 for matmul operations
# when m, n, k are multiples of 16, 16, 8, whereas triton supports TF32 for matmul operations
# for any combinations of m, n, k, regardless of their alignment. setting this flag will ensure
# that triton does not use TF32 wherever cublas would not use TF32
# DEPRECATED. cuBLAS no longer has the above alignment requirements. will remove in the future.
force_same_precision: bool = Config(
    justknob="pytorch/compiler:force_same_precision",
    env_name_force="TORCHINDUCTOR_FORCE_SAME_PRECISION",
    default=False,
)
# Size hints for multi-kernel dispatch.
# A reasonable default value of this config would be [64, 256, 4096]
# TODO: @bobrenjc93 to roll this out to a few internal models to ensure this works
# as expected before turning it on for everyone.
# NOTE(review): the empty default presumably leaves this feature off.
multi_kernel_hints: list[int] = []
# Specify candidate backends for gemm autotune.
# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CKTILE, CPP.
# ATen: default Pytorch ATen kernels.
# Triton: Triton templates defined in torch inductor (AMD and NVidia GPUs).
# CUTLASS: Cutlass templates and kernels (NVidia GPUs only).
# CK: Composable Kernel templates and kernels (AMD Instinct GPUs only).
# CKTILE: Composable Kernel templates and kernels, new API (AMD Instinct GPUs only).
# CPP: CPP templates and kernels for CPU.
# The value is a comma-separated list (default "ATEN,TRITON,CPP") and is
# upper-cased, so the environment variable may use any letter case.
max_autotune_gemm_backends = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
).upper()
# As above, specify candidate backends for conv autotune.
# NB: in some cases for 1x1 convs we emit as matmul,
# which will use the backends of `max_autotune_gemm_backends`
max_autotune_conv_backends = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_CONV_BACKENDS", "ATEN,TRITON"
).upper()
# Specify the size of the search space for GEMM autotuning.
# DEFAULT - balance between compile time overhead and performance
# EXHAUSTIVE - maximize performance
# (the upper-cased env value is not validated against these choices here)
max_autotune_gemm_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE", "DEFAULT"
).upper()  # type: ignore[assignment]
# Specify the size of the search space for flex attention autotuning.
# DEFAULT - balance between compile time overhead and performance
# EXHAUSTIVE - maximize performance
# (the upper-cased env value is not validated against these choices here)
max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.get(
    "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
).upper()  # type: ignore[assignment]
# DEPRECATED. This setting is ignored.
autotune_fallback_to_aten = False
# the value used as a fallback for the unbacked SymInts
# that can appear in the input shapes (e.g., in autotuning)
unbacked_symint_fallback = 8192
# DEPRECATED. This setting is ignored.
search_autotune_cache = False
# NOTE(review): undocumented; enabled only when TORCHINDUCTOR_SAVE_ARGS=1.
save_args = os.environ.get("TORCHINDUCTOR_SAVE_ARGS") == "1"
# We will disable creating subprocess for autotuning if this is False
# (off unless TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1)
autotune_in_subproc = os.environ.get("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC") == "1"
# The following three timeouts are applicable if autotune_in_subproc is True
# (only the first is still honored; the other two are deprecated and ignored):
# Max time that a valid benchmark result may take during autotuning
max_autotune_subproc_result_timeout_seconds = 60.0
# DEPRECATED. This setting is ignored.
max_autotune_subproc_graceful_timeout_seconds = 0.0
# DEPRECATED. This setting is ignored.
max_autotune_subproc_terminate_timeout_seconds = 0.0
# If autotuning in subprocess, whether to use multiple devices
autotune_multi_device = os.environ.get("TORCHINDUCTOR_AUTOTUNE_MULTI_DEVICE") == "1"
# Enable coordinate descent tuning (off unless
# TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1).
coordinate_descent_tuning = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_TUNING") == "1"
)
# NOTE(review): presumably makes coordinate descent try every direction rather
# than stopping early -- confirm against the tuner implementation.
coordinate_descent_check_all_directions = (
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_CHECK_ALL_DIRECTIONS") == "1"
)
# Search radius for coordinate descent (default 1). Parsed with int() at import
# time, so a non-integer TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS raises ValueError.
coordinate_descent_search_radius = int(
    os.environ.get("TORCHINDUCTOR_COORDINATE_DESCENT_RADIUS", "1")
)
# AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and
# generate the learned heuristic to code which is shipped with the compiler
# Specify a list of comma separated optimizations to collect data for
autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "")
# Specify a list of comma separated optimizations to use learned heuristics for
autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm")
# If set to 1, will run a JIT post compile hook if one is set.
run_jit_post_compile_hook = (
    os.environ.get("TORCHINDUCTOR_RUN_JIT_POST_COMPILE_HOOK", "0") == "1"
)
  453. def run_autoheuristic(name: str) -> bool:
  454. return collect_autoheuristic(name) or use_autoheuristic(name)
  455. def collect_autoheuristic(name: str) -> bool:
  456. return name in torch._inductor.config.autoheuristic_collect.split(",")
  457. def use_autoheuristic(name: str) -> bool:
  458. return name in torch._inductor.config.autoheuristic_use.split(",")
# If set to "DEFAULT", this will use the default log path specified in autoheuristic.py.
# If set to another path, autoheuristic will instead log results to the given path.
autoheuristic_log_path = os.environ.get(
    "TORCHINDUCTOR_AUTOHEURISTIC_LOG_PATH", "DEFAULT"
)
# Disabled by default on ROCm, opt-in if model utilises NHWC convolutions
layout_opt_default = "1" if not torch.version.hip else "0"
# TORCHINDUCTOR_LAYOUT_OPTIMIZATION overrides the platform default above;
# only the value "1" enables layout optimization.
layout_optimization = (
    os.environ.get("TORCHINDUCTOR_LAYOUT_OPTIMIZATION", layout_opt_default) == "1"
)
# Opt-in via TORCHINDUCTOR_FORCE_LAYOUT_OPT=1.
# NOTE(review): presumably applies layout optimization even where heuristics
# would skip it -- confirm at the use site.
force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1"
# Whether to keep the output strides the same as eager after layout optimization.
keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1"
# Enabling this will let the compiler print warning messages if a generated triton
# kernel has inputs with mixed layouts. This is helpful for perf debugging
# since a kernel with mixed-layout inputs may run much slower than one whose inputs
# have uniform layouts.
warn_mix_layout = os.environ.get("TORCHINDUCTOR_WARN_MIX_LAYOUT") == "1"
# Control the store vs. recompute heuristic.
# For fanouts, rematerialization can lead to exponential blowup, so use a
# smaller threshold.
realize_reads_threshold = 4
realize_opcount_threshold = 30
# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold = 8
realize_acc_reads_size_threshold: Optional[int] = (
    None  # TODO(xuanzh): harden this to make it non optional
)
# fallback to eager for random/dropout, this is slow but useful for debugging
fallback_random = False
# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks = True
# Opt-in via TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT=1.
# NOTE(review): presumably treats outputs of fallback kernels as possibly
# unaligned -- confirm at the use site.
assume_unaligned_fallback_output = (
    os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1"
)
# fuse even in cases without common reads
aggressive_fusion = False
# For each fused kernel in the wrapper, comment with the nodes that get fused.
# Useful for debugging fusion.
debug_fusion: bool = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
# Opt-in via TORCHINDUCTOR_BENCHMARK_FUSION=1.
# NOTE(review): presumably benchmarks candidate fusions before accepting them --
# confirm in the scheduler.
benchmark_fusion: bool = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
# Raw string from TORCHINDUCTOR_ENABLED_METRIC_TABLES (empty by default).
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")
# Opt-in via TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION=1.
loop_ordering_after_fusion: bool = (
    os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1"
)
# If fusing two nodes only saves less than score_fusion_memory_threshold memory,
# we should not bother fusing the nodes.
#
# This is especially helpful to resolve https://github.com/pytorch/pytorch/issues/133242
# Previously we fused two nodes because of a common read of a scalar tensor.
# If we skip it, the loop ordering after fusion mechanism kicks in and can
# bring more savings.
#
# For the cases where loop ordering after fusion does not help, we don't lose much.
score_fusion_memory_threshold = 10
# For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel
benchmark_epilogue_fusion = (
    os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1"
)
# How many of the top triton kernels to benchmark the epilogue with
max_epilogue_benchmarked_choices = 1
# how many nodes to allow into a single fusion
max_fusion_size = 64
# how many nodes to attempt pairwise fusion with in a buffer group
max_fusion_buffer_group_pairwise_attempts = 64
# max number of inputs to generate cat as a pointwise op with masked loads
max_pointwise_cat_inputs = 8
# force concat to be generated as a pointwise op with masked loads
force_pointwise_cat = False
# replace small reductions with pointwise, disable with `= 1`
unroll_reductions_threshold = 8
# Add extra comments to output code (causes compile cache misses)
comment_origin = False
# Convert 1x1 convs into matmuls
conv_1x1_as_mm = False
# For reductions with a small output size (usually 1, e.g. x.sum()) there is not enough
# parallelism to saturate the GPU. We have two ways of handling this, either `split_reductions`
# or `triton.cooperative_reductions` which are mutually exclusive.
#   split_reductions: uses multiple kernels to gain more parallelism
#   triton.cooperative_reductions: uses cross thread-block synchronization to gain more parallelism
# enabling both of these will implicitly disable split_reductions
split_reductions = True
# When we do split reduction, this number controls the minimum value for
# num_split. Too small a num_split makes the split reduction less efficient.
# It's a much bigger problem when we compile a dynamic shape kernel with
# non-representative inputs.
min_num_split = int(os.environ.get("TORCHINDUCTOR_MIN_NUM_SPLIT", 0))
# Opt-in via TORCHINDUCTOR_BENCHMARK_KERNEL=1.
benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"
# Enable constant and index_expr folding
constant_and_index_propagation = True
# When True, we always add constants into graph.constants without
# performing any constant-inlining optimization
always_keep_tensor_constants = False
# assert that indirect indexing does not read / write out of bounds
assert_indirect_indexing = True
# compute CSE bounds on variables that do not appear in the FX graph
compute_all_bounds = False
# enable the combo kernel that combines data-independent kernels (in addition
# to foreach kernels) into a single one (Experimental)
combo_kernels = False
# benchmark combo kernels and only allow ones with perf gains
benchmark_combo_kernel = False
# combo_kernel autotuning options: 0 - disable, 1 - enable except for foreach,
# 2 - enable for all
combo_kernels_autotune = 1
# Enable masking for combining kernels of mixed sizes: 0 - disable, 1 - enable
# for all except for foreach, 2 - enable for all
combo_kernel_allow_mixed_sizes = 1
# Enable dynamic shapes for foreach kernels
combo_kernel_foreach_dynamic_shapes = True
# constant folding on the joint graph
joint_graph_constant_folding = True
# Enable indirect_indexing asserts for decompositions and lowerings
debug_index_asserts = False
# Mode to emulate PyTorch eager numerics when doing lower precision compute
# (fp16, bf16).  PyTorch eager computes bf16/fp16 by upcasting inputs to fp32
# and downcasting after.  When two low precision operators are fused together,
# Inductor will elide the downcast-upcast pairs (effectively a precision
# truncation) that would occur between these two operators.  Typically,
# Inductor's behavior should be closer to fp64 ref numerics.  However, with
# this knob you can ensure the downcast-upcast are preserved so that you can
# emulate the eager numerics.
emulate_precision_casts = (
    os.environ.get("TORCHINDUCTOR_EMULATE_PRECISION_CASTS", "0") == "1"
)
# warnings intended for PyTorch developers, disable for point releases
# (a build counts as nightly/source when torch.__version__ contains "dev" or "git")
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings = is_fbcode() or is_nightly_or_source
# This pattern matches a special usage of scatter
# 1. It's applied to a constant tensor
# 2. The index tensor has size 1 in the scatter dimension
# Such pattern generates a sparse matrix when the const tensor is all-zero.
# We can lower this pattern to a pointwise kernel for more fusion opportunities
# and saving memory footprint.
optimize_scatter_upon_const_tensor = (
    os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1"
)
# options in caffe2/torch/_inductor/fx_passes/pre_grad.py
add_pre_grad_passes: Optional[str] = None
remove_pre_grad_passes: Optional[str] = None
  599. # The multiprocessing start method to use for inductor workers in the codecache.
  600. def decide_worker_start_method() -> str:
  601. if "TORCHINDUCTOR_WORKER_START" in os.environ:
  602. start_method = os.environ["TORCHINDUCTOR_WORKER_START"]
  603. else:
  604. start_method = "subprocess"
  605. assert start_method in (
  606. "subprocess",
  607. "fork",
  608. "spawn",
  609. ), f"Invalid start method: {start_method}"
  610. return start_method
# Resolved once, at import time, from TORCHINDUCTOR_WORKER_START.
worker_start_method: str = decide_worker_start_method()
# Threshold to decide if a kernel has small memory access in bytes
# Default value is 16 MB (16 * 1024 * 1024) which is arbitrarily selected.
small_memory_access_threshold: int = 16777216
# Whether to suppress logging from launched subprocess workers (default True).
worker_suppress_logging: bool = Config(
    justknob="pytorch/compiler:worker_suppress_logging",
    env_name_force="TORCHINDUCTOR_WORKER_SUPPRESS_LOGGING",
    default=True,
)
# Log per-operation runtime estimates for TLParse analysis.
log_tlparse: bool = Config(
    env_name_force="LOG_TLPARSE",
    default=False,
)
# Flags to turn on all_reduce fusion. These 2 flags should be automatically turned
# on by DDP and should not be set by the users.
_fuse_ddp_communication = False
_fuse_ddp_bucket_size = 25
# Flag to control which fusion passes to apply. Functions in the list will
# be applied in order. There are two different fusion passes
# --"fuse_ddp_with_concat_op" and "fuse_ddp_with_coalesced_op". The default
# one is "fuse_ddp_with_concat_op". Users can also change this to a customized
# fusion function.
#
# The fusion currently does not support multiple DDP with different PG or
# data type. This feature will be added in the future PRs.
#
# "schedule_comm_wait" is used to delay the wait ops to maximize comm/comp
# overlapping. At this moment, this pass performs better than
# reorder_for_compute_comm_overlap_passes but we will add the logic of
# "schedule_comm_wait" in the future and remove the one here.
_fuse_ddp_communication_passes: list[Union[Callable[..., None], str]] = [
    "fuse_ddp_with_concat_op",
    "schedule_comm_wait",
]
# NOTE(review): undocumented private flag; presumably enables micro-pipelined
# tensor-parallel passes -- confirm with the code that reads it.
_micro_pipeline_tp: bool = False
  648. class _collective:
  649. auto_select: bool = False
  650. one_shot_all_reduce_threshold_bytes: int = 128 * 1024
  651. def parallel_compile_enabled_internally() -> bool:
  652. """
  653. TODO: Remove when parallel compiled is fully enabled internally. For rollout, use a
  654. knob to enable / disable. The justknob should not be performed at import, however.
  655. So for fbcode, we assign compile_threads to 'None' below and initialize lazily in
  656. async_compile.py.
  657. """
  658. ENABLE_PARALLEL_COMPILE_VERSION = 1
  659. jk_name = "pytorch/inductor:enable_parallel_compile_version"
  660. version = torch._utils_internal.justknobs_getval_int(jk_name)
  661. return ENABLE_PARALLEL_COMPILE_VERSION >= version
  662. def decide_compile_threads() -> int:
  663. """
  664. Here are the precedence to decide compile_threads
  665. 1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by
  666. setting this to 1 to make pdb happy.
  667. 2. Set to 1 if it's win32 platform
  668. 3. decide by the number of CPU cores
  669. """
  670. import logging
  671. # Defined locally so install_config_module doesn't try to parse
  672. # as a config option.
  673. log = logging.getLogger(__name__)
  674. if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
  675. compile_threads = int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
  676. log.info("compile_threads set to %d via env", compile_threads)
  677. elif sys.platform == "win32":
  678. compile_threads = 1
  679. log.info("compile_threads set to 1 for win32")
  680. elif is_fbcode() and not parallel_compile_enabled_internally():
  681. compile_threads = 1
  682. log.info("compile_threads set to 1 in fbcode")
  683. else:
  684. cpu_count = (
  685. len(os.sched_getaffinity(0))
  686. if hasattr(os, "sched_getaffinity")
  687. else os.cpu_count()
  688. )
  689. assert cpu_count
  690. compile_threads = min(32, cpu_count)
  691. log.info("compile_threads set to %d", compile_threads)
  692. return compile_threads
# TODO: Set directly after internal rollout.
# None in fbcode (initialized lazily later); otherwise computed at import time
# by decide_compile_threads().
compile_threads: Optional[int] = None if is_fbcode() else decide_compile_threads()
# Whether to quiesce the Triton-compile subprocess pool at the end of each compilation.
quiesce_async_compile_pool: bool = Config(
    justknob="pytorch/inductor:quiesce_async_compile_pool",
    env_name_force="TORCHINDUCTOR_QUIESCE_ASYNC_COMPILE_POOL",
    default=False,
)
# Whether or not to enable statically launching CUDA kernels
# compiled by triton (instead of using triton's own launcher)
use_static_cuda_launcher: bool = static_cuda_launcher_default()
# Attempt to statically launch user defined triton kernels
# Requires use_static_cuda_launcher
static_launch_user_defined_triton_kernels: bool = Config(
    justknob="pytorch/inductor:static_launch_user_defined_triton_kernels",
    env_name_force="TORCHINDUCTOR_STATIC_LAUNCH_USER_DEFINED_TRITON_KERNELS",
    default=False,
)
# Raise error if we bypass the launcher
strict_static_cuda_launcher: bool = (
    os.environ.get("TORCHINDUCTOR_STRICT_STATIC_CUDA_LAUNCHER", "0") == "1"
)
# gemm autotuning global cache dir
# Outside fbcode this is always None. In fbcode it is resolved through parutil
# to the packaged "fb/cache" directory (relative to this package when
# __package__ is set), falling back to None on ValueError/ImportError.
global_cache_dir: Optional[str]
if is_fbcode():
    try:
        from libfb.py import parutil

        if __package__:
            global_cache_dir = parutil.get_dir_path(
                os.path.join(__package__.replace(".", os.sep), "fb/cache")
            )
        else:
            global_cache_dir = parutil.get_dir_path("fb/cache")
    except (ValueError, ImportError):
        global_cache_dir = None
else:
    global_cache_dir = None
# If a kernel is fused, its name is generated from the origin node op names;
# for larger kernels, limit this
kernel_name_max_ops = 10
# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "1") == "1"
# Control if we will do padding for pointwise/reductions
comprehensive_padding = (
    os.environ.get("TORCHINDUCTOR_COMPREHENSIVE_PADDING", "1") == "1"
)
# NOTE(review): undocumented; presumably controls whether channels-last tensors
# are padded too -- confirm at the use site.
pad_channels_last = False
# Control if we will do padding on dynamic shapes
pad_dynamic_shapes = False
# Disable comprehensive padding on the CPU
disable_padding_cpu = True
# Control if we will expand the dimension of pointwise nodes to fuse
expand_dimension_for_pointwise_nodes = False
# The width of comprehensive padding, in bytes.
# CUDA max memory transaction size is 128 bytes for a warp.
padding_alignment_bytes = 128
# Threshold on the minimum stride that will be padded.
#
# Don't align a too-small stride since that causes too much memory increase.
# Padding a too-small stride may also cause perf loss: we may end up with many
# tiny data blocks with gaps in between, which makes GPU memory access less
# coalesced.
#
# Initially we picked 320 as the threshold since, for alignment=16,
# that results in at most 5% memory cost.
#
# But later on we raised the threshold to 1024 to avoid interfering with persistent reduction.
# Let's say an inner reduction has a row size 513. Inductor will generate
# persistent reduction code.
# If we do padding, the strides are not contiguous any more. Inductor
# uses a much smaller threshold for persistent reduction in this case and
# generates potentially worse non-persistent reduction code.
#
# This change turns HF AllenaiLongformerBase amp training from a loss of 1.09x to a win of 1.05x.
# (baseline: 71.09ms, padding w/o this change: 77.38ms, padding with this change: 67.77ms)
padding_stride_threshold = 1024
# Enable padding outputs, even if they would not be padded in eager mode.
# By default, we use the same strides as eager mode.
pad_outputs = False
# Whether to treat output of the backward graph as user visible.
# For user visible outputs, inductor will make sure the stride matches with eager.
bw_outputs_user_visible = True
# Whether to always use shape padding if it is enabled and possible
force_shape_pad: bool = False
# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"
# Mark the wrapper call in PyTorch profiler
profiler_mark_wrapper_call = False
# Generate hook calls to torch._inductor.hooks.run_intermediate_hooks for
# every intermediate for which we can correlate it with an intermediate
# from the original FX graph
generate_intermediate_hooks = False
# Populate traceback field on IRNode; good for debugging why origin_node is
# not populated, or finding out where an IRNode was constructed
debug_ir_traceback = False
# used for debugging to make sure config is properly set
_raise_error_for_testing = False
# TORCHINDUCTOR_PROFILE: any non-empty value turns on bandwidth profiling.
# The value "1" leaves profile_bandwidth_regex empty (presumably "match all
# kernels"); any other value is used verbatim as profile_bandwidth_regex.
_profile_var = os.environ.get("TORCHINDUCTOR_PROFILE", "")
profile_bandwidth = _profile_var != ""
profile_bandwidth_regex = "" if _profile_var == "1" else _profile_var
# Specify a file where we print out the profiling results.
# None means we do not dump results to a file.
profile_bandwidth_output: Optional[str] = os.environ.get(
    "TORCHINDUCTOR_PROFILE_OUTPUT", None
)
# Switch to do_bench_using_profiling to exclude the CPU overheads
profile_bandwidth_with_do_bench_using_profiling = (
    os.environ.get("TORCHINDUCTOR_PROFILE_WITH_DO_BENCH_USING_PROFILING") == "1"
)
# TODO: remove later
# incompatible with cpp_wrapper
disable_cpp_codegen = False
# Freezing will attempt to inline weights as constants in optimization
# and run constant folding and other optimizations on them. After freezing, weights
# can no longer be updated.
freezing: bool = os.environ.get("TORCHINDUCTOR_FREEZING", "0") == "1"
# Make freezing invalidate the eager Parameters of nn modules, to avoid memory overhead
# of potentially keeping multiple copies of weights.
freezing_discard_parameters: bool = False
# decompose some memory bound matmul/bmm to mul
decompose_mem_bound_mm: bool = False
# assume_aligned_inputs means that we assume that inputs will be aligned; we generate
# code using this assumption, and clone tensors before use if they aren't aligned.
# In the common case, most inputs will be aligned.
assume_aligned_inputs: bool = False
# For the user-written Triton kernels compiled with the model, ignore the unsupported
# arguments passed to the @triton.autotune in the user's code; this is unsafe, as
# ignoring the unsupported args may lead to unexpected autotuning behavior: don't
# set unless you know what you're doing.
unsafe_ignore_unsupported_triton_autotune_args: bool = False
# When True, we will check in scheduler.py _codegen that there are no "loops"
# in the call stack; that is to say, the same frame multiple times. This
# ensures that a cProfile trace to this frame will be a straight line without
# any cycles. Incompatible with cpp_wrapper.
check_stack_no_cycles_TESTING_ONLY: bool = False
# When True, complex_memory_overlap always reports True
always_complex_memory_overlap_TESTING_ONLY: bool = False
# enable linear binary folding
enable_linear_binary_folding = (
    os.environ.get("TORCHINDUCTOR_ENABLE_LINEAR_BINARY_FOLDING", "0") == "1"
)
# Adds NVTX annotations around training phases
annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1"
# Enable caching codegen of triton templates.
enable_caching_generated_triton_templates: bool = True
# Lookup table for overriding autotune configs based on hash of Triton source code
autotune_lookup_table: dict[str, dict[str, Any]] = {}
  839. def get_worker_log_path() -> Optional[str]:
  840. log_loc = None
  841. if is_fbcode():
  842. mast_job_name = os.environ.get("MAST_HPC_JOB_NAME", None)
  843. global_rank = os.environ.get("ROLE_RANK", "0")
  844. if mast_job_name is not None:
  845. log_loc = f"/logs/dedicated_log_torch_compile_worker_rank{global_rank}"
  846. return log_loc
# Log path for compile workers, overridable via TORCHINDUCTOR_WORKER_LOGPATH.
# NOTE(review): the empty-string default presumably means "no dedicated worker
# log file" -- confirm with the code that reads it.
torchinductor_worker_logpath: str = Config(
    env_name_force="TORCHINDUCTOR_WORKER_LOGPATH",
    default="",
)
# config specific to codegen/cpp.py
class cpp:
    """
    Settings for cpp backend.
    This class provides a centralized location for managing cpp backend settings.
    Fields backed by TORCHINDUCTOR_CPP_* (and CXX) environment variables are read
    once, when this module is imported.
    """

    # Number of threads; the placeholder -1 is meant to be "set to
    # torch.get_num_threads()". NOTE(review): confirm where that happens.
    threads = -1

    # Do not generate loops when the condition doesn't hold, like:
    # for(long i0=4096; i0<4096; i0+=1)
    no_redundant_loops = (
        os.environ.get("TORCHINDUCTOR_CPP_NO_REDUNDANT_LOOPS", "1") == "1"
    )

    # Assume number of threads is dynamic, don't specialize thread number.
    # Kernels don't recompile on thread number changes with this flag on.
    # For single-threaded workload, turning it on would incur a slight
    # performance degradation.
    dynamic_threads = os.environ.get("TORCHINDUCTOR_CPP_DYNAMIC_THREADS", "0") == "1"

    # SIMD vector length override. NOTE(review): None presumably means "derive
    # from the detected ISA" -- confirm in codegen/cpp.py.
    simdlen: Optional[int] = None
    # Parsed from TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE (default 512).
    min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "512"))

    # C++ compiler candidates: None, then $CXX (default clang++ on macOS,
    # g++ elsewhere).
    cxx: tuple[Literal[None], str] = (
        None,  # download gcc12 from conda-forge if conda is installed
        os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
    )  # type: ignore[assignment]

    # Allow kernel performance profiling via PyTorch profiler
    enable_kernel_profile = (
        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_KERNEL_PROFILE", "0") == "1"
    )

    # enable weight prepacking to get a better performance; may lead to large memory footprint
    weight_prepack = os.environ.get("TORCHINDUCTOR_CPP_WEIGHT_PREPACK", "1") == "1"

    # Inject a bug into our relu implementation; useful for testing our repro
    # extraction and minification functionality.
    # Valid values: "compile_error", "runtime_error", "accuracy"
    inject_relu_bug_TESTING_ONLY: Optional[str] = None
    # The log1p counterpart of inject_relu_bug_TESTING_ONLY.
    inject_log1p_bug_TESTING_ONLY: Optional[str] = None

    # If None, autodetect whether or not AVX512/AVX2 can be used.  Otherwise,
    # force usage as specified, without testing. Default None.
    vec_isa_ok: Optional[bool] = get_tristate_env("TORCHINDUCTOR_VEC_ISA_OK")

    # similar to config.triton.descriptive_names
    descriptive_names: Literal["torch", "original_aten", "inductor_node"] = (
        "original_aten"
    )

    # how many nodes to allow into a single horizontal fusion
    max_horizontal_fusion_size = int(
        os.environ.get("TORCHINDUCTOR_CPP_MAX_HORIZONTAL_FUSION_SIZE", "16")
    )

    # Make scatter_reduce fallback when reduce is sum to avoid performance regression
    # using atomic_add.
    fallback_scatter_reduce_sum = (
        os.environ.get("TORCHINDUCTOR_CPP_FALLBACK_SCATTER_REDUCE_SUM", "1") == "1"
    )

    # Use -funsafe-math-optimizations when compiling
    enable_unsafe_math_opt_flag = (
        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_UNSAFE_MATH_OPT_FLAG", "0") == "1"
    )

    # Use -ffp-contract when compiling
    # Options: "off" (default), "on", "fast"
    # Per https://godbolt.org/z/bf4bvfc9r , clang/gcc has different behavior for "fast"
    enable_floating_point_contract_flag = os.environ.get(
        "TORCHINDUCTOR_CPP_ENABLE_FLOATING_POINT_CONTRACT_FLAG", "off"
    )

    # Enable the tiling select heuristic (on by default; set
    # TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC=0 to disable it)
    enable_tiling_heuristics = (
        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1"
    )

    # Enable the Grouped GEMM Fusion
    enable_grouped_gemm_template = False

    # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls
    # the maximal parallelism of K-slicing. Since K-slicing requires extra thread
    # synchronization and buffers,  the maximal number of slices is limited to
    # mitigate the sync overhead and memory usage.
    # When set to 0, the number of slices is unlimited.
    gemm_max_k_slices = int(os.environ.get("TORCHINDUCTOR_CPP_GEMM_MAX_K_SLICES", "1"))

    # For perf tuning and debugging purpose, configure the pre-defined cache blocking for
    # MxNxK dims respectively. The blockings are separated by comma and the unit is
    # the number of register blocks.
    # For example, "4,1,10" means 4 register blocks on M, 1 on N and 10 on K respectively.
    gemm_cache_blocking = os.environ.get("TORCHINDUCTOR_CPP_GEMM_CACHE_BLOCKING", None)

    # For perf tuning and debugging purpose, configure the pre-defined thread blocking factors for
    # MxNxK dims respectively. The factors are separated by comma and their product
    # should be the same as the total number of threads.
    # For example, if the total number of threads is 56, "7,4,2" means the work is
    # decomposed into 7x4x2 thread blocks along MxNxK of a GEMM.
    gemm_thread_factors = os.environ.get("TORCHINDUCTOR_CPP_GEMM_THREAD_FACTORS", None)

    # Whether to enable masked vectorization for the tail_loop.
    enable_loop_tail_vec = True

    # Whether to enable concat linear for cpu device
    # Currently concat linear on CPU does not always bring a benefit; it depends on
    # the linear's shape and the available computing resources. We set this default
    # to False to avoid regressions. Users can enable this feature as needed.
    enable_concat_linear = False

    # Whether to use decomposed tanh for cpu device
    # Disable by default due to https://github.com/pytorch/pytorch/issues/148241
    use_decompose_tanh = (
        os.environ.get("TORCHINDUCTOR_CPP_USE_DECOMPOSE_TANH", "0") == "1"
    )

    # Use a small dequant buffer for wgt of woq int4 size as: [q_group_size, Nr]
    use_small_dequant_buffer = False

    # Opt-in via TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL=1. NOTE(review): presumably
    # forces generated C++ kernels to be inlined -- confirm in codegen/cpp.py.
    force_inline_kernel = (
        os.environ.get("TORCHINDUCTOR_CPP_FORCE_INLINE_KERNEL", "0") == "1"
    )

    # Choose between `static constexpr` and `static const` for int arrays
    # (default on). NOTE(review): presumably True selects constexpr -- confirm.
    use_constexpr_for_int_array = (
        os.environ.get("TORCHINDUCTOR_CPP_USE_CONSTEXPR_FOR_INT_ARRAY", "1") == "1"
    )
  956. class triton:
  957. """
  958. Config specific to codegen/triton.py
  959. """
  960. # Use cudagraphs on output code
  961. cudagraphs = os.environ.get("TORCHINDUCTOR_CUDAGRAPHS") == "1"
  962. # Use cudagraph trees for memory pooling if `cudagraphs` is True
  963. cudagraph_trees = True
  964. # Should we skip cudagraphing graphs with dynamic shape inputs
  965. # If False, we will re-record a graph for each unique set of shape inputs
  966. cudagraph_skip_dynamic_graphs = False
  967. # Specify dynamic shapes to capture cudagraphs and skip cudagraph for other shapes.
  968. # Default to None, which means we capture cudagraphs for all shapes.
  969. cudagraph_capture_sizes: Optional[tuple[Union[int, tuple[int, ...]]]] = None
  970. # assertions not on the fast path, steady state
  971. slow_path_cudagraph_asserts = True
  972. # TODO - need to debug why this prevents cleanup
  973. cudagraph_trees_history_recording = False
  974. # Enable cudagraph support for mutated inputs from prior cudagraph pool
  975. cudagraph_support_input_mutation = False if is_fbcode() else True
  976. # Maximal number of allowed cudagraph re-record for a function and
  977. # a cudagraph node due to static input tensor address changes or
  978. # cudagraph managed tensor data pointer changed.
  979. # i.e., allow num_recording <= cudagraph_unexpected_rerecord_limit
  980. # note: we are conservative here and choose a large limit.
  981. cudagraph_unexpected_rerecord_limit = 128
  982. # Warn loudly when the number of cudagraphs due to dynamic shape
  983. # exceeds this limit
  984. cudagraph_dynamic_shape_warn_limit: Optional[int] = 50
  985. # synchronize after cudagraph invocation
  986. force_cudagraph_sync = False
  987. # always run cudagraphs in the eager warmup stage
  988. # instead of recording and executing cudagraphs
  989. force_cudagraphs_warmup = False
  990. # If False (default), torch.compile skips cudagraph for a graph if it
  991. # contains cudagraph-unsafe ops. If True, we require that all cuda ops
  992. # be captured into cudagraph. If this is not possible, this will raise
  993. # an error.
  994. cudagraph_or_error: bool = Config(
  995. env_name_force="TORCHINDUCTOR_CUDAGRAPH_OR_ERROR",
  996. default=False,
  997. )
  998. # assertions on the fast path
  999. fast_path_cudagraph_asserts = False
  1000. # skip warmup for cudagraph trees
  1001. skip_cudagraph_warmup = False
  1002. # Synchronize before and after every compiled graph.
  1003. debug_sync_graph = False
  1004. # Synchronize after every kernel launch, to help pinpoint bugs
  1005. debug_sync_kernel = False
  1006. # Always load full blocks (rather than broadcasting inside the block)
  1007. dense_indexing = False
  1008. # TODO - enable by default
  1009. coalesce_tiling_analysis: bool = (
  1010. os.environ.get(
  1011. "TORCHINDUCTOR_COALESCE_TILING_ANALYSIS", "1" if not is_fbcode() else "0"
  1012. )
  1013. == "1"
  1014. )
  1015. # limit tiling dimensions
  1016. # - max_tiles=1 disables tiling
  1017. # - max_tiles=2
  1018. # - max_tiles=3 is experimental and may have bugs
  1019. # higher values are unsupported
  1020. # We use a max of 3 if coalesce_tiling_analysis is True, and 2 otherwise.
  1021. # Note - coalesce_tiling_analysis does not yet apply to dynamic shapes.
  1022. max_tiles: Optional[int] = None
  1023. # Prefer higher dimensional tilings. This simplifies indexing expressions, making
  1024. # it easier to identify block pointers.
  1025. prefer_nd_tiling: bool = False
  1026. # use triton.autotune for pointwise ops with complex layouts
  1027. # this should only be disabled for debugging/testing
  1028. autotune_pointwise = True
  1029. # max autotune gemm with cublasLt
  1030. autotune_cublasLt = True
  1031. # Tune the generated Triton kernels at compile time instead of first time they run
  1032. # Setting to None means uninitialized
  1033. autotune_at_compile_time: Optional[bool] = None
  1034. # We use random tensors for autotune by default. Setting this as true will let us
  1035. # use inputs from sample inputs to autotune user defined triton kernels.
  1036. # Side effect for this option is increased memory footprint during first pass compilation.
  1037. autotune_with_sample_inputs: bool = False
  1038. # Allows tiling reductions into multiple dimensions.
  1039. # For best results, this should be used with prefer_nd_tiling.
  1040. tile_reductions: bool = False
  1041. # should we stop a fusion to allow better tiling?
  1042. tiling_prevents_pointwise_fusion = True
  1043. tiling_prevents_reduction_fusion = True
  1044. # should we give different names to kernels
  1045. # Note: This is orthogonal to descriptive_names - this is deciding whether
  1046. # our triton kernel names should all be `triton_` (to maximize caching) or
  1047. # whether they should be unique.
  1048. unique_kernel_names = (
  1049. os.environ.get("TORCHINDUCTOR_UNIQUE_KERNEL_NAMES", "1") == "1"
  1050. )
  1051. # similar to the option above, but this is specific to user defined kernels,
  1052. # while unique_kernel_name is for kernels generated by inductor.
  1053. # We have this option because sometimes we reuse user's kernel code with different
  1054. # configs which would result in the same name.
  1055. # Note: This MODIFIES the user's kernel function name within inductor phase.
  1056. unique_user_kernel_names = (
  1057. os.environ.get("TORCHINDUCTOR_UNIQUE_USER_KERNEL_NAMES", "0") == "1"
  1058. )
  1059. # should we put op names in kernel names
  1060. # "torch": Maps to the fx op in the Dynamo graph (module name, method name, etc.)
  1061. # "original_aten": Maps to the highest-level aten op (i.e. pre-decompositions)
  1062. # "inductor_node": Maps to the node name in the FX graph passed to Inductor
  1063. descriptive_names: Literal["torch", "original_aten", "inductor_node"] = (
  1064. "original_aten"
  1065. )
  1066. # use alternate codegen for smaller reductions
  1067. persistent_reductions = (
  1068. os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1"
  1069. )
  1070. # For small output size reductions uses cross thread-block synchronization to gain more parallelism
  1071. cooperative_reductions = (
  1072. os.environ.get("TORCHINDUCTOR_COOPERATIVE_REDUCTIONS", "0") == "1"
  1073. )
  1074. # used for debugging cooperative reduction codegen, always generate cooperative_reductions
  1075. force_cooperative_reductions = False
  1076. # 0: disable
  1077. # 1/True: enable, use tuning to pick between different subkernels
  1078. # 2: enable, force using persistent reduction (for debugging)
  1079. # 3: enable, force using non-persistent reduction (for debugging)
  1080. multi_kernel: Literal[0, 1, 2, 3] = int(
  1081. os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "0")
  1082. ) # type: ignore[assignment]
  1083. # hint to Triton when arguments are divisible by 16
  1084. divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1"
  1085. # Minimum R0_BLOCK to be used for a TritonSplitScanKernel
  1086. # NOTE: This also indirectly controls the size of workspace buffer required
  1087. min_split_scan_rblock = 256
  1088. # Store the generated cubin files for cpp wrapper code to load
  1089. store_cubin = False
  1090. # the max number of spills we allow for the configs we benchmark.
  1091. # Setting this to 0 means we skip a config if it spills even a single
  1092. # register.
  1093. # Setting it to a larger value allows a config spilling a small amount
  1094. # of registers being benchmarked.
  1095. #
  1096. # NOTE: triton will always report >0 register spills for kernels using sin/cos.
  1097. # (check this issue https://github.com/triton-lang/triton/issues/1756 )
  1098. # So far we see a fixed 8 spilled registers for kernels using sin/cos.
  1099. # Raise the threshold to 16 to be safe.
  1100. # We should revisit this once we understand more of the source of register spills.
  1101. spill_threshold: int = 16
  1102. # Generate code containing the newer tl.make_block_ptr() API for loads/store
  1103. use_block_ptr = False
  1104. # (Experimental)
  1105. # Generate code using the tl.make_tensor_descriptor() API for loads/store
  1106. # [Note: TMA API Restrictions] Currently the TMA API requires the following:
  1107. # - For Nvidia GPUs, the compute capability should be >= 9.0
  1108. # - The innermost stride of a descriptor should be 1
  1109. # - The size of the block shape in the innermost dimension should load / store
  1110. # at least 16 bytes.
  1111. # - Tensors are 16 byte aligned. Enabling this option therefore requires
  1112. # assume_aligned_inputs to also be enabled
  1113. # TMA descriptors are only going to be generated if the above conditions
  1114. # can be satisfied, along with any existing requirements for index expressions
  1115. use_tensor_descriptor = False
  1116. # Inject a bug into our relu implementation; useful for testing our repro
  1117. # extraction and minification functionality.
  1118. # Valid values: "compile_error", "runtime_error", "accuracy"
  1119. inject_relu_bug_TESTING_ONLY: Optional[str] = None
  1120. # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
  1121. codegen_upcast_to_fp32 = True
  1122. # Whether persistent matmul kernels should be enabled this flag only has effect when on h100
  1123. # with a version of triton new enough to support TMA
  1124. enable_persistent_tma_matmul = (
  1125. os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1"
  1126. )
  1127. # Skip L1 cache for buffers that are used only once. Disabled by default
  1128. skip_l1_cache = os.environ.get("TORCHINDUCTOR_SKIP_L1", "0") == "1"
  1129. # During autotuning, if one of the kernels/configs fails for some reason,
  1130. # Inductor will usually skip it (and assign its latency to inf).
  1131. # For testing it's helpful to be able to assert that none of the configs fail.
  1132. # Note: it may also need to be used with config.compile_threads = 1
  1133. disallow_failing_autotune_kernels_TESTING_ONLY = False
  1134. # specify number of splits to autotune on for decompose_k. 0 disables decompose_k
  1135. num_decompose_k_splits = int(
  1136. os.environ.get("TORCHINDUCTOR_NUM_DECOMPOSE_K_SPLITS", "10")
  1137. )
  1138. # specify minimum ratio of K to M AND N in order to autotune on decompose_k. 0 enables
  1139. # it as an autotuning choice for all matmuls
  1140. decompose_k_threshold = int(
  1141. os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32")
  1142. )
class aot_inductor:
    """
    Settings for Ahead-Of-Time Inductor (AOTInductor) compilation.

    Defaults taken from environment variables are read once, when this module
    is imported; boolean env-driven flags compare the variable (or its default)
    against the exact string "1".
    """

    # AOTInductor output path
    # If an absolute path is specified, the generated lib files will be stored under the directory;
    # If a relative path is specified, it will be used as a subdirectory under the default caching path;
    # If not specified, a temp directory will be created under the default caching path.
    # If the specified path contains something like "model.so", the sub-string will be used
    # to name the generated library.
    output_path = ""

    # Enabled by AOT_INDUCTOR_DEBUG_COMPILE=1.
    debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1"

    # Annotate the generated main wrapper function, i.e. AOTInductorModel::run_impl,
    # with the cpp compiler optimization level to use. Defaults to O1; override with
    # AOT_INDUCTOR_COMPILE_WRAPPER_OPT_LEVEL.
    compile_wrapper_opt_level = os.environ.get(
        "AOT_INDUCTOR_COMPILE_WRAPPER_OPT_LEVEL", "O1"
    )

    # Option for debug printing/saving of intermediate tensor values for aot inductor,
    # read from AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER:
    # 0: disable debug dumping
    # 1: enable saving intermediate tensor values
    # 2: enable printing intermediate tensor values
    # 3: enable printing kernel names only (useful for pinpointing troublesome kernels)
    # The value is kept as a string and the env var is not validated here.
    debug_intermediate_value_printer: Literal["0", "1", "2", "3"] = os.environ.get(
        "AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0"
    )  # type: ignore[assignment]

    # Filtered kernel names whose debug values should be printed (None when
    # AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT is unset). Specify this option when
    # debug_intermediate_value_printer is set to 2.
    filtered_kernel_names = os.environ.get(
        "AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT", None
    )

    # Serialized tree spec for flattening inputs
    # TODO: Move this into metadata
    serialized_in_spec = ""

    # Serialized tree spec for flattening outputs
    # TODO: Move this into metadata
    serialized_out_spec = ""

    # Flag to decide whether to create a submodule for the constant graph.
    use_runtime_constant_folding: bool = False

    # Flag to force weights to be appended to the shared library and mapped by the
    # runtime rather than embedded into the data section. Needed to support 1B+
    # parameter models.
    force_mmap_weights: bool = False

    # When True (the default), the constants are built via assembly language;
    # when False, they are built via C++.
    use_consts_asm_build = True

    # NOTE(review): undocumented here; presumably toggles packaging of the AOTI
    # output — confirm at the use site.
    package: bool = False
    # None means "not explicitly set"; how None is resolved is decided elsewhere.
    package_cpp_only: Optional[bool] = None

    # Dictionary of metadata users might want to save to pass to the runtime.
    # TODO: Move this somewhere else, since it's no longer really a config
    metadata: dict[str, str] = {}

    # fbcode only. Whether to raise an error if C++ codegen is too big to optimize.
    # On by default; setting AOTINDUCTOR_RAISE_ERROR_ON_IGNORED_OPTIMIZATION to any
    # value other than "1" turns it off.
    raise_error_on_ignored_optimization: bool = (
        os.environ.get("AOTINDUCTOR_RAISE_ERROR_ON_IGNORED_OPTIMIZATION", "1") == "1"
    )

    # Dump an aoti minifier if the program errors (DUMP_AOTI_MINIFIER=1).
    dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1"

    # Repro/minifier level for AOTI compilation failures (AOTINDUCTOR_REPRO_LEVEL,
    # default 2):
    # 1: Dumps the original graph out to repro.py if compilation fails
    # 2: Dumps a minifier_launcher.py if aoti fails.
    # 3: Always dumps a minifier_launcher.py. Good for segfaults.
    # 4: Dumps a minifier_launcher.py if the accuracy fails.
    # int() accepts both the integer default and a string from the environment.
    repro_level: int = int(os.environ.get("AOTINDUCTOR_REPRO_LEVEL", 2))

    # Dictionary of presets that can be passed in
    presets: dict[str, Any] = {}

    # Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
    # should be run with this flag both on and off to make sure we have coverage.
    allow_stack_allocation: bool = False

    # Enables an alternate DSO interface (the "minimal ArrayRef interface") intended
    # to maximize performance for use cases that it can accommodate at the expense of
    # generality. In brief:
    # - inputs and outputs are ArrayRefTensor<T> (note that strides are required, but the
    #   tensor must be contiguous)
    # - constant handling is unchanged because it is not a per-inference-iteration bottleneck
    #
    # When the DSO is generated in this mode, the usual interface will also be supported,
    # but performance for that interface may be degraded.
    use_minimal_arrayref_interface: bool = False

    # Set to True if we want to use PyTorch's CUDACachingAllocator for weight management
    # (AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR=1).
    weight_use_caching_allocator: bool = (
        os.environ.get("AOT_INDUCTOR_WEIGHT_USE_CACHING_ALLOCATOR", "0") == "1"
    )

    # Experimental. Flag to control whether to include weights in the .so
    package_constants_in_so: bool = True

    # Experimental. Flag to control whether to package weights separately on disk
    package_constants_on_disk: bool = False

    # Experimental. Controls automatic precompiling of common AOTI include files.
    # On by default everywhere except fbcode.
    precompile_headers: bool = not is_fbcode()

    # Embed generated kernel binary files into model.so
    embed_kernel_binary: Optional[bool] = None

    # Generate kernel files that support multiple archs.
    # For CUDA, this means generating fatbin files for kernels, and the fatbin files
    # contain PTX and SASS for the current architecture.
    emit_multi_arch_kernel: Optional[bool] = None

    # If not None, the generated files will use this name in their file stem.
    # If None, we will use a hash to name files.
    #
    # If package_cpp_only, this name is also used for the target name in CMakelists.txt
    # The default target name is "aoti_model"
    #
    # If compile_standalone, the aoti model class name is f"AOTInductorModel{name}"
    #
    # This name can only contain letters, numbers, and underscores.
    model_name_for_generated_files: Optional[str] = None

    # Custom ops that have implemented C shim wrappers, defined as an op to C shim declaration dict
    custom_ops_to_c_shims: dict[torch._ops.OpOverload, list[str]] = {}

    # Custom op libs that have implemented C shim wrappers
    custom_op_libs: Optional[list[str]] = None

    # See model_name_for_generated_files above for how this affects the
    # generated model class name.
    compile_standalone: bool = False

    # Whether to enable link-time optimization (AOT_INDUCTOR_ENABLE_LTO=1).
    enable_lto = os.environ.get("AOT_INDUCTOR_ENABLE_LTO", "0") == "1"
class cuda:
    """Settings for the CUDA backend; today this consists of CUTLASS.

    Defaults taken from environment variables are read once, when this module
    is imported.
    """

    # CUDA arch to use for CUDA template kernel compilation.
    # e.g. "70", "75", "80", "90", etc.
    # When arch is None, Inductor uses torch.cuda.get_device_capability(0).
    arch: Optional[str] = None

    # CUDA version to use for CUDA template kernel compilation.
    # e.g. "11.4", "12.1", etc.
    # When version is None, Inductor uses torch.version.cuda.
    version: Optional[str] = None

    # Optimization level for the host compiler.
    compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1"

    # Whether to enable device LTO (link-time-optimization).
    enable_cuda_lto = False

    # Whether to enable ptxas info output, judging by the flag name.
    # NOTE(review): the previous comment here said "keep intermediate files during
    # compilation", which does not match the name — confirm at the use site.
    enable_ptxas_info = False

    # Whether to enable debug info, e.g. line number, cutlass debug info.
    enable_debug_info = False

    # Whether to use fast math.
    use_fast_math = False

    # Path to the CUTLASS repo root directory (override with TORCHINDUCTOR_CUTLASS_DIR).
    # The default path only works under a PyTorch local development environment.
    # The result is normalized with os.path.realpath.
    cutlass_dir = os.path.realpath(
        os.environ.get(
            "TORCHINDUCTOR_CUTLASS_DIR",
            os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/"),
        )
    )

    # Configures the maximum number of CUTLASS configs to profile in max_autotune.
    # By default it's None, so that all CUTLASS configs are tuned.
    # This is mainly used to reduce test time in CI.
    cutlass_max_profiling_configs: Optional[int] = None

    # The L2 swizzle values to consider when profiling CUTLASS configs in max_autotune.
    cutlass_max_profiling_swizzle_options: list[int] = [1, 2, 4, 8]

    # Whether to use CUTLASS EVT for epilogue fusion (CUTLASS_EPILOGUE_FUSION=1).
    cutlass_epilogue_fusion_enabled = (
        os.environ.get("CUTLASS_EPILOGUE_FUSION", "0") == "1"
    )

    # Whether to only use TMA-compatible kernels in CUTLASS
    cutlass_tma_only = False

    # Path to CUDA NVCC.
    # NVCC search order:
    # 1) cuda_cxx set in this config
    # 2) CUDACXX environment variable
    # 3) CUDA_HOME environment variable
    # 4) default system search PATH.
    cuda_cxx: Optional[str] = None

    # Minimum value of M*N*K to consider the CUTLASS backend for GEMM ops.
    cutlass_backend_min_gemm_size: int = 1

    # Enable generation of an inline standalone runner in CUDA CPP generated code,
    # which allows compiling the generated code into a standalone executable
    # (INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE=1).
    generate_test_runner: bool = (
        os.environ.get("INDUCTOR_CUDA_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1"
    )

    # Keep only Cutlass op configs which contain this regular expression pattern
    # (TORCHINDUCTOR_CUTLASS_ALLOWLIST; None when unset).
    # Set this to "warpspecialized_cooperative_epi_tma" to enable only SM90 TMA Cutlass Kernels for large GEMMs
    cutlass_op_allowlist_regex: Optional[str] = os.environ.get(
        "TORCHINDUCTOR_CUTLASS_ALLOWLIST"
    )

    # Note: Names of Cutlass ops names can be obtained by calling
    # op.configuration_name() on a Cutlass op instance, for example those
    # returned from cutlass_utils.gen_ops() or the op argument passed to
    # CUTLASSGemmTemplate.render(...)

    # Filter out Cutlass configs which contain this regular expression pattern
    # (TORCHINDUCTOR_CUTLASS_DENYLIST; None when unset).
    # Set this to "pingpong" to avoid numerical issues
    # caused by the op ordering of the "pingpong" memory access
    # pattern used by some Cutlass Kernels.
    cutlass_op_denylist_regex: Optional[str] = os.environ.get(
        "TORCHINDUCTOR_CUTLASS_DENYLIST"
    )

    # Non-negative integer (stored as a string) which determines how many kernels
    # are instantiated. 0 = 0000 generates the fewest kernels, 9999 generates all
    # possible combinations.
    # increasing first digit reduces schedule / mixed type pruning,
    # increasing second digit generates more cluster sizes,
    # increasing third digit generates more MMA multipliers,
    # increasing fourth digit generates more instruction shapes.
    cutlass_instantiation_level: str = os.environ.get(
        "TORCHINDUCTOR_CUTLASS_INSTANTIATION_LEVEL", "0"
    )

    # Experimental. Only for H100 for now. Flag to control whether to use presets.
    # Format looks like: "0,1,3" for using presets 0, 1, and 3. Presets can be
    # controlled by some cutlass instantiation level flags (e.g. 0, 1111, 2222, ...)
    cutlass_presets: Optional[str] = os.environ.get("TORCHINDUCTOR_CUTLASS_PRESETS")

    # Use the compile command to create the kernel .cu and .so names.
    cutlass_hash_with_compile_cmd: bool = (
        os.environ.get("TORCHINDUCTOR_CUTLASS_HASH_WITH_COMPILE_CMD", "0") == "1"
    )

    # Experimental. Prescreen top x configs before tuning on swizzle.
    # On by default; set TORCHINDUCTOR_CUTLASS_PRESCREENING to anything but "1" to disable.
    cutlass_prescreening: bool = (
        os.environ.get("TORCHINDUCTOR_CUTLASS_PRESCREENING", "1") == "1"
    )

    # Specify which operations should use the CUTLASS backend.
    # Comma-separated list like "mm,addmm,bmm", "all" for all operations, and "" for none.
    # Acceptable operations: mm, int_mm, addmm, sparse_semi_structured_mm, bmm, scaled_mm
    cutlass_enabled_ops: str = os.environ.get(
        "TORCHINDUCTOR_CUTLASS_ENABLED_OPS", "all"
    )

    # Whether to consult the binary remote cache
    use_binary_remote_cache: bool = True

    # Whether to upload compiled kernels to remote cache
    upload_to_binary_remote_cache: bool = False

    # Whether to force upload if the key already exists
    # Use this to overwrite and handle cache pollution
    binary_remote_cache_force_write: bool = False

    # Enable caching codegen of cuda templates.
    enable_caching_codegen: bool = True
class rocm:
    """Settings for the ROCm backend, including the Composable Kernel (CK) backend."""

    # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"].
    # If empty, the `native` arch is used
    arch: list[str] = []

    # Architectures on which the CK backend is enabled (the list below).
    # The original wording was "CDNA2 and CDNA3 only (for now)".
    # NOTE(review): the list also contains gfx950; confirm whether that wording is stale.
    # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
    ck_supported_arch: list[Literal["gfx90a", "gfx942", "gfx950"]] = [
        "gfx90a",
        "gfx942",
        "gfx950",
    ]

    # Optimization level, used to balance compilation speed and runtime performance.
    # The type will not necessarily be comprehensive and won't be enforced at runtime.
    compile_opt_level: Literal[
        "-O0", "-O1", "-O2", "-O3", "-Os", "-Oz", "-Omin", "-Ofast", "-Omax"
    ] = "-O2"

    # Flag to keep debug information in compiled objects
    is_debug = False

    # Flag to keep intermediate files (assembly listings, preprocessed sources, etc.)
    save_temps = False

    # Flag to add `-ffast-math` to compile flags
    use_fast_math = True

    # Flag to add `-fgpu-flush-denormals-to-zero` to compile flags
    flush_denormals = True

    # Flag to print register and LDS usage during compilation
    print_kernel_resource_usage = False

    # Path to ROCm installation, if None, use env variable ROCM_HOME.
    # In fbcode see triton/fb/TARGETS for how ROCM_HOME gets set.
    rocm_home: Optional[str] = None

    # Path to Composable Kernel library (TORCHINDUCTOR_CK_DIR; None when unset).
    # Install with `pip install git+https://github.com/rocm/composable_kernel@develop`.
    ck_dir = os.environ.get("TORCHINDUCTOR_CK_DIR")

    # Generate standalone executables for instances generated with the CK backend
    # (INDUCTOR_CK_BACKEND_GENERATE_TEST_RUNNER_CODE=1).
    generate_test_runner: bool = (
        os.environ.get("INDUCTOR_CK_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1"
    )

    # Deprecated, use CK and/or CK-tile specific settings
    n_max_profiling_configs: Optional[int] = None

    # Number of op instance choices to trade off between runtime perf and compilation time
    # For CK Kernels
    ck_max_profiling_configs: Optional[int] = None

    # Number of op instance choices to trade off between runtime perf and compilation time
    # For CK-Tile Kernels
    ck_tile_max_profiling_configs: Optional[int] = None

    # Flag to use a short list of CK instances which perform well across a variety of shapes.
    # Currently RCR and F16 only
    use_preselected_instances: bool = False

    # List to determine kBatch parameters to sweep over. By default, we calculate one in splitK
    # scenarios, and run on kBatch=1 in non-splitK scenarios
    kBatch_sweep: Optional[list[int]] = None

    # The threshold at which we trigger a splitK config - K // max(M,N) has to be greater than this
    split_k_threshold: int = 16

    # The threshold at which we trigger a contiguous subgraph transformation
    contiguous_threshold: int = 16
  1411. # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
  1412. cpu_backend: Literal["cpp", "triton", "halide"] = "cpp"
  1413. # Backend to use for CUDA codegen either "triton" or "halide" (experimental)
  1414. cuda_backend: Literal["triton", "halide"] = "triton"
  1415. class halide:
  1416. # Base halide target to use for CPU devices
  1417. cpu_target = "host"
  1418. # Base halide target to use for CUDA devices
  1419. gpu_target = "host-cuda"
  1420. # Halide autoscheduler to use, choices are:
  1421. # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
  1422. scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
  1423. "Anderson2021"
  1424. )
  1425. scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
  1426. "Adams2019"
  1427. )
  1428. # Controls `no_asserts` flag passed to Halide target (warning: can false positive)
  1429. asserts = False
  1430. # Controls `debug` flag passed to Halide target
  1431. debug = False
  1432. # Enable (or fallback on) scan kernels such as cumsum
  1433. # Halide autoschedulers struggle with these kernels
  1434. scan_kernels = False
  1435. # create a directory containing lots of debug information
  1436. class trace:
  1437. # master switch for all debugging flags below
  1438. enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
  1439. # save real tensors
  1440. save_real_tensors = os.environ.get("TORCH_COMPILE_DEBUG_SAVE_REAL", "0") == "1"
  1441. # Save debug information to a temporary directory
  1442. # If not specified, a temp directory will be created by system
  1443. debug_dir: Optional[str] = None
  1444. # Save python logger call >=logging.DEBUG
  1445. debug_log = False
  1446. # Save python logger call >=logging.INFO
  1447. info_log = False
  1448. # Save input FX graph (post decomps, pre optimization)
  1449. fx_graph = True
  1450. # Save FX graph after transformations
  1451. fx_graph_transformed = True
  1452. # Save TorchInductor IR before fusion pass
  1453. ir_pre_fusion = True
  1454. # Save TorchInductor IR after fusion pass
  1455. ir_post_fusion = True
  1456. # Copy generated code to trace dir
  1457. output_code = True
  1458. # SVG figure showing post-fusion graph
  1459. graph_diagram = os.environ.get("INDUCTOR_POST_FUSION_SVG", "0") == "1"
  1460. # SVG figure showing fx with fusion
  1461. draw_orig_fx_graph = os.environ.get("INDUCTOR_ORIG_FX_SVG", "0") == "1"
  1462. # We draw our fx graphs with the "record" shape attribute by default.
  1463. # Sometimes, when the graph is very complex, we may hit dot errors like below:
  1464. # "flat edge between adjacent nodes one of which has a record shape -
  1465. # replace records with HTML-like labels"
  1466. # and thus fail to generate a graph. So, let's give the user an option
  1467. # to specify the shape attribute for the dot graph. For example, passing
  1468. # INDUCTOR_DOT_GRAPH_SHAPE_SVG = "none" would let us generate HTML-like labels
  1469. # to workaround the above failure.
  1470. dot_graph_shape = os.environ.get("INDUCTOR_DOT_GRAPH_SHAPE_SVG", None)
  1471. # If not None, this is the URL that saves the SVG files of the input/output
  1472. # graph of each pass that changed the graph
  1473. # The nodes that are being transformed in each pass will be colored in yellow
  1474. # URL only supports local directory for now
  1475. log_url_for_graph_xform = os.environ.get("INDUCTOR_LOG_URL_FOR_GRAPH_XFORM", None)
  1476. # Store cProfile (see snakeviz to view)
  1477. compile_profile = False
  1478. # Upload the .tar.gz file
  1479. # Needs to be overridden based on specific environment needs
  1480. upload_tar: Optional[Callable[[str], None]] = None
  1481. log_autotuning_results = os.environ.get("LOG_AUTOTUNE_RESULTS", "0") == "1"
  1482. # Save mapping info from inductor generated kernel to post_grad/pre_grad fx nodes
  1483. # Levels:
  1484. # 0 - disabled (default)
  1485. # 1 - normal
  1486. # 2 - basic
  1487. # Backward compatibility:
  1488. # If TORCH_COMPILE_DEBUG=1, level is set to at least 1.
  1489. # If INDUCTOR_PROVENANCE is set, use its integer value.
  1490. provenance_tracking_level: int = int(
  1491. os.environ.get(
  1492. "INDUCTOR_PROVENANCE", os.environ.get("TORCH_COMPILE_DEBUG", "0")
  1493. )
  1494. )
  1495. _save_config_ignore: list[str] = [
  1496. # workaround: "Can't pickle <function ...>"
  1497. "trace.upload_tar",
  1498. "joint_custom_pre_pass",
  1499. "joint_custom_post_pass",
  1500. "pre_grad_custom_pass",
  1501. "aot_inductor.repro_level",
  1502. "aot_inductor.dump_aoti_minifier",
  1503. "post_grad_custom_pre_pass",
  1504. "post_grad_custom_post_pass",
  1505. "_fuse_ddp_communication_passes",
  1506. "_pre_fusion_custom_pass",
  1507. ]
  1508. _cache_config_ignore_prefix: list[str] = [
  1509. # trace functions are not relevant to config caching
  1510. "trace",
  1511. # uses absolute path
  1512. "cuda.cutlass_dir",
  1513. # not relevant
  1514. "worker_start_method",
  1515. "compile_threads",
  1516. # see CustomGraphPass; these are handled specially
  1517. "post_grad_custom_post_pass",
  1518. "post_grad_custom_pre_pass",
  1519. "joint_custom_pre_pass",
  1520. "joint_custom_post_pass",
  1521. "_fuse_ddp_communication_passes",
  1522. "_pre_fusion_custom_pass",
  1523. # tests assume that changes here don't invalidate cache
  1524. "always_complex_memory_overlap_TESTING_ONLY",
  1525. # cache related options are not relevant to cache results
  1526. "fx_graph_cache",
  1527. "fx_graph_remote_cache",
  1528. "autotune_local_cache",
  1529. "autotune_remote_cache",
  1530. ]
  1531. # External callable for matmul tuning candidates
  1532. external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
  1533. class test_configs:
  1534. force_extern_kernel_in_multi_template: bool = False
  1535. max_mm_configs: Optional[int] = None
  1536. runtime_triton_dtype_assert = False
  1537. static_cpp_dtype_assert = False
  1538. # regex to control the set of considered autotuning
  1539. # choices (aka configs) by name and / or description
  1540. autotune_choice_name_regex: Optional[str] = None
  1541. autotune_choice_desc_regex: Optional[str] = None
  1542. graphsafe_rng_func_ignores_fallback_random = False
  1543. track_memory_lifecycle: Optional[Literal["assert", "log"]] = None
  1544. # If set to True, AOTI-generated CMakelists.txt will still use libtorch
  1545. # for unit testing
  1546. use_libtorch = False
# Seen only by static type checkers, never executed at runtime. Presumably it
# supplies typed declarations for the helpers added below (patch, save_config,
# ...) — confirm in torch/utils/_config_typing.
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403

# adds patch, save_config, etc
# This runs at import time over this module object. Anything defined after this
# call would not exist yet when it runs, so it must stay the last statement.
install_config_module(sys.modules[__name__])