| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import math
- import os.path as osp
- from typing import Any, Dict
- import torch
- from torchvision import transforms
- from modelscope.metainfo import Pipelines
- from modelscope.models.cv.tinynas_classfication import get_zennet
- from modelscope.outputs import OutputKeys
- from modelscope.pipelines.base import Input, Pipeline
- from modelscope.pipelines.builder import PIPELINES
- from modelscope.preprocessors import LoadImage
- from modelscope.utils.constant import ModelFile, Tasks
- from modelscope.utils.logger import get_logger
- logger = get_logger()
@PIPELINES.register_module(
    Tasks.image_classification, module_name=Pipelines.tinynas_classification)
class TinynasClassificationPipeline(Pipeline):
    """Image classification pipeline backed by the network from ``get_zennet``.

    The pipeline loads weights from ``ModelFile.TORCH_MODEL_FILE`` inside the
    model directory, preprocesses one image into a normalized tensor, runs the
    network in eval mode and maps the top-scoring class index to a label read
    from ``label_map.txt`` in the same directory.
    """

    def __init__(self, model: str, **kwargs):
        """Create a tinynas classification pipeline for prediction.

        Args:
            model: Location of the model. It is joined with the checkpoint
                and label-map file names, so it must be usable as a local
                directory path (presumably a hub model id has already been
                resolved to its download directory by the caller - confirm).
            **kwargs: Forwarded unchanged to ``Pipeline.__init__``.
        """
        super().__init__(model=model, **kwargs)
        self.path = model
        self.model = get_zennet()

        model_pth_path = osp.join(self.path, ModelFile.TORCH_MODEL_FILE)
        checkpoint = torch.load(
            model_pth_path, map_location='cpu', weights_only=True)

        # A checkpoint may either wrap the weights under 'state_dict' or be
        # the bare parameter mapping itself; accept both layouts.
        state_dict = (
            checkpoint['state_dict']
            if 'state_dict' in checkpoint else checkpoint)

        self.model.load_state_dict(state_dict, strict=True)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Turn one input image into a batched, normalized tensor.

        Steps: resize to ``ceil(380 / 0.875)`` with bicubic interpolation,
        center-crop to 380, convert to tensor, normalize with the mean/std
        below, add a leading batch dimension, then bilinearly resample to 224.

        Args:
            input: Anything accepted by ``LoadImage.convert_to_img``.

        Returns:
            ``{'img': tensor}`` where ``tensor`` has a leading batch dimension
            of 1 and spatial size 224.
        """
        img = LoadImage.convert_to_img(input)

        input_image_size = 224
        crop_image_size = 380
        input_image_crop = 0.875
        # Resize target chosen so that the center crop keeps
        # `input_image_crop` of the resized extent.
        resize_image_size = int(math.ceil(crop_image_size / input_image_crop))

        transformer = transforms.Compose([
            transforms.Resize(
                resize_image_size,
                interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(crop_image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        img = transformer(img)
        img = torch.unsqueeze(img, 0)  # add the batch dimension
        # Downsample the 380 crop to the size fed to the network.
        img = torch.nn.functional.interpolate(
            img, input_image_size, mode='bilinear')
        return {'img': img}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the network on the preprocessed image.

        Args:
            input: Mapping with key ``'img'`` as produced by ``preprocess``.

        Returns:
            ``{'outputs': raw_model_output}``.
        """
        # This pipeline only performs prediction, so the model is always put
        # in eval mode before the forward pass.
        self.model.eval()
        outputs = self.model(input['img'])
        return {'outputs': outputs}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Convert raw model outputs into a score and a label.

        Args:
            inputs: Mapping with key ``'outputs'`` as produced by ``forward``.

        Returns:
            A dict with ``OutputKeys.SCORES`` (single-element list holding the
            highest softmax probability) and ``OutputKeys.LABELS``
            (single-element list holding the label looked up by the argmax
            index of the raw outputs).

        Raises:
            OSError: If ``label_map.txt`` cannot be opened.
            ValueError, SyntaxError: If ``label_map.txt`` is not a Python
                literal.
        """
        # Local import keeps this block self-contained; `ast` is stdlib.
        import ast

        label_mapping_path = osp.join(self.path, 'label_map.txt')
        # Context manager guarantees the handle is closed even if read() fails.
        with open(label_mapping_path, encoding='utf-8') as f:
            content = f.read()

        # NOTE(review): this previously used eval(), which executes arbitrary
        # code from a file shipped alongside the model. literal_eval accepts
        # only Python literals; this assumes label_map.txt is a plain dict
        # literal mapping class index -> label - confirm against the published
        # model files.
        label_dict = ast.literal_eval(content)

        output_prob = torch.nn.functional.softmax(inputs['outputs'], dim=-1)
        score = torch.max(output_prob)
        top_index = inputs['outputs'].argmax().item()

        return {
            OutputKeys.SCORES: [score.item()],
            OutputKeys.LABELS: [label_dict[top_index]],
        }
|