# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os.path as osp
from typing import Any, Dict

import json
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from .backbone import build_backbone
from .head import FPNSegmentor, LinearClassifier


@MODELS.register_module(
    Tasks.image_segmentation, module_name=Models.vision_middleware)
class VisionMiddlewareModel(TorchModel):
    """
    The implementation of 'ViM: Vision Middleware for Unified Downstream
    Transferring'.

    This model is dynamically initialized with the following parts:
        - backbone: the upstream pre-trained backbone model (CLIP in this
          code)
        - ViM: the zoo of middlestream trained ViM modules
        - ViM-aggregation: the specific aggregation weights for downstream
          tasks
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """
        Initialize a ViM-based Model.

        Args:
            model_dir: model id or path, where model_dir/pytorch_model.pt
                contains:
                - 'meta_info': basic information of ViM. The keys read here
                  are 'task_list', 'backbone_arch', 'ViM_agg_algo' and the
                  optional 'label_map'
                - 'backbone_weights': parameters of backbone [upstream]
                - 'ViM_weights': parameters of ViM [midstream]
                - 'ViM_agg_weights': parameters of ViM-aggregation
                  [downstream]
                - 'head_weights': state dict of each task-specific head

        Raises:
            NotImplementedError: if a task name in 'task_list' starts with
                neither 'cls' nor 'seg'.
        """
        super(VisionMiddlewareModel, self).__init__()

        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        # NOTE: torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only load model files from a trusted source.
        model_dict = torch.load(model_path, map_location='cpu')

        meta_info = model_dict['meta_info']
        self.task_list = meta_info['task_list']

        # build up backbone
        backbone_weights = model_dict['backbone_weights']
        self.backbone = build_backbone(
            arch=meta_info['backbone_arch'], pretrained=backbone_weights)
        self.backbone.eval()

        # build up ViM
        vim_weights = model_dict['ViM_weights']
        num_layers = len(vim_weights)
        resblocks = self.backbone.transformer.resblocks
        # Index by position instead of iterating the container directly, so
        # this works whether the per-layer weights are a list or an
        # int-keyed mapping.
        for layer_i in range(num_layers):
            layer_weights = vim_weights[layer_i]
            resblocks[layer_i].vim_att.register_ViM(
                layer_weights['vim_att_weights'])
            resblocks[layer_i].vim_mlp.register_ViM(
                layer_weights['vim_mlp_weights'])

        # build up each task-related ViM aggregation
        agg_weights = model_dict['ViM_agg_weights']
        agg_algo = meta_info['ViM_agg_algo']
        for task_name in self.task_list:
            task_agg_weights = agg_weights[task_name]
            for layer_i in range(num_layers):
                layer_agg = task_agg_weights[layer_i]
                resblocks[layer_i].vim_att.register_task(
                    task_name, layer_agg['vim_att_agg'], agg_algo)
                resblocks[layer_i].vim_mlp.register_task(
                    task_name, layer_agg['vim_mlp_agg'], agg_algo)

        # build up each task-related head
        self.heads = nn.ModuleDict()
        self.label_maps = {}
        head_weights = model_dict['head_weights']
        # Label maps are optional: a missing or empty entry simply means
        # that no task gets a label map.
        label_map = meta_info.get('label_map') or {}
        for task_name in self.task_list:
            if task_name.startswith('cls'):
                # The number of classes is recovered from the checkpoint
                # itself: one bias entry per class.
                head = LinearClassifier(
                    in_channels=self.backbone.output_dim,
                    num_classes=head_weights[task_name]
                    ['classifier.bias'].shape[0])
            elif task_name.startswith('seg'):
                head = FPNSegmentor()
            else:
                raise NotImplementedError(
                    'Task type [{}] is not supported'.format(task_name))

            head.load_state_dict(head_weights[task_name])
            head.eval()
            self.heads[task_name] = head

            if task_name in label_map:
                self.label_maps[task_name] = label_map[task_name]

    # NOTE(review): if TorchModel derives from nn.Module, overriding
    # __call__ bypasses the module hook machinery -- confirm this is
    # intended.
    def __call__(self, inputs, task_name) -> Dict[str, Any]:
        """
        Run `forward` for `task_name` and post-process its outputs.

        Args:
            inputs: the input images (B, 3, H, W)
            task_name: specified task for forwarding

        Returns:
            The dict produced by `postprocess`.
        """
        return self.postprocess(
            self.forward(inputs, task_name), inputs, task_name)

    def forward(self, inputs, task_name):
        """
        Dynamic Forward Function of ViM.

        Args:
            inputs: the input images (B, 3, H, W)
            task_name: specified task for forwarding

        Returns:
            The raw output of the head registered for `task_name`.

        Raises:
            NotImplementedError: if `task_name` is not in the task list of
                this model.
        """
        if task_name not in self.task_list:
            raise NotImplementedError(
                f'task_name should in {self.task_list}, but got {task_name}')

        features = self.backbone(inputs, task_name=task_name)
        outputs = self.heads[task_name](features)

        return outputs

    def postprocess(self, outputs, inputs, task_name):
        """
        Post-process of ViM, based on task_name.

        Only the first image of the batch is post-processed.

        Args:
            inputs: batched input image (B, 3, H, W)
            outputs: batched output (format based on task_name)
            task_name (str): task name

        Returns:
            A dict with one entry per class present in the prediction:
                - OutputKeys.MASKS: list of (H, W) integer numpy arrays,
                  1 where the pixel is predicted as that class
                - OutputKeys.LABELS: list of label names taken from the
                  label map of the task, or the class index as a string
                  when the task has no label map
                - OutputKeys.SCORES: list of floats, the mean softmax
                  probability of the class over its mask

        Raises:
            NotImplementedError: if `task_name` is not a segmentation task.
        """
        # Use the same prefix rule as __init__ so that a task is
        # post-processed according to the head type that was built for it.
        if not task_name.startswith('seg'):
            raise NotImplementedError(
                'Only segmentation task is currently supported in pipeline')

        _, _, img_height, img_width = inputs.size()

        # outputs in shape of [1, C, H, W]
        seg = F.softmax(outputs, dim=1)
        # Resize the probability maps back to the input resolution.
        seg = F.interpolate(
            seg,
            size=(img_height, img_width),
            mode='bilinear',
            align_corners=True)
        seg = seg[0].detach().cpu()
        pred = torch.argmax(seg, dim=0)

        # torch.unique returns the distinct class ids in ascending order.
        labels = torch.unique(pred).tolist()

        masks, scores = [], []
        for label in labels:
            mask = pred == label
            mask_weight = mask.float()
            masks.append(mask.long().numpy())
            # Every label comes from `pred`, so its mask is never empty and
            # the denominator is non-zero.
            scores.append(
                ((mask_weight * seg[label]).sum() / mask_weight.sum()).item())

        label_map = self.label_maps.get(task_name)
        if label_map is None:
            # Label maps are optional at load time, so report the class
            # indices instead of failing with a KeyError.
            label_names = [str(label) for label in labels]
        else:
            label_names = [label_map[label] for label in labels]

        return {
            OutputKeys.MASKS: masks,
            OutputKeys.LABELS: label_names,
            OutputKeys.SCORES: scores
        }

    def get_tasks(self):
        """
        Get the supported tasks of current ViM model.
        """
        return self.task_list