| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import inspect
class Registry:
    """A general registry object mapping op type names to rule functions."""

    __slots__ = ['name', 'rules']

    def __init__(self, name):
        """Create an empty registry.

        Args:
            name (str): Human-readable name of this registry.
        """
        self.name = name
        # op_type (str) -> registered rule function.
        self.rules = {}

    def register(self, op_type, rule):
        """Register ``rule`` under ``op_type``.

        Args:
            op_type (str): The op name to register the rule for.
            rule (function): A plain Python function (``inspect.isfunction``
                must accept it, so builtins and other callables are rejected).

        Raises:
            AssertionError: If ``op_type`` is not a ``str``, ``rule`` is not a
                function, or ``op_type`` already has a registered rule.
        """
        # Explicit raises instead of ``assert`` statements so the checks still
        # run under ``python -O``; AssertionError is kept so existing callers
        # that expect it are unaffected.
        if not isinstance(op_type, str):
            raise AssertionError(
                f'op_type must be str, but got {type(op_type)}.'
            )
        if not inspect.isfunction(rule):
            raise AssertionError(
                f'rule must be a function, but got {type(rule)}.'
            )
        if op_type in self.rules:
            raise AssertionError(
                f'name "{op_type}" should not be registered before.'
            )
        self.rules[op_type] = rule

    def lookup(self, op_type):
        """Return the rule registered for ``op_type``, or ``None`` if absent."""
        return self.rules.get(op_type)


# Module-level registry shared by ``register_decomp`` and ``get_decomp_rule``.
_decomposition_ops = Registry('decomposition')
def register_decomp(op_type):
    """
    Decorator that registers a decomposition rule, i.e. a function that
    lowers an original op into a sequence of primitive ops.

    Args:
        op_type(str): The op name the decorated rule is registered under.

    Returns:
        wrapper: Inner decorator that registers the decorated function and
        returns it unchanged.

    Raises:
        TypeError: If ``op_type`` is not a ``str``.

    Examples:
        .. code-block:: python

            >>> from paddle.decomposition import register

            >>> @register.register_decomp('softmax')
            ... def softmax(x, axis):
            ...     molecular = exp(x)
            ...     denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape)
            ...     res = divide(molecular, denominator)
            ...     return res
    """
    # Validate eagerly: the decorator expression is evaluated before the
    # decorated function is defined, so a bad op_type fails immediately.
    if not isinstance(op_type, str):
        raise TypeError(f'op_type must be str, but got {type(op_type)}.')

    def wrapper(f):
        # Registry.register validates the rule and rejects an op_type that
        # already has a rule registered.
        _decomposition_ops.register(op_type, f)
        # Return the function unchanged so it stays directly callable.
        return f

    return wrapper
def get_decomp_rule(op_type):
    """Look up the decomposition rule registered for ``op_type``.

    Args:
        op_type (str): The op name to look up.

    Returns:
        The rule function registered in the module-level decomposition
        registry for ``op_type``, or ``None`` if no rule is registered.
    """
    return _decomposition_ops.lookup(op_type)
|