modular_sam2.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463
  1. # coding=utf-8
  2. # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch SAM 2 model."""
  16. from dataclasses import dataclass
  17. from typing import Callable, Optional, Union
  18. import numpy as np
  19. import torch
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22. from ...activations import ACT2FN
  23. from ...image_processing_utils import BatchFeature, get_size_dict
  24. from ...image_processing_utils_fast import BaseImageProcessorFast, DefaultFastImageProcessorKwargs
  25. from ...image_utils import (
  26. IMAGENET_DEFAULT_MEAN,
  27. IMAGENET_DEFAULT_STD,
  28. ChannelDimension,
  29. ImageInput,
  30. PILImageResampling,
  31. SizeDict,
  32. pil_torch_interpolation_mapping,
  33. )
  34. from ...modeling_layers import GradientCheckpointingLayer
  35. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  36. from ...processing_utils import Unpack
  37. from ...utils import (
  38. ModelOutput,
  39. TensorType,
  40. auto_docstring,
  41. logging,
  42. )
  43. from ...utils.generic import TransformersKwargs, check_model_inputs
  44. from ..auto import AutoModel
  45. from ..maskformer.modeling_maskformer import MaskFormerSinePositionEmbedding
  46. from ..sam.image_processing_sam_fast import SamImageProcessorFast
  47. from ..sam.modeling_sam import (
  48. SamLayerNorm,
  49. SamMaskDecoder,
  50. SamMaskEmbedding,
  51. SamModel,
  52. SamPromptEncoder,
  53. SamTwoWayAttentionBlock,
  54. SamTwoWayTransformer,
  55. eager_attention_forward,
  56. )
  57. from ..vitdet.modeling_vitdet import window_partition, window_unpartition
  58. from .configuration_sam2 import (
  59. Sam2Config,
  60. Sam2HieraDetConfig,
  61. Sam2MaskDecoderConfig,
  62. Sam2PromptEncoderConfig,
  63. Sam2VisionConfig,
  64. )
  65. logger = logging.get_logger(__name__)
  class Sam2FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
      r"""
      mask_size (`dict[str, int]`, *optional*):
          Target size, given as `{"height": int, "width": int}`, that segmentation maps are resized to.
      """

      # Only key added on top of the default fast-processor kwargs.
      mask_size: Optional[dict[str, int]]
  @auto_docstring
  class Sam2ImageProcessorFast(SamImageProcessorFast):
      resample = PILImageResampling.BILINEAR
      image_mean = IMAGENET_DEFAULT_MEAN
      image_std = IMAGENET_DEFAULT_STD
      size = {"height": 1024, "width": 1024}
      mask_size = {"height": 256, "width": 256}
      do_resize = True
      do_rescale = True
      do_normalize = True
      do_convert_rgb = True

      valid_kwargs = Sam2FastImageProcessorKwargs

      # modular artefacts: padding options inherited from the SAM processor that SAM 2 does not use
      do_pad = None
      pad_size = None
      mask_pad_size = None

      def __init__(self, **kwargs: Unpack[Sam2FastImageProcessorKwargs]):
          # Deliberately bypass `SamImageProcessorFast.__init__` and initialise from the generic fast base class.
          BaseImageProcessorFast.__init__(self, **kwargs)

      def pad_image(self):
          raise NotImplementedError("No pad_image for SAM 2.")

      def _get_preprocess_shape(self):
          raise NotImplementedError("No _get_preprocess_shape for SAM 2.")

      def resize(self, *args, **kwargs):
          """
          Resize with the generic [`BaseImageProcessorFast.resize`] instead of the SAM override.

          The previous stub took no arguments and raised. `_preprocess` below delegates to
          `BaseImageProcessorFast._preprocess`, which resizes through `self.resize(...)` when `do_resize` is set
          (NOTE(review): per the transformers base class, confirm). A raising stub would therefore break
          preprocessing. Delegating keeps the "no SAM-specific resize" intent.
          """
          return BaseImageProcessorFast.resize(self, *args, **kwargs)

      def _preprocess(
          self,
          images: list["torch.Tensor"],
          return_tensors: Optional[Union[str, TensorType]],
          **kwargs,
      ) -> "torch.Tensor":
          """Run the generic fast-processor pipeline and return only its `pixel_values`."""
          return BaseImageProcessorFast._preprocess(self, images, return_tensors=return_tensors, **kwargs).pixel_values

      def _preprocess_image_like_inputs(
          self,
          images: ImageInput,
          segmentation_maps: Optional[ImageInput],
          do_convert_rgb: bool,
          input_data_format: ChannelDimension,
          device: Optional[Union[str, "torch.device"]] = None,
          **kwargs: Unpack[Sam2FastImageProcessorKwargs],
      ) -> BatchFeature:
          """
          Preprocess image-like inputs.

          Returns a `BatchFeature` with the following entries:
          - `pixel_values`.
          - `original_sizes`: height/width before preprocessing.
          - `reshaped_input_sizes`: height/width after preprocessing.
          - `labels`: only when `segmentation_maps` is given. The maps are resized to `mask_size` with
            nearest-neighbour interpolation, are neither rescaled nor normalised, and are cast to `torch.int64`.
          """
          images = self._prepare_image_like_inputs(
              images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
          )
          original_sizes = [image.shape[-2:] for image in images]
          images_kwargs = kwargs.copy()
          pixel_values = self._preprocess(images, **images_kwargs)
          # Measure the *processed* images. `images` still holds the un-resized inputs, so measuring it
          # would just repeat `original_sizes`.
          reshaped_input_sizes = [pixel_value.shape[-2:] for pixel_value in pixel_values]
          data = {
              "pixel_values": pixel_values,
              "original_sizes": original_sizes,
              "reshaped_input_sizes": reshaped_input_sizes,
          }

          if segmentation_maps is not None:
              processed_segmentation_maps = self._prepare_image_like_inputs(
                  images=segmentation_maps,
                  expected_ndims=2,
                  do_convert_rgb=False,
                  input_data_format=ChannelDimension.FIRST,
              )

              # Masks are label maps: resize them to `mask_size` with nearest-neighbour interpolation, and
              # never rescale or normalise them.
              segmentation_maps_kwargs = kwargs.copy()
              segmentation_maps_kwargs.update(
                  {
                      "do_normalize": False,
                      "do_rescale": False,
                      "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST],
                      "size": segmentation_maps_kwargs.pop("mask_size"),
                  }
              )
              processed_segmentation_maps = self._preprocess(
                  images=processed_segmentation_maps, **segmentation_maps_kwargs
              )
              data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)

          return BatchFeature(data=data, tensor_type=kwargs["return_tensors"])

      def _further_process_kwargs(
          self,
          size: Optional[SizeDict] = None,
          mask_size: Optional[SizeDict] = None,
          default_to_square: Optional[bool] = None,
          image_mean: Optional[Union[float, list[float]]] = None,
          image_std: Optional[Union[float, list[float]]] = None,
          data_format: Optional[ChannelDimension] = None,
          **kwargs,
      ) -> dict:
          """
          Update kwargs that need further processing before being validated
          Can be overridden by subclasses to customize the processing of kwargs.

          `size` and `mask_size` are normalised to `SizeDict`. Mean and std lists become tuples.
          `data_format` defaults to channels-first. `resample` is replaced by the matching torch `interpolation`.
          """
          # `**kwargs` always binds a dict, so there is no `None` case to handle.
          if size is not None:
              size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
          if mask_size is not None:
              mask_size = SizeDict(**get_size_dict(mask_size, param_name="mask_size"))
          if isinstance(image_mean, list):
              image_mean = tuple(image_mean)
          if isinstance(image_std, list):
              image_std = tuple(image_std)
          if data_format is None:
              data_format = ChannelDimension.FIRST

          kwargs["size"] = size
          kwargs["mask_size"] = mask_size
          kwargs["image_mean"] = image_mean
          kwargs["image_std"] = image_std
          kwargs["data_format"] = data_format

          # torch resize uses interpolation instead of resample
          # Check if resample is an int before checking if it's an instance of PILImageResampling
          # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
          # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
          resample = kwargs.pop("resample")
          kwargs["interpolation"] = (
              pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
          )

          return kwargs

      def _apply_non_overlapping_constraints(self, pred_masks: torch.Tensor) -> torch.Tensor:
          """
          Apply non-overlapping constraints to the object scores in pred_masks. Here we
          keep only the highest scoring object at each spatial location in pred_masks.

          Objects are indexed along dim 0. Losing scores are clamped to at most -10.0.
          """
          batch_size = pred_masks.size(0)
          if batch_size == 1:
              return pred_masks

          device = pred_masks.device
          # "max_obj_inds": object index of the object with the highest score at each location
          max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
          # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
          batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
          keep = max_obj_inds == batch_obj_inds
          # suppress overlapping regions' scores below -10.0 so that the foreground regions
          # don't overlap (here sigmoid(-10.0)=4.5398e-05)
          pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))

          return pred_masks

      def post_process_masks(
          self,
          masks,
          original_sizes,
          mask_threshold=0.0,
          binarize=True,
          max_hole_area=0.0,
          max_sprinkle_area=0.0,
          apply_non_overlapping_constraints=False,
          **kwargs,
      ):
          """
          Upscale masks to the original image size.

          Args:
              masks (`Union[torch.Tensor, List[torch.Tensor], np.ndarray, List[np.ndarray]]`):
                  Batched masks from the mask_decoder. Each `masks[i]` must be 4D, in
                  (batch_size, num_channels, height, width) format. The input container is not modified.
              original_sizes (`Union[torch.Tensor, List[Tuple[int,int]]]`):
                  The original sizes of each image before it was resized to the model's expected input shape, in (height,
                  width) format.
              mask_threshold (`float`, *optional*, defaults to 0.0):
                  Threshold for binarization and post-processing operations.
              binarize (`bool`, *optional*, defaults to `True`):
                  Whether to binarize the masks.
              max_hole_area (`float`, *optional*, defaults to 0.0):
                  The maximum area of a hole to fill. Currently accepted but not applied.
              max_sprinkle_area (`float`, *optional*, defaults to 0.0):
                  The maximum area of a sprinkle to fill. Currently accepted but not applied.
              apply_non_overlapping_constraints (`bool`, *optional*, defaults to `False`):
                  Whether to apply non-overlapping constraints to the masks.
          Returns:
              (`list[torch.Tensor]`): One tensor per image, in (batch_size, num_channels, height, width) format,
              where (height, width) is given by original_size. Boolean if `binarize`, otherwise the interpolated scores.
          """
          if isinstance(original_sizes, (torch.Tensor, np.ndarray)):
              original_sizes = original_sizes.tolist()
          # TODO: add connected components kernel for postprocessing (max_hole_area / max_sprinkle_area)
          output_masks = []
          for i, original_size in enumerate(original_sizes):
              # Convert into a local variable instead of writing back into `masks`. Writing back mutates the
              # caller's data, fails for tuples, and coerces the tensor back to numpy when `masks` is a batched array.
              mask = masks[i]
              if isinstance(mask, np.ndarray):
                  mask = torch.from_numpy(mask)
              elif not isinstance(mask, torch.Tensor):
                  raise ValueError("Input masks should be a list of `torch.tensors` or a list of `np.ndarray`")
              interpolated_mask = F.interpolate(mask, original_size, mode="bilinear", align_corners=False)
              if apply_non_overlapping_constraints:
                  interpolated_mask = self._apply_non_overlapping_constraints(interpolated_mask)
              if binarize:
                  interpolated_mask = interpolated_mask > mask_threshold
              output_masks.append(interpolated_mask)

          return output_masks
  @dataclass
  @auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
  class Sam2VisionEncoderOutput(ModelOutput):
      r"""
      last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
          Sequence of hidden-states at the output of the last layer of the model.
      fpn_hidden_states (`tuple(torch.FloatTensor)`):
          Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
          `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
      fpn_position_encoding (`tuple(torch.FloatTensor)`):
          Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
          `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
      hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
          Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
          one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
          model at the output of each stage.
      attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
          sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
          the self-attention heads.
      """

      last_hidden_state: Optional[torch.FloatTensor] = None
      # The FPN fields hold one tensor per feature level, so they are tuples rather than a single tensor.
      fpn_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
      fpn_position_encoding: Optional[tuple[torch.FloatTensor, ...]] = None
      hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
      attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  @dataclass
  @auto_docstring(custom_intro="Base class for the Sam2 model's output.")
  class Sam2ImageSegmentationOutput(ModelOutput):
      r"""
      iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
          The Intersection over Union (IoU) scores of the predicted masks.
      pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
          The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
          by the processor to be brought to the original image size.
      object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
          Logits for the object score, indicating if an object is present.
      image_embeddings (`tuple(torch.FloatTensor)`):
          The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
          tensor has shape `(batch_size, channels, height, width)`.
      vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
          Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
          Hidden-states of the vision model at the output of each stage.
      vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
          Attentions weights of the vision model.
      mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
          Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
          Attentions weights of the mask decoder.
      """

      iou_scores: Optional[torch.FloatTensor] = None
      pred_masks: Optional[torch.FloatTensor] = None
      object_score_logits: Optional[torch.FloatTensor] = None
      # Defaults to None like every other field, so the annotation must admit None.
      image_embeddings: Optional[tuple[torch.FloatTensor, ...]] = None
      vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
      vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
      mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  class Sam2PatchEmbeddings(nn.Module):
      r"""
      Turns pixel values into patch embeddings for transformer consumption.

      Args:
          pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
              Pixel values. Pixel values can be obtained using
              [`AutoImageProcessor`]. See [`Sam2ImageProcessorFast.__call__`] for details.

      Returns:
          embeddings (`torch.FloatTensor` of shape `(batch_size, height', width', hidden_size)`):
              Channels-last patch embeddings. The spatial size depends on `patch_kernel_size`, `patch_stride`
              and `patch_padding`.
      """

      def __init__(self, config: Sam2HieraDetConfig):
          super().__init__()
          self.projection = nn.Conv2d(
              config.num_channels,
              config.hidden_size,
              kernel_size=config.patch_kernel_size,
              stride=config.patch_stride,
              padding=config.patch_padding,
          )

      def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
          # Explicit shape check replaces the old unused tuple-unpack, which only failed implicitly on non-4D input.
          if pixel_values.ndim != 4:
              raise ValueError(
                  "Expected `pixel_values` of shape (batch_size, num_channels, height, width), "
                  f"got a tensor with shape {tuple(pixel_values.shape)}."
              )
          # The conv output is channels-first; move channels last.
          return self.projection(pixel_values).permute(0, 2, 3, 1)
  class Sam2SinePositionEmbedding(MaskFormerSinePositionEmbedding):
      """SAM 2 alias of [`MaskFormerSinePositionEmbedding`]; behaviour is inherited unchanged."""
  class Sam2VisionNeck(nn.Module):
      """
      Feature-pyramid neck over the backbone outputs.

      Each backbone level is projected to `config.fpn_hidden_size` channels by its own conv. Every level listed
      in `config.fpn_top_down_levels`, except the coarsest, also receives the 2x nearest-upsampled result of the
      level above it. A sine position encoding is produced for every level.
      """

      def __init__(self, config: Sam2VisionConfig):
          super().__init__()
          self.config = config

          self.position_encoding = Sam2SinePositionEmbedding(num_pos_feats=config.fpn_hidden_size // 2, normalize=True)
          self.convs = nn.ModuleList(
              nn.Conv2d(
                  in_channels=channels,
                  out_channels=config.fpn_hidden_size,
                  kernel_size=config.fpn_kernel_size,
                  stride=config.fpn_stride,
                  padding=config.fpn_padding,
              )
              for channels in config.backbone_channel_list
          )
          self.fpn_top_down_levels = config.fpn_top_down_levels

      def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
          features_per_level = []
          encodings_per_level = []

          last_level = len(self.convs) - 1
          # Top-down pass: start at the highest index (lowest resolution) and walk towards index 0.
          for level in reversed(range(last_level + 1)):
              # Backbone outputs are channels-last; the conv expects channels-first.
              lateral = self.convs[last_level - level](hidden_states[level].permute(0, 3, 1, 2))

              if level in self.fpn_top_down_levels and level != last_level:
                  upsampled = F.interpolate(
                      current.to(dtype=torch.float32),
                      scale_factor=2.0,
                      mode="nearest",
                      align_corners=None,
                      antialias=False,
                  ).to(lateral.dtype)
                  current = lateral + upsampled
              else:
                  current = lateral

              encoding = self.position_encoding(current.shape, current.device, current.dtype).to(current.dtype)
              features_per_level.append(current)
              encodings_per_level.append(encoding)

          return tuple(features_per_level), tuple(encodings_per_level)
  def do_pool(x: torch.Tensor, query_stride: Optional[Union[int, tuple[int, int]]] = None) -> torch.Tensor:
      """
      Spatially max-pool a channels-last feature map.

      Args:
          x (`torch.Tensor`): Feature map of shape `(batch_size, height, width, channels)`.
          query_stride (`int` or `tuple[int, int]`, *optional*):
              Used as both kernel size and stride of the max-pool. When `None`, `x` is returned unchanged
              (same object). Callers pass a tuple, so the annotation now admits one.

      Returns:
          `torch.Tensor` of shape `(batch_size, height', width', channels)`. Trailing rows/columns that do not
          fill a full window are dropped (`ceil_mode=False`).
      """
      if query_stride is None:
          return x
      # (B, H, W, C) -> (B, C, H, W)
      x = x.permute(0, 3, 1, 2)
      x = nn.functional.max_pool2d(x, kernel_size=query_stride, stride=query_stride, ceil_mode=False)
      # (B, C, H', W') -> (B, H', W', C)
      return x.permute(0, 2, 3, 1)
  class Sam2MultiScaleAttention(nn.Module):
      """
      Multi-head self-attention over a channels-last feature map.

      Queries can optionally be max-pooled by `query_stride`, which downsamples the output's spatial size.
      """

      def __init__(
          self,
          config: Sam2HieraDetConfig,
          dim: int,
          dim_out: int,
          num_attention_heads: int,
          query_stride: Optional[tuple[int, int]] = None,
      ):
          super().__init__()

          self.config = config

          self.dim = dim
          self.dim_out = dim_out
          self.query_stride = query_stride

          self.num_attention_heads = num_attention_heads
          head_dim = dim_out // num_attention_heads
          self.scale = head_dim**-0.5
          self.qkv = nn.Linear(dim, dim_out * 3)
          self.proj = nn.Linear(dim_out, dim_out)

          self.is_causal = False

      def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
          batch_size, height, width, _ = hidden_states.shape
          # qkv with shape (B, H * W, 3, nHead, C)
          qkv = self.qkv(hidden_states).reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
          # q, k, v with shape (B, H * W, nheads, C)
          query, key, value = torch.unbind(qkv, 2)

          # The former eager `attn_weights = softmax(query @ key^T)` computed here was never used: the attention
          # interface below computes its own weights. It also multiplied the wrong axes, because q/k were still
          # (B, H * W, nHead, C), so it has been removed.

          # Q pooling (for downsample at stage changes)
          if self.query_stride:
              query = do_pool(query.reshape(batch_size, height, width, -1), self.query_stride)
              height, width = query.shape[1:3]  # downsampled shape
              query = query.reshape(batch_size, height * width, self.num_attention_heads, -1)

          # transpose query, key, value to (B, nHead, H * W, C)
          query = query.transpose(1, 2)
          key = key.transpose(1, 2)
          value = value.transpose(1, 2)

          attention_interface: Callable = eager_attention_forward
          if self.config._attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

          attn_output, _ = attention_interface(
              self,
              query,
              key,
              value,
              attention_mask=None,
              is_causal=self.is_causal,
              scaling=self.scale,
              **kwargs,
          )
          # Back to a channels-last map at the (possibly pooled) spatial size.
          attn_output = attn_output.reshape(batch_size, height, width, -1)

          attn_output = self.proj(attn_output)

          return attn_output
  class Sam2FeedForward(nn.Module):
      """
      Plain MLP: `proj_in`, then `num_layers - 2` hidden `Linear` layers, then `proj_out`.

      The configured activation follows every layer except `proj_out`. A sigmoid is optionally applied to the
      final output.
      """

      def __init__(
          self,
          input_dim: int,
          hidden_dim: int,
          output_dim: int,
          num_layers: int,
          activation: str = "relu",
          sigmoid_output: bool = False,
      ):
          super().__init__()
          self.num_layers = num_layers
          self.activation = ACT2FN[activation]
          # Creation order (proj_in, proj_out, layers) is kept so parameter initialisation order is unchanged.
          self.proj_in = nn.Linear(input_dim, hidden_dim)
          self.proj_out = nn.Linear(hidden_dim, output_dim)
          self.layers = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2))
          self.sigmoid_output = sigmoid_output

      def forward(self, hidden_states):
          hidden_states = self.activation(self.proj_in(hidden_states))
          for hidden_layer in self.layers:
              hidden_states = self.activation(hidden_layer(hidden_states))
          hidden_states = self.proj_out(hidden_states)
          return torch.sigmoid(hidden_states) if self.sigmoid_output else hidden_states
  class Sam2MultiScaleBlock(GradientCheckpointingLayer):
      """
      One Hiera block operating on channels-last maps.

      The data flow is: LayerNorm, then (windowed or global) multi-scale attention with optional query pooling,
      then a residual add; followed by LayerNorm, then MLP, then a residual add.
      """

      def __init__(
          self,
          config: Sam2HieraDetConfig,
          stage_idx: int,
          block_idx: int,
          total_block_idx: int,
      ):
          super().__init__()

          # take embed dim from previous stage if first block of stage
          self.dim = (
              config.embed_dim_per_stage[stage_idx - 1]
              if stage_idx > 0 and block_idx == 0
              else config.embed_dim_per_stage[stage_idx]
          )
          self.dim_out = config.embed_dim_per_stage[stage_idx]
          self.layer_norm1 = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
          # take window size from previous stage if first block of stage
          self.window_size = (
              config.window_size_per_stage[stage_idx - 1]
              if stage_idx > 0 and block_idx == 0
              else config.window_size_per_stage[stage_idx]
          )
          # window_size == 0 means global (un-windowed) attention
          self.window_size = 0 if total_block_idx in config.global_attention_blocks else self.window_size
          # use query stride for first block of stage if stage is a query pool stage
          self.query_stride = (
              config.query_stride if 0 < stage_idx <= config.num_query_pool_stages and block_idx == 0 else None
          )

          self.attn = Sam2MultiScaleAttention(
              config,
              self.dim,
              self.dim_out,
              num_attention_heads=config.num_attention_heads_per_stage[stage_idx],
              query_stride=self.query_stride,
          )
          self.layer_norm2 = nn.LayerNorm(self.dim_out, eps=config.layer_norm_eps)
          self.mlp = Sam2FeedForward(
              self.dim_out,
              int(self.dim_out * config.mlp_ratio),
              self.dim_out,
              num_layers=2,
              activation=config.hidden_act,
          )
          if self.dim != self.dim_out:
              self.proj = nn.Linear(self.dim, self.dim_out)

      def forward(
          self,
          hidden_states: torch.Tensor,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.FloatTensor:
          residual = hidden_states  # batch_size, height, width, channel

          hidden_states = self.layer_norm1(hidden_states)

          # Skip connection: project (and pool, if this block pools queries) when the width changes
          if self.dim != self.dim_out:
              residual = do_pool(self.proj(hidden_states), self.query_stride)

          # Window partition
          window_size = self.window_size
          if self.window_size > 0:
              height, width = hidden_states.shape[1], hidden_states.shape[2]
              hidden_states, pad_hw = window_partition(hidden_states, window_size)

          # Window Attention + Q Pooling (if stage change)
          hidden_states = self.attn(
              hidden_states=hidden_states,
              **kwargs,
          )

          # Reverse window partition. The window geometry only exists for windowed attention. Recomputing it for
          # a global-attention block (window_size == 0) previously divided by zero when query pooling was also active.
          if self.window_size > 0:
              if self.query_stride:
                  # Shapes have changed due to Q pooling; derive them from the pooled residual
                  window_size = self.window_size // self.query_stride[0]
                  height, width = residual.shape[1:3]
                  pad_h = (window_size - height % window_size) % window_size
                  pad_w = (window_size - width % window_size) % window_size
                  pad_hw = (height + pad_h, width + pad_w)
              hidden_states = window_unpartition(hidden_states, window_size, pad_hw, (height, width))

          hidden_states = residual + hidden_states
          layernorm_output = self.layer_norm2(hidden_states)
          hidden_states = hidden_states + self.mlp(layernorm_output)

          return hidden_states
  @dataclass
  @auto_docstring(
      custom_intro="""
      Hiera model's outputs that also contains a pooling of the last hidden states.
      """
  )
  class Sam2HieraDetModelOutput(ModelOutput):
      r"""
      last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
          Channels-last hidden-states produced by the final layer of the model.
      intermediate_hidden_states (`tuple[torch.FloatTensor]` of shape `(batch_size, height, width, hidden_size)`):
          Channels-last hidden-states collected from the model's intermediate layers.
      """

      # Both fields default to None so partially-filled outputs can be constructed.
      last_hidden_state: Optional[torch.FloatTensor] = None
      intermediate_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  @auto_docstring
  class Sam2PreTrainedModel(PreTrainedModel):
      config_class = Sam2Config
      base_model_prefix = "sam2"
      main_input_name = "pixel_values"
      _supports_sdpa = True
      _supports_flash_attn_2 = True
      _supports_attention_backend = True

      def _init_weights(self, module):
          """
          Initialise `module` in place.

          - Linear and (transposed) conv layers get N(0, `config.initializer_range`) weights and zero biases.
          - Embeddings get normal weights, with the padding row zeroed.
          - Layer norms are reset to weight 1 and bias 0.
          - Learned position embeddings on `Sam2HieraDetModel` and `no_memory_embedding` on `Sam2Model` are zeroed.
          """

          def zero_if_present(tensor):
              # Optional parameters may be None; only zero the ones that exist.
              if tensor is not None:
                  tensor.data.zero_()

          std = self.config.initializer_range

          if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
              module.weight.data.normal_(mean=0.0, std=std)
              zero_if_present(module.bias)
          elif isinstance(module, nn.Embedding):
              module.weight.data.normal_(mean=0.0, std=std)
              padding_row = module.padding_idx
              if padding_row is not None:
                  module.weight.data[padding_row].zero_()
          elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)):
              module.weight.data.fill_(1.0)
              module.bias.data.zero_()

          # Model-specific learned tensors are checked independently of the layer-type branches above.
          if isinstance(module, Sam2HieraDetModel):
              zero_if_present(module.pos_embed)
              zero_if_present(module.pos_embed_window)
          if isinstance(module, Sam2Model):
              zero_if_present(module.no_memory_embedding)
  594. class Sam2HieraDetModel(Sam2PreTrainedModel):
  595. config_class = Sam2HieraDetConfig
  596. main_input_name = "pixel_values"
  597. _can_record_outputs = {
  598. "hidden_states": Sam2MultiScaleBlock,
  599. "attentions": Sam2MultiScaleAttention,
  600. }
  601. def __init__(self, config: Sam2HieraDetConfig):
  602. super().__init__(config)
  603. self.patch_embed = Sam2PatchEmbeddings(config)
  604. # Windowed positional embedding (https://huggingface.co/papers/2311.05613)
  605. self.pos_embed = nn.Parameter(
  606. torch.zeros(1, config.hidden_size, *config.window_positional_embedding_background_size)
  607. )
  608. self.pos_embed_window = nn.Parameter(
  609. torch.zeros(1, config.hidden_size, config.window_size_per_stage[0], config.window_size_per_stage[0])
  610. )
  611. self.stage_ends = (np.cumsum(config.blocks_per_stage) - 1).tolist()
  612. self.blocks = nn.ModuleList()
  613. total_block_idx = 0
  614. for stage_idx, blocks_per_stage in enumerate(config.blocks_per_stage):
  615. for block_idx in range(blocks_per_stage):
  616. block = Sam2MultiScaleBlock(
  617. config=config, stage_idx=stage_idx, block_idx=block_idx, total_block_idx=total_block_idx
  618. )
  619. self.blocks.append(block)
  620. total_block_idx += 1
  621. def get_input_embeddings(self):
  622. return self.patch_embed
  623. def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
  624. h, w = hw
  625. window_embed = self.pos_embed_window
  626. pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
  627. pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
  628. pos_embed = pos_embed.permute(0, 2, 3, 1)
  629. return pos_embed
  630. @check_model_inputs()
  631. def forward(
  632. self,
  633. pixel_values: Optional[torch.FloatTensor] = None,
  634. **kwargs: Unpack[TransformersKwargs],
  635. ) -> Union[tuple, Sam2HieraDetModelOutput]:
  636. if pixel_values is None:
  637. raise ValueError("You have to specify pixel_values")
  638. hidden_states = self.patch_embed(pixel_values)
  639. hidden_states = hidden_states + self._get_pos_embed(hidden_states.shape[1:3])
  640. intermediate_hidden_states = ()
  641. for i, block_module in enumerate(self.blocks):
  642. hidden_states = block_module(hidden_states, **kwargs)
  643. if i in self.stage_ends:
  644. intermediate_hidden_states = intermediate_hidden_states + (hidden_states,)
  645. return Sam2HieraDetModelOutput(
  646. last_hidden_state=hidden_states,
  647. intermediate_hidden_states=intermediate_hidden_states,
  648. )
  649. @auto_docstring(
  650. custom_intro="""
  651. The vision model from Sam without any head or projection on top.
  652. """
  653. )
  654. class Sam2VisionModel(Sam2PreTrainedModel):
  655. config_class = Sam2VisionConfig
  656. main_input_name = "pixel_values"
  657. _can_record_outputs = {
  658. "hidden_states": Sam2MultiScaleBlock,
  659. "attentions": Sam2MultiScaleAttention,
  660. }
  661. def __init__(self, config: Sam2VisionConfig):
  662. super().__init__(config)
  663. self.config = config
  664. self.backbone = AutoModel.from_config(config.backbone_config)
  665. self.neck = Sam2VisionNeck(config)
  666. self.num_feature_levels = config.num_feature_levels
  667. self.post_init()
  668. def get_input_embeddings(self):
  669. return self.backbone.get_input_embeddings()
  670. @check_model_inputs()
  671. def forward(
  672. self,
  673. pixel_values: Optional[torch.FloatTensor] = None,
  674. **kwargs: Unpack[TransformersKwargs],
  675. ) -> Union[tuple, Sam2VisionEncoderOutput]:
  676. if pixel_values is None:
  677. raise ValueError("You have to specify pixel_values")
  678. # Forward through backbone
  679. backbone_output = self.backbone(pixel_values, **kwargs)
  680. hidden_states = backbone_output.last_hidden_state
  681. intermediate_hidden_states = backbone_output.intermediate_hidden_states
  682. fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
  683. # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
  684. fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
  685. fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
  686. return Sam2VisionEncoderOutput(
  687. last_hidden_state=hidden_states,
  688. fpn_hidden_states=fpn_hidden_states,
  689. fpn_position_encoding=fpn_position_encoding,
  690. )
  691. class Sam2PositionalEmbedding(nn.Module):
  692. def __init__(self, config: Sam2PromptEncoderConfig):
  693. super().__init__()
  694. self.scale = config.scale
  695. positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
  696. self.register_buffer("positional_embedding", positional_embedding)
  697. def forward(self, input_coords, input_shape=None):
  698. """Positionally encode points that are normalized to [0,1]."""
  699. coordinates = input_coords.clone()
  700. if input_shape is not None:
  701. coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
  702. coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
  703. coordinates.to(torch.float32)
  704. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  705. coordinates = 2 * coordinates - 1
  706. coordinates = coordinates.to(self.positional_embedding.dtype)
  707. coordinates = coordinates @ self.positional_embedding
  708. coordinates = 2 * np.pi * coordinates
  709. # outputs d_1 x ... x d_n x channel shape
  710. return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
  711. class Sam2MaskEmbedding(SamMaskEmbedding):
  712. pass
  713. class Sam2PromptEncoder(SamPromptEncoder):
  714. def __init__(self, config: Sam2PromptEncoderConfig):
  715. nn.Module.__init__(self)
  716. self.shared_embedding = Sam2PositionalEmbedding(config)
  717. self.mask_embed = Sam2MaskEmbedding(config)
  718. self.no_mask_embed = nn.Embedding(1, config.hidden_size)
  719. self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
  720. self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
  721. self.input_image_size = config.image_size
  722. self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
  723. self.hidden_size = config.hidden_size
  724. self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
  725. def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
  726. """Embeds point prompts."""
  727. points = points + 0.5 # Shift to center of pixel
  728. if pad:
  729. points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
  730. labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
  731. input_shape = (self.input_image_size, self.input_image_size)
  732. point_embedding = self.shared_embedding(points, input_shape)
  733. # torch.where and expanding the labels tensor is required by the ONNX export
  734. point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
  735. # This is required for the ONNX export. The dtype, device need to be explicitly
  736. # specified as otherwise torch.onnx.export interprets as double
  737. point_embedding = torch.where(
  738. labels[..., None] != -10,
  739. point_embedding,
  740. torch.zeros_like(point_embedding),
  741. )
  742. # Add point embeddings for labels >= 0
  743. point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
  744. return point_embedding
  745. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  746. """Embeds box prompts."""
  747. boxes += 0.5 # Shift to center of pixel
  748. coords = boxes.view(*boxes.shape[:2], 2, 2)
  749. # add padding point for consistency with the original implementation
  750. coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
  751. corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
  752. corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
  753. corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
  754. corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
  755. return corner_embedding
  756. class Sam2Attention(nn.Module):
  757. """
  758. SAM2's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  759. values.
  760. """
  761. def __init__(self, config, downsample_rate=None):
  762. super().__init__()
  763. downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
  764. self.config = config
  765. self.hidden_size = config.hidden_size
  766. self.internal_dim = config.hidden_size // downsample_rate
  767. self.num_attention_heads = config.num_attention_heads
  768. self.head_dim = self.internal_dim // config.num_attention_heads
  769. self.scaling = self.head_dim**-0.5
  770. self.is_causal = False
  771. self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
  772. self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
  773. self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
  774. self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
  775. def forward(
  776. self,
  777. query: torch.Tensor,
  778. key: torch.Tensor,
  779. value: torch.Tensor,
  780. attention_similarity: Optional[torch.Tensor] = None,
  781. **kwargs: Unpack[TransformersKwargs],
  782. ) -> tuple[torch.Tensor, torch.Tensor]:
  783. # Input projections
  784. batch_size, point_batch_size = query.shape[:2]
  785. new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
  786. query = self.q_proj(query).view(*new_shape).transpose(1, 2)
  787. key = self.k_proj(key).view(*new_shape).transpose(1, 2)
  788. value = self.v_proj(value).view(*new_shape).transpose(1, 2)
  789. attention_interface: Callable = eager_attention_forward
  790. if self.config._attn_implementation != "eager":
  791. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  792. attn_output, attn_weights = attention_interface(
  793. self,
  794. query,
  795. key,
  796. value,
  797. attention_mask=attention_similarity,
  798. dropout=0.0,
  799. scaling=self.scaling,
  800. is_causal=self.is_causal,
  801. **kwargs,
  802. )
  803. attn_output = attn_output.reshape(
  804. batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
  805. ).contiguous()
  806. attn_output = self.o_proj(attn_output)
  807. return attn_output, attn_weights
  808. class Sam2TwoWayAttentionBlock(SamTwoWayAttentionBlock, GradientCheckpointingLayer):
  809. def __init__(self, config: Sam2MaskDecoderConfig, skip_first_layer_pe: bool = False):
  810. nn.Module.__init__(self)
  811. self.self_attn = Sam2Attention(config, downsample_rate=1)
  812. self.layer_norm1 = nn.LayerNorm(config.hidden_size)
  813. self.cross_attn_token_to_image = Sam2Attention(config)
  814. self.layer_norm2 = nn.LayerNorm(config.hidden_size)
  815. self.mlp = Sam2FeedForward(
  816. config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
  817. )
  818. self.layer_norm3 = nn.LayerNorm(config.hidden_size)
  819. self.layer_norm4 = nn.LayerNorm(config.hidden_size)
  820. self.cross_attn_image_to_token = Sam2Attention(config)
  821. self.skip_first_layer_pe = skip_first_layer_pe
  822. class Sam2TwoWayTransformer(SamTwoWayTransformer):
  823. pass
  824. class Sam2LayerNorm(SamLayerNorm):
  825. pass
  826. class Sam2MaskDecoder(SamMaskDecoder):
  827. def __init__(self, config: Sam2MaskDecoderConfig):
  828. super().__init__(config)
  829. del self.iou_prediction_head
  830. self.iou_prediction_head = Sam2FeedForward(
  831. self.hidden_size,
  832. config.iou_head_hidden_dim,
  833. self.num_mask_tokens,
  834. config.iou_head_depth,
  835. sigmoid_output=True,
  836. )
  837. self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
  838. self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)
  839. self.obj_score_token = nn.Embedding(1, self.hidden_size)
  840. self.pred_obj_score_head = Sam2FeedForward(self.hidden_size, self.hidden_size, 1, 3)
  841. self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
  842. self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
  843. self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh
  844. def _get_stability_scores(self, mask_logits):
  845. """
  846. Compute stability scores of the mask logits based on the IoU between upper and
  847. lower thresholds.
  848. """
  849. mask_logits = mask_logits.flatten(-2)
  850. stability_delta = self.dynamic_multimask_stability_delta
  851. area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
  852. area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
  853. stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
  854. return stability_scores
  855. def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
  856. """
  857. When outputting a single mask, if the stability score from the current single-mask
  858. output (based on output token 0) falls below a threshold, we instead select from
  859. multi-mask outputs (based on output token 1~3) the mask with the highest predicted
  860. IoU score. This is intended to ensure a valid mask for both clicking and tracking.
  861. """
  862. # The best mask from multimask output tokens (1~3)
  863. multimask_logits = all_mask_logits[:, :, 1:, :, :]
  864. multimask_iou_scores = all_iou_scores[:, :, 1:]
  865. best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) # [B, P]
  866. best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
  867. best_scores_inds_expanded = best_scores_inds_expanded.expand(
  868. -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
  869. )
  870. best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded) # [B, P, 1, H, W]
  871. best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1)) # [B, P, 1]
  872. # The mask from singlemask output token 0 and its stability score
  873. singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
  874. singlemask_iou_scores = all_iou_scores[:, :, 0:1]
  875. stability_scores = self._get_stability_scores(singlemask_logits)
  876. is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
  877. # Dynamically fall back to best multimask output upon low stability scores.
  878. mask_logits_out = torch.where(
  879. is_stable[..., None, None].expand_as(singlemask_logits),
  880. singlemask_logits,
  881. best_multimask_logits,
  882. )
  883. iou_scores_out = torch.where(
  884. is_stable.expand_as(singlemask_iou_scores),
  885. singlemask_iou_scores,
  886. best_multimask_iou_scores,
  887. )
  888. return mask_logits_out, iou_scores_out
  889. def forward(
  890. self,
  891. image_embeddings: torch.Tensor,
  892. image_positional_embeddings: torch.Tensor,
  893. sparse_prompt_embeddings: torch.Tensor,
  894. dense_prompt_embeddings: torch.Tensor,
  895. multimask_output: bool,
  896. high_resolution_features: list[torch.Tensor],
  897. attention_similarity: Optional[torch.Tensor] = None,
  898. target_embedding: Optional[torch.Tensor] = None,
  899. **kwargs: Unpack[TransformersKwargs],
  900. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  901. """
  902. Predict masks given image and prompt embeddings.
  903. Args:
  904. image_embeddings (`torch.Tensor`):
  905. The embeddings from the image encoder.
  906. image_positional_embeddings (`torch.Tensor`):
  907. Positional encoding with the shape of image_embeddings.
  908. sparse_prompt_embeddings (`torch.Tensor`):
  909. The embeddings of the points and boxes.
  910. dense_prompt_embeddings (`torch.Tensor`):
  911. The embeddings of the mask inputs.
  912. multimask_output (`bool`):
  913. Whether to return multiple masks or a single mask.
  914. high_resolution_features (`list[torch.Tensor]`, *optional*):
  915. The high-resolution features from the vision encoder.
  916. attention_similarity (`torch.Tensor`, *optional*):
  917. The attention similarity tensor.
  918. target_embedding (`torch.Tensor`, *optional*):
  919. The target embedding.
  920. """
  921. batch_size, num_channels, height, width = image_embeddings.shape
  922. point_batch_size = sparse_prompt_embeddings.shape[1]
  923. # Concatenate output tokens
  924. output_tokens = torch.cat(
  925. [
  926. self.obj_score_token.weight,
  927. self.iou_token.weight,
  928. self.mask_tokens.weight,
  929. ],
  930. dim=0,
  931. )
  932. output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
  933. if sparse_prompt_embeddings.shape[0] != 0:
  934. tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
  935. else:
  936. tokens = output_tokens
  937. point_embeddings = tokens.to(self.iou_token.weight.dtype)
  938. # Expand per-image data in batch direction to be per-mask
  939. image_embeddings = image_embeddings + dense_prompt_embeddings
  940. image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
  941. image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
  942. # Run the transformer
  943. point_embeddings, image_embeddings = self.transformer(
  944. point_embeddings=point_embeddings,
  945. image_embeddings=image_embeddings,
  946. image_positional_embeddings=image_positional_embeddings,
  947. attention_similarity=attention_similarity,
  948. target_embedding=target_embedding,
  949. **kwargs,
  950. )
  951. iou_token_out = point_embeddings[:, :, 1, :]
  952. mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]
  953. # Upscale mask embeddings and predict masks using the mask tokens
  954. image_embeddings = image_embeddings.transpose(2, 3).view(
  955. batch_size * point_batch_size, num_channels, height, width
  956. )
  957. feat_s0, feat_s1 = high_resolution_features
  958. feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
  959. feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
  960. upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
  961. upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
  962. upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)
  963. hyper_in_list: list[torch.Tensor] = []
  964. for i in range(self.num_mask_tokens):
  965. current_mlp = self.output_hypernetworks_mlps[i]
  966. hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
  967. hyper_in = torch.stack(hyper_in_list, dim=2)
  968. _, num_channels, height, width = upscaled_embedding.shape
  969. upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
  970. masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)
  971. # Generate mask quality predictions
  972. iou_pred = self.iou_prediction_head(iou_token_out)
  973. object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])
  974. # Select the correct mask or masks for output
  975. if multimask_output:
  976. mask_slice = slice(1, None)
  977. masks = masks[:, :, mask_slice, :, :]
  978. iou_pred = iou_pred[:, :, mask_slice]
  979. elif self.dynamic_multimask_via_stability and not self.training:
  980. mask_slice = slice(0, 1)
  981. masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
  982. else:
  983. mask_slice = slice(0, 1)
  984. masks = masks[:, :, mask_slice, :, :]
  985. iou_pred = iou_pred[:, :, mask_slice]
  986. sam_tokens_out = mask_tokens_out[:, :, mask_slice] # [b, 3, c] shape
  987. return masks, iou_pred, sam_tokens_out, object_score_logits
  988. @auto_docstring(
  989. custom_intro="""
  990. Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
  991. input points and labels, boxes, or masks.
  992. """
  993. )
  994. class Sam2Model(SamModel):
  995. _keys_to_ignore_on_load_unexpected = [
  996. r"^memory_.*",
  997. r"^mask_downsample.*",
  998. r"^object_pointer_proj.*",
  999. r"^temporal_positional_encoding_projection_layer.*",
  1000. "no_memory_positional_encoding",
  1001. "no_object_pointer",
  1002. "occlusion_spatial_embedding_parameter",
  1003. ]
  1004. def __init__(self, config: Sam2Config):
  1005. PreTrainedModel.__init__(self, config)
  1006. self.shared_image_embedding = Sam2PositionalEmbedding(config.prompt_encoder_config)
  1007. self.vision_encoder = AutoModel.from_config(config.vision_config)
  1008. self.prompt_encoder = Sam2PromptEncoder(config.prompt_encoder_config)
  1009. # The module using it is not a PreTrainedModel subclass so we need this
  1010. config.mask_decoder_config._attn_implementation = config._attn_implementation
  1011. self.mask_decoder = Sam2MaskDecoder(config.mask_decoder_config)
  1012. self.num_feature_levels = config.vision_config.num_feature_levels
  1013. self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
  1014. # a single token to indicate no memory embedding from previous frames
  1015. self.hidden_dim = config.vision_config.fpn_hidden_size
  1016. self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
  1017. self.post_init()
  1018. def get_image_wide_positional_embeddings(self) -> torch.Tensor:
  1019. size = self.prompt_encoder.image_embedding_size
  1020. target_device = self.shared_image_embedding.positional_embedding.device
  1021. target_dtype = self.shared_image_embedding.positional_embedding.dtype
  1022. grid = torch.ones(size, device=target_device, dtype=target_dtype)
  1023. y_embed = grid.cumsum(dim=0) - 0.5
  1024. x_embed = grid.cumsum(dim=1) - 0.5
  1025. y_embed = y_embed / size[0]
  1026. x_embed = x_embed / size[1]
  1027. positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
  1028. return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
  1029. @torch.no_grad()
  1030. def get_image_embeddings(
  1031. self,
  1032. pixel_values: torch.FloatTensor,
  1033. **kwargs: Unpack[TransformersKwargs],
  1034. ) -> list[torch.Tensor]:
  1035. r"""
  1036. Returns the image embeddings by passing the pixel values through the vision encoder.
  1037. Args:
  1038. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  1039. Input pixel values
  1040. """
  1041. batch_size = pixel_values.shape[0]
  1042. feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs)
  1043. # add no memory embedding to the last feature map
  1044. feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
  1045. # reshape feature maps to the same shape as the backbone feature sizes
  1046. image_embeddings = [
  1047. feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
  1048. for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
  1049. ]
  1050. return image_embeddings
  1051. def get_image_features(
  1052. self,
  1053. pixel_values: torch.FloatTensor,
  1054. **kwargs: Unpack[TransformersKwargs],
  1055. ) -> tuple[
  1056. list[torch.Tensor],
  1057. list[torch.Tensor],
  1058. Optional[tuple[torch.FloatTensor, ...]],
  1059. Optional[tuple[torch.FloatTensor, ...]],
  1060. ]:
  1061. r"""
  1062. Extract and preprocess image features using the vision encoder.
  1063. Args:
  1064. pixel_values (`torch.FloatTensor`):
  1065. Input pixel values of shape `(batch_size, num_channels, height, width)`.
  1066. Returns:
  1067. `tuple`: A tuple containing:
  1068. - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels.
  1069. - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each feature level.
  1070. - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder.
  1071. - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder.
  1072. """
  1073. vision_outputs: Sam2VisionEncoderOutput = self.vision_encoder(
  1074. pixel_values,
  1075. **kwargs,
  1076. )
  1077. feature_maps = vision_outputs.fpn_hidden_states
  1078. feature_maps_position_embeddings = vision_outputs.fpn_position_encoding
  1079. # precompute projected level 0 and level 1 features in SAM decoder
  1080. # to avoid running it again on every SAM click
  1081. feature_maps = list(feature_maps)
  1082. feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
  1083. feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])
  1084. # flatten NxCxHxW to HWxNxC
  1085. feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
  1086. feature_maps_position_embeddings = [
  1087. feature_map_position_embedding.flatten(2).permute(2, 0, 1)
  1088. for feature_map_position_embedding in feature_maps_position_embeddings
  1089. ]
  1090. return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions
  1091. @check_model_inputs()
  1092. @auto_docstring
  1093. def forward(
  1094. self,
  1095. pixel_values: Optional[torch.FloatTensor] = None,
  1096. input_points: Optional[torch.FloatTensor] = None,
  1097. input_labels: Optional[torch.LongTensor] = None,
  1098. input_boxes: Optional[torch.FloatTensor] = None,
  1099. input_masks: Optional[torch.LongTensor] = None,
  1100. image_embeddings: Optional[torch.FloatTensor] = None,
  1101. multimask_output: bool = True,
  1102. attention_similarity: Optional[torch.FloatTensor] = None,
  1103. target_embedding: Optional[torch.FloatTensor] = None,
  1104. **kwargs: Unpack[TransformersKwargs],
  1105. ) -> Sam2ImageSegmentationOutput:
  1106. r"""
  1107. input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
  1108. Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
  1109. better results. The points can be obtained by passing a list of list of list to the processor that will
  1110. create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
  1111. second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
  1112. per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
  1113. multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
  1114. coordinates of the point. If a different number of points is passed either for each image, or for each
  1115. mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
  1116. computation of the embedding will be skipped for these points using the labels.
  1117. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
  1118. Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
  1119. official implementation, there are 3 types of labels
  1120. - `1`: the point is a point that contains the object of interest
  1121. - `0`: the point is a point that does not contain the object of interest
  1122. - `-1`: the point corresponds to the background
  1123. We added the label:
  1124. - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
  1125. The padding labels should be automatically done by the processor.
  1126. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
  1127. Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
  1128. much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
  1129. that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
  1130. size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
  1131. In the order (`x1`, `y1`, `x2`, `y2`):
  1132. - `x1`: the x coordinate of the top left point of the input box
  1133. - `y1`: the y coordinate of the top left point of the input box
  1134. - `x2`: the x coordinate of the bottom right point of the input box
  1135. - `y2`: the y coordinate of the bottom right point of the input box
  1136. input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
  1137. SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
  1138. generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
  1139. manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
  1140. image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
  1141. Image embeddings, this is used by the mask decoder to generate masks and iou scores. For more memory
  1142. efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
  1143. method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
  1144. multimask_output (`bool`, *optional*):
  1145. In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
  1146. bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
  1147. "best" mask, by specifying `multimask_output=False`.
  1148. attention_similarity (`torch.FloatTensor`, *optional*):
  1149. Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
  1150. model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  1151. target_embedding (`torch.FloatTensor`, *optional*):
  1152. Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
  1153. the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
  1154. Example:
  1155. ```python
  1156. >>> from PIL import Image
  1157. >>> import requests
  1158. >>> from transformers import AutoModel, AutoProcessor
  1159. >>> model = AutoModel.from_pretrained("danelcsb/sam2.1_hiera_tiny")
  1160. >>> processor = AutoProcessor.from_pretrained("danelcsb/sam2.1_hiera_tiny")
  1161. >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
  1162. >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
  1163. >>> input_points = [[[400, 650]]] # 2D location of a window on the car
  1164. >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
  1165. >>> # Get segmentation mask
  1166. >>> outputs = model(**inputs)
  1167. >>> # Postprocess masks
  1168. >>> masks = processor.post_process_masks(
  1169. ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
  1170. ... )
  1171. ```
  1172. """
  1173. if not ((pixel_values is None) ^ (image_embeddings is None)):
  1174. raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
  1175. if input_points is not None and input_boxes is not None:
  1176. if input_points.shape[1] != input_boxes.shape[1]:
  1177. raise ValueError(
  1178. f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
  1179. )
  1180. image_positional_embeddings = self.get_image_wide_positional_embeddings()
  1181. # repeat with batch size
  1182. batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
  1183. image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
  1184. vision_attentions = None
  1185. vision_hidden_states = None
  1186. if pixel_values is not None:
  1187. feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
  1188. pixel_values,
  1189. **kwargs,
  1190. )
  1191. # add no memory embedding to the last feature map
  1192. feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
  1193. # reshape feature maps to the same shape as the backbone feature sizes
  1194. image_embeddings = [
  1195. feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
  1196. for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
  1197. ]
  1198. if input_points is not None and input_labels is None:
  1199. input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
  1200. if input_points is None and input_boxes is None:
  1201. # If no points are provide, pad with an empty point (with label -1)
  1202. input_points = torch.zeros(
  1203. batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
  1204. )
  1205. input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)
  1206. if input_masks is not None:
  1207. # If mask_inputs is provided, downsize it into low-res mask input if needed
  1208. # and feed it as a dense mask prompt into the SAM mask encoder
  1209. if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
  1210. input_masks = F.interpolate(
  1211. input_masks.float(),
  1212. size=self.prompt_encoder.mask_input_size,
  1213. align_corners=False,
  1214. mode="bilinear",
  1215. antialias=True, # use antialias for downsampling
  1216. ).to(input_masks.dtype)
  1217. sparse_embeddings, dense_embeddings = self.prompt_encoder(
  1218. input_points=input_points,
  1219. input_labels=input_labels,
  1220. input_boxes=input_boxes,
  1221. input_masks=input_masks,
  1222. )
  1223. low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
  1224. image_embeddings=image_embeddings[-1],
  1225. image_positional_embeddings=image_positional_embeddings,
  1226. sparse_prompt_embeddings=sparse_embeddings,
  1227. dense_prompt_embeddings=dense_embeddings,
  1228. multimask_output=multimask_output,
  1229. high_resolution_features=image_embeddings[:-1],
  1230. attention_similarity=attention_similarity,
  1231. target_embedding=target_embedding,
  1232. **kwargs,
  1233. )
  1234. return Sam2ImageSegmentationOutput(
  1235. iou_scores=iou_scores,
  1236. pred_masks=low_res_multimasks,
  1237. object_score_logits=object_score_logits,
  1238. image_embeddings=image_embeddings,
  1239. vision_hidden_states=vision_hidden_states,
  1240. vision_attentions=vision_attentions,
  1241. )
  # Public API of this module: the names bound by `from <module> import *`.
  # NOTE(review): as this is a transformers `modular_*.py` file, this list is presumably also read by the
  # modular converter tooling — keep it a plain list of string literals.
  __all__ = ["Sam2Model", "Sam2VisionModel", "Sam2PreTrainedModel", "Sam2ImageProcessorFast", "Sam2HieraDetModel"]