| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358 |
- # coding=utf-8
- # Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch CLIPSeg model."""
- import copy
- import math
- from dataclasses import dataclass
- from typing import Any, Callable, Optional, Union
- import torch
- from torch import nn
- from ...activations import ACT2FN
- from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int
- from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig
# Module-level logger named after this module, obtained from the library's `logging` utilities.
logger = logging.get_logger(__name__)
# Contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """Cross-entropy that treats row ``i`` of ``logits`` as belonging to class ``i``.

    Args:
        logits: 2-D score matrix whose ``i``-th row should score highest at column ``i``. It needs at least as
            many columns as rows, because the targets run over ``0 .. len(logits) - 1``.

    Returns:
        The mean cross-entropy over all rows, as a 0-d tensor.
    """
    num_rows = len(logits)
    diagonal_targets = torch.arange(num_rows, device=logits.device)
    return nn.functional.cross_entropy(logits, diagonal_targets)
# Adapted from transformers.models.clip.modeling_clip.clip_loss with clip->clipseg
def clipseg_loss(similarity: torch.Tensor) -> torch.Tensor:
    """Symmetric contrastive loss over a similarity matrix.

    Averages `contrastive_loss` applied to ``similarity`` (rows as queries) and to its transpose
    (columns as queries), so both matching directions contribute equally.

    Args:
        similarity: 2-D tensor of pairwise scores whose diagonal holds the matching pairs.

    Returns:
        0-d tensor holding the average of the two directional losses.
    """
    row_loss = contrastive_loss(similarity)
    column_loss = contrastive_loss(similarity.t())
    return (row_loss + column_loss) / 2.0
@dataclass
@auto_docstring
# Adapted from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->CLIPSeg
class CLIPSegOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`CLIPSegTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`CLIPSegVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    # Nested outputs default to None like every other field, so they are annotated as Optional.
    text_model_output: Optional[BaseModelOutputWithPooling] = None
    vision_model_output: Optional[BaseModelOutputWithPooling] = None

    def to_tuple(self) -> tuple[Any]:
        """Return the populated fields as a tuple, converting the nested model outputs to tuples as well."""
        # NOTE(review): assumes `keys()` only yields fields that are set, so `getattr(...)` is never None here.
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
# Output container for the CLIPSeg decoder; it is carried as `decoder_output` inside
# `CLIPSegImageSegmentationOutput`. Every field defaults to None.
@dataclass
@auto_docstring
class CLIPSegDecoderOutput(ModelOutput):
    r"""
    logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
        Classification scores for each pixel.
    """

    logits: Optional[torch.FloatTensor] = None
    # NOTE(review): the two fields below are not described in the docstring above; presumably they hold
    # per-layer decoder hidden states / attention maps — confirm against CLIPSegDecoder.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring
class CLIPSegImageSegmentationOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Binary cross entropy loss for segmentation.
    logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
        Classification scores for each pixel.
    conditional_embeddings (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
        Conditional embeddings used for segmentation.
    pooled_output (`torch.FloatTensor` of shape `(batch_size, embed_dim)`):
        Pooled output of the [`CLIPSegVisionModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`CLIPSegVisionModel`].
    decoder_output (`CLIPSegDecoderOutput`):
        The output of the [`CLIPSegDecoder`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    conditional_embeddings: Optional[torch.FloatTensor] = None
    pooled_output: Optional[torch.FloatTensor] = None
    # Nested outputs default to None like every other field, so they are annotated as Optional.
    vision_model_output: Optional[BaseModelOutputWithPooling] = None
    decoder_output: Optional[CLIPSegDecoderOutput] = None

    def to_tuple(self) -> tuple[Any]:
        """Return the populated fields as a tuple, converting the nested vision/decoder outputs to tuples too."""
        # NOTE(review): assumes `keys()` only yields fields that are set, so `getattr(...)` is never None here.
        return tuple(
            self[k] if k not in ["vision_model_output", "decoder_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
class CLIPSegVisionEmbeddings(nn.Module):
    """Patch + class-token embeddings with learned absolute position embeddings for the CLIPSeg vision tower.

    A strided convolution cuts the image into non-overlapping `patch_size` x `patch_size` patches. A learned class
    embedding is then prepended, and position embeddings are added. The position embeddings can optionally be
    interpolated, so inputs whose size differs from `config.image_size` can still be embedded.
    """

    # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.__init__ with CLIP->CLIPSeg
    def __init__(self, config: CLIPSegVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # The stored patch position table is assumed to form a square grid.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=True) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: images of shape `(batch_size, num_channels, height, width)`.
            interpolate_pos_encoding: when True (the default), position embeddings are fitted to the input's patch
                grid. When False, the input must be exactly `image_size` x `image_size`.

        Returns:
            Tensor of shape `(batch_size, 1 + num_patches, hidden_size)`. Position 0 along dim 1 is the class token.

        Raises:
            ValueError: if `interpolate_pos_encoding` is False and the input size differs from `config.image_size`.
        """
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        # Cast to the convolution's parameter dtype (e.g. fp32 pixels fed to an fp16/bf16 model). Otherwise conv2d
        # rejects the mismatched input. This is a no-op when the dtypes already agree.
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings
# Adapted from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->CLIPSeg
class CLIPSegTextEmbeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings for the CLIPSeg text tower."""

    def __init__(self, config: CLIPSegTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Return `token embeddings + position embeddings`.

        Args:
            input_ids: token ids. Ignored for the lookup when `inputs_embeds` is given, but still used for the
                sequence length if present.
            position_ids: explicit positions. Defaults to `0 .. seq_length - 1`.
            inputs_embeds: precomputed token embeddings used instead of looking up `input_ids`.

        Raises:
            ValueError: if neither `input_ids` nor `inputs_embeds` is given, or if the sequence is longer than
                `max_position_embeddings`.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        max_position_embedding = self.position_embedding.weight.shape[0]

        # Equality is allowed: every position 0 .. seq_length - 1 must have an embedding row.
        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than or equal to max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding})"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
# Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled dot-product attention.

    Args:
        module: owning module. Only its `training` flag is read, to decide whether dropout is active.
        query, key, value: per-head tensors laid out as `(batch, heads, seq, head_dim)`.
        attention_mask: optional additive mask, broadcast-added to the raw scores before the softmax.
        scaling: multiplier applied to `query @ key^T`.
        dropout: dropout probability applied to the attention probabilities.
        **kwargs: accepted for interface compatibility with other kernels and ignored (e.g. `is_causal`).

    Returns:
        `(output, weights)`. `output` is `(batch, seq, heads, head_dim)` (contiguous), and `weights` holds the
        post-dropout attention probabilities in `query`'s dtype.
    """
    scores = torch.matmul(query, key.transpose(-2, -1)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax in float32 for stability, then return to the working dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probs = probs.to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value)
    context = context.transpose(1, 2).contiguous()
    return context, probs
class CLIPSegAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Used as self-attention: queries, keys and values are all linear projections of the same `hidden_states`.
    The attention kernel is chosen from `config._attn_implementation`. `eager_attention_forward` is used for
    `"eager"`, and other implementations are looked up in `ALL_ATTENTION_FUNCTIONS`.
    """

    def __init__(self, config: Union[CLIPSegVisionConfig, CLIPSegTextConfig]):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # The per-head reshape in forward() requires hidden_size to be an exact multiple of the head count.
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # 1/sqrt(head_dim), passed to the kernel as `scaling`.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        # Non-causal by default; forward() overwrites this only on the flash_attention_2 path.
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: `(batch_size, seq_length, embed_dim)` input.
            attention_mask: optional mask handed to the kernel. The eager kernel adds it to the raw scores.
            causal_attention_mask: optional causal mask. When flash_attention_2 is not active, it is summed into
                (or substituted for) `attention_mask`. Under flash_attention_2 only its presence is used, via
                `self.is_causal`.
            output_attentions: when falsy, the returned attention weights are replaced by `None`.

        Returns:
            `(attn_output, attn_weights)`, where `attn_output` has the same shape as `hidden_states`.
        """
        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
        queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
        # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
        if self.config._attn_implementation != "flash_attention_2":
            # Both masks are additive, so summing them applies padding and causal masking together.
            if attention_mask is not None and causal_attention_mask is not None:
                attention_mask = attention_mask + causal_attention_mask
            elif causal_attention_mask is not None:
                attention_mask = causal_attention_mask
        else:
            # NOTE(review): this writes module state from inside forward(); the value persists after the call.
            self.is_causal = causal_attention_mask is not None

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            # Attention dropout is only active in training mode.
            dropout=0.0 if not self.training else self.dropout,
        )

        # Merge the heads back into one embed_dim axis (the eager kernel returns (batch, seq, heads, head_dim)).
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights
# Adapted from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->CLIPSeg
class CLIPSegMLP(nn.Module):
    """Position-wise feed-forward block: `hidden_size -> intermediate_size -> hidden_size`.

    The activation named by `config.hidden_act` sits between the two linear layers.
    """

    def __init__(self, config):
        super().__init__()
        # Registration order (activation, fc1, fc2) is kept as-is: it fixes the submodule order and the order in
        # which the Linear layers draw their default initial weights.
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc2(activation(fc1(hidden_states)))`; the last dimension returns to `hidden_size`."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
# Adapted from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg
class CLIPSegEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block.

    The block computes `x + SelfAttn(LayerNorm1(x))`, then adds `MLP(LayerNorm2(.))` on a second residual branch.
    """

    def __init__(self, config: CLIPSegConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPSegAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPSegMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                May be `None`.
            causal_attention_mask (`torch.FloatTensor`): causal mask with the same layout as `attention_mask`,
                forwarded unchanged to the self-attention module. May be `None`.
            output_attentions (`bool`, *optional*):
                Whether to also return this layer's attention weights.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is set.
        """
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        hidden_states = hidden_states + mlp_out

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
@auto_docstring
class CLIPSegPreTrainedModel(PreTrainedModel):
    config: CLIPSegConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        - Text/vision embeddings, attention projections, MLP layers and the `CLIPSegModel` projections are drawn
          from normal distributions scaled by `config.initializer_factor`.
        - Every `nn.LayerNorm` is reset to weight 1 and bias 0.
        - Every `nn.Linear` bias, when present, is zeroed.
        """
        factor = self.config.initializer_factor

        if isinstance(module, CLIPSegTextEmbeddings):
            embedding_std = factor * 0.02
            module.token_embedding.weight.data.normal_(mean=0.0, std=embedding_std)
            module.position_embedding.weight.data.normal_(mean=0.0, std=embedding_std)
        elif isinstance(module, CLIPSegVisionEmbeddings):
            table_std = module.config.initializer_range * factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=table_std)
            nn.init.normal_(module.position_embedding.weight, std=table_std)
        elif isinstance(module, CLIPSegAttention):
            # Input projections are additionally scaled down by the encoder depth.
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            for projection in (module.q_proj, module.k_proj, module.v_proj):
                nn.init.normal_(projection.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, CLIPSegMLP):
            hidden_size = module.config.hidden_size
            fc1_std = (2 * hidden_size) ** -0.5 * factor
            fc2_std = (hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            nn.init.normal_(module.fc1.weight, std=fc1_std)
            nn.init.normal_(module.fc2.weight, std=fc2_std)
        elif isinstance(module, CLIPSegModel):
            nn.init.normal_(module.text_projection.weight, std=module.text_embed_dim**-0.5 * factor)
            nn.init.normal_(module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * factor)

        # Generic layer types are checked independently of the model-specific branches above.
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
# Adapted from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg
class CLIPSegEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPSegEncoderLayer`].

    Args:
        config: CLIPSegConfig
    """

    def __init__(self, config: CLIPSegConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input sequence fed to the first layer.
            attention_mask (`torch.Tensor` of shape `(batch_size, 1, tgt_len, src_len)`, *optional*):
                Additive padding mask, already expanded to 4D by the caller (the text tower builds it with
                `_prepare_4d_attention_mask`). Masked positions hold large negative values that are added to the
                attention scores.
            causal_attention_mask (`torch.Tensor`, *optional*):
                Additive 4D causal mask (the text tower builds it with `_create_4d_causal_attention_mask`). It is
                `None` for the vision tower.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Kept for API compatibility. This body always builds a `BaseModelOutput`; tuple conversion is
                presumably handled by the `@can_return_tuple` decorator.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            # Record each layer's *input*; the final output is appended after the loop.
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                causal_attention_mask,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
class CLIPSegTextTransformer(nn.Module):
    """CLIPSeg text tower.

    The tower runs the following steps:
    - embeds `input_ids`;
    - runs them through a `CLIPSegEncoder` with a causal mask, plus an optional padding mask;
    - applies a final LayerNorm;
    - pools one hidden state per sequence at an end-of-sequence position.
    """

    def __init__(self, config: CLIPSegTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPSegTextEmbeddings(config)
        self.encoder = CLIPSegEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        # Resolve any flag left as None from the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        # Collapse leading dimensions to (N, seq_len). The causal mask below is built from the original
        # `input_shape`, not from the reshaped ids.
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # CLIPSeg's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_shape, hidden_states.dtype, device=hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
            # A CLIPSeg model with such `eos_token_id` in the config can't work correctly with extra new tokens added
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
                # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
                .int()
                .argmax(dim=-1),
            ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class CLIPSegTextModel(CLIPSegPreTrainedModel):
    """Standalone CLIPSeg text encoder: a thin pretrained-model wrapper around `CLIPSegTextTransformer`."""

    config: CLIPSegTextConfig

    # Module classes the base class should not split (presumably across devices when a device map is used).
    _no_split_modules = ["CLIPSegTextEmbeddings", "CLIPSegEncoderLayer"]

    def __init__(self, config: CLIPSegTextConfig):
        super().__init__(config)
        self.text_model = CLIPSegTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token-embedding table of the wrapped text transformer."""
        embeddings = self.text_model.embeddings
        return embeddings.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token-embedding table of the wrapped text transformer with `value`."""
        embeddings = self.text_model.embeddings
        embeddings.token_embedding = value

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPSegTextModel

        >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegTextModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        text_transformer = self.text_model
        return text_transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class CLIPSegVisionTransformer(nn.Module):
    """CLIPSeg vision tower.

    The tower runs patch embeddings, then a pre-LayerNorm, then the transformer encoder. The pooled output is the
    post-LayerNormed class-token state.

    NOTE: the misspelled attribute `pre_layrnorm` is part of this module's state-dict keys, so it is deliberately
    left unchanged.
    """

    # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIP->CLIPSeg
    def __init__(self, config: CLIPSegVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPSegVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPSegEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = True,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        # Resolve any flag left as None from the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Fail early with a clear message, as the text tower does for `input_ids`, instead of crashing inside the
        # embeddings on `None.shape`.
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        hidden_states = self.pre_layrnorm(hidden_states)

        # No padding or causal mask: every patch attends to every other patch.
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Pool the class token (position 0). Only the pooled vector is post-LayerNormed; `last_hidden_state`
        # is returned without it.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
class CLIPSegVisionModel(CLIPSegPreTrainedModel):
    """Standalone CLIPSeg vision tower: a `CLIPSegVisionTransformer` exposed as a pretrained model."""

    config: CLIPSegVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPSegVisionConfig):
        super().__init__(config)
        self.vision_model = CLIPSegVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # For the vision tower, the patch-embedding module is what is reported as the "input embeddings".
        vision_embeddings = self.vision_model.embeddings
        return vision_embeddings.patch_embedding

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = True,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPSegVisionModel

        >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegVisionModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        # Pure delegation to the wrapped vision transformer; all arguments are passed by keyword.
        vision_kwargs = dict(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
        return self.vision_model(**vision_kwargs)
@auto_docstring
class CLIPSegModel(CLIPSegPreTrainedModel):
    config: CLIPSegConfig

    def __init__(self, config: CLIPSegConfig):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPSegTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type CLIPSegTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPSegVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type CLIPSegVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config
        # The module using it is not a PreTrainedModel subclass so we need this
        text_config._attn_implementation = config._attn_implementation
        # The module using it is not a PreTrainedModel subclass so we need this
        vision_config._attn_implementation = config._attn_implementation

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPSegTextTransformer(text_config)
        self.vision_model = CLIPSegVisionTransformer(vision_config)

        # Bias-free linear maps from each tower's hidden size into the shared `projection_dim` space.
        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        # Learnable logit scale, stored in log space and exponentiated in `forward`.
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    @filter_out_non_signature_kwargs()
    @auto_docstring
    def get_text_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPSegTextModel`].

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, CLIPSegModel

        >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> with torch.inference_mode():
        ...     text_features = model.get_text_features(**inputs)
        ```"""
        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        pooled_output = text_outputs.pooler_output
        text_features = self.text_projection(pooled_output)

        return text_features

    @filter_out_non_signature_kwargs()
    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPSegVisionModel`].

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, CLIPSegModel
        >>> from transformers.image_utils import load_image

        >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.inference_mode():
        ...     image_features = model.get_image_features(**inputs)
        ```"""
        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        pooled_output = vision_outputs.pooler_output
        image_features = self.visual_projection(pooled_output)

        return image_features

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = True,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, CLIPSegOutput]:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, CLIPSegModel
        >>> from transformers.image_utils import load_image

        >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> with torch.inference_mode():
        ...     outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use CLIPSEG model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output in both the tuple and the ModelOutput form.
        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = clipseg_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPSegOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
class CLIPSegDecoderLayer(nn.Module):
    """
    CLIPSeg decoder layer, which is identical to `CLIPSegEncoderLayer`, except that normalization is applied after
    self-attention/MLP, rather than before.
    """

    # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer.__init__ with AltCLIP->CLIPSeg
    def __init__(self, config: CLIPSegConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPSegAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPSegMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        causal_attention_mask: Optional[torch.Tensor],
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.FloatTensor]:
        """
        Apply self-attention and then the MLP, each followed by a residual add and a layer norm (post-norm).

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                Forwarded unchanged to the self-attention module; may be `None`.
            causal_attention_mask (`torch.FloatTensor`, *optional*): causal mask forwarded unchanged to the
                self-attention module; may be `None`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `tuple(torch.FloatTensor)`: `(hidden_states,)`, extended with the attention weights when
            `output_attentions` is truthy.
        """
        residual = hidden_states

        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )

        # Post-norm: normalize after the residual connection, not before the sub-layer.
        hidden_states = residual + hidden_states
        hidden_states = self.layer_norm1(hidden_states)

        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.layer_norm2(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
class CLIPSegDecoder(CLIPSegPreTrainedModel):
    """
    Lightweight transformer decoder that fuses reduced vision activations with a FiLM-conditioned prompt embedding
    and upsamples the result to per-pixel segmentation logits.
    """

    def __init__(self, config: CLIPSegConfig):
        super().__init__(config)

        patch_size = config.vision_config.patch_size
        reduce_dim = config.reduce_dim

        self.conditional_layer = config.conditional_layer

        # FiLM parameters: a multiplicative and an additive projection of the conditional embedding.
        self.film_mul = nn.Linear(config.projection_dim, reduce_dim)
        self.film_add = nn.Linear(config.projection_dim, reduce_dim)

        if config.use_complex_transposed_convolution:
            # Two transposed convolutions, each upsampling by patch_size // 4.
            kernel = patch_size // 4
            self.transposed_convolution = nn.Sequential(
                nn.Conv2d(reduce_dim, reduce_dim, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(reduce_dim, reduce_dim // 2, kernel_size=kernel, stride=kernel),
                nn.ReLU(),
                nn.ConvTranspose2d(reduce_dim // 2, 1, kernel_size=kernel, stride=kernel),
            )
        else:
            self.transposed_convolution = nn.ConvTranspose2d(reduce_dim, 1, patch_size, stride=patch_size)

        depth = len(config.extract_layers)
        self.reduces = nn.ModuleList(
            [nn.Linear(config.vision_config.hidden_size, reduce_dim) for _ in range(depth)]
        )

        # Decoder layers reuse the vision config, resized to the reduced width and using ReLU.
        decoder_config = copy.deepcopy(config.vision_config)
        decoder_config.hidden_size = reduce_dim
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        decoder_config.hidden_act = "relu"
        self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(depth)])

    def forward(
        self,
        hidden_states: tuple[torch.Tensor],
        conditional_embeddings: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ):
        """
        Decode extracted vision activations into segmentation logits conditioned on prompt embeddings.

        Args:
            hidden_states (sequence of `torch.Tensor`):
                Vision-encoder activations, consumed in reverse order (last entry first). Each is projected to
                `reduce_dim` by the matching module in `self.reduces`. Presumably each is
                `(batch, 1 + num_patches, vision_hidden_size)` with a leading CLS token — confirm against caller.
            conditional_embeddings (`torch.Tensor` of shape `(batch_size, projection_dim)`):
                Prompt embeddings used for FiLM conditioning; their first dimension defines the batch size.
            output_attentions (`bool`, *optional*): Collect per-layer attention weights.
            output_hidden_states (`bool`, *optional*): Collect per-layer decoder outputs.
            return_dict (`bool`, *optional*, defaults to `True`): Return a `CLIPSegDecoderOutput` instead of a tuple.

        Returns:
            `CLIPSegDecoderOutput`, or a tuple `(logits, hidden_states, attentions)` with `None` entries dropped.
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        stages = zip(hidden_states[::-1], self.layers, self.reduces)

        running = None
        for stage_index, (activation, decoder_layer, reducer) in enumerate(stages):
            reduced = reducer(activation)
            # Skip-style fusion: each stage adds its reduced activation to the previous decoder output.
            running = reduced if running is None else reduced + running

            if stage_index == self.conditional_layer:
                # FiLM: scale and shift features with projections of the conditional embedding
                # (broadcast over the sequence axis by temporarily moving it to the front).
                scaled = self.film_mul(conditional_embeddings) * running.permute(1, 0, 2)
                running = (scaled + self.film_add(conditional_embeddings)).permute(1, 0, 2)

            layer_outputs = decoder_layer(
                running, attention_mask=None, causal_attention_mask=None, output_attentions=output_attentions
            )
            running = layer_outputs[0]

            if output_hidden_states:
                collected_states = collected_states + (running,)
            if output_attentions:
                collected_attentions = collected_attentions + (layer_outputs[1],)

        # Drop the CLS token and move channels first: (batch_size, reduce_dim, seq_len).
        tokens = running[:, 1:, :].permute(0, 2, 1)
        side = int(math.sqrt(tokens.shape[2]))
        batch_size = conditional_embeddings.shape[0]
        feature_map = tokens.view(batch_size, tokens.shape[1], side, side)

        logits = self.transposed_convolution(feature_map).squeeze(1)

        if return_dict:
            return CLIPSegDecoderOutput(
                logits=logits,
                hidden_states=collected_states,
                attentions=collected_attentions,
            )
        return tuple(item for item in (logits, collected_states, collected_attentions) if item is not None)
@auto_docstring(
    custom_intro="""
    CLIPSeg model with a Transformer-based decoder on top for zero-shot and one-shot image segmentation.
    """
)
class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
    config: CLIPSegConfig

    def __init__(self, config: CLIPSegConfig):
        super().__init__(config)

        self.config = config

        self.clip = CLIPSegModel(config)
        self.extract_layers = config.extract_layers

        self.decoder = CLIPSegDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_conditional_embeddings(
        self,
        batch_size: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        conditional_pixel_values: Optional[torch.Tensor] = None,
    ):
        """
        Compute the prompt (conditional) embeddings from either text or images, without tracking gradients.

        Text prompts (`input_ids`) take precedence over image prompts (`conditional_pixel_values`).

        Args:
            batch_size (`int`, *optional*):
                Number of query images. When given, the number of prompts must match it; when `None`, the check is
                skipped.
            input_ids (`torch.Tensor`, *optional*): Tokenized text prompts.
            attention_mask (`torch.Tensor`, *optional*): Attention mask for `input_ids`.
            position_ids (`torch.Tensor`, *optional*): Position ids for `input_ids`.
            conditional_pixel_values (`torch.Tensor`, *optional*): Pixel values of the prompt images.

        Returns:
            `torch.FloatTensor`: the projected text or image features from the CLIP backbone.

        Raises:
            ValueError: if the number of prompts differs from `batch_size`, or if neither `input_ids` nor
                `conditional_pixel_values` is provided.
        """
        if input_ids is not None:
            # compute conditional embeddings from texts
            if batch_size is not None and len(input_ids) != batch_size:
                raise ValueError("Make sure to pass as many prompt texts as there are query images")
            with torch.no_grad():
                conditional_embeddings = self.clip.get_text_features(
                    input_ids, attention_mask=attention_mask, position_ids=position_ids
                )
        elif conditional_pixel_values is not None:
            # compute conditional embeddings from images
            if batch_size is not None and len(conditional_pixel_values) != batch_size:
                raise ValueError("Make sure to pass as many prompt images as there are query images")
            with torch.no_grad():
                conditional_embeddings = self.clip.get_image_features(conditional_pixel_values)
        else:
            raise ValueError(
                "Invalid conditional, should be either provided as `input_ids` or `conditional_pixel_values`"
            )

        return conditional_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        conditional_pixel_values: Optional[torch.FloatTensor] = None,
        conditional_embeddings: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = True,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, CLIPSegImageSegmentationOutput]:
        r"""
        conditional_pixel_values (`torch.FloatTensor`, *optional*):
            The pixel values of the conditional images.
        conditional_embeddings (`torch.FloatTensor` of shape `(batch_size, config.projection_dim)`, *optional*):
            The conditional embeddings for the query images. If provided, the model will use this instead of computing
            the embeddings from the conditional_pixel_values.
        labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
            Target segmentation masks with the same shape as the predicted `logits`, with values in `[0, 1]`. When
            provided, a binary cross-entropy loss (`nn.BCEWithLogitsLoss`) between `logits` and `labels` is computed.
            Integer masks (e.g. `0`/`1`) are accepted and cast to the dtype of `logits`.

        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, CLIPSegForImageSegmentation
        >>> from transformers.image_utils import load_image

        >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)
        >>> texts = ["a cat", "a remote", "a blanket"]
        >>> inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt")

        >>> with torch.inference_mode():
        ...     outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> print(logits.shape)
        torch.Size([3, 352, 352])
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the query images through the frozen CLIP vision encoder
        with torch.no_grad():
            vision_outputs = self.clip.vision_model(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=True,  # we need the intermediate hidden states
                interpolate_pos_encoding=interpolate_pos_encoding,
                return_dict=return_dict,
            )
            pooled_output = self.clip.visual_projection(vision_outputs[1])

            hidden_states = vision_outputs.hidden_states if return_dict else vision_outputs[2]
            # we add +1 here as the hidden states also include the initial embeddings
            activations = [hidden_states[i + 1] for i in self.extract_layers]

            # update vision_outputs: only expose hidden states if the caller asked for them
            if return_dict:
                vision_outputs = BaseModelOutputWithPooling(
                    last_hidden_state=vision_outputs.last_hidden_state,
                    pooler_output=vision_outputs.pooler_output,
                    hidden_states=vision_outputs.hidden_states if output_hidden_states else None,
                    attentions=vision_outputs.attentions,
                )
            else:
                vision_outputs = (
                    vision_outputs[:2] + vision_outputs[3:] if not output_hidden_states else vision_outputs
                )

        # step 2: compute conditional embeddings, either from text, images or an own provided embedding
        if conditional_embeddings is None:
            conditional_embeddings = self.get_conditional_embeddings(
                batch_size=pixel_values.shape[0],
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                conditional_pixel_values=conditional_pixel_values,
            )
        else:
            if conditional_embeddings.shape[0] != pixel_values.shape[0]:
                raise ValueError(
                    "Make sure to pass as many conditional embeddings as there are query images in the batch"
                )
            if conditional_embeddings.shape[1] != self.config.projection_dim:
                raise ValueError(
                    "Make sure that the feature dimension of the conditional embeddings matches"
                    " `config.projection_dim`."
                )

        # step 3: forward both the pooled output and the activations through the lightweight decoder to predict masks
        decoder_outputs = self.decoder(
            activations,
            conditional_embeddings,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = decoder_outputs.logits if return_dict else decoder_outputs[0]

        loss = None
        if labels is not None:
            # move labels to the correct device to enable PP
            labels = labels.to(logits.device)
            # BCEWithLogitsLoss needs floating-point targets; accept integer masks by casting them.
            if not labels.is_floating_point():
                labels = labels.to(logits.dtype)
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(logits, labels)

        if not return_dict:
            output = (logits, conditional_embeddings, pooled_output, vision_outputs, decoder_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPSegImageSegmentationOutput(
            loss=loss,
            logits=logits,
            conditional_embeddings=conditional_embeddings,
            pooled_output=pooled_output,
            vision_model_output=vision_outputs,
            decoder_output=decoder_outputs,
        )
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "CLIPSegModel",
    "CLIPSegPreTrainedModel",
    "CLIPSegTextModel",
    "CLIPSegVisionModel",
    "CLIPSegForImageSegmentation",
]
|