| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import abc
- import inspect
- from functools import partial
- from paddle.nn import Layer
- from .base_quanter import BaseQuanter
class ClassWithArguments(metaclass=abc.ABCMeta):
    """Abstract base that records the arguments it was constructed with.

    The positional and keyword arguments passed to ``__init__`` are stored
    unchanged and exposed through the read-only ``args`` and ``kwargs``
    properties. Subclasses must implement ``_get_class``.
    """

    def __init__(self, *args, **kwargs):
        """Store ``args`` and ``kwargs`` for later retrieval."""
        self._args = args
        self._kwargs = kwargs

    @property
    def args(self):
        """tuple: The positional arguments given at construction."""
        return self._args

    @property
    def kwargs(self):
        """dict: The keyword arguments given at construction."""
        return self._kwargs

    @abc.abstractmethod
    def _get_class(self):
        """Return the class associated with this object (subclass hook)."""
        pass

    def __str__(self):
        """Render as ``ClassName(arg1,arg2,key=value)``.

        Every argument is converted with ``str`` so that non-string
        positional arguments (ints, floats, objects, ...) can be rendered.
        """
        # str.join only accepts strings; convert positionals explicitly,
        # the same way the f-string below already converts keyword values.
        parts = [str(arg) for arg in self.args]
        parts.extend(f"{k}={v}" for k, v in self.kwargs.items())
        args_str = ",".join(parts)
        return f"{self.__class__.__name__}({args_str})"

    def __repr__(self):
        """Return the same text as ``__str__``."""
        return self.__str__()
class QuanterFactory(ClassWithArguments):
    r"""
    Factory that remembers which quanter class to build and the arguments
    to build it with.

    The concrete quanter class is supplied by ``_get_class``; the stored
    positional and keyword arguments come from the constructor.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Filled in on the first call to ``_instance`` and reused afterwards.
        self.partial_class = None

    def _instance(self, layer: Layer) -> BaseQuanter:
        r"""
        Build a quanter for ``layer``.

        The quanter class is bound to the stored arguments with
        ``functools.partial`` once and cached in ``self.partial_class``;
        ``layer`` is then passed after the stored positional arguments.
        """
        builder = self.partial_class
        if builder is None:
            builder = partial(self._get_class(), *self.args, **self.kwargs)
            self.partial_class = builder
        return builder(layer)


# ``ObserverFactory`` is another name for the very same class object.
ObserverFactory = QuanterFactory
def quanter(class_name):
    r"""
    Decorator that declares a factory class for a quanter.

    A new subclass of ``QuanterFactory`` named ``class_name`` is created and
    attached to the module in which the decorator is applied (and appended to
    that module's ``__all__`` when it defines one). The decorated class itself
    is returned unchanged.

    Args:
        class_name (str): The name of factory class to be declared.

    Returns:
        A decorator that takes the quanter class and returns it unchanged.

    Examples:
        .. code-block:: python

            >>> # doctest: +SKIP('need 2 file to run example')
            >>> # Given codes in ./customized_quanter.py
            >>> from paddle.quantization import quanter
            >>> from paddle.quantization import BaseQuanter
            >>> @quanter("CustomizedQuanter")
            ... class CustomizedQuanterLayer(BaseQuanter):
            ...     def __init__(self, arg1, kwarg1=None):
            ...         pass

            >>> # Used in ./test.py
            >>> # from .customized_quanter import CustomizedQuanter
            >>> from paddle.quantization import QuantConfig
            >>> arg1_value = "test"
            >>> kwarg1_value = 20
            >>> quanter = CustomizedQuanter(arg1_value, kwarg1=kwarg1_value)
            >>> q_config = QuantConfig(activation=quanter, weight=quanter)
    """

    def wrapper(target_class):
        # Imported here so the block does not rely on file-level imports
        # beyond the ones it already used.
        import importlib
        import sys

        target_module = target_class.__module__
        target_name = target_class.__name__

        def resolve_target():
            # Look the class up by module and name each time instead of
            # caching it under its bare name in this module's globals, where
            # equally named classes from different modules would collide.
            return getattr(importlib.import_module(target_module), target_name)

        def init_function(self, *args, **kwargs):
            # Pin super() to the generated class: super(type(self), self)
            # recurses forever as soon as the factory is subclassed.
            super(new_class, self).__init__(*args, **kwargs)
            # Resolve eagerly so a missing module/class is reported when the
            # factory is constructed rather than on first use.
            resolve_target()

        def get_class_function(self):
            return resolve_target()

        # The frame that applied the decorator. Only that one frame is
        # needed, so avoid inspect.stack(), which builds a record for every
        # frame on the stack.
        caller_frame = inspect.currentframe().f_back
        try:
            mod = inspect.getmodule(caller_frame)
        finally:
            # Drop the frame reference promptly to avoid reference cycles.
            del caller_frame
        if mod is None:
            # inspect.getmodule cannot map every frame to a module; fall
            # back to the module that defines the target class.
            mod = sys.modules.get(target_module)

        # Plain closures replace the former exec()/locals() round trip,
        # which no longer works on Python 3.13+ (PEP 667): names created by
        # exec() inside a function are not visible through a later locals().
        new_class = type(
            class_name,
            (QuanterFactory,),
            {
                "__init__": init_function,
                "_get_class": get_class_function,
            },
        )
        setattr(mod, class_name, new_class)
        if "__all__" in mod.__dict__ and class_name not in mod.__all__:
            mod.__all__.append(class_name)
        return target_class

    return wrapper
|