attention_pool2d.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. """ Attention Pool 2D
  2. Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.
  3. Based on idea in CLIP by OpenAI, licensed Apache 2.0
  4. https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
  5. Hacked together by / Copyright 2021 Ross Wightman
  6. """
  7. from typing import Optional, Union, Tuple
  8. import torch
  9. import torch.nn as nn
  10. from .config import use_fused_attn
  11. from .helpers import to_2tuple
  12. from .pos_embed import resample_abs_pos_embed
  13. from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
  14. from .weight_init import trunc_normal_
  15. class RotAttentionPool2d(nn.Module):
  16. """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
  17. This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
  18. Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
  19. https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
  20. NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from
  21. train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
  22. """
  23. fused_attn: torch.jit.Final[bool]
  24. def __init__(
  25. self,
  26. in_features: int,
  27. out_features: Optional[int] = None,
  28. ref_feat_size: Union[int, Tuple[int, int]] = 7,
  29. embed_dim: Optional[int] = None,
  30. head_dim: Optional[int] = 64,
  31. num_heads: Optional[int] = None,
  32. qkv_bias: bool = True,
  33. qkv_separate: bool = False,
  34. pool_type: str = 'token',
  35. class_token: bool = False,
  36. drop_rate: float = 0.,
  37. device=None,
  38. dtype=None,
  39. ):
  40. dd = {'device': device, 'dtype': dtype}
  41. super().__init__()
  42. assert pool_type in ('', 'token')
  43. self.embed_dim = embed_dim = embed_dim or in_features
  44. self.in_features = in_features
  45. self.out_features = out_features or in_features
  46. ref_feat_size = to_2tuple(ref_feat_size)
  47. if num_heads is not None:
  48. assert embed_dim % num_heads == 0
  49. head_dim = embed_dim // num_heads
  50. else:
  51. assert embed_dim % head_dim == 0
  52. num_heads = embed_dim // head_dim
  53. self.num_heads = num_heads
  54. self.head_dim = head_dim
  55. self.pool_type = pool_type.lower()
  56. self.scale = self.head_dim ** -0.5
  57. self.fused_attn = use_fused_attn()
  58. if class_token:
  59. self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd))
  60. else:
  61. self.cls_token = None
  62. if qkv_separate:
  63. self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  64. self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  65. self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  66. self.qkv = None
  67. else:
  68. self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd)
  69. self.drop = nn.Dropout(drop_rate)
  70. self.proj = nn.Linear(embed_dim, self.out_features, **dd)
  71. self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size, **dd)
  72. def init_weights(self, zero_init_last: bool = False):
  73. if self.qkv is None:
  74. in_features = self.q.in_features
  75. trunc_normal_(self.q.weight, std=in_features ** -0.5)
  76. nn.init.zeros_(self.q.bias)
  77. trunc_normal_(self.k.weight, std=in_features ** -0.5)
  78. nn.init.zeros_(self.k.bias)
  79. trunc_normal_(self.v.weight, std=in_features ** -0.5)
  80. nn.init.zeros_(self.v.bias)
  81. else:
  82. in_features = self.qkv.in_features
  83. trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
  84. nn.init.zeros_(self.qkv.bias)
  85. def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
  86. # NOTE: this module is being used as a head, so need compatible reset()
  87. if pool_type is not None:
  88. assert pool_type in ('', 'token')
  89. self.pool_type = pool_type
  90. if num_classes is not None:
  91. self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
  92. self.out_features = num_classes if num_classes > 0 else self.embed_dim
  93. def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
  94. if self.pool_type == 'token':
  95. x = x[:, 0]
  96. else:
  97. # if not pooled, return spatial output without token
  98. x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
  99. return x
  100. def forward(self, x, pre_logits: bool = False):
  101. B, _, H, W = x.shape
  102. N = H * W
  103. x = x.flatten(2).transpose(1, 2)
  104. if self.cls_token is None:
  105. x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
  106. else:
  107. x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
  108. if self.qkv is None:
  109. q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  110. k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  111. v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  112. else:
  113. x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  114. q, k, v = x.unbind(0)
  115. rse, rce = self.pos_embed.get_embed((H, W))
  116. q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
  117. k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
  118. if self.fused_attn:
  119. x = nn.functional.scaled_dot_product_attention(q, k, v)
  120. else:
  121. q = q * self.scale
  122. attn = q @ k.transpose(-2, -1)
  123. attn = attn.softmax(dim=-1)
  124. x = attn @ v
  125. x = x.transpose(1, 2).reshape(B, N + 1, -1)
  126. x = self.drop(x)
  127. if pre_logits:
  128. x = self._pool(x, H, W)
  129. return x
  130. x = self.proj(x)
  131. x = self._pool(x, H, W)
  132. return x
  133. class AttentionPool2d(nn.Module):
  134. """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
  135. This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
  136. It was based on impl in CLIP by OpenAI
  137. https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
  138. NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network.
  139. """
  140. fused_attn: torch.jit.Final[bool]
  141. def __init__(
  142. self,
  143. in_features: int,
  144. feat_size: Union[int, Tuple[int, int]] = 7,
  145. out_features: Optional[int] = None,
  146. embed_dim: Optional[int] = None,
  147. head_dim: Optional[int] = 64,
  148. num_heads: Optional[int] = None,
  149. qkv_bias: bool = True,
  150. qkv_separate: bool = False,
  151. pool_type: str = 'token',
  152. class_token: bool = False,
  153. drop_rate: float = 0.,
  154. device=None,
  155. dtype=None,
  156. ):
  157. dd = {'device': device, 'dtype': dtype}
  158. super().__init__()
  159. assert pool_type in ('', 'token')
  160. self.embed_dim = embed_dim = embed_dim or in_features
  161. self.in_features = in_features
  162. self.out_features = out_features or in_features
  163. if num_heads is not None:
  164. assert embed_dim % num_heads == 0
  165. head_dim = embed_dim // num_heads
  166. else:
  167. assert embed_dim % head_dim == 0
  168. num_heads = embed_dim // head_dim
  169. self.feat_size = to_2tuple(feat_size)
  170. self.seq_len = self.feat_size[0] * self.feat_size[1]
  171. self.num_heads = num_heads
  172. self.head_dim = head_dim
  173. self.pool_type = pool_type
  174. self.scale = self.head_dim ** -0.5
  175. self.fused_attn = use_fused_attn()
  176. if class_token:
  177. self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd))
  178. else:
  179. self.cls_token = None
  180. if qkv_separate:
  181. self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  182. self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  183. self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
  184. self.qkv = None
  185. else:
  186. self.q = self.k = self.v = None
  187. self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd)
  188. self.drop = nn.Dropout(drop_rate)
  189. self.proj = nn.Linear(embed_dim, self.out_features, **dd)
  190. self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features, **dd))
  191. self.init_weights()
  192. def init_weights(self, zero_init_last: bool = False):
  193. if self.qkv is None:
  194. in_features = self.q.in_features
  195. trunc_normal_(self.q.weight, std=in_features ** -0.5)
  196. nn.init.zeros_(self.q.bias)
  197. trunc_normal_(self.k.weight, std=in_features ** -0.5)
  198. nn.init.zeros_(self.k.bias)
  199. trunc_normal_(self.v.weight, std=in_features ** -0.5)
  200. nn.init.zeros_(self.v.bias)
  201. else:
  202. in_features = self.qkv.in_features
  203. trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
  204. nn.init.zeros_(self.qkv.bias)
  205. trunc_normal_(self.pos_embed, std=in_features ** -0.5)
  206. def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
  207. # NOTE: this module is being used as a head, so need compatible reset()
  208. if pool_type is not None:
  209. assert pool_type in ('', 'token')
  210. self.pool_type = pool_type
  211. if num_classes is not None:
  212. self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
  213. self.out_features = num_classes if num_classes > 0 else self.embed_dim
  214. def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
  215. if self.pool_type == 'token':
  216. x = x[:, 0]
  217. else:
  218. # if not pooled, return spatial output without token
  219. x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
  220. return x
  221. def forward(self, x, pre_logits: bool = False):
  222. B, _, H, W = x.shape
  223. N = H * W
  224. x = x.flatten(2).transpose(1, 2)
  225. if self.cls_token is None:
  226. x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
  227. else:
  228. x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
  229. pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
  230. x = x + pos_embed
  231. if self.qkv is None:
  232. q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  233. k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  234. v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
  235. else:
  236. x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  237. q, k, v = x.unbind(0)
  238. if self.fused_attn:
  239. x = nn.functional.scaled_dot_product_attention(q, k, v)
  240. else:
  241. q = q * self.scale
  242. attn = q @ k.transpose(-2, -1)
  243. attn = attn.softmax(dim=-1)
  244. x = attn @ v
  245. x = x.transpose(1, 2).reshape(B, N + 1, -1)
  246. x = self.drop(x)
  247. if pre_logits:
  248. x = self._pool(x, H, W)
  249. return x
  250. x = self.proj(x)
  251. x = self._pool(x, H, W)
  252. return x