activations.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import functools
  15. import math
  16. from collections import OrderedDict
  17. import torch
  18. from torch import Tensor, nn
  19. from .integrations.hub_kernels import use_kernel_forward_from_hub
  20. from .utils import logging
  21. from .utils.import_utils import is_torchdynamo_compiling
# Module-level logger; the classes below report through `logger.warning_once(...)`.
logger = logging.get_logger(__name__)
  23. @use_kernel_forward_from_hub("GeluTanh")
  24. class GELUTanh(nn.Module):
  25. """
  26. A fast C implementation of the tanh approximation of the GeLU activation function. See
  27. https://huggingface.co/papers/1606.08415.
  28. This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
  29. match due to rounding errors.
  30. """
  31. def __init__(self, use_gelu_tanh_python: bool = False):
  32. super().__init__()
  33. if use_gelu_tanh_python:
  34. self.act = self._gelu_tanh_python
  35. else:
  36. self.act = functools.partial(nn.functional.gelu, approximate="tanh")
  37. def _gelu_tanh_python(self, input: Tensor) -> Tensor:
  38. return input * 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
  39. def forward(self, input: Tensor) -> Tensor:
  40. return self.act(input)
  41. @use_kernel_forward_from_hub("NewGELU")
  42. class NewGELUActivation(nn.Module):
  43. """
  44. Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
  45. the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
  46. """
  47. def forward(self, input: Tensor) -> Tensor:
  48. return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
  49. @use_kernel_forward_from_hub("GeLU")
  50. class GELUActivation(nn.Module):
  51. """
  52. Original Implementation of the GELU activation function in Google BERT repo when initially created. For
  53. information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
  54. torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
  55. Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
  56. """
  57. def __init__(self, use_gelu_python: bool = False):
  58. super().__init__()
  59. if use_gelu_python:
  60. self.act = self._gelu_python
  61. else:
  62. self.act = nn.functional.gelu
  63. def _gelu_python(self, input: Tensor) -> Tensor:
  64. return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
  65. def forward(self, input: Tensor) -> Tensor:
  66. return self.act(input)
  67. @use_kernel_forward_from_hub("SiLU")
  68. class SiLUActivation(nn.Module):
  69. """
  70. See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
  71. Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
  72. Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
  73. Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
  74. later.
  75. """
  76. def forward(self, input: Tensor) -> Tensor:
  77. return nn.functional.silu(input)
  78. @use_kernel_forward_from_hub("FastGELU")
  79. class FastGELUActivation(nn.Module):
  80. """
  81. Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
  82. """
  83. def forward(self, input: Tensor) -> Tensor:
  84. return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
  85. @use_kernel_forward_from_hub("QuickGELU")
  86. class QuickGELUActivation(nn.Module):
  87. """
  88. Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
  89. """
  90. def forward(self, input: Tensor) -> Tensor:
  91. return input * torch.sigmoid(1.702 * input)
  92. class ClippedGELUActivation(nn.Module):
  93. """
  94. Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
  95. it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
  96. https://huggingface.co/papers/2004.09602.
  97. Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
  98. initially created.
  99. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
  100. torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://huggingface.co/papers/1606.08415
  101. """
  102. def __init__(self, min: float, max: float):
  103. if min > max:
  104. raise ValueError(f"min should be < max (got min: {min}, max: {max})")
  105. super().__init__()
  106. self.min = min
  107. self.max = max
  108. def forward(self, x: Tensor) -> Tensor:
  109. return torch.clip(gelu(x), self.min, self.max)
  110. class AccurateGELUActivation(nn.Module):
  111. """
  112. Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
  113. https://github.com/hendrycks/GELUs
  114. Implemented along with MEGA (Moving Average Equipped Gated Attention)
  115. """
  116. def __init__(self):
  117. super().__init__()
  118. self.precomputed_constant = math.sqrt(2 / math.pi)
  119. def forward(self, input: Tensor) -> Tensor:
  120. return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
  121. class MishActivation(nn.Module):
  122. """
  123. See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://huggingface.co/papers/1908.08681). Also
  124. visit the official repository for the paper: https://github.com/digantamisra98/Mish
  125. """
  126. def __init__(self):
  127. super().__init__()
  128. self.act = nn.functional.mish
  129. def _mish_python(self, input: Tensor) -> Tensor:
  130. return input * torch.tanh(nn.functional.softplus(input))
  131. def forward(self, input: Tensor) -> Tensor:
  132. return self.act(input)
  133. class LinearActivation(nn.Module):
  134. """
  135. Applies the linear activation function, i.e. forwarding input directly to output.
  136. """
  137. def forward(self, input: Tensor) -> Tensor:
  138. return input
  139. class LaplaceActivation(nn.Module):
  140. """
  141. Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
  142. https://huggingface.co/papers/2209.10655
  143. Inspired by squared relu, but with bounded range and gradient for better stability
  144. """
  145. def forward(self, input, mu=0.707107, sigma=0.282095):
  146. input = (input - mu).div(sigma * math.sqrt(2.0))
  147. return 0.5 * (1.0 + torch.erf(input))
  148. class ReLUSquaredActivation(nn.Module):
  149. """
  150. Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668v2
  151. """
  152. def forward(self, input):
  153. relu_applied = nn.functional.relu(input)
  154. squared = torch.square(relu_applied)
  155. return squared
  156. class ClassInstantier(OrderedDict):
  157. def __getitem__(self, key):
  158. content = super().__getitem__(key)
  159. cls, kwargs = content if isinstance(content, tuple) else (content, {})
  160. return cls(**kwargs)
  161. class XIELUActivation(nn.Module):
  162. """
  163. Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
  164. If the user has installed the nickjbrowning/XIELU wheel, we import xIELU CUDA
  165. Otherwise, we emit a single warning and use xIELU Python
  166. """
  167. def __init__(
  168. self,
  169. alpha_p_init=0.8,
  170. alpha_n_init=0.8,
  171. beta=0.5,
  172. eps=-1e-6,
  173. dtype=torch.bfloat16,
  174. with_vector_loads=False,
  175. ):
  176. super().__init__()
  177. self.alpha_p = nn.Parameter(torch.log(torch.expm1(torch.tensor(alpha_p_init, dtype=dtype))).unsqueeze(0))
  178. self.alpha_n = nn.Parameter(
  179. torch.log(torch.expm1(torch.tensor(alpha_n_init - beta, dtype=dtype))).unsqueeze(0)
  180. )
  181. self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
  182. self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
  183. self.with_vector_loads = with_vector_loads
  184. # Temporary until xIELU CUDA fully implemented
  185. self._beta_scalar = float(self.beta.detach().cpu().float().item())
  186. self._eps_scalar = float(self.eps.detach().cpu().float().item())
  187. self._xielu_cuda_obj = None
  188. try:
  189. import xielu.ops # noqa: F401
  190. self._xielu_cuda_obj = torch.classes.xielu.XIELU()
  191. msg = "Using experimental xIELU CUDA."
  192. try:
  193. from torch._dynamo import allow_in_graph
  194. self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
  195. msg += " Enabled torch._dynamo for xIELU CUDA."
  196. except Exception as err:
  197. msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance."
  198. self._xielu_cuda_fn = self._xielu_cuda
  199. logger.warning_once(msg)
  200. except Exception as err:
  201. logger.warning_once(
  202. "CUDA-fused xIELU not available (%s) – falling back to a Python version.\n"
  203. "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
  204. str(err),
  205. )
  206. def _xielu_python(self, x: Tensor) -> Tensor:
  207. alpha_p = nn.functional.softplus(self.alpha_p)
  208. alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
  209. return torch.where(
  210. x > 0,
  211. alpha_p * x * x + self.beta * x,
  212. (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
  213. )
  214. def _xielu_cuda(self, x: Tensor) -> Tensor:
  215. """Firewall function to prevent torch.compile from seeing .item() calls"""
  216. original_shape = x.shape
  217. # CUDA kernel expects 3D tensors, reshape if needed
  218. while x.dim() < 3:
  219. x = x.unsqueeze(0)
  220. if x.dim() > 3:
  221. x = x.view(-1, 1, x.size(-1))
  222. if original_shape != x.shape:
  223. logger.warning_once(
  224. "Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
  225. original_shape,
  226. x.shape,
  227. )
  228. result = self._xielu_cuda_obj.forward(
  229. x,
  230. self.alpha_p.to(x.dtype),
  231. self.alpha_n.to(x.dtype),
  232. # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
  233. self._beta_scalar,
  234. self._eps_scalar,
  235. self.with_vector_loads,
  236. )
  237. return result.view(original_shape)
  238. def forward(self, input: Tensor) -> Tensor:
  239. if self._xielu_cuda_obj is not None and input.is_cuda:
  240. if not is_torchdynamo_compiling():
  241. return self._xielu_cuda_fn(input)
  242. else:
  243. logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.")
  244. return self._xielu_python(input)
# Maps an activation name to how it is built: either a class (constructed without arguments) or a
# `(class, kwargs)` tuple (constructed as `class(**kwargs)`).
# Key order is preserved and shows up in the `KeyError` message raised by `get_activation`.
ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "gelu_pytorch_tanh": GELUTanh,
    "gelu_python_tanh": (GELUTanh, {"use_gelu_tanh_python": True}),
    "gelu_accurate": AccurateGELUActivation,
    "laplace": LaplaceActivation,
    "leaky_relu": nn.LeakyReLU,
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu2": ReLUSquaredActivation,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": SiLUActivation,
    "swish": nn.SiLU,
    "tanh": nn.Tanh,
    "prelu": nn.PReLU,
    "xielu": XIELUActivation,
}
# Indexing `ACT2FN[name]` constructs and returns a new module instance on every lookup (see `ClassInstantier`).
ACT2FN = ClassInstantier(ACT2CLS)
  270. def get_activation(activation_string):
  271. if activation_string in ACT2FN:
  272. return ACT2FN[activation_string]
  273. else:
  274. raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
# Each name below is bound to one module instance, created by `get_activation` when this module is loaded.
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
# `gelu` is also the callable that `ClippedGELUActivation.forward` invokes.
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
  283. linear_act = get_activation("linear")