image_classification_pipeline.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import numpy as np
  4. import torch
  5. from modelscope.metainfo import Pipelines, Preprocessors
  6. from modelscope.outputs import OutputKeys
  7. from modelscope.pipelines.base import Input, Model, Pipeline
  8. from modelscope.pipelines.builder import PIPELINES
  9. from modelscope.pipelines.util import batch_process
  10. from modelscope.preprocessors import Preprocessor
  11. from modelscope.preprocessors.image import LoadImage
  12. from modelscope.utils.constant import Fields, Tasks
  13. from modelscope.utils.logger import get_logger
# Module-level logger used by the pipeline class below.
logger = get_logger()
  15. @PIPELINES.register_module(
  16. Tasks.image_classification, module_name=Pipelines.image_classification)
  17. @PIPELINES.register_module(
  18. Tasks.image_classification,
  19. module_name=Pipelines.general_image_classification)
  20. @PIPELINES.register_module(
  21. Tasks.image_classification,
  22. module_name=Pipelines.daily_image_classification)
  23. @PIPELINES.register_module(
  24. Tasks.image_classification,
  25. module_name=Pipelines.nextvit_small_daily_image_classification)
  26. @PIPELINES.register_module(
  27. Tasks.image_classification,
  28. module_name=Pipelines.convnext_base_image_classification_garbage)
  29. @PIPELINES.register_module(
  30. Tasks.image_classification,
  31. module_name=Pipelines.common_image_classification)
  32. @PIPELINES.register_module(
  33. Tasks.image_classification,
  34. module_name=Pipelines.easyrobust_classification)
  35. @PIPELINES.register_module(
  36. Tasks.image_classification,
  37. module_name=Pipelines.bnext_small_image_classification)
  38. class GeneralImageClassificationPipeline(Pipeline):
  39. def __init__(self,
  40. model: str,
  41. preprocessor: Optional[Preprocessor] = None,
  42. config_file: str = None,
  43. device: str = 'gpu',
  44. auto_collate=True,
  45. **kwargs):
  46. """Use `model` and `preprocessor` to create an image classification pipeline for prediction
  47. Args:
  48. model: A str format model id or model local dir to build the model instance from.
  49. preprocessor: A preprocessor instance to preprocess the data, if None,
  50. the pipeline will try to build the preprocessor according to the configuration.json file.
  51. kwargs: The args needed by the `Pipeline` class.
  52. """
  53. super().__init__(
  54. model=model,
  55. preprocessor=preprocessor,
  56. config_file=config_file,
  57. device=device,
  58. auto_collate=auto_collate)
  59. self.target_gpus = None
  60. if preprocessor is None:
  61. assert hasattr(self.model, 'model_dir'), 'Model used in ImageClassificationPipeline should has ' \
  62. 'a `model_dir` attribute to build a preprocessor.'
  63. if self.model.__class__.__name__ == 'OfaForAllTasks':
  64. self.preprocessor = Preprocessor.from_pretrained(
  65. model_name_or_path=self.model.model_dir,
  66. type=Preprocessors.ofa_tasks_preprocessor,
  67. field=Fields.multi_modal,
  68. **kwargs)
  69. else:
  70. if next(self.model.parameters()).is_cuda:
  71. self.target_gpus = [next(self.model.parameters()).device]
  72. assert hasattr(self.model, 'model_dir'), 'Model used in GeneralImageClassificationPipeline' \
  73. ' should has a `model_dir` attribute to build a preprocessor.'
  74. self.preprocessor = Preprocessor.from_pretrained(
  75. self.model.model_dir, **kwargs)
  76. if self.preprocessor.__class__.__name__ == 'ImageClassificationBypassPreprocessor':
  77. from modelscope.preprocessors import ImageClassificationMmcvPreprocessor
  78. self.preprocessor = ImageClassificationMmcvPreprocessor(
  79. self.model.model_dir, **kwargs)
  80. logger.info('load model done')
  81. def _batch(self, data):
  82. if self.model.__class__.__name__ == 'OfaForAllTasks':
  83. return batch_process(self.model, data)
  84. else:
  85. return super()._batch(data)
  86. def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
  87. if self.model.__class__.__name__ == 'OfaForAllTasks':
  88. return super().preprocess(input, **preprocess_params)
  89. else:
  90. img = LoadImage.convert_to_ndarray(input)
  91. img = img[:, :, ::-1] # Convert to BGR
  92. data = super().preprocess(img, **preprocess_params)
  93. from mmcv.parallel import collate, scatter
  94. data = collate([data], samples_per_gpu=1)
  95. if self.target_gpus is not None:
  96. # scatter to specified GPU
  97. data = scatter(data, self.target_gpus)[0]
  98. return data
  99. def forward(self, input: Dict[str, Any],
  100. **forward_params) -> Dict[str, Any]:
  101. if self.model.__class__.__name__ != 'OfaForAllTasks':
  102. input['return_loss'] = False
  103. return self.model(input)
  104. def postprocess(self, inputs: Dict[str, Any],
  105. **post_params) -> Dict[str, Any]:
  106. if self.model.__class__.__name__ != 'OfaForAllTasks':
  107. scores = inputs
  108. pred_scores = np.sort(scores, axis=1)[0][::-1][:5]
  109. pred_labels = np.argsort(scores, axis=1)[0][::-1][:5]
  110. result = {
  111. 'pred_score': [score for score in pred_scores],
  112. 'pred_class':
  113. [self.model.CLASSES[label] for label in pred_labels]
  114. }
  115. outputs = {
  116. OutputKeys.SCORES: result['pred_score'],
  117. OutputKeys.LABELS: result['pred_class']
  118. }
  119. return outputs
  120. else:
  121. return inputs