| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import math
- import os.path as osp
- from typing import Any, Dict
- import torch
- from modelscope.metainfo import Pipelines
- from modelscope.models.cv.action_recognition import (BaseVideoModel,
- PatchShiftTransformer)
- from modelscope.outputs import OutputKeys
- from modelscope.pipelines.base import Input, Pipeline
- from modelscope.pipelines.builder import PIPELINES
- from modelscope.preprocessors import ReadVideoData
- from modelscope.utils.config import Config
- from modelscope.utils.constant import ModelFile, Tasks
- from modelscope.utils.logger import get_logger
# Module-level logger shared by the pipeline classes below; they use it to
# report which checkpoint/config files are being loaded.
logger = get_logger()
@PIPELINES.register_module(
    Tasks.action_recognition, module_name=Pipelines.action_recognition)
class ActionRecognitionPipeline(Pipeline):
    """Action recognition pipeline backed by ``BaseVideoModel``.

    The network is built from the configuration file found next to the
    checkpoint and initialised from the checkpoint's ``'model_state'`` entry.
    """

    def __init__(self, model: str, **kwargs):
        """Create an action recognition pipeline for prediction.

        Args:
            model: model id on modelscope hub.
            **kwargs: extra keyword arguments passed on to the parent
                constructor.
        """
        super().__init__(model=model, **kwargs)

        weights_file = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {weights_file}')
        cfg_file = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {cfg_file}')
        self.cfg = Config.from_file(cfg_file)

        # Build the network on the target device and switch it to eval mode
        # before the weights are restored.
        network = BaseVideoModel(cfg=self.cfg)
        self.infer_model = network.to(self.device)
        self.infer_model.eval()

        checkpoint = torch.load(
            weights_file, map_location=self.device, weights_only=True)
        self.infer_model.load_state_dict(checkpoint['model_state'])

        # Maps the stringified predicted index to its label (see `forward`).
        self.label_mapping = self.cfg.label_mapping
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Load the video data for `input` and move it to the pipeline device.

        Args:
            input: a ``str`` handed to ``ReadVideoData`` together with the
                pipeline config.

        Returns:
            A dict holding the loaded data under the ``'video_data'`` key.

        Raises:
            TypeError: if `input` is not a ``str``.
        """
        if not isinstance(input, str):
            raise TypeError(f'input should be a str,'
                            f' but got {type(input)}')
        clip = ReadVideoData(self.cfg, input).to(self.device)
        return {'video_data': clip}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference on ``input['video_data']`` and look up its label.

        Returns:
            A dict with the label stored under ``OutputKeys.LABELS``.
        """
        class_index = self.perform_inference(input['video_data'])
        # The label mapping is keyed by the string form of the index.
        label = self.label_mapping[str(class_index)]
        return {OutputKeys.LABELS: label}

    @torch.no_grad()
    def perform_inference(self, data, max_bsz=4):
        """Predict a single class index for `data`, evaluated in mini-batches.

        Args:
            data: tensor that is sliced along dim 0 into mini-batches.
            max_bsz: maximum number of dim-0 entries per forward pass.

        Returns:
            The argmax (as a Python number) of the predictions averaged over
            dim 0.
        """
        num_batches = math.ceil(data.size(0) / max_bsz)
        # Only the first element of each model output is kept.
        outputs = [
            self.infer_model(data[k * max_bsz:(k + 1) * max_bsz])[0]
            for k in range(num_batches)
        ]
        scores = torch.cat(outputs, dim=0)
        averaged = scores.mean(dim=0)
        return averaged.argmax().item()

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the forward results unchanged."""
        return inputs
@PIPELINES.register_module(
    Tasks.action_recognition, module_name=Pipelines.pst_action_recognition)
class PSTActionRecognitionPipeline(Pipeline):
    """Action recognition pipeline backed by ``PatchShiftTransformer``.

    The network is constructed from the `model` argument and initialised from
    the checkpoint's ``'state_dict'`` entry.
    """

    def __init__(self, model: str, **kwargs):
        """Create a PST action recognition pipeline for prediction.

        Args:
            model: model id on modelscope hub.
            **kwargs: extra keyword arguments passed on to the parent
                constructor.
        """
        super().__init__(model=model, **kwargs)

        weights_file = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {weights_file}')
        cfg_file = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {cfg_file}')
        self.cfg = Config.from_file(cfg_file)

        # NOTE: the transformer receives the constructor argument `model`,
        # not `self.model`.
        network = PatchShiftTransformer(model)
        self.infer_model = network.to(self.device)
        self.infer_model.eval()

        checkpoint = torch.load(
            weights_file, map_location=self.device, weights_only=True)
        self.infer_model.load_state_dict(checkpoint['state_dict'])

        # Maps the stringified predicted index to its label (see `forward`).
        self.label_mapping = self.cfg.label_mapping
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Load the video data for `input` and move it to the pipeline device.

        Args:
            input: a ``str`` handed to ``ReadVideoData`` together with the
                pipeline config.

        Returns:
            A dict holding the loaded data under the ``'video_data'`` key.

        Raises:
            TypeError: if `input` is not a ``str``.
        """
        if not isinstance(input, str):
            raise TypeError(f'input should be a str,'
                            f' but got {type(input)}')
        clip = ReadVideoData(self.cfg, input).to(self.device)
        return {'video_data': clip}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference on ``input['video_data']`` and look up its label.

        Returns:
            A dict with the label stored under ``OutputKeys.LABELS``.
        """
        class_index = self.perform_inference(input['video_data'])
        # The label mapping is keyed by the string form of the index.
        label = self.label_mapping[str(class_index)]
        return {OutputKeys.LABELS: label}

    @torch.no_grad()
    def perform_inference(self, data, max_bsz=4):
        """Predict a single class index for `data`, evaluated in mini-batches.

        Args:
            data: tensor that is sliced along dim 0 into mini-batches.
            max_bsz: maximum number of dim-0 entries per forward pass.

        Returns:
            The argmax (as a Python number) of the predictions averaged over
            dim 0.
        """
        num_batches = math.ceil(data.size(0) / max_bsz)
        # The model output is used as-is (no indexing into it).
        outputs = [
            self.infer_model(data[k * max_bsz:(k + 1) * max_bsz])
            for k in range(num_batches)
        ]
        scores = torch.cat(outputs, dim=0)
        averaged = scores.mean(dim=0)
        return averaged.argmax().item()

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the forward results unchanged."""
        return inputs
|