ptq.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. from paddle.distributed import fleet
  16. from paddle.nn import Layer
  17. from .config import QuantConfig
  18. from .quantize import Quantization
  19. class PTQ(Quantization):
  20. """
  21. Applying post training quantization to the model.
  22. """
  23. def __init__(self, config: QuantConfig):
  24. super().__init__(config)
  25. def _is_parallel_training(self):
  26. try:
  27. if fleet.worker_num() > 2:
  28. return True
  29. else:
  30. return False
  31. except Exception: # fleet is not initialized
  32. return False
  33. def quantize(self, model: Layer, inplace=False):
  34. r"""
  35. Create a model for post-training quantization.
  36. The quantization configuration will be propagated in the model.
  37. And it will insert observers into the model to collect and compute
  38. quantization parameters.
  39. Args:
  40. model(Layer): The model to be quantized.
  41. inplace(bool): Whether to modify the model in-place.
  42. Return: The prepared model for post-training quantization.
  43. Examples:
  44. .. code-block:: python
  45. >>> from paddle.quantization import PTQ, QuantConfig
  46. >>> from paddle.quantization.observers import AbsmaxObserver
  47. >>> from paddle.vision.models import LeNet
  48. >>> observer = AbsmaxObserver()
  49. >>> q_config = QuantConfig(activation=observer, weight=observer)
  50. >>> ptq = PTQ(q_config)
  51. >>> model = LeNet()
  52. >>> model.eval()
  53. >>> quant_model = ptq.quantize(model)
  54. >>> print(quant_model)
  55. LeNet(
  56. (features): Sequential(
  57. (0): QuantedConv2D(
  58. (weight_quanter): AbsmaxObserverLayer()
  59. (activation_quanter): AbsmaxObserverLayer()
  60. )
  61. (1): ObserveWrapper(
  62. (_observer): AbsmaxObserverLayer()
  63. (_observed): ReLU()
  64. )
  65. (2): ObserveWrapper(
  66. (_observer): AbsmaxObserverLayer()
  67. (_observed): MaxPool2D(kernel_size=2, stride=2, padding=0)
  68. )
  69. (3): QuantedConv2D(
  70. (weight_quanter): AbsmaxObserverLayer()
  71. (activation_quanter): AbsmaxObserverLayer()
  72. )
  73. (4): ObserveWrapper(
  74. (_observer): AbsmaxObserverLayer()
  75. (_observed): ReLU()
  76. )
  77. (5): ObserveWrapper(
  78. (_observer): AbsmaxObserverLayer()
  79. (_observed): MaxPool2D(kernel_size=2, stride=2, padding=0)
  80. )
  81. )
  82. (fc): Sequential(
  83. (0): QuantedLinear(
  84. (weight_quanter): AbsmaxObserverLayer()
  85. (activation_quanter): AbsmaxObserverLayer()
  86. )
  87. (1): QuantedLinear(
  88. (weight_quanter): AbsmaxObserverLayer()
  89. (activation_quanter): AbsmaxObserverLayer()
  90. )
  91. (2): QuantedLinear(
  92. (weight_quanter): AbsmaxObserverLayer()
  93. (activation_quanter): AbsmaxObserverLayer()
  94. )
  95. )
  96. )
  97. """
  98. _model = model
  99. if not inplace:
  100. assert (
  101. not self._is_parallel_training()
  102. ), "'inplace' is not compatible with parallel training."
  103. _model = copy.deepcopy(model)
  104. _model.eval()
  105. assert (
  106. not model.training
  107. ), "Post-Training Quantization should not work on training models. Please set evaluation mode by model.eval()."
  108. self._config._specify(_model)
  109. self._convert_to_quant_layers(_model, self._config)
  110. self._insert_activation_observers(_model, self._config)
  111. return _model