| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368 |
- import ast
- import contextlib
- import inspect
- import threading
- from collections.abc import Generator, Iterable
- from typing import Any, Callable, Optional, Union
- from torch.utils._exposed_in import exposed_in
- from .custom_ops import custom_op, CustomOpDef
- from .infer_schema import infer_schema
# Registry keyed by triton_op name ("namespace::op"). ``triton_op`` stores here
# the kernel objects it found in the decorated implementation's source.
triton_ops_to_kernels: dict[str, list[object]] = {}


def get_triton_kernels_for_op(name: str) -> list[object]:
    """Look up the triton kernels recorded for the triton_op called ``name``.

    Unknown names produce a new empty list; the registry is not modified.
    """
    if name in triton_ops_to_kernels:
        return triton_ops_to_kernels[name]
    return []
def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]:
    """
    Inspect the source of an arbitrary callable passed to torch.library.triton_op,
    and grab all of the triton kernels that are wrapped inside of it.

    A kernel is picked up when it is passed, as a plain name, as the first
    positional argument of a ``wrap_triton``/``capture_triton`` call. The call
    may be spelled as a bare name (``wrap_triton(k)``) or as an attribute of
    ``torch.library`` / ``torch._library`` (``torch.library.wrap_triton(k)``).
    Each name is resolved against ``fn``'s nonlocals, then globals, then
    builtins; names that resolve nowhere are skipped. Kernels are returned in
    the order their wrapping calls are visited (source order, outer call first).

    Returns an empty list (rather than raising) when:
    - the source of ``fn`` is unavailable;
    - the extracted source does not parse on its own;
    - ``fn`` is not a plain Python function or method.

    TODO: This check is best effort. It does *not* handle the case where the triton
    kernel is hidden behind recursive function calls.
    """
    import textwrap

    triton_func_names = ("capture_triton", "wrap_triton")
    # The wrappers are reachable through both the public ``torch.library``
    # namespace and the private ``torch._library`` one.
    triton_module_names = ("library", "_library")

    try:
        source = inspect.getsource(fn)
    except (OSError, TypeError):
        return []  # Source code not available

    try:
        # Dedent (and strip the first line's indent) so nested functions and
        # methods parse as top-level code.
        tree = ast.parse(textwrap.dedent(source).strip())
    except SyntaxError:
        # e.g. a lambda whose source line is only part of a larger expression;
        # this helper is best effort, so give up instead of failing triton_op.
        return []

    def is_wrapper_call(func: ast.expr) -> bool:
        # wrap_triton(...) / capture_triton(...) imported directly
        if isinstance(func, ast.Name):
            return func.id in triton_func_names
        # torch.library.wrap_triton(...) / torch._library.capture_triton(...)
        return (
            isinstance(func, ast.Attribute)
            and func.attr in triton_func_names
            and isinstance(func.value, ast.Attribute)
            and func.value.attr in triton_module_names
            and isinstance(func.value.value, ast.Name)
            and func.value.value.id == "torch"
        )

    # Visitor to collect the names of kernels passed to the wrappers
    class Visitor(ast.NodeVisitor):
        def __init__(self) -> None:
            self.triton_kernels: list[str] = []

        def visit_Call(self, node: ast.Call) -> None:
            if (
                is_wrapper_call(node.func)
                and node.args
                and isinstance(node.args[0], ast.Name)
            ):
                self.triton_kernels.append(node.args[0].id)
            self.generic_visit(node)

    collector = Visitor()
    collector.visit(tree)

    try:
        closure_vars = inspect.getclosurevars(fn)
    except TypeError:
        # getclosurevars only accepts Python functions/methods, while getsource
        # also succeeds for e.g. classes.
        return []

    resolved: list[object] = []
    for kernel_name in collector.triton_kernels:
        # First matching scope wins: closures, then module globals, then builtins.
        for scope in (
            closure_vars.nonlocals,
            closure_vars.globals,
            closure_vars.builtins,
        ):
            if kernel_name in scope:
                resolved.append(scope[kernel_name])
                break
    return resolved
@exposed_in("torch.library")
def triton_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    schema: Optional[str] = None,
) -> Callable:
    """Create a custom operator whose implementation is backed by 1+ triton kernels.

    This is a more structured way of using triton kernels with PyTorch.
    Prefer using triton kernels with no ``torch.library`` custom operator wrappers
    (like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because
    that is simpler;
    only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you
    want to create an operator that behaves like PyTorch built-in operators.
    For example, you may use a ``torch.library`` wrapper API to define the
    behavior of the triton kernel when passed a tensor subclass or under
    a TorchDispatchMode.

    Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op`
    when the implementation
    consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
    custom operators as opaque (:func:`torch.compile` and
    :func:`torch.export.export` will never trace into them), but ``triton_op``
    makes the implementation visible to these subsystems, allowing them
    to optimize the triton kernel(s).

    Note that ``fn`` must only consist of calls to PyTorch-understood
    operators and triton kernels. Any triton kernels called inside ``fn``
    must be wrapped in a call to :func:`torch.library.wrap_triton`.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        fn (Callable | None): The implementation. If None, ``triton_op`` returns
            a decorator that accepts the implementation.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from the type
            annotations of ``fn``; otherwise the given string is used as-is.
            We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".

    Returns:
        The ``CustomOpDef`` for the new operator when ``fn`` is given, otherwise a
        decorator that builds it.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch.library import triton_op, wrap_triton
        >>>
        >>> import triton
        >>> from triton import language as tl
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> @triton_op("mylib::add", mutates_args={})
        >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     # NB: we need to wrap the triton kernel in a call to wrap_triton
        >>>     wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> @torch.compile
        >>> def f(x, y):
        >>>     return add(x, y)
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>>
        >>> z = f(x, y)
        >>> assert torch.allclose(z, x + y)

    """

    def dec(fn: Callable[..., object]) -> CustomOpDef:
        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
            # Optimization: we're passing regular Tensors into the triton kernel, so
            # no need to go through HOP dispatch
            with set_wrap_triton_enabled(False):
                return fn(*args, **kwargs)

        # backend_fn only takes (*args, **kwargs), so no schema can be inferred
        # from it. Honor an explicitly passed ``schema`` (previously ignored) and
        # otherwise infer one from the user's ``fn``.
        if schema is None:
            op_schema = infer_schema(fn, mutates_args=mutates_args)
        else:
            op_schema = schema
        result = custom_op(
            name,
            backend_fn,
            mutates_args=mutates_args,
            schema=op_schema,
        )

        from .._subclasses.functional_tensor import FunctionalTensorMode

        # We require that the user pass us a function that is make_fx traceable,
        # so we can just register it as the Fake/meta kernel.
        result.register_fake(fn)

        # We decompose the operator when FunctionalTensorMode is active.
        # The goal is to decompose the operator in AOTDispatcher.
        # - With torch.compile, this means that the backend (usually Inductor)
        #   can see a call to the triton kernel(s) and so it can directly optimize
        #   them by inlining them into the lowering process.
        def functional_decomp(  # type: ignore[no-untyped-def]
            mode, op, types, args, kwargs
        ):
            # NOTE [Export custom triton op]
            # For torch.export (strict and non-strict), we don't do functional decomposition.
            # Instead, we preserve the custom triton ops as custom ops. This is because we want
            # the exported program to be high-level and serializable. If we decompose
            # the custom op to a functional hop and make it a node in exported program,
            # we need to figure out ways of serializing the hop and its arguments, which can be triton.jitted
            # functions and triton dtypes. This is undesirable because:
            # - it can be tedious to maintain a layer that serializes the jitted function (e.g. with a string) and dtypes.
            # - exported program will contain the implementation detail (e.g. triton source code) for a specific
            #   backend (GPU), which is probably at a wrong level of abstraction.
            # - changes to triton or the serialization logic for triton arguments can be BC breaking
            #
            # In the short term, we expect users to have a separate aot_compile stage that compiles the exported program
            # into a Cubin file on the same machine that users call export, which does autotuning and removes triton
            # dependency and serve the model with Cubin. This guarantees that triton changes won't break BC.
            # In the long term, we may export multiple cubins for the triton op directly
            from torch.export._trace import custom_triton_ops_decomposition_disabled

            if custom_triton_ops_decomposition_disabled():
                return mode.__torch_dispatch__(op, types, args, kwargs)
            else:
                # TODO: https://github.com/pytorch/pytorch/issues/160333
                # We should deduplicate the unrecognized_types logic.
                import torch._subclasses

                unrecognized_types = [
                    t
                    for t in types
                    if not issubclass(t, torch._subclasses.FakeTensor)
                    and t
                    not in [
                        torch.Tensor,
                        torch._subclasses.functional_tensor.FunctionalTensor,
                    ]
                ]

                if unrecognized_types:
                    return NotImplemented
                with mode:
                    return fn(*args, **kwargs)

        # Record the kernels found in fn's source so they can be looked up by
        # op name (see get_triton_kernels_for_op).
        triton_kernels = get_inner_triton_kernels(fn)
        triton_ops_to_kernels[name] = triton_kernels
        result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
        return result

    if fn is None:
        return dec
    else:
        return dec(fn)
# Per-thread override for whether wrap_triton routes kernels through the HOP
# wrapper; threads that never set it fall back to wrap_triton_enabled_default.
wrap_triton_enabled = threading.local()
wrap_triton_enabled_default = True


@contextlib.contextmanager
def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]:
    """If triton kernels annotated with @wrap_triton should dispatch via HOP
    or go straight to the triton kernel execution.

    We have this switch because eager-mode performance of HOP dispatch is slow
    enough to matter (~1ms) and we know that wrap_triton isn't necessary in
    some situations (eager-mode with regular Tensors).

    The setting is thread-local and the previous value is restored on exit,
    including when the body raises.
    """
    # Capture the previous value *before* entering ``try`` so ``finally`` can
    # never hit an unbound name and mask the original error.
    prev = is_wrap_triton_enabled()
    wrap_triton_enabled.value = enabled
    try:
        yield
    finally:
        wrap_triton_enabled.value = prev


def is_wrap_triton_enabled() -> bool:
    """Return this thread's wrap_triton setting, or the module default if unset."""
    return getattr(wrap_triton_enabled, "value", wrap_triton_enabled_default)
def capture_triton(triton_kernel: Callable, /) -> Any:
    """Legacy name for :func:`wrap_triton`; this API has been renamed.

    Forwards ``triton_kernel`` to ``wrap_triton`` and returns whatever it returns.
    """
    wrapped = wrap_triton(triton_kernel)
    return wrapped
@exposed_in("torch.library")
def wrap_triton(triton_kernel: Callable, /) -> Any:
    """Make a triton kernel capturable into a graph by make_fx or
    non-strict ``torch.export``.

    Those tools trace through the Dispatcher (``__torch_dispatch__``), so a
    direct call to a raw triton kernel is invisible to them. ``wrap_triton``
    returns a callable wrapper around the kernel that can be traced into a
    graph. Use it together with :func:`torch.library.triton_op`.

    Raises:
        RuntimeError: if ``triton_kernel`` was not produced by ``triton.jit``
            or ``triton.autotune``.

    Examples:

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import triton
        >>> from triton import language as tl
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>> from torch.library import wrap_triton
        >>>
        >>> @triton.jit
        >>> def add_kernel(
        >>>     in_ptr0,
        >>>     in_ptr1,
        >>>     out_ptr,
        >>>     n_elements,
        >>>     BLOCK_SIZE: "tl.constexpr",
        >>> ):
        >>>     pid = tl.program_id(axis=0)
        >>>     block_start = pid * BLOCK_SIZE
        >>>     offsets = block_start + tl.arange(0, BLOCK_SIZE)
        >>>     mask = offsets < n_elements
        >>>     x = tl.load(in_ptr0 + offsets, mask=mask)
        >>>     y = tl.load(in_ptr1 + offsets, mask=mask)
        >>>     output = x + y
        >>>     tl.store(out_ptr + offsets, output, mask=mask)
        >>>
        >>> def add(x, y):
        >>>     output = torch.empty_like(x)
        >>>     n_elements = output.numel()
        >>>
        >>>     def grid_fn(meta):
        >>>         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        >>>
        >>>     wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
        >>>     return output
        >>>
        >>> x = torch.randn(3, device="cuda")
        >>> y = torch.randn(3, device="cuda")
        >>> gm = make_fx(add)(x, y)
        >>> print(gm.code)
        >>> # def forward(self, x_1, y_1):
        >>> #     empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False)
        >>> #     triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation(
        >>> #         kernel_idx = 0, constant_args_idx = 0,
        >>> #         grid = [(1, 1, 1)], kwargs = {
        >>> #             'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like,
        >>> #             'n_elements': 3, 'BLOCK_SIZE': 16
        >>> #         })
        >>> #     return empty_like

    """
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction

    from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper

    accepted_kernel_types = (JITFunction, Autotuner)
    if not isinstance(triton_kernel, accepted_kernel_types):
        raise RuntimeError(
            "wrap_triton only works on functions annotated with triton.jit or triton.autotune"
        )

    if is_wrap_triton_enabled():
        # Wrap so dispatcher-based tracers can record the kernel call.
        return TraceableTritonKernelWrapper(triton_kernel, None, None)
    # Wrapping switched off (see set_wrap_triton_enabled): hand back the raw kernel.
    return triton_kernel
|