modeling_sam.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368
  1. # coding=utf-8
  2. # Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch SAM model."""
  16. import collections
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. from torch import Tensor, nn
  23. from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
  24. from ...activations import ACT2FN
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import BaseModelOutput
  27. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  28. from ...processing_utils import Unpack
  29. from ...utils import (
  30. ModelOutput,
  31. auto_docstring,
  32. logging,
  33. )
  34. from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
  35. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
    layer to the pooler_output.
    """
)
class SamVisionEncoderOutput(ModelOutput):
    r"""
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    # Only `image_embeds` has a custom description above; the remaining standard fields are presumably
    # documented by `auto_docstring` from their names — TODO confirm if their docs ever look wrong.
    # All fields default to None, so each one is optional at construction time.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Segment-Anything model's output
    """
)
class SamImageSegmentationOutput(ModelOutput):
    r"""
    iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
        The iou scores of the predicted masks.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
        The predicted low-resolution masks. Needs to be post-processed by the processor.
    vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

        Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
    vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attention weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attention weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    """

    # NOTE(review): SamMaskDecoder.forward returns masks of shape
    # (batch_size, point_batch_size, num_masks, height, width) and iou predictions of shape
    # (batch_size, point_batch_size, num_masks). Confirm whether the shapes documented above should include
    # the point_batch_size axis.
    iou_scores: Optional[torch.FloatTensor] = None
    pred_masks: Optional[torch.FloatTensor] = None
    vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  84. class SamPatchEmbeddings(nn.Module):
  85. """
  86. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  87. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  88. Transformer.
  89. """
  90. def __init__(self, config):
  91. super().__init__()
  92. image_size, patch_size = config.image_size, config.patch_size
  93. num_channels, hidden_size = config.num_channels, config.hidden_size
  94. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  95. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  96. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  97. self.image_size = image_size
  98. self.patch_size = patch_size
  99. self.num_channels = num_channels
  100. self.num_patches = num_patches
  101. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  102. def forward(self, pixel_values):
  103. batch_size, num_channels, height, width = pixel_values.shape
  104. if num_channels != self.num_channels:
  105. raise ValueError(
  106. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  107. )
  108. if height != self.image_size[0] or width != self.image_size[1]:
  109. raise ValueError(
  110. f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  111. )
  112. embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
  113. return embeddings
  114. class SamMLPBlock(nn.Module):
  115. def __init__(self, config):
  116. super().__init__()
  117. self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
  118. self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
  119. self.act = ACT2FN[config.hidden_act]
  120. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  121. hidden_states = self.lin1(hidden_states)
  122. hidden_states = self.act(hidden_states)
  123. hidden_states = self.lin2(hidden_states)
  124. return hidden_states
# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
class SamLayerNorm(nn.LayerNorm):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    # NOTE(review): the "Copied from" marker above presumably ties this class to ConvNextLayerNorm through the
    # repository's copy-consistency checks — change the source class rather than diverging here.
    def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(normalized_shape, eps=eps, **kwargs)
        # Any other layout string is rejected up front rather than silently mis-normalizing.
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {data_format}")
        self.data_format = data_format

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
        """
        if self.data_format == "channels_first":
            # The parent LayerNorm normalizes the trailing dimension, so move channels last, normalize,
            # then restore the (batch, channels, height, width) layout.
            features = features.permute(0, 2, 3, 1)
            features = super().forward(features)
            features = features.permute(0, 3, 1, 2)
        else:
            features = super().forward(features)
        return features
  148. def eager_attention_forward(
  149. module: nn.Module,
  150. query: torch.Tensor,
  151. key: torch.Tensor,
  152. value: torch.Tensor,
  153. attention_mask: Optional[torch.Tensor],
  154. scaling: float,
  155. dropout: float = 0.0,
  156. **kwargs,
  157. ):
  158. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  159. if attention_mask is not None:
  160. attn_weights = attn_weights + attention_mask
  161. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  162. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  163. attn_output = torch.matmul(attn_weights, value)
  164. attn_output = attn_output.transpose(1, 2).contiguous()
  165. return attn_output, attn_weights
  166. class SamAttention(nn.Module):
  167. """
  168. SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  169. values.
  170. """
  171. def __init__(self, config, downsample_rate=None):
  172. super().__init__()
  173. self.config = config
  174. self.hidden_size = config.hidden_size
  175. downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
  176. self.internal_dim = config.hidden_size // downsample_rate
  177. self.num_attention_heads = config.num_attention_heads
  178. if self.internal_dim % config.num_attention_heads != 0:
  179. raise ValueError("num_attention_heads must divide hidden_size.")
  180. self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
  181. self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
  182. self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
  183. self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
  184. self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
  185. self.is_causal = False
  186. def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
  187. batch, point_batch_size, n_tokens, channel = hidden_states.shape
  188. c_per_head = channel // num_attention_heads
  189. hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
  190. return hidden_states.transpose(1, 2)
  191. def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
  192. batch, n_tokens, n_heads, c_per_head = hidden_states.shape
  193. return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
  194. def forward(
  195. self,
  196. query: Tensor,
  197. key: Tensor,
  198. value: Tensor,
  199. attention_similarity: Optional[Tensor] = None,
  200. **kwargs: Unpack[TransformersKwargs],
  201. ) -> Tensor:
  202. # Input projections
  203. query = self.q_proj(query)
  204. key = self.k_proj(key)
  205. value = self.v_proj(value)
  206. point_batch_size = query.shape[1]
  207. # Separate into heads
  208. query = self._separate_heads(query, self.num_attention_heads)
  209. key = self._separate_heads(key, self.num_attention_heads)
  210. value = self._separate_heads(value, self.num_attention_heads)
  211. # SamAttention
  212. attention_interface: Callable = eager_attention_forward
  213. if self.config._attn_implementation != "eager":
  214. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  215. attn_output, attn_weights = attention_interface(
  216. self,
  217. query,
  218. key,
  219. value,
  220. attention_mask=attention_similarity,
  221. dropout=0.0,
  222. scaling=self.scaling,
  223. is_causal=self.is_causal,
  224. **kwargs,
  225. )
  226. attn_output = self._recombine_heads(attn_output, point_batch_size)
  227. attn_output = self.out_proj(attn_output)
  228. return attn_output, attn_weights
class SamTwoWayAttentionBlock(nn.Module):
    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
        """
        A transformer block with four layers:
        (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
        sparse inputs (4) cross attention of dense inputs -> sparse inputs

        Arguments:
            config (`SamMaskDecoderConfig`):
                The configuration file used to instantiate the block
            attention_downsample_rate (*optional*, int, defaults to 2):
                The downsample ratio of the block used to reduce the inner dim of the attention.
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                Whether or not to skip the addition of the query_point_embedding on the first layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_norm_eps = config.layer_norm_eps
        # Self-attention over the sparse tokens always runs at full width (downsample_rate=1).
        self.self_attn = SamAttention(config, downsample_rate=1)
        self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
        self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.mlp = SamMLPBlock(config)
        self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self,
        queries: Tensor,
        keys: Tensor,
        query_point_embedding: Tensor,
        key_point_embedding: Tensor,
        attention_similarity: Optional[Tensor],
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, Tensor, Tensor]:
        """
        Run one two-way block.

        Args:
            queries: sparse (prompt/output) tokens.
            keys: dense (image) tokens.
            query_point_embedding: positional embedding added to `queries` before each attention.
            key_point_embedding: positional embedding added to `keys` before each cross attention.
            attention_similarity: optional additive attention bias, used only by the token-to-image cross attention.
            **kwargs: accepted for interface compatibility; not forwarded to the attention layers in this block.

        Returns:
            `(queries, keys, attn_out)`, where `attn_out` is the image-to-token attention output before the
            residual add.
        """
        # Self attention block
        if self.skip_first_layer_pe:
            # First layer: plain self-attention whose output *replaces* the queries
            # (no positional embedding and no residual connection).
            queries, _ = self.self_attn(query=queries, key=queries, value=queries)
        else:
            query = queries + query_point_embedding
            attn_out, _ = self.self_attn(query=query, key=query, value=queries)
            queries = queries + attn_out
        queries = self.layer_norm1(queries)
        # Cross attention block, tokens attending to image embedding
        # Positional embeddings are added to query/key only; values stay un-embedded.
        query = queries + query_point_embedding
        key = keys + key_point_embedding
        attn_out, _ = self.cross_attn_token_to_image(
            query=query, key=key, value=keys, attention_similarity=attention_similarity
        )
        queries = queries + attn_out
        queries = self.layer_norm2(queries)
        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)
        # Cross attention block, image embedding attending to tokens
        # Roles are swapped: image tokens are the queries, the updated sparse tokens are keys/values.
        query = queries + query_point_embedding
        key = keys + key_point_embedding
        attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
        keys = keys + attn_out
        keys = self.layer_norm4(keys)
        return queries, keys, attn_out
class SamTwoWayTransformer(nn.Module):
    """
    Stack of `SamTwoWayAttentionBlock`s followed by a final token-to-image attention and layer norm.
    """

    def __init__(self, config: SamMaskDecoderConfig):
        super().__init__()
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.layers = nn.ModuleList()
        for i in range(self.num_hidden_layers):
            # Only the first block skips the positional embedding (and residual) in its self-attention.
            # NOTE(review): these blocks use SamTwoWayAttentionBlock's default attention_downsample_rate=2,
            # while the final attention below falls back to config.attention_downsample_rate — confirm intended.
            self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
        self.final_attn_token_to_image = SamAttention(config)
        # Built without an explicit eps, unlike the per-block norms, which use config.layer_norm_eps.
        self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        point_embeddings: Tensor,
        image_embeddings: Tensor,
        image_positional_embeddings: Tensor,
        attention_similarity: Optional[Tensor],
        target_embedding: Optional[Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[Tensor, Tensor]:
        """
        Args:
            point_embeddings: sparse tokens; used both as the initial queries and as their positional embedding.
            image_embeddings: 4-D `(N, channels, height, width)` image features (from SamMaskDecoder,
                N = batch_size * point_batch_size).
            image_positional_embeddings: positional embedding laid out like `image_embeddings`.
            attention_similarity: optional additive bias for each block's token-to-image cross attention.
            target_embedding: optional embedding added to the queries before every block.
            **kwargs: forwarded to every block.

        Returns:
            `(queries, keys)`: updated sparse tokens and the image tokens as `(N, 1, height * width, channels)`.

        Raises:
            ValueError: if `image_embeddings` is None.
        """
        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")
        # (N, C, H, W) -> (N, 1, H*W, C): one token per spatial location, with a singleton point-batch axis.
        image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
        image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
        # Prepare queries
        queries = point_embeddings
        keys = image_embeddings
        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            if target_embedding is not None:
                # NOTE(review): in-place add. On the first iteration `queries` *is* `point_embeddings`, so this
                # also mutates the positional embedding used for every block and the final attention. Later
                # iterations only touch the fresh `queries`. Kept as-is — changing it would alter numerical outputs.
                queries += target_embedding
            queries, keys, _ = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                attention_similarity=attention_similarity,
                **kwargs,
            )
        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings
        attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)
        return queries, keys
  336. class SamFeedForward(nn.Module):
  337. def __init__(
  338. self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
  339. ):
  340. super().__init__()
  341. self.num_layers = num_layers
  342. self.activation = nn.ReLU()
  343. self.proj_in = nn.Linear(input_dim, hidden_dim)
  344. self.proj_out = nn.Linear(hidden_dim, output_dim)
  345. self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
  346. self.sigmoid_output = sigmoid_output
  347. def forward(self, hidden_states):
  348. hidden_states = self.proj_in(hidden_states)
  349. hidden_states = self.activation(hidden_states)
  350. for layer in self.layers:
  351. hidden_states = self.activation(layer(hidden_states))
  352. hidden_states = self.proj_out(hidden_states)
  353. if self.sigmoid_output:
  354. hidden_states = F.sigmoid(hidden_states)
  355. return hidden_states
class SamMaskDecoder(nn.Module):
    """
    Predicts masks and IoU scores from image embeddings and prompt embeddings using a two-way transformer,
    a transposed-convolution upscaler and per-mask-token hypernetwork MLPs.
    """

    def __init__(self, config: SamMaskDecoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_multimask_outputs = config.num_multimask_outputs
        # One extra mask token: index 0 is the single-mask output (kept when multimask_output=False),
        # indices 1.. are the multimask outputs.
        self.num_mask_tokens = config.num_multimask_outputs + 1
        self.iou_token = nn.Embedding(1, self.hidden_size)
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
        self.transformer = SamTwoWayTransformer(config)
        # Upscaling path: two kernel-2/stride-2 transposed convs (4x spatial in total),
        # hidden_size -> hidden_size // 4 -> hidden_size // 8 channels.
        self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
        self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
        self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
        self.activation = nn.GELU()
        # One 3-layer hypernetwork per mask token, mapping the token to hidden_size // 8 mask weights.
        mlps_list = []
        for _ in range(self.num_mask_tokens):
            mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
        self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
        self.iou_prediction_head = SamFeedForward(
            self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_positional_embeddings: torch.Tensor,
        sparse_prompt_embeddings: Optional[torch.Tensor],
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        attention_similarity: Optional[torch.Tensor] = None,
        target_embedding: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Args:
            image_embeddings (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
                the embeddings from the image encoder
            image_positional_embeddings (`torch.Tensor`):
                positional encoding with the shape of image_embeddings (repeated per point batch below)
            sparse_prompt_embeddings (`torch.Tensor` of shape `(batch_size, point_batch_size, num_prompt_tokens, hidden_size)`, *optional*):
                The embeddings of the points and boxes. When `None`, a point batch size of 1 is used.
            dense_prompt_embeddings (`torch.Tensor`):
                the embeddings of the mask inputs; added to `image_embeddings`, so it must broadcast against it
            multimask_output (bool):
                Whether to return multiple masks or a single mask.
            attention_similarity (`torch.Tensor`, *optional*):
                Additive attention bias forwarded to the two-way transformer's token-to-image cross attention.
            target_embedding (`torch.Tensor`, *optional*):
                Embedding added to the point tokens before every two-way transformer layer.

        Returns:
            `tuple(torch.Tensor, torch.Tensor)`: `masks` of shape
            `(batch_size, point_batch_size, num_masks, 4 * height, 4 * width)` and `iou_pred` of shape
            `(batch_size, point_batch_size, num_masks)`, where `num_masks` is `num_multimask_outputs` when
            `multimask_output` is True and 1 otherwise.
        """
        batch_size, num_channels, height, width = image_embeddings.shape
        point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
        # Concatenate output tokens
        # [iou_token, mask_tokens...] tiled to (batch_size, point_batch_size, 1 + num_mask_tokens, hidden_size).
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
        if sparse_prompt_embeddings is not None:
            tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
        else:
            tokens = output_tokens
        point_embeddings = tokens.to(self.iou_token.weight.dtype)
        # Expand per-image data in batch direction to be per-point
        # repeat_interleave keeps each image's copies contiguous, so axis 0 becomes (batch_size, point_batch_size).
        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
        image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
        # Run the transformer, image_positional_embedding are consumed
        point_embedding, image_embeddings = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            attention_similarity=attention_similarity,
            target_embedding=target_embedding,
        )
        # Token 0 is the IoU token; the next num_mask_tokens tokens are the mask tokens.
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
        # Upscale mask embeddings and predict masks using the mask tokens
        # Transformer keys are (N, 1, H*W, C); transpose to (N, 1, C, H*W) and restore the (N, C, H, W) map.
        image_embeddings = image_embeddings.transpose(2, 3).reshape(
            batch_size * point_batch_size, num_channels, height, width
        )
        upscaled_embedding = self.upscale_conv1(image_embeddings)
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
        # Each mask token's hypernetwork produces the per-channel weights for its own mask.
        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = torch.stack(hyper_in_list, dim=2)
        _, num_channels, height, width = upscaled_embedding.shape
        upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
        # (B, P, num_mask_tokens, C') @ (B, P, C', H'*W') -> one mask per token, reshaped to (B, P, tokens, H', W').
        masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]
        return masks, iou_pred
  451. class SamPositionalEmbedding(nn.Module):
  452. def __init__(self, config):
  453. super().__init__()
  454. self.scale = config.hidden_size // 2
  455. self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
  456. def forward(self, input_coords, input_shape=None):
  457. """Positionally encode points that are normalized to [0,1]."""
  458. coordinates = input_coords.clone()
  459. if input_shape is not None:
  460. coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
  461. coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
  462. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  463. coordinates = 2 * coordinates - 1
  464. coordinates = coordinates.to(self.positional_embedding.dtype)
  465. coordinates = coordinates @ self.positional_embedding
  466. coordinates = 2 * np.pi * coordinates
  467. # outputs d_1 x ... x d_n x channel shape
  468. return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
  469. class SamMaskEmbedding(nn.Module):
  470. def __init__(self, config: SamPromptEncoderConfig):
  471. super().__init__()
  472. self.mask_input_channels = config.mask_input_channels // 4
  473. self.activation = ACT2FN[config.hidden_act]
  474. self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
  475. self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
  476. self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
  477. self.layer_norm1 = SamLayerNorm(
  478. self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
  479. )
  480. self.layer_norm2 = SamLayerNorm(
  481. self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
  482. )
  483. def forward(self, masks):
  484. hidden_states = self.conv1(masks)
  485. hidden_states = self.layer_norm1(hidden_states)
  486. hidden_states = self.activation(hidden_states)
  487. hidden_states = self.conv2(hidden_states)
  488. hidden_states = self.layer_norm2(hidden_states)
  489. hidden_states = self.activation(hidden_states)
  490. dense_embeddings = self.conv3(hidden_states)
  491. return dense_embeddings
  492. class SamPromptEncoder(nn.Module):
  493. def __init__(self, config: SamConfig):
  494. super().__init__()
  495. self.shared_embedding = SamPositionalEmbedding(config.vision_config)
  496. config = config.prompt_encoder_config
  497. self.mask_embed = SamMaskEmbedding(config)
  498. self.no_mask_embed = nn.Embedding(1, config.hidden_size)
  499. self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
  500. self.input_image_size = config.image_size
  501. self.point_embed = nn.ModuleList(
  502. [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
  503. )
  504. self.hidden_size = config.hidden_size
  505. self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
  506. def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
  507. """Embeds point prompts."""
  508. points = points + 0.5 # Shift to center of pixel
  509. if pad:
  510. target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
  511. target_labels_shape = (points.shape[0], points.shape[1], 1)
  512. padding_point = torch.zeros(target_point_shape, device=points.device)
  513. padding_label = -torch.ones(target_labels_shape, device=labels.device)
  514. points = torch.cat([points, padding_point], dim=2)
  515. labels = torch.cat([labels, padding_label], dim=2)
  516. input_shape = (self.input_image_size, self.input_image_size)
  517. point_embedding = self.shared_embedding(points, input_shape)
  518. # torch.where and expanding the labels tensor is required by the ONNX export
  519. point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
  520. # This is required for the ONNX export. The dtype, device need to be explicitly
  521. # specified as otherwise torch.onnx.export interprets as double
  522. point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
  523. point_embedding = torch.where(
  524. (labels == 0)[:, :, :, None],
  525. point_embedding + self.point_embed[0].weight[None, None, :, :],
  526. point_embedding,
  527. )
  528. point_embedding = torch.where(
  529. (labels == 1)[:, :, :, None],
  530. point_embedding + self.point_embed[1].weight[None, None, :, :],
  531. point_embedding,
  532. )
  533. return point_embedding
  534. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  535. """Embeds box prompts."""
  536. boxes = boxes + 0.5 # Shift to center of pixel
  537. batch_size, nb_boxes = boxes.shape[:2]
  538. coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
  539. input_shape = (self.input_image_size, self.input_image_size)
  540. corner_embedding = self.shared_embedding(coords, input_shape)
  541. corner_embedding[:, :, 0, :] += self.point_embed[2].weight
  542. corner_embedding[:, :, 1, :] += self.point_embed[3].weight
  543. return corner_embedding
    def forward(
        self,
        input_points: Optional[torch.Tensor],
        input_labels: Optional[torch.Tensor],
        input_boxes: Optional[torch.Tensor],
        input_masks: Optional[torch.Tensor],
    ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            input_points (`torch.Tensor`, *optional*):
                Point coordinates to embed. Requires `input_labels`.
            input_labels (`torch.Tensor`, *optional*):
                Labels of `input_points`. Must be given whenever `input_points` is given.
            input_boxes (`torch.Tensor`, *optional*):
                Boxes to embed.
            input_masks (`torch.Tensor`, *optional*):
                Masks to embed with `self.mask_embed`.

        Returns:
            `tuple(sparse_embeddings, dense_embeddings)`:
                - `sparse_embeddings`: the point embeddings and/or the box embeddings. When both are given they are
                  concatenated along dim 2. The value is `None` when neither points nor boxes are given.
                - `dense_embeddings`: `self.mask_embed(input_masks)` if masks are given. Otherwise the learned
                  `no_mask_embed` weight is expanded to `(batch_size, -1, *self.image_embedding_size)`.

        Raises:
            ValueError: if `input_points` is given without `input_labels`.
        """
        sparse_embeddings = None
        # Defaults to 1 when neither points nor boxes are given. When both are given, the boxes' batch size wins
        # because it is assigned last.
        batch_size = 1
        if input_points is not None:
            batch_size = input_points.shape[0]
            if input_labels is None:
                raise ValueError("If points are provided, labels must also be provided.")
            # An extra padding point is only appended when there are no boxes.
            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
            sparse_embeddings = point_embeddings
        if input_boxes is not None:
            batch_size = input_boxes.shape[0]
            box_embeddings = self._embed_boxes(input_boxes)
            if sparse_embeddings is None:
                sparse_embeddings = box_embeddings
            else:
                sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
        if input_masks is not None:
            dense_embeddings = self.mask_embed(input_masks)
        else:
            # No mask prompt: broadcast the single learned "no mask" vector over the image-embedding grid.
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings
  583. class SamVisionAttention(nn.Module):
  584. """Multi-head Attention block with relative position embeddings."""
  585. def __init__(self, config, window_size):
  586. super().__init__()
  587. input_size = (
  588. (config.image_size // config.patch_size, config.image_size // config.patch_size)
  589. if window_size == 0
  590. else (window_size, window_size)
  591. )
  592. self.num_attention_heads = config.num_attention_heads
  593. head_dim = config.hidden_size // config.num_attention_heads
  594. self.scale = head_dim**-0.5
  595. self.dropout = config.attention_dropout
  596. self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
  597. self.proj = nn.Linear(config.hidden_size, config.hidden_size)
  598. self.use_rel_pos = config.use_rel_pos
  599. if self.use_rel_pos:
  600. if input_size is None:
  601. raise ValueError("Input size must be provided if using relative positional encoding.")
  602. # initialize relative positional embeddings
  603. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
  604. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  605. def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
  606. """
  607. Get relative positional embeddings according to the relative positions of
  608. query and key sizes.
  609. Args:
  610. q_size (int):
  611. size of the query.
  612. k_size (int):
  613. size of key k.
  614. rel_pos (`torch.Tensor`):
  615. relative position embeddings (L, channel).
  616. Returns:
  617. Extracted positional embeddings according to relative positions.
  618. """
  619. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  620. # Interpolate rel pos.
  621. rel_pos_resized = F.interpolate(
  622. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  623. size=max_rel_dist,
  624. mode="linear",
  625. )
  626. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
  627. # Scale the coords with short length if shapes for q and k are different.
  628. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  629. k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  630. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  631. return rel_pos_resized[relative_coords.long()]
  632. def get_decomposed_rel_pos(
  633. self,
  634. query: torch.Tensor,
  635. rel_pos_h: torch.Tensor,
  636. rel_pos_w: torch.Tensor,
  637. q_size: tuple[int, int],
  638. k_size: tuple[int, int],
  639. ) -> torch.Tensor:
  640. """
  641. Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
  642. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
  643. Args:
  644. query (`torch.Tensor`):
  645. query q in the attention layer with shape (batch_size, query_height * query_width, channel).
  646. rel_pos_h (`torch.Tensor`):
  647. relative position embeddings (Lh, channel) for height axis.
  648. rel_pos_w (`torch.Tensor`):
  649. relative position embeddings (Lw, channel) for width axis.
  650. q_size (tuple):
  651. spatial sequence size of query q with (query_height, query_width).
  652. k_size (tuple):
  653. spatial sequence size of key k with (key_height, key_width).
  654. Returns:
  655. decomposed_rel_pos (`torch.Tensor`):
  656. decomposed relative position embeddings.
  657. """
  658. query_height, query_width = q_size
  659. key_height, key_width = k_size
  660. relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
  661. relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
  662. batch_size, _, dim = query.shape
  663. reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
  664. rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
  665. rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
  666. decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
  667. return decomposed_rel_pos
  668. def forward(self, hidden_states: torch.Tensor, output_attentions=None) -> tuple[torch.Tensor, torch.Tensor]:
  669. batch_size, height, width, _ = hidden_states.shape
  670. # qkv with shape (3, batch_size, nHead, height * width, channel)
  671. qkv = (
  672. self.qkv(hidden_states)
  673. .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
  674. .permute(2, 0, 3, 1, 4)
  675. )
  676. # q, k, v with shape (batch_size * nHead, height * width, channel)
  677. query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
  678. attn_weights = (query * self.scale) @ key.transpose(-2, -1)
  679. if self.use_rel_pos:
  680. decomposed_rel_pos = self.get_decomposed_rel_pos(
  681. query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
  682. )
  683. decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
  684. attn_weights = attn_weights + decomposed_rel_pos
  685. attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
  686. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  687. attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
  688. attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
  689. attn_output = self.proj(attn_output)
  690. return attn_output, attn_weights
  691. class SamVisionSdpaAttention(SamVisionAttention):
  692. """
  693. Multi-head Attention block with relative position embeddings.
  694. Using SDPA instead of the default attention.
  695. """
  696. def __init__(self, config, window_size):
  697. super().__init__(config, window_size)
  698. def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
  699. if output_attentions:
  700. logger.warning_once(
  701. "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
  702. "`output_attentions=True`. Falling back to the manual attention implementation, but "
  703. "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
  704. 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  705. )
  706. return super().forward(
  707. hidden_states=hidden_states,
  708. output_attentions=output_attentions,
  709. )
  710. batch_size, height, width, _ = hidden_states.shape
  711. # qkv with shape (3, B, nHead, H * W, C)
  712. qkv = (
  713. self.qkv(hidden_states)
  714. .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
  715. .permute(2, 0, 3, 1, 4)
  716. )
  717. # q, k, v with shape (B * nHead, H * W, C)
  718. query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
  719. attn_bias = None
  720. if self.use_rel_pos:
  721. decomposed_rel_pos = self.get_decomposed_rel_pos(
  722. query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
  723. )
  724. decomposed_rel_pos = decomposed_rel_pos.reshape(
  725. batch_size, self.num_attention_heads, height * width, height * width
  726. )
  727. attn_bias = decomposed_rel_pos
  728. query = query.view(batch_size, self.num_attention_heads, height * width, -1)
  729. key = key.view(batch_size, self.num_attention_heads, height * width, -1)
  730. value = value.view(batch_size, self.num_attention_heads, height * width, -1)
  731. attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
  732. attn_output = (
  733. attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
  734. .permute(0, 2, 3, 1, 4)
  735. .reshape(batch_size, height, width, -1)
  736. )
  737. attn_output = self.proj(attn_output)
  738. return attn_output, None
  739. SAM_VISION_ATTENTION_CLASSES = {
  740. "eager": SamVisionAttention,
  741. "sdpa": SamVisionSdpaAttention,
  742. }
class SamVisionLayer(GradientCheckpointingLayer):
    """
    One block of the SAM vision encoder. It applies pre-norm attention, optionally over non-overlapping windows, and
    then a pre-norm MLP. Each sub-layer is wrapped in a residual connection.
    """

    def __init__(self, config, window_size):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Attention class chosen by `config._attn_implementation` (keys of SAM_VISION_ATTENTION_CLASSES).
        self.attn = SAM_VISION_ATTENTION_CLASSES[config._attn_implementation](config, window_size)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SamMLPBlock(config)
        # 0 -> attention over the full grid; > 0 -> attention inside window_size x window_size windows.
        self.window_size = window_size

    def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Partition into non-overlapping windows, zero-padding the bottom and right edges when height/width are not
        multiples of `window_size`.

        Args:
            hidden_states (`torch.Tensor`):
                input tokens with [batch_size, height, width, channel].
            window_size (`int`):
                window size.

        Returns:
            windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
            (pad_height, pad_width): padded height and width before partition
        """
        batch_size, height, width, channel = hidden_states.shape

        # Amount of padding needed to reach the next multiple of window_size (0 if already aligned).
        pad_h = (window_size - height % window_size) % window_size
        pad_w = (window_size - width % window_size) % window_size
        # F.pad pads from the last dim backwards: channel (0, 0), width (0, pad_w), height (0, pad_h).
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
        pad_height, pad_width = height + pad_h, width + pad_w

        hidden_states = hidden_states.reshape(
            batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
        )
        windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
        return windows, (pad_height, pad_width)

    def window_unpartition(
        self, windows: torch.Tensor, window_size: int, padding_shape: tuple[int, int], original_shape: tuple[int, int]
    ) -> torch.Tensor:
        """
        Window unpartition into original sequences and removing padding. This is the inverse of `window_partition`.

        Args:
            windows (`torch.Tensor`):
                input tokens with [batch_size * num_windows, window_size, window_size, channel].
            window_size (`int`):
                window size.
            padding_shape (`tuple[int, int]`):
                padded height and width (pad_height, pad_width).
            original_shape (`tuple[int, int]`):
                original height and width (height, width) before padding.

        Returns:
            hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
        """
        pad_height, pad_width = padding_shape
        height, width = original_shape
        # Recover batch_size from the number of windows per image.
        batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
        hidden_states = windows.reshape(
            batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
        )
        hidden_states = (
            hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
        )

        # Drop the padding added by window_partition.
        hidden_states = hidden_states[:, :height, :width, :].contiguous()
        return hidden_states

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        # Window partition
        if self.window_size > 0:
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)

        # attn_weights is not used here. SamVisionEncoder declares `_can_record_outputs` for SamVisionAttention,
        # which presumably captures the weights through hooks — confirm.
        hidden_states, attn_weights = self.attn(
            hidden_states=hidden_states,
        )
        # Reverse window partition
        if self.window_size > 0:
            hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))

        hidden_states = residual + hidden_states
        layernorm_output = self.layer_norm2(hidden_states)
        hidden_states = hidden_states + self.mlp(layernorm_output)
        return hidden_states
  815. class SamVisionNeck(nn.Module):
  816. def __init__(self, config: SamVisionConfig):
  817. super().__init__()
  818. self.config = config
  819. self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
  820. self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
  821. self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
  822. self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")
  823. def forward(self, hidden_states):
  824. hidden_states = hidden_states.permute(0, 3, 1, 2)
  825. hidden_states = self.conv1(hidden_states)
  826. hidden_states = self.layer_norm1(hidden_states)
  827. hidden_states = self.conv2(hidden_states)
  828. hidden_states = self.layer_norm2(hidden_states)
  829. return hidden_states
  830. @auto_docstring
  831. class SamPreTrainedModel(PreTrainedModel):
  832. config: SamConfig
  833. base_model_prefix = "sam"
  834. main_input_name = "pixel_values"
  835. _no_split_modules = ["SamVisionAttention"]
  836. supports_gradient_checkpointing = True
  837. _supports_sdpa = True
  838. def _init_weights(self, module: nn.Module):
  839. super()._init_weights(module)
  840. if isinstance(module, SamVisionAttention):
  841. if module.use_rel_pos:
  842. module.rel_pos_h.data.zero_()
  843. module.rel_pos_w.data.zero_()
  844. elif isinstance(module, SamVisionEncoder):
  845. if self.config.use_abs_pos:
  846. module.pos_embed.data.zero_()
  847. class SamVisionEncoder(SamPreTrainedModel):
  848. _can_record_outputs = {"hidden_states": SamVisionLayer, "attentions": SamVisionAttention}
  849. def __init__(self, config: SamVisionConfig):
  850. super().__init__(config)
  851. self.config = config
  852. self.image_size = config.image_size
  853. self.patch_embed = SamPatchEmbeddings(config)
  854. self.pos_embed = None
  855. if config.use_abs_pos:
  856. # Initialize absolute positional embedding with pretrain image size.
  857. self.pos_embed = nn.Parameter(
  858. torch.zeros(
  859. 1,
  860. config.image_size // config.patch_size,
  861. config.image_size // config.patch_size,
  862. config.hidden_size,
  863. )
  864. )
  865. self.layers = nn.ModuleList()
  866. for i in range(config.num_hidden_layers):
  867. layer = SamVisionLayer(
  868. config,
  869. window_size=config.window_size if i not in config.global_attn_indexes else 0,
  870. )
  871. self.layers.append(layer)
  872. self.neck = SamVisionNeck(config)
  873. self.gradient_checkpointing = False
  874. def get_input_embeddings(self):
  875. return self.patch_embed
  876. @check_model_inputs(tie_last_hidden_states=False)
  877. def forward(
  878. self, pixel_values: Optional[torch.FloatTensor] = None, **kwargs: Unpack[TransformersKwargs]
  879. ) -> SamVisionEncoderOutput:
  880. if pixel_values is None:
  881. raise ValueError("You have to specify pixel_values")
  882. hidden_states = self.patch_embed(pixel_values)
  883. if self.pos_embed is not None:
  884. hidden_states = hidden_states + self.pos_embed
  885. for layer_module in self.layers:
  886. hidden_states = layer_module(hidden_states)
  887. hidden_states = self.neck(hidden_states)
  888. return SamVisionEncoderOutput(
  889. last_hidden_state=hidden_states,
  890. )
  891. @auto_docstring(
  892. custom_intro="""
  893. The vision model from Sam without any head or projection on top.
  894. """
  895. )
  896. class SamVisionModel(SamPreTrainedModel):
  897. config: SamVisionConfig
  898. main_input_name = "pixel_values"
  899. def __init__(self, config: SamVisionConfig):
  900. super().__init__(config)
  901. self.vision_encoder = SamVisionEncoder(config)
  902. self.post_init()
  903. def get_input_embeddings(self) -> nn.Module:
  904. return self.vision_encoder.patch_embed
  905. @auto_docstring
  906. def forward(
  907. self,
  908. pixel_values: Optional[torch.FloatTensor] = None,
  909. **kwargs: Unpack[TransformersKwargs],
  910. ) -> Union[tuple, SamVisionEncoderOutput]:
  911. return self.vision_encoder(pixel_values, **kwargs)
  912. @auto_docstring(
  913. custom_intro="""
  914. Segment Anything Model (SAM) for generating segmentation masks, given an input image and
  915. input points and labels, boxes, or masks.
  916. """
  917. )
  918. class SamModel(SamPreTrainedModel):
  919. _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
  920. # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
  921. _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
  922. _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamTwoWayAttentionBlock, index=2)}
  923. def __init__(self, config: SamConfig):
  924. super().__init__(config)
  925. self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
  926. self.vision_encoder = SamVisionEncoder(config.vision_config)
  927. self.prompt_encoder = SamPromptEncoder(config)
  928. # The module using it is not a PreTrainedModel subclass so we need this
  929. config.mask_decoder_config._attn_implementation = config._attn_implementation
  930. self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
  931. self.post_init()
  932. def _tie_weights(self):
  933. self.prompt_encoder.shared_embedding.positional_embedding.data = (
  934. self.shared_image_embedding.positional_embedding.data
  935. )
  936. def get_input_embeddings(self):
  937. return self.vision_encoder.get_input_embeddings()
  938. def get_image_wide_positional_embeddings(self):
  939. size = self.config.prompt_encoder_config.image_embedding_size
  940. target_device = self.shared_image_embedding.positional_embedding.device
  941. target_dtype = self.shared_image_embedding.positional_embedding.dtype
  942. grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
  943. y_embed = grid.cumsum(dim=0) - 0.5
  944. x_embed = grid.cumsum(dim=1) - 0.5
  945. y_embed = y_embed / size
  946. x_embed = x_embed / size
  947. positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
  948. return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
  949. @torch.no_grad()
  950. def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs]):
  951. r"""
  952. Returns the image embeddings by passing the pixel values through the vision encoder.
  953. Args:
  954. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  955. Input pixel values
  956. """
  957. vision_output = self.vision_encoder(
  958. pixel_values,
  959. **kwargs,
  960. )
  961. image_embeddings = vision_output[0]
  962. return image_embeddings
  963. @torch.no_grad()
  964. def get_prompt_embeddings(
  965. self,
  966. input_points: Optional[torch.FloatTensor] = None,
  967. input_labels: Optional[torch.LongTensor] = None,
  968. input_boxes: Optional[torch.FloatTensor] = None,
  969. input_masks: Optional[torch.LongTensor] = None,
  970. ):
  971. r"""
  972. Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
  973. Args:
  974. input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
  975. Optional input points for the prompt encoder. The padding of the point is automatically done by the
  976. processor. `point_batch_size` refers to the number of masks that we want the model to predict per
  977. point. The model will output `point_batch_size` times 3 masks in total.
  978. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
  979. Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
  980. processor, or can be fed by the user.
  981. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
  982. Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
  983. processor. users can also pass manually the input boxes.
  984. input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
  985. Optional input masks for the prompt encoder.
  986. """
  987. prompt_output = self.prompt_encoder(
  988. input_points=input_points,
  989. input_labels=input_labels,
  990. input_boxes=input_boxes,
  991. input_masks=input_masks,
  992. )
  993. return prompt_output
  994. @check_model_inputs()
  995. @auto_docstring
  996. def forward(
  997. self,
  998. pixel_values: Optional[torch.FloatTensor] = None,
  999. input_points: Optional[torch.FloatTensor] = None,
  1000. input_labels: Optional[torch.LongTensor] = None,
  1001. input_boxes: Optional[torch.FloatTensor] = None,
  1002. input_masks: Optional[torch.LongTensor] = None,
  1003. image_embeddings: Optional[torch.FloatTensor] = None,
  1004. multimask_output: bool = True,
  1005. attention_similarity: Optional[torch.FloatTensor] = None,
  1006. target_embedding: Optional[torch.FloatTensor] = None,
  1007. **kwargs: Unpack[TransformersKwargs],
  1008. ) -> SamImageSegmentationOutput:
  1009. r"""
  1010. input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
  1011. Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
  1012. better results. The points can be obtained by passing a list of list of list to the processor that will
  1013. create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
  1014. second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
  1015. per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
  1016. multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
  1017. coordinates of the point. If a different number of points is passed either for each image, or for each
  1018. mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
  1019. computation of the embedding will be skipped for these points using the labels.
  1020. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
  1021. Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
  1022. official implementation, there are 3 types of labels
  1023. - `1`: the point is a point that contains the object of interest
  1024. - `0`: the point is a point that does not contain the object of interest
  1025. - `-1`: the point corresponds to the background
  1026. We added the label:
  1027. - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
  1028. The padding labels should be automatically done by the processor.
  1029. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
  1030. Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
  1031. much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
  1032. that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
  1033. size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
  1034. In the order (`x1`, `y1`, `x2`, `y2`):
  1035. - `x1`: the x coordinate of the top left point of the input box
  1036. - `y1`: the y coordinate of the top left point of the input box
  1037. - `x2`: the x coordinate of the bottom right point of the input box
  1038. - `y2`: the y coordinate of the bottom right point of the input box
  1039. input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
  1040. SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
  1041. generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
  1042. manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
  1043. image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
  1044. Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory
  1045. efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
  1046. method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
  1047. multimask_output (`bool`, *optional*):
  1048. In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
  1049. bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
  1050. "best" mask, by specifying `multimask_output=False`.
  1051. attention_similarity (`torch.FloatTensor`, *optional*):
  1052. Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
  1053. model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  1054. target_embedding (`torch.FloatTensor`, *optional*):
  1055. Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
  1056. the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  1057. Example:
  1058. ```python
  1059. >>> from PIL import Image
  1060. >>> import requests
  1061. >>> from transformers import AutoModel, AutoProcessor
  1062. >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
  1063. >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
  1064. >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
  1065. >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
  1066. >>> input_points = [[[400, 650]]] # 2D location of a window on the car
  1067. >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
  1068. >>> # Get segmentation mask
  1069. >>> outputs = model(**inputs)
  1070. >>> # Postprocess masks
  1071. >>> masks = processor.post_process_masks(
  1072. ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
  1073. ... )
  1074. ```
  1075. """
  1076. if pixel_values is None and image_embeddings is None:
  1077. raise ValueError("Either pixel_values or image_embeddings must be provided.")
  1078. if pixel_values is not None and image_embeddings is not None:
  1079. raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
  1080. if input_points is not None and len(input_points.shape) != 4:
  1081. raise ValueError(
  1082. "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
  1083. f" got {input_points.shape}.",
  1084. )
  1085. if input_boxes is not None and len(input_boxes.shape) != 3:
  1086. raise ValueError(
  1087. "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
  1088. f" got {input_boxes.shape}.",
  1089. )
  1090. if input_points is not None and input_boxes is not None:
  1091. point_batch_size = input_points.shape[1]
  1092. box_batch_size = input_boxes.shape[1]
  1093. if point_batch_size != box_batch_size:
  1094. raise ValueError(
  1095. f"You should provide as many bounding boxes as input points per box. Got {point_batch_size} and {box_batch_size}."
  1096. )
  1097. image_positional_embeddings = self.get_image_wide_positional_embeddings()
  1098. # repeat with batch size
  1099. batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
  1100. image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
  1101. vision_attentions = None
  1102. vision_hidden_states = None
  1103. if pixel_values is not None:
  1104. vision_outputs: SamVisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
  1105. image_embeddings = vision_outputs.last_hidden_state
  1106. vision_hidden_states = vision_outputs.hidden_states
  1107. vision_attentions = vision_outputs.attentions
  1108. if input_points is not None and input_labels is None:
  1109. input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
  1110. if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
  1111. raise ValueError(
  1112. "The batch size of the image embeddings and the input points must be the same. ",
  1113. f"Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively.",
  1114. " if you want to pass multiple points for the same image, make sure that you passed ",
  1115. " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
  1116. " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
  1117. )
  1118. sparse_embeddings, dense_embeddings = self.prompt_encoder(
  1119. input_points=input_points,
  1120. input_labels=input_labels,
  1121. input_boxes=input_boxes,
  1122. input_masks=input_masks,
  1123. )
  1124. low_res_masks, iou_predictions = self.mask_decoder(
  1125. image_embeddings=image_embeddings,
  1126. image_positional_embeddings=image_positional_embeddings,
  1127. sparse_prompt_embeddings=sparse_embeddings,
  1128. dense_prompt_embeddings=dense_embeddings,
  1129. multimask_output=multimask_output,
  1130. attention_similarity=attention_similarity,
  1131. target_embedding=target_embedding,
  1132. )
  1133. return SamImageSegmentationOutput(
  1134. iou_scores=iou_predictions,
  1135. pred_masks=low_res_masks,
  1136. vision_hidden_states=vision_hidden_states,
  1137. vision_attentions=vision_attentions,
  1138. )
  1139. __all__ = ["SamVisionModel", "SamModel", "SamPreTrainedModel"]