| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493 |
- # mypy: ignore-errors
- """
- Function-related variable tracking classes for Dynamo's symbolic execution.
- This module contains classes that track different types of functions during graph
- compilation, including:
- - User-defined functions and methods
- - Built-in functions and methods
- - Wrapped functions (e.g. from decorators)
- - Special function types (e.g. functools.partial)
- - Triton kernels and related function types
- These classes are responsible for:
- - Tracking function calls and their arguments
- - Managing function closures and cell variables
- - Handling function attributes and special methods
- - Maintaining guards for function identity and closure contents
- - Supporting function inlining and specialization
- - Enabling proper symbolic execution of different function types
- The variable trackers here work together with the rest of Dynamo to enable
- accurate graph capture while handling Python's various function-related behaviors.
- """
- import builtins
- import functools
- import inspect
- import itertools
- import logging
- import sys
- import traceback
- import types
- from collections.abc import Sequence
- from types import FunctionType
- from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
- from typing_extensions import Never
- from unittest.mock import patch
- from weakref import WeakKeyDictionary
- import torch
- from torch._dynamo.exc import get_stack_above_dynamo
- from .. import config, graph_break_hints, polyfills, variables
- from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
- from ..exc import (
- get_dynamo_observed_exception,
- handle_observed_exception,
- InfiniteGeneratorError,
- ObservedException,
- ObservedGeneratorExit,
- ObservedUserStopIteration,
- raise_observed_exception,
- SkipFrame,
- unimplemented_v2,
- Unsupported,
- )
- from ..guards import GuardBuilder, install_guard
- from ..source import (
- AttrSource,
- ClosureSource,
- ConstantSource,
- DefaultsSource,
- GetItemSource,
- SkipGuardSource,
- )
- from ..utils import (
- check_constant_args,
- check_unspec_or_constant_args,
- cmp_name_to_op_mapping,
- counters,
- identity,
- is_function,
- is_wrapper_or_member_descriptor,
- istype,
- make_cell,
- )
- from .base import (
- AsPythonConstantNotImplementedError,
- AttributeMutationNew,
- ValueMutationNew,
- VariableTracker,
- )
- from .constant import ConstantVariable
- try:
- from torch.distributed.fsdp._fully_shard import _fsdp_param_group
- except ModuleNotFoundError:
- _fsdp_param_group = None
- if TYPE_CHECKING:
- from torch._dynamo.codegen import PyCodegen
- from torch._dynamo.symbolic_convert import InstructionTranslator
- from torch._higher_order_ops.triton_kernel_wrap import (
- TritonGridType,
- TritonKernelType,
- )
_F = TypeVar("_F", bound=Callable)

# Bits tested against `code.co_flags` below to detect a `*args` and a
# `**kwargs` parameter respectively.
CO_VARARGS = 0x04
CO_VARKEYWORDS = 0x08

# Module-level cache keyed by the function object
_spec_cache = WeakKeyDictionary()


class FunctionSpec:
    """Parameter layout of a Python function, read from its code object.

    ``__init__`` records the parameter names by kind (positional-only,
    positional-or-keyword, keyword-only, ``*args``, ``**kwargs``).  Default
    values are *not* captured there: ``defaults``, ``kwdefaults`` and
    ``pos_default_map`` only exist after ``update_defaults`` has been called.
    """

    def __init__(self, func: FunctionType):
        code = func.__code__
        varnames = code.co_varnames
        flags = code.co_flags

        n_posonly = code.co_posonlyargcount
        n_positional = code.co_argcount
        n_kwonly = code.co_kwonlyargcount

        self.posonly_count = n_posonly
        self.arg_count = n_positional
        self.kwonly_count = n_kwonly

        self.posonly_names = varnames[:n_posonly]
        self.pos_or_kw_names = varnames[n_posonly:n_positional]
        self.all_pos_names = self.posonly_names + self.pos_or_kw_names

        # `cursor` walks co_varnames past the positional and keyword-only
        # names; the *args name (if any) comes next, then the **kwargs name.
        cursor = n_positional + n_kwonly
        self.kwonly_names = varnames[n_positional:cursor]

        if flags & CO_VARARGS:
            self.varargs_name = varnames[cursor]
        else:
            self.varargs_name = None
        if self.varargs_name:
            cursor += 1

        if flags & CO_VARKEYWORDS:
            self.varkw_name = varnames[cursor]
        else:
            self.varkw_name = None

    def update_defaults(self, func: FunctionType):
        """Reload default values from ``func``.

        Defaults can change from function call to function call, so callers
        are expected to invoke this before every use of the spec.
        """
        self.defaults = func.__defaults__ or ()
        self.kwdefaults = func.__kwdefaults__ or {}

        # Map positional-default names -> their index in self.defaults.
        # With no defaults the slice below is the whole tuple (`[-0:]`), and
        # it is the empty `range(0)` that keeps the mapping empty.
        n_defaults = len(self.defaults)
        names_with_defaults = self.all_pos_names[-n_defaults:]
        self.pos_default_map = {
            name: index
            for name, index in zip(names_with_defaults, range(n_defaults))
        }
def _get_spec(func: FunctionType) -> FunctionSpec:
    """Return the ``FunctionSpec`` for ``func``, building it on first use.

    Specs are memoized in the module-level ``_spec_cache``, keyed by the
    function object itself.
    """
    try:
        return _spec_cache[func]
    except KeyError:
        fresh = FunctionSpec(func)
        _spec_cache[func] = fresh
        return fresh
def bind_args_cached(func, tx, fn_source, args, kwargs):
    """Bind the arguments of a call to the parameters of ``func``.

    The binding follows the order of a regular Python call: positional
    parameters, ``*args``, keyword-only parameters, then ``**kwargs``.

    Args:
        func: Function being called.  Its cached ``FunctionSpec`` is fetched
            with ``_get_spec`` and its defaults are re-read on every call.
        tx: Translator forwarded to ``wrap_bound_arg`` and
            ``raise_observed_exception``.
        fn_source: Source of ``func``, or a falsy value.  When truthy, default
            values are wrapped together with a ``DefaultsSource``; literal
            positional defaults skip that source when
            ``config.skip_guards_on_constant_func_defaults`` is set.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.  The mapping is copied, never
            mutated.

    Returns:
        A dict mapping every parameter name of ``func`` to its bound value,
        each passed through ``wrap_bound_arg``.

    Raises:
        Via ``raise_observed_exception``: ``TypeError`` when a required
        argument is missing, when there are too many positional arguments,
        when an unexpected keyword is given, or when a positional-only
        parameter is passed by keyword to a function without ``**kwargs``.
    """
    spec = _get_spec(func)
    spec.update_defaults(func)

    def fail(message):
        # Single exit point for every binding error.  Presumably
        # `raise_observed_exception` never returns -- the call sites below
        # rely on that, exactly as the inline calls did before.
        raise_observed_exception(
            TypeError,
            tx,
            args=[ConstantVariable.create(message)],
        )

    ba = {}
    rem_kw = dict(kwargs)
    n_given = len(args)

    # 1) Bind all positional (pos-only + pos-or-kw)
    for i, name in enumerate(spec.all_pos_names):
        if i < n_given:
            ba[name] = wrap_bound_arg(tx, args[i])
            continue

        passed_by_keyword = name in rem_kw
        if passed_by_keyword and name not in spec.posonly_names:
            ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
            continue
        if passed_by_keyword and not spec.varkw_name:
            fail(f"{name} is positional-only")

        # Reaching this point means either nothing was supplied for `name`,
        # or `name` is positional-only and the same-named keyword belongs to
        # **kwargs (PEP 570: `def f(a=1, /, **kw)` accepts `f(a=2)` and binds
        # a=1, kw={'a': 2}).  In the latter case the keyword stays in
        # `rem_kw` and is collected in step 4.
        if name in spec.pos_default_map:
            idx = spec.pos_default_map[name]
            default = spec.defaults[idx]
            default_source = None
            if fn_source and not (
                ConstantVariable.is_literal(default)
                and config.skip_guards_on_constant_func_defaults
            ):
                default_source = DefaultsSource(fn_source, idx)
            ba[name] = wrap_bound_arg(tx, default, default_source)
        else:
            fail(f"Missing required positional argument: {name}")

    # 2) *args
    extra = args[len(spec.all_pos_names) :]
    if spec.varargs_name:
        ba[spec.varargs_name] = wrap_bound_arg(tx, tuple(extra))
    elif extra:
        fail(
            f"Too many positional arguments: got {len(args)}, expected {len(spec.all_pos_names)}"
        )

    # 3) Keyword-only
    for name in spec.kwonly_names:
        if name in rem_kw:
            ba[name] = wrap_bound_arg(tx, rem_kw.pop(name))
        elif name in spec.kwdefaults:
            kwdefault_source = None
            if fn_source:
                kwdefault_source = DefaultsSource(fn_source, name, is_kw=True)
            ba[name] = wrap_bound_arg(tx, spec.kwdefaults[name], kwdefault_source)
        else:
            fail(f"Missing required keyword-only argument: {name}")

    # 4) **kwargs
    if spec.varkw_name:
        ba[spec.varkw_name] = wrap_bound_arg(tx, rem_kw)
    elif rem_kw:
        fail(f"Unexpected keyword arguments: {list(rem_kw)}")

    return ba
def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
    """Turn a bound argument value into a ``VariableTracker``.

    Values that already are trackers are returned unchanged.  Otherwise the
    value is built eagerly when no ``source`` is given, and lazily when one
    is.  Source propagation is best effort since not every object we
    encounter has a source to begin with.
    """
    if isinstance(val, VariableTracker):
        return val
    if source:
        # Create a lazy variable to avoid guarding on __defaults__ unless
        # really needed.
        return variables.LazyVariableTracker.create(val, source)
    return VariableTracker.build(tx, val)
def wrap_args_kwargs(tx: "InstructionTranslator", result):
    """Wrap, in place, every raw ``tuple`` / ``dict`` value of ``result``.

    Such values are the collected ``*args`` / ``**kwargs``; each one is
    replaced by ``wrap_bound_arg(tx, value)``.  Other entries are untouched.
    """
    # Select the keys up front so the dict is not modified while iterated.
    raw_keys = [
        key for key, value in result.items() if isinstance(value, (tuple, dict))
    ]
    for key in raw_keys:
        # args/kwargs
        result[key] = wrap_bound_arg(tx, result[key])
def init_cellvars(parent, result: dict[str, VariableTracker], code):
    """
    Add to `result` a mapping from local name to a new cell for every name in
    `code.co_cellvars`. If a name is already present in `result` (a cell
    argument), its current value is first stored into the new cell through
    the SideEffects of `parent`.
    """
    effects = parent.output.side_effects
    for cell_name in code.co_cellvars:
        cell = effects.track_cell_new()
        try:
            initial_value = result.pop(cell_name)
        except KeyError:
            pass
        else:
            # This handles when a function argument is a cell (e.g., captured
            # by a nested func). See `MAKE_CELL` bytecode for more info.
            effects.store_cell(cell, initial_value)
        result[cell_name] = cell
def _create_nested_fn(
    code, f_globals, name, defaults, closure, kwdefaults, annotations
):
    """Build a real Python function from its constituent parts.

    Args:
        code: Code object of the function.
        f_globals: Globals dict the function runs with.
        name: Value for the function's ``__name__``.
        defaults: Tuple of positional defaults, or ``None``.
        closure: Tuple of cells for the free variables, or ``None``.
        kwdefaults: Dict of keyword-only defaults, or ``None``.
        annotations: ``None``, a dict, or a flat tuple laid out as
            ``(name0, value0, name1, value1, ...)``.

    Returns:
        The newly created function object.

    Raises:
        AssertionError: if ``annotations`` is neither ``None``, a dict, nor a
            tuple.
    """
    from types import FunctionType

    func = FunctionType(code, f_globals, name, defaults, closure)
    func.__kwdefaults__ = kwdefaults

    if isinstance(annotations, tuple):
        # Pair element 2k with element 2k+1.  `itertools.pairwise` must not be
        # used here: it yields *overlapping* pairs, so ("a", int, "b", str)
        # would also produce the bogus entry {int: "b"} (and raise TypeError
        # for unhashable annotation values).
        annotations = dict(zip(annotations[::2], annotations[1::2]))

    # TypeError: __annotations__ must be set to a dict object
    assert annotations is None or isinstance(annotations, dict)
    func.__annotations__ = annotations

    return func
# Function dunder attributes that `fn_var_getattr` reads with a plain
# `getattr` instead of using the `inspect.getattr_static` result.
fn_known_dunder_attrs = {
    "__annotations__",
    "__defaults__",
    "__kwdefaults__",
    "__code__",
    "__globals__",
    "__closure__",
    "__doc__",
}


def fn_var_getattr(tx, fn, source, name):
    """Build the variable tracker for attribute ``name`` of function ``fn``.

    Args:
        tx: Current translator.
        fn: The function object whose attribute is read.
        source: Source of ``fn`` or a falsy value; when truthy, the result is
            a lazy tracker sourced at ``AttrSource(source, name)``.
        name: Attribute name.

    Raises:
        Via ``raise_observed_exception``: ``AttributeError`` when the static
        lookup of ``name`` on ``fn`` fails.
    """
    attr_source = source and AttrSource(source, name)
    if attr_source and name == "__annotations__":
        # We get a large number of silly guards from annotations from inspect
        # module. Changing annotations is rare, and it impacting the extracted
        # graph is even rarer. So skip guards.
        attr_source = SkipGuardSource(attr_source)

    try:
        value = inspect.getattr_static(fn, name)
    except AttributeError:
        # function does not have a __getattr__ or __getattribute__ method,
        # so we can safely assume that this attribute is absent
        raise_observed_exception(AttributeError, tx)

    # Special handling for known dunder attributes
    if name in fn_known_dunder_attrs:
        value = getattr(fn, name)

    if not attr_source:
        return VariableTracker.build(tx, value)
    return variables.LazyVariableTracker.create(value, attr_source)
class BaseUserFunctionVariable(VariableTracker):
    """Shared behaviour of variable trackers that stand for a user function.

    The methods below rely on ``get_code()``, ``get_function()`` and
    ``self_args()``, none of which is defined in this class.
    """

    def get_filename(self):
        """Return ``co_filename`` of the code object from ``get_code()``."""
        code = self.get_code()
        return code.co_filename

    def get_name(self):
        """Return ``co_name`` of the code object from ``get_code()``."""
        code = self.get_code()
        return code.co_name

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Inline the call through ``tx``, with ``self_args()`` prepended."""
        full_args = [*self.self_args(), *args]
        return tx.inline_user_function_return(self, full_args, kwargs)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> VariableTracker:
        """Return a constant telling whether the function has ``name``.

        When ``get_function()`` raises ``NotImplementedError``, only
        ``__name__`` on a ``NestedUserFunctionVariable`` is reported present.
        """
        try:
            found = hasattr(self.get_function(), name)
        except NotImplementedError:
            found = name == "__name__" and isinstance(
                self, NestedUserFunctionVariable
            )
        return variables.ConstantVariable.create(found)

    def inspect_parameter_names(self):
        """Return the parameter names of ``get_function()`` as a list."""
        signature = inspect.signature(self.get_function())
        return [*signature.parameters]

    def closure_vars(self, tx):
        """Return an empty dict."""
        return {}
class UserFunctionVariable(BaseUserFunctionVariable):
    """Variable tracker for a user-defined function held in ``self.fn``.

    Calls are normally inlined through the base class; ``call_function``
    special-cases a handful of ``torch._dynamo`` helpers and functions marked
    as constant.
    """

    _nonvar_fields = {
        "fn",
        "is_constant",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    @classmethod
    def create_with_source(cls, value, source):
        """Install a ``CLOSURE_MATCH`` guard on ``source`` and wrap ``value``."""
        install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
        return cls(value, source=source)

    def __init__(self, fn, is_constant=False, **kwargs) -> None:
        """Wrap ``fn``.

        Args:
            fn: The function to track.  Anything that is neither a
                ``types.FunctionType`` nor a ``torch.jit.ScriptFunction`` is
                reported through ``unimplemented_v2``.
            is_constant: Accepted but not read here; ``self.is_constant`` is
                derived from ``fn._dynamo_marked_constant`` alone.
            **kwargs: Forwarded to the base-class constructor.
        """
        super().__init__(**kwargs)
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of compilation
            self.is_constant = True
        else:
            self.is_constant = False

        # TODO putting this here to avoid duplication, because we could hit this
        # from several paths (e.g., SuperVariable or `var_getattr`s).
        if not isinstance(fn, (types.FunctionType, torch.jit.ScriptFunction)):
            unimplemented_v2(
                gb_type="can't handle functions not implemented in python ",
                context=f"{fn}",
                explanation="Dynamo can only handle functions defined in python",
                hints=[
                    "Move usage of this function out of `torch.compile` region",
                    *graph_break_hints.INFERENCE_MODE,
                ],
            )
        # TODO(anijain2305) - Replace directly calling UserFunctionVariable with
        # VariableBuilder, which handles the wrapping of _torchdynamo_inline.
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        self.fn: types.FunctionType = fn

    def as_python_constant(self):
        """Return ``self.fn`` for exact instances; defer to the base otherwise."""
        if istype(self, UserFunctionVariable):
            return self.fn
        # subclasses (such as methods) usually aren't a constant
        return super().as_python_constant()

    def self_args(self):
        """Return the implicit leading arguments of a call: none here."""
        return []

    def get_function(self):
        """Return the wrapped function."""
        return self.fn

    def get_code(self):
        """Return the code object of the wrapped function."""
        return self.fn.__code__

    def python_type(self):
        """Return ``types.FunctionType``."""
        return types.FunctionType

    def has_self(self):
        """Return whether the wrapped function has a non-``None`` ``__self__``."""
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        """Return the globals dict of the wrapped function."""
        return self.fn.__globals__

    def get_source(self):
        """Return the source of the function itself.

        For ``UserMethodVariable`` instances that have a source this is
        ``self.source_fn`` rather than ``self.source``.
        """
        source = self.source
        if source and isinstance(self, variables.UserMethodVariable):
            source = self.source_fn
        return source

    def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]:
        """
        Assume `args` and `kwargs` are VariableTracker arguments for a call to
        this function, create new bindings for initial locals.

        The result maps parameter names (from `bind_args_cached`), the code's
        cell variables (from `init_cellvars`) and its free variables (one
        tracked cell per closure entry) to their trackers.

        Raises:
            TypeError: if the wrapped object is not a plain `FunctionType`.
        """
        assert not self.is_constant

        fn: types.FunctionType = self.fn

        if not isinstance(fn, FunctionType):
            raise TypeError("Only supports regular Python functions.")
        root_tx = parent.output.root_tx

        source = self.get_source()
        result = bind_args_cached(fn, root_tx, source, args, kwargs)

        init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            # TODO refactor these 3 branches.
            side_effects = parent.output.side_effects
            if cell in side_effects:
                # The cell is already tracked: reuse its tracker.
                cell_var = side_effects[cell]

            elif source:
                closure_cell = GetItemSource(ClosureSource(source), idx)
                closure_cell_contents = AttrSource(closure_cell, "cell_contents")
                try:
                    contents_var = VariableTracker.build(
                        parent, cell.cell_contents, closure_cell_contents
                    )
                except ValueError:
                    # Cell has not yet been assigned
                    contents_var = variables.DeletedVariable()
                cell_var = side_effects.track_cell_existing(
                    closure_cell, cell, contents_var
                )

            else:
                # TODO figure out why source isn't available here, and whether
                # we can fix that and remove this branch.
                try:
                    contents_var = VariableTracker.build(parent, cell.cell_contents)
                except ValueError:
                    # Cell has not yet been assigned
                    contents_var = variables.DeletedVariable()
                cell_var = side_effects.track_cell_existing(None, cell, contents_var)

            result[name] = cell_var

        return result

    def var_getattr(self, tx: "InstructionTranslator", name: str):
        """Return a tracker for attribute ``name`` of the wrapped function.

        Names found in ``cmp_name_to_op_mapping`` yield a ``GetAttrVariable``
        bound to this tracker; everything else goes through
        ``fn_var_getattr``.
        """
        if name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(self, name)
        source = self.get_source()
        return fn_var_getattr(tx, self.fn, source, name)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> VariableTracker:
        """Return a constant for ``hasattr(self.fn, name)``."""
        result = hasattr(self.fn, name)
        return variables.ConstantVariable.create(result)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace a call to the wrapped function.

        ``torch._dynamo.patch_dynamo_config``, ``error_on_graph_break`` and
        ``nonstrict_trace`` are turned into dedicated variables, constant
        functions are invoked and stored as constants, and everything else is
        inlined through the base class (inside a side-effect-allowing context
        for the specific functions checked below).

        Raises:
            RuntimeError: if the arguments of ``patch_dynamo_config`` cannot
                be converted to constants, or the ``error_on_graph_break``
                call is malformed.
        """
        # Handle patch_dynamo_config call
        if self.fn is torch._dynamo.patch_dynamo_config:
            try:
                args_const = [arg.as_python_constant() for arg in args]
                kwargs_const = {
                    key: val.as_python_constant() for key, val in kwargs.items()
                }
                changes = torch._dynamo.patch_dynamo_config(
                    *args_const, **kwargs_const
                ).changes
                return variables.DynamoConfigPatchVariable(changes)
            except AsPythonConstantNotImplementedError as e:
                raise RuntimeError(
                    "Cannot convert patch_dynamo_config args/kwargs to constants. "
                    "Please fix your call to patch_dynamo_config by using simpler inputs. "
                    f"args: {args}, kwargs: {kwargs}"
                ) from e
        elif self.fn is torch._dynamo.error_on_graph_break:
            try:
                bound = inspect.signature(self.fn).bind(*args, **kwargs)
                error_on_graph_break = bound.arguments[
                    "error_on_graph_break"
                ].as_python_constant()
                assert isinstance(error_on_graph_break, bool)
                return variables.ErrorOnGraphBreakVariable(error_on_graph_break)
            except Exception as e:
                raise RuntimeError(
                    "Improper error_on_graph_break() call. Please fix your call to error_on_graph_break(). "
                    f"args: {args}, kwargs: {kwargs}"
                ) from e
        # Handle a `nonstrict_trace(fn)` call
        elif self.fn is torch._dynamo.nonstrict_trace:
            bound = inspect.signature(self.fn).bind(*args, **kwargs)
            fn_var = bound.args[0]
            if not isinstance(fn_var, BaseUserFunctionVariable):
                typ = fn_var.python_type()
                msg = f"`nonstrict_trace` expects a callable, but got value of type <{typ.__name__}>"
                unimplemented_v2(
                    gb_type="TypeError from user code",
                    # This class keeps the wrapped callable on `self.fn` and
                    # sets no `value` attribute, so the message is built from
                    # `self.fn`.
                    context=f"call_function({self.fn}, {args}, {kwargs})",
                    explanation=msg,
                    hints=[
                        *graph_break_hints.USER_ERROR,
                    ],
                )

            if not isinstance(fn_var, UserFunctionVariable):
                fn_name = fn_var.get_name()
                msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region."  # noqa: B950
                unimplemented_v2(
                    gb_type="Limitation of `nonstrict_trace",
                    context=f"{self}",
                    explanation=msg,
                    hints=[
                        f"make sure definition of {fn_name} is outside ",
                        "`torch.compile` region",
                    ],
                )

            fn = fn_var.fn
            return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True)

        if self.is_constant:
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), args, kwargs
            )

        if (
            not tx.output.current_tracer.unsafe_allow_externally_visible_side_effects
            and self.fn
            is torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer
        ):
            with torch._dynamo.side_effects.allow_externally_visible_side_effects_in_subtracer(
                tx
            ):
                return super().call_function(tx, args, kwargs)

        if (
            tx.output.current_tracer.under_activation_checkpoint
            and not tx.output.current_tracer.allow_side_effects_under_checkpoint
        ):
            # FSDPState is looked up lazily and any import failure is
            # tolerated: without it the special case below simply never fires.
            try:
                from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState
            except Exception:
                FSDPState = None
            if FSDPState is not None and self.fn in [
                FSDPState._pre_forward,
                FSDPState._post_forward,
            ]:
                with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
                    return super().call_function(tx, args, kwargs)
        return super().call_function(tx, args, kwargs)
class BuiltinMethodVariable(BaseUserFunctionVariable):
    """Variable tracker for a bound builtin method (``types.BuiltinMethodType``)."""

    def __init__(self, fn, is_constant=False, **kwargs) -> None:
        """Store ``fn``; ``is_constant`` is accepted but not read here."""
        super().__init__(**kwargs)
        assert isinstance(fn, types.BuiltinMethodType)
        self.fn = fn

    @staticmethod
    def is_supported_builtin_method(obj):
        """Return whether ``obj`` is one of the handled builtin methods."""
        owner = obj.__self__
        attr = obj.__name__

        # TODO(anijain2305) - Add support for more builtin methods
        # Supports tuple.__new__ and frozenset({....}).__contains__
        if owner is tuple and attr == "__new__":
            return True
        return type(owner) is frozenset and attr == "__contains__"

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Build a tracker for the method's ``__self__`` and call the method on it."""
        receiver = self.fn.__self__
        method_name = self.fn.__name__
        receiver_source = self.source and AttrSource(self.source, "__self__")
        receiver_vt = VariableTracker.build(tx, receiver, receiver_source)
        return receiver_vt.call_method(tx, method_name, args, kwargs)
- class LocalGeneratorObjectVariable(VariableTracker):
def __init__(
    self,
    code: types.CodeType,
    f_globals,
    inline_tracer: Optional["InstructionTranslator"],
    **kwargs,
):
    """Store the generator's code object, its globals and its inline tracer.

    ``inline_tracer`` may be ``None``; ``_get_inline_tracer`` then creates it
    on first use.  Remaining keyword arguments go to the base constructor.
    """
    super().__init__(**kwargs)
    self.inline_tracer = inline_tracer
    self.f_globals = f_globals
    self.code = code
def get_code(self):
    """Return the code object stored on this variable."""
    code = self.code
    return code
def get_filename(self):
    """Return ``co_filename`` of the code object from ``get_code()``."""
    code = self.get_code()
    return code.co_filename
def get_name(self):
    """Return ``co_name`` of the code object from ``get_code()``."""
    code = self.get_code()
    return code.co_name
def get_function(self):
    """Not supported for this variable; always raises ``NotImplementedError``."""
    raise NotImplementedError
def has_self(self):
    """Return ``False`` unconditionally."""
    return False
def __name__(self):
    """Return the same value as ``get_name()``."""
    name = self.get_name()
    return name
def __str__(self):
    """Render as ``<ClassName>(<code name>)``."""
    class_name = self.__class__.__name__
    return f"{class_name}({self.get_name()})"


# The repr is the same rendering as the str.
__repr__ = __str__
def reconstruct(self, codegen: "PyCodegen"):
    """Emit code that rebuilds this generator as a list iterator.

    Unless the inline tracer reports the generator as exhausted, all of its
    remaining items are first collected with ``force_unpack_var_sequence``
    and stored on ``self.remaining_items``.

    NOTE(review): if the generator is already exhausted and
    ``remaining_items`` was never assigned, the last statement looks like it
    would raise ``AttributeError`` -- confirm whether that path is reachable.
    """
    from torch._dynamo.side_effects import disallow_side_effects_in_generator
    from torch._dynamo.symbolic_convert import (
        InstructionTranslator,
        save_and_restart_speculation_log,
        temporarely_allow_writes_to_output_graph,
    )

    tx = InstructionTranslator.current_tx()

    # All three context managers are created before any of them is entered.
    speculation_ctx = save_and_restart_speculation_log(tx)
    no_side_effects_ctx = disallow_side_effects_in_generator(tx)
    graph_writes_ctx = temporarely_allow_writes_to_output_graph(tx)

    with speculation_ctx, no_side_effects_ctx, graph_writes_ctx:
        inline_tracer = self._get_inline_tracer(tx)
        if not inline_tracer.generator_exhausted:
            self.remaining_items = self.force_unpack_var_sequence(tx)
        iterator_vt = variables.ListIteratorVariable(self.remaining_items)
        iterator_vt.reconstruct(codegen)
def bind_args(self, tx, args, kwargs):
    """Delegate argument binding to ``self.fn.bind_args``.

    NOTE(review): ``fn`` is not assigned by the ``__init__`` visible next to
    this method -- confirm where it is expected to come from.
    """
    target = self.fn
    return target.bind_args(tx, args, kwargs)
def get_globals(self):
    """Return the globals dict stored on this variable."""
    f_globals = self.f_globals
    return f_globals
def python_type(self):
    """Return ``types.GeneratorType``."""
    return types.GeneratorType
def _get_inline_tracer(self, tx):
    """Return the inline tracer of this generator, creating it on first use.

    A missing tracer is built with
    ``InliningInstructionTranslator.build_inline_tracer(tx, self, [], {})``
    and stored on ``self.inline_tracer``.
    """
    from torch._dynamo.symbolic_convert import InliningInstructionTranslator

    tracer = self.inline_tracer
    if tracer is None:
        tracer = InliningInstructionTranslator.build_inline_tracer(
            tx, self, [], {}
        )
        self.inline_tracer = tracer
    return tracer
def next_variable(self, tx):
    """Advance the generator by one step and return the produced variable.

    Raises:
        Via ``raise_observed_exception``: ``StopIteration`` when the
        generator is already exhausted.
        ObservedException: re-raised after marking the tracer exhausted.
        InfiniteGeneratorError: re-raised untouched.
        SkipFrame: when inlining hits ``Unsupported``; the generator's code
        is passed to ``skip_code`` first.
    """
    inline_tracer = self._get_inline_tracer(tx)

    if self._is_generator_exhausted():
        raise_observed_exception(StopIteration, tx)

    try:
        # Hierarchically, tx can be seen as the parent of the inline tracer
        # created on call_function. Any exception needs to be propagated to tx
        # for Dynamo to behave correctly
        counter_patch = {"unimplemented": counters["inline_call"]}
        with patch.dict(counters, counter_patch):
            return inline_tracer.inline_call_()
    except ObservedException as exc:
        inline_tracer.generator_exhausted = True
        raise exc
    except InfiniteGeneratorError:
        # test/dynamo/test_misc.py::test_iterator_limit
        raise
    except Unsupported as exc:
        torch._dynamo.eval_frame.skip_code(self.get_code())
        raise SkipFrame from exc
    finally:
        # Runs on every exit path, including the normal return above.
        counters["unimplemented"] |= counters["inline_call"]
def call_obj_hasattr(self, tx, name):
    """Return a constant telling whether ``name`` is in the ``__dict__`` of ``python_type()``."""
    is_present = name in self.python_type().__dict__
    return ConstantVariable.create(is_present)
def has_unpack_var_sequence(self, tx):
    """Return ``False`` unconditionally."""
    return False
def has_force_unpack_var_sequence(self, tx) -> builtins.bool:
    """Return ``True`` unconditionally."""
    return True
def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
    """Drain the generator and return everything it produced as a list."""
    collected = []
    self.force_apply_to_var_sequence(tx, collected.append)
    return collected
def force_apply_to_var_sequence(self, tx, fn) -> None:
    """Call ``fn`` on every remaining item of the generator.

    Iteration ends at the first ``ObservedUserStopIteration`` (raised either
    by ``next_variable`` or by ``fn``), which is then passed to
    ``handle_observed_exception``.
    """
    while True:
        try:
            fn(self.next_variable(tx))
        except ObservedUserStopIteration:
            handle_observed_exception(tx)
            return
def _setup_exception(self, tx, exc):
    """Raise ``exc`` inside the generator's inline tracer.

    An ``ObservedException`` coming out of the tracer is handed to the
    tracer's ``exception_handler``.
    """
    inline_tracer = self._get_inline_tracer(tx)
    try:
        inline_tracer._raise_exception_variable(exc)
    except ObservedException as observed:
        # if no handler is available (i.e. user code doesn't catch it), the
        # exception is raised again.
        inline_tracer.exception_handler(observed)
def _is_generator_just_started(self):
    """Return whether no instruction of the generator has run yet.

    That is the case when there is no inline tracer, or when the tracer's
    ``instruction_pointer`` is still 0.
    """
    tracer = self.inline_tracer
    if tracer is None:
        return True
    return tracer.instruction_pointer == 0
def _is_generator_exhausted(self):
    """Return the tracer's ``generator_exhausted`` flag, ``False`` if absent.

    A missing tracer (``None``) therefore also counts as not exhausted.
    """
    tracer = self.inline_tracer
    return getattr(tracer, "generator_exhausted", False)
def call_method(
    self,
    tx: "InstructionTranslator",
    name: str,
    args: "list[VariableTracker]",
    kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
    """Dispatch generator-protocol methods on the traced generator.

    ``__next__``, ``__iter__``, ``send``, ``close`` and ``throw`` are handled
    here. Any other method name is delegated to the parent class and the
    parent's result is returned to the caller.
    """
    if name == "__next__":
        return self.next_variable(tx)
    elif name == "__iter__":
        # iter(gen) returns itself
        return self
    elif name == "send":
        # Sends a value into the generator function. Returns the next value
        # yielded by the generator, or raises StopIteration if the generator
        # exits without yielding another value
        if self._is_generator_just_started() and len(args):
            # can't send non-None value to a just-started generator
            # Test: GeneratorCPythonTests.test_send_non_none_to_new_gen
            if not all(
                isinstance(arg, ConstantVariable) and arg.value is None
                for arg in args
            ):
                raise_observed_exception(TypeError, tx)
        tracer = self._get_inline_tracer(tx)
        tracer.push_many(args)
        return self.next_variable(tx)
    elif name == "close":
        # * Raises a GeneratorExit at the point where the generator function was paused.
        # * If the generator function catches the exception and returns a
        #   value, this value is returned from close() - Python 3.13+
        # * If the generator function is already closed, or raises GeneratorExit
        #   (by not catching the exception), close() returns None.
        # * If the generator yields a value, a RuntimeError is raised.
        # * If the generator raises any other exception, it is propagated to the caller.
        # * If the generator has already exited due to an exception or normal
        #   exit, close() returns None and has no other effect.

        # Return None if close is called on a just-started generator
        # See test GeneratorCloseCpythonTests::test_close_not_started
        tracer = self._get_inline_tracer(tx)
        if self._is_generator_just_started() or self._is_generator_exhausted():
            tracer.generator_exhausted = True
            return variables.ConstantVariable(None)

        # Raise GeneratorExit to see if user code catches it. Any other exception
        # is propagated to the parent frame.
        try:
            self._setup_exception(
                tx, variables.ExceptionVariable(GeneratorExit, ())
            )
            # There's an extra block on Python 3.12+ to handle StopIteration
            # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397
            #
            #   1           0 RETURN_GENERATOR
            #               2 POP_TOP
            #               4 RESUME                   0
            #   2           6 LOAD_CONST               1 (1)
            #               8 YIELD_VALUE              1
            #              10 RESUME                   1
            #              12 POP_TOP
            #              14 RETURN_CONST             0 (None)
            #         >>   16 CALL_INTRINSIC_1         3 (INTRINSIC_STOPITERATION_ERROR)
            #              18 RERAISE                  1
            # ExceptionTable:
            #   4 to 14 -> 16 [0] lasti
            if (
                sys.version_info >= (3, 12)
                and tracer.next_instruction.opname == "CALL_INTRINSIC_1"
            ):
                tracer.generator_exhausted = True
                return variables.ConstantVariable(None)
        except ObservedGeneratorExit:
            # If it doesn't catch, we just return None, as per the text above
            tracer.generator_exhausted = True
            return variables.ConstantVariable(None)

        try:
            # Raise RuntimeError if the generator yields any other value
            if self.next_variable(tx):
                raise_observed_exception(RuntimeError, tx)
        except ObservedGeneratorExit:
            tracer.generator_exhausted = True
            return variables.ConstantVariable(None)
        except ObservedUserStopIteration:
            # In Python 3.13+, one can capture GeneratorExit and return a value
            # See test_generator.py::test_close_capture_GeneratorExit_return
            # https://discuss.python.org/t/let-generator-close-return-stopiteration-value/24786/26
            # https://github.com/python/cpython/pull/104771
            assert tracer.symbolic_result is not None
            return tracer.symbolic_result
    elif name == "throw":
        # * Raises an exception at the point where the generator was paused, and
        #   returns the next value yielded by the generator.
        # * If the generator exits without yielding, raise StopIteration
        # * If the generator function does not catch the passed-in exception,
        #   or raises a different exception, then that exception propagates to the caller.

        # Setup the exception table and jump target in case of try...finally
        tracer = self._get_inline_tracer(tx)
        try:
            # In Python 3.9, the exception is represented as a triple (typ, val, tb)
            # In such cases, we re-raise the exception object given to avoid
            # creating a new object, so that IS_OP works.
            # See: https://github.com/pytorch/pytorch/pull/146496
            self._setup_exception(tx, args[1] if len(args) == 3 else args[0])
        except ObservedException:  # noqa: TRY203
            # propagate the exception back to the parent caller
            raise

        retval = self.next_variable(tx)

        # The exception raised before is still active. We need to check the exception
        # table one more time to find the next target. But why? Let's walk
        # through an example and its generated bytecode: https://godbolt.org/z/ebdTbMv8M
        #
        #     z = 0
        #     def whoo():
        #         global z
        #         z = 0
        #         try:
        #             yield 1
        #         except ValueError:
        #             yield 2
        #         finally:
        #             z += 1
        #         z += 10
        #
        #     gen = whoo()
        #     next(gen)
        #     gen.throw(ValueError)
        #     print('z', z)  -> z = 1
        #
        #              ...
        #         >>   58 PUSH_EXC_INFO
        #
        #   8          60 LOAD_GLOBAL              2 (ValueError)
        #              70 CHECK_EXC_MATCH
        #              72 POP_JUMP_IF_FALSE        7 (to 88)
        #              74 POP_TOP
        #
        #   9          76 LOAD_CONST               3 (2)
        #              78 YIELD_VALUE              3      <------ ValueError is still active here
        #              80 RESUME                   1
        #              82 POP_TOP
        #              84 POP_EXCEPT
        #              86 JUMP_BACKWARD           34 (to 20)
        #              ...
        #
        #     ExceptionTable:
        #     4 to 8 -> 124 [0] lasti
        #     12 to 18 -> 58 [0]
        #     20 to 56 -> 124 [0] lasti
        #     58 to 82 -> 90 [1] lasti     <------ move to 90
        #     84 to 86 -> 96 [0]
        #     88 to 88 -> 90 [1] lasti
        #     90 to 94 -> 96 [0]
        #     96 to 116 -> 118 [1] lasti
        #     118 to 122 -> 124 [0] lasti
        #
        # In this scenario, a generator can yield after `throw()` is called. Even
        # after the exception is raised a few lines above, it remains active
        # within the `78 YIELD_VALUE` instruction. When the generator resumes
        # after the second yield on instruction `80 RESUME`, we cannot simply
        # return the control flow to the next instruction. Instead, one must
        # check the exception table (or equivalent) to find the next target
        # In this case, it says the instruction pointer must be moved to 90.
        #
        # Without this step, if we let the trace proceed to the next
        # instruction, it would follow the control flow where the exception
        # raised by `throw()` was handled and swallowed, potentially leading
        # to incorrect behavior.
        exc_type = type("__InternalThrowException", (Exception,), {})

        try:
            self._setup_exception(tx, variables.ExceptionVariable(exc_type, ()))
            self.next_variable(tx)
        except get_dynamo_observed_exception(exc_type):
            # We should get back the exception raised before.
            pass
        else:
            raise_observed_exception(RuntimeError, tracer)
        return retval

    # Any other method: defer to the parent class. The result must be
    # returned; without the `return` the caller would receive None instead
    # of a VariableTracker.
    return super().call_method(tx, name, args, kwargs)
class ContextlibContextManagerLocalGeneratorObjectVariable(
    LocalGeneratorObjectVariable
):
    """
    .. note::

        This is only used when the function is annotated with @contextlib.contextmanager

        It is a special case of a generator function, as we do not allow
        returning a context manager from a torch.compile function.
    """
class LocalGeneratorFunctionVariable(BaseUserFunctionVariable):
    """Function variable whose call produces a generator-object variable.

    .. note::

        This is a wrapper around (Nested)UserFunctionVariable
    """

    def __init__(
        self,
        vt: VariableTracker,
        *,
        generator_cls=LocalGeneratorObjectVariable,
        **kwargs,
    ):
        """Store the wrapped function variable and the generator class to build."""
        super().__init__(**kwargs)
        self.vt = vt
        self.generator_cls = generator_cls

    def __getattr__(self, name):
        """Forward attributes not defined on this class to the wrapped ``vt``."""
        # Python only calls __getattr__ after normal lookup has failed.
        if name not in self.__class__.__dict__:
            return getattr(self.vt, name)
        return getattr(self, name)

    def _build_inline_tracer(self, tx, args, kwargs):
        """Build the inline tracer used to step through the generator body."""
        from torch._dynamo.symbolic_convert import InliningInstructionTranslator

        return InliningInstructionTranslator.build_inline_tracer(
            tx, self, args, kwargs
        )

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Return a ``generator_cls`` instance bound to a fresh inline tracer.

        Reports a graph break via ``unimplemented_v2`` when the wrapped code
        object is not a generator.
        """
        wrapped_code = self.vt.get_code()
        if not is_generator(wrapped_code):
            unimplemented_v2(
                gb_type="non-generator contextlib.contextmanager",
                context=str(wrapped_code),
                explanation="Cannot compile function decorated with `@contextlib.contextmanager` that is not a generator"
                ", i.e. does not use `yield`",
                hints=[
                    "Use `yield` in the function body instead of `return`.",
                    "Remove the `@contextlib.contextmanager` decorator.",
                ],
            )

        tracer = self._build_inline_tracer(tx, args, kwargs)
        module_globals = self.vt.get_globals()

        # calling a generator returns a generator object
        return self.generator_cls(
            wrapped_code,
            module_globals,
            tracer,
            source=self.source,
        )
class FunctionDecoratedByContextlibContextManagerVariable(
    LocalGeneratorFunctionVariable
):
    """
    .. note::

        This is only used when the function is annotated with @contextlib.contextmanager
    """

    def __init__(self, vt, **kwargs):
        """Wrap ``vt`` so calls yield the contextmanager-specific generator variable."""
        super().__init__(
            vt,
            generator_cls=ContextlibContextManagerLocalGeneratorObjectVariable,
            **kwargs,
        )

    def _build_inline_tracer(self, tx, args, kwargs):
        """Build the inline tracer and flag it as originating from a context manager."""
        # NOTE: This only exists to not break support for context manager when
        # config.enable_faithful_generator_behavior = False and
        # config.enable_trace_contextlib = True. In case the former is false,
        # Dynamo should still be able to trace through @contextmanager functions
        ctx_tracer = super()._build_inline_tracer(tx, args, kwargs)
        assert isinstance(
            ctx_tracer,
            torch._dynamo.symbolic_convert.InliningGeneratorInstructionTranslator,
        )
        ctx_tracer.is_generator_from_ctx_manager = True
        return ctx_tracer
class UserMethodVariable(UserFunctionVariable):
    """Some unsupported user-defined method"""

    def __init__(self, fn, obj, source_fn=None, **kwargs) -> None:
        """Record the underlying function ``fn`` and the bound instance ``obj``."""
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

        # Note on source and source_fn
        # Be careful with `source` when delegating to UserFunctionVariable
        # (base-class) methods. In this __init__, `source` is a *bound method*
        # object, but the base class expects the underlying *function* object.
        # One way is to simply use `__func__` to unwrap it.
        #
        # For recursive dict-tag optimizations, it can be faster to fetch the
        # function directly from `cls.__dict__`; that's why we pass on
        # `source_fn`. Whenever it is possible to access the function from
        # cls.__dict__, we pass that on to `source_fn`. Because bind_args
        # operates on the unbound function, most guards should target
        # `source_fn` rather than the original `source`.
        bound_source = kwargs.get("source")
        if source_fn is None and bound_source is not None:
            source_fn = AttrSource(bound_source, "__func__")
        self.source_fn = source_fn

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}({self.fn}, {self.obj})"

    def self_args(self):
        """Return the implicit leading argument list: the bound instance."""
        return [self.obj]

    def python_type(self):
        """Return ``types.MethodType``."""
        return types.MethodType

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Call the bound method, using special routes before plain inlining."""
        # NOTE this is to handle methods annotated by `nonstrict_trace`. Usually
        # a `nonstrict_trace`-ed function will be wrapped by
        # `VariableTracker.build` and route to `TorchInGraphFunctionVariable`,
        # but in the case of method, we manually wrap it with `UserMethodVariable`
        # inside `UserDefinedObjectVariable.var_getattr`.
        #
        # We might be able to simplify this away by canonicalizing the
        # function/method wrapping code paths.
        from ..trace_rules import is_nonstrict_trace_callable

        if is_nonstrict_trace_callable(self.fn):
            bound_args = [*self.self_args(), *args]
            in_graph_fn = variables.TorchInGraphFunctionVariable(
                self.fn, nonstrict_traceable=True
            )
            return in_graph_fn.call_function(tx, bound_args, kwargs)

        # For nn.Module methods, redirecting to NNModuleVariable.call_method for optimized solution
        # rather than simple inlining. E.g, putting `call_method` op in FX graph for `forward` method
        # since we ensure `forward` of allowed modules can be traced by AOT safely.
        # Note this is not only for allowed modules, as user customized modules can extend from
        # allowed modules but using parent's `forward` method, which is also covered by this branch.

        # If we are tracing the higher order op, we want Dynamo to step inside
        # the module call so that Dynamo can see the underlying parameters and
        # buffers and raise them as inputs to the graph. The is_root_tracer
        # check bypasses the if condition for non-root tracers and directly
        # calls the super().call_function at the end, which is basically
        # equivalent of inlining the method.
        if tx.output.is_root_tracer() and isinstance(
            self.obj, variables.NNModuleVariable
        ):
            module_attr = getattr(self.fn, "__module__", "")
            # inline torch.nn.utils.parametrize
            if (
                module_attr is not None
                and module_attr.startswith("torch.nn.")
                and module_attr != "torch.nn.utils.parametrize"
            ) or self.is_constant:
                return self.obj.call_method(
                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
                )
        elif (
            _fsdp_param_group is not None
            and self.fn is _fsdp_param_group.FSDPParamGroup.use_training_state
        ):
            ctx_cls_var = variables.TorchCtxManagerClassVariable(self.fn)
            return ctx_cls_var.call_function(tx, (self.obj, *args), kwargs)

        if self.is_constant:
            # Look the method up on the real object and evaluate it eagerly.
            bound_fn = getattr(self.obj.value, self.fn.__name__)
            return invoke_and_store_as_constant(
                tx, bound_fn, self.get_name(), args, kwargs
            )
        return super().call_function(tx, args, kwargs)

    def inspect_parameter_names(self):
        """Return the function's parameter names without the first one."""
        all_names = super().inspect_parameter_names()
        return all_names[1:]

    def var_getattr(self, tx: "InstructionTranslator", name: str):
        """Resolve ``__self__`` and ``__func__`` locally; defer everything else."""
        if name == "__self__":
            return self.obj
        elif name == "__func__":
            # We might have a better way to access the function object, this
            # information is stored in self.source_fn, use that to construct the
            # variable tracker.
            return VariableTracker.build(tx, self.fn, self.source_fn)
        return super().var_getattr(tx, name)
class WrappedUserMethodVariable(UserMethodVariable):
    """A user method variable whose call is bracketed by a context variable."""

    def __init__(self, wrapped, context, **kwargs) -> None:
        # `fn` and `obj` always come from `wrapped`; drop any duplicates
        # passed by keyword.
        for duplicate in ("fn", "obj"):
            kwargs.pop(duplicate, None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Enter the context, call the method, exit the context, return the result."""
        ctx = self.context
        ctx.enter(tx)
        # NOTE: `exit` is only reached when the call returns normally.
        outcome = super().call_function(tx, args, kwargs)
        ctx.exit(tx)
        return outcome

    def reconstruct(self, codegen):
        """Emit code that calls the context object with the wrapped variable."""

        def emit_context():
            return codegen(self.context)

        codegen.add_push_null(emit_context)
        codegen(self.wrapped)
        codegen.extend_output(create_call_function(1, False))
class WrappedUserFunctionVariable(UserFunctionVariable):
    """A user function variable whose call is bracketed by a context variable."""

    def __init__(self, wrapped, context, **kwargs) -> None:
        # `fn` always comes from `wrapped`; drop a duplicate passed by keyword.
        kwargs.pop("fn", None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Enter the context, call the function, exit the context, return the result."""
        ctx = self.context
        ctx.enter(tx)
        # NOTE: `exit` is only reached when the call returns normally.
        outcome = super().call_function(tx, args, kwargs)
        ctx.exit(tx)
        return outcome

    def reconstruct(self, codegen):
        """Emit code that calls the context object with the wrapped variable."""

        def emit_context():
            return codegen(self.context)

        codegen.add_push_null(emit_context)
        codegen(self.wrapped)
        codegen.extend_output(create_call_function(1, False))
def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs):
    """Call ``fn`` eagerly and register its result on the output graph.

    Tensor variables are replaced by their real values and every other
    argument by its Python constant before ``fn`` is invoked. The result is
    registered under ``name`` with a ``ConstantSource``.
    """

    def _materialize(var):
        if not isinstance(var, variables.TensorVariable):
            return var.as_python_constant()
        return var.get_real_value()

    real_args = [_materialize(arg) for arg in args]
    real_kwargs = {key: _materialize(val) for key, val in kwargs.items()}
    result = fn(*real_args, **real_kwargs)
    return tx.output.register_attr_or_module(
        result,
        name,
        source=ConstantSource(name),
    )
- class NestedUserFunctionVariable(BaseUserFunctionVariable):
- _nonvar_fields = {
- "f_globals",
- *BaseUserFunctionVariable._nonvar_fields,
- }
def __init__(
    self,
    fn_name,
    code,
    f_globals,
    defaults,
    kwdefaults,
    annotations,
    closure,
    # This is present when this function is created by
    # `functools.wrap(wrapped_fn)(this_fn)`.
    wrapped_fn=None,
    **kwargs,
) -> None:
    """Store the pieces from which the nested function is built.

    ``fn_name`` and ``code`` must be variables whose Python constants are a
    ``str`` and a code object respectively; ``f_globals`` must be a plain
    ``dict``.
    """
    # Default to a fresh attribute-mutation record when none was supplied.
    if kwargs.get("mutation_type") is None:
        kwargs["mutation_type"] = AttributeMutationNew()
    super().__init__(**kwargs)

    assert isinstance(fn_name.as_python_constant(), str)
    assert isinstance(code.as_python_constant(), types.CodeType)
    assert isinstance(f_globals, dict)

    self.fn_name, self.code, self.f_globals = fn_name, code, f_globals
    self.defaults, self.kwdefaults = defaults, kwdefaults
    self.annotations, self.closure = annotations, closure
    self.wrapped_fn: Optional[VariableTracker] = wrapped_fn
def self_args(self):
    """Return the implicit leading arguments: none for this variable."""
    return []
def get_code(self):
    """Return the Python constant held by the ``code`` variable."""
    code_obj = self.code.as_python_constant()
    return code_obj
def python_type(self):
    """Return ``types.FunctionType``."""
    return types.FunctionType
def get_function(self):
    """Build a real Python function from the tracked constants.

    Returns:
        A new ``types.FunctionType`` built from the tracked code, globals and
        name, with defaults, keyword defaults and annotations applied when
        they are present.

    Raises:
        NotImplementedError: if the function has a closure.
    """
    if self.closure:
        raise NotImplementedError
    func = types.FunctionType(
        self.code.as_python_constant(),
        self.f_globals,
        self.fn_name.as_python_constant(),
    )
    if self.defaults:
        func.__defaults__ = self.defaults.as_python_constant()
    if self.kwdefaults:
        func.__kwdefaults__ = self.kwdefaults.as_python_constant()
    if self.annotations:
        annotations = self.annotations.as_python_constant()
        if isinstance(annotations, tuple):
            # The tuple is flat: (name0, value0, name1, value1, ...). Pair
            # even-indexed names with odd-indexed values. `itertools.pairwise`
            # would yield overlapping pairs and add bogus
            # `value -> next_name` entries.
            annotations = dict(zip(annotations[::2], annotations[1::2]))

        # TypeError: __annotations__ must be set to a dict object
        assert isinstance(annotations, dict)
        func.__annotations__ = annotations
    return func
def call_setattr(
    self,
    tx: "InstructionTranslator",
    name_var: VariableTracker,
    val: VariableTracker,
):
    """Record an attribute store on this function as a side effect.

    The attribute name is read from ``name_var.value``. Returns a
    ``ConstantVariable`` wrapping ``None``.
    """
    attr_name = name_var.value
    tx.output.side_effects.store_attr(self, attr_name, val)
    return ConstantVariable(None)
def call_method(self, tx, name, args, kwargs):
    """Route ``__setattr__`` to ``call_setattr``; defer other methods to the parent."""
    if name != "__setattr__":
        return super().call_method(tx, name, args, kwargs)
    return self.call_setattr(tx, *args)
def has_closure(self):
    """Return whether a closure was recorded (anything other than ``None``)."""
    closure = self.closure
    return closure is not None
def const_getattr(self, tx, name):
    """Return the function name for ``__name__``; defer other attributes to the parent."""
    if name != "__name__":
        return super().const_getattr(tx, name)
    return self.fn_name.as_python_constant()
def has_self(self):
    """Always answer ``False``: this variable carries no bound ``self``."""
    return False
def get_globals(self):
    """Return the globals dict recorded for this function."""
    module_globals = self.f_globals
    return module_globals
def bind_args(self, parent, args, kwargs):
    """Bind call arguments to parameter names and attach closure cells.

    A throwaway function with the same code and defaults is built so that
    ``inspect.signature`` performs the binding. The bound values are then
    wrapped, cell variables are initialised, and each free variable is mapped
    to the closure item at the same index.
    """
    code = self.get_code()
    free_names = code.co_freevars

    # Placeholder cells: only the number of free variables matters for
    # signature binding.
    placeholder_cells = tuple(make_cell(None) for _ in free_names)
    default_values = tuple(self.defaults.items) if self.defaults else None
    func = types.FunctionType(
        code,
        self.f_globals,
        self.fn_name.as_python_constant(),
        default_values,
        placeholder_cells,
    )
    if self.kwdefaults:
        func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant()

    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    result = dict(bound.arguments)
    wrap_args_kwargs(parent.output.root_tx, result)
    init_cellvars(parent, result, code)

    for idx, free_name in enumerate(free_names):
        assert free_name not in result
        result[free_name] = self.closure.items[idx]

    return result
def reconstruct(self, codegen: "PyCodegen"):
    """Emit code that rebuilds this nested function at runtime.

    The emitted code calls ``_create_nested_fn`` (imported from this module)
    with seven arguments, pushed in this order: code, globals, name, defaults,
    closure, kwdefaults, annotations. ``None`` is pushed for any piece that is
    absent. When ``wrapped_fn`` is set, the result is additionally passed
    through ``functools.wraps(wrapped_fn)``. Pending attribute mutations
    recorded for this variable are then stored onto the rebuilt function.
    """
    codegen.add_push_null(
        lambda: codegen.load_import_from(__name__, "_create_nested_fn")
    )
    codegen(self.code)
    codegen.extend_output([codegen.create_load_const_unchecked(self.f_globals)])
    codegen(ConstantVariable.create(self.code.value.co_name))

    if self.defaults:
        codegen(self.defaults)
    else:
        codegen.extend_output([codegen.create_load_const(None)])

    if self.closure:
        codegen(self.closure)
    else:
        codegen.extend_output([codegen.create_load_const(None)])

    if self.kwdefaults:
        codegen(self.kwdefaults)
    else:
        codegen.extend_output([codegen.create_load_const(None)])

    if self.annotations:
        try:
            # Prefer loading the annotations as one constant; fall back to
            # reconstructing the variable when it is not a Python constant.
            annotations = self.annotations.as_python_constant()
            codegen.extend_output(
                [codegen.create_load_const_unchecked(annotations)]
            )
        except NotImplementedError:
            codegen(self.annotations)
    else:
        codegen.extend_output([codegen.create_load_const(None)])

    # Call _create_nested_fn with the seven values pushed above.
    codegen.extend_output(create_call_function(7, False))

    if self.wrapped_fn:
        # Build functools.wraps(wrapped_fn), rotate it beneath the function
        # created above, then call it with that function as its argument.
        codegen.add_push_null(
            lambda: codegen.load_import_from("functools", "wraps")
        )
        codegen(self.wrapped_fn)
        codegen.extend_output(create_call_function(1, False))
        codegen.extend_output(create_rot_n(2))
        codegen.extend_output(create_call_function(1, True))

    # codegen attributes
    from torch._dynamo.symbolic_convert import InstructionTranslator

    tx = InstructionTranslator.current_tx()
    if tx.output.side_effects.has_pending_mutation(self):
        for name, value in tx.output.side_effects.store_attr_mutations[
            self
        ].items():
            # Duplicate the function, push the value, swap the two, then
            # store the attribute; the original function stays on the stack.
            codegen.dup_top()
            codegen(value)
            codegen.extend_output(create_rot_n(2))
            codegen.store_attr(name)
class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable):
    """A nested function variable whose call is bracketed by a context variable."""

    def __init__(self, wrapped, context, **kwargs) -> None:
        # Every constructor field is taken from `wrapped`; drop duplicates
        # passed by keyword.
        for duplicate in (
            "fn_name",
            "code",
            "f_globals",
            "defaults",
            "kwdefaults",
            "annotations",
            "closure",
            "wrapped_fn",
        ):
            kwargs.pop(duplicate, None)
        # NOTE(review): unlike the sibling Wrapped* classes, the remaining
        # kwargs are not forwarded to the parent constructor. Confirm whether
        # that is intentional.
        super().__init__(
            wrapped.fn_name,
            wrapped.code,
            wrapped.f_globals,
            wrapped.defaults,
            wrapped.kwdefaults,
            wrapped.annotations,
            wrapped.closure,
            wrapped.wrapped_fn,
        )
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Enter the context, call the function, exit the context, return the result."""
        ctx = self.context
        ctx.enter(tx)
        # NOTE: `exit` is only reached when the call returns normally.
        outcome = super().call_function(tx, args, kwargs)
        ctx.exit(tx)
        return outcome

    def reconstruct(self, codegen):
        """Emit code that calls the context object with the wrapped variable."""

        def emit_context():
            return codegen(self.context)

        codegen.add_push_null(emit_context)
        codegen(self.wrapped)
        codegen.extend_output(create_call_function(1, False))
- class SkipFunctionVariable(VariableTracker):
- _nonvar_fields = {
- "value",
- "reason",
- *VariableTracker._nonvar_fields,
- }
def __init__(self, value, reason=None, **kwargs) -> None:
    """Remember the skipped callable ``value`` and an optional skip ``reason``."""
    super().__init__(**kwargs)
    self.value, self.reason = value, reason
def as_python_constant(self):
    """Return the wrapped callable itself."""
    wrapped_value = self.value
    return wrapped_value
@classmethod
def create_with_source(cls, value, source):
    """Create the variable for ``value`` and install a closure-match guard.

    For callables marked with ``_torchdynamo_disable`` the guard targets the
    innermost ``_torchdynamo_orig_callable``. For other values the guard
    targets ``source`` itself, unless the value is a wrapper or member
    descriptor, in which case no guard is installed.
    """
    # Use closure match guard (i.e. guard on __code__ object instead of
    # function id) to avoid guarding on nested functions.
    if inspect.getattr_static(value, "_torchdynamo_disable", False):
        # For torch._dynamo.disable function, ensure that the original
        # function is guarded. Otherwise, the else branch will guard on the
        # _dynamo.disable.__code__
        guard_on_source = source
        guard_on_value = value

        while getattr(guard_on_value, "_torchdynamo_orig_callable", False):
            guard_on_value = guard_on_value._torchdynamo_orig_callable
            guard_on_source = AttrSource(
                guard_on_source, "_torchdynamo_orig_callable"
            )

        # The guard has to be installed, as the sibling branch below does.
        # Previously the result of make_guard was discarded here, so the
        # original function was never guarded.
        install_guard(guard_on_source.make_guard(GuardBuilder.CLOSURE_MATCH))
    elif not is_wrapper_or_member_descriptor(value):
        # These descriptors are not guaranteed to return the same object on
        # attribute lookup. They are unlikely to be changed, so we can skip
        # guarding them.
        install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))

    return cls(value, source=source)
def call_function(
    self,
    tx: "InstructionTranslator",
    args: "list[VariableTracker]",
    kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
    """Handle a call to a function that Dynamo skips.

    Every path ends in ``unimplemented_v2`` or ``SkipFrame``, except when
    ``config.dont_skip_tracing`` is set and rebuilding the function yields
    something other than a ``SkipFunctionVariable``; that variable is then
    called instead. The remaining code only selects the explanation and hints
    attached to the graph break.
    """
    if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
        msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None)
        unimplemented_v2(
            gb_type="Skip calling `torch.compiler.disable()`d function",
            context=str(self.value),
            explanation=f"Skip calling function `{self.value}` since it was wrapped "
            f"with `torch.compiler.disable` (reason: {msg})",
            hints=[
                "Remove the `torch.compiler.disable` call",
            ],
        )
    elif self.value is torch._dynamo.graph_break:
        # Explicit user-requested graph break; an optional `msg` kwarg is
        # included in the explanation.
        graph_break_msg = kwargs.get("msg", None)
        if graph_break_msg:
            graph_break_msg = graph_break_msg.as_python_constant()
        unimplemented_v2(
            gb_type="Call to `torch._dynamo.graph_break()`",
            context=f"Called `torch._dynamo.graph_break()` with args `{args}`, kwargs `{kwargs}`",
            explanation=f"User-inserted graph break. Message: {graph_break_msg}",
            hints=[
                "Remove the `torch._dynamo.graph_break()` call.",
            ],
        )
    elif self.value is torch._dynamo.skip_frame:
        # Explicit user-requested frame skip.
        skip_frame_msg = kwargs.get("msg", None)
        if skip_frame_msg:
            skip_frame_msg = skip_frame_msg.as_python_constant()
        raise SkipFrame(
            f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
        )
    else:
        if config.dont_skip_tracing:
            from .builder import SourcelessBuilder

            # re-build the function, attempting to not skip
            rebuilt_fn = SourcelessBuilder.create(tx, self.value)
            # if we still get SkipFunctionVariable, then we *really* should skip this function
            if not isinstance(rebuilt_fn, SkipFunctionVariable):
                return rebuilt_fn.call_function(tx, args, kwargs)
        qualname = getattr(self.value, "__qualname__", "<unknown qualname>")
        module_or = getattr(self.value, "__module__", None)
        module_name = "<unknown module>" if module_or is None else str(module_or)
        try:
            # inspect.getfile raises TypeError for objects without a source
            # file; that case is handled in the except branch.
            path = inspect.getfile(self.value)
            explanation = (
                f"Dynamo developers have intentionally marked that the function `{qualname}` "
                f"in file `{path}` should not be traced."
            )
            hints = [
                f"Avoid calling the function `{qualname}`.",
            ]
            # TODO improve trace_rules reasoning to provide better hints.
            # How do we tell that a function/file should NOT be removed from skip files?
            # Do a very basic check for now.
            if "_dynamo" not in path:
                hints += [
                    f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{qualname}` "
                    "to force tracing into the function. "
                    "More graph breaks may occur as a result of attempting to trace into the function.",
                    "Please file an issue to PyTorch.",
                ]
        except TypeError:
            known_python_builtin_modules = {"_abc", "_warnings"}
            if module_or in known_python_builtin_modules:
                explanation = (
                    f"Dynamo does not know how to trace the Python builtin "
                    f"`{module_name}.{qualname}`."
                )
                hints = [
                    "If you are attempting to call a logging function (e.g. `_warnings.warn`), "
                    "you can try adding it to `torch._dynamo.config.reorderable_logging_functions`.",
                    "Please file an issue on GitHub "
                    "so the PyTorch team can add support for it. ",
                ]
            elif module_or is not None and module_or.startswith("optree"):
                explanation = f"Dynamo cannot trace optree C/C++ function {module_name}.{qualname}."
                hints = [
                    " Consider using torch.utils._pytree - "
                    "https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py"
                ]
                # also warn on it because most users won't see the graph break message
                torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
            else:
                explanation = (
                    f"Dynamo does not know how to trace the builtin `{module_name}.{qualname}.` "
                    f"This function is either a Python builtin (e.g. _warnings.warn) "
                    f"or a third-party C/C++ Python extension (perhaps created with pybind)."
                )
                hints = [
                    "If it is a Python builtin, please file an issue on GitHub "
                    "so the PyTorch team can add support for it and see the next case for a workaround.",
                    "If it is a third-party C/C++ Python extension, please "
                    "either wrap it into a PyTorch-understood custom operator "
                    "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html "
                    "for more details) or, if it is traceable, use "
                    "`torch.compiler.allow_in_graph`.",
                ]
                # also warn on it because most users won't see the graph break message
                torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
        # A function named `allow_in_graph` overrides whatever explanation
        # and hints were chosen above.
        if qualname == "allow_in_graph":
            explanation = (
                "Found an allow_in_graph decorator to a function which "
                "is created inside the parent function that is getting "
                "compiled. This is not supported for now."
            )
            hints = []

        reason = self.reason if self.reason else "<missing reason>"
        unimplemented_v2(
            gb_type="Attempted to call function marked as skipped",
            context=f"module: {module_name}, qualname: {qualname}, skip reason: {reason}",
            explanation=explanation,
            hints=hints,
        )
def call_obj_hasattr(self, tx: "InstructionTranslator", name):
    """Return a constant variable saying whether the real callable has ``name``."""
    present = hasattr(self.value, name)
    return variables.ConstantVariable.create(present)
def var_getattr(self, tx: "InstructionTranslator", name: str):
    """Resolve attribute ``name`` on the skipped callable.

    Names found in ``cmp_name_to_op_mapping`` become a ``GetAttrVariable`` on
    this variable; all other names go through ``fn_var_getattr``.
    """
    if name not in cmp_name_to_op_mapping:
        return fn_var_getattr(tx, self.value, self.source, name)
    return variables.GetAttrVariable(self, name)
class WrappedSkipFunctionVariable(SkipFunctionVariable):
    """A skipped-function variable whose call is bracketed by a context variable."""

    def __init__(self, wrapped, context, **kwargs) -> None:
        # `value` and `reason` always come from `wrapped`; drop duplicates
        # passed by keyword.
        for duplicate in ("value", "reason"):
            kwargs.pop(duplicate, None)
        super().__init__(wrapped.value, reason=wrapped.reason, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Enter the context, call the function, exit the context, return the result."""
        ctx = self.context
        ctx.enter(tx)
        # NOTE: `exit` is only reached when the call returns normally.
        outcome = super().call_function(tx, args, kwargs)
        ctx.exit(tx)
        return outcome

    def reconstruct(self, codegen):
        """Emit code that calls the context object with the wrapped variable."""

        def emit_context():
            return codegen(self.context)

        codegen.add_push_null(emit_context)
        codegen(self.wrapped)
        codegen.extend_output(create_call_function(1, False))
class WrapperUserFunctionVariable(VariableTracker):
    """
    Used to represent a wrapper object that contains the actual callable as an
    attribute. For example, torch.jit.script/trace have the original function at
    their _torchdynamo_inline attribute. Similarly, functions with
    __script_if_tracing_wrapper have the original attr at "__original_fn".
    """

    def __init__(self, wrapper_obj, attr_to_trace, **kwargs) -> None:
        """Store the wrapper object and the name of the attribute to trace."""
        super().__init__(**kwargs)
        self.wrapper_obj = wrapper_obj
        self.attr_to_trace = attr_to_trace

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Build a variable for the traced attribute; defer other names to the parent."""
        if name != self.attr_to_trace:
            return super().var_getattr(tx, name)

        target = getattr(self.wrapper_obj, self.attr_to_trace)
        # Only derive an attribute source when this variable has a source.
        target_source = self.source and AttrSource(self.source, name)
        return VariableTracker.build(tx, target, target_source)

    def self_args(self):
        """Return the implicit leading arguments: none for plain functions."""
        return []

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace the wrapped attribute through ``polyfills.getattr_and_trace``.

        When the wrapper exposes ``cache_info`` and the wrapped function's
        top-level module is not ``torch``, a one-time warning is issued and,
        at DEBUG level, the call-site stack is logged.
        """
        if hasattr(self.wrapper_obj, "cache_info"):
            target_fn = getattr(self.wrapper_obj, self.attr_to_trace, None)
            module_name = getattr(target_fn, "__module__", "") or ""

            top_level_package = module_name.split(".", maxsplit=1)[0]
            if top_level_package != "torch":
                msg = (
                    "Dynamo detected a call to a `functools.lru_cache`-wrapped "
                    "function. Dynamo ignores the cache wrapper and directly "
                    "traces the wrapped function. Silent incorrectness is only "
                    "a *potential* risk, not something we have observed. "
                    'Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.'
                )

                torch._dynamo.utils.warn_once(msg)

                dynamo_logger = torch._dynamo.utils.logging.getLogger("torch._dynamo")
                if dynamo_logger.isEnabledFor(logging.DEBUG):
                    user_stack = torch._guards.TracingContext.extract_stack()
                    user_stack = get_stack_above_dynamo() + user_stack
                    innermost = user_stack[-1]
                    frame_loc = (innermost.filename, innermost.lineno)
                    user_stack_formatted = "".join(traceback.format_list(user_stack))
                    user_stack_trace = f"call to a lru_cache wrapped function at: {frame_loc[0]}:{frame_loc[1]}\n"
                    user_stack_trace += str(user_stack_formatted)
                    dynamo_logger.debug(user_stack_trace)

        all_args = self.self_args() + args
        tracer_fn = variables.UserFunctionVariable(polyfills.getattr_and_trace)
        call_args = [self, variables.ConstantVariable(self.attr_to_trace), *all_args]
        return tracer_fn.call_function(tx, call_args, kwargs)
class WrapperUserMethodVariable(WrapperUserFunctionVariable):
    """
    Similar to WrapperUserFunctionVariable, but for methods. The only delta is
    saving the vt for `self` object of the method which is then used by
    WrapperUserFunctionVariable in `call_function` method.
    """

    def __init__(self, wrapper_obj, attr_to_trace, self_obj, **kwargs) -> None:
        """Store ``self_obj`` in addition to what the parent constructor stores."""
        super().__init__(wrapper_obj, attr_to_trace, **kwargs)
        self.obj = self_obj

    def self_args(self):
        """Return the implicit leading argument list: the bound object."""
        bound_obj = self.obj
        return [bound_obj]
- def _traceable_collective_remaps():
- # We can't rely on importing from distributed, since it's not always built
- if torch.distributed.is_available():
- from torch.distributed._functional_collectives import (
- traceable_collective_remaps,
- )
- return traceable_collective_remaps
- return {}
def _traceable_collectives_source(tx: "InstructionTranslator", fn):
    """Return an ``AttrSource`` for ``fn`` inside ``torch.distributed._functional_collectives``.

    ``fn`` must be one of the replacement functions in the remapping table.
    """
    assert torch.distributed.is_available(), "Illegal invocation."
    assert fn in _traceable_collective_remaps().values()

    attr_name = fn.__name__
    module_source = tx.import_source("torch.distributed._functional_collectives")
    return AttrSource(module_source, attr_name)
class CollectiveFunctionRewriteVariable(UserFunctionVariable):
    """
    Rewrites torch.distributed.* collective calls into 'traceable' collectives
    where such a rewrite exists.

    The class offers a check for whether a function is remappable
    (`can_rewrite`) and the remapping itself (`rewrite` / `create`).
    A function may be remappable only for some combinations of call-time
    arguments, so the arguments are inspected in `call_function` and we
    graph-break when they are unsupported. This is no worse than status-quo,
    as we currently graph-break on all distributed.* collectives.
    """

    def __init__(self, fn, *, replacement_var, **kwargs) -> None:
        super().__init__(fn, **kwargs)
        assert isinstance(replacement_var, UserFunctionVariable)
        # Tracker of the traceable function that calls are forwarded to.
        self.replacement_var = replacement_var

    @staticmethod
    def create(tx: "InstructionTranslator", old_fn, source, **options):
        """Build the rewrite variable for `old_fn` together with its replacement."""
        new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
        replacement = UserFunctionVariable(new_fn, source=new_source, **options)
        return CollectiveFunctionRewriteVariable(
            old_fn,
            replacement_var=replacement,
            source=source,
            **options,
        )

    @staticmethod
    def can_rewrite(variable):
        """Tell whether `variable` is a function with a traceable remapping."""
        if not inspect.isfunction(variable):
            return False
        return variable in _traceable_collective_remaps()

    @staticmethod
    def rewrite(tx: "InstructionTranslator", fn):
        """Return the remapped function for `fn` and the source describing it."""
        remapped_fn = _traceable_collective_remaps()[fn]
        return remapped_fn, _traceable_collectives_source(tx, remapped_fn)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Forward the call to the replacement variable after validating arguments.

        Graph-breaks when `async_op` is truthy. For the reducing collectives
        the `op` argument is replaced by its `REDUCE_OP_TO_STR` string;
        a ValueError is raised when the op has no entry there.
        """
        # call_function must check any unsupported arguments and graph-break.
        # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
        # since that's the contract for putting a mapping in `traceable_collective_remaps`
        import torch.distributed as dist
        from torch.distributed._functional_collectives import REDUCE_OP_TO_STR

        # Fold positional arguments into keyword arguments so that both
        # kinds are handled by the same code below.
        signature = inspect.signature(self.fn)
        bound_call = signature.bind(*args, **kwargs)
        kwargs = dict(bound_call.arguments)
        args = ()

        if "async_op" in kwargs:
            if kwargs["async_op"].as_python_constant():
                unimplemented_v2(
                    gb_type="async_op=True for distributed collectives",
                    context=f"{self.fn}, {args=}, {kwargs=}",
                    explanation=f"`torch.compile` doesn't support `async_op=True for {self.fn}",
                    hints=list(graph_break_hints.SUPPORTABLE),
                )

        collectives_with_reduce_op = (
            dist.all_reduce,
            dist.reduce_scatter_tensor,
            dist._reduce_scatter_base,
        )
        if self.fn in collectives_with_reduce_op:
            reduce_op_var = kwargs.get("op")
            if reduce_op_var is None:
                # `op` was not passed: use the default declared in the signature.
                reduce_op = signature.parameters["op"].default
            else:
                reduce_op = reduce_op_var.value
            if reduce_op not in REDUCE_OP_TO_STR:
                raise ValueError(f"Unsupported all_reduce op: {reduce_op}")
            op_name = REDUCE_OP_TO_STR[reduce_op]
            kwargs["op"] = variables.ConstantVariable.create(op_name)

        return self.replacement_var.call_function(tx, args, kwargs)
class FunctoolsWrapsVariable(UserFunctionVariable):
    """
    Function variable with special handling for the single-positional-argument
    call form, which is reported under the `functools.wraps` graph-break type.
    """

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        With exactly one positional argument and no keywords, return a
        LambdaVariable decorator; any other call shape is delegated to the
        parent implementation.
        """
        if kwargs or len(args) != 1:
            return super().call_function(tx, args, kwargs)

        def wraps(fn):
            # Only functions tracked as NestedUserFunctionVariable can be
            # cloned with the wrapped function attached; everything else
            # is a graph break.
            if isinstance(fn, variables.NestedUserFunctionVariable):
                return fn.clone(wrapped_fn=args[0])
            unimplemented_v2(
                gb_type="functools.wraps",
                context=f"{fn}",
                explanation="`torch.compile` can't trace `functools.wraps` on functions defined outside the compile region",
                hints=list(graph_break_hints.SUPPORTABLE),
            )

        return variables.LambdaVariable(wraps)
class CollectionsNamedTupleFunction(UserFunctionVariable):
    """
    Function variable whose call is evaluated eagerly on constant arguments;
    the result is tracked as a new user-defined class.
    """

    def as_python_constant(self):
        """Return the underlying Python function."""
        return self.fn

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Call `self.fn` on the constant values of `args`/`kwargs`.

        A TypeError from that call is re-raised as an observed exception
        carrying the same arguments. When the arguments are not all
        constants, a graph break is issued instead.
        """
        if check_constant_args(args, kwargs):
            try:
                positional = [x.as_python_constant() for x in args]
                keyword = {k: v.as_python_constant() for k, v in kwargs.items()}
                value = self.fn(*positional, **keyword)
            except TypeError as exc:
                raise_observed_exception(
                    type(exc),
                    tx,
                    args=[ConstantVariable.create(a) for a in exc.args],
                )
            else:
                return variables.UserDefinedClassVariable(
                    value, mutation_type=ValueMutationNew()
                )
            # Reached only if raise_observed_exception were to return; the
            # original code would then hit an unbound `value`, so mirror it.
            return variables.UserDefinedClassVariable(
                value, mutation_type=ValueMutationNew()
            )
        unimplemented_v2(
            gb_type="namedtuple construction",
            context=f"{args=}, {kwargs=}",
            explanation="`torch.compile` only support certain input types for namedtuple",
            hints=list(graph_break_hints.SUPPORTABLE),
        )
class FunctoolsPartialVariable(VariableTracker):
    """
    Variable tracker for a `functools.partial` object: a tracked callable
    plus the tracked positional and keyword arguments bound to it.
    """

    def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None:
        super().__init__(**kwargs)
        self.func = func
        assert isinstance(args, list)
        self.args = args
        assert isinstance(keywords, dict)
        self.keywords = keywords
        # fake_value is used for id calculation. Creating this value and id'ng
        # on it is sufficient for the tracing purposes.
        self.fake_value = functools.partial(identity)

    def python_type(self):
        """Report `functools.partial` as the Python type being modelled."""
        return functools.partial

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode that rebuilds `functools.partial(func, *args, **keywords)`."""
        codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial"))
        codegen(self.func)
        if self.args:
            codegen.foreach(self.args)

        # The callable itself counts as one argument of `partial(...)`.
        positional_count = len(self.args) + 1
        if self.keywords:
            codegen.foreach(self.keywords.values())
            keys = tuple(self.keywords.keys())
            codegen.extend_output(
                codegen.create_call_function_kw(
                    len(keys) + positional_count, keys, False
                )
            )
        else:
            codegen.extend_output(create_call_function(positional_count, False))

    def get_function(self):
        """Return the concrete partial object (see `as_python_constant`)."""
        return self.as_python_constant()

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Call the wrapped callable with bound arguments followed by call-time ones."""
        merged_args = self.args + args
        # Call-time keywords override the bound ones.
        merged_kwargs = dict(self.keywords)
        merged_kwargs.update(kwargs)
        return self.func.call_function(tx, merged_args, merged_kwargs)

    def call_obj_hasattr(
        self, tx: "InstructionTranslator", name: str
    ) -> VariableTracker:
        """Answer `hasattr` by probing a throwaway partial object."""
        # functools.partial uses slots, so attributes are constant
        probe = functools.partial(identity)
        return variables.ConstantVariable.create(hasattr(probe, name))

    def var_getattr(self, tx: "InstructionTranslator", name: str):
        """
        Resolve `func`, `args`, `keywords` and comparison-method names;
        any other attribute raises an observed AttributeError.
        """
        source = AttrSource(self.source, name) if self.source else self.source
        # Handle __slots__
        if name == "func":
            return self.func
        if name == "args":
            return variables.ListVariable(self.args, source=source)
        if name == "keywords":
            items = {}
            for key, value in self.keywords.items():
                items[ConstantVariable.create(key)] = value
            return variables.ConstDictVariable(items, source=source)
        if name in cmp_name_to_op_mapping:
            return variables.GetAttrVariable(self, name)
        raise_observed_exception(AttributeError, tx)

    def as_python_constant(self):
        """Build a real `functools.partial` from the constant values of all parts."""
        func_const = self.func.as_python_constant()
        arg_consts = [arg.as_python_constant() for arg in self.args]
        keyword_consts = {k: v.as_python_constant() for k, v in self.keywords.items()}
        return functools.partial(func_const, *arg_consts, **keyword_consts)

    def guard_as_python_constant(self):
        """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
        func_const = self.func.guard_as_python_constant()
        arg_consts = [v.guard_as_python_constant() for v in self.args]
        keyword_consts = {
            k: v.guard_as_python_constant() for k, v in self.keywords.items()
        }
        return functools.partial(func_const, *arg_consts, **keyword_consts)
class PolyfilledFunctionVariable(VariableTracker):
    """
    Variable tracker for a function that has a traceable substitute.

    `fn` is the original function, `wrapped_fn` the handler looked up for it
    and `traceable_fn` the function that is actually traced on calls.
    """

    _nonvar_fields = {
        "fn",
        "wrapped_fn",
        "traceable_fn",
    } | set(VariableTracker._nonvar_fields)

    @classmethod
    @functools.cache
    def _get_polyfill_handlers(cls) -> dict[Callable[..., Any], types.FunctionType]:
        """Return the handler table; cached, so repeated calls share one dict."""
        return {}

    @classmethod
    def create_with_source(cls, value, source):
        """Install a FUNCTION_MATCH guard on `source` and build the variable."""
        guard = source.make_guard(GuardBuilder.FUNCTION_MATCH)
        install_guard(guard)
        return cls(value, source=source)

    def __init__(self, fn: _F, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn: _F = fn

        # Fall back to `fn` itself when no handler is registered for it.
        handler = self._get_polyfill_handlers().get(fn, fn)
        assert callable(handler), f"Polyfill handler {handler} is not callable for {fn}"

        traceable_fn = None
        lookup_order = (
            "__torch_dynamo_polyfill__",  # registered polyfill
            "__python_implementation__",  # self handler from third-party libraries
        )
        for attr_name in lookup_order:
            candidate = getattr(handler, attr_name, None)
            if candidate:
                assert callable(candidate)
                traceable_fn = candidate
                break
        if traceable_fn is None:
            raise RuntimeError(
                f"Polyfill handler {handler} does not have a traceable function"
            )

        self.wrapped_fn: _F = handler
        self.traceable_fn: _F = traceable_fn

    @property
    def polyfill_fn(self) -> _F:
        """Alias for `traceable_fn`."""
        return self.traceable_fn

    def can_constant_fold_through(self):
        """Read the constant-folding marker from the handler (False if absent)."""
        marker = "__torch_dynamo_can_constant_fold_through__"
        return getattr(self.wrapped_fn, marker, False)

    def get_function(self):
        """Return the original function (see `as_python_constant`)."""
        return self.as_python_constant()

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Evaluate the call, trying in order: constant folding with the
        original function, the `sum`-of-ints special case, and finally
        tracing `traceable_fn`.
        """
        if self.can_constant_fold_through() and check_unspec_or_constant_args(
            args, kwargs
        ):
            positional = [x.as_python_constant() for x in args]
            keyword = {k: v.as_python_constant() for k, v in kwargs.items()}
            # use the original function which is faster than the polyfill
            result = self.fn(*positional, **keyword)
            return VariableTracker.build(tx, result)

        def is_int_like(item):
            # Either a constant int or a symbolic node whose Python type is int.
            if isinstance(item, variables.ConstantVariable) and isinstance(
                item.value, int
            ):
                return True
            return (
                isinstance(item, variables.SymNodeVariable)
                and item.python_type() is int
            )

        # Special case for sum on tuple/list of ints
        if (
            self.fn is builtins.sum
            and len(args) == 1
            and not kwargs
            and isinstance(args[0], (variables.ListVariable, variables.TupleVariable))
            and all(is_int_like(x) for x in args[0].items)
        ):
            operands = args[0].items
            proxy = tx.output.create_proxy(
                "call_function",
                torch.sym_sum,
                (tuple(a.as_proxy() for a in operands),),
                {},
            )
            concrete_operands = []
            for x in operands:
                if isinstance(x, variables.ConstantVariable):
                    concrete_operands.append(x.value)
                else:
                    concrete_operands.append(x.sym_num)
            return variables.SymNodeVariable.create(
                tx,
                proxy,
                sym_num=torch.sym_sum(concrete_operands),
            )

        traceable_function_variable = VariableTracker.build(tx, self.traceable_fn)
        return traceable_function_variable.call_function(tx, args, kwargs)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        `__call__` is a plain call; any other name is looked up on the
        original function, wrapped as a PolyfilledFunctionVariable and called.
        """
        if name == "__call__":
            return self.call_function(tx, args, kwargs)

        method = getattr(self.fn, name, None)
        assert method is not None, f"Member {name} not found in {self.fn}"
        assert is_function(method), f"Member {name} is not callable in {self.fn}"
        options = {"source": AttrSource(self.source, name)} if self.source else {}
        polyfilled_method_variable = PolyfilledFunctionVariable(method, **options)
        return polyfilled_method_variable.call_function(tx, args, kwargs)

    def as_python_constant(self):
        """Return the original (non-polyfilled) function."""
        return self.fn
class TracebackVariable(VariableTracker):
    """We don't track traceback. A call to any function in this module is a no-op."""

    def call_function(self, tx, args, kwargs):
        # Intentionally does nothing and yields None.
        return None
class SysFunctionVariable(VariableTracker):
    """
    Variable tracker for `sys.exc_info` / `sys.exception`, answered from the
    translator's `exn_vt_stack`.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        # The `sys` function being modelled.
        self.value = value

    def exc_info(self, tx):
        """
        Return a TupleVariable of (type, exception, traceback).

        The traceback slot is always built from None; all three slots are
        None constants when no exception is on the stack.
        """
        if len(tx.exn_vt_stack) == 0:
            items = [variables.ConstantVariable(None) for _ in range(3)]
        else:
            exn = tx.exn_vt_stack[-1]
            typ = exn.exc_type
            tb = None
            type_vt = VariableTracker.build(tx, typ)
            tb_vt = VariableTracker.build(tx, tb)
            items = [type_vt, exn, tb_vt]
        return variables.TupleVariable(items)

    def exception(self, tx):
        """Return the middle (exception) entry of `exc_info`."""
        info = self.exc_info(tx)
        return info.items[1]

    def call_function(self, tx, args, kwargs):
        """Dispatch on which `sys` function this variable stands for."""
        if self.value is sys.exc_info:
            return self.exc_info(tx)
        assert self.value is sys.exception
        return self.exception(tx)
- from torch._higher_order_ops.triton_kernel_wrap import (
- create_tma_experimental_metadata,
- create_tma_stable_metadata,
- TMADescriptorMetadata,
- TritonHOPifier,
- )
class DynamoTritonHOPifier(TritonHOPifier):
    """TritonHOPifier hooks implemented in terms of Dynamo variable trackers."""

    def raise_unsupported(self, msg: str) -> Never:
        """Raise `Unsupported` with the given message."""
        raise Unsupported(msg)

    def is_callable(self, maybe_callable: Any) -> bool:
        """True for nested and plain user function variables."""
        callable_types = (NestedUserFunctionVariable, UserFunctionVariable)
        return isinstance(maybe_callable, callable_types)

    def get_value(self, val: Any) -> Any:
        """Return the `.value` attribute of `val`."""
        return val.value

    def check_grid(self, grid) -> tuple[torch.fx.proxy.Proxy, ...]:
        """Return the proxy of a list-like grid; graph-break on anything else."""
        from .lists import BaseListVariable

        if isinstance(grid, BaseListVariable):
            return grid.as_proxy()

        unimplemented_v2(
            gb_type="unsupported grid type for triton hop check_grid",
            context=f"grid type = {type(grid)}",
            explanation="`torch.compile` only supports list-like grid for check_grid",
            hints=list(graph_break_hints.SUPPORTABLE),
        )

    def call_grid(self, grid, meta, tx):
        """Call the grid variable with `meta` whose keys are wrapped as constants."""
        wrapped_meta = {
            variables.ConstantVariable.create(key): value
            for key, value in meta.items()
        }
        return grid.call_function(tx, [wrapped_meta], {})

    # We use this function to wrap call_prune_configs
    def call_user_defined_fn(self, user_fn, args, kwargs, tx, variable):
        """Wrap `user_fn` without a source and call it with the given arguments."""
        from .builder import SourcelessBuilder

        wrapped_user_function = SourcelessBuilder.create(tx, user_fn)
        return wrapped_user_function.call_function(tx, args, kwargs)

    def wrap_user_defined_obj(self, user_obj, tx, variable, name):
        """Wrap `user_obj` with a source rooted at `variable.kernel_source`."""
        from .builder import VariableBuilder

        obj_source = AttrSource(variable.kernel_source, f"{name}")
        return VariableBuilder(tx, obj_source)._wrap(user_obj)

    def maybe_unpack_configs(self, configs, tx):
        """Unpack the config sequence and turn each entry into a guarded constant."""
        # guard_as_python_constant inserts guards for Dynamo to check if the configs object changed.
        return [
            config.guard_as_python_constant()
            for config in configs.unpack_var_sequence(tx)  # unpack the list of configs
        ]

    def maybe_unpack_heuristic_result(self, result: Any) -> Any:
        """Return the guarded constant value of a heuristic result."""
        is_constant = result.is_python_constant()
        if not is_constant:
            self.raise_unsupported(
                "@triton.heuristics must return constant values because configs can only contain constant values."
            )

        return result.guard_as_python_constant()

    # We need to override call_getitem here so that we can add the source in the case
    # where we call the triton kernel with a grid
    def call_getitem(
        self,
        variable: "TritonKernelVariable",
        args: Sequence[Any],
    ) -> "TritonKernelVariable":
        """Return a copy of `variable` with `args[0]` installed as its grid."""
        # __getitem__ should only be called if we don't already have a grid
        # Only grid needs to be passed
        has_grid = variable.grid is not None
        if has_grid or len(args) != 1:
            self.raise_unsupported(
                "Triton kernels should be called with only a single grid"
            )

        kernel_variable_cls = type(variable)
        return kernel_variable_cls(
            kernel=variable.kernel,
            kernel_idx=variable.kernel_idx,
            grid=args[0],
            kernel_source=variable.source,
        )

    def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable:
        """
        Emit a `triton_kernel_wrapper_mutation` call into the graph.

        Constant arguments are registered in `kernel_side_table`; the
        remaining ones must realize to tensor or symbolic-node variables
        and are passed through the proxy. Returns a None constant.
        """
        from torch._higher_order_ops.triton_kernel_wrap import (
            kernel_side_table,
            triton_kernel_wrapper_mutation,
        )

        from .constant import ConstantVariable
        from .dicts import ConstDictVariable

        # as we can only pass tensors as non-const args in fx graph,
        # here we replace TMA descriptors
        # (TMADescriptorExperimentalVariable and TMADescriptorStableVariable
        # instances) with the underlying tensors, while moving the
        # TMA descriptor-related metadata to a separate argument,
        # so that we can reconstruct the TMA descriptors downstream
        tma_descriptor_metadata: TMADescriptorMetadata = {}
        descriptor_types = (
            TMADescriptorExperimentalVariable,
            TMADescriptorStableVariable,
        )
        # Iterate over a snapshot because entries are replaced in the loop.
        for arg_name, arg_value in list(combined_args_raw.items()):
            if isinstance(arg_value, descriptor_types):
                tma_descriptor_metadata[arg_name] = arg_value.to_metadata()
                combined_args_raw[arg_name] = arg_value.get_tensor()

        combined_args = {
            variables.ConstantVariable.create(arg_name): arg_value
            for arg_name, arg_value in combined_args_raw.items()
        }

        # Combine args and kwargs and pass as a dict so that if user defined triton
        # kernel uses variables as 'grid' or 'kernel', it does not conflict with
        # parameters of the wrapper function
        constant_args = {}
        for arg_name, arg_value in combined_args_raw.items():
            if isinstance(arg_value, ConstantVariable):
                constant_args[arg_name] = arg_value.as_python_constant()

        non_constant_args = {}
        for key_vt, arg_value in combined_args.items():
            if not isinstance(arg_value, ConstantVariable):
                non_constant_args[key_vt] = arg_value

        allowed_types = (variables.TensorVariable, variables.SymNodeVariable)
        for candidate in non_constant_args.values():
            v = candidate.realize()
            if not isinstance(v, allowed_types):
                self.raise_unsupported(
                    f"Unexpected argument type for a Triton kernel: {repr(v)}."
                )

        constant_args_idx = kernel_side_table.add_constant_args(constant_args)
        meta = ConstDictVariable(non_constant_args, dict)
        hop_kwargs = {
            "kernel_idx": variable.kernel_idx,
            "constant_args_idx": constant_args_idx,
            "grid": grids,
            "tma_descriptor_metadata": tma_descriptor_metadata,
            "kwargs": meta.as_proxy(),
        }
        tx.output.create_proxy(
            "call_function",
            triton_kernel_wrapper_mutation,
            (),
            hop_kwargs,
        )

        return variables.ConstantVariable(None)


# Module-level instance used by TritonKernelVariable below.
dynamo_triton_hopifier_singleton = DynamoTritonHOPifier()
class TritonKernelVariable(VariableTracker):
    """
    Variable tracker for a Triton kernel; all real work is delegated to
    `dynamo_triton_hopifier_singleton`.
    """

    grid: "TritonGridType"
    kernel: "TritonKernelType"
    kernel_idx: Optional[int]
    kernel_source: "AttrSource"

    def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None:
        # `kernel_source` is consumed here and not forwarded to the base class.
        self.kernel_source = kwargs.pop("kernel_source", None)
        super().__init__(**kwargs)
        dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Launch the kernel through the hopifier."""
        hopifier = dynamo_triton_hopifier_singleton
        return hopifier.call_triton_kernel(self, args, kwargs, tx)

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle `__getitem__` and `run`; other names go to the parent class."""
        if name == "__getitem__":
            return dynamo_triton_hopifier_singleton.call_getitem(self, args)
        if name == "run":
            return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx)

        # Bail out to parent's implementation
        return super().call_method(tx, name, args, kwargs)

    def specialize_symbolic(self, arg: Any) -> Any:
        """Turn a symbolic-node variable into a constant; pass others through."""
        from .constant import ConstantVariable
        from .tensor import SymNodeVariable

        # See [Note: Specialize tl.constexpr args in user-defined triton kernels]
        if not isinstance(arg, SymNodeVariable):
            return arg
        return ConstantVariable.create(arg.evaluate_expr())
class TMADescriptorExperimentalVariable(VariableTracker):
    """
    Variable tracker for a TMA descriptor made by one of the
    `create_{N}d_tma_descriptor` functions of
    `triton.tools.experimental_descriptor`.
    """

    def __init__(
        self,
        data_ptr: "variables.DataPtrVariable",
        dims: "list[ConstantVariable]",
        block_dims: "list[ConstantVariable]",
        element_size: "ConstantVariable",
        **kwargs,
    ):
        assert isinstance(data_ptr, variables.DataPtrVariable)
        super().__init__(**kwargs)
        self.data_ptr = data_ptr
        self.dims = dims
        self.block_dims = block_dims
        self.element_size = element_size

    def to_metadata(self):
        """Build the experimental TMA metadata from the proxies of all fields."""
        dim_proxies = [dim.as_proxy() for dim in self.dims]
        block_dim_proxies = [dim.as_proxy() for dim in self.block_dims]
        element_size_proxy = self.element_size.as_proxy()
        return create_tma_experimental_metadata(
            dim_proxies,
            block_dim_proxies,
            element_size_proxy,
        )

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode that calls the descriptor factory matching `len(dims)`."""

        def load_factory():
            return codegen.load_import_from(
                "triton.tools.experimental_descriptor",
                f"create_{len(self.dims)}d_tma_descriptor",
            )

        codegen.add_push_null(load_factory)
        self.data_ptr.reconstruct(codegen)
        trailing_args = [*self.dims, *self.block_dims, self.element_size]
        codegen.foreach(trailing_args)
        # One extra argument for the data pointer emitted above.
        codegen.call_function(len(trailing_args) + 1, False)

    def get_tensor(self):
        """Return the tensor variable the data pointer was taken from."""
        return self.data_ptr.from_tensor
class TMADescriptorStableVariable(VariableTracker):
    """
    Variable tracker for a descriptor made with
    `TensorDescriptor.from_tensor` from `triton.tools.tensor_descriptor`.
    """

    def __init__(
        self,
        tensor: "variables.TensorVariable",
        block_shape: "variables.ListVariable",
        **kwargs,
    ):
        assert isinstance(tensor, variables.TensorVariable)
        super().__init__(**kwargs)
        self.tensor = tensor
        self.block_shape = block_shape

    def to_metadata(self):
        """Build the stable TMA metadata from the proxy of the block shape."""
        block_shape_proxy = self.block_shape.as_proxy()
        return create_tma_stable_metadata(block_shape_proxy)

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode for `TensorDescriptor.from_tensor(tensor, block_shape)`."""

        def load_descriptor_class():
            return codegen.load_import_from(
                "triton.tools.tensor_descriptor",
                "TensorDescriptor",
            )

        codegen.add_push_null(load_descriptor_class)
        codegen.load_method("from_tensor")
        self.tensor.reconstruct(codegen)
        codegen(self.block_shape)
        # Two arguments: the tensor and the block shape.
        codegen.call_method(2)

    def get_tensor(self) -> "variables.TensorVariable":
        """Return the tensor variable the descriptor refers to."""
        return self.tensor
class CreateTMADescriptorExperimentalVariable(VariableTracker):
    """
    Variable tracker for a `create_{rank}d_tma_descriptor` function;
    calling it produces a TMADescriptorExperimentalVariable.
    """

    def __init__(
        self,
        rank: int,
        **kwargs,
    ) -> None:
        assert rank in (1, 2)
        super().__init__(**kwargs)
        self.rank = rank

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Collect pointer, dims, block dims and element size from `args` /
        `kwargs` and wrap them in a TMADescriptorExperimentalVariable.

        Raises Unsupported when the pointer is not a DataPtrVariable.
        """

        def pick(keyword, position):
            # Prefer the keyword form, else fall back to the positional slot.
            return kwargs[keyword] if keyword in kwargs else args[position]

        ptr = pick("ptr", 0)

        if not isinstance(ptr, variables.DataPtrVariable):
            raise Unsupported(
                "Please ensure there were no graph breaks between "
                f"create_{self.rank}d_tma_descriptor and the upstream "
                ".data_ptr() call."
            )

        if self.rank == 1:
            assert len(args) + len(kwargs) == 4
            dims = [pick("dim", 1)]
            block_dims = [pick("block_dim", 2)]
        else:
            assert len(args) + len(kwargs) == 6
            dims = [pick("dim1", 1), pick("dim0", 2)]
            block_dims = [pick("block_dim1", 3), pick("block_dim0", 4)]
        # The element size is the last positional argument when not a keyword.
        element_size = pick("element_size", -1)

        return TMADescriptorExperimentalVariable(
            data_ptr=ptr,
            dims=dims,
            block_dims=block_dims,
            element_size=element_size,
        )
class CreateTMADescriptorStableVariable(VariableTracker):
    """Variable tracker whose call produces a TMADescriptorStableVariable."""

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Take `tensor` and `block_shape` by keyword or by position (0 and 1)."""
        if "tensor" in kwargs:
            tensor = kwargs["tensor"]
        else:
            tensor = args[0]

        if "block_shape" in kwargs:
            block_shape = kwargs["block_shape"]
        else:
            block_shape = args[1]

        return TMADescriptorStableVariable(
            tensor=tensor,
            block_shape=block_shape,
        )
|