loss_d_fine.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from ..utils import is_vision_available
  18. from .loss_for_object_detection import (
  19. box_iou,
  20. )
  21. from .loss_rt_detr import RTDetrHungarianMatcher, RTDetrLoss
  22. if is_vision_available():
  23. from transformers.image_transforms import center_to_corners_format
  24. @torch.jit.unused
  25. def _set_aux_loss(outputs_class, outputs_coord):
  26. # this is a workaround to make torchscript happy, as torchscript
  27. # doesn't support dictionary with non-homogeneous values, such
  28. # as a dict having both a Tensor and a list.
  29. return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
  30. @torch.jit.unused
  31. def _set_aux_loss2(
  32. outputs_class, outputs_coord, outputs_corners, outputs_ref, teacher_corners=None, teacher_logits=None
  33. ):
  34. # this is a workaround to make torchscript happy, as torchscript
  35. # doesn't support dictionary with non-homogeneous values, such
  36. # as a dict having both a Tensor and a list.
  37. return [
  38. {
  39. "logits": a,
  40. "pred_boxes": b,
  41. "pred_corners": c,
  42. "ref_points": d,
  43. "teacher_corners": teacher_corners,
  44. "teacher_logits": teacher_logits,
  45. }
  46. for a, b, c, d in zip(outputs_class, outputs_coord, outputs_corners, outputs_ref)
  47. ]
  48. def weighting_function(max_num_bins: int, up: torch.Tensor, reg_scale: int) -> torch.Tensor:
  49. """
  50. Generates the non-uniform Weighting Function W(n) for bounding box regression.
  51. Args:
  52. max_num_bins (int): Max number of the discrete bins.
  53. up (Tensor): Controls upper bounds of the sequence,
  54. where maximum offset is ±up * H / W.
  55. reg_scale (float): Controls the curvature of the Weighting Function.
  56. Larger values result in flatter weights near the central axis W(max_num_bins/2)=0
  57. and steeper weights at both ends.
  58. Returns:
  59. Tensor: Sequence of Weighting Function.
  60. """
  61. upper_bound1 = abs(up[0]) * abs(reg_scale)
  62. upper_bound2 = abs(up[0]) * abs(reg_scale) * 2
  63. step = (upper_bound1 + 1) ** (2 / (max_num_bins - 2))
  64. left_values = [-((step) ** i) + 1 for i in range(max_num_bins // 2 - 1, 0, -1)]
  65. right_values = [(step) ** i - 1 for i in range(1, max_num_bins // 2)]
  66. values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2]
  67. values = [v if v.dim() > 0 else v.unsqueeze(0) for v in values]
  68. values = torch.cat(values, 0)
  69. return values
  70. def translate_gt(gt: torch.Tensor, max_num_bins: int, reg_scale: int, up: torch.Tensor):
  71. """
  72. Decodes bounding box ground truth (GT) values into distribution-based GT representations.
  73. This function maps continuous GT values into discrete distribution bins, which can be used
  74. for regression tasks in object detection models. It calculates the indices of the closest
  75. bins to each GT value and assigns interpolation weights to these bins based on their proximity
  76. to the GT value.
  77. Args:
  78. gt (Tensor): Ground truth bounding box values, shape (N, ).
  79. max_num_bins (int): Maximum number of discrete bins for the distribution.
  80. reg_scale (float): Controls the curvature of the Weighting Function.
  81. up (Tensor): Controls the upper bounds of the Weighting Function.
  82. Returns:
  83. tuple[Tensor, Tensor, Tensor]:
  84. - indices (Tensor): Index of the left bin closest to each GT value, shape (N, ).
  85. - weight_right (Tensor): Weight assigned to the right bin, shape (N, ).
  86. - weight_left (Tensor): Weight assigned to the left bin, shape (N, ).
  87. """
  88. gt = gt.reshape(-1)
  89. function_values = weighting_function(max_num_bins, up, reg_scale)
  90. # Find the closest left-side indices for each value
  91. diffs = function_values.unsqueeze(0) - gt.unsqueeze(1)
  92. mask = diffs <= 0
  93. closest_left_indices = torch.sum(mask, dim=1) - 1
  94. # Calculate the weights for the interpolation
  95. indices = closest_left_indices.float()
  96. weight_right = torch.zeros_like(indices)
  97. weight_left = torch.zeros_like(indices)
  98. valid_idx_mask = (indices >= 0) & (indices < max_num_bins)
  99. valid_indices = indices[valid_idx_mask].long()
  100. # Obtain distances
  101. left_values = function_values[valid_indices]
  102. right_values = function_values[valid_indices + 1]
  103. left_diffs = torch.abs(gt[valid_idx_mask] - left_values)
  104. right_diffs = torch.abs(right_values - gt[valid_idx_mask])
  105. # Valid weights
  106. weight_right[valid_idx_mask] = left_diffs / (left_diffs + right_diffs)
  107. weight_left[valid_idx_mask] = 1.0 - weight_right[valid_idx_mask]
  108. # Invalid weights (out of range)
  109. invalid_idx_mask_neg = indices < 0
  110. weight_right[invalid_idx_mask_neg] = 0.0
  111. weight_left[invalid_idx_mask_neg] = 1.0
  112. indices[invalid_idx_mask_neg] = 0.0
  113. invalid_idx_mask_pos = indices >= max_num_bins
  114. weight_right[invalid_idx_mask_pos] = 1.0
  115. weight_left[invalid_idx_mask_pos] = 0.0
  116. indices[invalid_idx_mask_pos] = max_num_bins - 0.1
  117. return indices, weight_right, weight_left
  118. def bbox2distance(points, bbox, max_num_bins, reg_scale, up, eps=0.1):
  119. """
  120. Converts bounding box coordinates to distances from a reference point.
  121. Args:
  122. points (Tensor): (n, 4) [x, y, w, h], where (x, y) is the center.
  123. bbox (Tensor): (n, 4) bounding boxes in "xyxy" format.
  124. max_num_bins (float): Maximum bin value.
  125. reg_scale (float): Controlling curvarture of W(n).
  126. up (Tensor): Controlling upper bounds of W(n).
  127. eps (float): Small value to ensure target < max_num_bins.
  128. Returns:
  129. Tensor: Decoded distances.
  130. """
  131. reg_scale = abs(reg_scale)
  132. left = (points[:, 0] - bbox[:, 0]) / (points[..., 2] / reg_scale + 1e-16) - 0.5 * reg_scale
  133. top = (points[:, 1] - bbox[:, 1]) / (points[..., 3] / reg_scale + 1e-16) - 0.5 * reg_scale
  134. right = (bbox[:, 2] - points[:, 0]) / (points[..., 2] / reg_scale + 1e-16) - 0.5 * reg_scale
  135. bottom = (bbox[:, 3] - points[:, 1]) / (points[..., 3] / reg_scale + 1e-16) - 0.5 * reg_scale
  136. four_lens = torch.stack([left, top, right, bottom], -1)
  137. four_lens, weight_right, weight_left = translate_gt(four_lens, max_num_bins, reg_scale, up)
  138. if max_num_bins is not None:
  139. four_lens = four_lens.clamp(min=0, max=max_num_bins - eps)
  140. return four_lens.reshape(-1).detach(), weight_right.detach(), weight_left.detach()
  141. class DFineLoss(RTDetrLoss):
  142. """
  143. This class computes the losses for D-FINE. The process happens in two steps: 1) we compute hungarian assignment
  144. between ground truth boxes and the outputs of the model 2) we supervise each pair of matched ground-truth /
  145. prediction (supervise class and box).
  146. Args:
  147. matcher (`DetrHungarianMatcher`):
  148. Module able to compute a matching between targets and proposals.
  149. weight_dict (`Dict`):
  150. Dictionary relating each loss with its weights. These losses are configured in DFineConf as
  151. `weight_loss_vfl`, `weight_loss_bbox`, `weight_loss_giou`, `weight_loss_fgl`, `weight_loss_ddf`
  152. losses (`list[str]`):
  153. List of all the losses to be applied. See `get_loss` for a list of all available losses.
  154. alpha (`float`):
  155. Parameter alpha used to compute the focal loss.
  156. gamma (`float`):
  157. Parameter gamma used to compute the focal loss.
  158. eos_coef (`float`):
  159. Relative classification weight applied to the no-object category.
  160. num_classes (`int`):
  161. Number of object categories, omitting the special no-object category.
  162. """
  163. def __init__(self, config):
  164. super().__init__(config)
  165. self.matcher = RTDetrHungarianMatcher(config)
  166. self.max_num_bins = config.max_num_bins
  167. self.weight_dict = {
  168. "loss_vfl": config.weight_loss_vfl,
  169. "loss_bbox": config.weight_loss_bbox,
  170. "loss_giou": config.weight_loss_giou,
  171. "loss_fgl": config.weight_loss_fgl,
  172. "loss_ddf": config.weight_loss_ddf,
  173. }
  174. self.losses = ["vfl", "boxes", "local"]
  175. self.reg_scale = config.reg_scale
  176. self.up = nn.Parameter(torch.tensor([config.up]), requires_grad=False)
  177. def unimodal_distribution_focal_loss(
  178. self, pred, label, weight_right, weight_left, weight=None, reduction="sum", avg_factor=None
  179. ):
  180. dis_left = label.long()
  181. dis_right = dis_left + 1
  182. loss = F.cross_entropy(pred, dis_left, reduction="none") * weight_left.reshape(-1) + F.cross_entropy(
  183. pred, dis_right, reduction="none"
  184. ) * weight_right.reshape(-1)
  185. if weight is not None:
  186. weight = weight.float()
  187. loss = loss * weight
  188. if avg_factor is not None:
  189. loss = loss.sum() / avg_factor
  190. elif reduction == "mean":
  191. loss = loss.mean()
  192. elif reduction == "sum":
  193. loss = loss.sum()
  194. return loss
  195. def loss_local(self, outputs, targets, indices, num_boxes, T=5):
  196. """Compute Fine-Grained Localization (FGL) Loss
  197. and Decoupled Distillation Focal (DDF) Loss."""
  198. losses = {}
  199. if "pred_corners" in outputs:
  200. idx = self._get_source_permutation_idx(indices)
  201. target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
  202. pred_corners = outputs["pred_corners"][idx].reshape(-1, (self.max_num_bins + 1))
  203. ref_points = outputs["ref_points"][idx].detach()
  204. with torch.no_grad():
  205. self.fgl_targets = bbox2distance(
  206. ref_points,
  207. center_to_corners_format(target_boxes),
  208. self.max_num_bins,
  209. self.reg_scale,
  210. self.up,
  211. )
  212. target_corners, weight_right, weight_left = self.fgl_targets
  213. ious = torch.diag(
  214. box_iou(center_to_corners_format(outputs["pred_boxes"][idx]), center_to_corners_format(target_boxes))[
  215. 0
  216. ]
  217. )
  218. weight_targets = ious.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach()
  219. losses["loss_fgl"] = self.unimodal_distribution_focal_loss(
  220. pred_corners,
  221. target_corners,
  222. weight_right,
  223. weight_left,
  224. weight_targets,
  225. avg_factor=num_boxes,
  226. )
  227. pred_corners = outputs["pred_corners"].reshape(-1, (self.max_num_bins + 1))
  228. target_corners = outputs["teacher_corners"].reshape(-1, (self.max_num_bins + 1))
  229. if torch.equal(pred_corners, target_corners):
  230. losses["loss_ddf"] = pred_corners.sum() * 0
  231. else:
  232. weight_targets_local = outputs["teacher_logits"].sigmoid().max(dim=-1)[0]
  233. mask = torch.zeros_like(weight_targets_local, dtype=torch.bool)
  234. mask[idx] = True
  235. mask = mask.unsqueeze(-1).repeat(1, 1, 4).reshape(-1)
  236. weight_targets_local[idx] = ious.reshape_as(weight_targets_local[idx]).to(weight_targets_local.dtype)
  237. weight_targets_local = weight_targets_local.unsqueeze(-1).repeat(1, 1, 4).reshape(-1).detach()
  238. loss_match_local = (
  239. weight_targets_local
  240. * (T**2)
  241. * (
  242. nn.KLDivLoss(reduction="none")(
  243. F.log_softmax(pred_corners / T, dim=1),
  244. F.softmax(target_corners.detach() / T, dim=1),
  245. )
  246. ).sum(-1)
  247. )
  248. batch_scale = 1 / outputs["pred_boxes"].shape[0] # it should be refined
  249. self.num_pos, self.num_neg = (
  250. (mask.sum() * batch_scale) ** 0.5,
  251. ((~mask).sum() * batch_scale) ** 0.5,
  252. )
  253. loss_match_local1 = loss_match_local[mask].mean() if mask.any() else 0
  254. loss_match_local2 = loss_match_local[~mask].mean() if (~mask).any() else 0
  255. losses["loss_ddf"] = (loss_match_local1 * self.num_pos + loss_match_local2 * self.num_neg) / (
  256. self.num_pos + self.num_neg
  257. )
  258. return losses
  259. def get_loss(self, loss, outputs, targets, indices, num_boxes):
  260. loss_map = {
  261. "cardinality": self.loss_cardinality,
  262. "local": self.loss_local,
  263. "boxes": self.loss_boxes,
  264. "focal": self.loss_labels_focal,
  265. "vfl": self.loss_labels_vfl,
  266. }
  267. if loss not in loss_map:
  268. raise ValueError(f"Loss {loss} not supported")
  269. return loss_map[loss](outputs, targets, indices, num_boxes)
  270. def DFineForObjectDetectionLoss(
  271. logits,
  272. labels,
  273. device,
  274. pred_boxes,
  275. config,
  276. outputs_class=None,
  277. outputs_coord=None,
  278. enc_topk_logits=None,
  279. enc_topk_bboxes=None,
  280. denoising_meta_values=None,
  281. predicted_corners=None,
  282. initial_reference_points=None,
  283. **kwargs,
  284. ):
  285. criterion = DFineLoss(config)
  286. criterion.to(device)
  287. # Second: compute the losses, based on outputs and labels
  288. outputs_loss = {}
  289. outputs_loss["logits"] = logits
  290. outputs_loss["pred_boxes"] = pred_boxes.clamp(min=0, max=1)
  291. auxiliary_outputs = None
  292. if config.auxiliary_loss:
  293. if denoising_meta_values is not None:
  294. dn_out_coord, outputs_coord = torch.split(
  295. outputs_coord.clamp(min=0, max=1), denoising_meta_values["dn_num_split"], dim=2
  296. )
  297. dn_out_class, outputs_class = torch.split(outputs_class, denoising_meta_values["dn_num_split"], dim=2)
  298. dn_out_corners, out_corners = torch.split(predicted_corners, denoising_meta_values["dn_num_split"], dim=2)
  299. dn_out_refs, out_refs = torch.split(initial_reference_points, denoising_meta_values["dn_num_split"], dim=2)
  300. auxiliary_outputs = _set_aux_loss2(
  301. outputs_class[:, :-1].transpose(0, 1),
  302. outputs_coord[:, :-1].transpose(0, 1),
  303. out_corners[:, :-1].transpose(0, 1),
  304. out_refs[:, :-1].transpose(0, 1),
  305. out_corners[:, -1],
  306. outputs_class[:, -1],
  307. )
  308. outputs_loss["auxiliary_outputs"] = auxiliary_outputs
  309. outputs_loss["auxiliary_outputs"].extend(
  310. _set_aux_loss([enc_topk_logits], [enc_topk_bboxes.clamp(min=0, max=1)])
  311. )
  312. dn_auxiliary_outputs = _set_aux_loss2(
  313. dn_out_class.transpose(0, 1),
  314. dn_out_coord.transpose(0, 1),
  315. dn_out_corners.transpose(0, 1),
  316. dn_out_refs.transpose(0, 1),
  317. dn_out_corners[:, -1],
  318. dn_out_class[:, -1],
  319. )
  320. outputs_loss["dn_auxiliary_outputs"] = dn_auxiliary_outputs
  321. outputs_loss["denoising_meta_values"] = denoising_meta_values
  322. loss_dict = criterion(outputs_loss, labels)
  323. loss = sum(loss_dict.values())
  324. return loss, loss_dict, auxiliary_outputs