human3d_render_pipeline.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import io
  3. import os
  4. from typing import Any, Dict
  5. import cv2
  6. import numpy as np
  7. import nvdiffrast.torch as dr
  8. import torch
  9. import tqdm
  10. from modelscope.metainfo import Pipelines
  11. from modelscope.models.cv.face_reconstruction.utils import mesh_to_string
  12. from modelscope.models.cv.human3d_animation import (projection, read_obj,
  13. render, rotate_x, rotate_y,
  14. translate)
  15. from modelscope.msdatasets import MsDataset
  16. from modelscope.outputs import OutputKeys
  17. from modelscope.pipelines.base import Model, Pipeline
  18. from modelscope.pipelines.builder import PIPELINES
  19. from modelscope.pipelines.util import is_model
  20. from modelscope.utils.constant import Invoke, Tasks
  21. from modelscope.utils.logger import get_logger
  22. logger = get_logger()
  23. @PIPELINES.register_module(
  24. Tasks.human3d_render, module_name=Pipelines.human3d_render)
  25. class Human3DRenderPipeline(Pipeline):
  26. """ Human3D library render pipeline
  27. Example:
  28. ```python
  29. >>> from modelscope.pipelines import pipeline
  30. >>> human3d = pipeline(Tasks.human3d_render,
  31. 'damo/cv_3d-human-synthesis-library')
  32. >>> human3d({
  33. 'data_dir': '/data/human3d-syn-library', # data dir path (str)
  34. 'case_id': '3f2a7538253e42a8', # case id (str)
  35. })
  36. >>> #
  37. ```
  38. """
  39. def __init__(self, model: str, device='gpu', **kwargs):
  40. """
  41. use model to create a image sky change pipeline for image editing
  42. Args:
  43. model (str or Model): model_id on modelscope hub
  44. device (str): only support gpu
  45. """
  46. super().__init__(model=model, **kwargs)
  47. self.model_dir = model
  48. def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  49. return inputs
  50. def load_3d_model(self, mesh_path):
  51. mesh = read_obj(mesh_path)
  52. tex_path = mesh_path.replace('.obj', '.png')
  53. if not os.path.exists(tex_path):
  54. tex = np.zeros((256, 256, 3), dtype=np.uint8)
  55. else:
  56. tex = cv2.imread(tex_path)
  57. mesh['texture_map'] = tex.copy()
  58. return mesh, tex
  59. def format_nvdiffrast_format(self, mesh, tex):
  60. vert = mesh['vertices']
  61. cent = (vert.max(axis=0) + vert.min(axis=0)) / 2
  62. vert -= cent
  63. tri = mesh['faces']
  64. tri = tri - 1 if tri.min() == 1 else tri
  65. vert_uv = mesh['uvs']
  66. tri_uv = mesh['faces_uv']
  67. tri_uv = tri_uv - 1 if tri_uv.min() == 1 else tri_uv
  68. vtx_pos = torch.from_numpy(vert.astype(np.float32)).cuda()
  69. pos_idx = torch.from_numpy(tri.astype(np.int32)).cuda()
  70. vtx_uv = torch.from_numpy(vert_uv.astype(np.float32)).cuda()
  71. uv_idx = torch.from_numpy(tri_uv.astype(np.int32)).cuda()
  72. tex = tex[::-1, :, ::-1]
  73. tex = torch.from_numpy(tex.astype(np.float32) / 255.0).cuda()
  74. return vtx_pos, pos_idx, vtx_uv, uv_idx, tex
  75. def render_scene(self, mesh_path, resolution=512):
  76. if not os.path.exists(mesh_path):
  77. logger.info('can not found %s, use default one' % mesh_path)
  78. mesh_path = os.path.join(self.model_dir, '3D-assets',
  79. '3f2a7538253e42a8', 'body.obj')
  80. mesh, texture = self.load_3d_model(mesh_path)
  81. vtx_pos, pos_idx, vtx_uv, uv_idx, tex = self.format_nvdiffrast_format(
  82. mesh, texture)
  83. glctx = dr.RasterizeCudaContext()
  84. ang = 0.0
  85. frame_length = 80
  86. step = 2 * np.pi / frame_length
  87. frames_color = []
  88. frames_normals = []
  89. for i in tqdm.tqdm(range(frame_length)):
  90. proj = projection(x=0.4, n=1.0, f=200.0)
  91. a_rot = np.matmul(rotate_x(0.0), rotate_y(ang))
  92. a_mv = np.matmul(translate(0, 0, -2.7), a_rot)
  93. r_mvp = np.matmul(proj, a_mv).astype(np.float32)
  94. pred_img, pred_mask, normal = render(
  95. glctx,
  96. r_mvp,
  97. vtx_pos,
  98. pos_idx,
  99. vtx_uv,
  100. uv_idx,
  101. tex,
  102. resolution=resolution,
  103. enable_mip=False,
  104. max_mip_level=9)
  105. color = np.clip(
  106. np.rint(pred_img[0].detach().cpu().numpy() * 255.0), 0,
  107. 255).astype(np.uint8)[::-1, :, :]
  108. normals = np.clip(
  109. np.rint(normal[0].detach().cpu().numpy() * 255.0), 0,
  110. 255).astype(np.uint8)[::-1, :, :]
  111. frames_color.append(color)
  112. frames_normals.append(normals)
  113. ang = ang + step
  114. logger.info('render case %s done'
  115. % os.path.basename(os.path.dirname(mesh_path)))
  116. return mesh, frames_color, frames_normals
  117. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  118. dataset_id = input['dataset_id']
  119. case_id = input['case_id']
  120. if 'resolution' in input:
  121. resolution = input['resolution']
  122. else:
  123. resolution = 512
  124. if case_id.endswith('.obj'):
  125. mesh_path = case_id
  126. else:
  127. dataset_name = dataset_id.split('/')[-1]
  128. user_name = dataset_id.split('/')[0]
  129. data_dir = MsDataset.load(
  130. dataset_name, namespace=user_name,
  131. subset_name=case_id).config_kwargs['split_config']['test']
  132. case_dir = os.path.join(data_dir, case_id)
  133. mesh_path = os.path.join(case_dir, 'body.obj')
  134. mesh, colors, normals = self.render_scene(mesh_path, resolution)
  135. results = {
  136. 'mesh': mesh,
  137. 'frames_color': colors,
  138. 'frames_normal': normals,
  139. }
  140. return {OutputKeys.OUTPUT_OBJ: None, OutputKeys.OUTPUT: results}
  141. def postprocess(self, inputs, **kwargs) -> Dict[str, Any]:
  142. render = kwargs.get('render', False)
  143. output_obj = inputs[OutputKeys.OUTPUT_OBJ]
  144. results = inputs[OutputKeys.OUTPUT]
  145. if render:
  146. output_obj = io.BytesIO()
  147. mesh_str = mesh_to_string(results['mesh'])
  148. mesh_bytes = mesh_str.encode(encoding='utf-8')
  149. output_obj.write(mesh_bytes)
  150. result = {
  151. OutputKeys.OUTPUT_OBJ: output_obj,
  152. OutputKeys.OUTPUT: None if render else results,
  153. }
  154. return result