_jit_internal.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550
  1. # mypy: allow-untyped-defs
  2. """
  3. The weak_script annotation needs to be here instead of inside torch/jit/ so it
  4. can be used in other places in torch/ (namely torch.nn) without running into
  5. circular dependency problems
  6. """
  7. import ast
  8. import builtins
  9. import collections
  10. import contextlib
  11. import enum
  12. import inspect
  13. import io
  14. import pickle
  15. import sys
  16. import textwrap
  17. import threading
  18. import types
  19. import typing
  20. import warnings
  21. import weakref
  22. from typing import ( # noqa: UP035, F401 # (Dict, List, Tuple) imported by torch.jit.annotations
  23. Any,
  24. Callable,
  25. Dict,
  26. Final,
  27. ForwardRef,
  28. get_args,
  29. get_origin,
  30. List,
  31. Optional,
  32. Tuple,
  33. TypeVar,
  34. Union,
  35. )
  36. from typing_extensions import ParamSpec
  37. import torch
  38. # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
  39. # Explicitly ask to import `torch.distributed.__init__` first.
  40. # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
  41. import torch.distributed.rpc
  42. import torch.package._mangling as package_mangling
  43. from torch._awaits import _Await
  44. from torch._C import _Await as CAwait, Future as CFuture
  45. from torch._sources import fake_range, get_source_lines_and_file, parse_def
  46. from torch.futures import Future
  47. _P = ParamSpec("_P")
  48. _R = TypeVar("_R")
  49. BuiltinUnionType: type | tuple[type, ...] = types.UnionType
  50. LockType: type
  51. try:
  52. import _thread
  53. LockType = _thread.LockType
  54. except ImportError:
  55. import _dummy_thread # type: ignore[import-not-found]
  56. LockType = _dummy_thread.LockType
  57. # Wrapper functions that can call either of 2 functions depending on a boolean
  58. # argument
  59. boolean_dispatched: "weakref.WeakKeyDictionary[Callable, dict[str, Callable]]" = (
  60. weakref.WeakKeyDictionary()
  61. ) # noqa: T484
  62. FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
  63. def is_final(ann) -> bool:
  64. return (
  65. hasattr(ann, "__module__")
  66. and ann.__module__ in {"typing", "typing_extensions"}
  67. and (get_origin(ann) is Final or isinstance(ann, type(Final)))
  68. )
  69. # allows BroadcastingList instance to be subscriptable
  70. class BroadcastingListCls:
  71. def __getitem__(self, types):
  72. return
  73. # mypy doesn't support parameters on types, so we have to explicitly type each
  74. # list size
  75. BroadcastingList1 = BroadcastingListCls()
  76. for i in range(2, 7):
  77. globals()[f"BroadcastingList{i}"] = BroadcastingList1
  78. def is_scripting() -> bool:
  79. r"""
  80. Function that returns True when in compilation and False otherwise. This
  81. is useful especially with the @unused decorator to leave code in your
  82. model that is not yet TorchScript compatible.
  83. .. testcode::
  84. import torch
  85. @torch.jit.unused
  86. def unsupported_linear_op(x):
  87. return x
  88. def linear(x):
  89. if torch.jit.is_scripting():
  90. return torch.linear(x)
  91. else:
  92. return unsupported_linear_op(x)
  93. """
  94. return False
  95. # Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
  96. def _qualified_name(obj, mangle_name=True) -> str:
  97. # This special case allows us to override the qualified name on a type.
  98. # It's currently used in conjunction with tracing, where we create a
  99. # fake module to filter only supported attributes. However, since this
  100. # new type is defined as a local class, we need a mechanism to override
  101. # its qualname so it appears correctly in the TorchScript system. This,
  102. # we set '_jit_override_qualname' with the original traced module's
  103. # qualified name, which is picked up here
  104. if hasattr(obj, "_jit_override_qualname"):
  105. return obj._jit_override_qualname
  106. # short-circuit in cases where the object already has a known qualified name
  107. if isinstance(obj, torch._C.ScriptFunction):
  108. return obj.qualified_name
  109. if getattr(obj, "__name__", None):
  110. name = obj.__name__
  111. # Enum classes do not have `__name__` attr, instead they have `name`.
  112. elif isinstance(obj, enum.Enum):
  113. name = obj.name
  114. else:
  115. raise RuntimeError("Could not get name of python class object")
  116. if name == "<lambda>":
  117. name = "_lambda" # make name a valid identifier
  118. module_name = obj.__module__
  119. # If the module is actually a torchbind module, then we should short circuit
  120. if module_name == "torch._classes":
  121. return obj.qualified_name # pyrefly: ignore [missing-attribute]
  122. # The Python docs are very clear that `__module__` can be None, but I can't
  123. # figure out when it actually would be.
  124. if module_name is None:
  125. raise RuntimeError(
  126. f"Could not get qualified name for class '{name}': "
  127. "__module__ can't be None."
  128. )
  129. # if getattr(sys.modules[module_name], name) is not obj:
  130. # raise RuntimeError(f"Could not get qualified name for class '{name}': "
  131. # f"the attr {name} on module {module_name} is not the class")
  132. # torch.package and TorchScript have separate mangling schemes to avoid
  133. # name collisions from multiple packages. To avoid them interfering with
  134. # each other, normalize the package managing here.
  135. if package_mangling.is_mangled(module_name):
  136. module_name = module_name.replace("<", "_")
  137. module_name = module_name.replace(">", "_")
  138. # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
  139. # does not need mangle the python class name.
  140. if mangle_name:
  141. # __main__ is a builtin module, so rewrite it to "__torch__".
  142. if module_name == "__main__":
  143. module_name = "__torch__"
  144. else:
  145. # Everything else gets a "__torch__" prefix to avoid name collisions
  146. # with the names of user values.
  147. module_name = "__torch__." + module_name
  148. if "." in name:
  149. raise RuntimeError(
  150. f"Could not get qualified name for class '{name}': "
  151. f"'{name}' is not a valid identifier"
  152. )
  153. return module_name + "." + name
  154. class SourceLoader:
  155. def __init__(self):
  156. self.content = {}
  157. def cache(self, fn, source):
  158. self.content[fn] = source
  159. def get_source(self, fn):
  160. return self.content.get(fn)
  161. loader = SourceLoader()
  162. def createResolutionCallbackFromEnv(lookup_base):
  163. """
  164. Creates a resolution callback that will look up qualified names in an
  165. environment, starting with `lookup_base` for the base of any qualified
  166. names, then proceeding down the lookup chain with the resolved object.
  167. You should not use this directly, it should only be used from the other
  168. createResolutionCallbackFrom* functions.
  169. """
  170. def lookupInModule(qualified_name, module):
  171. if "." in qualified_name:
  172. base, remaining_pieces = qualified_name.split(".", maxsplit=1)
  173. module_value = getattr(module, base)
  174. return lookupInModule(remaining_pieces, module_value)
  175. else:
  176. return getattr(module, qualified_name)
  177. def parseNestedExpr(expr, module) -> tuple[Any, int]:
  178. i = 0
  179. while i < len(expr) and expr[i] not in (",", "[", "]"):
  180. i += 1
  181. # Special case logic for the empty Tuple as a subscript (used
  182. # in the type annotation `Tuple[()]`)
  183. if expr[:i] == "()":
  184. return (), i
  185. base = lookupInModule(expr[:i].strip(), module)
  186. assert base is not None, f"Unresolvable type {expr[:i]}"
  187. if i == len(expr) or expr[i] != "[":
  188. return base, i
  189. assert expr[i] == "["
  190. parts = []
  191. while expr[i] != "]":
  192. part_len = 0
  193. i += 1
  194. part, part_len = parseNestedExpr(expr[i:], module)
  195. parts.append(part)
  196. i += part_len
  197. if len(parts) > 1:
  198. return base[tuple(parts)], i + 1
  199. else:
  200. return base[parts[0]], i + 1
  201. def parseExpr(expr, module):
  202. try:
  203. value, len_parsed = parseNestedExpr(expr, module)
  204. assert len_parsed == len(expr), (
  205. "whole expression was not parsed, falling back to c++ parser"
  206. )
  207. return value
  208. except Exception:
  209. """
  210. The python resolver fails in several cases in known unit tests, and is intended
  211. to fall back gracefully to the c++ resolver in general. For example, python 2 style
  212. annotations which are frequent in our unit tests often fail with types e.g. int not
  213. resolvable from the calling frame.
  214. """
  215. return None
  216. return lambda expr: parseExpr(expr, lookup_base)
  217. def createResolutionCallbackFromFrame(frames_up: int = 0):
  218. """
  219. Creates a function which, given a string variable name,
  220. returns the value of the variable in the scope of the caller of
  221. the function which called createResolutionCallbackFromFrame (by default).
  222. This is used to enable access in-scope Python variables inside
  223. TorchScript fragments.
  224. frames_up is number of additional frames to go up on the stack.
  225. The default value is 0, which correspond to the frame of the caller
  226. of createResolutionCallbackFromFrame. Also for example, if frames_up is set
  227. to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
  228. will be taken.
  229. For example, the following program prints 2::
  230. def bar():
  231. cb = createResolutionCallbackFromFrame(1)
  232. print(cb("foo"))
  233. def baz():
  234. foo = 2
  235. bar()
  236. baz()
  237. """
  238. frame = inspect.currentframe()
  239. i = 0
  240. while i < frames_up + 1:
  241. assert frame is not None
  242. frame = frame.f_back
  243. i += 1
  244. assert frame is not None
  245. f_locals = frame.f_locals
  246. f_globals = frame.f_globals
  247. class env:
  248. def __getattr__(self, key):
  249. if key in f_locals:
  250. return f_locals[key]
  251. elif key in f_globals:
  252. return f_globals[key]
  253. elif key in dir(builtins):
  254. return getattr(builtins, key)
  255. return createResolutionCallbackFromEnv(env())
  256. def get_closure(fn):
  257. """
  258. Get a dictionary of closed over variables from a function
  259. """
  260. captures = {}
  261. captures.update(fn.__globals__)
  262. for index, captured_name in enumerate(fn.__code__.co_freevars):
  263. captures[captured_name] = fn.__closure__[index].cell_contents
  264. return captures
  265. # [local resolution in python]
  266. # Depending on where a variable is defined, and where it is used, we may
  267. # or may not be able to recover its value when recursively compiling a
  268. # script function. Remember in the general case, a module or function is
  269. # first defined and then later scripted. This means we do not have a
  270. # chance to capture the active frames when the function is defined. Hence any
  271. # name resolution has to happen later on the created closure. The way
  272. # python captures type annotations restricts what we can recover. The
  273. # following example illustrates the different cases:
  274. #
  275. # class MyGlobalClass:
  276. # ...
  277. # def my_local_scope():
  278. # @torch.jit.script
  279. # class MyClass:
  280. # ...
  281. # @torch.jit.script
  282. # class MyClassUsedAsVar:
  283. # ...
  284. # def eg(x: MyClass, y: MyGlobalClass):
  285. # a_local_capture : Foo
  286. # return MyClassUsedAsVar(x)
  287. #
  288. # MyGlobalClass is defined in the __globals__ dictionary of function
  289. # 'eg', so it is always recoverable. my_local_scope introduces a new local
  290. # variable scope in the function. Classes defined here are only visible as
  291. # local variables. For the case of MyClassUsedAsVar, it is captured
  292. # because it is used as a variable inside the body of the function, and we
  293. # can resolve it using the captures returned from `get_closure`. However,
  294. # the type annotations are not captured by the closure. In Python
  295. # 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as
  296. # annotations on `eg`, but starting in Python 4.0, they will be represented as
  297. # strings and no longer present. Furthermore, since the body of `eg` does
  298. # not reference those names, they do not appear in the list of closed over
  299. # variables. In Python 2.x, type annotations are in comments, leading to a
  300. # similar situation where their definitions are not available. We anticipate
  301. # that most users will not run into this issue because their modules and
  302. # functions will be defined at a global scope like MyGlobalClass. In cases
  303. # where they are not, it is possible to work around issues by declaring the
  304. # values global in the function.
  305. # In Python 3.9 declaring class as global will make it invisible to
  306. # `inspect.getsource`, see https://bugs.python.org/issue42666 .
  307. # This could be worked around by manually adding it to the `globals()` dictionary.
  308. def createResolutionCallbackFromClosure(fn):
  309. """
  310. Create a resolutionCallback by introspecting the function instead of
  311. looking up the stack for the enclosing scope
  312. """
  313. closure = get_closure(fn)
  314. class closure_lookup:
  315. # This is a class since `closure` is a dict and it's easier in
  316. # `env_helper` if everything just works with `getattr` calls
  317. def __getattr__(self, key):
  318. if key in closure:
  319. return closure[key]
  320. elif hasattr(typing, key):
  321. return getattr(typing, key)
  322. elif hasattr(builtins, key):
  323. return getattr(builtins, key)
  324. return None
  325. return createResolutionCallbackFromEnv(closure_lookup())
  326. def can_compile_class(cls) -> bool:
  327. # If any of the functions on a type don't have a code object, this type can't
  328. # be compiled and is probably a builtin / bound from C
  329. if is_ignored_fn(cls):
  330. return False
  331. # Ignore the following list of built-in classes.
  332. ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
  333. if issubclass(cls, ignored_builtin_classes):
  334. return False
  335. names = cls.__dict__
  336. fns = [
  337. getattr(cls, name)
  338. for name in names
  339. if inspect.isroutine(getattr(cls, name, None))
  340. ]
  341. has_code = [hasattr(fn, "__code__") for fn in fns]
  342. return all(has_code)
  343. def get_callable_argument_names(fn) -> list[str]:
  344. """
  345. Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.
  346. Returns an empty list when other types of arguments are present.
  347. This is used by `torch.jit.trace` to assign meaningful argument names to
  348. traced functions and modules.
  349. Args:
  350. fn: A callable.
  351. Returns:
  352. Argument names: List[str]
  353. """
  354. # inspect.signature may fail, give up in that case.
  355. try:
  356. callable_signature = inspect.signature(fn)
  357. except Exception:
  358. return []
  359. argument_names = []
  360. for name, param in callable_signature.parameters.items():
  361. # All four other types of arguments do not map to individual values
  362. # with a keyword as name.
  363. if param.kind != param.POSITIONAL_OR_KEYWORD:
  364. continue
  365. argument_names.append(name)
  366. return argument_names
  367. def get_annotation_str(annotation):
  368. """
  369. Convert an AST node containing a type annotation to the string present in the source
  370. that represents the same annotation.
  371. """
  372. if isinstance(annotation, ast.Name):
  373. return annotation.id
  374. elif isinstance(annotation, ast.Attribute):
  375. return ".".join([get_annotation_str(annotation.value), annotation.attr])
  376. elif isinstance(annotation, ast.Subscript):
  377. # In Python3.9+ subscript indices are not wrapped in ast.Index
  378. subscript_slice = annotation.slice
  379. return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
  380. elif isinstance(annotation, ast.Tuple):
  381. return ",".join([get_annotation_str(elt) for elt in annotation.elts])
  382. elif isinstance(annotation, ast.Constant):
  383. return f"{annotation.value}"
  384. # If an AST node is not handled here, it's probably handled in ScriptTypeParser.
  385. return None
  386. def get_type_hint_captures(fn):
  387. """
  388. Get a dictionary containing type resolution mappings necessary to resolve types
  389. for the literal annotations on 'fn'. These are not considered to be closed-over by fn
  390. and must be obtained separately (e.g. using this function).
  391. Args:
  392. fn: A callable.
  393. Returns:
  394. A Dict[str, Any] containing a mapping from the literal annotations used on
  395. fn to the Python objects they refer to.
  396. """
  397. # First, try to get the source of the function. We'll need to parse it to find the actual string names
  398. # that were used to annotate the types, since inspect.signature() will only return the class object that
  399. # the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
  400. # This may happen in cases where the function is synthesized dynamically at runtime.
  401. src = loader.get_source(fn)
  402. if src is None:
  403. try:
  404. src = inspect.getsource(fn)
  405. except OSError as e:
  406. raise OSError(
  407. f"Failed to get source for {fn} using inspect.getsource"
  408. ) from e
  409. # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
  410. # types are strings. These are only understood by TorchScript in the context of a type annotation
  411. # that refers to a class in its own definition, but trying to include a mapping for this in the result
  412. # function would cause infinite recursion because the class is currently being compiled.
  413. # In addition, there is logic in ScriptTypeParser to handle this.
  414. signature = inspect.signature(fn)
  415. name_to_type = {
  416. name: parameter.annotation
  417. for name, parameter in signature.parameters.items()
  418. if parameter.annotation is not inspect.Parameter.empty
  419. and not isinstance(parameter.annotation, str)
  420. }
  421. # Then, get the literal type annotations from the function declaration
  422. # by source inspection. This accounts for the case in which aliases are used
  423. # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
  424. # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
  425. a = ast.parse(textwrap.dedent(src))
  426. if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
  427. raise RuntimeError(f"Expected {fn} to be a function")
  428. f = a.body[0]
  429. # Prepare a dictionary of source annotation -> type, which will be the final result of this function,
  430. # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
  431. # them to the type object corresponding to the annotation via name_to_type using the parameter name.
  432. annotation_to_type = {}
  433. for arg in f.args.args:
  434. # Get the source type annotation string for this argument if possible.
  435. arg_annotation_str = (
  436. get_annotation_str(arg.annotation) if arg.annotation else None
  437. )
  438. # If the argument has no annotation or get_annotation_str cannot convert it to a string,
  439. # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
  440. # this in the latter case.
  441. if arg_annotation_str is None:
  442. continue
  443. # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
  444. # be present in name_to_type is that the annotation itself is a string and not a type object
  445. # (common for self-refential annotations in classes). Once again, let ScriptTypeParser handle this.
  446. arg_name = arg.arg
  447. if arg_name in name_to_type:
  448. annotation_to_type[arg_annotation_str] = name_to_type[arg_name]
  449. # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
  450. # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
  451. # of the annotation cannot be a string.
  452. literal_return_annotation = get_annotation_str(f.returns)
  453. valid_literal_annotation = literal_return_annotation is not None
  454. return_annotation = signature.return_annotation
  455. valid_return_annotation_type = (
  456. return_annotation is not inspect.Parameter.empty
  457. and not isinstance(return_annotation, str)
  458. )
  459. if valid_literal_annotation and valid_return_annotation_type:
  460. annotation_to_type[literal_return_annotation] = return_annotation
  461. return annotation_to_type
  462. def createResolutionCallbackForClassMethods(cls):
  463. """
  464. This looks at all the methods defined in a class and pulls their closed-over
  465. variables into a dictionary and uses that to resolve variables.
  466. """
  467. # cls is a type here, so `ismethod` is false since the methods on the type
  468. # aren't bound to anything, so Python treats them as regular functions
  469. fns = [
  470. getattr(cls, name)
  471. for name in cls.__dict__
  472. if inspect.isroutine(getattr(cls, name))
  473. ]
  474. # Skip built-ins, as they do not have global scope nor type hints
  475. # Needed to support `enum.Enum` derived classes in Python-3.11
  476. # That adds `_new_member_` property which is an alias to `__new__`
  477. fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")]
  478. captures = {}
  479. for fn in fns:
  480. captures.update(get_closure(fn))
  481. captures.update(get_type_hint_captures(fn))
  482. def lookup_in_class(key):
  483. if key in captures:
  484. return captures[key]
  485. else:
  486. return getattr(builtins, key, None)
  487. return lookup_in_class
  488. def boolean_dispatch(
  489. arg_name,
  490. arg_index,
  491. default,
  492. if_true,
  493. if_false,
  494. module_name,
  495. func_name,
  496. ):
  497. """
  498. Dispatches to either of 2 script functions based on a boolean argument.
  499. In TorchScript, the boolean argument must be constant so that the correct
  500. function to use can be determined at compile time.
  501. """
  502. def fn(*args, **kwargs):
  503. dispatch_flag = default
  504. if arg_name in kwargs:
  505. dispatch_flag = kwargs[arg_name]
  506. elif arg_index < len(args):
  507. dispatch_flag = args[arg_index]
  508. if dispatch_flag:
  509. return if_true(*args, **kwargs)
  510. else:
  511. return if_false(*args, **kwargs)
  512. if if_true.__doc__ is None and if_false.__doc__ is not None:
  513. doc = if_false.__doc__
  514. if_true.__doc__ = doc
  515. elif if_false.__doc__ is None and if_true.__doc__ is not None:
  516. doc = if_true.__doc__
  517. if_false.__doc__ = doc
  518. elif if_false.__doc__ is None and if_true.__doc__ is None:
  519. # neither function has a docstring
  520. doc = None
  521. else:
  522. raise RuntimeError("only one function can have a docstring")
  523. fn.__doc__ = doc
  524. if module_name is not None:
  525. fn.__module__ = module_name
  526. if func_name is not None:
  527. fn.__name__ = func_name
  528. boolean_dispatched[fn] = {
  529. "if_true": if_true,
  530. "if_false": if_false,
  531. "index": arg_index,
  532. "default": default,
  533. "arg_name": arg_name,
  534. }
  535. return fn
  536. class FunctionModifiers:
  537. """
  538. Used to denote the behavior of a function in TorchScript. See export() and
  539. ignore() for details.
  540. """
  541. UNUSED = "unused (ignored and replaced with raising of an exception)"
  542. IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
  543. EXPORT = "export (compile this function even if nothing calls it)"
  544. DEFAULT = "default (compile if called from a exported function / forward)"
  545. COPY_TO_SCRIPT_WRAPPER = (
  546. "if this method is not scripted, copy the python method onto the scripted model"
  547. )
  548. _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
  549. def export(fn: Callable[_P, _R]) -> Callable[_P, _R]:
  550. """
  551. This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
  552. :class:`ScriptModule` and should be compiled.
  553. ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
  554. Functions and methods called from ``forward`` are compiled as they are seen
  555. by the compiler, so they do not need this decorator either.
  556. Example (using ``@torch.jit.export`` on a method):
  557. .. testcode::
  558. import torch
  559. import torch.nn as nn
  560. class MyModule(nn.Module):
  561. def implicitly_compiled_method(self, x):
  562. return x + 99
  563. # `forward` is implicitly decorated with `@torch.jit.export`,
  564. # so adding it here would have no effect
  565. def forward(self, x):
  566. return x + 10
  567. @torch.jit.export
  568. def another_forward(self, x):
  569. # When the compiler sees this call, it will compile
  570. # `implicitly_compiled_method`
  571. return self.implicitly_compiled_method(x)
  572. def unused_method(self, x):
  573. return x - 20
  574. # `m` will contain compiled methods:
  575. # `forward`
  576. # `another_forward`
  577. # `implicitly_compiled_method`
  578. # `unused_method` will not be compiled since it was not called from
  579. # any compiled methods and wasn't decorated with `@torch.jit.export`
  580. m = torch.jit.script(MyModule())
  581. """
  582. fn._torchscript_modifier = FunctionModifiers.EXPORT # type:ignore[attr-defined]
  583. return fn
  584. def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
  585. """
  586. This decorator indicates to the compiler that a function or method should
  587. be ignored and replaced with the raising of an exception. This allows you
  588. to leave code in your model that is not yet TorchScript compatible and still
  589. export your model.
  590. Example (using ``@torch.jit.unused`` on a method)::
  591. import torch
  592. import torch.nn as nn
  593. class MyModule(nn.Module):
  594. def __init__(self, use_memory_efficient):
  595. super().__init__()
  596. self.use_memory_efficient = use_memory_efficient
  597. @torch.jit.unused
  598. def memory_efficient(self, x):
  599. import pdb
  600. pdb.set_trace()
  601. return x + 10
  602. def forward(self, x):
  603. # Use not-yet-scriptable memory efficient mode
  604. if self.use_memory_efficient:
  605. return self.memory_efficient(x)
  606. else:
  607. return x + 10
  608. m = torch.jit.script(MyModule(use_memory_efficient=False))
  609. m.save("m.pt")
  610. m = torch.jit.script(MyModule(use_memory_efficient=True))
  611. # exception raised
  612. m(torch.rand(100))
  613. """
  614. if isinstance(fn, property):
  615. prop = fn
  616. setattr( # noqa: B010
  617. prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED
  618. )
  619. if prop.fset:
  620. setattr( # noqa: B010
  621. prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
  622. )
  623. return prop # pyrefly: ignore [bad-return]
  624. fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
  625. return fn
  626. # No op context manager from python side
  627. class _IgnoreContextManager(contextlib.AbstractContextManager):
  628. def __init__(self, **kwargs):
  629. pass
  630. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  631. pass
  632. def ignore(drop=False, **kwargs):
  633. """
  634. This decorator indicates to the compiler that a function or method should
  635. be ignored and left as a Python function. This allows you to leave code in
  636. your model that is not yet TorchScript compatible. If called from TorchScript,
  637. ignored functions will dispatch the call to the Python interpreter. Models with ignored
  638. functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.
  639. Example (using ``@torch.jit.ignore`` on a method)::
  640. import torch
  641. import torch.nn as nn
  642. class MyModule(nn.Module):
  643. @torch.jit.ignore
  644. def debugger(self, x):
  645. import pdb
  646. pdb.set_trace()
  647. def forward(self, x):
  648. x += 10
  649. # The compiler would normally try to compile `debugger`,
  650. # but since it is `@ignore`d, it will be left as a call
  651. # to Python
  652. self.debugger(x)
  653. return x
  654. m = torch.jit.script(MyModule())
  655. # Error! The call `debugger` cannot be saved since it calls into Python
  656. m.save("m.pt")
  657. Example (using ``@torch.jit.ignore(drop=True)`` on a method):
  658. .. testcode::
  659. import torch
  660. import torch.nn as nn
  661. class MyModule(nn.Module):
  662. @torch.jit.ignore(drop=True)
  663. def training_method(self, x):
  664. import pdb
  665. pdb.set_trace()
  666. def forward(self, x):
  667. if self.training:
  668. self.training_method(x)
  669. return x
  670. m = torch.jit.script(MyModule())
  671. # This is OK since `training_method` is not saved, the call is replaced
  672. # with a `raise`.
  673. m.save("m.pt")
  674. .. testcleanup::
  675. import os
  676. os.remove('m.pt')
  677. """
  678. if callable(drop):
  679. # used without any args, so drop is actually a function
  680. # @torch.jit.ignore
  681. # def fn(...):
  682. fn = drop
  683. # pyrefly: ignore [missing-attribute]
  684. fn._torchscript_modifier = FunctionModifiers.IGNORE
  685. return fn
  686. if not isinstance(drop, bool):
  687. raise RuntimeError(
  688. f"Argument to @torch.jit.ignore must be a bool or a function but got {drop}"
  689. )
  690. # for backwards compat
  691. drop_on_export = kwargs.pop("drop_on_export", None)
  692. if drop_on_export:
  693. warnings.warn(
  694. "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
  695. "call on compilation. Use torch.jit.unused now. {}",
  696. stacklevel=2,
  697. category=FutureWarning,
  698. )
  699. drop = drop_on_export
  700. elif drop:
  701. warnings.warn(
  702. "ignore(True) has been deprecated. TorchScript will now drop the function "
  703. "call on compilation. Use torch.jit.unused now. {}",
  704. stacklevel=2,
  705. category=FutureWarning,
  706. )
  707. def decorator(fn):
  708. if drop:
  709. fn._torchscript_modifier = FunctionModifiers.UNUSED
  710. else:
  711. fn._torchscript_modifier = FunctionModifiers.IGNORE
  712. return fn
  713. return decorator
  714. def _drop(fn: Callable[_P, _R]) -> Callable[_P, _R]:
  715. fn._torchscript_modifier = FunctionModifiers._DROP # type: ignore[attr-defined]
  716. return fn
  717. def _copy_to_script_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
  718. fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER # type: ignore[attr-defined]
  719. return fn
  720. def module_has_exports(mod):
  721. for name in dir(mod):
  722. if hasattr(mod, name):
  723. item = getattr(mod, name)
  724. if callable(item):
  725. if get_torchscript_modifier(item) is FunctionModifiers.EXPORT:
  726. return True
  727. return False
  728. # WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
  729. # rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
  730. # allow JIT'd code to still be covered.
  731. def should_drop(fn) -> bool:
  732. attr = get_torchscript_modifier(fn)
  733. if attr is None:
  734. return False
  735. return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP
  736. def is_ignored_fn(fn) -> bool:
  737. mod = get_torchscript_modifier(fn)
  738. return (
  739. mod is FunctionModifiers.UNUSED
  740. or mod is FunctionModifiers.IGNORE
  741. or mod is FunctionModifiers._DROP
  742. )
  743. def _is_drop_fn(fn) -> bool:
  744. mod = get_torchscript_modifier(fn)
  745. return mod is FunctionModifiers._DROP
  746. def is_static_fn(cls, fn) -> bool:
  747. return isinstance(inspect.getattr_static(cls, fn, default=None), staticmethod)
  748. def get_static_fn(cls, fn):
  749. return inspect.getattr_static(cls, fn).__func__
  750. def get_torchscript_modifier(fn):
  751. if not callable(fn):
  752. return None
  753. if hasattr(fn, "__func__"):
  754. fn = fn.__func__
  755. return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT)
  756. def copy_torchscript_modifier(orig, new) -> None:
  757. attr = get_torchscript_modifier(orig)
  758. if attr is None:
  759. return
  760. new._torchscript_modifier = attr
  761. # overloading registration
  762. # overloads get registered in this file, and compiled in torch/jit/__init__.py
  763. # so that they can be imported in nn/functional.py without an import cycle
  764. # qualified_name => list[overload_functions]
  765. _overloaded_fns: dict[str, list[Callable]] = {} # noqa: T484
  766. _OVERLOAD_EXAMPLE = """
  767. Example usage of overload function:
  768. @torch.jit._overload
  769. def my_function(x: type0) -> type0: # decl 1
  770. pass
  771. @torch.jit._overload
  772. def my_function(x: type1) -> type1: # decl 2
  773. pass
  774. def my_function(x): # implementation
  775. if isinstance(x, type0):
  776. return x
  777. elif isinstance(x, type1):
  778. return x
  779. """
  780. def get_overload_no_implementation_error_message(kind, obj):
  781. sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
  782. return (
  783. f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
  784. f"sure a definition is provided and defined after all overload declarations.\n"
  785. f'File "{filename}", line {file_lineno}:\n'
  786. + "".join(sourcelines)
  787. + "\n"
  788. + _OVERLOAD_EXAMPLE
  789. )
  790. def _check_overload_body(func):
  791. try:
  792. parsed_def = parse_def(func)
  793. except OSError:
  794. # Parsing the function definition can raise an OSError if source is unavailable.
  795. # Since this is just an initial check, just raise a warning if this is the case.
  796. warnings.warn(
  797. f"Unable to retrieve source for @torch.jit._overload function: {func}.",
  798. stacklevel=2,
  799. )
  800. return
  801. body = parsed_def.ast.body[0].body
  802. def is_pass(x):
  803. return isinstance(x, ast.Pass)
  804. def is_ellipsis(x):
  805. return (
  806. isinstance(x, ast.Expr)
  807. and isinstance(x.value, ast.Constant)
  808. and x.value.value is Ellipsis
  809. )
  810. if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
  811. msg = (
  812. "Only `pass` statement or `...` can be the body of overload declaration:\n"
  813. )
  814. msg += "\n".join(parsed_def.source.split("\n")[:3])
  815. msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
  816. raise RuntimeError(msg)
  817. def _overload(func):
  818. _check_overload_body(func)
  819. qual_name = _qualified_name(func)
  820. global _overloaded_fns
  821. fn_overload_list = _overloaded_fns.get(qual_name)
  822. if fn_overload_list is None:
  823. fn_overload_list = []
  824. _overloaded_fns[qual_name] = fn_overload_list
  825. fn_overload_list.append(func)
  826. return func
  827. def _get_fn_overloads(qual_name):
  828. return _overloaded_fns.get(qual_name)
  829. def _clear_fn_overloads(qual_name) -> None:
  830. del _overloaded_fns[qual_name]
  831. def get_class_name_lineno(method) -> tuple[str, int]:
  832. current_frame = inspect.currentframe()
  833. # one for the get_class_name call, one for _overload_method call
  834. for _ in range(2):
  835. assert (
  836. current_frame is not None
  837. ) # assert current frame is not an Optional[FrameType]
  838. current_frame = current_frame.f_back
  839. assert current_frame is not None # same here
  840. class_name = current_frame.f_code.co_name
  841. line_no = current_frame.f_code.co_firstlineno
  842. return class_name, line_no
  843. # At the point the decorator is applied to class methods the method
  844. # has no reference to its owning class. _qualified_name would not include
  845. # the class it is defined in, so any methods with the same name in the same file
  846. # would have the same _qualified_name, even if they were defined in different
  847. # classes. This problem only exists in python 2.
  848. # We get around this problem by looking at the stack frame and identifying
  849. # the class name, and throwing an error whenever overloads are used
  850. # when modules of the same name are in the same file
  851. # qualified_name => class name => list[overload_functions]
  852. _overloaded_methods: dict[str, dict[str, list[Callable]]] = {} # noqa: T484
  853. # (qualified_name, class name) => class_fileno
  854. _overloaded_method_class_fileno: dict[tuple[str, str], int] = {}
  855. def _overload_method(func):
  856. _check_overload_body(func)
  857. qual_name = _qualified_name(func)
  858. global _overloaded_methods
  859. class_name_map = _overloaded_methods.get(qual_name)
  860. if class_name_map is None:
  861. class_name_map = {}
  862. _overloaded_methods[qual_name] = class_name_map
  863. class_name, line_no = get_class_name_lineno(func)
  864. method_overloads = class_name_map.get(class_name)
  865. if method_overloads is None:
  866. method_overloads = []
  867. class_name_map[class_name] = method_overloads
  868. _overloaded_method_class_fileno[(qual_name, class_name)] = line_no
  869. else:
  870. existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)]
  871. if existing_lineno != line_no:
  872. raise RuntimeError(
  873. "Cannot currently overload the same method name in two different"
  874. " classes with the same name in the same module"
  875. )
  876. method_overloads.append(func)
  877. return func
  878. def _get_overloaded_methods(method, mod_class):
  879. # TODO: __name__ not set for submodules in recursive script
  880. if not hasattr(method, "__name__"):
  881. return None
  882. qual_name = _qualified_name(method)
  883. class_name_map = _overloaded_methods.get(qual_name)
  884. if class_name_map is None:
  885. return None
  886. overloads = class_name_map.get(mod_class.__name__, None)
  887. if overloads is None:
  888. return None
  889. method_line_no = get_source_lines_and_file(method)[1]
  890. mod_class_fileno = get_source_lines_and_file(mod_class)[1]
  891. mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0])
  892. if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno):
  893. raise AssertionError(
  894. "Overloads are not usable when a module is redeclared within the same file: "
  895. + str(method)
  896. )
  897. return overloads
  898. def is_tuple(ann) -> bool:
  899. # Check for typing.Tuple missing args (but `tuple` is fine)
  900. if ann is typing.Tuple: # noqa: UP006
  901. raise_error_container_parameter_missing("Tuple")
  902. # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
  903. if not hasattr(ann, "__module__"):
  904. return False
  905. ann_origin = get_origin(ann)
  906. return ann.__module__ in ("builtins", "typing") and ann_origin is tuple
  907. def is_list(ann) -> bool:
  908. # Check for typing.List missing args (but `list` is fine)
  909. if ann is typing.List: # noqa: UP006
  910. raise_error_container_parameter_missing("List")
  911. if not hasattr(ann, "__module__"):
  912. return False
  913. ann_origin = get_origin(ann)
  914. return ann.__module__ in ("builtins", "typing") and ann_origin is list
  915. def is_dict(ann) -> bool:
  916. # Check for typing.Dict missing args (but `dict` is fine)
  917. if ann is typing.Dict: # noqa: UP006
  918. raise_error_container_parameter_missing("Dict")
  919. if not hasattr(ann, "__module__"):
  920. return False
  921. ann_origin = get_origin(ann)
  922. return ann.__module__ in ("builtins", "typing") and ann_origin is dict
  923. def is_union(ann):
  924. if ann is Union:
  925. raise_error_container_parameter_missing("Union")
  926. return isinstance(ann, BuiltinUnionType) or (
  927. hasattr(ann, "__module__")
  928. and ann.__module__ == "typing"
  929. and (get_origin(ann) is Union)
  930. )
  931. def is_optional(ann):
  932. if ann is Optional:
  933. raise_error_container_parameter_missing("Optional")
  934. def is_optional_as_optional(ann):
  935. return (
  936. hasattr(ann, "__module__")
  937. and ann.__module__ == "typing"
  938. and (get_origin(ann) is Optional)
  939. )
  940. def is_union_as_optional(ann):
  941. ann_args = get_args(ann)
  942. return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
  943. return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
  944. def is_future(ann) -> bool:
  945. if ann is Future:
  946. raise RuntimeError(
  947. "Attempted to use Future without a "
  948. "contained type. Please add a contained type, e.g. "
  949. "Future[int]"
  950. )
  951. return get_origin(ann) is Future
  952. def is_await(ann) -> bool:
  953. if ann is _Await:
  954. return True
  955. return get_origin(ann) is _Await
  956. if torch.distributed.rpc.is_available():
  957. from torch._C._distributed_rpc import PyRRef
  958. from torch.distributed.rpc import RRef
  959. def is_rref(ann) -> bool:
  960. if ann is RRef:
  961. raise RuntimeError(
  962. "Attempted to use RRef without a "
  963. "contained type. Please add a contained type, e.g. "
  964. "RRef[int]"
  965. )
  966. return get_origin(ann) is RRef
  967. def is_rref_instance(obj) -> bool:
  968. return isinstance(obj, PyRRef)
  969. else:
  970. def is_rref_instance(obj) -> bool:
  971. # If the RPC module doesn't exist then RRefs don't exist either.
  972. return False
  973. def _try_get_dispatched_fn(fn):
  974. if not callable(fn):
  975. return None
  976. return boolean_dispatched.get(fn)
def _get_named_tuple_properties(
    obj,
    loc: torch._C._jit_tree_views.SourceRange | None = None,
    rcb=None,
):
    """Collect what TorchScript needs to describe the NamedTuple class ``obj``.

    Args:
        obj: a ``tuple`` subclass with ``_fields`` (checked by the assert below).
        loc: source range passed to ``ann_to_type`` and used in error messages;
            defaults to ``fake_range()``.
        rcb: optional resolution callback used to resolve string
            (``ForwardRef``) field annotations; see the note below.

    Returns:
        A 4-tuple ``(name, fields, annotations, defaults)``:

        - ``fields`` is ``obj._fields``.
        - ``annotations`` holds one TorchScript type per field, from
          ``ann_to_type``; unannotated fields get
          ``torch._C.TensorType.getInferred()``.
        - ``defaults`` lists ``obj._field_defaults`` values for the fields
          that have one, in field order.

        NOTE(review): ``name`` is ``type(obj).__name__``, i.e. the name of
        ``obj``'s metaclass (typically ``"type"``), not the NamedTuple's own
        name. Confirm whether any caller relies on this element.

    Raises:
        ValueError: if ``rcb`` cannot resolve a ``ForwardRef`` field annotation.
    """
    if loc is None:
        loc = fake_range()

    assert issubclass(obj, tuple) and hasattr(obj, "_fields")
    if hasattr(obj, "_field_defaults"):
        # Keep defaults in field order; only fields that have a default appear.
        defaults = [
            obj._field_defaults[field]
            for field in obj._fields
            if field in obj._field_defaults
        ]
    else:
        defaults = []
    obj_annotations = inspect.get_annotations(obj)
    if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
        # No annotations on obj itself: fall back to its direct base class
        # (presumably for subclasses of an annotated NamedTuple; TODO confirm).
        obj_annotations = inspect.get_annotations(
            # pyrefly: ignore [bad-argument-type]
            obj.__base__
        )

    annotations = []
    for field in obj._fields:
        if field in obj_annotations:
            field_type = obj_annotations[field]
            # [Note: ForwardRef annotations in NamedTuple attributes]
            # NamedTuple types are slightly different from normal types.
            #
            # Normally, annotations are evaluated like this (during jit.script):
            # 1. Load strings of python code into c++ and parse.
            # 2. Get annotations as strings
            # 3. Use the PythonResolver's resolution callback (rcb) to convert
            #    the string into a python object
            # 4. We call into annotations.py:ann_to_type to convert python obj
            #    from step 3 into a type that torchscript understands.
            #
            # NamedTuples are more complicated, because they have sub-types.
            # Normally, once we have the NamedTuple type object from #3,
            # we can just look at the annotation literal values and use
            # ann_to_type directly on them.
            #
            # But sometimes, users will annotate with string literals, e.g.
            #    x: 'int'
            # This also happens with PEP563 (from __future__ import annotations)
            #
            # These annotations appear in the annotation dict as ForwardRef('int').
            #
            # Then, we need to convert the string into a python object. This
            # requires having local context for custom objects or imported types.
            # rcb() is what gives us this. So, we plumb rcb through the stack so
            # it can be used in this context for the if block below.
            #
            # FAQ:
            # - Why do we need this special handling for NamedTuple but string
            #   annotations work fine for normal types? Normally, we parse the
            #   string directly and then call rcb() directly from C++.
            # - Why not use ForwardRef._evaluate? For that, we need globals()
            #   and locals() for the local context where the NamedTuple was defined.
            #   rcb is what lets us look up into these. So, basically rcb does the
            #   hard work for us.
            if isinstance(field_type, ForwardRef) and rcb is not None:
                rcb_type = rcb(field_type.__forward_arg__)
                # rcb returns None if it can't find anything.
                if rcb_type is None:
                    raise ValueError(
                        f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
                        f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
                        f" Issue occurred at {loc.highlight()}"
                    )
                field_type = rcb_type
            the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
            annotations.append(the_type)
        else:
            # Unannotated field: use an inferred Tensor type.
            annotations.append(torch._C.TensorType.getInferred())
    return type(obj).__name__, obj._fields, annotations, defaults
  1053. def _create_named_tuple(
  1054. t,
  1055. unqual_name: str,
  1056. field_names: list[str],
  1057. defaults: tuple[Any, ...],
  1058. ):
  1059. TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc]
  1060. return TupleType(*t)
  1061. @contextlib.contextmanager
  1062. def _disable_emit_hooks():
  1063. hooks = torch._C._jit_get_emit_hooks()
  1064. torch._C._jit_set_emit_hooks(None, None)
  1065. try:
  1066. yield
  1067. finally:
  1068. torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None:  # noqa: F811
    """Has no effect as written.

    NOTE(review): despite the ``__enter__``/``__exit__`` bodies below, this is a
    *function*, not a class. Calling it only defines two local functions and
    returns ``None``; its single parameter is unused. It looks like it was meant
    to be ``class _disable_emit_hooks_decorator(_DecoratorContextManager)``.
    Confirm whether anything references it before changing it.
    ``_disable_emit_hooks`` provides the working context manager.
    """
    # noqa: F841
    def __enter__(self) -> None:
        self.hooks = torch._C._jit_get_emit_hooks()
        torch._C._jit_set_emit_hooks(None, None)

    def __exit__(self, *args) -> None:
        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
  1076. def _is_exception(obj) -> bool:
  1077. if not inspect.isclass(obj):
  1078. return False
  1079. return issubclass(obj, Exception)
  1080. def raise_error_container_parameter_missing(target_type) -> None:
  1081. if target_type.endswith("ict"):
  1082. raise RuntimeError(
  1083. f"Attempted to use {target_type} without "
  1084. "contained types. Please add contained type, e.g. "
  1085. f"{target_type}[int, int]"
  1086. )
  1087. raise RuntimeError(
  1088. f"Attempted to use {target_type} without a "
  1089. "contained type. Please add a contained type, e.g. "
  1090. f"{target_type}[int]"
  1091. )
  1092. _RAW_TYPE_NAME_MAPPING = {
  1093. dict: "dict",
  1094. list: "list",
  1095. tuple: "tuple",
  1096. typing.Dict: "Dict", # noqa: UP006
  1097. typing.List: "List", # noqa: UP006
  1098. typing.Optional: "Optional",
  1099. typing.Tuple: "Tuple", # noqa: UP006
  1100. }
  1101. def check_args_exist(target_type) -> None:
  1102. if name := _RAW_TYPE_NAME_MAPPING.get(target_type):
  1103. raise_error_container_parameter_missing(name)
  1104. def check_empty_containers(obj) -> None:
  1105. if obj == [] or obj == {} or obj == ():
  1106. warnings.warn(
  1107. "The inner type of a container is lost when "
  1108. "calling torch.jit.isinstance in eager mode. For "
  1109. "example, List[int] would become list and "
  1110. "therefore falsely return True for List[float] or"
  1111. " List[str].",
  1112. stacklevel=2,
  1113. )
  1114. # supports List/Dict/Tuple and Optional types
  1115. # TODO support future
  1116. def container_checker(obj, target_type) -> bool:
  1117. origin_type = get_origin(target_type)
  1118. check_args_exist(target_type)
  1119. if origin_type is None:
  1120. return False
  1121. elif origin_type is list or origin_type is typing.List: # noqa: UP006
  1122. check_empty_containers(obj)
  1123. if not isinstance(obj, list):
  1124. return False
  1125. arg_type = get_args(target_type)[0]
  1126. arg_origin = get_origin(arg_type)
  1127. for el in obj:
  1128. # check if nested container, ex: List[List[str]]
  1129. if arg_origin: # processes nested container, ex: List[List[str]]
  1130. if not container_checker(el, arg_type):
  1131. return False
  1132. elif not isinstance(el, arg_type):
  1133. return False
  1134. return True
  1135. elif origin_type is typing.Dict or origin_type is dict: # noqa: UP006
  1136. check_empty_containers(obj)
  1137. if not isinstance(obj, dict):
  1138. return False
  1139. key_type = get_args(target_type)[0]
  1140. val_type = get_args(target_type)[1]
  1141. for key, val in obj.items():
  1142. # check if keys are of right type
  1143. if not isinstance(key, key_type):
  1144. return False
  1145. val_origin = get_origin(val_type)
  1146. if val_origin:
  1147. if not container_checker(val, val_type):
  1148. return False
  1149. elif not isinstance(val, val_type):
  1150. return False
  1151. return True
  1152. elif origin_type is typing.Tuple or origin_type is tuple: # noqa: UP006
  1153. check_empty_containers(obj)
  1154. if not isinstance(obj, tuple):
  1155. return False
  1156. arg_types = get_args(target_type)
  1157. if len(obj) != len(arg_types):
  1158. return False
  1159. for el, el_type in zip(obj, arg_types):
  1160. el_origin = get_origin(el_type)
  1161. if el_origin:
  1162. if not container_checker(el, el_type):
  1163. return False
  1164. elif not isinstance(el, el_type):
  1165. return False
  1166. return True
  1167. elif origin_type is Union or issubclass(
  1168. # pyrefly: ignore [bad-argument-type]
  1169. origin_type,
  1170. BuiltinUnionType,
  1171. ): # also handles Optional
  1172. if obj is None: # check before recursion because None is always fine
  1173. return True
  1174. inner_types = get_args(target_type)
  1175. for t in inner_types:
  1176. t_origin = get_origin(t)
  1177. if t_origin:
  1178. return container_checker(obj, t)
  1179. elif isinstance(obj, t):
  1180. return True
  1181. return False
  1182. def _isinstance(obj, target_type) -> bool:
  1183. if isinstance(target_type, collections.abc.Container):
  1184. if not isinstance(target_type, tuple):
  1185. raise RuntimeError(
  1186. "The second argument to "
  1187. "`torch.jit.isinstance` must be a type "
  1188. "or a tuple of types"
  1189. )
  1190. for t_type in target_type:
  1191. if _isinstance(obj, t_type):
  1192. return True
  1193. return False
  1194. origin_type = get_origin(target_type)
  1195. if origin_type:
  1196. return container_checker(obj, target_type)
  1197. # Check to handle non-typed optional origin returns as none instead
  1198. # of as optional in 3.7-3.8
  1199. check_args_exist(target_type)
  1200. # handle non-containers
  1201. return isinstance(obj, target_type)
  1202. class _TensorExtractor(pickle.Pickler):
  1203. def __init__(self, *args, tensors: list[torch.Tensor], **kwargs):
  1204. super().__init__(*args, **kwargs)
  1205. self.tensors = tensors
  1206. def persistent_id(self, obj):
  1207. if isinstance(obj, torch.Tensor):
  1208. self.tensors.append(obj)
  1209. return ""
  1210. # Since we just want to extract tensors, we don't mind if an object is
  1211. # unpicklable if it doesn't contain tensors, as we can just ignore/skip
  1212. # it. To play it safe, we only do so for common objects that we're sure
  1213. # don't contain tensors. Feel free to add new types here. Note also that
  1214. # even if a type isn't listed here this won't block users, since they
  1215. # can just add a __getstate__ or __reduce__ method to their class.
  1216. if isinstance(obj, LockType):
  1217. return ""
  1218. # Futures and RRefs don't technically contain a value, they just offer
  1219. # the means to access a value.
  1220. if isinstance(obj, CFuture) or is_rref_instance(obj):
  1221. return ""
  1222. if isinstance(obj, CAwait):
  1223. return ""
  1224. if isinstance(obj, torch.cuda.Event):
  1225. return ""
  1226. if isinstance(obj, threading.Thread):
  1227. return ""
  1228. return None
  1229. def _extract_tensors(obj):
  1230. r"""
  1231. This function is exclusively called from C++.
  1232. See ``torch/csrc/jit/python/python_ivalue.h``.
  1233. It extracts the tensors contained in the given object, through pickling.
  1234. """
  1235. tensors: list[torch.Tensor] = []
  1236. extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
  1237. extractor.dump(obj)
  1238. return tensors
  1239. def _get_model_id(obj) -> str | None:
  1240. if isinstance(obj, torch.jit.ScriptModule):
  1241. return str(obj._c._type())
  1242. elif isinstance(obj, torch.jit.ScriptFunction):
  1243. return obj.qualified_name
  1244. else:
  1245. return None
  1246. # In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass
  1247. # that were previously dropped. To preserve the behavior, explicitly drop them there
  1248. if sys.version_info >= (3, 11):
  1249. _drop(enum.Enum.__new__)
  1250. _drop(enum.Enum.__format__)
  1251. _drop(enum.Enum.__repr__)
  1252. _drop(enum.Enum.__str__)