symbolic_convert.py 186 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689
  1. """
  2. Core module responsible for converting Python bytecode into TorchDynamo's symbolic execution format.
  3. This module implements the bytecode-level tracing system that allows TorchDynamo to analyze
  4. and transform Python code. It converts Python bytecode instructions into a symbolic format
  5. that tracks the flow of tensors and other values through the program.
  6. Key components:
  7. - InstructionTranslatorBase: Base class for converting bytecode to symbolic execution
  8. - InstructionTranslator: Main translator for function bytecode
  9. - InliningInstructionTranslator: Handles inlining of called functions
  10. - SpeculationLog: Manages state for speculative execution and rollback
  11. The symbolic conversion process handles:
  12. - Control flow (loops, conditionals, etc.)
  13. - Function inlining and call stack management
  14. - Tracking of program values and side effects
  15. - Graph breaks and resumption points
  16. - Exception handling and stack frame management
  17. This is a core part of TorchDynamo's tracing system that enables ahead-of-time
  18. optimization of PyTorch programs.
  19. """
  20. from __future__ import annotations
  21. import collections
  22. import collections.abc
  23. import contextlib
  24. import copy
  25. import dataclasses
  26. import dis
  27. import functools
  28. import importlib
  29. import inspect
  30. import itertools
  31. import linecache
  32. import logging
  33. import operator
  34. import re
  35. import sys
  36. import threading
  37. import traceback
  38. import types
  39. import weakref
  40. from traceback import StackSummary
  41. from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
  42. from typing_extensions import TypeAlias, TypeIs
  43. from unittest.mock import patch
  44. import torch
  45. import torch._logging
  46. from torch._dynamo.exc import ObservedException, TensorifyScalarRestartAnalysis
  47. from torch._guards import tracing, TracingContext
  48. from torch._logging.structured import dump_file
  49. from torch.fx.experimental.symbolic_shapes import guard_bool
  50. from torch.utils._functools import cache_method
  51. from . import (
  52. config,
  53. exc,
  54. graph_break_hints,
  55. logging as torchdynamo_logging,
  56. trace_rules,
  57. variables,
  58. )
  59. from .bytecode_analysis import (
  60. get_indexof,
  61. JUMP_OPNAMES,
  62. livevars_analysis,
  63. propagate_line_nums,
  64. )
  65. from .bytecode_transformation import (
  66. cleaned_instructions,
  67. create_binary_slice,
  68. create_call_function,
  69. create_copy,
  70. create_dup_top,
  71. create_instruction,
  72. create_jump_absolute,
  73. create_rot_n,
  74. create_swap,
  75. get_code_keys,
  76. Instruction,
  77. is_generator,
  78. is_jump_absolute,
  79. unique_id,
  80. )
  81. from .code_context import code_context
  82. from .codegen import PyCodegen
  83. from .exc import (
  84. ArgsMismatchError,
  85. BackendCompilerFailed,
  86. collapse_resume_frames,
  87. format_graph_break_message,
  88. get_stack_above_dynamo,
  89. ResumePrologueTracingError,
  90. unimplemented_v2,
  91. Unsupported,
  92. )
  93. from .funcname_cache import get_funcname
  94. from .guards import GuardBuilder, install_guard
  95. from .output_graph import GraphCompileReason, OutputGraph
  96. from .polyfills import impl_CONTAINS_OP_fallback
  97. from .replay_record import DummyModule, ExecutionRecorder
  98. from .resume_execution import (
  99. ContinueExecutionCache,
  100. IS_TRACING_RESUME_PROLOGUE_VARNAME,
  101. ReenterWith,
  102. )
  103. from .source import (
  104. AttrSource,
  105. DictGetItemSource,
  106. GlobalSource,
  107. GlobalWeakRefSource,
  108. LocalCellSource,
  109. LocalSource,
  110. SkipGuardSource,
  111. Source,
  112. )
  113. from .trace_rules import is_builtin_constant, is_forbidden
  114. from .utils import (
  115. _get_error_on_graph_break,
  116. counters,
  117. get_fake_value,
  118. get_instruction_source_311,
  119. get_metrics_context,
  120. graph_break_dup_warning_checker,
  121. istype,
  122. LazyString,
  123. proxy_args_kwargs,
  124. )
  125. from .variables.base import typestr, ValueMutationNew, VariableTracker
  126. from .variables.builder import FrameStateSizeEntry, VariableBuilder, wrap_fx_proxy
  127. from .variables.builtin import BuiltinVariable
  128. from .variables.constant import ConstantVariable
  129. from .variables.ctx_manager import (
  130. ContextWrappingVariable,
  131. GenericContextWrappingVariable,
  132. WithExitFunctionVariable,
  133. )
  134. from .variables.dicts import ConstDictVariable, SetVariable
  135. from .variables.functions import (
  136. BaseUserFunctionVariable,
  137. LocalGeneratorFunctionVariable,
  138. LocalGeneratorObjectVariable,
  139. NestedUserFunctionVariable,
  140. SkipFunctionVariable,
  141. UserFunctionVariable,
  142. UserMethodVariable,
  143. )
  144. from .variables.iter import MAX_ITERATOR_LIMIT
  145. from .variables.lazy import LazyVariableTracker
  146. from .variables.lists import (
  147. BaseListVariable,
  148. IteratorVariable,
  149. ListIteratorVariable,
  150. ListVariable,
  151. SliceVariable,
  152. TupleVariable,
  153. )
  154. from .variables.misc import (
  155. CellVariable,
  156. ExceptionVariable,
  157. GetAttrVariable,
  158. NullVariable,
  159. PythonModuleVariable,
  160. UnknownVariable,
  161. )
  162. from .variables.nn_module import NNModuleVariable
  163. from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
  164. from .variables.torch_function import (
  165. SymbolicTorchFunctionState,
  166. TorchFunctionModeVariable,
  167. )
  168. from .variables.user_defined import (
  169. RemovableHandleVariable,
  170. UserDefinedClassVariable,
  171. UserDefinedExceptionClassVariable,
  172. UserDefinedExceptionObjectVariable,
  173. UserDefinedObjectVariable,
  174. )
  175. if TYPE_CHECKING:
  176. from collections.abc import Generator, Sequence
  177. from torch._subclasses.fake_tensor import FakeTensorMode
  178. from .package import CompilePackage
log = logging.getLogger(__name__)
# Artifact loggers obtained from torch._logging, one per named artifact
# ("graph_breaks", "trace_call", "trace_source", "trace_bytecode").
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
trace_bytecode_log = torch._logging.getArtifactLogger(__name__, "trace_bytecode")
# Module-level thread-local storage object (its attributes are set outside this chunk).
tls = threading.local()
# Maps each comparison-operator key of supported_comparison_ops to the
# call_function of a BuiltinVariable wrapping that operator; handlers are
# invoked as handler(tx, args, kwargs).
compare_op_handlers: dict[str, Any] = {
    k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
}
handle_contains = BuiltinVariable(operator.contains).call_function
handle_not = BuiltinVariable(operator.not_).call_function
# `a in b` corresponds to operator.contains(b, a), so the operand order is
# reversed before dispatching; the kwargs argument is ignored.
compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
    tx, [*reversed(args)], {}
)
# `a not in b` is computed as operator.not_(operator.contains(b, a)).
compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
    tx, [handle_contains(tx, [*reversed(args)], {})], {}
)
# Link that opens a new pytorch/pytorch GitHub issue pre-filled with the
# "oncall: pt2" label and the pt2 bug-report template.
PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml"
# Union of the VariableTracker types used to model exception values
# (built-in exception variables plus user-defined exception classes/objects).
ExceptionVals: TypeAlias = Union[
    variables.ExceptionVariable,
    UserDefinedExceptionClassVariable,
    UserDefinedExceptionObjectVariable,
]
  202. @functools.cache
  203. def _import_module(name: str) -> types.ModuleType:
  204. """
  205. Import the named module and cache the result. importlib.import_module()
  206. seems to do some filesystem checking to validate the name so not caching
  207. this can be slow.
  208. """
  209. return importlib.import_module(name)
  210. @dataclasses.dataclass
  211. class SpeculationEntry:
  212. filename: str
  213. lineno: int
  214. instruction_pointer: int
  215. inst: Instruction # for debugging only
  216. _failed: bool = False
  217. error_on_graph_break: Optional[bool] = None
  218. reason: Optional[GraphCompileReason] = None
  219. def fail_and_restart_analysis(self, error_on_graph_break: bool) -> None:
  220. """
  221. Start tracing of the current frame over again, and don't take this branch.
  222. """
  223. self._failed = True
  224. self.error_on_graph_break = error_on_graph_break
  225. if self.reason is not None:
  226. restart_reason = self.reason.reason
  227. else:
  228. restart_reason = "Unknown fail_and_restart_analysis"
  229. raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason)
  230. def failed(self, tx: InstructionTranslatorBase) -> bool:
  231. if self._failed:
  232. assert self.error_on_graph_break is not None
  233. tx.error_on_graph_break = self.error_on_graph_break
  234. return True
  235. return False
  236. @dataclasses.dataclass
  237. class SpeculationLog:
  238. """
  239. SpeculationLog replaces the prior copy_graphstate/restore_graphstate
  240. checkpointing. Rather than saving/restoring state, we restart the
  241. dynamo conversion process over from the beginning -- but when we
  242. hit the start of the speculation that failed, we instead generate
  243. a graph break.
  244. """
  245. entries: list[SpeculationEntry] = dataclasses.field(default_factory=list)
  246. index: int = 0
  247. def restart(self) -> None:
  248. self.index = 0
  249. def clear(self) -> None:
  250. self.entries.clear()
  251. self.index = 0
  252. def next(
  253. self, filename: str, lineno: int, instruction_pointer: int, inst: Instruction
  254. ) -> SpeculationEntry:
  255. """
  256. Lookup or create a SpeculationEntry() that is shared across
  257. RestartAnalysis calls. Args are used only for debug checks.
  258. """
  259. if len(self.entries) == self.index:
  260. self.entries.append(
  261. SpeculationEntry(filename, lineno, instruction_pointer, inst)
  262. )
  263. entry = self.entries[self.index]
  264. prev_entry_msg = ""
  265. if self.index != 0:
  266. prev_entry = self.entries[self.index - 1]
  267. prev_entry_msg = (
  268. f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}"
  269. f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n"
  270. )
  271. if not (
  272. entry.instruction_pointer == instruction_pointer
  273. and entry.filename == filename
  274. and entry.lineno == lineno
  275. ):
  276. raise SpeculationLogDivergence(
  277. f"""
  278. SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries):
  279. - Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer})
  280. - Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer})
  281. {prev_entry_msg}
  282. There are two usual reasons why this may have occurred:
  283. - When Dynamo analysis restarted, the second run took a different path than
  284. the first. If this occurred, the previous instruction is the critical instruction that
  285. behaved differently.
  286. - Speculation entries are only added under certain conditions (as seen in
  287. step()), e.g., there must exist operators in the graph; those conditions may
  288. have changed on restart.
  289. If this divergence was intentional, clear the speculation log before restarting (do NOT
  290. do this for graph breaks, you will infinite loop).
  291. Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo
  292. """
  293. )
  294. self.index += 1
  295. return entry
  296. @dataclasses.dataclass
  297. class LocalState:
  298. automatic_dynamic: dict[str, FrameStateSizeEntry] = dataclasses.field(
  299. default_factory=dict
  300. )
  301. def render(self) -> str:
  302. return "\n".join(
  303. f"{k}: {v.render()}" for k, v in self.automatic_dynamic.items()
  304. )
  305. # Mutable box that is shared across restarts
  306. @dataclasses.dataclass
  307. class DistributedState:
  308. compile_pg: Any
  309. local_state: LocalState
  310. all_states: Optional[list[LocalState]] = None
  311. class TensorifyState:
  312. # These are the set of string symfloats names (eg. "zf0") that we collect
  313. # from the tensorify_python_scalars.py joint fx pass to inform us about
  314. # which float inputs we should specialize when we restart analysis.
  315. force_specializations: set[str] = set()
  316. @classmethod
  317. def specialize(cls, index: str) -> None:
  318. cls.force_specializations.add(index)
  319. @classmethod
  320. def should_specialize(cls, index: str) -> bool:
  321. return index in cls.force_specializations
  322. @classmethod
  323. def clear(cls) -> None:
  324. cls.force_specializations.clear()
  325. @classmethod
  326. def empty(cls) -> bool:
  327. return len(cls.force_specializations) == 0
  328. @functools.cache
  329. def _step_logger() -> Callable[..., None]:
  330. return torchdynamo_logging.get_step_logger(log)
  331. @contextlib.contextmanager
  332. def save_and_restart_speculation_log(
  333. tx: InstructionTranslatorBase,
  334. ) -> Generator[None, None, None]:
  335. # When reconstructing a generator after a graph break, we advance it until
  336. # it is fully exhausted. This process adds new entries to the speculation
  337. # log that were not previously observed. Without temporarily clearing the
  338. # speculation log, this could lead to a divergence error.
  339. entries = tx.speculation_log.entries
  340. index = tx.speculation_log.index
  341. try:
  342. tx.speculation_log.entries = []
  343. tx.speculation_log.index = 0
  344. yield
  345. finally:
  346. tx.speculation_log.entries = entries
  347. tx.speculation_log.index = index
  348. @contextlib.contextmanager
  349. def temporarely_allow_writes_to_output_graph(
  350. tx: InstructionTranslatorBase,
  351. ) -> Generator[None, None, None]:
  352. try:
  353. tmp = tx.output.should_exit
  354. tx.output.should_exit = False
  355. yield
  356. finally:
  357. tx.output.should_exit = tmp
  358. @dataclasses.dataclass
  359. class BlockStackEntry:
  360. # Current instruction that pushes something to block_stack
  361. inst: Instruction
  362. target: Instruction
  363. stack_index: int
  364. with_context: Optional[
  365. Union[ContextWrappingVariable, GenericContextWrappingVariable]
  366. ] = None
  367. def can_restore(self) -> bool:
  368. return self.with_context is not None
  369. def resume_fn(self) -> ReenterWith:
  370. assert self.stack_index is not None
  371. if (
  372. self.with_context
  373. and hasattr(self.with_context, "target_values")
  374. and self.with_context.target_values
  375. ):
  376. return ReenterWith(
  377. self.stack_index - 1, tuple(self.with_context.target_values)
  378. )
  379. else:
  380. return ReenterWith(self.stack_index - 1)
  381. def exit(self, tx: InstructionTranslatorBase, is_graph_break: bool) -> None:
  382. assert self.with_context is not None
  383. if (
  384. is_graph_break and self.with_context.exit_on_graph_break()
  385. ) or not is_graph_break:
  386. return self.with_context.exit(tx) # type: ignore[arg-type]
  387. class SpeculationLogDivergence(AssertionError):
  388. pass
  389. class ReturnValueOp(Exception):
  390. pass
  391. class YieldValueOp(Exception):
  392. """
  393. Signal to the symbolic tracer to stop and return control flow to the
  394. caller
  395. """
  396. def stack_op(fn: Callable[..., object]) -> Callable[..., Any]:
  397. nargs = len(inspect.signature(fn).parameters)
  398. fn_var = BuiltinVariable(fn)
  399. @functools.wraps(fn)
  400. def impl(self: InstructionTranslator, inst: Instruction) -> None:
  401. self.push(fn_var.call_function(self, self.popn(nargs), {}))
  402. return impl
  403. def is_stdlib(mod: object) -> bool:
  404. if sys.version_info < (3, 10):
  405. # For < 3.10, no easy way to identify a stdlib module name.
  406. return False
  407. if not isinstance(mod, types.ModuleType):
  408. return False
  409. return mod.__name__.split(".")[0] in sys.stdlib_module_names
  410. def _detect_and_normalize_assert_statement(
  411. self: InstructionTranslatorBase,
  412. truth_fn: Callable[[object], bool],
  413. push: bool,
  414. ) -> bool:
  415. # Detect if this jump instruction is assert and normalize the assert
  416. # by pushing dummy error message when nothing is given.
  417. #
  418. # Python 3.9 assertion is in following format:
  419. # 18 POP_JUMP_IF_TRUE 28
  420. # 20 LOAD_ASSERTION_ERROR
  421. # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
  422. # 24 CALL_FUNCTION 1 -> optional instruction
  423. # 26 RAISE_VARARGS
  424. #
  425. # Python 3.8 assertion is in following format:
  426. # 18 POP_JUMP_IF_TRUE 28
  427. # 20 LOAD_GLOBAL 0 (Assertion type)
  428. # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
  429. # 24 CALL_FUNCTION 1 -> optional instruction
  430. # 26 RAISE_VARARGS 1
  431. if (truth_fn is not operator.truth) or push:
  432. return False
  433. assert isinstance(self.instruction_pointer, int)
  434. current_instruction_pointer = self.instruction_pointer
  435. inst = self.instructions[current_instruction_pointer]
  436. # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
  437. if inst.opname != "LOAD_ASSERTION_ERROR":
  438. return False
  439. current_instruction_pointer += 1
  440. # Use dummy error message if its hard to extract
  441. error_msg = "assertion error"
  442. inst = self.instructions[current_instruction_pointer]
  443. # DETECT RAISE_VARARGS or LOAD CONST
  444. if inst.opname == "LOAD_CONST":
  445. if not isinstance(inst.argval, str):
  446. return False
  447. error_msg = inst.argval
  448. # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION
  449. # (PRECALL for Python 3.11, CALL for Python 3.12+)
  450. current_instruction_pointer += 1
  451. inst = self.instructions[current_instruction_pointer]
  452. if inst.opname not in ("CALL_FUNCTION", "PRECALL", "CALL"):
  453. return False
  454. # for Python 3.11, PRECALL should be followed by CALL, then RAISE_VARARGS
  455. # for Python != 3.11, CALL_FUNCTION/CALL should be followed by RAISE_VARARGS
  456. current_instruction_pointer += 1
  457. if inst.opname == "PRECALL":
  458. current_instruction_pointer += 1
  459. inst = self.instructions[current_instruction_pointer]
  460. if inst.opname != "RAISE_VARARGS":
  461. return False
  462. self.push(ConstantVariable.create(error_msg))
  463. return True
# Module-level switch read by log_graph_break(): while it is True, the
# full-stack DEBUG message is skipped and only the short "user stack
# suppressed" line is logged. (Presumably toggled by torch._dynamo.explain();
# confirm at the call site.)
explain = False


def log_graph_break(
    code_options: dict[str, Any],
    reason: str = "",
    exc_info: bool = False,
    user_stack: Optional[StackSummary] = None,
) -> None:
    """Report a graph break to the structured trace and the graph-break logger.

    Args:
        code_options: code-object options of the frame being traced;
            ``co_filename``/``co_firstlineno`` are used as the reported location
            when ``user_stack`` is empty.
        reason: human-readable graph-break reason embedded in the messages.
        exc_info: when true, ``traceback.format_exc()`` is appended to the
            structured-trace payload.
        user_stack: user frames to report; defaults to
            ``torch._guards.TracingContext.extract_stack()``.
    """
    if user_stack is None:
        user_stack = torch._guards.TracingContext.extract_stack()

    # Reported location: the innermost user frame, or the start of the code
    # object when there are no user frames.
    try:
        frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
    except IndexError:
        # first instruction
        frame_loc = (
            code_options["co_filename"],
            code_options["co_firstlineno"],
        )

    stack_above_dynamo_formatted = ""
    if config.verbose:
        # Verbose: the frames above Dynamo are shown as their own section.
        stack_above_dynamo = get_stack_above_dynamo()
        stack_above_dynamo_formatted = "".join(
            traceback.format_list(stack_above_dynamo)
        )
    else:
        # Non-verbose: prepend those frames to the user stack, then collapse
        # resume-function frames (see collapse_resume_frames).
        user_stack = get_stack_above_dynamo() + user_stack  # type: ignore[assignment]
        user_stack = collapse_resume_frames(user_stack)
    user_stack_formatted = "".join(traceback.format_list(user_stack))
    user_stack_trace = (
        f"Graph break in user code at {frame_loc[0]}:{frame_loc[1]}\n"
        f"Graph Break Reason: {reason}\n"
        "User code traceback:\n"
    )
    if config.verbose:
        user_stack_trace += (
            f"{stack_above_dynamo_formatted}\n"
            "========== most recent `torch.compile` tracing attempt started here ==========\n\n"
            f"{user_stack_formatted}\n"
            "NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! "
            "This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another "
            "Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python "
            "function, which Dynamo intercepts as a top-level frame.\n"
        )
    else:
        user_stack_trace += str(user_stack_formatted)

    # The payload is built lazily by trace_structured through payload_fn.
    torch._logging.trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "dynamo_graph_break_reason",
            "encoding": "string",
        },
        payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc() if exc_info else ''}",
    )

    # torch._dynamo.explain() formats this a little nicer, and presents a slightly
    # more actionable user code pointer
    # NOTE: because of short-circuiting, graph_break_dup_warning_checker.add()
    # only runs (and only records frame_loc) while DEBUG is enabled and
    # `explain` is False; presumably it returns a truthy value the first time a
    # location is seen. Confirm against its definition.
    if (
        graph_break_log.isEnabledFor(logging.DEBUG)
        and not explain
        and graph_break_dup_warning_checker.add(frame_loc)
    ):
        # This log line MUST contain the string "Graph break in user code",
        # This log line is exercised from
        #   python test/dynamo/test_exc.py -k test_graph_break_log
        graph_break_log.debug(
            user_stack_trace,
        )
    else:
        # This log line MUST not contain the string "Graph break in user code",
        # exercised by
        #   python test/dynamo/test_misc.py -k test_duplicate_graph_break_log
        graph_break_log.debug(
            "Graph break (user stack suppressed due to duplicate graph break) in user code at %s:%s\nGraph Break Reason: %s",
            frame_loc[0],
            frame_loc[1],
            reason,
        )
  539. def generic_jump(
  540. truth_fn: Callable[[object], bool], push: bool
  541. ) -> Callable[[InstructionTranslatorBase, Instruction], None]:
  542. # graph break message fields for data dependent branching
  543. _gb_type = "Data-dependent branching"
  544. _explanation = (
  545. "Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). "
  546. "Dynamo does not support tracing dynamic control flow."
  547. )
  548. _hints = [
  549. *graph_break_hints.FUNDAMENTAL,
  550. "Use `torch.cond` to express dynamic control flow.",
  551. ]
  552. def jump_graph_break(
  553. self: InstructionTranslatorBase,
  554. inst: Instruction,
  555. value: VariableTracker,
  556. extra_msg: str = "",
  557. ) -> None:
  558. log_graph_break(
  559. self.code_options,
  560. reason=format_graph_break_message(
  561. gb_type=_gb_type,
  562. context=f"attempted to jump with {value}",
  563. explanation=_explanation,
  564. hints=_hints,
  565. ),
  566. )
  567. assert self.should_compile_partial_graph()
  568. # compile a partial subgraph prefix then jump into user code
  569. if self.maybe_has_backedge():
  570. msg = (
  571. "Skipping frame because there is a graph break in a for/while loop\n"
  572. f"{self.frame_summary()}"
  573. )
  574. log.info(msg)
  575. raise exc.SkipFrame(msg)
  576. self.push(value)
  577. log.debug("generic_jump triggered compile")
  578. all_stack_locals_metadata = self.output.compile_subgraph(
  579. self,
  580. reason=GraphCompileReason(
  581. f"generic_jump {typestr(value)}{extra_msg}", [self.frame_summary()]
  582. ),
  583. stack_pops=1,
  584. )
  585. self.pop()
  586. if_next = self.create_call_resume_at(
  587. self.next_instruction, all_stack_locals_metadata, False
  588. )
  589. if push:
  590. self.push(value)
  591. assert inst.target is not None
  592. if_jump = self.create_call_resume_at(
  593. inst.target, all_stack_locals_metadata, False
  594. )
  595. if sys.version_info >= (3, 13):
  596. # 3.13 requires stack[-1] to be bool type
  597. self.output.add_output_instructions([create_instruction("TO_BOOL")])
  598. jump_inst = create_instruction(inst.opname, target=if_jump[0])
  599. jump_inst.copy_positions(inst)
  600. self.output.add_output_instructions([jump_inst] + if_next + if_jump)
  601. def inner(self: InstructionTranslatorBase, inst: Instruction) -> None:
  602. value: VariableTracker = self.pop()
  603. if (
  604. config.rewrite_assert_with_torch_assert
  605. and _detect_and_normalize_assert_statement(self, truth_fn, push)
  606. ):
  607. error_msg: VariableTracker = self.pop()
  608. # Skip over things like `assert True`
  609. if value.is_python_constant():
  610. if bool(value.as_python_constant()):
  611. return self.jump(inst)
  612. elif self.should_compile_partial_graph():
  613. jump_graph_break(self, inst, value)
  614. else:
  615. unimplemented_v2(
  616. gb_type="Data-dependent assertion failed (cannot compile partial graph)",
  617. context=f"value: {value}",
  618. explanation="Dynamo has determined when encountering a data-dependent assert failure "
  619. "that it should not compile the partial graph.",
  620. hints=[
  621. *graph_break_hints.FUNDAMENTAL,
  622. "Use `torch._assert()` to raise a hard AssertionError when the check fails. "
  623. "This error will propagate back the user code "
  624. "that called the compiled function (i.e. Dynamo will not trace any exception handling).",
  625. "Remove the assert statement.",
  626. "Move the assert statement outside of any context managers in order to graph break with "
  627. "partial graph compilation (if fullgraph=False).",
  628. ],
  629. )
  630. # TODO maybe should respect DtoH sync intention of users later??
  631. # Manually insert torch._assert_async instead of python assert and jump over
  632. # assert related instructions as we don't need them anymore.
  633. # if we see Tensor as assert statement, no need to call scalar_tensor
  634. if isinstance(value, TensorVariable):
  635. self.output.create_proxy(
  636. "call_function",
  637. torch._assert_async,
  638. *proxy_args_kwargs((value, error_msg), {}),
  639. )
  640. self.jump(inst)
  641. return
  642. if isinstance(value, SymNodeVariable):
  643. # if the assertion is normal shape expression.
  644. # just install guard and bail out.
  645. sym_expr = value.sym_num
  646. if not isinstance(sym_expr, torch.SymBool):
  647. sym_expr = sym_expr != 0
  648. result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
  649. if not result:
  650. unimplemented_v2(
  651. gb_type="Assertion failed on symbolic shapes",
  652. context=str(sym_expr),
  653. explanation="",
  654. hints=[*graph_break_hints.USER_ERROR],
  655. )
  656. self.jump(inst)
  657. return
  658. scalar_to_tensor_proxy = self.output.create_proxy(
  659. "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
  660. )
  661. scalar_to_tensor = wrap_fx_proxy(
  662. self,
  663. scalar_to_tensor_proxy,
  664. example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
  665. )
  666. self.output.create_proxy(
  667. "call_function",
  668. torch._assert_async,
  669. *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
  670. )
  671. self.jump(inst)
  672. return
  673. if value.is_python_constant():
  674. # ConstDictVariable is optimized to be very lazy about insertion of
  675. # guards, so we have to manually insert a SEQUENCE_LENGTH guard
  676. # here.
  677. if isinstance(value, ConstDictVariable) and value.source:
  678. install_guard(value.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
  679. if truth_fn(value.as_python_constant()):
  680. if push:
  681. self.push(value)
  682. self.jump(inst)
  683. elif (
  684. isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
  685. ):
  686. jump_graph_break(self, inst, value)
  687. elif isinstance(value, NNModuleVariable):
  688. # Equivalent of "self.nn_module is not None"
  689. mod = self.output.get_submodule(value.module_key)
  690. if truth_fn(mod):
  691. if push:
  692. self.push(value)
  693. self.jump(inst)
  694. elif isinstance(value, UserDefinedObjectVariable):
  695. try:
  696. x = value.var_getattr(self, "__bool__") # type: ignore[arg-type]
  697. except exc.ObservedAttributeError:
  698. exc.handle_observed_exception(self)
  699. # if __bool__ is missing, trying __len__ to infer a truth value.
  700. try:
  701. x = value.var_getattr(self, "__len__") # type: ignore[arg-type]
  702. except exc.ObservedAttributeError:
  703. exc.handle_observed_exception(self)
  704. x = None
  705. # __bool__ or __len__ is function
  706. if isinstance(x, UserMethodVariable):
  707. result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment]
  708. if isinstance(result, ConstantVariable) and isinstance(
  709. result.value, (bool, int)
  710. ):
  711. if truth_fn(result.value):
  712. if push:
  713. self.push(value)
  714. self.jump(inst)
  715. elif isinstance(result, SymNodeVariable):
  716. if result.evaluate_expr():
  717. if push:
  718. self.push(value)
  719. self.jump(inst)
  720. else:
  721. unimplemented_v2(
  722. gb_type="Data-dependent branching with non-constant __bool__",
  723. context=f"method: {x}, result: {result}",
  724. explanation="Attempted to perform data-dependent branching on a user-defined "
  725. "object with a __bool__ method that did not return a constant.",
  726. hints=[],
  727. )
  728. # __bool__ or __len__ is non-function or not existed in the user defined object
  729. else:
  730. if truth_fn(True):
  731. if push:
  732. self.push(value)
  733. self.jump(inst)
  734. elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
  735. self
  736. ):
  737. if truth_fn(len(value.unpack_var_sequence(self))):
  738. if push:
  739. self.push(value)
  740. self.jump(inst)
  741. elif isinstance(value, SymNodeVariable):
  742. try:
  743. # if the user is branching on a SymBool, guard on it
  744. # if the user has code like:
  745. # if size:
  746. # ...
  747. # then they are just testing truthiness: guard that the expr != 0
  748. if isinstance(value.sym_num, torch.SymBool):
  749. eval_result = value.evaluate_expr(self.output)
  750. else:
  751. eval_result = guard_bool(value.sym_num != 0)
  752. except exc.UserError as e:
  753. if self.should_compile_partial_graph():
  754. return jump_graph_break(self, inst, value, extra_msg=f"\n{e}")
  755. raise
  756. if truth_fn(eval_result):
  757. if push:
  758. self.push(value)
  759. self.jump(inst)
  760. elif isinstance(value, variables.BackwardHookVariable):
  761. if truth_fn(True):
  762. if push:
  763. self.push(value)
  764. self.jump(inst)
  765. else:
  766. from .source import is_constant_source
  767. if value.source is not None and is_constant_source(value.source):
  768. if truth_fn(value.get_real_value()): # type: ignore[attr-defined]
  769. if push:
  770. self.push(value)
  771. self.jump(inst)
  772. else:
  773. unimplemented_v2(
  774. gb_type="Data-dependent branching",
  775. context=f"attempted to jump with {value}",
  776. explanation=_explanation,
  777. hints=[
  778. *graph_break_hints.FUNDAMENTAL,
  779. "Use `torch.cond` to express dynamic control flow.",
  780. ],
  781. )
  782. return inner
def break_graph_if_unsupported(
    *, push: int
) -> Callable[
    [Callable[..., None]], Callable[[InstructionTranslatorBase, Instruction], None]
]:
    """Decorator factory for instruction handlers that may hit unsupported code.

    Args:
        push: number of values the wrapped instruction leaves on the stack.
            After a graph break, that many ``UnknownVariable`` placeholders are
            pushed in their place.

    The wrapped handler runs under ``self.speculate()``. If it raises
    ``Unsupported`` and partial-graph compilation is allowed, the graph break
    is logged and recorded on the speculation, and
    ``speculation.fail_and_restart_analysis`` is called. When the wrapper later
    finds ``speculation.failed(self)`` true (presumably on the restarted pass),
    ``handle_graph_break`` does three things: it compiles the graph traced so
    far, emits the original instruction as real bytecode, and resumes at the
    next instruction.
    """

    def decorator(
        inner_fn: Callable[..., None],
    ) -> Callable[[InstructionTranslatorBase, Instruction], None]:
        @functools.wraps(inner_fn)
        def wrapper(self: InstructionTranslatorBase, inst: Instruction) -> None:
            speculation = self.speculate()
            if speculation.failed(self):
                assert speculation.reason is not None
                return handle_graph_break(self, inst, speculation.reason)
            try:
                return inner_fn(self, inst)
            except Unsupported as excp:
                if self.active_generic_context_managers:
                    # We don't support graph break under GenericContextWrappingVariable,
                    # If there is, we roll back to the checkpoint and fall back.
                    excp.remove_from_stats()
                    unimplemented_v2(
                        gb_type="Graph break under GenericContextWrappingVariable",
                        context=f"Active generic context managers: {self.active_generic_context_managers}",
                        explanation="Attempted to graph break in an active context manager(s) that doesn't support graph breaking.",
                        hints=[
                            "Move the offending context manager(s) to outside the compiled region.",
                            *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
                        ],
                        from_exc=excp,
                    )

                # These failures are never turned into graph breaks.
                if isinstance(excp, exc.UncapturedHigherOrderOpError):
                    raise

                if not self.should_compile_partial_graph():
                    raise

                log_graph_break(
                    self.code_options,
                    exc_info=True,
                    reason=str(excp),
                    user_stack=excp.real_stack,
                )

                # A graph break that may be inside a loop skips the whole frame.
                if self.maybe_has_backedge():
                    msg = (
                        "Skipping frame because there is a graph break in a for/while loop\n"
                        f"{self.frame_summary()}"
                    )
                    log.info(msg)
                    raise exc.SkipFrame(msg) from excp

                # Re-file the failure under "graph_break" stats and remember why,
                # so handle_graph_break can use it as the compile reason.
                excp.remove_from_stats()
                excp.add_to_stats("graph_break")
                speculation.reason = GraphCompileReason(excp.msg, excp.real_stack)
            speculation.fail_and_restart_analysis(self.error_on_graph_break)

        def handle_graph_break(
            self: InstructionTranslatorBase,
            inst: Instruction,
            reason: GraphCompileReason,
        ) -> None:
            """Compile the current graph, run ``inst`` as plain bytecode, then resume.

            The emitted output, in order, is:

            1. code that re-enters the active block-stack contexts;
            2. the original instruction (or an equivalent CALL sequence);
            3. the cleanup instructions collected while re-entering the contexts;
            4. a call to the resume function for the next instruction.
            """
            if (
                sys.version_info >= (3, 11)
                and sys.version_info < (3, 12)
                and inst.opname == "CALL"
            ):
                # stack effect for PRECALL + CALL is split between the two instructions
                stack_effect = dis.stack_effect(
                    dis.opmap["PRECALL"], inst.arg
                ) + dis.stack_effect(dis.opmap["CALL"], inst.arg)
            else:
                stack_effect = dis.stack_effect(inst.opcode, inst.arg)

            # push - stack_effect is the number of inputs the instruction
            # consumes; they stay on the stack for the compiled graph to hand back.
            all_stack_locals_metadata = self.output.compile_subgraph(
                self, reason=reason, stack_pops=push - stack_effect
            )
            cg = PyCodegen(self)
            cleanup: list[Instruction] = []
            # Reconstruct the context variable CLASS in the block stack
            for b in self.block_stack:
                # Don't exit any modes we have entered,
                # output bytecode will mutate the tf mode stack accordingly
                if isinstance(b.with_context, TorchFunctionModeVariable):
                    cg.extend_output(
                        b.resume_fn().try_except_torch_function_mode(
                            cg.code_options, cleanup
                        )
                    )
                    continue
                assert b.with_context is not None
                assert isinstance(b.with_context, (ContextWrappingVariable))
                b.with_context.reconstruct_type(cg)
                cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
            del cg

            if sys.version_info >= (3, 11) and inst.opname == "CALL":
                # Re-emit the call; any pending keyword names need KW_NAMES first.
                kw_names = (
                    self.kw_names.as_python_constant()
                    if self.kw_names is not None
                    else ()
                )
                if len(kw_names) > 0:
                    # KW_NAMES no longer used in 3.13
                    assert sys.version_info < (3, 13)
                    self.output.add_output_instructions(
                        [create_instruction("KW_NAMES", argval=kw_names)]
                    )
                assert inst.arg is not None
                call_insts = create_call_function(inst.arg, False)
                call_insts[-1].copy_positions(inst)
                self.output.add_output_instructions(call_insts)
            else:
                # copy instruction, but without exception table data
                assert inst.target is None
                inst_copy = copy.copy(inst)
                inst_copy.exn_tab_entry = None
                self.output.add_output_instructions([inst_copy])

            self.output.add_output_instructions(cleanup)

            # Mirror the instruction's effect on the symbolic stack: drop its
            # inputs and push opaque placeholders for its outputs.
            self.popn(push - stack_effect)
            for _ in range(push):
                self.push(UnknownVariable())
            self.output.add_output_instructions(
                self.create_call_resume_at(
                    self.next_instruction, all_stack_locals_metadata, False
                )
            )

        return wrapper

    return decorator
  906. class BytecodeDistpatchTableMeta(type):
  907. """Installs a `cls.dispatch_table` on every subclass to speed up calls to self.OPCODE()"""
  908. def __init__(cls: type, name: str, bases: Any, dct: Any) -> None:
  909. super().__init__(name, bases, dct) # type: ignore[misc]
  910. def _missing(opname: str, *args: Any) -> None:
  911. unimplemented_v2(
  912. gb_type="Missing bytecode handler",
  913. context=f"{opname} with args {args}",
  914. explanation=f"Dynamo does not know how to handle the bytecode instruction `{opname}`.",
  915. hints=[
  916. f"Do not trace code that produces the `{opname}` bytecode instruction "
  917. "(see https://docs.python.org/3/library/dis.html for bytecode semantics).",
  918. *graph_break_hints.SUPPORTABLE,
  919. ],
  920. )
  921. dispatch_table = {
  922. op: getattr(cls, opname, functools.partial(_missing, opname))
  923. for opname, op in dis.opmap.items()
  924. }
  925. cls.dispatch_table = [dispatch_table.get(i) for i in range(2**8)]
  926. @dataclasses.dataclass
  927. class ExceptionStack:
  928. """
  929. Exception stack that it is shared among all InstructionTranslator instances
  930. """
  931. # Exception handling in CPython is a bit confusing and some of the bytecode
  932. # have a slightly different behavior than what is is documented. While reading
  933. # the documentation, is important to notice that the terms "current exception"
  934. # and "stack" sometimes refers to a C variable with the same name and the
  935. # exception stack, respectively.
  936. #
  937. # The lifetime of an exception is (Python 3.11+):
  938. # + tx._raise_exception_variable(...) := sets the current_exception variable
  939. # + PUSH_EXC_INFO := pushes the current_exception to the *exception stack*
  940. # + POP_EXCEPT := pops TOS from the *exception stack*
  941. _exc_stack: list[ExceptionVals] = dataclasses.field(default_factory=list)
  942. _current_exception: Optional[ExceptionVals] = dataclasses.field(default=None)
  943. def clear_current_exception(self) -> None:
  944. self._current_exception = None
  945. def set_current_exception(self, val: ExceptionVals) -> None:
  946. self._set_context_and_break_context_reference_cycle(val)
  947. self._current_exception = val
  948. def move_current_exception_to_stack(self) -> None:
  949. assert self._current_exception is not None
  950. self.append(self._current_exception)
  951. self.clear_current_exception()
  952. def get_current_exception(self) -> ExceptionVals:
  953. assert self._current_exception is not None
  954. return self._current_exception
  955. def _set_context_recursive(
  956. self, val: ExceptionVals, prev_idx: int
  957. ) -> ExceptionVals:
  958. if (ctx := val.__context__) and type(ctx) is not ConstantVariable: # type: ignore[union-attr]
  959. return val
  960. if len(self._exc_stack) + prev_idx > 0:
  961. prev = self._exc_stack[prev_idx]
  962. self._set_context_recursive(prev, prev_idx - 1)
  963. val.set_context(prev) # type: ignore[union-attr, arg-type]
  964. return val
  965. def _break_context_reference_cycle(self, val: ExceptionVals) -> None:
  966. # See test_exceptions::test_raise_does_not_create_context_chain_cycle
  967. # Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228
  968. # As noted on CPython, this is O(chain length) but the context chains
  969. # are usually very small
  970. o = slow_o = val
  971. slow_update_toggle = False # floyd's algorithm for detecting cycle
  972. while True:
  973. context = o.__context__ # type: ignore[union-attr]
  974. if type(context) is ConstantVariable: # context not set
  975. break
  976. if context is val:
  977. o.set_context(ConstantVariable(None)) # type: ignore[union-attr, arg-type]
  978. break
  979. o = context # type: ignore[assignment]
  980. if o is slow_o:
  981. # pre-existing cycle - all exceptions on the path were
  982. # visited and checked
  983. break
  984. if slow_update_toggle:
  985. # visited all exceptions
  986. slow_o = slow_o.__context__ # type: ignore[union-attr, assignment]
  987. slow_update_toggle = not slow_update_toggle
  988. def _set_context_and_break_context_reference_cycle(
  989. self, val: ExceptionVals
  990. ) -> None:
  991. # set Exception.__context__
  992. self._set_context_recursive(val, len(self._exc_stack) - 1)
  993. self._break_context_reference_cycle(val)
  994. def pop(self) -> ExceptionVals:
  995. return self._exc_stack.pop()
  996. def append(self, val: ExceptionVals) -> None:
  997. self._exc_stack.append(val)
  998. def __len__(self) -> int:
  999. return len(self._exc_stack)
  1000. def __getitem__(self, index: int) -> ExceptionVals:
  1001. return self._exc_stack[index]
  1002. def __str__(self) -> str:
  1003. return f"{self._exc_stack=} - {self._current_exception=}"
  1004. __repr__ = __str__
class InstructionTranslatorBase(
    metaclass=BytecodeDistpatchTableMeta,
):
    """Common state and helpers for Dynamo's symbolic bytecode interpreters.

    ``step()`` processes one instruction at a time. Per-opcode handlers are
    reached through ``dispatch_table``, which BytecodeDistpatchTableMeta
    installs on every subclass.
    """

    output: OutputGraph
    symbolic_locals: dict[str, VariableTracker]
    symbolic_globals: dict[str, VariableTracker]
    symbolic_torch_function_state: SymbolicTorchFunctionState
    # Cell/free-variable locals snapshotted by prune_dead_locals().
    post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
    stack: list[VariableTracker]
    # Index into self.instructions of the next instruction to run;
    # step() returns False when this is None.
    instruction_pointer: Optional[int]
    # Instruction most recently fetched by step().
    current_instruction: Instruction
    block_stack: list[BlockStackEntry]
    # Source line last reported through starts_line().
    lineno: int
    kw_names: Optional[ConstantVariable]
    accept_prefix_inst: bool
    prefix_insts: list[Instruction]
    inline_depth: int
    # Set by mark_inconsistent_side_effects().
    inconsistent_side_effects: bool
    current_speculation: Optional[SpeculationEntry]
    # Opcode-indexed handler list installed by BytecodeDistpatchTableMeta.
    dispatch_table: list[Any]
    exn_vt_stack: ExceptionStack
    exec_recorder: Optional[ExecutionRecorder]
    strict_checks_fn: Optional[Callable[[VariableTracker], bool]]
    start_point: Optional[int]
    # Cleared by inline_user_function_return() when a user function is inlined.
    is_leaf_tracer: bool
    # Enclosing translator, if any; maybe_has_backedge() walks this chain.
    parent: Optional[InstructionTranslatorBase]
    debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
    package: Optional[CompilePackage]
  1033. def mark_inconsistent_side_effects(self) -> None:
  1034. """
  1035. InstructionTranslator has encountered instructions which may cause
  1036. dynamo to see a different version of history from eager
  1037. See: https://github.com/pytorch/pytorch/issues/110765
  1038. """
  1039. self.inconsistent_side_effects = True
  1040. def maybe_has_backedge(self) -> bool:
  1041. # This function employs a heuristic. It does not reliably detect a backedge.
  1042. # The heuristic is straightforward: starting from the current instruction and
  1043. # continuing to the end, if any jump instruction targets an instruction before
  1044. # the current one, there might be a backedge.
  1045. # Python 3.12 introduced changes to bytecode that group common paths in
  1046. # blockstacks (with or try...else) and allow for early returns. Consequently,
  1047. # there can be multiple RETURN_VALUE instructions. Another heuristic is to
  1048. # halt detection upon encountering the first RETURN_VALUE or RETURN_CONST.
  1049. # These heuristics can result in both false positives and negatives, but
  1050. # in either case, the Dynamo code remains valid. For false positives
  1051. # (where an edge is incorrectly marked as a backedge), Dynamo will
  1052. # perform a SkipFrame instead of potentially applying optimizations. For
  1053. # false negatives (where an edge that should be marked as a backedge
  1054. # isn't), multiple graphs may be generated if there's a break in the
  1055. # graph during a for loop. In general, its better to have fewer false
  1056. # negatives so that Dynamo does not skip the whole frame.
  1057. # If any parent tx has a backedge, then return True
  1058. cur_tx: Optional[InstructionTranslatorBase] = self
  1059. while cur_tx is not None:
  1060. cur_offset = cur_tx.current_instruction.offset
  1061. assert cur_tx.instruction_pointer is not None
  1062. for inst in cur_tx.instructions[cur_tx.instruction_pointer :]:
  1063. if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
  1064. break
  1065. if inst.opname in JUMP_OPNAMES:
  1066. jump_offset = inst.argval
  1067. if jump_offset < cur_offset:
  1068. return True
  1069. cur_tx = cur_tx.parent
  1070. return False
  1071. def cellvars(self) -> list[str]:
  1072. return self.code_options["co_cellvars"]
  1073. def freevars(self) -> list[str]:
  1074. return self.code_options["co_freevars"]
  1075. def cell_and_freevars(self) -> list[str]:
  1076. if not hasattr(self, "_cell_and_freevars"):
  1077. self._cell_and_freevars = self.cellvars() + self.freevars()
  1078. return self._cell_and_freevars
  1079. def prune_dead_locals(self) -> None:
  1080. # keep cell and freevar references alive
  1081. self.post_prune_cell_and_freevars = {
  1082. k: v
  1083. for k, v in self.symbolic_locals.items()
  1084. if k in self.cell_and_freevars()
  1085. }
  1086. # Only keep the locals that must remain on the stack.
  1087. reads = livevars_analysis(self.instructions, self.current_instruction)
  1088. self.symbolic_locals = {
  1089. k: v for k, v in self.symbolic_locals.items() if k in reads
  1090. }
  1091. def call_function(
  1092. self,
  1093. fn: VariableTracker,
  1094. args: list[VariableTracker],
  1095. kwargs: dict[str, VariableTracker],
  1096. ) -> None:
  1097. assert isinstance(fn, VariableTracker)
  1098. assert isinstance(args, list)
  1099. assert isinstance(kwargs, dict)
  1100. assert all(
  1101. isinstance(x, VariableTracker)
  1102. for x in itertools.chain(args, kwargs.values())
  1103. )
  1104. inner_fn = None
  1105. if hasattr(fn, "value"):
  1106. inner_fn = fn.value
  1107. if hasattr(fn, "fn"):
  1108. inner_fn = fn.fn
  1109. if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
  1110. raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
  1111. self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
  1112. def inline_generator_function(
  1113. self, fn: VariableTracker, args: Sequence[Any], kwargs: dict[str, Any]
  1114. ) -> Any:
  1115. """
  1116. Redirect the call to the generator "call_function"
  1117. """
  1118. if not isinstance(fn, LocalGeneratorFunctionVariable):
  1119. fn = LocalGeneratorFunctionVariable(fn) # type: ignore[arg-type]
  1120. return fn.call_function(self, args, kwargs) # type: ignore[arg-type]
  1121. def inline_user_function_return(
  1122. self, fn: VariableTracker, args: Sequence[Any], kwargs: dict[str, Any]
  1123. ) -> Any:
  1124. """
  1125. A call to some user defined function by inlining it.
  1126. """
  1127. self.is_leaf_tracer = False
  1128. if config.enable_faithful_generator_behavior and is_generator(fn.get_code()): # type: ignore[attr-defined]
  1129. return self.inline_generator_function(fn, args, kwargs)
  1130. else:
  1131. return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  1132. def get_line_of_code_header(self, lineno: Optional[int] = None) -> str:
  1133. if lineno is None:
  1134. lineno = self.lineno
  1135. inline_depth_str = (
  1136. f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else ""
  1137. )
  1138. funcname = get_funcname(self.f_code.co_filename, lineno)
  1139. funcname_str = "" if funcname is None else f" ({funcname})"
  1140. return f"{self.f_code.co_filename}:{lineno} in {self.f_code.co_name}{funcname_str}{inline_depth_str}"
  1141. def get_log_starts_line_log_str(self) -> str:
  1142. log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n"
  1143. line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip()
  1144. log_str += f" {line}"
  1145. return log_str
  1146. def starts_line(self, lineno: int) -> None:
  1147. if self.lineno == lineno:
  1148. return
  1149. self.lineno = lineno
  1150. TracingContext.set_current_loc(
  1151. self.f_code.co_filename, lineno, self.f_code.co_name
  1152. )
  1153. if self.is_trace_source_log_enabled:
  1154. trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str))
    def step(self) -> bool:
        """Process exactly one instruction; return False if tracing should stop.

        Returns False when:
          * the instruction pointer is None,
          * a graph break is emitted at a failed speculation checkpoint,
          * a return/yield ends the frame,
          * ``self.output.should_exit`` is set after dispatch, or
          * analysis is restarted after an Unsupported instruction.

        Returns True otherwise, including after an observed (user-level)
        exception has been routed to ``exception_handler``.
        """
        # Re-read on every step (presumably so that a policy change made while
        # tracing is honored from the next instruction on).
        self.error_on_graph_break = _get_error_on_graph_break()
        ip = self.instruction_pointer
        if ip is None:
            return False
        self.current_instruction = inst = self.instructions[ip]
        self.instruction_pointer = ip + 1
        if inst.starts_line:
            self.starts_line(inst.starts_line)
        # With an empty stack this is a possible resume point: take a speculation
        # checkpoint. If an earlier attempt from this checkpoint failed, emit the
        # graph break here instead of tracing the instruction.
        if (
            not self.stack
            and self.should_compile_partial_graph()
            and self.is_non_empty_graph()
        ):
            self.current_speculation = self.speculate()
            if self.current_speculation.failed(self):
                self.step_graph_break(inst)
                return False
        if self.is_trace_bytecode_log_enabled:
            trace_bytecode_log.debug(
                "TRACE %s %s %s", inst.opname, inst.argval, self.stack
            )
        self.update_block_stack(inst)
        try:
            self.dispatch_table[inst.opcode](self, inst)
            return not self.output.should_exit
        except TensorifyScalarRestartAnalysis:
            raise
        except exc.ObservedException as e:
            # An exception raised by the traced program itself: jump to its
            # handler (or propagate it) and keep stepping.
            self.exception_handler(e)
            return True
        except (ReturnValueOp, YieldValueOp):
            return False
        except Unsupported:
            # More restrictive condition than should_compile_partial_graph:
            # if this condition is true, then we SHOULD NOT attempt to find
            # a previous checkpoint to resume from and try to resume - we should
            # immediately error out.
            # The condition is more restrictive because, it may be possible to resume significantly earlier
            # in the code (the most recent speculation point). This happens, for example, in the case
            # of a graph break in a try block.
            if (
                self.one_graph
                or self.error_on_graph_break
                or self.is_tracing_resume_prologue
            ):
                raise
            if self.current_speculation is None:
                log.debug("empty checkpoint")
                raise
            log.debug("step triggered compile", exc_info=True)
        # Only reached from the Unsupported handler above. Mark the most recent
        # speculation point as failed and restart analysis; on the retry,
        # `failed(self)` above is presumably what turns that checkpoint into a
        # graph break.
        self.current_speculation.fail_and_restart_analysis(self.error_on_graph_break)
        return False
    if sys.version_info >= (3, 11):

        def update_block_stack(self, inst: Instruction) -> None:
            """Maintain a pseudo block stack of active ``with`` blocks (3.11+).

            Called from ``step`` before each instruction is dispatched. It only
            pops entries, when ``inst`` shows that the innermost with block has
            been left. Entries are pushed elsewhere (when BEFORE_WITH is handled).
            """
            # 3.11+ no longer uses a block stack, but we still keep track of one
            # so that we know which contexts are currently active.
            # For our purposes, all exception table entries with the same target
            # are considered to be part of the same "block".
            # NOTE: we only keep track of with blocks that are not contained in try blocks.
            # This is because we will not create continuation functions on graph breaks in try blocks,
            # but we may for with blocks. We do not push blocks here since
            # with blocks are pushed when handling BEFORE_WITH.
            entry = inst.exn_tab_entry
            if entry:
                # Detect when we have exited the top with block.
                # The with blocks on the block stack are not enclosed in try
                # blocks, so a with block's cleanup code should be in the
                # previous with block (if any).
                if (
                    len(self.block_stack) >= 2
                    and entry.target is not self.block_stack[-1].target
                    and entry.target is self.block_stack[-2].target
                ):
                    # exit the current block
                    self.block_stack.pop()
            else:
                # no longer in any block
                # It is possible for NOPs to be between two instructions
                # in the same block, but the NOPs are not covered by an
                # exception table entry. In this case, assume that we
                # are still in the same block.
                # In 3.12+, JUMP_BACKWARD might also not be covered by
                # an exception table entry, so we also assume that we
                # are still in the same block. It is probably safe to do
                # this in 3.11, even though we haven't encountered this case before.
                if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
                    # If we really escape from a block and the current
                    # instruction is not in another block, then there
                    # should be no other nested blocks that we are in.
                    assert len(self.block_stack) == 1
                    self.block_stack.pop()

    else:

        def update_block_stack(self, inst: Instruction) -> None:
            """No-op before 3.11: the SETUP_* / POP_BLOCK handlers maintain the
            block stack directly."""
            pass
  1251. @property
  1252. def next_instruction(self) -> Instruction:
  1253. assert self.instruction_pointer is not None
  1254. return self.instructions[self.instruction_pointer]
    def step_graph_break(self, continue_inst: Instruction) -> None:
        """Emit a graph break at the current speculation checkpoint.

        Compiles the graph traced so far (``partial_convert=True``), then appends
        output bytecode that continues at ``continue_inst``:
          * nested frame (``self.parent`` set): the call returned by
            ``create_call_resume_at`` (requires ``config.nested_graph_breaks``);
          * root frame: code that stores this frame's locals from the compiled
            code's result, then jumps to ``continue_inst`` inside this frame's
            original instructions (``*self.instructions``).

        Requires an empty value stack and no previously emitted output
        instructions.
        """
        # generate code from checkpoint
        assert not self.output.output_instructions
        assert self.current_speculation is not None
        # NOTE: adding an assert here since it seems like the only place
        # where we call step_graph_break right now is when the stack is empty,
        # so let's enforce that for now.
        assert not self.stack
        # NOTE: if we support non-empty self.stack in the future, the `stack_pops` argument
        # below should be set to the stack length to ensure that the stack is codegen'd
        # for the rest of the function.
        all_stack_locals_metadata = self.output.compile_subgraph(
            self,
            partial_convert=True,
            reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
        )
        if self.parent:
            # nested graph break
            assert config.nested_graph_breaks
            self.output.add_output_instructions(
                self.create_call_resume_at(
                    continue_inst, all_stack_locals_metadata, True
                )
            )
        else:
            # load locals from frame values
            # current frame state
            # [
            # frame N locals,
            # frame N-1 stack + locals,
            # ...,
            # frame 1 stack + locals,
            # ],
            # NOTE(review): assumes the compiled code leaves the list above on
            # top of the stack - TODO confirm against compile_subgraph.
            # Replace it with its last element.
            cg = PyCodegen(self)
            self.output.add_output_instructions(
                [
                    cg.create_load_const(-1),
                    cg.create_binary_subscr(),
                ]
            )
            # For each local: duplicate the entry, index it at the local's
            # recorded position, and store the result into the local.
            for local, idx in all_stack_locals_metadata[-1].locals_names.items():
                self.output.add_output_instructions(
                    [
                        create_dup_top(),
                        cg.create_load_const(idx),
                        cg.create_binary_subscr(),
                        cg.create_store(local),
                    ]
                )
            # Drop the entry, then resume in the original bytecode at continue_inst.
            self.output.add_output_instructions(
                [
                    create_instruction("POP_TOP"),
                    create_jump_absolute(continue_inst),
                    *self.instructions,
                ]
            )
  1311. def run_ctx_mgr(self) -> Any:
  1312. # NB: Don't push the top level frame summary; set_current_loc will
  1313. # take care of it. However, DO make sure we attach real_stack to
  1314. # exceptions
  1315. return TracingContext.current_frame(None)
    def run(self) -> None:
        """Trace this frame by stepping instructions until ``step`` returns False.

        While running, this translator is registered on the output graph via
        ``push_tx``/``pop_tx``. Error handling:
          * an error raised while tracing a Dynamo-generated resume-function
            prologue is re-raised as ResumePrologueTracingError;
          * a RuntimeError whose ``msg`` mentions "Data-dependent" gets the
            readable, partially traced FX graph attached as ``partial_fx_graph``;
          * other exceptions reaching the outer handler get ``exec_record``
            attached when an execution recorder is active.

        On exit, a top-level InstructionTranslator also cleans up the output
        graph and marks bytecode tracing as stopped.
        """
        with self.run_ctx_mgr():
            dump_file(self.f_code.co_filename)
            try:
                self.output.push_tx(self)
                # Baseline for the "ir_count" metric accumulated by `jump`.
                self.start_point = self.instruction_pointer
                try:
                    while self.step():
                        pass
                except Exception as e:
                    if self.is_tracing_resume_prologue:
                        raise ResumePrologueTracingError(
                            "Error while tracing through a Dynamo-generated resume function prologue. "
                            "Errors are not allowed when tracing resume function prologues.\n"
                            f"{type(e).__qualname__}: {str(e)}"
                        ).with_traceback(e.__traceback__) from None
                    raise
            except TensorifyScalarRestartAnalysis:
                raise
            except BackendCompilerFailed:
                raise
            except RuntimeError as e:
                if hasattr(e, "msg") and "Data-dependent" in e.msg:
                    readable_graph = torch.fx.GraphModule(
                        self.output.nn_modules, self.output.graph
                    ).print_readable(
                        print_output=False, include_stride=True, include_device=True
                    )
                    e.partial_fx_graph = readable_graph  # type: ignore[attr-defined]
                    raise
                raise
            except Exception as e:
                if self.exec_recorder:
                    e.exec_record = self.exec_recorder.get_record()  # type: ignore[attr-defined]
                raise
            finally:
                self.output.pop_tx()
                # Cleanup the outputGraph to delete the held tensors. We perform the
                # cleanup only for InstructionTranslator and not
                # InliningInstructionTranslator. The InliningInstructionTranslator
                # mutates the output object and is restored to original state if
                # there was an exception.
                if isinstance(self, InstructionTranslator):
                    self.output.cleanup()
                    # Note that this call maybe redundant if compile_subgraph is
                    # called. This is ok, because calling exit stack close()
                    # twice is not an issue (second stop is a no op).
                    self.output.mark_bytecode_tracing_stop()
  1364. def push(self, val: Optional[VariableTracker]) -> None:
  1365. assert val is None or isinstance(val, VariableTracker), (
  1366. f"push expects VariableTracker, got {typestr(val)}"
  1367. )
  1368. self.stack.append(val) # type: ignore[arg-type]
  1369. def push_many(self, vals: list[VariableTracker]) -> None:
  1370. for val in vals:
  1371. self.push(val)
  1372. def pop(self) -> VariableTracker:
  1373. return self.stack.pop()
  1374. def popn(self, n: int) -> list[VariableTracker]:
  1375. return [*reversed([self.pop() for _ in range(n)])]
  1376. def LOAD_FAST(self, inst: Instruction) -> None:
  1377. name = inst.argval
  1378. if self.exec_recorder and name in self.f_locals:
  1379. self.exec_recorder.add_local_var(name, self.f_locals[name])
  1380. try:
  1381. self.push(self.symbolic_locals[name].unwrap())
  1382. except KeyError:
  1383. if name.startswith("."):
  1384. try:
  1385. # This happens in dict/list comprehensions
  1386. new_name = name.replace(".", "implicit")
  1387. self.push(self.symbolic_locals[new_name])
  1388. except KeyError:
  1389. unimplemented_v2(
  1390. gb_type="Attempted to read undefined local variable (implicit)",
  1391. context=f"LOAD_FAST {name}",
  1392. explanation=f"Could not find an implicit local variable with name `{name}`",
  1393. hints=[
  1394. "This happens in dict/list comprehensions",
  1395. *graph_break_hints.USER_ERROR,
  1396. ],
  1397. )
  1398. else:
  1399. unimplemented_v2(
  1400. gb_type="Attempted to read undefined local variable",
  1401. context=f"LOAD_FAST {name}",
  1402. explanation=f"Could not find a local variable with name `{name}`",
  1403. hints=[*graph_break_hints.USER_ERROR],
  1404. )
  1405. # for continuation functions
  1406. if name.startswith("__stack"):
  1407. self.symbolic_locals.pop(name)
  1408. def LOAD_DEREF(self, inst: Instruction) -> None:
  1409. assert inst.argval in self.cell_and_freevars()
  1410. cell = self.symbolic_locals[inst.argval]
  1411. contents_var = self.output.side_effects.load_cell(cell)
  1412. self.push(contents_var)
  1413. if self.exec_recorder and inst.argval in self.f_locals:
  1414. self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval])
  1415. def STORE_FAST(self, inst: Instruction) -> None:
  1416. name = inst.argval
  1417. loaded_vt = self.pop()
  1418. loaded_vt.set_name_hint(name)
  1419. self.symbolic_locals[name] = loaded_vt
  1420. if name == IS_TRACING_RESUME_PROLOGUE_VARNAME:
  1421. val = loaded_vt.as_python_constant()
  1422. assert type(val) is bool
  1423. self.is_tracing_resume_prologue = val
  1424. def DELETE_FAST(self, inst: Instruction) -> None:
  1425. del self.symbolic_locals[inst.argval]
  1426. def STORE_DEREF(self, inst: Instruction) -> None: # type: ignore[override]
  1427. assert inst.argval in self.cell_and_freevars()
  1428. cell = self.symbolic_locals[inst.argval]
  1429. val = self.pop()
  1430. self.output.side_effects.store_cell(cell, val)
  1431. assert isinstance(cell, CellVariable) # tame mypy
  1432. if cell.local_name is not None:
  1433. val.set_name_hint(cell.local_name) # type: ignore[attr-defined]
    # LOAD_CLOSURE reuses the LOAD_FAST handler: it pushes symbolic_locals[name].
    # For a cell variable, that is the cell tracker itself rather than its
    # contents (LOAD_DEREF reads contents via side_effects.load_cell).
    LOAD_CLOSURE = LOAD_FAST
  1435. def _load_const(self, inst: Instruction) -> ConstantVariable:
  1436. i = inst.arg
  1437. if i is None:
  1438. return ConstantVariable.create(value=inst.argval) # type: ignore[return-value]
  1439. val = self._constants_cache[i]
  1440. if not val:
  1441. self._constants_cache[i] = ConstantVariable.create(value=inst.argval) # type: ignore[call-overload]
  1442. val = self._constants_cache[i]
  1443. assert val is not None
  1444. return val
  1445. def LOAD_CONST(self, inst: Instruction) -> None:
  1446. self.push(self._load_const(inst))
  1447. def _load_global(self, inst: Instruction) -> None:
  1448. name = inst.argval
  1449. if self.exec_recorder:
  1450. if name in self.f_globals:
  1451. self.exec_recorder.add_global_var(name, self.f_globals[name])
  1452. else:
  1453. assert name in self.f_builtins
  1454. self.exec_recorder.builtins[name] = self.f_builtins[name]
  1455. if name not in self.f_globals:
  1456. return self.load_builtin(inst)
  1457. if name in self.symbolic_globals:
  1458. variable = self.output.side_effects[self.symbolic_globals[name]]
  1459. self.push(self.output.side_effects.load_global(variable, name))
  1460. return
  1461. value = self.f_globals[name]
  1462. self.push(VariableTracker.build(self, value, GlobalSource(name)))
  1463. @functools.cached_property
  1464. def nn_modules_globals_vt(self) -> VariableTracker:
  1465. module_name = "torch.nn.modules.module"
  1466. module_source = self.import_source(module_name)
  1467. fglobals_value = _import_module(module_name)
  1468. return VariableTracker.build(self, fglobals_value, module_source)
  1469. def LOAD_GLOBAL(self, inst: Instruction) -> None:
  1470. assert inst.arg is not None
  1471. if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2:
  1472. self.PUSH_NULL(inst)
  1473. self._load_global(inst)
  1474. if sys.version_info >= (3, 13) and inst.arg % 2:
  1475. self.PUSH_NULL(inst)
  1476. def STORE_GLOBAL(self, inst: Instruction) -> None:
  1477. value = self.pop()
  1478. name = inst.argval
  1479. source = GlobalSource(name)
  1480. if name not in self.symbolic_globals:
  1481. self.symbolic_globals[name] = object() # type: ignore[assignment] # sentinel object
  1482. variable = self.output.side_effects.track_global_existing(
  1483. source, self.symbolic_globals[name]
  1484. )
  1485. if isinstance(value, RemovableHandleVariable):
  1486. unimplemented_v2(
  1487. gb_type="Storing Tensor hook handle in globals",
  1488. context=name,
  1489. explanation="This is not supported.",
  1490. hints=[],
  1491. )
  1492. self.output.side_effects.store_global(variable, name, value)
  1493. # Cache note: This cache only exists for the duration of this
  1494. # InstructionTranslator - so it should be safe to do.
  1495. @cache_method
  1496. def import_source(self, module_name: str) -> GlobalSource:
  1497. """Create an alias to a module for use in guards"""
  1498. if "torch_package" in module_name:
  1499. value = torch.package.package_importer._package_imported_modules[
  1500. module_name
  1501. ]
  1502. alias = (
  1503. module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
  1504. )
  1505. else:
  1506. value = _import_module(module_name)
  1507. alias = f"__import_{module_name.replace('.', '_dot_')}"
  1508. if self.package is not None:
  1509. self.package.add_import_source(alias, module_name)
  1510. self.output.import_sources[alias] = module_name
  1511. f_globals = self.output.global_scope
  1512. assert alias not in f_globals or f_globals[alias] is value
  1513. f_globals[alias] = value
  1514. self.output.update_co_names(alias)
  1515. return GlobalSource(alias)
  1516. def resolve_name(self, name: str, package: str, level: int) -> str:
  1517. """
  1518. Copied from the Cpython implementation of __import__
  1519. Resolve a relative module name to an absolute one.
  1520. https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902
  1521. """
  1522. bits = package.rsplit(".", level - 1)
  1523. if len(bits) < level:
  1524. raise ImportError("attempted relative import beyond top-level package")
  1525. base = bits[0]
  1526. return f"{base}.{name}" if name else base
  1527. def calc_package(self) -> str:
  1528. """
  1529. Copied from the Cpython implementation of __import__
  1530. https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
  1531. """
  1532. package = self.f_globals.get("__package__")
  1533. spec = self.f_globals.get("__spec__")
  1534. if package is not None:
  1535. if spec is not None and package != spec.parent:
  1536. log.warning(
  1537. "__package__ != __spec__.parent (%r != %r)",
  1538. package,
  1539. spec.parent,
  1540. stacklevel=3,
  1541. )
  1542. return package
  1543. elif spec is not None:
  1544. return spec.parent
  1545. else:
  1546. log.warning(
  1547. "can't resolve package from __spec__ or __package__, "
  1548. "falling back on __name__ and __path__",
  1549. stacklevel=3,
  1550. )
  1551. package = self.f_globals["__name__"]
  1552. if "__path__" not in self.f_globals:
  1553. package = package.rpartition(".")[0]
  1554. return package
    def IMPORT_NAME(self, inst: Instruction) -> None:
        """Trace ``import <inst.argval>``; level and fromlist are popped from the stack.

        During execution-record replay, a module recorded under this
        (level, fromlist, name) key in ``f_globals`` is used directly.

        Otherwise the import is performed with ``__import__`` using the frame's
        globals. Relative names are then resolved to an absolute name, and a
        guard source is created for the module ``__import__`` actually returns:
        the top-level package when ``fromlist`` is empty, the named module
        otherwise.

        Pushes a PythonModuleVariable. Graph-breaks on ImportError or when the
        result is not a module.
        """
        level, fromlist = self.popn(2)
        level = level.as_python_constant()
        fromlist = fromlist.as_python_constant()
        module_name = inst.argval
        # Are we replaying? if so, load recorded module
        recorded_name = (
            f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
        )
        if recorded_name in self.f_globals:
            value = self.f_globals[recorded_name]
            source = GlobalSource(recorded_name)
        else:
            try:
                value = __import__(
                    module_name,
                    fromlist=fromlist,
                    level=level,
                    globals=self.f_globals,
                )
            except ImportError:
                unimplemented_v2(
                    gb_type="Import failure",
                    context=f"module_name: {module_name}, fromlist: {fromlist}, level={level}",
                    explanation="Failure when attempting to import.",
                    hints=[*graph_break_hints.USER_ERROR],
                )
            # Turn a relative name into an absolute one before building the source.
            if level != 0:
                pkg = self.calc_package()
                module_name = self.resolve_name(module_name, pkg, level)
            # For __import__, when the name variable is of the form package.module,
            # normally, the top-level package (the name up till the first dot) is
            # returned, not the module named by module_name. However, when a
            # non-empty fromlist argument is given, the module named by name is
            # returned. Therefore, we set the source correctly here.
            if not fromlist:
                top_level_module_name = module_name.partition(".")[0]
                source = self.import_source(top_level_module_name)
            else:
                source = self.import_source(module_name)
        if self.exec_recorder:
            self.exec_recorder.add_local_mod(recorded_name, value)
        if istype(value, (types.ModuleType, DummyModule)):
            self.push(PythonModuleVariable(value, source=source))
        else:
            unimplemented_v2(
                gb_type="Bad import result",
                context=typestr(value),
                explanation="Import result is not a Python module.",
                hints=[],
            )

    # fb internal 3.12 opcode; handled exactly like IMPORT_NAME.
    EAGER_IMPORT_NAME = IMPORT_NAME
  1608. def IMPORT_FROM(self, inst: Instruction) -> None:
  1609. self.DUP_TOP(inst)
  1610. self._load_attr(inst)
  1611. # Cache note: This cache only exists for the duration of this
  1612. # InstructionTranslator - so it should be safe to do.
  1613. @cache_method
  1614. def load_builtin_from_argval(self, argval: Any) -> VariableTracker:
  1615. if argval not in self.f_builtins:
  1616. raise Unsupported(f"name '{argval}' is not defined")
  1617. val = self.f_builtins[argval]
  1618. if callable(val):
  1619. builtins_source = GlobalSource(
  1620. self.output.name_of_builtins_dict_key_in_fglobals
  1621. )
  1622. var_source = DictGetItemSource(builtins_source, argval)
  1623. return VariableTracker.build(self, val, var_source)
  1624. else:
  1625. assert is_builtin_constant(val)
  1626. return ConstantVariable.create(value=val)
  1627. def load_builtin(self, inst: Instruction) -> None:
  1628. self.push(self.load_builtin_from_argval(inst.argval))
  1629. def jump(self, inst: Instruction) -> None:
  1630. assert self.instruction_pointer is not None
  1631. assert self.start_point is not None
  1632. assert inst.target is not None
  1633. get_metrics_context().increment(
  1634. "ir_count", self.instruction_pointer - self.start_point
  1635. )
  1636. self.instruction_pointer = self.indexof[inst.target]
  1637. self.start_point = self.instruction_pointer
    # Unconditional jumps share the `jump` handler.
    JUMP_FORWARD = jump
    JUMP_ABSOLUTE = jump
    # Conditional jumps are built by generic_jump(<truth fn>, <bool>).
    # NOTE(review): presumably the truth fn decides whether the jump is taken,
    # and the bool marks the *_OR_POP variants that keep the value on the stack
    # when jumping - confirm against generic_jump.
    POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
    JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
    JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)
  1644. def SETUP_LOOP(self, inst: Instruction) -> None:
  1645. # only exists in python<=3.7
  1646. assert inst.target is not None
  1647. self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack)))
  1648. def SETUP_EXCEPT(self, inst: Instruction) -> None:
  1649. # only exists in python<=3.7
  1650. assert inst.target is not None
  1651. self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack)))
  1652. def POP_BLOCK(self, inst: Instruction) -> None:
  1653. self.block_stack.pop()
  1654. def SETUP_WITH(self, inst: Instruction) -> None:
  1655. self.setup_or_before_with(inst)
  1656. def SETUP_FINALLY(self, inst: Instruction) -> None:
  1657. assert inst.target is not None
  1658. self.block_stack.append(BlockStackEntry(inst, inst.target, len(self.stack)))
  1659. def BEGIN_FINALLY(self, inst: Instruction) -> None:
  1660. self.push(None)
  1661. def WITH_CLEANUP_START(self, inst: Instruction) -> None:
  1662. exit, exc = self.popn(2)
  1663. assert exc is None
  1664. self.push(exc)
  1665. self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {}))
  1666. def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None:
  1667. self.popn(2)
  1668. self.push(None)
  1669. def FOR_ITER(self, inst: Instruction) -> None:
  1670. it = self.pop().realize()
  1671. try:
  1672. val = it.next_variable(self)
  1673. self.push(it)
  1674. self.push(val)
  1675. except (StopIteration, exc.ObservedUserStopIteration) as e:
  1676. if isinstance(e, exc.ObservedUserStopIteration):
  1677. exc.handle_observed_exception(self)
  1678. # leave iterator upon exhaustion in 3.12
  1679. if sys.version_info >= (3, 12):
  1680. # CPython 3.12 actually jumps to the instruction after the END_FOR
  1681. # and performs the action of END_FOR as part of FOR_ITER. We jump
  1682. # to the END_FOR and run it, so we need to make sure 2 values are
  1683. # on the stack for it to pop.
  1684. self.push(it)
  1685. self.push(ConstantVariable.create(None))
  1686. self.jump(inst)
  1687. def _create_exception_type(self, val: VariableTracker) -> VariableTracker:
  1688. if isinstance(
  1689. val, (variables.BuiltinVariable, UserDefinedExceptionClassVariable)
  1690. ):
  1691. # Create the instance of the exception type
  1692. # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
  1693. val = val.call_function(self, [], {}) # type: ignore[arg-type]
  1694. return val
  1695. def _raise_exception_variable(self, val: VariableTracker) -> NoReturn:
  1696. # User can raise exception in 2 ways
  1697. # 1) raise exception type - raise NotImplementedError
  1698. # 2) raise exception instance - raise NotImplemetedError("foo")
  1699. # 1) when user raises exception type
  1700. val = self._create_exception_type(val)
  1701. # Handle https://peps.python.org/pep-0479/
  1702. # CPython 3.12+ has a specific bytecode instruction (CALL_INTRINSIC_1 3) for this
  1703. if (
  1704. is_generator(self.f_code)
  1705. and isinstance(val, variables.ExceptionVariable)
  1706. and val.exc_type is StopIteration
  1707. ):
  1708. val = variables.BuiltinVariable(RuntimeError).call_function(self, [], {}) # type: ignore[arg-type]
  1709. # Save the exception in a global data structure
  1710. self.exn_vt_stack.set_current_exception(val) # type: ignore[arg-type]
  1711. # 2) when user raises exception instance
  1712. if self._isinstance_exception(val):
  1713. observed_exception_type = exc.get_dynamo_observed_exception(val.exc_type) # type: ignore[attr-defined, union-attr]
  1714. raise observed_exception_type(f"raised exception {val}")
  1715. unimplemented_v2(
  1716. gb_type="Failed to raise exception",
  1717. context=str(exc),
  1718. explanation="Attempted to raise a non-Exception type/value.",
  1719. hints=[*graph_break_hints.USER_ERROR],
  1720. )
    def RAISE_VARARGS(self, inst: Instruction) -> None:
        """Trace a ``raise`` statement; always exits by raising.

        oparg 0: bare ``raise``. Re-raises the exception on top of
            ``exn_vt_stack``, or raises an observed RuntimeError if there is none.
        oparg 1: ``raise X``. Raises TOS.
        otherwise: ``raise X from Y``. Pops Y and X, raises X, and sets
            ``__cause__`` on the resulting current exception.
        """
        if inst.arg == 0:
            if not len(self.exn_vt_stack):
                msg = ConstantVariable("No active exception to reraise")
                exc.raise_observed_exception(RuntimeError, self, args=[msg])
            # re-raise the previous exception. Here CPython refers to the exception
            # on top of the exception stack
            assert len(self.exn_vt_stack)
            val = self.exn_vt_stack[-1]
            assert self._isinstance_exception(val), val
            self._raise_exception_variable(val)
        elif inst.arg == 1:
            # raise TOS
            # NOTE(review): TOS is read but not popped. This is presumably harmless
            # because exception_handler trims the stack to the handler's depth
            # (or clears it) - confirm.
            val = self.stack[-1]  # type: ignore[assignment]
            self._raise_exception_variable(val)
        else:
            # raise .. from ...
            from_vt = self.pop()
            val = self.pop()  # type: ignore[assignment]
            try:
                self._raise_exception_variable(val)
            finally:
                # Runs while the raised exception propagates, so the cause is
                # attached before anyone observes the exception.
                # Update __cause__/__suppress_context__ in the raised exception
                curr_exc = self.exn_vt_stack.get_current_exception()
                cause = self._create_exception_type(from_vt)
                curr_exc.call_setattr(self, ConstantVariable("__cause__"), cause)  # type: ignore[arg-type, union-attr, assignment]
  1747. def CLEANUP_THROW(self, inst: Instruction) -> None:
  1748. # https://github.com/python/cpython/pull/96010
  1749. tos = self.stack[-1]
  1750. assert isinstance(tos, ExceptionVariable)
  1751. if tos.exc_type is StopIteration:
  1752. unimplemented_v2(
  1753. gb_type="CLEANUP_THROW with StopIteration",
  1754. context="",
  1755. explanation="Received StopIteration when handling generator.throw/close. This is not supported.",
  1756. hints=[],
  1757. )
  1758. else:
  1759. self.RERAISE(inst)
  1760. def RERAISE(self, inst: Instruction) -> None:
  1761. # https://docs.python.org/3/library/dis.html#opcode-RERAISE
  1762. # Re-raises the exception currently on top of the stack. If oparg is
  1763. # non-zero, pops an additional value from the stack which is used to
  1764. # set f_lasti of the current frame.
  1765. if sys.version_info >= (3, 11):
  1766. # RERAISE is currently supported in a narrow case of `raise ... from None`
  1767. val = self.pop()
  1768. if inst.argval:
  1769. # RERAISE 1
  1770. _ = self.pop()
  1771. self._raise_exception_variable(val)
  1772. else:
  1773. # RERAISE 0
  1774. self.push(val)
  1775. self._raise_exception_variable(val)
  1776. else:
  1777. _exc = self.pop()
  1778. val = self.pop()
  1779. _tb = self.pop()
  1780. self._raise_exception_variable(val)
  1781. def _isinstance_exception(self, val: VariableTracker) -> TypeIs[ExceptionVals]:
  1782. return isinstance(
  1783. val,
  1784. (
  1785. variables.ExceptionVariable,
  1786. UserDefinedExceptionClassVariable,
  1787. UserDefinedExceptionObjectVariable,
  1788. ),
  1789. )
  1790. def WITH_EXCEPT_START(self, inst: Instruction) -> None:
  1791. if sys.version_info >= (3, 11):
  1792. # At the top of the stack are 4 values:
  1793. # - TOP = exc_info()
  1794. # - SECOND = previous exception
  1795. # - THIRD: lasti of exception in exc_info()
  1796. # - FOURTH: the context.__exit__ bound method
  1797. # We call FOURTH(type(TOP), TOP, GetTraceback(TOP)).
  1798. # Then we push the __exit__ return value.
  1799. assert len(self.stack) >= 4
  1800. fn = self.stack[-4]
  1801. val = self.stack[-1]
  1802. assert self._isinstance_exception(val)
  1803. typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined, union-attr]
  1804. tb = ConstantVariable(None)
  1805. else:
  1806. assert len(self.stack) >= 7
  1807. fn = self.stack[-7]
  1808. val = self.stack[-2]
  1809. assert self._isinstance_exception(val)
  1810. typ = BuiltinVariable(val.exc_type) # type: ignore[attr-defined]
  1811. tb = ConstantVariable(None)
  1812. self.call_function(fn, [typ, val, tb], {})
  def exception_handler(self, raised_exception: ObservedException) -> None:
      """Route an exception observed during tracing to its handler, or propagate it.

      On Python >= 3.11 the handler is the current instruction's exception-table
      entry; on older versions it is taken from ``self.block_stack``.
      If a handler exists, the value stack is unwound to the handler's depth, the
      values the handler expects are pushed, and tracing jumps to the handler.
      If there is no handler in this frame, ``self.stack`` is cleared and
      ``raised_exception`` is re-raised so a parent translator can handle it; for
      the top-level ``InstructionTranslator`` a graph break is requested through
      ``unimplemented_v2`` before that.

      Args:
          raised_exception: the exception propagating through the traced code.

      Raises:
          raised_exception: re-raised when this frame has no handler for it.
      """
      observed_exn_gb_explanation = (
          "Dynamo found no exception handler at the top-level compiled function "
          "when encountering an exception. Exception will propagate outside the compiled region."
      )

      def bubble_exception_to_interpreter() -> None:
          # Bubble the exception to the interpreter: report a graph break that
          # names the exception currently tracked in exn_vt_stack.
          curr_exc = self.exn_vt_stack.get_current_exception()
          dynamo_exc = exc.get_dynamo_observed_exception(curr_exc.python_type())
          assert isinstance(raised_exception, dynamo_exc)  # sanity check
          unimplemented_v2(
              gb_type="Observed exception",
              context=f"raised exception {curr_exc.python_type_name()}({curr_exc.args})",  # type: ignore[union-attr]
              explanation=observed_exn_gb_explanation,
              hints=[
                  *graph_break_hints.USER_ERROR,
                  *graph_break_hints.SUPPORTABLE,
              ],
          )

      if sys.version_info >= (3, 11):
          exn_tab_entry = self.current_instruction.exn_tab_entry
          if exn_tab_entry:
              # Implementation is based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt

              # 1) pop values from the stack until it matches the stack depth
              # for the handler
              while len(self.stack) > exn_tab_entry.depth:
                  self.pop()

              # 2) if 'lasti' is true, then push the offset that the exception was raised at
              if exn_tab_entry.lasti:
                  self.push(
                      variables.ConstantVariable(self.current_instruction.offset)
                  )

              # 3) push the exception to the stack
              self.push(self.exn_vt_stack.get_current_exception())

              # 4) jump to the handler
              self.jump(exn_tab_entry)  # type: ignore[arg-type]
          else:
              # No handler found. Bubble the exception to the parent
              # instruction translator by re-raising it.
              self.stack.clear()
              if type(self) is InstructionTranslator:
                  bubble_exception_to_interpreter()
              raise raised_exception
      else:
          if len(self.block_stack):
              # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455

              block_stack_entry = self.block_stack.pop()

              # Skip over EXCEPT_HANDLER entries (handlers already running),
              # discarding the 3 values and the exception each one recorded.
              while block_stack_entry.inst.opname == "EXCEPT_HANDLER":
                  # TODO(anijain2305) - This is not tested .. unable to create a testcase
                  # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
                  self.popn(3)
                  self.exn_vt_stack.pop()
                  if len(self.block_stack) == 0:
                      # No handler found in this frame. Bubble the exception to the parent
                      # instruction translator.
                      self.stack.clear()
                      if type(self) is InstructionTranslator:
                          unimplemented_v2(
                              gb_type="Observed exception (EXCEPT_HANDLER)",
                              context=str(raised_exception),
                              explanation=observed_exn_gb_explanation
                              + " This graph break is unexpected.",
                              hints=[*graph_break_hints.DYNAMO_BUG],
                          )

                      raise raised_exception
                  block_stack_entry = self.block_stack.pop()

              exception_var = self.exn_vt_stack.get_current_exception()
              self.exn_vt_stack.move_current_exception_to_stack()

              # 1) pop values from the stack until it matches the stack depth
              # for the handler
              while len(self.stack) > block_stack_entry.stack_index:
                  self.pop()

              # Push a dummy block stack entry of EXCEPT_HANDLER
              # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
              except_handler_inst = Instruction(1e6, "EXCEPT_HANDLER", None, 0)
              self.block_stack.append(
                  BlockStackEntry(except_handler_inst, None, len(self.stack))
              )

              # Push old exception
              if len(self.exn_vt_stack) >= 2:
                  old_exception = self.exn_vt_stack[-2]

                  # Push the old exception on to stack - tb, value, type
                  # Traceback is currently mapped to UnknownVariable
                  self.push(variables.UnknownVariable())
                  self.push(old_exception)
                  self.push(variables.BuiltinVariable(old_exception.exc_type))
              else:
                  # Push empty exception tb, value, type
                  self.push(variables.ConstantVariable(None))
                  self.push(variables.ConstantVariable(None))
                  self.push(variables.ConstantVariable(None))

              # Push new exception - tb, val, type
              # Traceback is currently mapped to UnknownVariable
              self.push(variables.UnknownVariable())
              self.push(exception_var)
              self.push(variables.BuiltinVariable(exception_var.exc_type))

              # Jump to target
              self.jump(block_stack_entry)
          else:
              # No handler found. Bubble the exception to the parent
              # instruction translator by re-raising it.
              self.stack.clear()
              if type(self) is InstructionTranslator:
                  bubble_exception_to_interpreter()
              raise raised_exception
  def PUSH_EXC_INFO(self, inst: Instruction) -> None:
      """Model PUSH_EXC_INFO (https://docs.python.org/3/library/dis.html#opcode-PUSH_EXC_INFO).

      The documented effect is "pop a value, push the current exception, push the
      value back". CPython really inserts the *previous* exception (top of the
      exception stack) underneath the popped value, then moves the current
      exception onto the exception stack. Example:

          before: stack = [..., C(1), C(2)], current = TypeError, exn stack = [ValueError]
          after:  stack = [..., C(1), ValueError, C(2)], current = None,
                  exn stack = [ValueError, TypeError]

      When the exception stack is empty, ``ConstantVariable(None)`` is inserted.
      """
      top = self.pop()
      prev_exc: VariableTracker = (
          self.exn_vt_stack[-1]
          if len(self.exn_vt_stack) != 0
          else ConstantVariable(None)
      )
      self.push(prev_exc)
      self.push(top)
      self.exn_vt_stack.move_current_exception_to_stack()
  def POP_EXCEPT(self, inst: Instruction) -> None:
      """Leave an ``except`` block: drop the handled exception from ``exn_vt_stack``.

      On Python >= 3.11 one value is popped from the value stack and discarded.
      Before 3.11 the top block-stack entry must be the EXCEPT_HANDLER entry
      pushed when the handler was entered; it is removed along with three
      value-stack entries.

      Raises:
          AssertionError: if (pre-3.11) the top of the block stack is not an
              EXCEPT_HANDLER entry, which indicates a Dynamo tracing bug.
      """
      if sys.version_info >= (3, 11):
          _ = self.pop()
          # This exception is handled and therefore we can clear the error indicator
          assert len(self.exn_vt_stack)
          self.exn_vt_stack.pop()
      else:
          assert len(self.block_stack) > 0
          if self.block_stack[-1].inst.opname != "EXCEPT_HANDLER":
              # The two literals are concatenated implicitly; the trailing space
              # keeps the sentences from running together ("handling.Top").
              raise AssertionError(
                  "Bug in Dynamo tracing of exception handling. "
                  "Top of the block stack is not EXCEPT_HANDLER."
              )
          self.block_stack.pop()

          self.popn(3)

          # This exception is handled and therefore we can clear the error indicator
          assert len(self.exn_vt_stack)
          self.exn_vt_stack.pop()
  def check_if_exc_matches(self) -> bool:
      """Decide whether the in-flight exception matches an ``except`` target.

      Pops the ``except`` target (an exception class, a user-defined exception
      class/object, or a tuple of those). On Python >= 3.11 (CHECK_EXC_MATCH)
      the exception value below it stays on the stack. Before 3.11
      (JUMP_IF_NOT_EXC_MATCH) it is popped as well. Unsupported target or value
      types are reported through ``unimplemented_v2``.

      Returns:
          True if the exception matches any target, otherwise False.
      """
      assert len(self.stack) >= 2
      expected_exc_types = self.pop()
      if sys.version_info < (3, 11):
          # JUMP_IF_NOT_EXC_MATCH is undocumented but does two pops, see
          # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L3650-L3665
          exc_instance = self.stack.pop()
      else:
          # CHECK_EXC_MATCH tests STACK[-2] against STACK[-1], pops only
          # STACK[-1] and pushes the boolean result, so the exception stays.
          exc_instance = self.stack[-1]

      # Accepted spellings of the ``except`` target:
      #   except NotImplementedError               -> BuiltinVariable
      #   except CustomException                   -> UserDefinedExceptionClassVariable
      #   except (NotImplementedError, KeyError)   -> TupleVariable
      supported_targets = (
          BuiltinVariable,
          TupleVariable,
          UserDefinedExceptionClassVariable,
          UserDefinedExceptionObjectVariable,
      )
      if not isinstance(expected_exc_types, supported_targets):
          unimplemented_v2(
              gb_type="Exception with bad expected type",
              context=str(expected_exc_types),
              explanation=f"`except ...` has unsupported type {expected_exc_types}.",
              hints=[*graph_break_hints.USER_ERROR],
          )

      if sys.version_info >= (3, 11) and not self._isinstance_exception(exc_instance):
          unimplemented_v2(
              gb_type="Caught non-Exception value",
              context=str(exc_instance),
              explanation=f"Except expects to receive an object of Exception type but received {exc_instance}.",
              hints=[*graph_break_hints.USER_ERROR],
          )

      candidates = (
          expected_exc_types.items
          if isinstance(expected_exc_types, TupleVariable)
          else [expected_exc_types]
      )
      single_targets = (
          BuiltinVariable,
          UserDefinedExceptionObjectVariable,
          UserDefinedExceptionClassVariable,
      )
      for candidate in candidates:
          if not isinstance(candidate, single_targets):
              unimplemented_v2(
                  gb_type="Exception with non-type expectation",
                  context=str(candidate),
                  explanation=f"`except ...` expects a non-type: {candidate}.",
                  hints=[*graph_break_hints.USER_ERROR],
              )
          # An exception object matches when its type subclasses the target.
          if self._isinstance_exception(exc_instance) and issubclass(
              exc_instance.exc_type,  # type: ignore[union-attr]
              candidate.fn,  # type: ignore[attr-defined]
          ):
              return True
          # Pre-3.11 the stack holds the exception *class* as a BuiltinVariable.
          if isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
              exc_instance.fn, candidate.fn
          ):
              return True
      return False
  def CHECK_EXC_MATCH(self, inst: Instruction) -> None:
      """Replace the ``except`` target on TOS with a constant bool: does the exception match?"""
      matched = self.check_if_exc_matches()
      self.push(variables.ConstantVariable(matched))
  def JUMP_IF_NOT_EXC_MATCH(self, inst: Instruction) -> None:
      """Pre-3.11 exception match: fall through on a match, jump to ``inst``'s target otherwise."""
      if self.check_if_exc_matches():
          return
      self.jump(inst)
  def COMPARE_OP(self, inst: Instruction) -> None:
      """Apply the comparison named by ``inst.argval`` to the top two stack values.

      The legacy ``"exception match"`` comparison is delegated to CHECK_EXC_MATCH.
      """
      if inst.argval != "exception match":
          handler = compare_op_handlers[inst.argval]
          operands = self.popn(2)
          self.push(handler(self, operands, {}))
          return
      self.CHECK_EXC_MATCH(inst)
  def GET_ITER(self, inst: Instruction) -> None:
      """Replace TOS with ``iter(TOS)`` by tracing a call to the ``iter`` builtin."""
      iter_builtin = BuiltinVariable(iter)
      iterable = self.pop()
      self.call_function(iter_builtin, [iterable], {})
  @break_graph_if_unsupported(push=1)
  def CALL_FUNCTION(self, inst: Instruction) -> None:
      """Call the callable below ``inst.argval`` positional arguments (no kwargs)."""
      positional = self.popn(inst.argval)
      callee = self.pop()
      self.call_function(callee, positional, {})
  @break_graph_if_unsupported(push=1)
  def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
      """Trace ``f(*args)`` (oparg 0) or ``f(*args, **kwargs)`` (oparg 1).

      Any other oparg is reported through ``unimplemented_v2``. On 3.11+ a NULL
      is also popped: above the callable on 3.13+, below it on 3.11/3.12.
      Non-list ``args`` are force-unpacked into a tuple when possible, and a
      user-defined mapping passed as ``kwargs`` is converted with ``dict``.
      """
      flags = inst.argval
      kwargsvars: VariableTracker
      if flags == 0:
          kwargsvars = ConstDictVariable({})
          argsvars = self.pop()
      elif flags == 1:
          kwargsvars = self.pop()
          argsvars = self.pop()
      else:
          unimplemented_v2(
              gb_type="Variadic function call with bad flags",
              context=f"flags: {flags}",
              explanation=f"Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {flags}",
              hints=[*graph_break_hints.DYNAMO_BUG],
          )

      py_version = sys.version_info
      if py_version >= (3, 13):
          # 3.13 swapped null and callable: the NULL now sits above the callable.
          null = self.pop()
          assert isinstance(null, NullVariable)
      fn = self.pop()
      if (3, 11) <= py_version < (3, 13):
          null = self.pop()
          assert isinstance(null, NullVariable)

      if not isinstance(argsvars, BaseListVariable) and argsvars.has_force_unpack_var_sequence(self):
          argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))

      # Unpack for cases like fn(**obj) where obj is a map
      if isinstance(kwargsvars, UserDefinedObjectVariable):
          kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars)  # type: ignore[arg-type]

      if not (
          isinstance(argsvars, BaseListVariable)
          and isinstance(kwargsvars, ConstDictVariable)
      ):
          unimplemented_v2(
              gb_type="Variadic function call with bad args/kwargs type",
              context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
              explanation="Expected args to be a list and kwargs to be a dict",
              hints=[*graph_break_hints.USER_ERROR],
          )

      # Map to a dictionary of str -> VariableTracker
      kwargs_by_name = kwargsvars.keys_as_python_constant()
      self.call_function(fn, argsvars.items, kwargs_by_name)
  @break_graph_if_unsupported(push=1)
  def CALL_FUNCTION_KW(self, inst: Instruction) -> None:
      """Trace a call whose trailing arguments are passed by keyword.

      Stack (TOS last): ``fn, *args, names``. ``names`` is a constant tuple of
      keyword names, and the last ``len(names)`` entries of ``args`` (which has
      ``inst.argval`` entries) are the matching keyword values.
      """
      argnames = self.pop()
      args = self.popn(inst.argval)
      fn = self.pop()
      assert isinstance(argnames, TupleVariable) and argnames.is_python_constant()
      argnames = argnames.as_python_constant()
      # Split at an explicit index. With negative slicing, an empty ``argnames``
      # would produce ``args[:-0] == []`` and ``args[-0:] == args``, dropping every
      # positional argument.
      num_positional = len(args) - len(argnames)
      args, kwargs_list = args[:num_positional], args[num_positional:]
      kwargs = dict(zip(argnames, kwargs_list))
      # Also catches duplicate keyword names.
      assert len(kwargs) == len(argnames)
      self.call_function(fn, args, kwargs)
  def LOAD_METHOD_SUPER(self, inst: Instruction) -> None:
      """Call the callable beneath two stack values with those two values as
      arguments (presumably ``super`` and its two arguments; TODO confirm), then
      load the method ``co_names[inst.argval[0]]`` from the result.
      """
      self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
      method_name = self.code_options["co_names"][inst.argval[0]]
      name_inst = dataclasses.replace(inst, argval=method_name)
      if sys.version_info >= (3, 11):
          self.LOAD_METHOD(name_inst)
      else:
          self._load_attr(name_inst)
  def LOAD_ATTR_SUPER(self, inst: Instruction) -> None:
      """Attribute counterpart of LOAD_METHOD_SUPER: 2-argument call, then ``getattr``."""
      self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
      name_index = inst.argval[0]
      attr_inst = dataclasses.replace(
          inst, argval=self.code_options["co_names"][name_index]
      )
      self._load_attr(attr_inst)
  def LOAD_METHOD(self, inst: Instruction) -> None:
      """Load ``TOS.<inst.argval>`` and lay out the stack the way each Python version's call opcode expects.

      * < 3.11: ``[method, None]`` (CALL_METHOD asserts on the None)
      * 3.11-3.12: ``[NULL, method]``
      * >= 3.13: ``[method, NULL]``

      The NULL + fn convention is always used: if the attribute is a method,
      ``self`` is already bound into it and need not be passed separately.
      """
      self._load_attr(inst)
      method = self.pop()
      if sys.version_info < (3, 11):
          self.push(method)
          self.push(None)
      elif sys.version_info < (3, 13):
          self.PUSH_NULL(inst)
          self.push(method)
      else:
          self.push(method)
          self.PUSH_NULL(inst)
  def CALL_METHOD(self, inst: Instruction) -> None:
      """Pre-3.11 method call: ``[fn, None, *args]`` as laid out by LOAD_METHOD."""
      call_args = self.popn(inst.argval)
      placeholder = self.pop()
      assert placeholder is None
      callee = self.pop()
      self.call_function(callee, call_args, {})
  def _load_attr(self, inst: Instruction) -> None:
      """Replace TOS with ``getattr(TOS, inst.argval)``, traced through the ``getattr`` builtin."""
      target = self.pop()
      getattr_var = BuiltinVariable(getattr)
      attr_name = ConstantVariable.create(inst.argval)
      loaded = getattr_var.call_function(
          self,  # type: ignore[arg-type]
          [target, attr_name],
          {},
      )
      self.push(loaded)
  def LOAD_ATTR(self, inst: Instruction) -> None:
      """Load an attribute; on 3.12+ an odd oparg means "load as method" (LOAD_METHOD)."""
      if sys.version_info >= (3, 12) and inst.arg % 2:
          self.LOAD_METHOD(inst)
      else:
          self._load_attr(inst)
  def STORE_ATTR(self, inst: Instruction) -> None:
      """Trace ``obj.<inst.argval> = val`` (stack: ``val, obj`` with obj on top).

      Runs under a speculation checkpoint. If the traced ``setattr`` raises
      ``Unsupported`` and a partial graph may be compiled, the error is re-counted
      as a graph break and analysis is restarted via
      ``speculation.fail_and_restart_analysis``. Presumably ``speculation.failed``
      is then true on the restarted pass, which routes to
      ``store_attr_graph_break`` (TODO confirm speculation semantics).

      Raises:
          Unsupported: re-raised when a partial graph may not be compiled.
          AssertionError: when exporting and a non-constant is assigned to an
              ``NNModuleVariable`` attribute.
      """
      speculation = self.speculate()
      if speculation.failed(self):
          return self.store_attr_graph_break(inst)
      val, obj = self.popn(2)

      if isinstance(obj, NNModuleVariable) and not isinstance(val, ConstantVariable):
          # We don't allow side effects during export on non-constant values
          # https://github.com/pytorch/torchdynamo/issues/1475
          assert not self.export, (
              f"Mutating module attribute {inst.argval} during export."
          )

      try:
          BuiltinVariable(setattr).call_function(
              self,  # type: ignore[arg-type]
              [obj, ConstantVariable.create(inst.argval), val],
              {},
          )
          return
      except Unsupported as e:
          if not self.should_compile_partial_graph():
              raise
          log.debug("STORE_ATTR triggered compile", exc_info=True)
          # Re-classify the failure in the stats as a graph break.
          e.remove_from_stats()
          e.add_to_stats("graph_break")
      # Only reached after a handled Unsupported (the success path returns above).
      speculation.fail_and_restart_analysis(self.error_on_graph_break)
  def store_attr_graph_break(self, inst: Instruction) -> None:
      """Graph-break at a STORE_ATTR.

      Compiles the graph traced so far, emits the original STORE_ATTR so it runs
      eagerly, then emits a call to a resume function starting at the next
      instruction. If a partial graph may not be compiled, ``unimplemented_v2``
      is called instead.
      """
      log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break")
      if not self.should_compile_partial_graph():
          unimplemented_v2(
              gb_type="Should not compile partial graph (STORE_ATTR)",
              context="",
              explanation="Dynamo has determined when encountering an unsupported "
              "STORE_ATTR instruction (i.e. `obj.attr = val`) that it should not compile the partial graph.",
              hints=[],
          )
      # stack_pops=2: the STORE_ATTR operands (val, obj) stay on the generated
      # stack for the re-emitted instruction to consume.
      all_stack_locals_metadata = self.output.compile_subgraph(
          self,
          reason=GraphCompileReason("store_attr", [self.frame_summary()]),
          stack_pops=2,
      )
      # Copy the instruction without its exception-table entry before emitting it.
      inst_copy = copy.copy(inst)
      inst_copy.exn_tab_entry = None
      self.output.add_output_instructions([inst_copy])
      # Mirror the two pops performed by the emitted STORE_ATTR.
      self.popn(2)
      self.output.add_output_instructions(
          self.create_call_resume_at(
              self.next_instruction, all_stack_locals_metadata, False
          )
      )
  def DELETE_ATTR(self, inst: Instruction) -> None:
      """Trace ``del TOS.<inst.argval>`` through the ``delattr`` builtin."""
      target = self.pop()
      deleter = BuiltinVariable(delattr)
      attr_name = ConstantVariable.create(inst.argval)
      deleter.call_function(
          self,  # type: ignore[arg-type]
          [target, attr_name],
          {},
      )
  def create_call_resume_at(
      self,
      inst: Instruction,
      all_stack_locals_metadata: Any,
      disable_current_frame_resume: bool,
  ) -> list[Instruction]:
      """
      Codegen resume function(s) and call it.

      Assumes that the unsupported instruction has already been run.

      Expects the stack to be in the state:
          [
              frame N locals,
              frame N-1 stack + locals,
              ...,
              frame 1 stack + locals
          ], frame N stack (post-instruction)

      Frame N is this (deepest) translator and frame 1 is the root translator.
      ``all_stack_locals_metadata[i]`` describes the i-th translator reached by
      following ``.parent`` from ``self`` (index 0 is ``self``).

      Args:
          - inst: the instruction of the current (deepest) frame to resume at
          - all_stack_locals_metadata: metadata returned from OutputGraph.compile_subgraph - contains
              metadata such as local names, NULL positions, stack length, etc.
          - disable_current_frame_resume: If True, disable tracing on the current frame's resume function.
              Used for implementing nested step_graph_break.

      Returns:
          The bytecode to emit. If ``inst`` is RETURN_VALUE or RETURN_CONST, this is
          only that return instruction. Otherwise it builds one resume function per
          frame and calls frame 1's resume function (via CALL_FUNCTION_EX) with
          ``[resume N, ..., resume 2]``, the remaining frames' values, and frame 1's
          stack + locals, then returns the call's result.

      Side effects:
          Sets ``self.instruction_pointer`` to None, sets
          ``all_stack_locals_metadata[0].num_stack``, and installs each resume
          function (or its code object when it has free variables) as a global.
      """
      self.instruction_pointer = None

      # A return needs no resume function: just re-emit the return.
      if inst.opname == "RETURN_VALUE":
          return [create_instruction("RETURN_VALUE")]
      elif inst.opname == "RETURN_CONST":
          return [create_instruction("RETURN_CONST", argval=inst.argval)]

      cg = PyCodegen(self.output.root_tx)

      # move frame N stack to the frame values list
      # (NULL entries were not codegen'd, so they are not counted)
      current_num_stack = len(self.stack) - len(
          all_stack_locals_metadata[0].stack_null_idxes
      )
      all_stack_locals_metadata[0].num_stack = current_num_stack
      cg.extend_output(
          [
              create_instruction("BUILD_LIST", arg=current_num_stack),
              *create_copy(2),
              # frame_values, frame N stack, frame_values
              cg.create_load_const(0),
              cg.create_binary_subscr(),
              *create_binary_slice(0, 0, True),
              # frame_values[0][0:0] = frame N stack
              # frame_values left on top of stack
          ]
      )

      # current frame state
      # [
      #   [frame N stack (fixed) + locals]
      #   ...,
      #   [frame 1 stack + locals]
      # ],

      #
      # Collect translators from the deepest (self) up to the root, so that
      # txes[i] pairs with all_stack_locals_metadata[i].
      txes = []
      cur_tx: Optional[InstructionTranslatorBase] = self
      while cur_tx is not None:
          txes.append(cur_tx)
          cur_tx = cur_tx.parent
      assert len(txes) == len(all_stack_locals_metadata)

      # Handle inactive context variables.
      # The resume function assumes that context variables are the class, NOT the object.
      # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
      # NOTE: if the unsupported instruction modifies the inactive context variable, it may
      # result in silent incorrectness!
      for i, meta in enumerate(all_stack_locals_metadata):
          # The current frame is skipped when its resume function will not be traced.
          if i == 0 and disable_current_frame_resume:
              continue

          for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
              # Replace the stack var with the context class
              # (j indexes the codegen'd frame values; j_orig indexes txes[i].stack)
              ctx = cast(ContextWrappingVariable, txes[i].stack[j_orig])
              # frames[i][j] = reconstructed_ctx
              cg.append_output(create_dup_top())
              ctx.reconstruct_type(cg)
              cg.extend_output(
                  [
                      *create_swap(2),
                      cg.create_load_const(i),
                      cg.create_binary_subscr(),
                      cg.create_load_const(j),
                      create_instruction("STORE_SUBSCR"),
                  ]
              )

          for name, _ in meta.locals_ctx_args:
              # Replace the local with the context class
              ctx = cast(ContextWrappingVariable, txes[i].symbolic_locals[name])
              # frames[i][meta.num_stack +meta.locals_names[name]] = reconstructed_ctx
              cg.append_output(create_dup_top())
              ctx.reconstruct_type(cg)
              cg.extend_output(
                  [
                      *create_swap(2),
                      cg.create_load_const(i),
                      cg.create_binary_subscr(),
                      cg.create_load_const(meta.num_stack + meta.locals_names[name]),
                      create_instruction("STORE_SUBSCR"),
                  ]
              )

      # build the resume function for each frame
      resume_names = []
      resume_codes: list[types.CodeType] = []
      for i, meta in enumerate(all_stack_locals_metadata):
          cur_tx = txes[i]
          # The deepest frame resumes at `inst`; every parent frame resumes at
          # the instruction after its in-progress call.
          if cur_tx is self:
              resume_inst = inst
          else:
              resume_inst = cur_tx.next_instruction
          # If the resume instruction is a jump absolute, then resume
          # at the target instead. This handles the case where we
          # graph break again in a nested function before jump-resuming
          # this frame.
          if is_jump_absolute(resume_inst):
              assert resume_inst.target
              resume_inst = resume_inst.target
          resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
          resume_names.append(resume_name)

          # More locals may have been pruned in the current frame
          # after the unsupported instruction (e.g. branch).
          # There should not be any pruning in the other frames since
          # the current instruction is a CALL.
          if cur_tx is self:
              # Keep only locals that are still read from resume_inst onward
              # and are not cell/free variables.
              reads = livevars_analysis(cur_tx.instructions, resume_inst)
              all_argnames = tuple(
                  k
                  for k in cur_tx.symbolic_locals.keys()
                  if k in reads and k not in cur_tx.cell_and_freevars()
              )
              argnames_null_set = set(meta.locals_null_keys)
              argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
              argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)

              # codegen filter for current frame's locals
              # current stack state: frames
              cg.extend_output(
                  [
                      create_dup_top(),
                      cg.create_load_const(i),
                      cg.create_binary_subscr(),
                      create_dup_top(),
                  ]
              )

              for arg in argnames:
                  # current stack state: frames, frames[i], *(prev locals), frames[i]
                  cg.extend_output(
                      [
                          create_dup_top(),
                          cg.create_load_const(
                              meta.num_stack + meta.locals_names[arg]
                          ),
                          cg.create_binary_subscr(),
                          *create_swap(2),
                      ],
                  )
              # current stack state: frames, frames[i], *(frame i live locals), frames[i]
              cg.extend_output(
                  [
                      create_instruction("POP_TOP"),
                      create_instruction("BUILD_LIST", arg=len(argnames)),
                      *create_swap(2),
                      # frames, frames i live locals, frames[i]
                      *create_binary_slice(meta.num_stack, None, True),
                      # frames[i][num_stack:] = frame i live locals
                  ]
              )
              # current stack state: frames
          else:
              argnames = tuple(meta.locals_names.keys())
              argnames_null = tuple(meta.locals_null_keys)

          if sys.version_info < (3, 12):
              assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"

          assert cur_tx.current_instruction.offset is not None

          # compile_subgraph did not codegen any NULLs,
          # so we should not count NullVariables
          stack_len = len(cur_tx.stack) - len(meta.stack_null_idxes)

          # resume_codes holds the codes built for deeper frames so far
          # (passed as the last lookup argument).
          new_code: types.CodeType = ContinueExecutionCache.lookup(
              cur_tx.f_code,
              cur_tx.lineno,
              cur_tx.current_instruction.offset,
              resume_inst.offset,  # type: ignore[arg-type]
              tuple(b.target.offset for b in cur_tx.block_stack),
              stack_len,
              argnames,
              argnames_null,
              tuple(b.resume_fn() for b in cur_tx.block_stack),
              tuple(meta.stack_ctx_args),
              tuple(meta.locals_ctx_args),
              tuple(meta.stack_null_idxes),
              tuple(resume_codes),
          )
          resume_codes.append(new_code)

          # Add original GraphModule context to the resume function to handle
          # the case of a graph break while tracing a GraphModule
          orig_graphmodule_maybe = code_context.get_context(cur_tx.f_code).get(
              "orig_graphmodule", lambda: None
          )()
          if orig_graphmodule_maybe is not None:
              code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
                  orig_graphmodule_maybe
              )

          # add resume function to the global scope
          if new_code.co_freevars:
              # expose code object for debugging purposes
              cur_tx.output.install_global_unsafe(resume_name, new_code)
              package_name = None
          else:
              # This is safe: we pre-generate a unique name
              cur_tx.output.install_global_unsafe(
                  resume_name,
                  types.FunctionType(new_code, cur_tx.f_globals, resume_name),
              )
              package_name = resume_name

          if cur_tx.package is not None:
              cur_tx.package.add_resume_function(
                  new_code, cur_tx.f_globals["__name__"], package_name
              )

      if disable_current_frame_resume:
          from .eval_frame import skip_code

          # resume_codes[0] belongs to self (txes[0]).
          skip_code(resume_codes[0])

      # load first resume function (to be called this frame)
      # -- the last entry belongs to txes[-1], the root frame (frame 1).
      if resume_codes[-1].co_freevars:
          cg.make_function_with_closure(
              txes[-1], resume_names[-1], resume_codes[-1], True, 1
          )
      else:
          cg.extend_output(cg.load_function_name(resume_names[-1], True, 1))

      # load all other resume functions (to be called later)
      resume_names.pop()
      resume_codes.pop()
      for tx, name, code in zip(txes, resume_names, resume_codes):
          if code.co_freevars:
              cg.make_function_with_closure(tx, name, code, False, 0)
          else:
              cg.extend_output(cg.load_function_name(name, False, 0))
      cg.extend_output(
          [
              create_instruction("BUILD_LIST", arg=len(resume_codes)),
              *create_swap(2),
          ]
      )

      # resume 1 (+ NULL), [resume N, ..., resume 2], frames

      # load top level-frame; final stack state should be:
      # first resume function (+ NULL),
      # [
      #     [resume N, ..., resume 2],
      #     [
      #         frame N stack + locals,
      #         ...,
      #         frame 2 stack + locals,
      #     ], *(frame 1 stack + locals)
      # ]
      cg.extend_output(
          [
              create_dup_top(),
              create_dup_top(),
              # frames, frames, frames
              cg.create_load_const(-1),
              cg.create_binary_subscr(),
              # frames, frames, frames[-1]
              *create_swap(2),
              # frames, frames[-1], frames
              cg.create_load_const(-1),
              create_instruction("DELETE_SUBSCR"),
          ]
      )

      # TOS: resumes, frames (popped), frame 1 stack + locals
      cg.extend_output(
          [
              *create_rot_n(3),
              create_instruction("BUILD_LIST", arg=2),
              *create_swap(2),
              # [resumes, frames (popped)], frame 1 stack + locals
              create_instruction("LIST_EXTEND", arg=1),
          ]
      )

      # TOS: [resumes, frames, *(frame 1 stack + locals)]
      cg.extend_output(
          [
              create_instruction("CALL_FUNCTION_EX", arg=0),
              create_instruction("RETURN_VALUE"),
          ]
      )
      return cg.get_instructions()
  def should_compile_partial_graph(self) -> bool:
      """Return whether a graph break here may compile the graph so far and resume.

      It is refused when:
        * (3.11+) the current instruction has an exception-table entry whose
          target is not the target of the innermost block-stack entry, or there
          is no block-stack entry at all;
        * some block on the block stack cannot be restored;
        * one_graph / error_on_graph_break is set, the resume prologue is being
          traced, or generic context managers are active.
      """
      if sys.version_info >= (3, 11):
          entry = self.current_instruction.exn_tab_entry
          if entry:
              if not self.block_stack:
                  return False
              if entry.target is not self.block_stack[-1].target:
                  return False
      if not all(block.can_restore() for block in self.block_stack):
          return False
      return not (
          self.one_graph
          or self.error_on_graph_break
          or self.is_tracing_resume_prologue
          or self.active_generic_context_managers
      )
  @break_graph_if_unsupported(push=0)
  def STORE_SUBSCR(self, inst: Instruction) -> None:
      """Trace ``container[key] = value`` via ``container.__setitem__``."""
      value, container, key = self.popn(3)
      container.call_method(self, "__setitem__", [key, value], {})
  def DELETE_SUBSCR(self, inst: Instruction) -> None:
      """Trace ``del container[key]`` via ``container.__delitem__``."""
      container, key = self.popn(2)
      container.call_method(self, "__delitem__", [key], {})
  def BUILD_TUPLE(self, inst: Instruction) -> None:
      """Pack the top ``inst.argval`` stack values into a tuple."""
      self.push(TupleVariable(self.popn(inst.argval)))
  def BUILD_SLICE(self, inst: Instruction) -> None:
      """Build a slice from the top ``inst.argval`` (2 or 3) stack values."""
      self.push(SliceVariable(self.popn(inst.argval)))
  def BUILD_LIST(self, inst: Instruction) -> None:
      """Pack the top ``inst.argval`` stack values into a new (tracked-as-new) list."""
      self.push(
          ListVariable(self.popn(inst.argval), mutation_type=ValueMutationNew())
      )
  def BUILD_SET(self, inst: Instruction) -> None:
      """Pack the top ``inst.argval`` stack values into a new set.

      A testing-only config flag forces a graph break here instead.
      """
      if config.inject_BUILD_SET_unimplemented_TESTING_ONLY:
          unimplemented_v2(
              gb_type="missing BUILD_SET handler",
              context="",
              explanation="Missing BUILD_SET bytecode handler (for testing purposes).",
              hints=[],
          )
      self.push(
          SetVariable(self.popn(inst.argval), mutation_type=ValueMutationNew())
      )
  def BUILD_LIST_UNPACK(self, inst: Instruction, cls: type = ListVariable) -> None:
      """Concatenate the unpacked contents of the top ``inst.argval`` sequences.

      Implements ``[*x, *y, ...]``. ``cls`` selects the result container
      (``TupleVariable`` for the tuple form). A sequence that cannot be unpacked
      is reported through ``unimplemented_v2``.
      """
      flattened: list[Any] = []
      for seq in self.popn(inst.argval):
          try:
              flattened.extend(seq.force_unpack_var_sequence(self))
          except NotImplementedError:
              unimplemented_v2(
                  gb_type="Failed to unpack object for BUILD_LIST_UNPACK",
                  context=str(seq),
                  explanation=f"{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK "
                  "bytecode (`[*x, *y, ...]`).",
                  hints=[*graph_break_hints.USER_ERROR],
              )
      self.push(cls(flattened, mutation_type=ValueMutationNew()))
  def BUILD_TUPLE_UNPACK(self, inst: Instruction) -> None:
      """Tuple flavour of BUILD_LIST_UNPACK: ``(*x, *y, ...)``."""
      self.BUILD_LIST_UNPACK(inst, TupleVariable)

  # The call-site variant (used to build *args) shares the same behaviour.
  BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK
  def BUILD_MAP(self, inst: Instruction) -> None:
      """Build a dict from ``inst.argval`` key/value pairs laid out as k0, v0, k1, v1, ..."""
      flat = self.popn(inst.argval * 2)
      mapping = {flat[pos]: flat[pos + 1] for pos in range(0, len(flat), 2)}
      self.push(ConstDictVariable(mapping, mutation_type=ValueMutationNew()))
  def BUILD_MAP_UNPACK(self, inst: Instruction) -> None:
      """Merge the top ``inst.argval`` mappings (``{**a, **b}``) into a new dict.

      Every operand is first converted with a traced ``dict(...)`` call; later
      operands override earlier keys.
      """
      sources = self.popn(inst.argval)
      # ensure everything is a dict (all conversions happen before merging)
      as_dicts = [
          BuiltinVariable(dict).call_function(self, [src], {})  # type: ignore[arg-type]
          for src in sources
      ]
      merged: dict[Any, Any] = {}
      for converted in as_dicts:
          assert isinstance(converted, ConstDictVariable)
          merged.update(converted.items)
      self.push(ConstDictVariable(merged, mutation_type=ValueMutationNew()))

  # The call-site variant (used to build **kwargs) shares the same behaviour.
  BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK
  def BUILD_CONST_KEY_MAP(self, inst: Instruction) -> None:
      """Build a dict whose keys come from a constant tuple on TOS.

      Below the key tuple sit ``inst.argval`` values, paired with the keys in order.
      """
      key_tuple = self.pop()
      vals = self.popn(inst.argval)
      assert isinstance(key_tuple, TupleVariable)
      assert key_tuple.is_python_constant()
      key_vars = key_tuple.force_unpack_var_sequence(self)
      assert len(key_vars) == len(vals)
      mapping = dict(zip(key_vars, vals))
      self.push(ConstDictVariable(mapping, mutation_type=ValueMutationNew()))
  def MAP_ADD(self, inst: Instruction) -> None:
      """Dict-comprehension step: pop ``key, value`` and store them into the dict being built.

      After the pops, the target dict is at ``self.stack[-inst.arg]``; it is
      realized before use.
      """
      k, v = self.popn(2)
      assert inst.argval > 0
      assert inst.arg is not None
      obj = self.stack[-inst.arg].realize()
      assert isinstance(obj, ConstDictVariable)
      # Pass the arguments as a list, matching call_method's expected argument
      # type and every sibling handler (STORE_SUBSCR, SET_ADD, ...).
      obj.call_method(self, "__setitem__", [k, v], {})
  def SET_ADD(self, inst: Instruction) -> None:
      """Set-comprehension step: add TOS to the mutable set at ``stack[-inst.arg]`` (after the pop)."""
      element = self.pop()
      assert inst.argval > 0
      assert inst.arg is not None
      target_set = self.stack[-inst.arg]
      assert isinstance(target_set, SetVariable)
      assert target_set.is_mutable()
      target_set.call_method(self, "add", [element], {})
  def SET_UPDATE(self, inst: Instruction) -> None:
      """Pop an iterable and ``update`` the mutable set at ``stack[-inst.arg]`` with it (``{*a, *b}``)."""
      iterable = self.pop()
      assert inst.argval > 0
      assert inst.arg is not None
      target_set = self.stack[-inst.arg]
      assert isinstance(target_set, SetVariable)
      assert target_set.is_mutable()
      target_set.call_method(self, "update", [iterable], {})
  def LIST_APPEND(self, inst: Instruction) -> None:
      """List-comprehension step: append TOS to the list at ``stack[-inst.arg]`` (after the pop).

      The mutation is recorded in side effects before the item is appended.
      """
      element = self.pop()
      assert inst.argval > 0
      assert inst.arg is not None
      target_list = self.stack[-inst.arg].realize()
      assert isinstance(target_list, ListVariable)
      assert target_list.is_mutable()
      self.output.side_effects.mutation(target_list)
      target_list.items.append(element)
  def MAKE_FUNCTION(self, inst: Instruction) -> None:
      """Create a ``NestedUserFunctionVariable`` from the code object on the stack.

      Before 3.11 the qualified name is on TOS above the code object. From 3.11
      it is read from ``code.co_qualname`` (https://github.com/python/cpython/pull/93189/).
      Before 3.13 the oparg flags say which extras sit below:
      0x08 closure, 0x04 annotations, 0x02 kwdefaults, 0x01 defaults
      (popped in that order). From 3.13 these are set by SET_FUNCTION_ATTRIBUTE.
      """
      flags = inst.arg
      if sys.version_info >= (3, 11):
          code = self.pop()
          assert hasattr(code.value, "co_qualname")  # type: ignore[attr-defined]
          fn_name = ConstantVariable.create(value=code.value.co_qualname)  # type: ignore[attr-defined]
      else:
          fn_name = self.pop()
          code = self.pop()

      extras: dict[int, Any] = {}
      if sys.version_info < (3, 13) and flags is not None:
          for bit in (0x08, 0x04, 0x02, 0x01):
              if flags & bit:
                  extras[bit] = self.pop()

      self.push(
          NestedUserFunctionVariable(
              fn_name,
              code,
              self.f_globals,
              extras.get(0x01),  # defaults
              extras.get(0x02),  # kwdefaults
              extras.get(0x04),  # annotations
              extras.get(0x08),  # closure
          )
      )
  def UNPACK_SEQUENCE(self, inst: Instruction) -> None:
      """Unpack the top of stack into exactly ``inst.argval`` values.

      Tensors are unpacked via ``unpack_var_sequence`` with explicit indices.
      Attribute results of a tensor (e.g. ``x, y = a.shape``) are indexed
      through the FX proxy. Anything else must support
      ``force_unpack_var_sequence``. Values are pushed in reverse, so the first
      element ends on top of the stack. Graph-breaks when the object cannot be
      unpacked or yields the wrong number of values.
      """
      seq = self.pop()
      expected = inst.argval
      if isinstance(seq, TensorVariable):
          val = seq.unpack_var_sequence(self, idxes=range(expected))  # type: ignore[arg-type]
      elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
          # x, y = a.shape
          attr_proxy = getattr(seq.obj.as_proxy(), seq.name)
          val = [wrap_fx_proxy(self, attr_proxy[idx]) for idx in range(expected)]
      elif seq.has_force_unpack_var_sequence(self):
          val = seq.force_unpack_var_sequence(self)
      else:
          unimplemented_v2(
              gb_type="Failed to unpack object for UNPACK_SEQUENCE",
              context=str(seq),
              explanation=f"{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode "
              "(i.e. `a, b, c = d`).",
              hints=[*graph_break_hints.USER_ERROR],
          )
      if len(val) != expected:
          unimplemented_v2(
              gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
              context=f"expected length: {inst.argval}, actual: {len(val)}",
              explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
              "(i.e. `a, b, c = d`) with unexpected length.",
              hints=[*graph_break_hints.DYNAMO_BUG],
          )
      for element in reversed(val):
          self.push(element)
  def UNPACK_EX(self, inst: Instruction) -> None:
      """Unpack a sequence into targets that include one starred target.

      ``inst.argval`` packs two counts: the low byte is the number of targets
      before the starred one, and the high byte is the number after it. Values
      are pushed so that the first prefix value ends on top of the stack,
      followed by the starred collection and then the suffix values.
      Graph-breaks if the object cannot be force-unpacked.
      """
      assert 0 <= inst.argval <= 0xFFFF
      prefix = inst.argval & 0xFF  # low byte
      suffix = inst.argval >> 8  # high byte
      seq = self.pop()
      if seq.has_force_unpack_var_sequence(self):
          vals = list(seq.force_unpack_var_sequence(self))
          assert len(vals) >= prefix + suffix
          vals_prefix = vals[:prefix]
          vals_list = vals[prefix : len(vals) - suffix]
          vals_suffix = vals[len(vals) - suffix :]
          for item in reversed(vals_suffix):
              self.push(item)
          # Python always binds a starred assignment target to a *list*
          # (`a, *b = (1, 2, 3)` gives `b == [2, 3]`). Build a fresh, mutable
          # ListVariable rather than a TupleVariable so that `type(b)`,
          # `b.append(...)`, etc. behave as in eager Python.
          self.push(ListVariable(vals_list, mutation_type=ValueMutationNew()))
          for item in reversed(vals_prefix):
              self.push(item)
      else:
          unimplemented_v2(
              gb_type="Failed to unpack object for UNPACK_EX",
              context=str(seq),
              explanation=f"{seq} cannot be unpacked into a list for the UNPACK_EX bytecode.",
              hints=[*graph_break_hints.USER_ERROR],
          )
  @break_graph_if_unsupported(push=0)
  def graph_break_on_leaf_function(self, inst: Instruction) -> None:
      """Force a graph break when this tracer is a leaf tracer (testing aid).

      Does nothing when ``self.is_leaf_tracer`` is false.
      """
      if not self.is_leaf_tracer:
          return
      unimplemented_v2(
          gb_type="Forced graph break on leaf function",
          context="",
          explanation="Forced graph break for nested graph break testing purposes",
          hints=[
              "Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False",
          ],
      )
  def NOP(self, inst: Instruction) -> None:
      """No-op, except for the Dynamo-internal ``GRAPH_BREAK_IF_LEAF`` marker."""
      # Dynamo-specific testing behavior
      is_leaf_marker = inst.argval == "GRAPH_BREAK_IF_LEAF"
      if is_leaf_marker:
          self.graph_break_on_leaf_function(inst)
  def POP_TOP(self, inst: Instruction) -> None:
      """Discard the top of stack."""
      _discarded = self.pop()
  def ROT_TWO(self, inst: Instruction) -> None:
      """Swap the two topmost stack entries."""
      top = self.pop()
      second = self.pop()
      for item in (top, second):
          self.push(item)
  def ROT_THREE(self, inst: Instruction) -> None:
      """Move the top of stack down to third position; the next two move up one."""
      top, second, third = self.pop(), self.pop(), self.pop()
      for item in (top, third, second):
          self.push(item)
  def ROT_FOUR(self, inst: Instruction) -> None:
      """Move the top of stack down to fourth position; the next three move up one."""
      top, second, third, fourth = self.pop(), self.pop(), self.pop(), self.pop()
      for item in (top, fourth, third, second):
          self.push(item)
  def DUP_TOP(self, inst: Instruction) -> None:
      """Push a second reference to the top of stack."""
      value = self.pop()
      for _ in range(2):
          self.push(value)
  def DUP_TOP_TWO(self, inst: Instruction) -> None:
      """Duplicate the two topmost stack entries, preserving their order."""
      top = self.pop()
      below = self.pop()
      for item in (below, top, below, top):
          self.push(item)
  def _convert_value(self, value: VariableTracker, flag: int) -> VariableTracker:
      """Apply an f-string conversion: 1 -> ``str``, 2 -> ``repr``, 3 -> ``ascii``.

      Any other flag returns ``value`` unchanged.
      """
      converter = {1: str, 2: repr, 3: ascii}.get(flag)
      if converter is None:
          return value
      return BuiltinVariable(converter).call_function(self, [value], {})  # type: ignore[arg-type]
  def _format_value(self, fmt_spec: VariableTracker, flags: int) -> None:
      """Format the top of stack using ``fmt_spec`` and the conversion in ``flags & 0x03``.

      ``SymNodeVariable`` values are wrapped in a lazy format-string tracker
      and pushed directly. Everything else is converted and then formatted by
      calling ``str.format`` with ``"{:<spec>}"``.
      """
      value = self.pop()
      if isinstance(value, SymNodeVariable):
          from torch._dynamo.variables.lazy import (
              LazySymNodeFormatString,
              LazyVariableTracker,
          )

          lazy_value = LazyVariableTracker.create(
              LazySymNodeFormatString(value, fmt_spec), source=value.source
          )
          self.push(lazy_value)
          return

      converted = self._convert_value(value, flags & 0x03)
      pattern = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}")
      self.call_function(BuiltinVariable(str.format), [pattern, converted], {})
  def FORMAT_VALUE(self, inst: Instruction) -> None:
      """Pre-3.13 f-string formatting.

      Bit 0x04 of ``inst.arg`` means a format spec sits on top of the value;
      otherwise an empty spec is used.
      """
      flags = inst.arg
      assert flags is not None
      has_spec = (flags & 0x04) == 0x04
      fmt_spec = self.pop() if has_spec else ConstantVariable.create("")
      return self._format_value(fmt_spec, flags)
  def BUILD_STRING(self, inst: Instruction) -> None:
      """Join the top ``inst.arg`` stack entries into a single ``StringFormatVariable``.

      A constant part contributes a ``"{}"`` placeholder and itself as a
      positional argument. A ``StringFormatVariable`` part contributes its
      format string, ``sym_args`` and ``sym_kwargs``. Graph-breaks on
      conflicting keyword names or on any other part type.
      """
      pieces: list[str] = []
      positional: list[VariableTracker] = []
      keyword: dict[str, VariableTracker] = {}
      assert inst.arg is not None
      for part in self.popn(inst.arg):
          if isinstance(part, ConstantVariable):
              pieces.append("{}")
              positional.append(part)
          elif isinstance(part, variables.StringFormatVariable):
              pieces.append(part.format_string)
              positional.extend(part.sym_args)
              if set(keyword) & set(part.sym_kwargs):
                  unimplemented_v2(
                      gb_type="BUILD_STRING key conflict",
                      context=f"format_string_parts: {pieces}, kwargs: {keyword}, part.sym_kwargs: {part.sym_kwargs}",
                      explanation="Failed to build format string due to key conflict",
                      hints=[*graph_break_hints.USER_ERROR],
                  )
              keyword.update(part.sym_kwargs)
          else:
              unimplemented_v2(
                  gb_type="BUILD_STRING type error",
                  context=str(part),
                  explanation="Format string part type is not correct - expected constant or format string.",
                  hints=[*graph_break_hints.USER_ERROR],
              )
      joined = "".join(pieces)
      self.push(variables.StringFormatVariable.create(joined, positional, keyword))
  def IS_OP(self, inst: Instruction) -> None:
      """Lower ``IS_OP`` (argval 0 = ``is``, 1 = ``is not``) onto COMPARE_OP."""
      assert inst.argval == 0 or inst.argval == 1
      comparison = "is" if inst.argval == 0 else "is not"
      self.COMPARE_OP(create_instruction("COMPARE_OP", argval=comparison))
  def CONTAINS_OP(self, inst: Instruction) -> None:
      """Evaluate ``left in right`` (argval 0) or ``left not in right`` (argval 1).

      Tries ``right.__contains__(left)`` first. If that raises an observed
      TypeError or is unsupported, falls back to inlining
      ``impl_CONTAINS_OP_fallback``, which iterates. Negation is applied
      afterwards via UNARY_NOT.
      """
      assert inst.argval == 0 or inst.argval == 1
      left, right = self.popn(2)
      negate = inst.argval
      try:
          self.push(right.call_method(self, "__contains__", [left], {}))
      except (
          # right.__contains__ can raise TypeError
          exc.ObservedTypeError,
          # Ideally we should only capture TypeError here but some VTs don't
          # implement hasattr(vt, "__contains__") entirely
          Unsupported,
      ) as err:  # object doesn't support __contains__
          # Use __iter__ as fallback
          if isinstance(err, Unsupported):
              err.remove_from_stats()
          fallback_fn = VariableTracker.build(self, impl_CONTAINS_OP_fallback)
          self.push(self.inline_user_function_return(fallback_fn, [left, right], {}))
      if negate == 1:
          self.UNARY_NOT(inst)
  def LIST_EXTEND(self, inst: Instruction) -> None:
      """Pop an iterable and ``extend`` the mutable list ``inst.arg`` slots down the stack."""
      iterable = self.pop()
      assert inst.argval > 0
      assert inst.arg is not None
      target_list = self.stack[-inst.arg]
      assert isinstance(target_list, ListVariable)
      assert target_list.is_mutable()
      target_list.call_method(self, "extend", [iterable], {})
  def LIST_TO_TUPLE(self, inst: Instruction) -> None:
      """Replace the top of stack with ``tuple(top)``."""
      source = self.pop()
      as_tuple = BuiltinVariable(tuple).call_function(self, [source], {})  # type: ignore[arg-type]
      self.push(as_tuple)
  def STOPITERATION_ERROR(self, inst: Instruction) -> None:
      """Replace a StopIteration on top of the stack with a RuntimeError (PEP 479).

      The new RuntimeError gets the original exception as both ``__context__``
      and ``__cause__``. Other exception types are left untouched.
      """
      # wrap the generator body in a try: ... except StopIteration: ... which
      # converts the StopIteration into a RuntimeError
      # https://peps.python.org/pep-0479/
      # https://github.com/python/cpython/pull/99006
      # https://github.com/python/cpython/commit/28187141cc34063ef857976ddbca87ba09a882c2
      original = self.stack[-1]
      assert self._isinstance_exception(original)
      if original.exc_type is not StopIteration:  # type: ignore[union-attr]
          return
      replacement = variables.BuiltinVariable(RuntimeError).call_function(
          self,  # type: ignore[arg-type]
          [ConstantVariable("generator raised StopIteration")],
          {},
      )
      for link_name in ("__context__", "__cause__"):
          replacement.call_setattr(self, ConstantVariable(link_name), original)  # type: ignore[attr-defined]
      self.stack[-1] = replacement
  def DICT_MERGE(self, inst: Instruction) -> None:
      """Pop a mapping and merge it into the mutable dict ``inst.arg`` slots down the stack.

      The merge is delegated to the dict variable's ``update`` method.
      ``DICT_UPDATE`` is bound to this same handler.
      """
      source = self.pop()
      assert inst.argval > 0
      assert inst.arg is not None
      target_dict = self.stack[-inst.arg].realize()
      assert isinstance(target_dict, ConstDictVariable)
      assert target_dict.is_mutable()
      target_dict.call_method(self, "update", [source], {})

  DICT_UPDATE = DICT_MERGE
  def GEN_START(self, inst: Instruction) -> None:
      """Discard the top of stack (GEN_START opcode)."""
      _unused = self.pop()
  def GET_LEN(self, inst: Instruction) -> None:
      """Push ``len(TOS)`` without popping TOS.

      The length is computed directly for Python constants; otherwise it is
      obtained through ``__len__``.
      """
      subject = self.stack[-1]
      if not subject.is_python_constant():
          self.push(subject.call_method(self, "__len__", [], {}))
          return
      self.push(ConstantVariable.create(len(subject.as_python_constant())))
  def MATCH_MAPPING(self, inst: Instruction) -> None:
      """Push whether the (dict-variable) subject on TOS is a mapping; TOS is kept."""
      subject = self.stack[-1]
      assert isinstance(subject, ConstDictVariable)
      is_mapping = isinstance(subject.items, collections.abc.Mapping)
      self.push(ConstantVariable.create(is_mapping))
  def MATCH_SEQUENCE(self, inst: Instruction) -> None:
      """Push whether the constant subject on TOS is a non-string sequence; TOS is kept."""
      subject = self.stack[-1]
      assert subject.is_python_constant()
      value = subject.as_python_constant()
      is_sequence = isinstance(value, collections.abc.Sequence) and not isinstance(
          value, (str, bytes, bytearray)
      )
      self.push(ConstantVariable.create(is_sequence))
  def MATCH_KEYS(self, inst: Instruction) -> None:
      """Look up the keys tuple (TOS) in the dict subject (TOS1) for a mapping pattern.

      On success pushes a tuple of the looked-up values; otherwise pushes
      ``None``. Before 3.11 an extra True/False success flag is pushed too.
      """
      keys = self.stack[-1]
      subject = self.stack[-2]
      assert isinstance(subject, ConstDictVariable)
      pre_311 = sys.version_info < (3, 11)
      if all(k in subject for k in keys):  # type: ignore[attr-defined]
          self.push(TupleVariable([subject.getitem_const(self, k) for k in keys]))  # type: ignore[attr-defined,arg-type]
          if pre_311:
              self.push(ConstantVariable.create(True))
      else:
          self.push(ConstantVariable.create(None))
          if pre_311:
              self.push(ConstantVariable.create(False))
  def LOAD_ASSERTION_ERROR(self, inst: Instruction) -> None:
      """Push the builtin ``AssertionError``."""
      assertion_error = self.load_builtin_from_argval("AssertionError")
      self.push(assertion_error)
  def LOAD_BUILD_CLASS(self, inst: Instruction) -> None:
      """Always graph-break: class definitions inside the compiled region are unsupported."""
      unimplemented_v2(
          gb_type="LOAD_BUILD_CLASS bytecode not supported",
          context="",
          explanation="Dynamo does not support tracing classes that are defined in the compiled region.",
          hints=[
              "Move the class definition out of the compiled region.",
              *graph_break_hints.SUPPORTABLE,
          ],
      )
  # Arithmetic opcode handlers. Each one is produced by ``stack_op`` from the
  # matching function in the ``operator`` module.
  # Unary operators:
  UNARY_POSITIVE = stack_op(operator.pos)
  UNARY_NEGATIVE = stack_op(operator.neg)
  UNARY_NOT = stack_op(operator.not_)
  UNARY_INVERT = stack_op(operator.invert)
  # Binary operators. BINARY_REMAINDER repeats BINARY_MODULO (both use
  # operator.mod). NOTE(review): presumably the *_REMAINDER names exist so that
  # name-based lookups (e.g. building ``_binary_op_lookup`` for BINARY_OP)
  # resolve; confirm before removing.
  BINARY_POWER = stack_op(operator.pow)
  BINARY_MULTIPLY = stack_op(operator.mul)
  BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
  BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
  BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
  BINARY_MODULO = stack_op(operator.mod)
  BINARY_REMAINDER = stack_op(operator.mod)
  BINARY_ADD = stack_op(operator.add)
  BINARY_SUBTRACT = stack_op(operator.sub)
  # Unlike the other handlers here, subscription is additionally wrapped in
  # ``break_graph_if_unsupported(push=1)``.
  BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
  BINARY_LSHIFT = stack_op(operator.lshift)
  BINARY_RSHIFT = stack_op(operator.rshift)
  BINARY_AND = stack_op(operator.and_)
  BINARY_OR = stack_op(operator.or_)
  BINARY_XOR = stack_op(operator.xor)
  # In-place operators (``x += y`` and friends), built from ``operator.i*``.
  # INPLACE_REMAINDER repeats INPLACE_MODULO for the same reason as above.
  INPLACE_POWER = stack_op(operator.ipow)
  INPLACE_MULTIPLY = stack_op(operator.imul)
  INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
  INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
  INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
  INPLACE_MODULO = stack_op(operator.imod)
  INPLACE_REMAINDER = stack_op(operator.imod)
  INPLACE_ADD = stack_op(operator.iadd)
  INPLACE_SUBTRACT = stack_op(operator.isub)
  INPLACE_LSHIFT = stack_op(operator.ilshift)
  INPLACE_RSHIFT = stack_op(operator.irshift)
  INPLACE_AND = stack_op(operator.iand)
  INPLACE_XOR = stack_op(operator.ixor)
  INPLACE_OR = stack_op(operator.ior)
  # 3.11 opcodes
  def RESUME(self, inst: Instruction) -> None:
      """Close the prefix-instruction window at the function-start RESUME.

      ``RESUME`` with ``arg == 0`` is recorded as the last prefix instruction
      and stops further prefix collection. Any other RESUME must occur after
      that point.
      """
      if inst.arg != 0:
          assert not self.accept_prefix_inst
          return
      self.append_prefix_inst(inst)
      self.accept_prefix_inst = False
  if sys.version_info >= (3, 11):

      def BINARY_OP(self, inst: Instruction) -> None:
          """Dispatch the 3.11+ generic BINARY_OP to the handler indexed by ``inst.arg``."""
          op_index = inst.arg
          assert op_index is not None
          handler = _binary_op_lookup[op_index]
          return handler(self, inst)
  def PRECALL(self, inst: Instruction) -> None:
      """No-op: PRECALL needs no symbolic work; the following CALL does everything."""
  def KW_NAMES(self, inst: Instruction) -> None:
      """Record the keyword-argument names (a tuple of str constants) for the next call.

      The names are stored in ``self.kw_names``, which must currently be unset.
      """
      names = self.code_options["co_consts"][inst.arg]
      assert isinstance(names, tuple)
      assert all(isinstance(name, str) for name in names)
      assert self.kw_names is None
      self.kw_names = ConstantVariable.create(value=names)  # type: ignore[assignment]
  def PUSH_NULL(self, inst: Instruction) -> None:
      """Push a ``NullVariable`` placeholder onto the stack."""
      placeholder = NullVariable()
      self.push(placeholder)
  def _call(self, inst: Instruction, call_kw: bool = False) -> None:
      """Shared implementation of the 3.11+ CALL and 3.13+ CALL_KW opcodes.

      Pops ``inst.arg + 2`` entries: the callable and self-or-NULL pair, then
      the call arguments. It splits off trailing keyword arguments and invokes
      ``self.call_function``. For CALL_KW the keyword names come from a
      constant tuple on top of the stack; otherwise they come from
      ``self.kw_names``. ``self.kw_names`` is always reset afterwards, even if
      the call raises.
      """
      # see https://docs.python.org/3.11/library/dis.html#opcode-CALL
      # for convention
      if call_kw:
          # TOS is kw_names for CALL_KW instruction
          assert sys.version_info >= (3, 13)
          names_var = self.pop()
          assert isinstance(names_var, TupleVariable) and names_var.is_python_constant()
          kw_names = names_var.as_python_constant()
      else:
          kw_names = self.kw_names.value if self.kw_names else ()

      assert inst.arg is not None
      contents = self.popn(inst.arg + 2)
      first, second = contents[0], contents[1]
      if sys.version_info >= (3, 13):
          # NULL and callable swapped
          fn = first
          args = [] if isinstance(second, NullVariable) else [second]
      elif isinstance(first, NullVariable):
          fn = second
          args = []
      else:
          fn = first
          args = [second]

      num_kw = len(kw_names)
      if num_kw:
          args = args + contents[2:-num_kw]
          kwargs = dict(zip(kw_names, contents[-num_kw:]))
          assert len(kwargs) == num_kw
      else:
          args = args + contents[2:]
          kwargs = {}

      try:
          # if call_function fails, need to set kw_names to None, otherwise
          # a subsequent call may have self.kw_names set to an old value
          self.call_function(fn, args, kwargs)
      finally:
          self.kw_names = None
  @break_graph_if_unsupported(push=1)
  def CALL(self, inst: Instruction) -> None:
      """3.11+ CALL: delegate to ``_call`` with keyword names from ``self.kw_names``."""
      self._call(inst, call_kw=False)
  def COPY(self, inst: Instruction) -> None:
      """Push a copy of the stack entry ``inst.arg`` positions from the top (1 = TOS)."""
      depth = inst.arg
      assert depth is not None
      self.push(self.stack[-depth])
  def SWAP(self, inst: Instruction) -> None:
      """Exchange TOS with the stack entry ``inst.arg`` positions from the top."""
      depth = inst.arg
      assert depth is not None
      stack = self.stack
      top = stack[-1]
      stack[-1] = stack[-depth]
      stack[-depth] = top
  # Jump opcodes introduced in CPython 3.11 (see the ``dis`` docs).
  # The unconditional backward jumps reuse the generic ``jump`` handler.
  JUMP_BACKWARD = jump
  JUMP_BACKWARD_NO_INTERRUPT = jump
  # Forward and backward conditional jumps share one construction each; they
  # differ only in the predicate: ``operator.truth`` for *_IF_TRUE and
  # ``operator.not_`` for *_IF_FALSE.
  POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False)
  POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False)
  POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False)
  POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False)
  def CACHE(self, inst: Instruction) -> None:
      """No-op: CACHE entries carry no semantics for symbolic execution."""
  def BEFORE_WITH(self, inst: Instruction) -> None:
      """3.11+ BEFORE_WITH: shares its implementation with SETUP_WITH."""
      return self.setup_or_before_with(inst)
  def setup_or_before_with(self, inst: Instruction) -> None:
      """Enter a ``with`` block (shared by SETUP_WITH and BEFORE_WITH).

      Pops the context manager. It graph-breaks unless the manager is a
      ``ContextWrappingVariable`` or ``GenericContextWrappingVariable``. It
      then pushes a ``WithExitFunctionVariable``, possibly records a
      ``BlockStackEntry`` for the block, and finally pushes the result of
      ``ctx.enter(self)``.
      """
      ctx = self.pop()
      if not isinstance(
          ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
      ):
          unimplemented_v2(
              gb_type="Unsupported context manager",
              context=f"Attempted SETUP_WITH/BEFORE_WITH on {ctx}",
              explanation=f"Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.",
              hints=[
                  "Avoid using the unsupported context manager.",
                  "If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then "
                  "it may be the case that it was created outside the compiled region, which Dynamo does not support. "
                  "Supported context managers can cross graph break boundaries only if they are local non-closure "
                  "variables, or are intermediate values.",
                  "File an issue to PyTorch. Simple context managers can potentially be supported, "
                  "but note that context managers can't be supported in general",
              ],
          )

      # Generic context managers that cannot survive a graph break are tracked
      # separately in active_generic_context_managers.
      if (
          isinstance(ctx, GenericContextWrappingVariable)
          and not ctx.supports_graph_breaks()
      ):
          self.active_generic_context_managers.append(ctx)

      # Need this redundant check for mypy
      assert isinstance(
          ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
      )

      exit = WithExitFunctionVariable(
          ctx,
          inst.target,
      )

      # Decide which target (if any) the block-stack entry should point at.
      if sys.version_info >= (3, 11):
          # See create_call_resume_at for block stack details.
          # Only push a block if the current instruction's block is a
          # with block that is not nested in a try block - that is, the current
          # instruction's block target is the same as the top block's target.
          if inst.exn_tab_entry and (
              not self.block_stack
              or inst.exn_tab_entry.target is not self.block_stack[-1].target
          ):
              target = None
          else:
              assert self.next_instruction.exn_tab_entry is not None
              target = self.next_instruction.exn_tab_entry.target
      else:
          target = inst.target

      self.push(exit)

      if target:
          # The context manager itself is stored on the entry only for the
          # root InstructionTranslator, or when config.nested_graph_breaks is set.
          if isinstance(self, InstructionTranslator) or config.nested_graph_breaks:
              self.block_stack.append(
                  BlockStackEntry(inst, target, len(self.stack), ctx)
              )
          else:
              self.block_stack.append(BlockStackEntry(inst, target, len(self.stack)))

      self.push(ctx.enter(self))
  def append_prefix_inst(self, inst: Instruction) -> None:
      """Record ``inst`` as a prefix instruction; only allowed while the prefix window is open."""
      assert self.accept_prefix_inst
      prefix = self.prefix_insts
      prefix.append(inst)
  def MAKE_CELL(self, inst: Instruction) -> None:
      """Handle MAKE_CELL as a prefix instruction, or create a new cell mid-function.

      When the prefix window is open (or on Python < 3.12), the instruction is
      recorded via ``append_prefix_inst``. Otherwise the local, which must
      currently be a ``NullVariable``, is bound to a freshly tracked cell.
      """
      if sys.version_info < (3, 12) or self.accept_prefix_inst:
          self.append_prefix_inst(inst)
          return
      # In 3.12+, MAKE_CELL is no longer necessarily a prefix instruction.
      # It can be generated by inlined comprehensions.
      assert isinstance(self.symbolic_locals[inst.argval], NullVariable)
      self.symbolic_locals[inst.argval] = self.output.side_effects.track_cell_new()
  def COPY_FREE_VARS(self, inst: Instruction) -> None:
      """Record COPY_FREE_VARS as a prefix instruction."""
      return self.append_prefix_inst(inst)
  def RETURN_GENERATOR(self, inst: Instruction) -> None:
      """Record RETURN_GENERATOR as a prefix instruction."""
      return self.append_prefix_inst(inst)
  # 3.12 opcodes
  # BINARY/STORE_SLICE opcodes are broken down into
  # BUILD_SLICE 2 and BINARY/STORE_SUBSCR

  def END_FOR(self, inst: Instruction) -> None:
      """Clean up after a for-loop: pop two entries on 3.12, one entry on 3.13+."""
      if sys.version_info < (3, 13):
          self.popn(2)
          return
      self.pop()
  def LOAD_FAST_CHECK(self, inst: Instruction) -> None:
      """LOAD_FAST that graph-breaks if the local is an uninitialized ``NullVariable``."""
      current = self.symbolic_locals.get(inst.argval, None)
      if istype(current, NullVariable):
          unimplemented_v2(
              gb_type="LOAD_FAST_CHECK on uninitialized variable",
              context=inst.argval,
              explanation=f"Attempted to load uninitialized local variable {inst.argval}",
              hints=[*graph_break_hints.USER_ERROR],
          )
      self.LOAD_FAST(inst)
  def LOAD_FAST_AND_CLEAR(self, inst: Instruction) -> None:
      """Push a local (or a ``NullVariable`` if it is unbound), then reset it to ``NullVariable``."""
      name = inst.argval
      if name in self.symbolic_locals:
          self.LOAD_FAST(inst)
      else:
          self.push(NullVariable())
      self.symbolic_locals[name] = NullVariable()
  def LOAD_SUPER_ATTR(self, inst: Instruction) -> None:
      """3.12+ LOAD_SUPER_ATTR: call ``super`` and then load an attribute or method from it.

      The super call is emitted as CALL_FUNCTION with ``argval=2``.
      NOTE(review): presumably this is ``super(cls, self)`` from the stack;
      confirm. Bit 0 of ``inst.arg`` selects LOAD_METHOD; otherwise a plain
      attribute load is performed.
      """
      super_call = dataclasses.replace(inst, argval=2)
      self.CALL_FUNCTION(super_call)
      assert inst.arg is not None
      wants_method = bool(inst.arg & 1)
      if wants_method:
          self.LOAD_METHOD(inst)
      else:
          self._load_attr(inst)
  def CALL_INTRINSIC_1(self, inst: Instruction) -> None:
      """Dispatch the single-argument intrinsics Dynamo supports.

      Supported operands: 3 (INTRINSIC_STOPITERATION_ERROR), 5
      (INTRINSIC_UNARY_POSITIVE) and 6 (INTRINSIC_LIST_TO_TUPLE). Any other
      operand graph-breaks.
      """
      operand = inst.argval
      handler = {
          3: self.STOPITERATION_ERROR,  # INTRINSIC_STOPITERATION_ERROR
          5: self.UNARY_POSITIVE,  # INTRINSIC_UNARY_POSITIVE
      }.get(operand)
      if handler is not None:
          handler(inst)
      elif operand == 6:
          # INTRINSIC_LIST_TO_TUPLE
          self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
      else:
          unimplemented_v2(
              gb_type="Missing CALL_INTRINSIC_1 handler",
              context=f"CALL_INTRINSIC_1 operand: {inst.argval}",
              explanation=f"No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.",
              hints=[*graph_break_hints.SUPPORTABLE],
          )
  def END_SEND(self, inst: Instruction) -> None:
      """Remove the entry directly below TOS, keeping TOS."""
      result = self.pop()
      _below = self.pop()
      self.push(result)
  # 3.13 opcodes
  # fused instructions LOAD_FAST_LOAD_FAST, STORE_FAST_STORE_FAST, STORE_FAST_LOAD_FAST
  # are broken down.
  @break_graph_if_unsupported(push=1)
  def CALL_KW(self, inst: Instruction) -> None:
      """3.13+ CALL_KW: like CALL, but the keyword names tuple is on top of the stack."""
      self._call(inst, call_kw=True)
  def TO_BOOL(self, inst: Instruction) -> None:
      """No-op; only asserts that the next instruction consumes a truth value."""
      # TO_BOOL only precedes a conditional jump or UNARY_NOT (see compile.c in CPython)
      # So we can skip this instruction as long as we remember to codegen a TO_BOOL
      # before conditional jumps/UNARY_NOT.
      allowed_successors = (
          "POP_JUMP_IF_TRUE",
          "POP_JUMP_IF_FALSE",
          "UNARY_NOT",
      )
      assert self.next_instruction.opname in allowed_successors
  def SET_FUNCTION_ATTRIBUTE(self, inst: Instruction) -> None:
      """3.13+: attach one optional attribute to the nested function on TOS.

      Pops the function and then the attribute value. It stores the value in
      the field named by the highest set flag bit (0x08 closure, 0x04
      annotations, 0x02 kwdefaults, 0x01 defaults) and pushes the function back.
      """
      flags = inst.arg
      assert flags is not None
      fn = self.pop()
      assert isinstance(fn, NestedUserFunctionVariable)
      attr = self.pop()
      for bit, field_name in (
          (0x08, "closure"),
          (0x04, "annotations"),
          (0x02, "kwdefaults"),
          (0x01, "defaults"),
      ):
          if flags & bit:
              setattr(fn, field_name, attr)
              break
      self.push(fn)
  def CONVERT_VALUE(self, inst: Instruction) -> None:
      """3.13+: apply the str/repr/ascii conversion selected by ``inst.argval`` to TOS."""
      value = self.pop()
      converted = self._convert_value(value, inst.argval)
      self.push(converted)
  def FORMAT_SIMPLE(self, inst: Instruction) -> None:
      """3.13+: format TOS with an empty format spec and no conversion."""
      empty_spec = ConstantVariable.create("")
      self._format_value(empty_spec, 0)
  def FORMAT_WITH_SPEC(self, inst: Instruction) -> None:
      """3.13+: pop the format spec from TOS and format the value beneath it."""
      spec = self.pop()
      self._format_value(spec, 0)
  def is_non_empty_graph(self) -> bool:
      """Return True when ``self.output.count_calls()`` exceeds 1.

      After the first True result, the method is shadowed on the instance by
      ``lambda: True`` so that later calls skip the count. NOTE(review): this
      assumes the count never drops back; confirm.
      """
      if not self.output.count_calls() > 1:
          return False
      # perf optimization only
      self.is_non_empty_graph = lambda: True  # type: ignore[method-assign]
      return True
  def format_frame_summary(
      self, additional_stack_frames: Optional[list[Any]] = None
  ) -> str:
      """Render this frame, followed by ``additional_stack_frames`` in reverse, as traceback text."""
      extra = [] if additional_stack_frames is None else additional_stack_frames
      frames = [self.frame_summary(), *reversed(extra)]
      return "".join(traceback.format_list(frames))
  def frame_summary(self) -> traceback.FrameSummary:
      """Describe the current position (file, line, function) as a ``FrameSummary``.

      Missing ``co_filename`` / ``co_name`` fall back to ``"<unknown>"``. The
      source line is not looked up.
      """
      code = self.f_code
      filename = getattr(code, "co_filename", "<unknown>")
      func_name = getattr(code, "co_name", "<unknown>")
      return traceback.FrameSummary(filename, self.lineno, func_name, lookup_line=False)
  def is_co_filename_from_nn_modules(self) -> bool:
      """Return True if the code being traced lives under ``torch/nn/modules``.

      The filename is normalized to forward slashes first, so Windows paths
      (``...\\torch\\nn\\modules\\...``) are recognized too. For ordinary
      POSIX paths this gives the same answer as the previous
      ``re.match(r".*torch/nn/modules.*", filename)`` check. It also avoids
      recompiling that regex on every call.
      """
      filename = getattr(self.f_code, "co_filename", "<unknown>")
      return "torch/nn/modules" in filename.replace("\\", "/")
  def store_global_weakref_by_id(self, prefix: str, value: Any) -> str:
      """Install a weak reference to ``value`` as an output-graph global and guard on it.

      Returns the generated global name. The installed guard is
      ``GuardBuilder.WEAKREF_ALIVE`` on that global.
      """
      ref = weakref.ref(value)
      global_name = self.output.install_global_by_id(prefix, ref)
      alive_guard = GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE)
      install_guard(alive_guard)
      return global_name
  @property
  def fake_mode(self) -> Optional[FakeTensorMode]:
      """The fake tensor mode of the output graph's tracing context."""
      tracing_context = self.output.tracing_context
      return tracing_context.fake_mode
  @contextlib.contextmanager
  def strict_translation_mode(
      self, check_fn: Callable[[VariableTracker], bool]
  ) -> Any:
      """Temporarily install ``check_fn`` as ``self.strict_checks_fn``.

      Strict mode is enabled on a per-VariableTracker level depending on the
      return value of check_fn(node). The previous function is restored on
      exit, including when the body raises.
      """
      saved_fn, self.strict_checks_fn = self.strict_checks_fn, check_fn
      try:
          yield
      finally:
          self.strict_checks_fn = saved_fn
  def speculate(self) -> SpeculationEntry:
      """Fetch the speculation-log entry for the instruction that was just executed.

      That instruction is at ``instruction_pointer - 1``. The log is queried
      with the file name, current line number, instruction index and
      instruction.
      """
      ip = self.instruction_pointer
      assert ip is not None
      assert ip > 0
      prev_index = ip - 1
      return self.speculation_log.next(
          self.f_code.co_filename,
          self.lineno,
          prev_index,
          self.instructions[prev_index],
      )
  def __init__(
      self,
      output: OutputGraph,
      instructions: list[Instruction],
      f_locals: dict[str, Any],
      f_globals: dict[str, Any],
      f_builtins: dict[str, Any],
      code_options: dict[str, Any],
      symbolic_locals: dict[str, VariableTracker],
      symbolic_globals: dict[str, VariableTracker],
      symbolic_torch_function_state: SymbolicTorchFunctionState,
      f_code: types.CodeType,
      export: bool,
      inline_depth: int,
      speculation_log: SpeculationLog,
      exn_vt_stack: ExceptionStack,
      distributed_state: Optional[DistributedState],
      # This determines whether to use the execution recorder.
      closure: Optional[tuple[types.CellType]] = None,
      package: Optional[CompilePackage] = None,
  ) -> None:
      """Initialize the interpretation state shared by all instruction translators.

      Stores the caller-supplied output graph, symbolic locals/globals and
      frame objects (code, globals, builtins). Initializes the value stack,
      instruction pointer, block stack and bookkeeping flags. An
      ``ExecutionRecorder`` is created only when ``closure`` is given and
      ``config.replay_record_enabled`` is set. On Python 3.10+, generator and
      coroutine code objects start with ``BuiltinVariable(None)`` already
      pushed.
      """
      super().__init__()
      self.speculation_log = speculation_log
      self.distributed_state = distributed_state

      # Mutable state checkpointed by copy_graphstate()
      self.output = output
      self.symbolic_locals = symbolic_locals
      self.symbolic_globals = symbolic_globals
      self.symbolic_torch_function_state = symbolic_torch_function_state
      # used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
      # in order to generate any nested closures
      self.post_prune_cell_and_freevars = None
      # Symbolic value stack read and modified by the opcode handlers (e.g. COPY, SWAP).
      self.stack: list[VariableTracker] = []
      self.instruction_pointer = 0
      self.start_point = None
      self.current_instruction = create_instruction("NOP")
      # BlockStackEntry objects, appended by setup_or_before_with().
      self.block_stack = []
      # states before SETUP_WITH for checkpointing and fallback
      # (setup_or_before_with() appends generic context managers that do not
      # support graph breaks)
      self.active_generic_context_managers: list[GenericContextWrappingVariable] = []
      self.lineno = -1
      # Keyword names set by KW_NAMES; consumed and reset to None by _call().
      self.kw_names = None
      # Prefix instructions (MAKE_CELL, COPY_FREE_VARS, RETURN_GENERATOR, ...)
      # are collected by append_prefix_inst() until RESUME with arg 0 closes
      # the window.
      self.accept_prefix_inst = True
      self.prefix_insts = []
      self.exn_vt_stack = exn_vt_stack

      # Properties of the input/output code
      self.instructions: list[Instruction] = instructions
      self.indexof: dict[Instruction, int] = get_indexof(self.instructions)
      self.f_locals: dict[str, Any] = (
          f_locals  # needed for recording accessed locals for replay
      )
      self.f_globals: dict[str, Any] = f_globals
      self.f_builtins: dict[str, Any] = f_builtins
      self.code_options: dict[str, Any] = code_options
      self.f_code: types.CodeType = f_code

      # Execution record for replaying errors
      if closure is not None and config.replay_record_enabled:
          self.exec_recorder = ExecutionRecorder(
              code=f_code, closure=closure, code_options=code_options
          )
      else:
          self.exec_recorder = None
      # Stack of modules being parsed; the current nn.Module is at the end of
      # the (insertion-ordered) dict. The first field of each tuple is the
      # fully qualified name of the current module in the original hierarchy;
      # the second field is the type of the current nn.Module.
      self.nn_module_stack: dict[str, tuple[str, type[Any]]] = {}
      self.num_calls: dict[str, int] = {}
      # Flag to indicate whether tracing is used for export.
      self.export = export
      # NOTE: one_graph is used for export/fullgraph=True to always force errors on graph breaks.
      # To toggle erroring/resuming on graph breaks during fullgraph=False compile, self.error_on_graph_break
      # is used instead. Every step(), its value is updated to the global tls.error_on_graph_break.
      # We mirror this value since cleanup may (correctly) inadvertently change tls.error_on_graph_break.
      # This assumes that we cannot both trace a change to tls.error_on_graph_break and graph break on
      # the same instruction.
      self.one_graph = False
      self.error_on_graph_break = False
      # Also do not graph break when tracing resume function prologues
      self.is_tracing_resume_prologue = False

      self.current_speculation = None

      # Temporarily replaced by strict_translation_mode().
      self.strict_checks_fn = None

      # Read by graph_break_on_leaf_function().
      self.is_leaf_tracer = True
      self.parent = None
      self.debug_locals = []

      self.package = package

      if sys.version_info >= (3, 10):
          from .resume_execution import (
              CO_ASYNC_GENERATOR,
              CO_COROUTINE,
              CO_GENERATOR,
              CO_ITERABLE_COROUTINE,
          )

          # Generator-like code objects start with an extra BuiltinVariable(None)
          # on the stack. NOTE(review): presumably this models the value a
          # generator receives when it is first resumed; confirm.
          if f_code.co_flags & (
              CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
          ):
              self.push(BuiltinVariable(None))

      self.inline_depth = inline_depth
      self.inconsistent_side_effects = False
      # One slot per entry of f_code.co_consts, initialized to None
      # (presumably filled on demand; confirm).
      self._constants_cache: list[Optional[ConstantVariable]] = [None] * len(
          f_code.co_consts
      )

      self.is_trace_bytecode_log_enabled: Optional[bool] = (
          trace_bytecode_log.isEnabledFor(logging.DEBUG)
      )
      self.is_trace_source_log_enabled: Optional[bool] = (
          trace_source_log.isEnabledFor(logging.DEBUG)
      )
      # Seed linecache so the frame's source lines can be looked up later.
      linecache.lazycache(f_code.co_filename, f_globals)
  3370. class InstructionTranslator(InstructionTranslatorBase):
  3371. @staticmethod
  3372. def current_tx() -> InstructionTranslator:
  3373. return tls.current_tx
  3374. @contextlib.contextmanager
  3375. def set_current_tx(self) -> Any:
  3376. prior = getattr(tls, "current_tx", None)
  3377. tls.current_tx = self
  3378. try:
  3379. yield
  3380. finally:
  3381. tls.current_tx = prior
  3382. def __init__(
  3383. self,
  3384. instructions: list[Instruction],
  3385. f_code: types.CodeType,
  3386. f_locals: dict[str, Any],
  3387. f_globals: dict[str, Any],
  3388. f_builtins: dict[str, Any],
  3389. closure: Optional[tuple[Any, ...]],
  3390. torch_function_mode_stack: Any,
  3391. code_options: dict[str, Any],
  3392. compiler_fn: Any,
  3393. one_graph: bool,
  3394. export: bool,
  3395. export_constraints: Any,
  3396. frame_state: Any,
  3397. speculation_log: SpeculationLog,
  3398. exn_vt_stack: ExceptionStack,
  3399. distributed_state: Optional[DistributedState],
  3400. package: Optional[CompilePackage],
  3401. ) -> None:
  3402. _step_logger()(
  3403. logging.INFO,
  3404. f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}",
  3405. )
  3406. super().__init__(
  3407. output=OutputGraph(
  3408. code_options,
  3409. compiler_fn,
  3410. self,
  3411. export,
  3412. export_constraints,
  3413. frame_state,
  3414. local_scope=f_locals,
  3415. global_scope=f_globals,
  3416. f_code=f_code,
  3417. torch_function_mode_stack=torch_function_mode_stack,
  3418. package=package,
  3419. ),
  3420. instructions=instructions,
  3421. f_locals=f_locals,
  3422. f_globals=f_globals,
  3423. f_builtins=f_builtins,
  3424. closure=closure,
  3425. code_options=code_options,
  3426. symbolic_locals={}, # set below
  3427. # A global var is inserted only after a STORE_GLOBAL happens to it
  3428. symbolic_globals={},
  3429. symbolic_torch_function_state=None, # type: ignore[arg-type] # set below
  3430. f_code=f_code,
  3431. export=export,
  3432. inline_depth=0,
  3433. speculation_log=speculation_log,
  3434. exn_vt_stack=exn_vt_stack,
  3435. distributed_state=distributed_state,
  3436. package=package,
  3437. )
  3438. self._throw_if_in_functorch()
  3439. # as soon as we create the tracing context we should keep it active, so any calls
  3440. # into dynamo apis can rely on finding it
  3441. with tracing(self.output.tracing_context), self.set_current_tx():
  3442. self.one_graph: bool = one_graph
  3443. self.export = export
  3444. if self.export:
  3445. assert self.one_graph, (
  3446. "Export without one graph - something has gone wrong."
  3447. )
  3448. self.symbolic_locals = {}
  3449. # Populate `symbolic_locals` with non-cell variables.
  3450. cell_and_freevars: set[str] = set(self.cell_and_freevars())
  3451. dynamism = code_context.get_context(f_code).get("dynamism", None)
  3452. for name, value in f_locals.items():
  3453. if name not in cell_and_freevars:
  3454. local_dynamism = None
  3455. if dynamism:
  3456. local_dynamism = frozenset(dynamism.get(name, {}).items())
  3457. var = LazyVariableTracker.create(
  3458. value,
  3459. LocalSource(
  3460. name,
  3461. is_input=True,
  3462. dynamism=local_dynamism,
  3463. ),
  3464. )
  3465. self.symbolic_locals[name] = var
  3466. # Populate `symbolic_locals` with cells created by this frame,
  3467. # effectively implementing the `MAKE_CELL` instructions.
  3468. side_effects = self.output.side_effects
  3469. for name in self.cellvars():
  3470. if name in f_locals:
  3471. # This models cells that are also function inputs.
  3472. value = f_locals[name]
  3473. # NOTE: root frame inputs that are captured by a nested
  3474. # function become special cell objects -- they exist in
  3475. # `f_locals` as contents of the cells, rather than the cells
  3476. # objects themselves.
  3477. #
  3478. # In Dynamo, we choose to represent such input cell objects
  3479. # as newly created (rather than pre-existing) cell objects,
  3480. # because
  3481. #
  3482. # 1. The reason for representing a pre-existing cell object
  3483. # is to emit guard or codegen mutations. However, local
  3484. # cells should never be used for guards. Moreover, at this
  3485. # point these input cell objects should've never been
  3486. # accessed by anyone else, since Dynamo intercepts the frame
  3487. # right after its evaluation starts, i.e., right after these
  3488. # cell objects are created. So they should have no external
  3489. # reference, meaning no mutation needs to be propagated.
  3490. #
  3491. # 2. This conveniently allows codegen to prune away
  3492. # mutations to these cells, unless they escape the frame.
  3493. contents_source = LocalSource(
  3494. name, is_input=True, is_derefed_cell_contents=True
  3495. )
  3496. contents_var: VariableTracker = LazyVariableTracker.create(
  3497. value, contents_source
  3498. )
  3499. cell_var = side_effects.track_cell_new()
  3500. side_effects.store_cell(cell_var, contents_var)
  3501. else:
  3502. cell_var = side_effects.track_cell_new()
  3503. cell_var.local_name = name # type: ignore[attr-defined]
  3504. self.symbolic_locals[name] = cell_var
  3505. # Populate `symbolic_locals` with cells captured by this frame,
  3506. # effectively implementing the `COPY_FREE_VARS` instruction.
  3507. assert closure is not None
  3508. for name, cell in zip(self.freevars(), closure):
  3509. cell_source = LocalCellSource(name)
  3510. contents_source = LocalSource(name, is_derefed_cell_contents=True)
  3511. try:
  3512. contents_var = LazyVariableTracker.create(
  3513. cell.cell_contents, contents_source
  3514. )
  3515. except ValueError:
  3516. # Cell has not yet been assigned
  3517. contents_var = variables.DeletedVariable()
  3518. cell_var = side_effects.track_cell_existing(
  3519. cell_source, cell, contents_var
  3520. )
  3521. cell_var.local_name = name # type: ignore[attr-defined]
  3522. self.symbolic_locals[name] = cell_var
  3523. self.symbolic_torch_function_state = SymbolicTorchFunctionState(
  3524. torch_function_mode_stack
  3525. )
  3526. if export:
  3527. # export gets confused if we never realize unused inputs
  3528. # in export mode just eagerly realize everything
  3529. self.symbolic_locals = variables.LazyVariableTracker.realize_all(
  3530. self.symbolic_locals
  3531. )
  3532. def _throw_if_in_functorch(self) -> None:
  3533. # Fallback to eager in case of a graph break inside vmap
  3534. eager = torch._dynamo.lookup_backend("eager")
  3535. compiler_fn = inspect.getattr_static(
  3536. self.output.compiler_fn, "compiler_fn", self.output.compiler_fn
  3537. )
  3538. ci = torch._C._functorch.peek_interpreter_stack()
  3539. forbidden_keys = (
  3540. torch._C._functorch.TransformType.Vmap,
  3541. torch._C._functorch.TransformType.Grad,
  3542. torch._C._functorch.TransformType.Jvp,
  3543. )
  3544. if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager:
  3545. name = ci.key().name.lower()
  3546. msg = (
  3547. "If you are reaching here, it means dynamo failed for one of the following reasons:\n"
  3548. # Calling a torch.compiled function
  3549. f"- Calling torch.func.{name}(compiled_fn) function from eager mode is not supported. "
  3550. f"Ensure that torch.func.{name} is also wrapped within a torch.compile function. "
  3551. "For more information, see PyTorch issue #128711.\n"
  3552. # if it reaches here, it means Dynamo failed to inline a functorch function
  3553. f"- torch.func.{name}(fn) requires the function to be inlined by dynamo"
  3554. )
  3555. unimplemented_v2(
  3556. gb_type="Unsupported functorch tracing attempt",
  3557. context="",
  3558. explanation=msg,
  3559. hints=[],
  3560. )
  3561. def get_example_value(self, source: Source) -> Any:
  3562. if isinstance(source, LocalSource):
  3563. return self.f_locals[source.local_name]
  3564. if isinstance(source, GlobalSource):
  3565. return self.f_globals[source.global_name]
  3566. raise KeyError
  3567. def symbolic_locals_contain_module_class(self) -> bool:
  3568. for v in self.symbolic_locals.values():
  3569. if isinstance(v, UserDefinedClassVariable) and issubclass(
  3570. v.as_python_constant(), torch.nn.Module
  3571. ):
  3572. return True
  3573. return False
  3574. def replace_tos_if_return_is_generator(self) -> None:
  3575. if (
  3576. len(self.stack)
  3577. and (tos := self.stack[-1])
  3578. and isinstance(tos, LocalGeneratorObjectVariable)
  3579. ):
  3580. self.stack[-1] = ListIteratorVariable(
  3581. tos.force_unpack_var_sequence(self),
  3582. mutation_type=ValueMutationNew(),
  3583. )
  3584. def _return(self, inst: Instruction) -> None:
  3585. self.replace_tos_if_return_is_generator()
  3586. assert self.instruction_pointer is not None
  3587. assert self.start_point is not None
  3588. get_metrics_context().increment(
  3589. "ir_count", self.instruction_pointer - self.start_point
  3590. )
  3591. if (
  3592. not config.allow_empty_graphs
  3593. and self.output.count_calls() == 0
  3594. and not self.inconsistent_side_effects
  3595. and not self.symbolic_locals_contain_module_class()
  3596. and not self.export
  3597. and not self.one_graph
  3598. and not self.error_on_graph_break
  3599. and not self.is_tracing_resume_prologue
  3600. ):
  3601. raise exc.SkipFrame("because no content in function call")
  3602. self.instruction_pointer = None
  3603. _step_logger()(
  3604. logging.INFO,
  3605. f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})",
  3606. )
  3607. log.debug("%s triggered compile", inst.opname)
  3608. all_stack_locals_metadata = self.output.compile_subgraph(
  3609. self,
  3610. reason=GraphCompileReason(
  3611. "return_value", [self.frame_summary()], graph_break=False
  3612. ),
  3613. # the value to be returned
  3614. stack_pops=1 if inst.opname == "RETURN_VALUE" else 0,
  3615. )
  3616. # check that our stack/locals meta are correct:
  3617. # we should only be tracing 1 frame, and there should not be any NULLs on the stack
  3618. assert len(all_stack_locals_metadata) == 1
  3619. assert not all_stack_locals_metadata[0].stack_null_idxes
  3620. return_inst = (
  3621. create_instruction("RETURN_VALUE")
  3622. if inst.opname == "RETURN_VALUE"
  3623. else create_instruction("RETURN_CONST", argval=inst.argval)
  3624. )
  3625. # NOTE: does the stack need to be empty after the return?
  3626. self.output.add_output_instructions([return_inst])
  3627. raise ReturnValueOp
  3628. def RETURN_VALUE(self, inst: Instruction) -> None:
  3629. self._return(inst)
  3630. def RETURN_CONST(self, inst: Instruction) -> None:
  3631. self._return(inst)
  3632. if sys.version_info >= (3, 11):
  3633. _binary_op_lookup = [
  3634. getattr(
  3635. InstructionTranslator,
  3636. opname[3:] if "INPLACE" in opname else f"BINARY_{opname[3:]}",
  3637. )
  3638. for opname, _ in dis._nb_ops # type: ignore[attr-defined]
  3639. ]
  3640. class InliningInstructionTranslator(InstructionTranslatorBase):
  3641. """Trace and inline a called method"""
  3642. symbolic_result: Optional[VariableTracker]
  3643. parent: InstructionTranslatorBase
  3644. @classmethod
  3645. def inline_call(cls, parent: Any, func: Any, args: Any, kwargs: Any) -> Any:
  3646. with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
  3647. tracer = cls.build_inline_tracer(parent, func, args, kwargs)
  3648. return tracer.inline_call_()
  3649. @staticmethod
  3650. def check_inlineable(func: Any) -> trace_rules.SkipResult:
  3651. if func.has_self():
  3652. unimplemented_v2(
  3653. gb_type="Inline attempt with __self__",
  3654. context=str(func),
  3655. explanation="Attempted to inline a function with the `__self__` attribute. "
  3656. "Dynamo is expected to decompose method calls into function calls with a `self` argument.",
  3657. hints=[],
  3658. )
  3659. if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
  3660. func.get_function(), "_torchdynamo_disable", False
  3661. ):
  3662. msg = inspect.getattr_static(
  3663. func.get_function(), "_torchdynamo_disable_msg", None
  3664. )
  3665. unimplemented_v2(
  3666. gb_type="Skip inlining `torch.compiler.disable()`d function",
  3667. context=str(func.get_function()),
  3668. explanation=f"Skip inlining function {func.get_function()} since it was wrapped "
  3669. f"with `torch.compiler.disable` (reason: {msg})",
  3670. hints=[
  3671. "Remove the `torch.compiler.disable` call",
  3672. ],
  3673. )
  3674. result = trace_rules.check_verbose(func, is_inlined_call=True)
  3675. if result.skipped:
  3676. from torch._dynamo.variables.misc import produce_trampoline_autograd_apply
  3677. # _origin marks this as coming from an internal dynamo known function that is safe to
  3678. # trace through.
  3679. if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [
  3680. produce_trampoline_autograd_apply,
  3681. ]:
  3682. # Known sound
  3683. return trace_rules.SkipResult(
  3684. False, "allowlist in dynamo known function"
  3685. )
  3686. fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else ""
  3687. hints = [
  3688. f"Avoid calling the function `{fn_qualname}`.",
  3689. ]
  3690. if "_dynamo" not in func.get_filename():
  3691. hints += [
  3692. f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{fn_qualname}` "
  3693. "to force tracing into the function. "
  3694. "More graph breaks may occur as a result of attempting to trace into the function.",
  3695. "Please file an issue to PyTorch.",
  3696. ]
  3697. unimplemented_v2(
  3698. gb_type="Attempted to inline function marked as skipped",
  3699. context=f"qualname: {fn_qualname}, name: {func.get_name()}, "
  3700. f"filename: `{func.get_filename()}`, skip reason: {result.reason}",
  3701. explanation=f"Dynamo developers have intentionally marked that the function `{fn_qualname}` "
  3702. "should not be traced.",
  3703. hints=hints,
  3704. )
  3705. return result
  3706. @staticmethod
  3707. def build_inline_tracer(
  3708. parent: Any,
  3709. func: VariableTracker,
  3710. args: list[VariableTracker],
  3711. kwargs: Any,
  3712. ) -> InliningInstructionTranslator:
  3713. assert isinstance(
  3714. func,
  3715. (
  3716. UserFunctionVariable,
  3717. NestedUserFunctionVariable,
  3718. LocalGeneratorFunctionVariable,
  3719. LocalGeneratorObjectVariable,
  3720. ),
  3721. )
  3722. code: types.CodeType = func.get_code()
  3723. result = None
  3724. tracing_ctx = parent.output.tracing_context
  3725. # Check if we have already identified this function to be inline-able.
  3726. # The exception is dont_skip_tracing flag which affects the inline
  3727. # behavior. If the flag is True, don't rely on previous results.
  3728. if not config.dont_skip_tracing and tracing_ctx:
  3729. if previous_result := tracing_ctx.previously_inlined_functions.get(
  3730. code, None
  3731. ):
  3732. result = previous_result
  3733. if result is None:
  3734. if isinstance(func, SkipFunctionVariable):
  3735. unimplemented_v2(
  3736. gb_type="Attempted to inline function marked as skipped (SkipFunctionVariable)",
  3737. context=f"Attempted to inline a SkipFunctionVariable {func}",
  3738. explanation=(
  3739. "Attempted to inline a function that was previously determined to be marked as intentionally skipped."
  3740. ),
  3741. hints=[],
  3742. )
  3743. result = InliningInstructionTranslator.check_inlineable(func)
  3744. assert result.skipped is False
  3745. if not config.dont_skip_tracing and tracing_ctx:
  3746. tracing_ctx.previously_inlined_functions[code] = result
  3747. try:
  3748. sub_locals = func.bind_args(parent, args, kwargs)
  3749. except TypeError as e:
  3750. # Wrap the general TypeError during bind_args() to the internal ArgsMismatchError with detailed info
  3751. raise ArgsMismatchError( # noqa: B904
  3752. "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
  3753. reason=str(e),
  3754. func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
  3755. args=[arg.python_type() for arg in args],
  3756. kwargs=kwargs,
  3757. ),
  3758. )
  3759. for v in itertools.chain(sub_locals.values()):
  3760. if not isinstance(v, VariableTracker):
  3761. unimplemented_v2(
  3762. gb_type="Encountered unconverted argument when attempting to inline",
  3763. context=f"func: {func}, arg: {v}",
  3764. explanation="An argument to an inlined function was not successfully converted to a VariableTracker.",
  3765. hints=[*graph_break_hints.DYNAMO_BUG],
  3766. )
  3767. if code.co_name in ("__setitem__", "__setattr__") and not (
  3768. args and isinstance(args[0], variables.UserDefinedObjectVariable)
  3769. ):
  3770. unimplemented_v2(
  3771. gb_type="Unsupported __setitem__/__setattr__ inline attempt",
  3772. context=f"code name: {code.co_name}, args: {args}",
  3773. explanation=f"Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.",
  3774. hints=[],
  3775. )
  3776. suffix = ""
  3777. # TODO: mlazos, add support for enabling multiple artifact logs
  3778. # with a single alias
  3779. if torch._logging._internal.log_state.is_artifact_enabled("bytecode"):
  3780. suffix = f"\n{dis.Bytecode(code).dis()}"
  3781. if sys.version_info >= (3, 11):
  3782. cur_inst = parent.current_instruction
  3783. parent_code = parent.f_code
  3784. def get_trace_call_log_str() -> str:
  3785. header = parent.get_line_of_code_header(
  3786. lineno=cur_inst.positions.lineno
  3787. )
  3788. line = get_instruction_source_311(parent_code, cur_inst).rstrip()
  3789. return f"TRACE inlined call {code.co_name} from {header}\n{line}"
  3790. trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
  3791. log.debug("INLINING %s%s, %s", code, suffix, result.reason)
  3792. # Detect inline GraphModule calls in order to propagate node metadata,
  3793. # by checking if the first argument (self) is a variable tracking a GraphModule.
  3794. if args and isinstance(args[0], NNModuleVariable):
  3795. module = parent.output.get_submodule(args[0].module_key)
  3796. if isinstance(module, torch.fx.GraphModule):
  3797. # The inline call might not actually be a call to `forward`,
  3798. # but it is enough to add a context for `forward` in case it is called.
  3799. code_context.get_context(module.forward.__code__)[
  3800. "orig_graphmodule"
  3801. ] = weakref.ref(module)
  3802. tracer: InliningInstructionTranslator
  3803. if is_generator(code):
  3804. tracer = InliningGeneratorInstructionTranslator(
  3805. parent,
  3806. code,
  3807. sub_locals,
  3808. parent.symbolic_globals,
  3809. parent.symbolic_torch_function_state,
  3810. func,
  3811. )
  3812. else:
  3813. # need the line below to make MyPy happy
  3814. assert not isinstance(func, LocalGeneratorObjectVariable)
  3815. tracer = InliningInstructionTranslator(
  3816. parent,
  3817. code,
  3818. sub_locals,
  3819. parent.symbolic_globals,
  3820. parent.symbolic_torch_function_state,
  3821. func,
  3822. )
  3823. return tracer
  3824. def inline_call_(self) -> VariableTracker:
  3825. parent = self.parent
  3826. code = self.f_code
  3827. strict_ctx: Any = contextlib.nullcontext()
  3828. if parent.strict_checks_fn:
  3829. strict_ctx = self.strict_translation_mode(parent.strict_checks_fn)
  3830. try:
  3831. with strict_ctx:
  3832. self.run()
  3833. except exc.ObservedException as e:
  3834. msg = f"Observed exception DURING INLING {code} : {e}"
  3835. log.debug(msg)
  3836. # bubble up the exception to the parent frame.
  3837. raise
  3838. except exc.SkipFrame as e:
  3839. msg = f"SKIPPED INLINING {code}: {e}"
  3840. log.debug(msg)
  3841. raise Unsupported(msg) from e
  3842. except Exception:
  3843. log.debug("FAILED INLINING %s", code)
  3844. raise
  3845. finally:
  3846. parent.error_on_graph_break = self.error_on_graph_break
  3847. if self.output.should_exit:
  3848. # graph break
  3849. return ConstantVariable.create(None) # return dummy variable
  3850. assert self.symbolic_result is not None
  3851. if self.f_globals is parent.f_globals:
  3852. # Merge symbolic_globals back if parent and child are in the same namespace
  3853. parent.symbolic_globals.update(self.symbolic_globals)
  3854. parent.inconsistent_side_effects |= self.inconsistent_side_effects
  3855. log.debug("DONE INLINING %s", code)
  3856. self.output.tracing_context.traced_code.append(code)
  3857. if config.enable_faithful_generator_behavior or (
  3858. isinstance(self, InliningGeneratorInstructionTranslator)
  3859. and self.is_generator_from_ctx_manager
  3860. ):
  3861. if (
  3862. is_generator(code)
  3863. and isinstance(self, InliningGeneratorInstructionTranslator)
  3864. and self.generator_exhausted
  3865. ):
  3866. assert isinstance(self, InliningGeneratorInstructionTranslator)
  3867. # When the generator returns None, we raise StopIteration
  3868. args = []
  3869. if not (
  3870. isinstance(self.symbolic_result, ConstantVariable)
  3871. and self.symbolic_result.value is None
  3872. ):
  3873. args = [self.symbolic_result]
  3874. exc.raise_observed_exception(StopIteration, self, args=args)
  3875. else:
  3876. return self.symbolic_result
  3877. else:
  3878. if is_generator(code):
  3879. assert isinstance(self, InliningGeneratorInstructionTranslator)
  3880. assert self.symbolic_result.as_python_constant() is None
  3881. return ListIteratorVariable(
  3882. self.generated_items,
  3883. mutation_type=ValueMutationNew(),
  3884. )
  3885. else:
  3886. return self.symbolic_result
  3887. def __init__(
  3888. self,
  3889. parent: InstructionTranslatorBase,
  3890. code: types.CodeType,
  3891. symbolic_locals: dict[str, VariableTracker],
  3892. symbolic_globals: dict[str, VariableTracker],
  3893. symbolic_torch_function_state: SymbolicTorchFunctionState,
  3894. funcvar: BaseUserFunctionVariable,
  3895. ) -> None:
  3896. f_globals = funcvar.get_globals() # type: ignore[attr-defined]
  3897. f_builtins = f_globals["__builtins__"]
  3898. if not isinstance(f_builtins, dict):
  3899. f_builtins = f_builtins.__dict__
  3900. # Get the cached instructions. These instructions are safe to cache
  3901. # because we dont mutate them in transform_code_object (those
  3902. # instructions are for the top most Instruction translator). Also, we
  3903. # have to be careful about not using _cached_cleaned_instructions here
  3904. # because that function is global, while we want the the cache to be
  3905. # alive only during a compmilation.
  3906. tracing_ctx = parent.output.tracing_context
  3907. instructions = None
  3908. if tracing_ctx:
  3909. if tracing_ctx.previously_cleaned_instructions.get(code):
  3910. instructions = tracing_ctx.previously_cleaned_instructions[code]
  3911. if instructions is None:
  3912. instructions = cleaned_instructions(code)
  3913. propagate_line_nums(instructions)
  3914. if tracing_ctx:
  3915. tracing_ctx.previously_cleaned_instructions[code] = instructions
  3916. super().__init__(
  3917. output=parent.output,
  3918. f_locals={},
  3919. f_globals=f_globals,
  3920. f_builtins=f_builtins,
  3921. symbolic_locals=symbolic_locals,
  3922. symbolic_globals=symbolic_globals,
  3923. symbolic_torch_function_state=symbolic_torch_function_state,
  3924. instructions=instructions,
  3925. code_options={k: getattr(code, k) for k in get_code_keys()},
  3926. f_code=code,
  3927. export=parent.export,
  3928. inline_depth=parent.inline_depth + 1,
  3929. speculation_log=parent.speculation_log,
  3930. exn_vt_stack=parent.exn_vt_stack,
  3931. distributed_state=parent.distributed_state,
  3932. package=parent.package,
  3933. )
  3934. self.funcvar = funcvar
  3935. self.parent = parent
  3936. self.num_calls = parent.num_calls
  3937. self.symbolic_result = None
  3938. self.nn_module_stack = parent.nn_module_stack.copy()
  3939. self.one_graph = parent.one_graph
  3940. @property
  3941. def fake_mode(self) -> Optional[FakeTensorMode]:
  3942. return self.parent.fake_mode
  3943. def run_ctx_mgr(self) -> Any:
  3944. return TracingContext.current_frame(self.parent.frame_summary())
  3945. def should_compile_partial_graph(self) -> bool:
  3946. if config.nested_graph_breaks:
  3947. if not self.parent.should_compile_partial_graph():
  3948. return False
  3949. return super().should_compile_partial_graph()
  3950. return False # inlining functions is all-or-nothing
  3951. def create_call_resume_at(
  3952. self,
  3953. inst: Instruction,
  3954. all_stack_locals_metadata: Any,
  3955. disable_current_frame_resume: bool,
  3956. ) -> list[Instruction]:
  3957. if config.nested_graph_breaks:
  3958. return super().create_call_resume_at(
  3959. inst, all_stack_locals_metadata, disable_current_frame_resume
  3960. )
  3961. unimplemented_v2(
  3962. gb_type="Graph break in inlined function",
  3963. context="",
  3964. explanation="Graph breaks in an inlined call are not supported.",
  3965. hints=[],
  3966. )
  3967. def RETURN_VALUE(self, inst: Instruction) -> None:
  3968. self.symbolic_result = self.pop() # type: ignore[assignment]
  3969. self.instruction_pointer = None
  3970. raise ReturnValueOp
  3971. def RETURN_CONST(self, inst: Instruction) -> None:
  3972. self.symbolic_result = self._load_const(inst)
  3973. self.instruction_pointer = None
  3974. raise ReturnValueOp
  3975. def get_globals_source_and_value(
  3976. self, name: str
  3977. ) -> tuple[Any, VariableTracker, Source]:
  3978. # NamedTuple's `__new__` has a fake global scope that's not an actual
  3979. # module. TODO generalize the check for other non-importable cases.
  3980. # https://github.com/python/cpython/blob/8421b03b16a4852a527256cb7cdce2ab2d318548/Lib/collections/__init__.py#L441-L447
  3981. if "__name__" in self.f_globals and not self.f_globals["__name__"].startswith(
  3982. "namedtuple_"
  3983. ):
  3984. module_name = self.f_globals["__name__"]
  3985. module_source = self.import_source(module_name)
  3986. if "torch_package" in module_name:
  3987. fglobals_value = (
  3988. torch.package.package_importer._package_imported_modules[
  3989. module_name
  3990. ]
  3991. ) # type: ignore[assignment]
  3992. else:
  3993. fglobals_value = _import_module(module_name)
  3994. # Dont use lazy vt because we will do a setattr afterwards
  3995. fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
  3996. global_source = AttrSource(module_source, name)
  3997. else:
  3998. globals_name = self.output.install_global_by_id(
  3999. "___unnamed_scope", self.f_globals
  4000. )
  4001. globals_source = GlobalSource(globals_name)
  4002. fglobals_value = self.f_globals # type: ignore[assignment]
  4003. # Dont use lazy vt because we will do a setattr afterwards
  4004. fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
  4005. global_source = DictGetItemSource(globals_source, name) # type: ignore[assignment]
  4006. if is_stdlib(fglobals_value):
  4007. # Users don't inplace mutate a stdlib attribute (like inspect,
  4008. # collections), skip guards that originate from the stdlib modules.
  4009. global_source = SkipGuardSource(global_source) # type: ignore[assignment]
  4010. return fglobals_value, fglobals_vt, global_source
  4011. def _load_global(self, inst: Instruction) -> None:
  4012. name = inst.argval
  4013. if name not in self.f_globals:
  4014. return self.load_builtin(inst)
  4015. if self.output.global_scope is self.f_globals:
  4016. # If the global scope matches that of the root frame, use handler in
  4017. # root frame instruction translator, to enforce consistency.
  4018. super()._load_global(inst)
  4019. else:
  4020. _, fglobals_vt, global_source = self.get_globals_source_and_value(name)
  4021. if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name):
  4022. self.push(self.output.side_effects.load_attr(fglobals_vt, name))
  4023. else:
  4024. value = self.f_globals[name]
  4025. self.push(VariableTracker.build(self, value, global_source))
  4026. def STORE_GLOBAL(self, inst: Instruction) -> None:
  4027. if self.output.global_scope is self.f_globals:
  4028. # If the global scope matches that of the root frame, use handler in
  4029. # root frame instruction translator, to enforce consistency.
  4030. super().STORE_GLOBAL(inst)
  4031. else:
  4032. value = self.pop()
  4033. if isinstance(value, RemovableHandleVariable):
  4034. unimplemented_v2(
  4035. gb_type="Storing Tensor hook handle in globals (inline call)",
  4036. context=inst.argval,
  4037. explanation="This is not supported.",
  4038. hints=[],
  4039. )
  4040. name = inst.argval
  4041. _fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
  4042. self.output.side_effects.store_attr(fglobals_vt, name, value)
  4043. class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
  4044. generated_items: list[VariableTracker]
  4045. # Flag whether or not the InlineGenerator should consume the entire iterator
  4046. def __init__(self, *args: Any, **kwargs: Any) -> None:
  4047. super().__init__(*args, **kwargs)
  4048. self.generated_items = []
  4049. self.generator_exhausted = False
  4050. self.is_generator_from_ctx_manager = False
  4051. def YIELD_VALUE(self, inst: Instruction) -> None:
  4052. top = self.pop()
  4053. self.generated_items.append(top)
  4054. if len(self.generated_items) > MAX_ITERATOR_LIMIT:
  4055. raise exc.InfiniteGeneratorError(
  4056. "Too many yield values in generator. Maybe you are inlining an infinite generator. "
  4057. f"If not, please report a bug at {PT2_ISSUE_TRACKER_URL}",
  4058. )
  4059. self.push(ConstantVariable.create(None))
  4060. if (
  4061. config.enable_faithful_generator_behavior
  4062. or self.is_generator_from_ctx_manager
  4063. ):
  4064. self.symbolic_result = top
  4065. # Stop tracing
  4066. raise YieldValueOp
  4067. def GET_YIELD_FROM_ITER(self, inst: Instruction) -> None:
  4068. tos = self.stack[-1]
  4069. if not isinstance(tos, ListIteratorVariable):
  4070. self.pop()
  4071. res = BuiltinVariable(iter).call_function(self, [tos], {}) # type: ignore[arg-type]
  4072. self.push(res)
  4073. def RETURN_VALUE(self, inst: Instruction) -> None:
  4074. self.generator_exhausted = True
  4075. return super().RETURN_VALUE(inst)
  4076. def RETURN_CONST(self, inst: Instruction) -> None:
  4077. self.generator_exhausted = True
  4078. return super().RETURN_CONST(inst)
  def YIELD_FROM(self, inst: Instruction) -> None:
      """Emulate the pre-3.11 ``YIELD_FROM`` opcode for a delegating generator.

      Expects the stack to end in ``[..., sub_iterator, sent_value]``. Only a
      sent value of constant ``None`` is supported; any other value would need
      a ``send()`` forwarded to the sub-iterator and is reported through
      ``unimplemented_v2``.

      The sub-iterator is advanced once per execution:

      * If it produces an item, the instruction pointer is stepped back by one
        so ``YIELD_FROM`` runs again on the next eval-loop iteration. The item
        is then pushed and handed to ``YIELD_VALUE``.
      * If it is exhausted, the sub-iterator is popped and replaced with a
        constant holding the ``StopIteration`` value.
      """
      assert len(self.stack) >= 2
      sent = self.pop()
      sub_iter = self.stack[-1]
      sent_is_none = isinstance(sent, ConstantVariable) and sent.value is None
      if not sent_is_none:
          # Forwarding a non-None value would mean invoking send() on the
          # sub-iterator. This should be unreachable: reaching it means
          # generator support is being implemented and the frame-level
          # `unimplemented("generator")` was lifted. It corresponds to
          # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599
          unimplemented_v2(
              gb_type="Unreachable sub-generator code",
              context="",
              explanation="Should only be encountered while implementing generator support.",
              hints=[],
          )

      try:
          produced = sub_iter.next_variable(self)
      except (StopIteration, exc.ObservedUserStopIteration) as stop:
          if isinstance(stop, exc.ObservedUserStopIteration):
              exc.handle_observed_exception(self)
          # Sub-iterator exhausted: swap it for its return value and finish.
          self.pop()
          self.push(ConstantVariable.create(stop.value))
          return

      # Step back one instruction so YIELD_FROM is executed again next loop.
      current_ip = self.instruction_pointer
      assert isinstance(current_ip, int) and current_ip > 0
      self.instruction_pointer = current_ip - 1

      self.push(produced)
      # Per the original comment, YIELD_VALUE records the item in
      # generated_items and replaces the top of the stack with None.
      self.YIELD_VALUE(inst)
  def SEND(self, inst: Instruction) -> None:
      """Emulate the ``SEND`` opcode (Python 3.11+) used by ``yield from``.

      Expects the stack to end in ``[..., receiver, value]``.

      Supported receivers are those Dynamo can step:

      * ``IteratorVariable``
      * ``LocalGeneratorObjectVariable``
      * a ``UserDefinedObjectVariable`` whose underlying value is a
        ``collections.abc.Iterator``

      Only a sent value of constant ``None`` is supported. Every other case
      is reported through ``unimplemented_v2``.

      If the receiver produces an item, the item is pushed above the receiver.
      If the receiver is exhausted:

      * its ``StopIteration`` value is pushed as a constant;
      * control goes through ``self.jump(inst)``;
      * on Python < 3.12 the receiver is also popped here;
      * on 3.12+ that cleanup is left to ``END_SEND``.
      """
      assert len(self.stack) >= 2
      val = self.pop()
      tos = self.stack[-1]
      if isinstance(tos, (IteratorVariable, LocalGeneratorObjectVariable)) or (
          isinstance(tos, UserDefinedObjectVariable)
          and isinstance(tos.value, collections.abc.Iterator)
      ):
          if isinstance(val, ConstantVariable) and val.value is None:
              try:
                  val = tos.next_variable(self)
              except (StopIteration, exc.ObservedUserStopIteration) as ex:
                  # SEND consumes the StopIteration itself, exactly as
                  # YIELD_FROM does, so handle an observed user-level
                  # StopIteration the same way YIELD_FROM handles it.
                  # Otherwise the observed exception would presumably remain
                  # pending after SEND has already dealt with it.
                  if isinstance(ex, exc.ObservedUserStopIteration):
                      exc.handle_observed_exception(self)
                  # To implement SEND we mirror CPython's behavior when the
                  # iterator raises StopIteration:
                  # 3.11: https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2619
                  # 3.12: https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L866
                  # In 3.12 END_SEND does the cleanup; in 3.11 SEND does it too.
                  if sys.version_info < (3, 12):
                      self.pop()  # Python 3.12 uses new opcode END_SEND
                  self.push(ConstantVariable.create(ex.value))
                  self.jump(inst)
              else:
                  self.push(val)
          else:
              # Invoking send() with a non-None value is unreachable today:
              # hitting it means generator support is being implemented and
              # the frame-level `unimplemented("generator")` was lifted. This
              # handles sub-generators and corresponds to
              # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
              unimplemented_v2(
                  gb_type="Unreachable sub-generator code",
                  context="",
                  explanation="Should only be encountered while implementing generator support.",
                  hints=[],
              )
      else:
          unimplemented_v2(
              gb_type="SEND with bad type",
              context=f"TOS type: {typestr(tos)}",
              explanation=f"Attempted to SEND with unsupported type {typestr(tos)}.",
              hints=[],
          )