general_recognition_pipeline.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import cv2
  5. import numpy as np
  6. import torch
  7. from PIL import Image
  8. from torchvision import transforms
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.models.cv.animal_recognition import resnet
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines.base import Input, Pipeline
  13. from modelscope.pipelines.builder import PIPELINES
  14. from modelscope.preprocessors import LoadImage, load_image
  15. from modelscope.utils.constant import ModelFile, Tasks
  16. from modelscope.utils.logger import get_logger
  17. logger = get_logger()
  18. @PIPELINES.register_module(
  19. Tasks.general_recognition, module_name=Pipelines.general_recognition)
  20. class GeneralRecognitionPipeline(Pipeline):
  21. def __init__(self, model: str, device: str):
  22. """
  23. use `model` to create a general recognition pipeline for prediction
  24. Args:
  25. model: model id on modelscope hub.
  26. """
  27. super().__init__(model=model)
  28. import torch
  29. def resnest101(**kwargs):
  30. model = resnet.ResNet(
  31. resnet.Bottleneck, [3, 4, 23, 3],
  32. radix=2,
  33. groups=1,
  34. bottleneck_width=64,
  35. deep_stem=True,
  36. stem_width=64,
  37. avg_down=True,
  38. avd=True,
  39. avd_first=False,
  40. **kwargs)
  41. return model
  42. def filter_param(src_params, own_state):
  43. copied_keys = []
  44. for name, param in src_params.items():
  45. if 'module.' == name[0:7]:
  46. name = name[7:]
  47. if '.module.' not in list(own_state.keys())[0]:
  48. name = name.replace('.module.', '.')
  49. if (name in own_state) and (own_state[name].shape
  50. == param.shape):
  51. own_state[name].copy_(param)
  52. copied_keys.append(name)
  53. def load_pretrained(model, src_params):
  54. if 'state_dict' in src_params:
  55. src_params = src_params['state_dict']
  56. own_state = model.state_dict()
  57. filter_param(src_params, own_state)
  58. model.load_state_dict(own_state)
  59. device = 'cpu'
  60. self.local_path = self.model
  61. src_params = torch.load(
  62. osp.join(self.local_path, ModelFile.TORCH_MODEL_FILE),
  63. device,
  64. weights_only=True)
  65. self.model = resnest101(num_classes=54092)
  66. load_pretrained(self.model, src_params)
  67. logger.info('load model done')
  68. def preprocess(self, input: Input) -> Dict[str, Any]:
  69. img = LoadImage.convert_to_img(input)
  70. normalize = transforms.Normalize(
  71. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  72. transform = transforms.Compose([
  73. transforms.Resize(256),
  74. transforms.CenterCrop(224),
  75. transforms.ToTensor(), normalize
  76. ])
  77. img = transform(img)
  78. result = {'img': img}
  79. return result
  80. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  81. def set_phase(model, is_train):
  82. if is_train:
  83. model.train()
  84. else:
  85. model.eval()
  86. is_train = False
  87. set_phase(self.model, is_train)
  88. img = input['img']
  89. input_img = torch.unsqueeze(img, 0)
  90. outputs = self.model(input_img)
  91. return {'outputs': outputs}
  92. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  93. label_mapping_path = osp.join(self.local_path, 'meta_info.txt')
  94. with open(label_mapping_path, 'r', encoding='utf-8') as f:
  95. label_mapping = f.readlines()
  96. score = torch.max(inputs['outputs'])
  97. inputs = {
  98. OutputKeys.SCORES: [score.item()],
  99. OutputKeys.LABELS:
  100. [label_mapping[inputs['outputs'].argmax()].split('\t')[1]]
  101. }
  102. return inputs