| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413 |
- # The implementation here is modified based on timm,
- # originally Apache 2.0 License and publicly available at
- # https://github.com/naver-ai/vidt/blob/vidt-plus/methods/vidt/detector.py
- import copy
- import math
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class Detector(nn.Module):
    """ViDT+ detection head: "Swin with RAM" features + a "Neck-free Deformable Decoder".

    The module receives pre-computed backbone outputs: four multi-scale
    [PATCH] feature maps, the [DET] token features and their positional
    encodings, and a padding mask. It projects them to the transformer width,
    runs the transformer and applies the classification / box heads to every
    decoder level.
    """

    def __init__(
            self,
            backbone,
            transformer,
            num_classes,
            num_queries,
            aux_loss=False,
            with_box_refine=False,
            # The three additional techniques for ViDT+
            epff=None,  # (1) Efficient Pyramid Feature Fusion Module
            with_vector=False,
            processor_dct=None,
            vector_hidden_dim=256,  # (2) UQR Module
            iou_aware=False,
            token_label=False,  # (3) Additional losses
            distil=False):
        """Initialize the model.

        Args:
            backbone: backbone module; only its ``num_channels`` sequence
                (channel count of each output stage) is read here.
            transformer: neck-free transformer; must expose ``d_model`` and
                ``decoder.num_layers``. Its ``decoder.bbox_embed`` attribute
                is overwritten here.
            num_classes: number of object classes (output width of the
                classifier; there is no extra "no-object" logit).
            num_queries: number of object queries (i.e. [DET] tokens). This is
                the maximal number of objects detected in a single image.
            aux_loss: True to build auxiliary outputs for every intermediate
                decoder layer.
            with_box_refine: iterative bounding box refinement. When True,
                every prediction level gets its own deep-copied heads.
            epff: None, or a fusion module that replaces the per-level input
                projections.
            with_vector: enable the UQR vector head. Requires
                ``processor_dct`` with an ``n_keep`` attribute.
            processor_dct: DCT processor used by the UQR module.
            vector_hidden_dim: hidden width of the UQR vector MLP.
            iou_aware: True to add an IoU prediction head,
                see https://arxiv.org/abs/1912.05992
            token_label: True to expose encoder token logits for the
                token-labeling loss, see https://arxiv.org/abs/2104.10858
            distil: True to expose intermediate tokens for knowledge
                distillation with token matching.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        # two essential techniques used [default use]
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        # For UQR module for ViDT+
        self.with_vector = with_vector
        self.processor_dct = processor_dct
        if self.with_vector:
            print(
                f'Training with vector_hidden_dim {vector_hidden_dim}.',
                flush=True)
            self.vector_embed = MLP(hidden_dim, vector_hidden_dim,
                                    self.processor_dct.n_keep, 3)
        # For two additional losses for ViDT+
        self.iou_aware = iou_aware
        self.token_label = token_label
        # distillation
        self.distil = distil
        # For EPFF module for ViDT+
        if epff is None:
            # One projection per backbone stage: a 1x1 conv (i.e. a linear
            # layer over channels) to hidden_dim, followed by GroupNorm.
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ) for in_channels in backbone.num_channels
            ])
            # initialize the projection layer for [PATCH] tokens
            for proj in self.input_proj:
                nn.init.xavier_uniform_(proj[0].weight, gain=1)
                nn.init.constant_(proj[0].bias, 0)
            self.fusion = None
        else:
            # the cross scale fusion module has its own reduction layers
            self.fusion = epff
        # channel dim reduction for [DET] tokens
        self.tgt_proj = nn.Sequential(
            # This is 1x1 conv -> so linear layer
            nn.Conv2d(backbone.num_channels[-2], hidden_dim, kernel_size=1),
            nn.GroupNorm(32, hidden_dim),
        )
        # channel dim reduction for [DET] learnable pos encodings
        self.query_pos_proj = nn.Sequential(
            # This is 1x1 conv -> so linear layer
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=1),
            nn.GroupNorm(32, hidden_dim),
        )
        # initialize detection head: box regression and classification.
        # The classifier bias encodes a prior probability of 0.01 per class.
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        # initialize projection layer for [DET] tokens and encodings
        nn.init.xavier_uniform_(self.tgt_proj[0].weight, gain=1)
        nn.init.constant_(self.tgt_proj[0].bias, 0)
        nn.init.xavier_uniform_(self.query_pos_proj[0].weight, gain=1)
        nn.init.constant_(self.query_pos_proj[0].bias, 0)
        if self.with_vector:
            nn.init.constant_(self.vector_embed.layers[-1].weight.data, 0)
            nn.init.constant_(self.vector_embed.layers[-1].bias.data, 0)
        # the prediction is made for each decoding layers + the standalone detector (Swin with RAM)
        num_pred = transformer.decoder.num_layers + 1
        # set up all required nn.Module for additional techniques
        if with_box_refine:
            self.class_embed = _get_clones(self.class_embed, num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:],
                              -2.0)
            # hack implementation for iterative bounding box refinement
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            # without refinement every level shares the same head instance
            self.class_embed = nn.ModuleList(
                [self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList(
                [self.bbox_embed for _ in range(num_pred)])
            self.transformer.decoder.bbox_embed = None
        if self.with_vector:
            # NOTE(review): this mirrors the bbox-head bias prior (-2.0 on
            # entries 2:) onto the vector head; confirm it is intended.
            nn.init.constant_(self.vector_embed.layers[-1].bias.data[2:], -2.0)
            self.vector_embed = nn.ModuleList(
                [self.vector_embed for _ in range(num_pred)])
        if self.iou_aware:
            self.iou_embed = MLP(hidden_dim, hidden_dim, 1, 3)
            if with_box_refine:
                self.iou_embed = _get_clones(self.iou_embed, num_pred)
            else:
                self.iou_embed = nn.ModuleList(
                    [self.iou_embed for _ in range(num_pred)])

    def forward(self, features_0, features_1, features_2, features_3, det_tgt,
                det_pos, mask):
        """The forward step of ViDT on pre-computed backbone outputs.

        Args:
            features_0, features_1, features_2, features_3: the four
                multi-scale [PATCH] feature maps. Each is projected by
                ``input_proj`` or fused by the EPFF module.
            det_tgt: [DET] token features. Presumably (batch, C, num_queries)
                with C == backbone.num_channels[-2] — TODO confirm.
            det_pos: positional encodings of the [DET] tokens. Presumably
                (batch, hidden_dim, num_queries) — TODO confirm.
            mask: padding mask. It is resized per level with
                ``F.interpolate(mask[None].float(), size=(h, w))``, so it must
                be 3-D (batch, H, W).

        Returns:
            A tuple ``(out_pred_logits, out_pred_boxes)`` from the last
            decoder level:
              - out_pred_logits: classification logits whose last dim is
                ``num_classes`` (no extra no-object class; ``PostProcess``
                turns them into scores with a sigmoid).
              - out_pred_boxes: boxes as (center_x, center_y, width, height),
                squashed into [0, 1] by a sigmoid and relative to the size of
                each individual image. See ``PostProcess`` for how to get
                absolute boxes.
            Auxiliary, IoU, vector, token-label and distillation outputs are
            assembled in an internal dict but are not returned.
        """
        assert mask is not None
        features = [features_0, features_1, features_2, features_3]
        # [DET] token and encoding projection to compact representation for
        # the input to the Neck-free transformer. The trailing unsqueeze lets
        # the 1x1 Conv2d act as a per-token linear layer; the permute moves
        # the channel dim last.
        det_tgt = self.tgt_proj(det_tgt.unsqueeze(-1)).squeeze(-1).permute(
            0, 2, 1)
        det_pos = self.query_pos_proj(
            det_pos.unsqueeze(-1)).squeeze(-1).permute(0, 2, 1)
        # [PATCH] token projection
        if self.fusion is None:
            srcs = [
                self.input_proj[lvl](src) for lvl, src in enumerate(features)
            ]
        else:
            # EPFF (multi-scale fusion) is used if fusion is activated
            srcs = self.fusion(features)
        # resize the padding mask to the spatial size of every level
        masks = [
            F.interpolate(mask[None].float(),
                          size=src.shape[-2:]).to(torch.bool)[0]
            for src in srcs
        ]
        # return the output of the neck-free decoder
        hs, init_reference, inter_references, enc_token_class_unflat = self.transformer(
            srcs, masks, det_tgt, det_pos)
        # perform predictions via the detection head
        outputs_classes = []
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            reference = init_reference if lvl == 0 else inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_classes.append(self.class_embed[lvl](hs[lvl]))
            # the box head predicts an offset in logit space w.r.t. the reference
            tmp = self.bbox_embed[lvl](hs[lvl])
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coords.append(tmp.sigmoid())
        # stack all predictions made from each decoding layers
        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)
        outputs_vector = None
        if self.with_vector:
            outputs_vector = torch.stack([
                self.vector_embed[lvl](hs[lvl]) for lvl in range(hs.shape[0])
            ])
        # final prediction is made the last decoding layer
        out = {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_coord[-1]
        }
        if self.with_vector:
            out['pred_vectors'] = outputs_vector[-1]
        # aux loss is defined by using the rest predictions
        if self.aux_loss and self.transformer.decoder.num_layers > 0:
            out['aux_outputs'] = self._set_aux_loss(outputs_class,
                                                    outputs_coord,
                                                    outputs_vector)
        # iou awareness loss is defined for each decoding layer similar to auxiliary decoding loss
        if self.iou_aware:
            outputs_iou = torch.stack([
                self.iou_embed[lvl](hs[lvl]) for lvl in range(hs.shape[0])
            ])
            out['pred_ious'] = outputs_iou[-1]
            # 'aux_outputs' exists only when aux_loss is on AND the decoder
            # has layers; look it up defensively instead of raising KeyError.
            for i, aux in enumerate(out.get('aux_outputs', [])):
                aux['pred_ious'] = outputs_iou[i]
        # token label loss
        if self.token_label:
            out['enc_tokens'] = {'pred_logits': enc_token_class_unflat}
        if self.distil:
            # 'patch_token': multi-scale patch tokens from each stage
            # 'body_det_token' and 'neck_det_tgt': the input det_token for multiple detection heads
            out['distil_tokens'] = {
                'patch_token': srcs,
                'body_det_token': det_tgt,
                'neck_det_token': hs
            }
        out_pred_logits = out['pred_logits']
        out_pred_boxes = out['pred_boxes']
        return out_pred_logits, out_pred_boxes

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord, outputs_vector):
        """Build one prediction dict per intermediate level (all but the last).

        ``outputs_vector`` may be None, in which case 'pred_vectors' is omitted.
        """
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        if outputs_vector is None:
            return [{
                'pred_logits': a,
                'pred_boxes': b
            } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
        else:
            return [{
                'pred_logits': a,
                'pred_boxes': b,
                'pred_vectors': c
            } for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1],
                                 outputs_vector[:-1])]
class MLP(nn.Module):
    """Plain feed-forward network (FFN) made of ``num_layers`` linear layers.

    Hidden layers all have width ``hidden_dim``. A ReLU follows every layer
    except the final one, whose output is returned unactivated.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim, *hidden]
        out_dims = [*hidden, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)])

    def forward(self, x):
        last_idx = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            if idx < last_idx:
                x = F.relu(x)
        return x
def inverse_sigmoid(x, eps=1e-5):
    """Return the logit of ``x`` (inverse of ``torch.sigmoid``), numerically guarded.

    ``x`` is first clipped to [0, 1]. Both the numerator ``x`` and the
    denominator ``1 - x`` are then floored at ``eps``, so the result stays
    finite at 0 and 1.
    """
    clipped = torch.clamp(x, min=0, max=1)
    numerator = clipped.clamp(min=eps)
    denominator = (1 - clipped).clamp(min=eps)
    return torch.log(numerator / denominator)
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1).

    The last dimension of ``x`` holds the four box values. Any leading
    dimensions are preserved.
    """
    cx, cy, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = [cx - half_w, cy - half_h, cx + half_w, cy + half_h]
    return torch.stack(corners, dim=-1)
# process post_results
def get_predictions(post_results, bbox_thu=0.40):
    """Turn per-image detection results into plain Python lists, filtered by score.

    Args:
        post_results: iterable of per-image dicts with tensor entries
            'scores', 'labels' and 'boxes' (e.g. the output of ``PostProcess``).
        bbox_thu: minimum score (inclusive) for a detection to be kept.

    Returns:
        One list per image. Each entry is ``[score, label, box]``, where
        ``score`` is a float, ``label`` an int, and ``box`` the box's
        coordinates truncated to ints.
    """
    batch_final_res = []
    for per_img_res in post_results:
        # One device->host transfer per tensor instead of one per element.
        scores = per_img_res['scores'].cpu().tolist()
        labels = per_img_res['labels'].cpu().tolist()
        boxes = per_img_res['boxes'].cpu().tolist()
        per_img_final_res = [
            [float(score), int(label), [int(coord) for coord in box]]
            for score, label, box in zip(scores, labels, boxes)
            if score >= bbox_thu
        ]
        batch_final_res.append(per_img_final_res)
    return batch_final_res
class PostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the coco api"""

    def __init__(self, processor_dct=None, num_select=100):
        """
        Args:
            processor_dct: DCT processor for instance segmentation using the
                UQR module; stored but not used by ``forward``.
            num_select: maximum number of (query, class) detections kept per
                image. Defaults to 100, the previously hard-coded value.
        """
        super().__init__()
        # For instance segmentation using UQR module
        self.processor_dct = processor_dct
        self.num_select = num_select

    @torch.no_grad()
    def forward(self, out_logits, out_bbox, target_sizes):
        """ Perform the computation
        Args:
            out_logits: raw logits outputs of the model, indexed as
                (batch, num_queries, num_classes)
            out_bbox: raw bbox outputs of the model, normalized
                (center_x, center_y, width, height), indexed as
                (batch, num_queries, 4)
            target_sizes: tensor of dimension [batch_size x 2] containing the
                (height, width) of each image of the batch.
                For evaluation, this must be the original image size (before any data augmentation)
                For visualization, this should be the image size after data augment, but before padding
        Returns:
            A list with one dict per image with keys 'scores', 'labels' and
            'boxes'. Boxes are absolute (x0, y0, x1, y1) coordinates.
            At most ``num_select`` detections are kept per image; fewer are
            kept when num_queries * num_classes is smaller.
        """
        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2
        num_classes = out_logits.shape[2]
        # Sigmoid scores over every (query, class) pair, flattened per image.
        prob = out_logits.sigmoid().reshape(out_logits.shape[0], -1)
        # torch.topk raises if k exceeds the number of candidates.
        k = min(self.num_select, prob.shape[1])
        scores, topk_indexes = torch.topk(prob, k, dim=1)
        # Recover the query index and the class label from the flat index.
        topk_boxes = torch.div(
            topk_indexes, num_classes, rounding_mode='floor')
        labels = topk_indexes % num_classes
        boxes = box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(boxes, 1,
                             topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                dim=1).to(torch.float32)
        boxes = boxes * scale_fct[:, None, :]
        results = [{
            'scores': s,
            'labels': l,
            'boxes': b
        } for s, l, b in zip(scores, labels, boxes)]
        return results
- def _get_clones(module, N):
- """ Clone a module N times """
- return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|