functions.py 92 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493
  1. # mypy: ignore-errors
  2. """
  3. Function-related variable tracking classes for Dynamo's symbolic execution.
  4. This module contains classes that track different types of functions during graph
  5. compilation, including:
  6. - User-defined functions and methods
  7. - Built-in functions and methods
  8. - Wrapped functions (e.g. from decorators)
  9. - Special function types (e.g. functools.partial)
  10. - Triton kernels and related function types
  11. These classes are responsible for:
  12. - Tracking function calls and their arguments
  13. - Managing function closures and cell variables
  14. - Handling function attributes and special methods
  15. - Maintaining guards for function identity and closure contents
  16. - Supporting function inlining and specialization
  17. - Enabling proper symbolic execution of different function types
  18. The variable trackers here work together with the rest of Dynamo to enable
  19. accurate graph capture while handling Python's various function-related behaviors.
  20. """
  21. import builtins
  22. import functools
  23. import inspect
  24. import itertools
  25. import logging
  26. import sys
  27. import traceback
  28. import types
  29. from collections.abc import Sequence
  30. from types import FunctionType
  31. from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
  32. from typing_extensions import Never
  33. from unittest.mock import patch
  34. from weakref import WeakKeyDictionary
  35. import torch
  36. from torch._dynamo.exc import get_stack_above_dynamo
  37. from .. import config, graph_break_hints, polyfills, variables
  38. from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
  39. from ..exc import (
  40. get_dynamo_observed_exception,
  41. handle_observed_exception,
  42. InfiniteGeneratorError,
  43. ObservedException,
  44. ObservedGeneratorExit,
  45. ObservedUserStopIteration,
  46. raise_observed_exception,
  47. SkipFrame,
  48. unimplemented_v2,
  49. Unsupported,
  50. )
  51. from ..guards import GuardBuilder, install_guard
  52. from ..source import (
  53. AttrSource,
  54. ClosureSource,
  55. ConstantSource,
  56. DefaultsSource,
  57. GetItemSource,
  58. SkipGuardSource,
  59. )
  60. from ..utils import (
  61. check_constant_args,
  62. check_unspec_or_constant_args,
  63. cmp_name_to_op_mapping,
  64. counters,
  65. identity,
  66. is_function,
  67. is_wrapper_or_member_descriptor,
  68. istype,
  69. make_cell,
  70. )
  71. from .base import (
  72. AsPythonConstantNotImplementedError,
  73. AttributeMutationNew,
  74. ValueMutationNew,
  75. VariableTracker,
  76. )
  77. from .constant import ConstantVariable
  78. try:
  79. from torch.distributed.fsdp._fully_shard import _fsdp_param_group
  80. except ModuleNotFoundError:
  81. _fsdp_param_group = None
  82. if TYPE_CHECKING:
  83. from torch._dynamo.codegen import PyCodegen
  84. from torch._dynamo.symbolic_convert import InstructionTranslator
  85. from torch._higher_order_ops.triton_kernel_wrap import (
  86. TritonGridType,
  87. TritonKernelType,
  88. )
  89. _F = TypeVar("_F", bound=Callable)
  90. CO_VARARGS = 0x04
  91. CO_VARKEYWORDS = 0x08
  92. # Module‐level cache keyed by the function object
  93. _spec_cache = WeakKeyDictionary()
  94. class FunctionSpec:
  95. def __init__(self, func: FunctionType):
  96. code = func.__code__
  97. vn = code.co_varnames
  98. self.posonly_count = code.co_posonlyargcount
  99. self.arg_count = code.co_argcount
  100. self.kwonly_count = code.co_kwonlyargcount
  101. self.posonly_names = vn[: self.posonly_count]
  102. self.pos_or_kw_names = vn[self.posonly_count : self.arg_count]
  103. self.all_pos_names = self.posonly_names + self.pos_or_kw_names
  104. self.kwonly_names = vn[self.arg_count : self.arg_count + self.kwonly_count]
  105. off = self.arg_count + self.kwonly_count
  106. self.varargs_name = vn[off] if code.co_flags & CO_VARARGS else None
  107. off += 1 if self.varargs_name else 0
  108. self.varkw_name = vn[off] if code.co_flags & CO_VARKEYWORDS else None
  109. def update_defaults(self, func: FunctionType):
  110. # Defaults can change from function call to function call. So re-update
  111. # them on every call.
  112. self.defaults = func.__defaults__ or ()
  113. self.kwdefaults = func.__kwdefaults__ or {}
  114. # Map positional‐default names → their index in self.defaults
  115. self.pos_default_map = dict(
  116. zip(self.all_pos_names[-len(self.defaults) :], range(len(self.defaults)))
  117. )
  118. def _get_spec(func: FunctionType) -> FunctionSpec:
  119. spec = _spec_cache.get(func)
  120. if spec is None:
  121. spec = FunctionSpec(func)
  122. _spec_cache[func] = spec
  123. return spec
  124. def bind_args_cached(func, tx, fn_source, args, kwargs):
  125. spec = _get_spec(func)
  126. spec.update_defaults(func)
  127. ba = {}
  128. rem_kw = dict(kwargs)
  129. # 1) Bind all positional (pos-only + pos-or-kw)
  130. for i, name in enumerate(spec.all_pos_names):
  131. if i < len(args):
  132. ba[name] = wrap_bound_arg(tx, args[i])
  133. elif name in rem_kw:
  134. if name in spec.posonly_names:
  135. raise_observed_exception(
  136. TypeError,
  137. tx,
  138. args=[ConstantVariable.create(f"{name} is positional-only")],
  139. )
  140. ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
  141. elif name in spec.pos_default_map:
  142. idx = spec.pos_default_map[name]
  143. default_source = None
  144. if fn_source and not (
  145. ConstantVariable.is_literal(spec.defaults[idx])
  146. and config.skip_guards_on_constant_func_defaults
  147. ):
  148. default_source = DefaultsSource(fn_source, idx)
  149. ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source)
  150. else:
  151. raise_observed_exception(
  152. TypeError,
  153. tx,
  154. args=[
  155. ConstantVariable.create(
  156. f"Missing required positional argument: {name}"
  157. )
  158. ],
  159. )
  160. # 2) *args
  161. extra = args[len(spec.all_pos_names) :]
  162. if spec.varargs_name:
  163. ba[spec.varargs_name] = wrap_bound_arg(tx, tuple(extra))
  164. elif extra:
  165. raise_observed_exception(
  166. TypeError,
  167. tx,
  168. args=[
  169. ConstantVariable.create(
  170. f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}"
  171. )
  172. ],
  173. )
  174. # 3) Keyword-only
  175. for name in spec.kwonly_names:
  176. if name in rem_kw:
  177. ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
  178. elif name in spec.kwdefaults:
  179. kwdefault_source = None
  180. if fn_source:
  181. kwdefault_source = DefaultsSource(fn_source, name, is_kw=True)
  182. ba[name] = wrap_bound_arg(tx, spec.kwdefaults[name], kwdefault_source)
  183. else:
  184. raise_observed_exception(
  185. TypeError,
  186. tx,
  187. args=[
  188. ConstantVariable.create(
  189. f"Missing required keyword-only argument: {name}"
  190. )
  191. ],
  192. )
  193. # 4) **kwargs
  194. if spec.varkw_name:
  195. ba[spec.varkw_name] = wrap_bound_arg(tx, rem_kw)
  196. elif rem_kw:
  197. raise_observed_exception(
  198. TypeError,
  199. tx,
  200. args=[
  201. ConstantVariable.create(f"Unexpected keyword arguments: {list(rem_kw)}")
  202. ],
  203. )
  204. return ba
  205. def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
  206. # Source propagation is best effort since not every object we encounter has a source to begin with.
  207. if isinstance(val, VariableTracker):
  208. return val
  209. elif not source:
  210. return VariableTracker.build(tx, val)
  211. else:
  212. # Create a lazy variable to avoid guarding on __defaults__ unless really
  213. # needed.
  214. return variables.LazyVariableTracker.create(val, source)
  215. def wrap_args_kwargs(tx: "InstructionTranslator", result):
  216. for k, v in list(result.items()):
  217. if isinstance(v, (tuple, dict)):
  218. # args/kwargs
  219. result[k] = wrap_bound_arg(tx, v)
  220. def init_cellvars(parent, result: dict[str, VariableTracker], code):
  221. """
  222. Update `result` to add mapping from local name to new cells created
  223. directly by `code`, or update SideEffects in `parent` if the a local cell is
  224. already in `result` (cell argument).
  225. """
  226. side_effects = parent.output.side_effects
  227. for name in code.co_cellvars:
  228. new_cell = side_effects.track_cell_new()
  229. if name in result:
  230. # This handles when a function argument is a cell (e.g., captured by
  231. # a nested func). See `MAKE_CELL` bytecode for more info.
  232. side_effects.store_cell(new_cell, result.pop(name))
  233. result[name] = new_cell
  234. def _create_nested_fn(
  235. code, f_globals, name, defaults, closure, kwdefaults, annotations
  236. ):
  237. from types import FunctionType
  238. func = FunctionType(code, f_globals, name, defaults, closure)
  239. func.__kwdefaults__ = kwdefaults
  240. if isinstance(annotations, tuple):
  241. from itertools import pairwise
  242. annotations = dict(pairwise(annotations))
  243. # TypeError: __annotations__ must be set to a dict object
  244. assert annotations is None or isinstance(annotations, dict)
  245. func.__annotations__ = annotations
  246. return func
  247. fn_known_dunder_attrs = {
  248. "__annotations__",
  249. "__defaults__",
  250. "__kwdefaults__",
  251. "__code__",
  252. "__globals__",
  253. "__closure__",
  254. "__doc__",
  255. }
  256. def fn_var_getattr(tx, fn, source, name):
  257. source = source and AttrSource(source, name)
  258. if source and name == "__annotations__":
  259. # We get a large number of silly guards from annotations from inspect
  260. # module. Changing annotations is rare, and it impacting the extracted
  261. # graph is even rarer. So skip guards.
  262. source = SkipGuardSource(source)
  263. try:
  264. subobj = inspect.getattr_static(fn, name)
  265. except AttributeError:
  266. # function does not have a __getattr__ or __getattribute__ method,
  267. # so we can safely assume that this attribute is absent
  268. raise_observed_exception(AttributeError, tx)
  269. # Special handling for known dunder attributes
  270. if name in fn_known_dunder_attrs:
  271. subobj = getattr(fn, name)
  272. if source:
  273. return variables.LazyVariableTracker.create(subobj, source)
  274. return VariableTracker.build(tx, subobj)
  275. class BaseUserFunctionVariable(VariableTracker):
  276. def get_filename(self):
  277. return self.get_code().co_filename
  278. def get_name(self):
  279. return self.get_code().co_name
  280. def call_function(
  281. self,
  282. tx: "InstructionTranslator",
  283. args: "list[VariableTracker]",
  284. kwargs: "dict[str, VariableTracker]",
  285. ) -> "VariableTracker":
  286. return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  287. def call_obj_hasattr(
  288. self, tx: "InstructionTranslator", name: str
  289. ) -> VariableTracker:
  290. result = False
  291. try:
  292. result = hasattr(self.get_function(), name)
  293. except NotImplementedError:
  294. if name == "__name__" and isinstance(self, NestedUserFunctionVariable):
  295. result = True
  296. return variables.ConstantVariable.create(result)
  297. def inspect_parameter_names(self):
  298. return list(inspect.signature(self.get_function()).parameters)
  299. def closure_vars(self, tx):
  300. return {}
  301. class UserFunctionVariable(BaseUserFunctionVariable):
  302. """Some unsupported user-defined global function"""
  303. _nonvar_fields = {
  304. "fn",
  305. "is_constant",
  306. *BaseUserFunctionVariable._nonvar_fields,
  307. }
  308. @classmethod
  309. def create_with_source(cls, value, source):
  310. install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
  311. return cls(value, source=source)
  312. def __init__(self, fn, is_constant=False, **kwargs) -> None:
  313. super().__init__(**kwargs)
  314. if getattr(fn, "_dynamo_marked_constant", False):
  315. # This method should be treated as a constant for the purposes of compilation
  316. self.is_constant = True
  317. else:
  318. self.is_constant = False
  319. # TODO putting this here to avoid duplication, because we could hit this
  320. # from several paths (e.g., SuperVariable or `var_getattr`s).
  321. if not isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)):
  322. unimplemented_v2(
  323. gb_type="can't handle functions not implemented in python ",
  324. context=f"{fn}",
  325. explanation="Dynamo can only handle functions defined in python",
  326. hints=[
  327. "Move usage of this function out of `torch.compile` region",
  328. *graph_break_hints.INFERENCE_MODE,
  329. ],
  330. )
  331. # TODO(anijain2305) - Replace directly calling UserFunctionVariable with
  332. # VariableBuilder, which handles the wrapping of _torchdynamo_inline.
  333. # unpack @torch._dynamo.optimize()(fn) wrapped function
  334. fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
  335. self.fn: types.FunctionType = fn
  336. def as_python_constant(self):
  337. if istype(self, UserFunctionVariable):
  338. return self.fn
  339. # subclasses (such as methods) usually aren't a constant
  340. return super().as_python_constant()
  341. def self_args(self):
  342. return []
  343. def get_function(self):
  344. return self.fn
  345. def get_code(self):
  346. return self.fn.__code__
  347. def python_type(self):
  348. return types.FunctionType
  349. def has_self(self):
  350. return getattr(self.fn, "__self__", None) is not None
  351. def get_globals(self):
  352. return self.fn.__globals__
  353. def get_source(self):
  354. source = self.source
  355. if source and isinstance(self, variables.UserMethodVariable):
  356. source = self.source_fn
  357. return source
  358. def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]:
  359. """
  360. Assume `args` and `kwargs` are VariableTracker arguments for a call to
  361. this function, create new bindings for initial locals.
  362. """
  363. assert not self.is_constant
  364. fn: types.FunctionType = self.fn
  365. if not isinstance(fn, FunctionType):
  366. raise TypeError("Only supports regular Python functions.")
  367. root_tx = parent.output.root_tx
  368. source = self.get_source()
  369. result = bind_args_cached(fn, root_tx, source, args, kwargs)
  370. init_cellvars(parent, result, fn.__code__)
  371. closure = self.fn.__closure__ or ()
  372. assert len(closure) == len(self.fn.__code__.co_freevars)
  373. for idx, name, cell in zip(
  374. itertools.count(), self.fn.__code__.co_freevars, closure
  375. ):
  376. # TODO refactor these 3 branches.
  377. side_effects = parent.output.side_effects
  378. if cell in side_effects:
  379. cell_var = side_effects[cell]
  380. elif source:
  381. closure_cell = GetItemSource(ClosureSource(source), idx)
  382. closure_cell_contents = AttrSource(closure_cell, "cell_contents")
  383. try:
  384. contents_var = VariableTracker.build(
  385. parent, cell.cell_contents, closure_cell_contents
  386. )
  387. except ValueError:
  388. # Cell has not yet been assigned
  389. contents_var = variables.DeletedVariable()
  390. cell_var = side_effects.track_cell_existing(
  391. closure_cell, cell, contents_var
  392. )
  393. else:
  394. # TODO figure out why source isn't available here, and whether
  395. # we can fix that and remove this branch.
  396. try:
  397. contents_var = VariableTracker.build(parent, cell.cell_contents)
  398. except ValueError:
  399. # Cell has not yet been assigned
  400. contents_var = variables.DeletedVariable()
  401. cell_var = side_effects.track_cell_existing(None, cell, contents_var)
  402. result[name] = cell_var
  403. return result
  404. def var_getattr(self, tx: "InstructionTranslator", name: str):
  405. if name in cmp_name_to_op_mapping:
  406. return variables.GetAttrVariable(self, name)
  407. source = self.get_source()
  408. return fn_var_getattr(tx, self.fn, source, name)
  409. def call_obj_hasattr(
  410. self, tx: "InstructionTranslator", name: str
  411. ) -> VariableTracker:
  412. result = hasattr(self.fn, name)
  413. return variables.ConstantVariable.create(result)
  414. def call_function(
  415. self,
  416. tx: "InstructionTranslator",
  417. args: "list[VariableTracker]",
  418. kwargs: "dict[str, VariableTracker]",
  419. ) -> "VariableTracker":
  420. # Handle patch_dynamo_config call
  421. if self.fn is torch._dynamo.patch_dynamo_config:
  422. try:
  423. args_const = [arg.as_python_constant() for arg in args]
  424. kwargs_const = {
  425. key: val.as_python_constant() for key, val in kwargs.items()
  426. }
  427. changes = torch._dynamo.patch_dynamo_config(
  428. *args_const, **kwargs_const
  429. ).changes
  430. return variables.DynamoConfigPatchVariable(changes)
  431. except AsPythonConstantNotImplementedError as e:
  432. raise RuntimeError(
  433. "Cannot convert patch_dynamo_config args/kwargs to constants. "
  434. "Please fix your call to patch_dynamo_config by using simpler inputs. "
  435. f"args: {args}, kwargs: {kwargs}"
  436. ) from e
  437. elif self.fn is torch._dynamo.error_on_graph_break:
  438. try:
  439. bound = inspect.signature(self.fn).bind(*args, **kwargs)
  440. error_on_graph_break = bound.arguments[
  441. "error_on_graph_break"
  442. ].as_python_constant()
  443. assert isinstance(error_on_graph_break, bool)
  444. return variables.ErrorOnGraphBreakVariable(error_on_graph_break)
  445. except Exception as e:
  446. raise RuntimeError(
  447. "Improper error_on_graph_break() call. Please fix your call to error_on_graph_break(). "
  448. f"args: {args}, kwargs: {kwargs}"
  449. ) from e
  450. # Handle a `nonstrict_trace(fn)` call
  451. elif self.fn is torch._dynamo.nonstrict_trace:
  452. bound = inspect.signature(self.fn).bind(*args, **kwargs)
  453. fn_var = bound.args[0]
  454. if not isinstance(fn_var, BaseUserFunctionVariable):
  455. typ = fn_var.python_type()
  456. msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
  457. unimplemented_v2(
  458. gb_type="TypeError from user code",
  459. context=f"call_function({self.value}, {args}, {kwargs})",
  460. explanation=msg,
  461. hints=[
  462. *graph_break_hints.USER_ERROR,
  463. ],
  464. )
  465. if not isinstance(fn_var, UserFunctionVariable):
  466. fn_name = fn_var.get_name()
  467. msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # noqa: B950
  468. unimplemented_v2(
  469. gb_type="Limitation of `nonstrict_trace",
  470. context=f"{self}",
  471. explanation=msg,
  472. hints=[
  473. f"make sure definition of {fn_name} is outside ",
  474. "`torch.compile` region",
  475. ],
  476. )
  477. fn = fn_var.fn
  478. return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True)
  479. if self.is_constant:
  480. return invoke_and_store_as_constant(
  481. tx, self.fn, self.get_name(), args, kwargs
  482. )
  483. if (
  484. not tx.output.current_tracer.unsafe_allow_externally_visible_side_effects
  485. and self.fn
  486. is torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer
  487. ):
  488. with torch._dynamo.side_effects.allow_externally_visible_side_effects_in_subtracer(
  489. tx
  490. ):
  491. return super().call_function(tx, args, kwargs)
  492. if (
  493. tx.output.current_tracer.under_activation_checkpoint
  494. and not tx.output.current_tracer.allow_side_effects_under_checkpoint
  495. ):
  496. try:
  497. from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
  498. except Exception:
  499. FSDPState = None
  500. if FSDPState is not None and self.fn in [
  501. FSDPState._pre_forward,
  502. FSDPState._post_forward,
  503. ]:
  504. with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
  505. return super().call_function(tx, args, kwargs)
  506. return super().call_function(tx, args, kwargs)
  507. class BuiltinMethodVariable(BaseUserFunctionVariable):
  508. def __init__(self, fn, is_constant=False, **kwargs) -> None:
  509. super().__init__(**kwargs)
  510. assert isinstance(fn, types.BuiltinMethodType)
  511. self.fn = fn
  512. @staticmethod
  513. def is_supported_builtin_method(obj):
  514. method_self = obj.__self__
  515. method_name = obj.__name__
  516. # TODO(anijain2305) - Add support for more builtin methods
  517. # Supports tuple.__new__ and frozenset({....}).__contains__
  518. return (method_self is tuple and method_name == "__new__") or (
  519. type(method_self) is frozenset and method_name == "__contains__"
  520. )
  521. def call_function(
  522. self,
  523. tx: "InstructionTranslator",
  524. args: "list[VariableTracker]",
  525. kwargs: "dict[str, VariableTracker]",
  526. ) -> "VariableTracker":
  527. method_self = self.fn.__self__
  528. name = self.fn.__name__
  529. obj_source = self.source and AttrSource(self.source, "__self__")
  530. obj_vt = VariableTracker.build(tx, method_self, obj_source)
  531. return obj_vt.call_method(tx, name, args, kwargs)
  532. class LocalGeneratorObjectVariable(VariableTracker):
  533. def __init__(
  534. self,
  535. code: types.CodeType,
  536. f_globals,
  537. inline_tracer: Optional["InstructionTranslator"],
  538. **kwargs,
  539. ):
  540. super().__init__(**kwargs)
  541. self.code = code
  542. self.f_globals = f_globals
  543. self.inline_tracer = inline_tracer
  544. def get_code(self):
  545. return self.code
  546. def get_filename(self):
  547. return self.get_code().co_filename
  548. def get_name(self):
  549. return self.get_code().co_name
  550. def get_function(self):
  551. raise NotImplementedError
  552. def has_self(self):
  553. return False
  554. def __name__(self):
  555. return self.get_name()
  556. def __str__(self):
  557. return f"{self.__class__.__name__}({self.get_name()})"
  558. __repr__ = __str__
  559. def reconstruct(self, codegen: "PyCodegen"):
  560. from torch._dynamo.side_effects import disallow_side_effects_in_generator
  561. from torch._dynamo.symbolic_convert import (
  562. InstructionTranslator,
  563. save_and_restart_speculation_log,
  564. temporarely_allow_writes_to_output_graph,
  565. )
  566. tx = InstructionTranslator.current_tx()
  567. save = save_and_restart_speculation_log(tx)
  568. disallow = disallow_side_effects_in_generator(tx)
  569. temp = temporarely_allow_writes_to_output_graph(tx)
  570. with save, disallow, temp:
  571. tracer = self._get_inline_tracer(tx)
  572. if not tracer.generator_exhausted:
  573. self.remaining_items = self.force_unpack_var_sequence(tx)
  574. variables.ListIteratorVariable(self.remaining_items).reconstruct(codegen)
  575. def bind_args(self, tx, args, kwargs):
  576. return self.fn.bind_args(tx, args, kwargs)
  577. def get_globals(self):
  578. return self.f_globals
  579. def python_type(self):
  580. return types.GeneratorType
  581. def _get_inline_tracer(self, tx):
  582. from torch._dynamo.symbolic_convert import InliningInstructionTranslator
  583. if self.inline_tracer is None:
  584. self.inline_tracer = InliningInstructionTranslator.build_inline_tracer(
  585. tx, self, [], {}
  586. )
  587. return self.inline_tracer
  588. def next_variable(self, tx):
  589. tracer = self._get_inline_tracer(tx)
  590. if self._is_generator_exhausted():
  591. raise_observed_exception(StopIteration, tx)
  592. try:
  593. # Hierarchically, tx can be seen as the parent of the inline tracer
  594. # created on call_function. Any exception needs to be propagated to tx
  595. # for Dynamo to behave correctly
  596. with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
  597. return tracer.inline_call_()
  598. except ObservedException as e:
  599. tracer.generator_exhausted = True
  600. raise e
  601. except InfiniteGeneratorError:
  602. # test/dynamo/test_misc.py::test_iterator_limit
  603. raise
  604. except Unsupported as e:
  605. torch._dynamo.eval_frame.skip_code(self.get_code())
  606. raise SkipFrame from e
  607. finally:
  608. counters["unimplemented"] |= counters["inline_call"]
  609. def call_obj_hasattr(self, tx, name):
  610. if name in self.python_type().__dict__:
  611. return ConstantVariable.create(True)
  612. return ConstantVariable.create(False)
  613. def has_unpack_var_sequence(self, tx):
  614. return False
  615. def has_force_unpack_var_sequence(self, tx) -> builtins.bool:
  616. return True
  617. def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
  618. result = []
  619. self.force_apply_to_var_sequence(tx, result.append)
  620. return result
  621. def force_apply_to_var_sequence(self, tx, fn) -> None:
  622. while True:
  623. try:
  624. fn(self.next_variable(tx))
  625. except ObservedUserStopIteration:
  626. handle_observed_exception(tx)
  627. break
  628. def _setup_exception(self, tx, exc):
  629. tracer = self._get_inline_tracer(tx)
  630. try:
  631. tracer._raise_exception_variable(exc)
  632. except ObservedException as e:
  633. # if no handler is available (i.e. user code doesn't catch it), the
  634. # exception is raised again.
  635. tracer.exception_handler(e)
  636. def _is_generator_just_started(self):
  637. return self.inline_tracer is None or self.inline_tracer.instruction_pointer == 0
  638. def _is_generator_exhausted(self):
  639. return getattr(self.inline_tracer, "generator_exhausted", False)
  640. def call_method(
  641. self,
  642. tx: "InstructionTranslator",
  643. name: str,
  644. args: "list[VariableTracker]",
  645. kwargs: "dict[str, VariableTracker]",
  646. ) -> "VariableTracker":
  647. if name == "__next__":
  648. return self.next_variable(tx)
  649. elif name == "__iter__":
  650. # iter(gen) returns itself
  651. return self
  652. elif name == "send":
  653. # Sends a value into the generator function. Returns the next value
  654. # yielded by the generator, or raises StopIteration if the generator
  655. # exits without yielding another value
  656. if self._is_generator_just_started() and len(args):
  657. # can't send non-None value to a just-started generator
  658. # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
  659. if not all(
  660. isinstance(arg, ConstantVariable) and arg.value is None
  661. for arg in args
  662. ):
  663. raise_observed_exception(TypeError, tx)
  664. tracer = self._get_inline_tracer(tx)
  665. tracer.push_many(args)
  666. return self.next_variable(tx)
  667. elif name == "close":
  668. # * Raises a GeneratorExit at the point where the generator function was paused.
  669. # * If the generator function catches the exception and returns a
  670. # value, this value is returned from close() - Python 3.13+
  671. # * If the generator function is already closed, or raises GeneratorExit
  672. # (by not catching the exception), close() returns None.
  673. # * If the generator yields a value, a RuntimeError is raised.
  674. # * If the generator raises any other exception, it is propagated to the caller.
  675. # * If the generator has already exited due to an exception or normal
  676. # exit, close() returns None and has no other effect.
  677. # Return None if close is called on a just-started generator
  678. # See test GeneratorCloseCpythonTests::test_close_not_started
  679. tracer = self._get_inline_tracer(tx)
  680. if self._is_generator_just_started() or self._is_generator_exhausted():
  681. tracer.generator_exhausted = True
  682. return variables.ConstantVariable(None)
  683. # Raise GeneratorExit to see if user code catches it. Any other exception
  684. # is propagated to the parent frame.
  685. try:
  686. self._setup_exception(
  687. tx, variables.ExceptionVariable(GeneratorExit, ())
  688. )
  689. # There's an extra block on Python 3.12+ to handle StopIteration
  690. # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397
  691. #
  692. # 1 0 RETURN_GENERATOR
  693. # 2 POP_TOP
  694. # 4 RESUME 0
  695. # 2 6 LOAD_CONST 1 (1)
  696. # 8 YIELD_VALUE 1
  697. # 10 RESUME 1
  698. # 12 POP_TOP
  699. # 14 RETURN_CONST 0 (None)
  700. # >> 16 CALL_INTRINSIC_1 3 (INTRINSIC_STOPITERATION_ERROR)
  701. # 18 RERAISE 1
  702. # ExceptionTable:
  703. # 4 to 14 -> 16 [0] lasti
  704. if (
  705. sys.version_info >= (3, 12)
  706. and tracer.next_instruction.opname == "CALL_INTRINSIC_1"
  707. ):
  708. tracer.generator_exhausted = True
  709. return variables.ConstantVariable(None)
  710. except ObservedGeneratorExit:
  711. # If it doesn't catch, we just return None, as per the text above
  712. tracer.generator_exhausted = True
  713. return variables.ConstantVariable(None)
  714. try:
  715. # Raise RuntimeError if the generator yields any other value
  716. if self.next_variable(tx):
  717. raise_observed_exception(RuntimeError, tx)
  718. except ObservedGeneratorExit:
  719. tracer.generator_exhausted = True
  720. return variables.ConstantVariable(None)
  721. except ObservedUserStopIteration:
  722. # In Python 3.13+, one can capture GeneratorExit and return a value
  723. # See test_generator.py::test_close_capture_GeneratorExit_return
  724. # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26
  725. # https://github.com/python/cpython/pull/104771
  726. assert tracer.symbolic_result is not None
  727. return tracer.symbolic_result
  728. elif name == "throw":
  729. # * Raises an exception at the point where the generator was paused, and
  730. # returns the next value yielded by the generator.
  731. # * If the generator exits without yielding, raise StopIteration
  732. # * If the generator function does not catch the passed-in exception,
  733. # or raises a different exception, then that exception propagates to the caller.
  734. # Setup the exception table and jump target in case of try...finally
  735. tracer = self._get_inline_tracer(tx)
  736. try:
  737. # In Python 3.9, the exception is represented as a triple (typ, val, tb)
  738. # In such cases, we re-raise the exception object given to avoid
  739. # creating a new object, so that IS_OP works.
  740. # See: https://github.com/pytorch/pytorch/pull/146496
  741. self._setup_exception(tx, args[1] if len(args) == 3 else args[0])
  742. except ObservedException: # noqa: TRY203
  743. # propagate the exception back to the parent caller
  744. raise
  745. retval = self.next_variable(tx)
  746. # The exception raised before is still active. We need to check the exception
  747. # table one more time to find the next target. But why? Let’s walk
  748. # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M
  749. #
  750. # z = 0
  751. # def whoo():
  752. # global z
  753. # z = 0
  754. # try:
  755. # yield 1
  756. # except ValueError:
  757. # yield 2
  758. # finally:
  759. # z += 1
  760. # z += 10
  761. #
  762. # gen = whoo()
  763. # next(gen)
  764. # gen.throw(ValueError)
  765. # print('z', z) -> z = 1
  766. #
  767. # ...
  768. # >> 58 PUSH_EXC_INFO
  769. #
  770. # 8 60 LOAD_GLOBAL 2 (ValueError)
  771. # 70 CHECK_EXC_MATCH
  772. # 72 POP_JUMP_IF_FALSE 7 (to 88)
  773. # 74 POP_TOP
  774. #
  775. # 9 76 LOAD_CONST 3 (2)
  776. # 78 YIELD_VALUE 3 <------ ValueError is still active here
  777. # 80 RESUME 1
  778. # 82 POP_TOP
  779. # 84 POP_EXCEPT
  780. # 86 jump_backward 34 (to 20)
  781. # ...
  782. #
  783. # ExceptionTable:
  784. # 4 to 8 -> 124 [0] lasti
  785. # 12 to 18 -> 58 [0]
  786. # 20 to 56 -> 124 [0] lasti
  787. # 58 to 82 -> 90 [1] lasti <------ move to 90
  788. # 84 to 86 -> 96 [0]
  789. # 88 to 88 -> 90 [1] lasti
  790. # 90 to 94 -> 96 [0]
  791. # 96 to 116 -> 118 [1] lasti
  792. # 118 to 122 -> 124 [0] lasti
  793. #
  794. # In this scenario, a generator can yield after `throw()` is called. Even
  795. # after the exception is raised a few lines above, it remains active
  796. # within the `78 YIELD_VALUE` instruction. When the generator resumes
  797. # after the second yield on instruction `80 RESUME`, we cannot simply
  798. # return the control flow to the next instruction. Instead, one must
  799. # check the exception table (or equivalent) to find the next target
  800. # In this case, it says the instruction pointer must be moved to 90.
  801. #
  802. # Without this step, if we let the trace proceed to the next
  803. # instruction, it would follow the control flow where the exception
  804. # raised by `throw()` was handled and swallowed, potentially leading
  805. # to incorrect behavior.
  806. exc_type = type("__InternalThrowException", (Exception,), {})
  807. try:
  808. self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
  809. self.next_variable(tx)
  810. except get_dynamo_observed_exception(exc_type):
  811. # We should get back the exception raised before.
  812. pass
  813. else:
  814. raise_observed_exception(RuntimeError, tracer)
  815. return retval
  816. super().call_method(tx, name, args, kwargs)
  817. class ContextlibContextManagerLocalGeneratorObjectVariable(
  818. LocalGeneratorObjectVariable
  819. ):
  820. """
  821. .. note::
  822. This is only used when the function is annotated with @contextlib.contextmanager
  823. It is a special case of a generator function as we do not allow return a context manager
  824. from a torch.compile function.
  825. """
  826. class LocalGeneratorFunctionVariable(BaseUserFunctionVariable):
  827. """functions that behaves like iterators
  828. .. note::
  829. This is a wrapper around (Nested)UserFunctionVariable
  830. """
  831. def __init__(
  832. self,
  833. vt: VariableTracker,
  834. *,
  835. generator_cls=LocalGeneratorObjectVariable,
  836. **kwargs,
  837. ):
  838. super().__init__(**kwargs)
  839. self.vt = vt
  840. self.generator_cls = generator_cls
  841. def __getattr__(self, name):
  842. if name in self.__class__.__dict__.keys():
  843. return getattr(self, name)
  844. return getattr(self.vt, name)
  845. def _build_inline_tracer(self, tx, args, kwargs):
  846. from torch._dynamo.symbolic_convert import InliningInstructionTranslator
  847. return InliningInstructionTranslator.build_inline_tracer(
  848. tx,
  849. self,
  850. args,
  851. kwargs,
  852. )
  853. def call_function(
  854. self,
  855. tx: "InstructionTranslator",
  856. args: "list[VariableTracker]",
  857. kwargs: "dict[str, VariableTracker]",
  858. ) -> "VariableTracker":
  859. if not is_generator(self.vt.get_code()):
  860. unimplemented_v2(
  861. gb_type="non-generator contextlib.contextmanager",
  862. context=str(self.vt.get_code()),
  863. explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator"
  864. ", i.e. does not use `yield`",
  865. hints=[
  866. "Use `yield` in the function body instead of `return`.",
  867. "Remove the `@contextlib.contextmanager` decorator.",
  868. ],
  869. )
  870. inline_tracer = self._build_inline_tracer(tx, args, kwargs)
  871. code = self.vt.get_code()
  872. f_globals = self.vt.get_globals()
  873. # calling a generator returns a generator object
  874. return self.generator_cls(
  875. code,
  876. f_globals,
  877. inline_tracer,
  878. source=self.source,
  879. )
  880. class FunctionDecoratedByContextlibContextManagerVariable(
  881. LocalGeneratorFunctionVariable
  882. ):
  883. """
  884. .. note::
  885. This is only used when the function is annotated with @contextlib.contextmanager
  886. """
  887. def __init__(self, vt, **kwargs):
  888. super().__init__(
  889. vt,
  890. generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable,
  891. **kwargs,
  892. )
  893. def _build_inline_tracer(self, tx, args, kwargs):
  894. # NOTE: This only exists to not break support for context manager when
  895. # config.enable_faithful_generator_behavior = False and
  896. # config.enable_trace_contextlib = True. In case the former is false,
  897. # Dynamo should still be able to trace through @contextmanager functions
  898. tracer = super()._build_inline_tracer(tx, args, kwargs)
  899. assert isinstance(
  900. tracer,
  901. torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator,
  902. )
  903. tracer.is_generator_from_ctx_manager = True
  904. return tracer
  905. class UserMethodVariable(UserFunctionVariable):
  906. """Some unsupported user-defined method"""
  907. def __init__(self, fn, obj, source_fn=None, **kwargs) -> None:
  908. super().__init__(fn=fn, **kwargs)
  909. self.obj = obj
  910. self.source_fn = source_fn
  911. # Note on source and source_fn
  912. # Be careful with `source` when delegating to UserFunctionVariable
  913. # (base-class) methods. In this __init__, `source` is a *bound method*
  914. # object, but the base class expects the underlying *function* object.
  915. # One way is to simplly use `__func__` to unwrap it.
  916. #
  917. # For recursive dict-tag optimizations, it can be faster to fetch the
  918. # function directly from `cls.__dict__`; that’s why we pass on
  919. # `source_fn`. Whenever it is possible to access the function from
  920. # cls.__dict__, we pass that on to `source_fn`. Because bind_args
  921. # operates on the unbound function, most guards should target
  922. # `source_fn` rather than the original `source`.
  923. if source_fn is None and kwargs.get("source") is not None:
  924. self.source_fn = AttrSource(kwargs.get("source"), "__func__")
  925. def __repr__(self) -> str:
  926. return f"{self.__class__.__name__}({self.fn}, {self.obj})"
  927. def self_args(self):
  928. return [self.obj]
  929. def python_type(self):
  930. return types.MethodType
  931. def call_function(
  932. self,
  933. tx: "InstructionTranslator",
  934. args: "list[VariableTracker]",
  935. kwargs: "dict[str, VariableTracker]",
  936. ) -> "VariableTracker":
  937. # NOTE this is to handle methods annotated by `nonstrict_trace`. Usually
  938. # a `nonstrict_trace`-ed function will be wrapped by
  939. # `VariableTracker.build` and route to `TorchInGraphFunctionVariable`,
  940. # but in the case of method, we manually wrap it with `UserMethodVariable`
  941. # inside `UserDefinedObjectVariable.var_getattr`.
  942. #
  943. # We might be able to simplify this away by canonicalizing the
  944. # function/method wrapping code paths.
  945. from ..trace_rules import is_nonstrict_trace_callable
  946. if is_nonstrict_trace_callable(self.fn):
  947. call_args = [*self.self_args(), *args]
  948. var = variables.TorchInGraphFunctionVariable(
  949. self.fn, nonstrict_traceable=True
  950. )
  951. return var.call_function(tx, call_args, kwargs)
  952. # For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution
  953. # rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method
  954. # since we ensure `forward` of allowed modules can be traced by AOT safely.
  955. # Note this is not only for allowed modules, as user customized modules can extend from
  956. # allowed modules but using parent's `forward` method, which is also covered by this branch.
  957. # If we are tracing the higher order op, we want Dynamo to step inside
  958. # the module call so that Dynamo can see the underlying parameters and
  959. # buffers and raise them as inputs to the graph. The is_root_tracer
  960. # check bypasses the if condition for non-root tracers and directly
  961. # calls the super().call_function at the end, which is basically
  962. # equivalent of inlining the method.
  963. if tx.output.is_root_tracer() and isinstance(
  964. self.obj, variables.NNModuleVariable
  965. ):
  966. module_attr = getattr(self.fn, "__module__", "")
  967. # inline torch.nn.utils.parametrize
  968. if (
  969. module_attr is not None
  970. and module_attr.startswith("torch.nn.")
  971. and module_attr != "torch.nn.utils.parametrize"
  972. or self.is_constant
  973. ):
  974. return self.obj.call_method(
  975. tx, self.fn.__name__, args, kwargs, constant=self.is_constant
  976. )
  977. elif (
  978. _fsdp_param_group is not None
  979. and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state
  980. ):
  981. return variables.TorchCtxManagerClassVariable(self.fn).call_function(
  982. tx, (self.obj, *args), kwargs
  983. )
  984. if self.is_constant:
  985. fn = getattr(self.obj.value, self.fn.__name__)
  986. return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
  987. return super().call_function(tx, args, kwargs)
  988. def inspect_parameter_names(self):
  989. return super().inspect_parameter_names()[1:]
  990. def var_getattr(self, tx: "InstructionTranslator", name: str):
  991. if name == "__self__":
  992. return self.obj
  993. if name == "__func__":
  994. # We might have a better way to access the function object, this
  995. # information is stored in self.source_fn, use that to construct the
  996. # variable tracker.
  997. return VariableTracker.build(tx, self.fn, self.source_fn)
  998. return super().var_getattr(tx, name)
  999. class WrappedUserMethodVariable(UserMethodVariable):
  1000. def __init__(self, wrapped, context, **kwargs) -> None:
  1001. kwargs.pop("fn", None)
  1002. kwargs.pop("obj", None)
  1003. super().__init__(wrapped.fn, wrapped.obj, **kwargs)
  1004. self.wrapped = wrapped
  1005. self.context = context
  1006. def call_function(
  1007. self,
  1008. tx: "InstructionTranslator",
  1009. args: "list[VariableTracker]",
  1010. kwargs: "dict[str, VariableTracker]",
  1011. ) -> "VariableTracker":
  1012. self.context.enter(tx)
  1013. result = super().call_function(tx, args, kwargs)
  1014. self.context.exit(tx)
  1015. return result
  1016. def reconstruct(self, codegen):
  1017. codegen.add_push_null(lambda: codegen(self.context))
  1018. codegen(self.wrapped)
  1019. codegen.extend_output(create_call_function(1, False))
  1020. class WrappedUserFunctionVariable(UserFunctionVariable):
  1021. def __init__(self, wrapped, context, **kwargs) -> None:
  1022. kwargs.pop("fn", None)
  1023. super().__init__(wrapped.fn, **kwargs)
  1024. self.wrapped = wrapped
  1025. self.context = context
  1026. def call_function(
  1027. self,
  1028. tx: "InstructionTranslator",
  1029. args: "list[VariableTracker]",
  1030. kwargs: "dict[str, VariableTracker]",
  1031. ) -> "VariableTracker":
  1032. self.context.enter(tx)
  1033. result = super().call_function(tx, args, kwargs)
  1034. self.context.exit(tx)
  1035. return result
  1036. def reconstruct(self, codegen):
  1037. codegen.add_push_null(lambda: codegen(self.context))
  1038. codegen(self.wrapped)
  1039. codegen.extend_output(create_call_function(1, False))
  1040. def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs):
  1041. def convert(x):
  1042. if isinstance(x, variables.TensorVariable):
  1043. return x.get_real_value()
  1044. return x.as_python_constant()
  1045. args = [convert(x) for x in args]
  1046. kwargs = {k: convert(v) for k, v in kwargs.items()}
  1047. res = fn(*args, **kwargs)
  1048. return tx.output.register_attr_or_module(
  1049. res,
  1050. name,
  1051. source=ConstantSource(name),
  1052. )
  1053. class NestedUserFunctionVariable(BaseUserFunctionVariable):
  1054. _nonvar_fields = {
  1055. "f_globals",
  1056. *BaseUserFunctionVariable._nonvar_fields,
  1057. }
  1058. def __init__(
  1059. self,
  1060. fn_name,
  1061. code,
  1062. f_globals,
  1063. defaults,
  1064. kwdefaults,
  1065. annotations,
  1066. closure,
  1067. # This is present when this function is created by
  1068. # `functools.wrap(wrapped_fn)(this_fn)`.
  1069. wrapped_fn=None,
  1070. **kwargs,
  1071. ) -> None:
  1072. if kwargs.get("mutation_type") is None:
  1073. kwargs.update(mutation_type=AttributeMutationNew())
  1074. super().__init__(**kwargs)
  1075. assert isinstance(fn_name.as_python_constant(), str)
  1076. assert isinstance(code.as_python_constant(), types.CodeType)
  1077. assert isinstance(f_globals, dict)
  1078. self.fn_name = fn_name
  1079. self.code = code
  1080. self.f_globals = f_globals
  1081. self.defaults = defaults
  1082. self.kwdefaults = kwdefaults
  1083. self.annotations = annotations
  1084. self.closure = closure
  1085. self.wrapped_fn: Optional[VariableTracker] = wrapped_fn
  1086. def self_args(self):
  1087. return []
  1088. def get_code(self):
  1089. return self.code.as_python_constant()
  1090. def python_type(self):
  1091. return types.FunctionType
  1092. def get_function(self):
  1093. if self.closure:
  1094. raise NotImplementedError
  1095. func = types.FunctionType(
  1096. self.code.as_python_constant(),
  1097. self.f_globals,
  1098. self.fn_name.as_python_constant(),
  1099. )
  1100. if self.defaults:
  1101. func.__defaults__ = self.defaults.as_python_constant()
  1102. if self.kwdefaults:
  1103. func.__kwdefaults__ = self.kwdefaults.as_python_constant()
  1104. if self.annotations:
  1105. annotations = self.annotations.as_python_constant()
  1106. if isinstance(annotations, tuple):
  1107. from itertools import pairwise
  1108. annotations = dict(pairwise(annotations))
  1109. # TypeError: __annotations__ must be set to a dict object
  1110. assert isinstance(annotations, dict)
  1111. func.__annotations__ = annotations
  1112. return func
  1113. def call_setattr(
  1114. self,
  1115. tx: "InstructionTranslator",
  1116. name_var: VariableTracker,
  1117. val: VariableTracker,
  1118. ):
  1119. tx.output.side_effects.store_attr(self, name_var.value, val)
  1120. return ConstantVariable(None)
  1121. def call_method(self, tx, name, args, kwargs):
  1122. if name == "__setattr__":
  1123. return self.call_setattr(tx, *args)
  1124. return super().call_method(tx, name, args, kwargs)
  1125. def has_closure(self):
  1126. return self.closure is not None
  1127. def const_getattr(self, tx, name):
  1128. if name == "__name__":
  1129. return self.fn_name.as_python_constant()
  1130. return super().const_getattr(tx, name)
  1131. def has_self(self):
  1132. return False
  1133. def get_globals(self):
  1134. return self.f_globals
  1135. def bind_args(self, parent, args, kwargs):
  1136. code = self.get_code()
  1137. func = types.FunctionType(
  1138. code,
  1139. self.f_globals,
  1140. self.fn_name.as_python_constant(),
  1141. tuple(self.defaults.items) if self.defaults else None,
  1142. tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
  1143. )
  1144. if self.kwdefaults:
  1145. func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant()
  1146. bound = inspect.signature(func).bind(*args, **kwargs)
  1147. bound.apply_defaults()
  1148. result = dict(bound.arguments.items())
  1149. wrap_args_kwargs(parent.output.root_tx, result)
  1150. init_cellvars(parent, result, code)
  1151. for idx, name in enumerate(code.co_freevars):
  1152. assert name not in result
  1153. cell = self.closure.items[idx]
  1154. result[name] = cell
  1155. return result
  1156. def reconstruct(self, codegen: "PyCodegen"):
  1157. codegen.add_push_null(
  1158. lambda: codegen.load_import_from(__name__, "_create_nested_fn")
  1159. )
  1160. codegen(self.code)
  1161. codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)])
  1162. codegen(ConstantVariable.create(self.code.value.co_name))
  1163. if self.defaults:
  1164. codegen(self.defaults)
  1165. else:
  1166. codegen.extend_output([codegen.create_load_const(None)])
  1167. if self.closure:
  1168. codegen(self.closure)
  1169. else:
  1170. codegen.extend_output([codegen.create_load_const(None)])
  1171. if self.kwdefaults:
  1172. codegen(self.kwdefaults)
  1173. else:
  1174. codegen.extend_output([codegen.create_load_const(None)])
  1175. if self.annotations:
  1176. try:
  1177. annotations = self.annotations.as_python_constant()
  1178. codegen.extend_output(
  1179. [codegen.create_load_const_unchecked(annotations)]
  1180. )
  1181. except NotImplementedError:
  1182. codegen(self.annotations)
  1183. else:
  1184. codegen.extend_output([codegen.create_load_const(None)])
  1185. codegen.extend_output(create_call_function(7, False))
  1186. if self.wrapped_fn:
  1187. codegen.add_push_null(
  1188. lambda: codegen.load_import_from("functools", "wraps")
  1189. )
  1190. codegen(self.wrapped_fn)
  1191. codegen.extend_output(create_call_function(1, False))
  1192. codegen.extend_output(create_rot_n(2))
  1193. codegen.extend_output(create_call_function(1, True))
  1194. # codegen attributes
  1195. from torch._dynamo.symbolic_convert import InstructionTranslator
  1196. tx = InstructionTranslator.current_tx()
  1197. if tx.output.side_effects.has_pending_mutation(self):
  1198. for name, value in tx.output.side_effects.store_attr_mutations[
  1199. self
  1200. ].items():
  1201. codegen.dup_top()
  1202. codegen(value)
  1203. codegen.extend_output(create_rot_n(2))
  1204. codegen.store_attr(name)
  1205. class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable):
  1206. def __init__(self, wrapped, context, **kwargs) -> None:
  1207. kwargs.pop("fn_name", None)
  1208. kwargs.pop("code", None)
  1209. kwargs.pop("f_globals", None)
  1210. kwargs.pop("defaults", None)
  1211. kwargs.pop("kwdefaults", None)
  1212. kwargs.pop("annotations", None)
  1213. kwargs.pop("closure", None)
  1214. kwargs.pop("wrapped_fn", None)
  1215. super().__init__(
  1216. wrapped.fn_name,
  1217. wrapped.code,
  1218. wrapped.f_globals,
  1219. wrapped.defaults,
  1220. wrapped.kwdefaults,
  1221. wrapped.annotations,
  1222. wrapped.closure,
  1223. wrapped.wrapped_fn,
  1224. )
  1225. self.wrapped = wrapped
  1226. self.context = context
  1227. def call_function(
  1228. self,
  1229. tx: "InstructionTranslator",
  1230. args: "list[VariableTracker]",
  1231. kwargs: "dict[str, VariableTracker]",
  1232. ) -> "VariableTracker":
  1233. self.context.enter(tx)
  1234. result = super().call_function(tx, args, kwargs)
  1235. self.context.exit(tx)
  1236. return result
  1237. def reconstruct(self, codegen):
  1238. codegen.add_push_null(lambda: codegen(self.context))
  1239. codegen(self.wrapped)
  1240. codegen.extend_output(create_call_function(1, False))
  1241. class SkipFunctionVariable(VariableTracker):
  1242. _nonvar_fields = {
  1243. "value",
  1244. "reason",
  1245. *VariableTracker._nonvar_fields,
  1246. }
  1247. def __init__(self, value, reason=None, **kwargs) -> None:
  1248. super().__init__(**kwargs)
  1249. self.value = value
  1250. self.reason = reason
  1251. def as_python_constant(self):
  1252. return self.value
  1253. @classmethod
  1254. def create_with_source(cls, value, source):
  1255. # Use closure match guard (i.e. guard on __code__ object instead of
  1256. # function id) to avoid guarding on nested functions.
  1257. if inspect.getattr_static(value, "_torchdynamo_disable", False):
  1258. # For torch._dynamo.disable function, ensure that the original
  1259. # function is guarded. Otherwise, the else branch will guard on the
  1260. # _dynamo.disable.__code__
  1261. guard_on_source = source
  1262. guard_on_value = value
  1263. while getattr(guard_on_value, "_torchdynamo_orig_callable", False):
  1264. guard_on_value = guard_on_value._torchdynamo_orig_callable
  1265. guard_on_source = AttrSource(
  1266. guard_on_source, "_torchdynamo_orig_callable"
  1267. )
  1268. guard_on_source.make_guard(GuardBuilder.CLOSURE_MATCH)
  1269. elif not is_wrapper_or_member_descriptor(value):
  1270. # These descriptors are not guaranteed to return the same object on
  1271. # attribute lookup. They are unlikely to be changed, so we can skip
  1272. # guarding them.
  1273. install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
  1274. return cls(value, source=source)
  1275. def call_function(
  1276. self,
  1277. tx: "InstructionTranslator",
  1278. args: "list[VariableTracker]",
  1279. kwargs: "dict[str, VariableTracker]",
  1280. ) -> "VariableTracker":
  1281. if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
  1282. msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None)
  1283. unimplemented_v2(
  1284. gb_type="Skip calling `torch.compiler.disable()`d function",
  1285. context=str(self.value),
  1286. explanation=f"Skip calling function `{self.value}` since it was wrapped "
  1287. f"with `torch.compiler.disable` (reason: {msg})",
  1288. hints=[
  1289. "Remove the `torch.compiler.disable` call",
  1290. ],
  1291. )
  1292. elif self.value is torch._dynamo.graph_break:
  1293. graph_break_msg = kwargs.get("msg", None)
  1294. if graph_break_msg:
  1295. graph_break_msg = graph_break_msg.as_python_constant()
  1296. unimplemented_v2(
  1297. gb_type="Call to `torch._dynamo.graph_break()`",
  1298. context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`",
  1299. explanation=f"User-inserted graph break. Message: {graph_break_msg}",
  1300. hints=[
  1301. "Remove the `torch._dynamo.graph_break()` call.",
  1302. ],
  1303. )
  1304. elif self.value is torch._dynamo.skip_frame:
  1305. skip_frame_msg = kwargs.get("msg", None)
  1306. if skip_frame_msg:
  1307. skip_frame_msg = skip_frame_msg.as_python_constant()
  1308. raise SkipFrame(
  1309. f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
  1310. )
  1311. else:
  1312. if config.dont_skip_tracing:
  1313. from .builder import SourcelessBuilder
  1314. # re-build the function, attempting to not skip
  1315. rebuilt_fn = SourcelessBuilder.create(tx, self.value)
  1316. # if we still get SkipFunctionVariable, then we *really* should skip this function
  1317. if not isinstance(rebuilt_fn, SkipFunctionVariable):
  1318. return rebuilt_fn.call_function(tx, args, kwargs)
  1319. qualname = getattr(self.value, "__qualname__", "<unknown qualname>")
  1320. module_or = getattr(self.value, "__module__", None)
  1321. module_name = "<unknown module>" if module_or is None else str(module_or)
  1322. try:
  1323. path = inspect.getfile(self.value)
  1324. explanation = (
  1325. f"Dynamo developers have intentionally marked that the function `{qualname}` "
  1326. f"in file `{path}` should not be traced."
  1327. )
  1328. hints = [
  1329. f"Avoid calling the function `{qualname}`.",
  1330. ]
  1331. # TODO improve trace_rules reasoning to provide better hints.
  1332. # How do we tell that a function/file should NOT be removed from skip files?
  1333. # Do a very basic check for now.
  1334. if "_dynamo" not in path:
  1335. hints += [
  1336. f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{qualname}` "
  1337. "to force tracing into the function. "
  1338. "More graph breaks may occur as a result of attempting to trace into the function.",
  1339. "Please file an issue to PyTorch.",
  1340. ]
  1341. except TypeError:
  1342. known_python_builtin_modules = {"_abc", "_warnings"}
  1343. if module_or in known_python_builtin_modules:
  1344. explanation = (
  1345. f"Dynamo does not know how to trace the Python builtin "
  1346. f"`{module_name}.{qualname}`."
  1347. )
  1348. hints = [
  1349. "If you are attempting to call a logging function (e.g. `_warnings.warn`), "
  1350. "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
  1351. "Please file an issue on GitHub "
  1352. "so the PyTorch team can add support for it. ",
  1353. ]
  1354. elif module_or is not None and module_or.startswith("optree"):
  1355. explanation = f"Dynamo cannot trace optree C/C++ function {module_name}.{qualname}."
  1356. hints = [
  1357. " Consider using torch.utils._pytree - "
  1358. "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
  1359. ]
  1360. # also warn on it because most users won't see the graph break message
  1361. torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
  1362. else:
  1363. explanation = (
  1364. f"Dynamo does not know how to trace the builtin `{module_name}.{qualname}.` "
  1365. f"This function is either a Python builtin (e.g. _warnings.warn) "
  1366. f"or a third-party C/C++ Python extension (perhaps created with pybind)."
  1367. )
  1368. hints = [
  1369. "If it is a Python builtin, please file an issue on GitHub "
  1370. "so the PyTorch team can add support for it and see the next case for a workaround.",
  1371. "If it is a third-party C/C++ Python extension, please "
  1372. "either wrap it into a PyTorch-understood custom operator "
  1373. "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
  1374. "for more details) or, if it is traceable, use "
  1375. "`torch.compiler.allow_in_graph`.",
  1376. ]
  1377. # also warn on it because most users won't see the graph break message
  1378. torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
  1379. if qualname == "allow_in_graph":
  1380. explanation = (
  1381. "Found an allow_in_graph decorator to a function which "
  1382. "is created inside the parent function that is getting "
  1383. "compiled. This is not supported for now."
  1384. )
  1385. hints = []
  1386. reason = self.reason if self.reason else "<missing reason>"
  1387. unimplemented_v2(
  1388. gb_type="Attempted to call function marked as skipped",
  1389. context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}",
  1390. explanation=explanation,
  1391. hints=hints,
  1392. )
  def call_obj_hasattr(self, tx: "InstructionTranslator", name):
      """Answer ``hasattr(<skipped fn>, name)`` at trace time.

      The check runs against the real Python function object held in
      ``self.value`` and the boolean is returned as a constant.
      """
      has_attr = hasattr(self.value, name)
      return variables.ConstantVariable.create(has_attr)
  def var_getattr(self, tx: "InstructionTranslator", name: str):
      """Resolve ``<skipped fn>.<name>``.

      Names found in ``cmp_name_to_op_mapping`` are returned as a deferred
      ``GetAttrVariable`` on this tracker. Every other name is delegated to
      ``fn_var_getattr`` with the underlying function and its source.
      """
      if name not in cmp_name_to_op_mapping:
          return fn_var_getattr(tx, self.value, self.source, name)
      return variables.GetAttrVariable(self, name)
  class WrappedSkipFunctionVariable(SkipFunctionVariable):
      """A skipped function paired with a context object that wraps it.

      Calling it enters ``context``, delegates to
      ``SkipFunctionVariable.call_function`` and then exits ``context``.
      Reconstruction emits ``context(wrapped)``.
      """

      def __init__(self, wrapped, context, **kwargs) -> None:
          # ``value`` and ``reason`` are always taken from ``wrapped``; discard
          # any copies in kwargs so they are not passed to the base twice.
          for duplicated_key in ("value", "reason"):
              kwargs.pop(duplicated_key, None)
          super().__init__(wrapped.value, reason=wrapped.reason, **kwargs)
          self.wrapped = wrapped
          self.context = context

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          ctx = self.context
          ctx.enter(tx)
          # NOTE(review): if the delegated call raises, ``exit`` below is never
          # reached — confirm this is intended.
          outcome = super().call_function(tx, args, kwargs)
          ctx.exit(tx)
          return outcome

      def reconstruct(self, codegen):
          # Bytecode for ``context(wrapped)``: NULL + context, the wrapped
          # function as the single argument, then a 1-arg call.
          codegen.add_push_null(lambda: codegen(self.context))
          codegen(self.wrapped)
          call_instructions = create_call_function(1, False)
          codegen.extend_output(call_instructions)
  class WrapperUserFunctionVariable(VariableTracker):
      """
      Used to represent a wrapper object that contains the actual callable as an
      attribute. For example, torch.jit.script/trace have the original function at
      their _torchdynamo_inline attribute. Similarly, functions with
      __script_if_tracing_wrapper have the original attr at "__original_fn".
      """

      def __init__(self, wrapper_obj, attr_to_trace, **kwargs) -> None:
          super().__init__(**kwargs)
          # The real (untraced) wrapper object and the attribute name holding
          # the callable that should actually be traced.
          self.wrapper_obj = wrapper_obj
          self.attr_to_trace = attr_to_trace

      def var_getattr(self, tx: "InstructionTranslator", name):
          """Build a tracker for ``attr_to_trace``; defer other names to the base."""
          if name == self.attr_to_trace:
              val = getattr(self.wrapper_obj, self.attr_to_trace)
              source = self.source and AttrSource(self.source, name)
              return VariableTracker.build(tx, val, source)

          return super().var_getattr(tx, name)

      def self_args(self):
          """Extra leading positional args for the call (none for plain functions)."""
          return []

      def _maybe_warn_lru_cache(self) -> None:
          """Warn once when the wrapper looks like ``functools.lru_cache`` around non-torch code.

          The wrapper is treated as an lru_cache wrapper when it has a
          ``cache_info`` attribute. If the ``torch._dynamo`` logger is enabled
          for DEBUG, the user stack at the call site is also logged.
          """
          if not hasattr(self.wrapper_obj, "cache_info"):
              return
          target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None)
          module_name = getattr(target_fn, "__module__", "") or ""
          if module_name.split(".", maxsplit=1)[0] == "torch":
              return

          msg = (
              "Dynamo detected a call to a `functools.lru_cache`-wrapped "
              "function. Dynamo ignores the cache wrapper and directly "
              "traces the wrapped function. Silent incorrectness is only "
              "a *potential* risk, not something we have observed. "
              'Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.'
          )
          torch._dynamo.utils.warn_once(msg)

          dynamo_logger = logging.getLogger("torch._dynamo")
          if not dynamo_logger.isEnabledFor(logging.DEBUG):
              return
          user_stack = torch._guards.TracingContext.extract_stack()
          user_stack = get_stack_above_dynamo() + user_stack
          # Guard against an empty stack instead of raising IndexError while
          # merely producing a debug log line.
          if user_stack:
              last_frame = user_stack[-1]
              location = f"{last_frame.filename}:{last_frame.lineno}"
          else:
              location = "<unknown location>"
          user_stack_trace = f"call to a lru_cache wrapped function at: {location}\n"
          user_stack_trace += "".join(traceback.format_list(user_stack))
          dynamo_logger.debug(user_stack_trace)

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          """Trace ``getattr(wrapper, attr_to_trace)(*self_args(), *args, **kwargs)``.

          The call goes through ``polyfills.getattr_and_trace``.
          """
          self._maybe_warn_lru_cache()

          # Unpacking (rather than ``+``) accepts list or tuple ``args``.
          all_args = [*self.self_args(), *args]
          return variables.UserFunctionVariable(
              polyfills.getattr_and_trace
          ).call_function(
              tx,
              [self, variables.ConstantVariable(self.attr_to_trace), *all_args],
              kwargs,
          )
  class WrapperUserMethodVariable(WrapperUserFunctionVariable):
      """Method counterpart of WrapperUserFunctionVariable.

      In addition to the wrapper, it keeps the tracker for the method's bound
      receiver. ``self_args`` returns that receiver, so the inherited
      ``call_function`` places it before the caller's arguments.
      """

      def __init__(self, wrapper_obj, attr_to_trace, self_obj, **kwargs) -> None:
          super().__init__(wrapper_obj, attr_to_trace, **kwargs)
          # Tracker for the receiver object (the method's ``self``).
          self.obj = self_obj

      def self_args(self):
          receiver = self.obj
          return [receiver]
  1485. def _traceable_collective_remaps():
  1486. # We can't rely on importing from distributed, since it's not always built
  1487. if torch.distributed.is_available():
  1488. from torch.distributed._functional_collectives import (
  1489. traceable_collective_remaps,
  1490. )
  1491. return traceable_collective_remaps
  1492. return {}
  def _traceable_collectives_source(tx: "InstructionTranslator", fn):
      """Return an AttrSource for ``fn`` inside ``torch.distributed._functional_collectives``.

      ``fn`` must be one of the replacement functions (the values) of
      ``_traceable_collective_remaps()``.
      """
      assert torch.distributed.is_available(), "Illegal invocation."
      known_replacements = _traceable_collective_remaps().values()
      assert fn in known_replacements

      attr_name = fn.__name__
      module_source = tx.import_source("torch.distributed._functional_collectives")
      return AttrSource(module_source, attr_name)
  class CollectiveFunctionRewriteVariable(UserFunctionVariable):
      """
      Rewrites eligible torch.distributed.* collectives into their 'traceable'
      counterparts from ``_traceable_collective_remaps()``.

      ``can_rewrite`` tells whether a function has such a counterpart, and
      ``create``/``rewrite`` build the replacement tracker. ``call_function``
      inspects the call-time arguments: ``async_op=True`` graph-breaks, and
      reduce ops are translated to their string form. Anything else is
      forwarded to the replacement. Arguments of the original function are
      assumed to map 1:1 onto those of the replacement; that is the contract
      for entries in ``traceable_collective_remaps``.
      """

      def __init__(self, fn, *, replacement_var, **kwargs) -> None:
          super().__init__(fn, **kwargs)
          assert isinstance(replacement_var, UserFunctionVariable)
          self.replacement_var = replacement_var

      @staticmethod
      def create(tx: "InstructionTranslator", old_fn, source, **options):
          new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
          replacement = UserFunctionVariable(new_fn, source=new_source, **options)
          return CollectiveFunctionRewriteVariable(
              old_fn,
              replacement_var=replacement,
              source=source,
              **options,
          )

      @staticmethod
      def can_rewrite(variable):
          if not inspect.isfunction(variable):
              return False
          return variable in _traceable_collective_remaps()

      @staticmethod
      def rewrite(tx: "InstructionTranslator", fn):
          replacement_fn = _traceable_collective_remaps()[fn]
          replacement_source = _traceable_collectives_source(tx, replacement_fn)
          return replacement_fn, replacement_source

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          import torch.distributed as dist
          from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

          # Fold positional arguments into keywords so every argument can be
          # looked up by name; the replacement is then called keyword-only.
          signature = inspect.signature(self.fn)
          kwargs = dict(signature.bind(*args, **kwargs).arguments)
          args = ()

          if "async_op" in kwargs:
              if kwargs["async_op"].as_python_constant():
                  unimplemented_v2(
                      gb_type="async_op=True for distributed collectives",
                      context=f"{self.fn}, {args=}, {kwargs=}",
                      explanation=f"`torch.compile` doesn't support `async_op=True for {self.fn}",
                      hints=[
                          *graph_break_hints.SUPPORTABLE,
                      ],
                  )

          reducing_collectives = (
              dist.all_reduce,
              dist.reduce_scatter_tensor,
              dist._reduce_scatter_base,
          )
          if self.fn in reducing_collectives:
              op_var = kwargs.get("op")
              if op_var is None:
                  reduce_op = signature.parameters["op"].default
              else:
                  reduce_op = op_var.value
              # Unknown ops raise rather than graph-break.
              if reduce_op not in REDUCE_OP_TO_STR:
                  raise ValueError(f"Unsupported all_reduce op: {reduce_op}")
              kwargs["op"] = variables.ConstantVariable.create(
                  REDUCE_OP_TO_STR[reduce_op]
              )

          return self.replacement_var.call_function(tx, args, kwargs)
  class FunctoolsWrapsVariable(UserFunctionVariable):
      """Special handling for ``functools.wraps(wrapped)``.

      Only the single-positional-argument form is modelled. It returns a
      decorator that clones a nested (in-region) function with
      ``wrapped_fn`` set. Any other call shape is traced as an ordinary user
      function.
      """

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          if kwargs or len(args) != 1:
              return super().call_function(tx, args, kwargs)

          def wraps(fn):
              if isinstance(fn, variables.NestedUserFunctionVariable):
                  return fn.clone(wrapped_fn=args[0])
              # Functions defined outside the compiled region cannot be cloned.
              unimplemented_v2(
                  gb_type="functools.wraps",
                  context=f"{fn}",
                  explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region",
                  hints=[
                      *graph_break_hints.SUPPORTABLE,
                  ],
              )

          return variables.LambdaVariable(wraps)
  class CollectionsNamedTupleFunction(UserFunctionVariable):
      """Models ``collections.namedtuple(...)`` calls with constant arguments.

      The real factory runs at trace time. The resulting class is tracked as
      a newly created user-defined class. A ``TypeError`` from the factory is
      re-raised as an observed exception inside the traced program.
      """

      def as_python_constant(self):
          return self.fn

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          if not check_constant_args(args, kwargs):
              unimplemented_v2(
                  gb_type="namedtuple construction",
                  context=f"{args=}, {kwargs=}",
                  explanation="`torch.compile` only support certain input types for namedtuple",
                  hints=[
                      *graph_break_hints.SUPPORTABLE,
                  ],
              )

          try:
              created_cls = self.fn(
                  *[arg.as_python_constant() for arg in args],
                  **{name: val.as_python_constant() for name, val in kwargs.items()},
              )
          except TypeError as err:
              raise_observed_exception(
                  type(err),
                  tx,
                  args=list(map(ConstantVariable.create, err.args)),
              )
          return variables.UserDefinedClassVariable(
              created_cls, mutation_type=ValueMutationNew()
          )
  class FunctoolsPartialVariable(VariableTracker):
      """Tracks a ``functools.partial(func, *args, **keywords)`` object.

      ``func`` is the tracker of the wrapped callable. ``args`` is a list of
      positional-argument trackers, and ``keywords`` is a dict from keyword
      name to argument tracker.
      """

      def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None:
          super().__init__(**kwargs)
          self.func = func
          assert isinstance(args, list)
          self.args = args
          assert isinstance(keywords, dict)
          self.keywords = keywords
          # fake_value is used for id calculation. Creating this value and id'ng
          # on it is sufficient for the tracing purposes.
          self.fake_value = functools.partial(identity)

      def python_type(self):
          return functools.partial

      def reconstruct(self, codegen: "PyCodegen"):
          """Emit bytecode that rebuilds ``functools.partial(func, *args, **keywords)``."""
          codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial"))
          codegen(self.func)
          if self.args:
              codegen.foreach(self.args)
          if not self.keywords:
              # +1 accounts for ``func`` itself.
              codegen.extend_output(create_call_function(len(self.args) + 1, False))
              return

          codegen.foreach(self.keywords.values())
          keys = tuple(self.keywords.keys())
          codegen.extend_output(
              codegen.create_call_function_kw(len(keys) + len(self.args) + 1, keys, False)
          )

      def get_function(self):
          return self.as_python_constant()

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          """Call ``func`` with stored args first; call-time keywords override stored ones."""
          # Unpacking (rather than ``+``) accepts call-time args as list or tuple.
          merged_args = [*self.args, *args]
          merged_kwargs = {**self.keywords, **kwargs}
          return self.func.call_function(tx, merged_args, merged_kwargs)

      def call_obj_hasattr(
          self, tx: "InstructionTranslator", name: str
      ) -> VariableTracker:
          # functools.partial uses slots, so attributes are constant
          return variables.ConstantVariable.create(
              hasattr(functools.partial(identity), name)
          )

      def var_getattr(self, tx: "InstructionTranslator", name: str):
          """Model the ``func`` / ``args`` / ``keywords`` attributes of ``partial``.

          Unknown names raise an observed ``AttributeError``.
          """
          source = self.source and AttrSource(self.source, name)
          # Handle __slots__
          if name == "func":
              return self.func
          if name == "args":
              # ``partial.args`` is a tuple at runtime, so model it as a tuple.
              # Copy the items so the returned tracker does not alias
              # ``self.args``.
              return variables.TupleVariable(list(self.args), source=source)
          if name == "keywords":
              items = {ConstantVariable.create(k): v for k, v in self.keywords.items()}
              return variables.ConstDictVariable(items, source=source)
          if name in cmp_name_to_op_mapping:
              return variables.GetAttrVariable(self, name)
          raise_observed_exception(AttributeError, tx)

      def as_python_constant(self):
          return functools.partial(
              self.func.as_python_constant(),
              *[arg.as_python_constant() for arg in self.args],
              **{k: v.as_python_constant() for k, v in self.keywords.items()},
          )

      def guard_as_python_constant(self):
          """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
          return functools.partial(
              self.func.guard_as_python_constant(),
              *[v.guard_as_python_constant() for v in self.args],
              **{k: v.guard_as_python_constant() for k, v in self.keywords.items()},
          )
  class PolyfilledFunctionVariable(VariableTracker):
      """Tracks a function that Dynamo traces through a Python polyfill.

      Attributes:
        fn: the original function being called.
        wrapped_fn: the handler registered for ``fn`` in
          ``_get_polyfill_handlers()``, or ``fn`` itself when none is registered.
        traceable_fn: the pure-Python implementation that is actually traced,
          taken from the handler's ``__torch_dynamo_polyfill__`` or
          ``__python_implementation__`` attribute.
      """

      # These fields hold plain Python callables (see __init__), not
      # VariableTrackers.
      _nonvar_fields = {
          "fn",
          "wrapped_fn",
          "traceable_fn",
          *VariableTracker._nonvar_fields,
      }

      @classmethod
      @functools.cache
      def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]:
          # functools.cache makes this return the same dict object on every call,
          # so it acts as a registry. NOTE(review): presumably populated by a
          # polyfill registration helper elsewhere — confirm.
          return {}

      @classmethod
      def create_with_source(cls, value, source):
          """Install a FUNCTION_MATCH guard on ``source``, then wrap ``value``."""
          install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))

          return cls(value, source=source)

      def __init__(self, fn: _F, **kwargs) -> None:
          super().__init__(**kwargs)
          self.fn: _F = fn

          # Fall back to ``fn`` itself when no handler is registered for it.
          handler = self._get_polyfill_handlers().get(fn, fn)
          assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}"
          # The first truthy attribute in this order supplies the traceable
          # implementation.
          for candidate_attr in (
              "__torch_dynamo_polyfill__",  # registered polyfill
              "__python_implementation__",  # self handler from third-party libraries
          ):
              candidate = getattr(handler, candidate_attr, None)
              if candidate:
                  assert callable(candidate)
                  traceable_fn = candidate
                  break
          else:
              # Neither attribute was present/truthy on the handler.
              raise RuntimeError(
                  f"Polyfill handler {handler} does not have a traceable function"
              )

          self.wrapped_fn: _F = handler
          self.traceable_fn: _F = traceable_fn

      @property
      def polyfill_fn(self) -> _F:
          """Alias for ``traceable_fn``."""
          return self.traceable_fn

      def can_constant_fold_through(self):
          """Return the handler's ``__torch_dynamo_can_constant_fold_through__`` flag (default False)."""
          return getattr(
              self.wrapped_fn, "__torch_dynamo_can_constant_fold_through__", False
          )

      def get_function(self):
          return self.as_python_constant()

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          """Trace a call, trying cheaper paths first.

          1. If the handler allows constant folding and all arguments pass
             ``check_unspec_or_constant_args``, run ``fn`` eagerly and wrap the
             result.
          2. ``sum`` over a list/tuple whose items are all int constants or
             int SymNodes becomes a single ``torch.sym_sum`` graph node.
          3. Otherwise trace ``traceable_fn`` with the same arguments.
          """
          if self.can_constant_fold_through() and check_unspec_or_constant_args(
              args, kwargs
          ):
              result = (
                  self.fn(  # use the original function which is faster than the polyfill
                      *[x.as_python_constant() for x in args],
                      **{k: v.as_python_constant() for k, v in kwargs.items()},
                  )
              )
              return VariableTracker.build(tx, result)

          # Special case for sum on tuple/list of ints
          if (
              self.fn is builtins.sum
              and len(args) == 1
              and not kwargs
              and isinstance(args[0], (variables.ListVariable, variables.TupleVariable))
              and all(
                  (isinstance(x, variables.ConstantVariable) and isinstance(x.value, int))
                  or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int)
                  for x in args[0].items
              )
          ):
              # Emit one sym_sum node and compute its example value from the
              # constant values / symbolic numbers of the items.
              return variables.SymNodeVariable.create(
                  tx,
                  tx.output.create_proxy(
                      "call_function",
                      torch.sym_sum,
                      (tuple(a.as_proxy() for a in args[0].items),),
                      {},
                  ),
                  sym_num=torch.sym_sum(
                      [
                          (
                              x.value
                              if isinstance(x, variables.ConstantVariable)
                              else x.sym_num
                          )
                          for x in args[0].items
                      ]
                  ),
              )

          traceable_function_variable = VariableTracker.build(tx, self.traceable_fn)
          return traceable_function_variable.call_function(tx, args, kwargs)

      def call_method(
          self,
          tx,
          name,
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          """Handle ``fn.__call__`` directly; polyfill any other function attribute of ``fn``."""
          if name == "__call__":
              return self.call_function(tx, args, kwargs)

          method = getattr(self.fn, name, None)
          assert method is not None, f"Member {name} not found in {self.fn}"
          assert is_function(method), f"Member {name} is not callable in {self.fn}"
          options = {}
          if self.source:
              options["source"] = AttrSource(self.source, name)
          polyfilled_method_variable = PolyfilledFunctionVariable(method, **options)
          return polyfilled_method_variable.call_function(tx, args, kwargs)

      def as_python_constant(self):
          return self.fn
  class TracebackVariable(VariableTracker):
      """Stand-in for functions of the ``traceback`` module.

      Tracebacks are not tracked, so a call to any function in this module is
      modelled as a no-op.
      """

      def call_function(self, tx, args, kwargs):
          # The original ``...`` body returned a bare Python ``None``, which the
          # translator would push onto its symbolic stack as-is. Return a real
          # tracker for ``None`` instead, so the result can be popped, stored
          # or reconstructed like any other value.
          return variables.ConstantVariable.create(None)
  class SysFunctionVariable(VariableTracker):
      """Models calls to ``sys.exc_info`` and ``sys.exception``.

      Both answers are derived from the innermost entry of ``tx.exn_vt_stack``.
      The traceback slot is always ``None``.
      """

      def __init__(self, value, **kwargs):
          super().__init__(**kwargs)
          # Either ``sys.exc_info`` or ``sys.exception`` (see call_function).
          self.value = value

      def exc_info(self, tx):
          """Return a TupleVariable ``(type, exception, None)``, or three ``None`` constants when idle."""
          if not len(tx.exn_vt_stack):
              return variables.TupleVariable(
                  [variables.ConstantVariable(None) for _ in range(3)]
              )
          active_exc = tx.exn_vt_stack[-1]
          exc_type_vt = VariableTracker.build(tx, active_exc.exc_type)
          traceback_vt = VariableTracker.build(tx, None)
          return variables.TupleVariable([exc_type_vt, active_exc, traceback_vt])

      def exception(self, tx):
          """Return the middle element of ``exc_info`` (the active exception or None)."""
          info = self.exc_info(tx)
          return info.items[1]

      def call_function(self, tx, args, kwargs):
          if self.value is not sys.exc_info:
              assert self.value is sys.exception
              return self.exception(tx)
          return self.exc_info(tx)
  1838. from torch._higher_order_ops.triton_kernel_wrap import (
  1839. create_tma_experimental_metadata,
  1840. create_tma_stable_metadata,
  1841. TMADescriptorMetadata,
  1842. TritonHOPifier,
  1843. )
  class DynamoTritonHOPifier(TritonHOPifier):
      """Dynamo-side implementation of the TritonHOPifier hooks.

      It works on VariableTrackers, and ``call_HOP`` records a
      ``triton_kernel_wrapper_mutation`` node in the output graph.
      """

      def raise_unsupported(self, msg: str) -> Never:
          """Abort tracing of the construct with an ``Unsupported`` error."""
          raise Unsupported(msg)

      def is_callable(self, maybe_callable: Any) -> bool:
          """Only user-function trackers (nested or top-level) count as callable here."""
          return isinstance(
              maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable)
          )

      def get_value(self, val: Any) -> Any:
          return val.value

      def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]:
          """Return the proxy for a list-like grid; graph-break on any other grid type."""
          from .lists import BaseListVariable

          if isinstance(grid, BaseListVariable):
              return grid.as_proxy()
          else:
              unimplemented_v2(
                  gb_type="unsupported grid type for triton hop check_grid",
                  context=f"grid type = {type(grid)}",
                  explanation="`torch.compile` only supports list-like grid for check_grid",
                  hints=[
                      *graph_break_hints.SUPPORTABLE,
                  ],
              )

      def call_grid(self, grid, meta, tx):
          """Call a grid callable with ``meta`` re-keyed by ConstantVariable names."""
          meta = {variables.ConstantVariable.create(k): v for k, v in meta.items()}
          grid = grid.call_function(tx, [meta], {})
          return grid

      # We use this function to wrap call_prune_configs
      def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable):
          """Build a sourceless tracker for ``user_fn`` and trace a call to it."""
          from .builder import SourcelessBuilder

          wrapped_user_function = SourcelessBuilder.create(tx, user_fn)
          result = wrapped_user_function.call_function(tx, args, kwargs)
          return result

      def wrap_user_defined_obj(self, user_obj, tx, variable, name):
          """Wrap ``user_obj`` with source ``<variable.kernel_source>.<name>``."""
          from .builder import VariableBuilder

          wrapped_user_obj = VariableBuilder(
              tx, AttrSource(variable.kernel_source, f"{name}")
          )._wrap(user_obj)
          return wrapped_user_obj

      def maybe_unpack_configs(self, configs, tx):
          """Unpack a traced configs sequence into guarded Python constants."""
          # unpack the list of configs
          configs = configs.unpack_var_sequence(tx)

          # guard_as_python_constant inserts guards for Dynamo to check if the configs object changed.
          configs = [config.guard_as_python_constant() for config in configs]

          return configs

      def maybe_unpack_heuristic_result(self, result: Any) -> Any:
          """Return a guarded constant for a heuristic's result; non-constants are unsupported."""
          if not result.is_python_constant():
              self.raise_unsupported(
                  "@triton.heuristics must return constant values because configs can only contain constant values."
              )

          return result.guard_as_python_constant()

      # We need to override call_getitem here so that we can add the source in the case
      # where we call the triton kernel with a grid
      def call_getitem(
          self,
          variable: "TritonKernelVariable",
          args: Sequence[Any],
      ) -> "TritonKernelVariable":
          """Handle ``kernel[grid]`` by returning a new tracker that carries the grid."""
          # __getitem__ should only be called if we don't already have a grid
          # Only grid needs to be passed
          if variable.grid is not None or len(args) != 1:
              self.raise_unsupported(
                  "Triton kernels should be called with only a single grid"
              )

          # The original kernel tracker's source becomes ``kernel_source`` of
          # the new one.
          return type(variable)(
              kernel=variable.kernel,
              kernel_idx=variable.kernel_idx,
              grid=args[0],
              kernel_source=variable.source,
          )

      def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable:
          """Record a ``triton_kernel_wrapper_mutation`` node for the kernel call.

          Side effect: TMA descriptor entries in ``combined_args_raw`` are
          replaced in place by their underlying tensors. Constant arguments
          go to ``kernel_side_table``; the remaining (tensor/SymNode) args
          are passed as a dict proxy. Returns a ``None`` constant.
          """
          from .constant import ConstantVariable
          from .dicts import ConstDictVariable

          # as we can only pass tensors as non-const args in fx graph,
          # here we replace TMA descriptors
          # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable
          # instances) with the underlying tensors, while moving the
          # TMA descriptor-related metadata to a separate argument,
          # so that we can reconstruct the TMA descriptors downstream
          tma_descriptor_metadata: TMADescriptorMetadata = {}
          # Iterate over a snapshot of the keys because entries are reassigned.
          for k in list(combined_args_raw.keys()):
              v = combined_args_raw[k]
              if isinstance(
                  v, (TMADescriptorExperimentalVariable, TMADescriptorStableVariable)
              ):
                  tma_descriptor_metadata[k] = v.to_metadata()
                  combined_args_raw[k] = v.get_tensor()

          # Same arguments, keyed by ConstantVariable names (for ConstDictVariable).
          combined_args = {
              variables.ConstantVariable.create(k): v
              for k, v in combined_args_raw.items()
          }

          from torch._higher_order_ops.triton_kernel_wrap import (
              kernel_side_table,
              triton_kernel_wrapper_mutation,
          )

          # Combine args and kwargs and pass as a dict so that if user defined triton
          # kernel uses variables as 'grid' or 'kernel', it does not conflict with
          # parameters of the wrapper function
          # Constant args are keyed by the raw names; non-constant args by the
          # ConstantVariable names built above.
          constant_args = {
              k: v.as_python_constant()
              for k, v in combined_args_raw.items()
              if isinstance(v, ConstantVariable)
          }
          non_constant_args = {
              k: v
              for k, v in combined_args.items()
              if not isinstance(v, ConstantVariable)
          }

          # Every non-constant argument must realize to a tensor or SymNode.
          for v in non_constant_args.values():
              v = v.realize()
              if not isinstance(v, (variables.TensorVariable, variables.SymNodeVariable)):
                  self.raise_unsupported(
                      f"Unexpected argument type for a Triton kernel: {repr(v)}."
                  )

          constant_args_idx = kernel_side_table.add_constant_args(constant_args)
          meta = ConstDictVariable(non_constant_args, dict)
          tx.output.create_proxy(
              "call_function",
              triton_kernel_wrapper_mutation,
              (),
              {
                  "kernel_idx": variable.kernel_idx,
                  "constant_args_idx": constant_args_idx,
                  "grid": grids,
                  "tma_descriptor_metadata": tma_descriptor_metadata,
                  "kwargs": meta.as_proxy(),
              },
          )

          return variables.ConstantVariable(
              None,
          )
  # Module-level DynamoTritonHOPifier instance used by TritonKernelVariable's methods.
  dynamo_triton_hopifier_singleton = DynamoTritonHOPifier()
  class TritonKernelVariable(VariableTracker):
      """Tracks a user-defined Triton kernel, optionally already bound to a grid.

      Calls, ``kernel[grid]`` and ``kernel.run(...)`` are forwarded to
      ``dynamo_triton_hopifier_singleton``.
      """

      grid: "TritonGridType"
      kernel: "TritonKernelType"
      kernel_idx: Optional[int]
      kernel_source: "AttrSource"

      def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None:
          # ``kernel_source`` is stripped before the remaining kwargs reach the
          # base class.
          self.kernel_source = kwargs.pop("kernel_source", None)
          super().__init__(**kwargs)
          dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          hopifier = dynamo_triton_hopifier_singleton
          return hopifier.call_triton_kernel(self, args, kwargs, tx)

      def call_method(
          self,
          tx,
          name,
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          hopifier = dynamo_triton_hopifier_singleton
          if name == "__getitem__":
              return hopifier.call_getitem(self, args)
          if name == "run":
              return hopifier.call_run(self, args, kwargs, tx)
          # Anything else falls back to the generic implementation.
          return super().call_method(tx, name, args, kwargs)

      def specialize_symbolic(self, arg: Any) -> Any:
          from .constant import ConstantVariable
          from .tensor import SymNodeVariable

          # See [Note: Specialize tl.constexpr args in user-defined triton kernels]
          if not isinstance(arg, SymNodeVariable):
              return arg
          return ConstantVariable.create(arg.evaluate_expr())
  class TMADescriptorExperimentalVariable(VariableTracker):
      """Tracks a TMA descriptor made by ``triton.tools.experimental_descriptor``.

      Holds the ``data_ptr()`` tracker of the source tensor plus the dims,
      block dims and element size passed to ``create_{1,2}d_tma_descriptor``.
      """

      def __init__(
          self,
          data_ptr: "variables.DataPtrVariable",
          dims: "list[ConstantVariable]",
          block_dims: "list[ConstantVariable]",
          element_size: "ConstantVariable",
          **kwargs,
      ):
          assert isinstance(data_ptr, variables.DataPtrVariable)
          super().__init__(**kwargs)
          self.data_ptr = data_ptr
          self.dims = dims
          self.block_dims = block_dims
          self.element_size = element_size

      def to_metadata(self):
          dim_proxies = [d.as_proxy() for d in self.dims]
          block_dim_proxies = [d.as_proxy() for d in self.block_dims]
          element_size_proxy = self.element_size.as_proxy()
          return create_tma_experimental_metadata(
              dim_proxies, block_dim_proxies, element_size_proxy
          )

      def reconstruct(self, codegen: "PyCodegen"):
          # Rebuild as create_<rank>d_tma_descriptor(ptr, *dims, *block_dims, element_size).
          factory_name = f"create_{len(self.dims)}d_tma_descriptor"
          codegen.add_push_null(
              lambda: codegen.load_import_from(
                  "triton.tools.experimental_descriptor",
                  factory_name,
              )
          )
          self.data_ptr.reconstruct(codegen)
          trailing_args = [*self.dims, *self.block_dims, self.element_size]
          codegen.foreach(trailing_args)
          # +1 for the data pointer pushed before the trailing args.
          codegen.call_function(len(trailing_args) + 1, False)

      def get_tensor(self):
          return self.data_ptr.from_tensor
  class TMADescriptorStableVariable(VariableTracker):
      """Tracks ``TensorDescriptor.from_tensor(tensor, block_shape)``.

      ``TensorDescriptor`` comes from ``triton.tools.tensor_descriptor``.
      """

      def __init__(
          self,
          tensor: "variables.TensorVariable",
          block_shape: "variables.ListVariable",
          **kwargs,
      ):
          assert isinstance(tensor, variables.TensorVariable)
          super().__init__(**kwargs)
          self.tensor = tensor
          self.block_shape = block_shape

      def to_metadata(self):
          block_shape_proxy = self.block_shape.as_proxy()
          return create_tma_stable_metadata(block_shape_proxy)

      def reconstruct(self, codegen: "PyCodegen"):
          def load_descriptor_cls():
              codegen.load_import_from(
                  "triton.tools.tensor_descriptor",
                  "TensorDescriptor",
              )

          codegen.add_push_null(load_descriptor_cls)
          codegen.load_method("from_tensor")
          self.tensor.reconstruct(codegen)
          codegen(self.block_shape)
          codegen.call_method(2)

      def get_tensor(self) -> "variables.TensorVariable":
          return self.tensor
  class CreateTMADescriptorExperimentalVariable(VariableTracker):
      """Models ``create_{rank}d_tma_descriptor(...)`` for rank 1 or 2.

      Each argument may be given positionally or by keyword. The result is a
      ``TMADescriptorExperimentalVariable``.
      """

      def __init__(
          self,
          rank: int,
          **kwargs,
      ) -> None:
          assert rank in (1, 2)
          super().__init__(**kwargs)
          self.rank = rank

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          def pick(arg_name, position):
              # Keyword form wins; otherwise fall back to the positional slot.
              return kwargs[arg_name] if arg_name in kwargs else args[position]

          ptr = pick("ptr", 0)
          if not isinstance(ptr, variables.DataPtrVariable):
              raise Unsupported(
                  "Please ensure there were no graph breaks between "
                  f"create_{self.rank}d_tma_descriptor and the upstream "
                  ".data_ptr() call."
              )

          if self.rank == 1:
              assert len(args) + len(kwargs) == 4
              dims = [pick("dim", 1)]
              block_dims = [pick("block_dim", 2)]
          else:
              assert len(args) + len(kwargs) == 6
              dims = [pick("dim1", 1), pick("dim0", 2)]
              block_dims = [pick("block_dim1", 3), pick("block_dim0", 4)]
          element_size = pick("element_size", -1)

          return TMADescriptorExperimentalVariable(
              data_ptr=ptr,
              dims=dims,
              block_dims=block_dims,
              element_size=element_size,
          )
  class CreateTMADescriptorStableVariable(VariableTracker):
      """Call target producing a ``TMADescriptorStableVariable``.

      It takes ``tensor`` and ``block_shape``, each positionally or by keyword.
      """

      def call_function(
          self,
          tx: "InstructionTranslator",
          args: "list[VariableTracker]",
          kwargs: "dict[str, VariableTracker]",
      ) -> "VariableTracker":
          def pick(arg_name, position):
              return kwargs[arg_name] if arg_name in kwargs else args[position]

          tensor_vt = pick("tensor", 0)
          block_shape_vt = pick("block_shape", 1)
          return TMADescriptorStableVariable(
              tensor=tensor_vt,
              block_shape=block_shape_vt,
          )