| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
- import os.path as osp
- from typing import Any, Dict
- import json
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from modelscope.metainfo import Models
- from modelscope.models.base.base_torch_model import TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.outputs import OutputKeys
- from modelscope.utils.constant import ModelFile, Tasks
- from .backbone import build_backbone
- from .head import FPNSegmentor, LinearClassifier
@MODELS.register_module(
    Tasks.image_segmentation, module_name=Models.vision_middleware)
class VisionMiddlewareModel(TorchModel):
    """
    The implementation of 'ViM: Vision Middleware for Unified Downstream Transferring'.

    This model is dynamically initialized with the following parts:
        - backbone: the upstream pre-trained backbone model (CLIP in this code)
        - ViM: the zoo of middlestream trained ViM modules
        - ViM-aggregation: the specific aggregation weights for downstream tasks
        - heads: one task-specific head per entry of ``task_list``
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """
        Initialize a ViM-based Model.

        Args:
            model_dir: path to the model directory, where
                model_dir/pytorch_model.pt contains:
                - 'meta_info': basic information of ViM, e.g. task_list,
                  backbone_arch, ViM_agg_algo and (optionally) label_map
                - 'backbone_weights': parameters of backbone [upstream]
                - 'ViM_weights': parameters of ViM [midstream]
                - 'ViM_agg_weights': parameters of ViM-aggregation [downstream]
                - 'head_weights': parameters of the task heads [downstream]

        Raises:
            NotImplementedError: if a task in task_list is neither a
                classification ('cls*') nor a segmentation ('seg*') task.
        """
        # Forward model_dir so the base class receives it rather than None.
        super(VisionMiddlewareModel, self).__init__(model_dir)
        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        # NOTE(review): torch.load unpickles the checkpoint; only load
        # model directories that come from a trusted source.
        model_dict = torch.load(model_path, map_location='cpu')
        meta_info = model_dict['meta_info']
        self.task_list = meta_info['task_list']

        # build up backbone [upstream]
        self.backbone = build_backbone(
            arch=meta_info['backbone_arch'],
            pretrained=model_dict['backbone_weights'])
        self.backbone.eval()

        # build up ViM [midstream] and each task-related aggregation
        self._register_vim(model_dict['ViM_weights'],
                           model_dict['ViM_agg_weights'],
                           meta_info['ViM_agg_algo'])

        # build up each task-related head; a missing label_map simply means
        # no task has label names (only segmentation postprocess needs them)
        self._build_heads(model_dict['head_weights'],
                          meta_info.get('label_map', {}))

    def _register_vim(self, vim_weights, agg_weights, agg_algo):
        """
        Register the ViM modules and every task's aggregation weights.

        Args:
            vim_weights: sequence with one entry per transformer block, each
                holding 'vim_att_weights' and 'vim_mlp_weights'.
            agg_weights: mapping task_name -> per-block entries holding
                'vim_att_agg' and 'vim_mlp_agg'.
            agg_algo: aggregation algorithm, forwarded to ``register_task``.
        """
        resblocks = self.backbone.transformer.resblocks
        num_layers = len(vim_weights)
        # ViM modules live in the first len(vim_weights) transformer blocks.
        for layer_i, layer_weights in enumerate(vim_weights):
            resblocks[layer_i].vim_att.register_ViM(
                layer_weights['vim_att_weights'])
            resblocks[layer_i].vim_mlp.register_ViM(
                layer_weights['vim_mlp_weights'])
        for task_name in self.task_list:
            for layer_i in range(num_layers):
                layer_agg = agg_weights[task_name][layer_i]
                resblocks[layer_i].vim_att.register_task(
                    task_name, layer_agg['vim_att_agg'], agg_algo)
                resblocks[layer_i].vim_mlp.register_task(
                    task_name, layer_agg['vim_mlp_agg'], agg_algo)

    def _build_heads(self, head_weights, label_map):
        """
        Build, load and put in eval mode one head per task in task_list.

        Args:
            head_weights: mapping task_name -> state dict of that task's head.
            label_map: mapping task_name -> label names; tasks without an
                entry get no label map.

        Raises:
            NotImplementedError: for task names not starting with 'cls'/'seg'.
        """
        self.heads = nn.ModuleDict()
        self.label_maps = {}
        for task_name in self.task_list:
            if task_name.startswith('cls'):
                # the number of classes is inferred from the stored bias
                head = LinearClassifier(
                    in_channels=self.backbone.output_dim,
                    num_classes=head_weights[task_name]
                    ['classifier.bias'].shape[0])
            elif task_name.startswith('seg'):
                head = FPNSegmentor()
            else:
                raise NotImplementedError(
                    'Task type [{}] is not supported'.format(task_name))
            self.heads[task_name] = head
            head.load_state_dict(head_weights[task_name])
            head.eval()
            if task_name in label_map:
                self.label_maps[task_name] = label_map[task_name]

    def __call__(self, inputs, task_name) -> Dict[str, Any]:
        """Run ``forward`` for ``task_name`` and post-process its output."""
        return self.postprocess(
            self.forward(inputs, task_name), inputs, task_name)

    def forward(self, inputs, task_name):
        """
        Dynamic Forward Function of ViM.

        Args:
            inputs: the input images (B, 3, H, W)
            task_name: specified task for forwarding; must be in task_list

        Returns:
            The output of the head registered for ``task_name``.

        Raises:
            NotImplementedError: if task_name is not in task_list.
        """
        if task_name not in self.task_list:
            raise NotImplementedError(
                f'task_name should in {self.task_list}, but got {task_name}')
        features = self.backbone(inputs, task_name=task_name)
        return self.heads[task_name](features)

    def postprocess(self, outputs, inputs, task_name):
        """
        Post-process of ViM, based on task_name.

        Only segmentation tasks ('seg*', the same prefix used when building
        heads) are supported, and only the first image of the batch is
        post-processed.

        Args:
            outputs: batched output (format based on task_name); for
                segmentation, per-class logits of shape (B, C, h, w)
            inputs: batched input image (B, 3, H, W)
            task_name (str): task name

        Returns:
            A dict with OutputKeys.MASKS (one 0/1 mask of shape (H, W) per
            predicted label), OutputKeys.LABELS (the matching label names)
            and OutputKeys.SCORES (mean class probability inside each mask).

        Raises:
            NotImplementedError: for non-segmentation tasks.
        """
        if not task_name.startswith('seg'):
            raise NotImplementedError(
                'Only segmentation task is currently supported in pipeline')
        _, _, img_height, img_width = inputs.size()

        # per-pixel class probabilities resized to the input resolution
        probs = F.softmax(outputs, dim=1)
        probs = F.interpolate(
            probs,
            size=(img_height, img_width),
            mode='bilinear',
            align_corners=True)
        probs = probs[0].detach().cpu()
        pred = torch.argmax(probs, dim=0)

        label_map = self.label_maps[task_name]
        masks, label_names, scores = [], [], []
        # torch.unique yields the sorted class ids present in the prediction
        for label in torch.unique(pred).tolist():
            mask = pred == label
            mask_f = mask.float()
            masks.append(mask.long().numpy())
            # the mask is never empty: label was taken from pred itself
            scores.append(((mask_f * probs[label]).sum()
                           / mask_f.sum()).item())
            label_names.append(label_map[label])
        return {
            OutputKeys.MASKS: masks,
            OutputKeys.LABELS: label_names,
            OutputKeys.SCORES: scores
        }

    def get_tasks(self):
        """
        Get the supported tasks of current ViM model.
        """
        return self.task_list
|