motion_generation_pipeline.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. import tempfile
  4. from typing import Any, Dict
  5. import numpy as np
  6. import torch
  7. from modelscope.metainfo import Pipelines
  8. from modelscope.models.cv.motion_generation import (ClassifierFreeSampleModel,
  9. create_model,
  10. load_model_wo_clip)
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines.base import Input, Pipeline
  13. from modelscope.pipelines.builder import PIPELINES
  14. from modelscope.utils.config import Config
  15. from modelscope.utils.constant import ModelFile, Tasks
  16. from modelscope.utils.cv.motion_utils.motion_process import recover_from_ric
  17. from modelscope.utils.cv.motion_utils.plot_script import plot_3d_motion
  18. from modelscope.utils.logger import get_logger
  19. logger = get_logger()
  20. @PIPELINES.register_module(
  21. Tasks.motion_generation, module_name=Pipelines.motion_generattion)
  22. class MDMMotionGeneration(Pipeline):
  23. def __init__(self, model: str, **kwargs):
  24. """
  25. use `model` to create motion generation pipeline for prediction
  26. Args:
  27. model: model id on modelscope hub.
  28. """
  29. super().__init__(model=model, **kwargs)
  30. model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
  31. logger.info(f'loading model from {model_path}')
  32. config_path = osp.join(self.model, ModelFile.CONFIGURATION)
  33. logger.info(f'loading config from {config_path}')
  34. self.mean = np.load(osp.join(self.model, 'Mean.npy'))
  35. self.std = np.load(osp.join(self.model, 'Std.npy'))
  36. self.cfg = Config.from_file(config_path)
  37. self.cfg.update({'smpl_data_path': osp.join(self.model, 'smpl')})
  38. self.cfg.update(kwargs)
  39. self.n_joints = 22
  40. self.fps = 20
  41. self.n_frames = 120
  42. self.mdm, self.diffusion = create_model(self.cfg)
  43. state_dict = torch.load(
  44. model_path, map_location='cpu', weights_only=True)
  45. load_model_wo_clip(self.mdm, state_dict)
  46. self.mdm = ClassifierFreeSampleModel(self.mdm)
  47. self.mdm.to(self.device)
  48. self.mdm.eval()
  49. logger.info('load model done')
  50. def preprocess(self, input: Input) -> Dict[str, Any]:
  51. if isinstance(input, str):
  52. input_text = input
  53. else:
  54. raise TypeError(f'input should be a str,'
  55. f' but got {type(input)}')
  56. result = {'input_text': input_text}
  57. return result
  58. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  59. texts = [input['input_text']]
  60. model_kwargs = {
  61. 'y': {
  62. 'mask': torch.ones(1, 1, 1, self.n_frames) > 0,
  63. 'lengths': torch.tensor([self.n_frames]),
  64. 'tokens': None,
  65. 'text': texts,
  66. 'scale': torch.ones(1, device=self.device) * 2.5
  67. }
  68. }
  69. sample_fn = self.diffusion.p_sample_loop
  70. sample = sample_fn(
  71. self.mdm,
  72. (1, self.mdm.njoints, self.mdm.nfeats, self.n_frames),
  73. clip_denoised=False,
  74. model_kwargs=model_kwargs,
  75. skip_timesteps=0,
  76. init_image=None,
  77. progress=True,
  78. dump_steps=None,
  79. noise=None,
  80. const_noise=False,
  81. )
  82. sample = (sample.cpu().permute(0, 2, 3, 1) * self.std
  83. + self.mean).float()
  84. sample = recover_from_ric(sample, self.n_joints)
  85. sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
  86. sample = self.mdm.rot2xyz(
  87. x=sample,
  88. mask=None,
  89. pose_rep='xyz',
  90. glob=True,
  91. translation=True,
  92. jointstype='smpl',
  93. vertstrans=True,
  94. betas=None,
  95. beta=0,
  96. glob_rot=None,
  97. get_rotations_back=False)
  98. motion = sample.cpu().numpy()
  99. motion = motion[0].transpose(2, 0, 1)
  100. out = {OutputKeys.KEYPOINTS: motion, 'text': input['input_text']}
  101. return out
  102. def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  103. output_video_path = kwargs.get(
  104. 'output_video',
  105. tempfile.NamedTemporaryFile(suffix='.mp4').name)
  106. kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10],
  107. [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21],
  108. [9, 13, 16, 18, 20]]
  109. if output_video_path is not None:
  110. plot_3d_motion(
  111. output_video_path,
  112. kinematic_chain,
  113. inputs[OutputKeys.KEYPOINTS],
  114. inputs.pop('text'),
  115. dataset='humanml',
  116. fps=20)
  117. inputs.update({OutputKeys.OUTPUT_VIDEO: output_video_path})
  118. return inputs