model.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. import os
  3. import torch
  4. from modelscope.metainfo import Models
  5. from modelscope.models.base.base_torch_model import TorchModel
  6. from modelscope.models.builder import MODELS
  7. from modelscope.utils.constant import ModelFile, Tasks
  8. from .backbone import SwinTransformer
  9. from .deformable_transformer import DeformableTransformer
  10. from .fpn_fusion import FPNFusionModule
  11. from .head import Detector
  12. @MODELS.register_module(Tasks.image_object_detection, module_name=Models.vidt)
  13. class VidtModel(TorchModel):
  14. """
  15. The implementation of 'ViDT for joint-learning of object detection and instance segmentation'.
  16. This model is dynamically initialized with the following parts:
  17. - 'backbone': pre-trained backbone model with parameters.
  18. - 'head': detection and segentation head with fine-tuning.
  19. """
  20. def __init__(self, model_dir: str, **kwargs):
  21. """ Initialize a Vidt Model.
  22. Args:
  23. model_dir: model id or path, where model_dir/pytorch_model.pt contains:
  24. - 'backbone_weights': parameters of backbone.
  25. - 'head_weights': parameters of head.
  26. """
  27. super(VidtModel, self).__init__()
  28. model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
  29. model_dict = torch.load(
  30. model_path, map_location='cpu', weights_only=True)
  31. # build backbone
  32. backbone = SwinTransformer(
  33. pretrain_img_size=[224, 224],
  34. embed_dim=96,
  35. depths=[2, 2, 6, 2],
  36. num_heads=[3, 6, 12, 24],
  37. window_size=7,
  38. drop_path_rate=0.2)
  39. backbone.finetune_det(
  40. method='vidt', det_token_num=300, pos_dim=256, cross_indices=[3])
  41. self.backbone = backbone
  42. self.backbone.load_state_dict(
  43. model_dict['backbone_weights'], strict=True)
  44. # build head
  45. epff = FPNFusionModule(backbone.num_channels, fuse_dim=256)
  46. deform_transformers = DeformableTransformer(
  47. d_model=256,
  48. nhead=8,
  49. num_decoder_layers=6,
  50. dim_feedforward=1024,
  51. dropout=0.1,
  52. activation='relu',
  53. return_intermediate_dec=True,
  54. num_feature_levels=4,
  55. dec_n_points=4,
  56. token_label=False)
  57. head = Detector(
  58. backbone,
  59. deform_transformers,
  60. num_classes=2,
  61. num_queries=300,
  62. # two essential techniques used in ViDT
  63. aux_loss=True,
  64. with_box_refine=True,
  65. # an epff module for ViDT+
  66. epff=epff,
  67. # an UQR module for ViDT+
  68. with_vector=False,
  69. processor_dct=None,
  70. # two additional losses for VIDT+
  71. iou_aware=True,
  72. token_label=False,
  73. vector_hidden_dim=256,
  74. # distil
  75. distil=False)
  76. self.head = head
  77. self.head.load_state_dict(model_dict['head_weights'], strict=True)
  78. def forward(self, x, mask):
  79. """ Dynamic forward function of VidtModel.
  80. Args:
  81. x: input images (B, 3, H, W)
  82. mask: input padding masks (B, H, W)
  83. """
  84. features_0, features_1, features_2, features_3, det_tgt, det_pos = self.backbone(
  85. x, mask)
  86. out_pred_logits, out_pred_boxes = self.head(features_0, features_1,
  87. features_2, features_3,
  88. det_tgt, det_pos, mask)
  89. return out_pred_logits, out_pred_boxes