_impl.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. # flake8: noqa: B950
  2. import math
  3. from typing import Callable, Optional, TypeVar
  4. from typing_extensions import ParamSpec
  5. import torch
  6. from torch.onnx.ops import _dtype_mappings
  7. # Use ParamSpec for better type preservation instead of bound Callable TypeVar
  8. _P = ParamSpec("_P")
  9. _R = TypeVar("_R")
  10. # ONNX to ATen decomp table
  11. ONNX_ATEN_DECOMP_TABLE: dict[torch._ops.OpOverload, Callable] = {}
  12. _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS = frozenset(
  13. {
  14. 1, # FLOAT
  15. 10, # FLOAT16
  16. 11, # DOUBLE
  17. 16, # BFLOAT16
  18. }
  19. )
  20. def _onnx_op(
  21. op_type: str, opset_version: int
  22. ) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
  23. """Decorator to register an ONNX operator with a custom implementation."""
  24. def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
  25. overload = f"opset{opset_version}"
  26. torch_op = torch.library.custom_op(
  27. f"onnx::{op_type}.{overload}", mutates_args=()
  28. )(func)
  29. ONNX_ATEN_DECOMP_TABLE[getattr(getattr(torch.ops.onnx, op_type), overload)] = (
  30. func # type: ignore[assignment]
  31. )
  32. # Use the same implementation for the fake implementation
  33. # This is possible because we use pure aten ops to implement ONNX ops
  34. torch_op.register_fake(func)
  35. return torch_op # type: ignore[return-value]
  36. return decorator
  37. @_onnx_op("RotaryEmbedding", 23)
  38. def rotary_embedding_23(
  39. x: torch.Tensor,
  40. cos_cache: torch.Tensor,
  41. sin_cache: torch.Tensor,
  42. position_ids: Optional[torch.Tensor] = None,
  43. *,
  44. interleaved: bool = False,
  45. num_heads: int = 0,
  46. rotary_embedding_dim: int = 0,
  47. ) -> torch.Tensor:
  48. """RotaryEmbedding-23 https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html#rotaryembedding-23"""
  49. # x has shape (batch_size, num_heads, sequence_length, head_size)
  50. # or (batch_size, sequence_length, hidden_size)
  51. input_shape = x.shape
  52. input_rank = len(input_shape)
  53. batch_size = input_shape[0]
  54. sequence_length = input_shape[-2]
  55. # Validate position_ids and caches match x
  56. if position_ids is not None:
  57. torch._check(
  58. position_ids.dim() == 2,
  59. lambda: f"position_ids must be 2D when provided. Received shape {position_ids.shape}",
  60. )
  61. torch._check(
  62. position_ids.shape[0] == batch_size,
  63. lambda: f"position_ids first dim (batch) must match x.shape[0] ({batch_size}). Received {position_ids.shape[0]}",
  64. )
  65. torch._check(
  66. position_ids.shape[1] == sequence_length,
  67. lambda: f"position_ids second dim (sequence) must match x.shape[-2] ({sequence_length}). Received {position_ids.shape[1]}",
  68. )
  69. torch._check(
  70. cos_cache.dim() == 2 and sin_cache.dim() == 2,
  71. lambda: "cos_cache/sin_cache must be 2D when position_ids is provided. "
  72. f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
  73. )
  74. else:
  75. torch._check(
  76. cos_cache.dim() == 3 and sin_cache.dim() == 3,
  77. lambda: "cos_cache/sin_cache must be 3D when position_ids is not provided. "
  78. f"Received cos_cache shape {cos_cache.shape}, sin_cache shape {sin_cache.shape}",
  79. )
  80. # First ensure x has shape [batch_size, num_heads, seq_len, head_size]
  81. # So that the rotation logic can be shared with reshaped 3D inputs
  82. if input_rank == 4:
  83. # Reshape from (batch_size, num_heads, seq_len, head_size)
  84. # to [batch_size, seq_len, num_heads, head_size]
  85. x = torch.permute(x, (0, 2, 1, 3))
  86. elif input_rank == 3:
  87. torch._check(
  88. num_heads != 0,
  89. lambda: f"num_heads must be provided for 3D inputs. Received input tensor with shape {input_shape}",
  90. )
  91. hidden_size = input_shape[2]
  92. head_size = hidden_size // num_heads
  93. new_shape = [batch_size, sequence_length, num_heads, head_size]
  94. x = torch.reshape(x, new_shape)
  95. torch._check(len(x.shape) == 4, lambda: "x should be a 4D tensor by now")
  96. head_size = x.shape[3]
  97. # Fully or partially perform rotation on x based on rotary_embedding_dim attribute
  98. if rotary_embedding_dim == 0:
  99. # If rotary_embedding_dim not provided, perform full rotation by using head_size
  100. rotary_embedding_dim = head_size
  101. x_rotate = x[:, :, :, :rotary_embedding_dim]
  102. x_not_rotate = x[:, :, :, rotary_embedding_dim:]
  103. rotary_embedding_dim_half = rotary_embedding_dim // 2
  104. # Retrieve sin and cos caches using position ids
  105. if position_ids is not None:
  106. cos = cos_cache[
  107. position_ids
  108. ] # Shape: [batch_size, sequence_length, head_size/2]
  109. sin = sin_cache[
  110. position_ids
  111. ] # Shape: [batch_size, sequence_length, head_size/2]
  112. else:
  113. cos = cos_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
  114. sin = sin_cache # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
  115. torch._check(
  116. cos.shape[0] == batch_size and cos.shape[1] == sequence_length,
  117. lambda: f"cos has shape {cos.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
  118. )
  119. torch._check(
  120. sin.shape[0] == batch_size and sin.shape[1] == sequence_length,
  121. lambda: f"sin has shape {sin.shape} but expected (batch={batch_size}, seq={sequence_length}, ...)",
  122. )
  123. torch._check(
  124. cos.shape[-1] == rotary_embedding_dim_half,
  125. lambda: f"Last dimension of cos cache ({cos.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
  126. )
  127. torch._check(
  128. sin.shape[-1] == rotary_embedding_dim_half,
  129. lambda: f"Last dimension of sin cache ({sin.shape[-1]}) should match rotary_embedding_dim/2 ({rotary_embedding_dim_half}).",
  130. )
  131. cos = torch.unsqueeze(
  132. cos, 2
  133. ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
  134. sin = torch.unsqueeze(
  135. sin, 2
  136. ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
  137. # Either divide the x in halves or interleave (based on interleaved attribute)
  138. if interleaved:
  139. x1 = x_rotate[:, :, :, 0::2]
  140. x2 = x_rotate[:, :, :, 1::2]
  141. else:
  142. x1, x2 = torch.chunk(x_rotate, 2, dim=-1)
  143. # Calculate real and imaginary values
  144. real = cos * x1 - sin * x2
  145. imag = sin * x1 + cos * x2
  146. # Inserted rotated embeddings back to the original x
  147. if interleaved:
  148. # x_rotate[:, :, :, 0::2] = real
  149. # x_rotate[:, :, :, 1::2] = imag
  150. real = torch.unsqueeze(real, -1)
  151. imag = torch.unsqueeze(imag, -1)
  152. x_rotate_concat = torch.cat((real, imag), dim=-1)
  153. x_rotate = torch.reshape(x_rotate_concat, x_rotate.shape)
  154. else:
  155. x_rotate = torch.cat((real, imag), dim=-1)
  156. output = torch.cat((x_rotate, x_not_rotate), dim=-1)
  157. if input_rank == 3:
  158. return torch.reshape(output, input_shape)
  159. # Return the dimensions to the original order
  160. return torch.permute(output, (0, 2, 1, 3))
  161. def _get_scale_factor(scale: Optional[float], head_size: int) -> float:
  162. """Get the scale factor for attention computation."""
  163. return scale if scale is not None else (1.0 / math.sqrt(head_size))
  164. def _reshape_3d_to_4d(
  165. tensor: torch.Tensor, batch_size: int, num_heads: int
  166. ) -> torch.Tensor:
  167. """Reshape 3D tensor to 4D for multi-head attention."""
  168. sequence_length, hidden_size = tensor.shape[1], tensor.shape[2]
  169. head_size = hidden_size // num_heads
  170. return (
  171. tensor.view(batch_size, sequence_length, num_heads, head_size)
  172. .transpose(1, 2)
  173. .contiguous()
  174. )
  175. def _get_qk_output_for_aten_spda(
  176. Q: torch.Tensor,
  177. K: torch.Tensor,
  178. current_q_num_heads: int,
  179. current_kv_num_heads: int,
  180. scale: Optional[float],
  181. qk_matmul_output_mode: int,
  182. ) -> torch.Tensor:
  183. """Get QK output tensor based on the specified mode."""
  184. if qk_matmul_output_mode == 0:
  185. return _compute_qk_output_for_mode_0(
  186. Q, K, current_q_num_heads, current_kv_num_heads, scale
  187. )
  188. else:
  189. # For other modes, return a zero tensor with correct shape
  190. return torch.zeros_like(torch.matmul(Q, K.transpose(-2, -1)))
  191. def _validate_gqa_configuration(
  192. current_q_num_heads: int, current_kv_num_heads: int
  193. ) -> None:
  194. """Validate Group Query Attention configuration."""
  195. torch._check(
  196. current_q_num_heads % current_kv_num_heads == 0,
  197. lambda: f"q_num_heads ({current_q_num_heads}) must be divisible by kv_num_heads ({current_kv_num_heads}) for GQA",
  198. )
  199. def _compute_qk_output_for_mode_0(
  200. Q: torch.Tensor,
  201. K: torch.Tensor,
  202. current_q_num_heads: int,
  203. current_kv_num_heads: int,
  204. scale: Optional[float],
  205. ) -> torch.Tensor:
  206. """Helper function to compute QK output for qk_matmul_output_mode == 0."""
  207. # Handle GQA manually for QK output
  208. K_for_qk = K
  209. if current_q_num_heads != current_kv_num_heads:
  210. repeat_factor = current_q_num_heads // current_kv_num_heads
  211. K_for_qk = K.repeat_interleave(repeat_factor, dim=1)
  212. scale_factor = _get_scale_factor(scale, Q.shape[3])
  213. # Scale both Q and K by sqrt(scale_factor) for numerical stability
  214. sqrt_scale = math.sqrt(scale_factor)
  215. Q_scaled = Q * sqrt_scale
  216. K_scaled = K_for_qk * sqrt_scale
  217. return torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))
  218. @_onnx_op("Attention", 23)
  219. def attention_23(
  220. Q: torch.Tensor,
  221. K: torch.Tensor,
  222. V: torch.Tensor,
  223. attn_mask: Optional[torch.Tensor] = None,
  224. past_key: Optional[torch.Tensor] = None,
  225. past_value: Optional[torch.Tensor] = None,
  226. *,
  227. is_causal: bool = False,
  228. kv_num_heads: int = 0,
  229. q_num_heads: int = 0,
  230. qk_matmul_output_mode: int = 0,
  231. scale: Optional[float] = None,
  232. softcap: float = 0.0,
  233. softmax_precision: Optional[int] = None,
  234. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  235. """Attention-23 https://onnx.ai/onnx/operators/onnx__Attention.html#attention-23"""
  236. num_head_dim, sequence_dim, head_dim = 1, 2, 3
  237. # Store original input shape to determine output shape
  238. input_shape_len = len(Q.shape)
  239. batch_size = Q.shape[0]
  240. # Reshape 3D inputs to 4D format
  241. if len(Q.shape) == 3:
  242. torch._check(
  243. q_num_heads != 0 and kv_num_heads != 0,
  244. lambda: "q_num_heads and kv_num_heads must be provided for 3D inputs",
  245. )
  246. q_sequence_length = Q.shape[1]
  247. Q = _reshape_3d_to_4d(Q, batch_size, q_num_heads)
  248. K = _reshape_3d_to_4d(K, batch_size, kv_num_heads)
  249. V = _reshape_3d_to_4d(V, batch_size, kv_num_heads)
  250. torch._check(
  251. len(Q.shape) == 4 and len(K.shape) == 4 and len(V.shape) == 4,
  252. lambda: "Q, K, and V should be 4D tensors by now",
  253. )
  254. # Calculate scale factor if not provided
  255. q_head_size = Q.shape[head_dim]
  256. scale = _get_scale_factor(scale, q_head_size)
  257. # Handle past key/value caches
  258. present_key = (
  259. torch.cat([past_key, K], dim=sequence_dim)
  260. if past_key is not None
  261. else K.clone()
  262. )
  263. present_value = (
  264. torch.cat([past_value, V], dim=sequence_dim)
  265. if past_value is not None
  266. else V.clone()
  267. )
  268. # Update K and V to include past states
  269. K, V = present_key, present_value
  270. # Get current dimensions
  271. current_q_num_heads = Q.shape[num_head_dim]
  272. current_kv_num_heads = K.shape[num_head_dim]
  273. q_sequence_length = Q.shape[sequence_dim]
  274. kv_sequence_length = K.shape[sequence_dim]
  275. # Check if we can use the optimized scaled_dot_product_attention (most optimized)
  276. can_use_sdpa = (
  277. softcap == 0.0 # No softcap
  278. and qk_matmul_output_mode == 0 # Default QK output mode
  279. and softmax_precision is None # No custom softmax precision
  280. and (attn_mask is None or attn_mask.dtype == torch.bool)
  281. )
  282. _validate_gqa_configuration(current_q_num_heads, current_kv_num_heads)
  283. if can_use_sdpa:
  284. # Use PyTorch's optimized scaled_dot_product_attention
  285. # Prepare attention mask for SDPA
  286. sdpa_attn_mask = None
  287. if attn_mask is not None:
  288. # Convert boolean mask: True means participate, SDPA expects True to mask out
  289. sdpa_attn_mask = ~attn_mask if attn_mask.dtype == torch.bool else attn_mask
  290. output = torch.nn.functional.scaled_dot_product_attention(
  291. Q,
  292. K,
  293. V,
  294. attn_mask=sdpa_attn_mask,
  295. dropout_p=0.0,
  296. is_causal=is_causal,
  297. scale=scale,
  298. enable_gqa=bool(
  299. current_q_num_heads != current_kv_num_heads
  300. ), # Ensure enable_gqa is not SymBool
  301. )
  302. qk_output = _get_qk_output_for_aten_spda(
  303. Q,
  304. K,
  305. current_q_num_heads,
  306. current_kv_num_heads,
  307. scale,
  308. qk_matmul_output_mode,
  309. )
  310. else:
  311. # Fallback to manual implementation for complex cases
  312. # Handle Group Query Attention (GQA) and Multi-Query Attention (MQA)
  313. if current_q_num_heads != current_kv_num_heads:
  314. repeat_factor = current_q_num_heads // current_kv_num_heads
  315. K = K.repeat_interleave(repeat_factor, dim=num_head_dim)
  316. V = V.repeat_interleave(repeat_factor, dim=num_head_dim)
  317. # Create attention bias
  318. attn_bias = torch.zeros(
  319. q_sequence_length, kv_sequence_length, dtype=Q.dtype, device=Q.device
  320. )
  321. # Apply causal masking
  322. if is_causal:
  323. torch._check(
  324. attn_mask is None, lambda: "Cannot use both is_causal and attn_mask"
  325. )
  326. causal_mask = torch.tril(
  327. torch.ones(
  328. q_sequence_length,
  329. kv_sequence_length,
  330. dtype=torch.bool,
  331. device=Q.device,
  332. )
  333. )
  334. attn_bias = attn_bias.masked_fill(~causal_mask, float("-inf"))
  335. # Apply attention mask
  336. if attn_mask is not None:
  337. if attn_mask.dtype == torch.bool:
  338. # Boolean mask: True means participate in attention
  339. attn_bias = attn_bias.masked_fill(~attn_mask, float("-inf"))
  340. else:
  341. # Float mask: added to attention scores
  342. attn_bias = attn_bias + attn_mask
  343. # Apply scaling factor
  344. scale_factor = _get_scale_factor(scale, Q.shape[3])
  345. # Scale both Q and K by sqrt(scale_factor) for numerical stability
  346. sqrt_scale = math.sqrt(scale_factor)
  347. Q_scaled = Q * sqrt_scale
  348. K_scaled = K * sqrt_scale
  349. # Compute Q @ K^T
  350. qk_matmul_output = torch.matmul(Q_scaled, K_scaled.transpose(-2, -1))
  351. # Initialize QK output based on mode
  352. qk_output = qk_matmul_output # Default case for mode 0
  353. # Add attention bias
  354. qk_with_bias = qk_matmul_output + attn_bias
  355. if qk_matmul_output_mode == 1:
  356. qk_output = qk_with_bias
  357. # Apply softcap if provided
  358. if softcap > 0.0:
  359. qk_with_bias = softcap * torch.tanh(qk_with_bias / softcap)
  360. if qk_matmul_output_mode == 2:
  361. qk_output = qk_with_bias
  362. # Apply softmax with optional precision casting
  363. if softmax_precision is not None:
  364. # Map ONNX data type to torch dtype
  365. if softmax_precision in _ATTENTION_23_ALLOWED_INTERMEDIATE_PRECISIONS:
  366. original_dtype = qk_with_bias.dtype
  367. qk_with_bias = qk_with_bias.to(
  368. _dtype_mappings.ONNX_DTYPE_TO_TORCH_DTYPE[softmax_precision]
  369. )
  370. qk_softmax = torch.softmax(qk_with_bias, dim=-1)
  371. qk_softmax = qk_softmax.to(original_dtype)
  372. else:
  373. qk_softmax = torch.softmax(qk_with_bias, dim=-1)
  374. else:
  375. qk_softmax = torch.softmax(qk_with_bias, dim=-1)
  376. if qk_matmul_output_mode == 3:
  377. qk_output = qk_softmax
  378. # Compute attention output
  379. output = torch.matmul(qk_softmax, V)
  380. # Reshape output back to 3D if input was 3D
  381. if input_shape_len == 3:
  382. # output: (batch_size, q_num_heads, q_sequence_length, v_head_size) -> (batch_size, q_sequence_length, hidden_size)
  383. output = (
  384. output.transpose(1, 2).contiguous().view(batch_size, q_sequence_length, -1)
  385. )
  386. return output, present_key, present_value, qk_output