utils.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. # mypy: allow-untyped-defs
  2. import operator
  3. import types
  4. from typing import Any, Callable, Optional, Union
  5. import torch
  6. import torch.ao.quantization.pt2e._affine_quantization # noqa: F401
  7. import torch.nn.functional as F
  8. # Makes sure that quantized_decomposed ops are registered
  9. from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
  10. from torch.ao.quantization.quantizer import QuantizationAnnotation
  11. from torch.export.unflatten import _assign_attr, _AttrKind
  12. from torch.fx import GraphModule, Node
  13. from torch.nn.utils.fusion import fuse_conv_bn_weights
  14. from torch.utils._pytree import LeafSpec
  15. __all__ = [
  16. "fold_bn_weights_into_conv_node",
  17. "remove_tensor_overload_for_qdq_ops",
  18. ]
  19. _QUANTIZE_OPS = [
  20. torch.ops.quantized_decomposed.quantize_per_tensor.default,
  21. torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
  22. torch.ops.quantized_decomposed.quantize_per_channel.default,
  23. ]
  24. _DEQUANTIZE_OPS = [
  25. torch.ops.quantized_decomposed.dequantize_per_tensor.default,
  26. torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
  27. torch.ops.quantized_decomposed.dequantize_per_channel.default,
  28. ]
  29. def _is_connected(source: torch.fx.Node, dest: torch.fx.Node) -> bool:
  30. """
  31. Assuming dest is one of the ops inserted by quant workflow, this function
  32. finds if source and dest are connected. Assumption is that only quant workflow
  33. inserted ops exist between source and dest
  34. """
  35. quant_workflow_ops = _QUANTIZE_OPS + _DEQUANTIZE_OPS
  36. quant_workflow_ops.append(torch.ops.quantized_decomposed.choose_qparams.tensor)
  37. while dest.target in quant_workflow_ops:
  38. if not isinstance(dest.args[0], torch.fx.Node):
  39. raise ValueError(
  40. f"expected arg[0] of quant workflow ops to be a node but found {dest.args[0]}"
  41. )
  42. dest = dest.args[0]
  43. return dest == source
  44. def _find_q_dq_node_for_user(
  45. produer: torch.fx.Node, user: torch.fx.Node
  46. ) -> tuple[Any, Any]:
  47. """
  48. Find q, dq pair corresponding to [producer -> q -> dq -> user]
  49. Utils works by finding dq arg of user and ensuring it is connected to
  50. producer
  51. """
  52. dq_node = None
  53. for n in user.args:
  54. if (
  55. isinstance(n, torch.fx.Node)
  56. and n.op == "call_function"
  57. and n.target in _DEQUANTIZE_OPS
  58. ):
  59. if _is_connected(produer, n):
  60. dq_node = n
  61. break
  62. if dq_node is None:
  63. for n in user.kwargs:
  64. if (
  65. isinstance(n, torch.fx.Node)
  66. and n.op == "call_function"
  67. and n.target in _DEQUANTIZE_OPS
  68. ):
  69. if _is_connected(produer, n):
  70. dq_node = n
  71. break
  72. if dq_node is None:
  73. return (None, None)
  74. q_node = None
  75. if (
  76. isinstance(arg := dq_node.args[0], torch.fx.Node)
  77. and arg.op == "call_function"
  78. and arg.target in _QUANTIZE_OPS
  79. ):
  80. q_node = arg
  81. return (q_node, dq_node)
  82. def _is_sym_size_node(node: Node):
  83. return (
  84. node.op == "call_function"
  85. and node.target == torch.ops.aten.sym_size.default
  86. or node.target == torch.ops.aten.sym_numel.default
  87. or node.target == torch.ops.aten.sym_numel
  88. or node.target == torch.ops.aten.sym_size
  89. )
  90. def _filter_sym_size_users(node: torch.fx.Node) -> list[torch.fx.Node]:
  91. node_users = list(filter((lambda x: (_is_sym_size_node(x) is False)), node.users))
  92. return node_users
  93. def _is_valid_annotation(annotation: QuantizationAnnotation) -> bool:
  94. if annotation is None:
  95. return False
  96. input_qspec_map = annotation.input_qspec_map
  97. output_qspec = annotation.output_qspec
  98. if len(input_qspec_map) == 0 and output_qspec is None:
  99. return False
  100. return True
  101. def _get_tensor_constant_from_node(node, m):
  102. if node is None:
  103. return None
  104. assert node.op == "get_attr"
  105. target_atoms = node.target.split(".")
  106. attr_itr = m
  107. for i, atom in enumerate(target_atoms):
  108. if not hasattr(attr_itr, atom):
  109. raise RuntimeError(
  110. f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
  111. )
  112. attr_itr = getattr(attr_itr, atom)
  113. return attr_itr
  114. def _get_all_arguments(orig_args, orig_kwargs, args_schema):
  115. all_args = []
  116. for i, schema in enumerate(args_schema):
  117. if schema.name in orig_kwargs:
  118. all_args.append(orig_kwargs[schema.name])
  119. elif not schema.kwarg_only and i < len(orig_args):
  120. all_args.append(orig_args[i])
  121. else:
  122. all_args.append(schema.default_value)
  123. return all_args
  124. def _is_supported_batch_norm_for_training(node: Node):
  125. """
  126. Return True if the given node refers to an aten batch norm op QAT supports.
  127. """
  128. supported_ops = [
  129. torch.ops.aten.batch_norm.default,
  130. torch.ops.aten._native_batch_norm_legit.default,
  131. # Note: we won't need this op anymore after batch norm consolidation
  132. # For now, we need to continue to support it because it gives better
  133. # training numerics than `_native_batch_norm_legit`
  134. torch.ops.aten.cudnn_batch_norm.default,
  135. torch.ops.aten.miopen_batch_norm.default,
  136. ]
  137. return node.target in supported_ops
  138. # TODO: move this to torch/ao/quantization/utils.py
  139. def _is_conv_node(n: Node):
  140. """
  141. Return whether the node refers to an aten conv op.
  142. """
  143. return n.op == "call_function" and n.target in [
  144. torch.ops.aten.conv1d.default,
  145. torch.ops.aten.conv1d.padding,
  146. torch.ops.aten.conv2d.default,
  147. torch.ops.aten.conv2d.padding,
  148. torch.ops.aten.conv3d.default,
  149. torch.ops.aten.conv3d.padding,
  150. ]
  151. def _is_conv_transpose_node(n: Node):
  152. """
  153. Return whether the node refers to an aten conv_transpose op.
  154. """
  155. return n.op == "call_function" and n.target in [
  156. torch.ops.aten.conv_transpose1d,
  157. torch.ops.aten.conv_transpose1d.default,
  158. torch.ops.aten.conv_transpose2d,
  159. torch.ops.aten.conv_transpose2d.input,
  160. ]
  161. def _is_conv_or_conv_transpose_node(n: Node):
  162. """
  163. Return whether the node refers to an aten conv or conv transpose op.
  164. """
  165. return _is_conv_node(n) or _is_conv_transpose_node(n)
  166. def _is_conv_transpose_fn(conv_fn: Callable):
  167. return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]
  168. def _is_bn_node(n: Node):
  169. return (
  170. _is_supported_batch_norm_for_training(n)
  171. or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
  172. )
  173. def fold_bn_weights_into_conv_node(
  174. conv_node: Node,
  175. conv_weight_node: Node,
  176. conv_bias_node: Optional[Node],
  177. bn_node: Node,
  178. m: GraphModule,
  179. ) -> None:
  180. # conv args: input, weight, bias, stride, padding, dilation, ...
  181. conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
  182. conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
  183. transpose = _is_conv_transpose_node(conv_node)
  184. # eval bn args: input, weight, bias, running mean, running var, momentum, eps
  185. # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
  186. bn_args_schema = bn_node.target._schema.arguments # type: ignore[union-attr]
  187. bn_args = _get_all_arguments(bn_node.args, bn_node.kwargs, bn_args_schema)
  188. bn_w = _get_tensor_constant_from_node(bn_args[1], m)
  189. bn_b = _get_tensor_constant_from_node(bn_args[2], m)
  190. bn_rm = _get_tensor_constant_from_node(bn_args[3], m)
  191. bn_rv = _get_tensor_constant_from_node(bn_args[4], m)
  192. if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default:
  193. eps_arg_index = 6
  194. elif _is_supported_batch_norm_for_training(bn_node):
  195. eps_arg_index = 7
  196. else:
  197. raise ValueError("BN node target is unexpected ", bn_node.target)
  198. bn_eps = bn_args[eps_arg_index]
  199. fused_weight, fused_bias = fuse_conv_bn_weights(
  200. conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
  201. )
  202. # update the weight and bias for conv
  203. conv_args = list(conv_node.args)
  204. # filling in the default bias argument
  205. if len(conv_args) == 2:
  206. conv_args.append(None)
  207. # calling data since the fused_weight and fused_bias are nn.Parameter
  208. weight_attr_name = conv_weight_node.target
  209. assert isinstance(weight_attr_name, str)
  210. _assign_attr(fused_weight, m, weight_attr_name, _AttrKind.PARAMETER)
  211. if conv_bias_node is not None:
  212. bias_attr_name = conv_bias_node.target
  213. _assign_attr(fused_bias, m, str(bias_attr_name), _AttrKind.PARAMETER)
  214. else:
  215. bias_attr_name = weight_attr_name + "_bias"
  216. _assign_attr(fused_bias, m, bias_attr_name, _AttrKind.PARAMETER)
  217. with m.graph.inserting_before(conv_node):
  218. get_bias_node = m.graph.get_attr(bias_attr_name)
  219. # NOTE: here we assume the bias of conv is not quantized!
  220. conv_args[2] = get_bias_node
  221. conv_node.args = tuple(conv_args)
  222. # native_batch_norm has 3 outputs, we expect getitem calls on the output
  223. # and we want to replace the uses of getitem 0 with the output of conv
  224. #
  225. if bn_node.target == torch.ops.aten.batch_norm.default:
  226. # With the new training ir, instead of batch_norm + getitem,
  227. # we only have the batch_norm node.
  228. #
  229. # Before:
  230. # conv -> bn -> users
  231. # After:
  232. # conv -> users
  233. # bn has no users now
  234. bn_node.replace_all_uses_with(conv_node)
  235. else:
  236. # Before:
  237. # conv -> bn - (first output) -> users1
  238. # \ - (second output) -> users2
  239. # \ - (third output) -> users3
  240. # After:
  241. # conv -> (first output) -> users1
  242. # bn -
  243. # \ - (second output) -> users2
  244. # \ - (third output) -> users3
  245. # if users2 and users3 are empty then bn will be removed through dead code elimination
  246. for user in bn_node.users:
  247. if (
  248. user.op != "call_function"
  249. or user.target != operator.getitem
  250. or user.args[1] != 0
  251. ):
  252. continue
  253. user.replace_all_uses_with(conv_node)
  254. # If the BN node does not have users, erase it from the graph
  255. # Note: we need to do this manually because the model can still be in train
  256. # mode at this point, in which case DCE won't erase the BN node automatically
  257. # since the node refers to a mutating op. Here we still need to call DCE first
  258. # to get rid of the unused getitem nodes that consume the BN node.
  259. m.graph.eliminate_dead_code()
  260. if len(bn_node.users) == 0:
  261. m.graph.erase_node(bn_node)
  262. # fuse conv bn weights, inplace modification of the graph_module and graph
  263. def _fuse_conv_bn_(m: GraphModule) -> None:
  264. has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
  265. if not has_bn:
  266. return
  267. for n in m.graph.nodes:
  268. if n.op != "call_function" or n.target not in (
  269. torch.ops.aten._native_batch_norm_legit_no_training.default,
  270. torch.ops.aten.batch_norm.default,
  271. ):
  272. continue
  273. bn_node = n
  274. n = bn_node.args[0]
  275. if not _is_conv_or_conv_transpose_node(n):
  276. continue
  277. conv_node = n
  278. conv_weight_node = conv_node.args[1]
  279. conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
  280. fold_bn_weights_into_conv_node(
  281. conv_node, conv_weight_node, conv_bias_node, bn_node, m
  282. )
  283. m.graph.eliminate_dead_code()
  284. m.recompile()
  285. def _get_node_name_to_scope(model: GraphModule) -> dict[str, tuple[str, type]]:
  286. # TODO: move this information to fx node itself
  287. node_name_to_scope: dict[str, tuple[str, type]] = {}
  288. for n in model.graph.nodes:
  289. nn_module_stack = n.meta.get("nn_module_stack", None)
  290. current_scope = ("", type(None))
  291. if nn_module_stack:
  292. bt = list(nn_module_stack.values())[-1]
  293. current_scope = (bt[0].split(".")[-1], bt[1])
  294. node_name_to_scope[n.name] = current_scope
  295. return node_name_to_scope
  296. def _get_aten_graph_module_for_pattern(
  297. pattern: Callable,
  298. example_inputs: tuple[Any, ...],
  299. is_cuda: bool = False,
  300. **kwargs,
  301. ) -> GraphModule:
  302. """
  303. Convert the pattern to an FX graph with decomposed aten ops.
  304. """
  305. if is_cuda:
  306. example_inputs = tuple(
  307. [x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
  308. )
  309. aten_pattern = torch.export.export_for_training(
  310. pattern, # type: ignore[arg-type]
  311. example_inputs,
  312. kwargs,
  313. strict=True,
  314. ).module(check_guards=False)
  315. aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
  316. aten_pattern.recompile() # type: ignore[operator]
  317. # ep.module() adds copy_ nodes for the mutated inputs.
  318. # For patterns, it doesn't matter
  319. for node in aten_pattern.graph.nodes: # type: ignore[union-attr]
  320. if (
  321. node.op == "call_function"
  322. and node.target == torch.ops.aten.copy_.default
  323. and len(node.users) == 0
  324. ):
  325. aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr]
  326. aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
  327. aten_pattern.recompile() # type: ignore[operator]
  328. return aten_pattern # type: ignore[return-value]
  329. def remove_tensor_overload_for_qdq_ops(match_pattern: GraphModule) -> None:
  330. """Remove .tensor overload for quantize/dequantize ops so that we can
  331. use the match_pattern that we get from torchdynamo export to match the output of convert_pt2e
  332. """
  333. _MAP = {
  334. torch.ops.quantized_decomposed.quantize_per_tensor.default: torch.ops.quantized_decomposed.quantize_per_tensor,
  335. torch.ops.quantized_decomposed.dequantize_per_tensor.default: torch.ops.quantized_decomposed.dequantize_per_tensor,
  336. torch.ops.quantized_decomposed.quantize_per_tensor.tensor: torch.ops.quantized_decomposed.quantize_per_tensor,
  337. torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor,
  338. torch.ops.quantized_decomposed.quantize_per_tensor.tensor2: torch.ops.quantized_decomposed.quantize_per_tensor,
  339. torch.ops.quantized_decomposed.dequantize_per_tensor.tensor2: torch.ops.quantized_decomposed.dequantize_per_tensor,
  340. torch.ops.quantized_decomposed.quantize_per_channel.default: torch.ops.quantized_decomposed.quantize_per_channel,
  341. torch.ops.quantized_decomposed.dequantize_per_channel.default: torch.ops.quantized_decomposed.dequantize_per_channel,
  342. torch.ops.aten.clamp.Tensor: torch.ops.aten.clamp,
  343. }
  344. for n in match_pattern.graph.nodes:
  345. if n.op != "call_function":
  346. continue
  347. if n.target in _MAP:
  348. n.target = _MAP[n.target]
  349. def _is_literal(arg):
  350. if isinstance(arg, (int, float)):
  351. return True
  352. if isinstance(arg, (tuple, list)):
  353. return all(map(_is_literal, arg))
  354. return False
  355. def _replace_literals_with_new_placeholders(
  356. gm: torch.fx.GraphModule,
  357. merge_dup: bool = False,
  358. exclude_literals: Optional[list[Any]] = None,
  359. ):
  360. """Replace the literals in the graph with placeholder nodes that's created on the fly while we
  361. traverse the graph, so that the literal arguments in the graph can be matched and replaced
  362. To use this, the pattern and replacement graph should have the exact same number of literal args
  363. and they should be used in the exact same order in the pattern and replacement graph.
  364. If the literal arguments are not used in the same order in pattern and replacement graph, please
  365. use `_replace_literals_with_existing_placeholders` instead
  366. Args:
  367. `gm`: input GraphModule that we'll transform
  368. `merge_dup`: boolean flag to indicate that if the same literal appears multiple times in
  369. the graph, whether they should correspond to the same placeholder or not
  370. `exclude_literals`: a list of literals that will not be replaced with placeholders
  371. Example:
  372. # 1. Original Graph
  373. def pattern(self, x):
  374. return x + 3
  375. def replacement(self, x):
  376. return x - 3
  377. example_inputs = (torch.randn(1, 3, 3, 3),)
  378. pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
  379. replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus)
  380. # 2. Before calling replace literals we'll see the following graph:
  381. def pattern(self, x):
  382. return x + 3
  383. def replacement(self, x):
  384. return x - 3
  385. pattern_gm = _replace_literals_with_new_placeholders(pattern_gm)
  386. replacement_gm = _replace_literals_with_new_placeholders(replacement_gm)
  387. # 3. After replacing literals with new placeholder nodes
  388. def pattern(self, x, new_ph):
  389. return x + new_ph
  390. def pattern(self, x, new_ph):
  391. return x - new_ph
  392. """
  393. last_ph = None
  394. cnt = 0
  395. literal_to_ph: dict[Union[float, bool, int, torch.dtype], Node] = {}
  396. if exclude_literals is None:
  397. exclude_literals = []
  398. in_spec = gm._in_spec
  399. args_spec = in_spec.children_specs[0]
  400. for node in gm.graph.nodes:
  401. if node.op == "placeholder":
  402. last_ph = node
  403. cnt += 1
  404. continue
  405. with gm.graph.inserting_after(last_ph):
  406. new_args = []
  407. for arg in node.args:
  408. if _is_literal(arg) and arg not in exclude_literals:
  409. if merge_dup and arg in literal_to_ph:
  410. new_args.append(literal_to_ph[arg])
  411. else:
  412. ph_node = gm.graph.placeholder("arg" + str(cnt))
  413. new_args.append(ph_node)
  414. args_spec.children_specs.append(LeafSpec())
  415. cnt += 1
  416. if merge_dup:
  417. literal_to_ph[arg] = ph_node
  418. else:
  419. new_args.append(arg)
  420. new_args = tuple(new_args)
  421. node.args = new_args
  422. # Update `num_nodes`, `num_leaves`, `num_children`.
  423. args_spec.__post_init__()
  424. in_spec.__post_init__()
  425. return gm
  426. def _replace_literals_with_existing_placeholders(
  427. gm: torch.fx.GraphModule,
  428. exclude_literals: Optional[list[Any]] = None,
  429. literal_to_ph_idx: Optional[dict[Union[float, int, bool, torch.dtype], int]] = None,
  430. ):
  431. """Replace the literals in the graph with **existing** placeholder nodes, so that the literal arguments
  432. in the graph can be matched and replaced
  433. To use this, all literal args in the graph should be unique and each of them should correspond
  434. to exactly one placeholder node
  435. # 1. Original Graph
  436. def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
  437. return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)
  438. def replacement(x_i8, scale, zero_point, quant_min, quant_max):
  439. x_i8 = torch.clamp(x_i8, quant_min, quant_max)
  440. return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
  441. example_inputs = (
  442. torch.randn(1, 3, 3, 3),
  443. 1.0,
  444. 0,
  445. -128,
  446. 127,
  447. )
  448. pattern_gm = _get_aten_graph_module_for_pattern(pattern, example_inputs)
  449. replacement_gm = _get_aten_graph_module_for_pattern(pattern, example_inptus)
  450. # 2. Before calling replace literals we'll see the following graph:
  451. def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
  452. # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
  453. return torch.dequantize_per_tensor(x_i8, 1.0, 0, -128, 127)
  454. def replacement(x_i8, scale, zero_point, quant_min, quant_max):
  455. # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
  456. x_i8 = torch.clamp(x_i8, -128, 127)
  457. return ((x_i8.to(torch.float32) - 0) * 1.0).to(dtype=torch.float32)
  458. # Note that literal args appear in different order in pattern and replacement graph, so
  459. # we can't use _replace_literals_with_new_placeholders
  460. literal_to_ph_idx = {1.0: 1, 0: 2, -128: 3, 127: 4}
  461. pattern_gm = _replace_literals_with_existing_placeholders(pattern_gm, literal_to_ph_idx)
  462. replacement_gm = _replace_literals_with_existing_placeholders(replacement_gm, literal_to_ph_idx)
  463. # 3. After replacing literals with existing placeholder nodes
  464. def pattern(self, x_i8, scale, zero_point, quant_min, quant_max):
  465. # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
  466. return torch.dequantize_per_tensor(x_i8, scale, zero_point, quant_min, quant_max)
  467. def replacement(x_i8, scale, zero_point, quant_min, quant_max):
  468. # scale/zero_point/quant_min/quant_max are burnt in since they are scalar values
  469. x_i8 = torch.clamp(x_i8, quant_min, quant_max)
  470. return ((x_i8.to(torch.float32) - zero_point) * scale).to(dtype=torch.float32)
  471. """
  472. if exclude_literals is None:
  473. exclude_literals = []
  474. if literal_to_ph_idx is None:
  475. literal_to_ph_idx = {}
  476. phs = [node for node in gm.graph.nodes if node.op == "placeholder"]
  477. for node in gm.graph.nodes:
  478. if node.op != "call_function":
  479. continue
  480. new_args = []
  481. for arg in node.args:
  482. if (
  483. _is_literal(arg)
  484. and arg not in exclude_literals
  485. and arg in literal_to_ph_idx
  486. ):
  487. ph_idx = literal_to_ph_idx[arg]
  488. ph_node = phs[ph_idx]
  489. new_args.append(ph_node)
  490. else:
  491. new_args.append(arg)
  492. new_args = tuple(new_args)
  493. node.args = new_args
  494. return gm
  495. # TODO: Handle this in export itself and don't wrap the model in another GraphModule
  496. # in prepare and convert
  497. def _disallow_eval_train(model: GraphModule):
  498. """
  499. Disallow calling `model.train()` or `model.eval()` on the given GraphModule.
  500. This is useful for exported models, where these methods don't actually behave as expected.
  501. """
  502. error_message = """
  503. Calling train() or eval() is not supported for exported models.
  504. Please call `torch.ao.quantization.move_exported_model_to_train(model)` (or eval) instead.
  505. If you cannot replace the calls to `model.train()` and `model.eval()`, you may override
  506. the behavior for these methods by calling `torch.ao.quantization.allow_exported_model_train_eval(model)`,
  507. which does the above automatically for you. Note that this has limited effect on switching
  508. behavior between train and eval modes, and should be used only for special ops such as dropout
  509. and batchnorm.
  510. """
  511. def _train(self, mode: bool = True):
  512. raise NotImplementedError(error_message)
  513. def _eval(self, mode: bool = True):
  514. raise NotImplementedError(error_message)
  515. model.train = types.MethodType(_train, model) # type: ignore[method-assign]
  516. model.eval = types.MethodType(_eval, model) # type: ignore[method-assign]
  517. return model