registry.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
"""
This module implements TorchDynamo's backend registry system for managing compiler backends.

The registry provides a centralized way to register, discover, and manage different compiler
backends that can be used with torch.compile(). It handles:

- Backend registration and discovery through decorators and entry points
- Lazy loading of backend implementations
- Lookup and validation of backend names
- Categorization of backends using tags (debug, experimental, etc.)

Key components:

- CompilerFn: Type for backend compiler functions that transform FX graphs
- _BACKENDS: Registry mapping every known backend name to its setuptools entry
  point, or to None for backends added through register_backend
- _COMPILER_FNS: Registry mapping backend names to loaded compiler functions

Example usage:

    @register_backend
    def my_compiler(fx_graph, example_inputs):
        # Transform FX graph into optimized implementation
        return compiled_fn

    # Use registered backend
    torch.compile(model, backend="my_compiler")

The registry also supports discovering backends through setuptools entry points
in the "torch_dynamo_backends" group. Example:

```
setup.py
---
from setuptools import setup

setup(
    name='my_torch_backend',
    version='0.1',
    packages=['my_torch_backend'],
    entry_points={
        'torch_dynamo_backends': [
            # name = path to entry point of backend implementation
            'my_compiler = my_torch_backend.compiler:my_compiler_function',
        ],
    },
)
```

```
my_torch_backend/compiler.py
---
def my_compiler_function(fx_graph, example_inputs):
    # Transform FX graph into optimized implementation
    return compiled_fn
```

Using the `my_compiler` backend:

```
import torch

model = ...  # Your PyTorch model
optimized_model = torch.compile(model, backend="my_compiler")
```
"""
  52. import functools
  53. import logging
  54. import sys
  55. from collections.abc import Sequence
  56. from importlib.metadata import EntryPoint
  57. from typing import Any, Callable, Optional, Protocol, Union
  58. import torch
  59. from torch import fx
log = logging.getLogger(__name__)


class CompiledFn(Protocol):
    """Structural type of the callable a backend compiler returns.

    It is called with tensors as positional arguments and returns a tuple of
    tensors.
    """

    def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ...


# A backend compiler: takes an FX GraphModule plus a list of example input
# tensors and returns a CompiledFn.
CompilerFn = Callable[[fx.GraphModule, list[torch.Tensor]], CompiledFn]

# Every backend name the registry knows about. The value is the EntryPoint to
# load lazily for backends found in the "torch_dynamo_backends" entry-point
# group; it may be None for names added through register_backend().
_BACKENDS: dict[str, Optional[EntryPoint]] = {}
# Backend names whose compiler function has actually been registered/loaded.
_COMPILER_FNS: dict[str, CompilerFn] = {}
  66. def register_backend(
  67. compiler_fn: Optional[CompilerFn] = None,
  68. name: Optional[str] = None,
  69. tags: Sequence[str] = (),
  70. ) -> Callable[..., Any]:
  71. """
  72. Decorator to add a given compiler to the registry to allow calling
  73. `torch.compile` with string shorthand. Note: for projects not
  74. imported by default, it might be easier to pass a function directly
  75. as a backend and not use a string.
  76. Args:
  77. compiler_fn: Callable taking a FX graph and fake tensor inputs
  78. name: Optional name, defaults to `compiler_fn.__name__`
  79. tags: Optional set of string tags to categorize backend with
  80. """
  81. if compiler_fn is None:
  82. # @register_backend(name="") syntax
  83. return functools.partial(register_backend, name=name, tags=tags) # type: ignore[return-value]
  84. assert callable(compiler_fn)
  85. name = name or compiler_fn.__name__
  86. assert name not in _COMPILER_FNS, f"duplicate name: {name}"
  87. if compiler_fn not in _BACKENDS:
  88. _BACKENDS[name] = None
  89. _COMPILER_FNS[name] = compiler_fn
  90. compiler_fn._tags = tuple(tags) # type: ignore[attr-defined]
  91. return compiler_fn
# Shorthand versions of register_backend with ``tags`` pre-bound, so the backend
# is categorized as "debug" / "experimental". Both tags are excluded by
# list_backends() under its default ``exclude_tags``.
register_debug_backend = functools.partial(register_backend, tags=("debug",))
register_experimental_backend = functools.partial(
    register_backend, tags=("experimental",)
)
  96. def lookup_backend(compiler_fn: Union[str, CompilerFn]) -> CompilerFn:
  97. """Expand backend strings to functions"""
  98. if isinstance(compiler_fn, str):
  99. if compiler_fn not in _BACKENDS:
  100. _lazy_import()
  101. if compiler_fn not in _BACKENDS:
  102. from ..exc import InvalidBackend
  103. raise InvalidBackend(name=compiler_fn)
  104. if compiler_fn not in _COMPILER_FNS:
  105. entry_point = _BACKENDS[compiler_fn]
  106. if entry_point is not None:
  107. register_backend(compiler_fn=entry_point.load(), name=compiler_fn)
  108. compiler_fn = _COMPILER_FNS[compiler_fn]
  109. return compiler_fn
  110. # NOTE: can't type this due to public api mismatch; follow up with dev team
  111. def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: # type: ignore[no-untyped-def]
  112. """
  113. Return valid strings that can be passed to:
  114. torch.compile(..., backend="name")
  115. """
  116. _lazy_import()
  117. exclude_tags_set = set(exclude_tags or ())
  118. backends = [
  119. name
  120. for name in _BACKENDS.keys()
  121. if name not in _COMPILER_FNS
  122. or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined]
  123. ]
  124. return sorted(backends)
@functools.cache
def _lazy_import() -> None:
    """Populate the backend registry on first use.

    Because of ``functools.cache``, once this has completed successfully,
    later calls return immediately without re-running the imports.
    """
    from .. import backends
    from ..utils import import_submodule

    # Presumably imports each submodule of the ``backends`` package so their
    # @register_backend decorators run (import_submodule lives in ..utils —
    # confirm there).
    import_submodule(backends)

    from ..repro.after_dynamo import dynamo_minifier_backend

    # The import is kept for its side effect; this assert just references the
    # name so the import is not treated as unused.
    assert dynamo_minifier_backend is not None

    # Entry-point discovery runs last. It assigns directly into _BACKENDS, so an
    # entry point whose name matches an already-registered backend replaces that
    # name's _BACKENDS value.
    _discover_entrypoint_backends()
  133. @functools.cache
  134. def _discover_entrypoint_backends() -> None:
  135. # importing here so it will pick up the mocked version in test_backends.py
  136. from importlib.metadata import entry_points
  137. group_name = "torch_dynamo_backends"
  138. if sys.version_info < (3, 10):
  139. eps = entry_points()
  140. eps = eps[group_name] if group_name in eps else []
  141. eps_dict = {ep.name: ep for ep in eps}
  142. else:
  143. eps = entry_points(group=group_name)
  144. eps_dict = {name: eps[name] for name in eps.names}
  145. for backend_name in eps_dict:
  146. _BACKENDS[backend_name] = eps_dict[backend_name]