simple_registry.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # mypy: allow-untyped-defs
  2. from typing import Callable, Optional
  3. from .fake_impl import FakeImplHolder
  4. from .utils import RegistrationHandle
  5. __all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"]
  6. class SimpleLibraryRegistry:
  7. """Registry for the "simple" torch.library APIs
  8. The "simple" torch.library APIs are a higher-level API on top of the
  9. raw PyTorch DispatchKey registration APIs that includes:
  10. - fake impl
  11. Registrations for these APIs do not go into the PyTorch dispatcher's
  12. table because they may not directly involve a DispatchKey. For example,
  13. the fake impl is a Python function that gets invoked by FakeTensor.
  14. Instead, we manage them here.
  15. SimpleLibraryRegistry is a mapping from a fully qualified operator name
  16. (including the overload) to SimpleOperatorEntry.
  17. """
  18. def __init__(self):
  19. self._data = {}
  20. def find(self, qualname: str) -> "SimpleOperatorEntry":
  21. res = self._data.get(qualname, None)
  22. if res is None:
  23. self._data[qualname] = res = SimpleOperatorEntry(qualname)
  24. return res
  25. singleton: SimpleLibraryRegistry = SimpleLibraryRegistry()
  26. class SimpleOperatorEntry:
  27. """This is 1:1 to an operator overload.
  28. The fields of SimpleOperatorEntry are Holders where kernels can be
  29. registered to.
  30. """
  31. def __init__(self, qualname: str):
  32. self.qualname: str = qualname
  33. self.fake_impl: FakeImplHolder = FakeImplHolder(qualname)
  34. self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = (
  35. GenericTorchDispatchRuleHolder(qualname)
  36. )
  37. # For compatibility reasons. We can delete this soon.
  38. @property
  39. def abstract_impl(self):
  40. return self.fake_impl
  41. class GenericTorchDispatchRuleHolder:
  42. def __init__(self, qualname):
  43. self._data = {}
  44. self.qualname = qualname
  45. def register(
  46. self, torch_dispatch_class: type, func: Callable
  47. ) -> RegistrationHandle:
  48. if self.find(torch_dispatch_class):
  49. raise RuntimeError(
  50. f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}"
  51. )
  52. self._data[torch_dispatch_class] = func
  53. def deregister():
  54. del self._data[torch_dispatch_class]
  55. return RegistrationHandle(deregister)
  56. def find(self, torch_dispatch_class):
  57. return self._data.get(torch_dispatch_class, None)
  58. def find_torch_dispatch_rule(op, torch_dispatch_class: type) -> Optional[Callable]:
  59. return singleton.find(op.__qualname__).torch_dispatch_rules.find(
  60. torch_dispatch_class
  61. )