vit.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. # Copyright (c) 2021 OpenAI
  2. #
  3. # This source code is licensed under the MIT license which can be found at
  4. # https://github.com/openai/CLIP/blob/main/LICENSE
  5. from collections import OrderedDict
  6. import torch
  7. import torch.nn.functional as F
  8. from fairseq.modules import LayerNorm
  9. from torch import nn
  10. from .utils.utils import DropPath
  11. __all__ = [
  12. 'vit_base',
  13. 'vit_large',
  14. 'vit_large_336',
  15. 'vit_huge',
  16. ]
  17. class QuickGELU(nn.Module):
  18. r"""
  19. An activation function module.
  20. """
  21. def forward(self, x: torch.Tensor):
  22. return x * torch.sigmoid(1.702 * x)
  23. class ResidualAttentionBlock(nn.Module):
  24. r"""
  25. A residual attention block module.
  26. step 1. Calculate the self attention in input with layer normalization.
  27. step 2. Add input to the result of self attention's result as I.
  28. step 3. Calculate the mlp of input I with layer normalization.
  29. step 4. Add I to the result of mlp.
  30. """
  31. def __init__(self,
  32. d_model: int,
  33. n_head: int,
  34. attn_mask: torch.Tensor = None,
  35. drop_path_rate=0.0):
  36. r"""
  37. Args:
  38. d_model (`int`): The embedding dimensions.
  39. n_head (`int`): The number of heads in self attention block.
  40. attn_mask (`Tensor`, **optional**, default to None):
  41. Attention mask using in self attention.
  42. drop_path_rate (`float`, **optional**, default to 0.0):
  43. Drop path rate. See more details about drop path from
  44. https://arxiv.org/pdf/1605.07648v4.pdf.
  45. """
  46. super().__init__()
  47. self.attn = nn.MultiheadAttention(d_model, n_head)
  48. self.ln_1 = LayerNorm(d_model)
  49. self.mlp = nn.Sequential(
  50. OrderedDict([
  51. ('c_fc', nn.Linear(d_model, d_model * 4)),
  52. ('gelu', QuickGELU()),
  53. ('c_proj', nn.Linear(d_model * 4, d_model)),
  54. ]))
  55. self.ln_2 = LayerNorm(d_model)
  56. self.attn_mask = attn_mask
  57. self.drop_path = DropPath(drop_path_rate)
  58. def attention(self, x: torch.Tensor):
  59. r"""
  60. A wrapper of self attention .
  61. """
  62. self.attn_mask = (
  63. self.attn_mask.to(dtype=x.dtype, device=x.device)
  64. if self.attn_mask is not None else None)
  65. return self.attn(
  66. x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
  67. def forward(self, x: torch.Tensor):
  68. x = x + self.drop_path(self.attention(self.ln_1(x)))
  69. x = x + self.drop_path(self.mlp(self.ln_2(x)))
  70. return x
  71. class Transformer(nn.Module):
  72. r"""
  73. A transformer module using in `VisionTransformer`.
  74. Execute a sequential of `ResidualAttentionBlock`.
  75. """
  76. def __init__(
  77. self,
  78. width: int,
  79. layers: int,
  80. heads: int,
  81. attn_mask: torch.Tensor = None,
  82. drop_path_rate: float = 0.0,
  83. ):
  84. r"""
  85. Args:
  86. width (`int`): The width of input image.
  87. layers (`int`): The number of `ResidualAttentionBlock` layers.
  88. heads (int): The number of self attention heads.
  89. attn_mask (`Tensor`, **optional**, default to None):
  90. Attention mask using in self attention.
  91. drop_path_rate (`float`, **optional**, default to 0.0):
  92. Drop path rate. See more details about drop path from
  93. https://arxiv.org/pdf/1605.07648v4.pdf.
  94. """
  95. super().__init__()
  96. self.width = width
  97. self.layers = layers
  98. self.resblocks = nn.Sequential(*[
  99. ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate)
  100. for _ in range(layers)
  101. ])
  102. def forward(self, x: torch.Tensor):
  103. return self.resblocks(x)
  104. class VisionTransformer(nn.Module):
  105. r"""
  106. Vision transformer module.
  107. step 1. Using conv2d to get the image embedding.
  108. step 2. If the resolution of input image doesn't equal to the initialized one
  109. do `bilinear` interpolate to get new patch position embedding.
  110. step 3. Add position embedding to image embedding to generate final image representation.
  111. step 4. Do `Transformer` to the image representation.
  112. """
  113. def __init__(
  114. self,
  115. input_resolution: int,
  116. patch_size: int,
  117. width: int,
  118. layers: int,
  119. heads: int,
  120. drop_path_rate: float = 0.0,
  121. ):
  122. r"""
  123. Args:
  124. input_resolution (`int`): The resolution of input image.
  125. patch_size (`int`): The resolution of each patch image.
  126. width (`int`): The dimension of each patch image.
  127. layers (`int`): The number of `ResidualAttentionBlock` in `Transformer`.
  128. heads (`int`): The number of heads in self attention block.
  129. drop_path_rate (`float`, **optional**, default to 0.0):
  130. Drop path rate. See more details about drop path from
  131. https://arxiv.org/pdf/1605.07648v4.pdf.
  132. """
  133. super().__init__()
  134. self.input_resolution = input_resolution
  135. self.patch_size = patch_size
  136. self.conv1 = nn.Conv2d(
  137. in_channels=3,
  138. out_channels=width,
  139. kernel_size=patch_size,
  140. stride=patch_size,
  141. bias=False,
  142. )
  143. scale = width**-0.5
  144. self.width = width
  145. self.positional_embedding = nn.Parameter(scale * torch.randn(
  146. (input_resolution // patch_size)**2 + 1, width))
  147. self.ln_pre = LayerNorm(width)
  148. self.transformer = Transformer(
  149. width, layers, heads, drop_path_rate=drop_path_rate)
  150. def forward(self, x: torch.Tensor):
  151. resolution = x.shape[-2]
  152. height, width = x.shape[-2] // self.patch_size, x.shape[
  153. -1] // self.patch_size
  154. x = self.conv1(x) # shape = [*, width, grid, grid]
  155. x = x.reshape(x.shape[0], x.shape[1],
  156. -1) # shape = [*, width, grid ** 2]
  157. x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
  158. if resolution != self.input_resolution:
  159. old_pe = self.positional_embedding[1:]
  160. patch_num = self.input_resolution // self.patch_size
  161. old_pe = old_pe.reshape(1, patch_num, patch_num,
  162. -1).permute(0, 3, 1, 2)
  163. new_pe = F.interpolate(
  164. old_pe, size=(height, width), mode='bilinear')
  165. new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1)
  166. x = x + new_pe.to(x.dtype)
  167. else:
  168. x = x + self.positional_embedding[1:].to(x.dtype)
  169. x = self.ln_pre(x)
  170. x = x.permute(1, 0, 2) # NLD -> LND
  171. x = self.transformer(x)
  172. x = x.permute(1, 0, 2) # LND -> NLD
  173. bz, seq, hidden = x.shape
  174. x = x.transpose(1, 2).reshape(bz, hidden, height, width)
  175. return x
  176. def vit_base(drop_path_rate: float = 0.0):
  177. r"""
  178. An instance of base vision transformer model.
  179. Args:
  180. drop_path_rate (`float`, **optional**, default to 0.0):
  181. Drop path rate. See more details about drop path from
  182. https://arxiv.org/pdf/1605.07648v4.pdf.
  183. """
  184. return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate)
  185. def vit_large(drop_path_rate: float = 0.0):
  186. r"""
  187. An instance of large vision transformer model.
  188. Args:
  189. drop_path_rate (`float`, **optional**, default to 0.0):
  190. Drop path rate. See more details about drop path from
  191. https://arxiv.org/pdf/1605.07648v4.pdf.
  192. """
  193. return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate)
  194. def vit_large_336(drop_path_rate: float = 0.0):
  195. r"""
  196. An instance of large vision transformer model with 336 as input image width .
  197. Args:
  198. drop_path_rate (`float`, **optional**, default to 0.0):
  199. Drop path rate. See more details about drop path from
  200. https://arxiv.org/pdf/1605.07648v4.pdf.
  201. """
  202. return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate)
  203. def vit_huge(drop_path_rate: float = 0.0):
  204. r"""
  205. An instance of huge vision transformer model.
  206. Args:
  207. drop_path_rate (`float`, **optional**, default to 0.0):
  208. Drop path rate. See more details about drop path from
  209. https://arxiv.org/pdf/1605.07648v4.pdf.
  210. """
  211. return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate)