utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. # Copyright 2021 The OpenAI Team Authors.
  2. # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  3. #
  4. # The implementation here is modified based on OpenAI CLIP,
  5. # originally MIT License, Copyright (c) 2021 OpenAI,
  6. # and publicly available at https://github.com/openai/CLIP/.
  7. from collections import OrderedDict
  8. from typing import Tuple, Union
  9. import numpy as np
  10. import torch
  11. import torch.nn.functional as F
  12. import torch.utils.checkpoint as checkpoint
  13. from torch import nn
  14. from transformers import BertConfig, BertForMaskedLM
  15. from modelscope.utils.compatible_with_transformers import \
  16. compatible_position_ids
  17. class LayerNorm(nn.LayerNorm):
  18. """Subclass torch's LayerNorm to handle fp16."""
  19. def forward(self, x: torch.Tensor):
  20. orig_type = x.dtype
  21. ret = super().forward(x.type(torch.float32))
  22. return ret.type(orig_type)
  23. class QuickGELU(nn.Module):
  24. def forward(self, x: torch.Tensor):
  25. return x * torch.sigmoid(1.702 * x)
  26. class ResidualAttentionBlock(nn.Module):
  27. def __init__(self,
  28. d_model: int,
  29. n_head: int,
  30. attn_mask: torch.Tensor = None):
  31. super().__init__()
  32. self.attn = nn.MultiheadAttention(d_model, n_head)
  33. self.ln_1 = LayerNorm(d_model)
  34. self.mlp = nn.Sequential(
  35. OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
  36. ('gelu', QuickGELU()),
  37. ('c_proj', nn.Linear(d_model * 4, d_model))]))
  38. self.ln_2 = LayerNorm(d_model)
  39. self.attn_mask = attn_mask
  40. def attention(self, x: torch.Tensor):
  41. self.attn_mask = self.attn_mask.to(
  42. dtype=x.dtype,
  43. device=x.device) if self.attn_mask is not None else None
  44. return self.attn(
  45. x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
  46. def forward(self, x: torch.Tensor):
  47. x = x + self.attention(self.ln_1(x))
  48. x = x + self.mlp(self.ln_2(x))
  49. return x
  50. class Transformer(nn.Module):
  51. def __init__(self,
  52. width: int,
  53. layers: int,
  54. heads: int,
  55. attn_mask: torch.Tensor = None,
  56. use_gc=False):
  57. super().__init__()
  58. self.use_gc = use_gc
  59. self.width = width
  60. self.layers = layers
  61. self.resblocks = nn.Sequential(*[
  62. ResidualAttentionBlock(width, heads, attn_mask)
  63. for _ in range(layers)
  64. ])
  65. def forward(self, x: torch.Tensor):
  66. if self.use_gc:
  67. for each_block in self.resblocks:
  68. x = checkpoint.checkpoint(each_block, x)
  69. return x
  70. else:
  71. return self.resblocks(x)
  72. class VisionTransformer(nn.Module):
  73. def __init__(self,
  74. input_resolution: int,
  75. patch_size: int,
  76. width: int,
  77. layers: int,
  78. heads: int,
  79. output_dim: int,
  80. use_gc=False):
  81. super().__init__()
  82. self.input_resolution = input_resolution
  83. self.output_dim = output_dim
  84. self.conv1 = nn.Conv2d(
  85. in_channels=3,
  86. out_channels=width,
  87. kernel_size=patch_size,
  88. stride=patch_size,
  89. bias=False)
  90. scale = width**-0.5
  91. self.class_embedding = nn.Parameter(scale * torch.randn(width))
  92. self.positional_embedding = nn.Parameter(scale * torch.randn(
  93. (input_resolution // patch_size)**2 + 1, width))
  94. self.ln_pre = LayerNorm(width)
  95. self.transformer = Transformer(width, layers, heads, use_gc=use_gc)
  96. self.ln_post = LayerNorm(width)
  97. self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
  98. def forward(self, x: torch.Tensor):
  99. x = self.conv1(x) # shape = [*, width, grid, grid]
  100. x = x.reshape(x.shape[0], x.shape[1],
  101. -1) # shape = [*, width, grid ** 2]
  102. x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
  103. class_embedding = self.class_embedding.to(x.dtype) + \
  104. torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
  105. x = torch.cat([class_embedding, x],
  106. dim=1) # shape = [*, grid ** 2 + 1, width]
  107. x = x + self.positional_embedding.to(x.dtype)
  108. x = self.ln_pre(x)
  109. x = x.permute(1, 0, 2) # NLD -> LND
  110. x = self.transformer(x)
  111. x = x.permute(1, 0, 2) # LND -> NLD
  112. x = self.ln_post(x[:, 0, :])
  113. if self.proj is not None:
  114. x = x @ self.proj
  115. return x
  116. class CLIPVisionWrapper(nn.Module):
  117. def __init__(self, ):
  118. super().__init__()
  119. self.vision_transformer = VisionTransformer(
  120. input_resolution=224,
  121. patch_size=14,
  122. width=1024,
  123. layers=24,
  124. heads=16,
  125. output_dim=768)
  126. def forward(self, x):
  127. x = self.vision_transformer.conv1(x) # shape = [*, width, grid, grid]
  128. x = x.reshape(x.shape[0], x.shape[1],
  129. -1) # shape = [*, width, grid ** 2]
  130. x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
  131. class_embedding = self.vision_transformer.class_embedding.to(x.dtype) + \
  132. torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
  133. x = torch.cat([class_embedding, x],
  134. dim=1) # shape = [*, grid ** 2 + 1, width]
  135. x = x + self.vision_transformer.positional_embedding.to(x.dtype)
  136. x = self.vision_transformer.ln_pre(x)
  137. x = x.permute(1, 0, 2) # NLD -> LND
  138. x = self.vision_transformer.transformer(x)
  139. x = x.permute(1, 0, 2) # LND -> NLD
  140. x_tensor = x.clone()
  141. x = self.vision_transformer.ln_post(x[:, 0, :])
  142. if self.vision_transformer.proj is not None:
  143. x = x @ self.vision_transformer.proj
  144. return x, x_tensor
  145. class BertWrapper(nn.Module):
  146. def __init__(self, config_json, feat_dim, token_dim):
  147. super(BertWrapper, self).__init__()
  148. bert_config = BertConfig.from_json_file(config_json)
  149. self.bert = BertForMaskedLM(bert_config).bert
  150. self.projector = nn.Linear(768, feat_dim, bias=False)
  151. self.projector_token_embeds = nn.Linear(768, token_dim)
  152. def forward(self, input_ids, attention_mask):
  153. trans_features = {
  154. 'input_ids': input_ids,
  155. 'attention_mask': attention_mask
  156. }
  157. output_states = self.bert(**trans_features, return_dict=False)
  158. output_tokens = output_states[0]
  159. cls_tokens = output_tokens[:, 0, :] # CLS token is first token
  160. return self.projector(cls_tokens), self.projector_token_embeds(
  161. output_tokens)
  162. class Mlp(nn.Module):
  163. def __init__(self,
  164. in_features,
  165. hidden_features=None,
  166. out_features=None,
  167. act_layer=nn.GELU,
  168. drop=0.):
  169. super().__init__()
  170. out_features = out_features or in_features
  171. hidden_features = hidden_features or in_features
  172. self.fc1 = nn.Linear(in_features, hidden_features)
  173. self.act = act_layer()
  174. self.fc2 = nn.Linear(hidden_features, out_features)
  175. self.drop = nn.Dropout(drop)
  176. def forward(self, x):
  177. x = self.fc1(x)
  178. x = self.act(x)
  179. x = self.drop(x)
  180. x = self.fc2(x)
  181. x = self.drop(x)
  182. return x
  183. class CrossLayer(nn.Module):
  184. def __init__(self, feat_dim, mlp_ratio):
  185. super(CrossLayer, self).__init__()
  186. self.norm1 = nn.LayerNorm(feat_dim)
  187. self.norm2 = nn.LayerNorm(feat_dim)
  188. self.norm3 = nn.LayerNorm(feat_dim)
  189. self.self_attn = nn.MultiheadAttention(
  190. embed_dim=feat_dim, num_heads=16)
  191. self.cross_attn = nn.MultiheadAttention(
  192. embed_dim=feat_dim, num_heads=16)
  193. self.ffn = Mlp(
  194. in_features=feat_dim,
  195. hidden_features=feat_dim * mlp_ratio,
  196. drop=0.1)
  197. self.dropout1 = nn.Dropout(0.1)
  198. self.dropout2 = nn.Dropout(0.1)
  199. self.dropout3 = nn.Dropout(0.1)
  200. def forward(self, text_tensors, text_masks, image_tensors,
  201. retrieved_tensors):
  202. retrieved_tensors_res = self.norm1(retrieved_tensors)
  203. retrieved_tensors_res = self.self_attn(
  204. (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
  205. (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
  206. retrieved_tensors_res.permute(1, 0, 2),
  207. key_padding_mask=(text_masks == 0),
  208. )[0].permute(1, 0, 2)
  209. retrieved_tensors = retrieved_tensors + self.dropout1(
  210. retrieved_tensors_res)
  211. retrieved_tensors_res = self.norm2(retrieved_tensors)
  212. retrieved_tensors_res = self.cross_attn(
  213. (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
  214. image_tensors.permute(1, 0, 2),
  215. image_tensors.permute(1, 0, 2))[0].permute(1, 0, 2)
  216. retrieved_tensors = retrieved_tensors + self.dropout2(
  217. retrieved_tensors_res)
  218. retrieved_tensors_res = self.norm3(retrieved_tensors)
  219. retrieved_tensors = retrieved_tensors + self.dropout3(
  220. self.ffn(retrieved_tensors_res))
  221. return retrieved_tensors
  222. class TEAM(nn.Module):
  223. def __init__(self, text_model, image_model, pretrained):
  224. super(TEAM, self).__init__()
  225. self.text_model = text_model
  226. self.image_model = image_model
  227. self.cross_model = nn.ModuleList(
  228. [CrossLayer(feat_dim=1024, mlp_ratio=2)])
  229. self.image_tensor_fc = nn.Linear(1024, 768)
  230. self.text_tensor_fc = nn.Linear(1024, 768)
  231. params = torch.load(pretrained, 'cpu')
  232. compatible_position_ids(params,
  233. 'text_model.bert.embeddings.position_ids')
  234. self.load_state_dict(params, strict=True)
  235. def get_feature(self, text_data=None, text_mask=None, img_tensor=None):
  236. if text_data is not None:
  237. text_feature, text_tensors = self.text_model(text_data, text_mask)
  238. text_feature = F.normalize(text_feature, p=2.0, dim=1)
  239. else:
  240. text_feature, text_tensors = None, None
  241. if img_tensor is not None:
  242. image_feature, image_tensors = self.image_model(img_tensor)
  243. image_feature = F.normalize(image_feature, p=2.0, dim=1)
  244. else:
  245. image_feature, image_tensors = None, None
  246. return text_feature, text_tensors, image_feature, image_tensors
  247. def get_cross_score(self, text_tensors, text_mask, image_tensors):
  248. retrieved_tensors = torch.zeros_like(text_tensors)
  249. pair_score_list = []
  250. text_tensors_proj = self.text_tensor_fc(text_tensors)
  251. text_mask_float = text_mask.type(text_tensors_proj.dtype)
  252. for each_cross_model in self.cross_model:
  253. retrieved_tensors = each_cross_model(text_tensors, text_mask,
  254. image_tensors,
  255. retrieved_tensors)
  256. retrieved_tensors_proj = self.image_tensor_fc(retrieved_tensors)
  257. pair_score = torch.sum(
  258. F.normalize(retrieved_tensors_proj, p=2.0, dim=2)
  259. * F.normalize(text_tensors_proj, p=2.0, dim=2),
  260. dim=2)
  261. pair_score_reduced = torch.sum(
  262. pair_score * text_mask_float, dim=1) / torch.clamp(
  263. torch.sum(text_mask_float, dim=1), min=1.0)
  264. pair_score_list.append(pair_score_reduced)
  265. return pair_score_list