qat.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. from paddle.nn import Layer
  16. from .config import QuantConfig
  17. from .quantize import Quantization
  18. class QAT(Quantization):
  19. r"""
  20. Tools used to prepare model for quantization-aware training.
  21. Args:
  22. config(QuantConfig): Quantization configuration
  23. Examples:
  24. .. code-block:: python
  25. >>> from paddle.quantization import QAT, QuantConfig
  26. >>> from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
  27. >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
  28. >>> q_config = QuantConfig(activation=quanter, weight=quanter)
  29. >>> qat = QAT(q_config)
  30. """
  31. def __init__(self, config: QuantConfig):
  32. super().__init__(config)
  33. def quantize(self, model: Layer, inplace=False):
  34. r"""
  35. Create a model for quantization-aware training.
  36. The quantization configuration will be propagated in the model.
  37. And it will insert fake quanters into the model to simulate the quantization.
  38. Args:
  39. model(Layer): The model to be quantized.
  40. inplace(bool): Whether to modify the model in-place.
  41. Return: The prepared model for quantization-aware training.
  42. Examples:
  43. .. code-block:: python
  44. >>> from paddle.quantization import QAT, QuantConfig
  45. >>> from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
  46. >>> from paddle.vision.models import LeNet
  47. >>> quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
  48. >>> q_config = QuantConfig(activation=quanter, weight=quanter)
  49. >>> qat = QAT(q_config)
  50. >>> model = LeNet()
  51. >>> quant_model = qat.quantize(model)
  52. >>> print(quant_model)
  53. LeNet(
  54. (features): Sequential(
  55. (0): QuantedConv2D(
  56. (weight_quanter): FakeQuanterWithAbsMaxObserverLayer()
  57. (activation_quanter): FakeQuanterWithAbsMaxObserverLayer()
  58. )
  59. (1): ObserveWrapper(
  60. (_observer): FakeQuanterWithAbsMaxObserverLayer()
  61. (_observed): ReLU()
  62. )
  63. (2): ObserveWrapper(
  64. (_observer): FakeQuanterWithAbsMaxObserverLayer()
  65. (_observed): MaxPool2D(kernel_size=2, stride=2, padding=0)
  66. )
  67. (3): QuantedConv2D(
  68. (weight_quanter): FakeQuanterWithAbsMaxObserverLayer()
  69. (activation_quanter): FakeQuanterWithAbsMaxObserverLayer()
  70. )
  71. (4): ObserveWrapper(
  72. (_observer): FakeQuanterWithAbsMaxObserverLayer()
  73. (_observed): ReLU()
  74. )
  75. (5): ObserveWrapper(
  76. (_observer): FakeQuanterWithAbsMaxObserverLayer()
  77. (_observed): MaxPool2D(kernel_size=2, stride=2, padding=0)
  78. )
  79. )
  80. (fc): Sequential(
  81. (0): QuantedLinear(
  82. (weight_quanter): FakeQuanterWithAbsMaxObserverLayer()
  83. (activation_quanter): FakeQuanterWithAbsMaxObserverLayer()
  84. )
  85. (1): QuantedLinear(
  86. (weight_quanter): FakeQuanterWithAbsMaxObserverLayer()
  87. (activation_quanter): FakeQuanterWithAbsMaxObserverLayer()
  88. )
  89. (2): QuantedLinear(
  90. (weight_quanter): FakeQuanterWithAbsMaxObserverLayer()
  91. (activation_quanter): FakeQuanterWithAbsMaxObserverLayer()
  92. )
  93. )
  94. )
  95. """
  96. assert (
  97. model.training
  98. ), "Quantization-Aware Training should work on training models. Please set training mode by model.train()."
  99. _model = model if inplace else copy.deepcopy(model)
  100. self._config._specify(_model)
  101. self._convert_to_quant_layers(_model, self._config)
  102. self._insert_activation_observers(_model, self._config)
  103. return _model