_backends.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import functools
  2. from importlib.metadata import entry_points
  3. from functools import cache
  4. import os
  5. import warnings
  6. def dispatching_disabled():
  7. """Determine if dispatching has been disabled by the user."""
  8. no_dispatching = os.environ.get("SKIMAGE_NO_DISPATCHING", False)
  9. if no_dispatching == "1":
  10. return True
  11. else:
  12. return False
  13. def public_api_module(func):
  14. """Get the name of the public module for a scikit-image function.
  15. This computes the name of the public submodule in which the function can
  16. be found.
  17. """
  18. full_name = func.__module__
  19. # This relies on the fact that scikit-image does not use
  20. # sub-submodules in its public API, except in one case.
  21. # This means that public name can be atmost `skimage.foobar`
  22. # for everything else
  23. sub_submodules = ["skimage.filters.rank"]
  24. candidates = [name for name in sub_submodules if full_name.startswith(name)]
  25. if len(candidates) == 0:
  26. # Assume first two parts of the name are where the function is in our public API
  27. parts = full_name.split(".")
  28. if len(parts) <= 2:
  29. msg = f"expected {func.__module__=} with more than 2 dot-delimited parts"
  30. raise ValueError(msg)
  31. public_name = ".".join(parts[:2])
  32. elif len(candidates) == 1:
  33. public_name = candidates[0]
  34. else:
  35. msg = f"{func!r} matches more than one sub-submodule: {candidates!r}"
  36. raise ValueError(msg)
  37. # It would be nice to sanity check things by doing something like the
  38. # following. However we can't because this code is executed while the
  39. # module is being imported, which means this would create a circular
  40. # import
  41. # mod = importlib.import_module(public_name)
  42. # assert getattr(mod, func.__name__) is func
  43. return public_name
  44. @cache
  45. def all_backends():
  46. """List all installed backends and information about them."""
  47. backends = {}
  48. backends_ = entry_points(group="skimage_backends")
  49. backend_infos = entry_points(group="skimage_backend_infos")
  50. for backend in backends_:
  51. backends[backend.name] = {"implementation": backend}
  52. try:
  53. info = backend_infos[backend.name]
  54. # Double () to load and then call the backend information function
  55. backends[backend.name]["info"] = info.load()()
  56. except KeyError:
  57. pass
  58. return backends
  59. def dispatchable(func):
  60. """Mark a function as dispatchable.
  61. When a decorated function is called, the installed backends are
  62. searched for an implementation. If no backend implements the function
  63. then the scikit-image implementation is used.
  64. """
  65. func_name = func.__name__
  66. func_module = public_api_module(func)
  67. # If no backends are installed at all or dispatching is disabled,
  68. # return the original function. This way people who don't care about it
  69. # don't see anything related to dispatching
  70. if dispatching_disabled() or not all_backends():
  71. return func
  72. @functools.wraps(func)
  73. def wrapper(*args, **kwargs):
  74. # Backends are tried in alphabetical order, this makes things
  75. # predictable and stable across runs. Might need a better solution
  76. # when it becomes common that users have more than one backend
  77. # that would accept a call.
  78. for name in sorted(all_backends()):
  79. backend = all_backends()[name]
  80. # Check if the function we are looking for is implemented in
  81. # the backend
  82. if f"{func_module}:{func_name}" not in backend["info"].supported_functions:
  83. continue
  84. backend_impl = backend["implementation"].load()
  85. # Allow the backend to accept/reject a call based on the function
  86. # name and the arguments
  87. wants_it = backend_impl.can_has(
  88. f"{func_module}:{func_name}", *args, **kwargs
  89. )
  90. if not wants_it:
  91. continue
  92. func_impl = backend_impl.get_implementation(f"{func_module}:{func_name}")
  93. warnings.warn(
  94. f"Call to '{func_module}:{func_name}' was dispatched to"
  95. f" the '{name}' backend. Set SKIMAGE_NO_DISPATCHING=1 to"
  96. " disable this.",
  97. DispatchNotification,
  98. # XXX from where should this warning originate?
  99. # XXX from where the function that was dispatched was called?
  100. # XXX or from where the user called a function that called
  101. # XXX a function that was dispatched?
  102. stacklevel=2,
  103. )
  104. return func_impl(*args, **kwargs)
  105. else:
  106. return func(*args, **kwargs)
  107. return wrapper
  108. class BackendInformation:
  109. """Information about a backend
  110. A backend that wants to provide additional information about itself
  111. should return an instance of this from its information entry point.
  112. """
  113. def __init__(self, supported_functions):
  114. self.supported_functions = supported_functions
  115. class DispatchNotification(RuntimeWarning):
  116. """Notification issued when a function is dispatched to a backend."""
  117. pass