# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os.path as osp
from typing import Any, Dict

import torch
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models.cv.tinynas_classfication import get_zennet
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_classification, module_name=Pipelines.tinynas_classification)
class TinynasClassificationPipeline(Pipeline):
    """Image-classification pipeline backed by a TinyNAS ZenNet model."""

    def __init__(self, model: str, **kwargs):
        """Create a tinynas classification pipeline for prediction.

        Args:
            model: model id on the modelscope hub. The value is also joined
                directly with file names (checkpoint and ``label_map.txt``),
                so it is expected to resolve to a local model directory.
        """
        super().__init__(model=model, **kwargs)
        self.path = model
        self.model = get_zennet()

        model_pth_path = osp.join(self.path, ModelFile.TORCH_MODEL_FILE)
        checkpoint = torch.load(
            model_pth_path, map_location='cpu', weights_only=True)
        # Accept both a wrapped checkpoint ({'state_dict': ...}) and a bare
        # state dict.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        self.model.load_state_dict(state_dict, strict=True)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Load an image and convert it into a normalized single-image batch.

        The image is bicubically resized (shorter side to
        ``ceil(380 / 0.875)`` pixels), center-cropped to 380x380, converted
        to a tensor, normalized with the given mean/std, given a leading
        batch dimension, and finally bilinearly resized to 224x224.

        Returns:
            ``{'img': tensor}`` with shape ``(1, C, 224, 224)``.
        """
        img = LoadImage.convert_to_img(input)
        input_image_size = 224
        crop_image_size = 380
        input_image_crop = 0.875
        # Resize slightly larger than the crop so the center crop keeps
        # `input_image_crop` of the resized image.
        resize_image_size = int(math.ceil(crop_image_size / input_image_crop))
        transforms_normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform_list = [
            transforms.Resize(
                resize_image_size,
                interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(crop_image_size),
            transforms.ToTensor(), transforms_normalize
        ]
        transformer = transforms.Compose(transform_list)

        img = transformer(img)
        img = torch.unsqueeze(img, 0)
        # Downsample the 380x380 crop to the network's input resolution.
        img = torch.nn.functional.interpolate(
            img, input_image_size, mode='bilinear')
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the network on ``input['img']`` and return its raw outputs."""
        # This pipeline only performs inference, so the model is always put
        # in eval mode (layers such as BatchNorm/Dropout, if present, then
        # behave deterministically).
        self.model.eval()
        outputs = self.model(input['img'])
        return {'outputs': outputs}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Convert network outputs into the top-1 score and label.

        Returns:
            dict with ``OutputKeys.SCORES`` (a one-element list holding the
            highest softmax probability) and ``OutputKeys.LABELS`` (a
            one-element list holding the ``label_map.txt`` entry for the
            arg-max index).
        """
        label_dict = self._load_label_dict()
        logits = inputs['outputs']
        output_prob = torch.nn.functional.softmax(logits, dim=-1)
        score = torch.max(output_prob)
        # max/argmax run over the whole (flattened) tensor, which assumes a
        # single-image batch as produced by preprocess().
        label_index = logits.argmax().item()
        output_dict = {
            OutputKeys.SCORES: [score.item()],
            OutputKeys.LABELS: [label_dict[label_index]]
        }
        return output_dict

    def _load_label_dict(self):
        """Read and evaluate ``label_map.txt`` from the model directory."""
        label_mapping_path = osp.join(self.path, 'label_map.txt')
        # Context manager guarantees the handle is closed even if read fails.
        with open(label_mapping_path, encoding='utf-8') as f:
            content = f.read()
        # NOTE(security): eval() executes arbitrary Python from a file that
        # may come from a downloaded model directory. If the file is always a
        # plain dict literal, ast.literal_eval would be a safe replacement.
        return eval(content)