| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155 |
- """ Multi-Scale Vision Transformer v2
- @inproceedings{li2021improved,
- title={MViTv2: Improved multiscale vision transformers for classification and detection},
- author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph},
- booktitle={CVPR},
- year={2022}
- }
- Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit
- Original copyright below.
- Modifications and timm support by / Copyright 2022, Ross Wightman
- """
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
- import operator
- from collections import OrderedDict
- from dataclasses import dataclass
- from functools import partial, reduce
- from typing import Union, List, Tuple, Optional, Any, Type
- import torch
- from torch import nn
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import Mlp, DropPath, calculate_drop_path_rates, trunc_normal_tf_, get_norm_layer, to_2tuple
- from ._builder import build_model_with_cfg
- from ._features import feature_take_indices
- from ._features_fx import register_notrace_function
- from ._manipulate import checkpoint
- from ._registry import register_model, generate_default_cfgs
- __all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this
@dataclass
class MultiScaleVitCfg:
    """Architecture configuration for MultiScaleVit.

    ``embed_dim`` and ``num_heads`` may be passed as a single int; in that case
    ``__post_init__`` expands them into one value per stage, doubling at every
    stage. When ``stride_kv`` is left as ``None`` and ``stride_kv_adaptive`` is
    set, the per-stage key/value strides are derived from ``stride_q``.
    """
    depths: Tuple[int, ...] = (2, 3, 16, 3)
    embed_dim: Union[int, Tuple[int, ...]] = 96
    num_heads: Union[int, Tuple[int, ...]] = 1
    mlp_ratio: float = 4.
    pool_first: bool = False
    expand_attn: bool = True
    qkv_bias: bool = True
    use_cls_token: bool = False
    use_abs_pos: bool = False
    residual_pooling: bool = True
    mode: str = 'conv'
    kernel_qkv: Tuple[int, int] = (3, 3)
    stride_q: Optional[Tuple[Tuple[int, int]]] = ((1, 1), (2, 2), (2, 2), (2, 2))
    stride_kv: Optional[Tuple[Tuple[int, int]]] = None
    stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4)
    patch_kernel: Tuple[int, int] = (7, 7)
    patch_stride: Tuple[int, int] = (4, 4)
    patch_padding: Tuple[int, int] = (3, 3)
    pool_type: str = 'max'
    rel_pos_type: str = 'spatial'
    act_layer: Union[str, Tuple[str, str]] = 'gelu'
    norm_layer: Union[str, Tuple[str, str]] = 'layernorm'
    norm_eps: float = 1e-6

    def __post_init__(self):
        """Expand scalar settings to per-stage tuples and derive ``stride_kv``."""
        stage_count = len(self.depths)

        # Scalar width -> per-stage widths, doubling at each stage.
        if not isinstance(self.embed_dim, (tuple, list)):
            base_dim = self.embed_dim
            self.embed_dim = tuple(base_dim * 2 ** stage for stage in range(stage_count))
        assert len(self.embed_dim) == stage_count

        # Scalar head count -> per-stage head counts, doubling at each stage.
        if not isinstance(self.num_heads, (tuple, list)):
            base_heads = self.num_heads
            self.num_heads = tuple(base_heads * 2 ** stage for stage in range(stage_count))
        assert len(self.num_heads) == stage_count

        # Only derive the kv strides when the adaptive seed is given and no
        # explicit per-stage strides were supplied.
        if self.stride_kv_adaptive is None or self.stride_kv is not None:
            return

        running = self.stride_kv_adaptive
        per_stage = []
        for stage in range(stage_count):
            q_stride = self.stride_q[stage]
            if min(q_stride) > 1:
                # Shrink the kv stride by the q stride of this stage, never below 1.
                running = [max(kv // q_stride[axis], 1) for axis, kv in enumerate(running)]
            per_stage.append(tuple(running))
        self.stride_kv = tuple(per_stage)
def prod(iterable):
    """Return the product of all items in ``iterable`` (1 for an empty iterable)."""
    result = 1
    for item in iterable:
        result = operator.mul(result, item)
    return result
class PatchEmbed(nn.Module):
    """
    Patch embedding: a single strided Conv2d followed by flattening to tokens.
    """

    def __init__(
            self,
            dim_in: int = 3,
            dim_out: int = 768,
            kernel: Tuple[int, int] = (7, 7),
            stride: Tuple[int, int] = (4, 4),
            padding: Tuple[int, int] = (3, 3),
            device=None,
            dtype=None,
    ):
        super().__init__()
        self.proj = nn.Conv2d(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            device=device,
            dtype=dtype,
        )

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        """Project ``x`` and return ``(tokens, spatial size of the projected map)``."""
        x = self.proj(x)
        feat_size = x.shape[-2:]
        # B C H W -> B HW C
        tokens = x.flatten(2).transpose(1, 2)
        return tokens, feat_size
@register_notrace_function
def reshape_pre_pool(
        x,
        feat_size: List[int],
        has_cls_token: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Turn a 4-d token tensor into a 2D feature map ready for a pooling op.

    When ``has_cls_token`` is set, the first entry along dim 2 is split off and
    returned separately; otherwise ``None`` is returned in its place. The
    remaining tokens are reshaped to ``(-1, H, W, C)`` with ``H, W`` taken from
    ``feat_size`` and permuted to channels-first.
    """
    if has_cls_token:
        cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
    else:
        cls_tok = None
    height = feat_size[0]
    width = feat_size[1]
    channels = x.shape[-1]
    x = x.reshape(-1, height, width, channels)
    # (N, H, W, C) -> (N, C, H, W)
    x = x.permute(0, 3, 1, 2).contiguous()
    return x, cls_tok
@register_notrace_function
def reshape_post_pool(
        x,
        num_heads: int,
        cls_tok: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[int]]:
    """Flatten a pooled 2D feature map back into per-head token form.

    Returns the reshaped tensor, with ``cls_tok`` concatenated in front along
    dim 2 when given, together with the pooled ``[H, W]`` size.
    """
    pooled_h = x.shape[2]
    pooled_w = x.shape[3]
    channels = x.shape[1]
    # (N, C, H, W) -> (-1, heads, C, H*W) -> (-1, heads, H*W, C)
    x = x.reshape(-1, num_heads, channels, pooled_h * pooled_w)
    x = x.transpose(2, 3)
    if cls_tok is not None:
        x = torch.cat((cls_tok, x), dim=2)
    return x, [pooled_h, pooled_w]
@register_notrace_function
def cal_rel_pos_type(
        attn: torch.Tensor,
        q: torch.Tensor,
        has_cls_token: bool,
        q_size: List[int],
        k_size: List[int],
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
):
    """
    Spatial Relative Positional Embeddings.

    Adds height and width relative-position terms to the spatial part of
    ``attn``. ``attn`` is modified in place and also returned; when
    ``has_cls_token`` is set, the first row/column of the last two dims is
    left untouched.
    """
    prefix = 1 if has_cls_token else 0
    q_h, q_w = q_size
    k_h, k_w = k_size
    dev = q.device

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    q_rows = torch.arange(q_h, device=dev, dtype=torch.long).unsqueeze(-1)
    k_rows = torch.arange(k_h, device=dev, dtype=torch.long).unsqueeze(0)
    dist_h = q_rows * q_h_ratio - k_rows * k_h_ratio
    # Shift so that the smallest offset maps to table index 0.
    dist_h += (k_h - 1) * k_h_ratio

    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    q_cols = torch.arange(q_w, device=dev, dtype=torch.long).unsqueeze(-1)
    k_cols = torch.arange(k_w, device=dev, dtype=torch.long).unsqueeze(0)
    dist_w = q_cols * q_w_ratio - k_cols * k_w_ratio
    dist_w += (k_w - 1) * k_w_ratio

    # Gather the embedding rows for every (query, key) offset along each axis.
    table_h = rel_pos_h[dist_h.long()]
    table_w = rel_pos_w[dist_w.long()]

    B, n_head, _, dim = q.shape
    r_q = q[:, :, prefix:].reshape(B, n_head, q_h, q_w, dim)
    bias_h = torch.einsum("byhwc,hkc->byhwk", r_q, table_h)
    bias_w = torch.einsum("byhwc,wkc->byhwk", r_q, table_w)

    spatial = attn[:, :, prefix:, prefix:].view(B, -1, q_h, q_w, k_h, k_w)
    spatial = spatial + bias_h.unsqueeze(-1) + bias_w.unsqueeze(-2)
    attn[:, :, prefix:, prefix:] = spatial.view(B, -1, q_h * q_w, k_h * k_w)

    return attn
class MultiScaleAttentionPoolFirst(nn.Module):
    """Multi-scale attention that pools the input before the q/k/v projections.

    The input tokens are folded into heads (or a single group in
    ``'conv_unshared'`` mode), optionally pooled and normalized per q/k/v
    branch, and only then passed through the separate ``q``, ``k`` and ``v``
    linear layers.

    Args:
        dim: Input channel count.
        dim_out: Output channel count (split across ``num_heads``).
        feat_size: Input feature map size; must be square when
            ``rel_pos_type == 'spatial'``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the q/k/v linear layers have a bias.
        mode: Pooling mode, one of ``'avg'``, ``'max'``, ``'conv'``,
            ``'conv_unshared'``.
        kernel_q, kernel_kv: Pooling kernel sizes for q and for k/v.
        stride_q, stride_kv: Pooling strides for q and for k/v.
        has_cls_token: Whether the token sequence starts with a class token.
        rel_pos_type: ``'spatial'`` enables relative position embeddings.
        residual_pooling: Add the (pooled) query to the attention output.
        norm_layer: Norm applied after conv pooling.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported values.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            feat_size: Tuple[int, int],
            num_heads: int = 8,
            qkv_bias: bool = True,
            mode: str = "conv",
            kernel_q: Tuple[int, int] = (1, 1),
            kernel_kv: Tuple[int, int] = (1, 1),
            stride_q: Tuple[int, int] = (1, 1),
            stride_kv: Tuple[int, int] = (1, 1),
            has_cls_token: bool = True,
            rel_pos_type: str = 'spatial',
            residual_pooling: bool = True,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        # Paddings are computed before the kernels may be cleared below.
        padding_q = tuple([int(q // 2) for q in kernel_q])
        padding_kv = tuple([int(kv // 2) for kv in kernel_kv])

        self.q = nn.Linear(dim, dim_out, bias=qkv_bias, **dd)
        self.k = nn.Linear(dim, dim_out, bias=qkv_bias, **dd)
        self.v = nn.Linear(dim, dim_out, bias=qkv_bias, **dd)
        self.proj = nn.Linear(dim_out, dim_out, **dd)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        if mode in ("avg", "max"):
            pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode == "conv" or mode == "conv_unshared":
            # Pooling runs on the un-projected input, hence channels come from `dim`.
            dim_conv = dim // num_heads if mode == "conv" else dim

            def _depthwise_conv(kernel, stride, padding):
                return nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel,
                    stride=stride,
                    padding=padding,
                    groups=dim_conv,
                    bias=False,
                    **dd,
                )

            if kernel_q:
                self.pool_q = _depthwise_conv(kernel_q, stride_q, padding_q)
                self.norm_q = norm_layer(dim_conv, **dd)
            if kernel_kv:
                self.pool_k = _depthwise_conv(kernel_kv, stride_kv, padding_kv)
                self.norm_k = norm_layer(dim_conv, **dd)
                self.pool_v = _depthwise_conv(kernel_kv, stride_kv, padding_kv)
                self.norm_v = norm_layer(dim_conv, **dd)
        else:
            raise NotImplementedError(f"Unsupported model {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            assert feat_size[0] == feat_size[1]
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            # One embedding row per possible relative offset along an axis.
            rel_sp_dim = 2 * max(q_size, kv_size) - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        """Apply pooled attention.

        Args:
            x: Tokens of shape ``(B, N, C)``.
            feat_size: ``[H, W]`` of the spatial tokens in ``x``.

        Returns:
            Tuple of the projected output tokens and the query feature size
            after pooling.
        """
        B, N, _ = x.shape

        fold_dim = 1 if self.unshared else self.num_heads
        x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
        q = k = v = x

        # NOTE: un-fold with `fold_dim` (the same factor used to fold above);
        # using num_heads here would scramble the batch dim in unshared mode.
        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, fold_dim, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, fold_dim, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            v, v_size = reshape_post_pool(v, fold_dim, v_tok)
        else:
            v_size = feat_size
        if self.norm_v is not None:
            v = self.norm_v(v)

        # Merge the folded groups back to (B, L, C), project, then split heads.
        q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
        q = q.transpose(1, 2).reshape(B, q_N, -1)
        q = self.q(q).reshape(B, q_N, self.num_heads, -1).transpose(1, 2)

        k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
        k = k.transpose(1, 2).reshape(B, k_N, -1)
        # Keys are laid out pre-transposed as (B, heads, head_dim, k_N) so the
        # logits are a plain `q @ k` below.
        k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 3, 1)

        v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
        v = v.transpose(1, 2).reshape(B, v_N, -1)
        v = self.v(v).reshape(B, v_N, self.num_heads, -1).transpose(1, 2)

        attn = (q * self.scale) @ k
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        return x, q_size
class MultiScaleAttention(nn.Module):
    """Multi-scale attention with a fused qkv projection followed by pooling.

    The input is projected to q, k and v by one linear layer; each branch is
    then optionally pooled (and normalized, for the conv modes) before the
    attention product is computed.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            feat_size: Tuple[int, int],
            num_heads: int = 8,
            qkv_bias: bool = True,
            mode: str = "conv",
            kernel_q: Tuple[int, int] = (1, 1),
            kernel_kv: Tuple[int, int] = (1, 1),
            stride_q: Tuple[int, int] = (1, 1),
            stride_kv: Tuple[int, int] = (1, 1),
            has_cls_token: bool = True,
            rel_pos_type: str = 'spatial',
            residual_pooling: bool = True,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        # "Same"-style padding derived from the kernels before they may be cleared.
        padding_q = tuple(int(size // 2) for size in kernel_q)
        padding_kv = tuple(int(size // 2) for size in kernel_kv)

        self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias, **dd)
        self.proj = nn.Linear(dim_out, dim_out, **dd)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        if mode in ("avg", "max"):
            if mode == "max":
                pool_op = nn.MaxPool2d
            else:
                pool_op = nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode in ("conv", "conv_unshared"):
            # Pooling runs after the projection, hence channels come from `dim_out`.
            dim_conv = dim_out // num_heads if mode == "conv" else dim_out

            def depthwise_conv(kernel, stride, padding):
                return nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel,
                    stride=stride,
                    padding=padding,
                    groups=dim_conv,
                    bias=False,
                    **dd,
                )

            if kernel_q:
                self.pool_q = depthwise_conv(kernel_q, stride_q, padding_q)
                self.norm_q = norm_layer(dim_conv, **dd)
            if kernel_kv:
                self.pool_k = depthwise_conv(kernel_kv, stride_kv, padding_kv)
                self.norm_k = norm_layer(dim_conv, **dd)
                self.pool_v = depthwise_conv(kernel_kv, stride_kv, padding_kv)
                self.norm_v = norm_layer(dim_conv, **dd)
        else:
            raise NotImplementedError(f"Unsupported model {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            assert feat_size[0] == feat_size[1]
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            # One embedding row per possible relative offset along an axis.
            rel_sp_dim = 2 * max(q_size, kv_size) - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim, **dd))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        """Return ``(output tokens, query feature size after pooling)``."""
        B, N, _ = x.shape

        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q_size = feat_size
        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
        if self.norm_q is not None:
            q = self.norm_q(q)

        k_size = feat_size
        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            v, _ = reshape_post_pool(v, self.num_heads, v_tok)
        if self.norm_v is not None:
            v = self.norm_v(v)

        attn = torch.matmul(q * self.scale, k.transpose(-2, -1))
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, v)

        if self.residual_pooling:
            out = out + q

        # (B, heads, L, head_dim) -> (B, L, dim_out)
        out = out.transpose(1, 2).reshape(B, -1, self.dim_out)
        out = self.proj(out)

        return out, q_size
class MultiScaleBlock(nn.Module):
    """One MViT block: pooled attention plus an MLP, each with a residual shortcut.

    When ``dim != dim_out`` the channel expansion is done on the attention
    shortcut if ``expand_attn`` is set, otherwise on the MLP shortcut. When
    ``stride_q`` down-samples, the attention shortcut is max-pooled to match.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            num_heads: int,
            feat_size: Tuple[int, int],
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            drop_path: float = 0.0,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            kernel_q: Tuple[int, int] = (1, 1),
            kernel_kv: Tuple[int, int] = (1, 1),
            stride_q: Tuple[int, int] = (1, 1),
            stride_kv: Tuple[int, int] = (1, 1),
            mode: str = "conv",
            has_cls_token: bool = True,
            expand_attn: bool = False,
            pool_first: bool = False,
            rel_pos_type: str = 'spatial',
            residual_pooling: bool = True,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        needs_proj = dim != dim_out
        self.dim = dim
        self.dim_out = dim_out
        self.has_cls_token = has_cls_token

        self.norm1 = norm_layer(dim, **dd)

        if needs_proj and expand_attn:
            self.shortcut_proj_attn = nn.Linear(dim, dim_out, **dd)
        else:
            self.shortcut_proj_attn = None

        if stride_q and prod(stride_q) > 1:
            # Kernel of stride + 1 on every down-sampled axis, with matching padding.
            skip_kernel = [s + 1 if s > 1 else s for s in stride_q]
            skip_padding = [int(k // 2) for k in skip_kernel]
            self.shortcut_pool_attn = nn.MaxPool2d(skip_kernel, stride_q, skip_padding)
        else:
            self.shortcut_pool_attn = None

        att_dim = dim_out if expand_attn else dim
        if pool_first:
            attn_layer = MultiScaleAttentionPoolFirst
        else:
            attn_layer = MultiScaleAttention
        self.attn = attn_layer(
            dim,
            att_dim,
            num_heads=num_heads,
            feat_size=feat_size,
            qkv_bias=qkv_bias,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            has_cls_token=has_cls_token,
            mode=mode,
            rel_pos_type=rel_pos_type,
            residual_pooling=residual_pooling,
            **dd,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(att_dim, **dd)
        if needs_proj and not expand_attn:
            self.shortcut_proj_mlp = nn.Linear(dim, dim_out, **dd)
        else:
            self.shortcut_proj_mlp = None
        self.mlp = Mlp(
            in_features=att_dim,
            hidden_features=int(att_dim * mlp_ratio),
            out_features=dim_out,
            **dd,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def _shortcut_pool(self, x, feat_size: List[int]):
        """Max-pool the spatial tokens of the shortcut; a class token is passed through."""
        if self.shortcut_pool_attn is None:
            return x
        if self.has_cls_token:
            cls_tok, x = x[:, :1, :], x[:, 1:, :]
        else:
            cls_tok = None
        B, _, C = x.shape
        # (B, L, C) -> (B, C, H, W)
        x = x.reshape(B, feat_size[0], feat_size[1], C)
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.shortcut_pool_attn(x)
        # (B, C, H', W') -> (B, L', C)
        x = x.reshape(B, C, -1).transpose(1, 2)
        if cls_tok is not None:
            x = torch.cat((cls_tok, x), dim=1)
        return x

    def forward(self, x, feat_size: List[int]):
        """Return ``(output tokens, feature size after the attention pooling)``."""
        normed = self.norm1(x)
        # NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj
        if self.shortcut_proj_attn is None:
            shortcut = x
        else:
            shortcut = self.shortcut_proj_attn(normed)
        shortcut = self._shortcut_pool(shortcut, feat_size)
        x, feat_size_new = self.attn(normed, feat_size)
        x = shortcut + self.drop_path1(x)

        normed = self.norm2(x)
        if self.shortcut_proj_mlp is None:
            shortcut = x
        else:
            shortcut = self.shortcut_proj_mlp(normed)
        x = shortcut + self.drop_path2(self.mlp(normed))
        return x, feat_size_new
class MultiScaleVitStage(nn.Module):
    """A stage of ``depth`` MultiScaleBlocks.

    Only the first block applies ``stride_q``; later blocks use a query stride
    of ``(1, 1)``. With ``expand_attn`` every block outputs ``dim_out``
    channels, otherwise only the last block does. ``self.feat_size`` holds the
    feature size after the first block's stride.
    """

    def __init__(
            self,
            dim: int,
            dim_out: int,
            depth: int,
            num_heads: int,
            feat_size: Tuple[int, int],
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            kernel_q: Tuple[int, int] = (1, 1),
            kernel_kv: Tuple[int, int] = (1, 1),
            stride_q: Tuple[int, int] = (1, 1),
            stride_kv: Tuple[int, int] = (1, 1),
            mode: str = "conv",
            has_cls_token: bool = True,
            expand_attn: bool = False,
            pool_first: bool = False,
            rel_pos_type: str = 'spatial',
            residual_pooling: bool = True,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            drop_path: Union[float, List[float]] = 0.0,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False

        self.blocks = nn.ModuleList()
        if expand_attn:
            out_dims = (dim_out,) * depth
        else:
            out_dims = (dim,) * (depth - 1) + (dim_out,)
        per_block_drop = isinstance(drop_path, (list, tuple))

        for idx in range(depth):
            first = idx == 0
            block_dim_out = out_dims[idx]
            block = MultiScaleBlock(
                dim=dim,
                dim_out=block_dim_out,
                num_heads=num_heads,
                feat_size=feat_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                kernel_q=kernel_q,
                kernel_kv=kernel_kv,
                stride_q=stride_q if first else (1, 1),
                stride_kv=stride_kv,
                mode=mode,
                has_cls_token=has_cls_token,
                pool_first=pool_first,
                rel_pos_type=rel_pos_type,
                residual_pooling=residual_pooling,
                expand_attn=expand_attn,
                norm_layer=norm_layer,
                drop_path=drop_path[idx] if per_block_drop else drop_path,
                **dd,
            )
            self.blocks.append(block)
            dim = block_dim_out
            if first:
                # Blocks after the first see the down-sampled feature map.
                feat_size = tuple([size // stride for size, stride in zip(feat_size, stride_q)])

        self.feat_size = feat_size

    def forward(self, x, feat_size: List[int]):
        """Run all blocks in order, threading the feature size through them."""
        for block in self.blocks:
            use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
            if use_checkpoint:
                x, feat_size = checkpoint(block, x, feat_size)
            else:
                x, feat_size = block(x, feat_size)
        return x, feat_size
class MultiScaleVit(nn.Module):
    """
    Improved Multiscale Vision Transformers for Classification and Detection
    Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2112.01526

    Multiscale Vision Transformers
    Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2104.11227
    """

    def __init__(
            self,
            cfg: MultiScaleVitCfg,
            img_size: Tuple[int, int] = (224, 224),
            in_chans: int = 3,
            global_pool: Optional[str] = None,
            num_classes: int = 1000,
            drop_path_rate: float = 0.,
            drop_rate: float = 0.,
            device=None,
            dtype=None,
    ):
        """Build the patch embedding, the stages described by ``cfg``, and the head.

        Args:
            cfg: Architecture configuration.
            img_size: Input image size, used to size position embeddings.
            in_chans: Number of input image channels.
            global_pool: Pooling for the head; defaults to ``'token'`` when the
                config uses a class token and ``'avg'`` otherwise.
            num_classes: Classifier outputs; ``0`` or less gives an identity head.
            drop_path_rate: Stochastic depth rate, distributed over the blocks.
            drop_rate: Dropout rate before the classifier.
            device, dtype: Passed to every created parameter / module.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        img_size = to_2tuple(img_size)
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        if global_pool is None:
            global_pool = 'token' if cfg.use_cls_token else 'avg'
        self.global_pool = global_pool
        # `depths` and `expand_attn` are read back by checkpoint remapping code.
        self.depths = tuple(cfg.depths)
        self.expand_attn = cfg.expand_attn

        embed_dim = cfg.embed_dim[0]
        self.patch_embed = PatchEmbed(
            dim_in=in_chans,
            dim_out=embed_dim,
            kernel=cfg.patch_kernel,
            stride=cfg.patch_stride,
            padding=cfg.patch_padding,
            **dd,
        )
        patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1])
        num_patches = prod(patch_dims)

        if cfg.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
            self.num_prefix_tokens = 1
            pos_embed_dim = num_patches + 1
        else:
            self.num_prefix_tokens = 0
            self.cls_token = None
            pos_embed_dim = num_patches

        if cfg.use_abs_pos:
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim, **dd))
        else:
            self.pos_embed = None

        num_stages = len(cfg.embed_dim)
        feat_size = patch_dims
        curr_stride = max(cfg.patch_stride)
        dpr = calculate_drop_path_rates(drop_path_rate, cfg.depths, stagewise=True)
        self.stages = nn.ModuleList()
        self.feature_info = []
        for i in range(num_stages):
            if cfg.expand_attn:
                dim_out = cfg.embed_dim[i]
            else:
                # Without attention expansion a stage already outputs the next stage's width.
                dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)]
            stage = MultiScaleVitStage(
                dim=embed_dim,
                dim_out=dim_out,
                depth=cfg.depths[i],
                num_heads=cfg.num_heads[i],
                feat_size=feat_size,
                mlp_ratio=cfg.mlp_ratio,
                qkv_bias=cfg.qkv_bias,
                mode=cfg.mode,
                pool_first=cfg.pool_first,
                expand_attn=cfg.expand_attn,
                kernel_q=cfg.kernel_qkv,
                kernel_kv=cfg.kernel_qkv,
                stride_q=cfg.stride_q[i],
                stride_kv=cfg.stride_kv[i],
                has_cls_token=cfg.use_cls_token,
                rel_pos_type=cfg.rel_pos_type,
                residual_pooling=cfg.residual_pooling,
                norm_layer=norm_layer,
                drop_path=dpr[i],
                **dd,
            )
            curr_stride *= max(cfg.stride_q[i])
            self.feature_info += [dict(module=f'block.{i}', num_chs=dim_out, reduction=curr_stride)]
            embed_dim = dim_out
            feat_size = stage.feat_size
            self.stages.append(stage)

        self.num_features = self.head_hidden_size = embed_dim
        self.norm = norm_layer(embed_dim, **dd)
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity())
        ]))

        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            trunc_normal_tf_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, zeros for their biases."""
        if isinstance(m, nn.Linear):
            trunc_normal_tf_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the names of position / class-token parameters."""
        return {k for k, _ in self.named_parameters()
                if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns grouping parameters into stem and per-stage blocks."""
        matcher = dict(
            stem=r'^patch_embed',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on every stage."""
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Return the final classifier module."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        """Replace the head with a fresh one for ``num_classes`` outputs.

        The new Linear layer reuses the device / dtype of the previous
        classifier weight when there is one. ``global_pool`` is only updated
        when it is not ``None``.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        device = self.head.fc.weight.device if hasattr(self.head.fc, 'weight') else None
        dtype = self.head.fc.weight.dtype if hasattr(self.head.fc, 'weight') else None
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity())
        ]))

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            The list of intermediates when ``intermediates_only`` is set, otherwise a
            tuple of the last computed features and the list of intermediates. The
            final norm is applied to the features only when the last stage was run.
        """
        assert output_fmt in ('NCHW', 'NLC'), 'Output shape must be NCHW or NLC.'
        reshape = output_fmt == 'NCHW'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x, feat_size = self.patch_embed(x)
        B = x.shape[0]
        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        if torch.jit.is_scripting() or not stop_early:  # can't slice stages in torchscript
            stages = self.stages
        else:
            # Skip the stages after the last requested intermediate.
            stages = self.stages[:max_index + 1]

        last_idx = len(self.stages) - 1
        for feat_idx, stage in enumerate(stages):
            x, feat_size = stage(x, feat_size)
            if feat_idx in take_indices:
                if norm and feat_idx == last_idx:
                    x_inter = self.norm(x)  # applying final norm last intermediate
                else:
                    x_inter = x
                if reshape:
                    if self.cls_token is not None:
                        # possible to allow return of class tokens, TBD
                        x_inter = x_inter[:, 1:]
                    x_inter = x_inter.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2)
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        if feat_idx == last_idx:
            x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.

        Stages themselves are not pruned yet; only the final norm and the head
        can be removed. Returns the resolved take indices.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        # FIXME add stage pruning
        # self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x):
        """Embed the image, run all stages and apply the final norm."""
        x, feat_size = self.patch_embed(x)
        B = x.shape[0]

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed

        for stage in self.stages:
            x, feat_size = stage(x, feat_size)

        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the tokens (if pooling is enabled) and apply the head unless ``pre_logits``."""
        if self.global_pool:
            if self.global_pool == 'avg':
                # Average over the non-prefix tokens only.
                x = x[:, self.num_prefix_tokens:].mean(1)
            else:
                x = x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def checkpoint_filter_fn(state_dict, model):
    """Adapt a checkpoint ``state_dict`` to the layout expected by ``model``.

    Native checkpoints (recognised by the ``stages.0.blocks.0.norm1.weight``
    key) are returned as the same dict object, with any ``rel_pos`` table whose
    first dim differs from the model's linearly interpolated to the model's
    size. Other checkpoints are unwrapped from ``'model_state'`` if present and
    have their keys remapped: flat ``blocks.N`` indices become
    ``stages.S.blocks.B`` using ``model.depths``, block-level ``proj`` becomes
    ``shortcut_proj_attn`` or ``shortcut_proj_mlp`` depending on
    ``model.expand_attn``, and ``head.projection`` becomes ``head.fc``.

    Args:
        state_dict: Checkpoint mapping of parameter names to values.
        model: Target model; needs ``state_dict()`` for native checkpoints and
            a ``depths`` attribute for remapped ones.

    Returns:
        The filtered / remapped state dict.
    """
    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        # native checkpoint, look for rel_pos interpolations
        # Build the model state dict once rather than once per rel_pos key.
        model_state = model.state_dict()
        for k in state_dict.keys():
            if 'rel_pos' in k:
                rel_pos = state_dict[k]
                dest_rel_pos_shape = model_state[k].shape
                if rel_pos.shape[0] != dest_rel_pos_shape[0]:
                    # (L, C) -> (1, C, L) so that interpolation runs along the offset axis.
                    rel_pos_resized = torch.nn.functional.interpolate(
                        rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                        size=dest_rel_pos_shape[0],
                        mode="linear",
                    )
                    state_dict[k] = rel_pos_resized.reshape(-1, dest_rel_pos_shape[0]).permute(1, 0)
        return state_dict

    import re
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']

    depths = getattr(model, 'depths', None)
    expand_attn = getattr(model, 'expand_attn', True)
    assert depths is not None, 'model requires depth attribute to remap checkpoints'

    # Map each flat block index to its (stage index, index within the stage).
    depth_map = {}
    block_idx = 0
    for stage_idx, d in enumerate(depths):
        depth_map.update({i: (stage_idx, i - block_idx) for i in range(block_idx, block_idx + d)})
        block_idx += d

    shortcut_name = 'shortcut_proj_attn' if expand_attn else 'shortcut_proj_mlp'
    out_dict = {}
    for k, v in state_dict.items():
        k = re.sub(
            r'blocks\.(\d+)',
            lambda x: f'stages.{depth_map[int(x.group(1))][0]}.blocks.{depth_map[int(x.group(1))][1]}',
            k)

        k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', rf'stages.\1.blocks.\2.{shortcut_name}', k)
        if 'head' in k:
            k = k.replace('head.projection', 'head.fc')
        out_dict[k] = v

    return out_dict
# Architecture configurations keyed by variant name. Only the fields that differ are
# spelled out; everything else is left to MultiScaleVitCfg's own defaults.
model_cfgs = {
    'mvitv2_tiny': MultiScaleVitCfg(depths=(1, 2, 5, 2)),
    'mvitv2_small': MultiScaleVitCfg(depths=(1, 2, 11, 2)),
    'mvitv2_base': MultiScaleVitCfg(depths=(2, 3, 16, 3)),
    'mvitv2_large': MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        expand_attn=False,
    ),

    # Variants that use a class token.
    'mvitv2_small_cls': MultiScaleVitCfg(depths=(1, 2, 11, 2), use_cls_token=True),
    'mvitv2_base_cls': MultiScaleVitCfg(depths=(2, 3, 16, 3), use_cls_token=True),
    'mvitv2_large_cls': MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        use_cls_token=True,
        expand_attn=True,
    ),
    'mvitv2_huge_cls': MultiScaleVitCfg(
        depths=(4, 8, 60, 8),
        embed_dim=192,
        num_heads=3,
        use_cls_token=True,
        expand_attn=True,
    ),
}
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Build a `MultiScaleVit` through `build_model_with_cfg`.

    Args:
        variant: Model variant name; also the `model_cfgs` key unless `cfg_variant` is given.
        cfg_variant: Optional `model_cfgs` key that takes precedence over `variant` when truthy.
        pretrained: Forwarded positionally to `build_model_with_cfg`.
        **kwargs: Forwarded to `build_model_with_cfg`, except `out_indices` (default 4),
            which is moved into the feature config.

    Returns:
        Whatever `build_model_with_cfg` returns.
    """
    feature_cfg = {
        'out_indices': kwargs.pop('out_indices', 4),
        'feature_cls': 'getter',
    }
    # A truthy cfg_variant overrides the variant name for the architecture lookup.
    arch_cfg = model_cfgs[cfg_variant or variant]
    return build_model_with_cfg(
        MultiScaleVit,
        variant,
        pretrained,
        model_cfg=arch_cfg,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
def _cfg(url='', **kwargs):
    """Return a pretrained-config dict with this file's defaults.

    Args:
        url: Value stored under the 'url' key.
        **kwargs: Extra entries; these override the defaults on key collision.

    Returns:
        A new dict of the defaults below updated with `kwargs`.
    """
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': None,
        'crop_pct': .9,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj',
        'classifier': 'head.fc',
        'fixed_input_size': True,
        'license': 'apache-2.0',
    }
    # Caller-supplied entries win over the defaults above.
    cfg.update(kwargs)
    return cfg
# Pretrained configs keyed by '<model>.<tag>'; each entry starts from the _cfg() defaults.
default_cfgs = generate_default_cfgs({
    'mvitv2_tiny.fb_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth',
        hf_hub_id='timm/',
    ),
    'mvitv2_small.fb_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth',
        hf_hub_id='timm/',
    ),
    'mvitv2_base.fb_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth',
        hf_hub_id='timm/',
    ),
    'mvitv2_large.fb_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth',
        hf_hub_id='timm/',
    ),

    # No weights URL is listed for this entry.
    'mvitv2_small_cls': _cfg(url=''),

    # The entries below override num_classes to 19168.
    'mvitv2_base_cls.fb_inw21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
        hf_hub_id='timm/',
        num_classes=19168,
    ),
    'mvitv2_large_cls.fb_inw21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
        hf_hub_id='timm/',
        num_classes=19168,
    ),
    'mvitv2_huge_cls.fb_inw21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
        hf_hub_id='timm/',
        num_classes=19168,
    ),
})
@register_model
def mvitv2_tiny(pretrained=False, **kwargs) -> MultiScaleVit:
    """Create the 'mvitv2_tiny' variant; all arguments are forwarded to `_create_mvitv2`."""
    model = _create_mvitv2('mvitv2_tiny', pretrained=pretrained, **kwargs)
    return model
@register_model
def mvitv2_small(pretrained=False, **kwargs) -> MultiScaleVit:
    """Create the 'mvitv2_small' variant; all arguments are forwarded to `_create_mvitv2`."""
    model = _create_mvitv2('mvitv2_small', pretrained=pretrained, **kwargs)
    return model
@register_model
def mvitv2_base(pretrained=False, **kwargs) -> MultiScaleVit:
    """Create the 'mvitv2_base' variant; all arguments are forwarded to `_create_mvitv2`."""
    model = _create_mvitv2('mvitv2_base', pretrained=pretrained, **kwargs)
    return model
@register_model
def mvitv2_large(pretrained=False, **kwargs) -> MultiScaleVit:
    """Create the 'mvitv2_large' variant; all arguments are forwarded to `_create_mvitv2`."""
    model = _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs)
    return model
@register_model
def mvitv2_small_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """Create the 'mvitv2_small_cls' variant; all arguments are forwarded to `_create_mvitv2`."""
    model = _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
    return model
@register_model
def mvitv2_base_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """Create the 'mvitv2_base_cls' variant; all arguments are forwarded to `_create_mvitv2`."""
    model = _create_mvitv2('mvitv2_base_cls', pretrained=pretrained, **kwargs)
    return model
@register_model
def mvitv2_large_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """Create the 'mvitv2_large_cls' variant; all arguments are forwarded to `_create_mvitv2`."""
    model = _create_mvitv2('mvitv2_large_cls', pretrained=pretrained, **kwargs)
    return model
@register_model
def mvitv2_huge_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """Create the 'mvitv2_huge_cls' variant; all arguments are forwarded to `_create_mvitv2`."""
    model = _create_mvitv2('mvitv2_huge_cls', pretrained=pretrained, **kwargs)
    return model
|