modeling_eomt.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/eomt/modular_eomt.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_eomt.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import collections.abc
  22. import math
  23. from dataclasses import dataclass
  24. from typing import Callable, Optional
  25. import numpy as np
  26. import torch
  27. import torch.nn.functional as F
  28. from torch import Tensor, nn
  29. from ...activations import ACT2FN
  30. from ...file_utils import ModelOutput, is_scipy_available, requires_backends
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, is_accelerate_available
  35. from ...utils.generic import check_model_inputs
  36. from .configuration_eomt import EomtConfig
  37. if is_scipy_available():
  38. from scipy.optimize import linear_sum_assignment
  39. if is_accelerate_available():
  40. from accelerate import PartialState
  41. from accelerate.utils import reduce
  42. @dataclass
  43. @auto_docstring(
  44. custom_intro="""
  45. Class for outputs of [`EomtForUniversalSegmentationOutput`].
  46. This output can be directly passed to [`~EomtImageProcessor.post_process_semantic_segmentation`] or
  47. [`~EomtImageProcessor.post_process_instance_segmentation`] or
  48. [`~EomtImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
  49. [`~EomtImageProcessor] for details regarding usage.
  50. """
  51. )
  52. class EomtForUniversalSegmentationOutput(ModelOutput):
  53. r"""
  54. loss (`torch.Tensor`, *optional*):
  55. The computed loss, returned when labels are present.
  56. class_queries_logits (`torch.FloatTensor`):
  57. A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
  58. query. Note the `+ 1` is needed because we incorporate the null class.
  59. masks_queries_logits (`torch.FloatTensor`):
  60. A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
  61. query.
  62. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  63. Last hidden states (final feature map) of the last layer.
  64. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  65. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  66. shape `(batch_size, sequence_length, hidden_size)`. Hidden-states all layers of the model.
  67. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  68. Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  69. sequence_length)`. Self and Cross Attentions weights from transformer decoder.
  70. patch_offsets (`list[torch.Tensor]`, *optional*):
  71. list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
  72. """
  73. loss: Optional[torch.FloatTensor] = None
  74. class_queries_logits: Optional[torch.FloatTensor] = None
  75. masks_queries_logits: Optional[torch.FloatTensor] = None
  76. last_hidden_state: Optional[torch.FloatTensor] = None
  77. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  78. attentions: Optional[tuple[torch.FloatTensor]] = None
  79. patch_offsets: Optional[list[torch.Tensor]] = None
  80. # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
  81. def sample_point(
  82. input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
  83. ) -> torch.Tensor:
  84. """
  85. A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors.
  86. Args:
  87. input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
  88. A tensor that contains features map on a height * width grid
  89. point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,:
  90. 2)):
  91. A tensor that contains [0, 1] * [0, 1] normalized point coordinates
  92. add_dim (`bool`):
  93. boolean value to keep track of added dimension
  94. Returns:
  95. point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels,
  96. height_grid, width_grid):
  97. A tensor that contains features for points in `point_coordinates`.
  98. """
  99. if point_coordinates.dim() == 3:
  100. add_dim = True
  101. point_coordinates = point_coordinates.unsqueeze(2)
  102. # use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation
  103. point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs)
  104. if add_dim:
  105. point_features = point_features.squeeze(3)
  106. return point_features
  107. def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
  108. """
  109. A pair wise version of the dice loss, see `dice_loss` for usage.
  110. Args:
  111. inputs (`torch.Tensor`):
  112. A tensor representing a mask
  113. labels (`torch.Tensor`):
  114. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  115. (0 for the negative class and 1 for the positive class).
  116. Returns:
  117. `torch.Tensor`: The computed loss between each pairs.
  118. """
  119. inputs = inputs.sigmoid().flatten(1)
  120. numerator = 2 * torch.matmul(inputs, labels.T)
  121. # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
  122. denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
  123. loss = 1 - (numerator + 1) / (denominator + 1)
  124. return loss
  125. def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
  126. r"""
  127. A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage.
  128. Args:
  129. inputs (`torch.Tensor`):
  130. A tensor representing a mask.
  131. labels (`torch.Tensor`):
  132. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  133. (0 for the negative class and 1 for the positive class).
  134. Returns:
  135. loss (`torch.Tensor`): The computed loss between each pairs.
  136. """
  137. height_and_width = inputs.shape[1]
  138. criterion = nn.BCEWithLogitsLoss(reduction="none")
  139. cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
  140. cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
  141. loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)
  142. loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T)
  143. loss = loss_pos + loss_neg
  144. return loss
  145. # Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/matcher.py
  146. class EomtHungarianMatcher(nn.Module):
  147. """This class computes an assignment between the labels and the predictions of the network.
  148. For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
  149. predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
  150. un-matched (and thus treated as non-objects).
  151. """
  152. def __init__(
  153. self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
  154. ):
  155. """Creates the matcher
  156. Params:
  157. cost_class (`float`, *optional*, defaults to 1.0):
  158. Relative weight of the classification error in the matching cost.
  159. cost_mask (`float`, *optional*, defaults to 1.0):
  160. This is the relative weight of the focal loss of the binary mask in the matching cost.
  161. cost_dice (`float`, *optional*, defaults to 1.0):
  162. This is the relative weight of the dice loss of the binary mask in the matching cost.
  163. num_points (`int`, *optional*, defaults to 12544):
  164. No. of points to sample on which the mask loss will be calculated. The same set of K points are
  165. uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite
  166. matching.
  167. """
  168. super().__init__()
  169. if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
  170. raise ValueError("All costs can't be 0")
  171. self.num_points = num_points
  172. self.cost_class = cost_class
  173. self.cost_mask = cost_mask
  174. self.cost_dice = cost_dice
  175. @torch.no_grad()
  176. def forward(
  177. self,
  178. masks_queries_logits: torch.Tensor,
  179. class_queries_logits: torch.Tensor,
  180. mask_labels: torch.Tensor,
  181. class_labels: torch.Tensor,
  182. ) -> list[tuple[Tensor]]:
  183. """
  184. Params:
  185. masks_queries_logits (`torch.Tensor`):
  186. A tensor of dim `batch_size, num_queries, num_labels` with the classification logits.
  187. class_queries_logits (`torch.Tensor`):
  188. A tensor of dim `batch_size, num_queries, height, width` with the predicted masks.
  189. class_labels (`torch.Tensor`):
  190. A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the
  191. target) containing the class labels.
  192. mask_labels (`torch.Tensor`):
  193. A tensor of dim `num_target_boxes, height, width` containing the target masks.
  194. Returns:
  195. matched_indices (`list[tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j)
  196. where:
  197. - index_i is the indices of the selected predictions (in order)
  198. - index_j is the indices of the corresponding selected labels (in order)
  199. For each batch element, it holds:
  200. len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
  201. """
  202. indices: list[tuple[np.array]] = []
  203. # iterate through batch size
  204. batch_size = masks_queries_logits.shape[0]
  205. for i in range(batch_size):
  206. pred_probs = class_queries_logits[i].softmax(-1)
  207. pred_mask = masks_queries_logits[i]
  208. # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. The 1 is a constant that doesn't change the matching, it can be omitted.
  209. cost_class = -pred_probs[:, class_labels[i]]
  210. target_mask = mask_labels[i].to(pred_mask)
  211. target_mask = target_mask[:, None]
  212. pred_mask = pred_mask[:, None]
  213. # Sample ground truth and predicted masks
  214. point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)
  215. target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1)
  216. target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)
  217. pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1)
  218. pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)
  219. # compute the cross entropy loss between each mask pairs -> shape (num_queries, num_labels)
  220. cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
  221. # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels)
  222. cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
  223. # final cost matrix
  224. cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
  225. # eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible``
  226. cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10))
  227. cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10))
  228. cost_matrix = torch.nan_to_num(cost_matrix, 0)
  229. # do the assignment using the hungarian algorithm in scipy
  230. assigned_indices: tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
  231. indices.append(assigned_indices)
  232. # It could be stacked in one tensor
  233. matched_indices = [
  234. (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
  235. ]
  236. return matched_indices
  237. def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
  238. r"""
  239. Compute the DICE loss, similar to generalized IOU for masks as follows:
  240. $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$
  241. In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow
  242. $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$
  243. Args:
  244. inputs (`torch.Tensor`):
  245. A tensor representing a mask.
  246. labels (`torch.Tensor`):
  247. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  248. (0 for the negative class and 1 for the positive class).
  249. num_masks (`int`):
  250. The number of masks present in the current batch, used for normalization.
  251. Returns:
  252. `torch.Tensor`: The computed loss.
  253. """
  254. probs = inputs.sigmoid().flatten(1)
  255. numerator = 2 * (probs * labels).sum(-1)
  256. denominator = probs.sum(-1) + labels.sum(-1)
  257. loss = 1 - (numerator + 1) / (denominator + 1)
  258. loss = loss.sum() / num_masks
  259. return loss
  260. def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
  261. r"""
  262. Args:
  263. inputs (`torch.Tensor`):
  264. A float tensor of arbitrary shape.
  265. labels (`torch.Tensor`):
  266. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  267. (0 for the negative class and 1 for the positive class).
  268. Returns:
  269. loss (`torch.Tensor`): The computed loss.
  270. """
  271. criterion = nn.BCEWithLogitsLoss(reduction="none")
  272. cross_entropy_loss = criterion(inputs, labels)
  273. loss = cross_entropy_loss.mean(1).sum() / num_masks
  274. return loss
  275. # Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/criterion.py
  276. class EomtLoss(nn.Module):
  277. def __init__(self, config: EomtConfig, weight_dict: dict[str, float]):
  278. """
  279. The Eomt Loss. The loss is computed very similar to DETR. The process happens in two steps: 1) we
  280. compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair
  281. of matched ground-truth / prediction (supervise class and mask)
  282. Args:
  283. config (`EomtConfig`):
  284. The configuration for Eomt model also containing loss calculation specific parameters.
  285. weight_dict (`dict[str, float]`):
  286. A dictionary of weights to be applied to the different losses.
  287. """
  288. super().__init__()
  289. requires_backends(self, ["scipy"])
  290. self.num_labels = config.num_labels
  291. self.weight_dict = weight_dict
  292. # Weight to apply to the null class
  293. self.eos_coef = config.no_object_weight
  294. empty_weight = torch.ones(self.num_labels + 1)
  295. empty_weight[-1] = self.eos_coef
  296. self.register_buffer("empty_weight", empty_weight)
  297. # pointwise mask loss parameters
  298. self.num_points = config.train_num_points
  299. self.oversample_ratio = config.oversample_ratio
  300. self.importance_sample_ratio = config.importance_sample_ratio
  301. self.matcher = EomtHungarianMatcher(
  302. cost_class=config.class_weight,
  303. cost_dice=config.dice_weight,
  304. cost_mask=config.mask_weight,
  305. num_points=self.num_points,
  306. )
  307. def _max_by_axis(self, sizes: list[list[int]]) -> list[int]:
  308. maxes = sizes[0]
  309. for sublist in sizes[1:]:
  310. for index, item in enumerate(sublist):
  311. maxes[index] = max(maxes[index], item)
  312. return maxes
  313. # Adapted from nested_tensor_from_tensor_list() in original implementation
  314. def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
  315. # get the maximum size in the batch
  316. max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
  317. # compute final size
  318. batch_shape = [len(tensors)] + max_size
  319. batch_size, _, height, width = batch_shape
  320. dtype = tensors[0].dtype
  321. device = tensors[0].device
  322. padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
  323. padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
  324. # pad the tensors to the size of the biggest one
  325. for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
  326. padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
  327. padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
  328. return padded_tensors, padding_masks
  329. def loss_labels(
  330. self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array]
  331. ) -> dict[str, Tensor]:
  332. """Compute the losses related to the labels using cross entropy.
  333. Args:
  334. class_queries_logits (`torch.Tensor`):
  335. A tensor of shape `batch_size, num_queries, num_labels`
  336. class_labels (`list[torch.Tensor]`):
  337. List of class labels of shape `(labels)`.
  338. indices (`tuple[np.array])`:
  339. The indices computed by the Hungarian matcher.
  340. Returns:
  341. `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
  342. - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
  343. """
  344. pred_logits = class_queries_logits
  345. batch_size, num_queries, _ = pred_logits.shape
  346. criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
  347. idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries)
  348. target_classes_o = torch.cat(
  349. [target[j] for target, (_, j) in zip(class_labels, indices)]
  350. ) # shape of (batch_size, num_queries)
  351. target_classes = torch.full(
  352. (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
  353. )
  354. target_classes[idx] = target_classes_o
  355. # Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries)
  356. pred_logits_transposed = pred_logits.transpose(1, 2)
  357. loss_ce = criterion(pred_logits_transposed, target_classes)
  358. losses = {"loss_cross_entropy": loss_ce}
  359. return losses
  360. def loss_masks(
  361. self,
  362. masks_queries_logits: torch.Tensor,
  363. mask_labels: list[torch.Tensor],
  364. indices: tuple[np.array],
  365. num_masks: int,
  366. ) -> dict[str, torch.Tensor]:
  367. """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss.
  368. Args:
  369. masks_queries_logits (`torch.Tensor`):
  370. A tensor of shape `(batch_size, num_queries, height, width)`.
  371. mask_labels (`torch.Tensor`):
  372. List of mask labels of shape `(labels, height, width)`.
  373. indices (`tuple[np.array])`:
  374. The indices computed by the Hungarian matcher.
  375. num_masks (`int)`:
  376. The number of masks, used for normalization.
  377. Returns:
  378. losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys:
  379. - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth.
  380. masks.
  381. - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth,
  382. masks.
  383. """
  384. src_idx = self._get_predictions_permutation_indices(indices)
  385. tgt_idx = self._get_targets_permutation_indices(indices)
  386. # shape (batch_size * num_queries, height, width)
  387. pred_masks = masks_queries_logits[src_idx]
  388. # shape (batch_size, num_queries, height, width)
  389. # pad all and stack the targets to the num_labels dimension
  390. target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
  391. target_masks = target_masks[tgt_idx]
  392. # No need to upsample predictions as we are using normalized coordinates
  393. pred_masks = pred_masks[:, None]
  394. target_masks = target_masks[:, None]
  395. # Sample point coordinates
  396. with torch.no_grad():
  397. point_coordinates = self.sample_points_using_uncertainty(
  398. pred_masks,
  399. lambda logits: self.calculate_uncertainty(logits),
  400. self.num_points,
  401. self.oversample_ratio,
  402. self.importance_sample_ratio,
  403. )
  404. point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1)
  405. point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1)
  406. losses = {
  407. "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
  408. "loss_dice": dice_loss(point_logits, point_labels, num_masks),
  409. }
  410. del pred_masks
  411. del target_masks
  412. return losses
  413. def _get_predictions_permutation_indices(self, indices):
  414. # Permute predictions following indices
  415. batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  416. predictions_indices = torch.cat([src for (src, _) in indices])
  417. return batch_indices, predictions_indices
  418. def _get_targets_permutation_indices(self, indices):
  419. # Permute labels following indices
  420. batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
  421. target_indices = torch.cat([tgt for (_, tgt) in indices])
  422. return batch_indices, target_indices
  423. def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
  424. """
  425. In Eomt paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits'
  426. for the foreground class in `classes`.
  427. Args:
  428. logits (`torch.Tensor`):
  429. A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is:
  430. the number of foreground classes. The values are logits.
  431. Returns:
  432. scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most
  433. uncertain locations having the highest uncertainty score.
  434. """
  435. uncertainty_scores = -(torch.abs(logits))
  436. return uncertainty_scores
  437. def sample_points_using_uncertainty(
  438. self,
  439. logits: torch.Tensor,
  440. uncertainty_function,
  441. num_points: int,
  442. oversample_ratio: int,
  443. importance_sample_ratio: float,
  444. ) -> torch.Tensor:
  445. """
  446. This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The
  447. uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit
  448. prediction as input.
  449. Args:
  450. logits (`float`):
  451. Logit predictions for P points.
  452. uncertainty_function:
  453. A function that takes logit predictions for P points and returns their uncertainties.
  454. num_points (`int`):
  455. The number of points P to sample.
  456. oversample_ratio (`int`):
  457. Oversampling parameter.
  458. importance_sample_ratio (`float`):
  459. Ratio of points that are sampled via importance sampling.
  460. Returns:
  461. point_coordinates (`torch.Tensor`):
  462. Coordinates for P sampled points.
  463. """
  464. num_boxes = logits.shape[0]
  465. num_points_sampled = int(num_points * oversample_ratio)
  466. # Get random point coordinates
  467. point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
  468. # Get sampled prediction value for the point coordinates
  469. point_logits = sample_point(logits, point_coordinates, align_corners=False)
  470. # Calculate the uncertainties based on the sampled prediction values of the points
  471. point_uncertainties = uncertainty_function(point_logits)
  472. num_uncertain_points = int(importance_sample_ratio * num_points)
  473. num_random_points = num_points - num_uncertain_points
  474. idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
  475. shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
  476. idx += shift[:, None]
  477. point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
  478. if num_random_points > 0:
  479. point_coordinates = torch.cat(
  480. [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
  481. dim=1,
  482. )
  483. return point_coordinates
  484. def forward(
  485. self,
  486. masks_queries_logits: torch.Tensor,
  487. class_queries_logits: torch.Tensor,
  488. mask_labels: list[torch.Tensor],
  489. class_labels: list[torch.Tensor],
  490. auxiliary_predictions: Optional[dict[str, torch.Tensor]] = None,
  491. ) -> dict[str, torch.Tensor]:
  492. """
  493. This performs the loss computation.
  494. Args:
  495. masks_queries_logits (`torch.Tensor`):
  496. A tensor of shape `(batch_size, num_queries, height, width)`.
  497. class_queries_logits (`torch.Tensor`):
  498. A tensor of shape `(batch_size, num_queries, num_labels)`.
  499. mask_labels (`torch.Tensor`):
  500. List of mask labels of shape `(labels, height, width)`.
  501. class_labels (`list[torch.Tensor]`):
  502. List of class labels of shape `(labels)`.
  503. auxiliary_predictions (`dict[str, torch.Tensor]`, *optional*):
  504. if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], then it contains the logits from
  505. the inner layers of the EomtMaskedAttentionDecoder.
  506. Returns:
  507. losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys:
  508. - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
  509. - **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth
  510. masks.
  511. - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth
  512. masks.
  513. if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], the dictionary contains additional
  514. losses for each auxiliary predictions.
  515. """
  516. # retrieve the matching between the outputs of the last layer and the labels
  517. indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
  518. # compute the average number of target masks for normalization purposes
  519. num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
  520. # get all the losses
  521. losses: dict[str, Tensor] = {
  522. **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
  523. **self.loss_labels(class_queries_logits, class_labels, indices),
  524. }
  525. # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
  526. if auxiliary_predictions is not None:
  527. for idx, aux_outputs in enumerate(auxiliary_predictions):
  528. masks_queries_logits = aux_outputs["masks_queries_logits"]
  529. class_queries_logits = aux_outputs["class_queries_logits"]
  530. loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
  531. loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
  532. losses.update(loss_dict)
  533. return losses
  534. def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
  535. """
  536. Computes the average number of target masks across the batch, for normalization purposes.
  537. """
  538. num_masks = sum(len(classes) for classes in class_labels)
  539. num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
  540. world_size = 1
  541. if is_accelerate_available():
  542. if PartialState._shared_state != {}:
  543. num_masks = reduce(num_masks)
  544. world_size = PartialState().num_processes
  545. num_masks = torch.clamp(num_masks / world_size, min=1)
  546. return num_masks
  547. class EomtPatchEmbeddings(nn.Module):
  548. """
  549. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  550. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  551. Transformer.
  552. """
  553. def __init__(self, config):
  554. super().__init__()
  555. image_size, patch_size = config.image_size, config.patch_size
  556. num_channels, hidden_size = config.num_channels, config.hidden_size
  557. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  558. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  559. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  560. self.image_size = image_size
  561. self.patch_size = patch_size
  562. self.num_channels = num_channels
  563. self.num_patches = num_patches
  564. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  565. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  566. num_channels = pixel_values.shape[1]
  567. if num_channels != self.num_channels:
  568. raise ValueError(
  569. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  570. f" Expected {self.num_channels} but got {num_channels}."
  571. )
  572. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  573. return embeddings
  574. class EomtEmbeddings(nn.Module):
  575. """
  576. Construct the CLS token, mask token, position and patch embeddings.
  577. """
  578. def __init__(self, config: EomtConfig) -> None:
  579. super().__init__()
  580. self.config = config
  581. self.patch_size = config.patch_size
  582. self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  583. self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
  584. self.patch_embeddings = EomtPatchEmbeddings(config)
  585. num_patches = self.patch_embeddings.num_patches
  586. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  587. self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS]
  588. self.position_embeddings = nn.Embedding(num_patches, config.hidden_size)
  589. self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
  590. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  591. batch_size, _, _, _ = pixel_values.shape
  592. target_dtype = self.patch_embeddings.projection.weight.dtype
  593. embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
  594. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  595. register_tokens = self.register_tokens.expand(batch_size, -1, -1)
  596. embeddings = embeddings + self.position_embeddings(self.position_ids)
  597. embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
  598. embeddings = self.dropout(embeddings)
  599. return embeddings
  600. def eager_attention_forward(
  601. module: nn.Module,
  602. query: torch.Tensor,
  603. key: torch.Tensor,
  604. value: torch.Tensor,
  605. attention_mask: Optional[torch.Tensor],
  606. scaling: float,
  607. dropout: float = 0.0,
  608. **kwargs,
  609. ):
  610. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  611. if attention_mask is not None:
  612. attn_weights = attn_weights + attention_mask
  613. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  614. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  615. attn_output = torch.matmul(attn_weights, value)
  616. attn_output = attn_output.transpose(1, 2).contiguous()
  617. return attn_output, attn_weights
  618. class EomtAttention(nn.Module):
  619. """Multi-headed attention from 'Attention Is All You Need' paper"""
  620. def __init__(self, config):
  621. super().__init__()
  622. self.config = config
  623. self.embed_dim = config.hidden_size
  624. self.num_heads = config.num_attention_heads
  625. self.head_dim = self.embed_dim // self.num_heads
  626. if self.head_dim * self.num_heads != self.embed_dim:
  627. raise ValueError(
  628. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  629. f" {self.num_heads})."
  630. )
  631. self.scale = self.head_dim**-0.5
  632. self.dropout = config.attention_dropout
  633. self.is_causal = False
  634. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  635. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  636. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  637. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  638. def forward(
  639. self,
  640. hidden_states: torch.Tensor,
  641. attention_mask: Optional[torch.Tensor] = None,
  642. **kwargs,
  643. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  644. """Input shape: Batch x Time x Channel"""
  645. batch_size, seq_length, embed_dim = hidden_states.shape
  646. queries = self.q_proj(hidden_states)
  647. keys = self.k_proj(hidden_states)
  648. values = self.v_proj(hidden_states)
  649. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  650. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  651. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  652. attention_interface: Callable = eager_attention_forward
  653. if self.config._attn_implementation != "eager":
  654. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  655. attn_output, attn_weights = attention_interface(
  656. self,
  657. queries,
  658. keys,
  659. values,
  660. attention_mask,
  661. is_causal=self.is_causal,
  662. scaling=self.scale,
  663. dropout=0.0 if not self.training else self.dropout,
  664. )
  665. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  666. attn_output = self.out_proj(attn_output)
  667. return attn_output, attn_weights
  668. class EomtLayerScale(nn.Module):
  669. def __init__(self, config) -> None:
  670. super().__init__()
  671. self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
  672. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  673. return hidden_state * self.lambda1
  674. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  675. """
  676. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  677. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  678. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  679. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  680. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  681. argument.
  682. """
  683. if drop_prob == 0.0 or not training:
  684. return input
  685. keep_prob = 1 - drop_prob
  686. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  687. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  688. random_tensor.floor_() # binarize
  689. output = input.div(keep_prob) * random_tensor
  690. return output
  691. class EomtDropPath(nn.Module):
  692. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  693. def __init__(self, drop_prob: Optional[float] = None) -> None:
  694. super().__init__()
  695. self.drop_prob = drop_prob
  696. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  697. return drop_path(hidden_states, self.drop_prob, self.training)
  698. def extra_repr(self) -> str:
  699. return f"p={self.drop_prob}"
  700. class EomtMLP(nn.Module):
  701. def __init__(self, config) -> None:
  702. super().__init__()
  703. in_features = out_features = config.hidden_size
  704. hidden_features = int(config.hidden_size * config.mlp_ratio)
  705. self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
  706. if isinstance(config.hidden_act, str):
  707. self.activation = ACT2FN[config.hidden_act]
  708. else:
  709. self.activation = config.hidden_act
  710. self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
  711. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  712. hidden_state = self.fc1(hidden_state)
  713. hidden_state = self.activation(hidden_state)
  714. hidden_state = self.fc2(hidden_state)
  715. return hidden_state
  716. class EomtSwiGLUFFN(nn.Module):
  717. def __init__(self, config) -> None:
  718. super().__init__()
  719. in_features = out_features = config.hidden_size
  720. hidden_features = int(config.hidden_size * config.mlp_ratio)
  721. hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  722. self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
  723. self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
  724. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  725. hidden_state = self.weights_in(hidden_state)
  726. x1, x2 = hidden_state.chunk(2, dim=-1)
  727. hidden = nn.functional.silu(x1) * x2
  728. return self.weights_out(hidden)
  729. class EomtLayer(GradientCheckpointingLayer):
  730. """This corresponds to the Block class in the original implementation."""
  731. def __init__(self, config: EomtConfig) -> None:
  732. super().__init__()
  733. self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  734. self.attention = EomtAttention(config)
  735. self.layer_scale1 = EomtLayerScale(config)
  736. self.drop_path = EomtDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
  737. self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  738. if config.use_swiglu_ffn:
  739. self.mlp = EomtSwiGLUFFN(config)
  740. else:
  741. self.mlp = EomtMLP(config)
  742. self.layer_scale2 = EomtLayerScale(config)
  743. def forward(
  744. self,
  745. hidden_states: torch.Tensor,
  746. head_mask: Optional[torch.Tensor] = None,
  747. ) -> torch.Tensor:
  748. hidden_states_norm = self.norm1(hidden_states)
  749. self_attention_output, _ = self.attention(hidden_states_norm, head_mask)
  750. self_attention_output = self.layer_scale1(self_attention_output)
  751. # first residual connection
  752. hidden_states = self.drop_path(self_attention_output) + hidden_states
  753. # in Eomt, layernorm is also applied after self-attention
  754. layer_output = self.norm2(hidden_states)
  755. layer_output = self.mlp(layer_output)
  756. layer_output = self.layer_scale2(layer_output)
  757. # second residual connection
  758. layer_output = self.drop_path(layer_output) + hidden_states
  759. return layer_output
  760. class EomtLayerNorm2d(nn.LayerNorm):
  761. def __init__(self, num_channels, eps=1e-6, affine=True):
  762. super().__init__(num_channels, eps=eps, elementwise_affine=affine)
  763. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  764. hidden_state = hidden_state.permute(0, 2, 3, 1)
  765. hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps)
  766. hidden_state = hidden_state.permute(0, 3, 1, 2)
  767. return hidden_state
  768. class EomtScaleLayer(nn.Module):
  769. def __init__(self, config: EomtConfig):
  770. super().__init__()
  771. hidden_size = config.hidden_size
  772. self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2)
  773. self.activation = ACT2FN[config.hidden_act]
  774. self.conv2 = nn.Conv2d(
  775. hidden_size,
  776. hidden_size,
  777. kernel_size=3,
  778. padding=1,
  779. groups=hidden_size,
  780. bias=False,
  781. )
  782. self.layernorm2d = EomtLayerNorm2d(hidden_size)
  783. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  784. hidden_states = self.conv1(hidden_states)
  785. hidden_states = self.activation(hidden_states)
  786. hidden_states = self.conv2(hidden_states)
  787. hidden_states = self.layernorm2d(hidden_states)
  788. return hidden_states
  789. class EomtScaleBlock(nn.Module):
  790. def __init__(self, config: EomtConfig):
  791. super().__init__()
  792. self.num_blocks = config.num_upscale_blocks
  793. self.block = nn.ModuleList([EomtScaleLayer(config) for _ in range(self.num_blocks)])
  794. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  795. for block in self.block:
  796. hidden_states = block(hidden_states)
  797. return hidden_states
  798. class EomtMaskHead(nn.Module):
  799. def __init__(self, config: EomtConfig):
  800. super().__init__()
  801. hidden_size = config.hidden_size
  802. self.fc1 = nn.Linear(hidden_size, hidden_size)
  803. self.fc2 = nn.Linear(hidden_size, hidden_size)
  804. self.fc3 = nn.Linear(hidden_size, hidden_size)
  805. self.activation = ACT2FN[config.hidden_act]
  806. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  807. hidden_states = self.activation(self.fc1(hidden_states))
  808. hidden_states = self.activation(self.fc2(hidden_states))
  809. hidden_states = self.fc3(hidden_states)
  810. return hidden_states
  811. @auto_docstring
  812. class EomtPreTrainedModel(PreTrainedModel):
  813. """
  814. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  815. models.
  816. """
  817. config: EomtConfig
  818. base_model_prefix = "eomt"
  819. main_input_name = "pixel_values"
  820. supports_gradient_checkpointing = False
  821. _no_split_modules = ["EomtLayer"]
  822. _supports_sdpa = True
  823. _can_record_outputs = {
  824. "hidden_states": EomtLayer,
  825. "attentions": EomtAttention,
  826. }
  827. def _init_weights(self, module: nn.Module) -> None:
  828. std = self.config.initializer_range
  829. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  830. nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  831. if module.bias is not None:
  832. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
  833. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  834. nn.init.uniform_(module.bias, -bound, bound)
  835. elif isinstance(module, nn.LayerNorm):
  836. module.weight.data.fill_(1.0)
  837. module.bias.data.zero_()
  838. elif isinstance(module, nn.Embedding):
  839. module.weight.data.normal_(mean=0.0, std=1)
  840. if module.padding_idx is not None:
  841. module.weight.data[module.padding_idx].zero_()
  842. elif isinstance(module, EomtLayerScale):
  843. if hasattr(module, "lambda1"):
  844. module.lambda1.data.fill_(self.config.layerscale_value)
  845. elif isinstance(module, EomtEmbeddings):
  846. module.cls_token.data = nn.init.trunc_normal_(
  847. module.cls_token.data.to(torch.float32), mean=0.0, std=std
  848. ).to(module.cls_token.dtype)
  849. module.register_tokens.data.zero_()
@auto_docstring(
    custom_intro="""
    The EoMT Model with head on top for instance/semantic/panoptic segmentation.
    """
)
class EomtForUniversalSegmentation(EomtPreTrainedModel):
    # Stack of `EomtLayer` blocks in which learned query tokens are prepended to the token sequence for the
    # last `config.num_blocks` layers. Class logits and mask logits are predicted from the query tokens and
    # the patch tokens (see `predict`).
    main_input_name = "pixel_values"

    def __init__(self, config: EomtConfig):
        super().__init__(config)
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.embeddings = EomtEmbeddings(config)
        # One LayerNorm instance is reused for the intermediate predictions and for the final output.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Learned query tokens, one embedding row per query.
        self.query = nn.Embedding(config.num_queries, config.hidden_size)
        self.layers = nn.ModuleList([EomtLayer(config) for _ in range(config.num_hidden_layers)])

        self.upscale_block = EomtScaleBlock(config)
        self.mask_head = EomtMaskHead(config)

        # `num_labels + 1` outputs: one extra logit beyond the configured labels
        # (presumably a "no object" class -- TODO confirm against EomtLoss).
        self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1)

        # NOTE(review): uses `//` on `config.image_size` / `config.patch_size` directly, so both are assumed
        # to be plain ints here, whereas EomtPatchEmbeddings also accepts iterables -- confirm.
        self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
        self.weight_dict: dict[str, float] = {
            "loss_cross_entropy": config.class_weight,
            "loss_mask": config.mask_weight,
            "loss_dice": config.dice_weight,
        }

        self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict)

        # One probability per query-carrying layer, initialised to 1; read in `forward` to decide whether
        # (and how strongly) the prediction-derived attention mask is applied.
        self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks))

        self.post_init()

    def get_loss_dict(
        self,
        masks_queries_logits: Tensor,
        class_queries_logits: Tensor,
        mask_labels: Tensor,
        class_labels: Tensor,
        auxiliary_predictions: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        # Runs the criterion and scales every returned loss by its configured weight.
        # NOTE(review): `forward` passes `auxiliary_predictions=None`, so the annotation above is narrower
        # than actual usage.
        loss_dict: dict[str, Tensor] = self.criterion(
            masks_queries_logits=masks_queries_logits,
            class_queries_logits=class_queries_logits,
            mask_labels=mask_labels,
            class_labels=class_labels,
            auxiliary_predictions=auxiliary_predictions,
        )

        # weight each loss by `self.weight_dict[<LOSS_NAME>]` including auxiliary losses
        for key, weight in self.weight_dict.items():
            for loss_key, loss in loss_dict.items():
                # Substring match, so suffixed keys such as "<name>_<idx>" are weighted too.
                if key in loss_key:
                    # In-place multiply: relies on `loss` being a tensor so the dict entry itself is updated.
                    loss *= weight

        return loss_dict

    def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor:
        # Total loss is the plain sum of all (already weighted) entries.
        return sum(loss_dict.values())

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        pixel_values: Tensor,
        mask_labels: Optional[list[Tensor]] = None,
        class_labels: Optional[list[Tensor]] = None,
        patch_offsets: Optional[list[Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> EomtForUniversalSegmentationOutput:
        r"""
        mask_labels (`list[torch.Tensor]`, *optional*):
            list of mask labels of shape `(num_labels, height, width)` to be fed to a model
        class_labels (`list[torch.LongTensor]`, *optional*):
            list of target class labels of shape `(num_labels,)` to be fed to a model. They identify the
            labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
        patch_offsets (`list[torch.Tensor]`, *optional*):
            list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
        """

        # Predictions collected from every layer that produced them, plus the final ones appended after the loop.
        masks_queries_logits_per_layer, class_queries_logits_per_layer = (), ()
        attention_mask = None

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)

        for idx, layer_module in enumerate(self.layers):
            # At the first of the last `num_blocks` layers, prepend the query tokens to the sequence.
            if idx == self.num_hidden_layers - self.config.num_blocks:
                query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
                hidden_states = torch.cat((query, hidden_states), dim=1)

            # For the query-carrying layers: always in training, and in eval only while this layer's
            # `attn_mask_probs` entry is positive, make an intermediate prediction and derive a mask from it.
            if idx >= self.num_hidden_layers - self.config.num_blocks and (
                self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0
            ):
                norm_hidden_states = self.layernorm(hidden_states)
                masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states)

                masks_queries_logits_per_layer += (masks_queries_logits,)
                class_queries_logits_per_layer += (class_queries_logits,)

                # Start from "everything may attend to everything": boolean (batch, seq, seq), all True.
                attention_mask = torch.ones(
                    hidden_states.shape[0],
                    hidden_states.shape[1],
                    hidden_states.shape[1],
                    device=hidden_states.device,
                    dtype=torch.bool,
                )

                # Resize the mask logits to the patch grid and flatten the grid into one dim per query.
                interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear")
                interpolated_logits = interpolated_logits.view(
                    interpolated_logits.size(0), interpolated_logits.size(1), -1
                )

                num_query_tokens = self.config.num_queries
                # Index of the first patch token: queries come first, then the CLS/register prefix tokens.
                encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens

                # Set attention mask for queries to focus on encoder tokens based on interpolated logits
                attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0

                # Disable attention mask for random query tokens.
                attention_mask = self._disable_attention_mask(
                    attention_mask,
                    prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks],
                    num_query_tokens=num_query_tokens,
                    encoder_start_tokens=encoder_start_tokens,
                    device=attention_mask.device,
                )

                # Expand attention mask to 4d mask.
                attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1)
                # Convert to a float mask: True entries become 1.0, False entries become -1e9.
                attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9)

            # NOTE(review): `attention_mask` is only reassigned inside the branch above, so when that branch is
            # skipped for a later layer the mask built for an earlier layer is reused -- confirm this is intended.
            hidden_states = layer_module(hidden_states, attention_mask)

        # Final prediction from the output of the last layer.
        sequence_output = self.layernorm(hidden_states)

        masks_queries_logits, class_queries_logits = self.predict(sequence_output)
        masks_queries_logits_per_layer += (masks_queries_logits,)
        class_queries_logits_per_layer += (class_queries_logits,)

        loss = None
        if mask_labels is not None and class_labels is not None:
            loss = 0.0
            # The loop rebinds `masks_queries_logits` / `class_queries_logits`; the final-layer predictions are
            # the last pair in the tuples, so after the loop these names again hold the final predictions.
            for masks_queries_logits, class_queries_logits in zip(
                masks_queries_logits_per_layer, class_queries_logits_per_layer
            ):
                loss_dict = self.get_loss_dict(
                    masks_queries_logits=masks_queries_logits,
                    class_queries_logits=class_queries_logits,
                    mask_labels=mask_labels,
                    class_labels=class_labels,
                    auxiliary_predictions=None,
                )
                # Sum the weighted losses of every collected prediction.
                loss += self.get_loss(loss_dict)

        return EomtForUniversalSegmentationOutput(
            loss=loss,
            masks_queries_logits=masks_queries_logits,
            class_queries_logits=class_queries_logits,
            last_hidden_state=sequence_output,
            patch_offsets=patch_offsets,
        )

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def predict(self, logits: torch.Tensor):
        # `logits` is the layer-normed token sequence (that is what `forward` passes in), laid out as
        # [query tokens | prefix tokens | patch tokens].
        query_tokens = logits[:, : self.config.num_queries, :]
        class_logits = self.class_predictor(query_tokens)

        # Despite its name, `prefix_tokens` holds everything AFTER the query and prefix tokens,
        # i.e. the patch tokens.
        prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :]
        # Channels first, then fold the sequence dim back into the 2D patch grid.
        prefix_tokens = prefix_tokens.transpose(1, 2)

        prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size)

        query_tokens = self.mask_head(query_tokens)
        prefix_tokens = self.upscale_block(prefix_tokens)

        # Per-query mask logits: dot product over channels between each query and each spatial location.
        mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens)

        return mask_logits, class_logits

    @staticmethod
    def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
        # With `prob` >= 1 the mask is returned unchanged. Otherwise each (sample, query) pair is selected
        # when a uniform draw exceeds `prob`, and the selected queries are allowed to attend to all patch tokens.
        if prob < 1:
            # Generate random queries to disable based on the probs
            random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob

            # Reset the query -> patch-token entries of the selected queries to 1 (attend); the slice is a view,
            # so `attn_mask` is modified in place.
            attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1

        return attn_mask
# Names exported by `from <module> import *`.
__all__ = ["EomtPreTrainedModel", "EomtForUniversalSegmentation"]