wrapper.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddle.nn import Layer
  15. from .base_quanter import BaseQuanter
  16. class ObserveWrapper(Layer):
  17. r"""
  18. Put an observer layer and an observed layer into a wrapping layer.
  19. It is used to insert layers into the model for QAT or PTQ.
  20. Args:
  21. observer(BaseQuanter): Observer layer
  22. observed(Layer): Observed layer
  23. observe_input(bool): If it is true the observer layer will be called before observed layer.
  24. If it is false the observed layer will be called before observer layer. Default: True.
  25. """
  26. def __init__(
  27. self,
  28. observer: BaseQuanter,
  29. observed: Layer,
  30. observe_input=True,
  31. ):
  32. super().__init__()
  33. self._observer = observer
  34. self._observed = observed
  35. self._observe_input = observe_input
  36. def forward(self, *inputs, **kwargs):
  37. if self._observe_input:
  38. out = self._observer(*inputs, **kwargs)
  39. return self._observed(out, **kwargs)
  40. else:
  41. out = self._observed(*inputs, **kwargs)
  42. return self._observer(out, **kwargs)