human3d_animation_pipeline.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict
  4. import cv2
  5. from modelscope.metainfo import Pipelines
  6. from modelscope.models.cv.human3d_animation import (gen_skeleton_bvh, read_obj,
  7. write_obj)
  8. from modelscope.msdatasets import MsDataset
  9. from modelscope.outputs import OutputKeys
  10. from modelscope.pipelines.base import Pipeline
  11. from modelscope.pipelines.builder import PIPELINES
  12. from modelscope.utils.constant import Tasks
  13. from modelscope.utils.logger import get_logger
# Module-level logger (from modelscope.utils.logger.get_logger) shared by the
# pipeline class in this module.
logger = get_logger()
  15. @PIPELINES.register_module(
  16. Tasks.human3d_animation, module_name=Pipelines.human3d_animation)
  17. class Human3DAnimationPipeline(Pipeline):
  18. """ Human3D library render pipeline
  19. Example:
  20. ```python
  21. >>> from modelscope.pipelines import pipeline
  22. >>> human3d = pipeline(Tasks.human3d_animation,
  23. 'damo/cv_3d-human-animation')
  24. >>> human3d({
  25. 'dataset_id': 'damo/3DHuman_synthetic_dataset', # dataset id (str)
  26. 'case_id': '3f2a7538253e42a8', # case id (str)
  27. 'action_dataset': 'damo/3DHuman_action_dataset', # action data id
  28. 'action': 'ArmsHipHopDance' # action name or action file path (str)
  29. 'save_dir': 'output' # save directory (str)
  30. })
  31. >>> #
  32. ```
  33. """
  34. def __init__(self, model, device='gpu', **kwargs):
  35. """
  36. use model to create a image sky change pipeline for image editing
  37. Args:
  38. model (str or Model): model_id on modelscope hub
  39. device (str): only support gpu
  40. """
  41. super().__init__(model=model, **kwargs)
  42. self.model_dir = model
  43. logger.info('model_dir:', self.model_dir)
  44. def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  45. return inputs
  46. def gen_skeleton(self, case_dir, action_dir, action):
  47. self.case_dir = case_dir
  48. self.action_dir = action_dir
  49. self.action = action
  50. status = gen_skeleton_bvh(self.model_dir, self.action_dir,
  51. self.case_dir, self.action)
  52. return status
  53. def gen_weights(self, save_dir=None):
  54. case_name = os.path.basename(self.case_dir)
  55. action_name = os.path.basename(self.action).replace('.npy', '')
  56. if save_dir is None:
  57. gltf_path = os.path.join(self.case_dir, '%s-%s.glb' %
  58. (case_name, action_name))
  59. else:
  60. os.makedirs(save_dir, exist_ok=True)
  61. gltf_path = os.path.join(save_dir, '%s-%s.glb' %
  62. (case_name, action_name))
  63. exec_path = os.path.join(self.model_dir, 'skinning.py')
  64. cmd = f'{self.blender} -b -P {exec_path} -- --input {self.case_dir}' \
  65. f' --gltf_path {gltf_path} --action {self.action}'
  66. os.system(cmd)
  67. return gltf_path
  68. def animate(self, mesh_path, action_dir, action, save_dir=None):
  69. case_dir = os.path.dirname(os.path.abspath(mesh_path))
  70. tex_path = mesh_path.replace('.obj', '.png')
  71. mesh = read_obj(mesh_path)
  72. tex = cv2.imread(tex_path)
  73. vertices = mesh['vertices']
  74. mesh['vertices'] = vertices
  75. mesh['texture_map'] = tex
  76. write_obj(mesh_path, mesh)
  77. self.gen_skeleton(case_dir, action_dir, action)
  78. gltf_path = self.gen_weights(save_dir)
  79. if os.path.exists(gltf_path):
  80. logger.info('save animation succeed!')
  81. else:
  82. logger.info('save animation failed!')
  83. return gltf_path
  84. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  85. dataset_id = input['dataset_id']
  86. case_id = input['case_id']
  87. action_data_id = input['action_dataset']
  88. action = input['action']
  89. if 'save_dir' in input:
  90. save_dir = input['save_dir']
  91. else:
  92. save_dir = None
  93. if 'blender' in input:
  94. self.blender = input['blender']
  95. else:
  96. self.blender = 'blender'
  97. if case_id.endswith('.obj'):
  98. mesh_path = case_id
  99. else:
  100. dataset_name = dataset_id.split('/')[-1]
  101. user_name = dataset_id.split('/')[0]
  102. data_dir = MsDataset.load(
  103. dataset_name, namespace=user_name,
  104. subset_name=case_id).config_kwargs['split_config']['test']
  105. case_dir = os.path.join(data_dir, case_id)
  106. mesh_path = os.path.join(case_dir, 'body.obj')
  107. logger.info('load mesh:', mesh_path)
  108. dataset_name = action_data_id.split('/')[-1]
  109. user_name = action_data_id.split('/')[0]
  110. action_dir = MsDataset.load(
  111. dataset_name, namespace=user_name,
  112. split='test').config_kwargs['split_config']['test']
  113. action_dir = os.path.join(action_dir, 'actions_a')
  114. output = self.animate(mesh_path, action_dir, action, save_dir)
  115. return {OutputKeys.OUTPUT: output}
  116. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  117. return inputs