animal_recognition_pipeline.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import cv2
  5. import numpy as np
  6. import torch
  7. from PIL import Image
  8. from torchvision import transforms
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.models.cv.animal_recognition import Bottleneck, ResNet
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines.base import Input, Pipeline
  13. from modelscope.pipelines.builder import PIPELINES
  14. from modelscope.preprocessors import LoadImage
  15. from modelscope.utils.constant import Devices, ModelFile, Tasks
  16. from modelscope.utils.logger import get_logger
  17. logger = get_logger()
  18. @PIPELINES.register_module(
  19. Tasks.animal_recognition, module_name=Pipelines.animal_recognition)
  20. class AnimalRecognitionPipeline(Pipeline):
  21. def __init__(self, model: str, **kwargs):
  22. """
  23. use `model` to create a animal recognition pipeline for prediction
  24. Args:
  25. model: model id on modelscope hub.
  26. """
  27. super().__init__(model=model, **kwargs)
  28. import torch
  29. def resnest101(**kwargs):
  30. model = ResNet(
  31. Bottleneck, [3, 4, 23, 3],
  32. radix=2,
  33. groups=1,
  34. bottleneck_width=64,
  35. deep_stem=True,
  36. stem_width=64,
  37. avg_down=True,
  38. avd=True,
  39. avd_first=False,
  40. **kwargs)
  41. return model
  42. def filter_param(src_params, own_state):
  43. copied_keys = []
  44. for name, param in src_params.items():
  45. if 'module.' == name[0:7]:
  46. name = name[7:]
  47. if '.module.' not in list(own_state.keys())[0]:
  48. name = name.replace('.module.', '.')
  49. if (name in own_state) and (own_state[name].shape
  50. == param.shape):
  51. own_state[name].copy_(param)
  52. copied_keys.append(name)
  53. def load_pretrained(model, src_params):
  54. if 'state_dict' in src_params:
  55. src_params = src_params['state_dict']
  56. own_state = model.state_dict()
  57. filter_param(src_params, own_state)
  58. model.load_state_dict(own_state)
  59. self.local_path = self.model
  60. src_params = torch.load(
  61. osp.join(self.local_path, ModelFile.TORCH_MODEL_FILE),
  62. Devices.cpu,
  63. weights_only=True)
  64. self.model = resnest101(num_classes=8288)
  65. load_pretrained(self.model, src_params)
  66. logger.info('load model done')
  67. def preprocess(self, input: Input) -> Dict[str, Any]:
  68. img = LoadImage.convert_to_img(input)
  69. normalize = transforms.Normalize(
  70. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  71. test_transforms = transforms.Compose([
  72. transforms.Resize(256),
  73. transforms.CenterCrop(224),
  74. transforms.ToTensor(), normalize
  75. ])
  76. img = test_transforms(img)
  77. result = {'img': img}
  78. return result
  79. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  80. def set_phase(model, is_train):
  81. if is_train:
  82. model.train()
  83. else:
  84. model.eval()
  85. is_train = False
  86. set_phase(self.model, is_train)
  87. img = input['img']
  88. input_img = torch.unsqueeze(img, 0)
  89. outputs = self.model(input_img)
  90. return {'outputs': outputs}
  91. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  92. label_mapping_path = osp.join(self.local_path, 'label_mapping.txt')
  93. with open(label_mapping_path, 'r', encoding='utf-8') as f:
  94. label_mapping = f.readlines()
  95. score = torch.max(inputs['outputs'])
  96. inputs = {
  97. OutputKeys.SCORES: [score.item()],
  98. OutputKeys.LABELS:
  99. [label_mapping[inputs['outputs'].argmax()].split('\t')[1]]
  100. }
  101. return inputs