| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os.path as osp
- import tempfile
- from typing import Any, Dict
- import numpy as np
- import torch
- from modelscope.metainfo import Pipelines
- from modelscope.models.cv.motion_generation import (ClassifierFreeSampleModel,
- create_model,
- load_model_wo_clip)
- from modelscope.outputs import OutputKeys
- from modelscope.pipelines.base import Input, Pipeline
- from modelscope.pipelines.builder import PIPELINES
- from modelscope.utils.config import Config
- from modelscope.utils.constant import ModelFile, Tasks
- from modelscope.utils.cv.motion_utils.motion_process import recover_from_ric
- from modelscope.utils.cv.motion_utils.plot_script import plot_3d_motion
- from modelscope.utils.logger import get_logger
# Module-level logger obtained via modelscope's get_logger(); used in this file to report model/config loading progress.
- logger = get_logger()
@PIPELINES.register_module(
    Tasks.motion_generation, module_name=Pipelines.motion_generattion)
class MDMMotionGeneration(Pipeline):
    """Text-driven motion generation pipeline.

    The pipeline accepts a single text prompt, samples one motion clip with
    the diffusion sampler returned by ``create_model`` and, in
    ``postprocess``, optionally renders the resulting keypoints to a video
    file with ``plot_3d_motion``.
    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create motion generation pipeline for prediction
        Args:
            model: model id on modelscope hub.
            **kwargs: forwarded to the base ``Pipeline`` constructor and also
                merged into the configuration loaded from the model
                directory (overriding entries with the same key).
        """
        super().__init__(model=model, **kwargs)
        # From here on ``self.model`` is used as the directory that holds
        # the weights, configuration and normalisation statistics.
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        logger.info(f'loading model from {model_path}')
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading config from {config_path}')

        # Statistics applied in ``forward`` as ``sample * std + mean``.
        self.mean = np.load(osp.join(self.model, 'Mean.npy'))
        self.std = np.load(osp.join(self.model, 'Std.npy'))

        self.cfg = Config.from_file(config_path)
        self.cfg.update({'smpl_data_path': osp.join(self.model, 'smpl')})
        self.cfg.update(kwargs)

        # Fixed generation settings: number of joints handed to
        # ``recover_from_ric``, frame rate used when rendering, and the
        # number of frames sampled per clip.
        self.n_joints = 22
        self.fps = 20
        self.n_frames = 120

        self.mdm, self.diffusion = create_model(self.cfg)
        # weights_only=True restricts unpickling to tensors/primitive types.
        state_dict = torch.load(
            model_path, map_location='cpu', weights_only=True)
        load_model_wo_clip(self.mdm, state_dict)
        self.mdm = ClassifierFreeSampleModel(self.mdm)
        self.mdm.to(self.device)
        self.mdm.eval()
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Validate the prompt and wrap it for ``forward``.

        Args:
            input: the text prompt; must be a ``str``.

        Returns:
            A dict with the single key ``'input_text'``.

        Raises:
            TypeError: if ``input`` is not a ``str``.
        """
        if not isinstance(input, str):
            raise TypeError(f'input should be a str,'
                            f' but got {type(input)}')
        return {'input_text': input}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Sample one motion clip conditioned on the prompt.

        Args:
            input: dict produced by ``preprocess`` (reads ``'input_text'``).

        Returns:
            A dict holding the generated keypoints as a numpy array under
            ``OutputKeys.KEYPOINTS`` and the prompt under ``'text'``.
        """
        prompt = input['input_text']
        # Conditioning for a batch of exactly one clip of ``n_frames``
        # frames with every frame marked valid in the mask.
        model_kwargs = {
            'y': {
                'mask': torch.ones(1, 1, 1, self.n_frames) > 0,
                'lengths': torch.tensor([self.n_frames]),
                'tokens': None,
                'text': [prompt],
                # NOTE(review): presumably the classifier-free guidance
                # scale consumed by ClassifierFreeSampleModel - confirm.
                'scale': torch.ones(1, device=self.device) * 2.5
            }
        }
        sample_shape = (1, self.mdm.njoints, self.mdm.nfeats, self.n_frames)
        sample = self.diffusion.p_sample_loop(
            self.mdm,
            sample_shape,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            skip_timesteps=0,
            init_image=None,
            progress=True,
            dump_steps=None,
            noise=None,
            const_noise=False,
        )
        # Move the joint axis last, then apply the stored statistics
        # (``* std + mean``) on the CPU copy.
        sample = (sample.cpu().permute(0, 2, 3, 1) * self.std
                  + self.mean).float()
        # NOTE(review): recover_from_ric presumably converts the feature
        # vectors into per-joint positions - confirm against motion_process.
        sample = recover_from_ric(sample, self.n_joints)
        # Merge the two leading axes and reorder the rest for ``rot2xyz``.
        sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
        sample = self.mdm.rot2xyz(
            x=sample,
            mask=None,
            pose_rep='xyz',
            glob=True,
            translation=True,
            jointstype='smpl',
            vertstrans=True,
            betas=None,
            beta=0,
            glob_rot=None,
            get_rotations_back=False)
        # Keep the single batch item and bring its last axis to the front
        # (presumably yielding frames-first keypoints - TODO confirm).
        motion = sample.cpu().numpy()[0].transpose(2, 0, 1)
        return {OutputKeys.KEYPOINTS: motion, 'text': prompt}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Optionally render the generated keypoints to a video.

        Args:
            inputs: dict returned by ``forward``.
            **kwargs: ``output_video`` may give the target video path. When
                the key is absent a temporary ``.mp4`` path is created. When
                it is explicitly ``None`` no video is rendered.

        Returns:
            ``inputs`` updated in place with ``OutputKeys.OUTPUT_VIDEO`` set
            to the video path (or ``None``). The ``'text'`` entry is removed
            only when a video is rendered.
        """
        if 'output_video' in kwargs:
            output_video_path = kwargs['output_video']
        else:
            # Create the temporary file only when it is actually needed and
            # keep it on disk (delete=False) so the reserved name cannot be
            # claimed by someone else before the renderer writes to it.
            with tempfile.NamedTemporaryFile(
                    suffix='.mp4', delete=False) as tmp_file:
                output_video_path = tmp_file.name

        # Joint-index chains passed to the plotting helper (presumably the
        # limbs/spine of the 22-joint skeleton - TODO confirm).
        kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10],
                           [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21],
                           [9, 13, 16, 18, 20]]
        if output_video_path is not None:
            plot_3d_motion(
                output_video_path,
                kinematic_chain,
                inputs[OutputKeys.KEYPOINTS],
                inputs.pop('text'),
                dataset='humanml',
                # Use the frame rate configured in __init__ instead of a
                # duplicated literal.
                fps=self.fps)
        inputs.update({OutputKeys.OUTPUT_VIDEO: output_video_path})
        return inputs
|