| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- # Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
- import os
- import torch
- from modelscope.metainfo import Models
- from modelscope.models.base.base_torch_model import TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.utils.constant import ModelFile, Tasks
- from .backbone import SwinTransformer
- from .deformable_transformer import DeformableTransformer
- from .fpn_fusion import FPNFusionModule
- from .head import Detector
@MODELS.register_module(Tasks.image_object_detection, module_name=Models.vidt)
class VidtModel(TorchModel):
    """
    The implementation of 'ViDT for joint-learning of object detection and
    instance segmentation'.

    This model is dynamically initialized with the following parts:
        - 'backbone': pre-trained Swin Transformer backbone, whose parameters
          are loaded from the checkpoint.
        - 'head': detection and segmentation head (EPFF fusion module plus a
          deformable transformer decoder), whose parameters are also loaded
          from the checkpoint.
    """

    def __init__(self, model_dir: str, **kwargs):
        """Initialize a ViDT model and load its weights.

        Args:
            model_dir: path to the model directory. The file
                ``model_dir/<ModelFile.TORCH_MODEL_FILE>`` (pytorch_model.pt)
                must hold a dict containing:
                    - 'backbone_weights': state dict of the backbone.
                    - 'head_weights': state dict of the head.
            **kwargs: extra keyword arguments; ignored.

        Raises:
            KeyError: if either weight entry is missing from the checkpoint.
            RuntimeError: if a state dict does not exactly match the built
                module (loading is strict).
        """
        # Forward model_dir so the base class records it (self.model_dir);
        # the previous call dropped it.
        super().__init__(model_dir)
        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a tampered checkpoint cannot run arbitrary code on load.
        model_dict = torch.load(
            model_path, map_location='cpu', weights_only=True)

        # build backbone
        self.backbone = self._build_backbone()
        self.backbone.load_state_dict(
            model_dict['backbone_weights'], strict=True)

        # build head
        # NOTE(review): the backbone is also passed into Detector. If Detector
        # registers it as a submodule, 'head_weights' must contain the
        # backbone parameters too (strict=True), and they overwrite the ones
        # loaded just above. Confirm this is intended.
        self.head = self._build_head(self.backbone)
        self.head.load_state_dict(model_dict['head_weights'], strict=True)

    @staticmethod
    def _build_backbone():
        """Build the Swin backbone and switch it to ViDT detection mode."""
        # embed_dim=96, depths [2, 2, 6, 2], heads [3, 6, 12, 24]: the
        # Swin-Tiny configuration.
        backbone = SwinTransformer(
            pretrain_img_size=[224, 224],
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
            drop_path_rate=0.2)
        # det_token_num presumably has to match num_queries=300 in the head.
        # Keep the two values in sync.
        backbone.finetune_det(
            method='vidt', det_token_num=300, pos_dim=256, cross_indices=[3])
        return backbone

    @staticmethod
    def _build_head(backbone):
        """Build the ViDT+ detection head on top of ``backbone``."""
        epff = FPNFusionModule(backbone.num_channels, fuse_dim=256)
        deform_transformers = DeformableTransformer(
            d_model=256,
            nhead=8,
            num_decoder_layers=6,
            dim_feedforward=1024,
            dropout=0.1,
            activation='relu',
            return_intermediate_dec=True,
            num_feature_levels=4,
            dec_n_points=4,
            token_label=False)
        return Detector(
            backbone,
            deform_transformers,
            num_classes=2,
            num_queries=300,
            # two essential techniques used in ViDT
            aux_loss=True,
            with_box_refine=True,
            # an EPFF module for ViDT+
            epff=epff,
            # a UQR module for ViDT+ (disabled)
            with_vector=False,
            processor_dct=None,
            # two additional losses for ViDT+
            iou_aware=True,
            token_label=False,
            vector_hidden_dim=256,
            # distillation (disabled)
            distil=False)

    def forward(self, x, mask):
        """Run the backbone and the detection head.

        Args:
            x: input images (B, 3, H, W).
            mask: input padding masks (B, H, W).

        Returns:
            tuple: ``(out_pred_logits, out_pred_boxes)`` exactly as returned
            by the head. Their shapes are defined by Detector, presumably
            (B, num_queries, num_classes) and (B, num_queries, 4); TODO
            confirm.
        """
        features_0, features_1, features_2, features_3, det_tgt, det_pos = (
            self.backbone(x, mask))
        out_pred_logits, out_pred_boxes = self.head(
            features_0, features_1, features_2, features_3,
            det_tgt, det_pos, mask)
        return out_pred_logits, out_pred_boxes
|