quantize.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. import copy
  16. from paddle.nn import Layer
  17. from paddle.nn.quant.format import (
  18. ConvertibleQuantedLayer,
  19. LinearQuanterDequanter,
  20. )
  21. from .base_quanter import BaseQuanter
  22. from .config import QuantConfig
  23. class Quantization(metaclass=abc.ABCMeta):
  24. r"""
  25. Abstract class used to prepares a copy of the model for quantization calibration or quantization-aware training.
  26. Args:
  27. config(QuantConfig): Quantization configuration
  28. """
  29. def __init__(self, config: QuantConfig):
  30. self._config = copy.deepcopy(config)
  31. @abc.abstractmethod
  32. def quantize(self, model: Layer, inplace=False):
  33. r"""Create a model for quantization-aware training or post-training quantization."""
  34. pass
  35. def convert(self, model: Layer, inplace=False, remain_weight=False):
  36. r"""Convert the quantization model to ONNX style. And the converted
  37. model can be saved as inference model by calling paddle.jit.save.
  38. Args:
  39. model(Layer): The quantized model to be converted.
  40. inplace(bool, optional): Whether to modify the model in-place, default is False.
  41. remain_weight(bool, optional): Whether to remain weights in floats, default is False.
  42. Return: The converted model
  43. Examples:
  44. .. code-block:: python
  45. >>> import paddle
  46. >>> from paddle.quantization import QAT, QuantConfig
  47. >>> from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
  48. >>> from paddle.vision.models import LeNet
  49. >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
  50. >>> q_config = QuantConfig(activation=quanter, weight=quanter)
  51. >>> qat = QAT(q_config)
  52. >>> model = LeNet()
  53. >>> quantized_model = qat.quantize(model)
  54. >>> converted_model = qat.convert(quantized_model)
  55. >>> dummy_data = paddle.rand([1, 1, 32, 32], dtype="float32")
  56. >>> paddle.jit.save(converted_model, "./quant_deploy", [dummy_data])
  57. """
  58. _model = model if inplace else copy.deepcopy(model)
  59. replaced = {}
  60. for name, child in _model.named_children():
  61. quant_dequant = None
  62. if isinstance(child, ConvertibleQuantedLayer):
  63. if child.converted:
  64. continue
  65. if hasattr(child, 'weight_quanter') and (
  66. child.weight_quanter is None
  67. or child.weight_quanter.scales() is None
  68. ):
  69. continue
  70. child._convert(remain_weight=remain_weight)
  71. elif isinstance(child, BaseQuanter):
  72. quant_dequant = LinearQuanterDequanter.from_quanter(child)
  73. else:
  74. self.convert(child, inplace=True, remain_weight=remain_weight)
  75. if quant_dequant is not None:
  76. replaced[name] = quant_dequant
  77. for key, value in replaced.items():
  78. _model._sub_layers[key] = value
  79. return _model
  80. def _convert_to_quant_layers(self, model: Layer, config: QuantConfig):
  81. replaced = {}
  82. for name, child in model.named_children():
  83. if (
  84. config._is_quantifiable(child)
  85. and type(child) in config.qat_layer_mappings
  86. ):
  87. replaced[name] = config._get_qat_layer(child)
  88. else:
  89. self._convert_to_quant_layers(child, config)
  90. for key, value in replaced.items():
  91. model._sub_layers[key] = value
  92. def _insert_activation_observers(self, model: Layer, config: QuantConfig):
  93. replaced = {}
  94. for name, child in model.named_children():
  95. if config._need_observe(child):
  96. replaced[name] = config._get_observe_wrapper(child)
  97. else:
  98. if (
  99. type(child) not in config._qat_layer_mapping.values()
  100. and type(child)
  101. not in config._customized_qat_layer_mapping.values()
  102. ):
  103. self._insert_activation_observers(child, config)
  104. for key, value in replaced.items():
  105. model._sub_layers[key] = value
  106. def _details(self):
  107. return self._config.details()
  108. def __str__(self):
  109. return self._details()
  110. def __repr__(self):
  111. return self.__str__()