register.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. class Registry:
  16. """A general registry object."""
  17. __slots__ = ['name', 'rules']
  18. def __init__(self, name):
  19. self.name = name
  20. self.rules = {}
  21. def register(self, op_type, rule):
  22. assert isinstance(op_type, str)
  23. assert inspect.isfunction(rule)
  24. assert (
  25. op_type not in self.rules
  26. ), f'name "{op_type}" should not be registered before.'
  27. self.rules[op_type] = rule
  28. def lookup(self, op_type):
  29. return self.rules.get(op_type)
  30. _decomposition_ops = Registry('decomposition')
  31. def register_decomp(op_type):
  32. """
  33. Decorator for registering the lower function for an original op into sequence of primitive ops.
  34. Args:
  35. op_type(str): The op name
  36. Returns:
  37. wrapper: Inner wrapper function
  38. Examples:
  39. .. code-block:: python
  40. >>> from paddle.decomposition import register
  41. >>> @register.register_decomp('softmax')
  42. >>> def softmax(x, axis):
  43. ... molecular = exp(x)
  44. ... denominator = broadcast_to(sum(molecular, axis=axis, keepdim=True), x.shape)
  45. ... res = divide(molecular, denominator)
  46. ... return res
  47. """
  48. if not isinstance(op_type, str):
  49. raise TypeError(f'op_type must be str, but got {type(op_type)}.')
  50. def wrapper(f):
  51. _decomposition_ops.register(op_type, f)
  52. return f
  53. return wrapper
  54. def get_decomp_rule(op_type):
  55. _lowerrule = _decomposition_ops.lookup(op_type)
  56. return _lowerrule