| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- # Copyright (c) 2021 OpenAI
- #
- # This source code is licensed under the MIT license which can be found at
- # https://github.com/openai/CLIP/blob/main/LICENSE
- from collections import OrderedDict
- import torch
- import torch.nn.functional as F
- from fairseq.modules import LayerNorm
- from torch import nn
- from .utils.utils import DropPath
# Public API of this module: the factory functions that build the
# pre-configured `VisionTransformer` variants defined below.
__all__ = [
    'vit_base',
    'vit_large',
    'vit_large_336',
    'vit_huge',
]
class QuickGELU(nn.Module):
    r"""
    Sigmoid-gated activation, a cheap approximation of GELU.

    Computes ``x * sigmoid(1.702 * x)`` element-wise, so the output has the
    same shape and dtype as the input.
    """

    def forward(self, x: torch.Tensor):
        # The sigmoid acts as a soft gate on the input itself.
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    r"""
    Pre-norm transformer block with two residual branches.

    branch 1. Layer-normalize the input, run multi-head self attention on it,
              apply drop path and add the result back to the input.
    branch 2. Layer-normalize the output of branch 1, run it through a
              two-layer MLP (hidden size ``4 * d_model``, `QuickGELU`
              activation), apply drop path and add the result back.
    """

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path_rate=0.0):
        r"""
        Args:
            d_model (`int`): The embedding dimension of each token.
            n_head (`int`): The number of heads of the self attention.
            attn_mask (`Tensor`, **optional**, default to None):
                Mask forwarded to the self attention on every call.
            drop_path_rate (`float`, **optional**, default to 0.0):
                Drop path rate applied to both residual branches.
                See more details about drop path from
                https://arxiv.org/pdf/1605.07648v4.pdf.
        """
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        # The entry names below become state-dict keys; keep them stable.
        hidden_dim = d_model * 4
        mlp_layers = OrderedDict()
        mlp_layers['c_fc'] = nn.Linear(d_model, hidden_dim)
        mlp_layers['gelu'] = QuickGELU()
        mlp_layers['c_proj'] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(mlp_layers)
        self.ln_2 = LayerNorm(d_model)

        # Stored as a plain attribute (not a buffer); it is moved to the
        # right dtype/device lazily in `attention`.
        self.attn_mask = attn_mask
        self.drop_path = DropPath(drop_path_rate)

    def attention(self, x: torch.Tensor):
        r"""
        Run self attention with `x` as query, key and value.

        The stored mask, if any, is first cast to the dtype and device of
        `x`, and the converted mask is written back to `self.attn_mask`.

        Returns:
            The attention output only; attention weights are not computed.
        """
        mask = self.attn_mask
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)
        self.attn_mask = mask

        attn_out, _ = self.attn(
            x, x, x, need_weights=False, attn_mask=mask)
        return attn_out

    def forward(self, x: torch.Tensor):
        r"""
        Apply the attention branch and then the MLP branch to `x`.

        Returns:
            A tensor with the same shape as `x`.
        """
        attn_branch = self.attention(self.ln_1(x))
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.ln_2(x))
        x = x + self.drop_path(mlp_branch)
        return x
class Transformer(nn.Module):
    r"""
    The transformer encoder used by `VisionTransformer`.

    A plain stack of `ResidualAttentionBlock` modules applied one after
    another; every block shares the same configuration.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: torch.Tensor = None,
        drop_path_rate: float = 0.0,
    ):
        r"""
        Args:
            width (`int`): The embedding dimension of each token (passed to
                every block as `d_model`).
            layers (`int`): The number of `ResidualAttentionBlock` layers.
            heads (`int`): The number of self attention heads per block.
            attn_mask (`Tensor`, **optional**, default to None):
                Attention mask handed to every block.
            drop_path_rate (`float`, **optional**, default to 0.0):
                Drop path rate handed to every block. See more details
                about drop path from
                https://arxiv.org/pdf/1605.07648v4.pdf.
        """
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = []
        for _ in range(layers):
            blocks.append(
                ResidualAttentionBlock(
                    width, heads, attn_mask, drop_path_rate))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        r"""
        Pass `x` through all residual attention blocks in order.
        """
        return self.resblocks(x)
class VisionTransformer(nn.Module):
    r"""
    Vision transformer module that returns a spatial feature map.

    step 1. Use a strided conv2d to turn the image into patch embeddings.
    step 2. If the patch grid of the input differs from the grid the
            positional embedding was created for, `bilinear`-interpolate
            the positional embedding to the new grid.
    step 3. Add the positional embedding to the patch embeddings.
    step 4. Layer-normalize and run the `Transformer`.
    step 5. Reshape the token sequence back to a 2D feature map.

    No class token is prepended to the sequence. The positional embedding
    still holds one extra leading row, which `forward` skips (presumably
    kept so that checkpoints trained with a class token load unchanged;
    confirm against the checkpoint source).
    """

    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        drop_path_rate: float = 0.0,
    ):
        r"""
        Args:
            input_resolution (`int`): The side length of the (square) image
                the positional embedding is created for.
            patch_size (`int`): The side length of each image patch; used as
                both kernel size and stride of the patch convolution.
            width (`int`): The embedding dimension of each patch.
            layers (`int`): The number of `ResidualAttentionBlock` in `Transformer`.
            heads (`int`): The number of heads in self attention block.
            drop_path_rate (`float`, **optional**, default to 0.0):
                Drop path rate. See more details about drop path from
                https://arxiv.org/pdf/1605.07648v4.pdf.
        """
        super().__init__()
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        # Non-overlapping patch projection: kernel == stride == patch_size.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width**-0.5
        self.width = width
        # One row per patch of the default grid plus one leading row that
        # `forward` never adds to the tokens (it always slices `[1:]`).
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(
            width, layers, heads, drop_path_rate=drop_path_rate)

    def forward(self, x: torch.Tensor):
        r"""
        Encode a batch of images into a patch-level feature map.

        Args:
            x (`Tensor`): Images of shape `[batch, 3, H, W]`. `H` and `W`
                may differ from `input_resolution` and from each other.

        Returns:
            `Tensor` of shape
            `[batch, width, H // patch_size, W // patch_size]`.
        """
        # Size of the patch grid produced by the strided convolution.
        height = x.shape[-2] // self.patch_size
        width = x.shape[-1] // self.patch_size

        x = self.conv1(x)  # shape = [*, width, grid_h, grid_w]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid_h * grid_w]
        x = x.permute(0, 2, 1)  # shape = [*, grid_h * grid_w, width]

        patch_pe = self.positional_embedding[1:]
        patch_num = self.input_resolution // self.patch_size
        # Compare the full (height, width) grid rather than only the input
        # height: a non-square input whose height happens to match
        # `input_resolution` still needs a resized positional embedding.
        if (height, width) != (patch_num, patch_num):
            patch_pe = patch_pe.reshape(1, patch_num, patch_num,
                                        -1).permute(0, 3, 1, 2)
            # `align_corners=False` is the default for bilinear mode; it is
            # spelled out to make the resampling convention explicit.
            patch_pe = F.interpolate(
                patch_pe,
                size=(height, width),
                mode='bilinear',
                align_corners=False)
            patch_pe = patch_pe.permute(0, 2, 3, 1).reshape(
                height * width, -1)
        x = x + patch_pe.to(x.dtype)

        x = self.ln_pre(x)

        # nn.MultiheadAttention (batch_first=False) expects sequence-first.
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Fold the token sequence back into a channel-first 2D map.
        bz, _, hidden = x.shape
        x = x.transpose(1, 2).reshape(bz, hidden, height, width)
        return x
def vit_base(drop_path_rate: float = 0.0):
    r"""
    Build the base vision transformer.

    Configuration: 224 input resolution, 16 patch size, 768 embedding
    dimension, 9 layers and 12 attention heads.

    Args:
        drop_path_rate (`float`, **optional**, default to 0.0):
            Drop path rate. See more details about drop path from
            https://arxiv.org/pdf/1605.07648v4.pdf.
    """
    return VisionTransformer(
        input_resolution=224,
        patch_size=16,
        width=768,
        layers=9,
        heads=12,
        drop_path_rate=drop_path_rate,
    )
def vit_large(drop_path_rate: float = 0.0):
    r"""
    Build the large vision transformer.

    Configuration: 224 input resolution, 14 patch size, 1024 embedding
    dimension, 18 layers and 16 attention heads.

    Args:
        drop_path_rate (`float`, **optional**, default to 0.0):
            Drop path rate. See more details about drop path from
            https://arxiv.org/pdf/1605.07648v4.pdf.
    """
    return VisionTransformer(
        input_resolution=224,
        patch_size=14,
        width=1024,
        layers=18,
        heads=16,
        drop_path_rate=drop_path_rate,
    )
def vit_large_336(drop_path_rate: float = 0.0):
    r"""
    Build the large vision transformer for 336-pixel input images.

    Configuration: 336 input resolution, 14 patch size, 1024 embedding
    dimension, 18 layers and 16 attention heads.

    Args:
        drop_path_rate (`float`, **optional**, default to 0.0):
            Drop path rate. See more details about drop path from
            https://arxiv.org/pdf/1605.07648v4.pdf.
    """
    return VisionTransformer(
        input_resolution=336,
        patch_size=14,
        width=1024,
        layers=18,
        heads=16,
        drop_path_rate=drop_path_rate,
    )
def vit_huge(drop_path_rate: float = 0.0):
    r"""
    Build the huge vision transformer.

    Configuration: 224 input resolution, 14 patch size, 1280 embedding
    dimension, 24 layers and 16 attention heads.

    Args:
        drop_path_rate (`float`, **optional**, default to 0.0):
            Drop path rate. See more details about drop path from
            https://arxiv.org/pdf/1605.07648v4.pdf.
    """
    return VisionTransformer(
        input_resolution=224,
        patch_size=14,
        width=1280,
        layers=24,
        heads=16,
        drop_path_rate=drop_path_rate,
    )
|