| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334 |
- # Copyright 2021 The OpenAI Team Authors.
- # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
- #
- # The implementation here is modified based on OpenAI CLIP,
- # originally MIT License, Copyright (c) 2021 OpenAI,
- # and publicly available at https://github.com/openai/CLIP/.
- from collections import OrderedDict
- from typing import Tuple, Union
- import numpy as np
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint as checkpoint
- from torch import nn
- from transformers import BertConfig, BertForMaskedLM
- from modelscope.utils.compatible_with_transformers import \
- compatible_position_ids
- class LayerNorm(nn.LayerNorm):
- """Subclass torch's LayerNorm to handle fp16."""
- def forward(self, x: torch.Tensor):
- orig_type = x.dtype
- ret = super().forward(x.type(torch.float32))
- return ret.type(orig_type)
- class QuickGELU(nn.Module):
- def forward(self, x: torch.Tensor):
- return x * torch.sigmoid(1.702 * x)
- class ResidualAttentionBlock(nn.Module):
- def __init__(self,
- d_model: int,
- n_head: int,
- attn_mask: torch.Tensor = None):
- super().__init__()
- self.attn = nn.MultiheadAttention(d_model, n_head)
- self.ln_1 = LayerNorm(d_model)
- self.mlp = nn.Sequential(
- OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
- ('gelu', QuickGELU()),
- ('c_proj', nn.Linear(d_model * 4, d_model))]))
- self.ln_2 = LayerNorm(d_model)
- self.attn_mask = attn_mask
- def attention(self, x: torch.Tensor):
- self.attn_mask = self.attn_mask.to(
- dtype=x.dtype,
- device=x.device) if self.attn_mask is not None else None
- return self.attn(
- x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
- def forward(self, x: torch.Tensor):
- x = x + self.attention(self.ln_1(x))
- x = x + self.mlp(self.ln_2(x))
- return x
- class Transformer(nn.Module):
- def __init__(self,
- width: int,
- layers: int,
- heads: int,
- attn_mask: torch.Tensor = None,
- use_gc=False):
- super().__init__()
- self.use_gc = use_gc
- self.width = width
- self.layers = layers
- self.resblocks = nn.Sequential(*[
- ResidualAttentionBlock(width, heads, attn_mask)
- for _ in range(layers)
- ])
- def forward(self, x: torch.Tensor):
- if self.use_gc:
- for each_block in self.resblocks:
- x = checkpoint.checkpoint(each_block, x)
- return x
- else:
- return self.resblocks(x)
- class VisionTransformer(nn.Module):
- def __init__(self,
- input_resolution: int,
- patch_size: int,
- width: int,
- layers: int,
- heads: int,
- output_dim: int,
- use_gc=False):
- super().__init__()
- self.input_resolution = input_resolution
- self.output_dim = output_dim
- self.conv1 = nn.Conv2d(
- in_channels=3,
- out_channels=width,
- kernel_size=patch_size,
- stride=patch_size,
- bias=False)
- scale = width**-0.5
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
- self.positional_embedding = nn.Parameter(scale * torch.randn(
- (input_resolution // patch_size)**2 + 1, width))
- self.ln_pre = LayerNorm(width)
- self.transformer = Transformer(width, layers, heads, use_gc=use_gc)
- self.ln_post = LayerNorm(width)
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
- def forward(self, x: torch.Tensor):
- x = self.conv1(x) # shape = [*, width, grid, grid]
- x = x.reshape(x.shape[0], x.shape[1],
- -1) # shape = [*, width, grid ** 2]
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
- class_embedding = self.class_embedding.to(x.dtype) + \
- torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
- x = torch.cat([class_embedding, x],
- dim=1) # shape = [*, grid ** 2 + 1, width]
- x = x + self.positional_embedding.to(x.dtype)
- x = self.ln_pre(x)
- x = x.permute(1, 0, 2) # NLD -> LND
- x = self.transformer(x)
- x = x.permute(1, 0, 2) # LND -> NLD
- x = self.ln_post(x[:, 0, :])
- if self.proj is not None:
- x = x @ self.proj
- return x
- class CLIPVisionWrapper(nn.Module):
- def __init__(self, ):
- super().__init__()
- self.vision_transformer = VisionTransformer(
- input_resolution=224,
- patch_size=14,
- width=1024,
- layers=24,
- heads=16,
- output_dim=768)
- def forward(self, x):
- x = self.vision_transformer.conv1(x) # shape = [*, width, grid, grid]
- x = x.reshape(x.shape[0], x.shape[1],
- -1) # shape = [*, width, grid ** 2]
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
- class_embedding = self.vision_transformer.class_embedding.to(x.dtype) + \
- torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
- x = torch.cat([class_embedding, x],
- dim=1) # shape = [*, grid ** 2 + 1, width]
- x = x + self.vision_transformer.positional_embedding.to(x.dtype)
- x = self.vision_transformer.ln_pre(x)
- x = x.permute(1, 0, 2) # NLD -> LND
- x = self.vision_transformer.transformer(x)
- x = x.permute(1, 0, 2) # LND -> NLD
- x_tensor = x.clone()
- x = self.vision_transformer.ln_post(x[:, 0, :])
- if self.vision_transformer.proj is not None:
- x = x @ self.vision_transformer.proj
- return x, x_tensor
- class BertWrapper(nn.Module):
- def __init__(self, config_json, feat_dim, token_dim):
- super(BertWrapper, self).__init__()
- bert_config = BertConfig.from_json_file(config_json)
- self.bert = BertForMaskedLM(bert_config).bert
- self.projector = nn.Linear(768, feat_dim, bias=False)
- self.projector_token_embeds = nn.Linear(768, token_dim)
- def forward(self, input_ids, attention_mask):
- trans_features = {
- 'input_ids': input_ids,
- 'attention_mask': attention_mask
- }
- output_states = self.bert(**trans_features, return_dict=False)
- output_tokens = output_states[0]
- cls_tokens = output_tokens[:, 0, :] # CLS token is first token
- return self.projector(cls_tokens), self.projector_token_embeds(
- output_tokens)
- class Mlp(nn.Module):
- def __init__(self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
- class CrossLayer(nn.Module):
- def __init__(self, feat_dim, mlp_ratio):
- super(CrossLayer, self).__init__()
- self.norm1 = nn.LayerNorm(feat_dim)
- self.norm2 = nn.LayerNorm(feat_dim)
- self.norm3 = nn.LayerNorm(feat_dim)
- self.self_attn = nn.MultiheadAttention(
- embed_dim=feat_dim, num_heads=16)
- self.cross_attn = nn.MultiheadAttention(
- embed_dim=feat_dim, num_heads=16)
- self.ffn = Mlp(
- in_features=feat_dim,
- hidden_features=feat_dim * mlp_ratio,
- drop=0.1)
- self.dropout1 = nn.Dropout(0.1)
- self.dropout2 = nn.Dropout(0.1)
- self.dropout3 = nn.Dropout(0.1)
- def forward(self, text_tensors, text_masks, image_tensors,
- retrieved_tensors):
- retrieved_tensors_res = self.norm1(retrieved_tensors)
- retrieved_tensors_res = self.self_attn(
- (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
- (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
- retrieved_tensors_res.permute(1, 0, 2),
- key_padding_mask=(text_masks == 0),
- )[0].permute(1, 0, 2)
- retrieved_tensors = retrieved_tensors + self.dropout1(
- retrieved_tensors_res)
- retrieved_tensors_res = self.norm2(retrieved_tensors)
- retrieved_tensors_res = self.cross_attn(
- (text_tensors + retrieved_tensors_res).permute(1, 0, 2),
- image_tensors.permute(1, 0, 2),
- image_tensors.permute(1, 0, 2))[0].permute(1, 0, 2)
- retrieved_tensors = retrieved_tensors + self.dropout2(
- retrieved_tensors_res)
- retrieved_tensors_res = self.norm3(retrieved_tensors)
- retrieved_tensors = retrieved_tensors + self.dropout3(
- self.ffn(retrieved_tensors_res))
- return retrieved_tensors
- class TEAM(nn.Module):
- def __init__(self, text_model, image_model, pretrained):
- super(TEAM, self).__init__()
- self.text_model = text_model
- self.image_model = image_model
- self.cross_model = nn.ModuleList(
- [CrossLayer(feat_dim=1024, mlp_ratio=2)])
- self.image_tensor_fc = nn.Linear(1024, 768)
- self.text_tensor_fc = nn.Linear(1024, 768)
- params = torch.load(pretrained, 'cpu')
- compatible_position_ids(params,
- 'text_model.bert.embeddings.position_ids')
- self.load_state_dict(params, strict=True)
- def get_feature(self, text_data=None, text_mask=None, img_tensor=None):
- if text_data is not None:
- text_feature, text_tensors = self.text_model(text_data, text_mask)
- text_feature = F.normalize(text_feature, p=2.0, dim=1)
- else:
- text_feature, text_tensors = None, None
- if img_tensor is not None:
- image_feature, image_tensors = self.image_model(img_tensor)
- image_feature = F.normalize(image_feature, p=2.0, dim=1)
- else:
- image_feature, image_tensors = None, None
- return text_feature, text_tensors, image_feature, image_tensors
- def get_cross_score(self, text_tensors, text_mask, image_tensors):
- retrieved_tensors = torch.zeros_like(text_tensors)
- pair_score_list = []
- text_tensors_proj = self.text_tensor_fc(text_tensors)
- text_mask_float = text_mask.type(text_tensors_proj.dtype)
- for each_cross_model in self.cross_model:
- retrieved_tensors = each_cross_model(text_tensors, text_mask,
- image_tensors,
- retrieved_tensors)
- retrieved_tensors_proj = self.image_tensor_fc(retrieved_tensors)
- pair_score = torch.sum(
- F.normalize(retrieved_tensors_proj, p=2.0, dim=2)
- * F.normalize(text_tensors_proj, p=2.0, dim=2),
- dim=2)
- pair_score_reduced = torch.sum(
- pair_score * text_mask_float, dim=1) / torch.clamp(
- torch.sum(text_mask_float, dim=1), min=1.0)
- pair_score_list.append(pair_score_reduced)
- return pair_score_list
|