modeling_edgetam.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/edgetam/modular_edgetam.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_edgetam.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from dataclasses import dataclass
  23. from typing import Callable, Optional, Union
  24. import numpy as np
  25. import torch
  26. import torch.nn as nn
  27. import torch.nn.functional as F
  28. from torch import Tensor
  29. from transformers.utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
  30. from ...activations import ACT2FN
  31. from ...modeling_outputs import BaseModelOutput
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...pytorch_utils import compile_compatible_method_lru_cache
  35. from ...utils import ModelOutput, auto_docstring
  36. from ..auto import AutoModel
  37. from .configuration_edgetam import (
  38. EdgeTamConfig,
  39. EdgeTamMaskDecoderConfig,
  40. EdgeTamPromptEncoderConfig,
  41. EdgeTamVisionConfig,
  42. )
  43. # fix this in modular
  44. if True:
  45. from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel
  46. class EdgeTamLayerNorm(nn.LayerNorm):
  47. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  48. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
  49. width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
  50. """
  51. def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
  52. super().__init__(normalized_shape, eps=eps, **kwargs)
  53. if data_format not in ["channels_last", "channels_first"]:
  54. raise NotImplementedError(f"Unsupported data format: {data_format}")
  55. self.data_format = data_format
  56. def forward(self, features: torch.Tensor) -> torch.Tensor:
  57. """
  58. Args:
  59. features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
  60. """
  61. if self.data_format == "channels_first":
  62. features = features.permute(0, 2, 3, 1)
  63. features = super().forward(features)
  64. features = features.permute(0, 3, 1, 2)
  65. else:
  66. features = super().forward(features)
  67. return features
  68. @dataclass
  69. @auto_docstring(custom_intro="Base class for the vision encoder's outputs.")
  70. class EdgeTamVisionEncoderOutput(ModelOutput):
  71. r"""
  72. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, height, width, hidden_size)`):
  73. Sequence of hidden-states at the output of the last layer of the model.
  74. fpn_hidden_states (`tuple(torch.FloatTensor)`):
  75. Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
  76. `(batch_size, hidden_size, height, width)`. Feature maps from the Feature Pyramid Network neck.
  77. fpn_position_encoding (`tuple(torch.FloatTensor)`):
  78. Tuple of `torch.FloatTensor` (one for each feature level, from high to low resolution) of shape
  79. `(batch_size, hidden_size, height, width)`. Positional encodings corresponding to the `fpn_hidden_states`.
  80. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  81. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  82. one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`. Hidden-states of the
  83. model at the output of each stage.
  84. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  85. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  86. sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
  87. the self-attention heads.
  88. """
  89. last_hidden_state: Optional[torch.FloatTensor] = None
  90. fpn_hidden_states: Optional[torch.FloatTensor] = None
  91. fpn_position_encoding: Optional[torch.FloatTensor] = None
  92. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  93. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  94. def eager_attention_forward(
  95. module: nn.Module,
  96. query: torch.Tensor,
  97. key: torch.Tensor,
  98. value: torch.Tensor,
  99. attention_mask: Optional[torch.Tensor],
  100. scaling: float,
  101. dropout: float = 0.0,
  102. **kwargs,
  103. ):
  104. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  105. if attention_mask is not None:
  106. attn_weights = attn_weights + attention_mask
  107. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  108. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  109. attn_output = torch.matmul(attn_weights, value)
  110. attn_output = attn_output.transpose(1, 2).contiguous()
  111. return attn_output, attn_weights
  112. class EdgeTamAttention(nn.Module):
  113. """
  114. EDGETAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  115. values.
  116. """
  117. def __init__(self, config, downsample_rate=None):
  118. super().__init__()
  119. downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
  120. self.config = config
  121. self.hidden_size = config.hidden_size
  122. self.internal_dim = config.hidden_size // downsample_rate
  123. self.num_attention_heads = config.num_attention_heads
  124. self.head_dim = self.internal_dim // config.num_attention_heads
  125. self.scaling = self.head_dim**-0.5
  126. self.is_causal = False
  127. self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
  128. self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
  129. self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
  130. self.o_proj = nn.Linear(self.internal_dim, self.hidden_size)
  131. def forward(
  132. self,
  133. query: torch.Tensor,
  134. key: torch.Tensor,
  135. value: torch.Tensor,
  136. attention_similarity: Optional[torch.Tensor] = None,
  137. **kwargs: Unpack[TransformersKwargs],
  138. ) -> tuple[torch.Tensor, torch.Tensor]:
  139. # Input projections
  140. batch_size, point_batch_size = query.shape[:2]
  141. new_shape = (batch_size * point_batch_size, -1, self.num_attention_heads, self.head_dim)
  142. query = self.q_proj(query).view(*new_shape).transpose(1, 2)
  143. key = self.k_proj(key).view(*new_shape).transpose(1, 2)
  144. value = self.v_proj(value).view(*new_shape).transpose(1, 2)
  145. attention_interface: Callable = eager_attention_forward
  146. if self.config._attn_implementation != "eager":
  147. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  148. attn_output, attn_weights = attention_interface(
  149. self,
  150. query,
  151. key,
  152. value,
  153. attention_mask=attention_similarity,
  154. dropout=0.0,
  155. scaling=self.scaling,
  156. is_causal=self.is_causal,
  157. **kwargs,
  158. )
  159. attn_output = attn_output.reshape(
  160. batch_size, point_batch_size, -1, self.num_attention_heads * self.head_dim
  161. ).contiguous()
  162. attn_output = self.o_proj(attn_output)
  163. return attn_output, attn_weights
  164. class EdgeTamTwoWayAttentionBlock(nn.Module):
  165. def __init__(self, config: EdgeTamMaskDecoderConfig, skip_first_layer_pe: bool = False):
  166. """
  167. A transformer block with four layers:
  168. (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
  169. sparse inputs (4) cross attention of dense inputs -> sparse inputs
  170. Arguments:
  171. config (`EdgeTamMaskDecoderConfig`):
  172. The configuration file used to instantiate the block
  173. attention_downsample_rate (*optionalk*, int, defaults to 2):
  174. The downsample ratio of the block used to reduce the inner dim of the attention.
  175. skip_first_layer_pe (*optional*, bool, defaults to `False`):
  176. Whether or not to skip the addition of the query_point_embedding on the first layer.
  177. """
  178. super().__init__()
  179. self.self_attn = EdgeTamAttention(config, downsample_rate=1)
  180. self.layer_norm1 = nn.LayerNorm(config.hidden_size)
  181. self.cross_attn_token_to_image = EdgeTamAttention(config)
  182. self.layer_norm2 = nn.LayerNorm(config.hidden_size)
  183. self.mlp = EdgeTamFeedForward(
  184. config.hidden_size, config.mlp_dim, config.hidden_size, num_layers=config.num_hidden_layers
  185. )
  186. self.layer_norm3 = nn.LayerNorm(config.hidden_size)
  187. self.layer_norm4 = nn.LayerNorm(config.hidden_size)
  188. self.cross_attn_image_to_token = EdgeTamAttention(config)
  189. self.skip_first_layer_pe = skip_first_layer_pe
  190. def forward(
  191. self,
  192. queries: Tensor,
  193. keys: Tensor,
  194. query_point_embedding: Tensor,
  195. key_point_embedding: Tensor,
  196. attention_similarity: Tensor,
  197. **kwargs: Unpack[TransformersKwargs],
  198. ):
  199. # Self attention block
  200. if self.skip_first_layer_pe:
  201. queries, _ = self.self_attn(query=queries, key=queries, value=queries)
  202. else:
  203. query = queries + query_point_embedding
  204. attn_out, _ = self.self_attn(query=query, key=query, value=queries)
  205. queries = queries + attn_out
  206. queries = self.layer_norm1(queries)
  207. # Cross attention block, tokens attending to image embedding
  208. query = queries + query_point_embedding
  209. key = keys + key_point_embedding
  210. attn_out, _ = self.cross_attn_token_to_image(
  211. query=query, key=key, value=keys, attention_similarity=attention_similarity
  212. )
  213. queries = queries + attn_out
  214. queries = self.layer_norm2(queries)
  215. # MLP block
  216. mlp_out = self.mlp(queries)
  217. queries = queries + mlp_out
  218. queries = self.layer_norm3(queries)
  219. # Cross attention block, image embedding attending to tokens
  220. query = queries + query_point_embedding
  221. key = keys + key_point_embedding
  222. attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
  223. keys = keys + attn_out
  224. keys = self.layer_norm4(keys)
  225. return queries, keys, attn_out
  226. class EdgeTamFeedForward(nn.Module):
  227. def __init__(
  228. self,
  229. input_dim: int,
  230. hidden_dim: int,
  231. output_dim: int,
  232. num_layers: int,
  233. activation: str = "relu",
  234. sigmoid_output: bool = False,
  235. ):
  236. super().__init__()
  237. self.num_layers = num_layers
  238. self.activation = ACT2FN[activation]
  239. self.proj_in = nn.Linear(input_dim, hidden_dim)
  240. self.proj_out = nn.Linear(hidden_dim, output_dim)
  241. self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
  242. self.sigmoid_output = sigmoid_output
  243. def forward(self, hidden_states):
  244. hidden_states = self.proj_in(hidden_states)
  245. hidden_states = self.activation(hidden_states)
  246. for layer in self.layers:
  247. hidden_states = self.activation(layer(hidden_states))
  248. hidden_states = self.proj_out(hidden_states)
  249. if self.sigmoid_output:
  250. hidden_states = F.sigmoid(hidden_states)
  251. return hidden_states
  252. @auto_docstring
  253. class EdgeTamPreTrainedModel(PreTrainedModel):
  254. config_class = EdgeTamConfig
  255. base_model_prefix = "edgetam"
  256. main_input_name = "pixel_values"
  257. _supports_sdpa = True
  258. _supports_flash_attn_2 = True
  259. _supports_attention_backend = True
  260. def _init_weights(self, module):
  261. std = self.config.initializer_range
  262. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  263. module.weight.data.normal_(mean=0.0, std=std)
  264. if module.bias is not None:
  265. module.bias.data.zero_()
  266. elif isinstance(module, nn.Embedding):
  267. module.weight.data.normal_(mean=0.0, std=std)
  268. if module.padding_idx is not None:
  269. module.weight.data[module.padding_idx].zero_()
  270. elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)):
  271. module.weight.data.fill_(1.0)
  272. module.bias.data.zero_()
  273. if isinstance(module, EdgeTamModel):
  274. if module.no_memory_embedding is not None:
  275. module.no_memory_embedding.data.zero_()
  276. # copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
  277. class EdgeTamSinePositionEmbedding(nn.Module):
  278. """
  279. This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
  280. need paper, generalized to work on images.
  281. """
  282. def __init__(
  283. self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
  284. ):
  285. super().__init__()
  286. if scale is not None and normalize is False:
  287. raise ValueError("normalize should be True if scale is passed")
  288. self.num_pos_feats = num_pos_feats
  289. self.temperature = temperature
  290. self.normalize = normalize
  291. self.scale = 2 * math.pi if scale is None else scale
  292. @compile_compatible_method_lru_cache(maxsize=1)
  293. def forward(
  294. self,
  295. shape: torch.Size,
  296. device: Union[torch.device, str],
  297. dtype: torch.dtype,
  298. mask: Optional[Tensor] = None,
  299. ) -> Tensor:
  300. if mask is None:
  301. mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool)
  302. not_mask = (~mask).to(dtype)
  303. y_embed = not_mask.cumsum(1)
  304. x_embed = not_mask.cumsum(2)
  305. if self.normalize:
  306. eps = 1e-6
  307. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
  308. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  309. dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype)
  310. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
  311. pos_x = x_embed[:, :, :, None] / dim_t
  312. pos_y = y_embed[:, :, :, None] / dim_t
  313. pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  314. pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  315. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  316. return pos
  317. class EdgeTamVisionNeck(nn.Module):
  318. def __init__(self, config: EdgeTamVisionConfig):
  319. super().__init__()
  320. self.config = config
  321. self.position_encoding = EdgeTamSinePositionEmbedding(
  322. num_pos_feats=config.fpn_hidden_size // 2, normalize=True
  323. )
  324. self.convs = nn.ModuleList()
  325. for in_channels in config.backbone_channel_list:
  326. self.convs.append(
  327. nn.Conv2d(
  328. in_channels=in_channels,
  329. out_channels=config.fpn_hidden_size,
  330. kernel_size=config.fpn_kernel_size,
  331. stride=config.fpn_stride,
  332. padding=config.fpn_padding,
  333. ),
  334. )
  335. self.fpn_top_down_levels = config.fpn_top_down_levels
  336. def forward(self, hidden_states: torch.Tensor) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
  337. fpn_hidden_states = ()
  338. fpn_position_encoding = ()
  339. # forward in top-down order (from low to high resolution)
  340. n = len(self.convs) - 1
  341. for i in range(n, -1, -1):
  342. lateral_features = hidden_states[i].permute(0, 3, 1, 2)
  343. lateral_features = self.convs[n - i](lateral_features)
  344. if i not in self.fpn_top_down_levels or i == n:
  345. prev_features = lateral_features
  346. else:
  347. top_down_features = F.interpolate(
  348. prev_features.to(dtype=torch.float32),
  349. scale_factor=2.0,
  350. mode="nearest",
  351. align_corners=None,
  352. antialias=False,
  353. ).to(lateral_features.dtype)
  354. prev_features = lateral_features + top_down_features
  355. prev_position_encoding = self.position_encoding(
  356. prev_features.shape, prev_features.device, prev_features.dtype
  357. ).to(prev_features.dtype)
  358. fpn_hidden_states += (prev_features,)
  359. fpn_position_encoding += (prev_position_encoding,)
  360. return fpn_hidden_states, fpn_position_encoding
  361. @auto_docstring(
  362. custom_intro="""
  363. The vision model from EdgeTAM without any head or projection on top.
  364. """
  365. )
  366. class EdgeTamVisionModel(EdgeTamPreTrainedModel):
  367. config_class = EdgeTamVisionConfig
  368. main_input_name = "pixel_values"
  369. _can_record_outputs = {"hidden_states": TimmWrapperModel, "attentions": TimmWrapperModel}
  370. def __init__(self, config: EdgeTamVisionConfig):
  371. super().__init__(config)
  372. self.config = config
  373. self.backbone = AutoModel.from_config(config.backbone_config)
  374. self.neck = EdgeTamVisionNeck(config)
  375. self.num_feature_levels = config.num_feature_levels
  376. self.post_init()
  377. @check_model_inputs()
  378. def forward(
  379. self,
  380. pixel_values: Optional[torch.FloatTensor] = None,
  381. **kwargs: Unpack[TransformersKwargs],
  382. ) -> Union[tuple, EdgeTamVisionEncoderOutput]:
  383. if pixel_values is None:
  384. raise ValueError("You have to specify pixel_values")
  385. # Forward through backbone
  386. backbone_output = self.backbone(pixel_values)
  387. intermediate_hidden_states = backbone_output.last_hidden_state
  388. intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states]
  389. fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
  390. # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
  391. fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
  392. fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
  393. return EdgeTamVisionEncoderOutput(
  394. last_hidden_state=intermediate_hidden_states[-1],
  395. fpn_hidden_states=fpn_hidden_states,
  396. fpn_position_encoding=fpn_position_encoding,
  397. )
  398. @dataclass
  399. @auto_docstring(custom_intro="Base class for the EdgeTam model's output.")
  400. class EdgeTamImageSegmentationOutput(ModelOutput):
  401. r"""
  402. iou_scores (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks)`):
  403. The Intersection over Union (IoU) scores of the predicted masks.
  404. pred_masks (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_masks, height, width)`):
  405. The predicted low-resolution masks. This is an alias for `low_res_masks`. These masks need to be post-processed
  406. by the processor to be brought to the original image size.
  407. object_score_logits (`torch.FloatTensor` of shape `(batch_size, point_batch_size, 1)`):
  408. Logits for the object score, indicating if an object is present.
  409. image_embeddings (`tuple(torch.FloatTensor)`):
  410. The features from the FPN, which are used by the mask decoder. This is a tuple of `torch.FloatTensor` where each
  411. tensor has shape `(batch_size, channels, height, width)`.
  412. vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
  413. Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, height, width, hidden_size)`.
  414. Hidden-states of the vision model at the output of each stage.
  415. vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
  416. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
  417. Attentions weights of the vision model.
  418. mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
  419. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
  420. Attentions weights of the mask decoder.
  421. """
  422. iou_scores: Optional[torch.FloatTensor] = None
  423. pred_masks: Optional[torch.FloatTensor] = None
  424. object_score_logits: Optional[torch.FloatTensor] = None
  425. image_embeddings: tuple[torch.FloatTensor, ...] = None
  426. vision_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  427. vision_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  428. mask_decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  429. class EdgeTamPositionalEmbedding(nn.Module):
  430. def __init__(self, config: EdgeTamPromptEncoderConfig):
  431. super().__init__()
  432. self.scale = config.scale
  433. positional_embedding = self.scale * torch.randn((2, config.hidden_size // 2))
  434. self.register_buffer("positional_embedding", positional_embedding)
  435. def forward(self, input_coords, input_shape=None):
  436. """Positionally encode points that are normalized to [0,1]."""
  437. coordinates = input_coords.clone()
  438. if input_shape is not None:
  439. coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
  440. coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
  441. coordinates.to(torch.float32)
  442. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  443. coordinates = 2 * coordinates - 1
  444. coordinates = coordinates.to(self.positional_embedding.dtype)
  445. coordinates = coordinates @ self.positional_embedding
  446. coordinates = 2 * np.pi * coordinates
  447. # outputs d_1 x ... x d_n x channel shape
  448. return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
  449. class EdgeTamMaskEmbedding(nn.Module):
  450. def __init__(self, config: EdgeTamPromptEncoderConfig):
  451. super().__init__()
  452. self.mask_input_channels = config.mask_input_channels // 4
  453. self.activation = ACT2FN[config.hidden_act]
  454. self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
  455. self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
  456. self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
  457. self.layer_norm1 = EdgeTamLayerNorm(
  458. self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
  459. )
  460. self.layer_norm2 = EdgeTamLayerNorm(
  461. self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
  462. )
  463. def forward(self, masks):
  464. hidden_states = self.conv1(masks)
  465. hidden_states = self.layer_norm1(hidden_states)
  466. hidden_states = self.activation(hidden_states)
  467. hidden_states = self.conv2(hidden_states)
  468. hidden_states = self.layer_norm2(hidden_states)
  469. hidden_states = self.activation(hidden_states)
  470. dense_embeddings = self.conv3(hidden_states)
  471. return dense_embeddings
  472. class EdgeTamPromptEncoder(nn.Module):
  473. def __init__(self, config: EdgeTamPromptEncoderConfig):
  474. super().__init__()
  475. self.shared_embedding = EdgeTamPositionalEmbedding(config)
  476. self.mask_embed = EdgeTamMaskEmbedding(config)
  477. self.no_mask_embed = nn.Embedding(1, config.hidden_size)
  478. self.image_embedding_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
  479. self.mask_input_size = (4 * config.image_size // config.patch_size, 4 * config.image_size // config.patch_size)
  480. self.input_image_size = config.image_size
  481. self.point_embed = nn.Embedding(config.num_point_embeddings, config.hidden_size)
  482. self.hidden_size = config.hidden_size
  483. self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
  484. def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
  485. """Embeds point prompts."""
  486. points = points + 0.5 # Shift to center of pixel
  487. if pad:
  488. points = torch.nn.functional.pad(points, (0, 0, 0, 1), mode="constant", value=0)
  489. labels = torch.nn.functional.pad(labels, (0, 1), mode="constant", value=-1)
  490. input_shape = (self.input_image_size, self.input_image_size)
  491. point_embedding = self.shared_embedding(points, input_shape)
  492. # torch.where and expanding the labels tensor is required by the ONNX export
  493. point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
  494. # This is required for the ONNX export. The dtype, device need to be explicitly
  495. # specified as otherwise torch.onnx.export interprets as double
  496. point_embedding = torch.where(
  497. labels[..., None] != -10,
  498. point_embedding,
  499. torch.zeros_like(point_embedding),
  500. )
  501. # Add point embeddings for labels >= 0
  502. point_embedding = point_embedding + self.point_embed(labels.clamp(min=0)) * (labels >= 0).unsqueeze(-1)
  503. return point_embedding
  504. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  505. """Embeds box prompts."""
  506. boxes += 0.5 # Shift to center of pixel
  507. coords = boxes.view(*boxes.shape[:2], 2, 2)
  508. # add padding point for consistency with the original implementation
  509. coords = torch.nn.functional.pad(coords, (0, 0, 0, 1), mode="constant", value=0)
  510. corner_embedding = self.shared_embedding(coords, (self.input_image_size, self.input_image_size))
  511. corner_embedding[:, :, 0, :] += self.point_embed.weight[2]
  512. corner_embedding[:, :, 1, :] += self.point_embed.weight[3]
  513. corner_embedding[:, :, 2, :] = self.not_a_point_embed.weight.expand_as(corner_embedding[:, :, 2, :])
  514. return corner_embedding
  def forward(
      self,
      input_points: Optional[torch.Tensor],
      input_labels: Optional[torch.Tensor],
      input_boxes: Optional[torch.Tensor],
      input_masks: Optional[torch.Tensor],
  ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
      """
      Embeds different types of prompts, returning both sparse and dense embeddings.

      Args:
          input_points (`torch.Tensor`, *optional*):
              Point coordinates to embed. Requires `input_labels`.
          input_labels (`torch.Tensor`, *optional*):
              Labels of `input_points`. Must be provided whenever `input_points` is.
          input_boxes (`torch.Tensor`, *optional*):
              Boxes to embed. When given together with points, the box tokens are concatenated after the point
              tokens along dimension 2, and `_embed_points` is called with `pad=False`.
          input_masks (`torch.Tensor`, *optional*):
              Masks to embed with `mask_embed` as dense prompts. When omitted, the learned `no_mask_embed`
              vector is broadcast over the `image_embedding_size` grid instead.

      Returns:
          `tuple`: `(sparse_embeddings, dense_embeddings)`. `sparse_embeddings` is `None` when neither points
          nor boxes are given. The batch size of the broadcast "no mask" embedding is taken from the boxes if
          present, else from the points, else 1.

      Raises:
          ValueError: If `input_points` is given without `input_labels`.
      """
      sparse_embeddings = None
      batch_size = 1
      if input_points is not None:
          batch_size = input_points.shape[0]
          if input_labels is None:
              raise ValueError("If points are provided, labels must also be provided.")
          # Point prompts are padded only when no box tokens will be appended after them.
          sparse_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
      if input_boxes is not None:
          batch_size = input_boxes.shape[0]
          box_embeddings = self._embed_boxes(input_boxes)
          if sparse_embeddings is None:
              sparse_embeddings = box_embeddings
          else:
              sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
      if input_masks is not None:
          dense_embeddings = self.mask_embed(input_masks)
      else:
          dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
              batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
          )
      return sparse_embeddings, dense_embeddings
  class EdgeTamTwoWayTransformer(nn.Module):
      """
      Two-way transformer of the mask decoder.

      Runs `config.num_hidden_layers` `EdgeTamTwoWayAttentionBlock`s that update both the prompt tokens
      (queries) and the flattened image tokens (keys). It then applies one final token-to-image attention and a
      layer norm to the queries.
      """

      def __init__(self, config: EdgeTamMaskDecoderConfig):
          super().__init__()
          self.config = config
          self.num_hidden_layers = config.num_hidden_layers
          # `skip_first_layer_pe` is enabled for the first block only.
          self.layers = nn.ModuleList(
              [
                  EdgeTamTwoWayAttentionBlock(config, skip_first_layer_pe=(layer_idx == 0))
                  for layer_idx in range(self.num_hidden_layers)
              ]
          )
          self.final_attn_token_to_image = EdgeTamAttention(config)
          self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

      def forward(
          self,
          point_embeddings: Tensor,
          image_embeddings: Tensor,
          image_positional_embeddings: Tensor,
          attention_similarity: Optional[Tensor] = None,
          target_embedding: Optional[Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[Tensor, Tensor]:
          """
          Args:
              point_embeddings (`Tensor`):
                  Prompt and output tokens. They are the initial queries, are passed as `query_point_embedding`
                  to every block, and are added to the queries before the final attention.
              image_embeddings (`Tensor`):
                  Image features, presumably `(batch, channels, height, width)` as passed by the mask decoder.
                  They are flattened to `(batch, 1, height * width, channels)` and used as the initial keys.
              image_positional_embeddings (`Tensor`):
                  Positional encoding of `image_embeddings`, flattened the same way.
              attention_similarity (`Tensor`, *optional*):
                  Forwarded to every two-way attention block (not to the final attention).
              target_embedding (`Tensor`, *optional*):
                  When given, added to the queries before each block.

          Returns:
              `tuple[Tensor, Tensor]`: The queries after the final attention and layer norm, and the keys as
              returned by the last block.

          Raises:
              ValueError: If `image_embeddings` is `None`.
          """
          if image_embeddings is None:
              raise ValueError("You have to specify an image_embedding")

          image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
          image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)

          # Prepare queries
          queries = point_embeddings
          keys = image_embeddings

          # Apply transformer blocks and final layernorm
          for layer in self.layers:
              if target_embedding is not None:
                  # NOTE(review): on the first iteration `queries` is the same tensor as `point_embeddings`, so this
                  # in-place add also shifts the positional term used by the blocks and by the final attention.
                  # This is kept as-is to preserve existing outputs; confirm whether it is intended.
                  queries += target_embedding
              queries, keys, _ = layer(
                  queries=queries,
                  keys=keys,
                  query_point_embedding=point_embeddings,
                  key_point_embedding=image_positional_embeddings,
                  attention_similarity=attention_similarity,
                  **kwargs,
              )

          # Apply the final attention layer from the points to the image
          query = queries + point_embeddings
          key = keys + image_positional_embeddings

          attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)

          queries = queries + attn_out
          queries = self.layer_norm_final_attn(queries)
          return queries, keys
  class EdgeTamMaskDecoder(nn.Module):
      """
      Mask decoder of EdgeTAM.

      It prepends learned output tokens (object score, IoU, mask tokens) to the sparse prompt tokens and runs
      the two-way transformer against the image embedding. The image tokens are then upscaled 4x with the help
      of two high-resolution feature maps, and the method predicts masks, IoU scores and object-score logits.
      """

      def __init__(self, config: EdgeTamMaskDecoderConfig):
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size

          self.num_multimask_outputs = config.num_multimask_outputs
          # one single-mask token (index 0) plus `num_multimask_outputs` multimask tokens
          self.num_mask_tokens = config.num_multimask_outputs + 1

          self.iou_token = nn.Embedding(1, self.hidden_size)
          self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)

          self.transformer = EdgeTamTwoWayTransformer(config)

          # Upscaling path: two 2x transposed convolutions (hidden -> hidden/4 -> hidden/8 channels).
          # NOTE: module creation order is kept as-is so that random initialization does not change.
          self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
          self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
          self.upscale_layer_norm = EdgeTamLayerNorm(self.hidden_size // 4, data_format="channels_first")
          self.activation = nn.GELU()

          # One hypernetwork MLP per mask token, producing `hidden_size // 8` weights that match the channel count
          # of the upscaled embedding.
          self.output_hypernetworks_mlps = nn.ModuleList(
              [
                  EdgeTamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)
                  for _ in range(self.num_mask_tokens)
              ]
          )
          self.iou_prediction_head = EdgeTamFeedForward(
              self.hidden_size,
              config.iou_head_hidden_dim,
              self.num_mask_tokens,
              config.iou_head_depth,
              sigmoid_output=True,
          )

          # 1x1 projections of the two high-resolution feature levels. `EdgeTamModel.get_image_features` applies
          # them once per image, so they are not re-run on every prompt.
          self.conv_s0 = nn.Conv2d(config.hidden_size, config.hidden_size // 8, kernel_size=1, stride=1)
          self.conv_s1 = nn.Conv2d(config.hidden_size, config.hidden_size // 4, kernel_size=1, stride=1)

          self.obj_score_token = nn.Embedding(1, self.hidden_size)
          self.pred_obj_score_head = EdgeTamFeedForward(self.hidden_size, self.hidden_size, 1, 3)

          self.dynamic_multimask_via_stability = config.dynamic_multimask_via_stability
          self.dynamic_multimask_stability_delta = config.dynamic_multimask_stability_delta
          self.dynamic_multimask_stability_thresh = config.dynamic_multimask_stability_thresh

      def forward(
          self,
          image_embeddings: torch.Tensor,
          image_positional_embeddings: torch.Tensor,
          sparse_prompt_embeddings: torch.Tensor,
          dense_prompt_embeddings: torch.Tensor,
          multimask_output: bool,
          high_resolution_features: list[torch.Tensor],
          attention_similarity: Optional[torch.Tensor] = None,
          target_embedding: Optional[torch.Tensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
          """
          Predict masks given image and prompt embeddings.

          Args:
              image_embeddings (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
                  The embeddings from the image encoder.
              image_positional_embeddings (`torch.Tensor`):
                  Positional encoding with the shape of image_embeddings.
              sparse_prompt_embeddings (`torch.Tensor`):
                  The embeddings of the points and boxes. Dimension 1 is the point batch size and dimension 2
                  holds the prompt tokens.
              dense_prompt_embeddings (`torch.Tensor`):
                  The embeddings of the mask inputs; they are added to `image_embeddings`.
              multimask_output (`bool`):
                  Whether to return the multimask outputs (`True`) or a single mask (`False`).
              high_resolution_features (`list[torch.Tensor]`):
                  Required. Exactly two high-resolution feature maps `[feat_s0, feat_s1]`, as produced by
                  `EdgeTamModel.get_image_features`, which applies `conv_s0` / `conv_s1`. `feat_s1` is added
                  after the first 2x upscaling and `feat_s0` after the second.
              attention_similarity (`torch.Tensor`, *optional*):
                  The attention similarity tensor, forwarded to the transformer.
              target_embedding (`torch.Tensor`, *optional*):
                  The target embedding, forwarded to the transformer.

          Returns:
              `tuple` of `(masks, iou_pred, sam_tokens_out, object_score_logits)`:
              - `masks`: `(batch_size, point_batch_size, num_masks, 4 * height, 4 * width)` mask logits.
              - `iou_pred`: `(batch_size, point_batch_size, num_masks)` predicted IoU scores.
              - `sam_tokens_out`: the mask tokens selected by the same slice as the masks. In the
                dynamic-stability path this is always token 0.
              - `object_score_logits`: output of `pred_obj_score_head` on the object-score token.

              `num_masks` is `num_multimask_outputs` when `multimask_output` is `True`, otherwise 1.
          """
          batch_size, num_channels, height, width = image_embeddings.shape
          point_batch_size = sparse_prompt_embeddings.shape[1]

          # Output tokens, in order: object score (index 0), IoU (index 1), mask tokens (index 2 onwards).
          output_tokens = torch.cat(
              [
                  self.obj_score_token.weight,
                  self.iou_token.weight,
                  self.mask_tokens.weight,
              ],
              dim=0,
          )
          output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

          if sparse_prompt_embeddings.shape[0] != 0:
              tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
          else:
              tokens = output_tokens
          point_embeddings = tokens.to(self.iou_token.weight.dtype)

          # Expand per-image data in batch direction to be per-mask
          image_embeddings = image_embeddings + dense_prompt_embeddings
          image_embeddings = image_embeddings.repeat_interleave(point_batch_size, dim=0)
          image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)

          # Run the transformer
          point_embeddings, image_embeddings = self.transformer(
              point_embeddings=point_embeddings,
              image_embeddings=image_embeddings,
              image_positional_embeddings=image_positional_embeddings,
              attention_similarity=attention_similarity,
              target_embedding=target_embedding,
              **kwargs,
          )
          iou_token_out = point_embeddings[:, :, 1, :]
          mask_tokens_out = point_embeddings[:, :, 2 : (2 + self.num_mask_tokens), :]

          # Upscale mask embeddings and predict masks using the mask tokens
          image_embeddings = image_embeddings.transpose(2, 3).view(
              batch_size * point_batch_size, num_channels, height, width
          )

          feat_s0, feat_s1 = high_resolution_features
          feat_s0 = feat_s0.repeat_interleave(point_batch_size, dim=0)
          feat_s1 = feat_s1.repeat_interleave(point_batch_size, dim=0)
          upscaled_embedding = self.upscale_conv1(image_embeddings) + feat_s1
          upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
          upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding) + feat_s0)

          # Each mask token is mapped by its own hypernetwork MLP to per-channel weights for the upscaled map;
          # the mask logits are the weighted sum over channels.
          hyper_in = torch.stack(
              [self.output_hypernetworks_mlps[i](mask_tokens_out[:, :, i, :]) for i in range(self.num_mask_tokens)],
              dim=2,
          )

          _, num_channels, height, width = upscaled_embedding.shape
          upscaled_embedding = upscaled_embedding.view(batch_size, point_batch_size, num_channels, height * width)
          masks = (hyper_in @ upscaled_embedding).view(batch_size, point_batch_size, -1, height, width)

          # Generate mask quality predictions
          iou_pred = self.iou_prediction_head(iou_token_out)
          object_score_logits = self.pred_obj_score_head(point_embeddings[:, :, 0, :])

          # Select the correct mask or masks for output
          if multimask_output:
              mask_slice = slice(1, None)
              masks = masks[:, :, mask_slice, :, :]
              iou_pred = iou_pred[:, :, mask_slice]
          elif self.dynamic_multimask_via_stability and not self.training:
              mask_slice = slice(0, 1)
              masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
          else:
              mask_slice = slice(0, 1)
              masks = masks[:, :, mask_slice, :, :]
              iou_pred = iou_pred[:, :, mask_slice]

          # Mask tokens matching `mask_slice`: (batch_size, point_batch_size, num_selected_tokens, hidden_size).
          # In the dynamic-stability path this is token 0 even when a multimask output was selected above.
          sam_tokens_out = mask_tokens_out[:, :, mask_slice]

          return masks, iou_pred, sam_tokens_out, object_score_logits

      def _get_stability_scores(self, mask_logits):
          """
          Compute stability scores of the mask logits based on the IoU between upper and
          lower thresholds.

          The score is `area(logits > delta) / area(logits > -delta)` over the last two (spatial) dimensions,
          with `delta = dynamic_multimask_stability_delta`. It is defined as 1.0 when the lower-threshold area
          is empty.
          """
          mask_logits = mask_logits.flatten(-2)
          stability_delta = self.dynamic_multimask_stability_delta
          area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
          area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
          stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
          return stability_scores

      def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
          """
          When outputting a single mask, if the stability score from the current single-mask
          output (based on output token 0) falls below a threshold, we instead select from
          multi-mask outputs (based on output token 1~3) the mask with the highest predicted
          IoU score. This is intended to ensure a valid mask for both clicking and tracking.

          Args:
              all_mask_logits: mask logits of all mask tokens, `(B, P, num_mask_tokens, H, W)`.
              all_iou_scores: IoU predictions of all mask tokens, `(B, P, num_mask_tokens)`.

          Returns:
              `(mask_logits_out, iou_scores_out)` with shapes `(B, P, 1, H, W)` and `(B, P, 1)`.
          """
          # The best mask from multimask output tokens (1~3)
          multimask_logits = all_mask_logits[:, :, 1:, :, :]
          multimask_iou_scores = all_iou_scores[:, :, 1:]
          best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)  # [B, P]
          best_scores_inds_expanded = best_scores_inds.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
          best_scores_inds_expanded = best_scores_inds_expanded.expand(
              -1, -1, 1, multimask_logits.size(-2), multimask_logits.size(-1)
          )
          best_multimask_logits = torch.gather(multimask_logits, 2, best_scores_inds_expanded)  # [B, P, 1, H, W]
          best_multimask_iou_scores = torch.gather(multimask_iou_scores, 2, best_scores_inds.unsqueeze(-1))  # [B, P, 1]

          # The mask from singlemask output token 0 and its stability score
          singlemask_logits = all_mask_logits[:, :, 0:1, :, :]
          singlemask_iou_scores = all_iou_scores[:, :, 0:1]
          stability_scores = self._get_stability_scores(singlemask_logits)  # [B, P, 1]
          is_stable = stability_scores >= self.dynamic_multimask_stability_thresh

          # Dynamically fall back to best multimask output upon low stability scores.
          mask_logits_out = torch.where(
              is_stable[..., None, None].expand_as(singlemask_logits),
              singlemask_logits,
              best_multimask_logits,
          )
          iou_scores_out = torch.where(
              is_stable.expand_as(singlemask_iou_scores),
              singlemask_iou_scores,
              best_multimask_iou_scores,
          )
          return mask_logits_out, iou_scores_out
  @auto_docstring(
      custom_intro="""
      Segment Anything Model 2 (SAM 2) for generating segmentation masks, given an input image and
      input points and labels, boxes, or masks.
      """
  )
  class EdgeTamModel(EdgeTamPreTrainedModel):
      _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
      # need to be ignored, as it's a buffer and will not be correctly detected as tied weight
      _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"]
      # Records output index 2 of every EdgeTamTwoWayAttentionBlock (the element the two-way transformer
      # discards as `_`) under `mask_decoder_attentions`.
      _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamTwoWayAttentionBlock, index=2)}
      # Checkpoint entries for modules this model does not define (presumably the video/memory components).
      _keys_to_ignore_on_load_unexpected = [
          r"^memory_.*",
          r"^mask_downsample.*",
          r"spatial_perceiver.*",
          r"^object_pointer_proj.*",
          r"^temporal_positional_encoding_projection_layer.*",
          "no_memory_positional_encoding",
          "no_object_pointer",
          "occlusion_spatial_embedding_parameter",
      ]

      def __init__(self, config: EdgeTamConfig):
          super().__init__(config)
          self.shared_image_embedding = EdgeTamPositionalEmbedding(config.prompt_encoder_config)
          self.vision_encoder = AutoModel.from_config(config.vision_config)
          self.prompt_encoder = EdgeTamPromptEncoder(config.prompt_encoder_config)
          # The module using it is not a PreTrainedModel subclass so we need this
          config.mask_decoder_config._attn_implementation = config._attn_implementation
          self.mask_decoder = EdgeTamMaskDecoder(config.mask_decoder_config)

          self.num_feature_levels = config.vision_config.num_feature_levels
          self.backbone_feature_sizes = config.vision_config.backbone_feature_sizes
          # a single token to indicate no memory embedding from previous frames
          self.hidden_dim = config.vision_config.fpn_hidden_size
          self.no_memory_embedding = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))

          self.post_init()

      def _tie_weights(self):
          # Share the positional-embedding data of the image-wide embedding with the prompt encoder.
          self.prompt_encoder.shared_embedding.positional_embedding.data = (
              self.shared_image_embedding.positional_embedding.data
          )

      def get_image_wide_positional_embeddings(self) -> torch.Tensor:
          """
          Positional embedding of every cell of the image-embedding grid, returned as `1 x channel x height x width`.

          Cell centres are normalised by the grid size and stacked as `(x, y)`, with `x` running along the width.
          """
          size = self.prompt_encoder.image_embedding_size
          target_device = self.shared_image_embedding.positional_embedding.device
          target_dtype = self.shared_image_embedding.positional_embedding.dtype
          grid = torch.ones(size, device=target_device, dtype=target_dtype)
          y_embed = grid.cumsum(dim=0) - 0.5
          x_embed = grid.cumsum(dim=1) - 0.5
          y_embed = y_embed / size[0]
          x_embed = x_embed / size[1]

          positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
          return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # 1 x channel x height x width

      def _feature_maps_to_image_embeddings(
          self, feature_maps: list[torch.Tensor], batch_size: int
      ) -> list[torch.Tensor]:
          """
          Convert the flattened `(H*W, batch_size, C)` feature maps returned by `get_image_features` into
          `(batch_size, C, H, W)` image embeddings, one per level of `backbone_feature_sizes`.

          The learned `no_memory_embedding` is added to the last level, which is the one fed to the mask decoder
          as its image embedding. The input list is not modified.
          """
          feature_maps = list(feature_maps)
          # add no memory embedding to the last feature map
          feature_maps[-1] = feature_maps[-1] + self.no_memory_embedding
          # reshape feature maps to the same shape as the backbone feature sizes
          return [
              feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
              for feat, feat_size in zip(feature_maps, self.backbone_feature_sizes)
          ]

      @torch.no_grad()
      def get_image_embeddings(
          self,
          pixel_values: torch.FloatTensor,
          **kwargs: Unpack[TransformersKwargs],
      ) -> list[torch.Tensor]:
          r"""
          Returns the image embeddings by passing the pixel values through the vision encoder.

          Args:
              pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                  Input pixel values

          Returns:
              `list[torch.Tensor]`: One `(batch_size, channels, height, width)` embedding per feature level. The
              result can be passed as `image_embeddings` to `forward`.
          """
          batch_size = pixel_values.shape[0]
          feature_maps, _, _, _ = self.get_image_features(pixel_values, **kwargs)
          return self._feature_maps_to_image_embeddings(feature_maps, batch_size)

      @torch.no_grad()
      def get_prompt_embeddings(
          self,
          input_points: Optional[torch.FloatTensor] = None,
          input_labels: Optional[torch.LongTensor] = None,
          input_boxes: Optional[torch.FloatTensor] = None,
          input_masks: Optional[torch.LongTensor] = None,
      ):
          r"""
          Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.

          Args:
              input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                  Optional input points for the prompt encoder. The padding of the point is automatically done by the
                  processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                  point. The model will output `point_batch_size` times 3 masks in total.
              input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                  Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                  processor, or can be fed by the user.
              input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                  Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                  processor. Users can also pass the input boxes manually.
              input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
                  Optional input masks for the prompt encoder.

          Returns:
              The `(sparse_embeddings, dense_embeddings)` tuple produced by the prompt encoder.
          """
          prompt_output = self.prompt_encoder(
              input_points=input_points,
              input_labels=input_labels,
              input_boxes=input_boxes,
              input_masks=input_masks,
          )
          return prompt_output

      @check_model_inputs()
      @auto_docstring
      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
          input_points: Optional[torch.FloatTensor] = None,
          input_labels: Optional[torch.LongTensor] = None,
          input_boxes: Optional[torch.FloatTensor] = None,
          input_masks: Optional[torch.LongTensor] = None,
          image_embeddings: Optional[list[torch.FloatTensor]] = None,
          multimask_output: bool = True,
          attention_similarity: Optional[torch.FloatTensor] = None,
          target_embedding: Optional[torch.FloatTensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> EdgeTamImageSegmentationOutput:
          r"""
          input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points, 2)`):
              Input 2D spatial points; the prompt encoder uses them to encode the prompt. Generally yields much
              better results. The points can be obtained by passing a list of list of list to the processor that will
              create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
              second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
              per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
              multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
              coordinates of the point. If a different number of points is passed either for each image, or for each
              mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
              computation of the embedding will be skipped for these points using the labels.
          input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
              Input labels for the points; the prompt encoder uses them to encode the prompt. According to the
              official implementation, there are 3 types of labels
              - `1`: the point is a point that contains the object of interest
              - `0`: the point is a point that does not contain the object of interest
              - `-1`: the point corresponds to the background
              We added the label:
              - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
              The padding labels should be automatically done by the processor. If `input_points` is given without
              `input_labels`, every point is labelled `1`.
          input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
              Input boxes for the points; the prompt encoder uses them to encode the prompt. Generally yields
              much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
              that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
              size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
              In the order (`x1`, `y1`, `x2`, `y2`):
              - `x1`: the x coordinate of the top left point of the input box
              - `y1`: the y coordinate of the top left point of the input box
              - `x2`: the x coordinate of the bottom right point of the input box
              - `y2`: the y coordinate of the bottom right point of the input box
              When given together with `input_points`, `num_boxes` must equal the point batch size.
          input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
              SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
              generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to be
              manually fed by the user. If their spatial size differs from `prompt_encoder.mask_input_size`, they
              are resized with bilinear interpolation. That resize requires a 4-D `(batch_size, channels, height,
              width)` tensor.
          image_embeddings (`list[torch.FloatTensor]`):
              Multi-level image embeddings, as returned by `get_image_embeddings`. The last entry is the image
              embedding used by the mask decoder; the preceding entries are its high-resolution features. For more
              memory-efficient computation, users can first retrieve the image embeddings using the
              `get_image_embeddings` method, and then feed them to the `forward` method instead of feeding the
              `pixel_values`. Exactly one of `pixel_values` and `image_embeddings` must be provided.
          multimask_output (`bool`, *optional*):
              If `True` (default), the masks predicted by the multimask output tokens are returned (3 per image in the
              original implementation and paper, or per point / per bounding box if relevant). If `False`, a single
              mask is returned. It is the output of the single-mask token, or, when dynamic multimask via stability
              is enabled in eval mode, the highest-scoring multimask output if the single mask is unstable.
          attention_similarity (`torch.FloatTensor`, *optional*):
              Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
              model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).
          target_embedding (`torch.FloatTensor`, *optional*):
              Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
              the model is used for personalization as introduced in [PerSAM](https://huggingface.co/papers/2305.03048).

          Example:

          ```python
          >>> from PIL import Image
          >>> import requests
          >>> from transformers import AutoModel, AutoProcessor

          >>> model = AutoModel.from_pretrained("danelcsb/edgetam.1_hiera_tiny")
          >>> processor = AutoProcessor.from_pretrained("danelcsb/edgetam.1_hiera_tiny")

          >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
          >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
          >>> input_points = [[[400, 650]]]  # 2D location of a window on the car
          >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

          >>> # Get segmentation mask
          >>> outputs = model(**inputs)

          >>> # Postprocess masks
          >>> masks = processor.post_process_masks(
          ...     outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
          ... )
          ```
          """
          if not ((pixel_values is None) ^ (image_embeddings is None)):
              raise ValueError("Exactly one of pixel_values or image_embeddings must be provided.")
          if input_points is not None and input_boxes is not None:
              if input_points.shape[1] != input_boxes.shape[1]:
                  raise ValueError(
                      f"You should provide as many bounding boxes as input points per box. Got {input_points.shape[1]} and {input_boxes.shape[1]}."
                  )

          image_positional_embeddings = self.get_image_wide_positional_embeddings()
          # repeat with batch size
          batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings[-1].shape[0]
          image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)

          vision_attentions = None
          vision_hidden_states = None

          if pixel_values is not None:
              feature_maps, _, vision_hidden_states, vision_attentions = self.get_image_features(
                  pixel_values,
                  **kwargs,
              )
              image_embeddings = self._feature_maps_to_image_embeddings(feature_maps, batch_size)

          if input_points is not None and input_labels is None:
              input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)

          if input_points is None and input_boxes is None:
              # If neither points nor boxes are provided, pad with an empty point (with label -1)
              input_points = torch.zeros(
                  batch_size, 1, 1, 2, dtype=image_embeddings[-1].dtype, device=image_embeddings[-1].device
              )
              input_labels = -torch.ones(batch_size, 1, 1, dtype=torch.int32, device=image_embeddings[-1].device)

          if input_masks is not None:
              # If mask_inputs is provided, downsize it into low-res mask input if needed
              # and feed it as a dense mask prompt into the SAM mask encoder
              if input_masks.shape[-2:] != self.prompt_encoder.mask_input_size:
                  input_masks = F.interpolate(
                      input_masks.float(),
                      size=self.prompt_encoder.mask_input_size,
                      align_corners=False,
                      mode="bilinear",
                      antialias=True,  # use antialias for downsampling
                  ).to(input_masks.dtype)

          sparse_embeddings, dense_embeddings = self.prompt_encoder(
              input_points=input_points,
              input_labels=input_labels,
              input_boxes=input_boxes,
              input_masks=input_masks,
          )
          low_res_multimasks, iou_scores, _, object_score_logits = self.mask_decoder(
              image_embeddings=image_embeddings[-1],
              image_positional_embeddings=image_positional_embeddings,
              sparse_prompt_embeddings=sparse_embeddings,
              dense_prompt_embeddings=dense_embeddings,
              multimask_output=multimask_output,
              high_resolution_features=image_embeddings[:-1],
              attention_similarity=attention_similarity,
              target_embedding=target_embedding,
              **kwargs,
          )

          return EdgeTamImageSegmentationOutput(
              iou_scores=iou_scores,
              pred_masks=low_res_multimasks,
              object_score_logits=object_score_logits,
              image_embeddings=image_embeddings,
              vision_hidden_states=vision_hidden_states,
              vision_attentions=vision_attentions,
          )

      def get_image_features(
          self,
          pixel_values: torch.FloatTensor,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[
          list[torch.Tensor],
          list[torch.Tensor],
          Optional[tuple[torch.FloatTensor, ...]],
          Optional[tuple[torch.FloatTensor, ...]],
      ]:
          r"""
          Extract and preprocess image features using the vision encoder.

          Args:
              pixel_values (`torch.FloatTensor`):
                  Input pixel values of shape `(batch_size, num_channels, height, width)`.

          Returns:
              `tuple`: A tuple containing:
                  - feature_maps (`list[torch.Tensor]`): List of feature maps from different levels, each flattened
                    to `(height * width, batch_size, channels)`. Levels 0 and 1 are already projected by the mask
                    decoder's `conv_s0` / `conv_s1`.
                  - feature_maps_position_embeddings (`list[torch.Tensor]`): List of positional embeddings for each
                    feature level, flattened the same way.
                  - vision_hidden_states (`tuple[torch.FloatTensor]`, *optional*): Hidden states from the vision encoder.
                  - vision_attentions (`tuple[torch.FloatTensor]`, *optional*): Attention weights from the vision encoder.
          """
          vision_outputs: EdgeTamVisionEncoderOutput = self.vision_encoder(
              pixel_values,
              **kwargs,
          )

          feature_maps = vision_outputs.fpn_hidden_states
          feature_maps_position_embeddings = vision_outputs.fpn_position_encoding

          # precompute projected level 0 and level 1 features in SAM decoder
          # to avoid running it again on every SAM click
          feature_maps = list(feature_maps)
          feature_maps[0] = self.mask_decoder.conv_s0(feature_maps[0])
          feature_maps[1] = self.mask_decoder.conv_s1(feature_maps[1])

          # flatten NxCxHxW to HWxNxC
          feature_maps = [feature_map.flatten(2).permute(2, 0, 1) for feature_map in feature_maps]
          feature_maps_position_embeddings = [
              feature_map_position_embedding.flatten(2).permute(2, 0, 1)
              for feature_map_position_embedding in feature_maps_position_embeddings
          ]

          return feature_maps, feature_maps_position_embeddings, vision_outputs.hidden_states, vision_outputs.attentions
  # Public API of this module.
  __all__ = [
      "EdgeTamModel",
      "EdgeTamVisionModel",
      "EdgeTamPreTrainedModel",
  ]