clip.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import numpy as np
  3. import open_clip
  4. import torch
  5. import torch.nn as nn
  6. import torchvision.transforms as T
  7. class FrozenOpenCLIPEmbedder(nn.Module):
  8. """
  9. Uses the OpenCLIP transformer encoder for text
  10. """
  11. LAYERS = ['last', 'penultimate']
  12. def __init__(self,
  13. arch='ViT-H-14',
  14. pretrained='laion2b_s32b_b79k',
  15. device='cuda',
  16. max_length=77,
  17. freeze=True,
  18. layer='last'):
  19. super().__init__()
  20. assert layer in self.LAYERS
  21. model, _, _ = open_clip.create_model_and_transforms(
  22. arch, device=torch.device('cpu'), pretrained=pretrained)
  23. del model.visual
  24. self.model = model
  25. self.device = device
  26. self.max_length = max_length
  27. if freeze:
  28. self.freeze()
  29. self.layer = layer
  30. if self.layer == 'last':
  31. self.layer_idx = 0
  32. elif self.layer == 'penultimate':
  33. self.layer_idx = 1
  34. else:
  35. raise NotImplementedError()
  36. def freeze(self):
  37. self.model = self.model.eval()
  38. for param in self.parameters():
  39. param.requires_grad = False
  40. def forward(self, text):
  41. tokens = open_clip.tokenize(text)
  42. z = self.encode_with_transformer(tokens.to(self.device))
  43. return z
  44. def encode_with_transformer(self, text):
  45. x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
  46. x = x + self.model.positional_embedding
  47. x = x.permute(1, 0, 2) # NLD -> LND
  48. x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
  49. x = x.permute(1, 0, 2) # LND -> NLD
  50. x = self.model.ln_final(x)
  51. return x
  52. def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
  53. for i, r in enumerate(self.model.transformer.resblocks):
  54. if i == len(self.model.transformer.resblocks) - self.layer_idx:
  55. break
  56. if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
  57. ):
  58. x = checkpoint(r, x, attn_mask)
  59. else:
  60. x = r(x, attn_mask=attn_mask)
  61. return x
  62. def encode(self, text):
  63. return self(text)
  64. class FrozenOpenCLIPVisualEmbedder(nn.Module):
  65. """
  66. Uses the OpenCLIP transformer encoder for text
  67. """
  68. LAYERS = ['last', 'penultimate']
  69. def __init__(self,
  70. arch='ViT-H-14',
  71. pretrained='laion2b_s32b_b79k',
  72. device='cuda',
  73. max_length=77,
  74. freeze=True,
  75. layer='last',
  76. input_shape=(224, 224, 3)):
  77. super().__init__()
  78. assert layer in self.LAYERS
  79. model, _, preprocess = open_clip.create_model_and_transforms(
  80. arch, device=torch.device('cpu'), pretrained=pretrained)
  81. del model.transformer
  82. self.model = model
  83. data_white = np.ones(input_shape, dtype=np.uint8) * 255
  84. self.black_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0)
  85. self.preprocess = preprocess
  86. self.device = device
  87. self.max_length = max_length # 77
  88. if freeze:
  89. self.freeze()
  90. self.layer = layer # 'penultimate'
  91. if self.layer == 'last':
  92. self.layer_idx = 0
  93. elif self.layer == 'penultimate':
  94. self.layer_idx = 1
  95. else:
  96. raise NotImplementedError()
  97. def freeze(self):
  98. self.model = self.model.eval()
  99. for param in self.parameters():
  100. param.requires_grad = False
  101. def forward(self, image):
  102. # tokens = open_clip.tokenize(text)
  103. z = self.model.encode_image(image.to(self.device))
  104. return z
  105. def encode_with_transformer(self, text):
  106. x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
  107. x = x + self.model.positional_embedding
  108. x = x.permute(1, 0, 2) # NLD -> LND
  109. x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
  110. x = x.permute(1, 0, 2) # LND -> NLD
  111. x = self.model.ln_final(x)
  112. return x
  113. def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
  114. for i, r in enumerate(self.model.transformer.resblocks):
  115. if i == len(self.model.transformer.resblocks) - self.layer_idx:
  116. break
  117. if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
  118. ):
  119. x = checkpoint(r, x, attn_mask)
  120. else:
  121. x = r(x, attn_mask=attn_mask)
  122. return x
  123. def encode(self, text):
  124. return self(text)