human_reconstruction_pipeline.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. from typing import Any, Dict
  5. import numpy as np
  6. import torch
  7. import trimesh
  8. from modelscope.metainfo import Pipelines
  9. from modelscope.models.cv.human_reconstruction.utils import (
  10. keep_largest, reconstruction, save_obj_mesh, save_obj_mesh_with_color,
  11. to_tensor)
  12. from modelscope.outputs import OutputKeys
  13. from modelscope.pipelines import pipeline
  14. from modelscope.pipelines.base import Input, Model, Pipeline
  15. from modelscope.pipelines.builder import PIPELINES
  16. from modelscope.utils.constant import ModelFile, Tasks
  17. from modelscope.utils.logger import get_logger
  18. logger = get_logger()
  @PIPELINES.register_module(
      Tasks.human_reconstruction, module_name=Pipelines.human_reconstruction)
  class HumanReconstructionPipeline(Pipeline):
      """Single-image human mesh reconstruction pipeline.

      ``preprocess`` asks the model to crop the input, build a mask and
      generate front/back normal maps. ``forward`` feeds the masked image
      and normals to the model's mesh network, filters the reconstructed
      mesh through ``keep_largest``, queries a color for every vertex and
      saves two OBJ files (``human_reconstruction.obj`` and
      ``human_color.obj``) using relative paths.
      """

      def __init__(self, model: str, **kwargs):
          """The inference pipeline for human reconstruction task.

          Human Reconstruction Pipeline. Given one image generate a human
          mesh.

          Args:
              model (`str` or `Model` or module instance): A model instance
                  or a model local dir or a model id in the model hub.
              **kwargs: Forwarded unchanged to ``Pipeline.__init__``.

          Raises:
              Exception: If ``self.model`` is not a ``Model`` instance after
                  the base class has been initialised.

          Example:
              >>> from modelscope.outputs import OutputKeys
              >>> from modelscope.pipelines import pipeline
              >>> test_input = 'human_reconstruction.jpg' # input image path
              >>> pipeline_humanRecon = pipeline('human-reconstruction',
                      model='damo/cv_hrnet_image-human-reconstruction')
              >>> result = pipeline_humanRecon(test_input)
              >>> output = result[OutputKeys.OUTPUT]
          """
          super().__init__(model=model, **kwargs)
          if not isinstance(self.model, Model):
              logger.error('model object is not initialized.')
              raise Exception('model object is not initialized.')

      def preprocess(self, input: Input) -> Dict[str, Any]:
          """Build the tensors consumed by ``forward`` from a raw input.

          Args:
              input: Pipeline input, passed as-is to ``self.model.crop_img``.

          Returns:
              Dict with keys ``'img'``, ``'mask'``, ``'normal_F'`` and
              ``'normal_B'``. Image and normals are rescaled with
              ``x * 2 - 1`` after ``to_tensor``; the mask is only converted.
          """
          img_crop = self.model.crop_img(input)
          img, mask = self.model.get_mask(img_crop)
          normal_f, normal_b = self.model.generation_normal(img, mask)
          # NOTE(review): the tensor is built from `img_crop`, not from the
          # `img` returned by get_mask -- confirm this is intended.
          image = to_tensor(img_crop) * 2 - 1
          normal_b = to_tensor(normal_b) * 2 - 1
          normal_f = to_tensor(normal_f) * 2 - 1
          mask = to_tensor(mask)
          result = {
              'img': image,
              'mask': mask,
              'normal_F': normal_f,
              'normal_B': normal_b
          }
          return result

      def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
          """Reconstruct a colored mesh from the preprocessed tensors.

          Args:
              input: Dict produced by ``preprocess`` (keys ``'img'``,
                  ``'mask'``, ``'normal_F'``, ``'normal_B'``). The tensors
                  are not modified.

          Returns:
              ``{OutputKeys.OUTPUT: {'vertices', 'faces', 'colors'}}`` where
              ``colors`` has the same shape as ``vertices``.

          Side effects:
              Calls ``save_obj_mesh('human_reconstruction.obj', ...)`` and
              ``save_obj_mesh_with_color('human_color.obj', ...)``.
          """
          image = input['img']
          mask = input['mask']
          # Clone before negating a channel so the caller's tensors are left
          # untouched; flipping in place would undo itself if the same dict
          # were passed to forward() again.
          normF = input['normal_F'].clone()
          normB = input['normal_B'].clone()
          normF[1, ...] = -normF[1, ...]
          normB[0, ...] = -normB[0, ...]

          img = image * mask
          normal_b = normB * mask
          normal_f = normF * mask
          # Stack image, front normal and back normal along dim 0, then add
          # a leading batch dimension.
          img = torch.cat([img, normal_f, normal_b], dim=0).float()
          image_tensor = img.unsqueeze(0).to(self.model.device)
          calib_tensor = self.model.calib
          net = self.model.meshmodel
          net.extract_features(image_tensor)
          verts, faces = reconstruction(net, calib_tensor, self.model.coords,
                                        self.model.mat)
          pre_mesh = trimesh.Trimesh(
              verts, faces, process=False, maintain_order=True)
          final_mesh = keep_largest(pre_mesh)
          verts = final_mesh.vertices
          faces = final_mesh.faces

          # Vertices are transposed so the last axis indexes vertices.
          verts_tensor = torch.from_numpy(verts.T).unsqueeze(0).to(
              self.model.device).float()
          color = torch.zeros(verts.shape)
          # Query colors in fixed-size chunks. Stepping over the real vertex
          # count covers meshes smaller than one chunk as well as the final
          # vertex, which a `[left:-1]` slice would leave uncolored.
          interval = 20000
          num_verts = verts.shape[0]
          for left in range(0, num_verts, interval):
              right = min(left + interval, num_verts)
              pred_color = net.query_rgb(verts_tensor[:, :, left:right],
                                         calib_tensor)
              # Presumably maps predictions from [-1, 1] to [0, 1].
              rgb = pred_color[0].detach().cpu() * 0.5 + 0.5
              color[left:right] = rgb.T

          # Shift the mesh so its lowest point along axis 1 sits at zero.
          vert_min = np.min(verts[:, 1])
          verts[:, 1] = verts[:, 1] - vert_min
          save_obj_mesh('human_reconstruction.obj', verts, faces)
          save_obj_mesh_with_color('human_color.obj', verts, faces,
                                   color.numpy())
          results = {'vertices': verts, 'faces': faces, 'colors': color.numpy()}
          return {OutputKeys.OUTPUT: results}

      def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Return the output of ``forward`` unchanged."""
          return inputs