face_reconstruction_pipeline.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import io
  3. import os
  4. import shutil
  5. from typing import Any, Dict
  6. import cv2
  7. import face_alignment
  8. import numpy as np
  9. import PIL.Image
  10. import tensorflow as tf
  11. import torch
  12. from scipy.io import loadmat, savemat
  13. from modelscope.metainfo import Pipelines
  14. from modelscope.models.cv.face_reconstruction.models.facelandmark.large_base_lmks_infer import \
  15. LargeBaseLmkInfer
  16. from modelscope.models.cv.face_reconstruction.utils import (
  17. align_for_lm, align_img, draw_line, enlarged_bbox, image_warp_grid1,
  18. load_lm3d, mesh_to_string, read_obj, resize_on_long_side, spread_flow,
  19. write_obj)
  20. from modelscope.models.cv.skin_retouching.retinaface.predict_single import \
  21. Model
  22. from modelscope.outputs import OutputKeys
  23. from modelscope.pipelines import pipeline
  24. from modelscope.pipelines.base import Input, Pipeline
  25. from modelscope.pipelines.builder import PIPELINES
  26. from modelscope.preprocessors import LoadImage
  27. from modelscope.utils.constant import ModelFile, Tasks
  28. from modelscope.utils.device import create_device, device_placement
  29. from modelscope.utils.logger import get_logger
  30. if tf.__version__ >= '2.0':
  31. tf = tf.compat.v1
  32. tf.disable_eager_execution()
  33. logger = get_logger()
  34. @PIPELINES.register_module(
  35. Tasks.face_reconstruction, module_name=Pipelines.face_reconstruction)
  36. class FaceReconstructionPipeline(Pipeline):
  37. def __init__(self, model: str, device: str):
  38. """The inference pipeline for face reconstruction task.
  39. Args:
  40. model (`str` or `Model` or module instance): A model instance or a model local dir
  41. or a model id in the model hub.
  42. device ('str'): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X.
  43. Example:
  44. >>> from modelscope.pipelines import pipeline
  45. >>> test_image = 'data/test/images/face_reconstruction.jpg'
  46. >>> pipeline_faceRecon = pipeline('face-reconstruction',
  47. model='damo/cv_resnet50_face-reconstruction')
  48. >>> result = pipeline_faceRecon(test_image)
  49. >>> mesh = result[OutputKeys.OUTPUT]['mesh']
  50. >>> texture_map = result[OutputKeys.OUTPUT_IMG]
  51. >>> mesh['texture_map'] = texture_map
  52. >>> write_obj('hrn_mesh_mid.obj', mesh)
  53. """
  54. super().__init__(model=model, device=device)
  55. model_root = model
  56. bfm_folder = os.path.join(model_root, 'assets')
  57. checkpoint_path = os.path.join(model_root, ModelFile.TORCH_MODEL_FILE)
  58. if 'gpu' in device:
  59. self.device_name_ = 'cuda'
  60. else:
  61. self.device_name_ = device
  62. self.device_name_ = self.device_name_.lower()
  63. lmks_cpkt_path = os.path.join(model_root, 'large_base_net.pth')
  64. self.large_base_lmks_model = LargeBaseLmkInfer.model_preload(
  65. lmks_cpkt_path, self.device_name_ == 'cuda')
  66. self.detector = Model(max_size=512, device=self.device_name_)
  67. detector_ckpt_name = 'retinaface_resnet50_2020-07-20_old_torch.pth'
  68. state_dict = torch.load(
  69. os.path.join(os.path.dirname(lmks_cpkt_path), detector_ckpt_name),
  70. map_location='cpu',
  71. weights_only=True)
  72. self.detector.load_state_dict(state_dict)
  73. self.detector.eval()
  74. device = torch.device(self.device_name_)
  75. self.model.set_device(device)
  76. self.model.setup(checkpoint_path)
  77. self.model.parallelize()
  78. self.model.eval()
  79. self.model.set_render(image_res=512)
  80. save_ckpt_dir = os.path.join(
  81. os.path.expanduser('~'), '.cache/torch/hub/checkpoints')
  82. if not os.path.exists(save_ckpt_dir):
  83. os.makedirs(save_ckpt_dir)
  84. shutil.copy(
  85. os.path.join(model_root, 'face_alignment', 's3fd-619a316812.pth'),
  86. save_ckpt_dir)
  87. shutil.copy(
  88. os.path.join(model_root, 'face_alignment',
  89. '3DFAN4-4a694010b9.zip'), save_ckpt_dir)
  90. shutil.copy(
  91. os.path.join(model_root, 'face_alignment', 'depth-6c4283c0e0.zip'),
  92. save_ckpt_dir)
  93. self.lm_sess = face_alignment.FaceAlignment(
  94. face_alignment.LandmarksType.THREE_D, flip_input=False)
  95. config = tf.ConfigProto(allow_soft_placement=True)
  96. config.gpu_options.per_process_gpu_memory_fraction = 0.2
  97. config.gpu_options.allow_growth = True
  98. g1 = tf.Graph()
  99. self.face_sess = tf.Session(graph=g1, config=config)
  100. with self.face_sess.as_default():
  101. with g1.as_default():
  102. with tf.gfile.FastGFile(
  103. os.path.join(model_root, 'segment_face.pb'),
  104. 'rb') as f:
  105. graph_def = tf.GraphDef()
  106. graph_def.ParseFromString(f.read())
  107. self.face_sess.graph.as_default()
  108. tf.import_graph_def(graph_def, name='')
  109. self.face_sess.run(tf.global_variables_initializer())
  110. self.tex_size = 4096
  111. self.lm3d_std = load_lm3d(bfm_folder)
  112. self.align_params = loadmat(
  113. '{}/assets/BBRegressorParam_r.mat'.format(model_root))
  114. device = create_device(self.device_name)
  115. self.device = device
  116. def preprocess(self, input: Input) -> Dict[str, Any]:
  117. img = LoadImage.convert_to_ndarray(input)
  118. if len(img.shape) == 2:
  119. img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
  120. img = img.astype(float)
  121. result = {'img': img}
  122. return result
  123. def read_data(self, img, lm, lm3d_std, to_tensor=True, image_res=1024):
  124. # to RGB
  125. im = PIL.Image.fromarray(img[..., ::-1])
  126. W, H = im.size
  127. lm[:, -1] = H - 1 - lm[:, -1]
  128. _, im_lr, lm_lr, _ = align_img(im, lm, lm3d_std)
  129. _, im_hd, lm_hd, _ = align_img(
  130. im,
  131. lm,
  132. lm3d_std,
  133. target_size=image_res,
  134. rescale_factor=102. * image_res / 224)
  135. mask_lr = self.face_sess.run(
  136. self.face_sess.graph.get_tensor_by_name('output_alpha:0'),
  137. feed_dict={'input_image:0': np.array(im_lr)})
  138. if to_tensor:
  139. im_lr = torch.tensor(
  140. np.array(im_lr) / 255.,
  141. dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
  142. im_hd = torch.tensor(
  143. np.array(im_hd) / 255.,
  144. dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
  145. mask_lr = torch.tensor(
  146. np.array(mask_lr) / 255., dtype=torch.float32)[None,
  147. None, :, :]
  148. lm_lr = torch.tensor(lm_lr).unsqueeze(0)
  149. lm_hd = torch.tensor(lm_hd).unsqueeze(0)
  150. return im_lr, lm_lr, im_hd, lm_hd, mask_lr
  151. def parse_label(self, label):
  152. return torch.tensor(np.array(label).astype(np.float32))
  153. def prepare_data(self, img, lm_sess, five_points=None):
  154. input_img, scale, bbox = align_for_lm(
  155. img, five_points,
  156. self.align_params) # align for 68 landmark detection
  157. if scale == 0:
  158. return None
  159. # detect landmarks
  160. input_img = np.reshape(input_img, [1, 224, 224, 3]).astype(np.float32)
  161. input_img = input_img[0, :, :, ::-1]
  162. landmark = lm_sess.get_landmarks_from_image(input_img)[0]
  163. landmark = landmark[:, :2] / scale
  164. landmark[:, 0] = landmark[:, 0] + bbox[0]
  165. landmark[:, 1] = landmark[:, 1] + bbox[1]
  166. return landmark
  167. def get_img_for_texture(self, input_img_tensor):
  168. input_img = input_img_tensor.permute(
  169. 0, 2, 3, 1).detach().cpu().numpy()[0] * 255.
  170. input_img = input_img.astype(np.uint8)
  171. input_img_for_texture = self.fat_face(input_img, degree=0.03)
  172. input_img_for_texture_tensor = torch.tensor(
  173. np.array(input_img_for_texture) / 255.,
  174. dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
  175. input_img_for_texture_tensor = input_img_for_texture_tensor.to(
  176. self.model.device)
  177. return input_img_for_texture_tensor
  178. def infer_lmks(self, img_bgr):
  179. INPUT_SIZE = 224
  180. ENLARGE_RATIO = 1.35
  181. landmarks = []
  182. rgb_image = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
  183. results = self.detector.predict_jsons(rgb_image)
  184. boxes = []
  185. for anno in results:
  186. if anno['score'] == -1:
  187. break
  188. boxes.append({
  189. 'x1': anno['bbox'][0],
  190. 'y1': anno['bbox'][1],
  191. 'x2': anno['bbox'][2],
  192. 'y2': anno['bbox'][3]
  193. })
  194. for detect_result in boxes:
  195. x1 = detect_result['x1']
  196. y1 = detect_result['y1']
  197. x2 = detect_result['x2']
  198. y2 = detect_result['y2']
  199. w = x2 - x1 + 1
  200. h = y2 - y1 + 1
  201. cx = (x2 + x1) / 2
  202. cy = (y2 + y1) / 2
  203. sz = max(h, w) * ENLARGE_RATIO
  204. x1 = cx - sz / 2
  205. y1 = cy - sz / 2
  206. trans_x1 = x1
  207. trans_y1 = y1
  208. x2 = x1 + sz
  209. y2 = y1 + sz
  210. height, width, _ = rgb_image.shape
  211. dx = max(0, -x1)
  212. dy = max(0, -y1)
  213. x1 = max(0, x1)
  214. y1 = max(0, y1)
  215. edx = max(0, x2 - width)
  216. edy = max(0, y2 - height)
  217. x2 = min(width, x2)
  218. y2 = min(height, y2)
  219. crop_img = rgb_image[int(y1):int(y2), int(x1):int(x2)]
  220. if dx > 0 or dy > 0 or edx > 0 or edy > 0:
  221. crop_img = cv2.copyMakeBorder(
  222. crop_img,
  223. int(dy),
  224. int(edy),
  225. int(dx),
  226. int(edx),
  227. cv2.BORDER_CONSTANT,
  228. value=(103.94, 116.78, 123.68))
  229. crop_img = cv2.resize(crop_img, (INPUT_SIZE, INPUT_SIZE))
  230. base_lmks = LargeBaseLmkInfer.infer_img(
  231. crop_img, self.large_base_lmks_model,
  232. self.device_name_ == 'cuda')
  233. inv_scale = sz / INPUT_SIZE
  234. affine_base_lmks = np.zeros((106, 2))
  235. for idx in range(106):
  236. affine_base_lmks[idx][
  237. 0] = base_lmks[0][idx * 2 + 0] * inv_scale + trans_x1
  238. affine_base_lmks[idx][
  239. 1] = base_lmks[0][idx * 2 + 1] * inv_scale + trans_y1
  240. x1 = np.min(affine_base_lmks[:, 0])
  241. y1 = np.min(affine_base_lmks[:, 1])
  242. x2 = np.max(affine_base_lmks[:, 0])
  243. y2 = np.max(affine_base_lmks[:, 1])
  244. w = x2 - x1 + 1
  245. h = y2 - y1 + 1
  246. cx = (x2 + x1) / 2
  247. cy = (y2 + y1) / 2
  248. sz = max(h, w) * ENLARGE_RATIO
  249. x1 = cx - sz / 2
  250. y1 = cy - sz / 2
  251. trans_x1 = x1
  252. trans_y1 = y1
  253. x2 = x1 + sz
  254. y2 = y1 + sz
  255. height, width, _ = rgb_image.shape
  256. dx = max(0, -x1)
  257. dy = max(0, -y1)
  258. x1 = max(0, x1)
  259. y1 = max(0, y1)
  260. edx = max(0, x2 - width)
  261. edy = max(0, y2 - height)
  262. x2 = min(width, x2)
  263. y2 = min(height, y2)
  264. crop_img = rgb_image[int(y1):int(y2), int(x1):int(x2)]
  265. if dx > 0 or dy > 0 or edx > 0 or edy > 0:
  266. crop_img = cv2.copyMakeBorder(
  267. crop_img,
  268. int(dy),
  269. int(edy),
  270. int(dx),
  271. int(edx),
  272. cv2.BORDER_CONSTANT,
  273. value=(103.94, 116.78, 123.68))
  274. crop_img = cv2.resize(crop_img, (INPUT_SIZE, INPUT_SIZE))
  275. base_lmks = LargeBaseLmkInfer.infer_img(
  276. crop_img, self.large_base_lmks_model,
  277. self.device_name_.lower() == 'cuda')
  278. inv_scale = sz / INPUT_SIZE
  279. affine_base_lmks = np.zeros((106, 2))
  280. for idx in range(106):
  281. affine_base_lmks[idx][
  282. 0] = base_lmks[0][idx * 2 + 0] * inv_scale + trans_x1
  283. affine_base_lmks[idx][
  284. 1] = base_lmks[0][idx * 2 + 1] * inv_scale + trans_y1
  285. landmarks.append(affine_base_lmks)
  286. return boxes, landmarks
  287. def find_face_contour(self, image):
  288. boxes, landmarks = self.infer_lmks(image)
  289. landmarks = np.array(landmarks)
  290. args = [[0, 33, False], [33, 38, False], [42, 47, False],
  291. [51, 55, False], [57, 64, False], [66, 74, True],
  292. [75, 83, True], [84, 96, True]]
  293. roi_bboxs = []
  294. for i in range(len(boxes)):
  295. roi_bbox = enlarged_bbox([
  296. boxes[i]['x1'], boxes[i]['y1'], boxes[i]['x2'], boxes[i]['y2']
  297. ], image.shape[1], image.shape[0], 0.5)
  298. roi_bbox = [int(x) for x in roi_bbox]
  299. roi_bboxs.append(roi_bbox)
  300. people_maps = []
  301. for i in range(landmarks.shape[0]):
  302. landmark = landmarks[i, :, :]
  303. maps = []
  304. whole_mask = np.zeros((image.shape[0], image.shape[1]), np.uint8)
  305. roi_box = roi_bboxs[i]
  306. roi_box_width = roi_box[2] - roi_box[0]
  307. roi_box_height = roi_box[3] - roi_box[1]
  308. short_side_length = roi_box_width if roi_box_width < roi_box_height else roi_box_height
  309. line_width = short_side_length // 10
  310. if line_width == 0:
  311. line_width = 1
  312. kernel_size = line_width * 2
  313. gaussian_kernel = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
  314. for t, arg in enumerate(args):
  315. mask = np.zeros((image.shape[0], image.shape[1]), np.uint8)
  316. draw_line(mask, landmark[arg[0]:arg[1]], (255, 255, 255),
  317. line_width, arg[2])
  318. mask = cv2.GaussianBlur(mask,
  319. (gaussian_kernel, gaussian_kernel), 0)
  320. if t >= 1:
  321. draw_line(whole_mask, landmark[arg[0]:arg[1]],
  322. (255, 255, 255), line_width * 2, arg[2])
  323. maps.append(mask)
  324. whole_mask = cv2.GaussianBlur(whole_mask,
  325. (gaussian_kernel, gaussian_kernel),
  326. 0)
  327. maps.append(whole_mask)
  328. people_maps.append(maps)
  329. return people_maps[0], boxes
  330. def fat_face(self, img, degree=0.1):
  331. _img, scale = resize_on_long_side(img, 800)
  332. contour_maps, boxes = self.find_face_contour(_img)
  333. contour_map = contour_maps[0]
  334. boxes = boxes[0]
  335. Flow = np.zeros(
  336. shape=(contour_map.shape[0], contour_map.shape[1], 2),
  337. dtype=np.float32)
  338. box_center = [(boxes['x1'] + boxes['x2']) / 2,
  339. (boxes['y1'] + boxes['y2']) / 2]
  340. box_length = max(
  341. abs(boxes['y1'] - boxes['y2']), abs(boxes['x1'] - boxes['x2']))
  342. value_1 = 2 * (Flow.shape[0] - box_center[1] - 1)
  343. value_2 = 2 * (Flow.shape[1] - box_center[0] - 1)
  344. value_list = [
  345. box_length * 2, 2 * (box_center[0] - 1), 2 * (box_center[1] - 1),
  346. value_1, value_2
  347. ]
  348. flow_box_length = min(value_list)
  349. flow_box_length = int(flow_box_length)
  350. sf = spread_flow(100, flow_box_length * degree)
  351. sf = cv2.resize(sf, (flow_box_length, flow_box_length))
  352. Flow[int(box_center[1]
  353. - flow_box_length / 2):int(box_center[1]
  354. + flow_box_length / 2),
  355. int(box_center[0]
  356. - flow_box_length / 2):int(box_center[0]
  357. + flow_box_length / 2)] = sf
  358. Flow = Flow * np.dstack((contour_map, contour_map)) / 255.0
  359. inter_face_maps = contour_maps[-1]
  360. Flow = Flow * (1.0 - np.dstack(
  361. (inter_face_maps, inter_face_maps)) / 255.0)
  362. Flow = cv2.resize(Flow, (img.shape[1], img.shape[0]))
  363. Flow = Flow / scale
  364. pred, top_bound, bottom_bound, left_bound, right_bound = image_warp_grid1(
  365. Flow[..., 0], Flow[..., 1], img, 1.0, [0, 0, 0, 0])
  366. return pred
  367. def predict_base(self, img):
  368. if img.shape[0] > 2000 or img.shape[1] > 2000:
  369. img, _ = resize_on_long_side(img, 1500)
  370. box, results = self.infer_lmks(img)
  371. if results is None or np.array(results).shape[0] == 0:
  372. return {}
  373. landmarks = []
  374. results = results[0]
  375. for idx in [74, 83, 54, 84, 90]:
  376. landmarks.append([results[idx][0], results[idx][1]])
  377. landmarks = np.array(landmarks)
  378. landmarks = self.prepare_data(img, self.lm_sess, five_points=landmarks)
  379. im_tensor, lm_tensor, im_hd_tensor, lm_hd_tensor, mask = self.read_data(
  380. img, landmarks, self.lm3d_std, image_res=512)
  381. data = {
  382. 'imgs': im_tensor,
  383. 'imgs_hd': im_hd_tensor,
  384. 'lms': lm_tensor,
  385. 'lms_hd': lm_hd_tensor,
  386. 'face_mask': mask,
  387. }
  388. self.model.set_input_base(data) # unpack data from data loader
  389. output = self.model.predict_results_base() # run inference
  390. return output
  391. def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
  392. rgb_image = input['img'].cpu().numpy().astype(np.uint8)
  393. bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
  394. img = bgr_image
  395. base_model_output = self.predict_base(img)
  396. input_img_for_tex = self.get_img_for_texture(
  397. base_model_output['input_img'])
  398. hrn_input = {
  399. 'input_img': base_model_output['input_img'],
  400. 'input_img_for_tex': input_img_for_tex,
  401. 'input_img_hd': base_model_output['input_img_hd'],
  402. 'face_mask': base_model_output['face_mask'],
  403. 'gt_lm': base_model_output['gt_lm'],
  404. 'coeffs': base_model_output['coeffs'],
  405. 'position_map': base_model_output['position_map'],
  406. 'texture_map': base_model_output['texture_map'],
  407. 'tex_valid_mask': base_model_output['tex_valid_mask'],
  408. 'de_retouched_albedo_map':
  409. base_model_output['de_retouched_albedo_map']
  410. }
  411. self.model.set_input_hrn(hrn_input)
  412. self.model.get_edge_points_horizontal()
  413. self.model(visualize=True)
  414. results = self.model.save_results_hrn()
  415. texture_map = results['texture_map']
  416. results = {
  417. 'mesh': results['face_mesh'],
  418. 'vis_image': results['vis_image'],
  419. 'frame_list': results['frame_list'],
  420. }
  421. return {
  422. OutputKeys.OUTPUT_OBJ: None,
  423. OutputKeys.OUTPUT_IMG: texture_map,
  424. OutputKeys.OUTPUT: results
  425. }
  426. def postprocess(self, inputs, **kwargs) -> Dict[str, Any]:
  427. render = kwargs.get('render', False)
  428. output_obj = inputs[OutputKeys.OUTPUT_OBJ]
  429. texture_map = inputs[OutputKeys.OUTPUT_IMG]
  430. results = inputs[OutputKeys.OUTPUT]
  431. if render:
  432. output_obj = io.BytesIO()
  433. mesh_str = mesh_to_string(results['mesh'])
  434. mesh_bytes = mesh_str.encode(encoding='utf-8')
  435. output_obj.write(mesh_bytes)
  436. result = {
  437. OutputKeys.OUTPUT_OBJ: output_obj,
  438. OutputKeys.OUTPUT_IMG: texture_map,
  439. OutputKeys.OUTPUT: None if render else results,
  440. }
  441. return result