quantize_pt2e.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import typing_extensions
  2. import torch
  3. from torch._export.passes.constant_folding import constant_fold
  4. from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass
  5. from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ
  6. from torch.ao.quantization.quantizer import ( # noqa: F401
  7. DerivedQuantizationSpec,
  8. FixedQParamsQuantizationSpec,
  9. QuantizationAnnotation,
  10. QuantizationSpec,
  11. QuantizationSpecBase,
  12. Quantizer,
  13. SharedQuantizationSpec,
  14. )
  15. from torch.fx import GraphModule, Node
  16. from torch.fx.passes.infra.pass_manager import PassManager
  17. from .pt2e.prepare import prepare
  18. from .pt2e.qat_utils import _fold_conv_bn_qat, _fuse_conv_bn_qat
  19. from .pt2e.representation import reference_representation_rewrite
  20. from .pt2e.utils import _disallow_eval_train, _fuse_conv_bn_, _get_node_name_to_scope
  21. from .quantize_fx import _convert_to_reference_decomposed_fx
  22. from .utils import DEPRECATION_WARNING
  23. __all__ = [
  24. "prepare_pt2e",
  25. "prepare_qat_pt2e",
  26. "convert_pt2e",
  27. ]
  28. @typing_extensions.deprecated(DEPRECATION_WARNING)
  29. def prepare_pt2e(
  30. model: GraphModule,
  31. quantizer: Quantizer,
  32. ) -> GraphModule:
  33. """Prepare a model for post training quantization
  34. Args:
  35. * `model` (torch.fx.GraphModule): a model captured by `torch.export.export_for_training` API.
  36. * `quantizer`: A backend specific quantizer that conveys how user want the
  37. model to be quantized. Tutorial for how to write a quantizer can be found here:
  38. https://pytorch.org/tutorials/prototype/pt2e_quantizer.html
  39. Return:
  40. A GraphModule with observer (based on quantizer annotation), ready for calibration
  41. Example::
  42. import torch
  43. from torch.ao.quantization.quantize_pt2e import prepare_pt2e
  44. from torch.ao.quantization.quantizer import (
  45. XNNPACKQuantizer,
  46. get_symmetric_quantization_config,
  47. )
  48. class M(torch.nn.Module):
  49. def __init__(self) -> None:
  50. super().__init__()
  51. self.linear = torch.nn.Linear(5, 10)
  52. def forward(self, x):
  53. return self.linear(x)
  54. # initialize a floating point model
  55. float_model = M().eval()
  56. # define calibration function
  57. def calibrate(model, data_loader):
  58. model.eval()
  59. with torch.no_grad():
  60. for image, target in data_loader:
  61. model(image)
  62. # Step 1. program capture
  63. # NOTE: this API will be updated to torch.export API in the future, but the captured
  64. # result should mostly stay the same
  65. m = torch.export.export_for_training(m, *example_inputs).module()
  66. # we get a model with aten ops
  67. # Step 2. quantization
  68. # backend developer will write their own Quantizer and expose methods to allow
  69. # users to express how they
  70. # want the model to be quantized
  71. quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
  72. m = prepare_pt2e(m, quantizer)
  73. # run calibration
  74. # calibrate(m, sample_inference_data)
  75. """
  76. torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e")
  77. original_graph_meta = model.meta
  78. node_name_to_scope = _get_node_name_to_scope(model)
  79. # TODO: check qconfig_mapping to make sure conv and bn are both configured
  80. # to be quantized before fusion
  81. # TODO: (maybe) rewrite this with subgraph_rewriter
  82. _fuse_conv_bn_(model)
  83. model = quantizer.transform_for_annotation(model)
  84. quantizer.annotate(model)
  85. quantizer.validate(model)
  86. model = prepare(
  87. model,
  88. node_name_to_scope,
  89. is_qat=False,
  90. obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
  91. )
  92. model.meta.update(original_graph_meta)
  93. model = _disallow_eval_train(model)
  94. return model
  95. @typing_extensions.deprecated(DEPRECATION_WARNING)
  96. def prepare_qat_pt2e(
  97. model: GraphModule,
  98. quantizer: Quantizer,
  99. ) -> GraphModule:
  100. """Prepare a model for quantization aware training
  101. Args:
  102. * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`
  103. * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e`
  104. Return:
  105. A GraphModule with fake quant modules (based on quantizer annotation), ready for
  106. quantization aware training
  107. Example::
  108. import torch
  109. from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
  110. from torch.ao.quantization.quantizer import (
  111. XNNPACKQuantizer,
  112. get_symmetric_quantization_config,
  113. )
  114. class M(torch.nn.Module):
  115. def __init__(self) -> None:
  116. super().__init__()
  117. self.linear = torch.nn.Linear(5, 10)
  118. def forward(self, x):
  119. return self.linear(x)
  120. # initialize a floating point model
  121. float_model = M().eval()
  122. # define the training loop for quantization aware training
  123. def train_loop(model, train_data):
  124. model.train()
  125. for image, target in data_loader:
  126. ...
  127. # Step 1. program capture
  128. # NOTE: this API will be updated to torch.export API in the future, but the captured
  129. # result should mostly stay the same
  130. m = torch.export.export_for_training(m, *example_inputs).module()
  131. # we get a model with aten ops
  132. # Step 2. quantization
  133. # backend developer will write their own Quantizer and expose methods to allow
  134. # users to express how they
  135. # want the model to be quantized
  136. quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
  137. m = prepare_qat_pt2e(m, quantizer)
  138. # run quantization aware training
  139. train_loop(prepared_model, train_loop)
  140. """
  141. torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e")
  142. original_graph_meta = model.meta
  143. node_name_to_scope = _get_node_name_to_scope(model)
  144. model = quantizer.transform_for_annotation(model)
  145. quantizer.annotate(model)
  146. quantizer.validate(model)
  147. # Perform fusion after annotate to avoid quantizing ops in the new
  148. # subgraph that don't need to be quantized
  149. # TODO: only fuse if conv and bn are both configured to be quantized
  150. _fuse_conv_bn_qat(model)
  151. model = prepare(
  152. model,
  153. node_name_to_scope,
  154. is_qat=True,
  155. obs_or_fq_callback=quantizer.prepare_obs_or_fq_callback,
  156. )
  157. model.meta.update(original_graph_meta)
  158. model = _disallow_eval_train(model)
  159. return model
  160. _QUANT_OPS = [
  161. torch.ops.quantized_decomposed.quantize_per_tensor.default,
  162. torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
  163. torch.ops.quantized_decomposed.quantize_per_channel.default,
  164. torch.ops.pt2e_quant.quantize_affine,
  165. ]
  166. def _quant_node_constraint(n: Node) -> bool:
  167. """If there is any pure ops between get_attr and quantize op they will be const propagated
  168. e.g. get_attr(weight) -> transpose -> quantize -> dequantize*
  169. (Note: dequantize op is not going to be constant propagated)
  170. This filter is added because we don't want to constant fold the things that are not
  171. related to quantization
  172. """
  173. return n.op == "call_function" and n.target in _QUANT_OPS
  174. @typing_extensions.deprecated(DEPRECATION_WARNING)
  175. def convert_pt2e(
  176. model: GraphModule,
  177. use_reference_representation: bool = False,
  178. fold_quantize: bool = True,
  179. ) -> GraphModule:
  180. """Convert a calibrated/trained model to a quantized model
  181. Args:
  182. * `model` (torch.fx.GraphModule): calibrated/trained model
  183. * `use_reference_representation` (bool): boolean flag to indicate whether to produce reference representation or not
  184. * `fold_quantize` (bool): boolean flag for whether fold the quantize op or not
  185. Returns:
  186. quantized model, either in q/dq representation or reference representation
  187. Example::
  188. # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training
  189. # `convert_pt2e` produces a quantized model that represents quantized computation with
  190. # quantize dequantize ops and fp32 ops by default.
  191. # Please refer to
  192. # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model
  193. # for detailed explanation of output quantized model
  194. quantized_model = convert_pt2e(prepared_model)
  195. """
  196. torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e")
  197. if not isinstance(use_reference_representation, bool):
  198. raise ValueError(
  199. "Unexpected argument type for `use_reference_representation`, "
  200. f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e"
  201. )
  202. original_graph_meta = model.meta
  203. model = _convert_to_reference_decomposed_fx(model)
  204. model = _fold_conv_bn_qat(model)
  205. pm = PassManager([DuplicateDQPass()])
  206. model = pm(model).graph_module
  207. pm = PassManager([PortNodeMetaForQDQ()])
  208. model = pm(model).graph_module
  209. if fold_quantize:
  210. constant_fold(model, _quant_node_constraint)
  211. if use_reference_representation:
  212. model = reference_representation_rewrite(model)
  213. model.meta.update(original_graph_meta)
  214. model = _disallow_eval_train(model)
  215. return model