hub_kernels.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. from functools import partial
  16. from typing import Optional, Union
  17. from ..modeling_flash_attention_utils import lazy_import_flash_attention
  18. from .flash_attention import flash_attention_forward
  19. try:
  20. from kernels import (
  21. Device,
  22. LayerRepository,
  23. Mode,
  24. get_kernel,
  25. register_kernel_mapping,
  26. replace_kernel_forward_from_hub,
  27. use_kernel_forward_from_hub,
  28. )
  29. _kernels_available = True
  30. _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
  31. "MultiScaleDeformableAttention": {
  32. "cuda": LayerRepository(
  33. repo_id="kernels-community/deformable-detr",
  34. layer_name="MultiScaleDeformableAttention",
  35. )
  36. },
  37. "Llama4TextMoe": {
  38. "cuda": LayerRepository(
  39. # Move to kernels-community/moe once we release.
  40. repo_id="kernels-community/moe",
  41. layer_name="Llama4TextMoe",
  42. )
  43. },
  44. "RMSNorm": {
  45. "cuda": LayerRepository(
  46. repo_id="kernels-community/liger_kernels",
  47. layer_name="LigerRMSNorm",
  48. # revision="pure-layer-test",
  49. ),
  50. "rocm": {
  51. Mode.INFERENCE: LayerRepository(
  52. repo_id="kernels-community/liger_kernels",
  53. layer_name="LigerRMSNorm",
  54. # revision="pure-layer-test",
  55. )
  56. },
  57. },
  58. "MLP": {
  59. "cuda": LayerRepository(
  60. repo_id="medmekk/triton-llama-mlp",
  61. layer_name="TritonLlamaMLP",
  62. )
  63. },
  64. "MegaBlocksMoeMLP": {
  65. "cuda": {
  66. Mode.TRAINING: LayerRepository(
  67. repo_id="kernels-community/megablocks",
  68. layer_name="MegaBlocksMoeMLP",
  69. ),
  70. Mode.INFERENCE: LayerRepository(
  71. repo_id="kernels-community/megablocks",
  72. layer_name="MegaBlocksMoeMLP",
  73. ),
  74. },
  75. "rocm": {
  76. Mode.INFERENCE: LayerRepository(
  77. repo_id="ahadnagy/megablocks",
  78. layer_name="MegaBlocksMoeMLP",
  79. )
  80. },
  81. },
  82. "FastGELU": {
  83. "cuda": {
  84. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  85. repo_id="kernels-community/activation",
  86. layer_name="FastGELU",
  87. version=">=0.0.4,<0.1.0",
  88. )
  89. }
  90. },
  91. "QuickGELU": {
  92. "cuda": {
  93. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  94. repo_id="kernels-community/activation",
  95. layer_name="QuickGELU",
  96. version=">=0.0.4,<0.1.0",
  97. )
  98. }
  99. },
  100. "NewGELU": {
  101. "cuda": {
  102. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  103. repo_id="kernels-community/activation",
  104. layer_name="NewGELU",
  105. version=">=0.0.4,<0.1.0",
  106. )
  107. }
  108. },
  109. "SiLU": {
  110. "cuda": {
  111. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  112. repo_id="kernels-community/activation", layer_name="Silu", version=">=0.1.0"
  113. )
  114. }
  115. },
  116. "GeLU": {
  117. "cuda": {
  118. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  119. repo_id="kernels-community/activation", layer_name="Gelu", version=">=0.1.0"
  120. )
  121. }
  122. },
  123. "GeluTanh": {
  124. "cuda": {
  125. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  126. repo_id="kernels-community/activation", layer_name="GeluTanh", version=">=0.1.0"
  127. )
  128. }
  129. },
  130. }
  131. register_kernel_mapping(_KERNEL_MAPPING)
  132. except ImportError:
  133. _kernels_available = False
  134. # Stub to make decorators int transformers work when `kernels`
  135. # is not installed.
  136. def use_kernel_forward_from_hub(*args, **kwargs):
  137. def decorator(cls):
  138. return cls
  139. return decorator
  140. class LayerRepository:
  141. def __init__(self, *args, **kwargs):
  142. raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
  143. def replace_kernel_forward_from_hub(*args, **kwargs):
  144. raise RuntimeError(
  145. "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
  146. )
  147. def register_kernel_mapping(*args, **kwargs):
  148. raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
  149. def is_kernel(attn_implementation: Optional[str]) -> bool:
  150. """Check whether `attn_implementation` matches a kernel pattern from the hub."""
  151. return (
  152. attn_implementation is not None
  153. and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None
  154. )
  155. def load_and_register_kernel(attn_implementation: str) -> None:
  156. """Load and register the kernel associated to `attn_implementation`."""
  157. if not is_kernel(attn_implementation):
  158. return
  159. if not _kernels_available:
  160. raise ImportError(
  161. "`kernels` is either not installed or uses an incompatible version. "
  162. "Please install the latest version with `pip install -U kernels`."
  163. )
  164. # Need to be imported here as otherwise we have a circular import in `modeling_utils`
  165. from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
  166. from ..modeling_utils import ALL_ATTENTION_FUNCTIONS
  167. attention_wrapper = None
  168. # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work
  169. actual_attn_name = attn_implementation
  170. if "|" in attn_implementation:
  171. attention_wrapper, actual_attn_name = attn_implementation.split("|")
  172. # `transformers` has wrapper for sdpa, paged, flash, flex etc.
  173. attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)
  174. # Extract repo_id and kernel_name from the string
  175. if ":" in actual_attn_name:
  176. repo_id, kernel_name = actual_attn_name.split(":")
  177. kernel_name = kernel_name.strip()
  178. else:
  179. repo_id = actual_attn_name
  180. kernel_name = None
  181. repo_id = repo_id.strip()
  182. # extract the rev after the @ if it exists
  183. repo_id, _, rev = repo_id.partition("@")
  184. repo_id = repo_id.strip()
  185. rev = rev.strip() if rev else None
  186. # Load the kernel from hub
  187. try:
  188. kernel = get_kernel(repo_id, revision=rev)
  189. except Exception as e:
  190. raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.")
  191. # correctly wrap the kernel
  192. if hasattr(kernel, "flash_attn_varlen_func"):
  193. if attention_wrapper is None:
  194. attention_wrapper = flash_attention_forward
  195. kernel_function = partial(attention_wrapper, implementation=kernel)
  196. lazy_import_flash_attention(kernel, force_import=True)
  197. elif kernel_name is not None:
  198. kernel_function = getattr(kernel, kernel_name)
  199. # Register the kernel as a valid attention
  200. ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
  201. ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
# Public API of this module. Each listed name is bound whichever branch of the `kernels` import ran: it is imported
# from `kernels` when that succeeds, or defined as a stub in the `except ImportError` branch.
# NOTE(review): `is_kernel` and `load_and_register_kernel` are not listed; presumably callers import them by name.
__all__ = [
    "LayerRepository",
    "use_kernel_forward_from_hub",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
]