# backbone.py (5.6 KB)
  1. # The implementation is adopted from CLIP,
  2. # made publicly available under the MIT License at https://github.com/openai/CLIP
  3. import math
  4. import os
  5. from collections import OrderedDict
  6. from typing import Tuple, Union
  7. import numpy as np
  8. import torch
  9. import torch.nn.functional as F
  10. from torch import nn
  11. from .vim import ViM
  12. class LayerNorm(nn.LayerNorm):
  13. def forward(self, x: torch.Tensor):
  14. orig_type = x.dtype
  15. ret = super().forward(x.type(torch.float32))
  16. return ret.type(orig_type)
  17. class QuickGELU(nn.Module):
  18. def forward(self, x: torch.Tensor):
  19. return x * torch.sigmoid(1.702 * x)
  20. class ResidualAttentionBlock(nn.Module):
  21. def __init__(self,
  22. d_model: int,
  23. n_head: int,
  24. attn_mask: torch.Tensor = None):
  25. super().__init__()
  26. self.attn = nn.MultiheadAttention(d_model, n_head)
  27. self.ln_1 = LayerNorm(d_model)
  28. self.mlp = nn.Sequential(
  29. OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
  30. ('gelu', QuickGELU()),
  31. ('c_proj', nn.Linear(d_model * 4, d_model))]))
  32. self.ln_2 = LayerNorm(d_model)
  33. self.attn_mask = attn_mask
  34. self.vim_att = ViM()
  35. self.vim_mlp = ViM()
  36. def attention(self, x: torch.Tensor):
  37. self.attn_mask = self.attn_mask.to(
  38. dtype=x.dtype,
  39. device=x.device) if self.attn_mask is not None else None
  40. return self.attn(
  41. x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
  42. def forward(self, x: torch.Tensor, task_name: str):
  43. x_normed_1 = self.ln_1(x)
  44. x = x + self.attention(x_normed_1)
  45. x = x + self.vim_att(x_normed_1, task_name)
  46. x_normed_2 = self.ln_2(x)
  47. x = x + self.mlp(x_normed_2)
  48. x = x + self.vim_mlp(x_normed_2, task_name)
  49. return x
  50. class Transformer(nn.Module):
  51. def __init__(self,
  52. width: int,
  53. layers: int,
  54. heads: int,
  55. attn_mask: torch.Tensor = None):
  56. super().__init__()
  57. self.width = width
  58. self.layers = layers
  59. self.resblocks = nn.ModuleList([
  60. ResidualAttentionBlock(width, heads, attn_mask)
  61. for _ in range(layers)
  62. ])
  63. def forward(self, x: torch.Tensor, **kwargs):
  64. L, B, D = x.size()
  65. features = []
  66. for i, blk in enumerate(self.resblocks):
  67. x = blk(x, **kwargs)
  68. features.append(x)
  69. return features
  70. class VisionTransformer(nn.Module):
  71. """
  72. The Vision Transformer (ViT) model
  73. Args:
  74. - input_resolution (int): shape of input image
  75. - patch_width (int): size of patch tokens
  76. - width (int): feature channels
  77. - layers (int): number of transformer layers
  78. - heads (int): number of multi-head attention
  79. - output_dim (int): output feature channels
  80. """
  81. def __init__(self,
  82. input_resolution: int,
  83. patch_size: int,
  84. width: int,
  85. layers: int,
  86. heads: int,
  87. output_dim: int = 512):
  88. super().__init__()
  89. self.input_resolution = input_resolution
  90. self.conv1 = nn.Conv2d(
  91. in_channels=3,
  92. out_channels=width,
  93. kernel_size=patch_size,
  94. stride=patch_size,
  95. bias=False)
  96. scale = width**-0.5
  97. self.class_embedding = nn.Parameter(scale * torch.randn(width))
  98. self.positional_embedding = nn.Parameter(scale * torch.randn(
  99. (input_resolution // patch_size)**2 + 1, width))
  100. self.ln_pre = LayerNorm(width)
  101. self.patch_per_side = input_resolution // patch_size
  102. self.transformer = Transformer(width, layers, heads)
  103. self.ln_post = LayerNorm(width)
  104. self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
  105. self.output_dim = output_dim
  106. def forward(self, x: torch.Tensor, **kwargs):
  107. x = self.conv1(x) # shape = [*, width, grid, grid]
  108. B = x.size(0)
  109. P = x.size(2)
  110. x = x.reshape(x.shape[0], x.shape[1], -1) # [*, width, grid ** 2]
  111. x = x.permute(0, 2, 1) # [*, grid ** 2, width]
  112. cls_token = self.class_embedding.to(x.dtype).reshape(1, 1, -1).repeat(
  113. B, 1, 1)
  114. x = torch.cat([cls_token, x],
  115. dim=1) # shape = [*, grid ** 2 + 1, width]
  116. x = x + self.positional_embedding.to(x.dtype)
  117. x = self.ln_pre(x)
  118. x = x.permute(1, 0, 2) # NLD -> LND
  119. x_per_layer = self.transformer(x, **kwargs)
  120. x = x_per_layer[-1]
  121. x = x.permute(1, 0, 2) # LND -> NLD
  122. x = self.ln_post(x[:, 0, :])
  123. if self.proj is not None:
  124. x = x @ self.proj
  125. # outputs: [x_1, ..., x_N, last_cls_token], x_i in 2D
  126. outputs = []
  127. for output in x_per_layer:
  128. outputs.append(output[1:, :, :].permute(1, 2,
  129. 0).reshape(B, -1, P, P))
  130. outputs.append(x)
  131. return outputs
  132. model_dict = {
  133. 'vit_b16_224':
  134. dict(input_resolution=224, patch_size=16, width=768, layers=12, heads=12),
  135. 'vit_b32_224':
  136. dict(input_resolution=224, patch_size=32, width=768, layers=12, heads=12),
  137. }
  138. def build_backbone(arch='vit_b16_224', pretrained=None):
  139. """ build a ViT + ViM model
  140. Args:
  141. arch: name of backbone
  142. pretrained: weights of pretrained model
  143. """
  144. model_args = model_dict[arch]
  145. model = VisionTransformer(**model_args)
  146. model.load_state_dict(pretrained)
  147. return model
  148. if __name__ == '__main__':
  149. model = build_backbone()