| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454 |
- from __future__ import annotations
- import typing
- from typing import Any, Optional, TYPE_CHECKING, Union
- import sympy
- import torch
- from . import config
- from .codecache import write_text
- from .kernel_inputs import KernelInputs # noqa: TC001
- from .metrics import get_metric_table, is_metric_table_enabled
- from .runtime.hints import DeviceProperties, ReductionHint
- from .scheduler import BaseSchedulerNode, Scheduler, WhyNoFuse
- from .template_heuristics import get_template_heuristic
- from .template_heuristics.triton import (
- BaseConfigHeuristic,
- CPUConfigHeuristic,
- CUDAConfigHeuristic,
- MTIAConfigHeuristic,
- ROCmConfigHeuristic,
- XPUConfigHeuristic,
- )
- from .virtualized import V
- if TYPE_CHECKING:
- from collections.abc import Generator
- from functools import partial
- from triton import Config as TritonConfig
- from torch.utils._ordered_set import OrderedSet
- from .codegen.common import KernelTemplate
- from .codegen.simd_kernel_features import SIMDKernelFeatures
- from .codegen.triton import TritonKernel
- from .ir import ChoiceCaller
- from .select_algorithm import ExternKernelChoice
class Sortable(typing.Protocol):
    """
    Structural type for values that may serve as a ``list.sort()`` key.

    Any object that can be compared with ``<`` against another instance of its
    own type qualifies, e.g. ints or tuples.
    """

    def __lt__(self, other: typing.Self) -> bool:
        """Return True when this key orders strictly before ``other``."""
class InductorChoices:
    """
    This class contains a collection of default heuristics that affect performance of our generated
    code. We try to not put correctness requirements in this file.

    You can override the choices made here by doing:

            class MyHeuristics(InductorChoices):
                ...

            torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
    """

    def get_config_heuristics(
        self, device_type: Optional[str] = "cuda"
    ) -> BaseConfigHeuristic:
        """
        Return the Triton config heuristic object for ``device_type``.

        "cuda" resolves to the ROCm heuristic when torch is a HIP build
        (``torch.version.hip`` is set). Any unrecognized device type, including
        ``None``, falls back to ``BaseConfigHeuristic``.
        """
        if device_type == "cuda":
            if torch.version.hip is None:
                return CUDAConfigHeuristic()
            else:
                return ROCmConfigHeuristic()
        elif device_type == "xpu":
            return XPUConfigHeuristic()
        elif device_type == "cpu":
            return CPUConfigHeuristic()
        elif device_type == "mtia":
            return MTIAConfigHeuristic()
        else:
            return BaseConfigHeuristic()

    # Conv configs
    def get_conv_configs(
        self, device_type: Optional[str] = "cuda"
    ) -> partial[Generator[TritonConfig, None, None]]:
        """Return the convolution config generator from the device's config heuristic."""
        conv_heuristics = self.get_config_heuristics(device_type)
        return conv_heuristics.get_conv_configs()

    # Flex attention configs
    # TODO(coconutruben): break out flexattention/decode configs into the new retrieval mechanism
    def get_flex_attention_fwd_configs(
        self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
    ) -> list[Any]:
        """Return flex-attention forward configs for ``head_dim``/``dtype`` on ``device_type``."""
        flex_heuristics = self.get_config_heuristics(device_type)
        return flex_heuristics.get_flex_attn_fwd_configs(head_dim, dtype)

    def get_flex_attention_bwd_configs(
        self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
    ) -> list[Any]:
        """Return flex-attention backward configs for ``head_dim``/``dtype`` on ``device_type``."""
        flex_heuristics = self.get_config_heuristics(device_type)
        return flex_heuristics.get_flex_attn_bwd_configs(head_dim, dtype)

    def get_flex_decode_configs(
        self, head_dim: int, dtype: torch.dtype, device_type: Optional[str] = "cuda"
    ) -> list[Any]:
        """Return flex-decode configs for ``head_dim``/``dtype`` on ``device_type``."""
        flex_heuristics = self.get_config_heuristics(device_type)
        return flex_heuristics.get_flex_decode_configs(head_dim, dtype)

    def get_mm_configs(
        self,
        kernel_inputs: KernelInputs,
        layout: Any,
        templates: list[Union[KernelTemplate, ExternKernelChoice]],
        op_name: str,
        kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None,
    ) -> Generator[ChoiceCaller, None, None]:
        """
        Get generator of ChoiceCallers for MM templates using template-specific heuristics.

        Args:
            kernel_inputs: Kernel inputs (e.g. MMKernelInputs) holding the input tensor
                nodes; must contain at least 2 nodes and have a non-None ``device_type``.
            layout: Output layout
            templates: List of template objects (KernelTemplate or ExternKernelChoice)
            op_name: Operation name (e.g., "bmm", "baddbmm", "addmm", "mm_plus_mm")
            kwarg_overrides: Optional dict of kwargs to override for each template heuristic,
                indexed by template.uid. These only override the per config kwargs, not the extra kwargs

        Yields:
            ChoiceCaller objects from the templates. Configs for which
            ``template.choice_or_none`` returns None are skipped.

        Raises:
            ValueError: fewer than 2 input nodes. Because this is a generator, the
                error surfaces on the first iteration, not at call time.
        """
        if kwarg_overrides is None:
            kwarg_overrides = {}

        num_inputs = len(kernel_inputs.nodes())
        if num_inputs < 2:
            raise ValueError(f"Need at least 2 input tensors, got {num_inputs}")

        device_type = kernel_inputs.device_type
        assert device_type is not None, "get_mm_configs requires a valid device type"

        for template in templates:
            template_name = template.uid
            heuristic = get_template_heuristic(template_name, device_type, op_name)

            # Per-config kwargs (may be lazy) and kwargs shared by every config.
            cs = heuristic.get_template_configs(kernel_inputs, layout, op_name)
            extra_kwargs = heuristic.get_extra_kwargs(kernel_inputs, layout, op_name)

            # The heuristic may adapt the inputs for this template; by default
            # adjust_kernel_inputs returns kernel_inputs unchanged.
            adjusted_nodes = heuristic.adjust_kernel_inputs(
                kernel_inputs, op_name
            ).nodes()

            overrides = kwarg_overrides.get(template_name, {})

            # layout / input_nodes are always passed explicitly, replacing any
            # values the heuristic may have put into extra_kwargs.
            extra_kwargs["layout"] = layout
            extra_kwargs["input_nodes"] = adjusted_nodes

            for c in cs:
                # Overrides win over the per-config kwargs only.
                choice = template.choice_or_none(**{**c, **overrides}, **extra_kwargs)
                if choice is not None:
                    yield choice

    def triton_kernel_kwargs(
        self,
        kernel_cls: type[TritonKernel],
        features: SIMDKernelFeatures,
        groups: list[sympy.Expr],
        kernel_kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations.

        The default implementation returns ``kernel_kwargs`` unchanged (same object).
        """
        return kernel_kwargs

    @staticmethod
    def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool:
        """
        Heuristic to decide if a cooperative reduction should be used.

        Always True when ``config.triton.force_cooperative_reductions`` is set.
        Otherwise False when cooperative reductions are disabled or the current
        device is a CPU. The number of outputs (``features.numel``, hinted with a
        fallback of 2) then selects a minimum reduction size that must be
        statically known to hold; more than 16 outputs never qualify.
        """
        if config.triton.force_cooperative_reductions:
            return True
        if (
            not config.triton.cooperative_reductions
            or V.graph.get_current_device_or_throw().type == "cpu"
        ):
            return False

        xhint = V.graph.sizevars.size_hint(features.numel, fallback=2)
        if xhint <= 8:
            threshold = 32768 * xhint
        elif xhint <= 16:
            threshold = 2097152
        else:
            return False
        # TODO(jansel): should this default on for dynamic shapes?
        return V.graph.sizevars.statically_known_geq(
            features.reduction_numel, threshold
        )

    @staticmethod
    def should_use_persistent_reduction(
        features: SIMDKernelFeatures, cooperative_reduction: bool
    ) -> bool:
        """
        Heuristic to decide if a persistent reduction should be used.

        The reduction size must be statically known to be <= a threshold: 1024
        for ``ReductionHint.INNER`` and 64 otherwise. The threshold is scaled up
        for cooperative reductions (see ``_cooperative_threshold_scale``) and
        multiplied by 16 when ``config.triton.multi_kernel`` is enabled.
        """
        if not config.triton.persistent_reductions:
            return False
        threshold = {
            ReductionHint.INNER: 1024,
        }.get(features.get_reduction_hint(), 64)

        if cooperative_reduction:
            # The RSPLIT of cooperative reductions means each thread block is operating on fewer elements
            try:
                numel_hint = V.graph.sizevars.size_hint_or_throw(features.numel)
            except (TypeError, ValueError):
                # No concrete hint (e.g. unbacked symint): keep the base threshold.
                # int() on a symbolic sympy expression raises TypeError, so
                # catching only ValueError could let this crash.
                pass
            else:
                threshold *= InductorChoices._cooperative_threshold_scale(numel_hint)

        # If multi_kernel is enabled, we do more aggressive persistent reduction.
        # This may result in some persistent reductions slower than the
        # corresponding non-persistent reductions. MultiKernel will do benchmarking
        # to pick the faster one.
        if config.triton.multi_kernel:
            threshold *= 16
        return V.graph.sizevars.statically_known_leq(
            features.reduction_numel, threshold
        )  # type: ignore[arg-types]

    @staticmethod
    def _cooperative_threshold_scale(numel_hint: int) -> int:
        """
        Return the persistent-reduction threshold multiplier for a cooperative reduction.

        Computes ``32 // min(numel_hint, 32)``, with ``numel_hint`` clamped to at
        least 1 so that a zero-size hint cannot raise ZeroDivisionError.
        """
        return 32 // min(max(numel_hint, 1), 32)

    @staticmethod
    def reduction_split_factor(
        device: torch.device,
        reduction_numel_hint: int,
        numel_hint: int,
        inner_reduction: bool,
    ) -> int:
        """Heuristic to decide the RSPLIT used for split reductions.

        When a reduction has a small number of outputs there is not enough parallelism,
        so we will do the reduction in two phases.

        Args:
            device: Device whose multiprocessor count (via DeviceProperties) sizes the targets.
            reduction_numel_hint: Size hint of the reduction dimension.
            numel_hint: Size hint of the number of outputs.
            inner_reduction: Selects the inner-reduction branch; otherwise the
                outer-reduction branch is used.

        Returns:
            The split factor; 1 means the reduction is not split.
        """
        props = DeviceProperties.create(device)
        num_sm = props.multi_processor_count
        min_elements_per_thread = 32
        max_elements_per_thread = 512
        threads_per_sm = 2048
        min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
        max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm
        num_warps = 8
        num_threads = 32 * num_warps

        if inner_reduction:
            # do heuristics that's close to eager mode for split inner reduction
            # we leak reduction autotune configs here, and will need to refactor to avoid this later
            if numel_hint >= 2 * num_sm:  # don't split if there are enough outputs
                return 1
            if reduction_numel_hint <= 8192:
                return 1
            if reduction_numel_hint * numel_hint <= min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (2 * num_threads)
                blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
                tmp_split_size = (
                    reduction_numel_hint + num_threads * blocks_per_output - 1
                ) // (num_threads * blocks_per_output)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(closest - tmp_split_size) < 30:
                    # prefer even splits, but never smaller than min_elements_per_thread
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread
            # ceil(reduction_numel_hint / (split_size * num_threads))
            return (reduction_numel_hint + split_size * num_threads - 1) // (
                split_size * num_threads
            )
        else:
            # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
            # extend to even smaller number of outputs
            rvals_per_thread = 4  # comes from heuristics, refactor to not leak here
            xvals_per_block = 128
            xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
            if reduction_numel_hint * numel_hint < min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (num_threads)
                target_blocks = (target_blocks + xblocks - 1) // xblocks
                tmp_split_size = (
                    reduction_numel_hint + rvals_per_thread * target_blocks - 1
                ) // (rvals_per_thread * target_blocks)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(tmp_split_size - closest) < 20:
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread
            # ceil(reduction_numel_hint / (rvals_per_thread * split_size))
            return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
                rvals_per_thread * split_size
            )

    @staticmethod
    def can_fuse(
        scheduler: Scheduler,
        node1: BaseSchedulerNode,
        node2: BaseSchedulerNode,
        shared_data_score: int,
    ) -> bool:
        """
        Heuristics to prevent fusion applied to both horizontal and vertical fusions. Heuristics here should not
        be needed for correctness and tweaking them may yield additional performance.

        Fusion is rejected (and the reason recorded via WhyNoFuse) when:
          - the nodes share no data and either aggressive fusion is off or one is a reduction,
          - neither node is a foreach node and together they exceed config.max_fusion_size,
          - the scheduler estimates the fusion would increase peak memory,
          - the fusion accumulates reads above config.realize_acc_reads_size_threshold.

        See also some related heuristics that can be changed via config:
            - config.triton.tiling_prevents_pointwise_fusion
            - config.triton.tiling_prevents_reduction_fusion
            - config.aggressive_fusion (will cause this function to be called more times)
        """
        if shared_data_score == 0 and (
            not config.aggressive_fusion or node1.is_reduction() or node2.is_reduction()
        ):
            if is_metric_table_enabled("fusion_failure_due_to_indexing_mismatch"):
                common_buf_names: OrderedSet[str] = (
                    node1.read_writes.buffer_names() & node2.read_writes.buffer_names()
                )
                # Buffers are shared but the score is 0: log it as an indexing mismatch.
                if len(common_buf_names) > 0:
                    get_metric_table("fusion_failure_due_to_indexing_mismatch").add_row(
                        lambda: {
                            "pre_grad_graph_id": V.graph.graph_id,
                            "post_grad_graph_id": V.graph.post_grad_graph_id,
                            "node1_name": node1.get_name(),
                            "node2_name": node2.get_name(),
                            "node1_debug_str": write_text(node1.debug_str()),
                            "node2_debug_str": write_text(node2.debug_str()),
                            "common_buffer_names": list(common_buf_names),  # type: ignore[dict-item]
                            "failure_reason": scheduler.decide_fusion_fail_reason(
                                node1, node2, common_buf_names
                            ),
                        }
                    )

                    WhyNoFuse(node1, node2)("no shared data due to indexing mismatch")
                    return False
            WhyNoFuse(node1, node2)("no shared data")
            return False  # heuristic not needed for correctness

        if (
            not node1.is_foreach()
            and not node2.is_foreach()
            and len(node1.get_nodes()) + len(node2.get_nodes()) > config.max_fusion_size
        ):
            WhyNoFuse(node1, node2)("exceeds max fusion")
            return False  # heuristic not needed for correctness

        if scheduler.can_fusion_increase_peak_memory(node1, node2):
            WhyNoFuse(node1, node2)("Fusion will increase peak memory")
            return False

        if (
            config.realize_acc_reads_size_threshold is not None
            and scheduler.fusion_accumulate_large_reads(
                node1,
                node2,
                config.realize_acc_reads_size_threshold,
            )
        ):
            WhyNoFuse(node1, node2)("Fusion accumulate large amount of reads")
            return False

        return True

    @staticmethod
    def can_fuse_vertical(
        scheduler: Scheduler,
        node1: BaseSchedulerNode,
        node2: BaseSchedulerNode,
        shared_data_score: int,
    ) -> bool:
        """Hook for heuristics to prevent vertical (producer/consumer) fusions.

        The default allows every vertical fusion.
        """
        return True

    @staticmethod
    def can_fuse_horizontal(
        scheduler: Scheduler,
        node1: BaseSchedulerNode,
        node2: BaseSchedulerNode,
        shared_data_score: int,
    ) -> bool:
        """Hook for heuristics to prevent horizontal (consumer/consumer) fusions.

        Rejects fusions whose shared-data score is below
        ``config.score_fusion_memory_threshold``, or whose nodes the scheduler
        reports as too far apart.
        """
        if shared_data_score < config.score_fusion_memory_threshold:
            WhyNoFuse(node1, node2)("score_fusion_memory_threshold")
            return False
        if scheduler.are_long_distant_nodes(node1, node2):
            WhyNoFuse(node1, node2)(
                "Nodes are too far away. Fusing them may increase peak memory."
            )
            return False
        return True

    @staticmethod
    def score_fusion(
        scheduler: Scheduler,
        node1: BaseSchedulerNode,
        node2: BaseSchedulerNode,
    ) -> Sortable:
        """
        Assign a score (higher comes first) to the fusion of node1 and node2.
        When different fusions conflict with each other, this is the way we
        decide what order to run them in.

        Our current score is based on:
        - The type of fusion (template/reduction/etc)
        - Estimate of the saved memory operations
        - Fusions closer together in original graph order

        Returns a tuple compared lexicographically:
        (template_score, same-kind-with-shared-memory flag, memory_score, proximity_score).
        """
        memory_score = scheduler.score_fusion_memory(node1, node2)
        # Negated distance so that closer nodes produce a larger score.
        proximity_score = -max(
            abs(node1.min_order - node2.max_order),
            abs(node2.min_order - node1.max_order),
        )

        # prologue fusion always last
        if node2.is_template():
            template_score = 0
        else:
            # bool adds as 0/1: +1 when node1's template-ness matches
            # config.epilogue_fusion_first and the fusion saves memory.
            template_score = 1 + (
                (node1.is_template() == config.epilogue_fusion_first)
                and memory_score > 0
            )

        return (
            template_score,
            node1.is_reduction() == node2.is_reduction() and memory_score > 0,
            memory_score,
            proximity_score,
        )
|