head.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. # The implementation here is modified based on timm,
  2. # originally Apache 2.0 License and publicly available at
  3. # https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/detector.py
  4. import copy
  5. import math
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. class Detector(nn.Module):
  10. """ This is a combination of "Swin with RAM" and a "Neck-free Deformable Decoder" """
  11. def __init__(
  12. self,
  13. backbone,
  14. transformer,
  15. num_classes,
  16. num_queries,
  17. aux_loss=False,
  18. with_box_refine=False,
  19. # The three additional techniques for ViDT+
  20. epff=None, # (1) Efficient Pyramid Feature Fusion Module
  21. with_vector=False,
  22. processor_dct=None,
  23. vector_hidden_dim=256, # (2) UQR Module
  24. iou_aware=False,
  25. token_label=False, # (3) Additional losses
  26. distil=False):
  27. """ Initializes the model.
  28. Args:
  29. backbone: torch module of the backbone to be used. See backbone.py
  30. transformer: torch module of the transformer architecture. See transformer.py
  31. num_classes: number of object classes
  32. num_queries: number of object queries (i.e., det tokens). This is the maximal number of objects
  33. DETR can detect in a single image. For COCO, we recommend 100 queries.
  34. aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
  35. with_box_refine: iterative bounding box refinement
  36. epff: None or fusion module available
  37. iou_aware: True if iou_aware is to be used.
  38. see the original paper https://arxiv.org/abs/1912.05992
  39. token_label: True if token_label is to be used.
  40. see the original paper https://arxiv.org/abs/2104.10858
  41. distil: whether to use knowledge distillation with token matching
  42. """
  43. super().__init__()
  44. self.num_queries = num_queries
  45. self.transformer = transformer
  46. hidden_dim = transformer.d_model
  47. self.class_embed = nn.Linear(hidden_dim, num_classes)
  48. self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
  49. # two essential techniques used [default use]
  50. self.aux_loss = aux_loss
  51. self.with_box_refine = with_box_refine
  52. # For UQR module for ViDT+
  53. self.with_vector = with_vector
  54. self.processor_dct = processor_dct
  55. if self.with_vector:
  56. print(
  57. f'Training with vector_hidden_dim {vector_hidden_dim}.',
  58. flush=True)
  59. self.vector_embed = MLP(hidden_dim, vector_hidden_dim,
  60. self.processor_dct.n_keep, 3)
  61. # For two additional losses for ViDT+
  62. self.iou_aware = iou_aware
  63. self.token_label = token_label
  64. # distillation
  65. self.distil = distil
  66. # For EPFF module for ViDT+
  67. if epff is None:
  68. num_backbone_outs = len(backbone.num_channels)
  69. input_proj_list = []
  70. for _ in range(num_backbone_outs):
  71. in_channels = backbone.num_channels[_]
  72. input_proj_list.append(
  73. nn.Sequential(
  74. # This is 1x1 conv -> so linear layer
  75. nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
  76. nn.GroupNorm(32, hidden_dim),
  77. ))
  78. self.input_proj = nn.ModuleList(input_proj_list)
  79. # initialize the projection layer for [PATCH] tokens
  80. for proj in self.input_proj:
  81. nn.init.xavier_uniform_(proj[0].weight, gain=1)
  82. nn.init.constant_(proj[0].bias, 0)
  83. self.fusion = None
  84. else:
  85. # the cross scale fusion module has its own reduction layers
  86. self.fusion = epff
  87. # channel dim reduction for [DET] tokens
  88. self.tgt_proj = nn.Sequential(
  89. # This is 1x1 conv -> so linear layer
  90. nn.Conv2d(backbone.num_channels[-2], hidden_dim, kernel_size=1),
  91. nn.GroupNorm(32, hidden_dim),
  92. )
  93. # channel dim reductionfor [DET] learnable pos encodings
  94. self.query_pos_proj = nn.Sequential(
  95. # This is 1x1 conv -> so linear layer
  96. nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1),
  97. nn.GroupNorm(32, hidden_dim),
  98. )
  99. # initialize detection head: box regression and classification
  100. prior_prob = 0.01
  101. bias_value = -math.log((1 - prior_prob) / prior_prob)
  102. self.class_embed.bias.data = torch.ones(num_classes) * bias_value
  103. nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
  104. nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
  105. # initialize projection layer for [DET] tokens and encodings
  106. nn.init.xavier_uniform_(self.tgt_proj[0].weight, gain=1)
  107. nn.init.constant_(self.tgt_proj[0].bias, 0)
  108. nn.init.xavier_uniform_(self.query_pos_proj[0].weight, gain=1)
  109. nn.init.constant_(self.query_pos_proj[0].bias, 0)
  110. if self.with_vector:
  111. nn.init.constant_(self.vector_embed.layers[-1].weight.data, 0)
  112. nn.init.constant_(self.vector_embed.layers[-1].bias.data, 0)
  113. # the prediction is made for each decoding layers + the standalone detector (Swin with RAM)
  114. num_pred = transformer.decoder.num_layers + 1
  115. # set up all required nn.Module for additional techniques
  116. if with_box_refine:
  117. self.class_embed = _get_clones(self.class_embed, num_pred)
  118. self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
  119. nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:],
  120. -2.0)
  121. # hack implementation for iterative bounding box refinement
  122. self.transformer.decoder.bbox_embed = self.bbox_embed
  123. else:
  124. nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
  125. self.class_embed = nn.ModuleList(
  126. [self.class_embed for _ in range(num_pred)])
  127. self.bbox_embed = nn.ModuleList(
  128. [self.bbox_embed for _ in range(num_pred)])
  129. self.transformer.decoder.bbox_embed = None
  130. if self.with_vector:
  131. nn.init.constant_(self.vector_embed.layers[-1].bias.data[2:], -2.0)
  132. self.vector_embed = nn.ModuleList(
  133. [self.vector_embed for _ in range(num_pred)])
  134. if self.iou_aware:
  135. self.iou_embed = MLP(hidden_dim, hidden_dim, 1, 3)
  136. if with_box_refine:
  137. self.iou_embed = _get_clones(self.iou_embed, num_pred)
  138. else:
  139. self.iou_embed = nn.ModuleList(
  140. [self.iou_embed for _ in range(num_pred)])
  141. def forward(self, features_0, features_1, features_2, features_3, det_tgt,
  142. det_pos, mask):
  143. """ The forward step of ViDT
  144. Args:
  145. The forward expects a NestedTensor, which consists of:
  146. - features_0: images feature
  147. - features_1: images feature
  148. - features_2: images feature
  149. - features_3: images feature
  150. - det_tgt: images det logits feature
  151. - det_pos: images det position feature
  152. - mask: images mask
  153. Returns:
  154. A dictionary having the key and value pairs below:
  155. - "out_pred_logits": the classification logits (including no-object) for all queries.
  156. Shape= [batch_size x num_queries x (num_classes + 1)]
  157. - "out_pred_boxes": The normalized boxes coordinates for all queries, represented as
  158. (center_x, center_y, height, width). These values are normalized in [0, 1],
  159. relative to the size of each individual image (disregarding possible padding).
  160. See PostProcess for information on how to retrieve the unnormalized bounding box.
  161. """
  162. features = [features_0, features_1, features_2, features_3]
  163. # [DET] token and encoding projection to compact representation for the input to the Neck-free transformer
  164. det_tgt = self.tgt_proj(det_tgt.unsqueeze(-1)).squeeze(-1).permute(
  165. 0, 2, 1)
  166. det_pos = self.query_pos_proj(
  167. det_pos.unsqueeze(-1)).squeeze(-1).permute(0, 2, 1)
  168. # [PATCH] token projection
  169. shapes = []
  170. for le, src in enumerate(features):
  171. shapes.append(src.shape[-2:])
  172. srcs = []
  173. if self.fusion is None:
  174. for le, src in enumerate(features):
  175. srcs.append(self.input_proj[le](src))
  176. else:
  177. # EPFF (multi-scale fusion) is used if fusion is activated
  178. srcs = self.fusion(features)
  179. masks = []
  180. for le, src in enumerate(srcs):
  181. # resize mask
  182. shapes.append(src.shape[-2:])
  183. _mask = F.interpolate(
  184. mask[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
  185. masks.append(_mask)
  186. assert mask is not None
  187. outputs_classes = []
  188. outputs_coords = []
  189. # return the output of the neck-free decoder
  190. hs, init_reference, inter_references, enc_token_class_unflat = self.transformer(
  191. srcs, masks, det_tgt, det_pos)
  192. # perform predictions via the detection head
  193. for lvl in range(hs.shape[0]):
  194. reference = init_reference if lvl == 0 else inter_references[lvl
  195. - 1]
  196. reference = inverse_sigmoid(reference)
  197. outputs_class = self.class_embed[lvl](hs[lvl])
  198. # bbox output + reference
  199. tmp = self.bbox_embed[lvl](hs[lvl])
  200. if reference.shape[-1] == 4:
  201. tmp += reference
  202. else:
  203. assert reference.shape[-1] == 2
  204. tmp[..., :2] += reference
  205. outputs_coord = tmp.sigmoid()
  206. outputs_classes.append(outputs_class)
  207. outputs_coords.append(outputs_coord)
  208. # stack all predictions made from each decoding layers
  209. outputs_class = torch.stack(outputs_classes)
  210. outputs_coord = torch.stack(outputs_coords)
  211. outputs_vector = None
  212. if self.with_vector:
  213. outputs_vectors = []
  214. for lvl in range(hs.shape[0]):
  215. outputs_vector = self.vector_embed[lvl](hs[lvl])
  216. outputs_vectors.append(outputs_vector)
  217. outputs_vector = torch.stack(outputs_vectors)
  218. # final prediction is made the last decoding layer
  219. out = {
  220. 'pred_logits': outputs_class[-1],
  221. 'pred_boxes': outputs_coord[-1]
  222. }
  223. if self.with_vector:
  224. out.update({'pred_vectors': outputs_vector[-1]})
  225. # aux loss is defined by using the rest predictions
  226. if self.aux_loss and self.transformer.decoder.num_layers > 0:
  227. out['aux_outputs'] = self._set_aux_loss(outputs_class,
  228. outputs_coord,
  229. outputs_vector)
  230. # iou awareness loss is defined for each decoding layer similar to auxiliary decoding loss
  231. if self.iou_aware:
  232. outputs_ious = []
  233. for lvl in range(hs.shape[0]):
  234. outputs_ious.append(self.iou_embed[lvl](hs[lvl]))
  235. outputs_iou = torch.stack(outputs_ious)
  236. out['pred_ious'] = outputs_iou[-1]
  237. if self.aux_loss:
  238. for i, aux in enumerate(out['aux_outputs']):
  239. aux['pred_ious'] = outputs_iou[i]
  240. # token label loss
  241. if self.token_label:
  242. out['enc_tokens'] = {'pred_logits': enc_token_class_unflat}
  243. if self.distil:
  244. # 'patch_token': multi-scale patch tokens from each stage
  245. # 'body_det_token' and 'neck_det_tgt': the input det_token for multiple detection heads
  246. out['distil_tokens'] = {
  247. 'patch_token': srcs,
  248. 'body_det_token': det_tgt,
  249. 'neck_det_token': hs
  250. }
  251. out_pred_logits = out['pred_logits']
  252. out_pred_boxes = out['pred_boxes']
  253. return out_pred_logits, out_pred_boxes
  254. @torch.jit.unused
  255. def _set_aux_loss(self, outputs_class, outputs_coord, outputs_vector):
  256. # this is a workaround to make torchscript happy, as torchscript
  257. # doesn't support dictionary with non-homogeneous values, such
  258. # as a dict having both a Tensor and a list.
  259. if outputs_vector is None:
  260. return [{
  261. 'pred_logits': a,
  262. 'pred_boxes': b
  263. } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
  264. else:
  265. return [{
  266. 'pred_logits': a,
  267. 'pred_boxes': b,
  268. 'pred_vectors': c
  269. } for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1],
  270. outputs_vector[:-1])]
  271. class MLP(nn.Module):
  272. """ Very simple multi-layer perceptron (also called FFN)"""
  273. def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
  274. super().__init__()
  275. self.num_layers = num_layers
  276. h = [hidden_dim] * (num_layers - 1)
  277. self.layers = nn.ModuleList(
  278. nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
  279. def forward(self, x):
  280. for i, layer in enumerate(self.layers):
  281. x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
  282. return x
  283. def inverse_sigmoid(x, eps=1e-5):
  284. x = x.clamp(min=0, max=1)
  285. x1 = x.clamp(min=eps)
  286. x2 = (1 - x).clamp(min=eps)
  287. return torch.log(x1 / x2)
  288. def box_cxcywh_to_xyxy(x):
  289. x_c, y_c, w, h = x.unbind(-1)
  290. b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
  291. return torch.stack(b, dim=-1)
  292. # process post_results
  293. def get_predictions(post_results, bbox_thu=0.40):
  294. batch_final_res = []
  295. for per_img_res in post_results:
  296. per_img_final_res = []
  297. for i in range(len(per_img_res['scores'])):
  298. score = float(per_img_res['scores'][i].cpu())
  299. label = int(per_img_res['labels'][i].cpu())
  300. bbox = []
  301. for it in per_img_res['boxes'][i].cpu():
  302. bbox.append(int(it))
  303. if score >= bbox_thu:
  304. per_img_final_res.append([score, label, bbox])
  305. batch_final_res.append(per_img_final_res)
  306. return batch_final_res
  307. class PostProcess(nn.Module):
  308. """ This module converts the model's output into the format expected by the coco api"""
  309. def __init__(self, processor_dct=None):
  310. super().__init__()
  311. # For instance segmentation using UQR module
  312. self.processor_dct = processor_dct
  313. @torch.no_grad()
  314. def forward(self, out_logits, out_bbox, target_sizes):
  315. """ Perform the computation
  316. Args:
  317. out_logits: raw logits outputs of the model
  318. out_bbox: raw bbox outputs of the model
  319. target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
  320. For evaluation, this must be the original image size (before any data augmentation)
  321. For visualization, this should be the image size after data augment, but before padding
  322. """
  323. assert len(out_logits) == len(target_sizes)
  324. assert target_sizes.shape[1] == 2
  325. prob = out_logits.sigmoid()
  326. topk_values, topk_indexes = torch.topk(
  327. prob.view(out_logits.shape[0], -1), 100, dim=1)
  328. scores = topk_values
  329. topk_boxes = topk_indexes // out_logits.shape[2]
  330. labels = topk_indexes % out_logits.shape[2]
  331. boxes = box_cxcywh_to_xyxy(out_bbox)
  332. boxes = torch.gather(boxes, 1,
  333. topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
  334. # and from relative [0, 1] to absolute [0, height] coordinates
  335. img_h, img_w = target_sizes.unbind(1)
  336. scale_fct = torch.stack([img_w, img_h, img_w, img_h],
  337. dim=1).to(torch.float32)
  338. boxes = boxes * scale_fct[:, None, :]
  339. results = [{
  340. 'scores': s,
  341. 'labels': l,
  342. 'boxes': b
  343. } for s, l, b in zip(scores, labels, boxes)]
  344. return results
  345. def _get_clones(module, N):
  346. """ Clone a module N times """
  347. return nn.ModuleList([copy.deepcopy(module) for i in range(N)])