modeling_sam2.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/sam2/modular_sam2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_sam2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from dataclasses import dataclass
  23. from typing import Callable, Optional, Union
  24. import numpy as np
  25. import torch
  26. import torch.nn as nn
  27. import torch.nn.functional as F
  28. from torch import Tensor
  29. from transformers.utils.generic import OutputRecorder
  30. from ...activations import ACT2FN
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import BaseModelOutput
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...pytorch_utils import compile_compatible_method_lru_cache
  36. from ...utils import ModelOutput, auto_docstring
  37. from ...utils.generic import TransformersKwargs, check_model_inputs
  38. from ..auto import AutoModel
  39. from .configuration_sam2 import (
  40. Sam2Config,
  41. Sam2HieraDetConfig,
  42. Sam2MaskDecoderConfig,
  43. Sam2PromptEncoderConfig,
  44. Sam2VisionConfig,
  45. )
  @dataclass
  @auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
  class Sam2VisionEncoderOutput(ModelOutput):
      r"""
      last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
          Sequence of hidden-states at the output of the last layer of the model.
      fpn_hidden_states (`tuple(torch.FloatTensor)`):
          Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
          `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
      fpn_position_encoding (`tuple(torch.FloatTensor)`):
          Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
          `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
      hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
          Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
          one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
          model at the output of each stage.
      attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
          the self-attention heads.
      """

      last_hidden_state: Optional[torch.FloatTensor] = None
      # The two FPN fields carry one tensor per feature level, so they are annotated as
      # tuples to agree with the field documentation above.
      fpn_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
      fpn_position_encoding: Optional[tuple[torch.FloatTensor, ...]] = None
      hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
      attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  @dataclass
  @auto_docstring(custom_intro="Base class for the Sam2 model's output.")
  class Sam2ImageSegmentationOutput(ModelOutput):
      r"""
      iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
          The Intersection over Union (IoU) scores of the predicted masks.
      pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
          The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
          by the processor to be brought to the original image size.
      object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
          Logits for the object score, indicating if an object is present.
      image_embeddings (`tuple(torch.FloatTensor)`):
          The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
          tensor has shape `(batch_size, channels, height, width)`.
      vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
          Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
          Hidden-states of the vision model at the output of each stage.
      vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
          Attention weights of the vision model.
      mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
          Attention weights of the mask decoder.
      """

      iou_scores: Optional[torch.FloatTensor] = None
      pred_masks: Optional[torch.FloatTensor] = None
      object_score_logits: Optional[torch.FloatTensor] = None
      # Defaults to None like every other field, so the annotation has to admit None as well.
      image_embeddings: Optional[tuple[torch.FloatTensor, ...]] = None
      vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
      vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
      mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  class Sam2PatchEmbeddings(nn.Module):
      r"""
      Turns pixel values into patch embeddings for transformer consumption.

      Args:
          pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
              Pixel values. Pixel values can be obtained using
              [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details.

      Returns:
          embeddings (`torch.FloatTensor`):
              Patch embeddings depend on image_size, patch_kernel_size, patch_stride and patch_padding
      """

      def __init__(self, config: Sam2HieraDetConfig):
          super().__init__()
          # A single strided convolution performs the patchification; kernel, stride and
          # padding all come from the config.
          self.projection = nn.Conv2d(
              config.num_channels,
              config.hidden_size,
              kernel_size=config.patch_kernel_size,
              stride=config.patch_stride,
              padding=config.patch_padding,
          )

      def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
          """Project `pixel_values` with the patch convolution and return the result channels-last."""
          # The convolution output is channels-first; move the channel axis (dim 1) to the end.
          return self.projection(pixel_values).permute(0, 2, 3, 1)
  129. # copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
  class Sam2SinePositionEmbedding(nn.Module):
      """
      Sinusoidal position embedding in the style of "Attention is all you need", extended to the two spatial axes of
      an image: one sine/cosine code is built from the row position, one from the column position, and both are
      concatenated along the channel axis.
      """

      def __init__(
          self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
      ):
          super().__init__()
          # An explicit scale only makes sense when positions are normalized first.
          if normalize is False and scale is not None:
              raise ValueError("normalize should be True if scale is passed")
          self.num_pos_feats = num_pos_feats
          self.temperature = temperature
          self.normalize = normalize
          self.scale = scale if scale is not None else 2 * math.pi

      @compile_compatible_method_lru_cache(maxsize=1)
      def forward(
          self,
          shape: torch.Size,
          device: Union[torch.device, str],
          dtype: torch.dtype,
          mask: Optional[Tensor] = None,
      ) -> Tensor:
          """
          Build the position encoding for a feature map.

          Args:
              shape: Shape of the feature map; entries 0, 2 and 3 are used as (batch, height, width) when no mask is given.
              device: Device on which the encoding is created.
              dtype: Floating dtype used for the computation.
              mask: Optional boolean mask; `True` entries are excluded from the position count. Defaults to all-`False`.

          Returns:
              Tensor with the two per-axis codes concatenated and moved to dim 1, i.e. laid out as
              `(batch, 2 * num_pos_feats, height, width)`.
          """
          if mask is None:
              mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)

          # Running counts of unmasked cells give each location a 1-based row / column position.
          valid = (~mask).to(dtype)
          row_pos = valid.cumsum(1)
          col_pos = valid.cumsum(2)

          if self.normalize:
              eps = 1e-6
              # Divide by the last (largest) position on each axis, then stretch to `self.scale`.
              row_pos = row_pos / (row_pos[:, -1:, :] + eps) * self.scale
              col_pos = col_pos / (col_pos[:, :, -1:] + eps) * self.scale

          # Per-channel frequency divisor; consecutive channel pairs share the same exponent.
          channel_idx = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
          divisor = self.temperature ** (2 * torch.div(channel_idx, 2, rounding_mode="floor") / self.num_pos_feats)

          col_angles = col_pos[..., None] / divisor
          row_angles = row_pos[..., None] / divisor

          # Even channels take the sine, odd channels the cosine; stack + flatten interleaves them again.
          col_code = torch.stack((col_angles[..., 0::2].sin(), col_angles[..., 1::2].cos()), dim=4).flatten(3)
          row_code = torch.stack((row_angles[..., 0::2].sin(), row_angles[..., 1::2].cos()), dim=4).flatten(3)

          # Row code first, then column code; finally move channels to dim 1.
          return torch.cat((row_code, col_code), dim=3).permute(0, 3, 1, 2)
  class Sam2VisionNeck(nn.Module):
      """
      Feature-pyramid neck: projects every backbone level to `config.fpn_hidden_size` channels with its own
      convolution, optionally adds an upsampled copy of the previously processed level, and produces a sine
      position encoding for each resulting map.
      """

      def __init__(self, config: Sam2VisionConfig):
          super().__init__()
          self.config = config
          self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
          # One lateral convolution per entry of `backbone_channel_list`, in list order.
          self.convs = nn.ModuleList(
              [
                  nn.Conv2d(
                      in_channels=in_channels,
                      out_channels=config.fpn_hidden_size,
                      kernel_size=config.fpn_kernel_size,
                      stride=config.fpn_stride,
                      padding=config.fpn_padding,
                  )
                  for in_channels in config.backbone_channel_list
              ]
          )
          self.fpn_top_down_levels = config.fpn_top_down_levels

      def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
          """
          Args:
              hidden_states: Indexable collection of channels-last feature maps, one per backbone level.

          Returns:
              `(fpn_hidden_states, fpn_position_encoding)`: two tuples with one entry per level, in processing order
              (last backbone level first).
          """
          last_level = len(self.convs) - 1
          feature_maps = []
          position_maps = []

          # Walk the levels from the last one back to the first (low to high resolution). The i-th
          # convolution is paired with the i-th level visited, i.e. convs[0] handles the last level.
          for lateral_conv, level in zip(self.convs, range(last_level, -1, -1)):
              lateral = lateral_conv(hidden_states[level].permute(0, 3, 1, 2))

              if level != last_level and level in self.fpn_top_down_levels:
                  # Merge in the previous (coarser) result, upsampled 2x in float32.
                  upsampled = F.interpolate(
                      merged.to(dtype=torch.float32),
                      scale_factor=2.0,
                      mode="nearest",
                      align_corners=None,
                      antialias=False,
                  ).to(lateral.dtype)
                  merged = lateral + upsampled
              else:
                  merged = lateral

              position = self.position_encoding(merged.shape, merged.device, merged.dtype).to(merged.dtype)
              feature_maps.append(merged)
              position_maps.append(position)

          return tuple(feature_maps), tuple(position_maps)
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: Optional[torch.Tensor],
      scaling: float,
      dropout: float = 0.0,
      **kwargs,
  ):
      """
      Plain scaled dot-product attention computed with explicit matmuls.

      Args:
          module: Module on whose behalf attention runs; only `module.training` is read (to gate dropout).
          query, key, value: 4D tensors; dims 2 and 3 of `key` are swapped for the score product.
          attention_mask: Optional additive mask added to the scaled scores.
          scaling: Factor applied to the raw scores.
          dropout: Dropout probability applied to the attention weights.
          **kwargs: Accepted for interface compatibility and ignored.

      Returns:
          `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 swapped back and is contiguous.
      """
      scores = torch.matmul(query, key.transpose(2, 3)) * scaling
      if attention_mask is not None:
          scores = scores + attention_mask

      # Softmax runs in float32 and is cast back to the query dtype afterwards.
      probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
      probs = F.dropout(probs, p=dropout, training=module.training)

      context = torch.matmul(probs, value).transpose(1, 2).contiguous()
      return context, probs
  def do_pool(x: torch.Tensor, query_stride: Optional[Union[int, tuple[int, int]]] = None) -> torch.Tensor:
      """
      Max-pool a channels-last feature map over its two middle axes.

      Args:
          x: Tensor laid out as `(B, H, W, C)`.
          query_stride: Pooling kernel size and stride, given as one int or as a pair (callers in this file
              pass a pair). `None` disables pooling.

      Returns:
          `x` itself when `query_stride` is `None`, otherwise the pooled tensor laid out as `(B, H', W', C)`.
      """
      if query_stride is None:
          return x
      # max_pool2d works channels-first: (B, H, W, C) -> (B, C, H, W), pool, then back again.
      pooled = nn.functional.max_pool2d(
          x.permute(0, 3, 1, 2), kernel_size=query_stride, stride=query_stride, ceil_mode=False
      )
      return pooled.permute(0, 2, 3, 1)
  class Sam2MultiScaleAttention(nn.Module):
      """
      Multi-head self-attention over a channels-last feature map, with optional max-pooling of the queries so
      that the output is spatially downsampled.
      """

      def __init__(
          self,
          config: Sam2HieraDetConfig,
          dim: int,
          dim_out: int,
          num_attention_heads: int,
          query_stride: Optional[tuple[int, int]] = None,
      ):
          """
          Args:
              config: Model config; `config._attn_implementation` selects the attention backend.
              dim: Input channel count.
              dim_out: Output channel count, split evenly across the heads.
              num_attention_heads: Number of attention heads.
              query_stride: Pooling stride applied to the queries, or `None` for no pooling.
          """
          super().__init__()

          self.config = config

          self.dim = dim
          self.dim_out = dim_out
          self.query_stride = query_stride

          self.num_attention_heads = num_attention_heads
          head_dim = dim_out // num_attention_heads
          self.scale = head_dim**-0.5
          # A single projection produces queries, keys and values together.
          self.qkv = nn.Linear(dim, dim_out * 3)
          self.proj = nn.Linear(dim_out, dim_out)

          self.is_causal = False

      def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
          """
          Args:
              hidden_states: Tensor of shape `(batch_size, height, width, dim)`.
              **kwargs: Forwarded unchanged to the attention backend.

          Returns:
              Tensor of shape `(batch_size, height', width', dim_out)`; the spatial size shrinks only when
              `query_stride` is set.
          """
          batch_size, height, width, _ = hidden_states.shape
          # qkv with shape (B, H * W, 3, nHead, C)
          qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
          # q, k, v with shape (B, H * W, nheads, C)
          query, key, value = torch.unbind(qkv, 2)

          # NOTE: no attention weights are computed here. The attention backend called below does
          # all of the work, so an extra eager softmax at this point would only be discarded.

          # Q pooling (for downsample at stage changes)
          if self.query_stride:
              query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride)
              height, width = query.shape[1:3]  # downsampled shape
              query = query.reshape(batch_size, height * width, self.num_attention_heads, -1)

          # transpose query, key, value to (B, nHead, H * W, C)
          query = query.transpose(1, 2)
          key = key.transpose(1, 2)
          value = value.transpose(1, 2)

          attention_interface: Callable = eager_attention_forward
          if self.config._attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
          attn_output, _ = attention_interface(
              self,
              query,
              key,
              value,
              attention_mask=None,
              is_causal=self.is_causal,
              scaling=self.scale,
              **kwargs,
          )
          # Fold the (possibly pooled) token axis back into a 2D grid before the output projection.
          attn_output = attn_output.reshape(batch_size, height, width, -1)

          attn_output = self.proj(attn_output)
          return attn_output
  class Sam2FeedForward(nn.Module):
      """
      Multi-layer perceptron: an input projection, `num_layers - 2` hidden linear layers (each followed by the
      activation), and an output projection that is optionally passed through a sigmoid.
      """

      def __init__(
          self,
          input_dim: int,
          hidden_dim: int,
          output_dim: int,
          num_layers: int,
          activation: str = "relu",
          sigmoid_output: bool = False,
      ):
          super().__init__()
          self.num_layers = num_layers
          self.activation = ACT2FN[activation]
          self.proj_in = nn.Linear(input_dim, hidden_dim)
          self.proj_out = nn.Linear(hidden_dim, output_dim)
          # proj_in and proj_out account for two of the `num_layers`; the rest are hidden -> hidden.
          hidden_layers = [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]
          self.layers = nn.ModuleList(hidden_layers)
          self.sigmoid_output = sigmoid_output

      def forward(self, hidden_states):
          """Run the MLP on `hidden_states` and return the projected (optionally sigmoid-squashed) result."""
          hidden_states = self.activation(self.proj_in(hidden_states))
          for hidden_layer in self.layers:
              hidden_states = self.activation(hidden_layer(hidden_states))

          output = self.proj_out(hidden_states)
          return F.sigmoid(output) if self.sigmoid_output else output
  def window_partition(hidden_state, window_size):
      """
      Split a feature map into non-overlapping square windows, zero-padding at the bottom/right when the
      spatial size is not a multiple of the window size.

      Args:
          hidden_state (`torch.Tensor`):
              Input tokens with [batch_size, height, width, num_channels].
          window_size (`int`):
              Window size.

      Returns:
          `tuple` of two elements:
          - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
          - (padded_height, padded_width): padded height and width before partition
      """
      batch_size, height, width, num_channels = hidden_state.shape

      # Amount needed to reach the next multiple of window_size (zero when already aligned).
      pad_height = -height % window_size
      pad_width = -width % window_size
      padded_height = height + pad_height
      padded_width = width + pad_width

      # F.pad lists pads from the last dim backwards: channels (0, 0), width (0, pad_width), height (0, pad_height).
      padded = F.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))

      # Expose the window grid as separate axes, then bring the two grid axes next to each other.
      grid = padded.view(
          batch_size,
          padded_height // window_size,
          window_size,
          padded_width // window_size,
          window_size,
          num_channels,
      )
      windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
      return windows, (padded_height, padded_width)
  def window_unpartition(windows, window_size, pad_height_width, height_width):
      """
      Reassemble windows into full feature maps and crop away the padding.

      Args:
          windows (`torch.Tensor`):
              Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
          window_size (`int`):
              Window size.
          pad_height_width (`tuple[int]`):
              Padded height and width (padded_height, padded_width).
          height_width (`tuple[int]`):
              Original height and width before padding.

      Returns:
          hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
      """
      padded_height, padded_width = pad_height_width
      height, width = height_width

      # The leading axis mixes batch and window index; recover the batch size from the window count.
      windows_per_image = padded_height * padded_width // window_size // window_size
      batch_size = windows.shape[0] // windows_per_image

      grid = windows.view(
          batch_size,
          padded_height // window_size,
          padded_width // window_size,
          window_size,
          window_size,
          -1,
      )
      # Interleave grid and in-window axes so rows and columns become contiguous again.
      stitched = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, padded_height, padded_width, -1)

      # height <= padded_height and width <= padded_width, so this slice only drops padding.
      return stitched[:, :height, :width, :].contiguous()
  class Sam2MultiScaleBlock(GradientCheckpointingLayer):
      """
      One transformer block of the hierarchical backbone: pre-norm (optionally windowed) attention with optional
      query pooling, followed by a pre-norm MLP, each wrapped in a residual connection.
      """

      def __init__(
          self,
          config: Sam2HieraDetConfig,
          stage_idx: int,
          block_idx: int,
          total_block_idx: int,
      ):
          """
          Args:
              config: Backbone config providing the per-stage dimensions, window sizes, head counts, etc.
              stage_idx: Index of the stage this block belongs to.
              block_idx: Index of the block within its stage.
              total_block_idx: Index of the block across all stages; compared against
                  `config.global_attention_blocks`.
          """
          super().__init__()

          is_stage_entry = stage_idx > 0 and block_idx == 0

          # take embed dim from previous stage if first block of stage
          self.dim = config.embed_dim_per_stage[stage_idx - 1] if is_stage_entry else config.embed_dim_per_stage[stage_idx]
          self.dim_out = config.embed_dim_per_stage[stage_idx]
          self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
          # take window size from previous stage if first block of stage
          self.window_size = (
              config.window_size_per_stage[stage_idx - 1] if is_stage_entry else config.window_size_per_stage[stage_idx]
          )
          # A window size of 0 marks a global-attention block (no window partitioning in forward).
          self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size
          # use query stride for first block of stage if stage is a query pool stage
          self.query_stride = (
              config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None
          )

          self.attn = Sam2MultiScaleAttention(
              config,
              self.dim,
              self.dim_out,
              num_attention_heads=config.num_attention_heads_per_stage[stage_idx],
              query_stride=self.query_stride,
          )
          self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps)
          self.mlp = Sam2FeedForward(
              self.dim_out,
              int(self.dim_out * config.mlp_ratio),
              self.dim_out,
              num_layers=2,
              activation=config.hidden_act,
          )
          # The residual branch only needs a projection when the channel count changes.
          if self.dim != self.dim_out:
              self.proj = nn.Linear(self.dim, self.dim_out)

      def forward(
          self,
          hidden_states: torch.Tensor,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.FloatTensor:
          """
          Args:
              hidden_states: Tensor of shape `(batch_size, height, width, dim)`.
              **kwargs: Forwarded unchanged to the attention module.

          Returns:
              Tensor with `dim_out` channels; spatially downsampled when the block pools its queries.
          """
          residual = hidden_states  # batch_size, height, width, channel

          hidden_states = self.layer_norm1(hidden_states)

          # Skip connection: project (and pool, if applicable) so it matches the attention output.
          if self.dim != self.dim_out:
              residual = do_pool(self.proj(hidden_states), self.query_stride)

          # Window partition
          window_size = self.window_size
          use_windows = window_size > 0
          if use_windows:
              H, W = hidden_states.shape[1], hidden_states.shape[2]
              hidden_states, pad_hw = window_partition(hidden_states, window_size)

          # Window Attention + Q Pooling (if stage change)
          hidden_states = self.attn(
              hidden_states=hidden_states,
              **kwargs,
          )

          # Reverse window partition. The window bookkeeping is needed only for windowed blocks; for a
          # global-attention block (window_size == 0) it would divide by zero when query pooling is on.
          if use_windows:
              if self.query_stride:
                  # Shapes have changed due to Q pooling
                  window_size = self.window_size // self.query_stride[0]
                  H, W = residual.shape[1:3]

                  pad_h = (window_size - H % window_size) % window_size
                  pad_w = (window_size - W % window_size) % window_size
                  pad_hw = (H + pad_h, W + pad_w)
              hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (H, W))

          hidden_states = residual + hidden_states
          layernorm_output = self.layer_norm2(hidden_states)
          hidden_states = hidden_states + self.mlp(layernorm_output)

          return hidden_states
  @dataclass
  @auto_docstring(
      custom_intro="""
      Hiera-Det backbone outputs: the last hidden state together with the hidden states collected at the end of each stage.
      """
  )
  class Sam2HieraDetModelOutput(ModelOutput):
      r"""
      last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
          hidden-states at the output of the last layer of the model.
      intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`):
          Sequence of hidden-states at the output of the intermediate layers of the model.
      """

      # The intro above describes exactly the two fields below; this output has no pooled representation.
      last_hidden_state: Optional[torch.FloatTensor] = None
      intermediate_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  @auto_docstring
  class Sam2PreTrainedModel(PreTrainedModel):
      # Shared base for the SAM2 models in this file: declares the config class, the attention
      # backends that are supported, and the weight-initialization scheme.
      config_class = Sam2Config
      base_model_prefix = "sam2"
      main_input_name = "pixel_values"
      _supports_sdpa = True
      _supports_flash_attn_2 = True
      _supports_attention_backend = True

      def _init_weights(self, module):
          """
          Initialize `module` in place: normal(0, `initializer_range`) weights for linear / convolutional /
          embedding layers, identity affine parameters for layer norms, and zeros for the extra learned
          embeddings owned by `Sam2HieraDetModel` and `Sam2Model`.
          """
          init_std = self.config.initializer_range
          dense_types = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)

          if isinstance(module, dense_types + (nn.Embedding,)):
              module.weight.data.normal_(mean=0.0, std=init_std)
              if isinstance(module, dense_types):
                  if module.bias is not None:
                      module.bias.data.zero_()
              elif module.padding_idx is not None:
                  # Embedding: keep the padding row at zero.
                  module.weight.data[module.padding_idx].zero_()
          elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)):
              module.weight.data.fill_(1.0)
              module.bias.data.zero_()

          # Model-level parameters are handled independently of the layer-type dispatch above.
          if isinstance(module, Sam2HieraDetModel):
              for attr_name in ("pos_embed", "pos_embed_window"):
                  embedding = getattr(module, attr_name)
                  if embedding is not None:
                      embedding.data.zero_()
          if isinstance(module, Sam2Model):
              if module.no_memory_embedding is not None:
                  module.no_memory_embedding.data.zero_()
  class Sam2HieraDetModel(Sam2PreTrainedModel):
      """
      Hierarchical backbone: patch embedding plus learned position embeddings, followed by a flat list of
      `Sam2MultiScaleBlock`s grouped into stages. The hidden state after the last block of every stage is
      returned alongside the final hidden state.
      """

      config_class = Sam2HieraDetConfig
      main_input_name = "pixel_values"
      _can_record_outputs = {
          "hidden_states": Sam2MultiScaleBlock,
          "attentions": Sam2MultiScaleAttention,
      }

      def __init__(self, config: Sam2HieraDetConfig):
          super().__init__(config)

          self.patch_embed = Sam2PatchEmbeddings(config)
          # Windowed positional embedding (https://huggingface.co/papers/2311.05613)
          self.pos_embed = nn.Parameter(
              torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size)
          )
          first_stage_window = config.window_size_per_stage[0]
          self.pos_embed_window = nn.Parameter(
              torch.zeros(1, config.hidden_size, first_stage_window, first_stage_window)
          )
          # Index (into self.blocks) of the last block of each stage.
          self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist()

          self.blocks = nn.ModuleList()
          for stage_idx, num_blocks in enumerate(config.blocks_per_stage):
              for block_idx in range(num_blocks):
                  # len(self.blocks) is the running block count, i.e. this block's global index.
                  self.blocks.append(
                      Sam2MultiScaleBlock(
                          config=config,
                          stage_idx=stage_idx,
                          block_idx=block_idx,
                          total_block_idx=len(self.blocks),
                      )
                  )

      def get_input_embeddings(self):
          """Return the patch-embedding module."""
          return self.patch_embed

      def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
          """
          Build the position embedding for a `hw = (height, width)` grid: the background embedding is resized
          bicubically to that grid, the window embedding is tiled over it, and the sum is returned
          channels-last.
          """
          target_h, target_w = hw
          background = F.interpolate(self.pos_embed, size=(target_h, target_w), mode="bicubic")
          window = self.pos_embed_window
          # Per-axis repeat counts: how many whole windows fit along each dimension of the background.
          repeats = [full // part for full, part in zip(background.shape, window.shape)]
          return (background + window.tile(repeats)).permute(0, 2, 3, 1)

      @check_model_inputs()
      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Union[tuple, Sam2HieraDetModelOutput]:
          """
          Embed `pixel_values`, add position embeddings and run all blocks.

          Raises:
              ValueError: If `pixel_values` is `None`.
          """
          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          hidden_states = self.patch_embed(pixel_values)
          hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3])

          stage_outputs = []
          for block_index, block in enumerate(self.blocks):
              hidden_states = block(hidden_states, **kwargs)
              # Keep a snapshot whenever a stage finishes.
              if block_index in self.stage_ends:
                  stage_outputs.append(hidden_states)

          return Sam2HieraDetModelOutput(
              last_hidden_state=hidden_states,
              intermediate_hidden_states=tuple(stage_outputs),
          )
  @auto_docstring(
      custom_intro="""
      The vision model from Sam without any head or projection on top.
      """
  )
  class Sam2VisionModel(Sam2PreTrainedModel):
      # Vision encoder = backbone (built from `config.backbone_config`) followed by the FPN neck.
      config_class = Sam2VisionConfig
      main_input_name = "pixel_values"
      _can_record_outputs = {
          "hidden_states": Sam2MultiScaleBlock,
          "attentions": Sam2MultiScaleAttention,
      }

      def __init__(self, config: Sam2VisionConfig):
          super().__init__(config)
          self.config = config

          self.backbone = AutoModel.from_config(config.backbone_config)
          self.neck = Sam2VisionNeck(config)
          self.num_feature_levels = config.num_feature_levels

          self.post_init()

      def get_input_embeddings(self):
          """Delegate to the backbone's input embeddings."""
          return self.backbone.get_input_embeddings()

      @check_model_inputs()
      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Union[tuple, Sam2VisionEncoderOutput]:
          """
          Run the backbone and the neck on `pixel_values`.

          Raises:
              ValueError: If `pixel_values` is `None`.
          """
          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          # Forward through backbone
          backbone_output = self.backbone(pixel_values, **kwargs)
          last_hidden_state = backbone_output.last_hidden_state
          stage_features = backbone_output.intermediate_hidden_states

          fpn_hidden_states, fpn_position_encoding = self.neck(stage_features)

          # Keep only the last `num_feature_levels` FPN levels, then flip them so the order runs
          # from high to low resolution.
          num_levels = self.num_feature_levels
          selected_features = fpn_hidden_states[-num_levels:]
          selected_positions = fpn_position_encoding[-num_levels:]

          return Sam2VisionEncoderOutput(
              last_hidden_state=last_hidden_state,
              fpn_hidden_states=selected_features[::-1],
              fpn_position_encoding=selected_positions[::-1],
          )
  589. class Sam2PositionalEmbedding(nn.Module):
  590. def __init__(self, config: Sam2PromptEncoderConfig):
  591. super().__init__()
  592. self.scale = config.scale
  593. positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
  594. self.register_buffer("positional_embedding", positional_embedding)
  595. def forward(self, input_coords, input_shape=None):
  596. """Positionally encode points that are normalized to [0,1]."""
  597. coordinates = input_coords.clone()
  598. if input_shape is not None:
  599. coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
  600. coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
  601. coordinates.to(torch.float32)
  602. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  603. coordinates = 2 * coordinates - 1
  604. coordinates = coordinates.to(self.positional_embedding.dtype)
  605. coordinates = coordinates @ self.positional_embedding
  606. coordinates = 2 * np.pi * coordinates
  607. # outputs d_1 x ... x d_n x channel shape
  608. return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
  609. class Sam2MaskEmbedding(nn.Module):
  610. def __init__(self, config: Sam2PromptEncoderConfig):
  611. super().__init__()
  612. self.mask_input_channels = config.mask_input_channels // 4
  613. self.activation = ACT2FN[config.hidden_act]
  614. self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
  615. self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
  616. self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
  617. self.layer_norm1 = Sam2LayerNorm(
  618. self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
  619. )
  620. self.layer_norm2 = Sam2LayerNorm(
  621. self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
  622. )
  623. def forward(self, masks):
  624. hidden_states = self.conv1(masks)
  625. hidden_states = self.layer_norm1(hidden_states)
  626. hidden_states = self.activation(hidden_states)
  627. hidden_states = self.conv2(hidden_states)
  628. hidden_states = self.layer_norm2(hidden_states)
  629. hidden_states = self.activation(hidden_states)
  630. dense_embeddings = self.conv3(hidden_states)
  631. return dense_embeddings
  632. class Sam2PromptEncoder(nn.Module):
  633. def __init__(self, config: Sam2PromptEncoderConfig):
  634. super().__init__()
  635. self.shared_embedding = Sam2PositionalEmbedding(config)
  636. self.mask_embed = Sam2MaskEmbedding(config)
  637. self.no_mask_embed = nn.Embedding(1, config.hidden_size)
  638. self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
  639. self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
  640. self.input_image_size = config.image_size
  641. self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
  642. self.hidden_size = config.hidden_size
  643. self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
  644. def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
  645. """Embeds point prompts."""
  646. points = points + 0.5 # Shift to center of pixel
  647. if pad:
  648. points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
  649. labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
  650. input_shape = (self.input_image_size, self.input_image_size)
  651. point_embedding = self.shared_embedding(points, input_shape)
  652. # torch.where and expanding the labels tensor is required by the ONNX export
  653. point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
  654. # This is required for the ONNX export. The dtype, device need to be explicitly
  655. # specified as otherwise torch.onnx.export interprets as double
  656. point_embedding = torch.where(
  657. labels[..., None] != -10,
  658. point_embedding,
  659. torch.zeros_like(point_embedding),
  660. )
  661. # Add point embeddings for labels >= 0
  662. point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
  663. return point_embedding
  664. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  665. """Embeds box prompts."""
  666. boxes += 0.5 # Shift to center of pixel
  667. coords = boxes.view(*boxes.shape[:2], 2, 2)
  668. # add padding point for consistency with the original implementation
  669. coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
  670. corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
  671. corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
  672. corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
  673. corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
  674. return corner_embedding
  675. def forward(
  676. self,
  677. input_points: Optional[tuple[torch.Tensor, torch.Tensor]],
  678. input_labels: Optional[torch.Tensor],
  679. input_boxes: Optional[torch.Tensor],
  680. input_masks: Optional[torch.Tensor],
  681. ) -> tuple[torch.Tensor, torch.Tensor]:
  682. """
  683. Embeds different types of prompts, returning both sparse and dense embeddings.
  684. Args:
  685. points (`torch.Tensor`, *optional*):
  686. point coordinates and labels to embed.
  687. boxes (`torch.Tensor`, *optional*):
  688. boxes to embed
  689. masks (`torch.Tensor`, *optional*):
  690. masks to embed
  691. """
  692. sparse_embeddings = None
  693. batch_size = 1
  694. if input_points is not None:
  695. batch_size = input_points.shape[0]
  696. if input_labels is None:
  697. raise ValueError("If points are provided, labels must also be provided.")
  698. point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
  699. sparse_embeddings = point_embeddings
  700. if input_boxes is not None:
  701. batch_size = input_boxes.shape[0]
  702. box_embeddings = self._embed_boxes(input_boxes)
  703. if sparse_embeddings is None:
  704. sparse_embeddings = box_embeddings
  705. else:
  706. sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
  707. if input_masks is not None:
  708. dense_embeddings = self.mask_embed(input_masks)
  709. else:
  710. dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
  711. batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
  712. )
  713. return sparse_embeddings, dense_embeddings
  714. class Sam2Attention(nn.Module):
  715. """
  716. SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  717. values.
  718. """
  719. def __init__(self, config, downsample_rate=None):
  720. super().__init__()
  721. downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
  722. self.config = config
  723. self.hidden_size = config.hidden_size
  724. self.internal_dim = config.hidden_size // downsample_rate
  725. self.num_attention_heads = config.num_attention_heads
  726. self.head_dim = self.internal_dim // config.num_attention_heads
  727. self.scaling = self.head_dim**-0.5
  728. self.is_causal = False
  729. self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
  730. self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
  731. self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
  732. self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
  733. def forward(
  734. self,
  735. query: torch.Tensor,
  736. key: torch.Tensor,
  737. value: torch.Tensor,
  738. attention_similarity: Optional[torch.Tensor] = None,
  739. **kwargs: Unpack[TransformersKwargs],
  740. ) -> tuple[torch.Tensor, torch.Tensor]:
  741. # Input projections
  742. batch_size, point_batch_size = query.shape[:2]
  743. new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
  744. query = self.q_proj(query).view(*new_shape).transpose(1, 2)
  745. key = self.k_proj(key).view(*new_shape).transpose(1, 2)
  746. value = self.v_proj(value).view(*new_shape).transpose(1, 2)
  747. attention_interface: Callable = eager_attention_forward
  748. if self.config._attn_implementation != "eager":
  749. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  750. attn_output, attn_weights = attention_interface(
  751. self,
  752. query,
  753. key,
  754. value,
  755. attention_mask=attention_similarity,
  756. dropout=0.0,
  757. scaling=self.scaling,
  758. is_causal=self.is_causal,
  759. **kwargs,
  760. )
  761. attn_output = attn_output.reshape(
  762. batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
  763. ).contiguous()
  764. attn_output = self.o_proj(attn_output)
  765. return attn_output, attn_weights
  766. class Sam2TwoWayAttentionBlock(nn.Module):
  767. def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False):
  768. """
  769. A transformer block with four layers:
  770. (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
  771. sparse inputs (4) cross attention of dense inputs -> sparse inputs
  772. Arguments:
  773. config (`Sam2MaskDecoderConfig`):
  774. The configuration file used to instantiate the block
  775. attention_downsample_rate (*optionalk*, int, defaults to 2):
  776. The downsample ratio of the block used to reduce the inner dim of the attention.
  777. skip_first_layer_pe (*optional*, bool, defaults to `False`):
  778. Whether or not to skip the addition of the query_point_embedding on the first layer.
  779. """
  780. super().__init__()
  781. self.self_attn = Sam2Attention(config, downsample_rate=1)
  782. self.layer_norm1 = nn.LayerNorm(config.hidden_size)
  783. self.cross_attn_token_to_image = Sam2Attention(config)
  784. self.layer_norm2 = nn.LayerNorm(config.hidden_size)
  785. self.mlp = Sam2FeedForward(
  786. config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
  787. )
  788. self.layer_norm3 = nn.LayerNorm(config.hidden_size)
  789. self.layer_norm4 = nn.LayerNorm(config.hidden_size)
  790. self.cross_attn_image_to_token = Sam2Attention(config)
  791. self.skip_first_layer_pe = skip_first_layer_pe
  792. def forward(
  793. self,
  794. queries: Tensor,
  795. keys: Tensor,
  796. query_point_embedding: Tensor,
  797. key_point_embedding: Tensor,
  798. attention_similarity: Tensor,
  799. **kwargs: Unpack[TransformersKwargs],
  800. ):
  801. # Self attention block
  802. if self.skip_first_layer_pe:
  803. queries, _ = self.self_attn(query=queries, key=queries, value=queries)
  804. else:
  805. query = queries + query_point_embedding
  806. attn_out, _ = self.self_attn(query=query, key=query, value=queries)
  807. queries = queries + attn_out
  808. queries = self.layer_norm1(queries)
  809. # Cross attention block, tokens attending to image embedding
  810. query = queries + query_point_embedding
  811. key = keys + key_point_embedding
  812. attn_out, _ = self.cross_attn_token_to_image(
  813. query=query, key=key, value=keys, attention_similarity=attention_similarity
  814. )
  815. queries = queries + attn_out
  816. queries = self.layer_norm2(queries)
  817. # MLP block
  818. mlp_out = self.mlp(queries)
  819. queries = queries + mlp_out
  820. queries = self.layer_norm3(queries)
  821. # Cross attention block, image embedding attending to tokens
  822. query = queries + query_point_embedding
  823. key = keys + key_point_embedding
  824. attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
  825. keys = keys + attn_out
  826. keys = self.layer_norm4(keys)
  827. return queries, keys, attn_out
  828. class Sam2TwoWayTransformer(nn.Module):
  829. def __init__(self, config: Sam2MaskDecoderConfig):
  830. super().__init__()
  831. self.config = config
  832. self.num_hidden_layers = config.num_hidden_layers
  833. self.layers = nn.ModuleList()
  834. for i in range(self.num_hidden_layers):
  835. self.layers.append(Sam2TwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
  836. self.final_attn_token_to_image = Sam2Attention(config)
  837. self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
  838. def forward(
  839. self,
  840. point_embeddings: Tensor,
  841. image_embeddings: Tensor,
  842. image_positional_embeddings: Tensor,
  843. attention_similarity: Tensor,
  844. target_embedding=None,
  845. **kwargs: Unpack[TransformersKwargs],
  846. ) -> Union[tuple, BaseModelOutput]:
  847. if image_embeddings is None:
  848. raise ValueError("You have to specify an image_embedding")
  849. image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
  850. image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
  851. # Prepare queries
  852. queries = point_embeddings
  853. keys = image_embeddings
  854. # Apply transformer blocks and final layernorm
  855. for layer in self.layers:
  856. if target_embedding is not None:
  857. queries += target_embedding
  858. queries, keys, _ = layer(
  859. queries=queries,
  860. keys=keys,
  861. query_point_embedding=point_embeddings,
  862. key_point_embedding=image_positional_embeddings,
  863. attention_similarity=attention_similarity,
  864. **kwargs,
  865. )
  866. # Apply the final attention layer from the points to the image
  867. query = queries + point_embeddings
  868. key = keys + image_positional_embeddings
  869. attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
  870. queries = queries + attn_out
  871. queries = self.layer_norm_final_attn(queries)
  872. return queries, keys
  873. class Sam2LayerNorm(nn.LayerNorm):
  874. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  875. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  876. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  877. """
  878. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  879. super().__init__(normalized_shape, eps=eps, **kwargs)
  880. if data_format not in ["channels_last", "channels_first"]:
  881. raise NotImplementedError(f"Unsupported data format: {data_format}")
  882. self.data_format = data_format
  883. def forward(self, features: torch.Tensor) -> torch.Tensor:
  884. """
  885. Args:
  886. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  887. """
  888. if self.data_format == "channels_first":
  889. features = features.permute(0, 2, 3, 1)
  890. features = super().forward(features)
  891. features = features.permute(0, 3, 1, 2)
  892. else:
  893. features = super().forward(features)
  894. return features
  895. class Sam2MaskDecoder(nn.Module):
  896. def __init__(self, config: Sam2MaskDecoderConfig):
  897. super().__init__()
  898. self.config = config
  899. self.hidden_size = config.hidden_size
  900. self.num_multimask_outputs = config.num_multimask_outputs
  901. self.num_mask_tokens = config.num_multimask_outputs + 1
  902. self.iou_token = nn.Embedding(1, self.hidden_size)
  903. self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
  904. self.transformer = Sam2TwoWayTransformer(config)
  905. # should we create a new class for this?
  906. self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
  907. self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
  908. self.upscale_layer_norm = Sam2LayerNorm(self.hidden_size // 4, data_format="channels_first")
  909. self.activation = nn.GELU()
  910. mlps_list = []
  911. for _ in range(self.num_mask_tokens):
  912. mlps_list += [Sam2FeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
  913. self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
  914. self.iou_prediction_head = Sam2FeedForward(
  915. self.hidden_size,
  916. config.iou_head_hidden_dim,
  917. self.num_mask_tokens,
  918. config.iou_head_depth,
  919. sigmoid_output=True,
  920. )
  921. self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
  922. self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)
  923. self.obj_score_token = nn.Embedding(1, self.hidden_size)
  924. self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3)
  925. self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
  926. self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
  927. self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
  928. def forward(
  929. self,
  930. image_embeddings: torch.Tensor,
  931. image_positional_embeddings: torch.Tensor,
  932. sparse_prompt_embeddings: torch.Tensor,
  933. dense_prompt_embeddings: torch.Tensor,
  934. multimask_output: bool,
  935. high_resolution_features: list[torch.Tensor],
  936. attention_similarity: Optional[torch.Tensor] = None,
  937. target_embedding: Optional[torch.Tensor] = None,
  938. **kwargs: Unpack[TransformersKwargs],
  939. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  940. """
  941. Predict masks given image and prompt embeddings.
  942. Args:
  943. image_embeddings (`torch.Tensor`):
  944. The embeddings from the image encoder.
  945. image_positional_embeddings (`torch.Tensor`):
  946. Positional encoding with the shape of image_embeddings.
  947. sparse_prompt_embeddings (`torch.Tensor`):
  948. The embeddings of the points and boxes.
  949. dense_prompt_embeddings (`torch.Tensor`):
  950. The embeddings of the mask inputs.
  951. multimask_output (`bool`):
  952. Whether to return multiple masks or a single mask.
  953. high_resolution_features (`list[torch.Tensor]`, *optional*):
  954. The high-resolution features from the vision encoder.
  955. attention_similarity (`torch.Tensor`, *optional*):
  956. The attention similarity tensor.
  957. target_embedding (`torch.Tensor`, *optional*):
  958. The target embedding.
  959. """
  960. batch_size, num_channels, height, width = image_embeddings.shape
  961. point_batch_size = sparse_prompt_embeddings.shape[1]
  962. # Concatenate output tokens
  963. output_tokens = torch.cat(
  964. [
  965. self.obj_score_token.weight,
  966. self.iou_token.weight,
  967. self.mask_tokens.weight,
  968. ],
  969. dim=0,
  970. )
  971. output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
  972. if sparse_prompt_embeddings.shape[0] != 0:
  973. tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
  974. else:
  975. tokens = output_tokens
  976. point_embeddings = tokens.to(self.iou_token.weight.dtype)
  977. # Expand per-image data in batch direction to be per-mask
  978. image_embeddings = image_embeddings + dense_prompt_embeddings
  979. image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
  980. image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
  981. # Run the transformer
  982. point_embeddings, image_embeddings = self.transformer(
  983. point_embeddings=point_embeddings,
  984. image_embeddings=image_embeddings,
  985. image_positional_embeddings=image_positional_embeddings,
  986. attention_similarity=attention_similarity,
  987. target_embedding=target_embedding,
  988. **kwargs,
  989. )
  990. iou_token_out = point_embeddings[:, :, 1, :]
  991. mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]
  992. # Upscale mask embeddings and predict masks using the mask tokens
  993. image_embeddings = image_embeddings.transpose(2, 3).view(
  994. batch_size * point_batch_size, num_channels, height, width
  995. )
  996. feat_s0, feat_s1 = high_resolution_features
  997. feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
  998. feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
  999. upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
  1000. upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
  1001. upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)
  1002. hyper_in_list: list[torch.Tensor] = []
  1003. for i in range(self.num_mask_tokens):
  1004. current_mlp = self.output_hypernetworks_mlps[i]
  1005. hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
  1006. hyper_in = torch.stack(hyper_in_list, dim=2)
  1007. _, num_channels, height, width = upscaled_embedding.shape
  1008. upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
  1009. masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)
  1010. # Generate mask quality predictions
  1011. iou_pred = self.iou_prediction_head(iou_token_out)
  1012. object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])
  1013. # Select the correct mask or masks for output
  1014. if multimask_output:
  1015. mask_slice = slice(1, None)
  1016. masks = masks[:, :, mask_slice, :, :]
  1017. iou_pred = iou_pred[:, :, mask_slice]
  1018. elif self.dynamic_multimask_via_stability and not self.training:
  1019. mask_slice = slice(0, 1)
  1020. masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
  1021. else:
  1022. mask_slice = slice(0, 1)
  1023. masks = masks[:, :, mask_slice, :, :]
  1024. iou_pred = iou_pred[:, :, mask_slice]
  1025. sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
  1026. return masks, iou_pred, sam_tokens_out, object_score_logits
  1027. def _get_stability_scores(self, mask_logits):
  1028. """
  1029. Compute stability scores of the mask logits based on the IoU between upper and
  1030. lower thresholds.
  1031. """
  1032. mask_logits = mask_logits.flatten(-2)
  1033. stability_delta = self.dynamic_multimask_stability_delta
  1034. area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
  1035. area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
  1036. stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
  1037. return stability_scores
  1038. def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
  1039. """
  1040. When outputting a single mask, if the stability score from the current single-mask
  1041. output (based on output token 0) falls below a threshold, we instead select from
  1042. multi-mask outputs (based on output token 1~3) the mask with the highest predicted
  1043. IoU score. This is intended to ensure a valid mask for both clicking and tracking.
  1044. """
  1045. # The best mask from multimask output tokens (1~3)
  1046. multimask_logits = all_mask_logits[:, :, 1:, :, :]
  1047. multimask_iou_scores = all_iou_scores[:, :, 1:]
  1048. best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P]
  1049. best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
  1050. best_scores_inds_expanded = best_scores_inds_expanded.expand(
  1051. -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
  1052. )
  1053. best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W]
  1054. best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1]
  1055. # The mask from singlemask output token 0 and its stability score
  1056. singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
  1057. singlemask_iou_scores = all_iou_scores[:, :, 0:1]
  1058. stability_scores = self._get_stability_scores(singlemask_logits)
  1059. is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
  1060. # Dynamically fall back to best multimask output upon low stability scores.
  1061. mask_logits_out = torch.where(
  1062. is_stable[..., None, None].expand_as(singlemask_logits),
  1063. singlemask_logits,
  1064. best_multimask_logits,
  1065. )
  1066. iou_scores_out = torch.where(
  1067. is_stable.expand_as(singlemask_iou_scores),
  1068. singlemask_iou_scores,
  1069. best_multimask_iou_scores,
  1070. )
  1071. return mask_logits_out, iou_scores_out
  1072. @auto_docstring(
  1073. custom_intro="""
  1074. Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
  1075. input points and labels, boxes, or masks.
  1076. """
  1077. )
class Sam2Model(Sam2PreTrainedModel):
    # Buffer of the prompt encoder that `_tie_weights` re-points at the buffer of `shared_image_embedding`.
    _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
    # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
    _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
    # index=2 is the position of `attn_out` in the tuple returned by `Sam2TwoWayAttentionBlock.forward`.
    # NOTE(review): presumably the recorder captures that element as the attention output; confirm against
    # `OutputRecorder`.
    _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2TwoWayAttentionBlock, index=2)}
    # Checkpoint key patterns to skip when they show up unexpectedly while loading.
    # NOTE(review): these look like memory / object-pointer weights this class never creates; TODO confirm
    # which checkpoints carry them.
    _keys_to_ignore_on_load_unexpected = [
        r"^memory_.*",
        r"^mask_downsample.*",
        r"^object_pointer_proj.*",
        r"^temporal_positional_encoding_projection_layer.*",
        "no_memory_positional_encoding",
        "no_object_pointer",
        "occlusion_spatial_embedding_parameter",
    ]
  1092. def __init__(self, config: Sam2Config):
  1093. super().__init__(config)
  1094. self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config)
  1095. self.vision_encoder = AutoModel.from_config(config.vision_config)
  1096. self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config)
  1097. # The module using it is not a PreTrainedModel subclass so we need this
  1098. config.mask_decoder_config._attn_implementation = config._attn_implementation
  1099. self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config)
  1100. self.num_feature_levels = config.vision_config.num_feature_levels
  1101. self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
  1102. # a single token to indicate no memory embedding from previous frames
  1103. self.hidden_dim = config.vision_config.fpn_hidden_size
  1104. self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
  1105. self.post_init()
  1106. def _tie_weights(self):
  1107. self.prompt_encoder.shared_embedding.positional_embedding.data = (
  1108. self.shared_image_embedding.positional_embedding.data
  1109. )
  1110. def get_input_embeddings(self):
  1111. return self.vision_encoder.get_input_embeddings()
  1112. def get_image_wide_positional_embeddings(self) -> torch.Tensor:
  1113. size = self.prompt_encoder.image_embedding_size
  1114. target_device = self.shared_image_embedding.positional_embedding.device
  1115. target_dtype = self.shared_image_embedding.positional_embedding.dtype
  1116. grid = torch.ones(size, device=target_device, dtype=target_dtype)
  1117. y_embed = grid.cumsum(dim=0) - 0.5
  1118. x_embed = grid.cumsum(dim=1) - 0.5
  1119. y_embed = y_embed / size[0]
  1120. x_embed = x_embed / size[1]
  1121. positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
  1122. return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
  1123. @torch.no_grad()
  1124. def get_image_embeddings(
  1125. self,
  1126. pixel_values: torch.FloatTensor,
  1127. **kwargs: Unpack[TransformersKwargs],
  1128. ) -> list[torch.Tensor]:
  1129. r"""
  1130. Returns the image embeddings by passing the pixel values through the vision encoder.
  1131. Args:
  1132. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  1133. Input pixel values
  1134. """
  1135. batch_size = pixel_values.shape[0]
  1136. feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs)
  1137. # add no memory embedding to the last feature map
  1138. feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
  1139. # reshape feature maps to the same shape as the backbone feature sizes
  1140. image_embeddings = [
  1141. feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
  1142. for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
  1143. ]
  1144. return image_embeddings
  @torch.no_grad()
  def get_prompt_embeddings(
      self,
      input_points: Optional[torch.FloatTensor] = None,
      input_labels: Optional[torch.LongTensor] = None,
      input_boxes: Optional[torch.FloatTensor] = None,
      input_masks: Optional[torch.LongTensor] = None,
  ):
      r"""
      Encode the given prompts with the prompt encoder and return its output as-is.

      Gradient tracking is disabled while this method runs.

      Args:
          input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`, *optional*):
              Point prompts. Padding of the points is normally handled by the processor. `point_batch_size` is
              the number of masks to predict per point.
          input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`, *optional*):
              Labels for the point prompts. Padding is normally handled by the processor, but labels can also be
              supplied by the user.
          input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`, *optional*):
              Box prompts. Padding is normally handled by the processor, but boxes can also be supplied by the
              user.
          input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`, *optional*):
              Mask prompts.

      Returns:
          Whatever `self.prompt_encoder` returns for these prompts.
      """
      # Every prompt is forwarded by keyword; prompts that were not supplied stay `None`.
      prompt_inputs = {
          "input_points": input_points,
          "input_labels": input_labels,
          "input_boxes": input_boxes,
          "input_masks": input_masks,
      }
      return self.prompt_encoder(**prompt_inputs)
  @check_model_inputs()
  @auto_docstring
  def forward(
      self,
      pixel_values: Optional[torch.FloatTensor] = None,
      input_points: Optional[torch.FloatTensor] = None,
      input_labels: Optional[torch.LongTensor] = None,
      input_boxes: Optional[torch.FloatTensor] = None,
      input_masks: Optional[torch.LongTensor] = None,
      image_embeddings: Optional[list[torch.FloatTensor]] = None,
      multimask_output: bool = True,
      attention_similarity: Optional[torch.FloatTensor] = None,
      target_embedding: Optional[torch.FloatTensor] = None,
      **kwargs: Unpack[TransformersKwargs],
  ) -> Sam2ImageSegmentationOutput:
      r"""
      input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points, 2)`):
          Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields much
          better results. The points can be obtained by passing a list of list of list to the processor that will
          create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
          second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
          per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
          multiple points for a single mask), and the last dimension holds the (x, y) coordinates of the point. If a
          different number of points is passed either for each image, or for each mask, the processor will create
          "PAD" points that will correspond to the (0, 0) coordinate, and the computation of the embedding will be
          skipped for these points using the labels. If neither points nor boxes are given, a single point with
          label `-1` is used instead.
      input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
          Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
          official implementation, there are 3 types of labels

          - `1`: the point is a point that contains the object of interest
          - `0`: the point is a point that does not contain the object of interest
          - `-1`: the point corresponds to the background

          We added the label:

          - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

          The padding labels should be automatically done by the processor. When points are given without labels,
          every point is labelled `1`.
      input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
          Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields
          much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
          that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
          size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
          In the order (`x1`, `y1`, `x2`, `y2`):

          - `x1`: the x coordinate of the top left point of the input box
          - `y1`: the y coordinate of the top left point of the input box
          - `x2`: the x coordinate of the bottom right point of the input box
          - `y2`: the y coordinate of the bottom right point of the input box

          When both points and boxes are given, `num_boxes` must equal the `point_batch_size` of `input_points`.
      input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
          SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
          generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
          manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`). Masks
          whose last two dimensions differ from the prompt encoder's `mask_input_size` are resized to it.
      image_embeddings (`list[torch.FloatTensor]`, *optional*):
          Multi-level image embeddings, this is used by the mask decoder to generate masks and iou scores. The last
          entry is used as the image embedding and the preceding entries as high resolution features. For more
          memory efficient computation, users can first retrieve the image embeddings using the
          `get_image_embeddings` method, and then feed them to the `forward` method instead of feeding the
          `pixel_values`. Exactly one of `pixel_values` and `image_embeddings` must be provided.
      multimask_output (`bool`, *optional*):
          In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
          bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
          "best" mask, by specifying `multimask_output=False`.
      attention_similarity (`torch.FloatTensor`, *optional*):
          Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
          model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
      target_embedding (`torch.FloatTensor`, *optional*):
          Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
          the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).

      Example:

      ```python
      >>> from PIL import Image
      >>> import requests
      >>> from transformers import AutoModel, AutoProcessor

      >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
      >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")

      >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
      >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
      >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
      >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

      >>> # Get segmentation mask
      >>> outputs = model(**inputs)

      >>> # Postprocess masks
      >>> masks = processor.post_process_masks(
      ...     outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
      ... )
      ```
      """
      # The image must come from exactly one source: raw pixels or precomputed embeddings.
      if not ((pixel_values is None) ^ (image_embeddings is None)):
          raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
      if input_points is not None and input_boxes is not None:
          if input_points.shape[1] != input_boxes.shape[1]:
              raise ValueError(
                  f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
              )

      image_positional_embeddings = self.get_image_wide_positional_embeddings()
      # repeat with batch size
      batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
      image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

      # Only populated when the vision encoder actually runs in this call.
      vision_attentions = None
      vision_hidden_states = None

      if pixel_values is not None:
          feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
              pixel_values,
              **kwargs,
          )

          # add no memory embedding to the last feature map
          feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding

          # reshape feature maps to the same shape as the backbone feature sizes
          image_embeddings = [
              feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
              for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
          ]

      # The last level drives the mask decoder and fixes dtype/device of the default prompt.
      last_level_embedding = image_embeddings[-1]

      if input_points is not None and input_labels is None:
          # Points without labels are all given label 1.
          input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

      if input_points is None and input_boxes is None:
          # If no points are provided, pad with an empty point (with label -1)
          input_points = torch.zeros(
              batch_size, 1, 1, 2, dtype=last_level_embedding.dtype, device=last_level_embedding.device
          )
          input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=last_level_embedding.device)

      if input_masks is not None:
          # If mask_inputs is provided, downsize it into low-res mask input if needed
          # and feed it as a dense mask prompt into the SAM mask encoder
          # NOTE(review): bilinear `F.interpolate` requires a 4-D (N, C, H, W) input, while the docstring
          # advertises a 3-D mask; confirm the expected layout with callers.
          if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
              input_masks = F.interpolate(
                  input_masks.float(),
                  size=self.prompt_encoder.mask_input_size,
                  align_corners=False,
                  mode="bilinear",
                  antialias=True,  # use antialias for downsampling
              ).to(input_masks.dtype)

      sparse_embeddings, dense_embeddings = self.prompt_encoder(
          input_points=input_points,
          input_labels=input_labels,
          input_boxes=input_boxes,
          input_masks=input_masks,
      )
      low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
          image_embeddings=last_level_embedding,
          image_positional_embeddings=image_positional_embeddings,
          sparse_prompt_embeddings=sparse_embeddings,
          dense_prompt_embeddings=dense_embeddings,
          multimask_output=multimask_output,
          high_resolution_features=image_embeddings[:-1],
          attention_similarity=attention_similarity,
          target_embedding=target_embedding,
          **kwargs,
      )

      return Sam2ImageSegmentationOutput(
          iou_scores=iou_scores,
          pred_masks=low_res_multimasks,
          object_score_logits=object_score_logits,
          image_embeddings=image_embeddings,
          vision_hidden_states=vision_hidden_states,
          vision_attentions=vision_attentions,
      )
  def get_image_features(
      self,
      pixel_values: torch.FloatTensor,
      **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[
      list[torch.Tensor],
      list[torch.Tensor],
      Optional[tuple[torch.FloatTensor, ...]],
      Optional[tuple[torch.FloatTensor, ...]],
  ]:
      r"""
      Run the vision encoder and prepare its multi-level features.

      Args:
          pixel_values (`torch.FloatTensor`):
              Input pixel values of shape `(batch_size, num_channels, height, width)`.
          **kwargs:
              Forwarded unchanged to the vision encoder.

      Returns:
          `tuple`: A 4-tuple of:
          - feature_maps (`list[torch.Tensor]`): One feature map per level, each flattened from NxCxHxW to
            HWxNxC. Levels 0 and 1 have additionally gone through the mask decoder's `conv_s0` / `conv_s1`.
          - feature_maps_position_embeddings (`list[torch.Tensor]`): Positional embeddings per level, flattened
            the same way.
          - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): `hidden_states` of the vision encoder
            output.
          - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): `attentions` of the vision encoder output.
      """

      def _to_sequence_first(tensor):
          # NxCxHxW -> HWxNxC
          return tensor.flatten(2).permute(2, 0, 1)

      encoder_out: Sam2VisionEncoderOutput = self.vision_encoder(pixel_values, **kwargs)
      level_features = list(encoder_out.fpn_hidden_states)
      level_positions = encoder_out.fpn_position_encoding

      # Project the level 0 and level 1 features once here, so the SAM decoder
      # does not have to redo it on every click.
      level_features[0] = self.mask_decoder.conv_s0(level_features[0])
      level_features[1] = self.mask_decoder.conv_s1(level_features[1])

      flat_features = [_to_sequence_first(level) for level in level_features]
      flat_positions = [_to_sequence_first(position) for position in level_positions]

      return flat_features, flat_positions, encoder_out.hidden_states, encoder_out.attentions
  # Names exported by `from <module> import *`.
  __all__ = [
      "Sam2Model",
      "Sam2VisionModel",
      "Sam2PreTrainedModel",
      "Sam2HieraDetModel",
  ]