guards.py 172 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
7427842794280428142824283428442854286428742884289429042914292429342944295429642974298429943004301430243034304430543064307430843094310431143124313
  1. """
  2. Core guard system for Dynamo that detects when compiled code needs to be recompiled due to
  3. changes in program state. Guards are conditions that must remain true for previously-compiled
  4. code to be valid for reuse.
  5. This module provides the infrastructure for creating, managing and checking guards, including:
  6. - Guard creation and composition
  7. - Guard state management and invalidation
  8. - Guard checking and failure handling
  9. - Utilities for guard optimization and debugging
  10. - Integration with Dynamo's compilation caching
  11. The guard system is critical for Dynamo's ability to efficiently reuse compiled code while
  12. maintaining correctness by detecting when recompilation is necessary due to changes in
  13. program state, tensor properties, or control flow.
  14. """
  15. from __future__ import annotations
  16. import ast
  17. import builtins
  18. import collections
  19. import dataclasses
  20. import enum
  21. import functools
  22. import importlib
  23. import inspect
  24. import io
  25. import logging
  26. import math
  27. import pickle
  28. import sys
  29. import textwrap
  30. import traceback
  31. import types
  32. import warnings
  33. import weakref
  34. from contextlib import contextmanager
  35. from copy import deepcopy
  36. from inspect import currentframe
  37. from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
  38. try:
  39. from typing import LiteralString
  40. except ImportError:
  41. from typing_extensions import LiteralString
  42. from typing_extensions import TypeAliasType, TypeVar
  43. from weakref import ReferenceType
  44. import torch
  45. import torch.overrides
  46. import torch.utils._device
  47. from torch._C._dynamo.eval_frame import code_framelocals_names
  48. from torch._C._dynamo.guards import (
  49. check_obj_id,
  50. check_type_id,
  51. ClosureGuardAccessor,
  52. CodeGuardAccessor,
  53. dict_version,
  54. DictGetItemGuardAccessor,
  55. DictGuardManager,
  56. FuncDefaultsGuardAccessor,
  57. FuncKwDefaultsGuardAccessor,
  58. GetAttrGuardAccessor,
  59. GetGenericDictGuardAccessor,
  60. GuardAccessor,
  61. GuardDebugInfo,
  62. GuardManager,
  63. install_no_tensor_aliasing_guard,
  64. install_object_aliasing_guard,
  65. install_storage_overlapping_guard,
  66. install_symbolic_shape_guard,
  67. LeafGuard,
  68. profile_guard_manager,
  69. RelationalGuard,
  70. RootGuardManager,
  71. TupleGetItemGuardAccessor,
  72. TypeDictGuardAccessor,
  73. TypeGuardAccessor,
  74. TypeMROGuardAccessor,
  75. )
  76. from torch._dynamo.source import (
  77. get_global_source_name,
  78. get_local_source_name,
  79. IndexedSource,
  80. is_from_flatten_script_object_source,
  81. is_from_local_source,
  82. is_from_optimizer_source,
  83. is_from_skip_guard_source,
  84. is_from_unspecialized_builtin_nn_module_source,
  85. TensorProperty,
  86. TensorPropertySource,
  87. )
  88. from torch._dynamo.utils import CompileEventLogger, get_metrics_context
  89. from torch._guards import (
  90. CompileContext,
  91. CompileId,
  92. DuplicateInputs,
  93. Guard,
  94. GuardBuilderBase,
  95. GuardEnvExpr,
  96. GuardSource,
  97. Source,
  98. StorageOverlap,
  99. )
  100. from torch._inductor.utils import IndentedBuffer
  101. from torch._logging import structured
  102. from torch._utils_internal import justknobs_check
  103. from torch.fx.experimental.symbolic_shapes import (
  104. _CppShapeGuardsHelper,
  105. _ShapeGuardsHelper,
  106. EqualityConstraint,
  107. is_symbolic,
  108. SYMPY_INTERP,
  109. )
  110. from torch.utils import _pytree as pytree
  111. from torch.utils._ordered_set import OrderedSet
  112. from torch.utils._traceback import format_frame, report_compile_source_on_error
  113. from torch.utils.weak import TensorWeakRef
  114. from . import config, convert_frame, exc
  115. from .eval_frame import set_guard_error_hook
  116. from .source import (
  117. AttrProxySource,
  118. AttrSource,
  119. CallFunctionNoArgsSource,
  120. CallMethodItemSource,
  121. ChainedSource,
  122. ClosureSource,
  123. CodeSource,
  124. ConstantSource,
  125. ConstDictKeySource,
  126. DataclassFieldsSource,
  127. DefaultsSource,
  128. DictGetItemSource,
  129. DictSubclassGetItemSource,
  130. FlattenScriptObjectSource,
  131. FloatTensorSource,
  132. FSDPNNModuleSource,
  133. GenericAttrSource,
  134. GetItemSource,
  135. GlobalSource,
  136. GlobalStateSource,
  137. GlobalWeakRefSource,
  138. GradSource,
  139. ListGetItemSource,
  140. LocalSource,
  141. NamedTupleFieldsSource,
  142. NNModuleSource,
  143. NonSerializableSetGetItemSource,
  144. NumpyTensorSource,
  145. OptimizerSource,
  146. ScriptObjectQualifiedNameSource,
  147. ShapeEnvSource,
  148. SubclassAttrListSource,
  149. TorchFunctionModeStackSource,
  150. TorchSource,
  151. TupleIteratorGetItemSource,
  152. TypeDictSource,
  153. TypeMROSource,
  154. TypeSource,
  155. UnspecializedBuiltinNNModuleSource,
  156. UnspecializedNNModuleSource,
  157. UnspecializedParamBufferSource,
  158. WeakRefCallSource,
  159. )
  160. from .types import ( # noqa: F401
  161. CacheEntry,
  162. DynamoFrameType,
  163. ExtraState,
  164. GuardedCode,
  165. GuardFail,
  166. GuardFilterEntry,
  167. GuardFn,
  168. )
  169. from .utils import (
  170. builtin_dict_keys,
  171. common_constant_types,
  172. dataclass_fields,
  173. dict_keys,
  174. get_custom_getattr,
  175. get_torch_function_mode_stack,
  176. get_torch_function_mode_stack_at,
  177. guard_failures,
  178. istype,
  179. key_is_id,
  180. key_to_id,
  181. normalize_range_iter,
  182. orig_code_map,
  183. tensor_always_has_static_shape,
  184. tuple_iterator_getitem,
  185. tuple_iterator_len,
  186. unpatched_nn_module_getattr,
  187. verify_guard_fn_signature,
  188. )
# Module-level optional hook: a callable taking three positional arguments,
# ``None`` (disabled) by default.
# NOTE(review): the name suggests tests install this to observe guard-manager
# construction; its call site is outside this chunk — confirm there.
guard_manager_testing_hook_fn: Optional[Callable[[Any, Any, Any], Any]] = None
  190. try:
  191. import numpy as np
  192. except ModuleNotFoundError:
  193. np = None # type: ignore[assignment]
  194. if TYPE_CHECKING:
  195. from collections.abc import Generator, KeysView, Sequence
  196. from sympy import Symbol
  197. from torch._C import DispatchKeySet
  198. from torch._dynamo.output_graph import OutputGraph, OutputGraphGuardsState
  199. T = TypeVar("T")
  200. log = logging.getLogger(__name__)
  201. guards_log = torch._logging.getArtifactLogger(__name__, "guards")
  202. recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles")
  203. recompiles_verbose_log = torch._logging.getArtifactLogger(
  204. __name__, "recompiles_verbose"
  205. )
  206. verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
  207. dunder_attrs_assumed_constants = (
  208. "__defaults__",
  209. "__kwdefaults__",
  210. "__code__",
  211. "__closure__",
  212. "__annotations__",
  213. "__func__",
  214. "__mro__",
  215. )
  216. class IndentedBufferWithPrefix(IndentedBuffer):
  217. def prefix(self) -> str:
  218. return "| " * (self._indent * self.tabwidth)
  219. def writeline(self, line: str, skip_prefix: bool = False) -> None: # type: ignore[override]
  220. if skip_prefix:
  221. super().writeline(line)
  222. else:
  223. super().writeline("+- " + line)
  224. class GuardManagerWrapper:
  225. """
  226. A helper class that contains the root guard manager. An instance of this
  227. class is stored in the Dynamo cache entry, so that the cache entry can
  228. access the RootGuardManager stored in the "root" attribute and directly call
  229. the check_nopybind from C++.
  230. """
  231. def __init__(self, root: Optional[RootGuardManager] = None) -> None:
  232. if root is None:
  233. self.root = RootGuardManager()
  234. else:
  235. self.root = root
  236. self.diff_guard_root: Optional[RootGuardManager] = None
  237. self.closure_vars: Optional[dict[str, Any]] = None
  238. self.args: Optional[list[str]] = None
  239. self.code_parts: list[str] = []
  240. self.verbose_code_parts: Optional[list[str]] = None
  241. self.global_scope: Optional[dict[str, Any]] = None
  242. self.guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
  243. self.cache_entry: Optional[CacheEntry] = None
  244. self.extra_state: Optional[ExtraState] = None
  245. self.id_matched_objs: dict[str, ReferenceType[object]] = {}
  246. self.no_tensor_aliasing_sources: list[str] = []
  247. self.printed_relational_guards: set[RelationalGuard] = set()
  248. self.diff_guard_sources: OrderedSet[str] = OrderedSet()
  249. @contextmanager
  250. def _preserve_printed_relational_guards(self) -> Generator[None, None, None]:
  251. self.printed_relational_guards = set()
  252. try:
  253. yield
  254. finally:
  255. self.printed_relational_guards = set()
  256. # TODO: clarify what fn and attributes guard manager has to get the right things here
  257. def collect_diff_guard_sources(self) -> OrderedSet[str]:
  258. # At the time of finalize, we have only marked guard managers with
  259. # TENSOR_MATCH guards as diff guard managers. So, we do a tree traversal
  260. # and collect all the nodes in the tree (branches) that lead to tensor
  261. # guards.
  262. # After a recompilation, some of guard managers will have a fail_count >
  263. # 0, so we collect them as well. Later on, we accumulate the diff guard
  264. # sources for all the guard managers.
  265. def visit_dict_manager(node: DictGuardManager) -> bool:
  266. is_diff_guard_node = (
  267. node.get_source() in self.diff_guard_sources or node.fail_count() > 0
  268. )
  269. for idx, (key_mgr, val_mgr) in sorted(
  270. node.get_key_value_managers().items()
  271. ):
  272. is_diff_guard_node |= visit(key_mgr) | visit(val_mgr)
  273. if is_diff_guard_node:
  274. self.diff_guard_sources.add(node.get_source())
  275. return is_diff_guard_node
  276. def visit_manager(node: GuardManager) -> bool:
  277. assert not isinstance(node, DictGuardManager)
  278. is_diff_guard_node = (
  279. node.get_source() in self.diff_guard_sources or node.fail_count() > 0
  280. )
  281. for child_mgr in node.get_child_managers():
  282. is_diff_guard_node |= visit(child_mgr)
  283. if is_diff_guard_node:
  284. self.diff_guard_sources.add(node.get_source())
  285. return is_diff_guard_node
  286. def visit(node: GuardManager) -> bool:
  287. if node is None:
  288. return False
  289. if isinstance(node, DictGuardManager):
  290. return visit_dict_manager(node)
  291. return visit_manager(node)
  292. visit(self.root)
  293. return self.diff_guard_sources
  294. def finalize(self) -> None:
  295. if config.use_recursive_dict_tags_for_guards and justknobs_check(
  296. "pytorch/compiler:use_recursive_dict_tags_for_guards"
  297. ):
  298. self.find_tag_safe_roots()
  299. self.prepare_diff_guard_manager()
  300. def prepare_diff_guard_manager(self) -> None:
  301. self.collect_diff_guard_sources()
  302. self.populate_diff_guard_manager()
  303. def find_tag_safe_roots(self) -> None:
  304. """
  305. Identify ``tag safe nodes`` and ``tag safe roots`` within a guard tree.
  306. -----------------------------------------------------------------------
  307. tag safe node
  308. -----------------------------------------------------------------------
  309. A *tag safe node* is a ``GuardManager`` whose guarded value satisfies one
  310. of the following conditions:
  311. 1. Immutable value - The value is intrinsically immutable according to
  312. ``is_immutable_object``. Tensors are considered immutable. To ensure
  313. that symbolic guards run, we also check that the GuardManager has no
  314. accessors.
  315. 2. Nested tag safe dictionary - The value is a ``dict`` whose keys and
  316. values are all tag safe nodes (checked recursively). Such dictionaries
  317. allow entire nested structures to be skipped once their identity tag
  318. matches.
  319. 3. Pure ``nn.Module`` - The value is an ``nn.Module`` whose sole
  320. accessor is ``GetGenericDictGuardAccessor``—i.e., it only exposes its
  321. ``__dict__`` and nothing else that could mutate between runs.
  322. For every tag safe node, verifying the identity/tag of just the top-level
  323. dictionary is enough to guarantee the entire subtree is unchanged, enabling
  324. a *fast-path* guard check.
  325. -----------------------------------------------------------------------
  326. tag safe root
  327. -----------------------------------------------------------------------
  328. A ``tag safe root`` is a tag safe node whose parent is not tag safe.
  329. These boundary nodes mark the points where guard evaluation can safely
  330. prune traversal: if a tag-safe root’s dictionary tag matches, the entire
  331. subtree beneath it is skipped.
  332. One strong requirement for tag safe root is for the guarded object to
  333. support weakref. Refer to more details in the Recursive dict tag
  334. matching note. In short, we need to save the weakref of the object on
  335. first invocation, and check if it is still valid in later iterations, to
  336. apply recursive dict tag optimizations. `dict` objects do NOT support
  337. weakref. Therefore, as of now, we only mark nn module related guard
  338. managers as tag safe roots.
  339. Algorithm
  340. ---------
  341. The search runs in post-order traversal
  342. 1. Visit leaves and classify them as tag safe or not.
  343. 2. Propagate tag-safety upward: a parent dictionary becomes tag safe only if
  344. all of its children are already tag-safe.
  345. 3. Propagate tag-safe-rootness upward: if the whole subtree is tag safe,
  346. the current node becomes the new tag safe root, otherwise propagate the
  347. subtree tag safe roots.
  348. 4. Collect every tag safe node and, by inspecting parent tags, label the
  349. subset that are tag safe roots.
  350. """
  351. def check_tag_safety(
  352. node: GuardManager, accepted_accessors: tuple[type[GuardAccessor], ...]
  353. ) -> bool:
  354. accessors = node.get_accessors()
  355. child_mgrs = node.get_child_managers()
  356. return all(
  357. isinstance(accessor, accepted_accessors) and mgr.is_tag_safe()
  358. for accessor, mgr in zip(accessors, child_mgrs)
  359. )
  360. def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]:
  361. # Just recurse through the key and value dict managers and check if
  362. # all of them are tag safe nodes.
  363. assert issubclass(node.get_type_of_guarded_value(), dict)
  364. tag_safe_roots = []
  365. is_subtree_tag_safe = True
  366. # Recurse to get the tag safe roots from subtree.
  367. for idx, (key_mgr, val_mgr) in sorted(
  368. node.get_key_value_managers().items()
  369. ):
  370. if key_mgr is not None:
  371. visit(key_mgr)
  372. if val_mgr is not None:
  373. tag_safe_roots.extend(visit(val_mgr))
  374. for idx, (key_mgr, val_mgr) in sorted(
  375. node.get_key_value_managers().items()
  376. ):
  377. if key_mgr:
  378. is_subtree_tag_safe &= key_mgr.is_tag_safe()
  379. if val_mgr:
  380. is_subtree_tag_safe &= val_mgr.is_tag_safe()
  381. if is_subtree_tag_safe:
  382. node.mark_tag_safe()
  383. return tag_safe_roots
  384. def visit_manager(node: GuardManager) -> list[GuardManager]:
  385. assert not isinstance(node, DictGuardManager)
  386. # Collect the subtree tag safe roots
  387. tag_safe_roots = []
  388. for child_mgr in node.get_child_managers():
  389. tag_safe_roots.extend(visit(child_mgr))
  390. if node.is_guarded_value_immutable():
  391. # If the node guards a tensor, mark it tag safe only if there
  392. # are no accessors. Presence of accessors means presence of
  393. # symbolic shape guards.
  394. if issubclass(node.get_type_of_guarded_value(), torch.Tensor):
  395. if node.has_no_accessors() and not node.has_object_aliasing_guard():
  396. node.mark_tag_safe()
  397. else:
  398. node.mark_tag_safe()
  399. elif issubclass(node.get_type_of_guarded_value(), dict):
  400. accessors = node.get_accessors()
  401. child_mgrs = node.get_child_managers()
  402. is_subtree_tag_safe = all(
  403. isinstance(accessor, DictGetItemGuardAccessor) and mgr.is_tag_safe()
  404. for accessor, mgr in zip(accessors, child_mgrs)
  405. )
  406. if is_subtree_tag_safe:
  407. node.mark_tag_safe()
  408. elif issubclass(node.get_type_of_guarded_value(), torch.nn.Module):
  409. is_subtree_tag_safe = check_tag_safety(
  410. node, (GetGenericDictGuardAccessor, TypeGuardAccessor)
  411. )
  412. if is_subtree_tag_safe:
  413. node.mark_tag_safe()
  414. # Return the current node as tag safe root, discarding the
  415. # subtree tag safe roots.
  416. return [
  417. node,
  418. ]
  419. elif (
  420. node.get_type_of_guarded_value()
  421. in (
  422. types.FunctionType,
  423. types.MethodType,
  424. staticmethod,
  425. classmethod,
  426. )
  427. and config.assume_dunder_attributes_remain_unchanged
  428. ):
  429. # Assumption: callers will not reassignthe attributes
  430. # func.__code__, func.__closure__, func.__defaults__, or func.__kwdefaults__.
  431. # Mutating the objects those attributes point to is fine;
  432. # rebinding the attribute itself is not.
  433. # Example ─ allowed: foo.__defaults__[0].bar = 99
  434. # forbidden: foo.__defaults__ = (3, 4)
  435. is_subtree_tag_safe = check_tag_safety(
  436. node,
  437. (
  438. CodeGuardAccessor,
  439. ClosureGuardAccessor,
  440. FuncDefaultsGuardAccessor,
  441. FuncKwDefaultsGuardAccessor,
  442. GetAttrGuardAccessor,
  443. ),
  444. )
  445. for accessor in node.get_accessors():
  446. if isinstance(accessor, GetAttrGuardAccessor):
  447. is_subtree_tag_safe &= (
  448. accessor.get_attr_name() in dunder_attrs_assumed_constants
  449. )
  450. if is_subtree_tag_safe:
  451. node.mark_tag_safe()
  452. elif issubclass(node.get_type_of_guarded_value(), types.CellType):
  453. is_subtree_tag_safe = check_tag_safety(node, (GetAttrGuardAccessor,))
  454. is_subtree_tag_safe &= all(
  455. isinstance(accessor, GetAttrGuardAccessor)
  456. and accessor.get_attr_name() == "cell_contents"
  457. for accessor in node.get_accessors()
  458. )
  459. if is_subtree_tag_safe:
  460. node.mark_tag_safe()
  461. elif (
  462. issubclass(node.get_type_of_guarded_value(), tuple)
  463. and node.get_source().endswith(dunder_attrs_assumed_constants)
  464. and config.assume_dunder_attributes_remain_unchanged
  465. ):
  466. # We trust tuples obtained from a function’s __closure__ or
  467. # __defaults__. Any *other* tuple-valued attribute can be
  468. # silently replaced—for example:
  469. #
  470. # foo.bar = (1, 2) # original
  471. # foo.bar = (3, 4) # rebinding that our dict-tag optimisation won’t see
  472. #
  473. # Therefore only tuples from __closure__ / __defaults__ participate in the
  474. # recursive-dict-tag optimization; all others are ignored.
  475. is_subtree_tag_safe = check_tag_safety(
  476. node, (TupleGetItemGuardAccessor,)
  477. )
  478. if is_subtree_tag_safe:
  479. node.mark_tag_safe()
  480. elif issubclass(node.get_type_of_guarded_value(), type):
  481. is_subtree_tag_safe = check_tag_safety(
  482. node, (TypeDictGuardAccessor, TypeMROGuardAccessor)
  483. )
  484. if is_subtree_tag_safe:
  485. node.mark_tag_safe()
  486. return tag_safe_roots
  487. def visit(node: GuardManager) -> list[GuardManager]:
  488. if node is None:
  489. return []
  490. if isinstance(node, DictGuardManager):
  491. return visit_dict_manager(node)
  492. return visit_manager(node)
  493. tag_safe_roots = visit(self.root)
  494. for node in tag_safe_roots:
  495. if issubclass(node.get_type_of_guarded_value(), torch.nn.Module):
  496. node.mark_tag_safe_root()
  497. def populate_diff_guard_manager(self) -> None:
  498. self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources)
  499. # Ensure that that C++ side points to the updated diff guard manager.
  500. # When a new GuardManagerWrapper is created, it does not have a
  501. # cache_entry attribute, so it relies on the CacheEntry constructor to
  502. # set the diff_guard_root in C++. But once it is saved in the Dynamo
  503. # cache, C++ side adds a cache_entry attribute. On recompiles, this
  504. # cache_entry is visible, so we update the C++ side to point to the
  505. # update guard manager.
  506. if self.cache_entry:
  507. self.cache_entry.update_diff_guard_root_manager()
  508. def clone_with_chosen_sources(
  509. self, chosen_sources: OrderedSet[str]
  510. ) -> RootGuardManager:
  511. def filter_fn(node_mgr: GuardManager) -> bool:
  512. return node_mgr.get_source() in chosen_sources
  513. return self.root.clone_manager(filter_fn)
  514. def get_guard_lines(self, guard: LeafGuard) -> list[str]:
  515. guard_name = guard.__class__.__name__
  516. parts = guard.verbose_code_parts()
  517. parts = [guard_name + ": " + part for part in parts]
  518. return parts
  519. def get_manager_line(
  520. self, guard_manager: GuardManager, accessor_str: Optional[str] = None
  521. ) -> str:
  522. source = guard_manager.get_source()
  523. t = guard_manager.__class__.__name__
  524. s = t + ": source=" + source
  525. if accessor_str:
  526. s += ", " + accessor_str
  527. s += f", type={guard_manager.get_type_of_guarded_value()}"
  528. s += f", tag_safe=({guard_manager.is_tag_safe()}, {guard_manager.is_tag_safe_root()})"
  529. return s
  530. def construct_dict_manager_string(
  531. self, mgr: DictGuardManager, body: IndentedBufferWithPrefix
  532. ) -> None:
  533. for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()):
  534. body.writeline(f"KeyValueManager pair at index={idx}")
  535. with body.indent():
  536. if key_mgr:
  537. body.writeline(f"KeyManager: {self.get_manager_line(key_mgr)}")
  538. self.construct_manager_string(key_mgr, body)
  539. if val_mgr:
  540. body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}")
  541. self.construct_manager_string(val_mgr, body)
  542. def construct_manager_string(
  543. self, mgr: GuardManager, body: IndentedBufferWithPrefix
  544. ) -> None:
  545. with body.indent():
  546. for guard in mgr.get_leaf_guards():
  547. if isinstance(guard, RelationalGuard):
  548. if guard not in self.printed_relational_guards:
  549. self.printed_relational_guards.add(guard)
  550. body.writelines(self.get_guard_lines(guard))
  551. else:
  552. body.writelines(
  553. [
  554. guard.__class__.__name__,
  555. ]
  556. )
  557. else:
  558. body.writelines(self.get_guard_lines(guard))
  559. # This works for both DictGuardManager and SubclassedDictGuardManager
  560. if isinstance(mgr, DictGuardManager):
  561. self.construct_dict_manager_string(mgr, body)
  562. # General case of GuardManager/RootGuardManager
  563. for accessor, child_mgr in zip(
  564. mgr.get_accessors(), mgr.get_child_managers()
  565. ):
  566. body.writeline(
  567. self.get_manager_line(child_mgr, f"accessed_by={accessor.repr()}")
  568. )
  569. self.construct_manager_string(child_mgr, body)
  570. def __str__(self) -> str:
  571. with self._preserve_printed_relational_guards():
  572. body = IndentedBufferWithPrefix()
  573. body.tabwidth = 1
  574. body.writeline("", skip_prefix=True)
  575. body.writeline("TREE_GUARD_MANAGER:", skip_prefix=True)
  576. body.writeline("RootGuardManager")
  577. self.construct_manager_string(self.root, body)
  578. if hasattr(self.root, "get_epilogue_lambda_guards"):
  579. for guard in self.root.get_epilogue_lambda_guards():
  580. body.writelines(self.get_guard_lines(guard))
  581. return body.getvalue()
  582. def check(self, x: Any) -> bool:
  583. # Only needed for debugging purposes.
  584. return self.root.check(x)
  585. def check_verbose(self, x: Any) -> GuardDebugInfo:
  586. # Only needed for debugging purposes.
  587. return self.root.check_verbose(x)
  588. def populate_code_parts_for_debugging(self) -> None:
  589. # This should be called when the guard manager is fully populated
  590. relational_guards_seen = set()
  591. def get_code_parts(leaf_guard: LeafGuard) -> list[str]:
  592. code_parts = []
  593. for verbose_code_part in leaf_guard.verbose_code_parts():
  594. code_part = verbose_code_part.split("#")[0].rstrip()
  595. code_parts.append(code_part)
  596. return code_parts
  597. def visit(mgr: GuardManager) -> None:
  598. nonlocal relational_guards_seen
  599. for guard in mgr.get_leaf_guards():
  600. if isinstance(guard, RelationalGuard):
  601. if guard not in relational_guards_seen:
  602. self.code_parts.extend(get_code_parts(guard))
  603. relational_guards_seen.add(guard)
  604. else:
  605. self.code_parts.extend(get_code_parts(guard))
  606. for child_mgr in mgr.get_child_managers():
  607. visit(child_mgr)
  608. visit(self.root)
  609. def from_numpy(a: Any) -> torch.Tensor:
  610. # If not numpy array, piggy back on e.g. tensor guards to check type
  611. # Re-enable torch function since we disable it on leaf guards
  612. # we need it to properly construct the tensor if a default device is set
  613. with torch.overrides._enable_torch_function():
  614. return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a
  615. # For user stack printing
  616. @functools.cache
  617. def uninteresting_files() -> set[str]:
  618. import torch._dynamo.external_utils
  619. import torch._dynamo.polyfills
  620. mods = [torch._dynamo.external_utils, torch._dynamo.polyfills]
  621. from torch._dynamo.polyfills.loader import POLYFILLED_MODULES
  622. mods.extend(POLYFILLED_MODULES)
  623. return {inspect.getfile(m) for m in mods}
  624. _CLOSURE_VARS: Optional[dict[str, object]] = None
  625. def _get_closure_vars() -> dict[str, object]:
  626. global _CLOSURE_VARS
  627. if _CLOSURE_VARS is None:
  628. _CLOSURE_VARS = {
  629. "___check_type_id": check_type_id,
  630. "___check_obj_id": check_obj_id,
  631. "___odict_getitem": collections.OrderedDict.__getitem__,
  632. "___key_to_id": key_to_id,
  633. "___dict_version": dict_version,
  634. "___dict_contains": lambda a, b: dict.__contains__(b, a),
  635. "___tuple_iterator_len": tuple_iterator_len,
  636. "___normalize_range_iter": normalize_range_iter,
  637. "___tuple_iterator_getitem": tuple_iterator_getitem,
  638. "___dataclass_fields": dataclass_fields,
  639. "___namedtuple_fields": lambda x: x._fields,
  640. "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
  641. "__math_isnan": math.isnan,
  642. "__numpy_isnan": None if np is None else np.isnan,
  643. "inf": float("inf"),
  644. "__load_module": importlib.import_module,
  645. "utils_device": torch.utils._device,
  646. "device": torch.device,
  647. "___from_numpy": from_numpy,
  648. "___as_tensor": torch._as_tensor_fullprec,
  649. "torch": torch,
  650. "inspect": inspect,
  651. }
  652. return _CLOSURE_VARS
  653. def _ast_unparse(node: ast.AST) -> str:
  654. return ast.unparse(node).replace("\n", "")
# Module-level alias for the C++ helper exposed as
# torch._C._dynamo.strip_function_call.
strip_function_call = torch._C._dynamo.strip_function_call
  656. def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str:
  657. extra = ""
  658. if guard is not None:
  659. if guard.user_stack:
  660. for fs in reversed(guard.user_stack):
  661. if fs.filename not in uninteresting_files():
  662. extra = f" # {format_frame(fs, line=True)}"
  663. if len(extra) > 1024:
  664. # For fx graphs, the line can be very long in case of
  665. # torch.stack ops, where many inputs are set to None
  666. # after the operation. This increases the size of the
  667. # guards log file. In such cases, do not print the line
  668. # contents.
  669. extra = f" # {format_frame(fs)}"
  670. break
  671. elif guard.stack:
  672. summary = guard.stack.summary()
  673. if len(summary) > 0:
  674. extra = f" # {format_frame(summary[-1])}"
  675. else:
  676. extra = " # <unknown>"
  677. return f"{code_part:<60}{extra}"
  678. def get_verbose_code_parts(
  679. code_parts: Union[str, list[str]],
  680. guard: Optional[Guard],
  681. recompile_hint: Optional[str] = None,
  682. ) -> list[str]:
  683. if not isinstance(code_parts, list):
  684. code_parts = [code_parts]
  685. verbose_code_parts = [
  686. get_verbose_code_part(code_part, guard) for code_part in code_parts
  687. ]
  688. if recompile_hint:
  689. verbose_code_parts = [
  690. f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts
  691. ]
  692. return verbose_code_parts
  693. def convert_int_to_concrete_values(dim: Any) -> Optional[int]:
  694. if dim is None:
  695. return None
  696. if not is_symbolic(dim):
  697. return dim
  698. else:
  699. assert isinstance(dim, torch.SymInt)
  700. return dim.node.maybe_as_int()
  701. def convert_to_concrete_values(size_or_stride: list[Any]) -> list[Optional[int]]:
  702. return [convert_int_to_concrete_values(dim) for dim in size_or_stride]
  703. def get_tensor_guard_code_part(
  704. value: torch.Tensor,
  705. name: str,
  706. sizes: list[Optional[int]],
  707. strides: list[Optional[int]],
  708. pytype: type,
  709. dispatch_keys: DispatchKeySet,
  710. ) -> str:
  711. dispatch_key = (
  712. dispatch_keys | torch._C._dispatch_tls_local_include_set()
  713. ) - torch._C._dispatch_tls_local_exclude_set()
  714. dtype = value.dtype
  715. device_index = value.device.index
  716. requires_grad = value.requires_grad
  717. guard_str = (
  718. f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {dtype}, "
  719. f"device={device_index}, requires_grad={requires_grad}, size={sizes}, stride={strides})"
  720. )
  721. return guard_str
  722. def get_key_index(dct: dict[Any, Any], key: Any) -> int:
  723. # Ensure that we call dict.keys and not value.keys (which can call
  724. # overridden keys method). In the C++ guards, we relied on PyDict_Next
  725. # to traverse the dictionary, which uses the internal data structure and
  726. # does not call the overridden keys method.
  727. return list(builtin_dict_keys(dct)).index(key)
  728. def get_key_index_source(source: Any, index: Any) -> str:
  729. return f"list(dict.keys({source}))[{index}]"
  730. def raise_local_type_error(obj: Any) -> NoReturn:
  731. raise TypeError(
  732. f"Type {type(obj)} for object {obj} cannot be saved "
  733. + "into torch.compile() package since it's defined in local scope. "
  734. + "Please define the class at global scope (top level of a module)."
  735. )
  736. def should_optimize_getattr_on_nn_module(value: Any) -> bool:
  737. # If inline_inbuilt_nn_modules flag is True, Dynamo has already traced
  738. # through the __getattr__, and therefore it is always safe to optimize
  739. # getattr on nn modules.
  740. return isinstance(value, torch.nn.Module) and (
  741. config.inline_inbuilt_nn_modules
  742. or get_custom_getattr(value) is unpatched_nn_module_getattr
  743. )
  744. @dataclasses.dataclass(frozen=True)
  745. class NNModuleAttrAccessorInfo:
  746. # Represents where is the attr name is present in the nn module attribute
  747. # access
  748. # Tells that the attribute can be accessed via __dict__
  749. present_in_generic_dict: bool = False
  750. # Either the actual name or _parameters/_buffers/_modules
  751. l1_key: Optional[str] = None
  752. # Actual parameter/buffer/submodule name
  753. l2_key: Optional[str] = None
  754. def getitem_on_dict_manager(
  755. source: Union[DictGetItemSource, DictSubclassGetItemSource],
  756. base_guard_manager: DictGuardManager,
  757. base_example_value: Any,
  758. example_value: Any,
  759. guard_manager_enum: GuardManagerType,
  760. ) -> GuardManager:
  761. base_source_name = source.base.name()
  762. if isinstance(source.index, ConstDictKeySource):
  763. index = source.index.index
  764. else:
  765. assert isinstance(base_example_value, dict)
  766. index = get_key_index(base_example_value, source.index)
  767. key_source = get_key_index_source(base_source_name, index)
  768. # Ensure that we call dict.keys and not value.keys (which can call
  769. # overridden keys method). In the C++ guards, we relied on PyDict_Next
  770. # to traverse the dictionary, which uses the internal data structure and
  771. # does not call the overridden keys method.
  772. key_example_value = list(builtin_dict_keys(base_example_value))[index]
  773. if isinstance(key_example_value, (int, str)):
  774. value_source = f"{base_source_name}[{key_example_value!r}]"
  775. else:
  776. value_source = f"{base_source_name}[{key_source}]"
  777. if not isinstance(source.index, ConstDictKeySource):
  778. # We have to insert a key manager guard here
  779. # TODO - source debug string is probably wrong here.
  780. base_guard_manager.get_key_manager(
  781. index=index,
  782. source=key_source,
  783. example_value=source.index,
  784. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  785. ).add_equals_match_guard(
  786. source.index, [f"{key_source} == {key_example_value!r}"]
  787. )
  788. return base_guard_manager.get_value_manager(
  789. index=index,
  790. source=value_source,
  791. example_value=example_value,
  792. guard_manager_enum=guard_manager_enum,
  793. )
  794. def match_on_id_for_tensor(guard: Guard) -> bool:
  795. source = guard.originating_source
  796. # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads
  797. # to a new tensor every time and therefore id differs.
  798. if isinstance(source, NumpyTensorSource):
  799. return False
  800. if guard.is_specialized_nn_module():
  801. return True
  802. return source.is_dict_key() and not isinstance(source, GradSource)
  803. # The ready to eval generated code (possibly multiple parts) for a guard, plus
  804. # the original guard object that created it for provenance
  805. @dataclasses.dataclass
  806. class GuardCodeList:
  807. code_list: list[str]
  808. guard: Guard
  809. class GuardManagerType(enum.Enum):
  810. GUARD_MANAGER = 1
  811. DICT_GUARD_MANAGER = 2
  812. @functools.cache
  813. def code_framelocals_names_reversed_cached(code: types.CodeType) -> list[str]:
  814. return list(reversed(code_framelocals_names(code)))
  815. class GuardBuilder(GuardBuilderBase):
    def __init__(
        self,
        f_code: types.CodeType,
        id_ref: Callable[[object, str], int],
        source_ref: Callable[[Source], str],
        lookup_weakrefs: Callable[[object], Optional[weakref.ref[object]]],
        local_scope: dict[str, object],
        global_scope: dict[str, object],
        guard_manager: GuardManagerWrapper,
        check_fn_manager: CheckFunctionManager,
        save_guards: bool = False,
        runtime_global_scope: Optional[dict[str, object]] = None,
    ) -> None:
        """Set up the state used while building guard managers for ``f_code``.

        Args:
            f_code: the code object whose frame is being guarded.
            id_ref: returns an integer id for an object given its source
                string. It is used to build ID_MATCH guards.
            source_ref: callable mapping a Source to a string (stored only;
                usage is not visible in this chunk).
            lookup_weakrefs: callable returning an optional weakref for an
                object (stored only; usage is not visible in this chunk).
            local_scope: frame locals, exposed as scope "L".
            global_scope: frame globals, exposed as scope "G".
            guard_manager: wrapper whose ``root`` receives the guard managers.
            check_fn_manager: owner whose ``output_graph.guard_on_key_order``
                lists the sources that need dict key-order guards.
            save_guards: stored as ``self.save_guards``.
            runtime_global_scope: globals dict handed to the globals guard
                manager. It falls back to ``global_scope``.
        """
        self.f_code = f_code
        self.id_ref = id_ref
        self.source_ref = source_ref
        self.lookup_weakrefs = lookup_weakrefs
        # Name scopes: "L" (frame locals), "G" (frame globals) and, below,
        # "__builtins__" plus any torch.package modules.
        self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope}
        # NOTE(review): `or` also replaces an *empty* runtime_global_scope with
        # global_scope; confirm that is intended.
        self.runtime_global_scope = runtime_global_scope or global_scope
        self.scope["__builtins__"] = builtins.__dict__.copy()
        # Expose modules imported via torch.package under identifier-safe names
        # ("<"/">" -> "_", "." -> "_dot_").
        for (
            name,
            package_module,
        ) in torch.package.package_importer._package_imported_modules.items():
            name = name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
            # Write the package module into the scope so that we can import it
            self.scope["__builtins__"][name] = package_module
            # Write the demangled name to the scope so that we can use it
            self.scope[name] = package_module
        self.guard_manager = guard_manager
        self.argnames: list[str] = []
        # Code is python expression strings generated for each guard
        self.code: list[GuardCodeList] = []
        # shape_env_code is only used by builder and is used for
        # shape env code. This exists only because we need to make sure
        # shape env guards get run after tensor match guards (since the
        # tensor match guards make sure we actually have tensors)
        self.shape_env_code: list[GuardCodeList] = []
        # Collect the guard managers and debug info to insert no tensor aliasing
        # guards.
        self.no_tensor_aliasing_names: list[str] = []
        self.no_tensor_aliasing_guard_managers: list[GuardManager] = []
        self.check_fn_manager: CheckFunctionManager = check_fn_manager
        # Collect the ids of dicts which need key order guarding. source_name is
        # not sufficient because for nn modules, we can have different sources
        # to access the same object - self._module["param"] is same as
        # self.param.
        self.key_order_guarded_dict_ids: set[int] = set()
        assert self.check_fn_manager.output_graph is not None
        for source in self.check_fn_manager.output_graph.guard_on_key_order:
            self.key_order_guarded_dict_ids.add(id(self.get(source.name())))
        # Keep track of weak references of objects with ID_MATCH guard. This
        # info is stored alongside optimized_code and guard_manager and is used to
        # limit the number of cache entries with same ID_MATCH'd object.
        self.id_matched_objs: dict[str, ReferenceType[object]] = {}
        # Save the guard managers to avoid repeatedly traversing sources.
        self._cached_guard_managers: dict[str, GuardManager] = {}
        self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
        self.object_aliasing_guard_codes: list[tuple[str, str]] = []
        self.save_guards = save_guards
        # nn.Module guarding requires both the config flag and the justknobs
        # killswitch.
        self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
            "pytorch/compiler:guard_nn_modules"
        )
        self.already_guarded_not_present_in_generic_dict: OrderedSet[
            tuple[str, str]
        ] = OrderedSet()
  882. def guard_on_dict_keys_and_ignore_order(
  883. self, example_value: dict[Any, Any], guard: Guard
  884. ) -> None:
  885. dict_mgr = self.get_guard_manager(guard)
  886. if isinstance(dict_mgr, DictGuardManager):
  887. raise NotImplementedError(
  888. "Not expecting a DictGuardManager. Seems like Dynamo incorrectly "
  889. f"added the dict to tx.output.guard_on_key_order for {guard.name}"
  890. )
  891. # Iterate over the dicts and install a dict_getitem_manager.
  892. dict_source = guard.originating_source.name()
  893. # Ensure that we call dict.keys and not value.keys (which can call
  894. # overridden keys method). In the C++ guards, we relied on PyDict_Next
  895. # to traverse the dictionary, which uses the internal data structure and
  896. # does not call the overridden keys method.
  897. for key in builtin_dict_keys(example_value):
  898. value = example_value[key]
  899. value_source = DictGetItemSource(guard.originating_source, index=key)
  900. guard_manager_enum = self.get_guard_manager_type(
  901. value_source, example_value
  902. )
  903. dict_mgr.dict_getitem_manager(
  904. key=key,
  905. source=f"{dict_source}[{key!r}]",
  906. example_value=value,
  907. guard_manager_enum=guard_manager_enum,
  908. )
  909. def guard_on_dict_keys_and_order(self, value: dict[Any, Any], guard: Guard) -> None:
  910. # Add key managers for the DictGuardManager. Then add either an
  911. # ID_MATCH or EQUALS_MATCH guard on the key.
  912. dict_mgr = self.get_guard_manager(guard)
  913. if not isinstance(dict_mgr, DictGuardManager):
  914. raise NotImplementedError(
  915. "Expecting a DictGuardManager. Seems like Dynamo forgot "
  916. f"to set the right guard manager enum for {guard.name}"
  917. )
  918. assert isinstance(dict_mgr, DictGuardManager)
  919. # Ensure that we call dict.keys and not value.keys (which can call
  920. # overridden keys method). In the C++ guards, we relied on PyDict_Next
  921. # to traverse the dictionary, which uses the internal data structure and
  922. # does not call the overridden keys method.
  923. for idx, key in enumerate(builtin_dict_keys(value)):
  924. key_source = get_key_index_source(guard.name, idx)
  925. key_manager = dict_mgr.get_key_manager(
  926. index=idx,
  927. source=key_source,
  928. example_value=key,
  929. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  930. )
  931. if key_is_id(key):
  932. # Install ID_MATCH guard
  933. id_val = self.id_ref(key, key_source)
  934. key_manager.add_id_match_guard(
  935. id_val,
  936. get_verbose_code_parts(
  937. f"__check_obj_id({key_source}, {id_val})", guard
  938. ),
  939. )
  940. else:
  941. # Install EQUALS_MATCH guard
  942. key_manager.add_equals_match_guard(
  943. key, get_verbose_code_parts(f"{key_source} == {key!r}", guard)
  944. )
  945. @staticmethod
  946. def _get_generic_dict_manager_example_value(example_value: Any) -> Optional[Any]:
  947. # due to a bug in 3.13.0 (introduced by https://github.com/python/cpython/pull/116115,
  948. # reported in https://github.com/python/cpython/issues/125608,
  949. # fixed by https://github.com/python/cpython/pull/125611), we cannot take
  950. # advantage of __dict__ versions to speed up guard checks.
  951. if (
  952. config.issue_3_13_0_warning
  953. and sys.version_info >= (3, 13)
  954. and sys.version_info < (3, 13, 1)
  955. ):
  956. warnings.warn(
  957. "Guards may run slower on Python 3.13.0. Consider upgrading to Python 3.13.1+.",
  958. RuntimeWarning,
  959. )
  960. return None
  961. return example_value
  962. def getattr_on_nn_module(
  963. self,
  964. source: AttrSource,
  965. base_guard_manager: GuardManager,
  966. base_example_value: Any,
  967. example_value: Any,
  968. base_source_name: str,
  969. source_name: str,
  970. guard_manager_enum: GuardManagerType,
  971. ) -> GuardManager:
  972. """
  973. This tries to avoid calling the expensive nn module custom getattr method by
  974. checking if the attribute is accessible via __dict__. For attributes that
  975. are not accessible via __dict__ (like descriptors), we fallback to
  976. PyObject_GetAttr.
  977. There are two cases that we optimize for
  978. 1) attributes present directly in __dict__, e.g training.
  979. 2) parameters/buffers/modules - they can be accessed via _parameters,
  980. _buffers, _modules keys in __dict__. For example, mod.linear can be
  981. accessed as mod.__dict__["_parameters"]["linear"]
  982. The most common and expensive case for nn module guards is of type
  983. mod.submod1.submod2.submod3.training. We avoid the python getattr of nn
  984. modules by going through the __dict__.
  985. """
  986. def getitem_on_dict_mgr(
  987. mgr: GuardManager,
  988. key: Any,
  989. source_name: str,
  990. base_example_value: Any,
  991. example_value: Any,
  992. guard_manager_enum: GuardManagerType,
  993. ) -> GuardManager:
  994. if isinstance(mgr, DictGuardManager):
  995. # Case where the user code relies on key order, e.g.,
  996. # named_parameters
  997. index = get_key_index(base_example_value, key)
  998. # Install the key manager and add equals match guard
  999. key_source = f"list(dict.keys({source_name}))[{index!r}]"
  1000. mgr.get_key_manager(
  1001. index=index,
  1002. source=key_source,
  1003. example_value=key,
  1004. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  1005. ).add_equals_match_guard(key, [f"{key_source} == {key!r}"])
  1006. # Install the value manager
  1007. return mgr.get_value_manager(
  1008. index=index,
  1009. source=source_name,
  1010. example_value=example_value,
  1011. guard_manager_enum=guard_manager_enum,
  1012. )
  1013. else:
  1014. return mgr.dict_getitem_manager(
  1015. key=key,
  1016. source=source_name,
  1017. example_value=example_value,
  1018. guard_manager_enum=guard_manager_enum,
  1019. )
  1020. attr_name = source.member
  1021. mod_dict = base_example_value.__dict__
  1022. all_class_attribute_names: set[str] = set()
  1023. for x in inspect.getmro(base_example_value.__class__):
  1024. all_class_attribute_names.update(x.__dict__.keys())
  1025. accessor_info = NNModuleAttrAccessorInfo(False, None, None)
  1026. if attr_name in mod_dict:
  1027. accessor_info = NNModuleAttrAccessorInfo(True, attr_name, None)
  1028. elif "_parameters" in mod_dict and attr_name in mod_dict["_parameters"]:
  1029. accessor_info = NNModuleAttrAccessorInfo(True, "_parameters", attr_name)
  1030. elif "_buffers" in mod_dict and attr_name in mod_dict["_buffers"]:
  1031. accessor_info = NNModuleAttrAccessorInfo(True, "_buffers", attr_name)
  1032. elif (
  1033. attr_name not in all_class_attribute_names
  1034. and "_modules" in mod_dict
  1035. and attr_name in mod_dict["_modules"]
  1036. ):
  1037. # Check test_attr_precedence test - instance attributes always take precedence unless its an nn.Module.
  1038. accessor_info = NNModuleAttrAccessorInfo(True, "_modules", attr_name)
  1039. if not accessor_info.present_in_generic_dict:
  1040. # The attribute can be accessed by __getattribute__ call, so rely on
  1041. # PyObject_GetAttr
  1042. return base_guard_manager.getattr_manager(
  1043. attr=source.member,
  1044. source=source_name,
  1045. example_value=example_value,
  1046. guard_manager_enum=guard_manager_enum,
  1047. )
  1048. else:
  1049. assert accessor_info.l1_key
  1050. l1_key = accessor_info.l1_key
  1051. l2_key = accessor_info.l2_key
  1052. # Set source strings for debug info
  1053. mod_dict_source = f"{base_source_name}.__dict__"
  1054. l1_source_name = l2_source_name = None
  1055. l1_value = l2_value = None
  1056. l1_guard_manager_enum = l2_guard_manager_enum = None
  1057. if l2_key:
  1058. l1_source = AttrSource(source.base, l1_key)
  1059. l1_source_name = l1_source.name()
  1060. l1_value = mod_dict[l1_key]
  1061. # do not guard on key order for _parameters etc unless the user code
  1062. # actually needs the key order (e.g. calling named_parameters)
  1063. l1_guard_manager_enum = self.get_guard_manager_type(l1_source, l1_value)
  1064. l2_source_name = source_name
  1065. l2_value = example_value
  1066. l2_guard_manager_enum = self.get_guard_manager_type(
  1067. source, example_value
  1068. )
  1069. else:
  1070. l1_source_name = source_name
  1071. l1_value = example_value
  1072. l1_guard_manager_enum = self.get_guard_manager_type(
  1073. source, example_value
  1074. )
  1075. # Get __dict__ accessor. No need to guard on dict key order, so use base
  1076. # Guard Manager
  1077. mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager(
  1078. source=mod_dict_source,
  1079. example_value=self._get_generic_dict_manager_example_value(mod_dict),
  1080. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  1081. )
  1082. l1_mgr = getitem_on_dict_mgr(
  1083. mgr=mod_generic_dict_manager,
  1084. key=l1_key,
  1085. source_name=l1_source_name,
  1086. base_example_value=mod_dict,
  1087. example_value=l1_value,
  1088. guard_manager_enum=l1_guard_manager_enum,
  1089. )
  1090. if l2_key:
  1091. assert l2_source_name is not None and l2_guard_manager_enum is not None
  1092. return getitem_on_dict_mgr(
  1093. mgr=l1_mgr,
  1094. key=l2_key,
  1095. source_name=l2_source_name,
  1096. base_example_value=l1_value,
  1097. example_value=l2_value,
  1098. guard_manager_enum=l2_guard_manager_enum,
  1099. )
  1100. return l1_mgr
  1101. def requires_key_order_guarding(self, source: Source) -> bool:
  1102. source_name = source.name()
  1103. if source_name == "":
  1104. return False
  1105. obj_id = id(self.get(source_name))
  1106. return obj_id in self.key_order_guarded_dict_ids
  1107. def get_guard_manager_type(
  1108. self,
  1109. source: Source,
  1110. example_value: Optional[
  1111. Union[KeysView[Any], set[Any], frozenset[Any], dict[Any, Any]]
  1112. ],
  1113. ) -> GuardManagerType:
  1114. guard_manager_enum = GuardManagerType.GUARD_MANAGER
  1115. if self.requires_key_order_guarding(source):
  1116. # Fix this if condition
  1117. if isinstance(example_value, dict_keys):
  1118. guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER
  1119. elif isinstance(example_value, (set, frozenset)):
  1120. # we don't need to guard on key order for set/frozenset
  1121. # but the if above will be true for these types as set is
  1122. # implemented using a dict in Dynamo
  1123. guard_manager_enum = GuardManagerType.GUARD_MANAGER
  1124. else:
  1125. assert isinstance(example_value, dict)
  1126. guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER
  1127. return guard_manager_enum
  1128. def manager_guards_on_keys(self, mgr_enum: GuardManagerType) -> bool:
  1129. return mgr_enum == GuardManagerType.DICT_GUARD_MANAGER
  1130. def get_global_guard_manager(self) -> GuardManager:
  1131. return self.guard_manager.root.globals_dict_manager(
  1132. f_globals=self.runtime_global_scope,
  1133. source="G",
  1134. example_value=self.scope["G"],
  1135. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  1136. )
def get_guard_manager_from_source(self, source: Source) -> GuardManager:
    """Return (building it if necessary) the guard manager for ``source``.

    For a ``ChainedSource`` the manager of ``source.base`` is obtained first,
    recursively, and the accessor for ``source`` is attached to it. The
    accessor kind is selected by the exact type of ``source`` (``istype``).
    The one exception is ``ConstDictKeySource``, which is matched with
    ``isinstance``. Built managers are memoized in
    ``self._cached_guard_managers`` keyed by ``source.name()``.
    ``GlobalStateSource`` and ``ShapeEnvSource`` return the root manager
    directly, without caching.

    Raises:
        RuntimeError: for a dict getitem on a non-``DictGuardManager`` base
            whose index is a ``ConstDictKeySource``.
        AssertionError: for an unsupported source type, or a
            ``ConstDictKeySource`` whose base is not a ``DictGuardManager``.
    """
    root_guard_manager = self.guard_manager.root

    example_value = None
    source_name = source.name()

    # Reuse the manager already built for this (non-empty) source name.
    if source_name != "" and source_name in self._cached_guard_managers:
        return self._cached_guard_managers[source_name]

    # Only named sources can be evaluated; otherwise example_value stays None.
    if source_name != "":
        example_value = self.get(source_name)

    guard_manager_enum = self.get_guard_manager_type(source, example_value)

    # Get base manager related information
    base_source_name = None
    base_example_value = None
    base_guard_manager = None
    base_guard_manager_enum = GuardManagerType.GUARD_MANAGER
    if isinstance(source, ChainedSource):
        base_source_name = source.base.name()
        base_example_value = self.get(base_source_name)
        # Recursively build (or fetch from the cache) the parent's manager.
        base_guard_manager = self.get_guard_manager_from_source(source.base)
        base_guard_manager_enum = self.get_guard_manager_type(
            source.base, base_example_value
        )

    # Use istype instead of isinstance to check for exact type of source.
    if istype(source, LocalSource):
        # Refer to index in the frame's localsplus directly.
        # NOTE: name order for a code object doesn't change.
        # NOTE: we need to find the LAST matching index because <= 3.10 contains
        # duplicate names in the case of cells: a name can be both local and cell
        # and will take up 2 slots of the frame's localsplus. The correct behavior
        # is to refer to the cell, which has a higher index.
        framelocals_names_reversed = code_framelocals_names_reversed_cached(
            self.f_code
        )
        # index() on the reversed names finds the last occurrence; convert
        # that position back into a forward index.
        framelocals_idx = (
            len(framelocals_names_reversed)
            - framelocals_names_reversed.index(source.local_name)
            - 1
        )
        out = root_guard_manager.framelocals_manager(
            key=(source.local_name, framelocals_idx),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GlobalSource):
        # Global manager accepts a dict but it is not a DictGuardManager
        # because globals dict is big and we typically guard on a very
        # selected items on globals.
        out = self.get_global_guard_manager().dict_getitem_manager(
            key=source.global_name,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GlobalWeakRefSource):
        out = self.get_global_guard_manager().global_weakref_manager(
            global_name=source.global_name,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GlobalStateSource):
        # Don't do anything here. We guard on global state completely in
        # C++. So just return the root mgr.
        return root_guard_manager
    elif istype(source, ShapeEnvSource):
        # No accessor of its own: return the root manager (not cached).
        return root_guard_manager
    elif istype(source, TypeSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.type_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, TypeDictSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.type_dict_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, TypeMROSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.type_mro_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(
        source,
        (
            OptimizerSource,
            NNModuleSource,
            UnspecializedNNModuleSource,
            UnspecializedBuiltinNNModuleSource,
            FSDPNNModuleSource,
        ),
    ):
        assert base_guard_manager  # to make mypy happy
        # These wrapper sources add no accessor; they share the base manager.
        out = base_guard_manager
    elif istype(source, TorchSource):
        # Accessor that always yields the ``torch`` module itself.
        out = root_guard_manager.lambda_manager(
            python_lambda=lambda _: torch,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, TorchFunctionModeStackSource):
        out = root_guard_manager.lambda_manager(
            python_lambda=lambda _: get_torch_function_mode_stack_at(
                source._get_index()
            ),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GradSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.grad_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GenericAttrSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.generic_getattr_manager(
            attr=source.member,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, (AttrSource, UnspecializedParamBufferSource)):
        assert base_guard_manager  # to make mypy happy
        assert isinstance(source, AttrSource)
        # Some nn.Module attribute accesses are routed through the
        # getattr_on_nn_module helper instead of a plain getattr accessor.
        if should_optimize_getattr_on_nn_module(base_example_value):
            assert base_source_name
            out = self.getattr_on_nn_module(
                source,
                base_guard_manager,
                base_example_value,
                example_value,
                base_source_name,
                source_name,
                guard_manager_enum,
            )
        else:
            out = base_guard_manager.getattr_manager(
                attr=source.member,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
    elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)):
        assert base_guard_manager  # to make mypy happy
        assert isinstance(base_example_value, (dict, collections.OrderedDict))
        assert isinstance(source, (DictGetItemSource, DictSubclassGetItemSource))
        if isinstance(base_guard_manager, DictGuardManager):
            # Key-order guarded base dict: delegate to the module-level
            # getitem_on_dict_manager helper.
            assert self.manager_guards_on_keys(base_guard_manager_enum)
            out = getitem_on_dict_manager(
                source,
                base_guard_manager,
                base_example_value,
                example_value,
                guard_manager_enum,
            )
        else:
            # Without key-order guarding the index must be a concrete key,
            # not a positional ConstDictKeySource.
            if isinstance(source.index, ConstDictKeySource):
                raise RuntimeError(
                    "Expecting clean index here. Likely Dynamo forgot to mark"
                    " a dict as guard_on_key_order"
                )
            out = base_guard_manager.dict_getitem_manager(
                key=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
    elif istype(source, TensorPropertySource):
        # Calls tensor_property_<name>_manager on the base manager, where
        # <name> is the lower-cased ``source.prop.name``.
        # NOTE(review): unlike sibling branches there is no
        # ``assert base_guard_manager`` here.
        out = getattr(
            base_guard_manager,
            f"tensor_property_{source.prop.name.lower()}_manager",
        )(
            idx=source.idx,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, IndexedSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.indexed_manager(
            idx=source.idx,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, ListGetItemSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.list_getitem_manager(
            key=source.index,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, GetItemSource):
        assert base_guard_manager  # to make mypy happy
        assert not isinstance(
            base_example_value, (dict, collections.OrderedDict)
        ), "Use DictGetItemSource"
        # Specialized list/tuple accessors are used only for non-slice
        # indices; everything else goes through the generic getitem manager.
        if isinstance(base_example_value, list) and not source.index_is_slice:
            out = base_guard_manager.list_getitem_manager(
                key=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif isinstance(base_example_value, tuple) and not source.index_is_slice:
            out = base_guard_manager.tuple_getitem_manager(
                key=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        else:
            index = source.index
            if source.index_is_slice:
                index = source.unpack_slice()
            out = base_guard_manager.getitem_manager(
                key=index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
    elif istype(source, DefaultsSource):
        assert base_guard_manager  # to make mypy happy
        assert base_source_name
        assert callable(base_example_value)
        if not source.is_kw:
            # Positional defaults: index into the function's __defaults__.
            out = base_guard_manager.func_defaults_manager(
                source=base_source_name,
                example_value=base_example_value.__defaults__,
                guard_manager_enum=GuardManagerType.GUARD_MANAGER,
            ).getitem_manager(
                key=source.idx_key,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        else:
            # kwdefaults is a dict, but its key order is not guarded, so a
            # plain GUARD_MANAGER (not a DictGuardManager) is requested.
            kwdefaults = base_example_value.__kwdefaults__
            assert base_source_name is not None
            kw_source = base_source_name + ".__kwdefaults__"

            # kwdefaults is a dict. No need to guard on dict order.
            dict_mgr = base_guard_manager.func_kwdefaults_manager(
                source=kw_source,
                example_value=kwdefaults,
                guard_manager_enum=GuardManagerType.GUARD_MANAGER,
            )
            assert not isinstance(dict_mgr, DictGuardManager)

            out = dict_mgr.dict_getitem_manager(
                key=source.idx_key,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
    elif istype(source, NumpyTensorSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.lambda_manager(
            python_lambda=from_numpy,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, SubclassAttrListSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x.__tensor_flatten__()[0],
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, FlattenScriptObjectSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x.__obj_flatten__(),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, ScriptObjectQualifiedNameSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x._type().qualified_name(),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, AttrProxySource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x.get_base(),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, CallMethodItemSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x.item(),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, FloatTensorSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: torch._as_tensor_fullprec(x),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, TupleIteratorGetItemSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.tuple_iterator_getitem_manager(
            index=source.index,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif isinstance(source, ConstDictKeySource):
        # Positional dict-key access needs a key-order guarded base manager.
        if not isinstance(base_guard_manager, DictGuardManager):
            raise AssertionError(
                "ConstDictKeySource can only work on DictGuardManager"
            )
        out = base_guard_manager.get_key_manager(
            index=source.index,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, NonSerializableSetGetItemSource):
        assert base_guard_manager
        out = base_guard_manager.set_getitem_manager(
            index=source.index,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, WeakRefCallSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.weakref_call_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, CallFunctionNoArgsSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.call_function_no_args_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, DataclassFieldsSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: dataclass_fields(x),
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, NamedTupleFieldsSource):
        assert base_guard_manager
        out = base_guard_manager.lambda_manager(
            python_lambda=lambda x: x._fields,
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, CodeSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.code_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    elif istype(source, ClosureSource):
        assert base_guard_manager  # to make mypy happy
        out = base_guard_manager.closure_manager(
            source=source_name,
            example_value=example_value,
            guard_manager_enum=guard_manager_enum,
        )
    else:
        raise AssertionError(
            f"missing guard manager builder {source} - {source.name()}"
        )

    # Memoize by source name. NOTE(review): an empty name is stored here too,
    # although the lookup at the top never reads the "" key.
    self._cached_guard_managers[source.name()] = out
    return out
  1534. def get_guard_manager(self, guard: Guard) -> GuardManager:
  1535. return self.get_guard_manager_from_source(guard.originating_source)
  1536. def add_python_lambda_leaf_guard_to_root(
  1537. self,
  1538. code_parts: list[str],
  1539. verbose_code_parts: list[str],
  1540. closure_vars: Optional[dict[str, object]] = None,
  1541. is_epilogue: bool = True,
  1542. ) -> None:
  1543. if closure_vars is None:
  1544. closure_vars = _get_closure_vars()
  1545. # Adds a lambda leaf guard to the root guard manager. It wraps the
  1546. # code_parts in a function object which is then passed on to the leaf
  1547. # guard.
  1548. make_guard_fn_args = ", ".join(closure_vars.keys())
  1549. _guard_body, pycode = build_guard_function(code_parts, make_guard_fn_args)
  1550. out: dict[str, Any] = {}
  1551. globals_for_guard_fn = {"G": self.scope["G"]}
  1552. guards_log.debug("Python shape guard function:\n%s", pycode)
  1553. exec(pycode, globals_for_guard_fn, out)
  1554. guard_fn = out["___make_guard_fn"](*closure_vars.values())
  1555. if is_epilogue:
  1556. # Epilogue guards are run after all the other guards have finished.
  1557. # If epilogue guards contain a getattr or getitem access, one of the
  1558. # other guards would fail preventing the epilogue guards to run.
  1559. self.guard_manager.root.add_epilogue_lambda_guard(
  1560. guard_fn, verbose_code_parts
  1561. )
  1562. else:
  1563. self.guard_manager.root.add_lambda_guard(guard_fn, verbose_code_parts)
  1564. # Warning: use this with care! This lets you access what the current
  1565. # value of the value you are guarding on is. You probably don't want
  1566. # to actually durably save this value though (because it's specific
  1567. # to this frame!) Instead, you should be reading out some property
  1568. # (like its type) which is what you permanently install into the
  1569. # guard code.
  1570. def get(self, name: str, closure_vars: Optional[dict[str, Any]] = None) -> Any:
  1571. if closure_vars is None:
  1572. closure_vars = _get_closure_vars()
  1573. return eval(name, self.scope, closure_vars)
  1574. # Registers the usage of the source name referenced by the
  1575. # string (or stored in the Guard) as being guarded upon. It's important
  1576. # to call this before generating some code that makes use of 'guard',
  1577. # because without this call, we won't actually bind the variable
  1578. # you reference in the actual guard closure (oops!)
  1579. def arg_ref(self, guard: Union[str, Guard]) -> str:
  1580. name: str
  1581. if isinstance(guard, str):
  1582. name = guard
  1583. else:
  1584. name = guard.name
  1585. base = strip_function_call(name)
  1586. if base not in self.argnames:
  1587. is_valid = torch._C._dynamo.is_valid_var_name(base)
  1588. if is_valid:
  1589. if is_valid == 2:
  1590. log.warning("invalid var name: %s", guard)
  1591. self.argnames.append(base)
  1592. return name
  1593. def _guard_on_attribute(
  1594. self,
  1595. guard: Guard,
  1596. attr_name: str,
  1597. guard_fn: Callable[[GuardBuilderBase, Guard], Any],
  1598. ) -> None:
  1599. if attr_name == "__code__":
  1600. attr_source = CodeSource(guard.originating_source)
  1601. else:
  1602. attr_source = AttrSource(guard.originating_source, attr_name) # type: ignore[assignment]
  1603. # Copy the stack info
  1604. new_guard = Guard(
  1605. attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack
  1606. )
  1607. new_guard.create(self)
  1608. # Note: the order of the guards in this file matters since we sort guards on the same object by lineno
  1609. def HASATTR(self, guard: Guard) -> None:
  1610. source = guard.originating_source
  1611. if isinstance(source, NNModuleSource):
  1612. source = source.base
  1613. if isinstance(source, CodeSource):
  1614. # No need to guard that a function has a __code__ attribute
  1615. return
  1616. assert isinstance(source, AttrSource), f"invalid source {guard.name}"
  1617. base_source = source.base
  1618. base = base_source.name()
  1619. attr = source.member
  1620. ref = self.arg_ref(base)
  1621. val = hasattr(self.get(base), attr)
  1622. code = None
  1623. if val:
  1624. code = f"hasattr({ref}, {attr!r})"
  1625. else:
  1626. code = f"not hasattr({ref}, {attr!r})"
  1627. self._set_guard_export_info(
  1628. guard, [code], provided_guarded_object=self.get(base)
  1629. )
  1630. base_manager = self.get_guard_manager_from_source(base_source)
  1631. if val:
  1632. # Just install a getattr manager. GetAttrGuardAccessor itself
  1633. # acts as hasattr guard.
  1634. example_value = self.get(source.name())
  1635. base_example_value = self.get(base)
  1636. guard_manager_enum = self.get_guard_manager_type(source, example_value)
  1637. # if the base value is nn.Module, check if we can speedup the
  1638. # guard by going through __dict__ attrs.
  1639. if should_optimize_getattr_on_nn_module(base_example_value):
  1640. self.getattr_on_nn_module(
  1641. source,
  1642. base_manager,
  1643. base_example_value,
  1644. example_value,
  1645. base,
  1646. source.name(),
  1647. guard_manager_enum,
  1648. )
  1649. else:
  1650. base_manager.getattr_manager(
  1651. attr=attr,
  1652. source=guard.name,
  1653. example_value=example_value,
  1654. guard_manager_enum=guard_manager_enum,
  1655. )
  1656. else:
  1657. base_manager.add_no_hasattr_guard(attr, get_verbose_code_parts(code, guard))
  1658. def NOT_PRESENT_IN_GENERIC_DICT(
  1659. self, guard: Guard, attr: Optional[Any] = None
  1660. ) -> None:
  1661. assert attr is not None
  1662. ref = self.arg_ref(guard)
  1663. val = self.get(guard.name)
  1664. base_manager = self.get_guard_manager(guard)
  1665. if (ref, attr) in self.already_guarded_not_present_in_generic_dict:
  1666. return
  1667. mod_dict_source = f"{guard.name}.__dict__"
  1668. mod_generic_dict_manager = base_manager.get_generic_dict_manager(
  1669. source=mod_dict_source,
  1670. example_value=self._get_generic_dict_manager_example_value(val.__dict__),
  1671. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  1672. )
  1673. code = f"not ___dict_contains({attr!r}, {ref}.__dict__)"
  1674. mod_generic_dict_manager.add_dict_contains_guard(
  1675. False, attr, get_verbose_code_parts(code, guard)
  1676. )
  1677. self.already_guarded_not_present_in_generic_dict.add((ref, attr))
  1678. def TYPE_MATCH(self, guard: Guard) -> None:
  1679. # ___check_type_id is same as `id(type(x)) == y`
  1680. value = self.get(guard.name)
  1681. if isinstance(value, torch._subclasses.FakeTensor) and value.pytype:
  1682. t = value.pytype
  1683. else:
  1684. t = type(value)
  1685. if t.__qualname__ != t.__name__:
  1686. # Type match guards must be local scope, this is
  1687. # raised in self.serialize_guards
  1688. guard._unserializable = True
  1689. obj_id = self.id_ref(t, f"type({guard.name})")
  1690. code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
  1691. self._set_guard_export_info(guard, [code])
  1692. self.get_guard_manager(guard).add_type_match_guard(
  1693. obj_id, get_verbose_code_parts(code, guard)
  1694. )
  1695. def DICT_VERSION(self, guard: Guard) -> None:
  1696. # ___check_dict_version is same as `dict_version(x) == y`
  1697. ref = self.arg_ref(guard)
  1698. val = self.get(guard.name)
  1699. version = dict_version(self.get(guard.name))
  1700. code = f"___dict_version({ref}) == {version}"
  1701. self._set_guard_export_info(guard, [code])
  1702. # TODO(anijain2305) - Delete this when DictGuardManager uses tags
  1703. # for dicts.
  1704. self.get_guard_manager(guard).add_dict_version_guard(
  1705. val, get_verbose_code_parts(code, guard)
  1706. )
  1707. def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool) -> None:
  1708. dict_ref = self.arg_ref(guard)
  1709. maybe_not = "not " if invert else ""
  1710. code = f"{maybe_not}___dict_contains({key!r}, {dict_ref})"
  1711. self._set_guard_export_info(guard, [code])
  1712. self.get_guard_manager(guard).add_dict_contains_guard(
  1713. not invert, key, get_verbose_code_parts(code, guard)
  1714. )
  1715. def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool) -> None:
  1716. set_ref = self.arg_ref(guard)
  1717. item = key
  1718. contains = not invert # install_dict_contains_guard inverts "contains"
  1719. code = f"set.__contains__({set_ref}, {item!r})"
  1720. self._set_guard_export_info(guard, [code])
  1721. self.get_guard_manager(guard).add_set_contains_guard(
  1722. contains, item, get_verbose_code_parts(code, guard)
  1723. )
  1724. def BOOL_MATCH(self, guard: Guard) -> None:
  1725. # checks val == True or val == False
  1726. ref = self.arg_ref(guard)
  1727. val = self.get(guard.name)
  1728. assert istype(val, bool)
  1729. code = [f"{ref} == {val!r}"]
  1730. self._set_guard_export_info(guard, code)
  1731. if val:
  1732. self.get_guard_manager(guard).add_true_match_guard(
  1733. get_verbose_code_parts(code, guard)
  1734. )
  1735. else:
  1736. self.get_guard_manager(guard).add_false_match_guard(
  1737. get_verbose_code_parts(code, guard)
  1738. )
  1739. def NONE_MATCH(self, guard: Guard) -> None:
  1740. # checks `val is None`
  1741. ref = self.arg_ref(guard)
  1742. val = self.get(guard.name)
  1743. assert val is None
  1744. code = [f"{ref} is None"]
  1745. self._set_guard_export_info(guard, code)
  1746. self.get_guard_manager(guard).add_none_match_guard(
  1747. get_verbose_code_parts(code, guard)
  1748. )
  1749. def ID_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
  1750. return self.id_match_unchecked(guard, recompile_hint)
  1751. def id_match_unchecked(
  1752. self, guard: Guard, recompile_hint: Optional[str] = None
  1753. ) -> None:
  1754. # ___check_obj_id is same as `id(x) == y`
  1755. if isinstance(guard.originating_source, TypeSource):
  1756. # optional optimization to produce cleaner/faster guard code
  1757. return self.TYPE_MATCH(
  1758. Guard(guard.originating_source.base, GuardBuilder.TYPE_MATCH) # type: ignore[arg-type]
  1759. )
  1760. ref = self.arg_ref(guard)
  1761. val = self.get(guard.name)
  1762. id_val = self.id_ref(val, guard.name)
  1763. code = f"___check_obj_id({ref}, {id_val})"
  1764. self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH")
  1765. self.get_guard_manager(guard).add_id_match_guard(
  1766. id_val, get_verbose_code_parts(code, guard, recompile_hint)
  1767. )
  1768. # Keep track of ID_MATCH'd objects. This will be used to modify the
  1769. # cache size logic
  1770. if isinstance(guard.originating_source, LocalSource):
  1771. # TODO(anijain2305) - This is currently restricted to nn.Module objects
  1772. # because many other ID_MATCH'd objects fail - like DeviceMesh.
  1773. # Increase the scope of ID_MATCH'd objects.
  1774. if isinstance(val, torch.nn.Module):
  1775. local_name = guard.originating_source.local_name
  1776. weak_id = self.lookup_weakrefs(val)
  1777. if weak_id is not None:
  1778. self.id_matched_objs[local_name] = weak_id
  1779. def NOT_NONE_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None:
  1780. ref = self.arg_ref(guard)
  1781. val = self.get(guard.name)
  1782. assert isinstance(val, torch.Tensor)
  1783. code = f"{ref} is not None"
  1784. self._set_guard_export_info(guard, [code])
  1785. self.get_guard_manager(guard).add_not_none_guard(
  1786. get_verbose_code_parts(code, guard)
  1787. )
  1788. def DISPATCH_KEY_SET_MATCH(self, guard: Guard) -> None:
  1789. ref = self.arg_ref(guard)
  1790. val = self.get(guard.name)
  1791. assert isinstance(val, torch._C.DispatchKeySet)
  1792. code_parts = f"{ref}.raw_repr() == {val!r}.raw_repr()"
  1793. self.get_guard_manager(guard).add_dispatch_key_set_guard(
  1794. val, get_verbose_code_parts(code_parts, guard)
  1795. )
  1796. def NAME_MATCH(self, guard: Guard) -> None:
  1797. self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) # type: ignore[arg-type]
  1798. def DUAL_LEVEL(self, guard: Guard) -> None:
  1799. # Invalidate dual level if current dual level is different than the one
  1800. # in the fx graph
  1801. assert self.check_fn_manager.output_graph is not None
  1802. dual_level = self.check_fn_manager.output_graph.dual_level
  1803. code = [f"torch.autograd.forward_ad._current_level == {dual_level}"]
  1804. self._set_guard_export_info(guard, code)
  1805. # TODO(anijain2305) - Consider this moving this guard to C++
  1806. forward_ad = torch.autograd.forward_ad
  1807. def fn(x: Any) -> bool:
  1808. return forward_ad._current_level == dual_level
  1809. self.guard_manager.root.add_lambda_guard(
  1810. fn, get_verbose_code_parts(code, guard)
  1811. )
  1812. def FUNCTORCH_STACK_MATCH(self, guard: Guard) -> None:
  1813. # Invalidate functorch code if current level is different than
  1814. # the one when FX graph was generated
  1815. assert self.check_fn_manager.output_graph is not None
  1816. cis = self.check_fn_manager.output_graph.functorch_layers
  1817. states = [ci.get_state() for ci in cis]
  1818. code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"]
  1819. self._set_guard_export_info(guard, code)
  1820. # TODO(anijain2305) - Consider this moving this guard to C++
  1821. compare_fn = torch._functorch.pyfunctorch.compare_functorch_state
  1822. def fn(x: Any) -> bool:
  1823. return compare_fn(states)
  1824. self.guard_manager.root.add_lambda_guard(
  1825. fn, get_verbose_code_parts(code, guard)
  1826. )
  1827. def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard) -> None:
  1828. get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
  1829. are_inline_hooks = (
  1830. torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
  1831. )
  1832. def hooks_ids_fn(
  1833. hooks: tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]],
  1834. ) -> Optional[tuple[int, ...]]:
  1835. if not are_inline_hooks(hooks):
  1836. return None
  1837. pack_hook, unpack_hook = hooks
  1838. return tuple(map(id, hooks))
  1839. guard_hooks_ids = hooks_ids_fn(get_hooks())
  1840. code = [
  1841. f"torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == {guard_hooks_ids}"
  1842. ]
  1843. self._set_guard_export_info(guard, code)
  1844. def fn(x: Any) -> bool:
  1845. return guard_hooks_ids == hooks_ids_fn(get_hooks())
  1846. self.guard_manager.root.add_lambda_guard(
  1847. fn, get_verbose_code_parts(code, guard)
  1848. )
  1849. def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard) -> None:
  1850. value = self.get(guard.name)
  1851. original_metadata = deepcopy(self.get(guard.name).__tensor_flatten__()[1])
  1852. if hasattr(value, "__metadata_guard__"):
  1853. verify_guard_fn_signature(value)
  1854. def metadata_checker(x: Any) -> bool:
  1855. return value.__metadata_guard__(
  1856. original_metadata, x.__tensor_flatten__()[1]
  1857. )
  1858. else:
  1859. def metadata_checker(x: Any) -> bool:
  1860. return x.__tensor_flatten__()[1] == original_metadata
  1861. global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}"
  1862. self.get_guard_manager(guard).add_lambda_guard(
  1863. metadata_checker, get_verbose_code_parts(global_name, guard)
  1864. )
  1865. def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
  1866. ref = self.arg_ref(guard)
  1867. val = self.get(guard.name)
  1868. if np:
  1869. np_types: tuple[type[Any], ...] = (
  1870. np.int8,
  1871. np.int16,
  1872. np.int32,
  1873. np.int64,
  1874. np.uint8,
  1875. np.uint16,
  1876. np.uint32,
  1877. np.uint64,
  1878. np.float16,
  1879. np.float32,
  1880. np.float64,
  1881. )
  1882. else:
  1883. np_types = ()
  1884. ok_mutable_types = (list, set)
  1885. ok_types = tuple(
  1886. common_constant_types
  1887. | {
  1888. type,
  1889. tuple,
  1890. frozenset,
  1891. slice,
  1892. range,
  1893. dict_keys,
  1894. torch.Size,
  1895. *np_types,
  1896. *ok_mutable_types,
  1897. }
  1898. )
  1899. if torch.distributed.is_available():
  1900. from torch.distributed.device_mesh import DeviceMesh
  1901. from torch.distributed.tensor.placement_types import (
  1902. _StridedShard,
  1903. Partial,
  1904. Replicate,
  1905. Shard,
  1906. )
  1907. ok_types = ok_types + (
  1908. Shard,
  1909. Replicate,
  1910. Partial,
  1911. DeviceMesh,
  1912. _StridedShard,
  1913. )
  1914. from torch.export.dynamic_shapes import _IntWrapper
  1915. ok_types = ok_types + (_IntWrapper,)
  1916. import torch.utils._pytree as pytree
  1917. assert istype(val, ok_types) or pytree.is_constant_class(type(val)), (
  1918. f"Unexpected type {type(val)}"
  1919. )
  1920. # Special case for nan because float("nan") == float("nan") evaluates to False
  1921. if istype(val, float) and math.isnan(val):
  1922. self.TYPE_MATCH(guard)
  1923. code = []
  1924. code.append(f"__math_isnan({ref})")
  1925. self._set_guard_export_info(guard, code)
  1926. self.get_guard_manager(guard).add_lambda_guard(
  1927. _get_closure_vars()["__math_isnan"], # type: ignore[arg-type]
  1928. get_verbose_code_parts(code, guard),
  1929. )
  1930. return
  1931. # Python math library doesn't support complex nan, so we need to use numpy
  1932. if istype(val, complex) and np.isnan(val):
  1933. self.TYPE_MATCH(guard)
  1934. code = []
  1935. code.append(f"__numpy_isnan({ref})")
  1936. self._set_guard_export_info(guard, code)
  1937. self.get_guard_manager(guard).add_lambda_guard(
  1938. _get_closure_vars()["__numpy_isnan"], # type: ignore[arg-type]
  1939. get_verbose_code_parts(code, guard),
  1940. )
  1941. return
  1942. # Construct a debug string to put into the c++ equals match guard.
  1943. code = [f"{ref} == {val!r}"]
  1944. if istype(val, ok_mutable_types):
  1945. # C++ guards perform a pointer equality check to speedup guards, but the assumption is that the object
  1946. # is immutable. For a few corner cases like sets and lists, we make a deepcopy to purposefully fail the
  1947. # pointer equality check.
  1948. val = deepcopy(val)
  1949. verbose_code_parts = get_verbose_code_parts(code, guard)
  1950. if recompile_hint:
  1951. verbose_code_parts = [
  1952. f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts
  1953. ]
  1954. self.get_guard_manager(guard).add_equals_match_guard(val, verbose_code_parts)
  1955. self._set_guard_export_info(guard, code)
  1956. return
  1957. def CONSTANT_MATCH(self, guard: Guard) -> None:
  1958. val = self.get(guard.name)
  1959. if istype(val, bool):
  1960. self.BOOL_MATCH(guard)
  1961. elif val is None:
  1962. self.NONE_MATCH(guard)
  1963. elif istype(val, types.CodeType):
  1964. self.ID_MATCH(guard)
  1965. else:
  1966. self.EQUALS_MATCH(guard)
  1967. def NN_MODULE(self, guard: Guard) -> None:
  1968. # don't support this in serialization because it uses unsupported ID_MATCH
  1969. self.ID_MATCH(guard, "[inline-inbuilt-nn-modules-candidate]")
  1970. val = self.get(guard.name)
  1971. if hasattr(val, "training"):
  1972. assert istype(val.training, bool)
  1973. if not self.guard_nn_modules:
  1974. # If guard_nn_modules is true, we will guard on the right set of guards
  1975. self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH) # type: ignore[arg-type]
  1976. else:
  1977. exc.unimplemented_v2(
  1978. gb_type="Attempted to guard on uninitialized nn.Module",
  1979. context="",
  1980. explanation="Attempted to setup an NN_MODULE guard on uninitialized "
  1981. f"nn.Module subclass `{type(val)}`.",
  1982. hints=[
  1983. "Ensure the `nn.Module` subclass instance has called `super().__init__()`.",
  1984. ],
  1985. )
  1986. def FUNCTION_MATCH(self, guard: Guard) -> None:
  1987. """things like torch.add and user defined functions"""
  1988. # don't support this in serialization because it uses unsupported ID_MATCH
  1989. return self.ID_MATCH(guard)
  1990. def CLOSURE_MATCH(self, guard: Guard) -> None:
  1991. """matches a closure by __code__ id."""
  1992. # don't support this in serialization because it uses unsupported FUNCTION_MATCH
  1993. val = self.get(guard.name)
  1994. # Strictly only want user-defined functions
  1995. if type(val) == types.FunctionType and hasattr(val, "__code__"):
  1996. self._guard_on_attribute(guard, "__code__", GuardBuilder.HASATTR) # type: ignore[arg-type]
  1997. self._guard_on_attribute(guard, "__code__", GuardBuilder.FUNCTION_MATCH) # type: ignore[arg-type]
  1998. else:
  1999. self.FUNCTION_MATCH(guard)
  2000. def BUILTIN_MATCH(self, guard: Guard) -> None:
  2001. if self.save_guards:
  2002. # Record which builtin variables are used for pruning later.
  2003. if isinstance(guard.originating_source, DictGetItemSource):
  2004. self.check_fn_manager.used_builtin_vars.add(
  2005. guard.originating_source.index
  2006. )
  2007. return self.id_match_unchecked(guard)
  2008. return self.ID_MATCH(guard)
  2009. def SEQUENCE_LENGTH(self, guard: Guard) -> None:
  2010. # This guard is used to check length of PySequence objects like list,
  2011. # tuple, collections.deque etc
  2012. ref = self.arg_ref(guard)
  2013. value = self.get(guard.name)
  2014. if not isinstance(value, dict):
  2015. # C++ DICT_LENGTH checks for type
  2016. self.TYPE_MATCH(guard)
  2017. code = []
  2018. if len(value) == 0:
  2019. code.append(f"not {ref}")
  2020. else:
  2021. code.append(f"len({ref}) == {len(value)}")
  2022. self._set_guard_export_info(guard, code)
  2023. if isinstance(value, dict):
  2024. self.get_guard_manager(guard).add_dict_length_check_guard(
  2025. len(value), get_verbose_code_parts(code, guard)
  2026. )
  2027. else:
  2028. self.get_guard_manager(guard).add_length_check_guard(
  2029. len(value), get_verbose_code_parts(code, guard)
  2030. )
  2031. def TUPLE_ITERATOR_LEN(self, guard: Guard) -> None:
  2032. ref = self.arg_ref(guard)
  2033. value = self.get(guard.name)
  2034. t = type(value)
  2035. code = []
  2036. code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}")
  2037. self._set_guard_export_info(guard, code)
  2038. t = type(value)
  2039. obj_id = self.id_ref(t, f"type({guard.name})")
  2040. self.get_guard_manager(guard).add_tuple_iterator_length_guard(
  2041. tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard)
  2042. )
  2043. def RANGE_ITERATOR_MATCH(self, guard: Guard) -> None:
  2044. ref = self.arg_ref(guard)
  2045. value = self.get(guard.name)
  2046. t = type(value)
  2047. code = []
  2048. normalized_range_iter = normalize_range_iter(value)
  2049. code.append(f"___normalize_range_iter({ref}) == {normalized_range_iter}")
  2050. self._set_guard_export_info(guard, code)
  2051. t = type(value)
  2052. obj_id = self.id_ref(t, f"type({guard.name})")
  2053. start, stop, step = normalized_range_iter
  2054. self.get_guard_manager(guard).add_range_iterator_match_guard(
  2055. start, stop, step, obj_id, get_verbose_code_parts(code, guard)
  2056. )
  2057. # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
  2058. def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None:
  2059. if self.save_guards:
  2060. if name := get_local_source_name(source_b):
  2061. self.check_fn_manager.additional_used_local_vars.add(name)
  2062. if name := get_global_source_name(source_b):
  2063. self.check_fn_manager.additional_used_global_vars.add(name)
  2064. ref_a = self.arg_ref(guard)
  2065. ref_b = self.arg_ref(source_b.name())
  2066. if is_from_optimizer_source(
  2067. guard.originating_source
  2068. ) or is_from_optimizer_source(source_b):
  2069. return
  2070. # Check that the guard has not been inserted already
  2071. key = (ref_a, ref_b)
  2072. if key in self._cached_duplicate_input_guards:
  2073. return
  2074. self._cached_duplicate_input_guards.add((ref_a, ref_b))
  2075. self._cached_duplicate_input_guards.add((ref_b, ref_a))
  2076. code = [f"{ref_b} is {ref_a}"]
  2077. self._set_guard_export_info(guard, code)
  2078. if config.use_lamba_guard_for_object_aliasing:
  2079. # Save the code part so that we can install a lambda guard at the
  2080. # end. Read the Note - On Lambda guarding of object aliasing - to
  2081. # get more information.
  2082. code_part = code[0]
  2083. verbose_code_part = get_verbose_code_parts(code_part, guard)[0]
  2084. self.object_aliasing_guard_codes.append((code_part, verbose_code_part))
  2085. else:
  2086. install_object_aliasing_guard(
  2087. self.get_guard_manager(guard),
  2088. self.get_guard_manager_from_source(source_b),
  2089. get_verbose_code_parts(code, guard),
  2090. )
  2091. def WEAKREF_ALIVE(self, guard: Guard) -> None:
  2092. code = [f"{self.arg_ref(guard)} is not None"]
  2093. self._set_guard_export_info(guard, code)
  2094. self.get_guard_manager(guard).add_not_none_guard(
  2095. get_verbose_code_parts(code, guard)
  2096. )
  2097. def MAPPING_KEYS_CHECK(self, guard: Guard) -> None:
  2098. """Guard on the key order of types.MappingProxyType object"""
  2099. ref = self.arg_ref(guard)
  2100. value = self.get(guard.name)
  2101. code = []
  2102. code.append(f"list({ref}.keys()) == {list(value.keys())}")
  2103. self._set_guard_export_info(guard, code)
  2104. self.get_guard_manager(guard).add_mapping_keys_guard(value, code)
  2105. def DICT_KEYS_MATCH(self, guard: Guard) -> None:
  2106. """Insert guard to check that the keys of a dict are same"""
  2107. ref = self.arg_ref(guard)
  2108. value = self.get(guard.name)
  2109. if value is torch.utils._pytree.SUPPORTED_NODES:
  2110. # For SUPPORTED_NODES, we can guard on the dictionary version (PEP509).
  2111. self.DICT_VERSION(guard)
  2112. return
  2113. self.SEQUENCE_LENGTH(guard)
  2114. code = []
  2115. # Ensure that we call dict.keys and not value.keys (which can call
  2116. # overridden keys method). In the C++ guards, we relied on PyDict_Next
  2117. # to traverse the dictionary, which uses the internal data structure and
  2118. # does not call the overridden keys method.
  2119. code.append(f"list(dict.keys({ref})) == {list(builtin_dict_keys(value))!r}")
  2120. self._set_guard_export_info(guard, code)
  2121. if self.requires_key_order_guarding(guard.originating_source):
  2122. self.guard_on_dict_keys_and_order(value, guard)
  2123. else:
  2124. self.guard_on_dict_keys_and_ignore_order(value, guard)
  2125. def EMPTY_NN_MODULE_HOOKS_DICT(self, guard: Guard) -> None:
  2126. """Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards"""
  2127. if config.skip_nnmodule_hook_guards:
  2128. # This is unsafe if you add/remove a hook on nn module variable
  2129. return
  2130. self.SEQUENCE_LENGTH(guard)
  2131. def GRAD_MODE(self, guard: Guard) -> None:
  2132. pass # we always guard on this via GlobalStateGuard()
  2133. def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None:
  2134. pass # we always guard on this via GlobalStateGuard()
  2135. def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
  2136. pass # we always guard on this via GlobalStateGuard()
  2137. def FSDP_TRAINING_STATE(self, guard: Guard) -> None:
  2138. pass # we always guard on this via GlobalStateGuard()
  2139. def DEFAULT_DEVICE(self, guard: Guard) -> None:
  2140. """Guard on CURRENT_DEVICE per torch.utils._device"""
  2141. assert guard.source is GuardSource.GLOBAL
  2142. assert self.check_fn_manager.output_graph is not None
  2143. code = [
  2144. f"utils_device.CURRENT_DEVICE == {self.check_fn_manager.output_graph.current_device!r}"
  2145. ]
  2146. self._set_guard_export_info(guard, code)
  2147. self.get_guard_manager(guard).add_default_device_guard(
  2148. get_verbose_code_parts(code, guard)
  2149. )
    def SHAPE_ENV(self, guard: Guard) -> None:
        """Install the symbolic-shape (ShapeEnv) guards.

        The guard code is reused from ``check_fn_manager.shape_code_parts`` when
        that is already populated. Otherwise it is produced with
        ``output_graph.shape_env.produce_guards_verbose`` over the tracked fakes;
        in that case the ShapeEnv is then frozen, except in export mode.

        When C++ shape guards are usable, the C++ expressions are compiled with
        ``CppCodeCache`` and installed via ``install_symbolic_shape_guard``.
        Usable means: no Python fallback was forced, no source is a
        ``ConstantSource``, and every source's example value is an ``int`` or a
        ``float``. In all other cases, including when no valid C++ compiler is
        available, the Python expressions are installed as one lambda guard on
        the root guard manager.
        """
        from torch._dynamo.output_graph import OutputGraph

        assert guard.name == ""
        output_graph = self.check_fn_manager.output_graph
        assert output_graph is not None
        if self.check_fn_manager.shape_code_parts is not None:
            # Reuse previously generated shape guard code instead of
            # re-running produce_guards_verbose.
            shape_code_parts = self.check_fn_manager.shape_code_parts
            python_code_parts = shape_code_parts.python_code_parts
            verbose_code_parts = shape_code_parts.verbose_code_parts
            if shape_code_parts.cpp_code_parts is not None:
                cpp_code_parts = shape_code_parts.cpp_code_parts
            python_fallback = shape_code_parts.python_fallback
        else:
            # Let's handle ShapeEnv guards. To do this, we will resolve
            # shape variables to sources from tracked_fakes. This must happen after
            # tensor checks.
            # NB: self.output_graph can be None in the debug_nops tests
            assert isinstance(output_graph, OutputGraph)
            fs = output_graph.tracked_fakes
            input_contexts = [a.symbolic_context for a in fs]

            def get_sources(t_id: int, dim: int) -> list[Source]:
                # Looks up base sources mapped to a tensor id and uses them to create
                # sources for the corresponding tensor dimension.
                return [
                    TensorPropertySource(source, TensorProperty.SIZE, dim)
                    for source in output_graph.tracked_fakes_id_to_source[t_id]
                ]

            assert output_graph.shape_env is not None
            if output_graph.export_constraints:
                # Collect the equality information from the export constraints
                # into an EqualityConstraint handed to produce_guards_verbose.
                names: dict[str, tuple[int, int]] = {}
                source_pairs: list[tuple[Source, Source]] = []
                derived_equalities: list[  # type: ignore[type-arg]
                    tuple[Source, Union[Source, Symbol], Callable]
                ] = []
                phantom_symbols: dict[str, Symbol] = {}
                relaxed_sources: set[Source] = set()
                for constraint in output_graph.export_constraints:  # type: ignore[attr-defined]
                    if constraint.t_id in output_graph.tracked_fakes_id_to_source:
                        torch.export.dynamic_shapes._process_equalities(
                            constraint,
                            get_sources,
                            output_graph.shape_env,
                            names,
                            source_pairs,
                            derived_equalities,
                            phantom_symbols,
                            relaxed_sources,
                        )
                    else:
                        log.warning("Untracked tensor used in export constraints")
                equalities_inputs = EqualityConstraint(
                    source_pairs=source_pairs,
                    derived_equalities=derived_equalities,
                    phantom_symbols=list(phantom_symbols.values()),
                    relaxed_sources=relaxed_sources,
                    warn_only=False,
                )
            else:
                equalities_inputs = None

            def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]:
                # Returns one helper per requested language, in the order of `langs`.
                return output_graph.shape_env.produce_guards_verbose(
                    [a.fake for a in fs],  # type: ignore[misc]
                    [a.source for a in fs],
                    input_contexts=input_contexts,  # type: ignore[arg-type]
                    equalities_inputs=equalities_inputs,
                    source_ref=self.source_ref,
                    # Export keeps static.
                    ignore_static=(not output_graph.export),
                    langs=langs,
                )

            if config.enable_cpp_symbolic_shape_guards:
                try:
                    # For exporting we need the python code parts
                    python_code_parts, verbose_code_parts, cpp_code_parts = (
                        _get_code_parts(("python", "verbose_python", "cpp"))  # type: ignore[assignment]
                    )
                    python_fallback = False
                except OverflowError:
                    # Cannot use int64_t
                    python_fallback = True
                    python_code_parts, verbose_code_parts = _get_code_parts(
                        ("python", "verbose_python")
                    )
            else:
                python_fallback = True
                python_code_parts, verbose_code_parts = _get_code_parts(
                    ("python", "verbose_python")
                )

            # When exporting, we may work with the shape constraints some more in
            # postprocessing, so don't freeze yet
            if not output_graph.export:
                output_graph.shape_env.freeze()

        if self.save_guards:
            # For SHAPE_ENV we want to skip serializing the entire ShapeEnv so instead
            # we directly serialize the generated code here.
            # `cpp_code_parts` is only bound on some paths above, hence the
            # locals() lookup instead of a direct name reference.
            maybe_cpp_code_parts = locals().get("cpp_code_parts")
            assert maybe_cpp_code_parts is None or isinstance(
                maybe_cpp_code_parts, _CppShapeGuardsHelper
            )
            maybe_shape_env_sources = (
                []
                if maybe_cpp_code_parts is None
                else list(maybe_cpp_code_parts.source_to_symbol.keys())
            )
            self.check_fn_manager.shape_code_parts = ShapeCodeParts(
                python_code_parts=python_code_parts,
                verbose_code_parts=verbose_code_parts,
                cpp_code_parts=maybe_cpp_code_parts,
                python_fallback=python_fallback,
                shape_env_sources=maybe_shape_env_sources,
            )

        for code in python_code_parts.exprs:
            self._set_guard_export_info(guard, [code])

        # Make ShapeEnv guards available for testing.
        if compile_context := CompileContext.try_get():
            compile_context.shape_env_guards.extend(verbose_code_parts.exprs)

        int_source_to_symbol = []
        float_source_to_symbol = []

        if not python_fallback:
            assert cpp_code_parts  # type: ignore[possibly-undefined]
            code_parts, source_to_symbol = (
                cpp_code_parts.exprs,
                cpp_code_parts.source_to_symbol,
            )

            # No C++ shape expressions: nothing to install at all.
            if not code_parts:
                return

            # Partition sources by example-value type. Only plain int and float
            # values can be passed to the compiled C++ guard.
            for source, symbol in source_to_symbol.items():
                if isinstance(source, ConstantSource):
                    python_fallback = True
                else:
                    example_value = self.get(
                        source.name(),
                        closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
                    )
                    if isinstance(example_value, int):
                        int_source_to_symbol.append((source, symbol))
                    elif isinstance(example_value, float):
                        float_source_to_symbol.append((source, symbol))
                    else:
                        # SymInts/SymFloats go through python guard as we only support
                        # int64_t/double in C++ guards for now.
                        python_fallback = True

        if not python_fallback:
            import ctypes

            from torch._inductor.codecache import CppCodeCache

            assert cpp_code_parts  # type: ignore[possibly-undefined]
            code_parts, source_to_symbol = (
                cpp_code_parts.exprs,
                cpp_code_parts.source_to_symbol,
            )
            # Rebuild in int-then-float order, matching the int_values[...] and
            # float_values[...] bindings generated below. install_symbolic_shape_guard
            # receives both counts (presumably to split guard_managers
            # accordingly; confirm in its implementation).
            source_to_symbol = dict(int_source_to_symbol + float_source_to_symbol)
            try:
                guard_managers = [
                    self.get_guard_manager_from_source(IndexedSource(source, i))
                    for i, source in enumerate(source_to_symbol)
                ]

                int_symbols_str = ", ".join(
                    f"{symbol} = int_values[{i}]"
                    for i, (_, symbol) in enumerate(int_source_to_symbol)
                )
                float_symbols_str = ", ".join(
                    f"{symbol} = float_values[{i}]"
                    for i, (_, symbol) in enumerate(float_source_to_symbol)
                )

                if int_symbols_str:
                    int_symbols_str = f"int64_t {int_symbols_str};"
                if float_symbols_str:
                    float_symbols_str = f"double {float_symbols_str};"

                # The generated C function returns the conjunction of all
                # C++ shape expressions.
                func_str = textwrap.dedent(
                    f"""
                #include <algorithm>
                #include <cstdint>
                #include <cmath>
                #include <c10/util/generic_math.h>
                #if defined(_MSC_VER)
                # define EXTERN_DLL_EXPORT extern "C" __declspec(dllexport)
                #else
                # define EXTERN_DLL_EXPORT extern "C"
                #endif
                EXTERN_DLL_EXPORT int8_t guard(int64_t *int_values, double *float_values) {{
                  {int_symbols_str}
                  {float_symbols_str}
                  return ({") && (".join(code_parts)});
                }}
                """
                )
                guards_log.debug(
                    "C++ shape guard function: %s %s",
                    func_str,
                    verbose_code_parts.exprs,
                )
                clib = CppCodeCache.load(func_str)
                cguard = ctypes.cast(clib.guard, ctypes.c_void_p).value
                assert cguard
            except torch._inductor.exc.InvalidCxxCompiler:
                # No valid C++ compiler to compile the shape guard
                pass
            else:
                install_symbolic_shape_guard(
                    guard_managers,
                    len(int_source_to_symbol),
                    len(float_source_to_symbol),
                    cguard,
                    clib,
                    verbose_code_parts.exprs,
                )
                return

        # Install all the symbolic guards in one python lambda guard. These are run
        # at the very end of the RootGuardManager via epilogue guards.
        # TODO(anijain2305,williamwen42) - Consider moving this to C++.
        if python_code_parts.exprs:
            self.add_python_lambda_leaf_guard_to_root(
                python_code_parts.exprs,
                verbose_code_parts.exprs,
                closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
            )
    def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None:
        """Install the guards for a tensor.

        If ``match_on_id_for_tensor(guard)`` holds, the tensor is guarded by
        identity (ID_MATCH) only. Otherwise the tensor is taken from ``value``
        (dereferenced if it is a ``TensorWeakRef``) or looked up via
        ``self.get(guard.name)``, and then:

        * in export mode, a TYPE_MATCH is installed and Python code parts that
          compare ``dtype``, ``device``, ``requires_grad`` and ``ndimension()``
          are recorded;
        * otherwise, a C++ tensor-match guard on size, stride, Python type and
          dispatch keys is installed. The tensor may also be registered for the
          no-aliasing check inserted at the end.

        Unless the tensor always has a static shape, a guard on its
        ``_dynamo_dynamic_indices`` attribute is added as well.

        Args:
            guard: the guard being installed.
            value: optional tensor (or ``TensorWeakRef``) to use instead of
                looking the value up through ``guard.name``.
        """
        if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module():
            return
        # For tensors that are part of the Dynamo extracted Fx graph module, an
        # ID_MATCH suffices. Once we turn on inline_inbuilt_nn_modules, these
        # will be lifted as inputs and have a TENSOR_MATCH guard.
        if match_on_id_for_tensor(guard):
            self.ID_MATCH(guard)
        else:
            if isinstance(value, TensorWeakRef):
                value = value()

            # If no value was supplied (or the weakref call yielded None),
            # look the tensor up by name.
            value = value if value is not None else self.get(guard.name)

            pytype = type(value)
            dispatch_keys = torch._C._dispatch_keys(value)
            # FakeTensors can carry an explicit pytype / dispatch key set; when
            # present, those override the ones computed above.
            if isinstance(value, torch._subclasses.FakeTensor):
                if value.pytype is not None:
                    pytype = value.pytype
                if value.dispatch_keys is not None:
                    dispatch_keys = value.dispatch_keys

            assert isinstance(value, torch.Tensor)

            if config.log_compilation_metrics and isinstance(value, torch.nn.Parameter):
                metrics_context = get_metrics_context()
                metrics_context.increment("param_numel", value.numel())
                metrics_context.increment("param_bytes", value.nbytes)
                metrics_context.increment("param_count", 1)

            tensor_name = self.arg_ref(guard)
            # [Note - On Export Tensor Guards]
            #
            # In eager mode, tensor guards are evaluated through C++, in guards.cpp
            # see [Note - On Eager Tensor Guards] for more info.
            #
            # In export mode, we instead maintain parallel logic between C++ and python
            # here, with an exception of checking the dispatch key - with the idea that a dispatch key
            # is an entirely runtime notion that would make no sense to keep in an exported graph.
            #
            # Now, this idea is okay, but to paraphrase @ezyang, this mental model is sufficient for now, although
            # not entirely true.
            # For example, suppose one of the input tensors had the negative dispatch key.
            # You should end up with a graph that is specialized for tensors that have a negative dispatch key.
            # If you allow a Tensor that does NOT have this bit set, you will accidentally run it "as if" it were negated.
            # Now, negative key only shows up for complex numbers, and most likely, the exported to target doesn't
            # support this feature at all, but the point stands that :some: tensor state only shows up on dispatch key.
            # TODO(voz): Either populate a dispatch_key check into the guards, or error on users passing in an unsupported
            # subset of keys during export.
            #
            # The list of tensor fields and calls we care about can be found in `terms` below.
            # TODO(voz): We are missing storage offset in all our tensor guards?
            code: list[str] = []
            assert self.check_fn_manager.output_graph is not None
            if self.check_fn_manager.output_graph.export:
                self.TYPE_MATCH(guard)
                terms = [
                    "dtype",
                    "device",
                    "requires_grad",
                    "ndimension()",
                ]

                for term in terms:
                    real_value = self.get(tensor_name + "." + term)
                    if istype(real_value, (torch.device, torch.dtype)):
                        # copy pasted from EQUALS_MATCH
                        code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}")
                    else:
                        code.append(f"{tensor_name}.{term} == {real_value}")
            else:
                guard_manager = self.get_guard_manager(guard)

                # skip_no_tensor_aliasing_guards_on_parameters brings
                # unsoundness. If you compile a function with two different
                # parameters, but later pass the same tensor as both of them
                # (aliasing), Dynamo will not detect this.
                # But we deliberately take this soundness hit because this
                # usecase is quite rare and there is substantial reduction in
                # guard overhead.
                # For numpy tensors, since those are ephemeral, we don't have to
                # insert aliasing guards on them
                if not (
                    config.skip_no_tensor_aliasing_guards_on_parameters
                    and (
                        istype(value, torch.nn.Parameter)
                        or is_from_unspecialized_builtin_nn_module_source(
                            guard.originating_source
                        )
                    )
                ) and not isinstance(guard.originating_source, NumpyTensorSource):
                    # Keep track of all the tensor guard managers to insert
                    # NoAliasing check at the end.
                    self.no_tensor_aliasing_names.append(tensor_name)
                    self.no_tensor_aliasing_guard_managers.append(guard_manager)

                # Sizes/strides recorded for this input source by the output
                # graph, converted to concrete values for the C++ guard.
                output_graph = self.check_fn_manager.output_graph
                metadata = output_graph.input_source_to_sizes_strides[
                    guard.originating_source
                ]
                size = convert_to_concrete_values(metadata["size"])
                stride = convert_to_concrete_values(metadata["stride"])

                verbose_code_parts = get_verbose_code_parts(
                    get_tensor_guard_code_part(
                        value,
                        tensor_name,
                        size,
                        stride,
                        pytype,
                        dispatch_keys,
                    ),
                    guard,
                )
                guard_manager.add_tensor_match_guard(
                    value,
                    size,  # type: ignore[arg-type]
                    stride,  # type: ignore[arg-type]
                    tensor_name,
                    verbose_code_parts,
                    pytype,
                    dispatch_keys,
                )

                # We consider TENSOR_MATCH guard to be important enough to be
                # included in diff guard manager by default.
                if not isinstance(value, torch.nn.Parameter):
                    self.guard_manager.diff_guard_sources.add(guard.name)

            # A frame is valid for reuse with dynamic dimensions if the new
            # (user-requested) dynamic dimensions are a subset of the old
            # (already compiled) dynamic dimensions.
            #
            # It's a little non-obvious why you'd want this: in particular,
            # if an already compiled frame matches all of the guards, why
            # not just use it, why force a recompile?
            #
            # We force it for two reasons:
            #
            # - The user *required* us to compile with a new dynamic dimension,
            # we should not ignore that and serve up the old, specialized
            # frame. Listen to the user!
            #
            # - In fact, we are obligated to *raise an error* if we fail to
            # make the requested dimension dynamic. If we don't
            # recompile, we can't tell if that dimension can actually be
            # made dynamic.
            #
            # If the new dynamic dims are a subset of the old, we already know
            # we can make them dynamic (since we made them dynamic in old).
            # This is slightly unsound, because maybe your input size is
            # [s0, s0, s1] and so you can do it dynamic if you say dynamic
            # dims {0, 1, 2} but you can't if you only do {0, 2} (because now
            # the second s0 is specialized). But we're not entirely sure if
            # this is a good idea anyway lol... (if you want to try removing
            # this logic, be my guest! -- ezyang 2024)
            #
            assert guard.source is not None
            static, _reason = tensor_always_has_static_shape(
                value, is_tensor=True, tensor_source=guard.originating_source
            )

            if not static:
                if hasattr(value, "_dynamo_dynamic_indices"):
                    dynamic_indices = value._dynamo_dynamic_indices
                    code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)"  # noqa: B950
                    code.append(code_part)
                    self.get_guard_manager(guard).add_dynamic_indices_guard(
                        dynamic_indices, get_verbose_code_parts(code_part, guard)
                    )
                # In the case of us not having any dynamic dimension indices, we compiled the frame with no chance of
                # raising for this specific tensor - and any inputs with more dynamic user directives specified must be recompiled.
                else:
                    code_part = (
                        f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False"
                    )
                    code.append(code_part)
                    self.get_guard_manager(guard).add_no_hasattr_guard(
                        "_dynamo_dynamic_indices",
                        get_verbose_code_parts(code_part, guard),
                    )
            if len(code) > 0:
                self._set_guard_export_info(guard, code)
    # A util that in the case of export, adds data onto guards
    def _set_guard_export_info(
        self,
        guard: Guard,
        code_list: list[str],
        provided_guarded_object: Optional[Any] = None,
        provided_func_name: Optional[str] = None,
    ) -> None:
        """Attach export metadata to ``guard`` via ``guard.set_export_info``.

        The metadata is: the guard function name, a weakref to the guarded
        object's type, ``code_list``, and (when supported) a weakref to the
        guarded object itself.

        Unless ``provided_func_name`` is given, the function name is the name
        of the *immediate caller*, found by frame introspection. This must
        therefore be called directly from a method defined on this builder
        class; the assert below enforces that.

        Args:
            guard: the guard to annotate.
            code_list: the code strings describing the guard.
            provided_guarded_object: object to record instead of
                ``self.get(guard.name)``.
            provided_func_name: name to record instead of the caller's name.
        """
        # WARNING: It is important that cur_frame/caller do NOT stay in
        # the current frame, because they will keep things live longer
        # than they should. See TestMisc.test_release_module_memory
        cur_frame = currentframe()
        assert cur_frame is not None
        caller = cur_frame.f_back
        del cur_frame
        assert caller is not None
        func_name = provided_func_name or caller.f_code.co_name
        del caller
        # We use func_name for export, so might as well get a nice defensive check out of it
        assert func_name in self.__class__.__dict__, (
            f"_produce_guard_code must be called from inside GuardedCode. Called from {func_name}"
        )

        # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD)
        if provided_guarded_object is None:
            name = guard.name
            guarded_object = None if not name else self.get(name)
        else:
            guarded_object = provided_guarded_object

        guarded_object_type = (
            weakref.ref(type(guarded_object)) if guarded_object is not None else None
        )
        obj_ref = None
        # Not necessary to have weakref for Enum type, but there is a bug that
        # makes hasattr(guarded_object.__class__, "__weakref__") return True.
        # A non-zero __weakrefoffset__ means instances of the type can be
        # weakly referenced.
        supports_weakref = (
            getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
        )
        # See D64140537 for why we are checking for tuple.
        if supports_weakref and not isinstance(
            guarded_object, (enum.Enum, tuple, weakref.ProxyTypes)
        ):
            obj_ref = weakref.ref(guarded_object)

        guard.set_export_info(
            func_name,
            guarded_object_type,
            code_list,
            obj_ref,
        )
  2585. # Common Sub-Expression Elimination for Python expressions.
  2586. #
  2587. # There are 2 steps to this pass:
  2588. # 1. Count the frequency of each sub-expression (i.e. inner
  2589. # node in the AST tree)
  2590. #
  2591. # 2. Replace those that occur more than once by a fresh variable 'v'.
  2592. # 'v' will be defined in the 'preface' list (output argument to
  2593. # 'NodeTransformer')
  2594. #
  2595. # NB: the use of 'ast.unparse' while visiting the nodes makes this pass
  2596. # quadratic on the depth of the tree.
  2597. #
  2598. # NB: this pass creates a new variable for each AST node that is repeated
  2599. # more than 'USE_THRESHOLD'. e.g. if 'a.b.c.d' is used 10 times, 'a.b.c'
  2600. # and 'a.b' are also used 10 times. So, there will be a new variable for
  2601. # each of them.
  2602. class PyExprCSEPass:
  2603. # Maximum number of times a given expression can be used without being
  2604. # replaced by a fresh variable.
  2605. USE_THRESHOLD = 1
  2606. # Ad-Hoc: AST nodes this pass focuses on.
  2607. ALLOWED_NODE_TYPES = (ast.Attribute, ast.Call, ast.Subscript)
  2608. @dataclasses.dataclass
  2609. class Config:
  2610. expr_count: dict[str, int]
  2611. expr_to_name: dict[str, str]
  2612. class ExprCounter(ast.NodeVisitor):
  2613. def __init__(self, config: PyExprCSEPass.Config) -> None:
  2614. self._config = config
  2615. def visit(self, node: ast.AST) -> None:
  2616. if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
  2617. self._config.expr_count[_ast_unparse(node)] += 1
  2618. super().visit(node)
  2619. class Replacer(ast.NodeTransformer):
  2620. def __init__(
  2621. self,
  2622. config: PyExprCSEPass.Config,
  2623. gen_name: Callable[[], str],
  2624. ) -> None:
  2625. super().__init__()
  2626. self._config = config
  2627. self._gen_name = gen_name
  2628. self.preface: list[str] = []
  2629. def visit(self, node: ast.AST) -> Any:
  2630. if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
  2631. expr = _ast_unparse(node)
  2632. # Replacement only occurs if a given expression is used more
  2633. # than once.
  2634. if self._config.expr_count[expr] > PyExprCSEPass.USE_THRESHOLD:
  2635. if expr not in self._config.expr_to_name:
  2636. # Parent 'visit' is called so that we CSE the inner expressions first.
  2637. #
  2638. # The resulting expression is used as right-hand-side of the variable
  2639. # assignment. i.e. we are CSE-ing the children before the parents.
  2640. #
  2641. # Indexing still uses the old 'node', since that's what was counted
  2642. # by the 'NodeVisitor'.
  2643. node_ = super().visit(node)
  2644. expr_ = _ast_unparse(node_)
  2645. var_name = self._gen_name()
  2646. self.preface.append(f"{var_name} = {expr_}")
  2647. self._config.expr_to_name[expr] = var_name
  2648. else:
  2649. var_name = self._config.expr_to_name[expr]
  2650. return ast.Name(var_name, ast.Load())
  2651. return super().visit(node)
  2652. def __init__(self) -> None:
  2653. self._counter = 0
  2654. self._config = self.Config(
  2655. expr_count=collections.defaultdict(lambda: 0), expr_to_name={}
  2656. )
  2657. def _new_var(self, prefix: str = "_var") -> str:
  2658. name = f"{prefix}{self._counter}"
  2659. self._counter += 1
  2660. return name
  2661. def count(self, exprs: list[str]) -> None:
  2662. counter = self.ExprCounter(self._config)
  2663. for e in exprs:
  2664. try:
  2665. counter.visit(ast.parse(e))
  2666. except SyntaxError as ex:
  2667. log.exception("Failed to visit expr at line %s.\n%s", ex.lineno, e)
  2668. raise
  2669. def replace(self, expr: str) -> tuple[list[str], str]:
  2670. replacer = self.Replacer(self._config, self._new_var)
  2671. new_node = replacer.visit(ast.parse(expr))
  2672. return replacer.preface, _ast_unparse(new_node)
  2673. def must_add_nn_module_guards(guard: Guard) -> bool:
  2674. # For config.guard_nn_modules=False, we can skip all the guards that
  2675. # originate from inside of nn module except for a few categories.
  2676. return (
  2677. # Guard for defaults
  2678. isinstance(guard.originating_source, DefaultsSource)
  2679. # Guard using dict tags if the config flag is set
  2680. or (
  2681. config.guard_nn_modules_using_dict_tags
  2682. and guard.create_fn is GuardBuilder.NN_MODULE
  2683. )
  2684. )
  2685. class DeletedGuardManagerWrapper(GuardManagerWrapper):
  2686. def __init__(self, reason: str) -> None:
  2687. super().__init__()
  2688. self.invalidation_reason = reason
  2689. def populate_diff_guard_manager(self) -> None:
  2690. self.diff_guard_root = None
@dataclasses.dataclass
class ShapeCodeParts:
    """Shape-guard code pieces kept by CheckFunctionManager for serialization.

    Field order defines the generated constructor signature; do not reorder.
    """

    python_code_parts: _ShapeGuardsHelper
    verbose_code_parts: _ShapeGuardsHelper
    # NOTE(review): presumably None when no C++ form of the shape guards exists — confirm.
    cpp_code_parts: Optional[_CppShapeGuardsHelper]
    # NOTE(review): presumably selects the Python shape guards over the C++ ones — confirm.
    python_fallback: bool
    # Sources referenced by the shape guards; serialize_guards uses them to
    # decide which local/global variables must be kept.
    shape_env_sources: list[Source]
@dataclasses.dataclass
class GuardsState:
    """Top-level object pickled by ``pickle_guards_state`` for guard serialization.

    Field order defines the generated constructor signature; do not reorder.
    """

    # Trimmed guard state of the traced output graph.
    output_graph: OutputGraphGuardsState
    # Shape guard code, if any was recorded.
    shape_code_parts: Optional[ShapeCodeParts]
class _Missing:
    """Placeholder that GuardsStatePickler substitutes for objects it skips.

    ``reducer_override`` reduces such objects to ``(_Missing, ())``, so an
    instance of this class is what appears after unpickling.
    """

    pass
class GuardsStatePickler(pickle.Pickler):
    """Pickler for guard state that special-cases torch/dynamo objects.

    ``reducer_override`` maps objects that the default pickler cannot (or
    should not) serialize to ``(unpickle_fn, args)`` pairs that point at the
    ``_unpickle_*`` classmethods below. Tensors are stored as meta-device
    copies plus device/type/dispatch-key metadata and python modules by name.
    Some objects are replaced by ``_Missing``. Objects whose type is not
    defined at module top level raise ``PackageError``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # NOTE(review): neither attribute is read by the methods visible in this
        # class (_unpickle_tensor creates its own mode/converter) — confirm
        # whether anything else relies on them.
        self.fake_mode = torch._subclasses.FakeTensorMode()
        self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()

    @classmethod
    def _unpickle_module(cls, state: Any) -> torch.nn.Module:
        # Rebuilt as a plain torch.nn.Module carrying the saved state; the
        # original (non-top-level) class is not restored.
        mod = torch.nn.Module()
        mod.__setstate__(state)
        return mod

    @classmethod
    def _unpickle_tensor(
        cls,
        meta_tensor: torch.Tensor,
        device: torch.device,
        pytype: type,
        dispatch_keys_raw: int,
        grad: torch.Tensor,
    ) -> torch.Tensor:
        # Rebuild a tensor from its meta-device copy via
        # FakeTensorConverter.from_meta_and_device, then reattach the grad.
        # NOTE(review): every call creates its own FakeTensorMode, so tensors
        # from the same pickle do not share a fake mode — confirm intended.
        fake_mode = torch._subclasses.FakeTensorMode()
        tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
        ret = tensor_converter.from_meta_and_device(
            fake_mode,
            meta_tensor,
            device,
            pytype,
            torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
        )
        ret.grad = grad
        return ret

    @classmethod
    def _unpickle_traceable_wrapper_subclass(
        cls,
        meta_tensor: torch.Tensor,
        device: torch.device,
        pytype: type,
        dispatch_keys_raw: int,
        ctx: Any,
        inner_data: list[tuple[str, Callable[..., Any], tuple[Any, ...]]],
    ) -> torch.Tensor:
        # Unpickle the inner tensor components. These could also be subclass instances.
        inner_tensors = {}
        for attr, unpickle_func, unpickle_func_args in inner_data:
            inner_tensors[attr] = unpickle_func(*unpickle_func_args)

        # Reassemble the wrapper from its inner tensors using the outer
        # size/stride recorded on the meta copy.
        outer_size, outer_stride = meta_tensor.shape, meta_tensor.stride()
        out = type(meta_tensor).__tensor_unflatten__(  # type: ignore[attr-defined]
            inner_tensors, ctx, outer_size, outer_stride
        )
        out.pytype = pytype
        out.dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw)
        return out

    @classmethod
    def _unpickle_python_module(cls, alias: str) -> types.ModuleType:
        # Modules are pickled by name and re-imported on load.
        return importlib.import_module(alias)

    @classmethod
    def _unpickle_dispatch_key_set(cls, raw_repr: int) -> torch._C.DispatchKeySet:
        return torch._C.DispatchKeySet.from_raw_repr(raw_repr)

    @classmethod
    def _unpickle_functorch_interpreter(
        cls, json: bytes
    ) -> torch._C._functorch.CInterpreter:
        return torch._C._functorch.CInterpreter.deserialize(json)

    @classmethod
    def _unpickle_mapping_proxy(
        cls, d: dict[Any, Any]
    ) -> types.MappingProxyType[Any, Any]:
        # mappingproxy itself is not picklable; a copy of its dict is saved instead.
        return types.MappingProxyType(d)

    @classmethod
    def _unpickle_c_op(cls, name: str) -> Any:
        # Ops in the "_C" namespace are looked up again by name on load.
        return getattr(torch.ops._C, name)

    def reducer_override(
        self, obj: Any
    ) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
        """Return a custom ``(callable, args)`` reduction for ``obj``.

        Returns ``NotImplemented`` to fall back to default pickling. Raises
        ``RuntimeError`` for SymInts and ``PackageError`` for objects whose
        type's ``__qualname__`` differs from its ``__name__`` and that no
        earlier branch handled. The branch order below matters.
        """
        import sympy

        if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
            from torch.utils._python_dispatch import is_traceable_wrapper_subclass

            if is_traceable_wrapper_subclass(obj):
                # inner_data is a list of tuples of:
                #   (inner attr name, unpickle func, tuple of func inputs)
                # This supports traceable wrapper subclass inner tensors.
                inner_data = []
                attrs, ctx = obj.__tensor_flatten__()
                # recursively call for inner tensor components
                # NOTE(review): this assumes every inner component is a
                # non-meta tensor; anything else makes reducer_override return
                # NotImplemented and the tuple unpacking below fails.
                for attr in attrs:
                    inner = getattr(obj, attr)
                    func, args_tuple = self.reducer_override(inner)
                    inner_data.append((attr, func, args_tuple))

                return type(self)._unpickle_traceable_wrapper_subclass, (
                    torch.empty_like(obj, device="meta"),
                    obj.device,
                    type(obj),
                    torch._C._dispatch_keys(obj).raw_repr(),
                    ctx,
                    inner_data,
                )

            return type(self)._unpickle_tensor, (
                torch.empty_like(obj, device="meta", requires_grad=obj.requires_grad),
                obj.device,
                type(obj),
                torch._C._dispatch_keys(obj).raw_repr(),
                obj.grad,
            )

        elif isinstance(obj, torch.nn.Module):
            # Top-level module classes pickle normally.
            if type(obj).__qualname__ == type(obj).__name__:
                return NotImplemented
            # Non-top-level classes with the default __getstate__ are saved as
            # a plain Module plus state; any other case falls through to the
            # PackageError at the end.
            if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
                return type(self)._unpickle_module, (obj.__getstate__(),)

        elif inspect.ismodule(obj):
            return type(self)._unpickle_python_module, (obj.__name__,)

        elif isinstance(obj, torch._C.DispatchKeySet):
            return type(self)._unpickle_dispatch_key_set, (obj.raw_repr(),)

        elif isinstance(obj, torch._C._functorch.CInterpreter):
            return type(self)._unpickle_functorch_interpreter, (obj.serialize(),)

        elif (
            inspect.isclass(obj)
            and issubclass(obj, sympy.Function)
            and hasattr(obj, "_torch_handler_name")
        ):
            # sympy Function classes carrying torch handler metadata provide
            # their own unpickler.
            assert hasattr(obj, "_torch_unpickler")
            return obj._torch_unpickler, (obj._torch_handler_name,)

        elif isinstance(obj, torch.SymInt):
            raise RuntimeError(f"Cannot serialize SymInt {obj} (node: {obj.node})")

        elif isinstance(obj, types.MappingProxyType):
            return type(self)._unpickle_mapping_proxy, (obj.copy(),)

        elif isinstance(
            obj, torch._ops.OpOverloadPacket
        ) and obj._qualified_op_name.startswith("_C::"):
            return type(self)._unpickle_c_op, (obj.__name__,)

        elif (
            obj.__class__.__module__ == "builtins"
            and obj.__class__.__name__ == "PyCapsule"
        ):
            # Skipping PyCapsule since there isn't much to be guarded about them.
            return _Missing, ()

        elif isinstance(obj, types.CodeType):
            # We only do ID_MATCH on code objects which is already banned from guards serialization.
            return _Missing, ()

        elif inspect.isfunction(obj) and (obj.__code__.co_flags & inspect.CO_NESTED):
            # Skipping nested function since CLOSURE_MATCH is banned from guards serialization.
            assert obj.__qualname__ != obj.__name__
            return _Missing, ()

        # __qualname__ != __name__ means the type is not defined at module top
        # level (local scope, or nested inside another class).
        if type(obj).__qualname__ != type(obj).__name__:
            raise torch._dynamo.exc.PackageError(
                f"Type {type(obj)} for object {obj} cannot be saved "
                + "into torch.compile() package since it's defined in local scope. "
                + "Please define the class at global scope (top level of a module)."
            )

        return NotImplemented
  2852. def pickle_guards_state(state: GuardsState) -> bytes:
  2853. buf = io.BytesIO()
  2854. pickler = GuardsStatePickler(buf)
  2855. try:
  2856. pickler.dump(state)
  2857. except AttributeError as e:
  2858. raise torch._dynamo.exc.PackageError(str(e)) from e
  2859. return buf.getvalue()
  2860. # NB: Naively, you'd expect this to only be a function that produces
  2861. # the callable that constitutes the guard. However, there is some
  2862. # delicate handling for invalidating this check function when the
  2863. # locals/globals get invalidated, so there's some extra state
  2864. # we have to hold in this manager class.
  2865. class CheckFunctionManager:
  2866. def __init__(
  2867. self,
  2868. f_code: types.CodeType,
  2869. output_graph: OutputGraphGuardsState,
  2870. cache_entry: Optional[CacheEntry] = None,
  2871. guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
  2872. guard_filter_fn: Optional[
  2873. Callable[[list[GuardFilterEntry]], list[bool]]
  2874. ] = None,
  2875. shape_code_parts: Optional[ShapeCodeParts] = None,
  2876. runtime_global_scope: Optional[dict[str, Any]] = None,
  2877. save_guards: bool = False,
  2878. strict_error: bool = False,
  2879. ):
  2880. guards = output_graph.guards if output_graph else None
  2881. self._weakrefs: dict[int, ReferenceType[object]] = {}
  2882. existing_diff_guard_sources = (
  2883. update_diff_guard_managers_for_existing_cache_entries(cache_entry)
  2884. )
  2885. self.output_graph: Optional[OutputGraphGuardsState] = output_graph
  2886. assert self.output_graph is not None
  2887. # Only used for serialization.
  2888. self.shape_code_parts = shape_code_parts
  2889. # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
  2890. # in case a set default device call was made in the graph.
  2891. self.torch_function_mode_stack = (
  2892. output_graph.torch_function_mode_stack if output_graph else None
  2893. )
  2894. self.used_builtin_vars: OrderedSet[str] = OrderedSet()
  2895. self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
  2896. self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
  2897. self.runtime_global_scope = runtime_global_scope
  2898. if not justknobs_check("pytorch/compiler:guard_nn_modules"):
  2899. log.warning("guard_nn_modules is turned off using justknobs killswitch")
  2900. # TODO Be more explicit about the behavior for the users.
  2901. if torch._dynamo.config.caching_precompile:
  2902. _guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs])
  2903. def guard_filter_fn(guards: list[GuardFilterEntry]) -> list[bool]:
  2904. ret = []
  2905. for keep, g in zip(_guard_filter_fn(guards), guards):
  2906. if not keep:
  2907. ret.append(False)
  2908. elif (
  2909. g.guard_type in ("ID_MATCH", "CLOSURE_MATCH", "WEAKREF_ALIVE")
  2910. or "ID_MATCH" in g.derived_guard_types
  2911. ):
  2912. log.warning(
  2913. "%s guard on %s is dropped with caching_precompile=True.",
  2914. g.guard_type,
  2915. g.orig_guard.name,
  2916. )
  2917. ret.append(False)
  2918. else:
  2919. ret.append(True)
  2920. return ret
  2921. sorted_guards = sorted(guards or (), key=Guard.sort_key)
  2922. if guard_filter_fn:
  2923. # If we're filtering guards, we need to build it an extra time first
  2924. # because filtering depends on the builder/guard_manager results
  2925. builder, guard_manager = self.build_guards(
  2926. sorted_guards, existing_diff_guard_sources, f_code, output_graph, False
  2927. )
  2928. def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry:
  2929. MISSING = object()
  2930. name = strip_local_scope(guard.name)
  2931. if name == "":
  2932. has_value = False
  2933. value = MISSING
  2934. else:
  2935. try:
  2936. # Guard evaluation is expected to fail when we guard on
  2937. # things like "not hasattr(x, 'foo')". In cases like this,
  2938. # we don't have a well defined value because such thing
  2939. # doesn't exist.
  2940. value = builder.get(guard.name)
  2941. has_value = True
  2942. except: # noqa: B001,E722
  2943. value = MISSING
  2944. has_value = False
  2945. is_global = get_global_source_name(guard.originating_source) is not None
  2946. return GuardFilterEntry(
  2947. name=name,
  2948. has_value=has_value,
  2949. value=value,
  2950. guard_type=guard.create_fn_name(),
  2951. derived_guard_types=(
  2952. tuple(guard.guard_types) if guard.guard_types else ()
  2953. ),
  2954. is_global=is_global,
  2955. orig_guard=guard,
  2956. )
  2957. filter_results = guard_filter_fn(
  2958. [make_guard_filter_entry(guard) for guard in sorted_guards]
  2959. )
  2960. assert len(filter_results) == len(sorted_guards)
  2961. assert all(type(x) == bool for x in filter_results)
  2962. sorted_guards = [
  2963. guard for i, guard in enumerate(sorted_guards) if filter_results[i]
  2964. ]
  2965. # Redo the guards because filtering relies on the results from the last guard builder.
  2966. builder, guard_manager = self.build_guards(
  2967. sorted_guards,
  2968. existing_diff_guard_sources,
  2969. f_code,
  2970. output_graph,
  2971. save_guards,
  2972. )
  2973. self.guard_manager = guard_manager
  2974. self.compile_check_fn(builder, sorted_guards, guard_fail_fn)
  2975. # Keep track of weak references of objects with ID_MATCH guard. This
  2976. # info is stored alongside optimized_code and guard_manager and is used to
  2977. # limit the number of cache entries with same ID_MATCH'd object.
  2978. # TODO(anijain2305) - Currently this information is stored as an attr on
  2979. # the guard_manager itself to avoid changing CacheEntry data structure in
  2980. # eval_frame.c. In future, we should probably replace guard_manager with a
  2981. # queryable data structure such that this information is already present
  2982. # in some form.
  2983. self.guard_manager.id_matched_objs = builder.id_matched_objs
  2984. guards_log.debug("%s", self.guard_manager)
  2985. self.guard_manager.id_matched_objs = builder.id_matched_objs
  2986. # Check that the guard returns True. False means that we will always
  2987. # recompile.
  2988. # TODO(anijain2305, ydwu4) - Skipping export because of following test
  2989. # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
  2990. latency = 0.0
  2991. if not output_graph.skip_guards_check and not output_graph.export:
  2992. if not self.guard_manager.check(output_graph.local_scope):
  2993. reasons = get_guard_fail_reason_helper(
  2994. self.guard_manager,
  2995. output_graph.local_scope,
  2996. CompileContext.current_compile_id(),
  2997. )
  2998. raise AssertionError(f"Guard check failed: {reasons}")
  2999. if guard_manager_testing_hook_fn is not None:
  3000. guard_manager_testing_hook_fn(
  3001. self.guard_manager, output_graph.local_scope, builder
  3002. )
  3003. # NB for developers: n_iters is chosen to be 1 to prevent excessive
  3004. # increase in compile time. We first do a cache flush to measure the
  3005. # guard latency more accurately. This cache flush is expensive.
  3006. # Note - If you are working on a guard optimization, it might be a
  3007. # good idea to increase this number for more stabiilty during
  3008. # development.
  3009. latency = profile_guard_manager(
  3010. self.guard_manager.root, output_graph.local_scope, 1
  3011. )
  3012. guards_log.debug("Guard eval latency = %s us", f"{latency:.2f}")
  3013. # Note: We use `increment_toplevel` instead of `compilation_metric`
  3014. # here. This is because, in scenarios where `torch._dynamo.reset`
  3015. # is invoked, the same frame ID and compile ID may be reused during
  3016. # a new compilation cycle. This behavior causes issues with
  3017. # `compilation_metric`, as it expects the metric field to be empty.
  3018. # Ideally, we would overwrite the existing entry in such cases, but
  3019. # we currently lack an API to support overwriting metrics. However,
  3020. # since these situations are rare and typically impractical to
  3021. # account for, we simply increment at the toplevel instead.
  3022. CompileEventLogger.increment_toplevel("guard_latency_us", int(latency))
  3023. self.guards_state: Optional[bytes] = None
  3024. if save_guards:
  3025. from torch._dynamo.output_graph import OutputGraph
  3026. assert isinstance(self.output_graph, OutputGraph)
  3027. try:
  3028. self.guards_state = self.serialize_guards(
  3029. builder, sorted_guards, self.output_graph
  3030. )
  3031. except exc.PackageError as e:
  3032. if torch._dynamo.config.strict_precompile or strict_error:
  3033. raise e
  3034. self.output_graph.bypass_package(
  3035. f"Guard evaluation failed: {str(e)}",
  3036. traceback=traceback.format_exc().split("\n"),
  3037. )
  3038. # TODO: don't do the string rep, do something more structured here
  3039. torch._logging.trace_structured(
  3040. "dynamo_cpp_guards_str",
  3041. payload_fn=lambda: f"{self.guard_manager}\nGuard latency = {latency:.2f} us",
  3042. )
  3043. # NB - We have to very careful of cleaning up here. Because of the
  3044. # invalidate function, we can create a weakref finalizer that keeps
  3045. # `self` alive for very long. Sometimes by mistake, we can run
  3046. # invalidate for a type/object (check id_ref method) that Python can
  3047. # leak by design, preventing us from calling the finalizer. In that
  3048. # case, the `self` will be alive even though the cache entry will be
  3049. # deleted (check invalidate method), which can cause a memory leak,
  3050. # e.g., not setting output_graph = None can keep hold of nn_modules.
  3051. self._weakrefs.clear()
  3052. self.output_graph = None
    # Guard types that serialize_guards refuses to package. A guard whose own
    # type is listed here, or (for guards other than TYPE_MATCH /
    # BUILTIN_MATCH, which are checked separately) any of whose derived guard
    # types is listed here, makes serialize_guards raise
    # torch._dynamo.exc.PackageError.
    UNSUPPORTED_SERIALIZATION_GUARD_TYPES: tuple[LiteralString, ...] = (
        "DICT_VERSION",
        "NN_MODULE",
        "ID_MATCH",
        "FUNCTION_MATCH",
        "CLOSURE_MATCH",
        "WEAKREF_ALIVE",
    )
    def serialize_guards(
        self,
        builder: GuardBuilder,
        sorted_guards: list[Guard],
        output_graph: OutputGraph,
    ) -> bytes:
        """Pickle the state needed to rebuild these guards later.

        The work happens in three steps:

        1. Reject unserializable guards with ``torch._dynamo.exc.PackageError``:
           guards whose type, or one of whose derived guard types, is in
           ``UNSUPPORTED_SERIALIZATION_GUARD_TYPES``. TYPE_MATCH / BUILTIN_MATCH
           guards flagged ``_unserializable`` are reported through
           ``raise_local_type_error``.
        2. Trim the dumped local/global scopes to the variables referenced by
           the guards, plus ``self.additional_used_*_vars``. The builtins dict
           is trimmed to ``self.used_builtin_vars``.
        3. Strip weak references from the guards and return
           ``pickle_guards_state`` of the resulting ``GuardsState``.
        """
        # We check whether our list of guards are serializable here
        for guard in sorted_guards:
            guard_type = guard.create_fn_name()
            derived_guard_types = tuple(guard.guard_types) if guard.guard_types else ()
            # BUILTIN_MATCH calls TYPE_MATCH sometimes, so we need to check both for
            # a chance that the guard is unserializable
            if guard_type in ("TYPE_MATCH", "BUILTIN_MATCH"):
                if guard._unserializable:
                    # Only call builder.get again if we know we're going to throw
                    obj = builder.get(guard.name)
                    raise_local_type_error(obj)
            elif (
                guard_type in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
            ):
                raise torch._dynamo.exc.PackageError(
                    f"{guard_type} guard cannot be serialized."
                )
            elif failed := next(
                (
                    i
                    for i in derived_guard_types
                    if i in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
                ),
                None,
            ):
                # Just raise the first failed guard name
                raise torch._dynamo.exc.PackageError(
                    f"{failed} guard cannot be serialized."
                )

        builtins_dict_name = output_graph.name_of_builtins_dict_key_in_fglobals
        # Names of the globals / locals referenced by the guards.
        used_global_vars = set()
        used_local_vars = set()

        def prune_variable(source: Source) -> None:
            # Record the global or local variable name that `source` is rooted at.
            if name := get_global_source_name(source):
                assert isinstance(name, str)
                # Leave out the builtins dict key, as we will special handle
                # it later because the guarded code rarely use the entire
                # builtin dict in the common case.
                if name not in (builtins_dict_name,):
                    used_global_vars.add(name)
            elif name := get_local_source_name(source):
                assert isinstance(name, str)
                used_local_vars.add(name)

        output_graph_guards_state = output_graph.dump_guards_state()
        # Only serialize the global variables that are actually used in guards.
        for guard in sorted_guards:
            if isinstance(guard.originating_source, ShapeEnvSource):
                # Shape guards are attributed to the ShapeEnv; the real inputs
                # they read are listed in shape_code_parts.
                assert self.shape_code_parts
                for source in self.shape_code_parts.shape_env_sources:
                    prune_variable(source)
            else:
                prune_variable(guard.originating_source)

        for source in output_graph.guard_on_key_order:
            prune_variable(source)

        def normalize_create_fn(x: Callable[..., None]) -> Callable[..., None]:
            # Rebuild functools.partial create_fns with any weak-reference
            # arguments (TensorWeakRef / weakref.ref) replaced by their
            # referents, since weak references cannot be pickled.
            if isinstance(x, functools.partial):

                def _ref(x: Any) -> Any:
                    if isinstance(x, (TensorWeakRef, weakref.ref)):
                        return x()
                    return x

                new_args = tuple(_ref(a) for a in x.args)
                new_keywords = {k: _ref(v) for k, v in x.keywords.items()}
                return functools.partial(x.func, *new_args, **new_keywords)

            return x

        global_scope_state = {
            k: v
            for k, v in output_graph_guards_state.global_scope.items()
            if k in used_global_vars or k in self.additional_used_global_vars
        }
        # Keep only the builtins that the guards actually use.
        global_scope_state[builtins_dict_name] = {
            k: v
            for k, v in output_graph_guards_state.global_scope[
                builtins_dict_name
            ].items()  # type: ignore[attr-defined]
            if k in self.used_builtin_vars
        }
        output_graph_guards_state = dataclasses.replace(
            output_graph_guards_state,
            local_scope={
                k: v
                for k, v in output_graph_guards_state.local_scope.items()
                if k in used_local_vars or k in self.additional_used_local_vars
            },
            global_scope=global_scope_state,
            # Weak references are dropped: they cannot be pickled.
            _guards=torch._guards.GuardsSet(
                {
                    dataclasses.replace(
                        guard,
                        obj_weakref=None,
                        guarded_class_weakref=None,
                        create_fn=normalize_create_fn(guard.create_fn),
                    )
                    for guard in sorted_guards
                }
            ),
            input_source_to_sizes_strides=pytree.tree_map(
                convert_int_to_concrete_values,
                output_graph_guards_state.input_source_to_sizes_strides,
            ),
            skip_guards_check=True,
        )
        guards_state = GuardsState(
            output_graph=output_graph_guards_state,
            shape_code_parts=self.shape_code_parts,
        )
        return pickle_guards_state(guards_state)
  3173. def build_guards(
  3174. self,
  3175. sorted_guards: list[Guard],
  3176. existing_diff_guard_sources: OrderedSet[str],
  3177. f_code: types.CodeType,
  3178. output_graph: OutputGraphGuardsState,
  3179. save_guards: bool,
  3180. ) -> tuple[GuardBuilder, GuardManagerWrapper]:
  3181. guard_manager = GuardManagerWrapper()
  3182. guard_manager.diff_guard_sources = existing_diff_guard_sources
  3183. w_builder = None
  3184. def source_ref(source: Source) -> str:
  3185. guard_source = source.guard_source()
  3186. if guard_source is GuardSource.CONSTANT:
  3187. # No need to track constants
  3188. return source.name()
  3189. assert w_builder
  3190. r_builder = w_builder()
  3191. assert r_builder is not None
  3192. return r_builder.arg_ref(source.name())
  3193. builder = GuardBuilder(
  3194. f_code,
  3195. self.id_ref,
  3196. source_ref,
  3197. self.lookup_weakrefs,
  3198. output_graph.local_scope,
  3199. output_graph.global_scope,
  3200. guard_manager,
  3201. self,
  3202. save_guards,
  3203. runtime_global_scope=self.runtime_global_scope,
  3204. )
  3205. # Break retain cycle. See test_release_scope_memory
  3206. def cleanup_builder(weak_b: weakref.ref[GuardBuilder]) -> None:
  3207. b = weak_b()
  3208. if b:
  3209. b.scope = None # type: ignore[assignment]
  3210. # Break retain cycle. See test_release_input_memory
  3211. w_builder = weakref.ref(builder, cleanup_builder)
  3212. guard_on_nn_modules = config.guard_nn_modules and justknobs_check(
  3213. "pytorch/compiler:guard_nn_modules"
  3214. )
  3215. for guard in sorted_guards:
  3216. if (
  3217. not guard_on_nn_modules
  3218. and guard.is_specialized_nn_module()
  3219. # Default func args must be guarded on.
  3220. # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
  3221. and "__defaults__" not in guard.name
  3222. and "__kwdefaults__" not in guard.name
  3223. and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
  3224. ):
  3225. continue
  3226. guard.create(builder)
  3227. return builder, guard_manager
  3228. def compile_check_fn(
  3229. self,
  3230. builder: GuardBuilder,
  3231. guards_out: list[Guard],
  3232. guard_fail_fn: Optional[Callable[[GuardFail], None]],
  3233. ) -> None:
  3234. # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
  3235. largs = builder.argnames
  3236. largs += ["**___kwargs_ignored"]
  3237. guards_log.debug("GUARDS:")
  3238. code_parts = []
  3239. verbose_code_parts = []
  3240. structured_guard_fns: list[Callable[[], dict[str, Any]]] = []
  3241. assert self.torch_function_mode_stack is not None
  3242. torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
  3243. self.torch_function_mode_stack
  3244. )
  3245. # Add compile id info in the guard manager for debugging purpose
  3246. self.guard_manager.root.attach_compile_id(
  3247. str(CompileContext.current_compile_id())
  3248. )
  3249. # Insert the global_state guard
  3250. assert self.output_graph is not None
  3251. global_state = self.output_graph.global_state_guard
  3252. self.guard_manager.root.add_global_state_guard(
  3253. global_state, ["___check_global_state()"]
  3254. )
  3255. self.guard_manager.root.add_torch_function_mode_stack_guard(
  3256. self.torch_function_mode_stack,
  3257. ["___check_torch_function_mode_stack()"],
  3258. )
  3259. # Clear references to torch_function modes held in the list
  3260. self.torch_function_mode_stack = None
  3261. def add_code_part(
  3262. code_part: str, guard: Optional[Guard], log_only: bool = False
  3263. ) -> None:
  3264. verbose_code_part = get_verbose_code_part(code_part, guard)
  3265. guards_log.debug("%s", verbose_code_part)
  3266. structured_guard_fns.append(
  3267. lambda: {
  3268. "code": code_part,
  3269. "stack": (
  3270. structured.from_traceback(guard.stack.summary())
  3271. if guard and guard.stack
  3272. else None
  3273. ),
  3274. "user_stack": (
  3275. structured.from_traceback(guard.user_stack)
  3276. if guard and guard.user_stack
  3277. else None
  3278. ),
  3279. }
  3280. )
  3281. if verbose_guards_log.isEnabledFor(logging.DEBUG):
  3282. maybe_stack = ""
  3283. maybe_user_stack = ""
  3284. if guard is not None:
  3285. if guard.stack:
  3286. maybe_stack = f"\nStack:\n{''.join(guard.stack.format())}"
  3287. if guard.user_stack:
  3288. maybe_user_stack = (
  3289. f"\nUser stack:\n{''.join(guard.user_stack.format())}"
  3290. )
  3291. verbose_guards_log.debug(
  3292. "Guard: %s%s%s",
  3293. code_part,
  3294. maybe_stack,
  3295. maybe_user_stack,
  3296. )
  3297. if not log_only:
  3298. code_parts.append(code_part)
  3299. verbose_code_parts.append(verbose_code_part)
  3300. seen = set()
  3301. for gcl in builder.code:
  3302. for code in gcl.code_list:
  3303. if code not in seen:
  3304. # If Cpp guard manager is enabled, we don't need to add to
  3305. # code_parts.
  3306. add_code_part(code, gcl.guard, True)
  3307. seen.add(code)
  3308. no_tensor_aliasing_names = builder.no_tensor_aliasing_names
  3309. check_tensors_fn = None
  3310. check_tensors_verbose_fn = None
  3311. if len(no_tensor_aliasing_names) > 1:
  3312. # Install tensor aliasing guard. TENSOR_MATCH guards are already
  3313. # installed for cpp guard manager.
  3314. install_no_tensor_aliasing_guard(
  3315. builder.no_tensor_aliasing_guard_managers,
  3316. no_tensor_aliasing_names,
  3317. ["check_no_aliasing(" + ", ".join(no_tensor_aliasing_names) + ")"],
  3318. )
  3319. # Note - On Lambda guarding of object aliasing
  3320. # We previously installed object‑aliasing guards as relational guards,
  3321. # but that undermined the recursive‑dict guard optimization: placing the
  3322. # aliasing guard at a leaf prevented the parent dict node from
  3323. # qualifying as a recursive‑dict guard root. Because aliasing guards are
  3324. # rare, we now emit them as epilogue guards via a small Python lambda.
  3325. # This repeats the access in Python—adding a bit of work—but the
  3326. # overhead is outweighed by the gains from enabling recursive‑dict guard
  3327. # optimization.
  3328. if (
  3329. config.use_lamba_guard_for_object_aliasing
  3330. and builder.object_aliasing_guard_codes
  3331. ):
  3332. aliasing_code_parts, aliasing_verbose_code_parts = map(
  3333. list, zip(*builder.object_aliasing_guard_codes)
  3334. )
  3335. builder.add_python_lambda_leaf_guard_to_root(
  3336. aliasing_code_parts, aliasing_verbose_code_parts
  3337. )
  3338. aotautograd_guards: list[GuardEnvExpr] = (
  3339. self.output_graph.aotautograd_guards if self.output_graph else []
  3340. )
  3341. # TODO(anijain2305) - There is a duplicate logic in Dynamo to find
  3342. # aliased input tensors. So most probably we don't need this here.
  3343. # Revisit.
  3344. for guard in aotautograd_guards:
  3345. if isinstance(guard, DuplicateInputs):
  3346. source_a = guard.input_source_a
  3347. source_b = guard.input_source_b
  3348. code_part = f"{source_a.name()} is {source_b.name()}"
  3349. install_object_aliasing_guard(
  3350. builder.get_guard_manager_from_source(source_a),
  3351. builder.get_guard_manager_from_source(source_b),
  3352. [code_part],
  3353. )
  3354. add_code_part(code_part, None, True)
  3355. elif isinstance(guard, StorageOverlap):
  3356. overlapping_guard_managers = [
  3357. builder.get_guard_manager_from_source(s)
  3358. for s in guard.overlapping_sources
  3359. ]
  3360. non_overlapping_guard_managers = [
  3361. builder.get_guard_manager_from_source(s)
  3362. for s in guard.non_overlapping_sources
  3363. ]
  3364. code_part = (
  3365. """check_overlapping("""
  3366. f"""overlapping=[{", ".join(s.name() for s in guard.overlapping_sources)}], """
  3367. f"""non_overlapping=[{", ".join(s.name() for s in guard.non_overlapping_sources)}])"""
  3368. )
  3369. install_storage_overlapping_guard(
  3370. overlapping_guard_managers,
  3371. non_overlapping_guard_managers,
  3372. [code_part],
  3373. )
  3374. add_code_part(code_part, None, True)
  3375. else:
  3376. raise RuntimeError(f"Unknown GuardEnvExpr: {guard}")
  3377. # TODO: the "guard" here is actually just the top level SHAPE_ENV
  3378. # which is useless. Get ShapeEnv to pass in more provenance.
  3379. for gcl in builder.shape_env_code:
  3380. for code in gcl.code_list:
  3381. # Shape env guards are already added for CPP guard manager in
  3382. # SHAPE_ENV implementation.
  3383. add_code_part(code, gcl.guard, True)
  3384. # OK, all done generating guards
  3385. if structured_guard_fns:
  3386. torch._logging.trace_structured(
  3387. "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
  3388. )
  3389. if convert_frame.initial_global_state is None:
  3390. # we should only hit this case in NopTests()
  3391. global_state = convert_frame.GlobalStateGuard()
  3392. closure_vars = {
  3393. "___check_tensors": check_tensors_fn,
  3394. "___check_tensors_verbose": check_tensors_verbose_fn,
  3395. "___check_global_state": global_state.check,
  3396. "___check_torch_function_mode_stack": torch_function_mode_stack_check_fn,
  3397. **SYMPY_INTERP,
  3398. **_get_closure_vars(),
  3399. }
  3400. self.guard_manager.finalize()
  3401. globals_for_guard_fn = {"G": builder.scope["G"]}
  3402. # Guard manager construction is complete. Ensure we did not miss to
  3403. # insert a guard in cpp guard manager.
  3404. assert len(code_parts) == 0
  3405. self.guard_manager.closure_vars = closure_vars
  3406. self.guard_manager.args = largs
  3407. self.guard_manager.populate_code_parts_for_debugging()
  3408. self.guard_manager.verbose_code_parts = verbose_code_parts
  3409. # Grab only G, but preserve "G" because guards access it as "G"
  3410. self.guard_manager.global_scope = globals_for_guard_fn
  3411. self.guard_manager.guard_fail_fn = guard_fail_fn
  3412. # will be populated by a non-owning reference to CacheEntry/ExtraState
  3413. # when the CacheEntry is constructed
  3414. self.guard_manager.cache_entry = None
  3415. self.guard_manager.extra_state = None
  3416. self.guard_manager.no_tensor_aliasing_sources = no_tensor_aliasing_names
  3417. def invalidate(self, obj_str: str) -> None:
  3418. # Some tests reveal that CheckFunctionManager has no attribute
  3419. # guard_manager, but this case should not be of any concern.
  3420. # This case doesn't seem easy to repro.
  3421. if (
  3422. hasattr(self, "guard_manager")
  3423. and not isinstance(self.guard_manager, DeletedGuardManagerWrapper)
  3424. and (cache_entry := self.guard_manager.cache_entry) is not None
  3425. and (extra_state := self.guard_manager.extra_state) is not None
  3426. ):
  3427. assert isinstance(cache_entry, CacheEntry)
  3428. assert isinstance(extra_state, ExtraState)
  3429. reason = f"Cache line invalidated because {obj_str} got deallocated"
  3430. deleted_guard_manager = DeletedGuardManagerWrapper(reason)
  3431. extra_state.invalidate(cache_entry, deleted_guard_manager)
  3432. self.guard_manager = deleted_guard_manager
  3433. def id_ref(self, obj: object, obj_str: str) -> int:
  3434. """add a weakref, return the id"""
  3435. try:
  3436. if id(obj) not in self._weakrefs:
  3437. # We will clear the _weakrefs dict at the end of __init__
  3438. # function, which will delete the callbacks as well. Therefore,
  3439. # we are using a finalizer which is kept alive.
  3440. self._weakrefs[id(obj)] = weakref.ref(obj)
  3441. weakref.finalize(
  3442. obj, functools.partial(self.invalidate, obj_str=obj_str)
  3443. )
  3444. except TypeError:
  3445. pass # cannot weakref bool object
  3446. return id(obj)
  3447. def lookup_weakrefs(self, obj: object) -> Optional[weakref.ref[object]]:
  3448. """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects"""
  3449. if id(obj) in self._weakrefs:
  3450. return self._weakrefs[id(obj)]
  3451. return None
  3452. def build_guard_function(code_parts: list[str], closure_args: str) -> tuple[str, str]:
  3453. from torch._inductor.utils import IndentedBuffer
  3454. csepass = PyExprCSEPass()
  3455. try:
  3456. csepass.count(code_parts)
  3457. def replace(expr: str) -> tuple[list[str], str]:
  3458. return csepass.replace(expr)
  3459. except RecursionError:
  3460. # If we hit recursion limits during CSE analysis, fall back to a no-op replace function
  3461. # This can happen with extremely complex guard expressions
  3462. def replace(expr: str) -> tuple[list[str], str]:
  3463. return [], expr
  3464. # Generate the inner body of the guard function.
  3465. # i.e. if-chain of the guard expressions.
  3466. guard_body = IndentedBuffer()
  3467. for expr in code_parts:
  3468. preface, expr = replace(expr)
  3469. guard_body.writelines(preface)
  3470. guard_body.writeline(f"if not ({expr}):")
  3471. with guard_body.indent():
  3472. guard_body.writeline("return False")
  3473. # Wrap the inner body into the actual guard function.
  3474. guard = IndentedBuffer()
  3475. guard.writeline("def guard(L):")
  3476. with guard.indent():
  3477. guard.splice(guard_body)
  3478. guard.writeline("return True")
  3479. # Wrap the whole guard function into another function
  3480. # with the closure variables.
  3481. make_guard_fn = IndentedBuffer()
  3482. make_guard_fn.writeline(f"def ___make_guard_fn({closure_args}):")
  3483. with make_guard_fn.indent():
  3484. make_guard_fn.splice(guard)
  3485. make_guard_fn.writeline("return guard")
  3486. return guard_body.getvalue(), make_guard_fn.getvalue()
  3487. def is_recompiles_enabled() -> bool:
  3488. return torch._logging._internal.log_state.is_artifact_enabled("recompiles")
  3489. def is_recompiles_verbose_enabled() -> bool:
  3490. return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose")
  3491. # this will only be used if cpp guards are disabled
  3492. def make_torch_function_mode_stack_guard(
  3493. initial_stack: list[torch.overrides.TorchFunctionMode],
  3494. ) -> Callable[[], bool]:
  3495. types = [type(x) for x in initial_stack]
  3496. def check_torch_function_mode_stack() -> bool:
  3497. cur_stack = get_torch_function_mode_stack()
  3498. if len(cur_stack) != len(types):
  3499. return False
  3500. for ty, mode in zip(types, cur_stack):
  3501. if ty != type(mode):
  3502. return False
  3503. return True
  3504. return check_torch_function_mode_stack
# Name -> value mapping used as the eval() locals when re-evaluating guard
# source strings, e.g. {"L": f_locals, "G": <globals>, **closure_vars} as built
# in get_guard_fail_reason_helper.
Scope = TypeAliasType("Scope", dict[str, object])
  3506. def recompilation_reason_for_no_tensor_aliasing_guard(
  3507. guard_manager: GuardManagerWrapper, scope: Scope
  3508. ) -> list[str]:
  3509. assert guard_manager.global_scope is not None
  3510. global_scope = dict(guard_manager.global_scope)
  3511. ids_to_source = collections.defaultdict(list)
  3512. for tensor_source in guard_manager.no_tensor_aliasing_sources:
  3513. global_scope["__compile_source__"] = tensor_source
  3514. tensor_id = id(eval(tensor_source, global_scope, scope))
  3515. ids_to_source[tensor_id].append(tensor_source)
  3516. duplicate_tensors = [
  3517. f"{ids_to_source[key]}" for key in ids_to_source if len(ids_to_source[key]) > 1
  3518. ]
  3519. reason = ", ".join(duplicate_tensors)
  3520. return [f"Duplicate tensors found: {reason}"]
  3521. def strip_local_scope(s: str) -> str:
  3522. """
  3523. Replace occurrences of L[...] with just the inner content.
  3524. Handles both single and double quotes.
  3525. This is to generate user friendly recompilation messages.
  3526. """
  3527. import re
  3528. pattern = r"L\[\s*['\"](.*?)['\"]\s*\]"
  3529. return re.sub(pattern, r"\1", s)
def get_guard_fail_reason_helper(
    guard_manager: GuardManagerWrapper,
    f_locals: dict[str, object],
    compile_id: Optional[CompileId],
) -> str:
    """
    Return a human-readable reason why `guard_manager` rejected `f_locals`.

    The reason is prefixed with `compile_id` and has `L['...']` wrappers
    stripped via `strip_local_scope`. Normally only the first failing check is
    reported; when the ``recompiles_verbose`` artifact is enabled, every
    failing check that can be evaluated is reported, joined with "; ".

    This helper only computes the string; `get_guard_fail_reason` is what
    appends it to `guard_failures`.
    """
    assert guard_manager.global_scope is not None
    assert guard_manager.closure_vars is not None
    # eval() locals for re-running guard source strings: frame locals as "L",
    # guarded globals as "G", plus the closure helpers (e.g. ___check_tensors)
    # that the code parts may call.
    scope = {"L": f_locals, "G": guard_manager.global_scope["G"]}
    scope.update(guard_manager.closure_vars)
    reasons: list[str] = []
    no_tensor_aliasing_check_failed = False
    verbose_code_parts: list[str] = []
    guard_debug_info = guard_manager.check_verbose(f_locals)
    # For test_export_with_map_cond, the check_verbose fail even without the
    # C++ guard manager. We need to fix the issue to remove the comment.
    # assert not guard_debug_info.result
    if not guard_debug_info.result:
        verbose_code_parts = guard_debug_info.verbose_code_parts
        # verbose_code_parts is either the actual reason (e.g. in case of
        # TENSOR_MATCH) or it could be a list of verbose_code_part that we
        # passed to the leaf guard at construction time. If its a list, we
        # walk through this list and find the guard that failed. This is
        # very important for symbolic shape guards which are currently
        # installed as a lambda guard and can encompass a long list of code_parts.
        if len(verbose_code_parts) == 1:
            if "Duplicate tensor found" in verbose_code_parts[0]:
                # The aliasing check's message does not name the sources;
                # recompute them below.
                no_tensor_aliasing_check_failed = True
            else:
                # A single entry is used directly as the reason; nothing is
                # re-evaluated.
                reasons = verbose_code_parts
                verbose_code_parts = []
    if no_tensor_aliasing_check_failed:
        reasons = recompilation_reason_for_no_tensor_aliasing_guard(
            guard_manager, scope
        )
    else:
        for part in verbose_code_parts:
            # Fresh copy per part so the manager's global scope is not mutated.
            global_scope = dict(guard_manager.global_scope)
            global_scope["__compile_source__"] = part
            # presumably attaches __compile_source__ to errors raised by eval
            # (helper defined elsewhere) -- TODO confirm
            with report_compile_source_on_error():
                try:
                    fail_reason = eval(part, global_scope, scope)
                except Exception:
                    # Verbose mode skips parts that cannot be evaluated;
                    # otherwise the error propagates.
                    if is_recompiles_verbose_enabled():
                        continue
                    else:
                        raise
            # Only ___check_tensors knows how to return a fancy fail reason;
            # for everything else we just report the code that failed
            if isinstance(fail_reason, bool) and not fail_reason:
                fail_reason = part
            if isinstance(fail_reason, str):
                reasons.append(fail_reason)
                # Non-verbose mode stops at the first failing part.
                if not is_recompiles_verbose_enabled():
                    break
    # NOTE(review): when no reason is found (e.g. check_verbose passes), this
    # is just "<compile_id>: ", which is still a truthy string.
    reason_str = f"{compile_id}: " + "; ".join(reasons)
    return strip_local_scope(reason_str)
  3591. def get_guard_fail_reason(
  3592. guard_manager: GuardManagerWrapper,
  3593. code: types.CodeType,
  3594. f_locals: dict[str, object],
  3595. compile_id: CompileId,
  3596. skip_logging: bool = False,
  3597. ) -> str:
  3598. if isinstance(guard_manager, DeletedGuardManagerWrapper):
  3599. return f"{compile_id}: {guard_manager.invalidation_reason}"
  3600. reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id)
  3601. if skip_logging:
  3602. return reason_str
  3603. guard_failures[orig_code_map[code]].append(reason_str)
  3604. try:
  3605. if guard_manager.guard_fail_fn is not None:
  3606. guard_manager.guard_fail_fn(
  3607. GuardFail(reason_str or "unknown reason", orig_code_map[code])
  3608. )
  3609. except Exception:
  3610. log.exception(
  3611. "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval",
  3612. )
  3613. return reason_str
  3614. def get_and_maybe_log_recompilation_reasons(
  3615. cache_entry: Optional[CacheEntry],
  3616. frame: DynamoFrameType,
  3617. skip_logging: bool = False,
  3618. ) -> list[str]:
  3619. """
  3620. Return the list of guard failure reasons using cache_entry.
  3621. Logs the recompilation reason if `recompiles` logging is enabled.
  3622. Raises a RecompileError if `config.error_on_recompile` is enabled.
  3623. """
  3624. reasons = []
  3625. while cache_entry is not None:
  3626. reason = get_guard_fail_reason(
  3627. cache_entry.guard_manager,
  3628. cache_entry.code,
  3629. frame.f_locals,
  3630. cache_entry.compile_id,
  3631. skip_logging,
  3632. )
  3633. if reason:
  3634. reasons.append(reason)
  3635. cache_entry = cache_entry.next
  3636. code = frame.f_code
  3637. if skip_logging:
  3638. return reasons
  3639. # at least one of "recompiles" or "recompiles_verbose" is enabled
  3640. do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled()
  3641. if do_recompiles_log or config.error_on_recompile:
  3642. if is_recompiles_verbose_enabled():
  3643. failures = "\n\n".join(
  3644. f"guard {i} failures:\n" + textwrap.indent(reason, "- ")
  3645. for i, reason in enumerate(reasons)
  3646. )
  3647. else:
  3648. failures = textwrap.indent("\n".join(reasons), "- ")
  3649. guard_failure_details = (
  3650. f"triggered by the following guard failure(s):\n{failures}"
  3651. )
  3652. message = (
  3653. f"Recompiling function {code.co_name} in {code.co_filename}:{code.co_firstlineno}\n"
  3654. f"{textwrap.indent(guard_failure_details, ' ')}"
  3655. )
  3656. if do_recompiles_log:
  3657. if is_recompiles_verbose_enabled():
  3658. recompiles_verbose_log.debug(message)
  3659. else:
  3660. recompiles_log.debug(message)
  3661. if config.error_on_recompile:
  3662. raise exc.RecompileError(message)
  3663. torch._logging.trace_structured(
  3664. "artifact",
  3665. metadata_fn=lambda: {
  3666. "name": "recompile_reasons",
  3667. "encoding": "json",
  3668. },
  3669. payload_fn=lambda: reasons,
  3670. )
  3671. return reasons
  3672. def update_diff_guard_managers_for_existing_cache_entries(
  3673. cache_entry: Optional[CacheEntry],
  3674. ) -> OrderedSet[str]:
  3675. first_cache_entry = cache_entry
  3676. # On the first pass, go through the cache entries and accumulate the diff
  3677. # guard sources. Different guard managers can fail with different sources.
  3678. # So, we collect all of them first.
  3679. acc_diff_guard_sources: OrderedSet[str] = OrderedSet()
  3680. while cache_entry is not None:
  3681. acc_diff_guard_sources.update(
  3682. cache_entry.guard_manager.collect_diff_guard_sources()
  3683. )
  3684. cache_entry = cache_entry.next # type: ignore[assignment]
  3685. # On the second pass, set the diff_guard_sources for each cache line to the
  3686. # accumulated value. And the re-populate the diff guard manager.
  3687. cache_entry = first_cache_entry
  3688. while cache_entry is not None:
  3689. cache_entry.guard_manager.diff_guard_sources = acc_diff_guard_sources
  3690. cache_entry.guard_manager.populate_diff_guard_manager()
  3691. cache_entry = cache_entry.next # type: ignore[assignment]
  3692. # return the accumulated sources to set up the new cache line.
  3693. return acc_diff_guard_sources
  3694. def guard_error_hook(
  3695. guard_manager: GuardFn,
  3696. code: types.CodeType,
  3697. f_locals: dict[str, object],
  3698. index: int,
  3699. last: bool,
  3700. ) -> None:
  3701. print(
  3702. f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
  3703. )
  3704. print("lambda " + ", ".join(guard_manager.args) + ":")
  3705. print(" ", " and\n ".join(guard_manager.code_parts))
  3706. print(guard_manager)
  3707. local_scope = {"L": f_locals, **guard_manager.closure_vars}
  3708. for guard in guard_manager.code_parts:
  3709. try:
  3710. eval(guard, guard_manager.global_scope, local_scope)
  3711. except: # noqa: B001,E722
  3712. print(f"Malformed guard:\n{guard}")
# Register guard_error_hook at import time; presumably the guard runtime invokes
# it when evaluating a guard function raises -- confirm in set_guard_error_hook.
set_guard_error_hook(guard_error_hook)
  3714. def unique(seq: Sequence[T]) -> Generator[T, None, None]:
  3715. seen = set()
  3716. for x in seq:
  3717. if x not in seen:
  3718. yield x
  3719. seen.add(x)
  3720. def make_dupe_guard(
  3721. obj_source: Source, dupe_source: Source
  3722. ) -> Optional[functools.partial[Any]]:
  3723. # Note - we may end up in a situation where we invoke something like
  3724. # def fn(x, y)
  3725. # with fn(x, x)
  3726. # Prior to the addition of tracking to all relevant objects, we would handle this just fine by
  3727. # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However,
  3728. # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here -
  3729. # In the fn(x, x) example call above look like a graph with a single input.
  3730. # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard.
  3731. # Note - we may not have a source, that is fine, it just means we had an object that is safe to have
  3732. # leave unsourced - like a local list created and discharged entirely within a local scope.
  3733. if dupe_source and dupe_source != obj_source:
  3734. ser_source_is_local = is_from_local_source(dupe_source)
  3735. source_is_local = is_from_local_source(obj_source)
  3736. if is_from_flatten_script_object_source(
  3737. dupe_source
  3738. ) or is_from_flatten_script_object_source(obj_source):
  3739. raise exc.UnsafeScriptObjectError(
  3740. f"{obj_source.name()} is aliasing {dupe_source.name()}. This is not supported."
  3741. f" Please do a clone for corresponding input."
  3742. )
  3743. # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
  3744. # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
  3745. # so maybe we should do this refactor before we land this...
  3746. # TODO(voz): Combine local and global guard builders.
  3747. if ser_source_is_local == source_is_local:
  3748. # Note - this is a little aggressive - these being duplicate input does not always matter.
  3749. # However, this should always be a sound guard to add here.
  3750. return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
  3751. return None
  3752. def install_guard(*guards: Guard, skip: int = 0) -> None:
  3753. """
  3754. Add dynamo guards to the current tracing context.
  3755. Args:
  3756. guards: guard(s) to add
  3757. skip: number of stack frames to ignore for debug stack trace
  3758. """
  3759. from torch._guards import TracingContext
  3760. collect_debug_stack = guards_log.isEnabledFor(
  3761. logging.DEBUG
  3762. ) or verbose_guards_log.isEnabledFor(logging.DEBUG)
  3763. add = TracingContext.get().guards_context.dynamo_guards.add
  3764. for guard in guards:
  3765. assert isinstance(guard, Guard)
  3766. if is_from_skip_guard_source(guard.originating_source):
  3767. continue
  3768. add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1)