modeling_clip.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240
  1. # coding=utf-8
  2. # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch CLIP model."""
  16. from dataclasses import dataclass
  17. from typing import Any, Callable, Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int
  26. from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
  27. logger = logging.get_logger(__name__)
  28. # contrastive loss function, adapted from
  29. # https://sachinruk.github.io/blog/2021-03-07-clip.html
  30. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  31. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  32. def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
  33. caption_loss = contrastive_loss(similarity)
  34. image_loss = contrastive_loss(similarity.t())
  35. return (caption_loss + image_loss) / 2.0
  36. def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
  37. """
  38. This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
  39. model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
  40. """
  41. square_tensor = torch.pow(tensor, 2)
  42. sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
  43. normed_tensor = torch.pow(sum_tensor, 0.5)
  44. return normed_tensor
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
    """
)
class CLIPVisionModelOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Every field defaults to None, so an output object can be built with any subset of them.
    # NOTE(review): `ModelOutput` presumably exposes fields positionally in declaration order — confirm before
    # reordering anything below.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    # One tensor per recorded hidden state / attention map when the caller asked for them.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for text model's outputs that also contains a pooling of the last hidden states.
    """
)
class CLIPTextModelOutput(ModelOutput):
    r"""
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The text embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Every field defaults to None, so an output object can be built with any subset of them.
    # NOTE(review): `ModelOutput` presumably exposes fields positionally in declaration order — confirm before
    # reordering anything below.
    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    # One tensor per recorded hidden state / attention map when the caller asked for them.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  75. @dataclass
  76. @auto_docstring
  77. class CLIPOutput(ModelOutput):
  78. r"""
  79. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  80. Contrastive loss for image-text similarity.
  81. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  82. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  83. similarity scores.
  84. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  85. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  86. similarity scores.
  87. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  88. The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
  89. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  90. The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
  91. text_model_output (`BaseModelOutputWithPooling`):
  92. The output of the [`CLIPTextModel`].
  93. vision_model_output (`BaseModelOutputWithPooling`):
  94. The output of the [`CLIPVisionModel`].
  95. """
  96. loss: Optional[torch.FloatTensor] = None
  97. logits_per_image: Optional[torch.FloatTensor] = None
  98. logits_per_text: Optional[torch.FloatTensor] = None
  99. text_embeds: Optional[torch.FloatTensor] = None
  100. image_embeds: Optional[torch.FloatTensor] = None
  101. text_model_output: BaseModelOutputWithPooling = None
  102. vision_model_output: BaseModelOutputWithPooling = None
  103. def to_tuple(self) -> tuple[Any]:
  104. return tuple(
  105. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  106. for k in self.keys()
  107. )
  108. class CLIPVisionEmbeddings(nn.Module):
  109. def __init__(self, config: CLIPVisionConfig):
  110. super().__init__()
  111. self.config = config
  112. self.embed_dim = config.hidden_size
  113. self.image_size = config.image_size
  114. self.patch_size = config.patch_size
  115. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  116. self.patch_embedding = nn.Conv2d(
  117. in_channels=config.num_channels,
  118. out_channels=self.embed_dim,
  119. kernel_size=self.patch_size,
  120. stride=self.patch_size,
  121. bias=False,
  122. )
  123. self.num_patches = (self.image_size // self.patch_size) ** 2
  124. self.num_positions = self.num_patches + 1
  125. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  126. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  127. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  128. """
  129. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  130. images. This method is also adapted to support torch.jit tracing.
  131. Adapted from:
  132. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  133. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  134. """
  135. num_patches = embeddings.shape[1] - 1
  136. position_embedding = self.position_embedding.weight.unsqueeze(0)
  137. num_positions = position_embedding.shape[1] - 1
  138. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  139. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  140. return self.position_embedding(self.position_ids)
  141. class_pos_embed = position_embedding[:, :1]
  142. patch_pos_embed = position_embedding[:, 1:]
  143. dim = embeddings.shape[-1]
  144. new_height = height // self.patch_size
  145. new_width = width // self.patch_size
  146. sqrt_num_positions = torch_int(num_positions**0.5)
  147. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  148. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  149. patch_pos_embed = nn.functional.interpolate(
  150. patch_pos_embed,
  151. size=(new_height, new_width),
  152. mode="bicubic",
  153. align_corners=False,
  154. )
  155. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  156. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  157. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  158. batch_size, _, height, width = pixel_values.shape
  159. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  160. raise ValueError(
  161. f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
  162. )
  163. target_dtype = self.patch_embedding.weight.dtype
  164. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  165. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  166. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  167. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  168. if interpolate_pos_encoding:
  169. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  170. else:
  171. embeddings = embeddings + self.position_embedding(self.position_ids)
  172. return embeddings
  173. class CLIPTextEmbeddings(nn.Module):
  174. def __init__(self, config: CLIPTextConfig):
  175. super().__init__()
  176. embed_dim = config.hidden_size
  177. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  178. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  179. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  180. self.register_buffer(
  181. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  182. )
  183. def forward(
  184. self,
  185. input_ids: Optional[torch.LongTensor] = None,
  186. position_ids: Optional[torch.LongTensor] = None,
  187. inputs_embeds: Optional[torch.FloatTensor] = None,
  188. ) -> torch.Tensor:
  189. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  190. max_position_embedding = self.position_embedding.weight.shape[0]
  191. if seq_length > max_position_embedding:
  192. raise ValueError(
  193. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  194. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  195. )
  196. if position_ids is None:
  197. position_ids = self.position_ids[:, :seq_length]
  198. if inputs_embeds is None:
  199. inputs_embeds = self.token_embedding(input_ids)
  200. position_embeddings = self.position_embedding(position_ids)
  201. embeddings = inputs_embeds + position_embeddings
  202. return embeddings
  203. def eager_attention_forward(
  204. module: nn.Module,
  205. query: torch.Tensor,
  206. key: torch.Tensor,
  207. value: torch.Tensor,
  208. attention_mask: Optional[torch.Tensor],
  209. scaling: float,
  210. dropout: float = 0.0,
  211. output_attentions: bool = True,
  212. **kwargs,
  213. ):
  214. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  215. if attention_mask is not None:
  216. attn_weights = attn_weights + attention_mask
  217. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  218. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  219. attn_output = torch.matmul(attn_weights, value)
  220. attn_output = attn_output.transpose(1, 2).contiguous()
  221. if not output_attentions:
  222. attn_weights = None
  223. return attn_output, attn_weights
  224. class CLIPAttention(nn.Module):
  225. """Multi-headed attention from 'Attention Is All You Need' paper"""
  226. def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
  227. super().__init__()
  228. self.config = config
  229. self.embed_dim = config.hidden_size
  230. self.num_heads = config.num_attention_heads
  231. self.head_dim = self.embed_dim // self.num_heads
  232. if self.head_dim * self.num_heads != self.embed_dim:
  233. raise ValueError(
  234. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  235. f" {self.num_heads})."
  236. )
  237. self.scale = self.head_dim**-0.5
  238. self.dropout = config.attention_dropout
  239. self.is_causal = False
  240. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  241. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  242. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  243. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  244. def forward(
  245. self,
  246. hidden_states: torch.Tensor,
  247. attention_mask: Optional[torch.Tensor] = None,
  248. causal_attention_mask: Optional[torch.Tensor] = None,
  249. output_attentions: Optional[bool] = False,
  250. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  251. """Input shape: Batch x Time x Channel"""
  252. batch_size, seq_length, embed_dim = hidden_states.shape
  253. queries = self.q_proj(hidden_states)
  254. keys = self.k_proj(hidden_states)
  255. values = self.v_proj(hidden_states)
  256. queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
  257. keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
  258. values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
  259. # CLIP text model uses both `causal_attention_mask` and `attention_mask`
  260. # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
  261. if self.config._attn_implementation == "flash_attention_2":
  262. self.is_causal = causal_attention_mask is not None
  263. else:
  264. if attention_mask is not None and causal_attention_mask is not None:
  265. attention_mask = attention_mask + causal_attention_mask
  266. elif causal_attention_mask is not None:
  267. attention_mask = causal_attention_mask
  268. attention_interface: Callable = eager_attention_forward
  269. if self.config._attn_implementation != "eager":
  270. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  271. attn_output, attn_weights = attention_interface(
  272. self,
  273. queries,
  274. keys,
  275. values,
  276. attention_mask,
  277. is_causal=self.is_causal,
  278. scaling=self.scale,
  279. dropout=0.0 if not self.training else self.dropout,
  280. output_attentions=output_attentions,
  281. )
  282. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  283. attn_output = self.out_proj(attn_output)
  284. if not output_attentions:
  285. attn_weights = None
  286. return attn_output, attn_weights
  287. class CLIPMLP(nn.Module):
  288. def __init__(self, config):
  289. super().__init__()
  290. self.config = config
  291. self.activation_fn = ACT2FN[config.hidden_act]
  292. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  293. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  294. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  295. hidden_states = self.fc1(hidden_states)
  296. hidden_states = self.activation_fn(hidden_states)
  297. hidden_states = self.fc2(hidden_states)
  298. return hidden_states
  299. class CLIPEncoderLayer(GradientCheckpointingLayer):
  300. def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
  301. super().__init__()
  302. self.embed_dim = config.hidden_size
  303. self.self_attn = CLIPAttention(config)
  304. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  305. self.mlp = CLIPMLP(config)
  306. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  307. def forward(
  308. self,
  309. hidden_states: torch.Tensor,
  310. attention_mask: torch.Tensor,
  311. causal_attention_mask: torch.Tensor,
  312. output_attentions: Optional[bool] = False,
  313. ) -> tuple[torch.FloatTensor]:
  314. """
  315. Args:
  316. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  317. attention_mask (`torch.FloatTensor`): attention mask of size
  318. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  319. `(config.encoder_attention_heads,)`.
  320. output_attentions (`bool`, *optional*):
  321. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  322. returned tensors for more detail.
  323. """
  324. residual = hidden_states
  325. hidden_states = self.layer_norm1(hidden_states)
  326. hidden_states, attn_weights = self.self_attn(
  327. hidden_states=hidden_states,
  328. attention_mask=attention_mask,
  329. causal_attention_mask=causal_attention_mask,
  330. output_attentions=output_attentions,
  331. )
  332. hidden_states = residual + hidden_states
  333. residual = hidden_states
  334. hidden_states = self.layer_norm2(hidden_states)
  335. hidden_states = self.mlp(hidden_states)
  336. hidden_states = residual + hidden_states
  337. outputs = (hidden_states,)
  338. if output_attentions:
  339. outputs += (attn_weights,)
  340. return outputs
  341. @auto_docstring
  342. class CLIPPreTrainedModel(PreTrainedModel):
  343. config: CLIPConfig
  344. base_model_prefix = "clip"
  345. supports_gradient_checkpointing = True
  346. _supports_sdpa = True
  347. _supports_flash_attn = True
  348. _supports_flex_attn = True
  349. _supports_attention_backend = True
  350. def _init_weights(self, module):
  351. """Initialize the weights"""
  352. factor = self.config.initializer_factor
  353. if isinstance(module, CLIPTextEmbeddings):
  354. module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
  355. module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
  356. elif isinstance(module, CLIPVisionEmbeddings):
  357. factor = self.config.initializer_factor
  358. nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  359. nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  360. nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
  361. elif isinstance(module, CLIPAttention):
  362. factor = self.config.initializer_factor
  363. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  364. out_proj_std = (module.embed_dim**-0.5) * factor
  365. nn.init.normal_(module.q_proj.weight, std=in_proj_std)
  366. nn.init.normal_(module.k_proj.weight, std=in_proj_std)
  367. nn.init.normal_(module.v_proj.weight, std=in_proj_std)
  368. nn.init.normal_(module.out_proj.weight, std=out_proj_std)
  369. elif isinstance(module, CLIPMLP):
  370. factor = self.config.initializer_factor
  371. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  372. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  373. nn.init.normal_(module.fc1.weight, std=fc_std)
  374. nn.init.normal_(module.fc2.weight, std=in_proj_std)
  375. elif isinstance(module, CLIPModel):
  376. nn.init.normal_(
  377. module.text_projection.weight,
  378. std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
  379. )
  380. nn.init.normal_(
  381. module.visual_projection.weight,
  382. std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
  383. )
  384. elif isinstance(module, CLIPVisionModelWithProjection):
  385. nn.init.normal_(
  386. module.visual_projection.weight,
  387. std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
  388. )
  389. elif isinstance(module, CLIPTextModelWithProjection):
  390. nn.init.normal_(
  391. module.text_projection.weight,
  392. std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
  393. )
  394. elif isinstance(module, CLIPForImageClassification):
  395. nn.init.normal_(
  396. module.classifier.weight,
  397. std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
  398. )
  399. if isinstance(module, nn.LayerNorm):
  400. module.bias.data.zero_()
  401. module.weight.data.fill_(1.0)
  402. if isinstance(module, nn.Linear) and module.bias is not None:
  403. module.bias.data.zero_()
  404. class CLIPEncoder(nn.Module):
  405. """
  406. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  407. [`CLIPEncoderLayer`].
  408. Args:
  409. config: CLIPConfig
  410. """
  411. def __init__(self, config: CLIPConfig):
  412. super().__init__()
  413. self.config = config
  414. self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  415. self.gradient_checkpointing = False
  416. def forward(
  417. self,
  418. inputs_embeds,
  419. attention_mask: Optional[torch.Tensor] = None,
  420. causal_attention_mask: Optional[torch.Tensor] = None,
  421. output_attentions: Optional[bool] = None,
  422. output_hidden_states: Optional[bool] = None,
  423. ) -> BaseModelOutput:
  424. r"""
  425. Args:
  426. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  427. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  428. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  429. than the model's internal embedding lookup matrix.
  430. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  431. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  432. - 1 for tokens that are **not masked**,
  433. - 0 for tokens that are **masked**.
  434. [What are attention masks?](../glossary#attention-mask)
  435. causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  436. Causal mask for the text model. Mask values selected in `[0, 1]`:
  437. - 1 for tokens that are **not masked**,
  438. - 0 for tokens that are **masked**.
  439. [What are attention masks?](../glossary#attention-mask)
  440. output_attentions (`bool`, *optional*):
  441. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  442. returned tensors for more detail.
  443. output_hidden_states (`bool`, *optional*):
  444. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  445. for more detail.
  446. return_dict (`bool`, *optional*):
  447. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  448. """
  449. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  450. output_hidden_states = (
  451. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  452. )
  453. encoder_states = () if output_hidden_states else None
  454. all_attentions = () if output_attentions else None
  455. hidden_states = inputs_embeds
  456. for idx, encoder_layer in enumerate(self.layers):
  457. if output_hidden_states:
  458. encoder_states = encoder_states + (hidden_states,)
  459. layer_outputs = encoder_layer(
  460. hidden_states,
  461. attention_mask,
  462. causal_attention_mask,
  463. output_attentions=output_attentions,
  464. )
  465. hidden_states = layer_outputs[0]
  466. if output_attentions:
  467. all_attentions = all_attentions + (layer_outputs[1],)
  468. if output_hidden_states:
  469. encoder_states = encoder_states + (hidden_states,)
  470. return BaseModelOutput(
  471. last_hidden_state=hidden_states,
  472. hidden_states=encoder_states,
  473. attentions=all_attentions,
  474. )
  475. class CLIPTextTransformer(nn.Module):
  476. def __init__(self, config: CLIPTextConfig):
  477. super().__init__()
  478. self.config = config
  479. embed_dim = config.hidden_size
  480. self.embeddings = CLIPTextEmbeddings(config)
  481. self.encoder = CLIPEncoder(config)
  482. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  483. # For `pooled_output` computation
  484. self.eos_token_id = config.eos_token_id
  485. @auto_docstring
  486. def forward(
  487. self,
  488. input_ids: Optional[torch.Tensor] = None,
  489. attention_mask: Optional[torch.Tensor] = None,
  490. position_ids: Optional[torch.Tensor] = None,
  491. output_attentions: Optional[bool] = None,
  492. output_hidden_states: Optional[bool] = None,
  493. ) -> BaseModelOutputWithPooling:
  494. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  495. output_hidden_states = (
  496. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  497. )
  498. if input_ids is None:
  499. raise ValueError("You have to specify input_ids")
  500. input_shape = input_ids.size()
  501. input_ids = input_ids.view(-1, input_shape[-1])
  502. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  503. # CLIP's text model uses causal mask, prepare it here.
  504. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
  505. causal_attention_mask = _create_4d_causal_attention_mask(
  506. input_shape, hidden_states.dtype, device=hidden_states.device
  507. )
  508. # expand attention_mask
  509. if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
  510. # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
  511. attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
  512. encoder_outputs: BaseModelOutput = self.encoder(
  513. inputs_embeds=hidden_states,
  514. attention_mask=attention_mask,
  515. causal_attention_mask=causal_attention_mask,
  516. output_attentions=output_attentions,
  517. output_hidden_states=output_hidden_states,
  518. )
  519. last_hidden_state = encoder_outputs.last_hidden_state
  520. last_hidden_state = self.final_layer_norm(last_hidden_state)
  521. if self.eos_token_id == 2:
  522. # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
  523. # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
  524. # ------------------------------------------------------------
  525. # text_embeds.shape = [batch_size, sequence_length, transformer.width]
  526. # take features from the eot embedding (eot_token is the highest number in each sequence)
  527. # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
  528. pooled_output = last_hidden_state[
  529. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  530. input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
  531. ]
  532. else:
  533. # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
  534. pooled_output = last_hidden_state[
  535. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  536. # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
  537. # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
  538. (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
  539. .int()
  540. .argmax(dim=-1),
  541. ]
  542. return BaseModelOutputWithPooling(
  543. last_hidden_state=last_hidden_state,
  544. pooler_output=pooled_output,
  545. hidden_states=encoder_outputs.hidden_states,
  546. attentions=encoder_outputs.attentions,
  547. )
  548. @auto_docstring(
  549. custom_intro="""
  550. The text model from CLIP without any head or projection on top.
  551. """
  552. )
  553. class CLIPTextModel(CLIPPreTrainedModel):
  554. config: CLIPTextConfig
  555. _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
  556. _supports_flash_attn = False # mask creation only accounts for sdpa/eager
  557. def __init__(self, config: CLIPTextConfig):
  558. super().__init__(config)
  559. self.text_model = CLIPTextTransformer(config)
  560. # Initialize weights and apply final processing
  561. self.post_init()
  562. def get_input_embeddings(self) -> nn.Module:
  563. return self.text_model.embeddings.token_embedding
  564. def set_input_embeddings(self, value):
  565. self.text_model.embeddings.token_embedding = value
  566. @can_return_tuple
  567. @auto_docstring
  568. def forward(
  569. self,
  570. input_ids: Optional[torch.Tensor] = None,
  571. attention_mask: Optional[torch.Tensor] = None,
  572. position_ids: Optional[torch.Tensor] = None,
  573. output_attentions: Optional[bool] = None,
  574. output_hidden_states: Optional[bool] = None,
  575. ) -> BaseModelOutputWithPooling:
  576. r"""
  577. Examples:
  578. ```python
  579. >>> from transformers import AutoTokenizer, CLIPTextModel
  580. >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
  581. >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
  582. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  583. >>> outputs = model(**inputs)
  584. >>> last_hidden_state = outputs.last_hidden_state
  585. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  586. ```"""
  587. return self.text_model(
  588. input_ids=input_ids,
  589. attention_mask=attention_mask,
  590. position_ids=position_ids,
  591. output_attentions=output_attentions,
  592. output_hidden_states=output_hidden_states,
  593. )
  594. class CLIPVisionTransformer(nn.Module):
  595. def __init__(self, config: CLIPVisionConfig):
  596. super().__init__()
  597. self.config = config
  598. embed_dim = config.hidden_size
  599. self.embeddings = CLIPVisionEmbeddings(config)
  600. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  601. self.encoder = CLIPEncoder(config)
  602. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  603. @auto_docstring
  604. def forward(
  605. self,
  606. pixel_values: Optional[torch.FloatTensor] = None,
  607. output_attentions: Optional[bool] = None,
  608. output_hidden_states: Optional[bool] = None,
  609. interpolate_pos_encoding: Optional[bool] = False,
  610. ) -> BaseModelOutputWithPooling:
  611. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  612. output_hidden_states = (
  613. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  614. )
  615. if pixel_values is None:
  616. raise ValueError("You have to specify pixel_values")
  617. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  618. hidden_states = self.pre_layrnorm(hidden_states)
  619. encoder_outputs: BaseModelOutput = self.encoder(
  620. inputs_embeds=hidden_states,
  621. output_attentions=output_attentions,
  622. output_hidden_states=output_hidden_states,
  623. )
  624. last_hidden_state = encoder_outputs.last_hidden_state
  625. pooled_output = last_hidden_state[:, 0, :]
  626. pooled_output = self.post_layernorm(pooled_output)
  627. return BaseModelOutputWithPooling(
  628. last_hidden_state=last_hidden_state,
  629. pooler_output=pooled_output,
  630. hidden_states=encoder_outputs.hidden_states,
  631. attentions=encoder_outputs.attentions,
  632. )
  633. @auto_docstring(
  634. custom_intro="""
  635. The vision model from CLIP without any head or projection on top.
  636. """
  637. )
  638. class CLIPVisionModel(CLIPPreTrainedModel):
  639. config: CLIPVisionConfig
  640. main_input_name = "pixel_values"
  641. _no_split_modules = ["CLIPEncoderLayer"]
  642. def __init__(self, config: CLIPVisionConfig):
  643. super().__init__(config)
  644. self.vision_model = CLIPVisionTransformer(config)
  645. # Initialize weights and apply final processing
  646. self.post_init()
  647. def get_input_embeddings(self) -> nn.Module:
  648. return self.vision_model.embeddings.patch_embedding
  649. @can_return_tuple
  650. @auto_docstring
  651. def forward(
  652. self,
  653. pixel_values: Optional[torch.FloatTensor] = None,
  654. output_attentions: Optional[bool] = None,
  655. output_hidden_states: Optional[bool] = None,
  656. interpolate_pos_encoding: bool = False,
  657. ) -> BaseModelOutputWithPooling:
  658. r"""
  659. Example:
  660. ```python
  661. >>> from PIL import Image
  662. >>> import requests
  663. >>> from transformers import AutoProcessor, CLIPVisionModel
  664. >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
  665. >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
  666. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  667. >>> image = Image.open(requests.get(url, stream=True).raw)
  668. >>> inputs = processor(images=image, return_tensors="pt")
  669. >>> outputs = model(**inputs)
  670. >>> last_hidden_state = outputs.last_hidden_state
  671. >>> pooled_output = outputs.pooler_output # pooled CLS states
  672. ```"""
  673. return self.vision_model(
  674. pixel_values=pixel_values,
  675. output_attentions=output_attentions,
  676. output_hidden_states=output_hidden_states,
  677. interpolate_pos_encoding=interpolate_pos_encoding,
  678. )
  679. @auto_docstring
  680. class CLIPModel(CLIPPreTrainedModel):
  681. config: CLIPConfig
  682. _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
  683. _supports_flash_attn = False # mask creation only accounts for sdpa/eager
  684. def __init__(self, config: CLIPConfig):
  685. super().__init__(config)
  686. if not isinstance(config.text_config, CLIPTextConfig):
  687. raise TypeError(
  688. "config.text_config is expected to be of type CLIPTextConfig but is of type"
  689. f" {type(config.text_config)}."
  690. )
  691. if not isinstance(config.vision_config, CLIPVisionConfig):
  692. raise TypeError(
  693. "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
  694. f" {type(config.vision_config)}."
  695. )
  696. text_config = config.text_config
  697. vision_config = config.vision_config
  698. self.projection_dim = config.projection_dim
  699. self.text_embed_dim = text_config.hidden_size
  700. self.vision_embed_dim = vision_config.hidden_size
  701. text_model = CLIPTextModel._from_config(text_config)
  702. self.text_model = text_model.text_model
  703. vision_model = CLIPVisionModel._from_config(vision_config)
  704. self.vision_model = vision_model.vision_model
  705. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  706. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  707. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  708. # Initialize weights and apply final processing
  709. self.post_init()
  710. @filter_out_non_signature_kwargs()
  711. @auto_docstring
  712. def get_text_features(
  713. self,
  714. input_ids: torch.Tensor,
  715. attention_mask: Optional[torch.Tensor] = None,
  716. position_ids: Optional[torch.Tensor] = None,
  717. ) -> torch.FloatTensor:
  718. r"""
  719. Returns:
  720. text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
  721. applying the projection layer to the pooled output of [`CLIPTextModel`].
  722. Examples:
  723. ```python
  724. >>> import torch
  725. >>> from transformers import AutoTokenizer, CLIPModel
  726. >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  727. >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
  728. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  729. >>> with torch.inference_mode():
  730. ... text_features = model.get_text_features(**inputs)
  731. ```"""
  732. text_outputs: BaseModelOutputWithPooling = self.text_model(
  733. input_ids=input_ids,
  734. attention_mask=attention_mask,
  735. position_ids=position_ids,
  736. )
  737. pooled_output = text_outputs.pooler_output
  738. text_features = self.text_projection(pooled_output)
  739. return text_features
  740. @filter_out_non_signature_kwargs()
  741. @auto_docstring
  742. def get_image_features(
  743. self,
  744. pixel_values: torch.FloatTensor,
  745. interpolate_pos_encoding: bool = False,
  746. ) -> torch.FloatTensor:
  747. r"""
  748. Returns:
  749. image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
  750. applying the projection layer to the pooled output of [`CLIPVisionModel`].
  751. Examples:
  752. ```python
  753. >>> import torch
  754. >>> from transformers import AutoProcessor, CLIPModel
  755. >>> from transformers.image_utils import load_image
  756. >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  757. >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
  758. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  759. >>> image = load_image(url)
  760. >>> inputs = processor(images=image, return_tensors="pt")
  761. >>> with torch.inference_mode():
  762. ... image_features = model.get_image_features(**inputs)
  763. ```"""
  764. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  765. pixel_values=pixel_values,
  766. interpolate_pos_encoding=interpolate_pos_encoding,
  767. )
  768. pooled_output = vision_outputs.pooler_output
  769. image_features = self.visual_projection(pooled_output)
  770. return image_features
  771. @can_return_tuple
  772. @auto_docstring
  773. def forward(
  774. self,
  775. input_ids: Optional[torch.LongTensor] = None,
  776. pixel_values: Optional[torch.FloatTensor] = None,
  777. attention_mask: Optional[torch.Tensor] = None,
  778. position_ids: Optional[torch.LongTensor] = None,
  779. return_loss: Optional[bool] = None,
  780. output_attentions: Optional[bool] = None,
  781. output_hidden_states: Optional[bool] = None,
  782. interpolate_pos_encoding: bool = False,
  783. ) -> CLIPOutput:
  784. r"""
  785. return_loss (`bool`, *optional*):
  786. Whether or not to return the contrastive loss.
  787. Examples:
  788. ```python
  789. >>> import torch
  790. >>> from transformers import AutoProcessor, CLIPModel
  791. >>> from transformers.image_utils import load_image
  792. >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  793. >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
  794. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  795. >>> image = load_image(url)
  796. >>> inputs = processor(
  797. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  798. ... )
  799. >>> with torch.inference_mode():
  800. ... outputs = model(**inputs)
  801. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  802. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  803. ```"""
  804. # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
  805. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  806. output_hidden_states = (
  807. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  808. )
  809. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  810. pixel_values=pixel_values,
  811. output_attentions=output_attentions,
  812. output_hidden_states=output_hidden_states,
  813. interpolate_pos_encoding=interpolate_pos_encoding,
  814. )
  815. text_outputs: BaseModelOutputWithPooling = self.text_model(
  816. input_ids=input_ids,
  817. attention_mask=attention_mask,
  818. position_ids=position_ids,
  819. output_attentions=output_attentions,
  820. output_hidden_states=output_hidden_states,
  821. )
  822. image_embeds = vision_outputs.pooler_output
  823. image_embeds = self.visual_projection(image_embeds)
  824. text_embeds = text_outputs.pooler_output
  825. text_embeds = self.text_projection(text_embeds)
  826. # normalized features
  827. image_embeds = image_embeds / _get_vector_norm(image_embeds)
  828. text_embeds = text_embeds / _get_vector_norm(text_embeds)
  829. # cosine similarity as logits
  830. logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
  831. logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device)
  832. logits_per_image = logits_per_text.t()
  833. loss = None
  834. if return_loss:
  835. loss = clip_loss(logits_per_text)
  836. return CLIPOutput(
  837. loss=loss,
  838. logits_per_image=logits_per_image,
  839. logits_per_text=logits_per_text,
  840. text_embeds=text_embeds,
  841. image_embeds=image_embeds,
  842. text_model_output=text_outputs,
  843. vision_model_output=vision_outputs,
  844. )
  845. @auto_docstring
  846. class CLIPTextModelWithProjection(CLIPPreTrainedModel):
  847. config: CLIPTextConfig
  848. _supports_flash_attn = False
  849. _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
  850. def __init__(self, config: CLIPTextConfig):
  851. super().__init__(config)
  852. text_model = CLIPTextModel._from_config(config)
  853. self.text_model = text_model.text_model
  854. self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
  855. # Initialize weights and apply final processing
  856. self.post_init()
  857. def get_input_embeddings(self) -> nn.Module:
  858. return self.text_model.embeddings.token_embedding
  859. def set_input_embeddings(self, value):
  860. self.text_model.embeddings.token_embedding = value
  861. @can_return_tuple
  862. @auto_docstring
  863. def forward(
  864. self,
  865. input_ids: Optional[torch.Tensor] = None,
  866. attention_mask: Optional[torch.Tensor] = None,
  867. position_ids: Optional[torch.Tensor] = None,
  868. output_attentions: Optional[bool] = None,
  869. output_hidden_states: Optional[bool] = None,
  870. ) -> CLIPTextModelOutput:
  871. r"""
  872. Examples:
  873. ```python
  874. >>> import torch
  875. >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection
  876. >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
  877. >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
  878. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  879. >>> with torch.inference_mode():
  880. ... outputs = model(**inputs)
  881. >>> text_embeds = outputs.text_embeds
  882. ```"""
  883. text_outputs: BaseModelOutputWithPooling = self.text_model(
  884. input_ids=input_ids,
  885. attention_mask=attention_mask,
  886. position_ids=position_ids,
  887. output_attentions=output_attentions,
  888. output_hidden_states=output_hidden_states,
  889. )
  890. pooled_output = text_outputs.pooler_output
  891. text_embeds = self.text_projection(pooled_output)
  892. return CLIPTextModelOutput(
  893. text_embeds=text_embeds,
  894. last_hidden_state=text_outputs.last_hidden_state,
  895. hidden_states=text_outputs.hidden_states,
  896. attentions=text_outputs.attentions,
  897. )
  898. @auto_docstring
  899. class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
  900. config: CLIPVisionConfig
  901. main_input_name = "pixel_values"
  902. def __init__(self, config: CLIPVisionConfig):
  903. super().__init__(config)
  904. vision_model = CLIPVisionModel._from_config(config)
  905. self.vision_model = vision_model.vision_model
  906. self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
  907. # Initialize weights and apply final processing
  908. self.post_init()
  909. def get_input_embeddings(self) -> nn.Module:
  910. return self.vision_model.embeddings.patch_embedding
  911. @can_return_tuple
  912. @auto_docstring
  913. def forward(
  914. self,
  915. pixel_values: Optional[torch.FloatTensor] = None,
  916. output_attentions: Optional[bool] = None,
  917. output_hidden_states: Optional[bool] = None,
  918. interpolate_pos_encoding: bool = False,
  919. ) -> CLIPVisionModelOutput:
  920. r"""
  921. Examples:
  922. ```python
  923. >>> import torch
  924. >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection
  925. >>> from transformers.image_utils import load_image
  926. >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
  927. >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
  928. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  929. >>> image = load_image(url)
  930. >>> inputs = processor(images=image, return_tensors="pt")
  931. >>> with torch.inference_mode():
  932. ... outputs = model(**inputs)
  933. >>> image_embeds = outputs.image_embeds
  934. ```"""
  935. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  936. pixel_values=pixel_values,
  937. output_attentions=output_attentions,
  938. output_hidden_states=output_hidden_states,
  939. interpolate_pos_encoding=interpolate_pos_encoding,
  940. )
  941. pooled_output = vision_outputs.pooler_output
  942. image_embeds = self.visual_projection(pooled_output)
  943. return CLIPVisionModelOutput(
  944. image_embeds=image_embeds,
  945. last_hidden_state=vision_outputs.last_hidden_state,
  946. hidden_states=vision_outputs.hidden_states,
  947. attentions=vision_outputs.attentions,
  948. )
  949. @auto_docstring(
  950. custom_intro="""
  951. CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
  952. the patch tokens) e.g. for ImageNet.
  953. """
  954. )
  955. class CLIPForImageClassification(CLIPPreTrainedModel):
  956. main_input_name = "pixel_values"
  957. def __init__(self, config: CLIPConfig) -> None:
  958. super().__init__(config)
  959. self.num_labels = config.num_labels
  960. vision_model = CLIPVisionModel._from_config(config.vision_config)
  961. self.vision_model = vision_model.vision_model
  962. # Classifier head
  963. self.classifier = (
  964. nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
  965. )
  966. # Initialize weights and apply final processing
  967. self.post_init()
  968. @can_return_tuple
  969. @auto_docstring
  970. def forward(
  971. self,
  972. pixel_values: Optional[torch.Tensor] = None,
  973. labels: Optional[torch.Tensor] = None,
  974. output_attentions: Optional[bool] = None,
  975. output_hidden_states: Optional[bool] = None,
  976. ) -> ImageClassifierOutput:
  977. r"""
  978. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  979. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  980. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  981. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  982. """
  983. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  984. output_hidden_states = (
  985. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  986. )
  987. outputs: BaseModelOutputWithPooling = self.vision_model(
  988. pixel_values,
  989. output_attentions=output_attentions,
  990. output_hidden_states=output_hidden_states,
  991. )
  992. sequence_output = outputs.last_hidden_state
  993. # average pool the patch tokens
  994. sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
  995. # apply classifier
  996. logits = self.classifier(sequence_output)
  997. loss = None
  998. if labels is not None:
  999. loss = self.loss_function(labels, logits, self.config)
  1000. return ImageClassifierOutput(
  1001. loss=loss,
  1002. logits=logits,
  1003. hidden_states=outputs.hidden_states,
  1004. attentions=outputs.attentions,
  1005. )
# Public API of this module: the names bound by `from <module> import *`.
__all__ = [
    "CLIPModel",
    "CLIPPreTrainedModel",
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "CLIPForImageClassification",
]