graph.py 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461
  1. from __future__ import annotations
  2. import contextlib
  3. import functools
  4. import itertools
  5. import logging
  6. import operator
  7. import os
  8. import re
  9. import sys
  10. import time
  11. from collections import defaultdict
  12. from contextlib import contextmanager
  13. from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
  14. import sympy
  15. from sympy import Expr
  16. import torch
  17. import torch._logging
  18. import torch.fx
  19. from torch import device, Tensor
  20. from torch._decomp import get_decompositions
  21. from torch._dynamo.utils import defake, dynamo_timed
  22. from torch._library.fake_class_registry import FakeScriptObject
  23. from torch._library.utils import get_layout_constraint_tag
  24. from torch._logging import LazyString, trace_structured
  25. from torch._prims_common import (
  26. compute_required_storage_length,
  27. make_channels_last_strides_for,
  28. )
  29. from torch._subclasses.fake_tensor import FakeTensor
  30. from torch._utils_internal import full_aoti_runtime_assert
  31. from torch.fx.experimental._backward_state import BackwardState
  32. from torch.fx.experimental.sym_node import magic_methods, method_to_operator
  33. from torch.fx.experimental.symbolic_shapes import (
  34. _get_placeholder_expr,
  35. free_unbacked_symbols,
  36. has_free_symbols,
  37. resolve_unbacked_bindings,
  38. RuntimeAssert,
  39. ShapeEnv,
  40. SympyBoolean,
  41. SymTypes,
  42. )
  43. from torch.fx.node import Node
  44. from torch.utils._mode_utils import no_dispatch
  45. from torch.utils._ordered_set import OrderedSet
  46. from torch.utils._sympy.numbers import int_oo
  47. from . import config, ir, metrics
  48. from .codegen.common import (
  49. BackendFeature,
  50. DeviceOpOverrides,
  51. FileBackedGraphModule,
  52. get_backend_features,
  53. get_device_op_overrides,
  54. get_wrapper_codegen_for_device,
  55. init_backend_registration,
  56. WorkspaceArg,
  57. )
  58. from .exc import (
  59. CppWrapperCodegenError,
  60. LoweringException,
  61. MissingOperatorWithDecomp,
  62. MissingOperatorWithoutDecomp,
  63. )
  64. from .fx_utils import count_flops_fx
  65. from .ir import (
  66. Constant,
  67. DonatedBuffer,
  68. FixedLayout,
  69. get_device_type,
  70. GraphPartitionSignature,
  71. InputBuffer,
  72. Pointwise,
  73. Reduction,
  74. ShapeAsConstantBuffer,
  75. StorageBox,
  76. TensorBox,
  77. TorchBindObject,
  78. )
  79. from .lowering import (
  80. constrain_to_fake_tensors,
  81. constrain_to_fx_strides,
  82. FALLBACK_ALLOW_LIST,
  83. fallback_handler,
  84. fallback_node_due_to_unsupported_type,
  85. lowerings,
  86. make_fallback,
  87. maybe_layout_constraints,
  88. needs_realized_inputs,
  89. require_contiguous,
  90. tag_to_layout_constraint,
  91. unsupported_output_tensor,
  92. )
  93. from .runtime import autotune_cache
  94. from .runtime.autotune_cache import AutotuneCacheBundler
  95. from .sizevars import SizeVarAllocator
  96. from .utils import (
  97. convert_shape_to_inductor,
  98. gather_origins,
  99. get_cloned_parameter_buffer_name,
  100. get_donated_idxs,
  101. get_sympy_Expr_dtype,
  102. GraphPartitionMap,
  103. is_same_tensor,
  104. maybe_get_suppress_shape_guards_ctx,
  105. normalize_name,
  106. should_assume_input_aligned,
  107. SUPPORTED_MKLDNN_DEVICES,
  108. ValueWithLineMap,
  109. )
  110. from .virtualized import NullHandler, V
  111. if TYPE_CHECKING:
  112. from collections.abc import Iterable, Iterator, Sequence
  113. from types import ModuleType
  114. from torch._higher_order_ops.effects import _EffectType
  115. from torch.fx import GraphModule
  116. from torch.fx.graph import Graph
  117. from .codegen.wrapper import PythonWrapperCodegen
  118. from .dependencies import Dep
  119. from .scheduler import BaseSchedulerNode
  120. CompiledModule = Union[ModuleType, FileBackedGraphModule]
  121. from torch._inductor.codecache import output_code_log
# Module logger for this file.
log = logging.getLogger(__name__)
# Logger for the "perf_hints" logging artifact of this module.
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
# Shorthand for the aten operator namespace (used by the op sets below).
aten = torch.ops.aten
# Module-level counter; each GraphLowering takes its post_grad_graph_id from it.
_post_grad_graph_counter = itertools.count()
# In fbcode builds the real `log_module_code` is imported; otherwise a no-op
# stub with the same name is defined so it can be called unconditionally.
if config.is_fbcode():
    from torch._inductor.fb.utils import log_module_code
else:

    def log_module_code(*args: Any, **kwargs: Any) -> None:
        # No-op fallback outside fbcode: accepts and ignores all arguments.
        pass
  131. def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
  132. assert isinstance(
  133. constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
  134. ), (
  135. "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
  136. )
  137. if isinstance(constant_buffer, sympy.core.numbers.Integer):
  138. return torch.int64
  139. if isinstance(constant_buffer, sympy.Expr):
  140. return get_sympy_Expr_dtype(constant_buffer)
  141. if constant_buffer.is_integer:
  142. return torch.int64
  143. elif constant_buffer.is_float:
  144. return torch.float32
  145. else:
  146. return None
  147. def is_magic_method(op: Any) -> bool:
  148. magic_ops = OrderedSet(method_to_operator(m) for m in magic_methods)
  149. return op in magic_ops
  150. def getattr_recursive(
  151. obj: GraphModule, target: str
  152. ) -> Union[Tensor, torch._C.ScriptObject, GraphModule]:
  153. target_atoms = target.split(".")
  154. attr_itr = obj
  155. for i, atom in enumerate(target_atoms):
  156. if not hasattr(attr_itr, atom):
  157. raise RuntimeError(
  158. f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
  159. )
  160. attr_itr = getattr(attr_itr, atom)
  161. return attr_itr
  162. def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]:
  163. ret: dict[Node, tuple[int, ...]] = {}
  164. output_node = g.find_nodes(op="output")[0]
  165. if "user_visible_output_idxs" not in output_node.meta:
  166. return ret
  167. if not isinstance(output_node.args[0], torch.fx.Node):
  168. output_node_args = output_node.args[0]
  169. else:
  170. output_node_args = output_node.args
  171. for idx, node in enumerate(output_node_args):
  172. if idx in output_node.meta["user_visible_output_idxs"]:
  173. ret[node] = output_node.meta["original_output_strides"][idx]
  174. return ret
def mark_nodes_dislike_padding(
    g: Graph, user_visible_output_strides: dict[Node, tuple[int, ...]]
) -> None:
    """
    Set ``node.meta["dislike_padding"] = True`` on nodes whose tensors should
    not be padded.

    Nodes like convolution/convolution_backward want their inputs to be dense.
    If we pad their inputs, we incur extra calls to copy kernels. On the other
    hand, padding usually helps reductions.

    This pass finds nodes that dislike padding: nodes that can be reached
    from a convolution/convolution_backward in the backward direction without
    going through a reduction. Nodes are visited in reverse order, so an input
    marked while processing one of its users is visited later and in turn
    propagates the mark to its own inputs.

    Does nothing unless ``config.comprehensive_padding`` is enabled. When
    ``config.pad_outputs`` is False, nodes that are keys of
    ``user_visible_output_strides`` are marked as well (subject to the
    early ``continue`` noted in the loop).
    """
    if not config.comprehensive_padding:
        return
    # Ops that must not receive padded inputs.
    ops_dislike_padding = OrderedSet(
        [
            aten.convolution,
            aten.convolution_backward,
            aten._scaled_mm,
        ]
    )
    # Reductions: propagation of the dislike mark stops at these ops.
    # what's a better way to collect the reduction ops?
    ops_like_padding = OrderedSet(
        [
            aten.var_mean,
            aten.sum,
            aten.mean,
            aten.prod,
            aten.any,
            aten.amin,
            aten.amax,
            aten.min,
            aten.max,
            aten.argmin,
            aten.argmax,
            aten.scatter_reduce,
        ]
    )

    def _get_overload_packet(
        node: torch.fx.Node,
    ) -> Optional[torch._ops.OpOverloadPacket]:
        # Return the OpOverloadPacket of a call_function node whose target is an
        # OpOverload; None for every other node (placeholders, getitem, ...).
        return (
            node.target._overloadpacket
            if node.op == "call_function"
            # hasattr on OpOverloadPacket is slow, do isinstance first
            and isinstance(node.target, torch._ops.OpOverload)
            and hasattr(node.target, "_overloadpacket")
            else None
        )

    for cur in reversed(g.nodes):
        # Mutating user Triton kernels always dislike padding; their inputs are
        # not propagated to (note the `continue`).
        if isinstance(
            cur.target,
            torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
        ):
            cur.meta["dislike_padding"] = True
            continue

        # Ops tagged as needing exact strides also dislike padding, again
        # without propagating to their inputs.
        if (
            isinstance(cur.target, torch._ops.OpOverload)
            and get_layout_constraint_tag(cur.target)
            == torch._C.Tag.needs_exact_strides
        ):
            cur.meta["dislike_padding"] = True
            continue

        op = _get_overload_packet(cur)
        # NOTE(review): nodes without an OpOverload target (e.g. operator.getitem)
        # stop here and so never reach the user-visible-output check at the end
        # of the loop body; confirm that is intended.
        if not op:
            continue
        if op in ops_dislike_padding:
            cur.meta["dislike_padding"] = True

        if cur.meta.get("dislike_padding", False):
            # propagate: mark direct inputs that are aten ops other than
            # reductions; the reverse traversal carries the mark further back.
            for prior in cur.all_input_nodes:
                prior_op = _get_overload_packet(prior)
                if not prior_op:
                    continue
                if prior_op not in ops_like_padding:
                    prior.meta["dislike_padding"] = True
        # Mark user-visible outputs only after the propagation step above, so
        # that being an output does not by itself cause this node's inputs to
        # be marked.
        if not config.pad_outputs and cur in user_visible_output_strides:
            cur.meta["dislike_padding"] = True
  252. class GraphLowering(torch.fx.Interpreter):
  253. graph_outputs: list[ir.IRNode]
    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Optional[Sequence[object]] = None,
        shape_env: Optional[ShapeEnv] = None,
        graph_id: Optional[int] = None,
        cpp_wrapper: bool = False,
        aot_mode: bool = False,
        layout_opt: Optional[bool] = None,
        extern_node_serializer: Optional[
            Callable[[list[ir.ExternKernelNode]], Any]
        ] = None,
        is_inference: bool = False,
        is_backward: bool = False,
        is_const_graph: bool = False,
        const_output_index: Optional[dict[str, int]] = None,
        const_wrapper_code: Optional[str] = None,
        const_kernel_code: Optional[str] = None,
        const_module: Optional[GraphLowering] = None,
        name: Optional[str] = None,
        inputs_to_check: Optional[Sequence[int]] = None,
        fx_wrapper: bool = False,
    ) -> None:
        """Initialize lowering state for the FX graph module ``gm``.

        Args:
            gm: The graph module to interpret and lower.
            example_inputs: Stored as ``self.example_inputs``.
            shape_env: ShapeEnv to reuse. When None, a fresh ``ShapeEnv`` is
                created and ``self.reuse_shape_env`` is False.
            layout_opt: Whether to enable layout optimization. When None, it is
                decided by ``decide_layout_opt(gm, is_inference=is_inference)``.
            extern_node_serializer: Only used when ``config.is_fbcode()`` is
                true; otherwise ``extern_node_json_serializer`` is used.
            const_output_index: Its keys seed ``self.folded_constants``.
            const_module: When given, its ``device_types``, ``device_idxs``,
                ``constants``, ``named_buffers``, ``named_parameters`` and
                ``allocated_constant_name`` objects are shared by reference
                (not copied) with this graph.
            Remaining flags and codes are stored on the instance unchanged.

        Side effects: runs ``mark_nodes_dislike_padding`` on ``gm.graph``
        (writes ``node.meta["dislike_padding"]``) and calls
        ``init_backend_registration()``.
        """
        super().__init__(gm)
        self.example_inputs = example_inputs
        self.layout_opt = (
            layout_opt
            if layout_opt is not None
            else self.decide_layout_opt(gm, is_inference=is_inference)
        )
        self.num_channels_last_conv = 0
        self.is_inference = is_inference
        self.is_backward = is_backward
        self.is_const_graph = is_const_graph
        self.const_wrapper_code = const_wrapper_code
        self.const_kernel_code = const_kernel_code
        self.const_module = const_module
        self.inputs_to_check = inputs_to_check
        self.extra_traceback = False  # we do our own error wrapping
        if shape_env is None:
            shape_env = ShapeEnv()
            self.reuse_shape_env = False
        else:
            self.reuse_shape_env = True
        self._shape_env = shape_env
        # We're going to mutate ras_by_symbol as we finish generating them, so
        # take a shallow copy rather than aliasing the ShapeEnv's dict.
        self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = (
            shape_env.deferred_runtime_asserts.copy()
        )
        self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
        self.sizevars = SizeVarAllocator(shape_env)
        self.graph_input_names: list[str] = []
        self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
        self.graph_inputs_original: dict[str, InputBuffer] = {}
        self.partition_maps: Optional[list[GraphPartitionMap]] = None
        self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet()
        # Shared by reference with const_module when one is given.
        self.device_types: OrderedSet[str] = (
            const_module.device_types if const_module else OrderedSet()
        )
        self.device_idxs: OrderedSet[int] = (
            const_module.device_idxs if const_module else OrderedSet()
        )
        self.device_type = "cpu"
        # Inplace padding may require Inductor to allocate a slightly larger
        # tensor for padding (consulted by get_allocation_size).
        self.buffer_to_padded_size: dict[str, list[int]] = {}
        self.buffers: list[ir.Buffer] = []
        self.operations: list[ir.Operation] = []
        self.const_output_index: dict[str, int] = (
            const_output_index if const_output_index else {}
        )
        self.folded_constants: OrderedSet[str] = (
            OrderedSet(const_output_index.keys())
            if const_output_index
            else OrderedSet()
        )
        # The next three dicts are shared by reference with const_module when
        # one is given.
        self.constants: dict[str, torch.Tensor] = (
            const_module.constants if const_module else {}
        )
        self.named_buffers: dict[str, torch.Tensor] = (
            const_module.named_buffers if const_module else {}
        )
        self.named_parameters: dict[str, torch.Tensor] = (
            const_module.named_parameters if const_module else {}
        )
        self.torchbind_constants: dict[
            str, Union[torch._C.ScriptObject, FakeScriptObject]
        ] = {}
        self.seen_subgraphs: dict[str, ir.Subgraph] = {}
        self.constant_reprs: dict[str, str] = {}
        self.removed_operations: OrderedSet[str] = OrderedSet()
        self.removed_buffers: OrderedSet[str] = OrderedSet()
        self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
        self.mutated_buffers: OrderedSet[str] = OrderedSet()
        self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
        self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
        self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
        self.wrapper_code: PythonWrapperCodegen = None  # type: ignore[assignment]
        from torch._inductor.extern_node_serializer import extern_node_json_serializer

        # A caller-supplied serializer is honored only in fbcode builds.
        self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = (
            extern_node_serializer
            if config.is_fbcode() and extern_node_serializer
            else extern_node_json_serializer
        )
        self.current_node: torch.fx.Node = None  # type: ignore[assignment]
        self.lists: dict[str, list[str]] = {}
        self.mutated_inputs: OrderedSet[str] = OrderedSet()
        self.mutated_input_idxs: list[int] = []
        self.name_to_buffer: dict[str, ir.Buffer] = {}
        self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list)
        self.name_to_op: dict[str, ir.Operation] = {}
        self.creation_time = time.time()
        self.name = name  # type: ignore[assignment]
        self.cpp_wrapper = cpp_wrapper
        self.fx_wrapper = fx_wrapper
        # Record the multi_kernel choice for cpp_wrapper so the second pass knows
        # which sub-kernel was picked. cpp_wrapper is copied to a separate
        # variable because the cpp_wrapper flag is set to False for the first
        # pass of codegen.
        self.record_multi_kernel_choice = cpp_wrapper
        self.multi_kernel_to_choice: dict[str, str] = {}
        self.aot_mode = aot_mode
        self.graph_id = graph_id
        self.post_grad_graph_id = next(_post_grad_graph_counter)
        self.scheduler: torch._inductor.scheduler.Scheduler = None  # type: ignore[assignment]
        # Record intermediate results used as inputs to user-defined Triton
        # kernels. This will be used if autotuning is done in one pass.
        self.autotuning_inputs: Optional[list[torch.Tensor]] = None
        self.autotuning_mapping: Optional[dict[str, dict[str, int]]] = None
        self.autotuning_grids: Optional[dict[str, Any]] = None
        # current_device is set only during codegen of a device-specific kernel;
        # a graph can have many devices.
        self.current_device: Optional[torch.device] = None
        self.nodes_prefer_channels_last = (
            self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet()
        )
        self._warned_fallback = OrderedSet(["aten.convolution_backward"])
        self.user_visible_output_strides = get_user_visible_output_strides(gm.graph)
        # Annotates gm.graph's node.meta in place.
        mark_nodes_dislike_padding(gm.graph, self.user_visible_output_strides)
        self.cache_key: str = ""  # This is the cache key for the compiled artifact
        self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
        self.cache_linemap: list[
            tuple[int, str]
        ] = []  # This is the linemap used by the profiler to mark custom compiled kernels getting run
        # Used if lowering encounters cases where cudagraphs are not supported
        self.disable_cudagraphs_reason: Optional[str] = None
        # only keeping one node per device for stack trace purposes
        self.device_node_mapping: dict[torch.device, torch.fx.Node] = {}
        self.orig_gm: torch.fx.GraphModule = gm.__copy__()
        # NOTE(review): when const_module is given, self.named_buffers and
        # self.named_parameters are const_module's own dicts (assigned above),
        # so these loops also insert this graph's entries into const_module;
        # confirm this sharing is intended.
        for k, v in self.orig_gm.named_buffers():
            self.named_buffers[k] = v
        for k, v in self.orig_gm.named_parameters():
            self.named_parameters[k] = v
        self.dynamo_flat_name_to_original_fqn = self.module.meta.get(  # type: ignore[operator, union-attr]
            "dynamo_flat_name_to_original_fqn", {}
        )
        self.allocated_constant_name: dict[str, str] = (
            const_module.allocated_constant_name if const_module is not None else {}
        )
        init_backend_registration()
        # Per-instance memoization of backend feature lookups (see has_feature).
        self.get_backend_features = functools.lru_cache(None)(get_backend_features)
        self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
        # Track the buffers that we know are unaligned.
        # This can either be a graph input or the output of fallback
        # kernels.
        self.unaligned_buffers: OrderedSet[str] = OrderedSet()
        self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet()
        self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet()
        # more aggressive prologue fusion
        self.invoke_quant_ops: OrderedSet[str] = OrderedSet()
        # Below field is related to printing debug intermediate tensor values info for debugging
        self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet()
        # state used for Kernel.workspace
        self.workspace_id = itertools.count()
        # track the current placeholder index that we are processing
        self.placeholder_idx = -1
        self.bw_donated_idxs = get_donated_idxs()
        # Cache for dep size hints to avoid expensive recomputation
        self.dep_size_hint_cache: dict[Dep, int] = {}
  432. def freeze_runtime_asserts(self) -> None:
  433. self._shape_env.freeze_runtime_asserts()
  434. def symbolic_sizes_strides(
  435. self, ex: torch.Tensor
  436. ) -> tuple[Sequence[Union[int, Expr]], Sequence[Union[int, Expr]]]:
  437. """
  438. Support dynamic shapes and dynamic strides by assigning variables
  439. to each dimension. We duck-shape tensors, so if two tensors
  440. have the same size they get assigned the same symbolic variable.
  441. """
  442. if self.reuse_shape_env:
  443. return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
  444. ex.stride()
  445. )
  446. else:
  447. from torch._dynamo.source import ConstantSource
  448. # TODO: this should not be needed once #93059 lands
  449. # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
  450. # TODO: make a dedicated UnknownSource for this?
  451. # NB: This is using the legacy default behavior from
  452. # create_symbolic_sizes_strides_storage_offset but we hope we can
  453. # just delete this entirely
  454. source = ConstantSource(
  455. f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
  456. )
  457. (
  458. size,
  459. stride,
  460. _,
  461. ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(
  462. ex,
  463. source,
  464. )
  465. r_size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
  466. r_stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
  467. return r_size, r_stride
  468. def static_sizes_strides(
  469. self, ex: torch.Tensor
  470. ) -> tuple[list[sympy.Expr], list[sympy.Expr]]:
  471. """
  472. Primarily used to weights
  473. """
  474. size = [sympy.Integer(i) for i in ex.size()]
  475. stride = [sympy.Integer(i) for i in ex.stride()]
  476. return size, stride
  477. def get_allocation_size(
  478. self,
  479. node: Union[
  480. ir.TensorBox, ir.StorageBox, ir.Buffer, WorkspaceArg, ir.TorchBindObject
  481. ],
  482. ) -> Sequence[Expr]:
  483. if isinstance(node, ir.TensorBox):
  484. node = node.data # type: ignore[assignment]
  485. if isinstance(node, ir.StorageBox):
  486. node = node.data # type: ignore[assignment]
  487. if (
  488. isinstance(node, ir.ComputedBuffer)
  489. and node.name in self.buffer_to_padded_size
  490. ):
  491. return self.buffer_to_padded_size[node.name]
  492. else:
  493. return node.get_size()
  494. def get_allocation_storage_size(
  495. self, node: Union[ir.Buffer, WorkspaceArg, ir.TorchBindObject]
  496. ) -> Expr:
  497. layout = node.get_layout()
  498. size = self.get_allocation_size(node) # consider inplace padding
  499. stride = layout.stride
  500. offset = layout.offset
  501. return compute_required_storage_length(size, stride, offset) # type: ignore[arg-type]
  502. def has_feature(
  503. self,
  504. device: Union[torch._inductor.ir.IRNode, device, None],
  505. feature: BackendFeature,
  506. ) -> bool:
  507. assert isinstance(feature, BackendFeature), feature
  508. return feature in self.get_backend_features(get_device_type(device))
  509. def get_dep_size_hint(self, dep: Dep) -> int:
  510. """
  511. Get the size hint for a dependency with caching to avoid expensive recomputation.
  512. """
  513. if dep not in self.dep_size_hint_cache:
  514. res = 0
  515. try:
  516. if not dep.has_unbacked_symbols():
  517. res = dep.numbytes_hint()
  518. except KeyError:
  519. # In at least one test (test/inductor/test_torchbind.py) we
  520. # create a StarDep that doesn't exist in the graph and calling
  521. # `has_unbacked_symbols()` throws an error.
  522. pass
  523. self.dep_size_hint_cache[dep] = res
  524. return self.dep_size_hint_cache[dep]
  525. def get_current_device_or_throw(self) -> torch.device:
  526. if device := self.current_device:
  527. return device
  528. else:
  529. raise RuntimeError("No current device")
  530. @contextlib.contextmanager
  531. def set_current_device(self, device: torch.device) -> Iterator[None]:
  532. prior = self.current_device
  533. self.current_device = device
  534. try:
  535. yield
  536. finally:
  537. self.current_device = prior
  538. def get_training_phase(self) -> str:
  539. if self.is_inference:
  540. return "inference"
  541. if self.is_backward:
  542. return "backward"
  543. return "forward"
  544. @staticmethod
  545. def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool:
  546. """
  547. Decide if we should enable layout optimization for this graph based on
  548. heuristics.
  549. """
  550. if not config.layout_optimization:
  551. return False
  552. if config.force_layout_optimization:
  553. return True
  554. conv_nodes = [
  555. n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
  556. ]
  557. nconv = len(conv_nodes)
  558. if nconv == 0:
  559. return False
  560. # For cpu backend and mkldnn enabled, we always use channels_last for better performance.
  561. if (
  562. torch.backends.mkldnn.enabled
  563. and torch.backends.mkldnn.is_available()
  564. and all(
  565. n.args[idx].meta["val"].device.type in SUPPORTED_MKLDNN_DEVICES
  566. for n in conv_nodes
  567. for idx in [0, 1]
  568. )
  569. ):
  570. return True
  571. # Following models are skipped due to this:
  572. # jx_nest_base
  573. # volo_d1_224
  574. if len(list(gm.graph.nodes)) >= 300 * nconv:
  575. log.debug("Skipped layout opt because only a few conv")
  576. return False
  577. if any(
  578. has_free_symbols(n.args[idx].meta["val"])
  579. for n in conv_nodes
  580. for idx in [0, 1]
  581. ):
  582. log.debug(
  583. "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
  584. )
  585. return False
  586. def is_grouped(n: Any) -> bool:
  587. meta_val = n.args[1].meta["val"] # type: ignore[union-attr, operator]
  588. assert isinstance(meta_val, torch.Tensor)
  589. return n.args[-1] > 1 and meta_val.size(1) > 1 # type: ignore[union-attr, operator]
  590. def is_in_out_channel(n: torch.fx.Node) -> bool:
  591. return (
  592. n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1) # type: ignore[union-attr, operator]
  593. and n.args[1].meta["val"].size(2) > 1 # type: ignore[union-attr, operator]
  594. )
  595. def is_small_channel(n: torch.fx.Node) -> bool:
  596. return (
  597. n.args[1].meta["val"].size(0) <= 64 # type: ignore[union-attr, operator]
  598. and n.args[1].meta["val"].size(1) <= 64 # type: ignore[union-attr, operator]
  599. )
  600. # only grouped convolutions benchmarked as slower in conv samples for inference only
  601. if is_inference:
  602. flop_counts: dict[str, float] = defaultdict(float)
  603. for node in conv_nodes:
  604. counted_flops = count_flops_fx(node)
  605. if counted_flops is None:
  606. continue
  607. if is_grouped(node):
  608. node_type = "grouped"
  609. elif is_small_channel(node):
  610. node_type = "small"
  611. elif is_in_out_channel(node):
  612. node_type = "in_out"
  613. else:
  614. node_type = "default"
  615. flop_counts[node_type] += counted_flops
  616. else:
  617. log.debug("Conv inputs meta not found")
  618. # average benchmarked channels last speedup / slowdown, < 1 is speedup.
  619. # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
  620. # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
  621. GROUPED_MULTIPLIER = 1.358
  622. DEFAULT_MULTIPLIER = 0.823
  623. IN_OUT_MULTIPLIER = 0.725
  624. SMALL_MULTIPLIER = 0.783
  625. total_flops = sum(flop_counts.values())
  626. # TODO - get different values per hardware
  627. weighted_flops = (
  628. flop_counts["grouped"] * GROUPED_MULTIPLIER
  629. + flop_counts["small"] * SMALL_MULTIPLIER
  630. + flop_counts["in_out"] * IN_OUT_MULTIPLIER
  631. + flop_counts["default"] * DEFAULT_MULTIPLIER
  632. )
  633. do_layout_opt = weighted_flops <= total_flops
  634. if not do_layout_opt:
  635. log.debug(
  636. "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
  637. total_flops,
  638. weighted_flops,
  639. )
  640. return do_layout_opt
  641. # Channels last layout can dramatically hurt grouped conv perf. E.g.
  642. # Conv with arguments like
  643. # {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
  644. # "stride": [2, 2], "padding": [1, 1], "groups": 2}
  645. # slows down 31x using channels last..
  646. # But a lot of timm models use depthwise separable convolution which will
  647. # result in grouped convolution with in-channel size == 1.
  648. # For those grouped convolution, channels last still helps a lot.
  649. # E.g.
  650. # Conv with arguments
  651. # {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
  652. # "stride": [2, 2], "padding": [1, 1], "groups": 58}
  653. # get 1.86x speedup with channels last layout.
  654. #
  655. # The following heuristics skip using channels-last if the model contains
  656. # grouped convolution with in-channels > 1.
  657. if any(map(is_grouped, conv_nodes)):
  658. log.debug(
  659. "Skip layout opt because found grouped convolution with >1 in_channels!"
  660. )
  661. return False
  662. # For some models that contain convolution with larger in-channel than out-channel, applying
  663. # channels last hurts performance.
  664. # Following models are skipped due to this:
  665. # - pytorch_unet
  666. # - phlippe_densenet (slightly worse)
  667. # - Background_Matting (1.22x -> 0.821x)
  668. # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
  669. if any(map(is_in_out_channel, conv_nodes)):
  670. log.debug(
  671. "Skip layout opt because some convolutions have smaller out_channel"
  672. )
  673. return False
  674. # Following models are skipped due to this:
  675. # - functorch_maml_omniglot
  676. if all(map(is_small_channel, conv_nodes)):
  677. log.debug("Skip layout opt because all convolution channels are too small")
  678. return False
  679. return True
  680. def qualify_name(self, name: str) -> str:
  681. """Prepend the given name with the graph name if any."""
  682. if self.name is not None:
  683. return f"{self.name}_{name}"
  684. return name
  685. def make_subgraph(
  686. self,
  687. gm: torch.fx.GraphModule,
  688. example_inputs: list[torch.Tensor],
  689. subgraph_name: str,
  690. ) -> SubgraphLowering:
  691. """
  692. Make a subgraph of the current graph with all inherited parts, except
  693. the graph module (`gm`) and `example_inputs`. The subgraphs are lowered
  694. separately and lifted into a separate function in the parent output
  695. wrapper code. The subgraph name is qualified by the parent graph's
  696. name. Note that the lifting of subgraph is supported for python wrapper
  697. only. For cpp wrapper, we inline the subgraphs in the parent wrapper.
  698. """
  699. return SubgraphLowering(
  700. parent=self,
  701. gm=gm,
  702. example_inputs=example_inputs,
  703. shape_env=self._shape_env,
  704. cpp_wrapper=self.cpp_wrapper,
  705. aot_mode=self.aot_mode,
  706. extern_node_serializer=self.extern_node_serializer,
  707. is_inference=self.is_inference,
  708. is_backward=self.is_backward,
  709. name=self.qualify_name(subgraph_name),
  710. )
  711. def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]:
  712. """
  713. The rule to decide if an node prefer channels last is simple.
  714. 1. if it's input/output of a convolution
  715. 2. if one of its user prefers channels last
  716. We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs;
  717. Rule 2 is also important. It makes sure that indirect inputs to convolution also prefers
  718. channels last.
  719. Consider the scenario: conv -> batch-norm -> relu -> conv
  720. Without rule 2, batch-norm output may use a contiguous layout. That will cause 2 extra copies:
  721. 1. the output of batch-norm should be channels last initially since its input is a conv's output.
  722. Forcing the batch-norm's output to be contiguous results in the first copy
  723. 2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output.
  724. We need convert it to channels last layout which results in the second copy.
  725. With rule 2, we makes sure all the tensors in the chain uses channels last layout. So both copies
  726. can be saved.
  727. """
  728. output_set = OrderedSet[Node]()
  729. for n in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr]
  730. if n.target == torch.ops.aten.convolution.default:
  731. output_set.add(n)
  732. continue
  733. for user in n.users:
  734. if user in output_set:
  735. output_set.add(n)
  736. break
  737. # need a second pass to add downstream nodes of those channel last nodes to the sets.
  738. # This pass is especially needed to avoid mix-layout kernel inputs in backward pass.
  739. #
  740. # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned
  741. # from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
  742. # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last
  743. # tensors and passed to a kernel.
  744. #
  745. # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x.
  746. # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x .
  747. # This also helps the following models:
  748. # - res2net101_26w_4s
  749. # - res2net50_14w_8s
  750. # - sebotnet33ts_256
  751. for n in self.module.graph.nodes: # type: ignore[union-attr]
  752. if n in output_set:
  753. output_set.update(n.users)
  754. return output_set
  755. def warn_fallback(self, name: str) -> None:
  756. if name not in self._warned_fallback:
  757. self._warned_fallback.add(name)
  758. perf_hint_log.info("Using FallbackKernel: %s", name)
  759. def add_device_info(self, device: torch.device) -> None:
  760. self.device_types.add(device.type)
  761. if device.index is not None:
  762. self.device_idxs.add(device.index)
  763. if V.graph.current_node and device not in self.device_node_mapping:
  764. self.device_node_mapping[device] = V.graph.current_node
  765. @property
  766. def fake_mode(self) -> torch._subclasses.fake_tensor.FakeTensorMode:
  767. return V.fake_mode
  768. def try_get_buffer(
  769. self, buffer_name: str
  770. ) -> Optional[Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject]]:
  771. if buffer_name in self.name_to_buffer:
  772. return self.name_to_buffer[buffer_name]
  773. if buffer_name in self.graph_inputs:
  774. return self.graph_inputs[buffer_name]
  775. if buffer_name in self.constants:
  776. data = V.graph.constants[buffer_name]
  777. return ir.ConstantBuffer(
  778. name=buffer_name,
  779. layout=ir.FixedLayout(
  780. data.device, data.dtype, *V.graph.static_sizes_strides(data)
  781. ),
  782. )
  783. return None
  784. def add_symbol_graph_input(self, symbol: sympy.Expr) -> None:
  785. raise RuntimeError("Should not be called for the main graph")
  786. def get_buffer(
  787. self, buffer_name: str
  788. ) -> Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject]:
  789. buf = self.try_get_buffer(buffer_name)
  790. if buf is not None:
  791. return buf
  792. raise RuntimeError(f"Failed to find buffer matching name {buffer_name}")
  793. def get_dtype(self, buffer_name: str) -> torch.dtype:
  794. if buffer_name in self.constants:
  795. return self.constants[buffer_name].dtype
  796. # For a mutation op we should return the dtype of the buffer being mutated
  797. if (
  798. hasattr(self.scheduler, "mutation_real_name")
  799. and buffer_name in self.scheduler.mutation_real_name
  800. ):
  801. mutated_buf = self.scheduler.mutation_real_name[buffer_name]
  802. if mutated_buf in self.name_to_buffer:
  803. return self.name_to_buffer[mutated_buf].get_dtype()
  804. if mutated_buf in self.graph_inputs:
  805. return self.graph_inputs[mutated_buf].get_dtype()
  806. if buffer_name in self.name_to_buffer:
  807. return self.name_to_buffer[buffer_name].get_dtype()
  808. if buffer_name in self.graph_inputs:
  809. return self.graph_inputs[buffer_name].get_dtype()
  810. m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
  811. if m:
  812. return self.get_dtype(m.group(1))
  813. raise KeyError(f"could not find {buffer_name}")
  814. def get_numel(self, buffer_name: str) -> Union[int, Expr]:
  815. if buffer_name in self.constants:
  816. return self.constants[buffer_name].numel()
  817. if buffer_name in self.name_to_buffer:
  818. buf = self.name_to_buffer[buffer_name]
  819. if not buf.has_tensor_output():
  820. return 1
  821. return buf.get_numel()
  822. if buffer_name in self.graph_inputs:
  823. return self.graph_inputs[buffer_name].get_numel()
  824. raise KeyError(f"could not find {buffer_name}")
  825. def run(self, *args: Any) -> Any: # type: ignore[override]
  826. with dynamo_timed("GraphLowering.run"):
  827. return super().run(*args)
  828. def register_operation(self, op: ir.Operation) -> str:
  829. assert op.operation_name is None, f"Operation registered twice: {op}"
  830. assert isinstance(op, ir.Operation)
  831. name = self.qualify_name(f"op{len(self.operations)}")
  832. self.operations.append(op)
  833. self.name_to_op[name] = op
  834. op.operation_name = name
  835. return name
  836. def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
  837. name = self.qualify_name(f"buf{len(self.buffers)}")
  838. self.buffers.append(buffer)
  839. self.name_to_buffer[name] = buffer
  840. device = buffer.get_device()
  841. if (
  842. # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
  843. device is not None
  844. and not (
  845. isinstance(buffer, ir.ComputedBuffer)
  846. and buffer.is_zero_elements()
  847. and device == torch.device("cpu")
  848. )
  849. ):
  850. self.add_device_info(device)
  851. if set_name:
  852. buffer.name = name
  853. return name
  854. def register_operation_list(self, operation_names: list[str]) -> str:
  855. name = self.qualify_name("list_" + "_".join(operation_names))
  856. self.lists[name] = operation_names
  857. return name
  858. def register_users_of(
  859. self, node_output: Union[Iterable[ir.IRNode], ir.IRNode]
  860. ) -> None:
  861. def register(value: Union[Iterable[ir.IRNode], ir.IRNode]) -> None:
  862. if isinstance(value, (list, tuple)):
  863. for x in value:
  864. register(x)
  865. if isinstance(value, ir.TensorBox):
  866. for read_name in value.get_read_names():
  867. self.name_to_users[read_name].append(value)
  868. register(node_output)
  869. def mark_buffer_mutated(self, name: str) -> None:
  870. """
  871. When a buffer is mutated we need to make sure all the reads to
  872. the old version are realized before the mutation happens.
  873. """
  874. assert isinstance(name, str)
  875. self.mutated_buffers.add(name)
  876. if name not in self.name_to_users:
  877. return
  878. for user in self.name_to_users[name]:
  879. user.realize()
  880. def get_original_value_of_constant(self, name: str) -> torch.Tensor:
  881. """
  882. In AOTI, module buffers may have been mutated during the tracing and compilation.
  883. Thus we need to read from previously stored original buffers, to make sure the
  884. generated model.so uses correct initial values.
  885. """
  886. assert name in self.allocated_constant_name and name in self.constants, (
  887. "Can not find the original value for " + name
  888. )
  889. orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name])
  890. return (
  891. self.module.meta[orig_name] # type: ignore[index]
  892. if orig_name in self.module.meta # type: ignore[operator]
  893. else self.constants[name]
  894. )
  895. def allocate_non_dup_const_name(
  896. self, name: Optional[str], data: Union[Tensor]
  897. ) -> str:
  898. if not config.aot_inductor.use_runtime_constant_folding:
  899. for constant_name, value in self.constants.items():
  900. if is_same_tensor(data, value):
  901. return constant_name
  902. if name is None:
  903. name = f"constant{len(self.constants)}"
  904. orig_name = name
  905. if name[0].isdigit():
  906. name = f"constant_{name}"
  907. name = self.qualify_name(name)
  908. # We may generate a var name for each constant in the codegen.
  909. # Let's only keep sane characters.
  910. prefix = normalize_name(name)
  911. name = prefix
  912. cnt = 0
  913. while name in self.constants:
  914. name = f"{prefix}_{cnt}"
  915. cnt += 1
  916. self.constants[name] = data
  917. self.constant_reprs[name] = (
  918. f"{data.device!r} {data.dtype!r} "
  919. f"{tuple(data.size())!r} {tuple(data.stride())!r} "
  920. f"{hash(data):x}"
  921. )
  922. self.allocated_constant_name[name] = orig_name # type: ignore[assignment]
  923. return name
  924. def add_tensor_constant(
  925. self, data: Tensor, name: Optional[str] = None
  926. ) -> Union[TensorBox, ir.ShapeAsConstantBuffer]:
  927. new_name = self.allocate_non_dup_const_name(name, data)
  928. return TensorBox.create(
  929. ir.ConstantBuffer(
  930. name=new_name,
  931. layout=FixedLayout(
  932. data.device, data.dtype, *self.static_sizes_strides(data)
  933. ),
  934. )
  935. )
  936. def constant_name(self, name: str, device_override: Optional[torch.device]) -> str:
  937. """
  938. We AOT copy constants to the devices they are needed on.
  939. If device_override doesn't match the constant's device, then
  940. copy it and return a different name.
  941. """
  942. if self.constants[name].device == device_override or device_override is None:
  943. return name
  944. with torch.utils._python_dispatch._disable_current_modes():
  945. # caller might have OrderedSet fake tensor mode which will create a fake tensor
  946. # when calling .to, so unset modes here
  947. return self.allocate_non_dup_const_name(
  948. f"{name}_{device_override.type}{device_override.index or 0}",
  949. self.constants[name].to(device_override),
  950. )
def placeholder(
    self,
    target: str,  # type: ignore[override]
    args: tuple[object],  # type: ignore[override]
    kwargs: dict[str, object],
) -> Union[Expr, TensorBox, None]:
    """Lower an FX placeholder (graph input) into an Inductor graph input.

    ``target`` is graph-qualified and appended to ``graph_input_names``.
    Depending on the example value returned by ``super().placeholder``:
      * SymInt-like (``SymTypes``): a sympy expression, stored in ``graph_inputs``;
      * int/bool/float: ``sympy.sympify`` of the value, stored in ``graph_inputs``;
      * ``FakeScriptObject``: a ``TorchBindObject``, stored in ``graph_inputs``;
      * None or ``BackwardState``: returns None (not stored in ``graph_inputs``);
      * ``torch.Generator``: an ``ir.GeneratorState``, stored in ``graph_inputs``;
      * otherwise it must be a ``torch.Tensor``. It becomes a ``TensorBox``
        over an ``InputBuffer`` (or a ``DonatedBuffer`` for donated backward
        inputs), with device info and input alignment recorded.

    NOTE(review): the return annotation omits the ``TorchBindObject`` and
    ``ir.GeneratorState`` cases that are actually returned.
    """
    # Running placeholder counter. It is compared against bw_donated_idxs
    # below to detect donated backward inputs.
    self.placeholder_idx += 1
    example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
    target = self.qualify_name(target)
    if isinstance(example, SymTypes):
        # TODO fix partitioning issue and re-enable for backward
        # https://github.com/pytorch/pytorch/issues/155468.
        if not V.graph.is_backward:
            expr = _get_placeholder_expr(example.node)
        else:
            expr = example.node.expr
        self.graph_inputs[target] = expr
        self.graph_input_names.append(target)
        return expr
    elif isinstance(example, (int, bool, float)):
        expr = sympy.sympify(example)
        self.graph_inputs[target] = expr
        self.graph_input_names.append(target)
        return expr
    elif isinstance(example, FakeScriptObject):
        obj = TorchBindObject(name=target, value=example)
        self.graph_inputs[target] = obj
        self.graph_input_names.append(target)
        return obj
    elif example is None:
        # Keep the positional slot in graph_input_names; nothing to lower.
        self.graph_input_names.append(target)
        return None
    if isinstance(example, BackwardState):
        # Ignored arg, must be unused
        # Alternately we could filter this out in AotAutograd
        self.graph_input_names.append(target)
        return None
    # See note: Note: [Generator arguments in AOTDispatcher]
    elif isinstance(example, torch.Generator):
        # A generator input must have exactly one user, which must be one of
        # the two ops listed below.
        assert len(V.graph.current_node.users) == 1 and next(
            iter(V.graph.current_node.users)
        ).target in (
            torch._prims.rng_prims.graphsafe_run_with_rng_state,
            torch.ops.higher_order.invoke_subgraph,
        )
        gen = ir.GeneratorState(name=target, device=example.device)
        self.graph_inputs[target] = gen  # type: ignore[assignment]
        self.graph_input_names.append(target)
        return gen

    assert isinstance(example, torch.Tensor), example
    # todo(chilli): We can remove the last check once we turn buffers into
    # static shape tensors. That's a hack to workaround Inductor believing
    # the buffer should be static but us passing in a fake tensor with
    # symbolic shapes.
    if not example._has_symbolic_sizes_strides:
        # the first N inputs are weights
        sizes, strides = self.static_sizes_strides(example)
    else:
        sizes, strides = self.symbolic_sizes_strides(example)  # type: ignore[assignment]

    if (
        self.is_backward
        and self.bw_donated_idxs
        and self.placeholder_idx in self.bw_donated_idxs
    ):
        # This backward input's index is listed in bw_donated_idxs, so it is
        # wrapped as a DonatedBuffer instead of a plain InputBuffer.
        tensor = TensorBox.create(
            DonatedBuffer(
                name=target,
                layout=FixedLayout(example.device, example.dtype, sizes, strides),
            )
        )
    else:
        # TODO(jansel): handle input aliasing
        tensor = TensorBox.create(
            InputBuffer(
                name=target,
                layout=FixedLayout(example.device, example.dtype, sizes, strides),
            )
        )

    self.graph_inputs[target] = tensor
    self.graph_input_names.append(target)
    # Keep the underlying buffer (two .data levels below the TensorBox) so
    # output() can detect inputs that were replaced by mutation.
    self.graph_inputs_original[target] = tensor.data.data  # type: ignore[union-attr]
    if self.current_node.users:  # cudagraphs should work with an unused CPU input
        self.add_device_info(example.device)

    # Note: [Input Alignment handling in Inductor]
    # Alignment matters for generating efficient code. Some operations,
    # e.g. vectorized loads, can only be performed on aligned inputs.
    #
    # But if we codegen assuming aligned inputs and then get unaligned
    # inputs at runtime, then we are forced to clone - which is bad for
    # both perf and memory usage.
    #
    # One option would be to guard on storage_offset%ALIGNMENT, and then
    # codegen based on this. But storage_offset guards turned out to be
    # expensive and cause recompiles; Instead, we're generating code
    # based on the alignment of the example input without guarding.
    with maybe_get_suppress_shape_guards_ctx():
        if not should_assume_input_aligned(example):
            self.unaligned_buffers.add(target)
    return tensor
def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any:  # type: ignore[type-arg, override]
    """Lower an FX ``call_function`` node by dispatching to ``lowerings[target]``.

    Behavior by case:
      * ``operator.getitem`` on a Python list/tuple/dict is delegated to
        ``super().call_function``.
      * Targets carrying ``_inductor_lowering_function`` are called directly.
      * If ``target`` has no registered lowering, a fallback kernel is
        registered when ``target``'s base name is in ``FALLBACK_ALLOW_LIST``
        or ``config.implicit_fallbacks`` is set. Otherwise
        ``MissingOperatorWithDecomp`` / ``MissingOperatorWithoutDecomp`` is
        raised.
      * Layout constraints for ``target`` are applied to the inputs before
        lowering. Mutations on any copies they made are propagated back to
        the original inputs afterwards.

    Raises:
        LoweringException: wrapping any exception raised during lowering.
    """
    if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
        return super().call_function(target, args, kwargs)

    # hasattr on OpOverloadPacket is slow, check isinstance first
    if not isinstance(target, torch._ops.OpOverloadPacket) and hasattr(
        target, "_inductor_lowering_function"
    ):
        # passthrough lowerings from .pattern_matcher
        return target(*args, **kwargs)

    if target not in lowerings:
        assert isinstance(target, torch._ops.OpOverload), (
            f"{target} is not an OpOverload"
        )
        base_name = target.name().split(".")[0]
        if base_name in FALLBACK_ALLOW_LIST:
            make_fallback(target, warn=False, override_decomp=True)
        elif config.implicit_fallbacks:
            # error is used only to format the log message; nothing is raised here.
            error = (
                MissingOperatorWithDecomp
                if get_decompositions([target])
                else MissingOperatorWithoutDecomp
            )
            log.info(
                "Creating implicit fallback for:\n%s",
                error.operator_str(target, args, kwargs),
            )

            tag: Optional[torch._C.Tag] = get_layout_constraint_tag(
                target, with_default=False
            )
            if (
                tag is None
                and torch._library.utils.is_builtin(target)
                and self.is_backward
            ):
                # for implicit fallback ATen ops during backward, if there
                # is no layout constraint tag, we conservatively require contiguous
                # input since some eager kernels do not
                # support non-contiguous inputs. Otherwise they may silently cause
                # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
                # We only do this for ATen ops and for backward.
                #
                # TODO: should really switch to "needs_fixed_stride" constraint on these
                # and identify them one by one.
                decided_constraint = require_contiguous  # type: ignore[assignment]
            else:
                default_tag: torch._C.Tag = get_layout_constraint_tag(
                    target, with_default=True
                )
                decided_constraint = tag_to_layout_constraint(default_tag)

            make_fallback(target, layout_constraint=decided_constraint)

        elif get_decompositions([target]):
            # There isn't a good way to dynamically patch this in
            # since AOT Autograd already ran. The error message tells
            # the user how to fix it.
            raise MissingOperatorWithDecomp(target, args, kwargs)
        else:
            raise MissingOperatorWithoutDecomp(target, args, kwargs)

    try:
        log.debug(" via %s", lowerings[target])  # type: ignore[index]

        n = self.current_node
        layout_constraints = maybe_layout_constraints(target)
        if layout_constraints:
            # Remember the pre-constraint inputs so mutations can be copied
            # back to them after lowering.
            old_args, old_kwargs = args, kwargs
            if layout_constraints is constrain_to_fake_tensors:
                # only constrain_to_fake_tensor if this exists.
                # otherwise, no constraints at all: the implication is
                # that this operator was inserted by a custom pass
                # so we'll give them the freedom.
                if "eager_input_vals" in n.meta:
                    fake_args, fake_kwargs = n.meta["eager_input_vals"]

                    # (fake_args, fake_kwargs) might not align with (args, kwargs).
                    # we need to normalize them based on the schema
                    assert isinstance(target, torch._ops.OpOverload)

                    def normalize(args: Any, kwargs: Any) -> tuple[Any, Any]:
                        result = torch.fx.operator_schemas.normalize_function(
                            target, args, kwargs
                        )
                        assert result is not None
                        return result[0], result[1]

                    fake_args, fake_kwargs = normalize(fake_args, fake_kwargs)
                    args, kwargs = normalize(args, kwargs)
                    old_args, old_kwargs = normalize(old_args, old_kwargs)

                    args, kwargs = constrain_to_fake_tensors(
                        args, kwargs, fake_args, fake_kwargs
                    )
            else:
                args, kwargs = layout_constraints(n, *args, **kwargs)

        out = lowerings[target](*args, **kwargs)  # type: ignore[index]

        if layout_constraints:
            # layout_constraints are allowed to make new copies of the inputs.
            # if they do, and if the target is mutable, then we need to
            # write the new values back into the original inputs.
            self.propagate_mutation(n, old_args, old_kwargs, args, kwargs)  # type: ignore[possibly-undefined]

        return out
    except Exception as e:
        # Re-raise as LoweringException, keeping the original traceback and
        # suppressing implicit exception chaining.
        raise LoweringException(e, target, args, kwargs).with_traceback(
            e.__traceback__
        ) from None
  1148. @staticmethod
  1149. def can_inline_constant(t: torch.Tensor) -> bool:
  1150. """
  1151. True if this is a small constant attr that will be inlined.
  1152. """
  1153. return len(t.shape) == 1 and t.shape[0] <= 8
  1154. def get_attr(
  1155. self,
  1156. target: str, # type: ignore[override]
  1157. args: tuple[()], # type: ignore[override]
  1158. kwargs: dict[str, object],
  1159. ) -> Union[
  1160. Constant, TensorBox, ShapeAsConstantBuffer, ir.Subgraph, TorchBindObject
  1161. ]:
  1162. # this is a constant
  1163. value = getattr_recursive(self.module, target) # type: ignore[arg-type]
  1164. if isinstance(value, torch.fx.GraphModule):
  1165. # Reuse the existing subgraph if we have seen it before already.
  1166. if target in self.seen_subgraphs:
  1167. return self.seen_subgraphs[target]
  1168. out = ir.Subgraph(name=target, graph_module=value)
  1169. self.seen_subgraphs[target] = out
  1170. return out
  1171. if isinstance(value, torch._C.ScriptObject):
  1172. self.torchbind_constants[target] = value
  1173. self.constant_reprs[target] = ""
  1174. return TorchBindObject(name=target, value=value)
  1175. elif isinstance(value, FakeScriptObject):
  1176. self.torchbind_constants[target] = value
  1177. self.constant_reprs[target] = ""
  1178. return TorchBindObject(name=target, value=value)
  1179. assert isinstance(value, torch.Tensor)
  1180. if (
  1181. config.aot_inductor.use_runtime_constant_folding
  1182. or config.always_keep_tensor_constants
  1183. or unsupported_output_tensor(value)
  1184. ):
  1185. return self.add_tensor_constant(value, target)
  1186. with no_dispatch():
  1187. if value.shape == ():
  1188. return Constant(
  1189. value=value.item(), dtype=value.dtype, device=value.device
  1190. )
  1191. if self.can_inline_constant(value):
  1192. log.debug("Inlining constant: %s ", str(target))
  1193. # tensor lowering has constant inlining logic
  1194. from .lowering import tensor
  1195. return tensor(value.tolist(), dtype=value.dtype, device=value.device)
  1196. return self.add_tensor_constant(value, target)
  1197. def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
  1198. raise AssertionError
  1199. def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
  1200. raise AssertionError
def output(
    self,
    target: str,  # type: ignore[override]
    args: tuple[object],  # type: ignore[override]
    kwargs: dict[str, object],
) -> None:
    """Lower the FX ``output`` node and finalize the graph.

    Steps:
      1. Normalize the lowered outputs to a sequence and realize each one.
      2. Fix up strides. Comm-buffer outputs are copied out. Other
         tensor/view outputs have insignificant strides matched to the FX
         node's meta strides.
      3. Store the result in ``self.graph_outputs``.
      4. For every graph input whose buffer was replaced (mutated), copy the
         new value back into the original input buffer. If the mutated
         storage is also an output, return the original input in its place.
      5. Call ``finalize()``.
    """
    result = super().output(target, args, kwargs)  # type: ignore[arg-type]
    if not isinstance(result, (tuple, list)):
        # nested subgraphs can have singleton outputs
        result = (result,)
    assert isinstance(result, (tuple, list)), type(result)
    assert all(
        isinstance(
            x,
            (
                TensorBox,
                ir.Constant,
                type(None),
                ir.ConstantBuffer,
                sympy.Expr,
                sympy.logic.boolalg.Boolean,
                int,
                ir.EffectfulKernel,
                ir.ShapeAsConstantBuffer,
            ),
        )
        for x in result
    ), result

    fx_node_args = V.graph.current_node.args[0]  # type: ignore[arg-type]
    if not isinstance(fx_node_args, (tuple, list)):
        # nested subgraphs can have singleton outputs
        fx_node_args = (fx_node_args,)
    result = [ir.ExternKernel.realize_input(x) for x in result]
    result_correct_strides = []

    # Each lowered output is paired positionally with its FX output node.
    assert len(fx_node_args) == len(result)
    for r, fx_node in zip(result, fx_node_args):
        if not isinstance(r, (ir.TensorBox, ir.BaseView)):
            result_correct_strides.append(r)
        elif isinstance(r.get_output_spec(), ir.CommBufferLayout):
            # Active references to persistent comm buffers are not allowed
            # outside of graphs
            result_correct_strides.append(ir.ExternKernel.copy_input(r))
        else:
            # AOT Autograd tries to detect stride divergence of inductor from output metadata.
            # Here, we try to avoid spurious divergence by matching insignificant strides such as
            # NOTE(review): the comment above is truncated in the source; it
            # presumably refers to strides that do not affect memory layout
            # (e.g. of size-1 dims) — confirm against try_match_insignificant_strides.

            # should have already been realized
            assert torch._inductor.ir.is_storage_and_layout(r)
            meta_strides = [
                s.node.expr if isinstance(s, torch.SymInt) else s
                for s in fx_node.meta["val"].stride()
            ]
            result_correct_strides.append(
                ir.try_match_insignificant_strides(r, meta_strides)
            )

    self.graph_outputs = result_correct_strides
    value: ir.IRNode
    for name, value in self.graph_inputs.items():
        if isinstance(value, TorchBindObject):
            continue
        assert isinstance(
            value, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState)
        ), f"Unsupported inductor graph input type: {type(value)}"
        if not isinstance(value, TensorBox):
            continue
        value.realize()
        assert isinstance(value, TensorBox)
        value = value.data
        assert isinstance(value, ir.StorageBox)
        value_storage_box = value
        value = value.data
        # If the storage no longer holds the original InputBuffer of this
        # name, the input was replaced by a mutation during lowering.
        if not isinstance(value, InputBuffer) or value.get_name() != name:
            # one of our inputs was mutated, need to turn that into a copy
            ir.MutationLayoutSHOULDREMOVE.realize_into(
                value, self.graph_inputs_original[name]
            )
            # replace output with mutated input
            try:
                ind = self.graph_outputs.index(value_storage_box)
                self.graph_outputs[ind] = self.graph_inputs_original[name]
            except ValueError:
                # list.index raises ValueError: the mutated input is not an output.
                pass

    self.finalize()
    log.debug(
        "Force channels last inputs for %d conv for the current graph with id %d",
        self.num_channels_last_conv,
        self.graph_id if self.graph_id is not None else -1,
    )
  1288. def finalize(self) -> None:
  1289. for buf in self.buffers:
  1290. buf.decide_layout()
  @contextmanager
  def set_current_node(self, node: torch.fx.Node) -> Iterator[None]:
      """Temporarily make ``node`` the FX node currently being lowered.

      The previous value of ``self.current_node`` is restored when the
      ``with`` block exits, including when it exits via an exception.
      """
      old = self.current_node
      try:
          self.current_node = node
          yield
      finally:
          self.current_node = old
  @contextmanager
  def set_current_wrapper_code(self) -> Iterator[None]:
      """Snapshot ``self.wrapper_code`` and put it back when the context exits.

      Code inside the ``with`` block may reassign ``self.wrapper_code``; the
      saved object is restored afterwards, also when an exception propagates.
      """
      saved_wrapper = self.wrapper_code
      try:
          yield
      finally:
          self.wrapper_code = saved_wrapper
  1306. def propagate_mutation(
  1307. self,
  1308. fx_node: torch.fx.Node,
  1309. old_args: tuple[Any],
  1310. old_kwargs: dict[str, Any],
  1311. new_args: tuple[Any],
  1312. new_kwargs: dict[str, Any],
  1313. ) -> None:
  1314. """Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs.
  1315. Assumes we may have cloned old_args/old_kwargs into new_args/new_kwargs
  1316. and then called fx_node(*new_args, **new_kwargs).
  1317. If fx_node mutates any of new_args/new_kwargs, and they are different from
  1318. old_args/old_kwargs, then we need to update the original tensor.
  1319. """
  1320. assert len(old_args) == len(new_args)
  1321. assert len(old_kwargs) == len(new_kwargs)
  1322. if fx_node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation:
  1323. kwargs = fx_node.kwargs["kwargs"]
  1324. assert isinstance(kwargs, dict)
  1325. mutated = torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors(
  1326. old_kwargs["kernel_idx"],
  1327. old_kwargs["constant_args_idx"],
  1328. {
  1329. k: v.meta["val"] if isinstance(v, torch.fx.Node) else v
  1330. for k, v in kwargs.items()
  1331. },
  1332. old_kwargs["tma_descriptor_metadata"],
  1333. )
  1334. for name in mutated:
  1335. old_arg = old_kwargs["kwargs"][name]
  1336. new_arg = new_kwargs["kwargs"][name]
  1337. if old_arg is new_arg:
  1338. continue
  1339. self.call_function(torch.ops.aten.copy_.default, (old_arg, new_arg), {})
  1340. return
  1341. assert isinstance(fx_node.target, torch._ops.OpOverload)
  1342. def maybe_propagate(
  1343. schema_arg: torch._C.Argument, old_arg: ir.IRNode, new_arg: ir.IRNode
  1344. ) -> None:
  1345. if old_arg is new_arg:
  1346. return
  1347. if schema_arg.alias_info is not None and schema_arg.alias_info.is_write:
  1348. # The lowering for copy_ is smart enough to "replace" old_arg with
  1349. # new_arg in all future uses so a copy_ kernel never gets emitted.
  1350. # old_arg, new_arg may be immutable_list
  1351. if isinstance(old_arg, ir.IRNode):
  1352. old_arg = (old_arg,) # type: ignore[assignment]
  1353. new_arg = (new_arg,) # type: ignore[assignment]
  1354. for old_arg_item, new_arg_item in zip(old_arg, new_arg): # type: ignore[call-overload]
  1355. if old_arg_item is new_arg_item:
  1356. continue
  1357. self.call_function(
  1358. torch.ops.aten.copy_.default, (old_arg_item, new_arg_item), {}
  1359. )
  1360. schema = fx_node.target._schema
  1361. for idx, (old_arg, new_arg) in enumerate(zip(old_args, new_args)):
  1362. schema_arg = schema.arguments[idx]
  1363. maybe_propagate(schema_arg, old_arg, new_arg)
  1364. schema_kwargs = {arg.name: arg for arg in schema.arguments}
  1365. for key in old_kwargs.keys():
  1366. old_arg = old_kwargs[key]
  1367. new_arg = new_kwargs[key]
  1368. schema_arg = schema_kwargs[key]
  1369. maybe_propagate(schema_arg, old_arg, new_arg)
  def run_node(self, n: torch.fx.Node) -> object:
      """Lower the FX node ``n`` to Inductor IR and return the lowered value.

      The lowering path is chosen as follows:
        * built-in ``OpOverload`` calls for which
          ``fallback_node_due_to_unsupported_type`` is true, or whose lowering
          the compiler bisector disables, go through ``fallback_handler``;
        * ``triton_kernel_wrapper_mutation`` calls, unless
          ``config.triton_kernel_default_layout_constraint`` is
          ``"flexible_layout"``, get their arguments constrained to eager/FX
          strides before lowering; mutations on the constrained copies are
          written back with ``propagate_mutation``. Any constraint value other
          than ``"needs_fixed_stride_order"`` raises ``RuntimeError``;
        * magic methods whose FX value is a SymInt/SymFloat/SymBool lower
          directly to the underlying sympy expression;
        * everything else uses ``super().run_node(n)``.

      The result may then be restrided or realized:
        * graph outputs and inputs of as_strided-like ops get the FX (or
          user-visible) strides;
        * several heuristics realize results that have many users, too many
          reads, or a large pointwise inner function.

      The IR is tagged (best effort) with ``origin_node``. Finally, except for
      placeholders of a backward graph, this checks that every unbacked symbol
      bound by ``n`` is defined by the buffers/operations created while
      lowering it, and emits deferred runtime asserts via
      ``create_deferred_runtime_asserts``.
      """

      def debug(msg: str) -> None:
          log.debug("lowering %s %s", LazyString(n.format_node), msg)  # type: ignore[arg-type]

      from torch._inductor.compiler_bisector import CompilerBisector

      # Snapshot the buffer/operation counts so the ones created while
      # lowering `n` can be inspected below for newly defined unbacked symbols.
      buffer_watermark = len(self.buffers)
      operation_watermark = len(self.operations)

      # origins: OrderedSet[Union[Node, ir.IRNode]] = OrderedSet([n])
      origins: OrderedSet[Any] = OrderedSet([n])
      is_call_function = n.op == "call_function"
      if is_call_function:
          # args/kwargs are only bound for call_function nodes; the branches
          # below that use them are all guarded by n.op == "call_function".
          args, kwargs = self.fetch_args_kwargs_from_env(n)
          origins |= gather_origins(args, kwargs)
      with (
          ir.IRNode.current_origins(origins),
          self.set_current_node(n),
          V.set_current_node(n),
      ):
          if (
              n.op == "call_function"
              # this path only for built-in operators
              and n.target
              and isinstance(n.target, torch._ops.OpOverload)
              and torch._library.utils.is_builtin(n.target)
              and (
                  fallback_node_due_to_unsupported_type(n)
                  or CompilerBisector.disable_subsystem(
                      "inductor", "lowerings", lambda: repr(n)
                  )
              )
          ):
              debug("fallback_handler")
              result = fallback_handler(n.target, add_to_fallback_set=False)(
                  *args,  # type: ignore[possibly-undefined]
                  **kwargs,  # type: ignore[possibly-undefined]
              )
          elif (
              n.op == "call_function"
              and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation
              and config.triton_kernel_default_layout_constraint != "flexible_layout"
          ):
              debug("user_defined_triton_kernel_layout_constraints")
              if (
                  config.triton_kernel_default_layout_constraint
                  == "needs_fixed_stride_order"
              ):
                  # Keep the unconstrained args so mutations on the
                  # (possibly restrided) copies can be copied back afterwards.
                  old_args = args  # type: ignore[possibly-undefined]
                  old_kwargs = kwargs  # type: ignore[possibly-undefined]

                  if eager_input_vals := n.meta.get("eager_input_vals"):
                      # Prefer the eager input values' layouts when recorded.
                      inp_args = eager_input_vals[0]
                      inp_kwargs = eager_input_vals[1]
                      args, kwargs = constrain_to_fake_tensors(
                          args, kwargs, inp_args, inp_kwargs
                      )
                  else:
                      args, kwargs = constrain_to_fx_strides(n, *args, **kwargs)  # type: ignore[index]
                  result = self.call_function(n.target, args, kwargs)  # type: ignore[arg-type]
                  self.propagate_mutation(n, old_args, old_kwargs, args, kwargs)  # type: ignore[possibly-undefined]
              else:
                  raise RuntimeError(
                      f"Unknown triton_kernel_default_layout_constraint: {config.triton_kernel_default_layout_constraint}"
                  )
          elif is_magic_method(n.target):
              # TODO: this is sus, it probably should be handled in the
              # lowerings themselves similarly to sym_size/sym-stride
              # https://github.com/pytorch/pytorch/issues/127789
              debug("is_magic_method")
              if isinstance(
                  n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
              ):
                  result = n.meta["val"].node.expr
              else:
                  result = super().run_node(n)
          else:
              debug("")
              result = super().run_node(n)

          # require the same stride order for dense outputs,
          # 1. user-land view() will not throw because inductor
          # output different strides than eager
          # long term the solution is to make view() always succeed
          # with infallible strides.
          # 2: as_strided ops, we need make sure its input has same size/stride with
          # eager model to align with eager behavior.
          as_strided_ops = [
              torch.ops.aten.as_strided.default,
              torch.ops.aten.as_strided_.default,
              torch.ops.aten.as_strided_scatter.default,
              torch.ops.aten.resize.default,
              torch.ops.aten.resize_as.default,
          ]
          is_output = any(user.op == "output" for user in n.users)
          is_user_visible = n in self.user_visible_output_strides
          is_input_for_as_strided = any(
              user.target in as_strided_ops for user in n.users
          )

          # Nodes flagged with meta["inductor_realize_to_strides"] are realized
          # and, when their FX strides are static, forced to that stride order.
          if n.meta.get("inductor_realize_to_strides", False) and isinstance(
              result, TensorBox
          ):
              result.realize()
              strides = n.meta["val"].stride()
              sym_strides = torch._inductor.utils.any_is_symbolic(*strides)
              if result.maybe_get_stride() != strides and not sym_strides:
                  stride_order = ir.get_stride_order(strides)
                  result = ir.ExternKernel.require_stride_order(result, stride_order)
          if (
              is_output
              and isinstance(result, TensorBox)
              and isinstance(result.data, ir.BaseView)
          ):
              # Realize so that outputs are correctly aliased
              result.realize()

          if (is_output or is_input_for_as_strided) and isinstance(
              n.meta["val"], torch.Tensor
          ):
              # User-visible outputs use the strides recorded for them;
              # everything else uses the FX fake tensor's strides.
              if is_user_visible:
                  strides = self.user_visible_output_strides.get(n)
              else:
                  strides = n.meta["val"].stride()

              if strides is not None and len(strides) > 0:
                  allow_padding = (
                      config.pad_outputs or not is_user_visible
                  ) and not is_input_for_as_strided
                  dense = torch._prims_common.is_non_overlapping_and_dense(
                      n.meta["val"]
                  )
                  unbacked_symbols_in_strides = (
                      len(free_unbacked_symbols(strides)) > 0
                  )
                  # Internal 4-D dense nodes chosen for channels-last get
                  # channels-last strides instead of the FX strides.
                  if (
                      not unbacked_symbols_in_strides
                      and dense
                      and len(result.get_size()) == 4
                      and n in self.nodes_prefer_channels_last
                      and not is_user_visible
                      and not is_input_for_as_strided
                  ):
                      strides = ir.FlexibleLayout.stride_ordered_for_memory_format(
                          result.get_size(), torch.channels_last
                      )
                  if not unbacked_symbols_in_strides and len(strides):
                      # To avoid converting possible view ops to a copy kernel, we use the previous
                      # require_exact_strides to handle views. But ultimately it's better to require
                      # the right strides at the tensor definition.
                      if n.meta["val"]._is_view() or isinstance(
                          result.data, ir.BaseView
                      ):
                          result = ir.ExternKernel.require_stride_order(
                              result,
                              ir.get_stride_order(strides),
                              allow_padding=allow_padding,
                          )
                      else:
                          # Convert SymInt strides to sympy expressions first.
                          strides = [
                              s.node.expr if isinstance(s, torch.SymInt) else s
                              for s in strides
                          ]
                          result = ir.ExternKernel.require_exact_strides(
                              result, strides, allow_padding=allow_padding
                          )

          # Realize if (1) any user need inputs realized, or (2) there is
          # already too many reads and rematerializing can be bad.
          num_users = len(OrderedSet(n.users))
          if num_users > 1 and isinstance(result, TensorBox):
              for user in n.users:
                  if user.target in needs_realized_inputs:
                      result.realize_hint()
                      # This inclusion is somewhat controversial (from
                      # discussion between Horace, Natalia, and Elias).
                      # Currently, it's not very clear why this is helpful.
                      # The general idea here is that even though a node may
                      # have FlexibleLayout, we still often *treat* it as if
                      # it was contiguous. This appears to sometimes result in
                      # suboptimal behavior.
                      #
                      # When we do a better job selecting layout, we should
                      # revisit this.
                      need_fixed_layout = [
                          torch.ops.aten.convolution_backward.default,
                          torch.ops.aten.mm.default,
                          torch.ops.aten._int_mm.default,
                      ]
                      need_fixed_channels_last_layout = []
                      if not self.layout_opt:
                          need_fixed_layout.append(torch.ops.aten.convolution.default)
                      if torch._C._has_mkldnn:
                          need_fixed_layout += [
                              torch.ops.mkldnn._linear_pointwise.default,
                              torch.ops.mkldnn._linear_pointwise.binary,
                              torch.ops.aten.mkldnn_rnn_layer.default,
                              torch.ops.onednn.qlinear_pointwise.default,
                              torch.ops.onednn.qlinear_pointwise.tensor,
                              torch.ops.onednn.qlinear_pointwise.binary,
                              torch.ops.onednn.qlinear_pointwise.binary_tensor,
                          ]
                          need_fixed_channels_last_layout += [
                              torch.ops.mkldnn._convolution_pointwise.default,
                              torch.ops.mkldnn._convolution_pointwise.binary,
                              torch.ops.mkldnn._convolution_pointwise_.binary,
                              torch.ops.mkldnn._convolution_transpose_pointwise.default,
                              torch.ops.onednn.qconv_pointwise.default,
                              torch.ops.onednn.qconv2d_pointwise.binary,
                          ]
                          if torch._C.has_mkl:
                              need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
                      if user.target in need_fixed_layout:
                          result = ir.ExternKernel.require_stride_order(
                              result,
                              ir.get_stride_order(n.meta["val"].stride()),
                              allow_padding=True,
                          )
                      # Only the first (activation) argument of these ops is
                      # forced to channels-last.
                      if (
                          user.target in need_fixed_channels_last_layout
                          and n is user.args[0]
                      ):
                          result = ir.ExternKernel.require_stride_order(
                              result,
                              ir.get_stride_order(
                                  make_channels_last_strides_for(n.meta["val"].shape)
                              ),
                          )
                  if user.op == "output":
                      if isinstance(result.data.data, (Pointwise, Reduction)):
                          result.realize()

              # TODO(jansel): introduce a store vs inline choice
              result.mark_reuse(len(n.users))

          # Realize if the IRNode already has accumulated lots of reads
          if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
              # Prevent excessive accumulation in a computed buffer, when
              # there are multiple branches each with small number of memory
              # reads, but they converge to a user.
              result.realize_hint()

          # Realize if a Pointwise has too much stuff to be inlined.
          # As this may cause RecursionError during Inductor's evaluation.
          if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
              curr = result.data.data
              if isinstance(curr, Pointwise):
                  # Use inner fn as a rough proxy. Good enough.
                  if curr.has_large_inner_fn(threshold=100):
                      result.realize()

      # This is not complete, but it doesn't have to be: origin_node
      # tracking is best effort. The logic here critically relies on direct
      # TensorBox -> StorageBox denoting a non-view; we don't bother trying
      # to get views to work. Feel free to add any extra cases as needed.
      #
      # Note: we can't YOLO tree_map over this result, because if there are
      # buffers or a view involved, we might not be able to validly assign
      # the origin_node here.
      if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
          if isinstance(result.data.data, ir.Loops):
              result.data.data._post_init_setattr("origin_node", n)
          elif isinstance(result.data.data, ir.Buffer):
              result.data.data._post_init_setattr("origin_node", n)
              if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
                  result.data.data.data, ir.Loops
              ):
                  result.data.data.data._post_init_setattr("origin_node", n)
              # Not really multi-output, can straightforwardly recurse in
              elif (
                  isinstance(result.data.data, ir.MultiOutput)
                  and not result.data.data.indices
              ):
                  if isinstance(result.data.data.inputs[0], ir.Buffer):
                      result.data.data.inputs[0]._post_init_setattr("origin_node", n)

      self.register_users_of(result)

      # Collect the unbacked symbols defined by everything created while
      # lowering this node (see the watermarks taken above).
      new_unbacked_defs = OrderedSet[sympy.Symbol]()
      for buf in self.buffers[buffer_watermark:]:
          new_unbacked_defs |= buf.get_unbacked_symbol_defs()
      for op in self.operations[operation_watermark:]:
          new_unbacked_defs |= op.get_unbacked_symbol_defs()

      shape_env = V.graph.sizevars.shape_env

      # An input can be unbacked symint i.e.: when mark_unbacked is used.
      # in that case add it to new_unbacked_defs.
      if (
          n.op == "placeholder"
          and isinstance(result, sympy.Symbol)
          and shape_env.is_unbacked_symint(result)
      ):
          new_unbacked_defs.add(result)

      def format_new_defs() -> str:
          # Debug text for the assertion below: every new buffer/operation
          # with the unbacked symbols it defines.
          r = [
              f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
              for buf in self.buffers[buffer_watermark:]
          ]
          r.extend(
              f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
              for op in self.operations[operation_watermark:]
          )
          return "***\n".join(r)

      # We do not skip unbacked symints that are input for backward see the note below.
      if V.graph.is_backward and n.op == "placeholder":
          return result

      # Note [Backwards runtime asserts]
      # Backwards poses an interesting problem for deferred runtime
      # asserts. In the easy case, we may solely close over data
      # dependent sized tensors, and there are no binding sites for
      # unbacked SymInts. In this case, we can just drop all the
      # runtime asserts on the floor: no non-placeholder bindings, no
      # problem.
      #
      # However, it is *possible* for a fresh runtime assert to show up
      # between forwards and backwards. Right now, the freezing process
      # that happens when we lower forwards means that we will freeze
      # runtime asserts, and then the moment the backwards lowering
      # process attempts to add a new deferred runtime assert, we will
      # fail. Let's say you remove that assert. Now when we get here,
      # we need to make sure we actually emit these asserts (because we
      # can't emit them in forwards, we already compiled it). So we
      # have to do something here. But we don't want to reemit ALL
      # deferred runtime asserts, we only want to emit the NEW ones.
      # Therefore needing some sort of stratification in the ShapeEnv.
      # This is all doable, it just hasn't been done yet.
      unbacked_bindings = resolve_unbacked_bindings(
          V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
      )
      assert unbacked_bindings is not None
      # When we do lowering, it is possible we reallocate unbacked SymInts.
      # So we need to line up the unbacked SymInts when performing the test
      # here
      #
      # In principle, we could permit lowering to introduce MORE unbacked
      # SymInts: as long as all the old unbacked ones are accounted for,
      # it's fine for inductor to introduce extra calls to item()/unbacked()
      # whatever. This actually happens in practice when an unbacked SymInt
      # gets memoized away; naively, when Inductor reprocesses a kernel, it
      # doesn't know that the memo still applies, and ends up allocating a
      # new symbol. However, this is generally a bad thing: we may still
      # end up needing to test equalities on the symbols, and a fresh
      # symbol is likely to hit lots of GuardOnDataDependent errors that
      # we already know facts for.
      renamed_unbacked_bindings = OrderedSet(
          V.fake_mode.shape_env.unbacked_renamings.get(s, s)
          for s in unbacked_bindings.keys()
      )
      assert new_unbacked_defs >= renamed_unbacked_bindings, (
          f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
          f"fx node is: {n.format_node()}\n"
          f"new operations are:\n\n{format_new_defs()}"
      )

      self.create_deferred_runtime_asserts(n, new_unbacked_defs)
      return result
  def create_deferred_runtime_asserts(
      self, n: torch.fx.Node, new_unbacked_defs: OrderedSet[sympy.Symbol]
  ) -> None:
      """Emit ``ir.AssertScalar`` operations that become checkable after ``n``.

      In AOT mode with ``full_aoti_runtime_assert()`` enabled, an
      ``aten._assert_scalar`` node is replayed directly as an assert (unless
      its condition is literally ``True``). Otherwise the symbols in
      ``new_unbacked_defs`` are marked as bound. For each of them this emits:
        * asserts for its finite, integer-convertible value-range bounds
          (only when the range is narrower than the default);
        * the pending runtime asserts from ``self.ras_by_symbol`` whose free
          unbacked symbols are now all bound. Asserts that still reference
          unbound symbols are re-queued under one of the missing symbols.
      """
      # [NOTE] Codegen runtime asserts in Inductor
      #
      # We need to generate runtime asserts directly in Inductor instead
      # of just reusing the asserts from input graphs because we reuse the
      # same ShapeEnv as before. In particular, on subsequent graph passes,
      # we would immediately turn all of these assertions into noops,
      # because when we evaluated their expressions, we would see that
      # because we had a deferred runtime assert in the ShapeEnv, we
      # know "oh, of course this expression is True" already.
      # One example is below:
      #
      # class Model(torch.nn.Module):
      #     def forward(self, a, b, c):
      #         nz = torch.nonzero(a)
      #         ones = a.new_ones([nz.size(0), b.size(0)])
      #         torch._check(ones.size(0) >= 1)
      #         equals = torch.add(ones, c)
      #         return equals
      # torch._dynamo.mark_dynamic(c, 0)
      # When we reuse the ShapeEnv in Inductor lowering, the check that checks
      # a and nonzero have the same shape would be evaluated to True after we resolve
      # unbacked bindings using the ShapeEnv.
      # See test_unbacked_equals_input_size_runtime_assertion in test_aot_inductor.
      #
      # In addition to the Inductor generated runtime asserts, we also
      # need the runtime asserts from the input graph, because some derived
      # runtime asserts on backed symints are not generated in Inductor. One example is
      # this: `y = x.reshape(100, -1).clone()`. x.shape[0] needs to be a multiple of 100.
      # See test_aoti_runtime_asserts_backed_symint in test_aot_inductor.

      def emit_assert(condition: SympyBoolean, message: str) -> None:
          # Each assert is registered as a named buffer and as an operation.
          assert_op = ir.AssertScalar(condition, message)
          self.register_buffer(assert_op, set_name=True)
          self.register_operation(assert_op)

      def is_concrete_bound(bound: Expr) -> bool:
          # Infinite bounds cannot be asserted, and neither can bounds that
          # int() rejects with TypeError (e.g. symbolic expressions).
          if bound in (int_oo, -int_oo):
              return False
          try:
              int(bound)
          except TypeError:
              return False
          return True

      replay_graph_assert = (
          full_aoti_runtime_assert()
          and n.target == torch.ops.aten._assert_scalar.default
          and self.aot_mode
      )
      if replay_graph_assert:
          node_args, _ = self.fetch_args_kwargs_from_env(n)
          condition = node_args[0]
          if condition != True:  # noqa: E712
              emit_assert(condition, f"{condition} to be True")
          return

      # bound_unbacked_symbols tracks the symbols that are created so far,
      # we use it to make sure that runtime assertions are added after all
      # symbols used in them are defined.
      self.bound_unbacked_symbols |= new_unbacked_defs

      shape_env = V.graph.sizevars.shape_env

      # Emit code for runtime asserts that can be inserted at this point.
      for sym in new_unbacked_defs:
          pending = self.ras_by_symbol.pop(sym, [])
          # NB: size-like not needed, we won't retrace
          value_range = shape_env.var_to_range[sym]
          default_range = shape_env._default_unspecified_value_range()
          if not default_range.issubset(value_range):
              low, high = value_range.lower, value_range.upper
              if is_concrete_bound(low):
                  emit_assert(sym >= low, f"{sym} >= {low}")
              if is_concrete_bound(high):
                  emit_assert(sym <= high, f"{sym} <= {high}")

          for ra in pending:
              unbound = free_unbacked_symbols(ra.expr) - self.bound_unbacked_symbols
              if not unbound:
                  emit_assert(ra.expr, f"{ra.expr}")
                  continue
              # Park the assert under a deterministic (str-minimal) missing
              # symbol; it is retried once that symbol gets bound.
              self.ras_by_symbol.setdefault(min(unbound, key=str), []).append(ra)
  def validate_can_generate_cpp_wrapper(self) -> None:
      """Raise ``CppWrapperCodegenError`` if a C++ wrapper cannot be produced.

      This happens when ``config.disable_cpp_codegen`` is set, or when
      ``sys.platform`` is not one of ``linux``, ``darwin`` or ``win32``.
      """
      if config.disable_cpp_codegen:
          raise CppWrapperCodegenError("C++ codegen is disabled")

      supported_platforms = ("linux", "darwin", "win32")
      if sys.platform in supported_platforms:
          return
      raise CppWrapperCodegenError(f"Unsupported platform {sys.platform}")
  def init_wrapper_code(
      self,
      is_subgraph: bool = False,
      subgraph_name: Optional[str] = None,
      parent_wrapper_code: Optional[PythonWrapperCodegen] = None,
      partition_signatures: Optional[GraphPartitionSignature] = None,
  ) -> None:
      """Select the target device type and create ``self.wrapper_code``.

      The device type is the single device in ``self.device_types`` other
      than ``"cpu"``/``"meta"``, or ``"cpu"`` if there is none. Mixing
      several such devices trips an assertion. When ``self.cpp_wrapper`` is
      set, C++ wrapper support is validated first. The wrapper class comes
      from ``get_wrapper_codegen_for_device`` and is instantiated with the
      given subgraph/partition arguments.
      """
      accelerators = self.device_types.copy()
      for ignored_device in ("cpu", "meta"):
          accelerators.discard(ignored_device)

      # TODO(Eikan): Only support mixing cpu and other device now.
      assert len(accelerators) <= 1, "Does not support mixing {}".format(
          "+".join(accelerators)
      )
      self.device_type = "cpu" if len(accelerators) == 0 else accelerators.pop()

      if self.cpp_wrapper:
          self.validate_can_generate_cpp_wrapper()

      self.device_ops = get_device_op_overrides(self.device_type)
      codegen_cls = get_wrapper_codegen_for_device(
          self.device_type, self.cpp_wrapper, self.fx_wrapper
      )
      assert codegen_cls is not None, f"Device {self.device_type} not supported"
      self.wrapper_code = codegen_cls.create(
          is_subgraph,
          subgraph_name,
          parent_wrapper_code,
          partition_signatures,
      )

      if self.const_module:
          # Share the constant module's name iterator with this wrapper
          # (presumably so generated names do not collide — confirm).
          self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
  def extract_autotune_inputs(
      self, example_inputs: list[Union[int, float, torch.Tensor]]
  ) -> None:
      """Collect concrete kernel arguments for user-defined Triton kernels.

      A deep copy of ``self.orig_gm`` is rewritten so that its outputs are the
      kwargs (and dynamic grid entries) of every
      ``triton_kernel_wrapper_mutation`` node. The copy is then run with an
      FX ``Interpreter`` on a deep copy of ``example_inputs``. Kwargs that
      the kernel mutates are cloned right before the kernel node, so the
      recorded value is the pre-mutation one.

      Sets:
        * ``self.autotuning_inputs``: the collected kwarg values;
        * ``self.autotuning_mapping``: ``{node.name: {kwarg: index}}`` into
          ``self.autotuning_inputs``;
        * ``self.autotuning_grids``: only when some grid entry is an FX node,
          ``{node.name: [grid tuples with concrete values]}`` for kernels
          whose grid depends on such nodes.

      Neither ``self.orig_gm`` nor the caller's ``example_inputs`` is modified.
      """
      import copy

      cloned_gm = copy.deepcopy(self.orig_gm)
      example_inputs = copy.deepcopy(example_inputs)
      triton_nodes = [
          node
          for node in cloned_gm.graph.nodes
          if node.op == "call_function"
          and node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation
      ]

      # FX nodes used in grids, deduplicated; visited_grids: node -> index.
      grid_inputs: list[torch.fx.Node] = []
      visited_grids: dict[torch.fx.Node, int] = {}
      # Per-kernel {kwarg name: index into kwargs_inputs}.
      triton_inputs: dict[str, Any] = {}
      # Values to output for kwargs: FX nodes, clone nodes and constants.
      kwargs_inputs: list[Any] = []
      # Only FX nodes are deduplicated (they hash by identity). Constants are
      # appended as-is: deduplicating them by value would conflate equal-hash
      # values of different types (1, 1.0, True) and fail on unhashable ones.
      visited_kwargs: dict[torch.fx.Node, int] = {}

      for node in triton_nodes:
          # first check whether we have fx node in grid settings.
          for grid in node.kwargs["grid"]:
              for val in grid:
                  if isinstance(val, torch.fx.Node) and val not in visited_grids:
                      visited_grids[val] = len(grid_inputs)
                      grid_inputs.append(val)

          kwargs = node.kwargs["kwargs"]
          # identify which args might be mutated, those should be cloned.
          mutated = torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors(
              node.kwargs["kernel_idx"],
              node.kwargs["constant_args_idx"],
              {
                  k: v.meta["val"] if isinstance(v, torch.fx.Node) else v
                  for k, v in kwargs.items()
              },
              node.kwargs["tma_descriptor_metadata"],
          )

          new_kwargs: dict[str, int] = {}
          with cloned_gm.graph.inserting_before(node):
              for k, v in kwargs.items():
                  if k in mutated:
                      # Snapshot the value before this kernel mutates it.
                      new_kwargs[k] = len(kwargs_inputs)
                      kwargs_inputs.append(
                          cloned_gm.graph.call_function(torch.clone, args=(v,))
                      )
                  elif isinstance(v, torch.fx.Node):
                      if v not in visited_kwargs:
                          visited_kwargs[v] = len(kwargs_inputs)
                          kwargs_inputs.append(v)
                      new_kwargs[k] = visited_kwargs[v]
                  else:
                      new_kwargs[k] = len(kwargs_inputs)
                      kwargs_inputs.append(v)
          triton_inputs[node.name] = new_kwargs

      # Make the copied graph return exactly the values we want to observe.
      new_outputs = kwargs_inputs + grid_inputs
      for node in cloned_gm.graph.nodes:
          if node.op == "output":
              node.args = (tuple(new_outputs),)
              break

      cloned_gm.recompile()
      runner = torch.fx.Interpreter(cloned_gm)
      # NOTE(review): Interpreter.run takes graph inputs as *args; the list is
      # passed as a single argument (unchanged from the original code) —
      # confirm this matches how orig_gm's placeholders consume inputs.
      returned_outputs = runner.run(example_inputs)

      # Extract and store the grid for autotuning
      if len(grid_inputs) > 0:
          grid_outputs = returned_outputs[len(kwargs_inputs) :]
          self.autotuning_grids = {}
          for node in triton_nodes:
              dynamic_grid = False
              new_grids: list[tuple[Any, ...]] = []
              for grid in node.kwargs["grid"]:
                  new_grid = []
                  for val in grid:
                      if isinstance(val, torch.fx.Node):
                          dynamic_grid = True
                          new_grid.append(grid_outputs[visited_grids[val]])
                      else:
                          new_grid.append(val)
                  new_grids.append(tuple(new_grid))

              if dynamic_grid:
                  self.autotuning_grids[node.name] = new_grids

      # Store the kwargs input for autotuning
      self.autotuning_inputs = returned_outputs[: len(kwargs_inputs)]
      self.autotuning_mapping = triton_inputs
  def codegen_with_cpp_wrapper(
      self,
  ) -> tuple[ValueWithLineMap, ValueWithLineMap]:
      """Generate code for this graph using the C++ wrapper.

      For GPU, Triton kernels are autotuned and stored as cubin files.

      * Graphs without "cuda"/"xpu" devices go straight to ``self.codegen()``.
      * With ``config.triton.autotune_at_compile_time`` codegen is a single
        pass. If ``config.triton.autotune_with_sample_inputs`` is also set and
        the graph has user-defined Triton kernels, real inputs are extracted
        first and fed to ``extract_autotune_inputs``.
      * Otherwise the graph is first compiled with the Python wrapper and run
        once on real inputs. Then buffer-removal state, precomputed size
        replacements and metrics are reset, and ``codegen()`` runs again with
        ``cpp_wrapper=True`` and compile-time autotuning disabled.
      """
      if any(device in self.device_types for device in ["cuda", "xpu"]):

          def extract_real_inputs() -> list[Union[int, float, torch.Tensor]]:
              """Return concrete values for the graph inputs.

              Uses the tracing context's params followed by ``V.real_inputs``
              when both are available; otherwise ``V.real_inputs`` or, if it
              is unset, ``self.example_inputs``. Tensor inputs listed in
              ``self.mutated_inputs`` are cloned (strides preserved) so that
              the first compilation pass does not mutate them.
              """

              def materialize(
                  x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
              ) -> Union[int, float, torch.Tensor]:
                  # None entries are passed through unchanged.
                  if x is None:
                      return None
                  elif isinstance(x, (torch.SymInt, torch.SymFloat)):
                      # Need concrete value to run dynamic shapes and tune the result
                      return x.node.hint
                  elif isinstance(x, FakeTensor):
                      return defake(x)
                  else:
                      assert isinstance(x, torch.Tensor), (
                          "Unknown type when creating real inputs" + str(type(x))
                      )
                      return x

              tracing_context = torch._guards.TracingContext.try_get()
              if tracing_context is not None and not isinstance(
                  V.real_inputs, NullHandler
              ):
                  if tracing_context.output_strides:
                      tracing_context.output_strides.clear()

                  params_flat = [
                      param
                      for param in tracing_context.params_flat  # type: ignore[union-attr]
                      if param is not None
                  ]
                  real_inputs = [
                      materialize(x)
                      for x in itertools.chain(params_flat, V.real_inputs)
                  ]
              else:
                  # In the backward pass, V.real_inputs is not set.
                  # Generating random inputs based on self.example_inputs sometimes can be problematic,
                  # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
                  real_inputs = [
                      materialize(x)  # type:ignore[arg-type]
                      for x in (
                          self.example_inputs  # type:ignore[union-attr]
                          if isinstance(V.real_inputs, NullHandler)
                          else V.real_inputs
                      )
                  ]

              if self.mutated_inputs:
                  from .compile_fx import clone_preserve_strides

                  mutated_input_idxs = [
                      idx
                      for idx, name in enumerate(self.graph_inputs)
                      if name in self.mutated_inputs
                      and isinstance(real_inputs[idx], torch.Tensor)
                  ]
                  for idx in mutated_input_idxs:
                      # clone mutated Tensor inputs to avoid mutating them in
                      # the first pass of the CPP wrapper-based compilation, as
                      # this will lead to a side effect on the example inputs:
                      # e.g. if torch.compile(f)(x) if called on input-mutating
                      # f, the inputs x will be mutated twice in the process:
                      # once here, and again when running the compiled model;
                      # this will also lead to a numerically incorrect output
                      mutated_inp = real_inputs[idx]
                      assert isinstance(mutated_inp, torch.Tensor)
                      real_inputs[idx] = clone_preserve_strides(mutated_inp)
                      del mutated_inp
              return real_inputs

          if config.triton.autotune_at_compile_time:
              # If autotune_at_compile_time is True, we can do the codegen in one-pass
              # We will construct the autotuning values if user defined kernel exists.
              if config.triton.autotune_with_sample_inputs:
                  user_defined_kernels = any(
                      isinstance(op, ir.UserDefinedTritonKernel)
                      for op in self.operations
                  )
                  if user_defined_kernels:
                      real_inputs = extract_real_inputs()
                      self.extract_autotune_inputs(real_inputs)
              return self.codegen()
          else:
              # first pass
              self.cpp_wrapper = False
              compiled = self.compile_to_module().call

              real_inputs = extract_real_inputs()
              with torch.utils._python_dispatch._disable_current_modes():
                  compiled(real_inputs)
              del real_inputs

              # second pass
              self.cpp_wrapper = True
              self.removed_buffers.clear()
              self.removed_operations.clear()
              self.inplaced_to_remove.clear()
              V.graph.sizevars.precomputed_replacements.clear()
              V.graph.sizevars.inv_precomputed_replacements.clear()
              metrics.reset()
              with config.patch({"triton.autotune_at_compile_time": False}):
                  return self.codegen()
      else:
          # cpu
          return self.codegen()
  def _update_scheduler(self) -> None:
      """(Re)build ``self.scheduler`` from ``self.operations``.

      ``triton.store_cubin`` is patched to ``False`` while the scheduler is
      constructed, so initializing it generates no CUBIN files. Such files
      could bias benchmarks and pessimize fusion decisions.
      """
      from .scheduler import Scheduler

      cubin_disabled = config.patch("triton.store_cubin", False)
      with cubin_disabled:
          self.scheduler = Scheduler(self.operations)
  2022. def codegen(self) -> tuple[ValueWithLineMap, ValueWithLineMap]:
  2023. with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True):
  2024. self.init_wrapper_code()
  2025. self._update_scheduler()
  2026. V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)
  2027. self.wrapper_code.push_codegened_graph(self)
  2028. self.scheduler.codegen()
  2029. log.debug(
  2030. "Finished codegen for all nodes. The list of kernel names available: %s",
  2031. V.graph.all_codegen_kernel_names,
  2032. )
  2033. result = self.wrapper_code.generate(self.is_inference)
  2034. self.wrapper_code.pop_codegened_graph()
  2035. return result
  2036. def codegen_subgraph(self, parent_graph: GraphLowering) -> None:
  2037. """
  2038. This is a more compact version of the `codegen()` above
  2039. where we codegen this graph as a subgraph of some parent
  2040. graph. The parent graph is passed as an argument: the
  2041. intention is to inline codegening of the subgraph in
  2042. the parent graph's wrapper code (including the generated
  2043. kernels). The wrapper code is not finalized (via `.generate()`
  2044. call), as this will be done in the parent graph's `codegen()`.
  2045. """
  2046. with dynamo_timed("GraphLowering.codegen_subgraph", log_pt2_compile_event=True):
  2047. self.wrapper_code = parent_graph.wrapper_code
  2048. self.device_ops = parent_graph.device_ops
  2049. self.cpp_wrapper = parent_graph.cpp_wrapper
  2050. self._update_scheduler()
  2051. self.scheduler.codegen()
  2052. def count_bytes(
  2053. self,
  2054. ) -> tuple[
  2055. int, list[tuple[BaseSchedulerNode, int]], list[tuple[BaseSchedulerNode, float]]
  2056. ]:
  2057. total_bytes = 0
  2058. node_counts = []
  2059. node_runtimes = []
  2060. for node in self.scheduler.nodes:
  2061. num_bytes = node.get_read_write_buffers_sizes()
  2062. total_bytes += num_bytes
  2063. node_counts.append((node, num_bytes // 4))
  2064. node_runtimes.append((node, node.get_estimated_runtime()))
  2065. return total_bytes, node_counts, node_runtimes
  2066. # No-op to be patched for unit tests
  2067. save_output_code: Optional[Callable[[str], None]] = None
  2068. def compile_to_module(self) -> CompiledModule:
  2069. with dynamo_timed(
  2070. "GraphLowering.compile_to_module",
  2071. phase_name="code_gen",
  2072. log_pt2_compile_event=True,
  2073. dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
  2074. ):
  2075. return self._compile_to_module()
  2076. def _compile_to_module(self) -> CompiledModule:
  2077. # If we're here, we don't have to worry about the kernel code, which is only
  2078. # returned separately in AOTInductor mode.
  2079. wrapper_code, _ = (
  2080. self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  2081. )
  2082. if isinstance(wrapper_code, ValueWithLineMap):
  2083. mod = self._compile_to_module_lines(wrapper_code)
  2084. elif isinstance(wrapper_code, FileBackedGraphModule):
  2085. mod = wrapper_code
  2086. else:
  2087. raise NotImplementedError(
  2088. f"Unrecognized wrapper code type: {type(wrapper_code)}"
  2089. )
  2090. # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
  2091. # TODO. Revisit this once the logging API is more mature
  2092. assert mod.__file__ is not None
  2093. log_module_code(mod.__file__)
  2094. log.debug("Output code written to: %s", mod.__file__)
  2095. output_code_log.info("Output code written to: %s", mod.__file__)
  2096. if config.benchmark_kernel:
  2097. print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
  2098. V.debug.output_code(mod.__file__)
  2099. V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
  2100. return mod
  2101. def _compile_to_module_lines(
  2102. self, wrapper_code: ValueWithLineMap
  2103. ) -> CompiledModule:
  2104. from .codecache import PyCodeCache
  2105. if config.triton.autotune_at_compile_time:
  2106. # sanitize docstrings in kernel defs (#155006)
  2107. kernel_autotune_defs = self.wrapper_code.kernel_autotune_defs.getvalue()
  2108. kernel_autotune_defs = kernel_autotune_defs.replace('"""', '\\"\\"\\"')
  2109. tuning_code = (
  2110. '"""\n'
  2111. + "Compile-time auto-tuning block: \n"
  2112. + kernel_autotune_defs
  2113. + self.wrapper_code.kernel_autotune_calls.getvalue()
  2114. + '"""\n'
  2115. )
  2116. wrapper_code.value = tuning_code + wrapper_code.value
  2117. if GraphLowering.save_output_code is not None:
  2118. GraphLowering.save_output_code(wrapper_code.value)
  2119. output_code_log.debug("Output code: \n%s", wrapper_code.value)
  2120. inductor_meta = autotune_cache.inductor_meta_from_config()
  2121. AutotuneCacheBundler.begin_compile(inductor_meta, code=wrapper_code.value)
  2122. try:
  2123. linemap = [
  2124. (line_no, node.stack_trace) # type: ignore[attr-defined]
  2125. for line_no, node in wrapper_code.line_map
  2126. ]
  2127. key, path = PyCodeCache.write(wrapper_code.value)
  2128. output_code_log.debug("Output code written to: %s", path)
  2129. except Exception:
  2130. trace_structured(
  2131. "inductor_output_code",
  2132. # Just omit the filename, I still want the code though!
  2133. payload_fn=lambda: wrapper_code.value,
  2134. )
  2135. raise
  2136. else:
  2137. trace_structured(
  2138. "inductor_output_code",
  2139. lambda: {"filename": path},
  2140. payload_fn=lambda: wrapper_code.value,
  2141. )
  2142. with dynamo_timed("PyCodeCache.load_by_key_path", log_pt2_compile_event=True):
  2143. mod = PyCodeCache.load_by_key_path(
  2144. key,
  2145. path,
  2146. linemap=linemap, # type: ignore[arg-type]
  2147. attrs={**self.constants, **self.torchbind_constants},
  2148. )
  2149. self.cache_key = key
  2150. self.cache_path = path
  2151. self.cache_linemap = linemap # type: ignore[assignment]
  2152. if config.benchmark_harness and config.profile_bandwidth_output:
  2153. # run the inputs code gen to get the bandwidth info
  2154. mod.benchmark_compiled_module(times=1, repeat=1)
  2155. return mod
  2156. def _get_output_names(self, graph_outputs: list[ir.IRNode]) -> list[str]:
  2157. names = []
  2158. shape_counter = itertools.count(0)
  2159. none_counter = itertools.count(0)
  2160. for node in graph_outputs:
  2161. if isinstance(node, ir.NoneAsConstantBuffer):
  2162. names.append(f"{self.name}_none{next(none_counter)}")
  2163. elif isinstance(node, ir.ShapeAsConstantBuffer):
  2164. names.append(f"{self.name}_shape{next(shape_counter)}")
  2165. else:
  2166. names.append(node.get_name())
  2167. return names
  2168. def get_output_names(self) -> list[str]:
  2169. return self._get_output_names(self.graph_outputs)
  2170. def is_unspec_arg(self, name: str) -> bool:
  2171. # dynamo wraps unspec variable as 0d CPU tensor,
  2172. # need to convert to scalar during codegen (triton only)
  2173. return (
  2174. name in self.graph_inputs.keys()
  2175. and self.graph_inputs[name].get_numel() == 1
  2176. and len(self.graph_inputs[name].get_size()) == 0
  2177. and get_device_type(self.graph_inputs[name]) == "cpu"
  2178. ) or name in self.zero_dim_cpu_tensor_list
  2179. class SubgraphLowering(GraphLowering):
  2180. """
  2181. Mostly a helper class for the subgraph lowering. The main goal is to call
  2182. init_wrapper_code with the subgraph related arguments.
  2183. """
  2184. def __init__(self, parent: GraphLowering, *args: Any, **kwargs: Any) -> None:
  2185. self.parent = parent
  2186. super().__init__(*args, **kwargs)
  2187. def init_wrapper_code(
  2188. self,
  2189. is_subgraph: bool = False,
  2190. subgraph_name: Optional[str] = None,
  2191. parent_wrapper_code: Optional[PythonWrapperCodegen] = None,
  2192. partition_signatures: Optional[GraphPartitionSignature] = None,
  2193. ) -> None:
  2194. super().init_wrapper_code(
  2195. is_subgraph=True,
  2196. subgraph_name=self.name,
  2197. parent_wrapper_code=self.parent.wrapper_code,
  2198. )