clip.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. # The implementation is adopted from CLIP, made publicly available
  2. # under MIT License at https://github.com/openai/CLIP
  3. import warnings
  4. from collections import OrderedDict
  5. from typing import Tuple, Union
  6. import numpy as np
  7. import torch
  8. from torch import nn
  9. class CLIP(nn.Module):
  10. def __init__(
  11. self,
  12. embed_dim: int,
  13. # vision
  14. image_resolution: int,
  15. vision_layers: Union[Tuple[int, int, int, int], int],
  16. vision_width: int,
  17. vision_patch_size: int,
  18. # text
  19. context_length: int,
  20. vocab_size: int,
  21. transformer_width: int,
  22. transformer_heads: int,
  23. transformer_layers: int):
  24. super().__init__()
  25. self.context_length = context_length
  26. vision_heads = vision_width // 64
  27. self.visual = VisionTransformer(
  28. input_resolution=image_resolution,
  29. patch_size=vision_patch_size,
  30. width=vision_width,
  31. layers=vision_layers,
  32. heads=vision_heads,
  33. output_dim=embed_dim)
  34. self.transformer = Transformer(
  35. width=transformer_width,
  36. layers=transformer_layers,
  37. heads=transformer_heads,
  38. attn_mask=self.build_attention_mask())
  39. self.vocab_size = vocab_size
  40. self.token_embedding = nn.Embedding(vocab_size, transformer_width)
  41. self.positional_embedding = nn.Parameter(
  42. torch.empty(self.context_length, transformer_width))
  43. self.ln_final = LayerNorm(transformer_width)
  44. self.text_projection = nn.Parameter(
  45. torch.empty(transformer_width, embed_dim))
  46. self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
  47. self.initialize_parameters()
  48. def initialize_parameters(self):
  49. nn.init.normal_(self.token_embedding.weight, std=0.02)
  50. nn.init.normal_(self.positional_embedding, std=0.01)
  51. proj_std = (self.transformer.width**-0.5) * (
  52. (2 * self.transformer.layers)**-0.5)
  53. attn_std = self.transformer.width**-0.5
  54. fc_std = (2 * self.transformer.width)**-0.5
  55. for block in self.transformer.resblocks:
  56. nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
  57. nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
  58. nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
  59. nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
  60. if self.text_projection is not None:
  61. nn.init.normal_(
  62. self.text_projection, std=self.transformer.width**-0.5)
  63. def build_attention_mask(self):
  64. # lazily create causal attention mask, with full attention between the vision tokens
  65. # pytorch uses additive attention mask; fill with -inf
  66. mask = torch.empty(self.context_length, self.context_length)
  67. mask.fill_(float('-inf'))
  68. mask.triu_(1) # zero out the lower diagonal
  69. return mask
  70. @property
  71. def dtype(self):
  72. return self.visual.conv1.weight.dtype
  73. def encode_image(self, image):
  74. return self.visual(image.type(self.dtype))
  75. def encode_text(self, text):
  76. x = self.token_embedding(text).type(
  77. self.dtype) # [batch_size, n_ctx, d_model]
  78. x = x + self.positional_embedding.type(self.dtype)
  79. x = x.permute(1, 0, 2) # NLD -> LND
  80. x = self.transformer(x)
  81. x = x.permute(1, 0, 2) # LND -> NLD
  82. x = self.ln_final(x).type(self.dtype)
  83. # x.shape = [batch_size, n_ctx, transformer.width]
  84. # take features from the eot embedding (eot_token is the highest number in each sequence)
  85. # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
  86. x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
  87. return x
  88. def forward(self, image, text):
  89. image_features = self.encode_image(image)
  90. text_features = self.encode_text(text)
  91. # normalized features
  92. image_features = image_features / image_features.norm(
  93. dim=1, keepdim=True)
  94. text_features = text_features / text_features.norm(dim=1, keepdim=True)
  95. # cosine similarity as logits
  96. logit_scale = self.logit_scale.exp()
  97. logits_per_image = logit_scale * image_features @ text_features.t()
  98. logits_per_text = logits_per_image.t()
  99. # shape = [global_batch_size, global_batch_size]
  100. return logits_per_image, logits_per_text
  101. class LayerNorm(nn.LayerNorm):
  102. """Subclass torch's LayerNorm to handle fp16."""
  103. def forward(self, x: torch.Tensor):
  104. orig_type = x.dtype
  105. ret = super().forward(x.type(torch.float32))
  106. return ret.type(orig_type)
  107. class QuickGELU(nn.Module):
  108. def forward(self, x: torch.Tensor):
  109. return x * torch.sigmoid(1.702 * x)
  110. class ResidualAttentionBlock(nn.Module):
  111. def __init__(self,
  112. d_model: int,
  113. n_head: int,
  114. attn_mask: torch.Tensor = None):
  115. super().__init__()
  116. self.attn = nn.MultiheadAttention(d_model, n_head)
  117. self.ln_1 = LayerNorm(d_model)
  118. self.mlp = nn.Sequential(
  119. OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
  120. ('gelu', QuickGELU()),
  121. ('c_proj', nn.Linear(d_model * 4, d_model))]))
  122. self.ln_2 = LayerNorm(d_model)
  123. self.attn_mask = attn_mask
  124. def attention(self, x: torch.Tensor):
  125. self.attn_mask = self.attn_mask.to(
  126. dtype=x.dtype,
  127. device=x.device) if self.attn_mask is not None else None
  128. return self.attn(
  129. x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
  130. def forward(self, x: torch.Tensor):
  131. x = x + self.attention(self.ln_1(x))
  132. x = x + self.mlp(self.ln_2(x))
  133. return x
  134. class Transformer(nn.Module):
  135. def __init__(self,
  136. width: int,
  137. layers: int,
  138. heads: int,
  139. attn_mask: torch.Tensor = None):
  140. super().__init__()
  141. self.width = width
  142. self.layers = layers
  143. self.resblocks = nn.Sequential(*[
  144. ResidualAttentionBlock(width, heads, attn_mask)
  145. for _ in range(layers)
  146. ])
  147. def forward(self, x: torch.Tensor):
  148. return self.resblocks(x)
  149. class VisionTransformer(nn.Module):
  150. def __init__(self, input_resolution: int, patch_size: int, width: int,
  151. layers: int, heads: int, output_dim: int):
  152. super().__init__()
  153. self.input_resolution = input_resolution
  154. self.output_dim = output_dim
  155. self.conv1 = nn.Conv2d(
  156. in_channels=3,
  157. out_channels=width,
  158. kernel_size=patch_size,
  159. stride=patch_size,
  160. bias=False)
  161. scale = width**-0.5
  162. self.class_embedding = nn.Parameter(scale * torch.randn(width))
  163. self.positional_embedding = nn.Parameter(scale * torch.randn(
  164. (input_resolution // patch_size)**2 + 1, width))
  165. self.ln_pre = LayerNorm(width)
  166. self.transformer = Transformer(width, layers, heads)
  167. self.ln_post = LayerNorm(width)
  168. self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
  169. def forward(self, x: torch.Tensor):
  170. x = self.conv1(x) # shape = [*, width, grid, grid]
  171. x = x.reshape(x.shape[0], x.shape[1],
  172. -1) # shape = [*, width, grid ** 2]
  173. x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
  174. class_token = self.class_embedding.to(x.dtype) + torch.zeros(
  175. x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
  176. x = torch.cat([class_token, x], dim=1)
  177. x = x + self.positional_embedding.to(x.dtype)
  178. x = self.ln_pre(x)
  179. x = x.permute(1, 0, 2) # NLD -> LND
  180. x = self.transformer(x)
  181. x = x.permute(1, 0, 2) # LND -> NLD
  182. x = self.ln_post(x[:, 0, :])
  183. if self.proj is not None:
  184. x = x @ self.proj
  185. return x
  186. def build_model(state_dict: dict):
  187. vision_width = state_dict['visual.conv1.weight'].shape[0]
  188. vision_layers = len([
  189. k for k in state_dict.keys()
  190. if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
  191. ])
  192. vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
  193. grid_size = round(
  194. (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
  195. image_resolution = vision_patch_size * grid_size
  196. embed_dim = state_dict['text_projection'].shape[1]
  197. context_length = state_dict['positional_embedding'].shape[0]
  198. vocab_size = state_dict['token_embedding.weight'].shape[0]
  199. transformer_width = state_dict['ln_final.weight'].shape[0]
  200. transformer_heads = transformer_width // 64
  201. transformer_layers = len(
  202. set(
  203. k.split('.')[2] for k in state_dict
  204. if k.startswith('transformer.resblocks')))
  205. model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
  206. vision_patch_size, context_length, vocab_size,
  207. transformer_width, transformer_heads, transformer_layers)
  208. for key in ['input_resolution', 'context_length', 'vocab_size']:
  209. if key in state_dict:
  210. del state_dict[key]
  211. model.load_state_dict(state_dict)
  212. return model.eval()
  213. def load_clip(name: str,
  214. device: Union[str, torch.device] = 'cuda'
  215. if torch.cuda.is_available() else 'cpu',
  216. jit=True):
  217. jit = False
  218. model_path = name
  219. try:
  220. model = torch.jit.load(
  221. model_path, map_location=device if jit else 'cpu').eval()
  222. state_dict = None
  223. except RuntimeError:
  224. if jit:
  225. warnings.warn(
  226. f'File {model_path} is not a JIT archive. Loading as a state dict instead'
  227. )
  228. jit = False
  229. state_dict = torch.load(model_path, map_location='cpu')
  230. if not jit:
  231. model = build_model(state_dict or model.state_dict()).to(device)
  232. if str(device) == 'cpu':
  233. model.float()
  234. return model
  235. device_holder = torch.jit.trace(
  236. lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
  237. device_node = [
  238. n for n in device_holder.graph.findAllNodes('prim::Constant')
  239. if 'Device' in repr(n)
  240. ][-1]
  241. def patch_device(module):
  242. graphs = [module.graph] if hasattr(module, 'graph') else []
  243. if hasattr(module, 'forward1'):
  244. graphs.append(module.forward1.graph)
  245. for graph in graphs:
  246. for node in graph.findAllNodes('prim::Constant'):
  247. if 'value' in node.attributeNames() and str(
  248. node['value']).startswith('cuda'):
  249. node.copyAttributes(device_node)
  250. model.apply(patch_device)
  251. patch_device(model.encode_image)
  252. patch_device(model.encode_text)
  253. if str(device) == 'cpu':
  254. float_holder = torch.jit.trace(
  255. lambda: torch.ones([]).float(), example_inputs=[])
  256. float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
  257. float_node = float_input.node()
  258. def patch_float(module):
  259. graphs = [module.graph] if hasattr(module, 'graph') else []
  260. if hasattr(module, 'forward1'):
  261. graphs.append(module.forward1.graph)
  262. for graph in graphs:
  263. for node in graph.findAllNodes('aten::to'):
  264. inputs = list(node.inputs())
  265. for i in [1, 2]:
  266. if inputs[i].node()['value'] == 5:
  267. inputs[i].node().copyAttributes(float_node)
  268. model.apply(patch_float)
  269. patch_float(model.encode_image)
  270. patch_float(model.encode_text)
  271. model.float()
  272. return model