model.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import json
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from modelscope.metainfo import Models
  9. from modelscope.models.base.base_torch_model import TorchModel
  10. from modelscope.models.builder import MODELS
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.utils.constant import ModelFile, Tasks
  13. from .backbone import build_backbone
  14. from .head import FPNSegmentor, LinearClassifier
  @MODELS.register_module(
      Tasks.image_segmentation, module_name=Models.vision_middleware)
  class VisionMiddlewareModel(TorchModel):
      """
      The implementation of 'ViM: Vision Middleware for Unified Downstream Transferring'.

      This model is dynamically initialized with the following parts:
      - backbone: the upstream pre-trained backbone model (CLIP in this code)
      - ViM: the zoo of middlestream trained ViM modules
      - ViM-aggregation: the specific aggregation weights for downstream tasks
      """

      def __init__(self, model_dir: str, *args, **kwargs):
          """
          Initialize a ViM-based Model.

          Args:
              model_dir: model id or path, where model_dir/pytorch_model.pt contains:
                  - 'meta_info': basic information of ViM, read here as
                    'task_list', 'backbone_arch', 'ViM_agg_algo' and
                    (optionally) 'label_map'
                  - 'backbone_weights': parameters of backbone [upstream]
                  - 'ViM_weights': parameters of ViM [midstream], one entry
                    per transformer resblock
                  - 'ViM_agg_weights': parameters of ViM-aggregation
                    [downstream], per task and per layer
                  - 'head_weights': per-task state dicts of the task heads

          Raises:
              NotImplementedError: if a task in meta_info['task_list'] is
                  neither a classification ('cls*') nor a segmentation
                  ('seg*') task.
          """
          super(VisionMiddlewareModel, self).__init__()
          model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
          # NOTE(review): torch.load unpickles the checkpoint (it carries
          # metadata as well as tensors); only load trusted model files.
          model_dict = torch.load(model_path, map_location='cpu')
          meta_info = model_dict['meta_info']
          self.task_list = meta_info['task_list']

          # build up backbone [upstream]
          self.backbone = build_backbone(
              arch=meta_info['backbone_arch'],
              pretrained=model_dict['backbone_weights'])
          self.backbone.eval()
          resblocks = self.backbone.transformer.resblocks

          # build up ViM [midstream]: entry i of ViM_weights belongs to resblock i
          vim_weights = model_dict['ViM_weights']
          num_layers = len(vim_weights)
          for layer_i in range(num_layers):
              layer_weights = vim_weights[layer_i]
              resblocks[layer_i].vim_att.register_ViM(
                  layer_weights['vim_att_weights'])
              resblocks[layer_i].vim_mlp.register_ViM(
                  layer_weights['vim_mlp_weights'])

          # build up each task-related ViM aggregation [downstream]
          agg_weights = model_dict['ViM_agg_weights']
          agg_algo = meta_info['ViM_agg_algo']
          for task_name in self.task_list:
              task_agg = agg_weights[task_name]
              for layer_i in range(num_layers):
                  resblocks[layer_i].vim_att.register_task(
                      task_name, task_agg[layer_i]['vim_att_agg'], agg_algo)
                  resblocks[layer_i].vim_mlp.register_task(
                      task_name, task_agg[layer_i]['vim_mlp_agg'], agg_algo)

          # build up each task-related head; the task type is encoded in the
          # task-name prefix ('cls*' or 'seg*')
          head_weights = model_dict['head_weights']
          # 'label_map' is optional: tasks without one report raw label ids
          label_map = meta_info.get('label_map', {})
          self.heads = nn.ModuleDict()
          self.label_maps = {}
          for task_name in self.task_list:
              if task_name.startswith('cls'):
                  # the number of classes is inferred from the stored bias
                  head = LinearClassifier(
                      in_channels=self.backbone.output_dim,
                      num_classes=head_weights[task_name]
                      ['classifier.bias'].shape[0])
              elif task_name.startswith('seg'):
                  head = FPNSegmentor()
              else:
                  raise NotImplementedError(
                      'Task type [{}] is not supported'.format(task_name))
              head.load_state_dict(head_weights[task_name])
              head.eval()
              self.heads[task_name] = head
              if task_name in label_map:
                  self.label_maps[task_name] = label_map[task_name]

      def __call__(self, inputs, task_name) -> Dict[str, Any]:
          """
          Run the model for task_name and post-process the raw output.

          Args:
              inputs: batched input images (B, 3, H, W)
              task_name (str): one of self.task_list

          Returns:
              the dict produced by postprocess().
          """
          return self.postprocess(
              self.forward(inputs, task_name), inputs, task_name)

      def forward(self, inputs, task_name):
          """
          Dynamic Forward Function of ViM.

          Args:
              inputs: the input images (B, 3, H, W)
              task_name: specified task for forwarding, one of self.task_list

          Returns:
              the raw output of the head registered for task_name.

          Raises:
              NotImplementedError: if task_name is not in self.task_list.
          """
          if task_name not in self.task_list:
              raise NotImplementedError(
                  f'task_name should in {self.task_list}, but got {task_name}')
          # the backbone selects the task-specific ViM aggregation by name
          features = self.backbone(inputs, task_name=task_name)
          return self.heads[task_name](features)

      def postprocess(self, outputs, inputs, task_name):
          """
          Post-process of ViM, based on task_name.

          Only segmentation tasks ('seg*') are supported. Only the FIRST image
          of the batch is post-processed.

          Args:
              outputs: batched output (format based on task_name); for
                  segmentation, per-class logits [B, C, h, w]
              inputs: batched input image (B, 3, H, W); only its spatial size
                  is used, to resize the prediction back to (H, W)
              task_name (str): task name

          Returns:
              dict with
                  OutputKeys.MASKS: one (H, W) integer numpy mask per label
                      present in the prediction,
                  OutputKeys.LABELS: label names from the task's label map,
                      or the raw integer label ids if the task has none,
                  OutputKeys.SCORES: mean class probability inside each mask.

          Raises:
              NotImplementedError: if task_name is not a segmentation task.
          """
          # same prefix test as __init__ uses to build segmentation heads
          if not task_name.startswith('seg'):
              raise NotImplementedError(
                  'Only segmentation task is currently supported in pipeline')

          _, _, img_height, img_width = inputs.size()
          probs = F.softmax(outputs, dim=1)
          probs = F.interpolate(
              probs,
              size=(img_height, img_width),
              mode='bilinear',
              align_corners=True)
          probs = probs[0].detach().cpu()  # [C, H, W] for the first image
          pred = torch.argmax(probs, dim=0)

          # unique predicted labels, ascending, as Python ints
          labels = torch.unique(pred, sorted=True).tolist()
          label_map = self.label_maps.get(task_name)

          masks, scores = [], []
          for label in labels:
              mask = pred == label
              masks.append(mask.long().numpy())
              # every label comes from pred, so each mask is non-empty
              scores.append(probs[label][mask].mean().item())
          if label_map is None:
              label_names = labels
          else:
              label_names = [label_map[label] for label in labels]
          return {
              OutputKeys.MASKS: masks,
              OutputKeys.LABELS: label_names,
              OutputKeys.SCORES: scores
          }

      def get_tasks(self):
          """
          Get the supported tasks of current ViM model.
          """
          return self.task_list