model.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright 2021 The OpenAI Team Authors.
  2. # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  3. #
  4. # The implementation here is modified based on OpenAI CLIP,
  5. # originally MIT License, Copyright (c) 2021 OpenAI,
  6. # and publicly available at https://github.com/openai/CLIP/.
  7. """ Generative Multimodal Model Architecture."""
  8. import os
  9. import json
  10. import torch
  11. import torch.nn.functional as F
  12. from torch import nn
  13. from modelscope.models.multi_modal.gemm import gemm_base, tokenizer
  class ImageEncoder(nn.Module):
      """Image feature encoder wrapping ``gemm_base.VisualTransformer``.

      Args:
          configs: indexable model configuration whose first five entries
              are, in order, ``embed_dim``, ``image_resolution``,
              ``vision_layers``, ``vision_width`` and ``vision_patch_size``.
      """

      def __init__(self, configs):
          super().__init__()
          (embed_dim, image_resolution, vision_layers, vision_width,
           vision_patch_size) = configs[:5]
          vit_kwargs = dict(
              input_resolution=image_resolution,
              patch_size=vision_patch_size,
              width=vision_width,
              layers=vision_layers,
              # One attention head per 64 channels of width.
              heads=vision_width // 64,
              output_dim=embed_dim,
              use_gc=False,
          )
          self.visual = gemm_base.VisualTransformer(**vit_kwargs)

      def forward(self, image, return_tokens=False):
          """Encode a batch of images.

          ``self.visual`` must return a rank-3 tensor for the indexing below
          (presumably laid out as ``(batch, tokens, dim)`` -- TODO confirm
          against ``gemm_base.VisualTransformer``).

          Args:
              image: input accepted by ``gemm_base.VisualTransformer``.
              return_tokens: when truthy, also return the remaining tokens.

          Returns:
              The features at token position 0 (presumably the class token),
              or, if ``return_tokens`` is truthy, an ``(embedding, tokens)``
              tuple where ``tokens`` holds positions 1 onward.
          """
          hidden = self.visual(image)
          leading, rest = hidden[:, 0, :], hidden[:, 1:, :]
          if not return_tokens:
              return leading
          return leading, rest
  class TextEncoder(nn.Module):
      """Text feature encoder built on ``gemm_base.Transformer``.

      The transformer receives the additive mask from
      ``build_attention_mask`` (-1e4 strictly above the diagonal, 0
      elsewhere) -- the usual causal layout, assuming ``gemm_base.Transformer``
      adds ``attn_mask`` to its attention logits.

      Args:
          configs: indexable model configuration. Its last five entries are,
              in order, ``context_length``, ``vocab_size``, ``model_width``,
              ``model_heads`` and ``model_layers``; its first entry is the
              output width of ``text_projection``.
      """

      def __init__(self, configs):
          super().__init__()
          (context_length, vocab_size, model_width, model_heads,
           model_layers) = configs[-5:]
          projection_dim = configs[0]
          causal_mask = self.build_attention_mask(context_length)
          self.transformer = gemm_base.Transformer(
              width=model_width,
              layers=model_layers,
              heads=model_heads,
              attn_mask=causal_mask,
          )
          # Embeddings, final norm and output projection. Both Parameters are
          # allocated with torch.empty (uninitialised); presumably their
          # values are loaded from a checkpoint elsewhere -- confirm.
          self.token_embedding = nn.Embedding(vocab_size, model_width)
          self.positional_embedding = nn.Parameter(
              torch.empty(context_length, model_width))
          self.ln_final = nn.LayerNorm(model_width)
          self.text_projection = nn.Parameter(
              torch.empty(model_width, projection_dim))

      def build_attention_mask(self, seq_length=None):
          """Return a ``(seq_length, seq_length)`` float additive mask.

          Entries strictly above the diagonal are -1e4; the diagonal and
          everything below it are 0. -1e4 (rather than -inf) is the masking
          value used here.
          """
          filled = torch.full((seq_length, seq_length), -1e4)
          return filled.triu(1)

      def forward(self, text, return_tokens=False):
          """Encode a batch of token-id sequences.

          Args:
              text: integer token ids of shape ``(batch, context_length)``
                  (the positional-embedding addition requires the sequence
                  length to equal ``context_length``).
              return_tokens: when truthy, also return the per-token features.

          Returns:
              The projected feature of each sequence, taken at the position of
              its largest token id, or, if ``return_tokens`` is truthy, an
              ``(embedding, token_features)`` tuple.
          """
          hidden = self.token_embedding(text) + self.positional_embedding
          hidden = self.transformer(hidden.permute(1, 0, 2))  # NLD -> LND
          hidden = self.ln_final(hidden.permute(1, 0, 2))  # LND -> NLD
          # Pool at argmax(token id): this relies on the end-of-text token
          # having the highest id in each sequence (CLIP convention).
          batch_index = torch.arange(hidden.shape[0])
          eot_index = text.argmax(dim=-1)
          pooled = hidden[batch_index, eot_index] @ self.text_projection
          if not return_tokens:
              return pooled
          return pooled, hidden
  class RLEGModel(nn.Module):
      """Generative multi-modal model trained with the RLEG method.

      Takes an image, a text, or both, and returns L2-normalised feature
      vectors (as NumPy arrays) for whichever inputs were given.

      Args:
          model_dir: directory containing ``encoder_config.json`` and the BPE
              vocabulary file ``bpe_vocab_16e6.txt.gz``.

      Raises:
          ValueError: if ``encoder_config.json`` defines no model entry.
      """

      def __init__(self, model_dir):
          super().__init__()
          # Build both file paths the same way (os.path.join).
          config_path = os.path.join(model_dir, 'encoder_config.json')
          with open(config_path, 'r', encoding='utf-8') as f:
              model_config = json.load(f)
          if not model_config:
              raise ValueError(
                  '{} does not define any model configuration'.format(
                      config_path))
          # The JSON maps a model name to its argument list; only the first
          # entry (in file order) is used.
          model_name = next(iter(model_config))
          config_args = model_config[model_name]
          bpe_path = os.path.join(model_dir, 'bpe_vocab_16e6.txt.gz')
          self.tokenizer = tokenizer.SimpleTokenizer(bpe_path)
          # build model architecture
          self.image_encoder = ImageEncoder(config_args)
          self.text_encoder = TextEncoder(config_args)
          self.logit_scale = nn.Parameter(torch.ones([]))

      def tokenize(self, text_str):
          """Tokenize a single string.

          The string is passed to ``tokenizer.clip_tokenize`` as a one-element
          batch and the first row of the result is returned.
          """
          text_tensor = tokenizer.clip_tokenize(self.tokenizer, [text_str])[0]
          return text_tensor

      def encode_text(self, text):
          """Encode token ids with ``text_encoder``; L2-normalise the last dim."""
          feature = self.text_encoder(text)
          feature = F.normalize(feature, p=2, dim=-1)
          return feature

      def encode_image(self, image):
          """Encode images with ``image_encoder``; L2-normalise the last dim."""
          feature = self.image_encoder(image)
          feature = F.normalize(feature, p=2, dim=-1)
          return feature

      def parse_feat(self, feat):
          """Convert a feature tensor to a NumPy array on the CPU.

          ``detach()`` lets this accept tensors that require grad (e.g.
          ``encode_image`` output computed outside ``torch.no_grad``), for
          which ``Tensor.numpy()`` would otherwise raise RuntimeError.
          """
          out = feat.detach().cpu().numpy()
          return out

      @torch.no_grad()
      def forward(self, image=None, text=None):
          """Extract features for the given image and/or text.

          Args:
              image: input for ``image_encoder``, or None to skip.
              text: token ids for ``text_encoder``, or None to skip.

          Returns:
              dict with keys ``'image_feature'`` and ``'text_feature'``; each
              value is a NumPy array of L2-normalised features, or None when
              the corresponding input was None.
          """
          img_feature, text_feature = None, None
          if image is not None:
              img_feature = self.parse_feat(self.encode_image(image))
          if text is not None:
              text_feature = self.parse_feat(self.encode_text(text))
          out = {
              'image_feature': img_feature,
              'text_feature': text_feature,
          }
          return out