triton.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. import ast
  2. import contextlib
  3. import inspect
  4. import threading
  5. from collections.abc import Generator, Iterable
  6. from typing import Any, Callable, Optional, Union
  7. from torch.utils._exposed_in import exposed_in
  8. from .custom_ops import custom_op, CustomOpDef
  9. from .infer_schema import infer_schema
  10. triton_ops_to_kernels: dict[str, list[object]] = {}
  11. def get_triton_kernels_for_op(name: str) -> list[object]:
  12. return triton_ops_to_kernels.get(name, [])
  13. def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]:
  14. """
  15. Inspect the source of an arbitrary callable passed to torch._library.triton_op,
  16. and grab all of the triton kernels that are wrapped inside of it.
  17. TODO: This check is best effort. It does *not* handle the case where the triton
  18. kernel is hidden behind recursive function calls.
  19. """
  20. def find_triton_kernels(fn: Callable[..., Any]) -> list[object]:
  21. try:
  22. source = inspect.getsource(fn)
  23. except (OSError, TypeError):
  24. return [] # Source code not available
  25. from torch._inductor.utils import IndentedBuffer
  26. buffer = IndentedBuffer()
  27. buffer.splice(source, strip=True)
  28. tree = ast.parse(buffer.getrawvalue())
  29. # Visitor to collect function calls and triton kernels
  30. class Visitor(ast.NodeVisitor):
  31. def __init__(self) -> None:
  32. self.triton_kernels: list[Any] = []
  33. def visit_Call(self, node: ast.Call) -> None:
  34. triton_func_names = ("capture_triton", "wrap_triton")
  35. if isinstance(node.func, ast.Attribute):
  36. attr = node.func
  37. if (
  38. isinstance(attr.value, ast.Attribute)
  39. and isinstance(attr.value.value, ast.Name)
  40. and attr.value.value.id == "torch"
  41. and attr.value.attr == "_library"
  42. and attr.attr in triton_func_names
  43. ):
  44. if node.args and isinstance(node.args[0], ast.Name):
  45. self.triton_kernels.append(node.args[0].id)
  46. # Catch capture_triton, wrap_triton that's been
  47. # imported directly
  48. elif isinstance(node.func, ast.Name):
  49. if node.func.id in triton_func_names:
  50. if node.args and isinstance(node.args[0], ast.Name):
  51. self.triton_kernels.append(node.args[0].id)
  52. self.generic_visit(node)
  53. collector = Visitor()
  54. collector.visit(tree)
  55. closure_vars = inspect.getclosurevars(fn)
  56. resolved = []
  57. # First, resolve triton kernel names
  58. for name in collector.triton_kernels:
  59. if name in closure_vars.nonlocals:
  60. resolved.append(closure_vars.nonlocals[name])
  61. elif name in closure_vars.globals:
  62. resolved.append(closure_vars.globals[name])
  63. elif name in closure_vars.builtins:
  64. resolved.append(closure_vars.builtins[name])
  65. return resolved
  66. return find_triton_kernels(fn)
  67. @exposed_in("torch.library")
  68. def triton_op(
  69. name: str,
  70. fn: Optional[Callable] = None,
  71. /,
  72. *,
  73. mutates_args: Union[str, Iterable[str]],
  74. schema: Optional[str] = None,
  75. ) -> Callable:
  76. """Create a custom operator whose implementation is backed by 1+ triton kernels.
  77. This is a more structured way of using triton kernels with PyTorch.
  78. Prefer using triton kernels with no ``torch.library`` custom operator wrappers
  79. (like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because
  80. that is simpler;
  81. only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you
  82. want to create an operator that behaves like PyTorch built-in operators.
  83. For example, you may use a ``torch.library`` wrapper API to define the
  84. behavior of the triton kernel when passed a tensor subclass or under
  85. a TorchDispatchMode.
  86. Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op`
  87. when the implementation
  88. consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
  89. custom operators as opaque (:func:`torch.compile` and
  90. :func:`torch.export.export` will never trace into them), but ``triton_op``
  91. makes the implementation visible to these subsystems, allowing them
  92. to optimize the triton kernel(s).
  93. Note that ``fn`` must only consist of calls to PyTorch-understood
  94. operators and triton kernels. Any triton kernels called inside ``fn``
  95. must be wrapped in a call to :func:`torch.library.wrap_triton`.
  96. Args:
  97. name (str): A name for the custom op that looks like "{namespace}::{name}",
  98. e.g. "mylib::my_linear". The name is used as the op's stable identifier
  99. in PyTorch subsystems (e.g. torch.export, FX graphs).
  100. To avoid name collisions, please use your project name as the namespace;
  101. e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
  102. mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
  103. This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
  104. it pessimistically assumes that all inputs to the operator are being mutated.
  105. schema (None | str): A schema string for the operator. If None
  106. (recommended) we'll infer a schema for the operator from its type
  107. annotations. We recommend letting us infer a schema unless you
  108. have a specific reason not to.
  109. Example: "(Tensor x, int y) -> (Tensor, Tensor)".
  110. Example::
  111. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  112. >>> import torch
  113. >>> from torch.library import triton_op, wrap_triton
  114. >>>
  115. >>> import triton
  116. >>> from triton import language as tl
  117. >>>
  118. >>> @triton.jit
  119. >>> def add_kernel(
  120. >>> in_ptr0,
  121. >>> in_ptr1,
  122. >>> out_ptr,
  123. >>> n_elements,
  124. >>> BLOCK_SIZE: "tl.constexpr",
  125. >>> ):
  126. >>> pid = tl.program_id(axis=0)
  127. >>> block_start = pid * BLOCK_SIZE
  128. >>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
  129. >>> mask = offsets < n_elements
  130. >>> x = tl.load(in_ptr0 + offsets, mask=mask)
  131. >>> y = tl.load(in_ptr1 + offsets, mask=mask)
  132. >>> output = x + y
  133. >>> tl.store(out_ptr + offsets, output, mask=mask)
  134. >>>
  135. >>> @triton_op("mylib::add", mutates_args={})
  136. >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
  137. >>> output = torch.empty_like(x)
  138. >>> n_elements = output.numel()
  139. >>>
  140. >>> def grid(meta):
  141. >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
  142. >>>
  143. >>> # NB: we need to wrap the triton kernel in a call to wrap_triton
  144. >>> wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
  145. >>> return output
  146. >>>
  147. >>> @torch.compile
  148. >>> def f(x, y):
  149. >>> return add(x, y)
  150. >>>
  151. >>> x = torch.randn(3, device="cuda")
  152. >>> y = torch.randn(3, device="cuda")
  153. >>>
  154. >>> z = f(x, y)
  155. >>> assert torch.allclose(z, x + y)
  156. """
  157. def dec(fn: Callable[..., object]) -> CustomOpDef:
  158. def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def]
  159. # Optimization: we're passing regular Tensors into the triton kernel, so
  160. # no need to go through HOP dispatch
  161. with set_wrap_triton_enabled(False):
  162. return fn(*args, **kwargs)
  163. result = custom_op(
  164. name,
  165. backend_fn,
  166. mutates_args=mutates_args,
  167. schema=infer_schema(fn, mutates_args=mutates_args),
  168. )
  169. from .._subclasses.functional_tensor import FunctionalTensorMode
  170. # We require that the user pass us a function that is make_fx traceable,
  171. # so we can just register it as the Fake/meta kernel.
  172. result.register_fake(fn)
  173. # We decompose the operator when FunctionalTensorMode is active.
  174. # The goal is to decompose the operator in AOTDispatcher.
  175. # - With torch.compile, this means that the backend (usually Inductor)
  176. # can see a call to the triton kernel(s) and so it can directly optimize
  177. # them by inlining them into the lowering process.
  178. def functional_decomp( # type: ignore[no-untyped-def]
  179. mode, op, types, args, kwargs
  180. ):
  181. # NOTE [Export custom triton op]
  182. # For torch.export (strict and non-strict), we don't do functional decomposition.
  183. # Instead, we preserve the custom triton ops as custom ops. This is because we want
  184. # the exported program to be high-level and serializable. If we decompose
  185. # the custom op to a functional hop and make it a node in exported program,
  186. # we need to figure out ways of serializing the hop and its arguments, which can be triton.jited
  187. # functions and triton dtypes. This is undesireble because:
  188. # - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes.
  189. # - exported program will contain the implementation detail (e.g. triton source code) for a specific
  190. # backend (GPU), which is probably at a wrong level of abstraction.
  191. # - changes to triton or the serialization logic for triton arguments can be BC breaking
  192. #
  193. # In the short term, we expect users to have a separate aot_compile stage that compiles the exported program
  194. # into a Cubin file on the same machine that users call export, which does autotuning and removes triton
  195. # dependency and serve the model with Cubin. This guarantees that triton changes won't break BC.
  196. # In the long term, we may export multiple cubins for the triton op directly
  197. from torch.export._trace import custom_triton_ops_decomposition_disabled
  198. if custom_triton_ops_decomposition_disabled():
  199. return mode.__torch_dispatch__(op, types, args, kwargs)
  200. else:
  201. # TODO: https://github.com/pytorch/pytorch/issues/160333
  202. # We should deduplicate the unrecognized_types logic.
  203. import torch._subclasses
  204. unrecognized_types = [
  205. t
  206. for t in types
  207. if not issubclass(t, torch._subclasses.FakeTensor)
  208. and t
  209. not in [
  210. torch.Tensor,
  211. torch._subclasses.functional_tensor.FunctionalTensor,
  212. ]
  213. ]
  214. if unrecognized_types:
  215. return NotImplemented
  216. with mode:
  217. return fn(*args, **kwargs)
  218. triton_kernels = get_inner_triton_kernels(fn)
  219. triton_ops_to_kernels[name] = triton_kernels
  220. result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
  221. return result
  222. if fn is None:
  223. return dec
  224. else:
  225. return dec(fn)
  226. wrap_triton_enabled = threading.local()
  227. wrap_triton_enabled_default = True
  228. @contextlib.contextmanager
  229. def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]:
  230. """If triton kernels annotated with @wrap_triton should dispatch via HOP
  231. or go straight to the triton kernel execution.
  232. We have this switch because eager-mode performance of HOP dispatch is slow
  233. enough to matter (~1ms) and we know that wrap_triton isn't necessary in
  234. some situations (eager-mode with regular Tensors)
  235. """
  236. try:
  237. prev = is_wrap_triton_enabled()
  238. wrap_triton_enabled.value = enabled
  239. yield
  240. finally:
  241. wrap_triton_enabled.value = prev
  242. def is_wrap_triton_enabled() -> bool:
  243. return getattr(wrap_triton_enabled, "value", wrap_triton_enabled_default)
  244. def capture_triton(triton_kernel: Callable, /) -> Any:
  245. """This API has been renamed to wrap_triton"""
  246. return wrap_triton(triton_kernel)
  247. @exposed_in("torch.library")
  248. def wrap_triton(triton_kernel: Callable, /) -> Any:
  249. """Allows capture of a triton kernel into a graph via make_fx or
  250. non-strict ``torch.export``.
  251. These technologies perform Dispatcher-based tracing (via
  252. ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
  253. The ``wrap_triton`` API wraps a triton kernel into a callable that
  254. can actually be traced into a graph.
  255. Please use this API together with :func:`torch.library.triton_op`.
  256. Examples:
  257. >>> # xdoctest: +SKIP
  258. >>> import torch
  259. >>> import triton
  260. >>> from triton import language as tl
  261. >>> from torch.fx.experimental.proxy_tensor import make_fx
  262. >>> from torch.library import wrap_triton
  263. >>>
  264. >>> @triton.jit
  265. >>> def add_kernel(
  266. >>> in_ptr0,
  267. >>> in_ptr1,
  268. >>> out_ptr,
  269. >>> n_elements,
  270. >>> BLOCK_SIZE: "tl.constexpr",
  271. >>> ):
  272. >>> pid = tl.program_id(axis=0)
  273. >>> block_start = pid * BLOCK_SIZE
  274. >>> offsets = block_start + tl.arange(0, BLOCK_SIZE)
  275. >>> mask = offsets < n_elements
  276. >>> x = tl.load(in_ptr0 + offsets, mask=mask)
  277. >>> y = tl.load(in_ptr1 + offsets, mask=mask)
  278. >>> output = x + y
  279. >>> tl.store(out_ptr + offsets, output, mask=mask)
  280. >>>
  281. >>> def add(x, y):
  282. >>> output = torch.empty_like(x)
  283. >>> n_elements = output.numel()
  284. >>>
  285. >>> def grid_fn(meta):
  286. >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
  287. >>>
  288. >>> wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
  289. >>> return output
  290. >>>
  291. >>> x = torch.randn(3, device="cuda")
  292. >>> y = torch.randn(3, device="cuda")
  293. >>> gm = make_fx(add)(x, y)
  294. >>> print(gm.code)
  295. >>> # def forward(self, x_1, y_1):
  296. >>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
  297. >>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
  298. >>> # kernel_idx = 0, constant_args_idx = 0,
  299. >>> # grid = [(1, 1, 1)], kwargs = {
  300. >>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
  301. >>> # 'n_elements': 3, 'BLOCK_SIZE': 16
  302. >>> # })
  303. >>> # return empty_like
  304. """
  305. from triton.runtime.autotuner import Autotuner
  306. from triton.runtime.jit import JITFunction
  307. from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper
  308. if not isinstance(triton_kernel, (JITFunction, Autotuner)):
  309. raise RuntimeError(
  310. "wrap_triton only works on functions annotated with triton.jit or triton.autotune"
  311. )
  312. if not is_wrap_triton_enabled():
  313. return triton_kernel
  314. return TraceableTritonKernelWrapper(triton_kernel, None, None)