utils.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import operator
  4. import warnings
  5. from collections import namedtuple
  6. from dataclasses import dataclass
  7. from typing import Any, Callable, Optional, Union
  8. import torch
  9. import torch.nn as nn
  10. from torch.ao.quantization import QConfigAny, QuantType
  11. from torch.ao.quantization.backend_config import DTypeWithConstraints
  12. from torch.ao.quantization.fake_quantize import (
  13. FakeQuantizeBase,
  14. FixedQParamsFakeQuantize,
  15. )
  16. from torch.ao.quantization.observer import (
  17. _is_activation_post_process,
  18. FixedQParamsObserver,
  19. ObserverBase,
  20. )
  21. from torch.ao.quantization.qconfig import (
  22. float16_dynamic_qconfig,
  23. float16_static_qconfig,
  24. qconfig_equals,
  25. )
  26. from torch.ao.quantization.qconfig_mapping import QConfigMapping
  27. from torch.ao.quantization.stubs import DeQuantStub
  28. from torch.ao.quantization.utils import (
  29. _assert_and_get_unique_device,
  30. activation_is_statically_quantized,
  31. )
  32. from torch.fx import GraphModule, map_arg
  33. from torch.fx.graph import Graph, Node
  34. # importing the lib so that the quantized_decomposed ops are registered
  35. from ._decomposed import quantized_decomposed_lib # noqa: F401
  36. from .custom_config import PrepareCustomConfig
# TODO: revisit this list. Many helper methods shouldn't be public
# Names exported by `from <this module> import *`, i.e. the module's declared
# public API.
__all__ = [
    "all_node_args_except_first",
    "all_node_args_have_no_tensors",
    "assert_and_get_unique_device",
    "collect_producer_nodes",
    "create_getattr_from_value",
    "create_node_from_old_node_preserve_meta",
    "EMPTY_ARG_DICT",
    "get_custom_module_class_keys",
    "get_linear_prepack_op_for_dtype",
    "get_new_attr_name_with_prefix",
    "get_non_observable_arg_indexes_and_types",
    "get_qconv_prepack_op",
    "get_skipped_module_name_and_classes",
    "graph_module_from_producer_nodes",
    "maybe_get_next_module",
    "NodeInfo",
    "node_arg_is_bias",
    "node_arg_is_weight",
    "NON_OBSERVABLE_ARG_DICT",
    "NON_QUANTIZABLE_WEIGHT_OPS",
    "return_arg_list",
    "ObservedGraphModuleAttrs",
]

# Functional normalization ops grouped as "non-quantizable weight" ops.
# Presumably their weight/bias arguments are left unquantized (inferred from the
# name -- confirm at the use sites).
NON_QUANTIZABLE_WEIGHT_OPS = {
    torch.nn.functional.layer_norm,
    torch.nn.functional.group_norm,
    torch.nn.functional.instance_norm,
}
@dataclass
class ObservedGraphModuleAttrs:
    """Bundle of attributes describing an observed (prepared) ``GraphModule``.

    NOTE(review): the per-field descriptions below are inferred from the field
    names and types; confirm them against the code that fills this in.
    """

    # node name -> QConfig chosen for that node
    node_name_to_qconfig: dict[str, QConfigAny]
    # node name -> (str, type) pair; presumably (module path, module type) -- confirm
    node_name_to_scope: dict[str, tuple[str, type]]
    # the PrepareCustomConfig used when preparing the model
    prepare_custom_config: PrepareCustomConfig
    # node name -> equalization qconfig (value type is untyped here)
    equalization_node_name_to_qconfig: dict[str, Any]
    # the QConfigMapping used when preparing the model
    qconfig_mapping: QConfigMapping
    # whether this is a quantization-aware-training flow
    is_qat: bool
    # names of the nodes that were observed
    observed_node_names: set[str]
    # standalone-module bookkeeping; the defaults describe a non-standalone module
    is_observed_standalone_module: bool = False
    # presumably indices of inputs/outputs expected to be quantized -- confirm
    standalone_module_input_quantized_idxs: Optional[list[int]] = None
    standalone_module_output_quantized_idxs: Optional[list[int]] = None
  79. def node_arg_is_weight(node: Node, arg: Any) -> bool:
  80. """Returns if node arg is weight"""
  81. weight_index = None
  82. if "target_dtype_info" in node.meta:
  83. weight_index = node.meta["target_dtype_info"].get("weight_index", None)
  84. if (
  85. weight_index is not None
  86. and weight_index < len(node.args)
  87. and node.args[weight_index] is arg
  88. ):
  89. return True
  90. return node.kwargs.get("weight") is arg
  91. def node_arg_is_bias(node: Node, arg: Any) -> bool:
  92. """Returns if node arg is bias"""
  93. bias_index = None
  94. if "target_dtype_info" in node.meta:
  95. bias_index = node.meta["target_dtype_info"].get("bias_index", None)
  96. if (
  97. bias_index is not None
  98. and bias_index < len(node.args)
  99. and node.args[bias_index] is arg
  100. ):
  101. return True
  102. return node.kwargs.get("bias") is arg
  103. def get_custom_module_class_keys(
  104. custom_module_mapping: dict[QuantType, dict[type, type]],
  105. ) -> list[Any]:
  106. r"""Get all the unique custom module keys in the custom config dict
  107. e.g.
  108. Input:
  109. {
  110. QuantType.STATIC: {
  111. CustomModule1: ObservedCustomModule
  112. },
  113. QuantType.DYNAMIC: {
  114. CustomModule2: DynamicObservedCustomModule
  115. },
  116. QuantType.WEIGHT_ONLY: {
  117. CustomModule3: WeightOnlyObservedCustomModule
  118. },
  119. }
  120. Output:
  121. # extract the keys across all inner STATIC, DYNAMIC, and WEIGHT_ONLY dicts
  122. [CustomModule1, CustomModule2, CustomModule3]
  123. """
  124. # using set to dedup
  125. float_custom_module_classes: set[Any] = set()
  126. for quant_mode in [QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY]:
  127. quant_mode_custom_module_config = custom_module_mapping.get(quant_mode, {})
  128. quant_mode_custom_module_classes = set(quant_mode_custom_module_config.keys())
  129. float_custom_module_classes |= quant_mode_custom_module_classes
  130. return list(float_custom_module_classes)
  131. def get_linear_prepack_op_for_dtype(dtype):
  132. if dtype == torch.float16:
  133. return torch.ops.quantized.linear_prepack_fp16
  134. elif dtype == torch.qint8:
  135. return torch.ops.quantized.linear_prepack
  136. else:
  137. raise Exception("can't get linear prepack op for dtype:", dtype) # noqa: TRY002
  138. def get_qconv_prepack_op(conv_op: Callable) -> Callable:
  139. prepack_ops = {
  140. torch.nn.functional.conv1d: torch.ops.quantized.conv1d_prepack,
  141. torch.nn.functional.conv2d: torch.ops.quantized.conv2d_prepack,
  142. torch.nn.functional.conv3d: torch.ops.quantized.conv3d_prepack,
  143. torch.nn.functional.conv_transpose1d: torch.ops.quantized.conv_transpose1d_prepack,
  144. torch.nn.functional.conv_transpose2d: torch.ops.quantized.conv_transpose2d_prepack,
  145. torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
  146. }
  147. prepack_op = prepack_ops.get(conv_op, None)
  148. assert prepack_op, f"Didn't find prepack op for {conv_op}"
  149. return prepack_op
  150. # Returns a function that can get a new attribute name for module with given
  151. # prefix, for example,
  152. # >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
  153. # >> new_name = get_new_observer_name(module)
  154. # new_name will be an unused attribute name on module, e.g. `_observer_1`
  155. def get_new_attr_name_with_prefix(prefix: str) -> Callable:
  156. prefix = prefix.replace(".", "_")
  157. def get_new_attr_name(module: torch.nn.Module):
  158. def get_attr_name(i: int):
  159. return prefix + str(i)
  160. i = 0
  161. attr_name = get_attr_name(i)
  162. while hasattr(module, attr_name):
  163. i += 1
  164. attr_name = get_attr_name(i)
  165. return attr_name
  166. return get_new_attr_name
  167. def collect_producer_nodes(node: Node) -> Optional[list[Node]]:
  168. r"""Starting from a target node, trace back until we hit input or
  169. getattr node. This is used to extract the chain of operators
  170. starting from getattr to the target node, for example
  171. def forward(self, x):
  172. observed = self.observer(self.weight)
  173. return F.linear(x, observed)
  174. collect_producer_nodes(observed) will either return a list of nodes that
  175. produces the observed node or None if we can't extract a self contained
  176. graph without free variables(inputs of the forward function).
  177. """
  178. nodes = [node]
  179. frontier = [node]
  180. while frontier:
  181. node = frontier.pop()
  182. all_args = list(node.args) + list(node.kwargs.values())
  183. for arg in all_args:
  184. if not isinstance(arg, Node):
  185. continue
  186. if arg.op == "placeholder":
  187. # hit input, can't fold in this case
  188. return None
  189. nodes.append(arg)
  190. if not (arg.op == "call_function" and arg.target == getattr):
  191. frontier.append(arg)
  192. return nodes
  193. def graph_module_from_producer_nodes(
  194. root: GraphModule, producer_nodes: list[Node]
  195. ) -> GraphModule:
  196. r"""Construct a graph module from extracted producer nodes
  197. from `collect_producer_nodes` function
  198. Args:
  199. root: the root module for the original graph
  200. producer_nodes: a list of nodes we use to construct the graph
  201. Return:
  202. A graph module constructed from the producer nodes
  203. """
  204. assert len(producer_nodes) > 0, "list of producer nodes can not be empty"
  205. # since we traced back from node to getattr
  206. producer_nodes.reverse()
  207. graph = Graph()
  208. env: dict[Any, Any] = {}
  209. def load_arg(a):
  210. return map_arg(a, lambda node: env[node])
  211. for producer_node in producer_nodes:
  212. env[producer_node] = graph.node_copy(producer_node, load_arg)
  213. graph.output(load_arg(producer_nodes[-1]))
  214. graph_module = GraphModule(root, graph)
  215. return graph_module
  216. # TODO: delete
  217. def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
  218. """
  219. Returns the unique device for a module, or None if no device is found.
  220. Throws an error if multiple devices are detected.
  221. """
  222. return _assert_and_get_unique_device(module)
  223. def create_getattr_from_value(
  224. module: torch.nn.Module,
  225. graph: Graph,
  226. prefix: str,
  227. value: Any,
  228. device: Optional[torch.device] = None,
  229. ) -> Node:
  230. """
  231. Given a value of any type, creates a getattr node corresponding to the value and
  232. registers the value as a buffer to the module.
  233. """
  234. get_new_attr_name = get_new_attr_name_with_prefix(prefix)
  235. attr_name = get_new_attr_name(module)
  236. if device is None:
  237. device = assert_and_get_unique_device(module)
  238. new_value = (
  239. value.detach().clone()
  240. if isinstance(value, torch.Tensor)
  241. else torch.tensor(value, device=device)
  242. )
  243. module.register_buffer(attr_name, new_value)
  244. # Create get_attr with value
  245. attr_node = graph.create_node("get_attr", attr_name)
  246. return attr_node
def all_node_args_have_no_tensors(
    node: Node, modules: dict[str, torch.nn.Module], cache: dict[Node, bool]
) -> bool:
    """
    If we know for sure that all of this node's args have no
    tensors (are primitives), return True. If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.

    Args:
        node: value to inspect. Recursion can pass non-Node values (e.g. the
            first arg of a getitem); those are treated as tensor-free.
        modules: module path -> module, used to resolve ``call_module`` targets.
        cache: memoization dict keyed by the inspected value.

    NOTE(review): the cache is read and written only when it is truthy
    (``if cache``), so a cache that starts out empty is never filled and
    memoization is effectively off unless the caller pre-populates it.
    Switching to ``cache is not None`` would also have to cope with non-Node
    (possibly unhashable) values being used as keys.
    """
    if cache and node in cache:
        return cache[node]

    result = False  # will be overwritten
    if not isinstance(node, Node):
        result = True
    elif node.op == "placeholder":
        # graph inputs are conservatively treated as possibly holding tensors
        result = False
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        # activation post-process modules are judged by their input; any other
        # module leaves result as False
        if _is_activation_post_process(modules[node.target]):
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    # NOTE(review): unreachable -- the identical `call_module` condition above
    # always matches first (and result is already False in that case).
    elif node.op == "call_module":
        result = False
    elif node.op == "call_function" and node.target is operator.getitem:
        # x1 = x0[i]: judged by the container being indexed
        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == "get_attr":
        result = False
    elif node.target is getattr and node.args[1] in ["ndim", "shape"]:
        # x1 = x0.ndim
        result = True
    elif node.op == "call_method" and node.target == "size":
        # x1 = x0.size(0)
        result = True
    else:
        # Generic node: tensor-free only if every positional arg is. Ints are
        # primitives, Nodes (also inside list args) are checked recursively, and
        # any other arg value is conservatively treated as a tensor. kwargs are
        # not inspected. With no positional args the loop body never runs and
        # result keeps its initial False.
        found_one_tensor = False
        for arg in node.args:
            if isinstance(arg, list):
                for list_el in arg:
                    if isinstance(list_el, Node):
                        this_list_el_args_have_no_tensors = (
                            all_node_args_have_no_tensors(list_el, modules, cache)
                        )
                        found_one_tensor = found_one_tensor or (
                            not this_list_el_args_have_no_tensors
                        )
                        # If found_one_tensor is True, there is no point in
                        # recursing further: it stays True, so the function
                        # result will always be False.
                        # TODO(future PR): remove this entire function and
                        # change to dtype inference without recursion.
                        if found_one_tensor:
                            result = not found_one_tensor
                            if cache:
                                cache[node] = result
                            return result
            elif isinstance(arg, int):
                pass
            else:
                if isinstance(arg, Node):
                    this_arg_args_have_no_tensors = all_node_args_have_no_tensors(
                        arg, modules, cache
                    )
                    found_one_tensor = found_one_tensor or (
                        not this_arg_args_have_no_tensors
                    )
                    # If found_one_tensor is True, there is no point in
                    # recursing further: it stays True, so the function
                    # result will always be False.
                    # TODO(future PR): remove this entire function and
                    # change to dtype inference without recursion.
                    if found_one_tensor:
                        result = not found_one_tensor
                        if cache:
                            cache[node] = result
                        return result
                else:
                    # non-Node, non-int, non-list value (e.g. float, None)
                    found_one_tensor = True
            result = not found_one_tensor
    if cache:
        cache[node] = result
    return result
  327. def all_node_args_except_first(node: Node) -> list[int]:
  328. """
  329. Returns all node arg indices after first
  330. """
  331. return list(range(1, len(node.args)))
  332. def return_arg_list(arg_indices: list[int]) -> Callable[[Node], list[int]]:
  333. """
  334. Constructs a function that takes a node as arg and returns the arg_indices
  335. that are valid for node.args
  336. """
  337. def arg_indices_func(node: Node) -> list[int]:
  338. return [i for i in arg_indices if i < len(node.args)]
  339. return arg_indices_func
# (op, target) pair used as the lookup key of NON_OBSERVABLE_ARG_DICT.
NodeInfo = namedtuple("NodeInfo", "op target")

# this dict identifies which indices of a node are non tensors
# so that they can be propagated correctly since inserting observers
# for them would cause errors
#
# Each value maps an argument type (``int``, ``float`` or ``torch.bool`` below)
# to a function that takes the node and returns the indices of its args that
# have that type.
#
# NOTE(review): the ``torch.transpose`` and ``torch.unsqueeze`` entries use op
# "call_method" with a function object as target, while every other
# "call_method" key here uses a method-name string; these two keys look like
# they can never match -- confirm whether "call_function" was intended.
NON_OBSERVABLE_ARG_DICT: dict[
    NodeInfo, dict[Union[type, torch.dtype], Callable[[Node], list[int]]]
] = {
    NodeInfo("call_method", "masked_fill"): {
        torch.bool: return_arg_list([1]),
        float: return_arg_list([2]),
    },
    NodeInfo("call_method", "permute"): {int: all_node_args_except_first},
    NodeInfo("call_method", "repeat"): {int: all_node_args_except_first},
    NodeInfo("call_method", "reshape"): {int: all_node_args_except_first},
    NodeInfo("call_method", "size"): {int: return_arg_list([1])},
    NodeInfo("call_method", "transpose"): {int: all_node_args_except_first},
    NodeInfo("call_method", torch.transpose): {int: all_node_args_except_first},
    NodeInfo("call_method", "unsqueeze"): {int: return_arg_list([1])},
    NodeInfo("call_method", "unsqueeze_"): {int: return_arg_list([1])},
    NodeInfo("call_method", torch.unsqueeze): {int: return_arg_list([1])},
    NodeInfo("call_method", "view"): {int: all_node_args_except_first},
}

# Shared empty mapping used as the fallback for nodes with no entry above.
EMPTY_ARG_DICT: dict[Union[type, torch.dtype], Callable[[Node], list[int]]] = {}
  363. def get_non_observable_arg_indexes_and_types(
  364. node: Node,
  365. ) -> dict[Union[type, torch.dtype], Callable[[Node], list[int]]]:
  366. """
  367. Returns a dict with of non float tensor types as keys and values which correspond to a
  368. function to retrieve the list (which takes the node as an argument)
  369. """
  370. info = NodeInfo(node.op, node.target)
  371. return NON_OBSERVABLE_ARG_DICT.get(info, EMPTY_ARG_DICT)
  372. def maybe_get_next_module(
  373. node: Node,
  374. modules: dict[str, nn.Module],
  375. target_module_type: Optional[type[nn.Module]] = None,
  376. target_functional_type: Any = None,
  377. ) -> Optional[Node]:
  378. """Gets the next module that matches what is needed in
  379. is_target_module_type if it exists
  380. Args:
  381. node: The node whose users we want to look at
  382. target_module_type: Module type that we want to check
  383. target_functional_type: Functional type that we want to check
  384. """
  385. for user in node.users.keys():
  386. if (
  387. user.op == "call_module"
  388. and target_module_type is not None
  389. and isinstance(modules[str(user.target)], target_module_type)
  390. ):
  391. return user
  392. elif (
  393. user.op == "call_function"
  394. and target_functional_type is not None
  395. and user.target == target_functional_type
  396. ):
  397. return user
  398. return None
  399. def create_node_from_old_node_preserve_meta(
  400. quantized_graph: Graph,
  401. create_node_args: tuple[Any, ...],
  402. old_node: Node,
  403. ) -> Node:
  404. """
  405. Creates `new_node` and copies the necessary metadata to it from `old_node`.
  406. """
  407. new_node = quantized_graph.create_node(*create_node_args)
  408. new_node.stack_trace = old_node.stack_trace
  409. return new_node
  410. def get_skipped_module_name_and_classes(
  411. prepare_custom_config: PrepareCustomConfig, is_standalone_module: bool
  412. ) -> tuple[list[str], list[type[Any]]]:
  413. skipped_module_names = copy.copy(prepare_custom_config.non_traceable_module_names)
  414. skipped_module_classes = copy.copy(
  415. prepare_custom_config.non_traceable_module_classes
  416. )
  417. if not is_standalone_module:
  418. # standalone module and custom module config are applied in top level module
  419. skipped_module_names += list(
  420. prepare_custom_config.standalone_module_names.keys()
  421. )
  422. skipped_module_classes += list(
  423. prepare_custom_config.standalone_module_classes.keys()
  424. )
  425. skipped_module_classes += get_custom_module_class_keys(
  426. prepare_custom_config.float_to_observed_mapping
  427. )
  428. return skipped_module_names, skipped_module_classes
  429. def _is_custom_module_lstm(
  430. node: Node,
  431. named_modules: dict[str, torch.nn.Module],
  432. qconfig: QConfigAny = None,
  433. # QuantizeHandler, but we cannot include the type here due to circular imports
  434. qhandler: Optional[Any] = None,
  435. ) -> bool:
  436. """
  437. Return whether this refers to the custom module LSTM flow.
  438. """
  439. mod = _get_module(node, named_modules)
  440. if qconfig is not None and qhandler is not None:
  441. assert isinstance(
  442. qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
  443. ) # type: ignore[attr-defined]
  444. return (
  445. isinstance(mod, torch.nn.LSTM)
  446. and activation_is_statically_quantized(qconfig)
  447. and qhandler.is_custom_module()
  448. )
  449. else:
  450. return isinstance(mod, torch.ao.nn.quantizable.LSTM)
  451. def _is_custom_module_mha(
  452. node: Node,
  453. named_modules: dict[str, torch.nn.Module],
  454. qconfig: QConfigAny = None,
  455. # QuantizeHandler, but we cannot include the type here due to circular imports
  456. qhandler: Optional[Any] = None,
  457. ) -> bool:
  458. """
  459. Return whether this refers to the custom module MultiheadAttention flow.
  460. """
  461. mod = _get_module(node, named_modules)
  462. if qconfig is not None and qhandler is not None:
  463. assert isinstance(
  464. qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
  465. ) # type: ignore[attr-defined]
  466. return (
  467. isinstance(mod, torch.nn.MultiheadAttention)
  468. and activation_is_statically_quantized(qconfig)
  469. and qhandler.is_custom_module()
  470. )
  471. else:
  472. return isinstance(mod, torch.ao.nn.quantizable.MultiheadAttention)
  473. def _get_module(
  474. node: Node, named_modules: dict[str, torch.nn.Module]
  475. ) -> Optional[torch.nn.Module]:
  476. """
  477. If `node` refers to a call_module node, return the module, else None.
  478. """
  479. if node.op == "call_module" and str(node.target) in named_modules:
  480. return named_modules[str(node.target)]
  481. else:
  482. return None
  483. def _insert_dequant_stub(
  484. node: Node,
  485. model: torch.nn.Module,
  486. named_modules: dict[str, torch.nn.Module],
  487. graph: Graph,
  488. ) -> Node:
  489. """
  490. Attach a `DeQuantStub` to the model and create a node that calls this
  491. `DeQuantStub` on the output of `node`, similar to how observers are inserted.
  492. """
  493. prefix = "dequant_stub_"
  494. get_new_dequant_stub_name = get_new_attr_name_with_prefix(prefix)
  495. dequant_stub_name = get_new_dequant_stub_name(model)
  496. dequant_stub = DeQuantStub()
  497. setattr(model, dequant_stub_name, dequant_stub)
  498. named_modules[dequant_stub_name] = dequant_stub
  499. with graph.inserting_after(node):
  500. return graph.call_module(dequant_stub_name, (node,))
def _insert_dequant_stubs_for_custom_module_lstm_output(
    node: Node,
    model: torch.nn.Module,
    named_modules: dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Insert DeQuantStubs after each internal output node of custom module LSTM.

    Custom module LSTM outputs are nested tuples of the structure (output, (hidden0, hidden1)).
    Since we cannot dequantize a tuple as a whole, we must first break down the tuple into its
    components through `getitem`. This function transforms the graph as follows:

      (1) Split the LSTM node into (output, (hidden0, hidden1))
      (2) Insert a DeQuantStub after each internal node
      (3) Recombine the DeQuantStubs into the same structure as before
      (4) Reroute all consumers of the original LSTM node and its sub-nodes
          (e.g. lstm[0])

    Before:

        lstm_output
             |
             v
        original_user(s)

    After:

                    lstm_output
                   /           \\
                  /  (getitem)  \\
                 /               \\
                v                 v
             output             hidden
               |               /      \\
         (DeQuantStub)        (getitem)
               |             /          \\
               v            v            v
           output_dq     hidden0       hidden1
               |            |             |
               |      (DeQuantStub)  (DeQuantStub)
               |            |             |
               |            v             v
               |       hidden0_dq     hidden1_dq
               |             \\          /
               |               (tuple)
               |                \\    /
               |                 v  v
               |               hidden_dq
                \\                 /
                 \\    (tuple)    /
                  v              v
                   lstm_output_dq
                         |
                         v
                  original_user(s)

    For step (4), reroute all users of the original LSTM node(s) as follows:

      lstm_output -> lstm_output_dq
      lstm_output[0] -> output_dq
      lstm_output[1] -> hidden_dq
      lstm_output[1][0] -> hidden0_dq
      lstm_output[1][1] -> hidden1_dq

    Direct users of `node` are rewired explicitly below. The indexed cases are
    then handled by `_reroute_tuple_getitem_pattern`.

    Args:
        node: the custom module LSTM node whose output is split.
        model: module that receives the new DeQuantStub submodules.
        named_modules: module path -> module; updated with the new stubs.
        graph: graph being modified in place.

    Return the node `lstm_output_dq`.
    """
    # (1) Split the LSTM node into (output, (hidden0, hidden1))
    # (2) Insert a DeQuantStub after each internal node
    # Each `inserting_after` anchors on the node created just before, so the new
    # nodes appear in the graph in the order they are written here.
    with graph.inserting_after(node):
        output = graph.call_function(operator.getitem, (node, 0))
        output_dq = _insert_dequant_stub(output, model, named_modules, graph)
    with graph.inserting_after(output_dq):
        hidden = graph.call_function(operator.getitem, (node, 1))
    with graph.inserting_after(hidden):
        hidden0 = graph.call_function(operator.getitem, (hidden, 0))
        hidden0_dq = _insert_dequant_stub(hidden0, model, named_modules, graph)
    with graph.inserting_after(hidden0_dq):
        hidden1 = graph.call_function(operator.getitem, (hidden, 1))
        hidden1_dq = _insert_dequant_stub(hidden1, model, named_modules, graph)

    # (3) Recombine the DeQuantStubs into the same structure as before
    with graph.inserting_after(hidden1_dq):
        hidden_dq = graph.call_function(tuple, ([hidden0_dq, hidden1_dq],))
    with graph.inserting_after(hidden_dq):
        lstm_output_dq = graph.call_function(tuple, ([output_dq, hidden_dq],))

    # (4) Reroute all consumers of the original LSTM node and its sub-nodes
    # (the two getitem nodes created above must keep reading from `node`)
    for user in list(node.users.keys()):
        if user != output and user != hidden:
            user.replace_input_with(node, lstm_output_dq)
    # The getitem and tuple nodes we added here may interfere with reference quantized
    # pattern matching, so we need to redirect the consumers of internal nodes to the
    # corresponding nodes with DeQuantStubs (e.g. lstm_output_dq[0] -> output_dq) attached,
    # in order to preserve reference patterns like "dequantize - consumer - quantize".
    _reroute_tuple_getitem_pattern(graph)
    return lstm_output_dq
  587. def _maybe_get_custom_module_lstm_from_node_arg(
  588. arg: Node,
  589. named_modules: dict[str, torch.nn.Module],
  590. ) -> Optional[Node]:
  591. """
  592. Given an argument of a node, if the argument refers to the path through which the node
  593. is a consumer of custom module LSTM, return the custom module LSTM node, or None otherwise.
  594. This is used to determine whether a node is a consumer of custom module LSTM, and, if so,
  595. skip inserting input observers for this node. This is because custom module LSTM produces
  596. quantized outputs, so inserting an input observer for the consumer of custom module LSTM
  597. would unnecessarily quantize the outputs again.
  598. lstm -> consumer
  599. In practice, however, custom module LSTM outputs a tuple (output, (hidden0, hidden1)) with
  600. DeQuantStubs attached to each internal node (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
  601. This tuple can be consumed in one of four ways:
  602. lstm -> getitem -> DeQuantStub -> consumer # consume lstm[0]
  603. lstm -> getitem -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm[1]
  604. lstm -> getitem -> getitem -> DeQuantStub -> consumer # consume lstm[1][0] or lstm[1][1]
  605. lstm -> getitem -> DeQuantStub -> tuple -> consumer # consume lstm
  606. Thus, we must match against the above patterns instead of simply checking the parent node
  607. to determine whether this node is a consumer of a custom module LSTM.
  608. """
  609. def match_dq(a):
  610. return isinstance(_get_module(a, named_modules), DeQuantStub)
  611. def match_lstm(a):
  612. return _is_custom_module_lstm(a, named_modules)
  613. def match_getitem(a):
  614. return a.op == "call_function" and a.target == operator.getitem
  615. def match_tuple(a):
  616. return a.op == "call_function" and a.target == tuple
  617. def _match_pattern(match_pattern: list[Callable]) -> Optional[Node]:
  618. """
  619. Traverse up the graph and match the args one by one.
  620. If there is a match, return the last matched node, or None otherwise.
  621. """
  622. a = arg
  623. for i, match in enumerate(match_pattern):
  624. if not match(a):
  625. return None
  626. # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],)
  627. if i < len(match_pattern) - 1:
  628. if match == match_tuple:
  629. a = a.args[0][0] # type: ignore[assignment,index]
  630. else:
  631. a = a.args[0] # type: ignore[assignment]
  632. return a
  633. all_match_patterns = [
  634. [match_dq, match_getitem, match_lstm],
  635. [match_tuple, match_dq, match_getitem, match_getitem, match_lstm],
  636. [match_dq, match_getitem, match_getitem, match_lstm],
  637. [match_tuple, match_dq, match_getitem, match_lstm],
  638. ]
  639. for p in all_match_patterns:
  640. matched_node = _match_pattern(p)
  641. if matched_node is not None:
  642. return matched_node
  643. return None
def _reroute_tuple_getitem_pattern(graph: Graph):
    """
    Search for patterns where N consecutive `tuple` call_function nodes are followed by
    N consecutive `getitem` call_function nodes that are "reverses" of the `tuple` nodes.
    If we find this pattern, reroute the consumers of the last `getitem` to skip these
    N `tuple` and `getitem` nodes.

    Before:

        a   b     c
        |   \\   /
        \\   tuple
         \\   /
          tuple
            |
        getitem(1)
            |
        getitem(0)
            |
            d

    After:

        b
        |
        d

    The graph is modified in place. Only the consumers' inputs change; the
    `tuple`/`getitem` nodes themselves are not removed here.
    """

    def find_patterns(
        node: Node,
        index_stack: list[int],
        current_pattern: list[Node],
        matched_patterns: list[list[Node]],
        seen: set[tuple[Node, tuple[int, ...]]],
    ):
        """
        Traverse the graph recursively to match for the N-tuple - N-getitem patterns,
        starting at the given node.

        We use a stack to keep track of the expected `getitem` indices, since these are
        reversed from the `tuple` indices. In the above example, the stack after
        (b -> tuple -> tuple) will be [0, 1], which will be popped by getitem(1) first
        and then by getitem(0).

        NOTE(review): `index_stack` and `current_pattern` are shared lists mutated in
        place and not restored after a recursive call returns, so state from one
        user's branch carries over into the next sibling user -- confirm intended.

        TODO: traverse upwards from the output and handle the case when tuple is not a
        separate node, e.g. graph.call_function(operator.getitem, args=(a, (b, c)))
        """
        # An empty stack with a non-empty pattern means every pushed `tuple` index
        # has been consumed by a matching `getitem`: record a copy of the match and
        # start a new pattern.
        if len(index_stack) == 0 and len(current_pattern) > 0:
            matched_patterns.append(copy.copy(current_pattern))
            current_pattern.clear()

        # Avoid duplicating work
        state = (node, tuple(index_stack))
        if state in seen:
            return
        seen.add(state)

        # Iterate through users of this node to find tuple/getitem nodes to match
        for user in node.users:
            if user.op == "call_function" and user.target == tuple:
                # Push the position of `node` inside the tuple's list argument.
                for i, user_arg in enumerate(user.args[0]):  # type: ignore[arg-type]
                    if user_arg == node:
                        index_stack.append(i)
                        current_pattern.append(user)
                        find_patterns(
                            user, index_stack, current_pattern, matched_patterns, seen
                        )
            elif user.op == "call_function" and user.target == operator.getitem:
                # Only a getitem whose index equals the most recently pushed tuple
                # position continues the pattern.
                if len(index_stack) > 0:
                    if user.args[1] == index_stack[-1]:
                        index_stack.pop()
                        current_pattern.append(user)
                        find_patterns(
                            user, index_stack, current_pattern, matched_patterns, seen
                        )
        return matched_patterns

    # Collect all matched patterns
    matched_patterns: list[list[Node]] = []
    seen: set[tuple[Node, tuple[int, ...]]] = set()  # (node, index_stack)
    for node in graph.nodes:
        find_patterns(node, [], [], matched_patterns, seen)

    # For each pattern, redirect all consumers of the last getitem node to the correct input
    # of the first tuple node
    for pattern in matched_patterns:
        first_tuple = pattern[0]
        last_getitem = pattern[-1]
        assert first_tuple.op == "call_function" and first_tuple.target == tuple
        assert (
            last_getitem.op == "call_function"
            and last_getitem.target == operator.getitem
        )
        # In a well-formed match the last getitem pops the first index pushed,
        # i.e. the position of the pattern's starting node inside the first
        # tuple, so `new_input` is that starting node.
        last_getitem_index = last_getitem.args[1]
        new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]
        for user in list(last_getitem.users.keys()):
            user.replace_input_with(last_getitem, new_input)  # type: ignore[arg-type]
  730. def _get_observer_from_activation_post_process(
  731. activation_post_process: Union[ObserverBase, FakeQuantizeBase],
  732. ) -> ObserverBase:
  733. """
  734. If `activation_post_process` is an observer, return the observer.
  735. If `activation_post_process` is a fake quantize, return the internal observer.
  736. """
  737. if isinstance(activation_post_process, ObserverBase):
  738. return activation_post_process
  739. else:
  740. assert isinstance(activation_post_process, FakeQuantizeBase)
  741. return activation_post_process.activation_post_process # type: ignore[return-value]
  742. def _qconfig_satisfies_dtype_config_constraints(
  743. qconfig: QConfigAny,
  744. dtype_with_constraints: DTypeWithConstraints,
  745. is_activation: bool = True,
  746. ) -> bool:
  747. """
  748. Return whether `qconfig` satisfies the following constraints from the backend,
  749. specified through the activation and weight DTypeWithConstraints.
  750. 1. QConfig specified a quantization range that falls within the backend's, if any
  751. 2. QConfig specified a min scale value that is >= the backend's, if any
  752. 3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has
  753. scale and zero point that match the backend's, if any
  754. If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`.
  755. If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True.
  756. """
  757. # TODO: log warnings only when the user enabled a debug flag
  758. def _activation_post_process_satisfies_dtype_config_constraints(
  759. activation_post_process: Union[ObserverBase, FakeQuantizeBase],
  760. dtype_with_constraints: DTypeWithConstraints,
  761. debug_string: str,
  762. ) -> bool:
  763. observer = _get_observer_from_activation_post_process(activation_post_process)
  764. app_quant_min = getattr(observer, "quant_min", None)
  765. app_quant_max = getattr(observer, "quant_max", None)
  766. # TODO: for now, just use the existing eps value as scale_min. In the future, we should
  767. # resolve the differences between the two, either by renaming eps or some other way
  768. app_scale_min = getattr(observer, "eps", None)
  769. backend_quant_min = dtype_with_constraints.quant_min_lower_bound
  770. backend_quant_max = dtype_with_constraints.quant_max_upper_bound
  771. backend_scale_min = dtype_with_constraints.scale_min_lower_bound
  772. backend_scale_exact_match = dtype_with_constraints.scale_exact_match
  773. backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match
  774. # check quantization ranges
  775. if backend_quant_min is not None and backend_quant_max is not None:
  776. if app_quant_min is None or app_quant_max is None:
  777. warnings.warn(
  778. f"QConfig {debug_string} must specify 'quant_min' and 'quant_max', ignoring {qconfig}"
  779. )
  780. return False
  781. elif app_quant_min < backend_quant_min or app_quant_max > backend_quant_max:
  782. warnings.warn(
  783. f"QConfig {debug_string} quantization range must fall within the backend's:\n"
  784. f"QConfig range = ({app_quant_min}, {app_quant_max}), "
  785. f"BackendConfig range = ({backend_quant_min}, {backend_quant_max}), "
  786. f"ignoring {qconfig}"
  787. )
  788. return False
  789. # check scale min
  790. if backend_scale_min is not None:
  791. if app_scale_min is None:
  792. warnings.warn(
  793. f"QConfig {debug_string} must specify 'eps', ignoring {qconfig}"
  794. )
  795. return False
  796. if app_scale_min < backend_scale_min:
  797. warnings.warn(
  798. f"QConfig {debug_string} eps ({app_scale_min}) must be greater than or equal to "
  799. f"the backend's min scale value ({backend_scale_min}), ignoring {qconfig}"
  800. )
  801. return False
  802. # check fixed scale and zero point
  803. if (
  804. backend_scale_exact_match is not None
  805. and backend_zero_point_exact_match is not None
  806. ):
  807. # For tests only, accept the following qconfigs for now
  808. # TODO: handle fp16 qconfigs properly
  809. for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]:
  810. if qconfig_equals(qconfig, accepted_qconfig):
  811. return True
  812. suggestion_str = (
  813. "Please use torch.ao.quantization.get_default_qconfig_mapping or "
  814. "torch.ao.quantization.get_default_qat_qconfig_mapping. Example:\n"
  815. ' qconfig_mapping = get_default_qconfig_mapping("fbgemm")\n'
  816. " model = prepare_fx(model, qconfig_mapping, example_inputs)"
  817. )
  818. if not isinstance(
  819. activation_post_process, FixedQParamsObserver
  820. ) and not isinstance(activation_post_process, FixedQParamsFakeQuantize):
  821. warnings.warn(
  822. f"QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize "
  823. f"for fixed qparams ops, ignoring {qconfig}.\n{suggestion_str}"
  824. )
  825. return False
  826. if (
  827. observer.scale != backend_scale_exact_match
  828. or observer.zero_point != backend_zero_point_exact_match
  829. ):
  830. warnings.warn(
  831. f"QConfig fixed scale ({observer.scale}) and zero point ({observer.zero_point}) "
  832. f"do not match the backend's ({backend_scale_exact_match} and {backend_zero_point_exact_match}), "
  833. f"ignoring {qconfig}.\n{suggestion_str}"
  834. )
  835. return False
  836. return True
  837. if qconfig is None or dtype_with_constraints.dtype is None:
  838. return True
  839. activation_post_process_ctr = (
  840. qconfig.activation if is_activation else qconfig.weight
  841. )
  842. debug_string = "activation" if is_activation else "weight"
  843. satisfies_constraints = True
  844. if activation_post_process_ctr is not None:
  845. activation_post_process = activation_post_process_ctr()
  846. assert _is_activation_post_process(activation_post_process)
  847. # If dtypes don't match, don't check the activation_post_process and return True early
  848. if activation_post_process.dtype != dtype_with_constraints.dtype:
  849. return True
  850. satisfies_constraints = (
  851. _activation_post_process_satisfies_dtype_config_constraints(
  852. activation_post_process, dtype_with_constraints, debug_string
  853. )
  854. )
  855. return satisfies_constraints