factory.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. import inspect
  16. from functools import partial
  17. from paddle.nn import Layer
  18. from .base_quanter import BaseQuanter
  19. class ClassWithArguments(metaclass=abc.ABCMeta):
  20. def __init__(self, *args, **kwargs):
  21. self._args = args
  22. self._kwargs = kwargs
  23. @property
  24. def args(self):
  25. return self._args
  26. @property
  27. def kwargs(self):
  28. return self._kwargs
  29. @abc.abstractmethod
  30. def _get_class(self):
  31. pass
  32. def __str__(self):
  33. args_str = ",".join(
  34. list(self.args) + [f"{k}={v}" for k, v in self.kwargs.items()]
  35. )
  36. return f"{self.__class__.__name__}({args_str})"
  37. def __repr__(self):
  38. return self.__str__()
  39. class QuanterFactory(ClassWithArguments):
  40. r"""
  41. The factory holds the quanter's class information and
  42. the arguments used to create quanter instance.
  43. """
  44. def __init__(self, *args, **kwargs):
  45. super().__init__(*args, **kwargs)
  46. self.partial_class = None
  47. def _instance(self, layer: Layer) -> BaseQuanter:
  48. r"""
  49. Create an instance of quanter for target layer.
  50. """
  51. if self.partial_class is None:
  52. self.partial_class = partial(
  53. self._get_class(), *self.args, **self.kwargs
  54. )
  55. return self.partial_class(layer)
# ObserverFactory is an alias of QuanterFactory (the very same class object);
# presumably observer factories reuse the quanter factory machinery.
ObserverFactory = QuanterFactory
  57. def quanter(class_name):
  58. r"""
  59. Annotation to declare a factory class for quanter.
  60. Args:
  61. class_name (str): The name of factory class to be declared.
  62. Examples:
  63. .. code-block:: python
  64. >>> # doctest: +SKIP('need 2 file to run example')
  65. >>> # Given codes in ./customized_quanter.py
  66. >>> from paddle.quantization import quanter
  67. >>> from paddle.quantization import BaseQuanter
  68. >>> @quanter("CustomizedQuanter")
  69. >>> class CustomizedQuanterLayer(BaseQuanter):
  70. ... def __init__(self, arg1, kwarg1=None):
  71. ... pass
  72. >>> # Used in ./test.py
  73. >>> # from .customized_quanter import CustomizedQuanter
  74. >>> from paddle.quantization import QuantConfig
  75. >>> arg1_value = "test"
  76. >>> kwarg1_value = 20
  77. >>> quanter = CustomizedQuanter(arg1_value, kwarg1=kwarg1_value)
  78. >>> q_config = QuantConfig(activation=quanter, weight=quanter)
  79. """
  80. def wrapper(target_class):
  81. init_function_str = f"""
  82. def init_function(self, *args, **kwargs):
  83. super(type(self), self).__init__(*args, **kwargs)
  84. import importlib
  85. module = importlib.import_module("{target_class.__module__}")
  86. my_class = getattr(module, "{target_class.__name__}")
  87. globals()["{target_class.__name__}"] = my_class
  88. def get_class_function(self):
  89. return {target_class.__name__}
  90. locals()["init_function"]=init_function
  91. locals()["get_class_function"]=get_class_function
  92. """
  93. exec(init_function_str)
  94. frm = inspect.stack()[1]
  95. mod = inspect.getmodule(frm[0])
  96. new_class = type(
  97. class_name,
  98. (QuanterFactory,),
  99. {
  100. "__init__": locals()["init_function"],
  101. "_get_class": locals()["get_class_function"],
  102. },
  103. )
  104. setattr(mod, class_name, new_class)
  105. if "__all__" in mod.__dict__:
  106. mod.__all__.append(class_name)
  107. return target_class
  108. return wrapper