| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os.path as osp
- from typing import Optional
- import cv2
- import numpy as np
- import PIL.Image as Image
- import torch
- import torchvision.transforms as transforms
- from skimage.io import imread
- from skimage.transform import estimate_transform, warp
- from modelscope.metainfo import Models
- from modelscope.models.base import Tensor, TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.models.cv.human_reconstruction.models.detectors import \
- FasterRCNN
- from modelscope.models.cv.human_reconstruction.models.human_segmenter import \
- human_segmenter
- from modelscope.models.cv.human_reconstruction.models.networks import define_G
- from modelscope.models.cv.human_reconstruction.models.PixToMesh import \
- Pixto3DNet
- from modelscope.models.cv.human_reconstruction.utils import create_grid
- from modelscope.utils.constant import ModelFile, Tasks
- from modelscope.utils.logger import get_logger
- logger = get_logger()
@MODELS.register_module(
    Tasks.human_reconstruction, module_name=Models.human_reconstruction)
class HumanReconstruction(TorchModel):
    """Bundle of the networks used for single-image human reconstruction.

    The instance holds a person detector, a human segmenter, two normal-map
    generators (front and back) and the pixel-to-mesh network, and exposes
    the pre-processing steps built on them: ``crop_img``, ``get_mask`` and
    ``generation_normal``.
    """

    def __init__(self, model_dir, modelconfig, *args, **kwargs):
        """The HumanReconstruction is modified based on PiFuHD and pix2pixhd,
        publicly available at
            https://shunsukesaito.github.io/PIFuHD/ &
            https://github.com/NVIDIA/pix2pixHD

        Args:
            model_dir: the root directory of the model files
            modelconfig: mapping with the model configuration; the entry
                ``'model'`` is expanded as keyword arguments of
                ``Pixto3DNet`` and ``'resolution'`` is passed to
                ``create_grid``
        """
        # Pass model_dir positionally so that extra positional arguments do
        # not collide with a keyword of the same name.
        super().__init__(model_dir, *args, **kwargs)

        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        logger.info('Use %s: %s', 'GPU' if use_cuda else 'CPU', self.device)

        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        normal_back_model = osp.join(model_dir, 'Norm_B_GAN.pth')
        normal_front_model = osp.join(model_dir, 'Norm_F_GAN.pth')
        human_seg_model = osp.join(model_dir, ModelFile.TF_GRAPH_FILE)
        fastrcnn_ckpt = osp.join(model_dir, 'fasterrcnn_resnet50.pth')

        self.detector = FasterRCNN(ckpt=fastrcnn_ckpt, device=self.device)

        # All checkpoints are deserialized on CPU first and moved to the
        # target device afterwards, so loading also works on hosts without
        # CUDA when the weights were saved from a GPU.
        self.meshmodel = Pixto3DNet(**modelconfig['model'])
        self.meshmodel.load_state_dict(
            torch.load(model_path, map_location='cpu'))
        self.meshmodel = self.meshmodel.to(self.device).eval()

        # Front and back normal generators share the same architecture.
        self.netB = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
        self.netF = define_G(3, 3, 64, 'global', 4, 9, 1, 3, 'instance')
        self.netF.load_state_dict(
            torch.load(normal_front_model, map_location='cpu'))
        self.netB.load_state_dict(
            torch.load(normal_back_model, map_location='cpu'))
        self.netF = self.netF.to(self.device)
        self.netB = self.netB.to(self.device)
        self.netF.eval()
        self.netB.eval()

        self.portrait_matting = human_segmenter(model_path=human_seg_model)

        # Sampling grid over the cube [-1, 1]^3.
        b_min = np.array([-1, -1, -1])
        b_max = np.array([1, 1, 1])
        self.coords, self.mat = create_grid(modelconfig['resolution'], b_min,
                                            b_max)

        # Identity projection with the second diagonal entry negated
        # (presumably flips the y axis -- confirm against the consumer of
        # ``calib``); stored as a (1, 3, 4) float tensor on the device.
        projection_matrix = np.identity(4)
        projection_matrix[1, 1] = -1
        calib = torch.tensor(
            projection_matrix, dtype=torch.float32, device=self.device)
        self.calib = calib[:3, :4].unsqueeze(0)
        logger.info('model load over')

    def get_mask(self, img):
        """Run the human segmenter and expand its output to three channels.

        Args:
            img: image handed unchanged to ``self.portrait_matting.run``

        Returns:
            tuple ``(img, mask)``: the untouched input and the segmenter
            output repeated three times along a new axis 2 (assumes the
            segmenter returns a 2-D array -- confirm against
            ``human_segmenter``).
        """
        result = self.portrait_matting.run(img)
        mask = result[..., None].repeat(3, axis=2)
        return img, mask

    @torch.no_grad()
    def crop_img(self, img_url):
        """Detect the person in an image and crop a 512x512 patch around it.

        Args:
            img_url: location of the image, forwarded to ``skimage.io.imread``

        Returns:
            ``numpy.ndarray`` of shape (512, 512, 3) and dtype ``uint8`` with
            the channel order reversed relative to the loaded image
            (presumably BGR for OpenCV-style consumers).
        """
        image_size = 512

        image = imread(img_url)
        if image.ndim == 2:
            # Single-channel input: replicate it so the slicing below and
            # the detector both see three channels.
            image = np.stack([image] * 3, axis=-1)
        # Drop a possible alpha channel and scale to [0, 1] (assumes 8-bit
        # pixel data).
        image = image[:, :, :3] / 255.

        image_tensor = torch.tensor(
            image.transpose(2, 0, 1), dtype=torch.float32)[None, ...]
        bbox = self.detector.run(image_tensor)
        left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]

        # Square crop centred on the box, enlarged by 10 %.
        old_size = max(right - left, bottom - top)
        center = np.array(
            [right - (right - left) / 2.0, bottom - (bottom - top) / 2.0])
        half = int(old_size * 1.1) / 2

        # Three corners of the source square and where they land in the
        # output patch.
        src_pts = np.array([[center[0] - half, center[1] - half],
                            [center[0] - half, center[1] + half],
                            [center[0] + half, center[1] - half]])
        dst_pts = np.array([[0, 0], [0, image_size - 1],
                            [image_size - 1, 0]])
        tform = estimate_transform('similarity', src_pts, dst_pts)
        dst_image = warp(
            image, tform.inverse, output_shape=(image_size, image_size))

        # Reverse the channel order and convert back to 8-bit.
        return (dst_image[:, :, ::-1] * 255).astype(np.uint8)

    @staticmethod
    def _normal_to_image(normal, mask):
        """Mask a generator output and convert it to an HxWx3 uint8 image.

        The first batch element is taken, moved to channel-last layout,
        mapped from [-1, 1] to [0, 255] and its channel order is reversed.
        """
        masked = (normal.cpu() * mask).detach().numpy()[0]
        image = (np.transpose(masked, (1, 2, 0)) * 0.5 + 0.5)[:, :, ::-1]
        return (image * 255.0).astype(np.uint8)

    @torch.no_grad()
    def generation_normal(self, img, mask):
        """Predict the front and back normal maps of a person image.

        Args:
            img: image array; it is resized to 512x512 with ``cv2.resize``
                and converted through ``PIL.Image.fromarray`` (assumes an
                HxWx3 uint8 array -- confirm against the caller)
            mask: mask array, e.g. the one returned by ``get_mask``; it is
                resized to 512x512 and multiplied onto both predictions

        Returns:
            tuple ``(nml_f, nml_b)`` of uint8 arrays of shape (512, 512, 3)
            holding the masked front and back normal maps.
        """
        to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        im_512 = cv2.resize(img, (512, 512))
        image_512 = Image.fromarray(im_512).convert('RGB')
        net_input = to_tensor(image_512).unsqueeze(0).to(self.device)

        nml_f = self.netF(net_input)
        nml_b = self.netB(net_input)

        # The mask stays on CPU; predictions are moved there before masking.
        mask = cv2.resize(mask, (512, 512))
        mask = transforms.ToTensor()(mask).unsqueeze(0)

        return (self._normal_to_image(nml_f, mask),
                self._normal_to_image(nml_b, mask))
- # def forward(self, img, mask, normal_f, normal_b):
|