misc.py 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947
  1. # mypy: ignore-errors
  2. """
  3. This module contains miscellaneous variable tracker implementations for various Python types
  4. and features used in Dynamo's symbolic execution. These classes help track and propagate
  5. information about different kinds of variables during graph capture.
  6. Key classes include:
  7. - SuperVariable: Handles super() calls and method resolution
  8. - ExceptionVariable: Tracks exception objects
  9. - RandomVariable: Manages random number generators
  10. - GetAttrVariable: Tracks attribute access
  11. - MethodWrapperVariable: Handles method wrappers
  12. - PythonModuleVariable: Tracks Python modules
  13. - NumpyVariable: Handles numpy functions and types
  14. - StringFormatVariable: Manages string formatting
  15. - DebuggingVariable: Handles print and logging
  16. """
  17. import dataclasses
  18. import functools
  19. import inspect
  20. import itertools
  21. import random
  22. import re
  23. import sys
  24. import types
  25. import warnings
  26. from typing import Optional, TYPE_CHECKING
  27. import torch._C
  28. import torch._numpy as tnp
  29. import torch.utils._pytree as pytree
  30. from .. import config, graph_break_hints, trace_rules, variables
  31. from ..bytecode_transformation import create_call_function, create_instruction
  32. from ..create_parameter_op import do_not_convert_to_tracable_parameter
  33. from ..exc import raise_observed_exception, unimplemented, unimplemented_v2
  34. from ..guards import GuardBuilder, install_guard
  35. from ..mutation_guard import unpatched_nn_module_init
  36. from ..source import (
  37. AttrSource,
  38. GenericAttrSource,
  39. GetItemSource,
  40. TypeMROSource,
  41. TypeSource,
  42. WeakRefCallSource,
  43. )
  44. from ..utils import (
  45. check_unspec_or_constant_args,
  46. cmp_name_to_op_mapping,
  47. identity,
  48. is_tensor_base_attr_getter,
  49. istype,
  50. list_methods,
  51. proxy_args_kwargs,
  52. tuple_methods,
  53. )
  54. from .base import VariableTracker
  55. from .constant import ConstantVariable
  56. from .functions import NestedUserFunctionVariable, UserFunctionVariable
  57. from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
  58. if TYPE_CHECKING:
  59. from torch._dynamo.codegen import PyCodegen
  60. from torch._dynamo.symbolic_convert import InstructionTranslator
  61. class NO_SUCH_SUBOBJ:
  62. pass
  63. class SuperVariable(VariableTracker):
  64. _nonvar_fields = {
  65. *VariableTracker._nonvar_fields,
  66. }
  67. def __init__(self, typevar, objvar=None, **kwargs) -> None:
  68. super().__init__(**kwargs)
  69. # typevar is the first argument to super(). In the case where no argument
  70. # is provided to super(), it is the __class__ object where
  71. # the super() function is being called
  72. self.typevar = typevar
  73. # objvar here must be an instance or subtype of typevar.
  74. # In the case where super() is called without arguments, it is the first argument
  75. # to the current function where super() is called from (self for regular method,
  76. # cls for a classmethod)
  77. self.objvar = objvar
  78. def reconstruct(self, codegen: "PyCodegen"):
  79. codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super)))
  80. codegen(self.typevar)
  81. if self.objvar is not None:
  82. codegen(self.objvar)
  83. codegen.extend_output(create_call_function(2, False))
  84. else:
  85. codegen.extend_output(create_call_function(1, False))
  86. def _resolved_getattr_and_source(self, tx: "InstructionTranslator", name):
  87. assert self.objvar, "1-arg super not implemented"
  88. search_type = self.typevar.as_python_constant()
  89. # The rest of this function does two things:
  90. # - Walk the mro to find where the attribute comes from to be
  91. # able to provide accurate source
  92. # - Call the getattr to get the object
  93. # Find the class object, where the function lives.
  94. # When objvar is "self", use type(self), when objvar is "cls", use it as-is
  95. type_to_use = self.objvar.python_type()
  96. type_to_use_source = (
  97. TypeSource(self.objvar.source) if self.objvar.source else None
  98. )
  99. if issubclass(type_to_use, type):
  100. type_to_use = self.objvar.value
  101. type_to_use_source = self.objvar.source
  102. source = None
  103. search_mro = type_to_use.__mro__
  104. try:
  105. start_index = search_mro.index(search_type) + 1
  106. except ValueError:
  107. # Corner case where the typevar is not in the mro of the objvar
  108. # https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8843-L8844
  109. return getattr(super(search_type, type_to_use), name), None
  110. # Implemented based on https://github.com/python/cpython/blob/3.11/Objects/typeobject.c#L8812
  111. # super has its getattro implementation. The key point is that instead of calling getattr, it checks the
  112. # attribute in the class __dict__
  113. for index in range(start_index, len(search_mro)):
  114. # Dont call getattr, just check the __dict__ of the class
  115. if resolved_getattr := search_mro[index].__dict__.get(name, NO_SUCH_SUBOBJ):
  116. if resolved_getattr is not NO_SUCH_SUBOBJ:
  117. # Equivalent of something like type(L['self']).__mro__[1].attr_name
  118. if type_to_use_source:
  119. source = AttrSource(
  120. GetItemSource(TypeMROSource(type_to_use_source), index),
  121. name,
  122. )
  123. return resolved_getattr, source
  124. unimplemented_v2(
  125. gb_type="Unable to resolve super getattr",
  126. context="",
  127. explanation=f"Dynamo failed to trace attribute `{name}` accessed "
  128. f"via `super()` (for type `{self.typevar}` and object `{self.objvar}`) "
  129. "because the resolved attribute type is not supported.",
  130. hints=[
  131. "Ensure the attribute exists in the parent class.",
  132. "Check the arguments passed to `super()`.",
  133. ],
  134. )
  135. def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
  136. # Check if getattr is a constant. If not, delay the actual work by
  137. # wrapping the result in GetAttrVariable. Mostly super is called with a
  138. # method, so most of the work is delayed to call_function.
  139. #
  140. # We could have just implemented a const_getattr. However, super is
  141. # special when it comes to finding sources. Compared to other VTs, super
  142. # requires the attr name to walk the mro and find the actual source (and
  143. # not just AttrSource).
  144. value, source = self._resolved_getattr_and_source(self, name)
  145. if not variables.ConstantVariable.is_literal(value):
  146. return GetAttrVariable(self, name)
  147. if source:
  148. install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
  149. return variables.ConstantVariable.create(value, source=source)
  150. def call_method(
  151. self,
  152. tx: "InstructionTranslator",
  153. name,
  154. args: "list[VariableTracker]",
  155. kwargs: "dict[str, VariableTracker]",
  156. ) -> "VariableTracker":
  157. inner_fn, source = self._resolved_getattr_and_source(self, name)
  158. # This essentially simulates CPython's `super_getattro`:
  159. # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168
  160. # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`.
  161. #
  162. # However, `res`'s type needs to be checked for `tp_descr_get`, and
  163. # applied if it has one. We currently don't have polyfills for all the
  164. # relevant `tp_descr_get`, so we explicitly handle the cases we care
  165. # about here (e.g., note the staticmethod, classmethod cases).
  166. if inner_fn is object.__init__:
  167. return LambdaVariable(identity)
  168. elif inner_fn is torch.nn.Module.__init__:
  169. objvar = self.objvar
  170. from ..side_effects import AttributeMutationNew
  171. if (
  172. isinstance(objvar, variables.UserDefinedObjectVariable)
  173. and isinstance(objvar.mutation_type, AttributeMutationNew)
  174. and not (args or kwargs)
  175. ):
  176. with do_not_convert_to_tracable_parameter():
  177. return variables.UserFunctionVariable(
  178. unpatched_nn_module_init, source=source
  179. ).call_function(tx, [self.objvar] + args, kwargs)
  180. else:
  181. unimplemented_v2(
  182. gb_type="Unsupported super().__init__() call",
  183. context=f"call_method {self} {name} {args} {kwargs}",
  184. explanation="Dynamo encountered a super().__init__() call "
  185. f"on {objvar} that resolved to a `torch.nn.Module.__init__()` "
  186. "call that we cannot trace.",
  187. hints=[*graph_break_hints.DIFFICULT],
  188. )
  189. elif (
  190. self.objvar.source
  191. and hasattr(inner_fn, "__name__")
  192. and inner_fn.__name__ == "__new__"
  193. and variables.UserDefinedClassVariable.is_supported_new_method(inner_fn)
  194. ):
  195. user_cls = inner_fn.__self__
  196. if hasattr(user_cls, "__module__") and user_cls.__module__ == "builtins":
  197. user_cls_vt = variables.BuiltinVariable(user_cls)
  198. else:
  199. user_cls_source = source.member
  200. user_cls_vt = variables.UserDefinedClassVariable(
  201. user_cls, source=user_cls_source
  202. )
  203. return user_cls_vt.call_method(tx, "__new__", args, kwargs)
  204. elif isinstance(inner_fn, staticmethod) and isinstance(
  205. inner_fn.__func__, types.FunctionType
  206. ):
  207. return variables.UserFunctionVariable(
  208. inner_fn.__func__, source=source
  209. ).call_function(tx, args, kwargs)
  210. elif isinstance(inner_fn, classmethod) and isinstance(
  211. inner_fn.__func__, types.FunctionType
  212. ):
  213. if isinstance(self.objvar, variables.UserDefinedClassVariable):
  214. # super().classmethod is called from a classmethod itself. So,
  215. # super was converted to super(__class__, cls) in bytecode and
  216. # therefore we have to propagate the cls.
  217. cls_variable = self.objvar
  218. else:
  219. # current function is an instance method, therefore super was
  220. # converted to super(__class__, self). We have to find
  221. # type(self) to bind the cls to the parent classmethod.
  222. # Note that it can't be the self.typevar because __class__ is
  223. # the class where the method is defined, which could be
  224. # different from type(self) with polymorphism.
  225. cls_source = None
  226. if self.objvar.source:
  227. cls_source = TypeSource(self.objvar.source)
  228. cls_variable = VariableTracker.build(
  229. tx, self.objvar.value_type, cls_source
  230. )
  231. return variables.UserFunctionVariable(
  232. inner_fn.__func__, source=AttrSource(source, "__func__")
  233. ).call_function(tx, [cls_variable, *args], kwargs)
  234. elif isinstance(inner_fn, types.FunctionType):
  235. return variables.UserFunctionVariable(
  236. inner_fn, source=source
  237. ).call_function(tx, [self.objvar] + args, kwargs)
  238. elif isinstance(inner_fn, types.MethodType):
  239. return variables.UserMethodVariable(
  240. inner_fn.__func__, self.objvar, source=source
  241. ).call_function(tx, args, kwargs)
  242. elif is_standard_setattr(inner_fn) and isinstance(
  243. self.objvar, UserDefinedObjectVariable
  244. ):
  245. return self.objvar.method_setattr_standard(tx, *args, **kwargs)
  246. elif inner_fn is object.__delattr__:
  247. attr = args[0]
  248. try:
  249. attr = attr.as_python_constant()
  250. except NotImplementedError as exc:
  251. unimplemented_v2(
  252. gb_type="Non-constant attribute given to `super().__delattr__()`",
  253. context=f"call_method {self} {name}",
  254. explanation="Dynamo requires the attribute name passed to "
  255. "`super().__delattr__(...)` to be a constant (string).",
  256. hints=[
  257. "Ensure the attribute name is a string literal or a constant variable."
  258. ],
  259. from_exc=exc,
  260. )
  261. if not tx.output.side_effects.is_attribute_mutation(self.objvar):
  262. unimplemented_v2(
  263. gb_type="Attempted super().__delattr__() on an object without mutation tracking",
  264. context=f"call_method {self} {name}",
  265. explanation="Dynamo needs to track mutations on an object "
  266. "before `super().__delattr__` can be used on it. But the "
  267. f"object ({self.objvar}) doesn't have attribute mutation "
  268. "tracking enabled.",
  269. hints=[
  270. "Ensure the object is tracked by Dynamo's side effect system.",
  271. *graph_break_hints.DYNAMO_BUG,
  272. ],
  273. )
  274. tx.output.side_effects.store_attr(
  275. self.objvar, attr, variables.DeletedVariable()
  276. )
  277. return variables.ConstantVariable(None)
  278. elif (
  279. isinstance(self.objvar, variables.UserDefinedDictVariable)
  280. and inner_fn in self.objvar._dict_methods
  281. ):
  282. return self.objvar._dict_vt.call_method(tx, name, args, kwargs)
  283. elif (
  284. isinstance(self.objvar, variables.UserDefinedSetVariable)
  285. and inner_fn in self.objvar._set_methods
  286. ):
  287. return self.objvar._set_vt.call_method(tx, name, args, kwargs)
  288. elif (
  289. isinstance(self.objvar, variables.UserDefinedTupleVariable)
  290. and inner_fn in tuple_methods
  291. ):
  292. return self.objvar._tuple_vt.call_method(tx, name, args, kwargs)
  293. elif (
  294. isinstance(self.objvar, variables.UserDefinedListVariable)
  295. and inner_fn in list_methods
  296. ):
  297. return self.objvar._list_vt.call_method(tx, name, args, kwargs)
  298. elif inner_fn is object.__getattribute__:
  299. # object.__getattribute__ has no side-effects. We can directly call
  300. # __getattribute__ to access the attribute.
  301. attr_name = args[0].value
  302. if tx.output.side_effects.has_pending_mutation_of_attr(
  303. self.objvar, attr_name
  304. ):
  305. result = tx.output.side_effects.load_attr(
  306. self.objvar, attr_name, deleted_ok=True
  307. )
  308. if isinstance(result, variables.DeletedVariable):
  309. raise_observed_exception(AttributeError, tx)
  310. return result
  311. try:
  312. # NB - use object.__getattribute__ to prevent running any user code
  313. attr_value = object.__getattribute__(self.objvar.value, attr_name)
  314. except AttributeError:
  315. raise_observed_exception(AttributeError, tx)
  316. attr_source = None
  317. if self.objvar.source is not None:
  318. # setup a object.__getattribute__(self.objvar, name) source
  319. attr_source = GenericAttrSource(self.objvar.source, attr_name)
  320. return VariableTracker.build(tx, attr_value, attr_source)
  321. elif inner_fn is torch._C._disabled_torch_function_impl:
  322. # See `THPModule_disable_torch_function` for the C impl.
  323. # The signature of _disabled_torch_function_impl is similar to
  324. # `__torch_function__`, just without the first `cls` argument:
  325. # * (func, types, args, kwargs)
  326. func = args[0]
  327. tf_kwargs = {}
  328. tf_args = args[2].items
  329. for hash_key_vt, value_vt in args[3].items.items():
  330. key_str = hash_key_vt.vt.as_python_constant()
  331. tf_kwargs[key_str] = value_vt
  332. tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled
  333. tx.symbolic_torch_function_state.torch_function_subclass_enabled = False
  334. try:
  335. return func.call_function(tx, tf_args, tf_kwargs)
  336. finally:
  337. tx.symbolic_torch_function_state.torch_function_subclass_enabled = (
  338. tx_old
  339. )
  340. elif (
  341. isinstance(inner_fn, types.MethodDescriptorType)
  342. and inner_fn in trace_rules.get_tensor_method()
  343. ):
  344. # FunctionType but implementation is in C, we support some of these,
  345. # e.g., tensor ops like `torch.Tensor.to`.
  346. fn_var = VariableTracker.build(tx, inner_fn, source)
  347. return fn_var.call_function(tx, [self.objvar] + args, kwargs)
  348. unimplemented_v2(
  349. gb_type="Attempted to call a super() attribute that is "
  350. "not a function or method",
  351. context=f"call_method {self} {name}",
  352. explanation="Dynamo does not know how to trace the call "
  353. f"`super().{name}()` because `super().{name}` is not a "
  354. "function or method attribute.",
  355. hints=[
  356. "Ensure the attribute accessed via `super()` is a standard method or function.",
  357. ],
  358. )
  359. class ExceptionVariable(VariableTracker):
  360. # The ExceptionVariable corresponds to the BaseException class in Python
  361. def __init__(self, exc_type, args, **kwargs) -> None:
  362. super().__init__(**kwargs)
  363. self.exc_type = exc_type
  364. self.args = args
  365. # When raising a new exception while another exception is already being
  366. # handled, the new exception's __context__ attribute is automatically
  367. # set to the handled exception.
  368. self.__context__ = ConstantVariable(None)
  369. # Set when user raised an exception from another:
  370. # raise ... from ...
  371. self.__cause__ = ConstantVariable(None)
  372. # Boolean flag that controls whether the __context__ attribute is set
  373. self.__suppress_context__ = ConstantVariable(False)
  374. # Contains the call stack where the exception was raised. Dynamo does
  375. # not track traceback. So, this variable is always set to None
  376. self.__traceback__ = ConstantVariable(None)
  377. def set_context(self, context: "ExceptionVariable"):
  378. self.__context__ = context
  379. def reconstruct(self, codegen: "PyCodegen"):
  380. codegen.add_push_null(
  381. lambda: codegen.load_import_from("builtins", self.exc_type.__name__)
  382. )
  383. codegen.foreach(self.args)
  384. codegen.call_function(len(self.args), False)
  385. def codegen_attr(name: str) -> None:
  386. attr = getattr(self, name)
  387. if istype(attr, ConstantVariable):
  388. assert attr.value in (True, False, None), attr
  389. else:
  390. codegen.dup_top()
  391. codegen(attr)
  392. codegen.extend_output(codegen.rot_n(2))
  393. codegen.store_attr(name)
  394. codegen_attr("__context__")
  395. codegen_attr("__cause__")
  396. codegen_attr("__suppress_context__")
  397. def python_type(self):
  398. return self.exc_type
  399. def call_setattr(
  400. self,
  401. tx: "InstructionTranslator",
  402. name_var: VariableTracker,
  403. val: VariableTracker,
  404. ):
  405. def raise_error(msg):
  406. raise_observed_exception(TypeError, tx, args=[ConstantVariable(msg)])
  407. name = name_var.as_python_constant()
  408. if name == "__context__":
  409. self.set_context(val)
  410. elif name == "__cause__":
  411. if (isinstance(val, ConstantVariable) and val.value is None) or isinstance(
  412. val,
  413. (
  414. variables.BuiltinVariable,
  415. variables.ExceptionVariable,
  416. variables.UserDefinedExceptionClassVariable,
  417. variables.UserDefinedExceptionObjectVariable,
  418. ),
  419. ):
  420. self.__cause__ = val
  421. self.__suppress_context__ = variables.ConstantVariable(True)
  422. else:
  423. raise_error("exception cause must be None or derive from BaseException")
  424. elif name == "__suppress_context__":
  425. if isinstance(val, ConstantVariable) and val.value in (True, False):
  426. self.__suppress_context__ = val
  427. else:
  428. raise_error("exception cause must be None or derive from BaseException")
  429. elif name == "__traceback__":
  430. if isinstance(val, ConstantVariable) and val.value is None:
  431. self.__traceback__ = val
  432. else:
  433. unimplemented_v2(
  434. gb_type="Set Exception object `__traceback__` attribute to not-`None`",
  435. context=f"call_setattr {self} {name}",
  436. explanation="Dynamo does not support setting the attribute "
  437. "'__traceback__' on tracked exception objects to anything "
  438. "other than None.",
  439. hints=[
  440. "Avoid setting '__traceback__' on exception objects "
  441. "within traced code, or set it to None."
  442. ],
  443. )
  444. else:
  445. unimplemented_v2(
  446. gb_type="Unsupported attribute assignment on Exception object",
  447. context=f"call_setattr {self} {name}",
  448. explanation="Dynamo does not support setting the attribute "
  449. f"'{name}' on tracked exception objects. Only `__context__`, "
  450. "`__cause__`, `__suppress_context__`, and `__traceback__` are supported.",
  451. hints=[*graph_break_hints.SUPPORTABLE],
  452. )
  453. return variables.ConstantVariable(None)
  454. def call_method(self, tx, name, args, kwargs):
  455. if name == "__setattr__":
  456. return self.call_setattr(tx, *args)
  457. elif name == "with_traceback":
  458. [tb] = args
  459. self.call_setattr(tx, ConstantVariable("__traceback__"), tb)
  460. return self
  461. else:
  462. return super().call_method(tx, name, args, kwargs)
  463. def var_getattr(self, tx, name):
  464. if name == "__context__":
  465. return self.__context__
  466. elif name == "__cause__":
  467. return self.__cause__
  468. elif name == "__suppress_context__":
  469. return self.__suppress_context__
  470. elif name == "__traceback__":
  471. return variables.ConstantVariable(None)
  472. elif name == "args":
  473. return variables.ListVariable(self.args, source=self.source)
  474. return super().var_getattr(tx, name)
  475. def __str__(self):
  476. return f"{self.__class__.__name__}({self.exc_type})"
  477. __repr__ = __str__
  478. class UnknownVariable(VariableTracker):
  479. """
  480. It could be anything!
  481. """
  482. class DelayGraphBreakVariable(UnknownVariable):
  483. """
  484. Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.
  485. """
  486. def __init__(self, msg=None, **kwargs):
  487. super().__init__(**kwargs)
  488. self.msg = msg
  489. def call_function(
  490. self,
  491. tx: "InstructionTranslator",
  492. args: "list[VariableTracker]",
  493. kwargs: "dict[str, VariableTracker]",
  494. ) -> "VariableTracker":
  495. unimplemented_v2(
  496. gb_type="Unsupported function call (delayed)",
  497. context=f"source: {self.source}",
  498. explanation="Dynamo determined that a graph break should occur "
  499. f"when calling `{self.source.name()}`. Reason: {self.msg}",
  500. hints=[],
  501. )
  502. class ComptimeVariable(VariableTracker):
  503. """
  504. This variable is special, it lets you execute arbitrary code at
  505. Dynamo compile time
  506. """
  507. def reconstruct(self, codegen: "PyCodegen"):
  508. raise NotImplementedError("comptime is special form")
  509. def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
  510. from ..comptime import comptime
  511. # To support the comptime.print_graph convenience accessors
  512. from .functions import UserFunctionVariable
  513. return UserFunctionVariable(
  514. getattr(comptime, name), source=AttrSource(self.source, name)
  515. )
  516. def call_function(
  517. self,
  518. tx: "InstructionTranslator",
  519. args: "list[VariableTracker]",
  520. kwargs: "dict[str, VariableTracker]",
  521. ) -> "VariableTracker":
  522. from ..comptime import ComptimeContext
  523. # TODO: support an expression form as well
  524. assert not kwargs
  525. # Second argument is runtime lambda, ignored
  526. assert len(args) <= 2
  527. fn = args[0]
  528. if isinstance(fn, UserFunctionVariable):
  529. fn.get_function()(ComptimeContext(tx))
  530. elif isinstance(fn, NestedUserFunctionVariable):
  531. # We have to manually bind the freevars ourselves
  532. code = fn.get_code()
  533. assert not fn.closure, (
  534. "comptime function must not have free variables, "
  535. f"but these variables were free: {code.co_freevars}"
  536. )
  537. func = types.FunctionType(
  538. code,
  539. fn.f_globals,
  540. fn.fn_name.as_python_constant(),
  541. tuple(fn.defaults.items) if fn.defaults else None,
  542. # We could automatically promote free variables into
  543. # ComptimeVar but this is confusing if you access
  544. # a free variable that we actually DO have the runtime
  545. # value for
  546. # tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
  547. (),
  548. )
  549. func(ComptimeContext(tx))
  550. else:
  551. raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")
  552. return variables.ConstantVariable.create(None)
  553. class CellVariable(VariableTracker):
  554. # If the cell existed before Dynamo tracing started, this will be the
  555. # VariableTracker that represents the cell content.
  556. #
  557. # Note that all mutation to the cell (i.e., its content) will be buffered in
  558. # SideEffects, rather than being reflected here. One can think of
  559. # `CellVariable` as a special case for `UserDefinedObjectVariable`.
  560. pre_existing_contents: Optional[VariableTracker]
  561. # This is set when this cell can be referenced via `LOAD/STORE_DEREF` in the
  562. # root frame via this name (e.g., the name is in `co_cellvars/co_freevars`).
  563. local_name: Optional[str] = None
  564. def __init__(
  565. self, pre_existing_contents: Optional[VariableTracker] = None, **kwargs
  566. ) -> None:
  567. super().__init__(**kwargs)
  568. self.pre_existing_contents = pre_existing_contents
  569. class NewGlobalVariable(VariableTracker):
  570. def __init__(self, **kwargs) -> None:
  571. super().__init__(**kwargs)
  572. def produce_trampoline_autograd_apply(fn_cls):
  573. def trampoline_autograd_apply(*args, **kwargs):
  574. return fn_cls.apply(*args, **kwargs)
  575. trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
  576. return trampoline_autograd_apply
class AutogradFunctionVariable(VariableTracker):
    """represents a torch.autograd.Function subclass

    ``fn_cls`` is the Function subclass itself (not an instance). Supported
    entry points are ``apply`` (see ``call_apply``), ``backward`` (see
    ``call_backward``), and static/class methods looked up on ``fn_cls``.
    """

    # fn_cls is a raw Python class, not a VariableTracker, so it is excluded
    # from VariableTracker's child-variable handling.
    _nonvar_fields = {
        "fn_cls",
        *VariableTracker._nonvar_fields,
    }

    def __init__(self, fn_cls, **kwargs) -> None:
        super().__init__(**kwargs)
        self.fn_cls = fn_cls

    def call_apply(self, tx: "InstructionTranslator", args, kwargs):
        """Trace ``fn_cls.apply(*args, **kwargs)``.

        If any input may require grad (a TensorVariable whose ``requires_grad``
        is not ``False``, or an NNModuleVariable in training mode) and grad
        mode is enabled, the call is traced through
        ``AutogradFunctionApplyVariable`` with forward and backward. Otherwise
        ``forward`` is inlined directly with a freshly created context object.
        Custom ``vjp``/``jvp`` overrides, and a ``forward`` that is neither a
        function nor a method, cause a graph break.
        """
        requires_grad = False

        def visit(vt):
            # Flip requires_grad if this leaf could participate in autograd.
            nonlocal requires_grad
            if isinstance(vt, variables.TensorVariable):
                if vt.requires_grad is not False:
                    requires_grad = True
            if isinstance(vt, variables.NNModuleVariable):
                if vt.is_training(tx):
                    requires_grad = True

        VariableTracker.visit(visit, (args, kwargs))

        if requires_grad and torch.is_grad_enabled():
            if config.capture_autograd_function is False:
                warnings.warn(
                    "The config.capture_autograd_function flag is deprecated, it's now always true."
                )

            from torch._functorch.autograd_function import (
                autograd_function_forward_rewritten,
            )
            from torch.autograd.function import _is_setup_context_defined

            forward_fn = self.fn_cls.forward

            is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
            if is_setup_ctx_defined:
                # If setup_context is defined, we generate a new forward function which includes
                # the original forward and setup_context function, and trace the new forward function.
                forward_fn = autograd_function_forward_rewritten(
                    self.fn_cls.forward, self.fn_cls.setup_context
                )

            # Identity comparison against the base-class attribute detects
            # whether the subclass overrides vjp.
            vjp_fn = self.fn_cls.vjp  # type: ignore[attr-defined]
            if vjp_fn is not torch.autograd.Function.vjp:
                unimplemented_v2(
                    gb_type="Unsupported custom vjp",
                    context=f"call_apply {self} {args} {kwargs}",
                    explanation="Dynamo does not support tracing "
                    "`torch.autograd.Function` subclasses that define "
                    "a custom `vjp` method.",
                    hints=[
                        "Remove the custom `vjp` method if possible.",
                        "Use standard `backward` instead if applicable.",
                        *graph_break_hints.SUPPORTABLE,
                    ],
                )

            # Same override check for jvp.
            jvp_fn = self.fn_cls.jvp  # type: ignore[attr-defined]
            if jvp_fn is not torch.autograd.Function.jvp:
                unimplemented_v2(
                    gb_type="Unsupported custom jvp",
                    context=f"call_apply {self} {args} {kwargs}",
                    explanation="Dynamo does not support tracing "
                    "`torch.autograd.Function` subclasses that define "
                    "a custom `jvp` method.",
                    hints=[
                        "Remove the custom `jvp` method if possible.",
                        *graph_break_hints.SUPPORTABLE,
                    ],
                )

            from .higher_order_ops import AutogradFunctionApplyVariable

            # Without a source of our own, synthesise one as
            # <module of fn_cls>.<name of fn_cls>.
            source = self.source
            if source is None:
                source = AttrSource(
                    tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
                )

            val = AutogradFunctionApplyVariable(
                forward_fn,
                self.fn_cls.backward,
                source,
                source=AttrSource(source, member="apply"),
            ).call_function(tx, args, kwargs)
            # Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
            # the forward function, as we don't want to generate guards for new_forward.__closure__
            # if forward is rewritten by autograd_function_forward_rewritten.
            # But we still need to generate correct guards for the original forward and setup_context
            # functions, so we have to add guards manually.
            if self.source:
                fwd_src = AttrSource(self.source, "forward")
                install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
                if is_setup_ctx_defined:
                    setup_ctx_src = AttrSource(self.source, "setup_context")
                    install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))

            return val

        # No-grad path: inline forward directly.
        if self.source:
            source = AttrSource(self.source, "forward")
        else:
            source = None

        fn = self.fn_cls.forward
        ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
        args = [ctx, *args]
        if isinstance(fn, types.FunctionType):
            sig = inspect.signature(fn)
            # A forward whose parameter count equals the number of user args
            # does not take ctx (e.g. forward paired with setup_context).
            if len(args) - 1 == len(sig._parameters):
                args = args[1:]  # Don't use context
            return variables.UserFunctionVariable(fn, source=source).call_function(
                tx, args, kwargs
            )
        elif isinstance(fn, types.MethodType):
            # e.g. forward defined as a classmethod: rebind to the class.
            return variables.UserMethodVariable(
                fn.__func__,
                variables.UserDefinedClassVariable(self.fn_cls),
                source=source,
            ).call_function(tx, args, kwargs)
        else:
            unimplemented_v2(
                gb_type="Non-function or method in subclass of torch.autograd.Function",
                context=f"call_apply {self} {args} {kwargs}",
                explanation="Dynamo requires the `forward` attribute of a "
                "`torch.autograd.Function` subclass to be a standard Python "
                f"function or method. Found type `{type(fn).__name__}` instead.",
                hints=[
                    "Ensure the `forward` method is defined as a regular "
                    "function or instance method."
                ],
            )

    def call_backward(self, tx: "InstructionTranslator", args, kwargs):
        """Inline ``fn_cls.backward`` with ``args``.

        The first argument must be wrapping a ``FakeBackwardCFunction``.
        NOTE(review): builds ``AttrSource(self.source, ...)`` unconditionally,
        so this appears to assume ``self.source`` is set — confirm with callers.
        """
        fn = self.fn_cls.backward
        assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
        assert isinstance(fn, types.FunctionType)

        fn_source = AttrSource(self.source, "backward")
        return variables.UserFunctionVariable(fn, source=fn_source).call_function(
            tx, args, kwargs
        )

    def call_function(self, tx: "InstructionTranslator", args, kwargs):
        # Calling the class yields a new, sourceless tracker for the same
        # Function subclass; args/kwargs are ignored.
        return AutogradFunctionVariable(self.fn_cls)

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ):
        """Dispatch ``fn_cls.<name>(*args, **kwargs)``.

        ``apply`` is either recorded as an opaque ``call_function`` node (when
        trace_rules allow the class) or traced via ``call_apply``.
        ``backward`` goes to ``call_backward``. Static and class methods are
        inlined; anything else graph-breaks.
        """
        from .builder import wrap_fx_proxy

        if name == "apply":
            if trace_rules.is_callable_allowed(self.fn_cls):
                trampoline_autograd_apply = produce_trampoline_autograd_apply(
                    self.fn_cls
                )
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        trampoline_autograd_apply,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
            else:
                return self.call_apply(tx, args, kwargs)

        elif name == "backward":
            return self.call_backward(tx, args, kwargs)
        else:
            source = AttrSource(self.source, name) if self.source is not None else None
            # getattr_static returns the raw staticmethod/classmethod wrapper
            # so the two kinds can be told apart below.
            try:
                obj = inspect.getattr_static(self.fn_cls, name)
            except AttributeError:
                obj = None

            if isinstance(obj, staticmethod):
                func = obj.__get__(self.fn_cls)
                if source is not None:
                    return (
                        trace_rules.lookup(func)
                        .create_with_source(func, source=source)
                        .call_function(tx, args, kwargs)
                    )
                else:
                    return trace_rules.lookup(func)(func).call_function(
                        tx, args, kwargs
                    )
            elif isinstance(obj, classmethod):
                return variables.UserMethodVariable(
                    obj.__func__, self, source=source
                ).call_function(tx, args, kwargs)
            else:
                unimplemented_v2(
                    gb_type="Unsupported autograd.Function method",
                    context=f"call_method {self} {name}",
                    explanation="Dynamo does not support calling the method "
                    f"`{name}` directly on the `torch.autograd.Function` "
                    "instance. Supported methods include `apply`, `backward`, "
                    "static methods, and class methods.",
                    hints=[
                        "Ensure the method is decorated with `@staticmethod` "
                        "or `@classmethod` if it's meant to be called on the class.",
                    ],
                )
  767. @dataclasses.dataclass
  768. class SavedTensorBox:
  769. tensors: list[VariableTracker] = dataclasses.field(default_factory=list)
class AutogradFunctionContextVariable(UserDefinedObjectVariable):
    """
    Tracks an autograd.Function() context using mutation tracking in side_effects.py

    Attributes:
        inference: True for contexts created by ``create`` during tracing.
            Non-inference contexts additionally record save_for_backward
            calls in side effects.
        saved_tensors: a SavedTensorBox holding the last arguments passed to
            ``save_for_backward``, or None when saving is unsupported.
        needs_input_grad: tuple of per-argument bools computed at creation,
            or None when it could not be computed up front.
        non_differentiable: proxy of the first argument passed to
            ``mark_non_differentiable``, or None.
    """

    # NOTE(review): "proxy" is listed here and read by as_proxy(), but it is
    # not assigned in the visible __init__ — presumably set elsewhere; confirm.
    _nonvar_fields = {
        "proxy",
        "inference",
        "saved_tensors",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        value_type=None,
        inference=False,
        saved_tensors=None,
        needs_input_grad=None,
        non_differentiable=None,
        **kwargs,
    ) -> None:
        super().__init__(value=value, value_type=value_type, **kwargs)
        self.inference = inference
        self.saved_tensors = saved_tensors
        self.needs_input_grad = needs_input_grad
        self.non_differentiable = non_differentiable

    @staticmethod
    def create(tx: "InstructionTranslator", args=None, kwargs=None):
        """Create a new, side-effect-tracked ``FunctionCtx`` for a traced forward.

        ``needs_input_grad`` is precomputed only for purely positional calls:
        one bool per argument, True for TensorVariables with requires_grad.
        """
        needs_input_grad = None
        if args and not kwargs:
            needs_input_grad = tuple(
                isinstance(x, variables.TensorVariable) and x.requires_grad
                for x in args
            )
        out = tx.output.side_effects.track_object_new(
            None,
            torch.autograd.function.FunctionCtx,
            functools.partial(
                AutogradFunctionContextVariable,
                inference=True,
                saved_tensors=SavedTensorBox(),
                needs_input_grad=needs_input_grad,
            ),
            {},
        )
        return out

    def as_proxy(self):
        # Graph-break (reported as a Dynamo bug) if no proxy is attached.
        if self.proxy is None:
            unimplemented_v2(
                gb_type="proxy not set",
                context=f"as_proxy {self}",
                explanation="Dynamo requires the autograd.Function context "
                "to be initialized with a proxy.",
                hints=[*graph_break_hints.DYNAMO_BUG],
            )
        return self.proxy

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Support ``__setattr__``, ``mark_non_differentiable`` and
        ``save_for_backward`` on the context; anything else graph-breaks."""
        if name == "__setattr__":
            return super().call_method(tx, name, args, kwargs)
        elif name == "mark_non_differentiable":
            assert len(kwargs) == 0
            # Only the proxy of the first argument is kept.
            self.non_differentiable = proxy_args_kwargs(args, {})[0]
            return variables.ConstantVariable.create(None)

        if name != "save_for_backward":
            unimplemented_v2(
                gb_type="Unsupported autograd.Function context method",
                context=f"call_method {self} {name}",
                explanation="Dynamo does not support calling the method "
                f"`{name}` on `autograd.Function` context objects. Supported "
                "methods are `__setattr__`, `save_for_backward` and "
                "`mark_non_differentiable`.",
                hints=[*graph_break_hints.SUPPORTABLE],
            )
        if self.saved_tensors is None:
            unimplemented_v2(
                gb_type="Unsupported autograd.Function context `save_for_backward`",
                context=f"call_method {self} {name}",
                explanation="Dynamo requires the `saved_tensors` attribute "
                "to be initialized on the `autograd.Function` context object.",
                hints=[
                    "Ensure that the `saved_tensors` attribute is properly "
                    "initialized before calling `save_for_backward`. "
                    "`save_for_backward` only supported on a newly constructed `torch.autograd.function.FunctionCtx`.",
                ],
            )

        # Contexts not created during tracing must have a source so the
        # save can be replayed via side effects.
        if not self.inference:
            assert self.source and not kwargs
            tx.output.side_effects.track_save_for_backward(self, args)

        # In eager mode, multiple calls to .save_for_backward() will overwrite previous calls.
        if len(self.saved_tensors.tensors) > 0:
            self.saved_tensors.tensors = []
        for arg in args:
            self.saved_tensors.tensors.append(arg)
        return variables.ConstantVariable.create(None)

    def var_getattr(self, tx: "InstructionTranslator", name):
        """Model the context's special attributes; defer the rest to the base."""
        # Bound-method access: defer the actual call to call_method.
        if name in ["save_for_backward", "mark_non_differentiable"]:
            return LambdaVariable(
                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
            )
        if name == "saved_tensors" and self.saved_tensors is not None:
            return variables.TupleVariable(list(self.saved_tensors.tensors))
        if name == "needs_input_grad":
            if self.needs_input_grad is not None:
                return variables.ConstantVariable.create(self.needs_input_grad)
            # Otherwise read the real attribute through the source, if any.
            if self.source:
                source = AttrSource(self.source, "needs_input_grad")
                return VariableTracker.build(tx, self.value.needs_input_grad, source)

        return super().var_getattr(tx, name)
  883. class AutogradEngineVariable(UserDefinedObjectVariable):
  884. """
  885. Represents a torch._C._ImperativeEngine instance.
  886. """
  887. def __init__(
  888. self,
  889. value,
  890. value_type=None,
  891. **kwargs,
  892. ) -> None:
  893. super().__init__(value=value, value_type=value_type, **kwargs)
  894. def call_method(
  895. self,
  896. tx: "InstructionTranslator",
  897. name,
  898. args: "list[VariableTracker]",
  899. kwargs: "dict[str, VariableTracker]",
  900. ) -> "VariableTracker":
  901. if name == "queue_callback":
  902. if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
  903. assert tx.one_graph or tx.error_on_graph_break, (
  904. "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
  905. )
  906. return variables.UserFunctionVariable(
  907. torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
  908. source=self.source,
  909. ).call_function(
  910. tx,
  911. (tx.output.side_effects.get_ca_final_callbacks_var(), *args),
  912. kwargs,
  913. )
  914. else:
  915. unimplemented_v2(
  916. gb_type="Unsupported torch._C._ImperativeEngine.queue_callback()",
  917. context=f"call_method {self} {name}",
  918. explanation="queue_callback() is only supported when "
  919. "Compiled Autograd is enabled with fullgraph=True.",
  920. hints=[],
  921. )
  922. else:
  923. unimplemented_v2(
  924. gb_type="Unsupported torch._C._ImperativeEngine method",
  925. context=f"call_method {self} {name}",
  926. explanation="Dynamo only supports the `queue_callback` method "
  927. f"on a torch._C._ImperativeEngine instance, but found: `{name}`.",
  928. hints=[],
  929. )
  930. class LambdaVariable(VariableTracker):
  931. def __init__(self, fn, **kwargs) -> None:
  932. super().__init__(**kwargs)
  933. self.fn = fn
  934. def call_function(
  935. self,
  936. tx: "InstructionTranslator",
  937. args: "list[VariableTracker]",
  938. kwargs: "dict[str, VariableTracker]",
  939. ) -> "VariableTracker":
  940. return self.fn(*args, **kwargs)
  941. class GetAttrVariable(VariableTracker):
  942. _nonvar_fields = {
  943. "name",
  944. "py_type",
  945. *VariableTracker._nonvar_fields,
  946. }
  947. def __init__(self, obj, name, py_type=None, **kwargs) -> None:
  948. super().__init__(**kwargs)
  949. assert isinstance(obj, VariableTracker)
  950. assert isinstance(name, str)
  951. self.obj = obj
  952. self.name = name
  953. self.py_type = py_type # In some cases we know the type (ex. tensor methods)
  954. def python_type(self):
  955. if self.py_type is not None:
  956. return self.py_type
  957. else:
  958. return super().python_type()
  959. def __repr__(self) -> str:
  960. return f"{self.__class__.__name__}({self.obj}, {self.name})"
  961. @staticmethod
  962. def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
  963. return getattr(base_proxy, attr)
  964. def as_proxy(self):
  965. return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)
  966. def as_python_constant(self):
  967. constant = self.obj.as_python_constant()
  968. try:
  969. return getattr(constant, self.name)
  970. except AttributeError:
  971. raise NotImplementedError(f"{self} is not a constant") from None
  972. def const_getattr(self, tx: "InstructionTranslator", name):
  973. if not isinstance(self.obj, variables.NNModuleVariable):
  974. raise NotImplementedError
  975. step1 = tx.output.get_submodule(self.obj.module_key)
  976. if self.name not in step1.__dict__:
  977. raise NotImplementedError
  978. step2 = inspect.getattr_static(step1, self.name)
  979. if name not in step2.__dict__:
  980. raise NotImplementedError
  981. return inspect.getattr_static(step2, name)
  982. def reconstruct(self, codegen: "PyCodegen"):
  983. codegen(self.obj)
  984. codegen.extend_output(codegen.create_load_attrs(self.name))
  985. def call_function(
  986. self,
  987. tx: "InstructionTranslator",
  988. args: "list[VariableTracker]",
  989. kwargs: "dict[str, VariableTracker]",
  990. ) -> "VariableTracker":
  991. return self.obj.call_method(tx, self.name, args, kwargs)
  992. def call_method(
  993. self,
  994. tx: "InstructionTranslator",
  995. name,
  996. args: list[VariableTracker],
  997. kwargs: dict[str, VariableTracker],
  998. ) -> VariableTracker:
  999. if (
  1000. name in ("__getitem__", "get")
  1001. and self.name == "__dict__"
  1002. and not kwargs
  1003. and args[0].is_python_constant()
  1004. and isinstance(
  1005. self.obj,
  1006. (
  1007. variables.UserDefinedObjectVariable,
  1008. variables.NNModuleVariable,
  1009. variables.UserDefinedClassVariable,
  1010. ),
  1011. )
  1012. ):
  1013. obj = self.obj
  1014. key = args[0].as_python_constant()
  1015. if obj.has_key_in_generic_dict(tx, key):
  1016. # redirect to var_getattr on the original obj
  1017. return obj.var_getattr(tx, key)
  1018. # Return the default value for get
  1019. if name == "get":
  1020. if len(args) == 2:
  1021. return args[1]
  1022. else:
  1023. return variables.ConstantVariable(None)
  1024. elif (
  1025. name == "__contains__"
  1026. and self.name == "__dict__"
  1027. and len(args) == 1
  1028. and args[0].is_python_constant()
  1029. and not kwargs
  1030. and isinstance(
  1031. self.obj,
  1032. (
  1033. variables.UserDefinedObjectVariable,
  1034. variables.NNModuleVariable,
  1035. variables.UserDefinedClassVariable,
  1036. ),
  1037. )
  1038. ):
  1039. obj = self.obj
  1040. key = args[0].as_python_constant()
  1041. if obj.has_key_in_generic_dict(tx, key):
  1042. return variables.ConstantVariable(True)
  1043. else:
  1044. return variables.ConstantVariable(False)
  1045. elif name == "__setitem__" and self.name == "__dict__" and not kwargs:
  1046. if isinstance(self.obj, variables.UserDefinedObjectVariable):
  1047. # Bypass any custom setattr as we are updating the `__dict__` itself
  1048. return self.obj.method_setattr_standard(
  1049. tx, args[0], args[1], directly_update_dict=True
  1050. )
  1051. if isinstance(self.obj, variables.NNModuleVariable):
  1052. # This matches how `setattr` is handled for NNModuleVariable
  1053. self.obj.convert_to_unspecialized(tx)
  1054. return super().call_method(tx, name, args, kwargs)
  1055. def get_forwarded_dict(self, tx):
  1056. assert (
  1057. self.name == "__dict__"
  1058. and isinstance(self.obj, variables.UserDefinedClassVariable)
  1059. and not tx.output.side_effects.has_pending_mutation(self.obj)
  1060. )
  1061. self.obj.ban_mutation = True
  1062. return VariableTracker.build(tx, self.obj.value.__dict__, self.source)
  1063. class MethodWrapperVariable(VariableTracker):
  1064. def __init__(self, method_wrapper, **kwargs) -> None:
  1065. super().__init__(**kwargs)
  1066. self.method_wrapper = method_wrapper
  1067. self._builtin_fns = {}
  1068. def call_function(
  1069. self,
  1070. tx: "InstructionTranslator",
  1071. args: "list[VariableTracker]",
  1072. kwargs: "dict[str, VariableTracker]",
  1073. ) -> "VariableTracker":
  1074. if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
  1075. args[0], variables.TensorVariable
  1076. ):
  1077. assert len(args) == 1 and len(kwargs) == 0
  1078. return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)
  1079. # method-wrapper variables are common in __init__ calls. For example,
  1080. # str("foo").__init__ is a method-wrapper. These method wrappers point
  1081. # to C functions. Here we intercept if these method-wrappers are from
  1082. # builtins and then call the function counterpart directly by obtaining
  1083. # the self object.
  1084. self_obj = self.method_wrapper.__self__
  1085. wrapper_name = self.method_wrapper.__name__
  1086. # TODO(dynamo-team) - We can perhaps expand the scope to more names and
  1087. # more builtins.
  1088. if wrapper_name == "__init__":
  1089. fn_obj = type(self_obj).__init__
  1090. if fn_obj is object.__init__:
  1091. return variables.BuiltinVariable(object).call_method(
  1092. tx, wrapper_name, [self_obj, *args], kwargs
  1093. )
  1094. return super().call_function(tx, args, kwargs)
  1095. def is_python_constant(self):
  1096. return True
  1097. def as_python_constant(self):
  1098. return self.method_wrapper
  1099. class GetSetDescriptorVariable(VariableTracker):
  1100. def __init__(self, desc, **kwargs) -> None:
  1101. super().__init__(**kwargs)
  1102. self.desc = desc
  1103. def var_getattr(self, tx: "InstructionTranslator", name):
  1104. if name == "__get__" and self.source:
  1105. source = AttrSource(self.source, "__get__")
  1106. return VariableTracker.build(tx, self.desc.__get__, source)
  1107. else:
  1108. return super().var_getattr(tx, name)
  1109. def is_python_constant(self):
  1110. return True
  1111. def as_python_constant(self):
  1112. return self.desc
  1113. class PythonModuleVariable(VariableTracker):
  1114. _nonvar_fields = {
  1115. "value",
  1116. "is_torch",
  1117. *VariableTracker._nonvar_fields,
  1118. }
  1119. def __init__(self, value: types.ModuleType, **kwargs) -> None:
  1120. super().__init__(**kwargs)
  1121. self.value = value
  1122. self.is_torch = self.value is torch or self.value.__name__.startswith("torch.")
  1123. def python_type(self):
  1124. return types.ModuleType
  1125. def as_python_constant(self):
  1126. return self.value
  1127. def __repr__(self) -> str:
  1128. return f"PythonModuleVariable({self.value})"
  1129. def call_obj_hasattr(self, tx: "InstructionTranslator", name):
  1130. result = hasattr(self.value, name)
  1131. return variables.ConstantVariable.create(result)
  1132. def var_getattr(self, tx: "InstructionTranslator", name):
  1133. if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
  1134. return tx.output.side_effects.load_attr(self, name)
  1135. if self.is_torch or name not in self.value.__dict__:
  1136. try:
  1137. attr_value = getattr(self.value, name)
  1138. except AttributeError:
  1139. raise_observed_exception(AttributeError, tx)
  1140. else:
  1141. attr_value = self.value.__dict__[name]
  1142. source = self.source and AttrSource(self.source, name)
  1143. return VariableTracker.build(tx, attr_value, source)
  1144. class TypingVariable(VariableTracker):
  1145. def __init__(self, value, **kwargs) -> None:
  1146. super().__init__(**kwargs)
  1147. self.value = value
  1148. def call_method(
  1149. self,
  1150. tx: "InstructionTranslator",
  1151. name,
  1152. args: "list[VariableTracker]",
  1153. kwargs: "dict[str, VariableTracker]",
  1154. ) -> "VariableTracker":
  1155. # Create a new typing variable, e.g., `List[int]`
  1156. if name == "__getitem__" and len(args) == 1:
  1157. new_typing = self.value[args[0].as_python_constant()]
  1158. return TypingVariable(new_typing)
  1159. unimplemented("unsupported method call on typing variablel")
  1160. def var_getattr(self, tx: "InstructionTranslator", name: str):
  1161. from .builder import SourcelessBuilder, VariableBuilder
  1162. if name in cmp_name_to_op_mapping:
  1163. return variables.GetAttrVariable(self, name)
  1164. if tx.output.side_effects.has_pending_mutation_of_attr(self, name):
  1165. return tx.side_effects.load_attr(self, name)
  1166. value = getattr(self.value, name)
  1167. if self.source:
  1168. attr_source = AttrSource(self.source, name)
  1169. return VariableBuilder(tx, attr_source)(value)
  1170. else:
  1171. return SourcelessBuilder.create(tx, value)
  1172. def as_python_constant(self):
  1173. return self.value
  1174. def reconstruct(self, codegen: "PyCodegen") -> None:
  1175. # We're just trying to load the type here. Reconstructing the type from
  1176. # scratch is tricky - for a type like `typing.List[int]` we'd need to
  1177. # deconstruct the origin and args. The origin for `List[int]` is `list`
  1178. # and the args is `(int,)`. When we recombine those we get the parts
  1179. # back and need to emit code for:
  1180. #
  1181. # `typing.List[int]`
  1182. #
  1183. # But it's # worse than that - what if `typing` isn't in the globals (or
  1184. # was loaded like `import typing as _typing ; _typing.List[int]`?) so we
  1185. # really need to do something like:
  1186. #
  1187. # `sys.modules["typing"].List[int]`
  1188. #
  1189. # Argh - but what if they rewrote the global `int`? So we have to do:
  1190. #
  1191. # `sys.modules["typing"].List[sys.modules["builtins"].int]`
  1192. #
  1193. # But where do we get `sys`? What if they never imported it or have
  1194. # something ELSE called `sys`?
  1195. #
  1196. # Let's skip all that noise and just emit it as a simple const.
  1197. #
  1198. codegen.append_output(codegen.create_load_const(self.value))
  1199. @functools.lru_cache(maxsize=1)
  1200. def get_np_to_tnp_map():
  1201. """
  1202. This generates a mapping from numpy modules to their torch._numpy
  1203. modules equivalents.
  1204. """
  1205. from ..utils import NP_TO_TNP_MODULE
  1206. np_fn_to_tnp_fn = {}
  1207. for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
  1208. for fn_name, tnp_fn in tnp_mod.__dict__.items():
  1209. if callable(tnp_fn):
  1210. # some internal details do leak from tnp
  1211. # which are not part of numpy API.
  1212. if np_fn := getattr(np_mod, fn_name, None):
  1213. np_fn_to_tnp_fn[np_fn] = tnp_fn
  1214. return np_fn_to_tnp_fn
  1215. @functools.lru_cache(maxsize=1)
  1216. def get_tnp_to_np_map():
  1217. """
  1218. This is just the reverse mapping of get_np_to_tnp_map() - mapping from
  1219. torch._numpy modules to numpy equivalents.
  1220. """
  1221. m = get_np_to_tnp_map()
  1222. return {v: k for k, v in m.items()}
class NumpyVariable(VariableTracker):
    """
    Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.

    Calls are redirected to the torch._numpy equivalent found via
    get_np_to_tnp_map(), then either evaluated eagerly (constant collections
    and constant-foldable functions) or recorded as a graph call returning a
    NumpyNdarrayVariable.
    """

    # torch._numpy functions that are evaluated eagerly when all arguments
    # are constants (see call_function).
    constant_fold_functions = (tnp.issubdtype,)

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    @classmethod
    def can_constant_fold_through(cls, fn):
        # fn must be a torch._numpy function (asserted via its module path).
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return fn in cls.constant_fold_functions

    @classmethod
    def get_constant_collection_for_func(cls, fn):
        # Returns the VariableTracker type used to wrap fn's result, or None
        # if fn has no entry in np_constant_collections_map.
        mod = fn.__module__.split(".")
        assert len(mod) >= 2 and mod[:2] == ["torch", "_numpy"]
        return np_constant_collections_map.get(fn, None)

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if not config.trace_numpy:
            unimplemented(f"numpy.{self.value}()")

        from ..utils import numpy_to_tensor_wrapper
        from .tensor import NumpyNdarrayVariable

        func = get_np_to_tnp_map().get(self.value)
        if func is None:
            unimplemented(
                f"Can't find numpy function {self.value} in torch._numpy. "
                " Please file an issue to request support for this function."
            )

        # We are dealing with a function that produces a const collection type (np.dtype, np.iinfo/np.finfo)
        if (
            collection_variable_typ := self.get_constant_collection_for_func(func)
        ) is not None:
            # Evaluated eagerly with the real NumPy function; non-constant
            # arguments surface as NotImplementedError and graph-break.
            try:
                return collection_variable_typ(
                    self.value(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    )
                )
            except NotImplementedError:
                unimplemented(
                    f"{self.value.__name__} with non-const args: {args} {kwargs}"
                )
        else:
            if (
                func.__module__ == "torch._numpy.random"
                and config.use_numpy_random_stream
            ):
                msg = f"delegate '{func.__qualname__}' to NumPy itself via "
                msg += (
                    f"config.use_numpy_random_stream={config.use_numpy_random_stream}"
                )
                unimplemented(msg)

            args, kwargs = NumpyNdarrayVariable.patch_args(func.__name__, args, kwargs)

            if self.can_constant_fold_through(func) and (
                check_unspec_or_constant_args(args, kwargs)
            ):
                # constant fold
                return variables.ConstantVariable.create(
                    self.as_python_constant()(
                        *[x.as_python_constant() for x in args],
                        **{k: v.as_python_constant() for k, v in kwargs.items()},
                    ),
                )

            # TODO Add all the functions that go from constants to constants to can_constant_fold_through
            proxy = tx.output.create_proxy(
                "call_function",
                numpy_to_tensor_wrapper(func),
                *proxy_args_kwargs(args, kwargs),
            )
            return NumpyNdarrayVariable.create(tx, proxy)

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Method calls on numpy objects are not supported: always graph-break.
        unimplemented("numpy")

    def as_python_constant(self):
        return self.value

    def as_proxy(self):
        if config.trace_numpy and isinstance(self.value, type):
            # This handles numpy dtype attributes such as np.float32
            # We return a string as we don't want to serialize non-PyTorch objects in the output FX graph
            # In torch/_numpy we normalize strings to their dtypes when the input is a dtype, as NumPy does
            return self.value.__name__

        return super().as_proxy()
  1317. # Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
  1318. class NullVariable(VariableTracker):
  1319. def __init__(self, **kwargs) -> None:
  1320. super().__init__(**kwargs)
  1321. def __repr__(self) -> str:
  1322. return "NullVariable"
  1323. def reconstruct(self, codegen: "PyCodegen"):
  1324. if sys.version_info < (3, 11):
  1325. unimplemented("cannot reconstruct NullVariable in < Python 3.11")
  1326. codegen.append_output(create_instruction("PUSH_NULL"))
  1327. class DeletedVariable(VariableTracker):
  1328. """Marker used to implement delattr()"""
  1329. class StringFormatVariable(VariableTracker):
  1330. """
  1331. Represents a call to str.format(), we delay calling format until after the graph.
  1332. """
  1333. _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields}
  1334. @classmethod
  1335. def create(cls, format_string, sym_args, sym_kwargs):
  1336. if all(
  1337. x.is_python_constant()
  1338. for x in itertools.chain(sym_args, sym_kwargs.values())
  1339. ):
  1340. return variables.ConstantVariable.create(
  1341. format_string.format(
  1342. *[v.as_python_constant() for v in sym_args],
  1343. **{k: v.as_python_constant() for k, v in sym_kwargs.items()},
  1344. )
  1345. )
  1346. return cls(format_string, list(sym_args), dict(sym_kwargs))
  1347. def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None:
  1348. super().__init__(**kwargs)
  1349. assert isinstance(format_string, str)
  1350. self.format_string = format_string
  1351. self.sym_args = sym_args
  1352. self.sym_kwargs = sym_kwargs
  1353. def __repr__(self) -> str:
  1354. return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"
  1355. def reconstruct(self, codegen: "PyCodegen"):
  1356. codegen.add_push_null(
  1357. lambda: codegen.extend_output(
  1358. [
  1359. codegen.create_load_const(self.format_string),
  1360. codegen.create_load_attr("format"),
  1361. ]
  1362. ),
  1363. call_function_ex=True,
  1364. )
  1365. codegen(variables.TupleVariable(self.sym_args))
  1366. kwargs = {
  1367. variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
  1368. }
  1369. codegen(variables.ConstDictVariable(kwargs))
  1370. codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))
  1371. class DebuggingVariable(VariableTracker):
  1372. """
  1373. Represents a call to a debugging function like print(), or something
  1374. registered to config.reorderable_logging_functions.
  1375. """
  1376. def __init__(self, value, **kwargs) -> None:
  1377. super().__init__(**kwargs)
  1378. self.value = value
  1379. @staticmethod
  1380. def is_reorderable_logging_function(obj):
  1381. return (
  1382. callable(obj)
  1383. and isinstance(obj, (types.FunctionType, types.BuiltinFunctionType))
  1384. and obj in torch._dynamo.config.reorderable_logging_functions
  1385. )
  1386. def call_function(self, tx: "InstructionTranslator", args, kwargs):
  1387. if tx.export:
  1388. # For export cases, we can just make debugging functions no-ops
  1389. return
  1390. if not self.can_reorder_logs(self.value, args, kwargs):
  1391. unimplemented(
  1392. f"Reordering debugging function {self.value} "
  1393. f"with inputs {args} {kwargs} is not yet implemented."
  1394. )
  1395. tx.debug_locals.append((self, list(args)))
  1396. def reconstruct(self, codegen: "PyCodegen"):
  1397. return self.source.reconstruct(codegen)
  1398. @staticmethod
  1399. def can_reorder_logs(fn, args, kwargs) -> True:
  1400. """
  1401. Run some additional checks for what sort of function calls can we
  1402. actually reorder.
  1403. """
  1404. allowed_input_types = (
  1405. variables.TensorVariable,
  1406. variables.ConstantVariable,
  1407. StringFormatVariable,
  1408. )
  1409. flat_args = pytree.tree_leaves([args, kwargs])
  1410. for arg in flat_args:
  1411. if not isinstance(arg, allowed_input_types):
  1412. return False
  1413. return True
  1414. class LoggingLoggerVariable(VariableTracker):
  1415. """
  1416. Represents a call to any of logging.Logger methods
  1417. """
  1418. def __init__(self, value, **kwargs) -> None:
  1419. super().__init__(**kwargs)
  1420. self.value = value
  1421. def call_method(
  1422. self,
  1423. tx: "InstructionTranslator",
  1424. name,
  1425. args: "list[VariableTracker]",
  1426. kwargs: "dict[str, VariableTracker]",
  1427. ) -> "VariableTracker":
  1428. if tx.export:
  1429. # For export cases, we can just make debugging functions no-ops
  1430. return
  1431. method = getattr(self.value, name, None)
  1432. function = getattr(method, "__func__", None)
  1433. if {method, function}.intersection(torch._dynamo.config.ignore_logger_methods):
  1434. return variables.ConstantVariable.create(None)
  1435. unimplemented(
  1436. "Logger not supported for non-export cases. "
  1437. "To avoid graph breaks caused by logger in compile-mode, it is recommended to"
  1438. " disable logging by adding logging methods to config.ignore_logger_methods"
  1439. )
  1440. class ConstantLikeVariable(VariableTracker):
  1441. """self.value is a compile-time constant, but not a literal"""
  1442. _error_prefix = "ConstantLikeVariable"
  1443. try:
  1444. from numpy import (
  1445. dtype as np_dtype,
  1446. floating as np_floating,
  1447. generic as np_generic,
  1448. )
  1449. except ImportError:
  1450. np_floating = type("invalid_type", (), {})
  1451. np_dtype = type("invalid_type", (), {})
  1452. def __init__(self, value, **kwargs) -> None:
  1453. super().__init__(**kwargs)
  1454. self.value = value
  1455. def as_python_constant(self):
  1456. return self.value
  1457. def call_method(
  1458. self,
  1459. tx: "InstructionTranslator",
  1460. name,
  1461. args: list[VariableTracker],
  1462. kwargs: dict[str, VariableTracker],
  1463. ) -> VariableTracker:
  1464. try:
  1465. # we only support constant propagation for methods
  1466. cargs = [x.as_python_constant() for x in args]
  1467. ckwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
  1468. except NotImplementedError:
  1469. unimplemented(f"{self._error_prefix}.{name}(*{args}, **{kwargs})")
  1470. result = getattr(self.value, name)(*cargs, **ckwargs)
  1471. if variables.ConstantVariable.is_literal(result):
  1472. return variables.ConstantVariable.create(result)
  1473. if isinstance(result, re.Match):
  1474. return ConstantRegexMatchVariable(result)
  1475. unimplemented(f"{self._error_prefix}.{name}() -> {result}")
  1476. def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
  1477. result = getattr(self.value, name)
  1478. if isinstance(result, self.np_floating):
  1479. result = float(result)
  1480. if isinstance(result, self.np_dtype):
  1481. return NumpyDTypeVariable(result)
  1482. if isinstance(result, type) and issubclass(result, self.np_generic):
  1483. # things like x.dtype.type
  1484. return NumpyVariable(result)
  1485. if variables.ConstantVariable.is_literal(result):
  1486. return variables.ConstantVariable.create(result)
  1487. return GetAttrVariable(self, name)
  1488. class RegexPatternVariable(ConstantLikeVariable):
  1489. _error_prefix = "re.Pattern"
  1490. class ConstantRegexMatchVariable(ConstantLikeVariable):
  1491. _error_prefix = "re.Match"
  1492. class TorchVersionVariable(ConstantLikeVariable):
  1493. _error_prefix = "torch.__version__"
  1494. def __init__(self, **kwargs) -> None:
  1495. kwargs.setdefault("value", torch.__version__)
  1496. assert kwargs["value"] is torch.__version__
  1497. super().__init__(**kwargs)
  1498. class NumpyTypeInfoVariable(ConstantLikeVariable):
  1499. _error_prefix = "np.iinfo/np.finfo"
  1500. class NumpyDTypeVariable(ConstantLikeVariable):
  1501. _error_prefix = "np.dtype[...]"
  1502. def as_proxy(self):
  1503. """Similar to how numpy dtype descriptors (e.g. np.float32 ) are handled by NumpyVariable:
  1504. np.dtype() objects are serialized as strings, torch._numpy wrappers will normalize to the torch dtype.
  1505. This also handles unsupported things nicely (i.e. structured arrays and object arrays).
  1506. """
  1507. return self.value.type.__name__
  1508. np_constant_collections_map = {
  1509. tnp.finfo: NumpyTypeInfoVariable,
  1510. tnp.iinfo: NumpyTypeInfoVariable,
  1511. tnp.dtype: NumpyDTypeVariable,
  1512. }
  1513. class RandomClassVariable(VariableTracker):
  1514. """random.Random"""
  1515. def __init__(self, **kwargs) -> None:
  1516. super().__init__(**kwargs)
  1517. def call_function(self, tx: "InstructionTranslator", args, kwargs):
  1518. if len(args) > 1:
  1519. unimplemented("random.Random() with > 1 arg")
  1520. elif kwargs:
  1521. unimplemented("random.Random() with kwargs")
  1522. seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0]
  1523. return RandomVariable(
  1524. seed=seed, mutation_type=variables.base.ValueMutationNew()
  1525. )
class RandomVariable(VariableTracker):
    """random.Random()

    Implemented by wrapping a VariableTracker around a private random.Random
    object (``self.random``). That copy is created from a supported Random's
    state or from a constant seed, and is advanced by every supported method
    call so that it stays in step with the traced program.
    The supported methods for the random.Random object cannot be overridden
    (see ``is_supported_random_obj``).
    Assumes that random objects behave the same given a set seed or state.
    """

    # "random" holds a plain random.Random, not VariableTrackers; presumably
    # this is what excludes it from VT traversal in the base class (verify).
    _nonvar_fields = {
        "random",
        *VariableTracker._nonvar_fields,
    }

    # Methods traced by replaying them from a state snapshot (see call_method).
    _supported_fn_names = {
        "random",
        "randint",
        "randrange",
        "uniform",
    }

    def __init__(
        self,
        rand: Optional[random.Random] = None,
        seed: Optional[VariableTracker] = None,
        **kwargs,
    ) -> None:
        """Create the tracked generator.

        Args:
            rand: an existing supported random.Random whose state is copied.
                The original object is never mutated.
            seed: used only when ``rand`` is None; its Python-constant value
                seeds a fresh random.Random (None when not given).
        """
        super().__init__(**kwargs)
        if rand is not None:
            assert self.is_supported_random_obj(rand)
            self.random = random.Random()
            self.random.setstate(rand.getstate())
        else:
            seed = seed.as_python_constant() if seed is not None else None
            self.random = random.Random(seed)

    def python_type(self) -> type:
        return random.Random

    def as_python_constant(self):
        # Returns the tracked copy itself, not the user's original object.
        return self.random

    @staticmethod
    def is_supported_random_obj(val) -> bool:
        """Return True only for exact random.Random instances (no subclasses)
        whose supported methods and seed/getstate/setstate are still the
        class's own implementations, i.e. not shadowed on the instance."""
        if type(val) is not random.Random:
            return False
        for name in itertools.chain(
            RandomVariable._supported_fn_names, ("seed", "getstate", "setstate")
        ):
            if not hasattr(val, name):
                return False
            meth = getattr(val, name)
            if inspect.isbuiltin(meth):
                # e.g. random.Random.random
                # Builtin methods are compared against the class descriptor
                # bound to this instance.
                if meth != getattr(random.Random, name).__get__(val):
                    return False
            else:
                # Python-level methods must wrap the class's own function.
                if getattr(meth, "__func__", None) is not getattr(random.Random, name):
                    return False
        return True

    @staticmethod
    def check_state(state) -> None:
        """Assert that ``state`` has the shape of ``random.Random.getstate()``:
        ``(int, tuple of ints, float or None)``.

        Uses ``assert``, so the checks are skipped under ``python -O``.
        """
        assert type(state) is tuple
        assert type(state[0]) is int
        assert type(state[1]) is tuple
        assert all(type(x) is int for x in state[1])
        assert state[2] is None or type(state[2]) is float

    @staticmethod
    def wrap_state(state):
        """Convert a getstate() tuple into a nested TupleVariable of constants."""
        RandomVariable.check_state(state)
        return variables.TupleVariable(
            [
                variables.ConstantVariable.create(state[0]),
                variables.TupleVariable(
                    [variables.ConstantVariable.create(x) for x in state[1]]
                ),
                variables.ConstantVariable.create(state[2]),
            ]
        )

    @staticmethod
    def unwrap_state(state):
        """Inverse of ``wrap_state``: extract and validate the Python state tuple."""
        state_obj = state.as_python_constant()
        RandomVariable.check_state(state_obj)
        return state_obj

    def call_method(
        self,
        tx: "InstructionTranslator",
        name,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Trace a method call on the generator.

        - ``seed`` / ``setstate`` update the tracked copy and return None.
        - ``getstate`` returns the current state as a TupleVariable.
        - ``random`` / ``randint`` / ``randrange`` / ``uniform`` advance the
          tracked copy and delegate to ``call_random_fn`` with a closure that
          replays the call from the pre-call state.
        - Anything else falls back to ``VariableTracker.call_method``.
        """
        if name == "seed":
            # Register the mutation with the side-effect tracker, then apply it.
            tx.output.side_effects.mutation(self)
            self.random.seed(
                *[x.as_python_constant() for x in args],
                **{key: val.as_python_constant() for key, val in kwargs.items()},
            )
            return variables.ConstantVariable.create(None)
        elif name == "getstate":
            return self.wrap_state(self.random.getstate())
        elif name == "setstate":
            tx.output.side_effects.mutation(self)
            self.random.setstate(self.unwrap_state(args[0]))
            return variables.ConstantVariable.create(None)
        elif name in self._supported_fn_names:
            tx.output.side_effects.mutation(self)
            # Snapshot the state *before* advancing self.random below, so that
            # call_random_meth reproduces exactly this call's result.
            state = self.random.getstate()

            def call_random_meth(*args, **kwargs):
                # Replay the method on a fresh Random restored to the snapshot.
                r = random.Random()
                r.setstate(state)
                return getattr(r, name)(*args, **kwargs)

            # self.random state not actually updated by call_random_meth, so update here
            # by calling the method
            getattr(self.random, name)(
                *[x.as_python_constant() for x in args],
                **{k: v.as_python_constant() for k, v in kwargs.items()},
            )
            # call_random_fn is defined elsewhere; how it evaluates or records
            # call_random_meth is not visible here (TODO confirm).
            return call_random_fn(tx, call_random_meth, args, kwargs)
        return super().call_method(tx, name, args, kwargs)

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode that leaves a new random.Random on the stack, already
        set to the tracked copy's current state."""
        # random.Random()
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(random),
                    codegen.create_load_attr("Random"),
                ]
            )
        )
        codegen.call_function(0, False)
        # <instance>.setstate(<state>) on a duplicated reference. The
        # setstate result is popped, so the instance stays on the stack.
        # NOTE using add_push_null may result in NULL being duplicated
        # so defer the push_null to call_function
        codegen.dup_top()
        codegen.load_attr("setstate")
        codegen(self.wrap_state(self.random.getstate()))
        codegen.call_function(1, True)
        codegen.pop_top()
  1654. class WeakRefVariable(VariableTracker):
  1655. @staticmethod
  1656. def build(tx, weakref_value, **options):
  1657. source = options.get("source", None)
  1658. callback = weakref_value.__callback__
  1659. callback_source = source and AttrSource(source, "__callback__")
  1660. callback_vt = VariableTracker.build(tx, callback, callback_source)
  1661. referent = weakref_value()
  1662. source = source and WeakRefCallSource(source)
  1663. referent_vt = VariableTracker.build(tx, referent, source)
  1664. options["source"] = source
  1665. return WeakRefVariable(referent_vt, callback_vt, **options)
  1666. def __init__(self, referent_vt, callback_vt, **options):
  1667. super().__init__(**options)
  1668. self.referent_vt = referent_vt
  1669. self.callback_vt = callback_vt
  1670. def call_function(
  1671. self,
  1672. tx: "InstructionTranslator",
  1673. args: "list[VariableTracker]",
  1674. kwargs: "dict[str, VariableTracker]",
  1675. ) -> "VariableTracker":
  1676. return self.referent_vt
  1677. def reconstruct(self, codegen: "PyCodegen"):
  1678. codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref"))
  1679. codegen(self.referent_vt)
  1680. codegen(self.callback_vt)
  1681. codegen.extend_output(create_call_function(2, False))