Reconstruction.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import Optional
  4. import cv2
  5. import numpy as np
  6. import PIL.Image as Image
  7. import torch
  8. import torchvision.transforms as transforms
  9. from skimage.io import imread
  10. from skimage.transform import estimate_transform, warp
  11. from modelscope.metainfo import Models
  12. from modelscope.models.base import Tensor, TorchModel
  13. from modelscope.models.builder import MODELS
  14. from modelscope.models.cv.human_reconstruction.models.detectors import \
  15. FasterRCNN
  16. from modelscope.models.cv.human_reconstruction.models.human_segmenter import \
  17. human_segmenter
  18. from modelscope.models.cv.human_reconstruction.models.networks import define_G
  19. from modelscope.models.cv.human_reconstruction.models.PixToMesh import \
  20. Pixto3DNet
  21. from modelscope.models.cv.human_reconstruction.utils import create_grid
  22. from modelscope.utils.constant import ModelFile, Tasks
  23. from modelscope.utils.logger import get_logger
  # Module-level logger; HumanReconstruction.__init__ reports the selected
  # device and model-loading completion through it.
  logger = get_logger()
  25. @MODELS.register_module(
  26. Tasks.human_reconstruction, module_name=Models.human_reconstruction)
  27. class HumanReconstruction(TorchModel):
  def __init__(self, model_dir, modelconfig, *args, **kwargs):
      """The HumanReconstruction is modified based on PIFuHD and pix2pixHD, publicly available at
      https://shunsukesaito.github.io/PIFuHD/ &
      https://github.com/NVIDIA/pix2pixHD

      Loads every sub-model from ``model_dir``, moves the networks to the
      selected device (CUDA when available, otherwise CPU) and switches them
      to eval mode.

      Args:
          model_dir: the root directory of the model files. The following
              files are read from it: ``ModelFile.TORCH_MODEL_FILE`` (mesh
              model weights), ``Norm_F_GAN.pth`` / ``Norm_B_GAN.pth`` (front /
              back normal networks), ``ModelFile.TF_GRAPH_FILE`` (human
              segmenter) and ``fasterrcnn_resnet50.pth`` (detector checkpoint).
          modelconfig: the model configuration mapping (not a path):
              ``modelconfig['model']`` is unpacked as keyword arguments of
              ``Pixto3DNet`` and ``modelconfig['resolution']`` is passed to
              ``create_grid``.
      """
      super().__init__(model_dir=model_dir, *args, **kwargs)
      if torch.cuda.is_available():
          self.device = torch.device('cuda')
          logger.info('Use GPU: {}'.format(self.device))
      else:
          self.device = torch.device('cpu')
          logger.info('Use CPU: {}'.format(self.device))

      model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
      normal_back_model = osp.join(model_dir, 'Norm_B_GAN.pth')
      normal_front_model = osp.join(model_dir, 'Norm_F_GAN.pth')
      human_seg_model = osp.join(model_dir, ModelFile.TF_GRAPH_FILE)
      fastrcnn_ckpt = osp.join(model_dir, 'fasterrcnn_resnet50.pth')

      self.meshmodel = Pixto3DNet(**modelconfig['model'])
      self.detector = FasterRCNN(ckpt=fastrcnn_ckpt, device=self.device)
      self.meshmodel.load_state_dict(
          torch.load(model_path, map_location='cpu'))

      # Front (netF) and back (netB) normal networks.
      self.netB = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
      self.netF = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
      # Load the weights onto the CPU first (like the mesh model above) so
      # checkpoints saved from CUDA tensors also load on CPU-only hosts;
      # the networks are moved to self.device right afterwards.
      self.netF.load_state_dict(
          torch.load(normal_front_model, map_location='cpu'))
      self.netB.load_state_dict(
          torch.load(normal_back_model, map_location='cpu'))
      self.netF = self.netF.to(self.device)
      self.netB = self.netB.to(self.device)
      self.netF.eval()
      self.netB.eval()
      self.meshmodel = self.meshmodel.to(self.device).eval()

      self.portrait_matting = human_segmenter(model_path=human_seg_model)

      # Bounding box [-1, 1]^3 handed to create_grid together with the
      # configured resolution.
      b_min = np.array([-1, -1, -1])
      b_max = np.array([1, 1, 1])
      self.coords, self.mat = create_grid(modelconfig['resolution'], b_min,
                                          b_max)

      # Identity projection with the y axis flipped, cut down to a 3x4
      # matrix with a leading batch dimension (final shape 1 x 3 x 4).
      projection_matrix = np.identity(4)
      projection_matrix[1, 1] = -1
      self.calib = torch.Tensor(projection_matrix).float().to(self.device)
      self.calib = self.calib[:3, :4].unsqueeze(0)
      logger.info('model load over')
  def get_mask(self, img):
      """Segment the person in ``img`` with the portrait matting model.

      Returns:
          tuple: ``(img, mask)`` — the input image, unchanged, and the
          matting result with a trailing axis appended and repeated 3 times
          along axis 2 (an H x W result becomes an H x W x 3 mask; assumes
          ``run`` returns a 2-D array — TODO confirm).
      """
      matte = self.portrait_matting.run(img)
      mask = np.repeat(matte[..., np.newaxis], 3, axis=2)
      return img, mask
  @torch.no_grad()
  def crop_img(self, img_url):
      """Detect the person in an image and return a square crop around it.

      The longest side of the detected box is scaled by 1.1 (truncated to
      an int), a square of that size is centred on the box, and the region
      is warped with a similarity transform to a 512 x 512 image.

      Args:
          img_url: image location accepted by ``skimage.io.imread``.

      Returns:
          np.ndarray: 512 x 512 x 3 uint8 image whose channel order is the
          reverse of the loaded image (e.g. RGB in, BGR out).
      """
      image_size = 512
      image = imread(img_url)
      # Normalise the layout to H x W x C with at least 3 channels:
      # grayscale (H x W) and grayscale+alpha (H x W x 2) inputs would
      # otherwise crash on the [:, :, :3] slice or hand the detector a
      # tensor that does not have 3 channels.
      if image.ndim == 2:
          image = image[:, :, None]
      if image.shape[2] < 3:
          image = np.repeat(image[:, :, :1], 3, axis=2)
      # Keep the first three channels (drops alpha). NOTE(review): the
      # division by 255 assumes 8-bit input — confirm for 16-bit images.
      image = image[:, :, :3] / 255.
      image_tensor = torch.tensor(
          image.transpose(2, 0, 1), dtype=torch.float32)[None, ...]

      # NOTE(review): bbox is indexed as (left, top, right, bottom); the
      # case where the detector finds no person is not handled here.
      bbox = self.detector.run(image_tensor)
      left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]

      old_size = max(right - left, bottom - top)
      center = np.array(
          [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
      size = int(old_size * 1.1)
      half = size / 2
      # Top-left, bottom-left and top-right corners of the square crop and
      # their destinations in the output image.
      src_pts = np.array([[center[0] - half, center[1] - half],
                          [center[0] - half, center[1] + half],
                          [center[0] + half, center[1] - half]])
      dst_pts = np.array([[0, 0], [0, image_size - 1], [image_size - 1, 0]])
      tform = estimate_transform('similarity', src_pts, dst_pts)
      dst_image = warp(
          image, tform.inverse, output_shape=(image_size, image_size))
      # Reverse the channel order and scale back to 0-255 uint8.
      return (dst_image[:, :, ::-1] * 255).astype(np.uint8)
  @torch.no_grad()
  def generation_normal(self, img, mask):
      """Predict masked front and back normal maps for an image.

      Args:
          img: image array accepted by ``cv2.resize`` and
              ``PIL.Image.fromarray`` (e.g. the uint8 output of
              ``crop_img``); it is resized to 512 x 512, converted to RGB and
              normalised with mean/std 0.5 before going through netF / netB.
          mask: foreground mask (e.g. from ``get_mask``); resized to
              512 x 512 and multiplied into both predictions.

      Returns:
          tuple(np.ndarray, np.ndarray): ``(nml_f, nml_b)``, each a
          512 x 512 x 3 uint8 image with the channel order reversed relative
          to the network output.
      """
      to_tensor = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
      ])
      im_512 = cv2.resize(img, (512, 512))
      image_512 = to_tensor(Image.fromarray(im_512).convert('RGB'))
      net_input = image_512.unsqueeze(0).to(self.device)
      nml_f = self.netF.forward(net_input)
      nml_b = self.netB.forward(net_input)

      mask = cv2.resize(mask, (512, 512))
      mask = transforms.ToTensor()(mask).unsqueeze(0)

      def _to_uint8_image(nml):
          # Apply the mask on the CPU, drop the batch axis, CHW -> HWC, map
          # with x * 0.5 + 0.5 (assumes output roughly in [-1, 1] — TODO
          # confirm), reverse the channel order and scale to 0-255.
          arr = (nml.cpu() * mask).detach().numpy()[0]
          arr = (np.transpose(arr, (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1] * 255.0
          # Clip before the cast: out-of-range floats are not saturated by
          # a float -> uint8 conversion (the result is undefined / wraps).
          return np.clip(arr, 0.0, 255.0).astype(np.uint8)

      return _to_uint8_image(nml_f), _to_uint8_image(nml_b)
  124. # def forward(self, img, mask, normal_f, normal_b):