_trace.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406
  1. # mypy: allow-untyped-defs
  2. """Tracing.
  3. This module contains functionality to support the JIT's tracing frontend, notably:
  4. * torch.jit.trace
  5. * torch.jit.trace_module
  6. This is not intended to be imported directly; please use the exposed
  7. functionalities in `torch.jit`.
  8. """
  9. import contextlib
  10. import copy
  11. import functools
  12. import inspect
  13. import os
  14. import re
  15. import warnings
  16. from enum import Enum
  17. from typing import Any, Callable, Optional, TypeVar
  18. from typing_extensions import ParamSpec
  19. import torch
  20. from torch._jit_internal import (
  21. _get_model_id,
  22. _qualified_name,
  23. get_callable_argument_names,
  24. is_scripting,
  25. )
  26. from torch.autograd import function
  27. from torch.jit._script import _CachedForward, script, ScriptModule
  28. from torch.jit._state import _enabled, _python_cu
  29. from torch.nn import Module
  30. from torch.testing._comparison import default_tolerances
  31. _flatten = torch._C._jit_flatten
  32. _unflatten = torch._C._jit_unflatten
  33. R = TypeVar("R", covariant=True) # return type (always covariant)
  34. P = ParamSpec("P")
  35. def _create_interpreter_name_lookup_fn(frames_up=1):
  36. def _get_interpreter_name_for_var(var):
  37. frame = inspect.currentframe()
  38. if not frame:
  39. raise RuntimeError("failed to inspect frame")
  40. i = 0
  41. while i < frames_up + 1:
  42. frame = frame.f_back
  43. if not frame:
  44. raise RuntimeError("failed to get frame")
  45. i += 1
  46. f_locals = frame.f_locals
  47. for k, v in f_locals.items():
  48. if isinstance(v, torch.Tensor) and var is v:
  49. return k if k != "self" else ""
  50. return ""
  51. return _get_interpreter_name_for_var
  52. def _unique_state_dict(module, keep_vars=False):
  53. # since Parameter.detach() always creates a new torch.Tensor instance,
  54. # id(v) doesn't work with it. So we always get the Parameter or Buffer
  55. # as values, and deduplicate the params using Parameters and Buffers
  56. state_dict = module.state_dict(keep_vars=True)
  57. filtered_dict = type(state_dict)()
  58. seen_ids: set[int] = set()
  59. for k, v in state_dict.items():
  60. if id(v) in seen_ids:
  61. continue
  62. seen_ids.add(id(v))
  63. if keep_vars:
  64. filtered_dict[k] = v
  65. else:
  66. filtered_dict[k] = v.detach()
  67. return filtered_dict
  68. class ONNXTracedModule(torch.nn.Module):
  69. def __init__(
  70. self,
  71. inner,
  72. strict=True,
  73. force_outplace=False,
  74. return_inputs=False,
  75. return_inputs_states=False,
  76. ):
  77. super().__init__()
  78. # inner may be a Module, or it may be an arbitrary callable
  79. # If it's a Module, we get its parameters automatically, which lets
  80. # us avoid a special casing functions versus modules.
  81. self.inner = inner
  82. self.strict = strict
  83. self._force_outplace = force_outplace
  84. self._return_inputs = return_inputs
  85. self._return_inputs_states = return_inputs_states
  86. def forward(self, *args: torch.Tensor):
  87. in_vars, in_desc = _flatten(args)
  88. # NOTE: use full state, because we need it for BatchNorm export
  89. # This differs from the compiler path, which doesn't support it at the moment.
  90. module_state = list(_unique_state_dict(self, keep_vars=True).values())
  91. ret_inputs = []
  92. inputs_states = []
  93. outs = []
  94. def wrapper(*args):
  95. in_args: list[torch.Tensor] = []
  96. for i in range(len(in_vars)):
  97. if not isinstance(args[i], torch.Tensor):
  98. raise RuntimeError("Expected Tensor argument")
  99. in_args.append(args[i])
  100. trace_inputs = _unflatten(in_args, in_desc)
  101. if self._return_inputs:
  102. ret_inputs.append(
  103. tuple(x.clone(memory_format=torch.preserve_format) for x in args)
  104. )
  105. if self._return_inputs_states:
  106. inputs_states.append(_unflatten(in_args, in_desc))
  107. outs.append(self.inner(*trace_inputs))
  108. if self._return_inputs_states:
  109. inputs_states[0] = (inputs_states[0], trace_inputs)
  110. out_vars, _ = _flatten(outs)
  111. if len(out_vars) == 1:
  112. return out_vars[0]
  113. else:
  114. return tuple(out_vars)
  115. graph, _out = torch._C._create_graph_by_tracing(
  116. wrapper,
  117. in_vars + module_state,
  118. _create_interpreter_name_lookup_fn(),
  119. self.strict,
  120. self._force_outplace,
  121. )
  122. if self._return_inputs:
  123. return graph, outs[0], ret_inputs[0]
  124. if self._return_inputs_states:
  125. return graph, outs[0], inputs_states[0]
  126. else:
  127. return graph, outs[0]
  128. def _clone_inputs(args):
  129. def clone_input(a):
  130. if a is None:
  131. return None
  132. elif isinstance(a, torch.Tensor):
  133. # TODO: figure out one liner to .clone() and set requires_grad
  134. v = (
  135. a.detach()
  136. .clone(memory_format=None if a.is_mkldnn else torch.preserve_format)
  137. .requires_grad_(a.requires_grad)
  138. )
  139. if a.grad is not None:
  140. v.grad = clone_input(v.grad)
  141. return v
  142. else:
  143. return a.clone(memory_format=torch.preserve_format)
  144. return function._nested_map(
  145. lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
  146. )(args)
  147. # This is purely for developer debugging. We are not going to advertise it.
  148. _JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False) # CUDA-only timing
  149. _JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False)
  150. _JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False)
  151. @contextlib.contextmanager
  152. def _time(trace_name, name, time=True):
  153. if (not _JIT_TIME and not time) or not torch.cuda.is_available():
  154. yield
  155. return
  156. stream = torch.cuda.current_stream()
  157. start = torch.cuda.Event(enable_timing=True)
  158. end = torch.cuda.Event(enable_timing=True)
  159. stream.record_event(start)
  160. try:
  161. yield
  162. finally:
  163. stream.record_event(end)
  164. end.synchronize()
  165. print(f"{trace_name} {name} time: {start.elapsed_time(end)} ms")
  166. def verify(model, args, loss_fn=torch.sum, devices=None):
  167. """
  168. Verify that a JIT compiled model has the same behavior as its uncompiled version along with its backwards pass.
  169. If your model returns multiple outputs,
  170. you must also specify a `loss_fn` to produce a loss for which
  171. the backwards will be computed.
  172. This function has side-effects (e.g., it executes your model / saves and loads
  173. parameters), so don't expect the model to come out exactly the same as what
  174. you passed in.
  175. Args:
  176. model (compiled torch.nn.Module or function): the module/function to be
  177. verified. The module/function definition MUST have been decorated with
  178. `@torch.jit.compile`.
  179. args (tuple or Tensor): the positional arguments to pass to the
  180. compiled function/module to be verified. A non-tuple is assumed to
  181. be a single positional argument to be passed to the model.
  182. loss_fn (function, optional): the loss function to be applied to
  183. the output of the model, before backwards is invoked. By default,
  184. we assume that a model returns a single result, and we :func:`torch.sum`
  185. before calling backwards; if this is inappropriate, you can pass your
  186. own loss function. Note that if a model returns a tuple of results,
  187. these are passed as separate positional arguments to `loss_fn`.
  188. devices (iterable of device IDs, optional): the GPU devices which the
  189. compiled module will be run on. This determines the RNG state we
  190. must save when running both compiled and uncompiled versions of the model.
  191. """
  192. # TODO: In principle, we track device information in our trace, so it
  193. # should be possible to check if our execution actually obeyed the 'devices'
  194. # the user provided.
  195. # TODO: Consider adding a utility function to torch.jit to test
  196. # for this case
  197. if not isinstance(model, torch._C.CompiledFunction): # type: ignore[attr-defined]
  198. raise TypeError(
  199. "Cannot verify an uncompiled module. Add @torch.jit.compile to compile it"
  200. )
  201. is_module = isinstance(model, Module)
  202. if not isinstance(args, tuple):
  203. args = (args,)
  204. if is_module:
  205. saved_state = copy.deepcopy(model.state_dict())
  206. def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
  207. params = list(model.parameters()) if is_module else []
  208. in_vars, _ = _flatten((args, params))
  209. # We use a special API to reset the trace and compile it from scratch.
  210. compiled_fn = model
  211. if force_trace:
  212. compiled_fn.clear_cache()
  213. if assert_compiled:
  214. hits = compiled_fn.hits
  215. out = model(*args)
  216. if assert_compiled and compiled_fn.hits == hits: # type: ignore[possibly-undefined]
  217. raise RuntimeError("failed to use the compiled function")
  218. if not isinstance(out, tuple):
  219. out = (out,)
  220. if loss_fn == torch.sum and len(out) != 1:
  221. raise ValueError(
  222. f"Model returns {len(out)} outputs, but default loss function "
  223. "(torch.sum) can only handle a single output"
  224. )
  225. out_vars, _ = _flatten(out)
  226. saved_outs = [
  227. v.detach().clone(memory_format=torch.preserve_format) for v in out_vars
  228. ]
  229. loss = loss_fn(*out)
  230. grads = torch.autograd.grad([loss], in_vars)
  231. # TODO: I'm not sure if the clone here is necessary but it is safer
  232. saved_grads = [
  233. v.detach().clone(memory_format=torch.preserve_format) for v in grads
  234. ]
  235. return (saved_outs, saved_grads)
  236. with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
  237. uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
  238. assert model.has_trace_for(*args)
  239. if is_module:
  240. model.load_state_dict(saved_state) # type: ignore[possibly-undefined]
  241. compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)
  242. _verify_equal(uncompiled_outs, compiled_outs)
  243. _verify_equal(uncompiled_grads, compiled_grads)
  244. def _verify_equal(xs, ys):
  245. for x, y in zip(xs, ys):
  246. if x.sub(y).abs().max() > 1e-6:
  247. raise RuntimeError("JIT and real computation mismatch")
  248. def indent(s):
  249. return "\n".join(["\t" + line for line in s.splitlines()])
  250. class TracingCheckError(Exception):
  251. def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
  252. self.message = "Tracing failed sanity checks!\n"
  253. if extra_msg is not None:
  254. self.message += extra_msg + "\n"
  255. if graph_diff_error is not None:
  256. self.message += "ERROR: Graphs differed across invocations!\n"
  257. self.message += indent(graph_diff_error) + "\n"
  258. if tensor_compare_error is not None:
  259. self.message += (
  260. "ERROR: Tensor-valued Constant nodes differed in value "
  261. "across invocations. This often indicates that the tracer has"
  262. " encountered untraceable code.\n"
  263. )
  264. self.message += indent(tensor_compare_error) + "\n"
  265. super().__init__(self.message)
  266. # Check the traced module against a set of user-provided validation inputs
  267. @torch.no_grad()
  268. def _check_trace(
  269. check_inputs,
  270. func,
  271. traced_func,
  272. check_tolerance,
  273. strict,
  274. force_outplace,
  275. is_trace_module,
  276. _module_class,
  277. example_inputs_is_kwarg=False,
  278. ):
  279. # Note: tracing is independent of optimizations, which consume the trace
  280. for inputs in check_inputs:
  281. if isinstance(inputs, torch.Tensor):
  282. inputs = (inputs,)
  283. if is_trace_module:
  284. copied_dict = {}
  285. for name, data in inputs.items():
  286. copied_dict[name] = _clone_inputs(data)
  287. check_mod = torch.jit.trace_module(
  288. getattr(func, "__self__", func),
  289. copied_dict,
  290. check_trace=False,
  291. strict=strict,
  292. _force_outplace=force_outplace,
  293. _module_class=_module_class,
  294. _compilation_unit=torch._C.CompilationUnit(),
  295. example_inputs_is_kwarg=example_inputs_is_kwarg,
  296. _store_inputs=False,
  297. )
  298. check_mod_func = check_mod._c._get_method(traced_func.name)
  299. inputs = inputs[traced_func.name]
  300. if (
  301. isinstance(inputs, (torch.Tensor))
  302. or isinstance(inputs, dict)
  303. and not example_inputs_is_kwarg
  304. ):
  305. inputs = (inputs,)
  306. else:
  307. if example_inputs_is_kwarg:
  308. check_mod = torch.jit.trace(
  309. func,
  310. check_trace=False,
  311. strict=strict,
  312. _force_outplace=force_outplace,
  313. _module_class=_module_class,
  314. example_kwarg_inputs=_clone_inputs(inputs),
  315. _store_inputs=False,
  316. )
  317. else:
  318. check_mod = torch.jit.trace(
  319. func,
  320. _clone_inputs(inputs),
  321. check_trace=False,
  322. strict=strict,
  323. _force_outplace=force_outplace,
  324. _module_class=_module_class,
  325. _store_inputs=False,
  326. )
  327. check_mod_func = check_mod
  328. def graph_diagnostic_info():
  329. mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph)
  330. torch._C._jit_pass_inline(mod_canonicalized)
  331. torch._C._jit_pass_erase_shape_information(mod_canonicalized)
  332. mod_str = str(mod_canonicalized)
  333. mod_str = re.sub(r"___torch_mangle_[0-9]+\.", "", mod_str)
  334. check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph)
  335. torch._C._jit_pass_inline(check_canonicalized)
  336. torch._C._jit_pass_erase_shape_information(check_canonicalized)
  337. check_str = str(check_canonicalized)
  338. check_str = re.sub(r"___torch_mangle_[0-9]+\.", "", check_str)
  339. graph_diff_errors = None
  340. if mod_str != check_str:
  341. import difflib
  342. graph_diff = difflib.ndiff(
  343. mod_str.splitlines(True), check_str.splitlines(True)
  344. )
  345. graph_diff_errors = "Graph diff:\n" + indent("".join(graph_diff)) + "\n"
  346. for n_mod, n_check in zip(
  347. mod_canonicalized.nodes(), check_canonicalized.nodes()
  348. ):
  349. if str(n_mod) != str(n_check):
  350. graph_diff_errors += "First diverging operator:\n"
  351. node_diff = difflib.ndiff(
  352. str(n_mod).splitlines(True), str(n_check).splitlines(True)
  353. )
  354. source_printout = (
  355. "Node diff:\n" + indent("".join(node_diff)) + "\n"
  356. )
  357. mod_stack = n_mod.sourceRange()
  358. if mod_stack:
  359. source_printout += (
  360. "Trace source location:\n" + indent(mod_stack) + "\n"
  361. )
  362. check_stack = n_check.sourceRange()
  363. if check_stack:
  364. source_printout += (
  365. "Check source location:\n" + indent(check_stack) + "\n"
  366. )
  367. graph_diff_errors += source_printout
  368. break # For now, only print out the first pair of nodes that diverges
  369. tensor_compare_errors = None
  370. # Check Tensor-valued constant nodes
  371. for n_mod, n_check in zip(
  372. mod_canonicalized.nodes(), check_canonicalized.nodes()
  373. ):
  374. if n_mod.kind() != n_check.kind():
  375. break # Graphs have already diverged
  376. if n_mod.kind() == "prim::Constant" and not (
  377. n_mod.mustBeNone() or n_check.mustBeNone()
  378. ):
  379. if not n_mod.hasAttribute("value"):
  380. continue
  381. if n_mod.kindOf("value") != "t" or n_check.kindOf("value") != "t":
  382. continue
  383. mod_tensor_val = n_mod.t("value")
  384. check_tensor_val = n_check.t("value")
  385. try:
  386. torch.testing.assert_close(
  387. mod_tensor_val, check_tensor_val, equal_nan=True
  388. )
  389. except (RuntimeError, AssertionError) as e:
  390. if tensor_compare_errors is None:
  391. tensor_compare_errors = ""
  392. tensor_compare_errors += "Node:\n" + indent(str(n_mod)) + "\n"
  393. compare_stack = n_mod.sourceRange()
  394. if compare_stack:
  395. tensor_compare_errors += (
  396. "Source Location:\n" + indent(compare_stack) + "\n"
  397. )
  398. tensor_compare_errors += "Comparison exception: " + indent(
  399. str(e)
  400. )
  401. break # For now, only print the first diverging pair
  402. return graph_diff_errors, tensor_compare_errors
  403. def wrap_retval(x):
  404. return x if isinstance(x, tuple) else (x,)
  405. def run_mod_and_filter_tensor_outputs(mod, inputs, running_what):
  406. try:
  407. if isinstance(inputs, dict) and example_inputs_is_kwarg:
  408. outs = wrap_retval(mod(**inputs))
  409. else:
  410. outs = wrap_retval(mod(*_clone_inputs(inputs)))
  411. outs = [out for out in outs if isinstance(out, torch.Tensor)]
  412. return outs
  413. except Exception as e:
  414. graph_diff_errors, tensor_compare_errors = graph_diagnostic_info()
  415. msg = f"encountered an exception while running the {running_what} with test inputs.\nException:\n{indent(str(e))}"
  416. raise TracingCheckError(
  417. graph_diff_errors,
  418. tensor_compare_errors,
  419. extra_msg=msg,
  420. ) from e
  421. has_warned = [False]
  422. def maybe_warn_nondeterministic():
  423. if has_warned[0]:
  424. return
  425. has_warned[0] = True
  426. nondeterm_ops = [
  427. op for op in traced_func.graph.nodes() if op.isNondeterministic()
  428. ]
  429. if len(nondeterm_ops) > 0:
  430. nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
  431. nondeterministic_ops_warning += (
  432. "Did you forget call .eval() on your model? Nodes:\n"
  433. )
  434. nondeterministic_ops_warning += "\n".join(
  435. [indent(str(op)) for op in nondeterm_ops][:20]
  436. )
  437. nondeterministic_ops_warning += (
  438. "\nThis may cause errors in trace checking. To disable trace checking,"
  439. " pass check_trace=False to torch.jit.trace()"
  440. )
  441. warnings.warn(
  442. nondeterministic_ops_warning, category=TracerWarning, stacklevel=5
  443. )
  444. def compare_outputs(original, reference, match_what):
  445. all_ok = True
  446. for i, (orig, ref) in enumerate(zip(original, reference)):
  447. try:
  448. if orig.is_quantized:
  449. orig = orig.dequantize()
  450. if ref.is_quantized:
  451. ref = ref.dequantize()
  452. if orig.is_mkldnn:
  453. orig = orig.to_dense()
  454. if ref.is_mkldnn:
  455. ref = ref.to_dense()
  456. if ref.is_complex() or orig.is_complex():
  457. torch.testing.assert_close(
  458. orig.to(torch.cdouble),
  459. ref.to(torch.cdouble),
  460. rtol=check_tolerance,
  461. atol=default_tolerances(orig, ref)[1],
  462. equal_nan=True,
  463. )
  464. else:
  465. if orig.is_mps or ref.is_mps:
  466. torch.testing.assert_close(
  467. orig.float(),
  468. ref.float(),
  469. rtol=check_tolerance,
  470. atol=default_tolerances(orig, ref)[1],
  471. equal_nan=True,
  472. )
  473. elif getattr(orig, "is_nested", None) or getattr(
  474. ref, "is_nested", None
  475. ):
  476. assert getattr(orig, "is_nested", None) == getattr(
  477. ref, "is_nested", None
  478. )
  479. for t_orig, t_ref in zip(orig.unbind(), ref.unbind()):
  480. torch.testing.assert_close(
  481. t_orig.double(),
  482. t_ref.double(),
  483. rtol=check_tolerance,
  484. atol=default_tolerances(t_orig, t_ref)[1],
  485. equal_nan=True,
  486. )
  487. else:
  488. torch.testing.assert_close(
  489. orig.double(),
  490. ref.double(),
  491. rtol=check_tolerance,
  492. atol=default_tolerances(orig, ref)[1],
  493. equal_nan=True,
  494. )
  495. except AssertionError as e:
  496. maybe_warn_nondeterministic()
  497. warnings.warn(
  498. "Output nr "
  499. + str(i + 1)
  500. + ". of the traced function does not match "
  501. "the corresponding output of the "
  502. + match_what
  503. + ". Detailed error:\n"
  504. + str(e),
  505. category=TracerWarning,
  506. stacklevel=4,
  507. )
  508. all_ok = False
  509. return all_ok
  510. traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
  511. fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function")
  512. if compare_outputs(traced_outs, fn_outs, "Python function"):
  513. check_outs = run_mod_and_filter_tensor_outputs(
  514. check_mod_func, inputs, "repeated trace"
  515. )
  516. compare_outputs(traced_outs, check_outs, "repeated trace")
  517. diag_info = graph_diagnostic_info()
  518. if any(info is not None for info in diag_info):
  519. raise TracingCheckError(*diag_info)
  520. class TracerWarning(Warning):
  521. @staticmethod
  522. def ignore_lib_warnings():
  523. # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace
  524. warnings.filterwarnings(
  525. "ignore", category=TracerWarning, module="torch.(?!jit)"
  526. )
  527. warnings.filterwarnings("ignore", "torch::jit::fuser::cuda")
  528. # We ignore the tracer warnings coming form inside the library, because all our shape
  529. # checks in nn will trigger them.
  530. TracerWarning.ignore_lib_warnings()
  531. torch._C._tracer_warn_use_python()
  532. def make_tuple(example_inputs):
  533. if isinstance(example_inputs, (torch.Tensor, dict)):
  534. return (example_inputs,)
  535. # done primarily so that weird iterables fail here and not pybind11 code
  536. if not isinstance(example_inputs, tuple):
  537. return tuple(example_inputs)
  538. return example_inputs
  539. def make_module(mod, _module_class, _compilation_unit):
  540. if isinstance(mod, ScriptModule):
  541. return mod
  542. elif torch._jit_internal.module_has_exports(mod):
  543. infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
  544. return torch.jit._recursive.create_script_module(
  545. mod, infer_methods_stubs_fn, share_types=False, is_tracing=True
  546. )
  547. else:
  548. if _module_class is None:
  549. _module_class = TopLevelTracedModule
  550. return _module_class(mod, _compilation_unit=_compilation_unit)
  551. def wrap_check_inputs(check_inputs):
  552. if check_inputs is None:
  553. return None
  554. return [{"forward": c} for c in check_inputs]
  555. def analyze_ts_result_with_export_result(export, trace):
  556. import torch.utils._pytree as pytree
  557. flat_export = pytree.tree_leaves(export)
  558. flat_trace = pytree.tree_leaves(trace)
  559. for orig, loaded in zip(flat_export, flat_trace):
  560. if orig.layout != loaded.layout:
  561. return False
  562. # mkldnn is not supported for torch.allclose
  563. if orig.layout == torch._mkldnn: # type: ignore[attr-defined]
  564. return True
  565. if type(orig) != type(loaded):
  566. return False
  567. if isinstance(orig, torch._subclasses.FakeTensor):
  568. # Skip for FakeTensor.
  569. return True
  570. elif isinstance(orig, torch.Tensor):
  571. if orig.dtype != loaded.dtype:
  572. return False
  573. if not torch.allclose(orig, loaded):
  574. return False
  575. else:
  576. if orig != loaded:
  577. return False
  578. return True
  579. def _trace_impl(
  580. func,
  581. example_inputs=None,
  582. optimize=None,
  583. check_trace=True,
  584. check_inputs=None,
  585. check_tolerance=1e-5,
  586. strict=True,
  587. _force_outplace=False,
  588. _module_class=None,
  589. _compilation_unit=_python_cu,
  590. example_kwarg_inputs=None,
  591. _store_inputs=True,
  592. ):
  593. if isinstance(func, torch.jit.ScriptModule):
  594. # it is hard to trace it because the forward method on ScriptModule is already defined, so it
  595. # would result in an error.
  596. warnings.warn(
  597. "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is."
  598. )
  599. return func
  600. if isinstance(func, torch.nn.Module):
  601. if example_inputs is None:
  602. if isinstance(example_kwarg_inputs, dict):
  603. example_inputs = example_kwarg_inputs
  604. else:
  605. raise RuntimeError("example_kwarg_inputs should be a dict")
  606. return trace_module(
  607. func,
  608. {"forward": example_inputs},
  609. None,
  610. check_trace,
  611. wrap_check_inputs(check_inputs),
  612. check_tolerance,
  613. strict,
  614. _force_outplace,
  615. _module_class,
  616. example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
  617. _store_inputs=_store_inputs,
  618. )
  619. if (
  620. hasattr(func, "__self__")
  621. and isinstance(func.__self__, torch.nn.Module)
  622. and func.__name__ == "forward"
  623. ):
  624. if example_inputs is None:
  625. if isinstance(example_kwarg_inputs, dict):
  626. example_inputs = example_kwarg_inputs
  627. else:
  628. raise RuntimeError("example_kwarg_inputs should be a dict")
  629. return trace_module(
  630. func.__self__,
  631. {"forward": example_inputs},
  632. None,
  633. check_trace,
  634. wrap_check_inputs(check_inputs),
  635. check_tolerance,
  636. strict,
  637. _force_outplace,
  638. _module_class,
  639. example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
  640. _store_inputs=_store_inputs,
  641. )
  642. # Special case for common case of passing a single Tensor
  643. if (
  644. isinstance(example_inputs, (torch.Tensor, dict))
  645. and example_kwarg_inputs is None
  646. ):
  647. example_inputs = (example_inputs,)
  648. # done primarily so that weird iterables fail here and not pybind11 code
  649. elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple):
  650. example_inputs = tuple(example_inputs)
  651. var_lookup_fn = _create_interpreter_name_lookup_fn(0)
  652. if hasattr(func, "__self__") and isinstance(func.__self__, torch.nn.Module):
  653. raise AttributeError(
  654. "trace doesn't support compiling individual module's functions.\n"
  655. "Please use trace_module"
  656. )
  657. name = _qualified_name(func)
  658. if isinstance(example_kwarg_inputs, dict):
  659. example_inputs = example_kwarg_inputs
  660. traced = torch._C._create_function_from_trace_with_dict(
  661. name,
  662. func,
  663. example_kwarg_inputs,
  664. var_lookup_fn,
  665. strict,
  666. _force_outplace,
  667. get_callable_argument_names(func),
  668. )
  669. else:
  670. traced = torch._C._create_function_from_trace(
  671. name,
  672. func,
  673. example_inputs,
  674. var_lookup_fn,
  675. strict,
  676. _force_outplace,
  677. get_callable_argument_names(func),
  678. )
  679. # Check the trace against new traces created from user-specified inputs
  680. if check_trace:
  681. if check_inputs is not None:
  682. _check_trace(
  683. check_inputs,
  684. func,
  685. traced,
  686. check_tolerance,
  687. strict,
  688. _force_outplace,
  689. False,
  690. _module_class,
  691. example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
  692. )
  693. else:
  694. _check_trace(
  695. [example_inputs],
  696. func,
  697. traced,
  698. check_tolerance,
  699. strict,
  700. _force_outplace,
  701. False,
  702. _module_class,
  703. example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
  704. )
  705. # Allow torch.compile() to inline
  706. traced._torchdynamo_inline = func # type: ignore[attr-defined]
  707. return traced
  708. class _ExportType(str, Enum):
  709. DIRECT_EXPORT = "DIRECT_EXPORT"
  710. TRACE_AND_EXPORT = "TRACE_AND_EXPORT"
  711. SOURCE_TO_SOURCE = "SOURCE_TO_SOURCE"
  712. def __str__(self) -> str:
  713. return self.value
  714. class _ExportOutcome(str, Enum):
  715. SUCCESS = "SUCCESS"
  716. FAILED_TO_EXPORT = "FAILED_TO_EXPORT"
  717. FAILED_TO_RUN = "FAILED_TO_RUN"
  718. ACCURACY_ERROR = "ACCURACY_ERROR"
  719. def __str__(self) -> str:
  720. return self.value
  721. def trace(
  722. func,
  723. example_inputs=None,
  724. optimize=None,
  725. check_trace=True,
  726. check_inputs=None,
  727. check_tolerance=1e-5,
  728. strict=True,
  729. _force_outplace=False,
  730. _module_class=None,
  731. _compilation_unit=_python_cu,
  732. example_kwarg_inputs=None,
  733. _store_inputs=True,
  734. ):
  735. r"""
  736. Trace a function and return an executable or :class:`ScriptFunction` that will be optimized using just-in-time compilation.
  737. Tracing is ideal for code that operates only on
  738. ``Tensor``\\s and lists, dictionaries, and
  739. tuples of ``Tensor``\\s.
  740. Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an
  741. existing module or Python function into a TorchScript
  742. :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example
  743. inputs, and we run the function, recording the operations performed on all
  744. the tensors.
  745. * The resulting recording of a standalone function produces `ScriptFunction`.
  746. * The resulting recording of `nn.Module.forward` or `nn.Module` produces
  747. `ScriptModule`.
  748. This module also contains any parameters that the original
  749. module had as well.
  750. Warning:
  751. Tracing only correctly records functions and modules which are not data
  752. dependent (e.g., do not have conditionals on data in tensors) and do not have
  753. any untracked external dependencies (e.g., perform input/output or
  754. access global variables). Tracing only records operations done when the given
  755. function is run on the given tensors. Therefore, the returned
  756. `ScriptModule` will always run the same traced graph on any input. This
  757. has some important implications when your module is expected to run
  758. different sets of operations, depending on the input and/or the module
  759. state. For example,
  760. * Tracing will not record any control-flow like if-statements or loops.
  761. When this control-flow is constant across your module, this is fine
  762. and it often inlines the control-flow decisions. But sometimes the
  763. control-flow is actually part of the model itself. For instance, a
  764. recurrent network is a loop over the (possibly dynamic) length of an
  765. input sequence.
  766. * In the returned :class:`ScriptModule`, operations that have different
  767. behaviors in ``training`` and ``eval`` modes will always behave as if
  768. it is in the mode it was in during tracing, no matter which mode the
  769. `ScriptModule` is in.
  770. In cases like these, tracing would not be appropriate and
  771. :func:`scripting <torch.jit.script>` is a better choice. If you trace
  772. such models, you may silently get incorrect results on subsequent
  773. invocations of the model. The tracer will try to emit warnings when
  774. doing something that may cause an incorrect trace to be produced.
  775. Args:
  776. func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
  777. that will be run with `example_inputs`. `func` arguments and return
  778. values must be tensors or (possibly nested) tuples that contain
  779. tensors. When a module is passed `torch.jit.trace`, only the
  780. ``forward`` method is run and traced (see :func:`torch.jit.trace
  781. <torch.jit.trace_module>` for details).
  782. Keyword arguments:
  783. example_inputs (tuple or torch.Tensor or None, optional): A tuple of example
  784. inputs that will be passed to the function while tracing.
  785. Default: ``None``. Either this argument or ``example_kwarg_inputs``
  786. should be specified. The resulting trace can be run with inputs of
  787. different types and shapes assuming the traced operations support those
  788. types and shapes. `example_inputs` may also be a single Tensor in which
  789. case it is automatically wrapped in a tuple. When the value is None,
  790. ``example_kwarg_inputs`` should be specified.
  791. check_trace (``bool``, optional): Check if the same inputs run through
  792. traced code produce the same outputs. Default: ``True``. You might want
  793. to disable this if, for example, your network contains non-
  794. deterministic ops or if you are sure that the network is correct despite
  795. a checker failure.
  796. check_inputs (list of tuples, optional): A list of tuples of input
  797. arguments that should be used to check the trace against what is
  798. expected. Each tuple is equivalent to a set of input arguments that
  799. would be specified in ``example_inputs``. For best results, pass in
  800. a set of checking inputs representative of the space of shapes and
  801. types of inputs you expect the network to see. If not specified,
  802. the original ``example_inputs`` are used for checking
  803. check_tolerance (float, optional): Floating-point comparison tolerance
  804. to use in the checker procedure. This can be used to relax the
  805. checker strictness in the event that results diverge numerically
  806. for a known reason, such as operator fusion.
  807. strict (``bool``, optional): run the tracer in a strict mode or not
  808. (default: ``True``). Only turn this off when you want the tracer to
  809. record your mutable container types (currently ``list``/``dict``)
  810. and you are sure that the container you are using in your
  811. problem is a ``constant`` structure and does not get used as
  812. control flow (if, for) conditions.
  813. example_kwarg_inputs (dict, optional): This parameter is a pack of keyword
  814. arguments of example inputs that will be passed to the function while
  815. tracing. Default: ``None``. Either this argument or ``example_inputs``
  816. should be specified. The dict will be unpacking by the arguments name
  817. of the traced function. If the keys of the dict don't not match with
  818. the traced function's arguments name, a runtime exception will be raised.
  819. Returns:
  820. If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns
  821. a :class:`ScriptModule` object with a single ``forward`` method
  822. containing the traced code. The returned `ScriptModule` will
  823. have the same set of sub-modules and parameters as the original
  824. ``nn.Module``. If ``func`` is a standalone function, ``trace``
  825. returns `ScriptFunction`.
  826. Example (tracing a function):
  827. .. testcode::
  828. import torch
  829. def foo(x, y):
  830. return 2 * x + y
  831. # Run `foo` with the provided inputs and record the tensor operations
  832. traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
  833. # `traced_foo` can now be run with the TorchScript interpreter or saved
  834. # and loaded in a Python-free environment
  835. Example (tracing an existing module)::
  836. import torch
  837. import torch.nn as nn
  838. class Net(nn.Module):
  839. def __init__(self) -> None:
  840. super().__init__()
  841. self.conv = nn.Conv2d(1, 1, 3)
  842. def forward(self, x):
  843. return self.conv(x)
  844. n = Net()
  845. example_weight = torch.rand(1, 1, 3, 3)
  846. example_forward_input = torch.rand(1, 1, 3, 3)
  847. # Trace a specific method and construct `ScriptModule` with
  848. # a single `forward` method
  849. module = torch.jit.trace(n.forward, example_forward_input)
  850. # Trace a module (implicitly traces `forward`) and construct a
  851. # `ScriptModule` with a single `forward` method
  852. module = torch.jit.trace(n, example_forward_input)
  853. """
  854. if not _enabled:
  855. return func
  856. if optimize is not None:
  857. warnings.warn(
  858. "`optimize` is deprecated and has no effect. "
  859. "Use `with torch.jit.optimized_execution()` instead",
  860. FutureWarning,
  861. stacklevel=2,
  862. )
  863. from torch._utils_internal import log_torchscript_usage
  864. traced_func = _trace_impl(
  865. func,
  866. example_inputs,
  867. optimize,
  868. check_trace,
  869. check_inputs,
  870. check_tolerance,
  871. strict,
  872. _force_outplace,
  873. _module_class,
  874. _compilation_unit,
  875. example_kwarg_inputs,
  876. _store_inputs,
  877. )
  878. log_torchscript_usage("trace", model_id=_get_model_id(traced_func))
  879. return traced_func
# Module-level map consulted during tracing. It starts out as ``None``;
# ``trace_module`` saves the current value, installs a fresh dict (keyed by
# submodule objects with qualified names like "__module.child" as values, plus
# a "__module" entry for the root) for the duration of the trace, and restores
# the saved value in a ``finally`` block.
_trace_module_map: Optional[dict[Any, Any]] = None
  881. def trace_module(
  882. mod,
  883. inputs,
  884. optimize=None,
  885. check_trace=True,
  886. check_inputs=None,
  887. check_tolerance=1e-5,
  888. strict=True,
  889. _force_outplace=False,
  890. _module_class=None,
  891. _compilation_unit=_python_cu,
  892. example_inputs_is_kwarg=False,
  893. _store_inputs=True,
  894. ):
  895. """
  896. Trace a module and return an executable :class:`ScriptModule` that will be optimized using just-in-time compilation.
  897. When a module is passed to :func:`torch.jit.trace <torch.jit.trace>`, only
  898. the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary of
  899. method names to example inputs to trace (see the ``inputs``) argument below.
  900. See :func:`torch.jit.trace <torch.jit.trace>` for more information on tracing.
  901. Args:
  902. mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are
  903. specified in ``inputs``. The given methods will be compiled
  904. as a part of a single `ScriptModule`.
  905. inputs (dict): A dict containing sample inputs indexed by method names in ``mod``.
  906. The inputs will be passed to methods whose names correspond to inputs'
  907. keys while tracing.
  908. ``{ 'forward' : example_forward_input, 'method2': example_method2_input}``
  909. Keyword arguments:
  910. check_trace (``bool``, optional): Check if the same inputs run through
  911. traced code produce the same outputs. Default: ``True``. You might want
  912. to disable this if, for example, your network contains non-
  913. deterministic ops or if you are sure that the network is correct despite
  914. a checker failure.
  915. check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used
  916. to check the trace against what is expected. Each tuple
  917. is equivalent to a set of input arguments that would
  918. be specified in ``inputs``. For best results, pass in a
  919. set of checking inputs representative of the space of
  920. shapes and types of inputs you expect the network to see.
  921. If not specified, the original ``inputs`` are used for checking
  922. check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure.
  923. This can be used to relax the checker strictness in the event that
  924. results diverge numerically for a known reason, such as operator fusion.
  925. example_inputs_is_kwarg (``bool``, optional): This parameter indicate whether the example inputs is a pack
  926. pack of keyword arguments. Default: ``False``.
  927. Returns:
  928. A :class:`ScriptModule` object with a single ``forward`` method containing the traced code.
  929. When ``func`` is a ``torch.nn.Module``, the returned :class:`ScriptModule` will have the same set of
  930. sub-modules and parameters as ``func``.
  931. Example (tracing a module with multiple methods)::
  932. import torch
  933. import torch.nn as nn
  934. class Net(nn.Module):
  935. def __init__(self) -> None:
  936. super().__init__()
  937. self.conv = nn.Conv2d(1, 1, 3)
  938. def forward(self, x):
  939. return self.conv(x)
  940. def weighted_kernel_sum(self, weight):
  941. return weight * self.conv.weight
  942. n = Net()
  943. example_weight = torch.rand(1, 1, 3, 3)
  944. example_forward_input = torch.rand(1, 1, 3, 3)
  945. # Trace a specific method and construct `ScriptModule` with
  946. # a single `forward` method
  947. module = torch.jit.trace(n.forward, example_forward_input)
  948. # Trace a module (implicitly traces `forward`) and construct a
  949. # `ScriptModule` with a single `forward` method
  950. module = torch.jit.trace(n, example_forward_input)
  951. # Trace specific methods on a module (specified in `inputs`), constructs
  952. # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
  953. inputs = {
  954. "forward": example_forward_input,
  955. "weighted_kernel_sum": example_weight,
  956. }
  957. module = torch.jit.trace_module(n, inputs)
  958. """
  959. if not _enabled:
  960. return mod
  961. if optimize is not None:
  962. warnings.warn(
  963. "`optimize` is deprecated and has no effect. "
  964. "Use `with torch.jit.optimized_execution()` instead",
  965. FutureWarning,
  966. stacklevel=2,
  967. )
  968. var_lookup_fn = _create_interpreter_name_lookup_fn(0)
  969. if not isinstance(mod, torch.nn.Module):
  970. raise AttributeError("expected torch.nn.Module as the first argument")
  971. if not isinstance(inputs, dict):
  972. raise AttributeError("expected a dictionary of (method_name, input) pairs")
  973. old_module_map = torch.jit._trace._trace_module_map
  974. try:
  975. trace_module_map: dict[Any, Any] = {}
  976. def register_submods(mod, prefix):
  977. for name, child in mod.named_children():
  978. submod_qualname = prefix + "." + name
  979. trace_module_map[child] = submod_qualname
  980. register_submods(child, submod_qualname)
  981. trace_module_map["__module"] = mod
  982. torch.jit._trace._trace_module_map = trace_module_map
  983. register_submods(mod, "__module")
  984. module = make_module(mod, _module_class, _compilation_unit)
  985. for method_name, example_inputs in inputs.items():
  986. if method_name == "forward":
  987. # "forward" is a special case because we need to trace
  988. # `Module.__call__`, which sets up some extra tracing, but uses
  989. # argument names of the real `Module.forward` method.
  990. func = mod
  991. forward_method = getattr(mod, method_name)
  992. argument_names = get_callable_argument_names(forward_method)
  993. else:
  994. func = getattr(mod, method_name)
  995. argument_names = get_callable_argument_names(func)
  996. if isinstance(example_inputs, dict) and example_inputs_is_kwarg:
  997. # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/
  998. for key in example_inputs:
  999. if key not in argument_names:
  1000. valid_arguments = "[" + ",".join(argument_names) + "]"
  1001. raise NameError(
  1002. f"""'{key}' is not in forward() method's arguments,
  1003. valid arguments name are {valid_arguments}"""
  1004. )
  1005. module._c._create_method_from_trace_with_dict(
  1006. method_name,
  1007. func,
  1008. example_inputs,
  1009. var_lookup_fn,
  1010. strict,
  1011. _force_outplace,
  1012. argument_names,
  1013. _store_inputs,
  1014. )
  1015. else:
  1016. example_inputs = make_tuple(example_inputs)
  1017. module._c._create_method_from_trace(
  1018. method_name,
  1019. func,
  1020. example_inputs,
  1021. var_lookup_fn,
  1022. strict,
  1023. _force_outplace,
  1024. argument_names,
  1025. _store_inputs,
  1026. )
  1027. check_trace_method = module._c._get_method(method_name)
  1028. # Check the trace against new traces created from user-specified inputs
  1029. if check_trace:
  1030. if check_inputs is not None:
  1031. _check_trace(
  1032. check_inputs,
  1033. func,
  1034. check_trace_method,
  1035. check_tolerance,
  1036. strict,
  1037. _force_outplace,
  1038. True,
  1039. _module_class,
  1040. example_inputs_is_kwarg=example_inputs_is_kwarg,
  1041. )
  1042. else:
  1043. _check_trace(
  1044. [inputs],
  1045. func,
  1046. check_trace_method,
  1047. check_tolerance,
  1048. strict,
  1049. _force_outplace,
  1050. True,
  1051. _module_class,
  1052. example_inputs_is_kwarg=example_inputs_is_kwarg,
  1053. )
  1054. finally:
  1055. torch.jit._trace._trace_module_map = old_module_map
  1056. return module
def is_tracing():
    """Return a boolean value.

    Returns ``True`` in tracing (if a function is called during the
    tracing of code with ``torch.jit.trace``) and ``False`` otherwise.

    When ``is_scripting()`` is true, ``False`` is returned immediately without
    querying the tracer state.
    """
    # NOTE(review): keep this as a plain `if` guard; if TorchScript ever
    # compiles this function, the statement form may matter for its
    # constant-condition handling of `is_scripting()` — confirm before changing.
    if is_scripting():
        return False
    return torch._C._is_tracing()
class TracedModule(ScriptModule):
    """Script-module wrapper built from an ``nn.Module`` during tracing.

    The constructor copies a subset of ``orig`` into a temporary module:
    non-None parameters and buffers, instance attributes that are TorchScript
    script objects, and submodules wrapped via ``make_module``. It then
    compiles that module with ``torch.jit._recursive.create_script_module`` and
    stores the result as ``_actual_script_module``. After construction,
    attribute reads and writes on this object are forwarded to that script
    module.
    """

    # NOTE(review): presumably opts this class out of ScriptModule's metaclass
    # processing — confirm against ScriptMeta.
    _disable_script_meta = True

    def __init__(self, orig, id_set=None, _compilation_unit=None):
        """Build the wrapper from ``orig``.

        Args:
            orig: The ``torch.nn.Module`` to wrap (enforced by the assert below).
            id_set: Ignored; the body always rebinds it to a fresh ``set()``.
            _compilation_unit: Not read by this constructor; child modules are
                wrapped with ``_compilation_unit=None``.

        Raises:
            ValueError: If the same tensor occurs more than once among
                ``orig``'s own parameters and buffers, or if ``orig`` has
                backward hooks registered.
        """
        # XXX: orig can be a nn.Module or a function!
        # NOTE(review): the assert below only accepts nn.Module, so the XXX
        # above appears stale.
        super().__init__()
        assert isinstance(orig, torch.nn.Module)

        # Copy a subset of `orig` to a temporary nn.Module.
        # This is a way to customize what will actually get compiled by create_script_module
        # NOTE(review): this rebinds the `id_set` argument, so any set passed by
        # the caller is discarded.
        id_set = set()

        # This allows us to preserve the original module's qualified name by defining a new
        # type with the attribute _jit_override_qualname. In torch._jit_internal._qualified_name
        # we have a special case that will look up this attribute to override whatever qualname
        # we would get from the python type system
        class QualnameWrapper(torch.nn.Module):
            pass

        QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name(  # type: ignore[attr-defined]
            type(orig)
        )

        tmp_module = QualnameWrapper()

        def check_unique(param):
            # Reject a tensor that was already copied from this module.
            # NOTE(review): despite the message, this only detects sharing
            # among `orig`'s own parameters/buffers, since id_set is fresh
            # for every constructor call.
            if param in id_set:
                raise ValueError(
                    "TracedModules don't support parameter sharing between modules"
                )
            id_set.add(param)

        tmp_module.training = orig.training

        for name, param in orig._parameters.items():
            if param is not None:
                tmp_module._parameters[name] = param
                check_unique(param)
        for name, buf in orig._buffers.items():
            if buf is not None:
                tmp_module._buffers[name] = buf
                check_unique(buf)
        # Carry over instance attributes that are TorchScript script objects,
        # skipping names already handled as parameters or buffers.
        for name, val in orig.__dict__.items():
            if (
                torch._C._jit_is_script_object(val)
                and name not in orig._parameters
                and name not in orig._buffers
            ):
                setattr(tmp_module, name, val)

        if orig._backward_hooks:
            raise ValueError(
                "Modules that have backward hooks assigned can't be compiled: "
                + str(orig)
            )

        # Wrap each non-None child through make_module, passing TracedModule
        # as the module class.
        for name, submodule in orig._modules.items():
            if submodule is None:
                continue
            tmp_module._modules[name] = make_module(
                submodule, TracedModule, _compilation_unit=None
            )

        # The second argument returns an empty tuple; presumably no methods are
        # compiled up front and traced methods are attached later — confirm.
        script_module = torch.jit._recursive.create_script_module(
            tmp_module, lambda module: (), share_types=False, is_tracing=True
        )

        # Written straight into __dict__, bypassing this class's __setattr__.
        self.__dict__["_name"] = type(orig).__name__
        self.__dict__["_actual_script_module"] = script_module
        # Drop nn.Module bookkeeping attributes from the wrapper itself;
        # presumably so later lookups of these names reach __getattr__ and are
        # forwarded to the script module — confirm.
        for name in ("_parameters", "_buffers", "_modules", "training"):
            delattr(self, name)

    def forward(self, *args, **kwargs):
        # Wrapped submodules are not directly callable.
        raise RuntimeError("Trace submodules cannot be called.")

    def __getattr__(self, attr):
        # Before __init__ has stored `_actual_script_module`, use the normal
        # ScriptModule lookup; afterwards forward to the wrapped script module.
        if "_actual_script_module" not in self.__dict__:
            return super().__getattr__(attr)
        return getattr(self._actual_script_module, attr)

    def __setattr__(self, attr, value):
        # Same two-phase behavior as __getattr__, for attribute writes.
        if "_actual_script_module" not in self.__dict__:
            return super().__setattr__(attr, value)
        setattr(self._actual_script_module, attr, value)

    def _get_name(self):
        # Class name of the original module, recorded in __init__.
        return self._name

    def extra_repr(self):
        return f"original_name={self._name}"
  1138. class TopLevelTracedModule(TracedModule):
  1139. forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment]
  1140. def _reconstruct(self, cpp_module):
  1141. """
  1142. Re-construct an instance of TopLevelTracedModule using an instance of a C++ module.
  1143. Args:
  1144. cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
  1145. """
  1146. self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
  1147. def _script_if_tracing(fn: Callable[P, R]) -> Callable[P, R]:
  1148. @functools.wraps(fn)
  1149. def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
  1150. if not is_tracing():
  1151. # Not tracing, don't do anything
  1152. return fn(*args, **kwargs)
  1153. compiled_fn: Callable[P, R] = script(wrapper.__original_fn) # type: ignore[attr-defined]
  1154. return compiled_fn(*args, **kwargs)
  1155. wrapper.__original_fn = fn # type: ignore[attr-defined]
  1156. wrapper.__script_if_tracing_wrapper = True # type: ignore[attr-defined]
  1157. return wrapper
  1158. def _get_trace_graph(
  1159. f,
  1160. args=(),
  1161. kwargs=None,
  1162. strict=True,
  1163. _force_outplace=False,
  1164. return_inputs=False,
  1165. _return_inputs_states=False,
  1166. ):
  1167. """Return a tuple on tracing a function or model.
  1168. .. warning::
  1169. This function is internal-only and should only be used by the ONNX
  1170. exporter. If you are trying to get a graph through tracing, please go
  1171. through the public API instead::
  1172. trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
  1173. trace_graph = trace.graph
  1174. Trace a function or model, returning a tuple consisting of the both the
  1175. *trace* of an execution, as well as the original return value. If return_inputs,
  1176. also returns the trace inputs as part of the tuple
  1177. Tracing is guaranteed not to change the semantics of the function/module
  1178. that is traced.
  1179. Args:
  1180. f (torch.nn.Module or function): the function or module
  1181. to be traced.
  1182. args (tuple or Tensor): the positional arguments to pass to the
  1183. function/module to be traced. A non-tuple is assumed to
  1184. be a single positional argument to be passed to the model.
  1185. kwargs (dict): the keyword arguments to pass to the function/module
  1186. to be traced.
  1187. Example (trace a cell):
  1188. .. testcode::
  1189. trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
  1190. """
  1191. if kwargs is None:
  1192. kwargs = {}
  1193. if not isinstance(args, tuple):
  1194. args = (args,)
  1195. outs = ONNXTracedModule(
  1196. f, strict, _force_outplace, return_inputs, _return_inputs_states
  1197. )(*args, **kwargs)
  1198. return outs