tinynas_classification_pipeline.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
import ast
import math
import os.path as osp
from typing import Any, Dict

import torch
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models.cv.tinynas_classfication import get_zennet
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
# Module-level logger; the pipeline uses it to report when model loading
# has finished.
logger = get_logger()
  16. @PIPELINES.register_module(
  17. Tasks.image_classification, module_name=Pipelines.tinynas_classification)
  18. class TinynasClassificationPipeline(Pipeline):
  19. def __init__(self, model: str, **kwargs):
  20. """
  21. use `model` to create a tinynas classification pipeline for prediction
  22. Args:
  23. model: model id on modelscope hub.
  24. """
  25. super().__init__(model=model, **kwargs)
  26. self.path = model
  27. self.model = get_zennet()
  28. model_pth_path = osp.join(self.path, ModelFile.TORCH_MODEL_FILE)
  29. checkpoint = torch.load(
  30. model_pth_path, map_location='cpu', weights_only=True)
  31. if 'state_dict' in checkpoint:
  32. state_dict = checkpoint['state_dict']
  33. else:
  34. state_dict = checkpoint
  35. self.model.load_state_dict(state_dict, strict=True)
  36. logger.info('load model done')
  37. def preprocess(self, input: Input) -> Dict[str, Any]:
  38. img = LoadImage.convert_to_img(input)
  39. input_image_size = 224
  40. crop_image_size = 380
  41. input_image_crop = 0.875
  42. resize_image_size = int(math.ceil(crop_image_size / input_image_crop))
  43. transforms_normalize = transforms.Normalize(
  44. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  45. transform_list = [
  46. transforms.Resize(
  47. resize_image_size,
  48. interpolation=transforms.InterpolationMode.BICUBIC),
  49. transforms.CenterCrop(crop_image_size),
  50. transforms.ToTensor(), transforms_normalize
  51. ]
  52. transformer = transforms.Compose(transform_list)
  53. img = transformer(img)
  54. img = torch.unsqueeze(img, 0)
  55. img = torch.nn.functional.interpolate(
  56. img, input_image_size, mode='bilinear')
  57. result = {'img': img}
  58. return result
  59. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  60. is_train = False
  61. if is_train:
  62. self.model.train()
  63. else:
  64. self.model.eval()
  65. outputs = self.model(input['img'])
  66. return {'outputs': outputs}
  67. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  68. label_mapping_path = osp.join(self.path, 'label_map.txt')
  69. f = open(label_mapping_path, encoding='utf-8')
  70. content = f.read()
  71. f.close()
  72. label_dict = eval(content)
  73. output_prob = torch.nn.functional.softmax(inputs['outputs'], dim=-1)
  74. score = torch.max(output_prob)
  75. output_dict = {
  76. OutputKeys.SCORES: [score.item()],
  77. OutputKeys.LABELS: [label_dict[inputs['outputs'].argmax().item()]]
  78. }
  79. return output_dict