# pos_embed_rel.py (20 KB)
  1. """ Relative position embedding modules and functions
  2. Hacked together by / Copyright 2022 Ross Wightman
  3. """
  4. import math
  5. import os
  6. from typing import Optional, Tuple
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from .grid import ndgrid
  11. from .interpolate import RegularGridInterpolator
  12. from .mlp import Mlp
  13. from .weight_init import trunc_normal_
# Opt-in switch read once at import time: when the TIMM_USE_SCIPY_INTERP environment
# variable parses to a positive integer, resize_rel_pos_bias_table() imports scipy and
# uses its cubic interp2d path instead of RegularGridInterpolator.
_USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0
  15. def gen_relative_position_index(
  16. q_size: Tuple[int, int],
  17. k_size: Optional[Tuple[int, int]] = None,
  18. class_token: bool = False,
  19. device=None,
  20. ) -> torch.Tensor:
  21. # Adapted with significant modifications from Swin / BeiT codebases
  22. # get pair-wise relative position index for each token inside the window
  23. assert k_size is None, 'Different q & k sizes not currently supported' # FIXME
  24. coords = torch.stack(ndgrid(
  25. torch.arange(q_size[0], device=device),
  26. torch.arange(q_size[1], device=device),
  27. )).flatten(1) # 2, Wh, Ww
  28. relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
  29. relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
  30. relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0
  31. relative_coords[:, :, 1] += q_size[1] - 1
  32. relative_coords[:, :, 0] *= 2 * q_size[1] - 1
  33. num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1)
  34. # else:
  35. # # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
  36. # q_coords = torch.stack(
  37. # ndgrid(
  38. # torch.arange(q_size[0]),
  39. # torch.arange(q_size[1])
  40. # )
  41. # ).flatten(1) # 2, Wh, Ww
  42. # k_coords = torch.stack(
  43. # ndgrid(
  44. # torch.arange(k_size[0]),
  45. # torch.arange(k_size[1])
  46. # )
  47. # ).flatten(1)
  48. # relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
  49. # relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
  50. # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
  51. # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
  52. # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
  53. # relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
  54. # num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3
  55. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  56. if class_token:
  57. # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
  58. # NOTE not intended or tested with MLP log-coords
  59. relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
  60. relative_position_index[0, 0:] = num_relative_distance
  61. relative_position_index[0:, 0] = num_relative_distance + 1
  62. relative_position_index[0, 0] = num_relative_distance + 2
  63. return relative_position_index.contiguous()
  64. def resize_rel_pos_bias_table_simple(
  65. rel_pos_bias,
  66. new_window_size: Tuple[int, int],
  67. new_bias_shape: Tuple[int, ...],
  68. ):
  69. dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
  70. if rel_pos_bias.ndim == 3:
  71. # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
  72. _, dst_h, dst_w = new_bias_shape
  73. num_attn_heads, src_h, src_w = rel_pos_bias.shape
  74. assert dst_h == dst_size[0] and dst_w == dst_size[1]
  75. if src_h != dst_h or src_w != dst_w:
  76. rel_pos_bias = torch.nn.functional.interpolate(
  77. rel_pos_bias.unsqueeze(0),
  78. size=dst_size,
  79. mode="bicubic",
  80. align_corners=False,
  81. ).squeeze(0)
  82. else:
  83. assert rel_pos_bias.ndim == 2
  84. # (num_pos, num_heads) (aka flat) bias shape
  85. dst_num_pos, _ = new_bias_shape
  86. src_num_pos, num_attn_heads = rel_pos_bias.shape
  87. num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
  88. src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
  89. src_size = (src_size, src_size) # FIXME could support non-equal src if argument passed
  90. if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
  91. if num_extra_tokens:
  92. extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
  93. rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
  94. else:
  95. extra_tokens = None
  96. rel_pos_bias = torch.nn.functional.interpolate(
  97. rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])),
  98. size=dst_size,
  99. mode="bicubic",
  100. align_corners=False,
  101. ).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1)
  102. if extra_tokens is not None:
  103. rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
  104. return rel_pos_bias
  105. def resize_rel_pos_bias_table_levit(
  106. position_bias_table,
  107. new_size,
  108. interpolation: str = 'bicubic',
  109. antialias: bool = True,
  110. ):
  111. """
  112. Resample relative position bias table suggested in LeVit
  113. Adapted from: https://github.com/microsoft/Cream/blob/main/TinyViT/utils.py
  114. """
  115. L1, nH1 = position_bias_table.size()
  116. L2, nH2 = new_size
  117. assert nH1 == nH2
  118. if L1 != L2:
  119. orig_dtype = position_bias_table.dtype
  120. position_bias_table = position_bias_table.float()
  121. # bicubic interpolate relative_position_bias_table if not match
  122. S1 = int(L1 ** 0.5)
  123. S2 = int(L2 ** 0.5)
  124. relative_position_bias_table_resized = F.interpolate(
  125. position_bias_table.permute(1, 0).view(1, nH1, S1, S1),
  126. size=(S2, S2),
  127. mode=interpolation,
  128. antialias=antialias,
  129. )
  130. relative_position_bias_table_resized = relative_position_bias_table_resized.view(nH2, L2).permute(1, 0)
  131. relative_position_bias_table_resized.to(orig_dtype)
  132. return relative_position_bias_table_resized
  133. else:
  134. return position_bias_table
  135. def resize_rel_pos_bias_table(
  136. rel_pos_bias,
  137. new_window_size: Tuple[int, int],
  138. new_bias_shape: Tuple[int, ...],
  139. ):
  140. """ Resize relative position bias table using more advanced interpolation.
  141. Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc).
  142. https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351
  143. Args:
  144. rel_pos_bias:
  145. new_window_size:
  146. new_bias_shape:
  147. Returns:
  148. """
  149. if _USE_SCIPY:
  150. from scipy import interpolate
  151. dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
  152. if rel_pos_bias.ndim == 3:
  153. # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
  154. num_extra_tokens = 0
  155. _, dst_h, dst_w = new_bias_shape
  156. assert dst_h == dst_size[0] and dst_w == dst_size[1]
  157. num_attn_heads, src_h, src_w = rel_pos_bias.shape
  158. src_size = (src_h, src_w)
  159. has_flat_shape = False
  160. else:
  161. assert rel_pos_bias.ndim == 2
  162. # (num_pos, num_heads) (aka flat) bias shape
  163. dst_num_pos, _ = new_bias_shape
  164. src_num_pos, num_attn_heads = rel_pos_bias.shape
  165. num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
  166. src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
  167. src_size = (src_size, src_size)
  168. has_flat_shape = True
  169. if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
  170. # print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1]))
  171. if num_extra_tokens:
  172. extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
  173. rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
  174. else:
  175. extra_tokens = None
  176. def geometric_progression(a, r, n):
  177. return a * (1.0 - r ** n) / (1.0 - r)
  178. def _calc(src, dst):
  179. left, right = 1.01, 1.5
  180. while right - left > 1e-6:
  181. q = (left + right) / 2.0
  182. gp = geometric_progression(1, q, src // 2)
  183. if gp > dst // 2:
  184. right = q
  185. else:
  186. left = q
  187. dis = []
  188. cur = 1
  189. for i in range(src // 2):
  190. dis.append(cur)
  191. cur += q ** (i + 1)
  192. r_ids = [-_ for _ in reversed(dis)]
  193. return r_ids + [0] + dis
  194. y = _calc(src_size[0], dst_size[0])
  195. x = _calc(src_size[1], dst_size[1])
  196. yx = [torch.tensor(y), torch.tensor(x)]
  197. # print("Original positions = %s" % str(x))
  198. ty = dst_size[0] // 2.0
  199. tx = dst_size[1] // 2.0
  200. dy = torch.arange(-ty, ty + 0.1, 1.0)
  201. dx = torch.arange(-tx, tx + 0.1, 1.0)
  202. dyx = ndgrid(dy, dx)
  203. # print("Target positions = %s" % str(dx))
  204. all_rel_pos_bias = []
  205. for i in range(num_attn_heads):
  206. if has_flat_shape:
  207. z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float()
  208. else:
  209. z = rel_pos_bias[i, :, :].float()
  210. if _USE_SCIPY:
  211. # Original beit code uses scipy w/ cubic interpolation
  212. f = interpolate.interp2d(x, y, z.numpy(), kind='cubic')
  213. r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device)
  214. else:
  215. # Without scipy dependency, I've found a reasonably simple impl
  216. # that supports uneven spaced interpolation pts with 'linear' interp.
  217. # Results are comparable to scipy for model accuracy in most cases.
  218. f = RegularGridInterpolator(yx, z)
  219. r = f(dyx).contiguous().to(rel_pos_bias.device)
  220. if has_flat_shape:
  221. r = r.view(-1, 1)
  222. all_rel_pos_bias.append(r)
  223. if has_flat_shape:
  224. rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
  225. else:
  226. rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0)
  227. if extra_tokens is not None:
  228. assert has_flat_shape
  229. rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
  230. return rel_pos_bias
  231. class RelPosBias(nn.Module):
  232. """ Relative Position Bias
  233. Adapted from Swin-V1 relative position bias impl, modularized.
  234. """
  235. def __init__(
  236. self,
  237. window_size: Tuple[int, int],
  238. num_heads: int,
  239. prefix_tokens: int = 0,
  240. device=None,
  241. dtype=None,
  242. ):
  243. dd = {'device': device, 'dtype': dtype}
  244. super().__init__()
  245. assert prefix_tokens <= 1
  246. self.window_size = window_size
  247. self.window_area = window_size[0] * window_size[1]
  248. self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
  249. num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
  250. self.relative_position_bias_table = nn.Parameter(torch.empty(num_relative_distance, num_heads, **dd))
  251. self.register_buffer(
  252. "relative_position_index",
  253. gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0, device=device).view(-1),
  254. persistent=False,
  255. )
  256. self.init_weights()
  257. def init_weights(self):
  258. trunc_normal_(self.relative_position_bias_table, std=.02)
  259. def get_bias(self) -> torch.Tensor:
  260. relative_position_bias = self.relative_position_bias_table[self.relative_position_index]
  261. # win_h * win_w, win_h * win_w, num_heads
  262. relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
  263. return relative_position_bias.unsqueeze(0).contiguous()
  264. def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
  265. return attn + self.get_bias()
  266. def gen_relative_log_coords(
  267. win_size: Tuple[int, int],
  268. pretrained_win_size: Tuple[int, int] = (0, 0),
  269. mode='swin',
  270. device=None,
  271. dtype=None,
  272. ):
  273. assert mode in ('swin', 'cr')
  274. # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
  275. relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], device=device).to(torch.float32)
  276. relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], device=device).to(torch.float32)
  277. relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
  278. relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
  279. if mode == 'swin':
  280. if pretrained_win_size[0] > 0:
  281. relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
  282. relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
  283. else:
  284. relative_coords_table[:, :, 0] /= (win_size[0] - 1)
  285. relative_coords_table[:, :, 1] /= (win_size[1] - 1)
  286. relative_coords_table *= 8 # normalize to -8, 8
  287. relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
  288. 1.0 + relative_coords_table.abs()) / math.log2(8)
  289. else:
  290. # mode == 'cr'
  291. relative_coords_table = torch.sign(relative_coords_table) * torch.log(
  292. 1.0 + relative_coords_table.abs())
  293. return relative_coords_table.to(dtype)
  294. class RelPosMlp(nn.Module):
  295. """ Log-Coordinate Relative Position MLP
  296. Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)
  297. This impl covers the 'swin' implementation as well as two timm specific modes ('cr', and 'rw')
  298. """
  299. def __init__(
  300. self,
  301. window_size: Tuple[int, int],
  302. num_heads: int = 8,
  303. hidden_dim: int = 128,
  304. prefix_tokens: int = 0,
  305. mode: str = 'cr',
  306. pretrained_window_size: Tuple[int, int] = (0, 0),
  307. device=None,
  308. dtype=None,
  309. ):
  310. dd = {'device': device, 'dtype': dtype}
  311. super().__init__()
  312. self.window_size = window_size
  313. self.window_area = self.window_size[0] * self.window_size[1]
  314. self.prefix_tokens = prefix_tokens
  315. self.num_heads = num_heads
  316. self.bias_shape = (self.window_area,) * 2 + (num_heads,)
  317. if mode == 'swin':
  318. self.bias_act = nn.Sigmoid()
  319. self.bias_gain = 16
  320. mlp_bias = (True, False)
  321. else:
  322. self.bias_act = nn.Identity()
  323. self.bias_gain = None
  324. mlp_bias = True
  325. self.mlp = Mlp(
  326. 2, # x, y
  327. hidden_features=hidden_dim,
  328. out_features=num_heads,
  329. act_layer=nn.ReLU,
  330. bias=mlp_bias,
  331. drop=(0.125, 0.),
  332. **dd,
  333. )
  334. self.register_buffer(
  335. "relative_position_index",
  336. gen_relative_position_index(window_size, device=device).view(-1),
  337. persistent=False,
  338. )
  339. # get relative_coords_table
  340. self.register_buffer(
  341. "rel_coords_log",
  342. gen_relative_log_coords(window_size, pretrained_window_size, mode=mode, **dd),
  343. persistent=False,
  344. )
  345. def get_bias(self) -> torch.Tensor:
  346. relative_position_bias = self.mlp(self.rel_coords_log)
  347. if self.relative_position_index is not None:
  348. relative_position_bias = relative_position_bias.view(-1, self.num_heads)[self.relative_position_index]
  349. relative_position_bias = relative_position_bias.view(self.bias_shape)
  350. relative_position_bias = relative_position_bias.permute(2, 0, 1)
  351. relative_position_bias = self.bias_act(relative_position_bias)
  352. if self.bias_gain is not None:
  353. relative_position_bias = self.bias_gain * relative_position_bias
  354. if self.prefix_tokens:
  355. relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
  356. return relative_position_bias.unsqueeze(0).contiguous()
  357. def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
  358. return attn + self.get_bias()
  359. def generate_lookup_tensor(
  360. length: int,
  361. max_relative_position: Optional[int] = None,
  362. device=None,
  363. dtype=None,
  364. ):
  365. """Generate a one_hot lookup tensor to reindex embeddings along one dimension.
  366. Args:
  367. length: the length to reindex to.
  368. max_relative_position: the maximum relative position to consider.
  369. Relative position embeddings for distances above this threshold
  370. are zeroed out.
  371. Returns:
  372. a lookup Tensor of size [length, length, vocab_size] that satisfies
  373. ret[n,m,v] = 1{m - n + max_relative_position = v}.
  374. """
  375. if max_relative_position is None:
  376. max_relative_position = length - 1
  377. # Return the cached lookup tensor, otherwise compute it and cache it.
  378. vocab_size = 2 * max_relative_position + 1
  379. ret = torch.zeros(length, length, vocab_size, device=device, dtype=dtype)
  380. for i in range(length):
  381. for x in range(length):
  382. v = x - i + max_relative_position
  383. if abs(x - i) > max_relative_position:
  384. continue
  385. ret[i, x, v] = 1
  386. return ret
  387. def reindex_2d_einsum_lookup(
  388. relative_position_tensor,
  389. height: int,
  390. width: int,
  391. height_lookup: torch.Tensor,
  392. width_lookup: torch.Tensor,
  393. ) -> torch.Tensor:
  394. """Reindex 2d relative position bias with 2 independent einsum lookups.
  395. Adapted from:
  396. https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
  397. Args:
  398. relative_position_tensor: tensor of shape
  399. [..., vocab_height, vocab_width, ...].
  400. height: height to reindex to.
  401. width: width to reindex to.
  402. height_lookup: one-hot height lookup
  403. width_lookup: one-hot width lookup
  404. Returns:
  405. reindexed_tensor: a Tensor of shape
  406. [..., height * width, height * width, ...]
  407. """
  408. reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
  409. reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
  410. area = height * width
  411. return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
  412. class RelPosBiasTf(nn.Module):
  413. """ Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
  414. Adapted from:
  415. https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
  416. """
  417. def __init__(
  418. self,
  419. window_size: Tuple[int, int],
  420. num_heads: int,
  421. prefix_tokens: int = 0,
  422. device=None,
  423. dtype=None,
  424. ):
  425. dd = {'device': device, 'dtype': dtype}
  426. super().__init__()
  427. assert prefix_tokens <= 1
  428. self.window_size = window_size
  429. self.window_area = window_size[0] * window_size[1]
  430. self.num_heads = num_heads
  431. vocab_height = 2 * window_size[0] - 1
  432. vocab_width = 2 * window_size[1] - 1
  433. self.bias_shape = (self.num_heads, vocab_height, vocab_width)
  434. self.relative_position_bias_table = nn.Parameter(torch.empty(self.bias_shape, **dd))
  435. self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0], **dd), persistent=False)
  436. self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1], **dd), persistent=False)
  437. self.init_weights()
  438. def init_weights(self):
  439. nn.init.normal_(self.relative_position_bias_table, std=.02)
  440. def get_bias(self) -> torch.Tensor:
  441. # FIXME change to not use one-hot/einsum?
  442. return reindex_2d_einsum_lookup(
  443. self.relative_position_bias_table,
  444. self.window_size[0],
  445. self.window_size[1],
  446. self.height_lookup,
  447. self.width_lookup
  448. )
  449. def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
  450. return attn + self.get_bias()