modular_eomt.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601
  1. # coding=utf-8
  2. # Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch EoMT model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Optional
  19. import torch
  20. import torch.nn.functional as F
  21. from torch import Tensor, nn
  22. from ...activations import ACT2FN
  23. from ...file_utils import (
  24. ModelOutput,
  25. )
  26. from ...modeling_utils import PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import (
  29. TransformersKwargs,
  30. auto_docstring,
  31. logging,
  32. )
  33. from ...utils.generic import check_model_inputs
  34. from ..dinov2.modeling_dinov2 import (
  35. Dinov2Embeddings,
  36. Dinov2Layer,
  37. Dinov2LayerScale,
  38. Dinov2PatchEmbeddings,
  39. )
  40. from ..mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentation, Mask2FormerLoss
  41. from ..siglip.modeling_siglip import SiglipAttention
  42. from ..vit.configuration_vit import ViTConfig
  43. logger = logging.get_logger(__name__)
  44. class EomtConfig(ViTConfig):
  45. r"""
  46. This is the configuration class to store the configuration of a [`EomtForUniversalSegmentation`]. It is used to instantiate an EoMT model
  47. according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  48. defaults will yield a similar configuration to that of the EoMT
  49. [tue-mps/coco_panoptic_eomt_large_640](https://huggingface.co/tue-mps/coco_panoptic_eomt_large_640)
  50. architecture.
  51. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  52. documentation from [`PretrainedConfig`] for more information.
  53. Args:
  54. hidden_size (`int`, *optional*, defaults to 1024):
  55. Dimensionality of the hidden representations.
  56. num_hidden_layers (`int`, *optional*, defaults to 24):
  57. Number of hidden layers in the Transformer encoder.
  58. num_attention_heads (`int`, *optional*, defaults to 16):
  59. Number of attention heads in each attention layer.
  60. mlp_ratio (`int`, *optional*, defaults to 4):
  61. Ratio of the MLP hidden dimensionality to the hidden size.
  62. hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
  63. The non-linear activation function (function or string) in the encoder.
  64. hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
  65. The dropout probability for all fully connected layers in the embeddings and encoder.
  66. initializer_range (`float`, *optional*, defaults to 0.02):
  67. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  68. layer_norm_eps (`float`, *optional*, defaults to 1e-06):
  69. The epsilon used by the layer normalization layers.
  70. image_size (`int`, *optional*, defaults to 640):
  71. The size (resolution) of each input image.
  72. patch_size (`int`, *optional*, defaults to 16):
  73. The size (resolution) of each patch.
  74. num_channels (`int`, *optional*, defaults to 3):
  75. The number of input channels.
  76. layerscale_value (`float`, *optional*, defaults to 1.0):
  77. Initial value for the LayerScale parameter.
  78. drop_path_rate (`float`, *optional*, defaults to 0.0):
  79. The stochastic depth rate (drop path) used during training.
  80. num_upscale_blocks (`int`, *optional*, defaults to 2):
  81. Number of upsampling blocks used in the decoder or segmentation head.
  82. attention_dropout (`float`, *optional*, defaults to 0.0):
  83. Dropout probability applied after attention projection.
  84. use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
  85. Whether to use the SwiGLU feedforward neural network.
  86. num_blocks (`int`, *optional*, defaults to 4):
  87. Number of feature blocks or stages in the architecture.
  88. no_object_weight (`float`, *optional*, defaults to 0.1):
  89. Loss weight for the 'no object' class in panoptic/instance segmentation.
  90. class_weight (`float`, *optional*, defaults to 2.0):
  91. Loss weight for classification targets.
  92. mask_weight (`float`, *optional*, defaults to 5.0):
  93. Loss weight for mask prediction.
  94. dice_weight (`float`, *optional*, defaults to 5.0):
  95. Loss weight for the dice loss component.
  96. train_num_points (`int`, *optional*, defaults to 12544):
  97. Number of points to sample for mask loss computation during training.
  98. oversample_ratio (`float`, *optional*, defaults to 3.0):
  99. Oversampling ratio used in point sampling for mask training.
  100. importance_sample_ratio (`float`, *optional*, defaults to 0.75):
  101. Ratio of points to sample based on importance during training.
  102. num_queries (`int`, *optional*, defaults to 200):
  103. Number of object queries in the Transformer.
  104. num_register_tokens (`int`, *optional*, defaults to 4):
  105. Number of learnable register tokens added to the transformer input.
  106. Example:
  107. ```python
  108. >>> from transformers import EomtConfig, EomtForUniversalSegmentation
  109. >>> # Initialize configuration
  110. >>> config = EomtConfig()
  111. >>> # Initialize model
  112. >>> model = EomtForUniversalSegmentation(config)
  113. >>> # Access config
  114. >>> config = model.config
  115. ```"""
  116. model_type = "eomt"
  117. def __init__(
  118. self,
  119. hidden_size=1024,
  120. num_hidden_layers=24,
  121. num_attention_heads=16,
  122. mlp_ratio=4,
  123. hidden_act="gelu",
  124. hidden_dropout_prob=0.0,
  125. initializer_range=0.02,
  126. layer_norm_eps=1e-6,
  127. image_size=640,
  128. patch_size=16,
  129. num_channels=3,
  130. layerscale_value=1.0,
  131. drop_path_rate=0.0,
  132. num_upscale_blocks=2,
  133. attention_dropout=0.0,
  134. use_swiglu_ffn=False,
  135. num_blocks=4,
  136. no_object_weight: float = 0.1,
  137. class_weight: float = 2.0,
  138. mask_weight: float = 5.0,
  139. dice_weight: float = 5.0,
  140. train_num_points: int = 12544,
  141. oversample_ratio: float = 3.0,
  142. importance_sample_ratio: float = 0.75,
  143. num_queries=200,
  144. num_register_tokens=4,
  145. **kwargs,
  146. ):
  147. super().__init__(
  148. hidden_size=hidden_size,
  149. num_hidden_layers=num_hidden_layers,
  150. num_attention_heads=num_attention_heads,
  151. hidden_dropout_prob=hidden_dropout_prob,
  152. hidden_act=hidden_act,
  153. initializer_range=initializer_range,
  154. layer_norm_eps=layer_norm_eps,
  155. image_size=image_size,
  156. patch_size=patch_size,
  157. num_channels=num_channels,
  158. **kwargs,
  159. )
  160. del self.intermediate_size
  161. del self.qkv_bias
  162. del self.pooler_act
  163. del self.pooler_output_size
  164. del self.encoder_stride
  165. del self.attention_probs_dropout_prob
  166. self.mlp_ratio = mlp_ratio
  167. self.attention_dropout = attention_dropout
  168. self.layerscale_value = layerscale_value
  169. self.drop_path_rate = drop_path_rate
  170. self.num_upscale_blocks = num_upscale_blocks
  171. self.use_swiglu_ffn = use_swiglu_ffn
  172. self.num_blocks = num_blocks
  173. self.no_object_weight = no_object_weight
  174. self.class_weight = class_weight
  175. self.mask_weight = mask_weight
  176. self.dice_weight = dice_weight
  177. self.train_num_points = train_num_points
  178. self.oversample_ratio = oversample_ratio
  179. self.importance_sample_ratio = importance_sample_ratio
  180. self.num_queries = num_queries
  181. self.num_register_tokens = num_register_tokens
  182. @dataclass
  183. @auto_docstring(
  184. custom_intro="""
  185. Class for outputs of [`EomtForUniversalSegmentationOutput`].
  186. This output can be directly passed to [`~EomtImageProcessor.post_process_semantic_segmentation`] or
  187. [`~EomtImageProcessor.post_process_instance_segmentation`] or
  188. [`~EomtImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
  189. [`~EomtImageProcessor] for details regarding usage.
  190. """
  191. )
  192. class EomtForUniversalSegmentationOutput(ModelOutput):
  193. r"""
  194. loss (`torch.Tensor`, *optional*):
  195. The computed loss, returned when labels are present.
  196. class_queries_logits (`torch.FloatTensor`):
  197. A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
  198. query. Note the `+ 1` is needed because we incorporate the null class.
  199. masks_queries_logits (`torch.FloatTensor`):
  200. A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
  201. query.
  202. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  203. Last hidden states (final feature map) of the last layer.
  204. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  205. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  206. shape `(batch_size, sequence_length, hidden_size)`. Hidden-states all layers of the model.
  207. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  208. Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  209. sequence_length)`. Self and Cross Attentions weights from transformer decoder.
  210. patch_offsets (`list[torch.Tensor]`, *optional*):
  211. list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
  212. """
  213. loss: Optional[torch.FloatTensor] = None
  214. class_queries_logits: Optional[torch.FloatTensor] = None
  215. masks_queries_logits: Optional[torch.FloatTensor] = None
  216. last_hidden_state: Optional[torch.FloatTensor] = None
  217. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  218. attentions: Optional[tuple[torch.FloatTensor]] = None
  219. patch_offsets: Optional[list[torch.Tensor]] = None
# Adds no overrides: the loss behaves exactly like `Mask2FormerLoss`, exposed under the Eomt name.
class EomtLoss(Mask2FormerLoss):
    pass
# Adds no overrides: patch embedding is inherited unchanged from `Dinov2PatchEmbeddings`.
class EomtPatchEmbeddings(Dinov2PatchEmbeddings):
    pass
  224. class EomtEmbeddings(Dinov2Embeddings):
  225. def __init__(self, config: EomtConfig) -> None:
  226. nn.Module.__init__(self)
  227. self.config = config
  228. self.patch_size = config.patch_size
  229. self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  230. self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
  231. self.patch_embeddings = EomtPatchEmbeddings(config)
  232. num_patches = self.patch_embeddings.num_patches
  233. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  234. self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS]
  235. self.position_embeddings = nn.Embedding(num_patches, config.hidden_size)
  236. self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
  237. def interpolate_pos_encoding(self):
  238. raise AttributeError("Not needed for Eomt Model")
  239. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  240. batch_size, _, _, _ = pixel_values.shape
  241. target_dtype = self.patch_embeddings.projection.weight.dtype
  242. embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
  243. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  244. register_tokens = self.register_tokens.expand(batch_size, -1, -1)
  245. embeddings = embeddings + self.position_embeddings(self.position_ids)
  246. embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
  247. embeddings = self.dropout(embeddings)
  248. return embeddings
# Adds no overrides: attention is inherited unchanged from `SiglipAttention`.
class EomtAttention(SiglipAttention):
    pass
# Adds no overrides: layer scaling is inherited unchanged from `Dinov2LayerScale`.
class EomtLayerScale(Dinov2LayerScale):
    pass
  253. class EomtLayer(Dinov2Layer):
  254. def forward(
  255. self,
  256. hidden_states: torch.Tensor,
  257. head_mask: Optional[torch.Tensor] = None,
  258. ) -> torch.Tensor:
  259. hidden_states_norm = self.norm1(hidden_states)
  260. self_attention_output, _ = self.attention(hidden_states_norm, head_mask)
  261. self_attention_output = self.layer_scale1(self_attention_output)
  262. # first residual connection
  263. hidden_states = self.drop_path(self_attention_output) + hidden_states
  264. # in Eomt, layernorm is also applied after self-attention
  265. layer_output = self.norm2(hidden_states)
  266. layer_output = self.mlp(layer_output)
  267. layer_output = self.layer_scale2(layer_output)
  268. # second residual connection
  269. layer_output = self.drop_path(layer_output) + hidden_states
  270. return layer_output
  271. class EomtLayerNorm2d(nn.LayerNorm):
  272. def __init__(self, num_channels, eps=1e-6, affine=True):
  273. super().__init__(num_channels, eps=eps, elementwise_affine=affine)
  274. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  275. hidden_state = hidden_state.permute(0, 2, 3, 1)
  276. hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps)
  277. hidden_state = hidden_state.permute(0, 3, 1, 2)
  278. return hidden_state
  279. class EomtScaleLayer(nn.Module):
  280. def __init__(self, config: EomtConfig):
  281. super().__init__()
  282. hidden_size = config.hidden_size
  283. self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2)
  284. self.activation = ACT2FN[config.hidden_act]
  285. self.conv2 = nn.Conv2d(
  286. hidden_size,
  287. hidden_size,
  288. kernel_size=3,
  289. padding=1,
  290. groups=hidden_size,
  291. bias=False,
  292. )
  293. self.layernorm2d = EomtLayerNorm2d(hidden_size)
  294. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  295. hidden_states = self.conv1(hidden_states)
  296. hidden_states = self.activation(hidden_states)
  297. hidden_states = self.conv2(hidden_states)
  298. hidden_states = self.layernorm2d(hidden_states)
  299. return hidden_states
  300. class EomtScaleBlock(nn.Module):
  301. def __init__(self, config: EomtConfig):
  302. super().__init__()
  303. self.num_blocks = config.num_upscale_blocks
  304. self.block = nn.ModuleList([EomtScaleLayer(config) for _ in range(self.num_blocks)])
  305. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  306. for block in self.block:
  307. hidden_states = block(hidden_states)
  308. return hidden_states
  309. class EomtMaskHead(nn.Module):
  310. def __init__(self, config: EomtConfig):
  311. super().__init__()
  312. hidden_size = config.hidden_size
  313. self.fc1 = nn.Linear(hidden_size, hidden_size)
  314. self.fc2 = nn.Linear(hidden_size, hidden_size)
  315. self.fc3 = nn.Linear(hidden_size, hidden_size)
  316. self.activation = ACT2FN[config.hidden_act]
  317. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  318. hidden_states = self.activation(self.fc1(hidden_states))
  319. hidden_states = self.activation(self.fc2(hidden_states))
  320. hidden_states = self.fc3(hidden_states)
  321. return hidden_states
  322. @auto_docstring
  323. class EomtPreTrainedModel(PreTrainedModel):
  324. """
  325. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  326. models.
  327. """
  328. config: EomtConfig
  329. base_model_prefix = "eomt"
  330. main_input_name = "pixel_values"
  331. supports_gradient_checkpointing = False
  332. _no_split_modules = ["EomtLayer"]
  333. _supports_sdpa = True
  334. _can_record_outputs = {
  335. "hidden_states": EomtLayer,
  336. "attentions": EomtAttention,
  337. }
  338. def _init_weights(self, module: nn.Module) -> None:
  339. std = self.config.initializer_range
  340. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  341. nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  342. if module.bias is not None:
  343. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
  344. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  345. nn.init.uniform_(module.bias, -bound, bound)
  346. elif isinstance(module, nn.LayerNorm):
  347. module.weight.data.fill_(1.0)
  348. module.bias.data.zero_()
  349. elif isinstance(module, nn.Embedding):
  350. module.weight.data.normal_(mean=0.0, std=1)
  351. if module.padding_idx is not None:
  352. module.weight.data[module.padding_idx].zero_()
  353. elif isinstance(module, EomtLayerScale):
  354. if hasattr(module, "lambda1"):
  355. module.lambda1.data.fill_(self.config.layerscale_value)
  356. elif isinstance(module, EomtEmbeddings):
  357. module.cls_token.data = nn.init.trunc_normal_(
  358. module.cls_token.data.to(torch.float32), mean=0.0, std=std
  359. ).to(module.cls_token.dtype)
  360. module.register_tokens.data.zero_()
  361. @auto_docstring(
  362. custom_intro="""
  363. The EoMT Model with head on top for instance/semantic/panoptic segmentation.
  364. """
  365. )
  366. class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation):
  367. def __init__(self, config: EomtConfig):
  368. PreTrainedModel.__init__(self, config)
  369. self.config = config
  370. self.num_hidden_layers = config.num_hidden_layers
  371. self.embeddings = EomtEmbeddings(config)
  372. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  373. self.query = nn.Embedding(config.num_queries, config.hidden_size)
  374. self.layers = nn.ModuleList([EomtLayer(config) for _ in range(config.num_hidden_layers)])
  375. self.upscale_block = EomtScaleBlock(config)
  376. self.mask_head = EomtMaskHead(config)
  377. self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1)
  378. self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
  379. self.weight_dict: dict[str, float] = {
  380. "loss_cross_entropy": config.class_weight,
  381. "loss_mask": config.mask_weight,
  382. "loss_dice": config.dice_weight,
  383. }
  384. self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict)
  385. self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks))
  386. self.post_init()
  387. def get_input_embeddings(self):
  388. return self.embeddings.patch_embeddings
  389. def get_auxiliary_logits(self):
  390. raise AttributeError("Note needed for Eomt Model.")
  391. def predict(self, logits: torch.Tensor):
  392. query_tokens = logits[:, : self.config.num_queries, :]
  393. class_logits = self.class_predictor(query_tokens)
  394. prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :]
  395. prefix_tokens = prefix_tokens.transpose(1, 2)
  396. prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size)
  397. query_tokens = self.mask_head(query_tokens)
  398. prefix_tokens = self.upscale_block(prefix_tokens)
  399. mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens)
  400. return mask_logits, class_logits
  401. @staticmethod
  402. def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
  403. if prob < 1:
  404. # Generate random queries to disable based on the probs
  405. random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob
  406. # Disable attention to the query tokens, considering the prefix tokens
  407. attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1
  408. return attn_mask
  409. @check_model_inputs()
  410. @auto_docstring
  411. def forward(
  412. self,
  413. pixel_values: Tensor,
  414. mask_labels: Optional[list[Tensor]] = None,
  415. class_labels: Optional[list[Tensor]] = None,
  416. patch_offsets: Optional[list[Tensor]] = None,
  417. **kwargs: Unpack[TransformersKwargs],
  418. ) -> EomtForUniversalSegmentationOutput:
  419. r"""
  420. mask_labels (`list[torch.Tensor]`, *optional*):
  421. list of mask labels of shape `(num_labels, height, width)` to be fed to a model
  422. class_labels (`list[torch.LongTensor]`, *optional*):
  423. list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
  424. labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
  425. patch_offsets (`list[torch.Tensor]`, *optional*):
  426. list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
  427. """
  428. masks_queries_logits_per_layer, class_queries_logits_per_layer = (), ()
  429. attention_mask = None
  430. if pixel_values is None:
  431. raise ValueError("You have to specify pixel_values")
  432. hidden_states = self.embeddings(pixel_values)
  433. for idx, layer_module in enumerate(self.layers):
  434. if idx == self.num_hidden_layers - self.config.num_blocks:
  435. query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
  436. hidden_states = torch.cat((query, hidden_states), dim=1)
  437. if idx >= self.num_hidden_layers - self.config.num_blocks and (
  438. self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0
  439. ):
  440. norm_hidden_states = self.layernorm(hidden_states)
  441. masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states)
  442. masks_queries_logits_per_layer += (masks_queries_logits,)
  443. class_queries_logits_per_layer += (class_queries_logits,)
  444. attention_mask = torch.ones(
  445. hidden_states.shape[0],
  446. hidden_states.shape[1],
  447. hidden_states.shape[1],
  448. device=hidden_states.device,
  449. dtype=torch.bool,
  450. )
  451. interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear")
  452. interpolated_logits = interpolated_logits.view(
  453. interpolated_logits.size(0), interpolated_logits.size(1), -1
  454. )
  455. num_query_tokens = self.config.num_queries
  456. encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens
  457. # Set attention mask for queries to focus on encoder tokens based on interpolated logits
  458. attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0
  459. # Disable attention mask for random query tokens.
  460. attention_mask = self._disable_attention_mask(
  461. attention_mask,
  462. prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks],
  463. num_query_tokens=num_query_tokens,
  464. encoder_start_tokens=encoder_start_tokens,
  465. device=attention_mask.device,
  466. )
  467. # Expand attention mask to 4d mask.
  468. attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1)
  469. attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9)
  470. hidden_states = layer_module(hidden_states, attention_mask)
  471. sequence_output = self.layernorm(hidden_states)
  472. masks_queries_logits, class_queries_logits = self.predict(sequence_output)
  473. masks_queries_logits_per_layer += (masks_queries_logits,)
  474. class_queries_logits_per_layer += (class_queries_logits,)
  475. loss = None
  476. if mask_labels is not None and class_labels is not None:
  477. loss = 0.0
  478. for masks_queries_logits, class_queries_logits in zip(
  479. masks_queries_logits_per_layer, class_queries_logits_per_layer
  480. ):
  481. loss_dict = self.get_loss_dict(
  482. masks_queries_logits=masks_queries_logits,
  483. class_queries_logits=class_queries_logits,
  484. mask_labels=mask_labels,
  485. class_labels=class_labels,
  486. auxiliary_predictions=None,
  487. )
  488. loss += self.get_loss(loss_dict)
  489. return EomtForUniversalSegmentationOutput(
  490. loss=loss,
  491. masks_queries_logits=masks_queries_logits,
  492. class_queries_logits=class_queries_logits,
  493. last_hidden_state=sequence_output,
  494. patch_offsets=patch_offsets,
  495. )
  496. __all__ = ["EomtConfig", "EomtPreTrainedModel", "EomtForUniversalSegmentation"]