action_detection_pipeline.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. import os.path as osp
  4. from typing import Any, Dict
  5. from modelscope.metainfo import Pipelines
  6. from modelscope.models.cv.action_detection import ActionDetONNX
  7. from modelscope.outputs import OutputKeys
  8. from modelscope.pipelines.base import Input, Pipeline
  9. from modelscope.pipelines.builder import PIPELINES
  10. from modelscope.utils.config import Config
  11. from modelscope.utils.constant import ModelFile, Tasks
  12. from modelscope.utils.logger import get_logger
# Module-level logger; the pipeline below uses it to report model/config
# loading progress.
logger = get_logger()
  14. @PIPELINES.register_module(
  15. Tasks.action_detection, module_name=Pipelines.action_detection)
  16. class ActionDetectionPipeline(Pipeline):
  17. def __init__(self, model: str, **kwargs):
  18. """
  19. use `model` to create a action detection pipeline for prediction
  20. Args:
  21. model: model id on modelscope hub.
  22. """
  23. super().__init__(model=model, **kwargs)
  24. model_path = osp.join(self.model, ModelFile.ONNX_MODEL_FILE)
  25. logger.info(f'loading model from {model_path}')
  26. config_path = osp.join(self.model, ModelFile.CONFIGURATION)
  27. logger.info(f'loading config from {config_path}')
  28. self.cfg = Config.from_file(config_path)
  29. self.cfg.MODEL.model_file = model_path
  30. self.cfg.MODEL.update(kwargs)
  31. self.model = ActionDetONNX(self.model, self.cfg.MODEL,
  32. self.device_name)
  33. logger.info('load model done')
  34. def preprocess(self, input: Input) -> Dict[str, Any]:
  35. if isinstance(input, str):
  36. video_name = input
  37. else:
  38. raise TypeError(f'input should be a str,'
  39. f' but got {type(input)}')
  40. result = {'video_name': video_name}
  41. return result
  42. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  43. preds = self.model.forward(input['video_name'])
  44. labels = sum([pred['actions']['labels'] for pred in preds], [])
  45. scores = sum([pred['actions']['scores'] for pred in preds], [])
  46. boxes = sum([pred['actions']['boxes'] for pred in preds], [])
  47. timestamps = sum([[pred['timestamp']] * len(pred['actions']['labels'])
  48. for pred in preds], [])
  49. out = {
  50. OutputKeys.TIMESTAMPS: timestamps,
  51. OutputKeys.LABELS: labels,
  52. OutputKeys.SCORES: scores,
  53. OutputKeys.BOXES: boxes
  54. }
  55. return out
  56. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  57. return inputs