modeling_clipseg.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358
  1. # coding=utf-8
  2. # Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch CLIPSeg model."""
  16. import copy
  17. import math
  18. from dataclasses import dataclass
  19. from typing import Any, Callable, Optional, Union
  20. import torch
  21. from torch import nn
  22. from ...activations import ACT2FN
  23. from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  26. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  27. from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int
  28. from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig
  29. logger = logging.get_logger(__name__)
  30. # contrastive loss function, adapted from
  31. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
  32. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  33. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  34. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->clipseg
  35. def clipseg_loss(similarity: torch.Tensor) -> torch.Tensor:
  36. caption_loss = contrastive_loss(similarity)
  37. image_loss = contrastive_loss(similarity.t())
  38. return (caption_loss + image_loss) / 2.0
  39. @dataclass
  40. @auto_docstring
  41. # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->CLIPSeg
  42. class CLIPSegOutput(ModelOutput):
  43. r"""
  44. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  45. Contrastive loss for image-text similarity.
  46. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  47. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  48. similarity scores.
  49. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  50. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  51. similarity scores.
  52. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  53. The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegTextModel`].
  54. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
  55. The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
  56. text_model_output (`BaseModelOutputWithPooling`):
  57. The output of the [`CLIPSegTextModel`].
  58. vision_model_output (`BaseModelOutputWithPooling`):
  59. The output of the [`CLIPSegVisionModel`].
  60. """
  61. loss: Optional[torch.FloatTensor] = None
  62. logits_per_image: Optional[torch.FloatTensor] = None
  63. logits_per_text: Optional[torch.FloatTensor] = None
  64. text_embeds: Optional[torch.FloatTensor] = None
  65. image_embeds: Optional[torch.FloatTensor] = None
  66. text_model_output: BaseModelOutputWithPooling = None
  67. vision_model_output: BaseModelOutputWithPooling = None
  68. def to_tuple(self) -> tuple[Any]:
  69. return tuple(
  70. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  71. for k in self.keys()
  72. )
  73. @dataclass
  74. @auto_docstring
  75. class CLIPSegDecoderOutput(ModelOutput):
  76. r"""
  77. logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
  78. Classification scores for each pixel.
  79. """
  80. logits: Optional[torch.FloatTensor] = None
  81. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  82. attentions: Optional[tuple[torch.FloatTensor]] = None
  83. @dataclass
  84. @auto_docstring
  85. class CLIPSegImageSegmentationOutput(ModelOutput):
  86. r"""
  87. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  88. Binary cross entropy loss for segmentation.
  89. logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
  90. Classification scores for each pixel.
  91. conditional_embeddings (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
  92. Conditional embeddings used for segmentation.
  93. pooled_output (`torch.FloatTensor` of shape `(batch_size, embed_dim)`):
  94. Pooled output of the [`CLIPSegVisionModel`].
  95. vision_model_output (`BaseModelOutputWithPooling`):
  96. The output of the [`CLIPSegVisionModel`].
  97. decoder_output (`CLIPSegDecoderOutput`):
  98. The output of the [`CLIPSegDecoder`].
  99. """
  100. loss: Optional[torch.FloatTensor] = None
  101. logits: Optional[torch.FloatTensor] = None
  102. conditional_embeddings: Optional[torch.FloatTensor] = None
  103. pooled_output: Optional[torch.FloatTensor] = None
  104. vision_model_output: BaseModelOutputWithPooling = None
  105. decoder_output: CLIPSegDecoderOutput = None
  106. def to_tuple(self) -> tuple[Any]:
  107. return tuple(
  108. self[k] if k not in ["vision_model_output", "decoder_output"] else getattr(self, k).to_tuple()
  109. for k in self.keys()
  110. )
  111. class CLIPSegVisionEmbeddings(nn.Module):
  112. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.__init__ with CLIP->CLIPSeg
  113. def __init__(self, config: CLIPSegVisionConfig):
  114. super().__init__()
  115. self.config = config
  116. self.embed_dim = config.hidden_size
  117. self.image_size = config.image_size
  118. self.patch_size = config.patch_size
  119. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  120. self.patch_embedding = nn.Conv2d(
  121. in_channels=config.num_channels,
  122. out_channels=self.embed_dim,
  123. kernel_size=self.patch_size,
  124. stride=self.patch_size,
  125. bias=False,
  126. )
  127. self.num_patches = (self.image_size // self.patch_size) ** 2
  128. self.num_positions = self.num_patches + 1
  129. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  130. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  131. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  132. """
  133. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  134. images. This method is also adapted to support torch.jit tracing.
  135. Adapted from:
  136. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  137. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  138. """
  139. num_patches = embeddings.shape[1] - 1
  140. position_embedding = self.position_embedding.weight.unsqueeze(0)
  141. num_positions = position_embedding.shape[1] - 1
  142. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  143. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  144. return self.position_embedding(self.position_ids)
  145. class_pos_embed = position_embedding[:, :1]
  146. patch_pos_embed = position_embedding[:, 1:]
  147. dim = embeddings.shape[-1]
  148. new_height = height // self.patch_size
  149. new_width = width // self.patch_size
  150. sqrt_num_positions = torch_int(num_positions**0.5)
  151. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  152. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  153. patch_pos_embed = nn.functional.interpolate(
  154. patch_pos_embed,
  155. size=(new_height, new_width),
  156. mode="bicubic",
  157. align_corners=False,
  158. )
  159. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  160. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  161. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=True) -> torch.Tensor:
  162. batch_size, _, height, width = pixel_values.shape
  163. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  164. raise ValueError(
  165. f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
  166. )
  167. patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
  168. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  169. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  170. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  171. if interpolate_pos_encoding:
  172. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  173. else:
  174. embeddings = embeddings + self.position_embedding(self.position_ids)
  175. return embeddings
  176. # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->CLIPSeg
  177. class CLIPSegTextEmbeddings(nn.Module):
  178. def __init__(self, config: CLIPSegTextConfig):
  179. super().__init__()
  180. embed_dim = config.hidden_size
  181. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  182. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  183. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  184. self.register_buffer(
  185. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  186. )
  187. def forward(
  188. self,
  189. input_ids: Optional[torch.LongTensor] = None,
  190. position_ids: Optional[torch.LongTensor] = None,
  191. inputs_embeds: Optional[torch.FloatTensor] = None,
  192. ) -> torch.Tensor:
  193. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  194. max_position_embedding = self.position_embedding.weight.shape[0]
  195. if seq_length > max_position_embedding:
  196. raise ValueError(
  197. f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
  198. f"{seq_length} and max_position_embeddings: {max_position_embedding}"
  199. )
  200. if position_ids is None:
  201. position_ids = self.position_ids[:, :seq_length]
  202. if inputs_embeds is None:
  203. inputs_embeds = self.token_embedding(input_ids)
  204. position_embeddings = self.position_embedding(position_ids)
  205. embeddings = inputs_embeds + position_embeddings
  206. return embeddings
  207. # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
  208. def eager_attention_forward(
  209. module: nn.Module,
  210. query: torch.Tensor,
  211. key: torch.Tensor,
  212. value: torch.Tensor,
  213. attention_mask: Optional[torch.Tensor],
  214. scaling: float,
  215. dropout: float = 0.0,
  216. **kwargs,
  217. ):
  218. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  219. if attention_mask is not None:
  220. attn_weights = attn_weights + attention_mask
  221. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  222. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  223. attn_output = torch.matmul(attn_weights, value)
  224. attn_output = attn_output.transpose(1, 2).contiguous()
  225. return attn_output, attn_weights
  226. class CLIPSegAttention(nn.Module):
  227. """Multi-headed attention from 'Attention Is All You Need' paper"""
  228. def __init__(self, config: Union[CLIPSegVisionConfig, CLIPSegTextConfig]):
  229. super().__init__()
  230. self.config = config
  231. self.embed_dim = config.hidden_size
  232. self.num_heads = config.num_attention_heads
  233. self.head_dim = self.embed_dim // self.num_heads
  234. if self.head_dim * self.num_heads != self.embed_dim:
  235. raise ValueError(
  236. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  237. f" {self.num_heads})."
  238. )
  239. self.scale = self.head_dim**-0.5
  240. self.dropout = config.attention_dropout
  241. self.is_causal = False
  242. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  243. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  244. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  245. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  246. def forward(
  247. self,
  248. hidden_states: torch.Tensor,
  249. attention_mask: Optional[torch.Tensor] = None,
  250. causal_attention_mask: Optional[torch.Tensor] = None,
  251. output_attentions: Optional[bool] = False,
  252. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  253. """Input shape: Batch x Time x Channel"""
  254. batch_size, seq_length, embed_dim = hidden_states.shape
  255. queries = self.q_proj(hidden_states)
  256. keys = self.k_proj(hidden_states)
  257. values = self.v_proj(hidden_states)
  258. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  259. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  260. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  261. # CLIP text model uses both `causal_attention_mask` and `attention_mask`
  262. # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
  263. if self.config._attn_implementation != "flash_attention_2":
  264. if attention_mask is not None and causal_attention_mask is not None:
  265. attention_mask = attention_mask + causal_attention_mask
  266. elif causal_attention_mask is not None:
  267. attention_mask = causal_attention_mask
  268. else:
  269. self.is_causal = causal_attention_mask is not None
  270. attention_interface: Callable = eager_attention_forward
  271. if self.config._attn_implementation != "eager":
  272. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  273. attn_output, attn_weights = attention_interface(
  274. self,
  275. queries,
  276. keys,
  277. values,
  278. attention_mask,
  279. is_causal=self.is_causal,
  280. scaling=self.scale,
  281. dropout=0.0 if not self.training else self.dropout,
  282. )
  283. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  284. attn_output = self.out_proj(attn_output)
  285. if not output_attentions:
  286. attn_weights = None
  287. return attn_output, attn_weights
  288. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->CLIPSeg
  289. class CLIPSegMLP(nn.Module):
  290. def __init__(self, config):
  291. super().__init__()
  292. self.config = config
  293. self.activation_fn = ACT2FN[config.hidden_act]
  294. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  295. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  296. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  297. hidden_states = self.fc1(hidden_states)
  298. hidden_states = self.activation_fn(hidden_states)
  299. hidden_states = self.fc2(hidden_states)
  300. return hidden_states
  301. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg
  302. class CLIPSegEncoderLayer(GradientCheckpointingLayer):
  303. def __init__(self, config: CLIPSegConfig):
  304. super().__init__()
  305. self.embed_dim = config.hidden_size
  306. self.self_attn = CLIPSegAttention(config)
  307. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  308. self.mlp = CLIPSegMLP(config)
  309. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  310. def forward(
  311. self,
  312. hidden_states: torch.Tensor,
  313. attention_mask: torch.Tensor,
  314. causal_attention_mask: torch.Tensor,
  315. output_attentions: Optional[bool] = False,
  316. ) -> tuple[torch.FloatTensor]:
  317. """
  318. Args:
  319. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  320. attention_mask (`torch.FloatTensor`): attention mask of size
  321. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  322. `(config.encoder_attention_heads,)`.
  323. output_attentions (`bool`, *optional*):
  324. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  325. returned tensors for more detail.
  326. """
  327. residual = hidden_states
  328. hidden_states = self.layer_norm1(hidden_states)
  329. hidden_states, attn_weights = self.self_attn(
  330. hidden_states=hidden_states,
  331. attention_mask=attention_mask,
  332. causal_attention_mask=causal_attention_mask,
  333. output_attentions=output_attentions,
  334. )
  335. hidden_states = residual + hidden_states
  336. residual = hidden_states
  337. hidden_states = self.layer_norm2(hidden_states)
  338. hidden_states = self.mlp(hidden_states)
  339. hidden_states = residual + hidden_states
  340. outputs = (hidden_states,)
  341. if output_attentions:
  342. outputs += (attn_weights,)
  343. return outputs
  344. @auto_docstring
  345. class CLIPSegPreTrainedModel(PreTrainedModel):
  346. config: CLIPSegConfig
  347. base_model_prefix = "clip"
  348. supports_gradient_checkpointing = True
  349. def _init_weights(self, module):
  350. """Initialize the weights"""
  351. factor = self.config.initializer_factor
  352. if isinstance(module, CLIPSegTextEmbeddings):
  353. module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
  354. module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
  355. elif isinstance(module, CLIPSegVisionEmbeddings):
  356. factor = self.config.initializer_factor
  357. nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  358. nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  359. nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
  360. elif isinstance(module, CLIPSegAttention):
  361. factor = self.config.initializer_factor
  362. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  363. out_proj_std = (module.embed_dim**-0.5) * factor
  364. nn.init.normal_(module.q_proj.weight, std=in_proj_std)
  365. nn.init.normal_(module.k_proj.weight, std=in_proj_std)
  366. nn.init.normal_(module.v_proj.weight, std=in_proj_std)
  367. nn.init.normal_(module.out_proj.weight, std=out_proj_std)
  368. elif isinstance(module, CLIPSegMLP):
  369. factor = self.config.initializer_factor
  370. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  371. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  372. nn.init.normal_(module.fc1.weight, std=fc_std)
  373. nn.init.normal_(module.fc2.weight, std=in_proj_std)
  374. elif isinstance(module, CLIPSegModel):
  375. nn.init.normal_(
  376. module.text_projection.weight,
  377. std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
  378. )
  379. nn.init.normal_(
  380. module.visual_projection.weight,
  381. std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
  382. )
  383. if isinstance(module, nn.LayerNorm):
  384. module.bias.data.zero_()
  385. module.weight.data.fill_(1.0)
  386. if isinstance(module, nn.Linear) and module.bias is not None:
  387. module.bias.data.zero_()
  388. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg
  389. class CLIPSegEncoder(nn.Module):
  390. """
  391. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  392. [`CLIPSegEncoderLayer`].
  393. Args:
  394. config: CLIPSegConfig
  395. """
  396. def __init__(self, config: CLIPSegConfig):
  397. super().__init__()
  398. self.config = config
  399. self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  400. self.gradient_checkpointing = False
  401. @can_return_tuple
  402. def forward(
  403. self,
  404. inputs_embeds,
  405. attention_mask: Optional[torch.Tensor] = None,
  406. causal_attention_mask: Optional[torch.Tensor] = None,
  407. output_attentions: Optional[bool] = None,
  408. output_hidden_states: Optional[bool] = None,
  409. return_dict: Optional[bool] = None,
  410. ) -> Union[tuple, BaseModelOutput]:
  411. r"""
  412. Args:
  413. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  414. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  415. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  416. than the model's internal embedding lookup matrix.
  417. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  418. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  419. - 1 for tokens that are **not masked**,
  420. - 0 for tokens that are **masked**.
  421. [What are attention masks?](../glossary#attention-mask)
  422. causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  423. Causal mask for the text model. Mask values selected in `[0, 1]`:
  424. - 1 for tokens that are **not masked**,
  425. - 0 for tokens that are **masked**.
  426. [What are attention masks?](../glossary#attention-mask)
  427. output_attentions (`bool`, *optional*):
  428. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  429. returned tensors for more detail.
  430. output_hidden_states (`bool`, *optional*):
  431. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  432. for more detail.
  433. return_dict (`bool`, *optional*):
  434. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  435. """
  436. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  437. output_hidden_states = (
  438. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  439. )
  440. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  441. encoder_states = () if output_hidden_states else None
  442. all_attentions = () if output_attentions else None
  443. hidden_states = inputs_embeds
  444. for idx, encoder_layer in enumerate(self.layers):
  445. if output_hidden_states:
  446. encoder_states = encoder_states + (hidden_states,)
  447. layer_outputs = encoder_layer(
  448. hidden_states,
  449. attention_mask,
  450. causal_attention_mask,
  451. output_attentions=output_attentions,
  452. )
  453. hidden_states = layer_outputs[0]
  454. if output_attentions:
  455. all_attentions = all_attentions + (layer_outputs[1],)
  456. if output_hidden_states:
  457. encoder_states = encoder_states + (hidden_states,)
  458. return BaseModelOutput(
  459. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  460. )
  461. class CLIPSegTextTransformer(nn.Module):
  462. def __init__(self, config: CLIPSegTextConfig):
  463. super().__init__()
  464. self.config = config
  465. embed_dim = config.hidden_size
  466. self.embeddings = CLIPSegTextEmbeddings(config)
  467. self.encoder = CLIPSegEncoder(config)
  468. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  469. # For `pooled_output` computation
  470. self.eos_token_id = config.eos_token_id
  471. @auto_docstring
  472. def forward(
  473. self,
  474. input_ids: Optional[torch.Tensor] = None,
  475. attention_mask: Optional[torch.Tensor] = None,
  476. position_ids: Optional[torch.Tensor] = None,
  477. output_attentions: Optional[bool] = None,
  478. output_hidden_states: Optional[bool] = None,
  479. return_dict: Optional[bool] = None,
  480. ) -> Union[tuple, BaseModelOutputWithPooling]:
  481. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  482. output_hidden_states = (
  483. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  484. )
  485. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  486. if input_ids is None:
  487. raise ValueError("You have to specify input_ids")
  488. input_shape = input_ids.size()
  489. input_ids = input_ids.view(-1, input_shape[-1])
  490. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  491. # CLIPSeg's text model uses causal mask, prepare it here.
  492. # https://github.com/openai/CLIPSeg/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clipseg/model.py#L324
  493. causal_attention_mask = _create_4d_causal_attention_mask(
  494. input_shape, hidden_states.dtype, device=hidden_states.device
  495. )
  496. # expand attention_mask
  497. if attention_mask is not None:
  498. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  499. attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
  500. encoder_outputs = self.encoder(
  501. inputs_embeds=hidden_states,
  502. attention_mask=attention_mask,
  503. causal_attention_mask=causal_attention_mask,
  504. output_attentions=output_attentions,
  505. output_hidden_states=output_hidden_states,
  506. return_dict=return_dict,
  507. )
  508. last_hidden_state = encoder_outputs[0]
  509. last_hidden_state = self.final_layer_norm(last_hidden_state)
  510. if self.eos_token_id == 2:
  511. # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
  512. # A CLIPSeg model with such `eos_token_id` in the config can't work correctly with extra new tokens added
  513. # ------------------------------------------------------------
  514. # text_embeds.shape = [batch_size, sequence_length, transformer.width]
  515. # take features from the eot embedding (eot_token is the highest number in each sequence)
  516. # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
  517. pooled_output = last_hidden_state[
  518. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  519. input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
  520. ]
  521. else:
  522. # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
  523. pooled_output = last_hidden_state[
  524. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  525. # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
  526. # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
  527. (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
  528. .int()
  529. .argmax(dim=-1),
  530. ]
  531. if not return_dict:
  532. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  533. return BaseModelOutputWithPooling(
  534. last_hidden_state=last_hidden_state,
  535. pooler_output=pooled_output,
  536. hidden_states=encoder_outputs.hidden_states,
  537. attentions=encoder_outputs.attentions,
  538. )
  539. class CLIPSegTextModel(CLIPSegPreTrainedModel):
  540. config: CLIPSegTextConfig
  541. _no_split_modules = ["CLIPSegTextEmbeddings", "CLIPSegEncoderLayer"]
  542. def __init__(self, config: CLIPSegTextConfig):
  543. super().__init__(config)
  544. self.text_model = CLIPSegTextTransformer(config)
  545. # Initialize weights and apply final processing
  546. self.post_init()
  547. def get_input_embeddings(self) -> nn.Module:
  548. return self.text_model.embeddings.token_embedding
  549. def set_input_embeddings(self, value):
  550. self.text_model.embeddings.token_embedding = value
  551. @auto_docstring
  552. def forward(
  553. self,
  554. input_ids: Optional[torch.Tensor] = None,
  555. attention_mask: Optional[torch.Tensor] = None,
  556. position_ids: Optional[torch.Tensor] = None,
  557. output_attentions: Optional[bool] = None,
  558. output_hidden_states: Optional[bool] = None,
  559. return_dict: Optional[bool] = None,
  560. ) -> Union[tuple, BaseModelOutputWithPooling]:
  561. r"""
  562. Examples:
  563. ```python
  564. >>> from transformers import AutoTokenizer, CLIPSegTextModel
  565. >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
  566. >>> model = CLIPSegTextModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  567. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  568. >>> outputs = model(**inputs)
  569. >>> last_hidden_state = outputs.last_hidden_state
  570. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  571. ```"""
  572. return self.text_model(
  573. input_ids=input_ids,
  574. attention_mask=attention_mask,
  575. position_ids=position_ids,
  576. output_attentions=output_attentions,
  577. output_hidden_states=output_hidden_states,
  578. return_dict=return_dict,
  579. )
  580. class CLIPSegVisionTransformer(nn.Module):
  581. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIP->CLIPSeg
  582. def __init__(self, config: CLIPSegVisionConfig):
  583. super().__init__()
  584. self.config = config
  585. embed_dim = config.hidden_size
  586. self.embeddings = CLIPSegVisionEmbeddings(config)
  587. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  588. self.encoder = CLIPSegEncoder(config)
  589. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  590. @auto_docstring
  591. def forward(
  592. self,
  593. pixel_values: Optional[torch.FloatTensor],
  594. output_attentions: Optional[bool] = None,
  595. output_hidden_states: Optional[bool] = None,
  596. return_dict: Optional[bool] = None,
  597. interpolate_pos_encoding: Optional[bool] = True,
  598. ) -> Union[tuple, BaseModelOutputWithPooling]:
  599. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  600. output_hidden_states = (
  601. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  602. )
  603. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  604. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  605. hidden_states = self.pre_layrnorm(hidden_states)
  606. encoder_outputs = self.encoder(
  607. inputs_embeds=hidden_states,
  608. output_attentions=output_attentions,
  609. output_hidden_states=output_hidden_states,
  610. return_dict=return_dict,
  611. )
  612. last_hidden_state = encoder_outputs[0]
  613. pooled_output = last_hidden_state[:, 0, :]
  614. pooled_output = self.post_layernorm(pooled_output)
  615. if not return_dict:
  616. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  617. return BaseModelOutputWithPooling(
  618. last_hidden_state=last_hidden_state,
  619. pooler_output=pooled_output,
  620. hidden_states=encoder_outputs.hidden_states,
  621. attentions=encoder_outputs.attentions,
  622. )
  623. class CLIPSegVisionModel(CLIPSegPreTrainedModel):
  624. config: CLIPSegVisionConfig
  625. main_input_name = "pixel_values"
  626. def __init__(self, config: CLIPSegVisionConfig):
  627. super().__init__(config)
  628. self.vision_model = CLIPSegVisionTransformer(config)
  629. # Initialize weights and apply final processing
  630. self.post_init()
  631. def get_input_embeddings(self) -> nn.Module:
  632. return self.vision_model.embeddings.patch_embedding
  633. @auto_docstring
  634. def forward(
  635. self,
  636. pixel_values: Optional[torch.FloatTensor] = None,
  637. output_attentions: Optional[bool] = None,
  638. output_hidden_states: Optional[bool] = None,
  639. interpolate_pos_encoding: Optional[bool] = True,
  640. return_dict: Optional[bool] = None,
  641. ) -> Union[tuple, BaseModelOutputWithPooling]:
  642. r"""
  643. Examples:
  644. ```python
  645. >>> from PIL import Image
  646. >>> import requests
  647. >>> from transformers import AutoProcessor, CLIPSegVisionModel
  648. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  649. >>> model = CLIPSegVisionModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  650. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  651. >>> image = Image.open(requests.get(url, stream=True).raw)
  652. >>> inputs = processor(images=image, return_tensors="pt")
  653. >>> outputs = model(**inputs)
  654. >>> last_hidden_state = outputs.last_hidden_state
  655. >>> pooled_output = outputs.pooler_output # pooled CLS states
  656. ```"""
  657. return self.vision_model(
  658. pixel_values=pixel_values,
  659. output_attentions=output_attentions,
  660. output_hidden_states=output_hidden_states,
  661. interpolate_pos_encoding=interpolate_pos_encoding,
  662. return_dict=return_dict,
  663. )
  664. @auto_docstring
  665. class CLIPSegModel(CLIPSegPreTrainedModel):
  666. config: CLIPSegConfig
  667. def __init__(self, config: CLIPSegConfig):
  668. super().__init__(config)
  669. if not isinstance(config.text_config, CLIPSegTextConfig):
  670. raise TypeError(
  671. "config.text_config is expected to be of type CLIPSegTextConfig but is of type"
  672. f" {type(config.text_config)}."
  673. )
  674. if not isinstance(config.vision_config, CLIPSegVisionConfig):
  675. raise TypeError(
  676. "config.vision_config is expected to be of type CLIPSegVisionConfig but is of type"
  677. f" {type(config.vision_config)}."
  678. )
  679. text_config = config.text_config
  680. vision_config = config.vision_config
  681. # The module using it is not a PreTrainedModel subclass so we need this
  682. text_config._attn_implementation = config._attn_implementation
  683. # The module using it is not a PreTrainedModel subclass so we need this
  684. vision_config._attn_implementation = config._attn_implementation
  685. self.projection_dim = config.projection_dim
  686. self.text_embed_dim = text_config.hidden_size
  687. self.vision_embed_dim = vision_config.hidden_size
  688. self.text_model = CLIPSegTextTransformer(text_config)
  689. self.vision_model = CLIPSegVisionTransformer(vision_config)
  690. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  691. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  692. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  693. # Initialize weights and apply final processing
  694. self.post_init()
  695. @filter_out_non_signature_kwargs()
  696. @auto_docstring
  697. def get_text_features(
  698. self,
  699. input_ids: torch.Tensor,
  700. attention_mask: Optional[torch.Tensor] = None,
  701. position_ids: Optional[torch.Tensor] = None,
  702. ) -> torch.FloatTensor:
  703. r"""
  704. Returns:
  705. text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
  706. applying the projection layer to the pooled output of [`CLIPSegTextModel`].
  707. Examples:
  708. ```python
  709. >>> import torch
  710. >>> from transformers import AutoTokenizer, CLIPSegModel
  711. >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
  712. >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  713. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  714. >>> with torch.inference_mode():
  715. ... text_features = model.get_text_features(**inputs)
  716. ```"""
  717. text_outputs: BaseModelOutputWithPooling = self.text_model(
  718. input_ids=input_ids,
  719. attention_mask=attention_mask,
  720. position_ids=position_ids,
  721. )
  722. pooled_output = text_outputs.pooler_output
  723. text_features = self.text_projection(pooled_output)
  724. return text_features
  725. @filter_out_non_signature_kwargs()
  726. @auto_docstring
  727. def get_image_features(
  728. self,
  729. pixel_values: torch.FloatTensor,
  730. interpolate_pos_encoding: bool = True,
  731. ) -> torch.FloatTensor:
  732. r"""
  733. Returns:
  734. image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
  735. applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
  736. Examples:
  737. ```python
  738. >>> import torch
  739. >>> from transformers import AutoProcessor, CLIPSegModel
  740. >>> from transformers.image_utils import load_image
  741. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  742. >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  743. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  744. >>> image = load_image(url)
  745. >>> inputs = processor(images=image, return_tensors="pt")
  746. >>> with torch.inference_mode():
  747. ... image_features = model.get_image_features(**inputs)
  748. ```"""
  749. vision_outputs: BaseModelOutputWithPooling = self.vision_model(
  750. pixel_values=pixel_values,
  751. interpolate_pos_encoding=interpolate_pos_encoding,
  752. )
  753. pooled_output = vision_outputs.pooler_output
  754. image_features = self.visual_projection(pooled_output)
  755. return image_features
  756. @auto_docstring
  757. def forward(
  758. self,
  759. input_ids: Optional[torch.LongTensor] = None,
  760. pixel_values: Optional[torch.FloatTensor] = None,
  761. attention_mask: Optional[torch.Tensor] = None,
  762. position_ids: Optional[torch.LongTensor] = None,
  763. return_loss: Optional[bool] = None,
  764. output_attentions: Optional[bool] = None,
  765. output_hidden_states: Optional[bool] = None,
  766. interpolate_pos_encoding: bool = True,
  767. return_dict: Optional[bool] = None,
  768. ) -> Union[tuple, CLIPSegOutput]:
  769. r"""
  770. return_loss (`bool`, *optional*):
  771. Whether or not to return the contrastive loss.
  772. Examples:
  773. ```python
  774. >>> import torch
  775. >>> from transformers import AutoProcessor, CLIPSegModel
  776. >>> from transformers.image_utils import load_image
  777. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  778. >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  779. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  780. >>> image = load_image(url)
  781. >>> inputs = processor(
  782. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  783. ... )
  784. >>> with torch.inference_mode():
  785. ... outputs = model(**inputs)
  786. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  787. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  788. ```"""
  789. # Use CLIPSEG model's config for some fields (if specified) instead of those of vision & text components.
  790. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  791. output_hidden_states = (
  792. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  793. )
  794. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  795. vision_outputs = self.vision_model(
  796. pixel_values=pixel_values,
  797. output_attentions=output_attentions,
  798. output_hidden_states=output_hidden_states,
  799. interpolate_pos_encoding=interpolate_pos_encoding,
  800. return_dict=return_dict,
  801. )
  802. text_outputs = self.text_model(
  803. input_ids=input_ids,
  804. attention_mask=attention_mask,
  805. position_ids=position_ids,
  806. output_attentions=output_attentions,
  807. output_hidden_states=output_hidden_states,
  808. return_dict=return_dict,
  809. )
  810. image_embeds = vision_outputs[1]
  811. image_embeds = self.visual_projection(image_embeds)
  812. text_embeds = text_outputs[1]
  813. text_embeds = self.text_projection(text_embeds)
  814. # normalized features
  815. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  816. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  817. # cosine similarity as logits
  818. logit_scale = self.logit_scale.exp()
  819. logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
  820. logits_per_image = logits_per_text.t()
  821. loss = None
  822. if return_loss:
  823. loss = clipseg_loss(logits_per_text)
  824. if not return_dict:
  825. output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
  826. return ((loss,) + output) if loss is not None else output
  827. return CLIPSegOutput(
  828. loss=loss,
  829. logits_per_image=logits_per_image,
  830. logits_per_text=logits_per_text,
  831. text_embeds=text_embeds,
  832. image_embeds=image_embeds,
  833. text_model_output=text_outputs,
  834. vision_model_output=vision_outputs,
  835. )
  836. class CLIPSegDecoderLayer(nn.Module):
  837. """
  838. CLIPSeg decoder layer, which is identical to `CLIPSegEncoderLayer`, except that normalization is applied after
  839. self-attention/MLP, rather than before.
  840. """
  841. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer.__init__ with AltCLIP->CLIPSeg
  842. def __init__(self, config: CLIPSegConfig):
  843. super().__init__()
  844. self.embed_dim = config.hidden_size
  845. self.self_attn = CLIPSegAttention(config)
  846. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  847. self.mlp = CLIPSegMLP(config)
  848. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  849. def forward(
  850. self,
  851. hidden_states: torch.Tensor,
  852. attention_mask: torch.Tensor,
  853. causal_attention_mask: torch.Tensor,
  854. output_attentions: Optional[bool] = False,
  855. ) -> tuple[torch.FloatTensor]:
  856. """
  857. Args:
  858. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  859. attention_mask (`torch.FloatTensor`): attention mask of size
  860. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  861. `(config.encoder_attention_heads,)`.
  862. output_attentions (`bool`, *optional*):
  863. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  864. returned tensors for more detail.
  865. """
  866. residual = hidden_states
  867. hidden_states, attn_weights = self.self_attn(
  868. hidden_states=hidden_states,
  869. attention_mask=attention_mask,
  870. causal_attention_mask=causal_attention_mask,
  871. output_attentions=output_attentions,
  872. )
  873. hidden_states = residual + hidden_states
  874. hidden_states = self.layer_norm1(hidden_states)
  875. residual = hidden_states
  876. hidden_states = self.mlp(hidden_states)
  877. hidden_states = residual + hidden_states
  878. hidden_states = self.layer_norm2(hidden_states)
  879. outputs = (hidden_states,)
  880. if output_attentions:
  881. outputs += (attn_weights,)
  882. return outputs
  883. class CLIPSegDecoder(CLIPSegPreTrainedModel):
  884. def __init__(self, config: CLIPSegConfig):
  885. super().__init__(config)
  886. self.conditional_layer = config.conditional_layer
  887. self.film_mul = nn.Linear(config.projection_dim, config.reduce_dim)
  888. self.film_add = nn.Linear(config.projection_dim, config.reduce_dim)
  889. if config.use_complex_transposed_convolution:
  890. transposed_kernels = (config.vision_config.patch_size // 4, config.vision_config.patch_size // 4)
  891. self.transposed_convolution = nn.Sequential(
  892. nn.Conv2d(config.reduce_dim, config.reduce_dim, kernel_size=3, padding=1),
  893. nn.ReLU(),
  894. nn.ConvTranspose2d(
  895. config.reduce_dim,
  896. config.reduce_dim // 2,
  897. kernel_size=transposed_kernels[0],
  898. stride=transposed_kernels[0],
  899. ),
  900. nn.ReLU(),
  901. nn.ConvTranspose2d(
  902. config.reduce_dim // 2, 1, kernel_size=transposed_kernels[1], stride=transposed_kernels[1]
  903. ),
  904. )
  905. else:
  906. self.transposed_convolution = nn.ConvTranspose2d(
  907. config.reduce_dim, 1, config.vision_config.patch_size, stride=config.vision_config.patch_size
  908. )
  909. depth = len(config.extract_layers)
  910. self.reduces = nn.ModuleList(
  911. [nn.Linear(config.vision_config.hidden_size, config.reduce_dim) for _ in range(depth)]
  912. )
  913. decoder_config = copy.deepcopy(config.vision_config)
  914. decoder_config.hidden_size = config.reduce_dim
  915. decoder_config.num_attention_heads = config.decoder_num_attention_heads
  916. decoder_config.intermediate_size = config.decoder_intermediate_size
  917. decoder_config.hidden_act = "relu"
  918. self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(len(config.extract_layers))])
  919. def forward(
  920. self,
  921. hidden_states: tuple[torch.Tensor],
  922. conditional_embeddings: torch.Tensor,
  923. output_attentions: Optional[bool] = None,
  924. output_hidden_states: Optional[bool] = None,
  925. return_dict: Optional[bool] = True,
  926. ):
  927. all_hidden_states = () if output_hidden_states else None
  928. all_attentions = () if output_attentions else None
  929. activations = hidden_states[::-1]
  930. output = None
  931. for i, (activation, layer, reduce) in enumerate(zip(activations, self.layers, self.reduces)):
  932. if output is not None:
  933. output = reduce(activation) + output
  934. else:
  935. output = reduce(activation)
  936. if i == self.conditional_layer:
  937. output = self.film_mul(conditional_embeddings) * output.permute(1, 0, 2) + self.film_add(
  938. conditional_embeddings
  939. )
  940. output = output.permute(1, 0, 2)
  941. layer_outputs = layer(
  942. output, attention_mask=None, causal_attention_mask=None, output_attentions=output_attentions
  943. )
  944. output = layer_outputs[0]
  945. if output_hidden_states:
  946. all_hidden_states += (output,)
  947. if output_attentions:
  948. all_attentions += (layer_outputs[1],)
  949. output = output[:, 1:, :].permute(0, 2, 1) # remove cls token and reshape to [batch_size, reduce_dim, seq_len]
  950. size = int(math.sqrt(output.shape[2]))
  951. batch_size = conditional_embeddings.shape[0]
  952. output = output.view(batch_size, output.shape[1], size, size)
  953. logits = self.transposed_convolution(output).squeeze(1)
  954. if not return_dict:
  955. return tuple(v for v in [logits, all_hidden_states, all_attentions] if v is not None)
  956. return CLIPSegDecoderOutput(
  957. logits=logits,
  958. hidden_states=all_hidden_states,
  959. attentions=all_attentions,
  960. )
  961. @auto_docstring(
  962. custom_intro="""
  963. CLIPSeg model with a Transformer-based decoder on top for zero-shot and one-shot image segmentation.
  964. """
  965. )
  966. class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
  967. config: CLIPSegConfig
  968. def __init__(self, config: CLIPSegConfig):
  969. super().__init__(config)
  970. self.config = config
  971. self.clip = CLIPSegModel(config)
  972. self.extract_layers = config.extract_layers
  973. self.decoder = CLIPSegDecoder(config)
  974. # Initialize weights and apply final processing
  975. self.post_init()
  976. def get_conditional_embeddings(
  977. self,
  978. batch_size: Optional[int] = None,
  979. input_ids: Optional[torch.Tensor] = None,
  980. attention_mask: Optional[torch.Tensor] = None,
  981. position_ids: Optional[torch.Tensor] = None,
  982. conditional_pixel_values: Optional[torch.Tensor] = None,
  983. ):
  984. if input_ids is not None:
  985. # compute conditional embeddings from texts
  986. if len(input_ids) != batch_size:
  987. raise ValueError("Make sure to pass as many prompt texts as there are query images")
  988. with torch.no_grad():
  989. conditional_embeddings = self.clip.get_text_features(
  990. input_ids, attention_mask=attention_mask, position_ids=position_ids
  991. )
  992. elif conditional_pixel_values is not None:
  993. # compute conditional embeddings from images
  994. if len(conditional_pixel_values) != batch_size:
  995. raise ValueError("Make sure to pass as many prompt images as there are query images")
  996. with torch.no_grad():
  997. conditional_embeddings = self.clip.get_image_features(conditional_pixel_values)
  998. else:
  999. raise ValueError(
  1000. "Invalid conditional, should be either provided as `input_ids` or `conditional_pixel_values`"
  1001. )
  1002. return conditional_embeddings
  1003. @auto_docstring
  1004. def forward(
  1005. self,
  1006. input_ids: Optional[torch.FloatTensor] = None,
  1007. pixel_values: Optional[torch.FloatTensor] = None,
  1008. conditional_pixel_values: Optional[torch.FloatTensor] = None,
  1009. conditional_embeddings: Optional[torch.FloatTensor] = None,
  1010. attention_mask: Optional[torch.Tensor] = None,
  1011. position_ids: Optional[torch.LongTensor] = None,
  1012. labels: Optional[torch.LongTensor] = None,
  1013. output_attentions: Optional[bool] = None,
  1014. output_hidden_states: Optional[bool] = None,
  1015. interpolate_pos_encoding: bool = True,
  1016. return_dict: Optional[bool] = None,
  1017. ) -> Union[tuple, CLIPSegOutput]:
  1018. r"""
  1019. conditional_pixel_values (`torch.FloatTensor`, *optional*):
  1020. The pixel values of the conditional images.
  1021. conditional_embeddings (`torch.FloatTensor` of shape `(batch_size, config.projection_dim)`, *optional*):
  1022. The conditional embeddings for the query images. If provided, the model will use this instead of computing
  1023. the embeddings from the conditional_pixel_values.
  1024. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1025. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1026. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1027. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1028. Examples:
  1029. ```python
  1030. >>> import torch
  1031. >>> from transformers import AutoProcessor, CLIPSegForImageSegmentation
  1032. >>> from transformers.image_utils import load_image
  1033. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  1034. >>> model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
  1035. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1036. >>> image = load_image(url)
  1037. >>> texts = ["a cat", "a remote", "a blanket"]
  1038. >>> inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt")
  1039. >>> with torch.inference_mode():
  1040. ... outputs = model(**inputs)
  1041. >>> logits = outputs.logits
  1042. >>> print(logits.shape)
  1043. torch.Size([3, 352, 352])
  1044. ```"""
  1045. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1046. # step 1: forward the query images through the frozen CLIP vision encoder
  1047. with torch.no_grad():
  1048. vision_outputs = self.clip.vision_model(
  1049. pixel_values=pixel_values,
  1050. output_attentions=output_attentions,
  1051. output_hidden_states=True, # we need the intermediate hidden states
  1052. interpolate_pos_encoding=interpolate_pos_encoding,
  1053. return_dict=return_dict,
  1054. )
  1055. pooled_output = self.clip.visual_projection(vision_outputs[1])
  1056. hidden_states = vision_outputs.hidden_states if return_dict else vision_outputs[2]
  1057. # we add +1 here as the hidden states also include the initial embeddings
  1058. activations = [hidden_states[i + 1] for i in self.extract_layers]
  1059. # update vision_outputs
  1060. if return_dict:
  1061. vision_outputs = BaseModelOutputWithPooling(
  1062. last_hidden_state=vision_outputs.last_hidden_state,
  1063. pooler_output=vision_outputs.pooler_output,
  1064. hidden_states=vision_outputs.hidden_states if output_hidden_states else None,
  1065. attentions=vision_outputs.attentions,
  1066. )
  1067. else:
  1068. vision_outputs = (
  1069. vision_outputs[:2] + vision_outputs[3:] if not output_hidden_states else vision_outputs
  1070. )
  1071. # step 2: compute conditional embeddings, either from text, images or an own provided embedding
  1072. if conditional_embeddings is None:
  1073. conditional_embeddings = self.get_conditional_embeddings(
  1074. batch_size=pixel_values.shape[0],
  1075. input_ids=input_ids,
  1076. attention_mask=attention_mask,
  1077. position_ids=position_ids,
  1078. conditional_pixel_values=conditional_pixel_values,
  1079. )
  1080. else:
  1081. if conditional_embeddings.shape[0] != pixel_values.shape[0]:
  1082. raise ValueError(
  1083. "Make sure to pass as many conditional embeddings as there are query images in the batch"
  1084. )
  1085. if conditional_embeddings.shape[1] != self.config.projection_dim:
  1086. raise ValueError(
  1087. "Make sure that the feature dimension of the conditional embeddings matches"
  1088. " `config.projection_dim`."
  1089. )
  1090. # step 3: forward both the pooled output and the activations through the lightweight decoder to predict masks
  1091. decoder_outputs = self.decoder(
  1092. activations,
  1093. conditional_embeddings,
  1094. output_attentions=output_attentions,
  1095. output_hidden_states=output_hidden_states,
  1096. return_dict=return_dict,
  1097. )
  1098. logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
  1099. loss = None
  1100. if labels is not None:
  1101. # move labels to the correct device to enable PP
  1102. labels = labels.to(logits.device)
  1103. loss_fn = nn.BCEWithLogitsLoss()
  1104. loss = loss_fn(logits, labels)
  1105. if not return_dict:
  1106. output = (logits, conditional_embeddings, pooled_output, vision_outputs, decoder_outputs)
  1107. return ((loss,) + output) if loss is not None else output
  1108. return CLIPSegImageSegmentationOutput(
  1109. loss=loss,
  1110. logits=logits,
  1111. conditional_embeddings=conditional_embeddings,
  1112. pooled_output=pooled_output,
  1113. vision_model_output=vision_outputs,
  1114. decoder_output=decoder_outputs,
  1115. )
# Names exported by `from <this module> import *`.
__all__ = [
    "CLIPSegModel",
    "CLIPSegPreTrainedModel",
    "CLIPSegTextModel",
    "CLIPSegVisionModel",
    "CLIPSegForImageSegmentation",
]