| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007 |
- # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
- # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. All Rights Reserved.
- from collections import OrderedDict
- from functools import partial
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as checkpoint
- from timm.models.layers import trunc_normal_
- try:
- from fairscale.nn.checkpoint import checkpoint_wrapper
- except ImportError:
- checkpoint_wrapper = None
# Backbone hyper-parameters for the 24-layer MViTv2-Base variant.
# Each `dim_mul` / `head_mul` entry is [layer_index, multiplier]; each
# `pool_q_stride` entry is [layer_index, stride_t, stride_h, stride_w].
# Layers 2, 5 and 21 are where the width / head count double and where the
# query is pooled with a (1, 2, 2) stride; all other layers keep unit stride.
MViTv2_Base_config = {
    'depth': 24,
    'dim_mul': [[layer, 2.0] for layer in (2, 5, 21)],
    'head_mul': [[layer, 2.0] for layer in (2, 5, 21)],
    'pool_q_stride': [
        [layer, 1, 2, 2] if layer in (2, 5, 21) else [layer, 1, 1, 1]
        for layer in range(24)
    ],
    'pool_kvq_kernel': [3, 3, 3],
    'pool_kv_stride_adaptive': [1, 4, 4],
}
def interpolate_rel_pos_embed(state_dict_origin,
                              state_dict_model,
                              temporal=True,
                              verbose=False):
    """Resize relative position tables of a checkpoint to fit a model.

    Every entry of ``state_dict_origin`` whose key contains ``rel_pos_h`` or
    ``rel_pos_w`` (and ``rel_pos_t`` when ``temporal`` is True) is compared,
    along its first dimension, with the entry of the same key in
    ``state_dict_model``. When the lengths differ the checkpoint table is
    linearly interpolated to the model's length; otherwise it is copied.

    Args:
        state_dict_origin: checkpoint state dict (key -> tensor).
        state_dict_model: state dict of the model that will load the
            checkpoint; only used for the target table lengths.
        temporal: also handle ``rel_pos_t`` tables.
        verbose: print one line for every table that gets resized.

    Returns:
        A shallow copy of ``state_dict_origin`` whose relative position
        tables have been replaced by (cloned) tensors of the model's length.
        Relative position keys that do not exist in ``state_dict_model`` are
        passed through untouched instead of raising ``KeyError``.
    """
    rel_pos_embed_types = ['rel_pos_h', 'rel_pos_w']
    if temporal:
        rel_pos_embed_types.append('rel_pos_t')

    state_dict_inflated = state_dict_origin.copy()
    for k, v2d in state_dict_origin.items():
        if not any(x in k for x in rel_pos_embed_types):
            continue
        if k not in state_dict_model:
            # Nothing to size against: keep the checkpoint entry as it is and
            # let the subsequent load_state_dict() report the extra key.
            continue

        v3d = state_dict_model[k]
        if v2d.shape[0] != v3d.shape[0]:
            # F.interpolate works on (N, C, L): move the table length last.
            rel_pos_resized = F.interpolate(
                v2d.reshape(1, v2d.shape[0], -1).permute(0, 2, 1),
                size=v3d.shape[0],
                mode='linear',
            )
            v3d = rel_pos_resized.reshape(-1, v3d.shape[0]).permute(1, 0)
            if verbose:
                print('Inflate {}: {} -> {}: {}'.format(
                    k, v2d.shape, k, v3d.shape))
        else:
            v3d = v2d
        state_dict_inflated[k] = v3d.clone()
    return state_dict_inflated
- def _prepare_mvit_configs(cfg):
- depth = cfg['depth']
- dim_mul, head_mul = torch.ones(depth + 1), torch.ones(depth + 1)
- for i in range(len(cfg['dim_mul'])):
- dim_mul[cfg['dim_mul'][i][0]] = cfg['dim_mul'][i][1]
- for i in range(len(cfg['head_mul'])):
- head_mul[cfg['head_mul'][i][0]] = cfg['head_mul'][i][1]
- pool_q = [[] for i in range(depth)]
- pool_kv = [[] for i in range(depth)]
- stride_q = [[] for i in range(depth)]
- stride_kv = [[] for i in range(depth)]
- for i in range(len(cfg['pool_q_stride'])):
- stride_q[cfg['pool_q_stride'][i][0]] = cfg['pool_q_stride'][i][1:]
- pool_q[cfg['pool_q_stride'][i][0]] = cfg['pool_kvq_kernel']
- if cfg['pool_kv_stride_adaptive'] is not None:
- _stride_kv = cfg['pool_kv_stride_adaptive']
- cfg['pool_kv_stride'] = []
- for i in range(cfg['depth']):
- if len(stride_q[i]) > 0:
- _stride_kv = [
- max(_stride_kv[d] // stride_q[i][d], 1)
- for d in range(len(_stride_kv))
- ]
- cfg['pool_kv_stride'].append([i] + _stride_kv)
- for i in range(len(cfg['pool_kv_stride'])):
- stride_kv[cfg['pool_kv_stride'][i][0]] = cfg['pool_kv_stride'][i][1:]
- pool_kv[cfg['pool_kv_stride'][i][0]] = cfg['pool_kvq_kernel']
- return dim_mul, head_mul, pool_q, pool_kv, stride_q, stride_kv
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    Dropout is applied after the activation and after the second linear
    layer, but only when ``drop_rate`` is positive; in that case a single
    ``nn.Dropout`` module (``self.drop``) serves both positions.

    Args:
        in_features: size of the input features.
        hidden_features: size of the hidden layer (defaults to
            ``in_features``).
        out_features: size of the output (defaults to ``in_features``).
        act_layer: callable building the activation module.
        drop_rate: dropout probability.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop_rate=0.0,
    ):
        super().__init__()
        self.drop_rate = drop_rate
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        if self.drop_rate > 0.0:
            self.drop = nn.Dropout(drop_rate)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        use_dropout = self.drop_rate > 0.0
        x = self.act(self.fc1(x))
        if use_dropout:
            x = self.drop(x)
        x = self.fc2(x)
        if use_dropout:
            x = self.drop(x)
        return x
class Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` with a fixed dimension order."""

    def __init__(self, dims):
        super().__init__()
        # Stored as given; unpacked into Tensor.permute on every call.
        self.dims = dims

    def forward(self, x):
        """Return ``x`` with its dimensions reordered according to ``dims``."""
        order = self.dims
        return x.permute(*order)
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Stochastic Depth per sample.

    During training each sample (index along dim 0) is zeroed with
    probability ``drop_prob``; surviving samples are divided by
    ``1 - drop_prob``.

    Args:
        x: input tensor; dim 0 is treated as the sample dimension.
        drop_prob: probability of dropping a sample. ``None`` and ``0`` both
            disable dropping.
        training: dropping is only active when True.

    Returns:
        ``x`` itself when dropping is inactive, otherwise a new tensor of the
        same shape.
    """
    # `not drop_prob` also covers None, which is DropPath's default value and
    # would otherwise fail on `1 - None` as soon as the module is trained.
    if not training or not drop_prob:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Everything is dropped; dividing by keep_prob would give inf/NaN.
        return torch.zeros_like(x)
    shape = (x.shape[0], ) + (1, ) * (
        x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize
    return x.div(keep_prob) * mask
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability forwarded to drop_path() on every call.
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply drop_path using this module's probability and training flag."""
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
def round_width(width, multiplier, min_width=1, divisor=1, verbose=False):
    """Scale ``width`` by ``multiplier`` and round to a multiple of ``divisor``.

    Args:
        width: base width.
        multiplier: scale factor; a falsy value returns ``width`` unchanged.
        min_width: lower bound for the result (falls back to ``divisor``
            when falsy).
        divisor: the result is a multiple of this value unless ``min_width``
            wins.
        verbose: print the intermediate values.

    Returns:
        The rounded width as an ``int`` (or ``width`` itself when
        ``multiplier`` is falsy).
    """
    if not multiplier:
        return width

    width *= multiplier
    lower_bound = min_width or divisor
    if verbose:
        print(f'min width {lower_bound}')
        print(f'width {width} divisor {divisor}')
        print(f'other {int(width + divisor / 2) // divisor * divisor}')

    # Round to the nearest multiple of `divisor`.
    nearest = int(width + divisor / 2) // divisor * divisor
    width_out = max(lower_bound, nearest)
    # Never let rounding lose more than 10% of the scaled width.
    if width_out < 0.9 * width:
        width_out += divisor
    return int(width_out)
class PatchEmbed(nn.Module):
    """
    Convolutional patch embedding.

    A single strided convolution (``nn.Conv3d`` by default, ``nn.Conv2d``
    when ``conv2d`` is True) maps ``dim_in`` channels to ``dim_out``
    channels.
    """

    def __init__(
        self,
        dim_in=3,
        dim_out=768,
        kernel=(7, 7),
        stride=(4, 4),
        padding=(3, 3),
        conv2d=False,
    ):
        super().__init__()
        conv_cls = nn.Conv2d if conv2d else nn.Conv3d
        self.proj = conv_cls(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        """Return ``(tokens, feature_shape)``.

        ``tokens`` is the convolution output with every dimension after the
        channel dimension flattened and then swapped with the channels;
        ``feature_shape`` is the shape of the un-flattened convolution
        output.
        """
        feat = self.proj(x)
        # B C (spatial dims) -> B (flattened positions) C
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, feat.shape
def attention_pool(tensor, pool, thw_shape, has_cls_embed=True, norm=None):
    """Pool a token sequence over its (T, H, W) layout.

    Args:
        tensor: tokens shaped ``(B, N, L, C)`` or ``(B, L, C)``; a 3-D input
            is treated as having a single head.
        pool: module applied to the tokens rearranged as
            ``(B * N, C, T, H, W)``; ``None`` disables pooling.
        thw_shape: ``(T, H, W)`` layout of the non-class tokens.
        has_cls_embed: when True the first token is set aside before pooling
            and concatenated back in front afterwards.
        norm: optional module applied to the pooled tokens.

    Returns:
        ``(tensor, thw_shape)`` with the pooled tokens (same rank as the
        input) and the new ``[T, H, W]`` layout. When ``pool`` is None both
        inputs are returned unchanged.

    Raises:
        NotImplementedError: if ``tensor`` is neither 3-D nor 4-D.
    """
    if pool is None:
        return tensor, thw_shape

    tensor_dim = tensor.ndim
    if tensor_dim not in (3, 4):
        raise NotImplementedError(
            f'Unsupported input dimension {tensor.shape}')
    added_head_dim = tensor_dim == 3
    if added_head_dim:
        tensor = tensor.unsqueeze(1)

    if has_cls_embed:
        cls_tok = tensor[:, :, :1, :]
        tensor = tensor[:, :, 1:, :]

    B, N, L, C = tensor.shape
    T, H, W = thw_shape
    # (B, N, L, C) -> (B * N, C, T, H, W) so a 3-D pooling op can be applied.
    volume = tensor.reshape(B * N, T, H, W, C)
    volume = volume.permute(0, 4, 1, 2, 3).contiguous()
    volume = pool(volume)

    thw_shape = [volume.shape[2], volume.shape[3], volume.shape[4]]
    L_pooled = volume.shape[2] * volume.shape[3] * volume.shape[4]
    tensor = volume.reshape(B, N, C, L_pooled).transpose(2, 3)

    if has_cls_embed:
        tensor = torch.cat((cls_tok, tensor), dim=2)
    if norm is not None:
        tensor = norm(tensor)

    # Undo the head dimension that was added for 3-D inputs.
    if added_head_dim:
        tensor = tensor.squeeze(1)
    return tensor, thw_shape
def get_rel_pos(rel_pos, d):
    """Return ``rel_pos`` resized to ``d`` rows.

    Args:
        rel_pos: relative position table; dim 0 is the table length.
        d: requested table length.

    Returns:
        ``rel_pos`` itself when it already has ``d`` rows, otherwise a
        linearly interpolated table with ``d`` rows. ``None`` is returned
        when ``d`` is not an ``int``.
    """
    if not isinstance(d, int):
        return None

    ori_d = rel_pos.shape[0]
    if ori_d == d:
        return rel_pos

    # Interpolate rel pos: F.interpolate expects the resized axis last.
    channels_first = rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1)
    new_pos_embed = F.interpolate(channels_first, size=d, mode='linear')
    return new_pos_embed.reshape(-1, d).permute(1, 0)
def cal_rel_pos_spatial(attn, q, k, has_cls_embed, q_shape, k_shape, rel_pos_h,
                        rel_pos_w):
    """
    Decomposed Spatial Relative Positional Embeddings.

    Adds a height term and a width term, each computed from ``q`` and the
    matching relative position table, to the non-class part of ``attn``.
    ``attn`` is modified in place and also returned.

    Args:
        attn: attention logits, 4-D with the query tokens on dim 2 and the
            key tokens on dim 3.
        q: queries shaped ``(B, n_head, q_N, dim)``.
        k: keys (not read by this function).
        has_cls_embed: when True the first query / key token is skipped.
        q_shape: ``(q_t, q_h, q_w)`` layout of the query tokens.
        k_shape: ``(k_t, k_h, k_w)`` layout of the key tokens.
        rel_pos_h: relative position table for the height axis.
        rel_pos_w: relative position table for the width axis.
    """
    sp_idx = 1 if has_cls_embed else 0
    q_t, q_h, q_w = q_shape
    k_t, k_h, k_w = k_shape

    def _table_index(q_len, k_len):
        # Scale up rel pos if shapes for q and k are different, then shift
        # so that the smallest distance maps to row 0 of the table.
        q_ratio = max(k_len / q_len, 1.0)
        k_ratio = max(q_len / k_len, 1.0)
        dist = (
            torch.arange(q_len)[:, None] * q_ratio
            - torch.arange(k_len)[None, :] * k_ratio)
        dist += (k_len - 1) * k_ratio
        return dist.long()

    idx_h = _table_index(q_h, k_h)
    idx_w = _table_index(q_w, k_w)

    # Interpolate rel pos if needed.
    table_h = get_rel_pos(rel_pos_h, int(2 * max(q_h, k_h) - 1))
    table_w = get_rel_pos(rel_pos_w, int(2 * max(q_w, k_w) - 1))
    Rh = table_h[idx_h]
    Rw = table_w[idx_w]

    B, n_head, q_N, dim = q.shape
    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
    # [B, H, q_t, qh, qw, k_h]
    rel_h_q = torch.einsum('bythwc,hkc->bythwk', r_q, Rh)
    # [B, H, q_t, qh, qw, k_w]
    rel_w_q = torch.einsum('bythwc,wkc->bythwk', r_q, Rw)

    biased = attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t,
                                               k_h, k_w)
    biased = biased + rel_h_q[..., None, :, None]
    biased = biased + rel_w_q[..., None, None, :]
    attn[:, :, sp_idx:, sp_idx:] = biased.view(B, -1, q_t * q_h * q_w,
                                               k_t * k_h * k_w)
    return attn
def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t):
    """
    Temporal Relative Positional Embeddings.

    Adds a term computed from ``q`` and the temporal relative position table
    to the non-class part of ``attn``. ``attn`` is modified in place and
    also returned.

    Args:
        attn: attention logits, 4-D with the query tokens on dim 2 and the
            key tokens on dim 3.
        q: queries shaped ``(B, n_head, q_N, dim)``.
        has_cls_embed: when True the first query / key token is skipped.
        q_shape: ``(q_t, q_h, q_w)`` layout of the query tokens.
        k_shape: ``(k_t, k_h, k_w)`` layout of the key tokens.
        rel_pos_t: relative position table for the temporal axis.
    """
    sp_idx = 1 if has_cls_embed else 0
    q_t, q_h, q_w = q_shape
    k_t, k_h, k_w = k_shape

    # Interpolate rel pos if needed.
    table_t = get_rel_pos(rel_pos_t, int(2 * max(q_t, k_t) - 1))

    # Scale up rel pos if shapes for q and k are different.
    q_t_ratio = max(k_t / q_t, 1.0)
    k_t_ratio = max(q_t / k_t, 1.0)
    q_pos = torch.arange(q_t)[:, None] * q_t_ratio
    k_pos = torch.arange(k_t)[None, :] * k_t_ratio
    dist_t = q_pos - k_pos
    dist_t += (k_t - 1) * k_t_ratio
    Rt = table_t[dist_t.long()]

    B, n_head, q_N, dim = q.shape
    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_t, q_h, q_w, dim)
    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B, H, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    r_q = r_q.permute(2, 0, 1, 3, 4, 5)
    r_q = r_q.reshape(q_t, B * n_head * q_h * q_w, dim)

    # [q_t, B*H*q_h*q_w, dim] * [q_t, dim, k_t] = [q_t, B*H*q_h*q_w, k_t] -> [B*H*q_h*q_w, q_t, k_t]
    rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
    # [B*H*q_h*q_w, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    biased = attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_t, q_h, q_w, k_t,
                                               k_h, k_w)
    biased = biased + rel[..., None, None]
    attn[:, :, sp_idx:, sp_idx:] = biased.view(B, -1, q_t * q_h * q_w,
                                               k_t * k_h * k_w)
    return attn
class MultiScaleAttention(nn.Module):
    """Multi-head attention whose query, key and value can each be pooled.

    The tokens are projected to q / k / v (before or after pooling, see
    ``pool_first``), pooled over their (T, H, W) layout with
    ``attention_pool``, optionally biased with spatial / temporal relative
    position terms, and combined with scaled dot-product attention.

    Args:
        dim: input feature size.
        dim_out: output feature size of the attention.
        input_size: sequence whose element 0 sizes the temporal relative
            position table and whose elements 1 and 2 (required to be equal
            when ``rel_pos_spatial`` is set) size the spatial ones.
        num_heads: number of attention heads.
        qkv_bias: add a bias to the q / k / v projections.
        drop_rate: dropout probability applied after the output projection.
        kernel_q, kernel_kv: pooling kernels for the query and key/value.
        stride_q, stride_kv: pooling strides for the query and key/value.
        norm_layer: callable building the norm applied after conv pooling.
        has_cls_embed: whether the first token is a class token that is
            excluded from pooling.
        mode: pooling operator, one of ``conv``, ``conv_unshared``, ``avg``
            and ``max``.
        pool_first: if True, pool before the q / k / v projections.
        rel_pos_spatial: add decomposed spatial relative position terms.
        rel_pos_temporal: add temporal relative position terms.
        rel_pos_zero_init: leave the spatial tables at zero instead of
            drawing them from a truncated normal.
        residual_pooling: add the pooled query to the attention output.
        separate_qkv: use three separate linear layers instead of one fused
            ``qkv`` layer.

    Raises:
        NotImplementedError: for an unknown ``mode``.
    """

    def __init__(
        self,
        dim,
        dim_out,
        input_size,
        num_heads=8,
        qkv_bias=False,
        drop_rate=0.0,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        norm_layer=nn.LayerNorm,
        has_cls_embed=True,
        # Options include `conv`, `avg`, and `max`.
        mode='conv',
        # If True, perform pool before projection.
        pool_first=False,
        rel_pos_spatial=False,
        rel_pos_temporal=False,
        rel_pos_zero_init=False,
        residual_pooling=True,
        separate_qkv=False,
    ):
        super().__init__()
        self.pool_first = pool_first
        self.separate_qkv = separate_qkv
        self.drop_rate = drop_rate
        self.num_heads = num_heads
        self.dim_out = dim_out
        head_dim = dim_out // num_heads
        # Standard 1 / sqrt(head_dim) scaling of the attention logits.
        self.scale = head_dim**-0.5
        self.has_cls_embed = has_cls_embed
        padding_q = [int(q // 2) for q in kernel_q]
        padding_kv = [int(kv // 2) for kv in kernel_kv]

        if pool_first or separate_qkv:
            self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
            self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
            self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)

        self.proj = nn.Linear(dim_out, dim_out)
        # The dropout module only exists when it is actually needed.
        if drop_rate > 0.0:
            self.proj_drop = nn.Dropout(drop_rate)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if np.prod(kernel_q) == 1 and np.prod(stride_q) == 1:
            kernel_q = ()
        if np.prod(kernel_kv) == 1 and np.prod(stride_kv) == 1:
            kernel_kv = ()
        self.mode = mode

        if mode in ('avg', 'max'):
            # Parameter-free pooling; no norm_q / norm_k / norm_v attributes
            # are created in this branch.
            pool_op = nn.MaxPool3d if mode == 'max' else nn.AvgPool3d
            self.pool_q = (
                pool_op(kernel_q, stride_q, padding_q, ceil_mode=False)
                if len(kernel_q) > 0 else None)
            self.pool_k = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0 else None)
            self.pool_v = (
                pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False)
                if len(kernel_kv) > 0 else None)
        elif mode == 'conv' or mode == 'conv_unshared':
            # `conv` uses one per-head-sized depthwise kernel; `conv_unshared`
            # uses a kernel over the full channel dimension.
            if pool_first:
                dim_conv = dim // num_heads if mode == 'conv' else dim
            else:
                dim_conv = dim_out // num_heads if mode == 'conv' else dim_out
            self.pool_q = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                ) if len(kernel_q) > 0 else None)
            self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None
            self.pool_k = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                ) if len(kernel_kv) > 0 else None)
            self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
            self.pool_v = (
                nn.Conv3d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                ) if len(kernel_kv) > 0 else None)
            self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None
        else:
            raise NotImplementedError(f'Unsupported model {mode}')

        self.rel_pos_spatial = rel_pos_spatial
        self.rel_pos_temporal = rel_pos_temporal
        if self.rel_pos_spatial:
            # A single table length is derived for both height and width.
            assert input_size[1] == input_size[2]
            size = input_size[1]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim))
            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_h, std=0.02)
                trunc_normal_(self.rel_pos_w, std=0.02)
        if self.rel_pos_temporal:
            # NOTE: the temporal table always starts at zero; the truncated
            # normal init below is disabled in this file.
            self.rel_pos_t = nn.Parameter(
                torch.zeros(2 * input_size[0] - 1, head_dim))
            # if not rel_pos_zero_init:
            #     trunc_normal_(self.rel_pos_t, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, thw_shape):
        """Attend over the tokens ``x``.

        Args:
            x: 3-D token tensor ``(B, N, C)``.
            thw_shape: ``(T, H, W)`` layout of the tokens, forwarded to
                ``attention_pool``.

        Returns:
            ``(x, q_shape)``: the attended tokens reshaped to
            ``(B, -1, dim_out)`` and the layout of the pooled query.
        """
        B, N, _ = x.shape

        if self.pool_first:
            # Pool the raw tokens first; the q / k / v projections are
            # applied after pooling further below.
            if self.mode == 'conv_unshared':
                fold_dim = 1
            else:
                fold_dim = self.num_heads
            x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
            q = k = v = x
        else:
            assert self.mode != 'conv_unshared'
            if not self.separate_qkv:
                # Fused projection: (B, N, 3 * dim_out) ->
                # (3, B, num_heads, N, head_dim).
                qkv = (
                    self.qkv(x).reshape(B, N, 3, self.num_heads,
                                        -1).permute(2, 0, 3, 1, 4))
                q, k, v = qkv[0], qkv[1], qkv[2]
            else:
                q = k = v = x
                q = (
                    self.q(q).reshape(B, N, self.num_heads,
                                      -1).permute(0, 2, 1, 3))
                k = (
                    self.k(k).reshape(B, N, self.num_heads,
                                      -1).permute(0, 2, 1, 3))
                v = (
                    self.v(v).reshape(B, N, self.num_heads,
                                      -1).permute(0, 2, 1, 3))

        # The norm_* attributes only exist for the conv pooling modes.
        q, q_shape = attention_pool(
            q,
            self.pool_q,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_q if hasattr(self, 'norm_q') else None,
        )
        k, k_shape = attention_pool(
            k,
            self.pool_k,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_k if hasattr(self, 'norm_k') else None,
        )
        v, v_shape = attention_pool(
            v,
            self.pool_v,
            thw_shape,
            has_cls_embed=self.has_cls_embed,
            norm=self.norm_v if hasattr(self, 'norm_v') else None,
        )

        if self.pool_first:
            # Token counts after pooling (plus one for the class token).
            q_N = (
                np.prod(q_shape)
                + 1 if self.has_cls_embed else np.prod(q_shape))
            k_N = (
                np.prod(k_shape)
                + 1 if self.has_cls_embed else np.prod(k_shape))
            v_N = (
                np.prod(v_shape)
                + 1 if self.has_cls_embed else np.prod(v_shape))

            q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
            q = (
                self.q(q).reshape(B, q_N, self.num_heads,
                                  -1).permute(0, 2, 1, 3))

            v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
            v = (
                self.v(v).reshape(B, v_N, self.num_heads,
                                  -1).permute(0, 2, 1, 3))

            k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
            k = (
                self.k(k).reshape(B, k_N, self.num_heads,
                                  -1).permute(0, 2, 1, 3))

        # NOTE(review): N is reassigned here but not read again below.
        N = q.shape[2]
        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_spatial:
            attn = cal_rel_pos_spatial(
                attn,
                q,
                k,
                self.has_cls_embed,
                q_shape,
                k_shape,
                self.rel_pos_h,
                self.rel_pos_w,
            )

        if self.rel_pos_temporal:
            attn = cal_rel_pos_temporal(
                attn,
                q,
                self.has_cls_embed,
                q_shape,
                k_shape,
                self.rel_pos_t,
            )
        attn = attn.softmax(dim=-1)

        x = attn @ v

        if self.residual_pooling:
            # Minor Difference
            # With a class token the residual is added in place and skips
            # the class token itself.
            if self.has_cls_embed:
                x[:, :, 1:, :] += q[:, :, 1:, :]
            else:
                x = x + q

        # (B, num_heads, N, head_dim) -> (B, N, dim_out)
        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        if self.drop_rate > 0.0:
            x = self.proj_drop(x)
        return x, q_shape
class MultiScaleBlock(nn.Module):
    """Transformer block built from MultiScaleAttention and an Mlp.

    The block normalizes its input, runs pooling attention, adds the result
    (through drop-path) to a max-pooled skip connection, and then applies a
    second norm followed by the MLP with another residual connection. When
    ``dim != dim_out`` a linear projection changes the channel size of the
    residual path, either around the attention (``dim_mul_in_att=True``) or
    around the MLP (``dim_mul_in_att=False``).

    Args:
        dim: input feature size.
        dim_out: output feature size.
        num_heads: number of attention heads.
        input_size: forwarded to MultiScaleAttention.
        mlp_ratio: hidden size of the MLP relative to the attention output.
        qkv_bias: forwarded to MultiScaleAttention.
        qk_scale: accepted but not used by this block.
        drop_rate: dropout probability for the attention and the MLP.
        drop_path: stochastic depth probability (0 disables it).
        act_layer: activation constructor for the MLP.
        norm_layer: norm constructor.
        up_rate: when greater than 1 the MLP outputs ``dim * up_rate``
            features instead of ``dim_out``.
        kernel_q, kernel_kv, stride_q, stride_kv, mode, has_cls_embed,
        pool_first, rel_pos_spatial, rel_pos_temporal, rel_pos_zero_init,
        residual_pooling, separate_qkv: forwarded to MultiScaleAttention.
        dim_mul_in_att: change the channel size inside the attention rather
            than in the MLP.
        use_grad_checkpoint: run the attention and the MLP through
            ``torch.utils.checkpoint``.
    """

    def __init__(
        self,
        dim,
        dim_out,
        num_heads,
        input_size,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        up_rate=None,
        kernel_q=(1, 1, 1),
        kernel_kv=(1, 1, 1),
        stride_q=(1, 1, 1),
        stride_kv=(1, 1, 1),
        mode='conv',
        has_cls_embed=True,
        pool_first=False,
        rel_pos_spatial=False,
        rel_pos_temporal=False,
        rel_pos_zero_init=False,
        residual_pooling=True,
        dim_mul_in_att=False,
        separate_qkv=False,
        use_grad_checkpoint=False,
    ):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)
        self.dim_mul_in_att = dim_mul_in_att
        # The skip connection is pooled with the query stride; strided axes
        # get a kernel one larger than the stride.
        kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
        stride_skip = stride_q
        padding_skip = [int(skip // 2) for skip in kernel_skip]
        att_dim = dim_out if dim_mul_in_att else dim
        self.use_grad_checkpoint = use_grad_checkpoint
        self.attn = MultiScaleAttention(
            dim,
            att_dim,
            num_heads=num_heads,
            input_size=input_size,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            has_cls_embed=has_cls_embed,
            mode=mode,
            pool_first=pool_first,
            rel_pos_spatial=rel_pos_spatial,
            rel_pos_temporal=rel_pos_temporal,
            rel_pos_zero_init=rel_pos_zero_init,
            residual_pooling=residual_pooling,
            separate_qkv=separate_qkv,
        )
        self.drop_path = (
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity())
        self.norm2 = norm_layer(att_dim)
        mlp_hidden_dim = int(att_dim * mlp_ratio)
        self.has_cls_embed = has_cls_embed
        # TODO: check the use case for up_rate, and merge the following lines
        if up_rate is not None and up_rate > 1:
            mlp_dim_out = dim * up_rate
        else:
            mlp_dim_out = dim_out
        self.mlp = Mlp(
            in_features=att_dim,
            hidden_features=mlp_hidden_dim,
            out_features=mlp_dim_out,
            act_layer=act_layer,
            drop_rate=drop_rate,
        )
        # Residual-path projection, only needed when the width changes.
        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

        self.pool_skip = (
            nn.MaxPool3d(
                kernel_skip, stride_skip, padding_skip, ceil_mode=False)
            if len(kernel_skip) > 0 else None)

    def forward(self, x, thw_shape):
        """Run the block.

        Args:
            x: input tokens.
            thw_shape: ``(T, H, W)`` layout of the tokens.

        Returns:
            ``(x, thw_shape_new)``: the output tokens and the token layout
            reported by the attention module.
        """
        x_norm = self.norm1(x)
        if self.use_grad_checkpoint:
            x_block, thw_shape_new = checkpoint.checkpoint(
                self.attn, x_norm, thw_shape)
        else:
            x_block, thw_shape_new = self.attn(x_norm, thw_shape)

        # Width change inside the attention: the skip path is taken from the
        # projected, normalized input.
        if self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
        # Pool the skip connection so it lines up with the pooled query.
        x_res, _ = attention_pool(
            x, self.pool_skip, thw_shape, has_cls_embed=self.has_cls_embed)
        x = x_res + self.drop_path(x_block)
        x_norm = self.norm2(x)
        if self.use_grad_checkpoint:
            x_mlp = checkpoint.checkpoint(self.mlp, x_norm)
        else:
            x_mlp = self.mlp(x_norm)

        # Width change in the MLP: project the residual to the new width.
        if not self.dim_mul_in_att and self.dim != self.dim_out:
            x = self.proj(x_norm)
        x = x + self.drop_path(x_mlp)
        return x, thw_shape_new
class MViTv2(nn.Module):
    """
    Improved Multiscale Vision Transformers for Classification and Detection
    Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2112.01526

    Multiscale Vision Transformers
    Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2104.11227

    Args:
        img_size: spatial size of the (square) input frames.
        embed_dim: channel size after the patch embedding.
        num_classes: stored on the module; the head itself is an identity.
        num_frames: number of input frames.
        num_heads: number of attention heads of the first block.
        depth: number of blocks. It should agree with ``config['depth']``,
            which sizes the per-layer tables.
        patch_kernel, patch_stride, patch_padding: parameters of the 3-D
            patch embedding convolution.
        config: backbone config dict (see ``MViTv2_Base_config``); required.
        dropout_rate: dropout probability.
        drop_path_rate: maximum stochastic depth rate; block ``i`` uses the
            ``i``-th value of a linear ramp from 0 to this rate.
        mlp_ratio, qkv_bias, mode, rel_pos_spatial, rel_pos_temporal,
        rel_pos_zero_init, residual_pooling, dim_mul_in_att, pool_first,
        separate_qkv: forwarded to every MultiScaleBlock.
        cls_embed_on: prepend a learnable class token.
        use_abs_pos: add absolute position embeddings.
        zero_decay_pos_cls: report position / class parameters from
            ``no_weight_decay``.
        norm_stem: apply a LayerNorm right after the patch embedding.
        sep_pos_embed: factorize the absolute position embedding into a
            spatial, a temporal and a class part.
        use_grad_checkpoint: wrap the patch embedding and every block with
            fairscale's ``checkpoint_wrapper``.

    Raises:
        ImportError: if ``use_grad_checkpoint`` is True but fairscale is not
            installed.
    """

    def __init__(
        self,
        img_size=224,
        embed_dim=96,
        num_classes=1000,
        num_frames=4,
        num_heads=1,
        depth=24,
        patch_kernel=(3, 7, 7),
        patch_stride=(2, 4, 4),
        patch_padding=(1, 3, 3),
        config=None,
        dropout_rate=0.,
        drop_path_rate=0.,
        mlp_ratio=4.,
        qkv_bias=True,
        mode='conv',
        cls_embed_on=True,
        use_abs_pos=False,
        rel_pos_spatial=True,
        rel_pos_temporal=True,
        rel_pos_zero_init=False,
        residual_pooling=True,
        dim_mul_in_att=True,
        pool_first=False,
        zero_decay_pos_cls=False,
        separate_qkv=False,
        norm_stem=False,
        sep_pos_embed=False,
        use_grad_checkpoint=True,
    ):
        super().__init__()
        if use_grad_checkpoint and checkpoint_wrapper is None:
            # The optional import at the top of the file failed; fail with a
            # clear message instead of "'NoneType' object is not callable".
            raise ImportError(
                'use_grad_checkpoint=True requires fairscale '
                '(fairscale.nn.checkpoint.checkpoint_wrapper); install '
                'fairscale or pass use_grad_checkpoint=False.')

        # Prepare input.
        in_chans = 3
        self.img_size = img_size
        # Prepare output.
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        # MViT params.
        self.num_heads = num_heads
        self.depth = depth
        self.cls_embed_on = cls_embed_on
        self.use_abs_pos = use_abs_pos
        self.zero_decay_pos_cls = zero_decay_pos_cls
        self.use_grad_checkpoint = use_grad_checkpoint
        self.sep_pos_embed = sep_pos_embed
        self.drop_rate = dropout_rate
        # Read back by no_weight_decay().
        self.rel_pos_spatial = rel_pos_spatial
        self.rel_pos_temporal = rel_pos_temporal

        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        patch_embed = PatchEmbed(
            dim_in=in_chans,
            dim_out=embed_dim,
            kernel=patch_kernel,
            stride=patch_stride,
            padding=patch_padding,
        )
        if use_grad_checkpoint:
            patch_embed = checkpoint_wrapper(patch_embed)
        self.patch_embed = patch_embed

        patch_dims = [
            num_frames // patch_stride[0],
            img_size // patch_stride[1],
            img_size // patch_stride[2],
        ]
        # Needed by _get_pos_embed() and the separable position embedding.
        self.patch_dims = patch_dims
        num_patches = int(np.prod(patch_dims))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
               ]  # stochastic depth decay rule

        if self.cls_embed_on:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            pos_embed_dim = num_patches + 1
        else:
            pos_embed_dim = num_patches

        if self.use_abs_pos:
            if self.sep_pos_embed:
                self.pos_embed_spatial = nn.Parameter(
                    torch.zeros(1, patch_dims[1] * patch_dims[2], embed_dim))
                self.pos_embed_temporal = nn.Parameter(
                    torch.zeros(1, patch_dims[0], embed_dim))
                if self.cls_embed_on:
                    self.pos_embed_class = nn.Parameter(
                        torch.zeros(1, 1, embed_dim))
            else:
                self.pos_embed = nn.Parameter(
                    torch.zeros(1, pos_embed_dim, embed_dim))

        if self.drop_rate > 0.0:
            self.pos_drop = nn.Dropout(p=self.drop_rate)

        assert config is not None
        # MViT backbone configs
        dim_mul, head_mul, pool_q, pool_kv, stride_q, stride_kv = _prepare_mvit_configs(
            config)
        input_size = patch_dims

        self.norm_stem = norm_layer(embed_dim) if norm_stem else None

        self.blocks = nn.ModuleList()
        for i in range(depth):
            num_heads = round_width(num_heads, head_mul[i])
            if dim_mul_in_att:
                dim_out = round_width(
                    embed_dim,
                    dim_mul[i],
                    divisor=round_width(num_heads, head_mul[i]),
                )
            else:
                dim_out = round_width(
                    embed_dim,
                    dim_mul[i + 1],
                    divisor=round_width(num_heads, head_mul[i + 1]),
                )
            attention_block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                input_size=input_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_rate=self.drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                kernel_q=pool_q[i] if len(pool_q) > i else [],
                kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                stride_q=stride_q[i] if len(stride_q) > i else [],
                stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                mode=mode,
                has_cls_embed=self.cls_embed_on,
                pool_first=pool_first,
                rel_pos_spatial=rel_pos_spatial,
                rel_pos_temporal=rel_pos_temporal,
                rel_pos_zero_init=rel_pos_zero_init,
                residual_pooling=residual_pooling,
                dim_mul_in_att=dim_mul_in_att,
                separate_qkv=separate_qkv,
                # Checkpointing is done by wrapping the whole block below.
                use_grad_checkpoint=False)
            if use_grad_checkpoint:
                attention_block = checkpoint_wrapper(
                    attention_block, offload_to_cpu=False)
            self.blocks.append(attention_block)

            # Query pooling shrinks the token grid seen by the next block.
            if len(stride_q[i]) > 0:
                input_size = [
                    size // stride
                    for size, stride in zip(input_size, stride_q[i])
                ]
            embed_dim = dim_out

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        if self.use_abs_pos:
            if self.sep_pos_embed:
                trunc_normal_(self.pos_embed_spatial, std=0.02)
                trunc_normal_(self.pos_embed_temporal, std=0.02)
                if self.cls_embed_on:
                    trunc_normal_(self.pos_embed_class, std=0.02)
            else:
                trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_embed_on:
            trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear and LayerNorm modules (used with ``apply``)."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter-name fragments to exclude from weight decay.

        The list is empty unless ``zero_decay_pos_cls`` is set.
        """
        names = []
        if self.zero_decay_pos_cls:
            if self.use_abs_pos:
                if self.sep_pos_embed:
                    names.extend([
                        'pos_embed_spatial',
                        'pos_embed_temporal',
                        'pos_embed_class',
                    ])
                else:
                    names.append('pos_embed')
            if self.rel_pos_spatial:
                names.extend(['rel_pos_h', 'rel_pos_w', 'rel_pos_hw'])
            if self.rel_pos_temporal:
                names.extend(['rel_pos_t'])
            if self.cls_embed_on:
                names.append('cls_token')

        return names

    def _get_pos_embed(self, pos_embed, bcthw):
        """Return ``pos_embed`` resized to the token grid given by ``bcthw``.

        Args:
            pos_embed: ``(1, L, C)`` position embedding laid out for
                ``self.patch_dims`` (with one leading class entry when the
                class token is enabled).
            bcthw: shape whose last three entries are the current
                ``(t, h, w)`` token grid.

        Returns:
            The embedding, trilinearly interpolated when the current grid
            differs from ``self.patch_dims``.
        """
        t, h, w = bcthw[-3], bcthw[-2], bcthw[-1]
        if self.cls_embed_on:
            cls_pos_embed = pos_embed[:, 0:1, :]
            pos_embed = pos_embed[:, 1:]
        txy_num = pos_embed.shape[1]
        p_t, p_h, p_w = self.patch_dims
        assert p_t * p_h * p_w == txy_num

        if (p_t, p_h, p_w) != (t, h, w):
            new_pos_embed = F.interpolate(
                pos_embed[:, :, :].reshape(1, p_t, p_h, p_w,
                                           -1).permute(0, 4, 1, 2, 3),
                size=(t, h, w),
                mode='trilinear',
            )
            pos_embed = new_pos_embed.reshape(1, -1,
                                              t * h * w).permute(0, 2, 1)

        if self.cls_embed_on:
            pos_embed = torch.cat((cls_pos_embed, pos_embed), dim=1)

        return pos_embed

    def forward_features(self, x):
        """Embed ``x`` and run it through all blocks and the final norm.

        Args:
            x: 5-D input whose dim 2 holds the 3 input channels (dims 1 and 2
                are swapped before the patch embedding).

        Returns:
            The normalized token sequence produced by the last block.
        """
        x = x.permute(0, 2, 1, 3, 4)
        x, bcthw = self.patch_embed(x)

        T, H, W = bcthw[-3], bcthw[-2], bcthw[-1]
        B, N, C = x.shape

        if self.cls_embed_on:
            cls_tokens = self.cls_token.expand(
                B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_tokens, x), dim=1)

        if self.use_abs_pos:
            if self.sep_pos_embed:
                # Sum of a per-frame spatial grid and a per-position
                # temporal code, laid out for self.patch_dims.
                pos_embed = self.pos_embed_spatial.repeat(
                    1, self.patch_dims[0], 1) + torch.repeat_interleave(
                        self.pos_embed_temporal,
                        self.patch_dims[1] * self.patch_dims[2],
                        dim=1)
                if self.cls_embed_on:
                    pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1)
                pos_embed = self._get_pos_embed(pos_embed, bcthw)
                x = x + pos_embed
            else:
                pos_embed = self._get_pos_embed(self.pos_embed, bcthw)
                x = x + pos_embed

        if self.drop_rate > 0.0:
            x = self.pos_drop(x)

        if self.norm_stem is not None:
            x = self.norm_stem(x)

        thw = [T, H, W]

        for blk in self.blocks:
            x, thw = blk(x, thw)

        x = self.norm(x)
        return x

    def forward(self, x):
        """Run ``forward_features`` followed by the (identity) head."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
|