dicts.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371
  1. # mypy: ignore-errors
  2. """
  3. Dictionary-related variable tracking classes for PyTorch Dynamo.
  4. This module implements variable tracking for different types of dictionary-like objects:
  5. - Regular Python dictionaries (dict)
  6. - Ordered dictionaries (collections.OrderedDict)
  7. - Default dictionaries (collections.defaultdict)
  8. - Dictionary views (keys, values and items)
  9. - Sets and frozensets (implemented internally using dictionaries)
  10. These classes are responsible for tracking dictionary operations during graph compilation,
  11. maintaining proper guards for dictionary mutations and key existence checks. They handle
  12. dictionary creation, modification, key/value access, and view operations while ensuring
  13. correct behavior in the compiled code through appropriate guard installation.
  14. The implementation uses a special _HashableTracker wrapper to handle dictionary keys
  15. while preserving proper aliasing semantics. Sets are implemented as dictionaries with
  16. None values for efficiency and code reuse.
  17. """
  18. import collections
  19. import functools
  20. import inspect
  21. import operator
  22. import types
  23. from collections.abc import Hashable as py_Hashable
  24. from typing import Optional, TYPE_CHECKING
  25. from torch._subclasses.fake_tensor import is_fake
  26. from .. import graph_break_hints, polyfills, variables
  27. from ..bytecode_transformation import create_call_function, create_instruction
  28. from ..exc import raise_observed_exception, unimplemented_v2
  29. from ..guards import GuardBuilder, install_guard
  30. from ..source import is_from_local_source
  31. from ..utils import (
  32. cmp_name_to_op_mapping,
  33. dict_items,
  34. dict_keys,
  35. dict_values,
  36. istype,
  37. raise_args_mismatch,
  38. specialize_symnode,
  39. )
  40. from .base import ValueMutationNew, VariableTracker
  41. from .constant import ConstantVariable
  42. if TYPE_CHECKING:
  43. from torch._dynamo.codegen import PyCodegen
  44. from torch._dynamo.symbolic_convert import InstructionTranslator
  45. # [Adding a new supported class within the keys of ConstDictVariable]
  46. # - Add its tracker type to is_hashable
  47. # - (perhaps) Define how it is compared in _HashableTracker._eq_impl
  48. def was_instancecheck_override(obj):
  49. return type(obj).__dict__.get("__instancecheck__", False)
  50. def raise_unhashable(arg, tx=None):
  51. if tx is None:
  52. from torch._dynamo.symbolic_convert import InstructionTranslator
  53. tx = InstructionTranslator.current_tx()
  54. raise_observed_exception(
  55. TypeError, tx, args=[ConstantVariable(f"unhashable type: {type(arg)}")]
  56. )
def is_hashable(x):
    """Return whether the VariableTracker ``x`` may be used as a dict/set key.

    Hashability is decided from the tracker's kind (and, for tuples and frozen
    dataclasses, recursively from their elements) rather than by hashing a
    concrete value.

    Args:
        x: the VariableTracker to classify.

    Returns:
        bool: True when Dynamo treats ``x`` as hashable.
    """
    # NB - performing isinstance check on a LazyVT realizes the VT, accidentally
    # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at
    # the underlying value without realizing the VT. Consider updating the
    # lazyVT `is_hashable` method if you see unnecessary guarding for a key VT.
    if (
        isinstance(x, variables.LazyVariableTracker)
        and not x.is_realized()
        and x.is_hashable()
    ):
        return True
    if isinstance(x, variables.TensorVariable):
        # Tensors are hashable if they have an example_value (a fake tensor)
        # Most VT's should have one.
        # It'd be nice if at some point we could assert that they all have one
        return x.as_proxy().node.meta.get("example_value") is not None
    elif isinstance(x, variables.TupleVariable):
        # A tuple is hashable only if every element is.
        return all(is_hashable(e) for e in x.items)
    elif isinstance(x, variables.FrozenDataClassVariable):
        # A frozen dataclass is hashable only if every field value is.
        return all(is_hashable(e) for e in x.fields.values())
    elif (
        isinstance(x, variables.UserDefinedObjectVariable)
        and not was_instancecheck_override(x.value)
        and inspect.getattr_static(x.value, "__hash__") is int.__hash__
        and isinstance(x.value, int)
    ):
        # User-defined int subclasses that keep int's __hash__ (e.g. the
        # sentinel key used by the `re` module's cache).
        return isinstance(x.value, py_Hashable)
    else:
        return isinstance(
            x,
            (
                variables.BuiltinVariable,
                variables.SymNodeVariable,
                variables.ConstantVariable,
                variables.EnumVariable,
                variables.FrozensetVariable,
                variables.UserDefinedClassVariable,
                variables.UserFunctionVariable,
                variables.SkipFunctionVariable,
                variables.misc.NumpyVariable,
                variables.NNModuleVariable,
                variables.UnspecializedNNModuleVariable,
                variables.MethodWrapperVariable,
                variables.TorchInGraphFunctionVariable,
                variables.TypingVariable,
                variables.FunctoolsPartialVariable,
                variables.WeakRefVariable,
                variables.TorchHigherOrderOperatorVariable,
            ),
        )
  107. class ConstDictVariable(VariableTracker):
  108. CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS
  109. _nonvar_fields = {
  110. "user_cls",
  111. *VariableTracker._nonvar_fields,
  112. }
  113. class _HashableTracker:
  114. """
  115. Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable
  116. This should not be seen or touched by anything outside of ConstDictVariable and its children
  117. Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
  118. """
  119. def __init__(self, vt) -> None:
  120. # We specialize SymNodes
  121. vt = specialize_symnode(vt)
  122. # TODO Temporarily remove to figure out what keys are we breaking on
  123. # and add proper support for them
  124. if not is_hashable(vt):
  125. raise_unhashable(vt)
  126. self.vt = vt
  127. @property
  128. def underlying_value(self):
  129. if (
  130. isinstance(self.vt, variables.LazyVariableTracker)
  131. and not self.vt.is_realized()
  132. and self.vt.is_hashable()
  133. ):
  134. return self.vt.original_value()
  135. if isinstance(self.vt, variables.TensorVariable):
  136. x = self.vt.as_proxy().node.meta["example_value"]
  137. elif isinstance(self.vt, variables.TupleVariable):
  138. Hashable = ConstDictVariable._HashableTracker
  139. x = tuple(Hashable(e).underlying_value for e in self.vt.items)
  140. elif isinstance(self.vt, variables.NNModuleVariable):
  141. return self.vt.value
  142. elif isinstance(self.vt, variables.UnspecializedNNModuleVariable):
  143. return self.vt.value
  144. elif isinstance(self.vt, variables.UserFunctionVariable):
  145. return self.vt.get_function()
  146. elif isinstance(self.vt, variables.WeakRefVariable):
  147. # Access the underlying value inside the referent_vt for the key representation
  148. Hashable = ConstDictVariable._HashableTracker
  149. return Hashable(self.vt.referent_vt).underlying_value
  150. elif isinstance(self.vt, variables.FrozenDataClassVariable):
  151. Hashable = ConstDictVariable._HashableTracker
  152. fields_values = {
  153. k: Hashable(v).underlying_value for k, v in self.vt.fields.items()
  154. }
  155. return variables.FrozenDataClassVariable.HashWrapper(
  156. self.vt.python_type(), fields_values
  157. )
  158. elif isinstance(self.vt, variables.UserDefinedObjectVariable):
  159. # The re module in Python 3.13+ has a dictionary (_cache2) with
  160. # an object as key (`class _ZeroSentinel(int): ...`):
  161. # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual
  162. return self.vt.value
  163. else:
  164. x = self.vt.as_python_constant()
  165. return x
  166. def __hash__(self):
  167. return hash(self.underlying_value)
  168. @staticmethod
  169. def _eq_impl(a, b):
  170. # TODO: Put this in utils and share it between variables/builtin.py and here
  171. if type(a) != type(b):
  172. return False
  173. elif isinstance(a, tuple):
  174. Hashable = ConstDictVariable._HashableTracker
  175. return len(a) == len(b) and all(
  176. Hashable._eq_impl(u, v) for u, v in zip(a, b)
  177. )
  178. elif is_fake(a):
  179. return a is b
  180. else:
  181. return a == b
  182. def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
  183. Hashable = ConstDictVariable._HashableTracker
  184. assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), (
  185. type(other)
  186. )
  187. if isinstance(other, Hashable):
  188. return Hashable._eq_impl(self.underlying_value, other.underlying_value)
  189. # constant
  190. return Hashable._eq_impl(self.underlying_value, other)
  191. def __init__(
  192. self,
  193. items: dict[VariableTracker, VariableTracker],
  194. user_cls=dict,
  195. **kwargs,
  196. ) -> None:
  197. # .clone() pass these arguments in kwargs but they're recreated a few
  198. # lines below
  199. if "original_items" in kwargs:
  200. kwargs.pop("original_items")
  201. if "should_reconstruct_all" in kwargs:
  202. kwargs.pop("should_reconstruct_all")
  203. super().__init__(**kwargs)
  204. Hashable = ConstDictVariable._HashableTracker
  205. # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers
  206. assert all(
  207. isinstance(x, (VariableTracker, Hashable))
  208. and isinstance(v, VariableTracker)
  209. for x, v in items.items()
  210. )
  211. def make_hashable(key):
  212. return key if isinstance(key, Hashable) else Hashable(key)
  213. dict_cls = self._get_dict_cls_from_user_cls(user_cls)
  214. self.items = dict_cls({make_hashable(x): v for x, v in items.items()})
  215. # need to reconstruct everything if the dictionary is an intermediate value
  216. # or if a pop/delitem was executed
  217. self.should_reconstruct_all = not is_from_local_source(self.source)
  218. self.original_items = items.copy()
  219. self.user_cls = user_cls
  220. def _get_dict_cls_from_user_cls(self, user_cls):
  221. accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict)
  222. # avoid executing user code if user_cls is a dict subclass
  223. if user_cls in accepted_dict_types:
  224. dict_cls = user_cls
  225. else:
  226. # <Subclass, ..., dict, object>
  227. dict_cls = next(
  228. base for base in user_cls.__mro__ if base in accepted_dict_types
  229. )
  230. assert dict_cls in accepted_dict_types, dict_cls
  231. # Use a dict instead as the call "defaultdict({make_hashable(x): v ..})"
  232. # would fail as defaultdict expects a callable as first argument
  233. if dict_cls is collections.defaultdict:
  234. dict_cls = dict
  235. return dict_cls
  236. def as_proxy(self):
  237. return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}
  238. def debug_repr(self):
  239. return (
  240. "{"
  241. + ", ".join(
  242. f"{k.vt.debug_repr()}: {v.debug_repr()}" for k, v in self.items.items()
  243. )
  244. + "}"
  245. )
  246. def as_python_constant(self):
  247. return {
  248. k.vt.as_python_constant(): v.as_python_constant()
  249. for k, v in self.items.items()
  250. }
  251. def keys_as_python_constant(self):
  252. self.install_dict_keys_match_guard()
  253. return {k.vt.as_python_constant(): v for k, v in self.items.items()}
  254. def python_type(self):
  255. return self.user_cls
  256. def __contains__(self, vt) -> bool:
  257. assert isinstance(vt, VariableTracker)
  258. Hashable = ConstDictVariable._HashableTracker
  259. return (
  260. is_hashable(vt)
  261. and Hashable(vt) in self.items
  262. and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable)
  263. )
  264. def len(self) -> int:
  265. return sum(
  266. not isinstance(x, variables.DeletedVariable) for x in self.items.values()
  267. )
  268. def has_new_items(self) -> bool:
  269. return self.should_reconstruct_all or any(
  270. self.is_new_item(self.original_items.get(key.vt), value)
  271. for key, value in self.items.items()
  272. )
  273. def is_new_item(self, value, other):
  274. # compare the id of the realized values if both values are not lazy VTs
  275. if value and value.is_realized() and other.is_realized():
  276. return id(value.realize()) != id(other.realize())
  277. return id(value) != id(other)
  278. def reconstruct_kvs_into_new_dict(self, codegen):
  279. # Build a dictionary that contains the keys and values.
  280. num_args = 0
  281. for key, value in self.items.items():
  282. # We can safely call realize() here as it won't introduce any new guards
  283. item = self.original_items.get(key.vt)
  284. if self.is_new_item(item, value) or self.should_reconstruct_all:
  285. codegen(key.vt)
  286. codegen(value)
  287. num_args += 1
  288. codegen.append_output(create_instruction("BUILD_MAP", arg=num_args))
  289. def reconstruct(self, codegen: "PyCodegen"):
  290. if self.user_cls is collections.OrderedDict:
  291. # emit `OrderedDict(constructed_dict)`
  292. codegen.add_push_null(
  293. lambda: codegen.extend_output(
  294. [
  295. codegen.create_load_python_module(collections),
  296. codegen.create_load_attr("OrderedDict"),
  297. ]
  298. )
  299. )
  300. self.reconstruct_kvs_into_new_dict(codegen)
  301. codegen.extend_output(create_call_function(1, False))
  302. else:
  303. self.reconstruct_kvs_into_new_dict(codegen)
  304. def getitem_const_raise_exception_if_absent(
  305. self, tx: "InstructionTranslator", arg: VariableTracker
  306. ):
  307. key = ConstDictVariable._HashableTracker(arg)
  308. if key not in self.items:
  309. raise_observed_exception(KeyError, tx)
  310. return self.items[key]
  311. def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
  312. key = ConstDictVariable._HashableTracker(arg)
  313. if key not in self.items:
  314. msg = f"Dictionary key {arg.value} not found during tracing"
  315. unimplemented_v2(
  316. gb_type="key not found in dict",
  317. context=f"Key {arg.value}",
  318. explanation=msg,
  319. hints=[
  320. "Check if the key exists in the dictionary before accessing it.",
  321. *graph_break_hints.USER_ERROR,
  322. ],
  323. )
  324. return self.items[key]
  325. def maybe_getitem_const(self, arg: VariableTracker):
  326. key = ConstDictVariable._HashableTracker(arg)
  327. if key not in self.items:
  328. return None
  329. return self.items[key]
  330. def realize_key_vt(self, arg: VariableTracker):
  331. # Realize the LazyVT on a particular index
  332. assert arg in self
  333. key = ConstDictVariable._HashableTracker(arg)
  334. index = tuple(self.items.keys()).index(key)
  335. original_key_vt = tuple(self.original_items.keys())[index]
  336. if isinstance(original_key_vt, variables.LazyVariableTracker):
  337. original_key_vt.realize()
  338. def install_dict_keys_match_guard(self):
  339. if self.source:
  340. install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH))
  341. def install_dict_contains_guard(self, tx, args):
  342. # Key guarding - These are the cases to consider
  343. # 1) The dict has been mutated. In this case, we would have already
  344. # inserted a DICT_KEYS_MATCH guard, so we can skip.
  345. #
  346. # 2) args[0].source is None. This happens for const keys. Here, we
  347. # have to insert the DICT_CONTAINS guard.
  348. #
  349. # 3) args[0].source is not None. This can happen for non-const VTs.
  350. # 3a) contains=True. In this case, we can access the lazyVT from
  351. # original_items and selectively realize it.
  352. # 3b) contains=False. There is no easy way to selectively apply this
  353. # DICT_NOT_CONTAINS guard because our guard are represented via trees.
  354. # Be conservative and add DICT_KEYS_MATCH guard.
  355. from . import ConstantVariable
  356. if not self.source:
  357. return
  358. if tx.output.side_effects.is_modified(self):
  359. return
  360. contains = args[0] in self
  361. if args[0].source is None and isinstance(args[0], ConstantVariable):
  362. install_guard(
  363. self.make_guard(
  364. functools.partial(
  365. type(self).CONTAINS_GUARD,
  366. key=args[0].value,
  367. invert=not contains,
  368. )
  369. )
  370. )
  371. elif args[0].source:
  372. if contains:
  373. self.realize_key_vt(args[0])
  374. else:
  375. self.install_dict_keys_match_guard()
  376. def call_method(
  377. self,
  378. tx,
  379. name,
  380. args: "list[VariableTracker]",
  381. kwargs: "dict[str, VariableTracker]",
  382. ) -> "VariableTracker":
  383. # NB - Both key and value are LazyVariableTrackers in the beginning. So,
  384. # we have to insert guards when a dict method is accessed. For this to
  385. # be simple, we are conservative and overguard. We skip guard only for
  386. # get/__getitem__ because the key guard will be inserted by the
  387. # corresponding value VT. For __contains__, we add a DICT_CONTAINS
  388. # guard. But for all the other methods, we insert the DICT_KEYS_MATCH
  389. # guard to be conservative.
  390. from . import BuiltinVariable, ConstantVariable
  391. Hashable = ConstDictVariable._HashableTracker
  392. arg_hashable = args and is_hashable(args[0])
  393. if name == "__init__":
  394. temp_dict_vt = variables.BuiltinVariable(dict).call_dict(
  395. tx, *args, **kwargs
  396. )
  397. tx.output.side_effects.mutation(self)
  398. self.items.update(temp_dict_vt.items)
  399. return ConstantVariable.create(None)
  400. elif name == "__getitem__":
  401. # Key guarding - Nothing to do. LazyVT for value will take care.
  402. if len(args) != 1:
  403. raise_args_mismatch(tx, name)
  404. return self.getitem_const_raise_exception_if_absent(tx, args[0])
  405. elif name == "items":
  406. if args or kwargs:
  407. raise_args_mismatch(tx, name)
  408. self.install_dict_keys_match_guard()
  409. if self.source:
  410. tx.output.guard_on_key_order.add(self.source)
  411. return DictItemsVariable(self)
  412. elif name == "keys":
  413. if len(args):
  414. raise_args_mismatch(tx, name)
  415. self.install_dict_keys_match_guard()
  416. if self.source:
  417. tx.output.guard_on_key_order.add(self.source)
  418. return DictKeysVariable(self)
  419. elif name == "values":
  420. if args or kwargs:
  421. raise_args_mismatch(tx, name)
  422. self.install_dict_keys_match_guard()
  423. if self.source:
  424. tx.output.guard_on_key_order.add(self.source)
  425. if args or kwargs:
  426. raise_observed_exception(TypeError, tx)
  427. return DictValuesVariable(self)
  428. elif name == "copy":
  429. self.install_dict_keys_match_guard()
  430. if args or kwargs:
  431. raise_args_mismatch(tx, name)
  432. return self.clone(
  433. items=self.items.copy(), mutation_type=ValueMutationNew(), source=None
  434. )
  435. elif name == "__len__":
  436. if args or kwargs:
  437. raise_args_mismatch(tx, name)
  438. self.install_dict_keys_match_guard()
  439. return ConstantVariable.create(len(self.items))
  440. elif name == "__setitem__" and self.is_mutable():
  441. if not arg_hashable:
  442. raise_unhashable(args[0])
  443. self.install_dict_keys_match_guard()
  444. assert not kwargs and len(args) == 2
  445. tx.output.side_effects.mutation(self)
  446. self.items[Hashable(args[0])] = args[1]
  447. return ConstantVariable.create(None)
  448. elif name == "__delitem__" and arg_hashable and self.is_mutable():
  449. self.install_dict_keys_match_guard()
  450. self.should_reconstruct_all = True
  451. tx.output.side_effects.mutation(self)
  452. self.items.__delitem__(Hashable(args[0]))
  453. return ConstantVariable.create(None)
  454. elif name == "get":
  455. if len(args) not in (1, 2):
  456. raise_args_mismatch(tx, name)
  457. if not arg_hashable:
  458. raise_unhashable(args[0])
  459. if args[0] not in self:
  460. self.install_dict_contains_guard(tx, args)
  461. if len(args) == 1:
  462. # if default is not given, return None
  463. return ConstantVariable.create(None)
  464. return args[1]
  465. # Key guarding - Nothing to do.
  466. return self.getitem_const(tx, args[0])
  467. elif name == "pop" and self.is_mutable():
  468. if len(args) not in (1, 2):
  469. raise_args_mismatch(tx, name)
  470. if not arg_hashable:
  471. raise_unhashable(args[0])
  472. if args[0] not in self:
  473. # missing item, return the default value. Install no DICT_CONTAINS guard.
  474. self.install_dict_contains_guard(tx, args)
  475. if len(args) == 1:
  476. # if default is not given, raise KeyError
  477. raise_observed_exception(KeyError, tx)
  478. return args[1]
  479. self.should_reconstruct_all = True
  480. tx.output.side_effects.mutation(self)
  481. return self.items.pop(Hashable(args[0]))
  482. elif name == "popitem" and self.is_mutable():
  483. if (
  484. issubclass(self.user_cls, dict)
  485. and not issubclass(self.user_cls, collections.OrderedDict)
  486. and len(args)
  487. ):
  488. raise_args_mismatch(tx, name)
  489. if not self.items:
  490. msg = ConstantVariable.create("popitem(): dictionary is empty")
  491. raise_observed_exception(KeyError, tx, args=[msg])
  492. if self.user_cls is collections.OrderedDict and (
  493. len(args) == 1 or "last" in kwargs
  494. ):
  495. if len(args) == 1 and isinstance(args[0], ConstantVariable):
  496. last = args[0].value
  497. elif (v := kwargs.get("last")) and isinstance(v, ConstantVariable):
  498. last = v.value
  499. else:
  500. raise_args_mismatch(tx, name)
  501. k, v = self.items.popitem(last=last)
  502. else:
  503. k, v = self.items.popitem()
  504. self.should_reconstruct_all = True
  505. tx.output.side_effects.mutation(self)
  506. return variables.TupleVariable([k.vt, v])
  507. elif name == "clear":
  508. if args or kwargs:
  509. raise_args_mismatch(tx, name)
  510. self.should_reconstruct_all = True
  511. tx.output.side_effects.mutation(self)
  512. self.items.clear()
  513. return ConstantVariable.create(None)
  514. elif name == "update" and self.is_mutable():
  515. # In general, this call looks like `a.update(b, x=1, y=2, ...)`.
  516. # Either `b` or the kwargs is omittable, but not both.
  517. self.install_dict_keys_match_guard()
  518. has_arg = len(args) == 1
  519. has_kwargs = len(kwargs) > 0
  520. if has_arg or has_kwargs:
  521. tx.output.side_effects.mutation(self)
  522. if has_arg:
  523. if isinstance(args[0], ConstDictVariable):
  524. # NB - Guard on all the keys of the other dict to ensure
  525. # correctness.
  526. args[0].install_dict_keys_match_guard()
  527. dict_vt = args[0]
  528. else:
  529. dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
  530. self.items.update(dict_vt.items)
  531. if has_kwargs:
  532. # Handle kwargs
  533. kwargs = {
  534. Hashable(ConstantVariable.create(k)): v
  535. for k, v in kwargs.items()
  536. }
  537. self.items.update(kwargs)
  538. return ConstantVariable.create(None)
  539. else:
  540. return super().call_method(tx, name, args, kwargs)
  541. elif name == "__contains__":
  542. if not len(args):
  543. raise_args_mismatch(tx, name)
  544. if not arg_hashable:
  545. raise_unhashable(args[0])
  546. self.install_dict_contains_guard(tx, args)
  547. contains = args[0] in self
  548. return ConstantVariable.create(contains)
  549. elif name == "setdefault" and self.is_mutable():
  550. if len(args) not in (1, 2):
  551. raise_args_mismatch(tx, name)
  552. if not arg_hashable:
  553. raise_unhashable(args[0])
  554. self.install_dict_keys_match_guard()
  555. assert not kwargs
  556. assert len(args) <= 2
  557. value = self.maybe_getitem_const(args[0])
  558. if value is not None:
  559. return value
  560. else:
  561. if len(args) == 1:
  562. x = ConstantVariable.create(None)
  563. else:
  564. x = args[1]
  565. tx.output.side_effects.mutation(self)
  566. self.items[Hashable(args[0])] = x
  567. return x
  568. elif name == "move_to_end":
  569. self.install_dict_keys_match_guard()
  570. tx.output.side_effects.mutation(self)
  571. if args[0] not in self:
  572. raise_observed_exception(KeyError, tx)
  573. last = True
  574. if len(args) == 2 and isinstance(args[1], ConstantVariable):
  575. last = args[1].value
  576. if (
  577. kwargs
  578. and "last" in kwargs
  579. and isinstance(kwargs["last"], ConstantVariable)
  580. ):
  581. last = kwargs.get("last").value
  582. key = Hashable(args[0])
  583. self.items.move_to_end(key, last=last)
  584. return ConstantVariable.create(None)
  585. elif name == "__eq__" and istype(
  586. self, ConstDictVariable
  587. ): # don't let Set use this function
  588. if len(args) != 1:
  589. raise_args_mismatch(tx, name)
  590. return variables.UserFunctionVariable(polyfills.dict___eq__).call_function(
  591. tx, [self, args[0]], {}
  592. )
  593. elif name == "__ne__":
  594. return ConstantVariable.create(
  595. not self.call_method(tx, "__eq__", args, kwargs).value
  596. )
  597. elif name == "__or__":
  598. assert len(args) == 1
  599. other = args[0]
  600. # Method resolution for binops works as follow (using __or__ as example):
  601. # (1) dict.__or__(dict) => dict
  602. # (2) dict.__or__(subclass): return NotImplemented
  603. # (3) Check if subclass implements __ror__ => forward the call
  604. # to subclass.__ror__(dict)
  605. # Let's not forward the call to __ror__ yet because __ror__ can be
  606. # implemented in C (i.e. OrderedDict subclass) which Dynamo cannot
  607. # trace
  608. # if istype(other, variables.UserDefinedDictVariable):
  609. # if other.call_obj_hasattr(tx, "__ror__").value:
  610. # return other.call_method(tx, "__ror__", [self], kwargs)
  611. # The three dict types Dynamo can handle are dict, OrderedDict and
  612. # defaultdict.
  613. # TODO(guilhermeleobas): this check should be on builtin.py::call_or_
  614. if not istype(
  615. other, (ConstDictVariable, variables.UserDefinedDictVariable)
  616. ):
  617. msg = (
  618. f"unsupported operand type(s) for |: '{self.python_type().__name__}'"
  619. f"and '{other.python_type().__name__}'"
  620. )
  621. raise_observed_exception(TypeError, tx, args=[msg])
  622. # OrderedDict overloads __ror__
  623. ts = {self.user_cls, other.user_cls}
  624. user_cls = (
  625. collections.OrderedDict
  626. if any(issubclass(t, collections.OrderedDict) for t in ts)
  627. else dict
  628. )
  629. self.install_dict_keys_match_guard()
  630. new_dict_vt = self.clone(
  631. items=self.items.copy(),
  632. mutation_type=ValueMutationNew(),
  633. source=None,
  634. user_cls=user_cls,
  635. )
  636. # NB - Guard on all the keys of the other dict to ensure
  637. # correctness.
  638. args[0].install_dict_keys_match_guard()
  639. new_dict_vt.items.update(args[0].items)
  640. return new_dict_vt
  641. elif name == "__ior__":
  642. self.call_method(tx, "update", args, kwargs)
  643. return self
  644. else:
  645. return super().call_method(tx, name, args, kwargs)
  646. def unpack_var_sequence(self, tx):
  647. self.install_dict_keys_match_guard()
  648. return [x.vt for x in self.items.keys()]
  649. def call_obj_hasattr(self, tx, name):
  650. # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict.
  651. # OrderedDict though requires side effects tracking because it supports arbitrary setattr.
  652. if self.user_cls is dict:
  653. if name in self.user_cls.__dict__:
  654. return ConstantVariable.create(True)
  655. return ConstantVariable.create(False)
  656. msg = f"hasattr on {self.user_cls} is not supported"
  657. unimplemented_v2(
  658. gb_type="unsupported hasattr operation",
  659. context=f"Class {self.user_cls}",
  660. explanation=msg,
  661. hints=[
  662. "Consider using a regular dictionary instead",
  663. *graph_break_hints.SUPPORTABLE,
  664. ],
  665. )
  666. def clone(self, **kwargs):
  667. self.install_dict_keys_match_guard()
  668. return super().clone(**kwargs)
  669. class MappingProxyVariable(VariableTracker):
  670. # proxies to the original dict_vt
  671. def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
  672. super().__init__(**kwargs)
  673. assert isinstance(dv_dict, ConstDictVariable)
  674. self.dv_dict = dv_dict
  675. def python_type(self):
  676. return types.MappingProxyType
  677. def unpack_var_sequence(self, tx):
  678. return self.dv_dict.unpack_var_sequence(tx)
  679. def reconstruct(self, codegen: "PyCodegen"):
  680. # load types.MappingProxyType
  681. if self.source:
  682. msg = (
  683. f"Preexisting MappingProxyVariable (source: {self.source}) cannot be reconstructed "
  684. "because the connection to the original dict will be lost."
  685. )
  686. unimplemented_v2(
  687. gb_type="mapping proxy cannot be reconstructed",
  688. context=f"Source: {self.source}",
  689. explanation=msg,
  690. hints=[
  691. "Use a mapping proxy constructed in the same `torch.compile` region.",
  692. *graph_break_hints.SUPPORTABLE,
  693. ],
  694. )
  695. codegen.add_push_null(
  696. lambda: codegen.extend_output(
  697. [
  698. codegen.create_load_python_module(types),
  699. codegen.create_load_attr("MappingProxyType"),
  700. ]
  701. )
  702. )
  703. codegen(self.dv_dict)
  704. codegen.extend_output(create_call_function(1, False))
  705. def call_method(
  706. self,
  707. tx,
  708. name,
  709. args: list["VariableTracker"],
  710. kwargs: dict[str, "VariableTracker"],
  711. ) -> "VariableTracker":
  712. if self.source and tx.output.side_effects.has_existing_dict_mutation():
  713. msg = (
  714. "A dict has been modified while we have an existing mappingproxy object. "
  715. "A mapping proxy object, as the name suggest, proxies a mapping "
  716. "object (usually a dict). If the original dict object mutates, it "
  717. "is reflected in the proxy object as well. For an existing proxy "
  718. "object, we do not know the original dict it points to. Therefore, "
  719. "for correctness we graph break when there is dict mutation and we "
  720. "are trying to access a proxy object."
  721. )
  722. unimplemented_v2(
  723. gb_type="mapping proxy affected by dictionary mutation",
  724. context=f"Source: {self.source}, Dict mutation detected",
  725. explanation=msg,
  726. hints=[
  727. "Avoid modifying dictionaries that might be referenced by mapping proxy objects",
  728. "Or avoid using the mapping proxy objects after modifying its underlying dictionary",
  729. ],
  730. )
  731. return self.dv_dict.call_method(tx, name, args, kwargs)
  732. class NNModuleHooksDictVariable(ConstDictVariable):
  733. # Special class to avoid adding any guards on the nn module hook ids.
  734. def install_dict_keys_match_guard(self):
  735. pass
  736. def install_dict_contains_guard(self, tx, args):
  737. pass
  738. class DefaultDictVariable(ConstDictVariable):
  739. def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None:
  740. super().__init__(items, user_cls, **kwargs)
  741. assert user_cls is collections.defaultdict
  742. self.default_factory = default_factory
  743. def is_python_constant(self):
  744. # Return false for unsupported defaults. This ensures that a bad handler
  745. # path is not taken in BuiltinVariable for getitem.
  746. if self.default_factory not in [list, tuple, dict] and not self.items:
  747. return False
  748. return super().is_python_constant()
  749. def debug_repr(self):
  750. return (
  751. f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})"
  752. )
  753. @staticmethod
  754. def is_supported_arg(arg):
  755. if isinstance(arg, variables.BuiltinVariable):
  756. return arg.fn in (list, tuple, dict, set)
  757. else:
  758. return isinstance(arg, variables.functions.BaseUserFunctionVariable)
  759. def call_method(
  760. self,
  761. tx,
  762. name,
  763. args: "list[VariableTracker]",
  764. kwargs: "dict[str, VariableTracker]",
  765. ) -> "VariableTracker":
  766. if name == "__getitem__":
  767. assert len(args) == 1
  768. if args[0] in self:
  769. return self.getitem_const(tx, args[0])
  770. else:
  771. if self.default_factory is None:
  772. raise KeyError(f"{args[0]}")
  773. else:
  774. default_var = self.default_factory.call_function(tx, [], {})
  775. super().call_method(
  776. tx, "__setitem__", (args[0], default_var), kwargs
  777. )
  778. return default_var
  779. else:
  780. return super().call_method(tx, name, args, kwargs)
  781. def reconstruct(self, codegen):
  782. # emit `defaultdict(default_factory, new_dict)`
  783. codegen.add_push_null(
  784. lambda: codegen.extend_output(
  785. [
  786. codegen.create_load_python_module(collections),
  787. codegen.create_load_attr("defaultdict"),
  788. ]
  789. )
  790. )
  791. codegen(self.default_factory)
  792. self.reconstruct_kvs_into_new_dict(codegen)
  793. codegen.extend_output(create_call_function(2, False))
# TODO: Implementing this via inheritance rather than composition is a
# footgun: when a ConstDictVariable method calls another method on `self`,
# the call dispatches to the SetVariable override instead, which is almost
# assuredly wrong.
  797. class SetVariable(ConstDictVariable):
  798. """We model a sets as dictionary with None values"""
  799. CONTAINS_GUARD = GuardBuilder.SET_CONTAINS
  800. def __init__(
  801. self,
  802. items: list[VariableTracker],
  803. **kwargs,
  804. ) -> None:
  805. items = dict.fromkeys(items, SetVariable._default_value())
  806. super().__init__(items, **kwargs)
  807. def debug_repr(self):
  808. if not self.items:
  809. return "set()"
  810. else:
  811. return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
  812. @property
  813. def set_items(self):
  814. return set(self.items.keys())
  815. @staticmethod
  816. def _default_value():
  817. # Variable to fill in he keys of the dictionary
  818. return ConstantVariable.create(None)
  819. def as_proxy(self):
  820. return {k.vt.as_proxy() for k in self.set_items}
  821. def python_type(self):
  822. return set
  823. def as_python_constant(self):
  824. return {k.vt.as_python_constant() for k in self.set_items}
  825. def reconstruct(self, codegen: "PyCodegen"):
  826. codegen.foreach([x.vt for x in self.set_items])
  827. codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items)))
  828. def _fast_set_method(self, tx, fn, args, kwargs):
  829. try:
  830. res = fn(
  831. *[x.as_python_constant() for x in [self, *args]],
  832. **{k: v.as_python_constant() for k, v in kwargs.items()},
  833. )
  834. except Exception as exc:
  835. raise_observed_exception(
  836. type(exc), tx, args=list(map(ConstantVariable.create, exc.args))
  837. )
  838. return VariableTracker.build(tx, res)
  839. def call_method(
  840. self,
  841. tx,
  842. name,
  843. args: list[VariableTracker],
  844. kwargs: dict[str, VariableTracker],
  845. ) -> "VariableTracker":
  846. # We forward the calls to the dictionary model
  847. from ..utils import check_constant_args
  848. if (
  849. name
  850. in (
  851. "isdisjoint",
  852. "union",
  853. "intersection",
  854. "difference",
  855. "symmetric_difference",
  856. )
  857. and check_constant_args(args, kwargs)
  858. and self.python_type() is set
  859. ):
  860. py_type = self.python_type()
  861. return self._fast_set_method(tx, getattr(py_type, name), args, kwargs)
  862. if name == "__init__":
  863. temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs)
  864. tx.output.side_effects.mutation(self)
  865. self.items.clear()
  866. self.items.update(temp_set_vt.items)
  867. return ConstantVariable.create(None)
  868. elif name == "add":
  869. assert not kwargs
  870. if len(args) != 1:
  871. raise_args_mismatch(tx, name)
  872. name = "__setitem__"
  873. args = (args[0], SetVariable._default_value())
  874. elif name == "pop":
  875. assert not kwargs
  876. assert not args
  877. # Choose an item at random and pop it via the Dict.pop method
  878. try:
  879. result = self.set_items.pop().vt
  880. except KeyError as e:
  881. raise_observed_exception(
  882. KeyError, tx, args=list(map(ConstantVariable.create, e.args))
  883. )
  884. super().call_method(tx, name, (result,), kwargs)
  885. return result
  886. elif name == "isdisjoint":
  887. if len(args) != 1:
  888. raise_args_mismatch(tx, name)
  889. assert not kwargs
  890. return variables.UserFunctionVariable(
  891. polyfills.set_isdisjoint
  892. ).call_function(tx, [self, args[0]], {})
  893. elif name == "intersection":
  894. assert not kwargs
  895. return variables.UserFunctionVariable(
  896. polyfills.set_intersection
  897. ).call_function(tx, [self, *args], {})
  898. elif name == "intersection_update":
  899. assert not kwargs
  900. return variables.UserFunctionVariable(
  901. polyfills.set_intersection_update
  902. ).call_function(tx, [self, *args], {})
  903. elif name == "union":
  904. assert not kwargs
  905. return variables.UserFunctionVariable(polyfills.set_union).call_function(
  906. tx, [self, *args], {}
  907. )
  908. elif name == "difference":
  909. assert not kwargs
  910. return variables.UserFunctionVariable(
  911. polyfills.set_difference
  912. ).call_function(tx, [self, *args], {})
  913. elif name == "difference_update":
  914. assert not kwargs
  915. return variables.UserFunctionVariable(
  916. polyfills.set_difference_update
  917. ).call_function(tx, [self, *args], {})
  918. elif name == "symmetric_difference":
  919. if len(args) != 1:
  920. raise_args_mismatch(tx, name)
  921. assert not kwargs
  922. return variables.UserFunctionVariable(
  923. polyfills.set_symmetric_difference
  924. ).call_function(tx, [self, *args], {})
  925. elif name == "symmetric_difference_update":
  926. if len(args) != 1:
  927. raise_args_mismatch(tx, name)
  928. assert not kwargs
  929. return variables.UserFunctionVariable(
  930. polyfills.set_symmetric_difference_update
  931. ).call_function(tx, [self, *args], {})
  932. elif name == "update" and self.is_mutable():
  933. assert not kwargs
  934. return variables.UserFunctionVariable(polyfills.set_update).call_function(
  935. tx, [self, *args], {}
  936. )
  937. elif name == "remove":
  938. assert not kwargs
  939. assert len(args) == 1
  940. if args[0] not in self:
  941. raise_observed_exception(KeyError, tx, args=args)
  942. return super().call_method(tx, "pop", args, kwargs)
  943. elif name == "discard":
  944. assert not kwargs
  945. assert len(args) == 1
  946. if args[0] in self:
  947. return super().call_method(tx, "pop", args, kwargs)
  948. else:
  949. return ConstantVariable.create(value=None)
  950. elif name in ("issubset", "issuperset"):
  951. if len(args) != 1:
  952. raise_args_mismatch(tx, name)
  953. op = {
  954. "issubset": operator.le,
  955. "issuperset": operator.ge,
  956. }
  957. other = args[0].realize()
  958. if not istype(other, SetVariable):
  959. other = variables.BuiltinVariable(set).call_function(tx, [other], {})
  960. return variables.BuiltinVariable(op.get(name)).call_function(
  961. tx, [self, other], {}
  962. )
  963. elif name in ("__and__", "__or__", "__xor__", "__sub__"):
  964. m = {
  965. "__and__": "intersection",
  966. "__or__": "union",
  967. "__xor__": "symmetric_difference",
  968. "__sub__": "difference",
  969. }.get(name)
  970. if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
  971. msg = ConstantVariable.create(
  972. f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
  973. )
  974. raise_observed_exception(TypeError, tx, args=[msg])
  975. return self.call_method(tx, m, args, kwargs)
  976. elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"):
  977. if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
  978. msg = ConstantVariable.create(
  979. f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'"
  980. )
  981. raise_observed_exception(TypeError, tx, args=[msg])
  982. m = {
  983. "__iand__": "intersection_update",
  984. "__ior__": "update",
  985. "__ixor__": "symmetric_difference_update",
  986. "__isub__": "difference_update",
  987. }.get(name)
  988. self.call_method(tx, m, args, kwargs)
  989. return self
  990. elif name == "__eq__":
  991. if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
  992. return ConstantVariable.create(False)
  993. r = self.call_method(tx, "symmetric_difference", args, kwargs)
  994. return ConstantVariable.create(len(r.set_items) == 0)
  995. elif name in cmp_name_to_op_mapping:
  996. if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)):
  997. return ConstantVariable.create(NotImplemented)
  998. return ConstantVariable.create(
  999. cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
  1000. )
  1001. return super().call_method(tx, name, args, kwargs)
  1002. def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker):
  1003. raise RuntimeError("Illegal to getitem on a set")
  1004. def install_dict_keys_match_guard(self):
  1005. # Already EQUALS_MATCH guarded
  1006. pass
  1007. def install_dict_contains_guard(self, tx, args):
  1008. super().install_dict_contains_guard(tx, args)
  1009. class FrozensetVariable(SetVariable):
  1010. def __init__(
  1011. self,
  1012. items: list[VariableTracker],
  1013. **kwargs,
  1014. ) -> None:
  1015. super().__init__(items, **kwargs)
  1016. def debug_repr(self):
  1017. if not self.items:
  1018. return "frozenset()"
  1019. else:
  1020. return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"
  1021. @property
  1022. def set_items(self):
  1023. return self.items.keys()
  1024. def python_type(self):
  1025. return frozenset
  1026. def as_python_constant(self):
  1027. return frozenset({k.vt.as_python_constant() for k in self.set_items})
  1028. def reconstruct(self, codegen: "PyCodegen"):
  1029. codegen.foreach([x.vt for x in self.set_items])
  1030. codegen.add_push_null(
  1031. lambda: codegen.extend_output(
  1032. [
  1033. codegen.create_load_global("frozenset"),
  1034. ]
  1035. )
  1036. )
  1037. codegen.extend_output(create_call_function(0, False))
  1038. def call_method(
  1039. self,
  1040. tx,
  1041. name,
  1042. args: list[VariableTracker],
  1043. kwargs: dict[str, VariableTracker],
  1044. ) -> "VariableTracker":
  1045. if name in ["add", "pop", "update", "remove", "discard", "clear"]:
  1046. raise RuntimeError(f"Illegal call_method {name} on a frozenset")
  1047. elif name == "__init__":
  1048. # frozenset is immutable. Calling __init__ again shouldn't have any effect
  1049. # In[1]: s = frozenset([1, 2])
  1050. #
  1051. # In[2]: s.__init__([3, 4])
  1052. #
  1053. # In[3]: s
  1054. # frozenset({1, 2})
  1055. return ConstantVariable.create(None)
  1056. elif name in (
  1057. "copy",
  1058. "difference",
  1059. "intersection",
  1060. "symmetric_difference",
  1061. ):
  1062. r = super().call_method(tx, name, args, kwargs)
  1063. return FrozensetVariable(r.items)
  1064. return super().call_method(tx, name, args, kwargs)
  1065. class DictKeySetVariable(SetVariable):
  1066. def __init__(
  1067. self,
  1068. items: list[VariableTracker],
  1069. **kwargs,
  1070. ) -> None:
  1071. super().__init__(items, **kwargs)
  1072. def debug_repr(self):
  1073. if not self.items:
  1074. return "dict_keys([])"
  1075. else:
  1076. return (
  1077. "dict_keys(["
  1078. + ",".join(k.vt.debug_repr() for k in self.items.keys())
  1079. + "])"
  1080. )
  1081. def install_dict_keys_match_guard(self):
  1082. # Already EQUALS_MATCH guarded
  1083. pass
  1084. def install_dict_contains_guard(self, tx, args):
  1085. # Already EQUALS_MATCH guarded
  1086. pass
  1087. @property
  1088. def set_items(self):
  1089. return self.items
  1090. def python_type(self):
  1091. return dict_keys
  1092. def as_python_constant(self):
  1093. return dict.fromkeys(
  1094. {k.vt.as_python_constant() for k in self.set_items}, None
  1095. ).keys()
  1096. def call_method(
  1097. self,
  1098. tx,
  1099. name,
  1100. args: list[VariableTracker],
  1101. kwargs: dict[str, VariableTracker],
  1102. ) -> "VariableTracker":
  1103. if name in ["add", "pop", "update", "remove", "discard", "clear"]:
  1104. raise RuntimeError(f"Illegal call_method {name} on a dict_keys")
  1105. return super().call_method(tx, name, args, kwargs)
  1106. class DictViewVariable(VariableTracker):
  1107. """
  1108. Models _PyDictViewObject
  1109. This is an "abstract" class. Subclasses will override kv and the items method
  1110. """
  1111. kv: Optional[str] = None
  1112. def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None:
  1113. super().__init__(**kwargs)
  1114. assert self.kv in ("keys", "values", "items")
  1115. assert isinstance(dv_dict, ConstDictVariable)
  1116. self.dv_dict = dv_dict
  1117. @property
  1118. def view_items(self):
  1119. return getattr(self.dv_dict.items, self.kv)()
  1120. @property
  1121. def view_items_vt(self):
  1122. # Returns an iterable of the unpacked items
  1123. # Implement in the subclasses
  1124. raise NotImplementedError
  1125. def unpack_var_sequence(self, tx):
  1126. return self.view_items_vt
  1127. def reconstruct(self, codegen: "PyCodegen"):
  1128. codegen(self.dv_dict)
  1129. codegen.load_method(self.kv)
  1130. codegen.call_method(0)
  1131. def call_obj_hasattr(self, tx, name):
  1132. if name in self.python_type().__dict__:
  1133. return ConstantVariable.create(True)
  1134. return ConstantVariable.create(False)
  1135. def call_method(
  1136. self,
  1137. tx,
  1138. name,
  1139. args: list["VariableTracker"],
  1140. kwargs: dict[str, "VariableTracker"],
  1141. ) -> "VariableTracker":
  1142. if name == "__len__":
  1143. return self.dv_dict.call_method(tx, name, args, kwargs)
  1144. return super().call_method(tx, name, args, kwargs)
  1145. class DictKeysVariable(DictViewVariable):
  1146. kv = "keys"
  1147. @property
  1148. def set_items(self):
  1149. return set(self.view_items)
  1150. @property
  1151. def view_items_vt(self):
  1152. # Returns an iterable of the unpacked items
  1153. return [x.vt for x in self.view_items]
  1154. def python_type(self):
  1155. return dict_keys
  1156. def call_method(
  1157. self,
  1158. tx,
  1159. name,
  1160. args: list["VariableTracker"],
  1161. kwargs: dict[str, "VariableTracker"],
  1162. ) -> "VariableTracker":
  1163. if name == "__contains__":
  1164. return self.dv_dict.call_method(tx, name, args, kwargs)
  1165. elif name in (
  1166. "__and__",
  1167. "__iand__",
  1168. "__or__",
  1169. "__ior__",
  1170. "__sub__",
  1171. "__isub__",
  1172. "__xor__",
  1173. "__ixor__",
  1174. ):
  1175. # These methods always returns a set
  1176. m = getattr(self.set_items, name)
  1177. r = m(args[0].set_items)
  1178. return SetVariable(r)
  1179. if name in cmp_name_to_op_mapping:
  1180. if not isinstance(args[0], (SetVariable, DictKeysVariable)):
  1181. return ConstantVariable.create(NotImplemented)
  1182. return ConstantVariable.create(
  1183. cmp_name_to_op_mapping[name](self.set_items, args[0].set_items)
  1184. )
  1185. return super().call_method(tx, name, args, kwargs)
  1186. class DictValuesVariable(DictViewVariable):
  1187. # DictValuesVariable is an iterable but cannot be compared.
  1188. kv = "values"
  1189. @property
  1190. def view_items_vt(self):
  1191. return list(self.view_items)
  1192. def python_type(self):
  1193. return dict_values
  1194. class DictItemsVariable(DictViewVariable):
  1195. kv = "items"
  1196. @property
  1197. def view_items_vt(self):
  1198. # Returns an iterable of the unpacked items
  1199. return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items]
  1200. def python_type(self):
  1201. return dict_items