action_recognition_pipeline.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. import os.path as osp
  4. from typing import Any, Dict
  5. import torch
  6. from modelscope.metainfo import Pipelines
  7. from modelscope.models.cv.action_recognition import (BaseVideoModel,
  8. PatchShiftTransformer)
  9. from modelscope.outputs import OutputKeys
  10. from modelscope.pipelines.base import Input, Pipeline
  11. from modelscope.pipelines.builder import PIPELINES
  12. from modelscope.preprocessors import ReadVideoData
  13. from modelscope.utils.config import Config
  14. from modelscope.utils.constant import ModelFile, Tasks
  15. from modelscope.utils.logger import get_logger
  16. logger = get_logger()
  17. @PIPELINES.register_module(
  18. Tasks.action_recognition, module_name=Pipelines.action_recognition)
  19. class ActionRecognitionPipeline(Pipeline):
  20. def __init__(self, model: str, **kwargs):
  21. """
  22. use `model` to create a action recognition pipeline for prediction
  23. Args:
  24. model: model id on modelscope hub.
  25. """
  26. super().__init__(model=model, **kwargs)
  27. model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
  28. logger.info(f'loading model from {model_path}')
  29. config_path = osp.join(self.model, ModelFile.CONFIGURATION)
  30. logger.info(f'loading config from {config_path}')
  31. self.cfg = Config.from_file(config_path)
  32. self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device)
  33. self.infer_model.eval()
  34. self.infer_model.load_state_dict(
  35. torch.load(
  36. model_path, map_location=self.device,
  37. weights_only=True)['model_state'])
  38. self.label_mapping = self.cfg.label_mapping
  39. logger.info('load model done')
  40. def preprocess(self, input: Input) -> Dict[str, Any]:
  41. if isinstance(input, str):
  42. video_input_data = ReadVideoData(self.cfg, input).to(self.device)
  43. else:
  44. raise TypeError(f'input should be a str,'
  45. f' but got {type(input)}')
  46. result = {'video_data': video_input_data}
  47. return result
  48. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  49. pred = self.perform_inference(input['video_data'])
  50. output_label = self.label_mapping[str(pred)]
  51. return {OutputKeys.LABELS: output_label}
  52. @torch.no_grad()
  53. def perform_inference(self, data, max_bsz=4):
  54. iter_num = math.ceil(data.size(0) / max_bsz)
  55. preds_list = []
  56. for i in range(iter_num):
  57. preds_list.append(
  58. self.infer_model(data[i * max_bsz:(i + 1) * max_bsz])[0])
  59. pred = torch.cat(preds_list, dim=0)
  60. return pred.mean(dim=0).argmax().item()
  61. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  62. return inputs
  63. @PIPELINES.register_module(
  64. Tasks.action_recognition, module_name=Pipelines.pst_action_recognition)
  65. class PSTActionRecognitionPipeline(Pipeline):
  66. def __init__(self, model: str, **kwargs):
  67. """
  68. use `model` to create a PST action recognition pipeline for prediction
  69. Args:
  70. model: model id on modelscope hub.
  71. """
  72. super().__init__(model=model, **kwargs)
  73. model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
  74. logger.info(f'loading model from {model_path}')
  75. config_path = osp.join(self.model, ModelFile.CONFIGURATION)
  76. logger.info(f'loading config from {config_path}')
  77. self.cfg = Config.from_file(config_path)
  78. self.infer_model = PatchShiftTransformer(model).to(self.device)
  79. self.infer_model.eval()
  80. self.infer_model.load_state_dict(
  81. torch.load(
  82. model_path, map_location=self.device,
  83. weights_only=True)['state_dict'])
  84. self.label_mapping = self.cfg.label_mapping
  85. logger.info('load model done')
  86. def preprocess(self, input: Input) -> Dict[str, Any]:
  87. if isinstance(input, str):
  88. video_input_data = ReadVideoData(self.cfg, input).to(self.device)
  89. else:
  90. raise TypeError(f'input should be a str,'
  91. f' but got {type(input)}')
  92. result = {'video_data': video_input_data}
  93. return result
  94. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  95. pred = self.perform_inference(input['video_data'])
  96. output_label = self.label_mapping[str(pred)]
  97. return {OutputKeys.LABELS: output_label}
  98. @torch.no_grad()
  99. def perform_inference(self, data, max_bsz=4):
  100. iter_num = math.ceil(data.size(0) / max_bsz)
  101. preds_list = []
  102. for i in range(iter_num):
  103. preds_list.append(
  104. self.infer_model(data[i * max_bsz:(i + 1) * max_bsz]))
  105. pred = torch.cat(preds_list, dim=0)
  106. return pred.mean(dim=0).argmax().item()
  107. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  108. return inputs