model.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
  1. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  2. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. from collections import OrderedDict
  17. from typing import Any, Dict, Tuple, Union
  18. import json
  19. import numpy as np
  20. import torch
  21. import torch.nn as nn
  22. import torch.nn.functional as F
  23. from modelscope.metainfo import Models
  24. from modelscope.models import TorchModel
  25. from modelscope.models.builder import MODELS
  26. from modelscope.models.multi_modal.clip.bert_tokenizer import FullTokenizer
  27. from modelscope.models.multi_modal.clip.configuration_bert import BertConfig
  28. from modelscope.models.multi_modal.clip.modeling_bert import BertModel
  29. from modelscope.utils.constant import ModeKeys, ModelFile, Tasks
  30. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's ``get_logger`` helper.
logger = get_logger()

# Names exported by ``from <this module> import *``.
__all__ = ['CLIPForMultiModalEmbedding']
  33. class Bottleneck(nn.Module):
  34. expansion = 4
  35. def __init__(self, inplanes, planes, stride=1):
  36. super().__init__()
  37. # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
  38. self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
  39. self.bn1 = nn.BatchNorm2d(planes)
  40. self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
  41. self.bn2 = nn.BatchNorm2d(planes)
  42. self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
  43. self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
  44. self.bn3 = nn.BatchNorm2d(planes * self.expansion)
  45. self.relu = nn.ReLU(inplace=True)
  46. self.downsample = None
  47. self.stride = stride
  48. if stride > 1 or inplanes != planes * Bottleneck.expansion:
  49. # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
  50. self.downsample = nn.Sequential(
  51. OrderedDict([('-1', nn.AvgPool2d(stride)),
  52. ('0',
  53. nn.Conv2d(
  54. inplanes,
  55. planes * self.expansion,
  56. 1,
  57. stride=1,
  58. bias=False)),
  59. ('1', nn.BatchNorm2d(planes * self.expansion))]))
  60. def forward(self, x: torch.Tensor):
  61. identity = x
  62. out = self.relu(self.bn1(self.conv1(x)))
  63. out = self.relu(self.bn2(self.conv2(out)))
  64. out = self.avgpool(out)
  65. out = self.bn3(self.conv3(out))
  66. if self.downsample is not None:
  67. identity = self.downsample(x)
  68. out += identity
  69. out = self.relu(out)
  70. return out
  71. class AttentionPool2d(nn.Module):
  72. def __init__(self,
  73. spacial_dim: int,
  74. embed_dim: int,
  75. num_heads: int,
  76. output_dim: int = None):
  77. super().__init__()
  78. self.positional_embedding = nn.Parameter(
  79. torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
  80. self.k_proj = nn.Linear(embed_dim, embed_dim)
  81. self.q_proj = nn.Linear(embed_dim, embed_dim)
  82. self.v_proj = nn.Linear(embed_dim, embed_dim)
  83. self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
  84. self.num_heads = num_heads
  85. def forward(self, x):
  86. x = x.reshape(x.shape[0], x.shape[1],
  87. x.shape[2] * x.shape[3]).permute(2, 0,
  88. 1) # NCHW -> (HW)NC
  89. x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
  90. x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
  91. x, _ = F.multi_head_attention_forward(
  92. query=x,
  93. key=x,
  94. value=x,
  95. embed_dim_to_check=x.shape[-1],
  96. num_heads=self.num_heads,
  97. q_proj_weight=self.q_proj.weight,
  98. k_proj_weight=self.k_proj.weight,
  99. v_proj_weight=self.v_proj.weight,
  100. in_proj_weight=None,
  101. in_proj_bias=torch.cat(
  102. [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
  103. bias_k=None,
  104. bias_v=None,
  105. add_zero_attn=False,
  106. dropout_p=0,
  107. out_proj_weight=self.c_proj.weight,
  108. out_proj_bias=self.c_proj.bias,
  109. use_separate_proj_weight=True,
  110. training=self.training,
  111. need_weights=False)
  112. return x[0]
  113. class ModifiedResNet(nn.Module):
  114. """
  115. A ResNet class that is similar to torchvision's but contains the following changes:
  116. - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
  117. - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
  118. - The final pooling layer is a QKV attention instead of an average pool
  119. """
  120. def __init__(self,
  121. layers,
  122. output_dim,
  123. heads,
  124. input_resolution=224,
  125. width=64):
  126. super().__init__()
  127. self.output_dim = output_dim
  128. self.input_resolution = input_resolution
  129. # the 3-layer stem
  130. self.conv1 = nn.Conv2d(
  131. 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
  132. self.bn1 = nn.BatchNorm2d(width // 2)
  133. self.conv2 = nn.Conv2d(
  134. width // 2, width // 2, kernel_size=3, padding=1, bias=False)
  135. self.bn2 = nn.BatchNorm2d(width // 2)
  136. self.conv3 = nn.Conv2d(
  137. width // 2, width, kernel_size=3, padding=1, bias=False)
  138. self.bn3 = nn.BatchNorm2d(width)
  139. self.avgpool = nn.AvgPool2d(2)
  140. self.relu = nn.ReLU(inplace=True)
  141. # residual layers
  142. self._inplanes = width # this is a *mutable* variable used during construction
  143. self.layer1 = self._make_layer(width, layers[0])
  144. self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
  145. self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
  146. self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
  147. embed_dim = width * 32 # the ResNet feature dimension
  148. self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim,
  149. heads, output_dim)
  150. def _make_layer(self, planes, blocks, stride=1):
  151. layers = [Bottleneck(self._inplanes, planes, stride)]
  152. self._inplanes = planes * Bottleneck.expansion
  153. for _ in range(1, blocks):
  154. layers.append(Bottleneck(self._inplanes, planes))
  155. return nn.Sequential(*layers)
  156. def forward(self, x):
  157. def stem(x):
  158. for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
  159. (self.conv3, self.bn3)]:
  160. x = self.relu(bn(conv(x)))
  161. x = self.avgpool(x)
  162. return x
  163. x = x.type(self.conv1.weight.dtype)
  164. x = stem(x)
  165. x = self.layer1(x)
  166. x = self.layer2(x)
  167. x = self.layer3(x)
  168. x = self.layer4(x)
  169. x = self.attnpool(x)
  170. return x
  171. class LayerNorm(nn.LayerNorm):
  172. """Subclass torch's LayerNorm to handle fp16."""
  173. def forward(self, x: torch.Tensor):
  174. orig_type = x.dtype
  175. ret = super().forward(x.type(torch.float32))
  176. return ret.type(orig_type)
  177. class QuickGELU(nn.Module):
  178. def forward(self, x: torch.Tensor):
  179. return x * torch.sigmoid(1.702 * x)
  180. class ResidualAttentionBlock(nn.Module):
  181. def __init__(self,
  182. d_model: int,
  183. n_head: int,
  184. attn_mask: torch.Tensor = None):
  185. super().__init__()
  186. self.attn = nn.MultiheadAttention(d_model, n_head)
  187. self.ln_1 = LayerNorm(d_model)
  188. self.mlp = nn.Sequential(
  189. OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
  190. ('gelu', QuickGELU()),
  191. ('c_proj', nn.Linear(d_model * 4, d_model))]))
  192. self.ln_2 = LayerNorm(d_model)
  193. self.attn_mask = attn_mask
  194. def attention(self, x: torch.Tensor):
  195. self.attn_mask = self.attn_mask.to(
  196. dtype=x.dtype,
  197. device=x.device) if self.attn_mask is not None else None
  198. return self.attn(
  199. x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
  200. def forward(self, x: torch.Tensor):
  201. x = x + self.attention(self.ln_1(x))
  202. x = x + self.mlp(self.ln_2(x))
  203. return x
  204. class Transformer(nn.Module):
  205. def __init__(self,
  206. width: int,
  207. layers: int,
  208. heads: int,
  209. attn_mask: torch.Tensor = None):
  210. super().__init__()
  211. self.width = width
  212. self.layers = layers
  213. self.resblocks = nn.Sequential(*[
  214. ResidualAttentionBlock(width, heads, attn_mask)
  215. for _ in range(layers)
  216. ])
  217. def forward(self, x: torch.Tensor):
  218. return self.resblocks(x)
  219. class VisualTransformer(nn.Module):
  220. def __init__(self, input_resolution: int, patch_size: int, width: int,
  221. layers: int, heads: int, output_dim: int):
  222. super().__init__()
  223. self.input_resolution = input_resolution
  224. self.output_dim = output_dim
  225. self.conv1 = nn.Conv2d(
  226. in_channels=3,
  227. out_channels=width,
  228. kernel_size=patch_size,
  229. stride=patch_size,
  230. bias=False)
  231. scale = width**-0.5
  232. self.class_embedding = nn.Parameter(scale * torch.randn(width))
  233. self.positional_embedding = nn.Parameter(scale * torch.randn(
  234. (input_resolution // patch_size)**2 + 1, width))
  235. self.ln_pre = LayerNorm(width)
  236. self.transformer = Transformer(width, layers, heads)
  237. self.ln_post = LayerNorm(width)
  238. self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
  239. def forward(self, x: torch.Tensor):
  240. x = self.conv1(x) # shape = [*, width, grid, grid]
  241. x = x.reshape(x.shape[0], x.shape[1],
  242. -1) # shape = [*, width, grid ** 2]
  243. x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
  244. x = torch.cat(
  245. [ # noqa
  246. self.class_embedding.to(x.dtype) + torch.zeros( # noqa
  247. x.shape[0],
  248. 1,
  249. x.shape[-1],
  250. dtype=x.dtype,
  251. device=x.device),
  252. x # noqa
  253. ],
  254. dim=1) # noqa shape = [*, grid ** 2 + 1, width]
  255. x = x + self.positional_embedding.to(x.dtype)
  256. x = self.ln_pre(x)
  257. x = x.permute(1, 0, 2) # NLD -> LND
  258. x = self.transformer(x)
  259. x = x.permute(1, 0, 2) # LND -> NLD
  260. x = self.ln_post(x[:, 0, :])
  261. if self.proj is not None:
  262. x = x @ self.proj
  263. return x
  264. class CLIP(nn.Module):
  265. def __init__(
  266. self,
  267. embed_dim: int,
  268. # vision
  269. image_resolution: int,
  270. vision_layers: Union[Tuple[int, int, int, int], int],
  271. vision_width: int,
  272. vision_patch_size: int,
  273. # text
  274. vocab_size: int,
  275. text_attention_probs_dropout_prob: float,
  276. text_hidden_act: str,
  277. text_hidden_dropout_prob: float,
  278. text_hidden_size: int,
  279. text_initializer_range: float,
  280. text_intermediate_size: int,
  281. text_max_position_embeddings: int,
  282. text_num_attention_heads: int,
  283. text_num_hidden_layers: int,
  284. text_type_vocab_size: int,
  285. tokenizer: FullTokenizer,
  286. # vision_head_width, added this param for ViT-H
  287. vision_head_width: int = 64,
  288. ):
  289. super().__init__()
  290. if isinstance(vision_layers, (tuple, list)):
  291. vision_heads = vision_width * 32 // vision_head_width
  292. self.visual = ModifiedResNet(
  293. layers=vision_layers,
  294. output_dim=embed_dim,
  295. heads=vision_heads,
  296. input_resolution=image_resolution,
  297. width=vision_width)
  298. else:
  299. vision_heads = vision_width // vision_head_width
  300. self.visual = VisualTransformer(
  301. input_resolution=image_resolution,
  302. patch_size=vision_patch_size,
  303. width=vision_width,
  304. layers=vision_layers,
  305. heads=vision_heads,
  306. output_dim=embed_dim)
  307. self.bert_config = BertConfig(
  308. vocab_size_or_config_json_file=vocab_size,
  309. hidden_size=text_hidden_size,
  310. num_hidden_layers=text_num_hidden_layers,
  311. num_attention_heads=text_num_attention_heads,
  312. intermediate_size=text_intermediate_size,
  313. hidden_act=text_hidden_act,
  314. hidden_dropout_prob=text_hidden_dropout_prob,
  315. attention_probs_dropout_prob=text_attention_probs_dropout_prob,
  316. max_position_embeddings=text_max_position_embeddings,
  317. type_vocab_size=text_type_vocab_size,
  318. initializer_range=text_initializer_range,
  319. layer_norm_eps=1e-12,
  320. )
  321. self.bert = BertModel(self.bert_config)
  322. self.text_projection = nn.Parameter(
  323. torch.empty(text_hidden_size, embed_dim))
  324. self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
  325. self.tokenizer = tokenizer
  326. self.initialize_parameters()
  327. def initialize_parameters(self):
  328. self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
  329. if isinstance(self.visual, ModifiedResNet):
  330. if self.visual.attnpool is not None:
  331. std = self.visual.attnpool.c_proj.in_features**-0.5
  332. nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
  333. nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
  334. nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
  335. nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
  336. for resnet_block in [
  337. self.visual.layer1, self.visual.layer2, self.visual.layer3,
  338. self.visual.layer4
  339. ]:
  340. for name, param in resnet_block.named_parameters():
  341. if name.endswith('bn3.weight'):
  342. nn.init.zeros_(param)
  343. if self.text_projection is not None:
  344. nn.init.normal_(
  345. self.text_projection, std=self.bert_config.hidden_size**-0.5)
  346. @property
  347. def dtype(self):
  348. return self.visual.conv1.weight.dtype
  349. def encode_image(self, image):
  350. return self.visual(image.type(self.dtype))
  351. def encode_text(self, text):
  352. pad_index = self.tokenizer.vocab['[PAD]']
  353. attn_mask = text.ne(pad_index).type(self.dtype)
  354. x = self.bert(
  355. text, attention_mask=attn_mask)[0].type(
  356. self.dtype) # [batch_size, seq_length, hidden_size]
  357. return x[:, 0, :] @ self.text_projection
  358. def forward(self, image, text):
  359. assert image is not None or text is not None, 'text and image cannot both be None!'
  360. if image is None:
  361. return self.encode_text(text)
  362. elif text is None:
  363. return self.encode_image(image)
  364. image_features = self.encode_image(image)
  365. text_features = self.encode_text(text)
  366. image_features = image_features / image_features.norm(
  367. dim=-1, keepdim=True)
  368. text_features = text_features / text_features.norm(
  369. dim=-1, keepdim=True)
  370. return image_features, text_features, self.logit_scale.exp()
  371. def get_similarity(self, image, text):
  372. image_features = self.encode_image(image)
  373. text_features = self.encode_text(text)
  374. # normalized features
  375. image_features = image_features / image_features.norm(
  376. dim=1, keepdim=True)
  377. text_features = text_features / text_features.norm(dim=1, keepdim=True)
  378. # cosine similarity as logits
  379. logit_scale = self.logit_scale.exp()
  380. logits_per_image = logit_scale * image_features @ text_features.t()
  381. logits_per_text = logits_per_image.t()
  382. # shape = [global_batch_size, global_batch_size]
  383. return logits_per_image, logits_per_text
  384. def convert_models_to_fp32(model):
  385. for p in model.parameters():
  386. p.data = p.data.float()
  387. if p.grad:
  388. p.grad.data = p.grad.data.float()
  389. def convert_weights(model: nn.Module):
  390. """Convert applicable model parameters to fp16"""
  391. def _convert_weights_to_fp16(module):
  392. if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
  393. module.weight.data = module.weight.data.half()
  394. if module.bias is not None:
  395. module.bias.data = module.bias.data.half()
  396. if isinstance(module, nn.MultiheadAttention):
  397. for attr in [
  398. *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
  399. 'in_proj_bias', 'bias_k', 'bias_v'
  400. ]:
  401. tensor = getattr(module, attr)
  402. if tensor is not None:
  403. tensor.data = tensor.data.half()
  404. if isinstance(module, BertModel):
  405. module.to(torch.half)
  406. for name in ['text_projection', 'proj']:
  407. if hasattr(module, name):
  408. attr = getattr(module, name)
  409. if attr is not None:
  410. attr.data = attr.data.half()
  411. model.apply(_convert_weights_to_fp16)
  412. @MODELS.register_module(Tasks.multi_modal_embedding, module_name=Models.clip)
  413. class CLIPForMultiModalEmbedding(TorchModel):
  414. def __init__(self, model_dir, *args, **kwargs):
  415. super().__init__(model_dir=model_dir, *args, **kwargs)
  416. # Initialize the model.
  417. vision_model_config_file = '{}/vision_model_config.json'.format(
  418. model_dir)
  419. logger.info(
  420. f'Loading vision model config from {vision_model_config_file}')
  421. assert os.path.exists(vision_model_config_file)
  422. text_model_config_file = '{}/text_model_config.json'.format(model_dir)
  423. logger.info(f'Loading text model config from {text_model_config_file}')
  424. assert os.path.exists(text_model_config_file)
  425. with open(
  426. vision_model_config_file, 'r',
  427. encoding='utf-8') as fv,\
  428. open(text_model_config_file, 'r', encoding='utf-8') as ft:
  429. self.model_info = json.load(fv)
  430. for k, v in json.load(ft).items():
  431. self.model_info[k] = v
  432. vocab_file = f'{model_dir}/{ModelFile.VOCAB_FILE}'
  433. self.tokenizer = FullTokenizer(vocab_file=vocab_file)
  434. # initialize the model
  435. self.clip_model = CLIP(**self.model_info, tokenizer=self.tokenizer)
  436. convert_weights(self.clip_model)
  437. # restore the pretrained weight
  438. checkpoint = torch.load(
  439. f'{model_dir}/{ModelFile.TORCH_MODEL_BIN_FILE}', 'cpu')
  440. sd = checkpoint[
  441. 'state_dict'] if 'state_dict' in checkpoint else checkpoint
  442. if next(iter(sd.items()))[0].startswith('module'):
  443. sd = {k[len('module.'):]: v for k, v in sd.items()}
  444. # support the finetuned model
  445. if next(iter(sd.items()))[0].startswith('clip_model'):
  446. sd = {k[len('clip_model.'):]: v for k, v in sd.items()}
  447. self.clip_model.load_state_dict(sd)
  448. self.clip_model.eval()
  449. # place the model
  450. self.device = 'cuda:{}'.format(int(os.environ.get(
  451. 'LOCAL_RANK', 0))) if torch.cuda.is_available() else 'cpu'
  452. if torch.cuda.is_available():
  453. self.clip_model.to(self.device)
  454. logger.info('Use GPU {} for finetuning & inference'.format(
  455. int(os.environ.get('LOCAL_RANK', 0))))
  456. else:
  457. self.clip_model.float()
  458. logger.info('Use CPU for finetuning & inference')
  459. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  460. from modelscope.outputs import OutputKeys
  461. output = {
  462. OutputKeys.IMG_EMBEDDING: None,
  463. OutputKeys.TEXT_EMBEDDING: None
  464. }
  465. mode = input.get('mode', ModeKeys.INFERENCE)
  466. # encode the image
  467. if 'img' in input and isinstance(input['img'], torch.Tensor):
  468. image_tensor = input['img'].to(self.device)
  469. if image_tensor.dim() == 5 and image_tensor.shape[1] == 1:
  470. image_tensor = image_tensor.squeeze(1)
  471. with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
  472. image_features = self.clip_model.encode_image(image_tensor)
  473. image_features = image_features / image_features.norm(
  474. dim=-1, keepdim=True) # l2-normalize
  475. output[OutputKeys.IMG_EMBEDDING] = image_features
  476. if 'text' in input and isinstance(input['text'], torch.Tensor):
  477. text_tensor = input['text'].to(self.device)
  478. if text_tensor.dim() == 3 and text_tensor.shape[1] == 1:
  479. text_tensor = text_tensor.squeeze(1)
  480. with torch.autograd.set_grad_enabled(mode == ModeKeys.TRAIN):
  481. text_features = self.clip_model.encode_text(text_tensor)
  482. text_features = text_features / text_features.norm(
  483. dim=-1, keepdim=True) # l2-normalize
  484. output[OutputKeys.TEXT_EMBEDDING] = text_features
  485. if mode == ModeKeys.TRAIN:
  486. output['logit_scale'] = (self.clip_model.logit_scale
  487. * 1.0).exp().mean()
  488. return output
  489. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  490. return inputs
  491. @property
  492. def temperature(self):
  493. return 1.0 / self.clip_model.logit_scale.exp()