backbone.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # The implementation here is modified based on HuggingFace, originally Apache 2.0 License
  2. # and publicly available at https://github.com/huggingface/transformers
  3. # Copyright 2018 The HuggingFace Inc. team.
  4. # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  5. import hashlib
  6. import os
  7. import urllib
  8. import warnings
  9. from collections import OrderedDict
  10. from typing import Tuple, Union
  11. import numpy as np
  12. import torch
  13. import torch.nn.functional as F
  14. from torch import nn
  15. from tqdm import tqdm
  16. from modelscope.models.base.base_torch_model import TorchModel
  17. class LayerNorm(nn.LayerNorm):
  18. def forward(self, x: torch.Tensor):
  19. orig_type = x.dtype
  20. ret = super().forward(x.type(torch.float32))
  21. return ret.type(orig_type)
  22. class QuickGELU(TorchModel):
  23. def forward(self, x: torch.Tensor):
  24. return x * torch.sigmoid(1.702 * x)
  25. class ResidualAttentionBlock(TorchModel):
  26. def __init__(self,
  27. d_model: int,
  28. n_head: int,
  29. attn_mask: torch.Tensor = None):
  30. super().__init__()
  31. self.attn = nn.MultiheadAttention(d_model, n_head)
  32. self.ln_1 = LayerNorm(d_model)
  33. self.mlp = nn.Sequential(
  34. OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
  35. ('gelu', QuickGELU()),
  36. ('c_proj', nn.Linear(d_model * 4, d_model))]))
  37. self.ln_2 = LayerNorm(d_model)
  38. self.attn_mask = attn_mask
  39. def attention(self, x: torch.Tensor):
  40. self.attn_mask = self.attn_mask.to(
  41. dtype=x.dtype,
  42. device=x.device) if self.attn_mask is not None else None
  43. return self.attn(
  44. x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
  45. def forward(self, x: torch.Tensor):
  46. x = x + self.attention(self.ln_1(x))
  47. x = x + self.mlp(self.ln_2(x))
  48. return x
  49. class Transformer(TorchModel):
  50. def __init__(self,
  51. width: int,
  52. layers: int,
  53. heads: int,
  54. attn_mask: torch.Tensor = None):
  55. super().__init__()
  56. self.width = width
  57. self.layers = layers
  58. self.resblocks = nn.Sequential(*[
  59. ResidualAttentionBlock(width, heads, attn_mask)
  60. for _ in range(layers)
  61. ])
  62. def forward(self, x: torch.Tensor):
  63. return self.resblocks(x)
  64. class VisualTransformer(TorchModel):
  65. def __init__(self, input_resolution: int, patch_size: int, width: int,
  66. layers: int, heads: int, output_dim: int):
  67. super().__init__()
  68. self.input_resolution = input_resolution
  69. self.output_dim = output_dim
  70. self.conv1 = nn.Conv2d(
  71. in_channels=3,
  72. out_channels=width,
  73. kernel_size=patch_size,
  74. stride=patch_size,
  75. bias=False)
  76. scale = width**-0.5
  77. self.class_embedding = nn.Parameter(scale * torch.randn(width))
  78. self.positional_embedding = nn.Parameter(scale * torch.randn(
  79. (input_resolution // patch_size)**2 + 1, width))
  80. self.ln_pre = LayerNorm(width)
  81. self.transformer = Transformer(width, layers, heads)
  82. self.ln_post = LayerNorm(width)
  83. self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
  84. def forward(self, x: torch.Tensor):
  85. x = self.conv1(x)
  86. x = x.reshape(x.shape[0], x.shape[1], -1)
  87. x = x.permute(0, 2, 1)
  88. x_1 = self.class_embedding.to(x.dtype)
  89. x_2 = torch.zeros(
  90. x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
  91. x_1 = x_1 + x_2
  92. x = torch.cat([x_1, x], dim=1)
  93. x = x + self.positional_embedding.to(x.dtype)
  94. x = self.ln_pre(x)
  95. x = x.permute(1, 0, 2)
  96. x = self.transformer(x)
  97. x = x.permute(1, 0, 2)
  98. x = self.ln_post(x[:, 0, :])
  99. if self.proj is not None:
  100. x = x @ self.proj
  101. return x
  102. class CLIP(TorchModel):
  103. def __init__(self, embed_dim: int, image_resolution: int,
  104. vision_layers: Union[Tuple[int, int, int, int], int],
  105. vision_width: int, vision_patch_size: int,
  106. context_length: int, vocab_size: int, transformer_width: int,
  107. transformer_heads: int, transformer_layers: int):
  108. super().__init__()
  109. self.context_length = context_length
  110. vision_heads = vision_width // 64
  111. self.visual = VisualTransformer(
  112. input_resolution=image_resolution,
  113. patch_size=vision_patch_size,
  114. width=vision_width,
  115. layers=vision_layers,
  116. heads=vision_heads,
  117. output_dim=embed_dim)
  118. self.transformer = Transformer(
  119. width=transformer_width,
  120. layers=transformer_layers,
  121. heads=transformer_heads,
  122. attn_mask=self.build_attention_mask())
  123. self.vocab_size = vocab_size
  124. self.token_embedding = nn.Embedding(vocab_size, transformer_width)
  125. self.positional_embedding = nn.Parameter(
  126. torch.empty(self.context_length, transformer_width))
  127. self.ln_final = LayerNorm(transformer_width)
  128. self.text_projection = nn.Parameter(
  129. torch.empty(transformer_width, embed_dim))
  130. self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
  131. self.initialize_parameters()
  132. def initialize_parameters(self):
  133. nn.init.normal_(self.token_embedding.weight, std=0.02)
  134. nn.init.normal_(self.positional_embedding, std=0.01)
  135. proj_std = (self.transformer.width**-0.5) * (
  136. (2 * self.transformer.layers)**-0.5)
  137. attn_std = self.transformer.width**-0.5
  138. fc_std = (2 * self.transformer.width)**-0.5
  139. for block in self.transformer.resblocks:
  140. nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
  141. nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
  142. nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
  143. nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
  144. if self.text_projection is not None:
  145. nn.init.normal_(
  146. self.text_projection, std=self.transformer.width**-0.5)
  147. def build_attention_mask(self):
  148. mask = torch.empty(self.context_length, self.context_length)
  149. mask.fill_(float('-inf'))
  150. mask.triu_(1)
  151. return mask
  152. @property
  153. def dtype(self):
  154. return self.visual.conv1.weight.dtype
  155. def encode_image(self, image):
  156. return self.visual(image.type(self.dtype))
  157. def encode_text(self, text, return_all_tokens=False):
  158. x = self.token_embedding(text).type(self.dtype)
  159. x = x + self.positional_embedding.type(self.dtype)
  160. x = x.permute(1, 0, 2)
  161. x = self.transformer(x)
  162. x = x.permute(1, 0, 2)
  163. x = self.ln_final(x).type(self.dtype)
  164. if return_all_tokens:
  165. return x @ self.text_projection
  166. x = x[torch.arange(x.shape[0]),
  167. text.argmax(dim=-1)] @ self.text_projection
  168. return x
  169. def forward(self, image, text):
  170. image_features = self.encode_image(image)
  171. text_features = self.encode_text(text)
  172. image_features = image_features / image_features.norm(
  173. dim=-1, keepdim=True)
  174. text_features = text_features / text_features.norm(
  175. dim=-1, keepdim=True)
  176. logit_scale = self.logit_scale.exp()
  177. logits_per_image = logit_scale * image_features @ text_features.t()
  178. logits_per_text = logit_scale * text_features @ image_features.t()
  179. return logits_per_image, logits_per_text
  180. def build_model(state_dict: dict):
  181. vit = 'visual.proj' in state_dict
  182. if vit:
  183. vision_width = state_dict['visual.conv1.weight'].shape[0]
  184. vision_layers = len([
  185. k for k in state_dict.keys()
  186. if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
  187. ])
  188. vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
  189. grid_size = round(
  190. (state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
  191. image_resolution = vision_patch_size * grid_size
  192. else:
  193. counts: list = [
  194. len(
  195. set(
  196. k.split('.')[2] for k in state_dict
  197. if k.startswith(f'visual.layer{b}')))
  198. for b in [1, 2, 3, 4]
  199. ]
  200. vision_layers = tuple(counts)
  201. vision_width = state_dict['visual.layer1.0.conv1.weight'].shape[0]
  202. output_width = round(
  203. (state_dict['visual.attnpool.positional_embedding'].shape[0]
  204. - 1)**0.5)
  205. vision_patch_size = None
  206. assert output_width**2 + 1 == state_dict[
  207. 'visual.attnpool.positional_embedding'].shape[0]
  208. image_resolution = output_width * 32
  209. embed_dim = state_dict['text_projection'].shape[1]
  210. context_length = state_dict['positional_embedding'].shape[0]
  211. vocab_size = state_dict['token_embedding.weight'].shape[0]
  212. transformer_width = state_dict['ln_final.weight'].shape[0]
  213. transformer_heads = transformer_width // 64
  214. transformer_layers = len(
  215. set(
  216. k.split('.')[2] for k in state_dict
  217. if k.startswith('transformer.resblocks')))
  218. model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
  219. vision_patch_size, context_length, vocab_size,
  220. transformer_width, transformer_heads, transformer_layers)
  221. for key in ['input_resolution', 'context_length', 'vocab_size']:
  222. if key in state_dict:
  223. del state_dict[key]
  224. model.load_state_dict(state_dict)
  225. return model.eval()
  226. def load_clip(name: str,
  227. device: Union[str, torch.device] = 'cuda'
  228. if torch.cuda.is_available() else 'cpu',
  229. jit=True):
  230. jit = False
  231. model_path = name
  232. try:
  233. model = torch.jit.load(
  234. model_path, map_location=device if jit else 'cpu').eval()
  235. state_dict = None
  236. except RuntimeError:
  237. if jit:
  238. warnings.warn(
  239. f'File {model_path} is not a JIT archive. Loading as a state dict instead'
  240. )
  241. jit = False
  242. state_dict = torch.load(
  243. model_path, map_location='cpu', weights_only=True)
  244. if not jit:
  245. model = build_model(state_dict or model.state_dict()).to(device)
  246. if str(device) == 'cpu':
  247. model.float()
  248. return model
  249. device_holder = torch.jit.trace(
  250. lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
  251. device_node = [
  252. n for n in device_holder.graph.findAllNodes('prim::Constant')
  253. if 'Device' in repr(n)
  254. ][-1]
  255. def patch_device(module):
  256. graphs = [module.graph] if hasattr(module, 'graph') else []
  257. if hasattr(module, 'forward1'):
  258. graphs.append(module.forward1.graph)
  259. for graph in graphs:
  260. for node in graph.findAllNodes('prim::Constant'):
  261. if 'value' in node.attributeNames() and str(
  262. node['value']).startswith('cuda'):
  263. node.copyAttributes(device_node)
  264. model.apply(patch_device)
  265. patch_device(model.encode_image)
  266. patch_device(model.encode_text)
  267. if str(device) == 'cpu':
  268. float_holder = torch.jit.trace(
  269. lambda: torch.ones([]).float(), example_inputs=[])
  270. float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
  271. float_node = float_input.node()
  272. def patch_float(module):
  273. graphs = [module.graph] if hasattr(module, 'graph') else []
  274. if hasattr(module, 'forward1'):
  275. graphs.append(module.forward1.graph)
  276. for graph in graphs:
  277. for node in graph.findAllNodes('aten::to'):
  278. inputs = list(node.inputs())
  279. for i in [1, 2]:
  280. if inputs[i].node()['value'] == 5:
  281. inputs[i].node().copyAttributes(float_node)
  282. model.apply(patch_float)
  283. patch_float(model.encode_image)
  284. patch_float(model.encode_text)
  285. model.float()
  286. return model