modeling_groupvit.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431
  1. # coding=utf-8
  2. # Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch GroupViT model."""
  16. import collections.abc
  17. from dataclasses import dataclass
  18. from typing import Any, Optional, Union
  19. import numpy as np
  20. import torch
  21. from torch import nn
  22. from ...activations import ACT2FN
  23. from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import ModelOutput, auto_docstring, filter_out_non_signature_kwargs, logging, torch_int
  28. from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
  29. logger = logging.get_logger(__name__)
  30. # contrastive loss function, adapted from
  31. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
  32. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  33. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  34. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->groupvit
  35. def groupvit_loss(similarity: torch.Tensor) -> torch.Tensor:
  36. caption_loss = contrastive_loss(similarity)
  37. image_loss = contrastive_loss(similarity.t())
  38. return (caption_loss + image_loss) / 2.0
  39. def hard_softmax(logits: torch.Tensor, dim: int):
  40. y_soft = logits.softmax(dim)
  41. # Straight through.
  42. index = y_soft.max(dim, keepdim=True)[1]
  43. y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
  44. ret = y_hard - y_soft.detach() + y_soft
  45. return ret
  46. def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
  47. # more stable https://github.com/pytorch/pytorch/issues/41663
  48. gumbel_dist = torch.distributions.gumbel.Gumbel(
  49. torch.tensor(0.0, device=logits.device, dtype=logits.dtype),
  50. torch.tensor(1.0, device=logits.device, dtype=logits.dtype),
  51. )
  52. gumbels = gumbel_dist.sample(logits.shape)
  53. gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
  54. y_soft = gumbels.softmax(dim)
  55. if hard:
  56. # Straight through.
  57. index = y_soft.max(dim, keepdim=True)[1]
  58. y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
  59. ret = y_hard - y_soft.detach() + y_soft
  60. else:
  61. # Reparameterization trick.
  62. ret = y_soft
  63. return ret
  64. def resize_attention_map(attentions, height, width, align_corners=False):
  65. """
  66. Args:
  67. attentions (`torch.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
  68. height (`int`): height of the output attention map
  69. width (`int`): width of the output attention map
  70. align_corners (`bool`, *optional*): the `align_corner` argument for `nn.functional.interpolate`.
  71. Returns:
  72. `torch.Tensor`: resized attention map of shape [batch_size, groups, height, width]
  73. """
  74. scale = (height * width // attentions.shape[2]) ** 0.5
  75. if height > width:
  76. feat_width = int(np.round(width / scale))
  77. feat_height = attentions.shape[2] // feat_width
  78. else:
  79. feat_height = int(np.round(height / scale))
  80. feat_width = attentions.shape[2] // feat_height
  81. batch_size = attentions.shape[0]
  82. groups = attentions.shape[1] # number of group token
  83. # [batch_size, groups, height*width, groups] -> [batch_size, groups, height, width]
  84. attentions = attentions.reshape(batch_size, groups, feat_height, feat_width)
  85. attentions = nn.functional.interpolate(
  86. attentions, size=(height, width), mode="bilinear", align_corners=align_corners
  87. )
  88. return attentions
  89. def get_grouping_from_attentions(attentions, hw_shape):
  90. """
  91. Args:
  92. attentions (`tuple(torch.FloatTensor)`: tuple of attention maps returned by `GroupViTVisionTransformer`
  93. hw_shape (`tuple(int)`): height and width of the output attention map
  94. Returns:
  95. `torch.Tensor`: the attention map of shape [batch_size, groups, height, width]
  96. """
  97. attn_maps = []
  98. with torch.no_grad():
  99. prev_attn_masks = None
  100. for attn_masks in attentions:
  101. # [batch_size, num_groups, height x width] -> [batch_size, height x width, num_groups]
  102. attn_masks = attn_masks.permute(0, 2, 1).contiguous()
  103. if prev_attn_masks is None:
  104. prev_attn_masks = attn_masks
  105. else:
  106. prev_attn_masks = prev_attn_masks @ attn_masks
  107. # [batch_size, heightxwidth, num_groups] -> [batch_size, num_groups, heightxwidth] -> [batch_size, num_groups, height, width]
  108. cur_attn_map = resize_attention_map(prev_attn_masks.permute(0, 2, 1).contiguous(), *hw_shape)
  109. attn_maps.append(cur_attn_map)
  110. # [batch_size, num_groups, height, width]
  111. final_grouping = attn_maps[-1]
  112. return final_grouping
  113. class GroupViTCrossAttentionLayer(nn.Module):
  114. def __init__(self, config: GroupViTVisionConfig):
  115. super().__init__()
  116. self.attn = GroupViTAttention(config)
  117. self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  118. self.mlp = GroupViTMLP(config)
  119. self.norm_post = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  120. def forward(self, query, key):
  121. x = query
  122. x = x + self.attn(query, encoder_hidden_states=key)[0]
  123. x = x + self.mlp(self.norm2(x))
  124. x = self.norm_post(x)
  125. return x
  126. class GroupViTAssignAttention(nn.Module):
  127. def __init__(self, config: GroupViTVisionConfig):
  128. super().__init__()
  129. self.scale = config.hidden_size**-0.5
  130. self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
  131. self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
  132. self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
  133. self.proj = nn.Linear(config.hidden_size, config.hidden_size)
  134. self.assign_eps = config.assign_eps
  135. def get_attn(self, attn, gumbel=True, hard=True):
  136. if gumbel and self.training:
  137. attn = gumbel_softmax(attn, dim=-2, hard=hard)
  138. else:
  139. if hard:
  140. attn = hard_softmax(attn, dim=-2)
  141. else:
  142. attn = nn.functional.softmax(attn, dim=-2)
  143. return attn
  144. def forward(self, query, key):
  145. value = key
  146. # [batch_size, query_length, channels]
  147. query = self.q_proj(query)
  148. # [batch_size, key_length, channels]
  149. key = self.k_proj(key)
  150. # [batch_size, key_length, channels]
  151. value = self.v_proj(value)
  152. # [batch_size, query_length, key_length]
  153. raw_attn = (query @ key.transpose(-2, -1)) * self.scale
  154. attn = self.get_attn(raw_attn)
  155. soft_attn = self.get_attn(raw_attn, gumbel=False, hard=False)
  156. attn = attn / (attn.sum(dim=-1, keepdim=True) + self.assign_eps)
  157. out = attn @ value
  158. out = self.proj(out)
  159. return out, soft_attn
  160. class GroupViTTokenAssign(nn.Module):
  161. def __init__(self, config: GroupViTVisionConfig, num_group_token, num_output_group):
  162. super().__init__()
  163. self.num_output_group = num_output_group
  164. # norm on group_tokens
  165. self.norm_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  166. assign_mlp_ratio = (
  167. config.assign_mlp_ratio
  168. if isinstance(config.assign_mlp_ratio, collections.abc.Iterable)
  169. else (config.assign_mlp_ratio, config.assign_mlp_ratio)
  170. )
  171. tokens_dim, channels_dim = [int(x * config.hidden_size) for x in assign_mlp_ratio]
  172. self.mlp_inter = GroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group)
  173. self.norm_post_tokens = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  174. # norm on x
  175. self.norm_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  176. self.pre_assign_attn = GroupViTCrossAttentionLayer(config)
  177. self.assign = GroupViTAssignAttention(config)
  178. self.norm_new_x = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  179. self.mlp_channels = GroupViTMLP(config, config.hidden_size, channels_dim, config.hidden_size)
  180. def project_group_token(self, group_tokens):
  181. """
  182. Args:
  183. group_tokens (torch.Tensor): group tokens, [batch_size, num_group_tokens, channels]
  184. Returns:
  185. projected_group_tokens (torch.Tensor): [batch_size, num_output_groups, channels]
  186. """
  187. # [B, num_output_groups, C] <- [B, num_group_tokens, C]
  188. projected_group_tokens = self.mlp_inter(group_tokens)
  189. projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
  190. return projected_group_tokens
  191. def forward(self, image_tokens, group_tokens):
  192. """
  193. Args:
  194. image_tokens (`torch.Tensor`): image tokens, of shape [batch_size, input_length, channels]
  195. group_tokens (`torch.Tensor`): group tokens, [batch_size, num_group_tokens, channels]
  196. """
  197. group_tokens = self.norm_tokens(group_tokens)
  198. image_tokens = self.norm_x(image_tokens)
  199. # [batch_size, num_output_groups, channels]
  200. projected_group_tokens = self.project_group_token(group_tokens)
  201. projected_group_tokens = self.pre_assign_attn(projected_group_tokens, image_tokens)
  202. new_image_tokens, attention = self.assign(projected_group_tokens, image_tokens)
  203. new_image_tokens += projected_group_tokens
  204. new_image_tokens = new_image_tokens + self.mlp_channels(self.norm_new_x(new_image_tokens))
  205. return new_image_tokens, attention
  206. @dataclass
  207. @auto_docstring
  208. class GroupViTModelOutput(ModelOutput):
  209. r"""
  210. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  211. Contrastive loss for image-text similarity.
  212. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  213. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  214. similarity scores.
  215. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  216. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  217. similarity scores.
  218. segmentation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
  219. Classification scores for each pixel.
  220. <Tip warning={true}>
  221. The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
  222. to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
  223. original image size as post-processing. You should always check your logits shape and resize as needed.
  224. </Tip>
  225. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  226. The text embeddings obtained by applying the projection layer to the pooled output of
  227. [`GroupViTTextModel`].
  228. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  229. The image embeddings obtained by applying the projection layer to the pooled output of
  230. [`GroupViTVisionModel`].
  231. text_model_output (`BaseModelOutputWithPooling`):
  232. The output of the [`GroupViTTextModel`].
  233. vision_model_output (`BaseModelOutputWithPooling`):
  234. The output of the [`GroupViTVisionModel`].
  235. """
  236. loss: Optional[torch.FloatTensor] = None
  237. logits_per_image: Optional[torch.FloatTensor] = None
  238. logits_per_text: Optional[torch.FloatTensor] = None
  239. segmentation_logits: Optional[torch.FloatTensor] = None
  240. text_embeds: Optional[torch.FloatTensor] = None
  241. image_embeds: Optional[torch.FloatTensor] = None
  242. text_model_output: BaseModelOutputWithPooling = None
  243. vision_model_output: BaseModelOutputWithPooling = None
  244. def to_tuple(self) -> tuple[Any]:
  245. return tuple(
  246. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  247. for k in self.keys()
  248. )
  249. class GroupViTPatchEmbeddings(nn.Module):
  250. """
  251. Image to Patch Embedding.
  252. """
  253. def __init__(
  254. self,
  255. image_size: int = 224,
  256. patch_size: Union[int, tuple[int, int]] = 16,
  257. num_channels: int = 3,
  258. embed_dim: int = 768,
  259. ):
  260. super().__init__()
  261. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  262. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  263. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  264. self.image_size = image_size
  265. self.patch_size = patch_size
  266. self.num_patches = num_patches
  267. self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
  268. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  269. batch_size, num_channels, height, width = pixel_values.shape
  270. if not interpolate_pos_encoding:
  271. if height != self.image_size[0] or width != self.image_size[1]:
  272. raise ValueError(
  273. f"Input image size ({height}*{width}) doesn't match model"
  274. f" ({self.image_size[0]}*{self.image_size[1]})."
  275. )
  276. x = self.projection(pixel_values).flatten(2).transpose(1, 2)
  277. return x
  278. class GroupViTVisionEmbeddings(nn.Module):
  279. def __init__(self, config: GroupViTVisionConfig):
  280. super().__init__()
  281. self.patch_embeddings = GroupViTPatchEmbeddings(
  282. image_size=config.image_size,
  283. patch_size=config.patch_size,
  284. num_channels=config.num_channels,
  285. embed_dim=config.hidden_size,
  286. )
  287. num_patches = self.patch_embeddings.num_patches
  288. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches, config.hidden_size))
  289. self.dropout = nn.Dropout(config.dropout)
  290. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  291. self.patch_size = config.patch_size
  292. self.config = config
  293. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  294. """
  295. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  296. images. This method is also adapted to support torch.jit tracing and no class embeddings.
  297. Adapted from:
  298. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  299. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  300. """
  301. num_patches = embeddings.shape[1]
  302. num_positions = self.position_embeddings.shape[1]
  303. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  304. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  305. return self.position_embeddings
  306. patch_pos_embed = self.position_embeddings
  307. dim = embeddings.shape[-1]
  308. new_height = height // self.patch_size
  309. new_width = width // self.patch_size
  310. sqrt_num_positions = torch_int(num_positions**0.5)
  311. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  312. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  313. patch_pos_embed = nn.functional.interpolate(
  314. patch_pos_embed,
  315. size=(new_height, new_width),
  316. mode="bicubic",
  317. align_corners=False,
  318. )
  319. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  320. return patch_pos_embed
  321. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  322. batch_size, num_channels, height, width = pixel_values.shape
  323. embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  324. embeddings = self.layernorm(embeddings)
  325. batch_size, seq_len, _ = embeddings.size()
  326. # add positional encoding to each token
  327. if interpolate_pos_encoding:
  328. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  329. else:
  330. embeddings = embeddings + self.position_embeddings
  331. embeddings = self.dropout(embeddings)
  332. return embeddings
  333. # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->GroupViT
  334. class GroupViTTextEmbeddings(nn.Module):
  335. def __init__(self, config: GroupViTTextConfig):
  336. super().__init__()
  337. embed_dim = config.hidden_size
  338. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  339. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  340. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  341. self.register_buffer(
  342. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  343. )
  344. def forward(
  345. self,
  346. input_ids: Optional[torch.LongTensor] = None,
  347. position_ids: Optional[torch.LongTensor] = None,
  348. inputs_embeds: Optional[torch.FloatTensor] = None,
  349. ) -> torch.Tensor:
  350. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  351. max_position_embedding = self.position_embedding.weight.shape[0]
  352. if seq_length > max_position_embedding:
  353. raise ValueError(
  354. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  355. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  356. )
  357. if position_ids is None:
  358. position_ids = self.position_ids[:, :seq_length]
  359. if inputs_embeds is None:
  360. inputs_embeds = self.token_embedding(input_ids)
  361. position_embeddings = self.position_embedding(position_ids)
  362. embeddings = inputs_embeds + position_embeddings
  363. return embeddings
  364. class GroupViTStage(nn.Module):
  365. """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""
  366. def __init__(
  367. self,
  368. config: GroupViTVisionConfig,
  369. depth: int,
  370. num_prev_group_token: int,
  371. num_group_token: int,
  372. num_output_group: int,
  373. ):
  374. super().__init__()
  375. self.depth = depth
  376. self.num_group_token = num_group_token
  377. if num_group_token > 0:
  378. self.group_token = nn.Parameter(torch.zeros(1, num_group_token, config.hidden_size))
  379. else:
  380. self.group_token = None
  381. self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(depth)])
  382. if num_group_token > 0:
  383. self.downsample = GroupViTTokenAssign(
  384. config=config,
  385. num_group_token=num_group_token,
  386. num_output_group=num_output_group,
  387. )
  388. else:
  389. self.downsample = None
  390. if num_prev_group_token > 0 and num_group_token > 0:
  391. self.group_projector = nn.Sequential(
  392. nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
  393. GroupViTMixerMLP(config, num_prev_group_token, config.hidden_size // 2, num_group_token),
  394. )
  395. else:
  396. self.group_projector = None
  397. @property
  398. def with_group_token(self):
  399. return self.group_token is not None
  400. def split_x(self, x):
  401. if self.with_group_token:
  402. return x[:, : -self.num_group_token], x[:, -self.num_group_token :]
  403. else:
  404. return x, None
  405. def concat_x(self, x: torch.Tensor, group_token: Optional[torch.Tensor] = None) -> torch.Tensor:
  406. if group_token is None:
  407. return x
  408. return torch.cat([x, group_token], dim=1)
  409. def forward(
  410. self,
  411. hidden_states: torch.Tensor,
  412. prev_group_token: Optional[torch.Tensor] = None,
  413. output_attentions: Optional[bool] = False,
  414. ) -> tuple[torch.FloatTensor]:
  415. """
  416. Args:
  417. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  418. attention_mask (`torch.FloatTensor`): attention mask of size
  419. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  420. `(config.encoder_attention_heads,)`.
  421. output_attentions (`bool`, *optional*):
  422. Whether or not to return the grouping tensors of Grouping block.
  423. """
  424. if self.with_group_token:
  425. group_token = self.group_token.expand(hidden_states.size(0), -1, -1)
  426. if self.group_projector is not None:
  427. group_token = group_token + self.group_projector(prev_group_token)
  428. else:
  429. group_token = None
  430. x = hidden_states
  431. cat_x = self.concat_x(x, group_token)
  432. for layer in self.layers:
  433. layer_out = layer(cat_x, attention_mask=None, causal_attention_mask=None)
  434. cat_x = layer_out[0]
  435. x, group_token = self.split_x(cat_x)
  436. attention = None
  437. if self.downsample is not None:
  438. x, attention = self.downsample(x, group_token)
  439. outputs = (x, group_token)
  440. if output_attentions:
  441. outputs = outputs + (attention,)
  442. return outputs
  443. class GroupViTMLP(nn.Module):
  444. def __init__(
  445. self,
  446. config: GroupViTVisionConfig,
  447. hidden_size: Optional[int] = None,
  448. intermediate_size: Optional[int] = None,
  449. output_size: Optional[int] = None,
  450. ):
  451. super().__init__()
  452. self.config = config
  453. self.activation_fn = ACT2FN[config.hidden_act]
  454. hidden_size = hidden_size if hidden_size is not None else config.hidden_size
  455. intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
  456. output_size = output_size if output_size is not None else hidden_size
  457. self.fc1 = nn.Linear(hidden_size, intermediate_size)
  458. self.fc2 = nn.Linear(intermediate_size, output_size)
  459. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  460. hidden_states = self.fc1(hidden_states)
  461. hidden_states = self.activation_fn(hidden_states)
  462. hidden_states = self.fc2(hidden_states)
  463. return hidden_states
  464. class GroupViTMixerMLP(GroupViTMLP):
  465. def forward(self, x):
  466. x = super().forward(x.transpose(1, 2))
  467. return x.transpose(1, 2)
  468. class GroupViTAttention(nn.Module):
  469. """Multi-headed attention from 'Attention Is All You Need' paper"""
  470. def __init__(self, config):
  471. super().__init__()
  472. self.config = config
  473. self.embed_dim = config.hidden_size
  474. self.num_heads = config.num_attention_heads
  475. self.head_dim = self.embed_dim // self.num_heads
  476. if self.head_dim * self.num_heads != self.embed_dim:
  477. raise ValueError(
  478. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  479. f" {self.num_heads})."
  480. )
  481. self.scale = self.head_dim**-0.5
  482. self.dropout = config.attention_dropout
  483. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  484. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  485. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  486. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  487. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  488. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  489. def forward(
  490. self,
  491. hidden_states: torch.Tensor,
  492. attention_mask: Optional[torch.Tensor] = None,
  493. causal_attention_mask: Optional[torch.Tensor] = None,
  494. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  495. output_attentions: Optional[bool] = False,
  496. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  497. """Input shape: Batch x Time x Channel"""
  498. bsz, tgt_len, embed_dim = hidden_states.size()
  499. is_cross_attention = encoder_hidden_states is not None
  500. # get query proj
  501. query_states = self.q_proj(hidden_states) * self.scale
  502. if is_cross_attention:
  503. key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
  504. value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
  505. else:
  506. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  507. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  508. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  509. query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
  510. key_states = key_states.view(*proj_shape)
  511. value_states = value_states.view(*proj_shape)
  512. src_len = key_states.size(1)
  513. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  514. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  515. raise ValueError(
  516. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  517. f" {attn_weights.size()}"
  518. )
  519. # apply the causal_attention_mask first
  520. if causal_attention_mask is not None:
  521. if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
  522. raise ValueError(
  523. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
  524. f" {causal_attention_mask.size()}"
  525. )
  526. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
  527. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  528. if attention_mask is not None:
  529. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  530. raise ValueError(
  531. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  532. )
  533. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  534. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  535. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  536. if output_attentions:
  537. # this operation is a bit awkward, but it's required to
  538. # make sure that attn_weights keeps its gradient.
  539. # In order to do so, attn_weights have to reshaped
  540. # twice and have to be reused in the following
  541. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  542. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  543. else:
  544. attn_weights_reshaped = None
  545. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  546. attn_output = torch.bmm(attn_probs, value_states)
  547. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  548. raise ValueError(
  549. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  550. f" {attn_output.size()}"
  551. )
  552. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  553. attn_output = attn_output.transpose(1, 2)
  554. attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
  555. attn_output = self.out_proj(attn_output)
  556. return attn_output, attn_weights_reshaped
  557. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT
  class GroupViTEncoderLayer(GradientCheckpointingLayer):
      """
      Single pre-norm transformer encoder layer.

      The layer applies `layer_norm1 -> self-attention -> residual add`, followed by
      `layer_norm2 -> MLP -> residual add`.
      """

      def __init__(self, config: GroupViTConfig):
          super().__init__()
          self.embed_dim = config.hidden_size
          self.self_attn = GroupViTAttention(config)
          self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
          self.mlp = GroupViTMLP(config)
          self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor],
          causal_attention_mask: Optional[torch.Tensor],
          output_attentions: Optional[bool] = False,
      ) -> tuple[torch.FloatTensor]:
          """
          Args:
              hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
              attention_mask (`torch.FloatTensor`, *optional*): additive attention mask of size
                  `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                  Skipped by the attention module when `None`.
              causal_attention_mask (`torch.FloatTensor`, *optional*): additive causal mask of size
                  `(batch, 1, tgt_len, src_len)`. The attention module adds it to the attention scores before
                  `attention_mask`; skipped when `None`.
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                  returned tensors for more detail.

          Returns:
              `tuple`: `(hidden_states,)`, extended with the attention weights of shape
              `(batch, num_heads, tgt_len, src_len)` when `output_attentions` is truthy.
          """
          # Attention sub-block (pre-norm + residual).
          residual = hidden_states
          hidden_states = self.layer_norm1(hidden_states)
          hidden_states, attn_weights = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              causal_attention_mask=causal_attention_mask,
              output_attentions=output_attentions,
          )
          hidden_states = residual + hidden_states

          # Feed-forward sub-block (pre-norm + residual).
          residual = hidden_states
          hidden_states = self.layer_norm2(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states

          outputs = (hidden_states,)
          if output_attentions:
              outputs += (attn_weights,)
          return outputs
  @auto_docstring
  class GroupViTPreTrainedModel(PreTrainedModel):
      # Shared base class of the GroupViT models: declares the config type, the checkpoint prefix and the
      # weight-initialization scheme.
      config: GroupViTConfig
      base_model_prefix = "groupvit"
      supports_gradient_checkpointing = True

      def _init_weights(self, module):
          """
          Initialize the weights of `module` in place.

          Two independent passes are applied:

          1. Generic layers: `nn.Linear` / `nn.Conv2d` weights are drawn from a normal distribution with
             std `config.initializer_range` and their biases are zeroed; `nn.LayerNorm` gets zero bias and
             unit weight.
          2. GroupViT-specific modules (text embeddings, attention, MLP) are drawn from normal distributions
             whose std is scaled by `config.initializer_factor`.
          """
          cfg = self.config

          # --- pass 1: generic layers -------------------------------------------------------------
          base_std = cfg.initializer_range
          if isinstance(module, (nn.Linear, nn.Conv2d)):
              # Slightly different from the TF version which uses truncated_normal for initialization
              # cf https://github.com/pytorch/pytorch/pull/5617
              module.weight.data.normal_(mean=0.0, std=base_std)
              if module.bias is not None:
                  module.bias.data.zero_()
          elif isinstance(module, nn.LayerNorm):
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)

          # --- pass 2: GroupViT-specific modules --------------------------------------------------
          scale = cfg.initializer_factor
          if isinstance(module, GroupViTTextEmbeddings):
              embedding_std = scale * 0.02
              module.token_embedding.weight.data.normal_(mean=0.0, std=embedding_std)
              module.position_embedding.weight.data.normal_(mean=0.0, std=embedding_std)
          elif isinstance(module, GroupViTAttention):
              width_term = module.embed_dim**-0.5
              depth_term = (2 * module.config.num_hidden_layers) ** -0.5
              qkv_std = width_term * depth_term * scale
              out_std = width_term * scale
              for projection in (module.q_proj, module.k_proj, module.v_proj):
                  nn.init.normal_(projection.weight, std=qkv_std)
              nn.init.normal_(module.out_proj.weight, std=out_std)
          elif isinstance(module, GroupViTMLP):
              hidden_size = module.config.hidden_size
              num_layers = module.config.num_hidden_layers
              fc2_std = (hidden_size**-0.5) * ((2 * num_layers) ** -0.5) * scale
              fc1_std = (2 * hidden_size) ** -0.5 * scale
              nn.init.normal_(module.fc1.weight, std=fc1_std)
              nn.init.normal_(module.fc2.weight, std=fc2_std)
  class GroupViTVisionEncoder(nn.Module):
      """
      Sequence of `GroupViTStage` modules, one per entry of `config.depths`.

      Each stage receives the group tokens produced by the previous stage (`None` for the first one).
      """

      def __init__(self, config: GroupViTVisionConfig) -> None:
          super().__init__()
          self.config = config

          stage_modules = []
          for stage_index, stage_depth in enumerate(config.depths):
              # The first stage has no predecessor, hence no incoming group tokens.
              incoming_groups = config.num_output_groups[stage_index - 1] if stage_index > 0 else 0
              stage_modules.append(
                  GroupViTStage(
                      config=config,
                      depth=stage_depth,
                      num_group_token=config.num_group_tokens[stage_index],
                      num_output_group=config.num_output_groups[stage_index],
                      num_prev_group_token=incoming_groups,
                  )
              )
          self.stages = nn.ModuleList(stage_modules)
          self.gradient_checkpointing = False

      def forward(
          self,
          hidden_states: torch.Tensor,
          output_hidden_states: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutput]:
          """
          Run `hidden_states` through every stage in order.

          Flags left as `None` fall back to the corresponding value on `self.config`.

          Returns:
              A `BaseModelOutput` whose `attentions` field holds the per-stage groupings (third element of
              each stage output, skipped when it is `None`), or - when `return_dict` is falsy - a tuple of
              the non-`None` values among `(last_hidden_state, hidden_states, groupings)`.
          """
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states
          if return_dict is None:
              return_dict = self.config.use_return_dict

          collected_states = () if output_hidden_states else None
          collected_groupings = () if output_attentions else None

          group_tokens = None
          for stage in self.stages:
              if output_hidden_states:
                  # Record the input of each stage; the final output is appended after the loop.
                  collected_states += (hidden_states,)

              stage_outputs = stage(hidden_states, group_tokens, output_attentions)
              hidden_states, group_tokens = stage_outputs[0], stage_outputs[1]

              if output_attentions and stage_outputs[2] is not None:
                  collected_groupings += (stage_outputs[2],)

          if output_hidden_states:
              collected_states += (hidden_states,)

          if return_dict:
              return BaseModelOutput(
                  last_hidden_state=hidden_states, hidden_states=collected_states, attentions=collected_groupings
              )
          return tuple(v for v in [hidden_states, collected_states, collected_groupings] if v is not None)
  class GroupViTTextEncoder(nn.Module):
      """
      Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is a
      [`GroupViTEncoderLayer`].

      Args:
          config: GroupViTTextConfig
      """

      def __init__(self, config: GroupViTTextConfig):
          super().__init__()
          self.config = config
          self.layers = nn.ModuleList([GroupViTEncoderLayer(config) for _ in range(config.num_hidden_layers)])
          self.gradient_checkpointing = False

      def forward(
          self,
          inputs_embeds,
          attention_mask: Optional[torch.Tensor] = None,
          causal_attention_mask: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutput]:
          r"""
          Args:
              inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                  Embedded representation of the input sequence; it is fed directly to the first layer.
              attention_mask (`torch.Tensor` of shape `(batch_size, 1, tgt_len, src_len)`, *optional*):
                  Already-expanded additive padding mask. It is added to the attention scores, so masked
                  positions must hold very large negative values and visible positions `0`. The attention
                  module raises a `ValueError` for any other shape.

                  [What are attention masks?](../glossary#attention-mask)
              causal_attention_mask (`torch.Tensor` of shape `(batch_size, 1, tgt_len, src_len)`, *optional*):
                  Already-expanded additive causal mask for the text model, with the same value convention as
                  `attention_mask`. It is applied before `attention_mask`.

                  [What are attention masks?](../glossary#attention-mask)
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                  returned tensors for more detail.
              output_hidden_states (`bool`, *optional*):
                  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                  for more detail.
              return_dict (`bool`, *optional*):
                  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

          Returns:
              A `BaseModelOutput`, or - when `return_dict` is falsy - a tuple of the non-`None` values among
              `(last_hidden_state, hidden_states, attentions)`.
          """
          # Flags left as `None` fall back to the values stored on the config.
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          encoder_states = () if output_hidden_states else None
          all_attentions = () if output_attentions else None

          hidden_states = inputs_embeds
          for encoder_layer in self.layers:
              if output_hidden_states:
                  # Record the input of each layer; the final output is appended after the loop.
                  encoder_states = encoder_states + (hidden_states,)
              layer_outputs = encoder_layer(
                  hidden_states,
                  attention_mask,
                  causal_attention_mask,
                  output_attentions=output_attentions,
              )

              hidden_states = layer_outputs[0]

              if output_attentions:
                  all_attentions = all_attentions + (layer_outputs[1],)

          if output_hidden_states:
              encoder_states = encoder_states + (hidden_states,)

          if not return_dict:
              return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
          return BaseModelOutput(
              last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
          )
  class GroupViTTextTransformer(nn.Module):
      """
      Text tower: token/position embeddings, a causal transformer encoder and a final layer norm.

      The pooled output is the final hidden state taken at the end-of-sequence position of each sequence.
      """

      def __init__(self, config: GroupViTTextConfig):
          super().__init__()
          self.config = config
          embed_dim = config.hidden_size
          self.embeddings = GroupViTTextEmbeddings(config)
          self.encoder = GroupViTTextEncoder(config)
          self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

          # For `pooled_output` computation
          self.eos_token_id = config.eos_token_id

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutputWithPooling]:
          # Flags left as `None` fall back to the values stored on the config.
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states
          if return_dict is None:
              return_dict = self.config.use_return_dict

          if input_ids is None:
              raise ValueError("You have to specify input_ids")

          # Collapse any leading dimensions so that `input_ids` is 2D: (-1, sequence_length).
          input_shape = input_ids.size()
          input_ids = input_ids.view(-1, input_shape[-1])

          hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

          # CLIP's text model uses causal mask, prepare it here.
          # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
          causal_attention_mask = _create_4d_causal_attention_mask(
              input_shape, hidden_states.dtype, device=hidden_states.device
          )
          if attention_mask is not None:
              # expand attention_mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
              attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

          encoder_outputs = self.encoder(
              inputs_embeds=hidden_states,
              attention_mask=attention_mask,
              causal_attention_mask=causal_attention_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          last_hidden_state = self.final_layer_norm(encoder_outputs[0])

          # Both pooling variants pick one position per sequence; they differ only in how it is located.
          target_device = last_hidden_state.device
          batch_index = torch.arange(last_hidden_state.shape[0], device=target_device)
          # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
          int_ids = input_ids.to(dtype=torch.int, device=target_device)

          if self.eos_token_id == 2:
              # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
              # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
              # ------------------------------------------------------------
              # take features from the eot embedding (eot_token is the highest number in each sequence)
              eos_position = int_ids.argmax(dim=-1)
          else:
              # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
              # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
              # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
              eos_position = (int_ids == self.eos_token_id).int().argmax(dim=-1)

          pooled_output = last_hidden_state[batch_index, eos_position]

          if return_dict:
              return BaseModelOutputWithPooling(
                  last_hidden_state=last_hidden_state,
                  pooler_output=pooled_output,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )
          return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  class GroupViTTextModel(GroupViTPreTrainedModel):
      # Standalone text model: a thin wrapper that owns a `GroupViTTextTransformer` and forwards to it.
      config: GroupViTTextConfig

      def __init__(self, config: GroupViTTextConfig):
          super().__init__(config)
          self.text_model = GroupViTTextTransformer(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> nn.Module:
          """Return the token embedding module of the wrapped text transformer."""
          text_embeddings = self.text_model.embeddings
          return text_embeddings.token_embedding

      def set_input_embeddings(self, value):
          """Replace the token embedding module of the wrapped text transformer with `value`."""
          text_embeddings = self.text_model.embeddings
          text_embeddings.token_embedding = value

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutputWithPooling]:
          r"""
          Examples:

          ```python
          >>> from transformers import CLIPTokenizer, GroupViTTextModel

          >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
          >>> model = GroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

          >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> last_hidden_state = outputs.last_hidden_state
          >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
          ```"""
          # Pure delegation: every argument is passed through unchanged.
          text_outputs = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          return text_outputs
  class GroupViTVisionTransformer(nn.Module):
      """
      Vision tower: image embeddings, the staged grouping encoder and a final layer norm.

      The pooled output is the mean of the normalized last hidden state over dimension 1.
      """

      def __init__(self, config: GroupViTVisionConfig):
          super().__init__()
          self.config = config
          embed_dim = config.hidden_size

          self.embeddings = GroupViTVisionEmbeddings(config)
          self.encoder = GroupViTVisionEncoder(config)
          self.layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

      @auto_docstring
      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
          output_hidden_states: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutputWithPooling]:
          # Flags left as `None` fall back to the values stored on the config.
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states
          if return_dict is None:
              return_dict = self.config.use_return_dict

          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          embedded = self.embeddings(pixel_values)
          encoder_outputs = self.encoder(
              hidden_states=embedded,
              output_hidden_states=output_hidden_states,
              output_attentions=output_attentions,
              return_dict=return_dict,
          )

          # normalize the last hidden state
          last_hidden_state = self.layernorm(encoder_outputs[0])
          # Mean over dimension 1 gives one vector per batch element.
          pooled_output = last_hidden_state.mean(dim=1)

          if return_dict:
              return BaseModelOutputWithPooling(
                  last_hidden_state=last_hidden_state,
                  pooler_output=pooled_output,
                  hidden_states=encoder_outputs.hidden_states,
                  attentions=encoder_outputs.attentions,
              )
          return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  class GroupViTVisionModel(GroupViTPreTrainedModel):
      # Standalone vision model: a thin wrapper that owns a `GroupViTVisionTransformer` and forwards to it.
      config: GroupViTVisionConfig
      main_input_name = "pixel_values"

      def __init__(self, config: GroupViTVisionConfig):
          super().__init__(config)
          self.vision_model = GroupViTVisionTransformer(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self) -> GroupViTPatchEmbeddings:
          """Return the patch embedding module of the wrapped vision transformer."""
          return self.vision_model.embeddings.patch_embeddings

      @auto_docstring
      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutputWithPooling]:
          r"""
          Examples:

          ```python
          >>> from PIL import Image
          >>> import requests
          >>> from transformers import AutoProcessor, GroupViTVisionModel

          >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
          >>> model = GroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)

          >>> inputs = processor(images=image, return_tensors="pt")

          >>> outputs = model(**inputs)
          >>> last_hidden_state = outputs.last_hidden_state
          >>> pooled_output = outputs.pooler_output  # mean of the normalized last hidden state over the sequence dimension
          ```"""
          # Pure delegation: every argument is passed through unchanged.
          return self.vision_model(
              pixel_values=pixel_values,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
  @auto_docstring
  class GroupViTModel(GroupViTPreTrainedModel):
      # Dual-tower model: a text transformer and a vision transformer whose pooled outputs are projected
      # into a shared space and compared with a learned temperature (`logit_scale`).
      config: GroupViTConfig

      def __init__(self, config: GroupViTConfig):
          super().__init__(config)

          if not isinstance(config.text_config, GroupViTTextConfig):
              raise TypeError(
                  "config.text_config is expected to be of type GroupViTTextConfig but is of type"
                  f" {type(config.text_config)}."
              )

          if not isinstance(config.vision_config, GroupViTVisionConfig):
              raise TypeError(
                  "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
                  f" {type(config.vision_config)}."
              )

          text_config = config.text_config
          vision_config = config.vision_config

          self.projection_dim = config.projection_dim
          self.projection_intermediate_dim = config.projection_intermediate_dim
          self.text_embed_dim = text_config.hidden_size
          self.vision_embed_dim = vision_config.hidden_size

          self.text_model = GroupViTTextTransformer(text_config)
          self.vision_model = GroupViTVisionTransformer(vision_config)

          # Both projection heads share the same layout: Linear -> BatchNorm1d -> ReLU -> Linear.
          self.visual_projection = nn.Sequential(
              nn.Linear(self.vision_embed_dim, self.projection_intermediate_dim, bias=True),
              nn.BatchNorm1d(self.projection_intermediate_dim),
              nn.ReLU(inplace=True),
              nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
          )
          self.text_projection = nn.Sequential(
              nn.Linear(self.text_embed_dim, self.projection_intermediate_dim, bias=True),
              nn.BatchNorm1d(self.projection_intermediate_dim),
              nn.ReLU(inplace=True),
              nn.Linear(self.projection_intermediate_dim, self.projection_dim, bias=True),
          )
          # Stored in log space; `forward` exponentiates it before use.
          self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

          # Initialize weights and apply final processing
          self.post_init()

      @filter_out_non_signature_kwargs()
      @auto_docstring
      def get_text_features(
          self,
          input_ids: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
      ) -> torch.FloatTensor:
          r"""
          Returns:
              text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
              applying the projection layer to the pooled output of [`GroupViTTextModel`].

          Examples:

          ```python
          >>> import torch
          >>> from transformers import CLIPTokenizer, GroupViTModel

          >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
          >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")

          >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
          >>> with torch.inference_mode():
          ...     text_features = model.get_text_features(**inputs)
          ```"""
          # Force a ModelOutput: when the text config resolves `use_return_dict` to False the text model
          # returns a plain tuple, on which `.pooler_output` would raise AttributeError.
          text_outputs: BaseModelOutputWithPooling = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              return_dict=True,
          )
          text_features = self.text_projection(text_outputs.pooler_output)
          return text_features

      @filter_out_non_signature_kwargs()
      @auto_docstring
      def get_image_features(self, pixel_values: torch.Tensor) -> torch.FloatTensor:
          r"""
          Returns:
              image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
              applying the projection layer to the pooled output of [`GroupViTVisionModel`].

          Examples:

          ```python
          >>> import torch
          >>> from transformers import AutoProcessor, GroupViTModel
          >>> from transformers.image_utils import load_image

          >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
          >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = load_image(url)

          >>> inputs = processor(images=image, return_tensors="pt")

          >>> with torch.inference_mode():
          ...     image_features = model.get_image_features(**inputs)
          ```"""
          # Force a ModelOutput for the same reason as in `get_text_features`: a tuple return has no
          # `.pooler_output` attribute.
          vision_outputs: BaseModelOutputWithPooling = self.vision_model(
              pixel_values=pixel_values,
              return_dict=True,
          )
          image_features = self.visual_projection(vision_outputs.pooler_output)
          return image_features

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          pixel_values: Optional[torch.FloatTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          return_loss: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          output_segmentation: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, GroupViTModelOutput]:
          r"""
          return_loss (`bool`, *optional*):
              Whether or not to return the contrastive loss.
          output_segmentation (`bool`, *optional*):
              Whether or not to return the segmentation logits.

          Examples:

          ```python
          >>> from PIL import Image
          >>> import requests
          >>> from transformers import AutoProcessor, GroupViTModel

          >>> model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
          >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)

          >>> inputs = processor(
          ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
          ... )

          >>> outputs = model(**inputs)
          >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
          >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
          ```"""
          # Use GROUPVIT model's config for some fields (if specified) instead of those of vision & text components.
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_segmentation = (
              output_segmentation if output_segmentation is not None else self.config.output_segmentation
          )
          if output_segmentation:
              # The segmentation logits are built from the vision attentions, so they must be returned.
              output_attentions = True
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          vision_outputs = self.vision_model(
              pixel_values=pixel_values,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          text_outputs = self.text_model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Index 1 is the pooled output for both the tuple and the ModelOutput return forms.
          image_embeds = vision_outputs[1]
          image_embeds = self.visual_projection(image_embeds)

          text_embeds = text_outputs[1]
          text_embeds = self.text_projection(text_embeds)

          # normalized features
          image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
          text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

          # cosine similarity as logits
          logit_scale = self.logit_scale.exp()
          logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
          logits_per_image = logits_per_text.t()

          seg_logits = None
          if output_segmentation:
              # grouped features
              # [batch_size_image, num_group, hidden_size]
              image_group_embeds = vision_outputs[0]
              # [batch_size_image*num_group, hidden_size]
              image_group_embeds = self.visual_projection(image_group_embeds.reshape(-1, image_group_embeds.shape[-1]))
              # The attentions sit one slot later when the hidden states are also returned.
              if output_hidden_states:
                  attentions = vision_outputs[3]
              else:
                  attentions = vision_outputs[2]
              # [batch_size_image, num_group, height, width] per the original annotation of this helper's
              # result - TODO confirm against `get_grouping_from_attentions`
              grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])

              # normalized features
              image_group_embeds = image_group_embeds / image_group_embeds.norm(dim=-1, keepdim=True)
              # [batch_size_image x num_group, batch_size_text]
              logits_per_image_group = torch.matmul(image_group_embeds, text_embeds.t()) * logit_scale
              # [batch_size_image, batch_size_text, num_group]
              logits_per_image_group = logits_per_image_group.reshape(
                  image_embeds.shape[0], -1, text_embeds.shape[0]
              ).permute(0, 2, 1)

              # [batch_size_image, num_group, height x width]
              flatten_grouping = grouping.reshape(grouping.shape[0], grouping.shape[1], -1)

              # [batch_size_image, batch_size_text, height x width]
              # NOTE(review): `logit_scale` is applied a second time here (it is already folded into
              # `logits_per_image_group`). Left unchanged to preserve existing outputs - confirm this is intended.
              seg_logits = torch.matmul(logits_per_image_group, flatten_grouping) * logit_scale
              # [batch_size_image, batch_size_text, height, width]
              seg_logits = seg_logits.reshape(
                  seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3]
              )

          loss = None
          if return_loss:
              loss = groupvit_loss(logits_per_text)

          if not return_dict:
              # Tuple layout: [loss], logits_per_image, logits_per_text, [seg_logits], text_embeds,
              # image_embeds, text_outputs, vision_outputs (bracketed entries only when not None).
              if seg_logits is not None:
                  output = (
                      logits_per_image,
                      logits_per_text,
                      seg_logits,
                      text_embeds,
                      image_embeds,
                      text_outputs,
                      vision_outputs,
                  )
              else:
                  output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
              return ((loss,) + output) if loss is not None else output

          return GroupViTModelOutput(
              loss=loss,
              logits_per_image=logits_per_image,
              logits_per_text=logits_per_text,
              segmentation_logits=seg_logits,
              text_embeds=text_embeds,
              image_embeds=image_embeds,
              text_model_output=text_outputs,
              vision_model_output=vision_outputs,
          )
  # Public API of this module.
  __all__ = [
      "GroupViTModel",
      "GroupViTPreTrainedModel",
      "GroupViTTextModel",
      "GroupViTVisionModel",
  ]