convert.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282
  1. # mypy: ignore-errors
  2. import copy
  3. import operator
  4. import warnings
  5. from typing import Any, Callable, Optional, Union
  6. import torch
  7. from torch.ao.quantization import CUSTOM_KEY, NUMERIC_DEBUG_HANDLE_KEY
  8. from torch.ao.quantization.backend_config import (
  9. BackendConfig,
  10. get_native_backend_config,
  11. )
  12. from torch.ao.quantization.backend_config.utils import (
  13. get_fused_module_classes,
  14. get_pattern_to_dtype_configs,
  15. get_qat_module_classes,
  16. get_root_module_to_quantized_reference_module,
  17. )
  18. from torch.ao.quantization.observer import _is_activation_post_process
  19. from torch.ao.quantization.qconfig import qconfig_equals, QConfigAny
  20. from torch.ao.quantization.qconfig_mapping import QConfigMapping
  21. from torch.ao.quantization.quant_type import QuantType
  22. from torch.ao.quantization.quantize import _remove_qconfig
  23. from torch.ao.quantization.stubs import DeQuantStub
  24. from torch.ao.quantization.utils import (
  25. _parent_name,
  26. activation_is_statically_quantized,
  27. get_qparam_dict,
  28. get_swapped_custom_module_class,
  29. is_per_channel,
  30. to_underlying_dtype,
  31. weight_is_quantized,
  32. )
  33. from torch.fx import GraphModule
  34. from torch.fx.graph import Argument, Graph, Node
  35. from torch.nn.utils.parametrize import type_before_parametrizations
  36. # importing the lib so that the quantized_decomposed ops are registered
  37. from ._decomposed import quantized_decomposed_lib # noqa: F401
  38. from ._equalize import convert_eq_obs, update_obs_for_equalization
  39. from .custom_config import ConvertCustomConfig, PrepareCustomConfig
  40. from .graph_module import _is_observed_module, _is_observed_standalone_module
  41. from .lower_to_fbgemm import lower_to_fbgemm
  42. from .qconfig_mapping_utils import (
  43. _compare_prepare_convert_qconfig_mappings,
  44. _generate_node_name_to_qconfig,
  45. _is_qconfig_supported_by_dtype_configs,
  46. _update_qconfig_for_fusion,
  47. _update_qconfig_for_qat,
  48. )
  49. from .utils import (
  50. _get_module,
  51. _is_custom_module_lstm,
  52. _is_custom_module_mha,
  53. assert_and_get_unique_device,
  54. collect_producer_nodes,
  55. create_getattr_from_value,
  56. get_custom_module_class_keys,
  57. graph_module_from_producer_nodes,
  58. node_arg_is_weight,
  59. )
# Public API of this module.
__all__ = [
    "convert",
    "convert_custom_module",
    "convert_standalone_module",
    "convert_weighted_module",
]
# Observer dtypes that are treated as convertible for *static* (non-dynamic)
# observers; e.g. `_is_conversion_supported` and the decomposed
# observer-replacement path check membership in this list.
SUPPORTED_QDTYPES = [
    torch.quint8,
    torch.qint8,
    torch.qint32,
    torch.uint8,
    torch.int8,
    torch.uint16,
    torch.int16,
    torch.int32,
    torch.float8_e5m2,
    torch.float8_e4m3fn,
]
# Maps an observer's qscheme to the decomposed `choose_qparams` op that is
# inserted for dynamic quantization. Only the two per-tensor schemes have an
# entry, so looking up any other qscheme raises KeyError.
_QSCHEME_TO_CHOOSE_QPARAMS_OP = {
    torch.per_tensor_affine: torch.ops.quantized_decomposed.choose_qparams.tensor,
    torch.per_tensor_symmetric: torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
}
  82. def _replace_observer_with_quantize_dequantize_node_decomposed(
  83. model: torch.fx.GraphModule,
  84. node: Node,
  85. modules: dict[str, torch.nn.Module],
  86. node_name_to_scope: dict[str, tuple[str, type]],
  87. node_name_to_qconfig: dict[str, QConfigAny],
  88. model_device: Optional[torch.device] = None,
  89. ) -> None:
  90. """Replace activation_post_process module call node with quantize and
  91. dequantize node working with decomposed Tensor
  92. Before:
  93. ... -> observer_0(x) -> ...
  94. After:
  95. ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) ->
  96. torch.ops.quantized_decomposed.dequantize_per_tensor() -> ...
  97. or quantize_per_channel and dequantize_per_channel
  98. """
  99. graph = model.graph
  100. assert modules is not None
  101. assert isinstance(node.target, str)
  102. module_path, prefix = _get_module_path_and_prefix(
  103. node, node_name_to_scope, node_name_to_qconfig
  104. )
  105. activation_post_process = modules[node.target]
  106. if hasattr(activation_post_process, "convert"):
  107. activation_post_process.convert(model, node)
  108. return
  109. # skip replacing observers to quant/dequant nodes if the qconfigs of all
  110. # consumers and producers of this observer are None
  111. skip_replacement = all(
  112. _has_none_qconfig(n, node_name_to_qconfig)
  113. for n in list(node.args) + list(node.users.keys())
  114. )
  115. if skip_replacement or not _is_conversion_supported(activation_post_process):
  116. # didn't find corresponding quantize op and info for the activation_post_process
  117. # so we just remove the observer
  118. with graph.inserting_before(node):
  119. node.replace_all_uses_with(node.args[0])
  120. graph.erase_node(node)
  121. return
  122. # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
  123. # 1. extract the information from activation_post_process module for generating
  124. # the quantize and dequantize operator
  125. dtype = activation_post_process.dtype # type: ignore[attr-defined]
  126. is_dynamic = False
  127. if hasattr(activation_post_process, "is_dynamic"):
  128. is_dynamic = activation_post_process.is_dynamic # type: ignore[assignment]
  129. def add_dequantize_op_kwargs(dequantize_op, input_node):
  130. dequantize_op_kwargs = {}
  131. if "val" in input_node.meta:
  132. dq_out_dtype = input_node.meta["val"].dtype
  133. if dq_out_dtype != torch.float32:
  134. dequantize_op_kwargs = {"out_dtype": dq_out_dtype}
  135. return dequantize_op_kwargs
  136. if dtype in SUPPORTED_QDTYPES and (not is_dynamic):
  137. # TODO: probably should cleanup this condition check, it's hard
  138. # to reason about this if and the following elif
  139. # uint8/int8/int32 static quantization branch
  140. # 1. extract information for inserting q/dq node from activation_post_process
  141. node_type = "call_function"
  142. quantize_op: Optional[Callable] = None
  143. scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator]
  144. if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined]
  145. ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type]
  146. quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
  147. dequantize_op = (
  148. torch.ops.quantized_decomposed.dequantize_per_channel.default
  149. )
  150. quant_min = activation_post_process.quant_min
  151. quant_max = activation_post_process.quant_max
  152. dtype_ = to_underlying_dtype(dtype)
  153. qparams = {
  154. "_scale_": scale,
  155. "_zero_point_": zero_point,
  156. "_axis_": ch_axis,
  157. "_quant_min_": quant_min,
  158. "_quant_max_": quant_max,
  159. "_dtype_": dtype_,
  160. }
  161. else:
  162. quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
  163. dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
  164. scale = float(scale)
  165. zero_point = int(zero_point)
  166. quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
  167. quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
  168. dtype_ = to_underlying_dtype(dtype)
  169. qparams = {
  170. "_scale_": scale,
  171. "_zero_point_": zero_point,
  172. "_quant_min_": quant_min,
  173. "_quant_max_": quant_max,
  174. "_dtype_": dtype_,
  175. }
  176. # 2. replace activation_post_process node with quantize and dequantize
  177. with graph.inserting_before(node):
  178. input_node = node.args[0]
  179. quantize_op_inputs = [input_node]
  180. for key, value_or_node in qparams.items():
  181. # TODO: we can add the information of whether a value needs to
  182. # be registered as an attribute in qparams dict itself
  183. if key in ["_scale_", "_zero_point_"] and (
  184. not isinstance(value_or_node, (float, int))
  185. ):
  186. # For scale and zero_point values we register them as buffers in the root module.
  187. # However, note that when the values are not tensors, as in the case of
  188. # per_tensor quantization, they will be treated as literals.
  189. # However, registering them as a node seems to cause issue with dynamo
  190. # tracing where it may consider tensor overload as opposed to default.
  191. # With extra check of scale and zero_point being scalar, it makes
  192. # sure that the default overload can be used.
  193. # TODO: maybe need more complex attr name here
  194. qparam_node = create_getattr_from_value(
  195. model,
  196. graph,
  197. module_path + prefix + key,
  198. value_or_node,
  199. model_device,
  200. )
  201. quantize_op_inputs.append(qparam_node)
  202. else:
  203. # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
  204. quantize_op_inputs.append(value_or_node)
  205. quantized_node = graph.create_node(
  206. node_type, quantize_op, tuple(quantize_op_inputs), {}
  207. )
  208. # use the same qparams from quantize op
  209. dq_inputs = [quantized_node] + quantize_op_inputs[1:]
  210. dequantized_node = graph.call_function(
  211. dequantize_op,
  212. tuple(dq_inputs),
  213. add_dequantize_op_kwargs(dequantize_op, input_node),
  214. )
  215. node.replace_all_uses_with(dequantized_node)
  216. # propagate numeric debug handle from observer/fake_quant node to dequantize node
  217. if (
  218. CUSTOM_KEY in node.meta
  219. and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
  220. ):
  221. if CUSTOM_KEY not in dequantized_node.meta:
  222. dequantized_node.meta[CUSTOM_KEY] = {}
  223. dequantized_node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
  224. CUSTOM_KEY
  225. ][NUMERIC_DEBUG_HANDLE_KEY]
  226. graph.erase_node(node)
  227. elif is_dynamic:
  228. # uint8/int8/fp16 dynamic quantization
  229. # 1. extract information for inserting q/dq node from activation_post_process
  230. node_type = "call_function"
  231. quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor
  232. # we only use choose_qparams for is_decomposed now,
  233. # but we should probably align the non-decomposed path with this as well,
  234. # and that can be done after we remove reduce_range flag
  235. # 1. extract qparams from activation_post_process module
  236. dtype_ = to_underlying_dtype(dtype)
  237. assert dtype_ in [torch.uint8, torch.int8], (
  238. "only uint8 and int8 are supported in reference flow for "
  239. "dynamic quantization right now"
  240. )
  241. quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
  242. quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
  243. qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined]
  244. eps = getattr(activation_post_process, "eps", torch.finfo(torch.float32).eps) # type: ignore[attr-defined]
  245. # note: scale and zero_point are missing for quantize_per_tensor op
  246. # we'll need to get this from choose_qparams op, which we'll add after
  247. # this step
  248. qparams = {
  249. "_quant_min_": quant_min,
  250. "_quant_max_": quant_max,
  251. "_eps_": eps,
  252. "_dtype_": dtype_,
  253. }
  254. choose_qparams_op = _QSCHEME_TO_CHOOSE_QPARAMS_OP[qscheme]
  255. # 2. insert choose_qparams op and update the qparams list
  256. with graph.inserting_before(node):
  257. input_node = node.args[0]
  258. choose_qparams_op_inputs = [node.args[0]]
  259. for key, value in qparams.items():
  260. # we have quant_min, quant_max and dtype, all should be stored
  261. # as literals
  262. choose_qparams_op_inputs.append(value)
  263. choose_qparams_node = graph.create_node(
  264. "call_function", choose_qparams_op, tuple(choose_qparams_op_inputs), {}
  265. )
  266. # choose_qparms returns (scale, zero_point)
  267. scale_node = graph.create_node(
  268. "call_function", operator.getitem, (choose_qparams_node, 0), {}
  269. )
  270. zero_point_node = graph.create_node(
  271. "call_function", operator.getitem, (choose_qparams_node, 1), {}
  272. )
  273. quant_min = qparams["_quant_min_"]
  274. quant_max = qparams["_quant_max_"]
  275. dtype = qparams["_dtype_"]
  276. qparams = {
  277. "_scale_": scale_node,
  278. "_zero_point_": zero_point_node,
  279. "_quant_min_": quant_min,
  280. "_quant_max_": quant_max,
  281. "_dtype_": dtype,
  282. }
  283. # 3. replace activation_post_process node to quantize and dequantize node
  284. with graph.inserting_before(node):
  285. input_node = node.args[0]
  286. quantize_op_inputs = [input_node]
  287. for key, value_or_node in qparams.items():
  288. # TODO: we can add the information of whether a value needs to
  289. # be registered as an attribute in qparams dict itself
  290. if key in ["_scale_", "_zero_point_"]:
  291. # in this case we have a node in the graph since it's dynamically
  292. # computed from the input, with choose_qparams op
  293. qparam_node = value_or_node
  294. quantize_op_inputs.append(qparam_node)
  295. else:
  296. # for qparams that are not scale/zero_point (like axis, dtype) we
  297. # store them as literals in the graph.
  298. quantize_op_inputs.append(value_or_node)
  299. quantized_node = graph.create_node(
  300. node_type, quantize_op, tuple(quantize_op_inputs), {}
  301. )
  302. # use the same qparams from quantize op
  303. dq_inputs = [quantized_node] + quantize_op_inputs[1:]
  304. # need to use the tensor variant of this op, since scale and zero_point
  305. # from choose_qparam are Tensors, instead of float/int, this is to
  306. # prevent these nodes being traced away by downstream systems
  307. dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor
  308. dequantized_node = graph.call_function(
  309. dequantize_op,
  310. tuple(dq_inputs),
  311. add_dequantize_op_kwargs(dequantize_op, input_node),
  312. )
  313. node.replace_all_uses_with(dequantized_node)
  314. # propagate numeric debug handle from observer/fake_quant node to dequantize node
  315. if NUMERIC_DEBUG_HANDLE_KEY in node.meta:
  316. dequantized_node.meta[NUMERIC_DEBUG_HANDLE_KEY] = node.meta[
  317. NUMERIC_DEBUG_HANDLE_KEY
  318. ]
  319. graph.erase_node(node)
  320. elif dtype == torch.float16:
  321. # Insert to_fp16 -> to_fp32 node
  322. dtype_convert_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse
  323. with graph.inserting_before(node):
  324. input_node = node.args[0]
  325. convert_fp16_node = graph.create_node(
  326. "call_function", dtype_convert_op, (input_node, torch.float16), {}
  327. )
  328. convert_fp32_node = graph.create_node(
  329. "call_function", dtype_convert_op, (convert_fp16_node, torch.float), {}
  330. )
  331. node.replace_all_uses_with(convert_fp32_node)
  332. graph.erase_node(node)
  333. # should not reach since we have checks in the beginning to make sure the
  334. # activation_post_process is supported
  335. def _replace_observer_with_quantize_dequantize_node(
  336. model: torch.fx.GraphModule,
  337. node: Node,
  338. modules: dict[str, torch.nn.Module],
  339. node_name_to_scope: dict[str, tuple[str, type]],
  340. node_name_to_qconfig: dict[str, QConfigAny],
  341. model_device: Optional[torch.device] = None,
  342. ) -> None:
  343. """Replace activation_post_process module call node with quantize and
  344. dequantize node
  345. Before:
  346. ... -> observer_0(x) -> ...
  347. After:
  348. ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
  349. """
  350. assert modules is not None
  351. assert isinstance(node.target, str)
  352. graph = model.graph
  353. module_path, prefix = _get_module_path_and_prefix(
  354. node, node_name_to_scope, node_name_to_qconfig
  355. )
  356. activation_post_process = modules[node.target]
  357. # skip replacing observers to quant/dequant nodes if the qconfigs of all
  358. # consumers and producers of this observer are None
  359. skip_replacement = all(
  360. _has_none_qconfig(n, node_name_to_qconfig)
  361. for n in list(node.args) + list(node.users.keys())
  362. )
  363. if skip_replacement or not _is_conversion_supported(activation_post_process):
  364. # didn't find corresponding quantize op and info for the activation_post_process
  365. # so we just remove the observer
  366. with graph.inserting_before(node):
  367. node.replace_all_uses_with(node.args[0])
  368. graph.erase_node(node)
  369. return
  370. # otherwise, we can convert the activation_post_process module call to quantize/dequantize node
  371. dtype = activation_post_process.dtype # type: ignore[attr-defined]
  372. is_dynamic = False
  373. if hasattr(activation_post_process, "is_dynamic"):
  374. is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
  375. if dtype in [
  376. torch.quint8,
  377. torch.qint8,
  378. torch.qint32,
  379. torch.float8_e5m2,
  380. torch.float8_e4m3fn,
  381. ] and (not is_dynamic):
  382. # TODO: probably should cleanup this condition check, it's hard
  383. # to reason about this if and the following elif
  384. # uint8/int8/int32 static quantization branch
  385. # 1. extract the information from activation_post_process module for generating
  386. # the quantize and dequantize operator
  387. node_type = "call_function"
  388. quantize_op: Optional[Callable] = None
  389. scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator]
  390. if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined]
  391. ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined, arg-type]
  392. qparams = {
  393. "_scale_": scale,
  394. "_zero_point_": zero_point,
  395. "_axis_": ch_axis,
  396. "_dtype_": dtype,
  397. }
  398. quantize_op = torch.quantize_per_channel
  399. else:
  400. scale = float(scale)
  401. zero_point = int(zero_point)
  402. qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype}
  403. quantize_op = torch.quantize_per_tensor
  404. # 2. replace activation_post_process node with quantize and dequantize
  405. with graph.inserting_before(node):
  406. input_node = node.args[0]
  407. quantize_op_inputs = [input_node]
  408. for key, value_or_node in qparams.items():
  409. # TODO: we can add the information of whether a value needs to
  410. # be registered as an attribute in qparams dict itself
  411. if key in ["_scale_", "_zero_point_"]:
  412. # For scale and zero_point values we register them as buffers in the root module.
  413. # TODO: maybe need more complex attr name here
  414. qparam_node = create_getattr_from_value(
  415. model,
  416. graph,
  417. module_path + prefix + key,
  418. value_or_node,
  419. model_device,
  420. )
  421. quantize_op_inputs.append(qparam_node)
  422. else:
  423. # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph.
  424. quantize_op_inputs.append(value_or_node)
  425. quantized_node = graph.create_node(
  426. node_type, quantize_op, tuple(quantize_op_inputs), {}
  427. )
  428. dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
  429. node.replace_all_uses_with(dequantized_node)
  430. graph.erase_node(node)
  431. elif is_dynamic:
  432. # uint8/int8/fp16 dynamic quantization branch
  433. node_type = "call_function"
  434. quantize_op = torch.quantize_per_tensor_dynamic
  435. # TODO: get reduce range from observer
  436. # reduce_range = activation_post_process.reduce_range
  437. reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86")
  438. qparams = {"_dtype_": dtype, "_reduce_range_": reduce_range}
  439. with graph.inserting_before(node):
  440. input_node = node.args[0]
  441. quantize_op_inputs = [input_node]
  442. for key, value in qparams.items():
  443. quantize_op_inputs.append(value)
  444. quantized_node = graph.create_node(
  445. node_type, quantize_op, tuple(quantize_op_inputs), {}
  446. )
  447. dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
  448. node.replace_all_uses_with(dequantized_node)
  449. graph.erase_node(node)
  450. elif dtype == torch.float16:
  451. node_type = "call_method"
  452. quantize_op = "to" # type: ignore[assignment]
  453. qparams = {"_dtype_": dtype}
  454. with graph.inserting_before(node):
  455. input_node = node.args[0]
  456. quantize_op_inputs = [input_node]
  457. for key, value in qparams.items():
  458. # TODO: we can add the information of whether a value needs to
  459. # be registered as an attribute in qparams dict itself
  460. quantize_op_inputs.append(value)
  461. quantized_node = graph.create_node(
  462. node_type, quantize_op, tuple(quantize_op_inputs), {}
  463. )
  464. dequantized_node = graph.call_method("dequantize", args=(quantized_node,))
  465. node.replace_all_uses_with(dequantized_node)
  466. graph.erase_node(node)
  467. # should not reach since we have checks in the beginning to make sure the
  468. # activation_post_process is supported
  469. # this is a temporary hack for custom module, we may want to implement
  470. # this properly after the custom module class design is finalized
  471. # TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted
  472. # after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs
  473. # after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively.
  474. def _replace_observer_or_dequant_stub_with_dequantize_node(
  475. node: Node, graph: Graph
  476. ) -> None:
  477. call_custom_module_node = node.args[0]
  478. assert isinstance(call_custom_module_node, Node), (
  479. f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
  480. )
  481. node.replace_all_uses_with(call_custom_module_node)
  482. graph.erase_node(node)
  483. _insert_dequantize_node(call_custom_module_node, graph)
  484. def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool:
  485. dtype = activation_post_process.dtype # type: ignore[attr-defined]
  486. is_dynamic = False
  487. if hasattr(activation_post_process, "is_dynamic"):
  488. is_dynamic = activation_post_process.is_dynamic # type: ignore[attr-defined, assignment]
  489. return (
  490. (dtype in SUPPORTED_QDTYPES and (not is_dynamic))
  491. or is_dynamic # type: ignore[return-value]
  492. or dtype == torch.float16
  493. )
  494. def _has_none_qconfig(
  495. node: Argument, node_name_to_qconfig: dict[str, QConfigAny]
  496. ) -> bool:
  497. """Check if a node has a qconfig of None, i.e. user requested to not quantize
  498. the node
  499. """
  500. return (
  501. isinstance(node, Node)
  502. and node.name in node_name_to_qconfig
  503. and node_name_to_qconfig[node.name] is None
  504. )
  505. def _run_weight_observers(observed: GraphModule, backend_config: BackendConfig) -> None:
  506. """Extract the subgraph that produces the weight for dynamic quant
  507. or weight only quant node and run the subgraph to observe the weight.
  508. Note that the observers of dynamic quant or weight only quant ops are
  509. run during the convert step.
  510. """
  511. for node in observed.graph.nodes:
  512. if node.op != "call_function":
  513. continue
  514. for node_arg in node.args:
  515. # node_arg is weight
  516. if node_arg and node_arg_is_weight(node, node_arg):
  517. weight_observer_nodes = collect_producer_nodes(node_arg)
  518. if weight_observer_nodes is None:
  519. continue
  520. weight_observer_module = graph_module_from_producer_nodes(
  521. observed, weight_observer_nodes
  522. )
  523. # run the weight observer
  524. weight_observer_module()
  525. def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> None:
  526. """If the arg is a dequantize Node, or a list/tuple/dict of dequantize Node,
  527. we'll recursively remove the dequantize Node
  528. """
  529. if isinstance(arg, Node) and arg.op == "call_method" and arg.target == "dequantize":
  530. quantize_node = arg.args[0]
  531. # we only replace the specific use since dequantize could be used by other nodes
  532. # as well
  533. node.replace_input_with(arg, quantize_node)
  534. elif isinstance(arg, (list, tuple)):
  535. for arg_element in arg:
  536. _maybe_recursive_remove_dequantize(arg_element, node, graph)
  537. elif isinstance(arg, dict):
  538. for arg_element in arg.values():
  539. _maybe_recursive_remove_dequantize(arg_element, node, graph)
  540. else:
  541. warnings.warn(
  542. f"Unsupported node type in recursive remove dequantize: {type(arg)}"
  543. )
  544. def _get_module_path_and_prefix(
  545. obs_node: Node,
  546. node_name_to_scope: dict[str, tuple[str, type]],
  547. node_name_to_qconfig: dict[str, QConfigAny],
  548. ) -> tuple[str, str]:
  549. """Given and observer node, get the `Scope` or the fully qualified name for
  550. the submodule containing the observed node, also return a prefix of "_input"
  551. when the observed node is an input of a F.linear op, and not the output of another
  552. quantized op.
  553. TODO: this logic is hacky, we should think about how to remove it or make it more
  554. general
  555. """
  556. observed_node = obs_node.args[0]
  557. # an observer can be inserted for both input of the next operator or output of the previous
  558. # operator (they can be the same)
  559. # this flag identifies if the observer is inserted only because the observed node is
  560. # the input of the next operator
  561. assert isinstance(observed_node, Node), (
  562. f"Expecting observed node to be a Node, but got {observed_node}"
  563. )
  564. is_input_observer_only = (
  565. node_name_to_qconfig[observed_node.name] is None
  566. if observed_node.name in node_name_to_qconfig
  567. else None
  568. )
  569. if is_input_observer_only:
  570. # if the quantize function is at the input of op, then we find the first user of the observer_node
  571. # to get the path. If a linear call_function is in the user list, we return the first instance
  572. # of linear node to get the FQN.
  573. users = list(obs_node.users)
  574. first_linear_use_or_first_use = users[0] if users else None
  575. linear_node = None
  576. for n in users:
  577. if n.op == "call_function" and n.target == torch.nn.functional.linear:
  578. linear_node = n
  579. break
  580. if linear_node:
  581. first_linear_use_or_first_use = linear_node
  582. prefix = "_input"
  583. else:
  584. # if the quantize function is at the output of the op, we use the observer input node to get the path
  585. first_linear_use_or_first_use = observed_node
  586. prefix = ""
  587. if (
  588. first_linear_use_or_first_use
  589. and first_linear_use_or_first_use.name in node_name_to_scope
  590. ):
  591. module_path, _ = node_name_to_scope[first_linear_use_or_first_use.name]
  592. else:
  593. # TODO: it's not used, so actually we can skip quantization
  594. # but this requires changing return type of quantize_node
  595. # we can fix it later if needed
  596. module_path = ""
  597. return module_path, prefix
  598. def _insert_dequantize_node(node: Node, graph: Graph) -> None:
  599. """Inserts dequantize node for `node` in `graph`"""
  600. with graph.inserting_after(node):
  601. dequantize_node = graph.call_method("dequantize", (node,))
  602. for user_node in dict(node.users):
  603. if user_node is not dequantize_node:
  604. user_node.replace_input_with(node, dequantize_node)
  605. def _maybe_get_observer_for_node(
  606. node: Node, modules: dict[str, torch.nn.Module]
  607. ) -> Optional[torch.nn.Module]:
  608. """
  609. If the node is observed, return the observer
  610. instance. Otherwise, return None.
  611. """
  612. for maybe_obs_node in node.users.keys():
  613. if maybe_obs_node.op == "call_module":
  614. maybe_obs = modules[str(maybe_obs_node.target)]
  615. if _is_activation_post_process(maybe_obs):
  616. return maybe_obs
  617. return None
  618. def convert_standalone_module(
  619. node: Node,
  620. modules: dict[str, torch.nn.Module],
  621. model: torch.fx.GraphModule,
  622. is_reference: bool,
  623. backend_config: Optional[BackendConfig],
  624. ) -> None:
  625. """Converts a observed standalone module to a quantized standalone module by calling
  626. the fx convert api, currently using the same `is_reference` flag as parent, but we may
  627. changing this behavior in the future (e.g. separating quantization and lowering for
  628. standalone module as well)
  629. Args:
  630. - node: The call_module node of the observed standalone module
  631. - modules: named_module of original model
  632. - model: original model
  633. - is_reference: a flag from parent provided by user to decide if we want to
  634. produce a reference model or a fbgemm/qnnpack model
  635. - backend_config: backend configuration of the target backend of quantization
  636. """
  637. # TODO: remove is_reference flag
  638. if is_reference:
  639. convert_fn = torch.ao.quantization.quantize_fx.convert_to_reference_fx
  640. else:
  641. convert_fn = torch.ao.quantization.quantize_fx.convert_fx # type: ignore[attr-defined]
  642. # We know that observed standalone module is a GraphModule since
  643. # it's produced by us
  644. observed_standalone_module: GraphModule = modules[str(node.target)] # type: ignore[assignment]
  645. sm_input_quantized_idxs = observed_standalone_module.meta[
  646. "_observed_graph_module_attrs"
  647. ].standalone_module_input_quantized_idxs
  648. # remove the dequantize nodes for inputs
  649. args = list(node.args)
  650. for idx in range(len(args)):
  651. if idx in sm_input_quantized_idxs:
  652. arg = args[idx]
  653. if arg.op == "call_method" and arg.target == "dequantize": # type: ignore[union-attr]
  654. quantize_node = arg.args[0] # type: ignore[union-attr]
  655. node.replace_input_with(arg, quantize_node)
  656. if len(arg.users) == 0: # type: ignore[union-attr]
  657. model.graph.erase_node(arg)
  658. # add dequantize node for output
  659. sm_output_quantized_idxs = observed_standalone_module.meta[
  660. "_observed_graph_module_attrs"
  661. ].standalone_module_output_quantized_idxs
  662. if len(sm_output_quantized_idxs) > 0:
  663. assert sm_output_quantized_idxs[0] == 0, "Currently only quantized"
  664. "output idxs = [0] is supported"
  665. # if it's non-empty, then it means the output is kept in quantized form
  666. # we'll just add a dequantize node after this node
  667. _insert_dequantize_node(node, model.graph)
  668. # TODO: allow convert_custom_config to override backend_config
  669. # for standalone module
  670. quantized_standalone_module = convert_fn(
  671. observed_standalone_module, backend_config=backend_config
  672. )
  673. parent_name, name = _parent_name(node.target)
  674. # update the modules dict
  675. setattr(modules[parent_name], name, quantized_standalone_module)
  676. modules[str(node.target)] = quantized_standalone_module
def convert_weighted_module(
    node: Node,
    modules: dict[str, torch.nn.Module],
    observed_node_names: set[str],
    node_name_to_qconfig: dict[str, QConfigAny],
    backend_config: BackendConfig,
    is_decomposed: bool = False,
    is_reference: bool = False,
    model_device: Optional[torch.device] = None,
) -> None:
    """Convert a weighted module to reference quantized module in the model.

    If the module is one of the backend's QAT module classes, it is first swapped
    back to a float module via ``to_float()``. This swap happens before any of the
    skip checks below, so a QAT module whose QConfig is not set is still
    converted to a float module.

    The (float) module is then replaced by a reference quantized module only when
    every one of these holds:

    * the module's ``qconfig`` is not None;
    * ``_has_none_qconfig(node, node_name_to_qconfig)`` is False;
    * the node name is in ``observed_node_names``;
    * the qconfig passes ``_is_qconfig_supported_by_dtype_configs`` for the
      module type's dtype configs;
    * ``weight_is_quantized(qconfig)`` is True.

    For a fused module (``torch.ao.nn.intrinsic._FusedModule``) only its first
    child, ``fused_module[0]``, is replaced.

    Args:
        - node: The call_module node of the weighted (possibly QAT and/or fused) module
        - modules: named_module of original model; parent modules looked up here are
          mutated with setattr when a module is swapped
        - observed_node_names: names for the set of observed fx node, we can skip
          this conversion if the node is not observed
        - node_name_to_qconfig: node name -> QConfig mapping, passed to
          ``_has_none_qconfig``
        - backend_config: used to look up the QAT module classes, the dtype configs
          per module type, and the float -> reference quantized module mapping
        - is_decomposed: stored in the weight qparams under ``"is_decomposed"``; also
          used with ``is_reference`` to decide whether a QAT weight fake_quant is run
          again here
        - is_reference: see ``is_decomposed``
        - model_device: device a newly created (PTQ) weight observer is moved to; when
          None, the device is taken from ``assert_and_get_unique_device(float_module)``

    Raises:
        AssertionError: if the backend config has no reference quantized module class
            for the float module's type.
    """
    original_module = modules[str(node.target)]
    qconfig: QConfigAny = original_module.qconfig  # type: ignore[assignment]
    # Remains None for non-QAT (PTQ) modules; set to the QAT weight fake_quant below.
    weight_post_process = None
    qat_module_classes = get_qat_module_classes(backend_config)

    if isinstance(original_module, qat_module_classes):
        # Converting qat module to a float module, we need to attach
        # weight fake_quant to the module, weight fake_quant is assumed to be run during
        # QAT so we don't need to run it again here
        weight_post_process = original_module.weight_fake_quant
        original_module = original_module.to_float()  # type: ignore[operator]
        # change qat module to float module
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, original_module)

    is_observed = node.name in observed_node_names
    # If a qconfig is not defined for this node, then skip converting to a reference module
    if (
        qconfig is None
        or _has_none_qconfig(node, node_name_to_qconfig)
        or not is_observed
    ):
        return

    # skip converting to reference quantized module if the qconfig is not supported
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs = pattern_to_dtype_configs.get(type(original_module), [])
    if not _is_qconfig_supported_by_dtype_configs(qconfig, dtype_configs):
        return

    # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
    is_weight_quantized = weight_is_quantized(qconfig)

    # the condition for swapping the module to reference quantized module is:
    # weights need to be quantized
    if not is_weight_quantized:
        return

    fused_module = None
    float_module = original_module
    # extract the individual float_module and fused module
    if isinstance(original_module, torch.ao.nn.intrinsic._FusedModule):
        fused_module = float_module
        float_module = fused_module[0]  # type: ignore[index]

    # TODO: move this to the reference quantized module
    # weight_qparams or weight_qparams dict
    wq_or_wq_dict = {"is_decomposed": is_decomposed}
    if isinstance(float_module, torch.nn.RNNCellBase):
        # RNN cells: one fresh weight observer per weight tensor (input-hidden and
        # hidden-hidden), each run once so its qparams reflect that weight.
        weight_post_process_ih = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_hh = qconfig.weight()  # type: ignore[union-attr, operator]
        weight_post_process_ih(float_module.weight_ih)
        weight_post_process_hh(float_module.weight_hh)
        weight_qparams_ih = get_qparam_dict(weight_post_process_ih)
        weight_qparams_hh = get_qparam_dict(weight_post_process_hh)
        wq_or_wq_dict.update(
            {
                "weight_ih": weight_qparams_ih,
                "weight_hh": weight_qparams_hh,
            }
        )
    elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)):
        # format for wq_or_wq_dict (flattened attributes):
        # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...}
        for wn in float_module._flat_weights_names:
            if hasattr(float_module, wn) and wn.startswith("weight"):
                weight = getattr(float_module, wn)
                # A new observer per flat weight; note this rebinds
                # weight_post_process, which is not read again in this branch.
                weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
                # Only qint8 observers are run on the weight before reading qparams.
                if weight_post_process.dtype == torch.qint8:  # type: ignore[union-attr]
                    weight_post_process(weight)  # type: ignore[operator, misc]
                wq_or_wq_dict[wn] = get_qparam_dict(weight_post_process)
    else:
        # weight_post_process is None means the original module is not a QAT module
        # we need to get weight_post_process from qconfig in this case
        is_ptq = weight_post_process is None
        if is_ptq:
            weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
            if model_device is not None:
                device = model_device
            else:
                device = assert_and_get_unique_device(float_module)
            # device may be falsy (e.g. None); in that case the observer is not moved.
            if device:
                weight_post_process.to(device)

        # Call weight observer/fake_quant at least once to ensure the scales and zero points
        # have the right shapes. Note: there are two cases where we don't have to do this:
        #
        # (1) QAT: The model's forward method already calls the weight observer/fake_quant,
        #     and this typically happens during training, so we don't need to do it here.
        #
        # (2) Non-reference (lowered) case: The quantized module's from_float method already
        #     calls the weight observer/fake_quant, so we don't have to do it here.
        #
        # Currently we ignore both cases and call the weight observer/fake_quant here
        # regardless, which is technically incorrect. For (1), this is mainly to preserve BC
        # in test code, which may not always train before convert. In the future, we should
        # break BC for these two cases. See https://github.com/pytorch/pytorch/issues/73941.
        #
        # For PT2, however, we don't need to preserve BC here, so we can skip this hack
        # for QAT. We identify this case as (is_decomposed + is_reference + is_qat).
        # Note that we still need it for PTQ in the PT2 flow since the model's forward
        # method doesn't call the weight observer.
        is_qat = not is_ptq
        if not (is_decomposed and is_reference and is_qat):
            weight_post_process(float_module.weight)  # type: ignore[operator]

        wq_or_wq_dict.update(get_qparam_dict(weight_post_process))

    # We use the same reference module for all modes of quantization: static, dynamic, weight_only
    # root_module_to_quantized_reference_module: module mapping from root (floating point) module class
    # to quantized reference module class, e.g. nn.Conv2d to nn.quantized._reference.Conv2d
    root_module_to_quantized_reference_module = (
        get_root_module_to_quantized_reference_module(backend_config)
    )
    # type_before_parametrizations is used so parametrized (e.g. sparsified)
    # modules map by their original class.
    ref_qmodule_cls = root_module_to_quantized_reference_module.get(
        type_before_parametrizations(float_module), None
    )
    assert ref_qmodule_cls is not None, (
        f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
    )
    ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
    if fused_module is not None:
        # Replace only the first child of the fused module in place.
        fused_module[0] = ref_qmodule  # type: ignore[operator]
    else:
        parent_name, name = _parent_name(node.target)
        setattr(modules[parent_name], name, ref_qmodule)
  812. def _remove_previous_dequantize_in_custom_module(
  813. node: Node, prev_node: Node, graph: Graph
  814. ) -> None:
  815. """
  816. Given a custom module `node`, if the previous node is a dequantize, reroute the custom as follows:
  817. Before: quantize - dequantize - custom_module
  818. After: quantize - custom_module
  819. \\ - dequantize
  820. """
  821. # expecting the input node for a custom module node to be a Node
  822. assert isinstance(prev_node, Node), (
  823. f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
  824. )
  825. if prev_node.op == "call_method" and prev_node.target == "dequantize":
  826. node.replace_input_with(prev_node, prev_node.args[0])
  827. # Remove the dequantize node if it doesn't have other users
  828. if len(prev_node.users) == 0:
  829. graph.erase_node(prev_node)
  830. def convert_custom_module(
  831. node: Node,
  832. graph: Graph,
  833. modules: dict[str, torch.nn.Module],
  834. custom_module_class_mapping: dict[QuantType, dict[type, type]],
  835. statically_quantized_custom_module_nodes: set[Node],
  836. ) -> None:
  837. """Converts an observed custom module to a quantized custom module based on
  838. `custom_module_class_mapping`
  839. For static quantization, we'll also remove the previous `dequantize` node and
  840. attach the observer node for output to the module, the observer for the node
  841. will be converted to a dequantize node instead of quantize-dequantize pairs
  842. later in the graph. In the end we would have a quantized custom module that
  843. has the same interface as a default quantized module in nn.quantized namespace,
  844. i.e. quantized input and quantized output.
  845. Args:
  846. - node: The call_module node of the observed standalone module
  847. - graph: The graph containing the node
  848. - modules: named_module of original model
  849. - custom_module_class_mapping: mapping from observed custom module class to
  850. quantized custom module class, used to swap custom modules
  851. - statically_quantized_custom_module_nodes: we'll add the custom module node
  852. if we find it is statically quantized, this will be used later when converting
  853. observers to quant/dequant node pairs, if the observed node is a statically
  854. quantized custom module nodes, we'll convert the observer to a dequantize node,
  855. this is to keep the interface the same as the default quantized module.
  856. TODO: maybe we want to redesign this part to align with reference model design
  857. as well, but there has been some discussions around the interface, so we can do
  858. it later.
  859. """
  860. observed_custom_module = modules[str(node.target)]
  861. qconfig = observed_custom_module.qconfig
  862. if activation_is_statically_quantized(qconfig):
  863. statically_quantized_custom_module_nodes.add(node)
  864. if _is_custom_module_lstm(node, modules):
  865. # The inputs are tuples in the form (input, (hidden0, hidden1))
  866. # Ensure all three input nodes are quantized
  867. assert (
  868. len(node.args) == 2
  869. and isinstance(node.args[1], tuple)
  870. and len(node.args[1]) == 2
  871. )
  872. (inputs, (hidden0, hidden1)) = node.args # type: ignore[misc]
  873. assert isinstance(inputs, Node)
  874. assert isinstance(hidden0, Node)
  875. assert isinstance(hidden1, Node)
  876. _remove_previous_dequantize_in_custom_module(node, inputs, graph)
  877. _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
  878. _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
  879. elif _is_custom_module_mha(node, modules):
  880. # Inputs are in the form (query, key, value)
  881. # TODO: This is the first step in enabling the full fx custom module
  882. # quantization path for MultiheadAttention, and only covers the inputs
  883. # to the module.
  884. # Additional handling is yet to be implemented for the outputs, similar
  885. # to LSTM custom module
  886. assert len(node.args) == 3
  887. query, key, value = node.args
  888. assert isinstance(query, Node)
  889. assert isinstance(key, Node)
  890. assert isinstance(value, Node)
  891. _remove_previous_dequantize_in_custom_module(node, query, graph)
  892. _remove_previous_dequantize_in_custom_module(node, key, graph)
  893. _remove_previous_dequantize_in_custom_module(node, value, graph)
  894. else:
  895. # remove the previous dequant node to ensure the inputs are quantized
  896. arg = node.args[0]
  897. assert isinstance(arg, Node)
  898. _remove_previous_dequantize_in_custom_module(node, arg, graph)
  899. # absorb the following observer into the module conversion
  900. activation_post_process = _maybe_get_observer_for_node(node, modules)
  901. assert activation_post_process is not None
  902. observed_custom_module.activation_post_process = activation_post_process
  903. # swap the observed custom module to quantized custom module
  904. quantized_custom_module_class = get_swapped_custom_module_class(
  905. observed_custom_module, custom_module_class_mapping, qconfig
  906. )
  907. quantized_custom_module = quantized_custom_module_class.from_observed(
  908. observed_custom_module
  909. )
  910. parent_name, name = _parent_name(node.target)
  911. setattr(modules[parent_name], name, quantized_custom_module)
  912. def convert(
  913. model: GraphModule,
  914. is_reference: bool = False,
  915. convert_custom_config: Union[ConvertCustomConfig, dict[str, Any], None] = None,
  916. is_standalone_module: bool = False,
  917. _remove_qconfig_flag: bool = True,
  918. qconfig_mapping: Union[QConfigMapping, dict[str, Any], None] = None,
  919. backend_config: Union[BackendConfig, dict[str, Any], None] = None,
  920. is_decomposed: bool = False,
  921. keep_original_weights: bool = False,
  922. ) -> GraphModule:
  923. """
  924. We will convert an observed model (a module with observer calls) to a reference
  925. quantized model, the rule is simple:
  926. 1. for each observer module call in the graph, we'll convert it to calls to
  927. quantize and dequantize functions based on the observer instance
  928. 2. for weighted operations like linear/conv, we need to convert them to reference
  929. quantized module, this requires us to know whether the dtype configured for the
  930. weight is supported in the backend, this is done in prepare step and the result
  931. is stored in observed_node_names, we can decide whether we need to swap the
  932. module based on this set
  933. Args:
  934. * `is_standalone_module`: when this flag is True, it means we are quantizing
  935. a submodule that is not inlined in parent module, and will be quantized
  936. separately as one unit.
  937. * `is_decomposed`: a boolean flag to indicate whether we want to use the
  938. quantize operator for decomposed quantized tensor
  939. (torch.ops.quantized_decomposed.quantize_per_tensor) or default/standalone
  940. quantized tensor (torch.quantize_per_tensor)
  941. Returns:
  942. a quantized standalone module, whether input/output is quantized is
  943. specified by prepare_custom_config, with
  944. input_quantized_idxs, output_quantized_idxs, please
  945. see docs for :func:`~torch.ao.quantization.prepare_fx` for details
  946. """
  947. if convert_custom_config is None:
  948. convert_custom_config = ConvertCustomConfig()
  949. if isinstance(convert_custom_config, dict):
  950. warnings.warn(
  951. "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
  952. "in a future version. Please pass in a ConvertCustomConfig instead.",
  953. FutureWarning,
  954. stacklevel=2,
  955. )
  956. convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)
  957. if isinstance(qconfig_mapping, dict):
  958. warnings.warn(
  959. "Passing a QConfig dictionary to convert is deprecated and will not be supported "
  960. "in a future version. Please pass in a QConfigMapping instead.",
  961. FutureWarning,
  962. stacklevel=2,
  963. )
  964. qconfig_mapping = (
  965. QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
  966. )
  967. qconfig_mapping = copy.deepcopy(qconfig_mapping)
  968. assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)
  969. if isinstance(backend_config, dict):
  970. warnings.warn(
  971. "Passing a backend_config_dict to prepare is deprecated and will not be supported "
  972. "in a future version. Please pass in a BackendConfig instead.",
  973. FutureWarning,
  974. stacklevel=2,
  975. )
  976. backend_config = BackendConfig.from_dict(backend_config)
  977. if backend_config is None:
  978. backend_config = get_native_backend_config()
  979. assert _is_observed_module(model), "incoming model must be produced by prepare_fx"
  980. observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
  981. node_name_to_scope: dict[str, tuple[str, type]] = (
  982. observed_graph_module_attrs.node_name_to_scope
  983. )
  984. prepare_custom_config: PrepareCustomConfig = (
  985. observed_graph_module_attrs.prepare_custom_config
  986. )
  987. observed_node_names: set[str] = observed_graph_module_attrs.observed_node_names
  988. node_name_to_qconfig: dict[str, QConfigAny] = (
  989. observed_graph_module_attrs.node_name_to_qconfig
  990. ) # type: ignore[assignment]
  991. # mapping from fully qualified module name to module instance
  992. # for example,
  993. # {
  994. # '': Model(...),
  995. # 'linear': Linear(...),
  996. # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
  997. # }
  998. # We use remove_duplicate=False here because torch.cat uses
  999. # the same activation_post_process module instance but different names
  1000. modules = dict(model.named_modules(remove_duplicate=False))
  1001. # TODO refactor this code once we update the prepare logic to have additional information on
  1002. # which graph nodes have been observed and share that with convert to decide which observers to ignore.
  1003. if qconfig_mapping:
  1004. prepare_qconfig_mapping: QConfigMapping = (
  1005. observed_graph_module_attrs.qconfig_mapping
  1006. ) # type: ignore[assignment]
  1007. modules_copy = copy.deepcopy(modules)
  1008. if observed_graph_module_attrs.is_qat:
  1009. _update_qconfig_for_qat(qconfig_mapping, backend_config)
  1010. _update_qconfig_for_fusion(model, qconfig_mapping)
  1011. _compare_prepare_convert_qconfig_mappings(
  1012. prepare_qconfig_mapping, qconfig_mapping
  1013. ) # type: ignore[arg-type]
  1014. convert_node_name_to_qconfig = _generate_node_name_to_qconfig(
  1015. model, modules_copy, model.graph, qconfig_mapping, node_name_to_scope
  1016. )
  1017. # check the convert_node_name_to_qconfig generated and ensure that
  1018. # all the values either match what was set in prepare node_name_to_qconfig
  1019. # or are set to None in the convert_node_name_to_qconfig.
  1020. for k, v in node_name_to_qconfig.items():
  1021. assert k in convert_node_name_to_qconfig, (
  1022. f"Expected key {k} in convert node_name_to_qconfig"
  1023. )
  1024. if convert_node_name_to_qconfig[k] is not None:
  1025. assert qconfig_equals(v, convert_node_name_to_qconfig[k]), (
  1026. f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
  1027. f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
  1028. )
  1029. node_name_to_qconfig = convert_node_name_to_qconfig
  1030. custom_module_classes = get_custom_module_class_keys(
  1031. convert_custom_config.observed_to_quantized_mapping
  1032. )
  1033. custom_module_class_mapping = convert_custom_config.observed_to_quantized_mapping
  1034. if observed_graph_module_attrs.equalization_node_name_to_qconfig is not None:
  1035. # If we want to do equalization then do the following:
  1036. # Calculate the equalization scale, update the observers with the scaled
  1037. # inputs, and scale the weight
  1038. weight_eq_obs_dict = update_obs_for_equalization(model, modules)
  1039. convert_eq_obs(model, modules, weight_eq_obs_dict)
  1040. # always run weight observers in the top level forward method
  1041. # for dynamic quant ops or weight only quant ops
  1042. _run_weight_observers(model, backend_config)
  1043. # additional state to override inputs to be quantized, if specified
  1044. # by the user
  1045. placeholder_node_seen_cnt = 0
  1046. input_quantized_idxs: list[int] = prepare_custom_config.input_quantized_indexes
  1047. output_quantized_idxs: list[int] = prepare_custom_config.output_quantized_indexes
  1048. root_module_to_quantized_reference_module = (
  1049. get_root_module_to_quantized_reference_module(backend_config)
  1050. )
  1051. # convert tuples so that it can work with isinstance(module, tuple_of_classes)
  1052. root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
  1053. qat_module_classes = get_qat_module_classes(backend_config)
  1054. fused_module_classes = get_fused_module_classes(backend_config)
  1055. statically_quantized_custom_module_nodes: set[Node] = set()
  1056. model_device = assert_and_get_unique_device(model)
  1057. for node in list(model.graph.nodes):
  1058. if node.op == "placeholder":
  1059. cur_placeholder_node_idx = placeholder_node_seen_cnt
  1060. placeholder_node_seen_cnt += 1
  1061. if cur_placeholder_node_idx in input_quantized_idxs:
  1062. # Inputs are assumed to be quantized if the user specified the
  1063. # input_quantized_idxs override.
  1064. # we need to dequantize the inputs since all operators took
  1065. # floating point inputs in reference quantized models
  1066. _insert_dequantize_node(node, model.graph)
  1067. elif node.op == "output":
  1068. # If the argument is empty we don't need to do anything
  1069. if len(output_quantized_idxs) == 0:
  1070. continue
  1071. # Result are kept quantized if the user specified the
  1072. # output_quantized_idxs override.
  1073. # Remove the dequantize operator for the node in the end if any
  1074. return_node = node
  1075. output = node.args[0]
  1076. # outputs can be Node, list, tuple, dict, other cases are not supported yet
  1077. if isinstance(output, (list, tuple)):
  1078. for idx in output_quantized_idxs:
  1079. _maybe_recursive_remove_dequantize(
  1080. output[idx], return_node, model.graph
  1081. )
  1082. elif isinstance(output, (Node, dict)):
  1083. # we treat dict as a single argument currently, but it can be extended
  1084. # to support {"key": dtype} after we change output_quantized_idxs to
  1085. # dict
  1086. if 0 in output_quantized_idxs:
  1087. _maybe_recursive_remove_dequantize(output, return_node, model.graph)
  1088. else:
  1089. warnings.warn(
  1090. f"Unsupported node type for output_quantized_idxs: {type(output)}"
  1091. )
  1092. elif node.op == "call_module":
  1093. mod = _get_module(node, modules)
  1094. assert mod is not None
  1095. if _is_activation_post_process(mod):
  1096. observed_node = node.args[0]
  1097. if observed_node in statically_quantized_custom_module_nodes:
  1098. _replace_observer_or_dequant_stub_with_dequantize_node(
  1099. node, model.graph
  1100. )
  1101. else:
  1102. if is_decomposed:
  1103. _replace_observer_with_quantize_dequantize_node_decomposed(
  1104. model,
  1105. node,
  1106. modules,
  1107. node_name_to_scope,
  1108. node_name_to_qconfig,
  1109. model_device,
  1110. )
  1111. else:
  1112. _replace_observer_with_quantize_dequantize_node(
  1113. model,
  1114. node,
  1115. modules,
  1116. node_name_to_scope,
  1117. node_name_to_qconfig,
  1118. model_device,
  1119. )
  1120. elif isinstance(mod, DeQuantStub):
  1121. _replace_observer_or_dequant_stub_with_dequantize_node(
  1122. node, model.graph
  1123. )
  1124. elif _is_observed_standalone_module(mod):
  1125. convert_standalone_module(
  1126. node, modules, model, is_reference, backend_config
  1127. )
  1128. # below this point `type_before_parametrizations` is used
  1129. # instead of `type` to handle situations with fx quant + sparsity
  1130. elif type_before_parametrizations(mod) in set(root_module_classes).union(
  1131. qat_module_classes
  1132. ).union(fused_module_classes):
  1133. # extra check for fused module classes to make sure they are fused module classes
  1134. # of target modules
  1135. if (
  1136. type_before_parametrizations(mod) in fused_module_classes
  1137. and type_before_parametrizations(mod[0]) not in root_module_classes
  1138. ): # type: ignore[index]
  1139. continue
  1140. convert_weighted_module(
  1141. node,
  1142. modules,
  1143. observed_node_names,
  1144. node_name_to_qconfig,
  1145. backend_config,
  1146. is_decomposed,
  1147. is_reference,
  1148. model_device,
  1149. )
  1150. elif type_before_parametrizations(mod) in custom_module_classes:
  1151. convert_custom_module(
  1152. node,
  1153. model.graph,
  1154. modules,
  1155. custom_module_class_mapping,
  1156. statically_quantized_custom_module_nodes,
  1157. )
  1158. # remove deadcode after converting observers to quant/dequant ops
  1159. model.graph.eliminate_dead_code()
  1160. model = GraphModule(model, model.graph)
  1161. # TODO: maybe move this to quantize_fx.py
  1162. if not is_reference:
  1163. model = lower_to_fbgemm(
  1164. model, node_name_to_qconfig, node_name_to_scope, keep_original_weights
  1165. )
  1166. # TODO: this looks hacky, we want to check why we need this and see if we can
  1167. # remove this
  1168. # removes qconfig and activation_post_process modules
  1169. if _remove_qconfig_flag:
  1170. _remove_qconfig(model)
  1171. model.delete_all_unused_submodules()
  1172. model.meta.pop("_observed_graph_module_attrs", None)
  1173. return model