modular_edgetam.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. # coding=utf-8
  2. # Copyright 2025 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch EdgeTAM model."""
  16. from typing import Optional, Union
  17. import torch
  18. import torch.nn as nn
  19. import torch.utils.checkpoint
  20. from transformers.models.sam2.configuration_sam2 import Sam2Config, Sam2MaskDecoderConfig, Sam2PromptEncoderConfig
  21. from transformers.models.sam2.modeling_sam2 import (
  22. Sam2Attention,
  23. Sam2FeedForward,
  24. Sam2LayerNorm,
  25. Sam2Model,
  26. Sam2PreTrainedModel,
  27. Sam2TwoWayAttentionBlock,
  28. Sam2VisionEncoderOutput,
  29. Sam2VisionModel,
  30. )
  31. from transformers.utils.generic import TransformersKwargs, check_model_inputs
  32. from ...configuration_utils import PretrainedConfig
  33. from ...processing_utils import Unpack
  34. from ...utils import (
  35. auto_docstring,
  36. )
  37. from ..auto import CONFIG_MAPPING, AutoConfig
  38. # fix this in modular
  39. if True:
  40. from transformers.models.timm_wrapper.modeling_timm_wrapper import TimmWrapperModel
  41. class EdgeTamVisionConfig(PretrainedConfig):
  42. r"""
  43. This is the configuration class to store the configuration of a [`EdgeTamVisionModel`]. It is used to instantiate a SAM
  44. vision encoder according to the specified arguments, defining the model architecture. Instantiating a configuration
  45. defaults will yield a similar configuration to that of SAM 2.1 Hiera-tiny
  46. [facebook/EdgeTAM](https://huggingface.co/facebook/EdgeTAM) architecture.
  47. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  48. documentation from [`PretrainedConfig`] for more information.
  49. Args:
  50. backbone_config (`Union[dict, "PretrainedConfig"]`, *optional*):
  51. Configuration for the vision backbone. This is used to instantiate the backbone using
  52. `AutoModel.from_config`.
  53. backbone_channel_list (`List[int]`, *optional*, defaults to `[384, 192, 96, 48]`):
  54. The list of channel dimensions for the backbone.
  55. backbone_feature_sizes (`List[List[int]]`, *optional*, defaults to `[[256, 256], [128, 128], [64, 64]]`):
  56. The spatial sizes of the feature maps from the backbone.
  57. fpn_hidden_size (`int`, *optional*, defaults to 256):
  58. The hidden dimension of the FPN.
  59. fpn_kernel_size (`int`, *optional*, defaults to 1):
  60. The kernel size for the convolutions in the neck.
  61. fpn_stride (`int`, *optional*, defaults to 1):
  62. The stride for the convolutions in the neck.
  63. fpn_padding (`int`, *optional*, defaults to 0):
  64. The padding for the convolutions in the neck.
  65. fpn_top_down_levels (`List[int]`, *optional*, defaults to `[2, 3]`):
  66. The levels for the top-down FPN connections.
  67. num_feature_levels (`int`, *optional*, defaults to 3):
  68. The number of feature levels from the FPN to use.
  69. hidden_act (`str`, *optional*, defaults to `"gelu"`):
  70. The non-linear activation function in the neck.
  71. layer_norm_eps (`float`, *optional*, defaults to 1e-06):
  72. The epsilon for the layer normalization.
  73. initializer_range (`float`, *optional*, defaults to 0.02):
  74. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  75. """
  76. base_config_key = "vision_config"
  77. model_type = "edgetam_vision_model"
  78. sub_configs = {
  79. "backbone_config": AutoConfig,
  80. }
  81. def __init__(
  82. self,
  83. backbone_config=None,
  84. backbone_channel_list=None,
  85. backbone_feature_sizes=None,
  86. fpn_hidden_size=256,
  87. fpn_kernel_size=1,
  88. fpn_stride=1,
  89. fpn_padding=0,
  90. fpn_top_down_levels=None,
  91. num_feature_levels=3,
  92. hidden_act="gelu",
  93. layer_norm_eps=1e-6,
  94. initializer_range=0.02,
  95. **kwargs,
  96. ):
  97. super().__init__(**kwargs)
  98. backbone_channel_list = [384, 192, 96, 48] if backbone_channel_list is None else backbone_channel_list
  99. backbone_feature_sizes = (
  100. [[256, 256], [128, 128], [64, 64]] if backbone_feature_sizes is None else backbone_feature_sizes
  101. )
  102. fpn_top_down_levels = [2, 3] if fpn_top_down_levels is None else fpn_top_down_levels
  103. if isinstance(backbone_config, dict):
  104. backbone_config["model_type"] = backbone_config.get("model_type", "timm_wrapper")
  105. backbone_config = CONFIG_MAPPING[backbone_config["model_type"]](**backbone_config)
  106. elif isinstance(backbone_config, AutoConfig):
  107. backbone_config = backbone_config
  108. elif backbone_config is None:
  109. backbone_config = AutoConfig.from_pretrained(
  110. "timm/repvit_m1.dist_in1k",
  111. model_args={"in_chans": 3, "features_only": True, "out_indices": [0, 1, 2, 3]},
  112. )
  113. self.backbone_config = backbone_config
  114. # Neck
  115. self.backbone_channel_list = backbone_channel_list
  116. self.backbone_feature_sizes = backbone_feature_sizes
  117. self.fpn_hidden_size = fpn_hidden_size
  118. self.fpn_kernel_size = fpn_kernel_size
  119. self.fpn_stride = fpn_stride
  120. self.fpn_padding = fpn_padding
  121. self.fpn_top_down_levels = fpn_top_down_levels
  122. self.num_feature_levels = num_feature_levels
  123. self.hidden_act = hidden_act
  124. self.layer_norm_eps = layer_norm_eps
  125. self.initializer_range = initializer_range
# Prompt-encoder configuration for EdgeTAM: subclasses `Sam2PromptEncoderConfig` without overriding anything.
class EdgeTamPromptEncoderConfig(Sam2PromptEncoderConfig):
    pass
# Mask-decoder configuration for EdgeTAM: subclasses `Sam2MaskDecoderConfig` without overriding anything.
class EdgeTamMaskDecoderConfig(Sam2MaskDecoderConfig):
    pass
# Top-level EdgeTAM configuration: subclasses `Sam2Config` without overriding anything.
class EdgeTamConfig(Sam2Config):
    pass
# EdgeTAM-named alias of `Sam2LayerNorm` (no overrides); `EdgeTamPreTrainedModel._init_weights` checks for it.
class EdgeTamLayerNorm(Sam2LayerNorm):
    pass
# Output container returned by `EdgeTamVisionModel.forward`: subclasses `Sam2VisionEncoderOutput` unchanged.
class EdgeTamVisionEncoderOutput(Sam2VisionEncoderOutput):
    pass
# EdgeTAM-named alias of `Sam2Attention` (no overrides).
class EdgeTamAttention(Sam2Attention):
    pass
# EdgeTAM-named alias of `Sam2TwoWayAttentionBlock` (no overrides).
class EdgeTamTwoWayAttentionBlock(Sam2TwoWayAttentionBlock):
    pass
# EdgeTAM-named alias of `Sam2FeedForward` (no overrides).
class EdgeTamFeedForward(Sam2FeedForward):
    pass
  142. @auto_docstring
  143. class EdgeTamPreTrainedModel(Sam2PreTrainedModel):
  144. def _init_weights(self, module):
  145. std = self.config.initializer_range
  146. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  147. module.weight.data.normal_(mean=0.0, std=std)
  148. if module.bias is not None:
  149. module.bias.data.zero_()
  150. elif isinstance(module, nn.Embedding):
  151. module.weight.data.normal_(mean=0.0, std=std)
  152. if module.padding_idx is not None:
  153. module.weight.data[module.padding_idx].zero_()
  154. elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)):
  155. module.weight.data.fill_(1.0)
  156. module.bias.data.zero_()
  157. if isinstance(module, EdgeTamModel):
  158. if module.no_memory_embedding is not None:
  159. module.no_memory_embedding.data.zero_()
  160. @auto_docstring(
  161. custom_intro="""
  162. The vision model from EdgeTAM without any head or projection on top.
  163. """
  164. )
  165. class EdgeTamVisionModel(Sam2VisionModel):
  166. config_class = EdgeTamVisionConfig
  167. main_input_name = "pixel_values"
  168. _can_record_outputs = {"hidden_states": TimmWrapperModel, "attentions": TimmWrapperModel}
  169. def get_input_embeddings(self):
  170. raise NotImplementedError("Can't get input embeddings from timm wrapper model")
  171. @check_model_inputs()
  172. def forward(
  173. self,
  174. pixel_values: Optional[torch.FloatTensor] = None,
  175. **kwargs: Unpack[TransformersKwargs],
  176. ) -> Union[tuple, EdgeTamVisionEncoderOutput]:
  177. if pixel_values is None:
  178. raise ValueError("You have to specify pixel_values")
  179. # Forward through backbone
  180. backbone_output = self.backbone(pixel_values)
  181. intermediate_hidden_states = backbone_output.last_hidden_state
  182. intermediate_hidden_states = [hidden_state.permute(0, 2, 3, 1) for hidden_state in intermediate_hidden_states]
  183. fpn_hidden_states, fpn_position_encoding = self.neck(intermediate_hidden_states)
  184. # Select last `num_feature_levels` feature levels from FPN and reverse order to get features from high to low resolution
  185. fpn_hidden_states = fpn_hidden_states[-self.num_feature_levels :][::-1]
  186. fpn_position_encoding = fpn_position_encoding[-self.num_feature_levels :][::-1]
  187. return EdgeTamVisionEncoderOutput(
  188. last_hidden_state=intermediate_hidden_states[-1],
  189. fpn_hidden_states=fpn_hidden_states,
  190. fpn_position_encoding=fpn_position_encoding,
  191. )
  192. class EdgeTamModel(Sam2Model):
  193. _keys_to_ignore_on_load_unexpected = [
  194. r"^memory_.*",
  195. r"^mask_downsample.*",
  196. r"spatial_perceiver.*",
  197. r"^object_pointer_proj.*",
  198. r"^temporal_positional_encoding_projection_layer.*",
  199. "no_memory_positional_encoding",
  200. "no_object_pointer",
  201. "occlusion_spatial_embedding_parameter",
  202. ]
  203. def get_input_embeddings(self):
  204. raise NotImplementedError("Can't get input embeddings from timm wrapper model")
# Names exported by this module. The pass-through helper classes defined above (e.g. `EdgeTamLayerNorm`,
# `EdgeTamAttention`, `EdgeTamFeedForward`) are not part of the export list.
__all__ = [
    "EdgeTamModel",
    "EdgeTamVisionModel",
    "EdgeTamPreTrainedModel",
    "EdgeTamConfig",
    "EdgeTamVisionConfig",
    "EdgeTamPromptEncoderConfig",
    "EdgeTamMaskDecoderConfig",
]