attention.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from typing import Final, Optional, Type
  2. import torch
  3. from torch import nn as nn
  4. from torch.nn import functional as F
  5. from ._fx import register_notrace_function
  6. from .config import use_fused_attn
  7. from .pos_embed_sincos import apply_rot_embed_cat
  8. @torch.fx.wrap
  9. @register_notrace_function
  10. def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
  11. return scores if attn_mask is None else scores + attn_mask
  12. class Attention(nn.Module):
  13. """Standard Multi-head Self Attention module with QKV projection.
  14. This module implements the standard multi-head attention mechanism used in transformers.
  15. It supports both the fused attention implementation (scaled_dot_product_attention) for
  16. efficiency when available, and a manual implementation otherwise. The module includes
  17. options for QK normalization, attention dropout, and projection dropout.
  18. """
  19. fused_attn: Final[bool]
  20. def __init__(
  21. self,
  22. dim: int,
  23. num_heads: int = 8,
  24. qkv_bias: bool = False,
  25. qk_norm: bool = False,
  26. scale_norm: bool = False,
  27. proj_bias: bool = True,
  28. attn_drop: float = 0.,
  29. proj_drop: float = 0.,
  30. norm_layer: Optional[Type[nn.Module]] = None,
  31. device=None,
  32. dtype=None
  33. ) -> None:
  34. """Initialize the Attention module.
  35. Args:
  36. dim: Input dimension of the token embeddings
  37. num_heads: Number of attention heads
  38. qkv_bias: Whether to use bias in the query, key, value projections
  39. qk_norm: Whether to apply normalization to query and key vectors
  40. proj_bias: Whether to use bias in the output projection
  41. attn_drop: Dropout rate applied to the attention weights
  42. proj_drop: Dropout rate applied after the output projection
  43. norm_layer: Normalization layer constructor for QK normalization if enabled
  44. """
  45. super().__init__()
  46. dd = {'device': device, 'dtype': dtype}
  47. assert dim % num_heads == 0, 'dim should be divisible by num_heads'
  48. if qk_norm or scale_norm:
  49. assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
  50. self.num_heads = num_heads
  51. self.head_dim = dim // num_heads
  52. self.scale = self.head_dim ** -0.5
  53. self.fused_attn = use_fused_attn()
  54. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  55. self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  56. self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  57. self.attn_drop = nn.Dropout(attn_drop)
  58. self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
  59. self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
  60. self.proj_drop = nn.Dropout(proj_drop)
  61. def forward(
  62. self,
  63. x: torch.Tensor,
  64. attn_mask: Optional[torch.Tensor] = None,
  65. ) -> torch.Tensor:
  66. B, N, C = x.shape
  67. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  68. q, k, v = qkv.unbind(0)
  69. q, k = self.q_norm(q), self.k_norm(k)
  70. if self.fused_attn:
  71. x = F.scaled_dot_product_attention(
  72. q, k, v,
  73. attn_mask=attn_mask,
  74. dropout_p=self.attn_drop.p if self.training else 0.,
  75. )
  76. else:
  77. q = q * self.scale
  78. attn = q @ k.transpose(-2, -1)
  79. attn = maybe_add_mask(attn, attn_mask)
  80. attn = attn.softmax(dim=-1)
  81. attn = self.attn_drop(attn)
  82. x = attn @ v
  83. x = x.transpose(1, 2).reshape(B, N, C)
  84. x = self.norm(x)
  85. x = self.proj(x)
  86. x = self.proj_drop(x)
  87. return x
  88. class AttentionRope(nn.Module):
  89. """ A Self Attention module with ROPE support.
  90. Includes options for:
  91. * QK normalization option
  92. * Attention output (scale) normalization
  93. * Fused or unfused QKV projection support
  94. """
  95. fused_attn: torch.jit.Final[bool]
  96. def __init__(
  97. self,
  98. dim: int,
  99. num_heads: int = 8,
  100. qkv_bias: bool = True,
  101. qkv_fused: bool = True,
  102. num_prefix_tokens: int = 1,
  103. attn_drop: float = 0.,
  104. proj_drop: float = 0.,
  105. attn_head_dim: Optional[int] = None,
  106. norm_layer: Type[nn.Module] = None,
  107. qk_norm: bool = False,
  108. scale_norm: bool = False,
  109. proj_bias: bool = True,
  110. rotate_half: bool = False,
  111. device=None,
  112. dtype=None,
  113. ):
  114. """Initialize the Attention module.
  115. Args:
  116. dim: Input dimension of the token embeddings
  117. num_heads: Number of attention heads
  118. qkv_bias: Whether to add a bias term to the query, key, and value projections
  119. num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
  120. should not have position embeddings applied
  121. attn_drop: Dropout rate for attention weights
  122. proj_drop: Dropout rate for the output projection
  123. attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
  124. norm_layer: Normalization layer constructor to use for QK and scale normalization
  125. qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
  126. scale_norm: Enable normalization (scaling) of attention output with norm_layer
  127. rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
  128. """
  129. super().__init__()
  130. dd = {'device': device, 'dtype': dtype}
  131. if scale_norm or qk_norm:
  132. assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
  133. self.num_heads = num_heads
  134. head_dim = dim // num_heads
  135. if attn_head_dim is not None:
  136. head_dim = attn_head_dim
  137. attn_dim = head_dim * self.num_heads
  138. self.scale = head_dim ** -0.5
  139. self.num_prefix_tokens = num_prefix_tokens
  140. self.fused_attn = use_fused_attn()
  141. self.rotate_half = rotate_half
  142. if qkv_fused:
  143. self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
  144. self.q_proj = self.k_proj = self.v_proj = None
  145. else:
  146. self.qkv = None
  147. self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
  148. self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
  149. self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
  150. self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
  151. self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
  152. self.attn_drop = nn.Dropout(attn_drop)
  153. self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
  154. self.proj = nn.Linear(attn_dim, dim, bias=proj_bias, **dd)
  155. self.proj_drop = nn.Dropout(proj_drop)
  156. def forward(
  157. self,
  158. x,
  159. rope: Optional[torch.Tensor] = None,
  160. attn_mask: Optional[torch.Tensor] = None,
  161. ):
  162. """Forward pass for the attention module.
  163. Args:
  164. x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
  165. rope: Rotary position embeddings tensor for position-aware attention
  166. attn_mask: Optional attention mask to apply during attention computation
  167. Returns:
  168. Tensor of shape (batch_size, sequence_length, embedding_dim)
  169. """
  170. B, N, C = x.shape
  171. if self.qkv is not None:
  172. qkv = self.qkv(x)
  173. qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  174. q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
  175. else:
  176. q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
  177. k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
  178. v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
  179. q, k = self.q_norm(q), self.k_norm(k)
  180. if rope is not None:
  181. npt = self.num_prefix_tokens
  182. half = getattr(self, 'rotate_half', False)
  183. q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
  184. k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
  185. if self.fused_attn:
  186. x = F.scaled_dot_product_attention(
  187. q, k, v,
  188. attn_mask=attn_mask,
  189. dropout_p=self.attn_drop.p if self.training else 0.,
  190. )
  191. else:
  192. q = q * self.scale
  193. attn = (q @ k.transpose(-2, -1))
  194. attn = maybe_add_mask(attn, attn_mask)
  195. attn = attn.softmax(dim=-1)
  196. attn = self.attn_drop(attn)
  197. x = attn @ v
  198. x = x.transpose(1, 2).reshape(B, N, C)
  199. x = self.norm(x)
  200. x = self.proj(x)
  201. x = self.proj_drop(x)
  202. return x