prepare.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603
  1. # mypy: allow-untyped-defs
  2. from typing import Any, Optional, Union
  3. import torch
  4. from torch._subclasses import FakeTensor
  5. from torch.ao.quantization import (
  6. CUSTOM_KEY,
  7. NUMERIC_DEBUG_HANDLE_KEY,
  8. ObserverOrFakeQuantize,
  9. QConfigMapping,
  10. )
  11. from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
  12. from torch.ao.quantization.fx.prepare import (
  13. _create_obs_or_fq_from_qspec,
  14. _insert_obs_or_fq,
  15. _is_activation_post_process_node,
  16. _save_state,
  17. )
  18. from torch.ao.quantization.qconfig import QConfigAny
  19. from torch.ao.quantization.quantizer import (
  20. EdgeOrNode,
  21. QuantizationSpecBase,
  22. SharedQuantizationSpec,
  23. )
  24. from torch.ao.quantization.utils import _assert_and_get_unique_device
  25. from torch.fx import Graph, GraphModule, Node
  26. from torch.fx.node import Argument
# TODO: make pt2e folder private?
# Names exported by `from <this module> import *`; only `prepare` is public,
# every other definition in this file is an underscore-prefixed helper.
__all__ = [
    "prepare",
]
  31. def _find_root_edge_or_node(
  32. edge_or_node: EdgeOrNode, shared_with_map: dict[EdgeOrNode, EdgeOrNode]
  33. ) -> EdgeOrNode:
  34. """Find the root node for the sharing tree
  35. Args:
  36. edge_or_node: edge/node that we want to find the root
  37. shared_with_map: each edge/node points to the parent, the root node will points to itself
  38. Returns:
  39. root edge/node
  40. """
  41. parent = shared_with_map[edge_or_node]
  42. if parent == edge_or_node:
  43. return edge_or_node
  44. root = _find_root_edge_or_node(parent, shared_with_map)
  45. # path compression
  46. shared_with_map[edge_or_node] = root
  47. return root
  48. def _union(
  49. parent: EdgeOrNode,
  50. child: EdgeOrNode,
  51. shared_with_map: dict[EdgeOrNode, EdgeOrNode],
  52. ) -> None:
  53. """Merge the subtree for `child` with `parent`, the order is important here"""
  54. root_parent = _find_root_edge_or_node(parent, shared_with_map)
  55. root_child = _find_root_edge_or_node(child, shared_with_map)
  56. # union the two trees by pointing the root of child to root of parent
  57. shared_with_map[root_child] = root_parent
  58. def _update_shared_with(
  59. child: EdgeOrNode,
  60. qspec: QuantizationSpecBase,
  61. shared_with_map: dict[EdgeOrNode, EdgeOrNode],
  62. ):
  63. """Update the `shared_with_map` based on the qspec, this applies the `SharedQuantizationSpec`
  64. configuration and established the relationship between `edge_or_node` with the edge/node that it
  65. is pointing to, we'll use this information in the end to get the group id
  66. """
  67. if isinstance(qspec, SharedQuantizationSpec):
  68. parent = qspec.edge_or_node
  69. # we point from edge_or_node to the node that it is sharing_with, e.g.
  70. # qspec for a = SharedQuantizationSpec(b) means `a` points to `b`
  71. _union(parent, child, shared_with_map)
  72. def _unwrap_shared_qspec(
  73. qspec: QuantizationSpecBase,
  74. edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
  75. shared_with_map: dict[EdgeOrNode, EdgeOrNode],
  76. ) -> QuantizationSpecBase:
  77. """Unwraps qspec to get the final root qspec (non SharedQuantizationSpec)
  78. if qspec is SharedQuantizationSpec
  79. (1). tries to find the root edge or node for the node that the qspec points to
  80. (2). recursively find the root qspec based on the qspec for the root node
  81. """
  82. if isinstance(qspec, SharedQuantizationSpec):
  83. sharing_with = qspec.edge_or_node
  84. root = _find_root_edge_or_node(sharing_with, shared_with_map)
  85. qspec = edge_or_node_to_qspec[root]
  86. return _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  87. return qspec
  88. def _has_same_attr(
  89. qspec_a: QuantizationSpecBase, qspec_b: QuantizationSpecBase, attr_name: str
  90. ):
  91. return (
  92. hasattr(qspec_a, attr_name)
  93. and hasattr(qspec_b, attr_name)
  94. and getattr(qspec_a, attr_name) == getattr(qspec_b, attr_name)
  95. ) or (not hasattr(qspec_a, attr_name) and not hasattr(qspec_b, attr_name))
  96. def _get_edge_or_node_to_qspec(
  97. model: torch.fx.GraphModule,
  98. ) -> dict[EdgeOrNode, QuantizationSpecBase]:
  99. """Get a map from EdgeOrNode to quantization spec based on annotations on the nodes"""
  100. edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {}
  101. for n in model.graph.nodes:
  102. if hasattr(n, "meta") and "quantization_annotation" in n.meta:
  103. qa = n.meta["quantization_annotation"]
  104. for input_to_n, qspec in qa.input_qspec_map.items():
  105. input_edge = (input_to_n, n)
  106. edge_or_node_to_qspec[input_edge] = qspec
  107. if qa.output_qspec is not None:
  108. output_node = n
  109. qspec = qa.output_qspec
  110. edge_or_node_to_qspec[output_node] = qspec
  111. return edge_or_node_to_qspec
  112. def _union_input_edge_with(
  113. input_edge,
  114. input_edge_root_qspec,
  115. edge_or_node,
  116. edge_or_node_to_qspec,
  117. shared_with_map,
  118. ):
  119. """Union input edge with another edge or node, used in implicit sharing to point the current input
  120. edge to other user edges of the producer node, or the output of producer node since these are
  121. referring to the same Tensor
  122. """
  123. root_qspec = None
  124. if edge_or_node in edge_or_node_to_qspec:
  125. qspec = edge_or_node_to_qspec[edge_or_node]
  126. root_qspec = _unwrap_shared_qspec(qspec, edge_or_node_to_qspec, shared_with_map)
  127. # TODO: add assertions for types of root qspecs
  128. if root_qspec is not None and all(
  129. _has_same_attr(root_qspec, input_edge_root_qspec, attr)
  130. for attr in [
  131. "dtype",
  132. "is_dynamic",
  133. "quant_min",
  134. "quant_max",
  135. "qscheme",
  136. "ch_axis",
  137. "scale",
  138. "zero_point",
  139. ]
  140. ):
  141. # the input arg to the node should reuse the existing output observer for arg
  142. # since dtype is the same (we may want to extend this to be a more strict check
  143. # in the future)
  144. # so we point from `input_edge` to `arg` (output of the argument)
  145. _union(edge_or_node, input_edge, shared_with_map)
  146. def _get_edge_or_node_to_group_id(
  147. edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
  148. ) -> dict[EdgeOrNode, int]:
  149. """Map from edge/node to the group ID, generated from quantization annotations,
  150. edge/node with the same group ID should use the same observer/fake_quant instance
  151. This is applying SharedQuantizationSpec configuration and map each edge/node to a group
  152. There is another implicit sharing that's built in the quantization, when we have the following:
  153. * op1 -> op2
  154. * output of op1: int8_qspec
  155. * (op1 -> op2) input edge: int8_qspec
  156. we'll assume sharing between the output of op1 and input of (op1 -> op2) since these are the same Tensor.
  157. Figuring out the correct group ID for all edge/node is a standard union find problem:
  158. https://www.geeksforgeeks.org/introduction-to-disjoint-set-data-structure-or-union-find-algorithm/
  159. Args:
  160. edge_or_node_to_qspec: Dictionary from edge_or_node to the qspec, derived from annotations
  161. Returns:
  162. edge_or_node_to_group_id: Dictionary from edge_or_node to group_id (int), all edge or node that
  163. belongs to the same group should have the same id
  164. Example:
  165. op2 -> cat1 -> cat2
  166. op1 / /
  167. op3
  168. edge_or_node_to_qspec: {
  169. op1: int8_qspec,
  170. op2: int8_qspec,
  171. (op1, cat1): int8_qspc,
  172. (op2, cat1): SharedQuantizationSpec((op1, cat1)),
  173. cat1: SharedQuantizationSpec((op1, cat1)),
  174. (op3, cat2): int8_qspec,
  175. (cat1, cat2): SharedQuantizationSpec((op3, cat2)),
  176. cat2: SharedQuantizationSpec((op3, cat2)),
  177. }
  178. edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
  179. edge_or_node_to_group_id: {
  180. op1: 1,
  181. op2: 1,
  182. (op1, cat1): 1,
  183. (op2, cat1): 1,
  184. cat1: 1,
  185. (op3, cat2): 1,
  186. (cat1, cat2): 1,
  187. cat2: 1,
  188. }
  189. # everything are in the same group because (cat1) and (cat1, cat2) are implicitly shared, which
  190. # connects the two sharing group around cat1 and cat2 op due to transitive sharing
  191. """
  192. # means the observer of key should be shared with observer with value, by default it will
  193. # be shared with itself
  194. shared_with_map: dict[EdgeOrNode, EdgeOrNode] = {
  195. k: k for k in edge_or_node_to_qspec.keys()
  196. }
  197. for edge_or_node, qspec in edge_or_node_to_qspec.items():
  198. if isinstance(edge_or_node, torch.fx.Node):
  199. output_node = edge_or_node
  200. _update_shared_with(output_node, qspec, shared_with_map)
  201. else:
  202. input_edge = edge_or_node
  203. input_edge_root_qspec = _unwrap_shared_qspec(
  204. qspec, edge_or_node_to_qspec, shared_with_map
  205. )
  206. assert isinstance(input_edge, tuple)
  207. arg, n = input_edge
  208. if n.meta["quantization_annotation"].allow_implicit_sharing:
  209. # NOTE: the order is important here, we first share with other users and then share with previous
  210. # output because the reverse order could cause circular dependency
  211. # e.g node1 -> node2
  212. # \ -> node3
  213. # when processing (node1, node2), if we first point (node1, node2) to node1
  214. # Step 1. shared_map = {(node1, node2): node1}
  215. # Step 2. after that, we point the (node1, node2) to its other user (node1, node3) ,
  216. # which means shared_map = {(node1, node2): node1, node1: (node1, node3)}
  217. # because we will point the root of (node1, node2) (in this case node1) to the root of (node1, node3)
  218. # Step 3. and when we process (node1, node3), it can try to point to node1 as well, then we'll
  219. # have a circular dependency
  220. # the following order works around this issue, but this does not allow arbitrary configuration
  221. # of sharing so it might break in a different case in the future, when it breaks
  222. # quantizer writer can check the notes here to debug the issue
  223. # sharing with other users of the producer node
  224. # (arg, user)
  225. if not isinstance(arg, Node) or not isinstance(n, Node):
  226. raise Exception( # noqa: TRY002
  227. f"Expected input_edge to have type Tuple[Node, Node], but got: {arg, n}"
  228. )
  229. for user in arg.users:
  230. if user is n:
  231. continue
  232. arg_to_user_edge = (arg, user)
  233. _union_input_edge_with(
  234. input_edge,
  235. input_edge_root_qspec,
  236. arg_to_user_edge,
  237. edge_or_node_to_qspec,
  238. shared_with_map,
  239. )
  240. # sharing with output of producer node
  241. _union_input_edge_with(
  242. input_edge,
  243. input_edge_root_qspec,
  244. arg,
  245. edge_or_node_to_qspec,
  246. shared_with_map,
  247. )
  248. _update_shared_with(input_edge, qspec, shared_with_map)
  249. # now that we get the sharing relations between all edges and nodes, we can assign group ids
  250. cur_group_id = 0
  251. edge_or_node_to_group_id: dict[EdgeOrNode, int] = {}
  252. for edge_or_node in shared_with_map.keys():
  253. root = _find_root_edge_or_node(edge_or_node, shared_with_map)
  254. if root not in edge_or_node_to_group_id:
  255. edge_or_node_to_group_id[root] = cur_group_id
  256. cur_group_id += 1
  257. edge_or_node_to_group_id[edge_or_node] = edge_or_node_to_group_id[root]
  258. return edge_or_node_to_group_id
  259. def _get_obs_or_fq_map(
  260. edge_or_node_to_group_id: dict[EdgeOrNode, int],
  261. edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase],
  262. is_qat: bool,
  263. ) -> dict[EdgeOrNode, ObserverOrFakeQuantize]:
  264. """Generates the EdgeOrNode to observer/fake_quant instances
  265. Makes sure that for EdgeOrNode that has the same group_id should have the same observer or fake quant
  266. instances
  267. """
  268. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize] = {}
  269. group_id_to_obs_or_fq: dict[int, ObserverOrFakeQuantize] = {}
  270. for edge_or_node, qspec in edge_or_node_to_qspec.items():
  271. group_id = edge_or_node_to_group_id[edge_or_node]
  272. if group_id not in group_id_to_obs_or_fq:
  273. # TODO: maybe edge_or_node_to_qspec should be edge_or_node_to_root_qspec, this will simplify
  274. # the implementation for _create_obs_or_fq_from_qspec
  275. group_id_to_obs_or_fq[group_id] = _create_obs_or_fq_from_qspec(
  276. qspec, obs_or_fq_map, is_qat
  277. )
  278. obs_or_fq_map[edge_or_node] = group_id_to_obs_or_fq[group_id]
  279. return obs_or_fq_map
  280. def _maybe_insert_input_observer_for_arg_or_kwarg(
  281. node: Union[Node, Any],
  282. arg: Argument,
  283. qconfig: QConfigAny,
  284. model: torch.nn.Module,
  285. named_modules: dict[str, torch.nn.Module],
  286. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  287. is_qat: bool,
  288. model_device: Optional[torch.device] = None,
  289. ) -> Argument:
  290. """
  291. Given a `node` and an `arg`, inserts an input observer between
  292. `node` and `arg` if necessary.
  293. """
  294. # for ops such as torch.cat([x0, x1]),
  295. # traverse through the list
  296. if isinstance(arg, (list, tuple)):
  297. new_arg_to_return = []
  298. for inner_arg in arg:
  299. new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
  300. node,
  301. inner_arg,
  302. qconfig,
  303. model,
  304. named_modules,
  305. obs_or_fq_map,
  306. is_qat,
  307. model_device,
  308. )
  309. new_arg_to_return.append(new_inner_arg)
  310. return type(arg)(new_arg_to_return)
  311. if not isinstance(arg, Node):
  312. return arg
  313. assert isinstance(arg, Node)
  314. # default (no observer)
  315. new_arg = arg
  316. # find the original `arg` node to the current node, skipping inserted observer/fake_quant nodes
  317. original_arg = arg
  318. while _is_activation_post_process_node(original_arg, named_modules):
  319. original_arg = original_arg.args[0] # type: ignore[assignment]
  320. assert isinstance(original_arg, Node), (
  321. f"expect original argument to be a Node, but got: {type(original_arg)}"
  322. )
  323. input_edge = (original_arg, node)
  324. if input_edge not in obs_or_fq_map:
  325. return new_arg
  326. # input_edge needs to be observed
  327. input_edge_obs_or_fq = obs_or_fq_map[input_edge]
  328. if input_edge_obs_or_fq is None:
  329. return new_arg
  330. arg_as_output_obs_or_fq = obs_or_fq_map.get(original_arg, None)
  331. # the arg is observed as the output and is using the same instance as the input_edge
  332. # we'll reuse the inserted observer/fake_quant
  333. if arg_as_output_obs_or_fq is not None and id(arg_as_output_obs_or_fq) == id(
  334. input_edge_obs_or_fq
  335. ):
  336. return new_arg
  337. # otherwise, we'll insert a new observer/fake_quant node
  338. # skip inserting new observers if the same observer instance is inserted before for another user
  339. # Example:
  340. # conv1 -> obs1 -> existing_obs -> conv2
  341. # \ -> conv3
  342. #
  343. # instead of inserting new observers we will have:
  344. # conv1 -> obs1 -> existing_obs -> conv2
  345. # \ -> conv3
  346. for maybe_obs_node in arg.users.keys():
  347. if not _is_activation_post_process_node(maybe_obs_node, named_modules):
  348. continue
  349. maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index]
  350. if id(maybe_obs_mod) == id(input_edge_obs_or_fq):
  351. return maybe_obs_node
  352. assert isinstance(model.graph, Graph)
  353. new_arg = _insert_obs_or_fq(
  354. arg,
  355. input_edge_obs_or_fq,
  356. model,
  357. named_modules,
  358. model.graph,
  359. model_device,
  360. )
  361. return new_arg
  362. def _maybe_insert_input_observers_for_node(
  363. node: Node,
  364. qconfig: QConfigAny,
  365. model: torch.nn.Module,
  366. named_modules: dict[str, torch.nn.Module],
  367. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  368. is_qat: bool,
  369. model_device: Optional[torch.device] = None,
  370. ) -> None:
  371. """
  372. If needed, inserts observers to the input args and kwargs of `node`.
  373. Note: modifies `node` inplace.
  374. For example, if cur_node needs an observer after prev_node, we change from
  375. prev_node -> cur_node
  376. To
  377. prev_node -> obs -> cur_node
  378. """
  379. # Look through every input arg. If that arg's target dtype does not
  380. # match the current node's target dtype, insert an observer.
  381. new_args = []
  382. for arg in node.args:
  383. new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
  384. node,
  385. arg,
  386. qconfig,
  387. model,
  388. named_modules,
  389. obs_or_fq_map,
  390. is_qat,
  391. model_device,
  392. )
  393. new_args.append(new_arg)
  394. # Clone has a memory_format kwarg, zeros_like has a pin_memory kwarg, and
  395. # gelu has a has an approximate kwarg that persist in exported graph.
  396. # This is just a work around for these.
  397. assert (
  398. node.target == torch.ops.aten.clone.default
  399. or node.target == torch.ops.aten.zeros_like.default
  400. or node.target == torch.ops.aten.gelu.default
  401. or len(node.kwargs) == 0
  402. ), " expecting kwargs for aten op IR to be empty"
  403. # assign the new args to the node, inplace
  404. node.args = tuple(new_args)
  405. def _maybe_insert_output_observer_for_node(
  406. node: Node,
  407. model: torch.nn.Module,
  408. named_modules: dict[str, torch.nn.Module],
  409. graph: Graph,
  410. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  411. is_qat: bool,
  412. model_device: Optional[torch.device] = None,
  413. ) -> Optional[Node]:
  414. if node in obs_or_fq_map:
  415. output_act_obs_or_fq = obs_or_fq_map[node]
  416. new_output = _insert_obs_or_fq(
  417. node,
  418. output_act_obs_or_fq,
  419. model,
  420. named_modules,
  421. graph,
  422. model_device,
  423. )
  424. # propagate numeric debug handle from original node to observer/fake_quant node
  425. if (
  426. isinstance(node, Node)
  427. and isinstance(new_output, Node)
  428. and CUSTOM_KEY in node.meta
  429. and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
  430. ):
  431. if CUSTOM_KEY not in new_output.meta:
  432. new_output.meta[CUSTOM_KEY] = {}
  433. new_output.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
  434. CUSTOM_KEY
  435. ][NUMERIC_DEBUG_HANDLE_KEY]
  436. return new_output
  437. return None
  438. def _maybe_insert_input_and_output_observers_for_node(
  439. node: Node,
  440. model: torch.fx.GraphModule,
  441. obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
  442. is_qat: bool,
  443. model_device: Optional[torch.device] = None,
  444. ):
  445. this_node_quantization_annotation = (
  446. node.meta["quantization_annotation"]
  447. if "quantization_annotation" in node.meta
  448. else None
  449. )
  450. if this_node_quantization_annotation is None:
  451. return
  452. named_modules = dict(model.named_modules(remove_duplicate=False))
  453. _maybe_insert_input_observers_for_node(
  454. node,
  455. None, # qconfig
  456. model,
  457. named_modules,
  458. obs_or_fq_map,
  459. is_qat,
  460. model_device,
  461. )
  462. output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
  463. if not output_is_a_tensor:
  464. return
  465. # this returns the new observer node if it was needed
  466. maybe_output_obs_node = _maybe_insert_output_observer_for_node(
  467. node,
  468. model,
  469. named_modules,
  470. model.graph,
  471. obs_or_fq_map,
  472. is_qat,
  473. model_device,
  474. )
  475. if maybe_output_obs_node is None:
  476. return
  477. # Update users of original node to use the output observer
  478. # instead. For example, change
  479. #
  480. # next_node
  481. # /
  482. # cur_node -> obs
  483. #
  484. # to
  485. #
  486. # next_node
  487. # /
  488. # cur_node -> obs
  489. #
  490. # We need to save orig users before updating uses because
  491. # the list of users will change as we update uses
  492. orig_users = list(node.users.keys())
  493. for user_node in orig_users:
  494. if user_node is maybe_output_obs_node:
  495. continue
  496. user_node.replace_input_with(node, maybe_output_obs_node)
  497. def prepare(
  498. model: GraphModule,
  499. node_name_to_scope: dict[str, tuple[str, type]],
  500. is_qat: bool,
  501. obs_or_fq_callback=None,
  502. ) -> GraphModule:
  503. # Since we are mutating the graph as we go, we iterate over the original
  504. # nodes before observer insertion, instead of model.graph.nodes.
  505. nodes_before_observation = list(model.graph.nodes)
  506. # At the high level we construct a map from EdgeOrNode to a observer_or_fake_quant instance
  507. # all edge/nodes that belongs to the same group will use the same instance
  508. # and when we insert observers we'll just query this map to get the correct observer_or_fake_quant
  509. # instance
  510. edge_or_node_to_qspec = _get_edge_or_node_to_qspec(model)
  511. edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
  512. obs_or_fq_map = _get_obs_or_fq_map(
  513. edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat
  514. )
  515. if obs_or_fq_callback:
  516. obs_or_fq_callback(model, obs_or_fq_map)
  517. model_device = _assert_and_get_unique_device(model)
  518. for node in nodes_before_observation:
  519. # TODO: simplify logic for inserting observers
  520. _maybe_insert_input_and_output_observers_for_node(
  521. node,
  522. model,
  523. obs_or_fq_map,
  524. is_qat,
  525. model_device,
  526. )
  527. model = GraphModule(model, model.graph)
  528. _save_state(
  529. model,
  530. {}, # node_name_to_qconfig
  531. node_name_to_scope,
  532. PrepareCustomConfig(),
  533. {}, # equalization_node_name_to_qconfig
  534. QConfigMapping(),
  535. is_qat,
  536. set(), # observed_node_names
  537. )
  538. return model