utils.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import argparse
  3. import math
  4. import os
  5. import os.path as osp
  6. from array import array
  7. import cv2
  8. import numba
  9. import numpy as np
  10. import torch
  11. import torch.nn.functional as F
  12. from PIL import Image
  13. from scipy.io import loadmat, savemat
  14. def img_value_rescale(img, old_range: list, new_range: list):
  15. assert len(old_range) == 2
  16. assert len(new_range) == 2
  17. img = (img - old_range[0]) / (old_range[1] - old_range[0]) * (
  18. new_range[1] - new_range[0]) + new_range[0]
  19. return img
  20. def resize_on_long_side(img, long_side=800):
  21. src_height = img.shape[0]
  22. src_width = img.shape[1]
  23. if src_height > src_width:
  24. scale = long_side * 1.0 / src_height
  25. _img = cv2.resize(
  26. img, (int(src_width * scale), long_side),
  27. interpolation=cv2.INTER_CUBIC)
  28. else:
  29. scale = long_side * 1.0 / src_width
  30. _img = cv2.resize(
  31. img, (long_side, int(src_height * scale)),
  32. interpolation=cv2.INTER_CUBIC)
  33. return _img, scale
  34. def get_mg_layer(src, gt, skin_mask=None):
  35. """
  36. src, gt shape: [h, w, 3] value: [0, 1]
  37. return: mg, shape: [h, w, 1] value: [0, 1]
  38. """
  39. mg = (src * src - gt + 1e-10) / (2 * src * src - 2 * src + 2e-10)
  40. mg[mg < 0] = 0.5
  41. mg[mg > 1] = 0.5
  42. diff_abs = np.abs(gt - src)
  43. mg[diff_abs < (1 / 255.0)] = 0.5
  44. if skin_mask is not None:
  45. mg[skin_mask == 0] = 0.5
  46. return mg
  47. def str2bool(v):
  48. if isinstance(v, bool):
  49. return v
  50. if v.lower() in ('yes', 'true', 't', 'y', '1'):
  51. return True
  52. elif v.lower() in ('no', 'false', 'f', 'n', '0'):
  53. return False
  54. else:
  55. raise argparse.ArgumentTypeError('Boolean value expected.')
  56. def spread_flow(length, spread_ratio=2):
  57. Flow = np.zeros(shape=(length, length, 2), dtype=np.float32)
  58. mag = np.zeros(shape=(length, length), dtype=np.float32)
  59. radius = length * 0.5
  60. for h in range(Flow.shape[0]):
  61. for w in range(Flow.shape[1]):
  62. if (h - length // 2)**2 + (w - length // 2)**2 <= radius**2:
  63. Flow[h, w, 0] = -(w - length // 2)
  64. Flow[h, w, 1] = -(h - length // 2)
  65. distance = np.sqrt((w - length // 2)**2 + (h - length // 2)**2)
  66. if distance <= radius / 2.0:
  67. mag[h, w] = 2.0 / radius * distance
  68. else:
  69. mag[h, w] = -2.0 / radius * distance + 2.0
  70. _, ang = cv2.cartToPolar(Flow[..., 0] + 1e-8, Flow[..., 1] + 1e-8)
  71. mag *= spread_ratio
  72. x, y = cv2.polarToCart(mag, ang, angleInDegrees=False)
  73. Flow = np.dstack((x, y))
  74. return Flow
  75. @numba.jit(nopython=True, parallel=True)
  76. def bilinear_interp(x, y, v11, v12, v21, v22):
  77. t = 0.2
  78. if x < t and y < t:
  79. return v11
  80. elif x < t and y > 1 - t:
  81. return v12
  82. elif x > 1 - t and y < t:
  83. return v21
  84. elif x > 1 - t and y > 1 - t:
  85. return v22
  86. else:
  87. result = (v11 * (1 - y) + v12 * y) * (1 - x) + \
  88. (v21 * (1 - y) + v22 * y) * x
  89. if result < 0:
  90. result = 0
  91. if result > 255:
  92. result = 255
  93. return result
@numba.jit(nopython=True, parallel=True)
def image_warp_grid1(rDx, rDy, oriImg, transRatio, pads):
    """Backward-warp ``oriImg`` by a per-pixel displacement field.

    Each output pixel (i, j) is bilinearly sampled from ``oriImg`` at
    row ``i + rDy[i, j] * transRatio`` and column ``j + rDx[i, j] * transRatio``.
    The sample position is clamped to the image. Pixels where both
    displacements are below 0.2 in magnitude are skipped and keep their
    original value, because the output starts as a copy.

    The function also tracks which output rows/columns sampled from inside the
    padded border:

    * ``right_bound`` / ``bottom_bound`` are lowered to the smallest column /
      row whose sample landed at or beyond ``srcW - padRight`` /
      ``srcH - padBottom``.
    * ``left_bound`` / ``top_bound`` are raised to one past the largest column
      / row whose sample landed before ``padLeft`` / ``padTop``.

    NOTE(review): presumably the caller crops to these bounds to drop pixels
    that sampled padding - confirm against the caller.

    Args:
        rDx, rDy: 2-D displacement arrays indexed [row, col]; x is the column
            offset, y the row offset.
        oriImg: image indexed [row, col, channel]. Only channels 0-2 are
            written.
        transRatio: scalar multiplier applied to both displacements.
        pads: 4-tuple (padTop, padBottom, padLeft, padRight).

    Returns:
        (newImg, top_bound, bottom_bound, left_bound, right_bound).
    """
    # NOTE(review): the interpolated float values are stored into newImg,
    # which has oriImg's dtype (presumably uint8 per the commented assert),
    # so they are converted on assignment - confirm intended rounding.
    # NOTE(review): parallel=True is set but the loops use plain `range`,
    # not numba.prange - verify whether any parallelization actually occurs.
    # assert oriImg.dtype == np.uint8
    srcW = oriImg.shape[1]
    srcH = oriImg.shape[0]
    padTop, padBottom, padLeft, padRight = pads
    # Initial bounds are the inner edges of the padding.
    left_bound = padLeft + 1
    right_bound = srcW - padRight
    bottom_bound = srcH - padBottom
    top_bound = padTop + 1
    newImg = oriImg.copy()
    for i in range(srcH):
        for j in range(srcW):
            _i = i
            _j = j
            deltaX = rDx[_i, _j]
            deltaY = rDy[_i, _j]
            # Negligible displacement: keep the original pixel.
            if abs(deltaX) < 0.2 and abs(deltaY) < 0.2:
                continue
            nx = _j + deltaX * transRatio
            ny = _i + deltaY * transRatio
            # Sample lands in the right padding: clamp and shrink right_bound.
            if nx >= srcW - padRight:
                if nx > srcW - 1:
                    nx = srcW - 1
                if _j < right_bound:
                    right_bound = _j
            # Sample lands in the bottom padding: clamp and shrink bottom_bound.
            if ny >= srcH - padBottom:
                if ny > srcH - 1:
                    ny = srcH - 1
                if _i < bottom_bound:
                    bottom_bound = _i
            # Sample lands in the left padding: clamp and grow left_bound.
            if nx < padLeft:
                if nx < 0:
                    nx = 0
                if _j + 1 > left_bound:
                    left_bound = _j + 1
            # Sample lands in the top padding: clamp and grow top_bound.
            if ny < padTop:
                if ny < 0:
                    ny = 0
                if _i + 1 > top_bound:
                    top_bound = _i + 1
            # Integer neighbours of the sample point, clamped to the image.
            nxi = int(math.floor(nx))
            nyi = int(math.floor(ny))
            nxi1 = int(math.ceil(nx))
            nyi1 = int(math.ceil(ny))
            if nxi < 0:
                nxi = 0
            if nxi > oriImg.shape[1] - 1:
                nxi = oriImg.shape[1] - 1
            if nxi1 < 0:
                nxi1 = 0
            if nxi1 > oriImg.shape[1] - 1:
                nxi1 = oriImg.shape[1] - 1
            if nyi < 0:
                nyi = 0
            if nyi > oriImg.shape[0] - 1:
                nyi = oriImg.shape[0] - 1
            if nyi1 < 0:
                nyi1 = 0
            if nyi1 > oriImg.shape[0] - 1:
                nyi1 = oriImg.shape[0] - 1
            # Row fraction first, then column fraction, matching
            # bilinear_interp's (x, y, v11, v12, v21, v22) corner layout.
            for ll in range(3):
                newImg[_i, _j, ll] = bilinear_interp(
                    ny - nyi, nx - nxi,
                    oriImg[nyi, nxi, ll], oriImg[nyi, nxi1, ll],
                    oriImg[nyi1, nxi, ll], oriImg[nyi1, nxi1, ll])
    return newImg, top_bound, bottom_bound, left_bound, right_bound
  164. def warp(x, flow, mode='bilinear', padding_mode='zeros', coff=0.1):
  165. """
  166. Args:
  167. x: [n, c, h, w]
  168. flow: [n, h, w, 2]
  169. mode:
  170. padding_mode:
  171. coff:
  172. Returns:
  173. """
  174. n, c, h, w = x.size()
  175. yv, xv = torch.meshgrid([torch.arange(h), torch.arange(w)])
  176. xv = xv.float() / (w - 1) * 2.0 - 1
  177. yv = yv.float() / (h - 1) * 2.0 - 1
  178. '''
  179. grid[0,:,:,0] =
  180. -1, .....1
  181. -1, .....1
  182. -1, .....1
  183. grid[0,:,:,1] =
  184. -1, -1, -1
  185. ; ;
  186. 1, 1, 1
  187. '''
  188. if torch.cuda.is_available():
  189. grid = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)),
  190. -1).unsqueeze(0).cuda()
  191. else:
  192. grid = torch.cat((xv.unsqueeze(-1), yv.unsqueeze(-1)), -1).unsqueeze(0)
  193. grid_x = grid + 2 * flow * coff
  194. warp_x = F.grid_sample(x, grid_x, mode=mode, padding_mode=padding_mode)
  195. return warp_x
  196. # load expression basis
  197. def LoadExpBasis(bfm_folder='asset/BFM'):
  198. n_vertex = 53215
  199. Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')
  200. exp_dim = array('i')
  201. exp_dim.fromfile(Expbin, 1)
  202. expMU = array('f')
  203. expPC = array('f')
  204. expMU.fromfile(Expbin, 3 * n_vertex)
  205. expPC.fromfile(Expbin, 3 * exp_dim[0] * n_vertex)
  206. Expbin.close()
  207. expPC = np.array(expPC)
  208. expPC = np.reshape(expPC, [exp_dim[0], -1])
  209. expPC = np.transpose(expPC)
  210. expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))
  211. return expPC, expEV
  212. # transfer original BFM09 to our face model
  213. def transferBFM09(bfm_folder='BFM'):
  214. print('Transfer BFM09 to BFM_model_front......')
  215. original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))
  216. shapePC = original_BFM['shapePC'] # shape basis, 160470*199
  217. shapeEV = original_BFM['shapeEV'] # corresponding eigen value, 199*1
  218. shapeMU = original_BFM['shapeMU'] # mean face, 160470*1
  219. texPC = original_BFM['texPC'] # texture basis, 160470*199
  220. texEV = original_BFM['texEV'] # eigen value, 199*1
  221. texMU = original_BFM['texMU'] # mean texture, 160470*1
  222. expPC, expEV = LoadExpBasis()
  223. # transfer BFM09 to our face model
  224. idBase = shapePC * np.reshape(shapeEV, [-1, 199])
  225. idBase = idBase / 1e5 # unify the scale to decimeter
  226. idBase = idBase[:, :80] # use only first 80 basis
  227. exBase = expPC * np.reshape(expEV, [-1, 79])
  228. exBase = exBase / 1e5 # unify the scale to decimeter
  229. exBase = exBase[:, :64] # use only first 64 basis
  230. texBase = texPC * np.reshape(texEV, [-1, 199])
  231. texBase = texBase[:, :80] # use only first 80 basis
  232. # our face model is cropped along face landmarks and contains only 35709 vertex.
  233. # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex.
  234. # thus we select corresponding vertex to get our face model.
  235. index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))
  236. index_exp = index_exp['idx'].astype(
  237. np.int32) - 1 # starts from 0 (to 53215)
  238. index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))
  239. index_shape = index_shape['trimIndex'].astype(
  240. np.int32) - 1 # starts from 0 (to 53490)
  241. index_shape = index_shape[index_exp]
  242. idBase = np.reshape(idBase, [-1, 3, 80])
  243. idBase = idBase[index_shape, :, :]
  244. idBase = np.reshape(idBase, [-1, 80])
  245. texBase = np.reshape(texBase, [-1, 3, 80])
  246. texBase = texBase[index_shape, :, :]
  247. texBase = np.reshape(texBase, [-1, 80])
  248. exBase = np.reshape(exBase, [-1, 3, 64])
  249. exBase = exBase[index_exp, :, :]
  250. exBase = np.reshape(exBase, [-1, 64])
  251. meanshape = np.reshape(shapeMU, [-1, 3]) / 1e5
  252. meanshape = meanshape[index_shape, :]
  253. meanshape = np.reshape(meanshape, [1, -1])
  254. meantex = np.reshape(texMU, [-1, 3])
  255. meantex = meantex[index_shape, :]
  256. meantex = np.reshape(meantex, [1, -1])
  257. # other info contains triangles, region used for computing photometric loss,
  258. # region used for skin texture regularization, and 68 landmarks index etc.
  259. other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))
  260. frontmask2_idx = other_info['frontmask2_idx']
  261. skinmask = other_info['skinmask']
  262. keypoints = other_info['keypoints']
  263. point_buf = other_info['point_buf']
  264. tri = other_info['tri']
  265. tri_mask2 = other_info['tri_mask2']
  266. # save our face model
  267. savemat(
  268. osp.join(bfm_folder, 'BFM_model_front.mat'), {
  269. 'meanshape': meanshape,
  270. 'meantex': meantex,
  271. 'idBase': idBase,
  272. 'exBase': exBase,
  273. 'texBase': texBase,
  274. 'tri': tri,
  275. 'point_buf': point_buf,
  276. 'tri_mask2': tri_mask2,
  277. 'keypoints': keypoints,
  278. 'frontmask2_idx': frontmask2_idx,
  279. 'skinmask': skinmask
  280. })
  281. # load landmarks for standard face, which is used for image preprocessing
  282. def load_lm3d(bfm_folder):
  283. Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))
  284. Lm3D = Lm3D['lm']
  285. # calculate 5 facial landmarks using 68 landmarks
  286. lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
  287. value_list = [
  288. Lm3D[lm_idx[0], :],
  289. np.mean(Lm3D[lm_idx[[1, 2]], :], 0),
  290. np.mean(Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :],
  291. Lm3D[lm_idx[6], :]
  292. ]
  293. Lm3D = np.stack(value_list, axis=0)
  294. Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
  295. return Lm3D
  296. def mesh_to_string(mesh):
  297. out_string = ''
  298. out_string += '# Create by HRN\n'
  299. if 'colors' in mesh:
  300. for i, v in enumerate(mesh['vertices']):
  301. out_string += \
  302. 'v {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format(
  303. v[0], v[1], v[2], mesh['colors'][i][0],
  304. mesh['colors'][i][1], mesh['colors'][i][2])
  305. else:
  306. for v in mesh['vertices']:
  307. out_string += 'v {:.6f} {:.6f} {:.6f}\n'.format(v[0], v[1], v[2])
  308. if 'UVs' in mesh:
  309. for uv in mesh['UVs']:
  310. out_string += 'vt {:.6f} {:.6f}\n'.format(uv[0], uv[1])
  311. if 'normals' in mesh:
  312. for vn in mesh['normals']:
  313. out_string += 'vn {:.6f} {:.6f} {:.6f}\n'.format(
  314. vn[0], vn[1], vn[2])
  315. if 'faces' in mesh:
  316. for ind, face in enumerate(mesh['faces']):
  317. if 'faces_uv' in mesh or 'faces_normal' in mesh or 'UVs' in mesh:
  318. if 'faces_uv' in mesh:
  319. face_uv = mesh['faces_uv'][ind]
  320. else:
  321. face_uv = face
  322. if 'faces_normal' in mesh:
  323. face_normal = mesh['faces_normal'][ind]
  324. else:
  325. face_normal = face
  326. row = 'f ' + ' '.join([
  327. '{}/{}/{}'.format(face[i], face_uv[i], face_normal[i])
  328. for i in range(len(face))
  329. ]) + '\n'
  330. else:
  331. row = 'f ' + ' '.join(
  332. ['{}'.format(face[i]) for i in range(len(face))]) + '\n'
  333. out_string += row
  334. return out_string
  335. def write_obj(save_path, mesh):
  336. save_dir = os.path.dirname(save_path)
  337. save_name = os.path.splitext(os.path.basename(save_path))[0]
  338. if 'texture_map' in mesh:
  339. cv2.imwrite(
  340. os.path.join(save_dir, save_name + '.jpg'), mesh['texture_map'])
  341. with open(os.path.join(save_dir, save_name + '.mtl'), 'w') as wf:
  342. wf.write('# Created by HRN\n')
  343. wf.write('newmtl material_0\n')
  344. wf.write('Ka 1.000000 0.000000 0.000000\n')
  345. wf.write('Kd 1.000000 1.000000 1.000000\n')
  346. wf.write('Ks 0.000000 0.000000 0.000000\n')
  347. wf.write('Tr 0.000000\n')
  348. wf.write('illum 0\n')
  349. wf.write('Ns 0.000000\n')
  350. wf.write('map_Kd {}\n'.format(save_name + '.jpg'))
  351. with open(save_path, 'w') as wf:
  352. if 'texture_map' in mesh:
  353. wf.write('# Create by HRN\n')
  354. wf.write('mtllib ./{}.mtl\n'.format(save_name))
  355. if 'colors' in mesh:
  356. for i, v in enumerate(mesh['vertices']):
  357. wf.write(
  358. 'v {:.6f} {:.6f} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format(
  359. v[0], v[1], v[2], mesh['colors'][i][0],
  360. mesh['colors'][i][1], mesh['colors'][i][2]))
  361. else:
  362. for v in mesh['vertices']:
  363. wf.write('v {:.6f} {:.6f} {:.6f}\n'.format(v[0], v[1], v[2]))
  364. if 'UVs' in mesh:
  365. for uv in mesh['UVs']:
  366. wf.write('vt {:.6f} {:.6f}\n'.format(uv[0], uv[1]))
  367. if 'normals' in mesh:
  368. for vn in mesh['normals']:
  369. wf.write('vn {:.6f} {:.6f} {:.6f}\n'.format(
  370. vn[0], vn[1], vn[2]))
  371. if 'faces' in mesh:
  372. for ind, face in enumerate(mesh['faces']):
  373. if 'faces_uv' in mesh or 'faces_normal' in mesh or 'UVs' in mesh:
  374. if 'faces_uv' in mesh:
  375. face_uv = mesh['faces_uv'][ind]
  376. else:
  377. face_uv = face
  378. if 'faces_normal' in mesh:
  379. face_normal = mesh['faces_normal'][ind]
  380. else:
  381. face_normal = face
  382. row = 'f ' + ' '.join([
  383. '{}/{}/{}'.format(face[i], face_uv[i], face_normal[i])
  384. for i in range(len(face))
  385. ]) + '\n'
  386. else:
  387. row = 'f ' + ' '.join(
  388. ['{}'.format(face[i])
  389. for i in range(len(face))]) + '\n'
  390. wf.write(row)
  391. def read_obj(obj_path, print_shape=True):
  392. with open(obj_path, 'r') as f:
  393. bfm_lines = f.readlines()
  394. vertices = []
  395. faces = []
  396. uvs = []
  397. vns = []
  398. faces_uv = []
  399. faces_normal = []
  400. max_face_length = 0
  401. for line in bfm_lines:
  402. if line[:2] == 'v ':
  403. vertex = [
  404. float(a) for a in line.strip().split(' ')[1:] if len(a) > 0
  405. ]
  406. vertices.append(vertex)
  407. if line[:2] == 'f ':
  408. items = line.strip().split(' ')[1:]
  409. face = [int(a.split('/')[0]) for a in items if len(a) > 0]
  410. max_face_length = max(max_face_length, len(face))
  411. if len(faces) > 0 and len(face) != len(faces[0]):
  412. continue
  413. faces.append(face)
  414. if '/' in items[0] and len(items[0].split('/')[1]) > 0:
  415. face_uv = [int(a.split('/')[1]) for a in items if len(a) > 0]
  416. faces_uv.append(face_uv)
  417. if '/' in items[0] and len(items[0].split('/')) >= 3 and len(
  418. items[0].split('/')[2]) > 0:
  419. face_normal = [
  420. int(a.split('/')[2]) for a in items if len(a) > 0
  421. ]
  422. faces_normal.append(face_normal)
  423. if line[:3] == 'vt ':
  424. items = line.strip().split(' ')[1:]
  425. uv = [float(a) for a in items if len(a) > 0]
  426. uvs.append(uv)
  427. if line[:3] == 'vn ':
  428. items = line.strip().split(' ')[1:]
  429. vn = [float(a) for a in items if len(a) > 0]
  430. vns.append(vn)
  431. vertices = np.array(vertices).astype(np.float32)
  432. if max_face_length <= 3:
  433. faces = np.array(faces).astype(np.int32)
  434. if vertices.shape[1] == 3:
  435. mesh = {
  436. 'vertices': vertices,
  437. 'faces': faces,
  438. }
  439. else:
  440. mesh = {
  441. 'vertices': vertices[:, :3],
  442. 'colors': vertices[:, 3:],
  443. 'faces': faces,
  444. }
  445. if len(uvs) > 0:
  446. uvs = np.array(uvs).astype(np.float32)
  447. mesh['uvs'] = uvs
  448. if len(vns) > 0:
  449. vns = np.array(vns).astype(np.float32)
  450. mesh['vns'] = vns
  451. if len(faces_uv) > 0:
  452. if max_face_length <= 3:
  453. faces_uv = np.array(faces_uv).astype(np.int32)
  454. mesh['faces_uv'] = faces_uv
  455. if len(faces_normal) > 0:
  456. if max_face_length <= 3:
  457. faces_normal = np.array(faces_normal).astype(np.int32)
  458. mesh['faces_normal'] = faces_normal
  459. return mesh
  460. # calculating least square problem for image alignment
  461. def POS(xp, x):
  462. npts = xp.shape[1]
  463. A = np.zeros([2 * npts, 8])
  464. A[0:2 * npts - 1:2, 0:3] = x.transpose()
  465. A[0:2 * npts - 1:2, 3] = 1
  466. A[1:2 * npts:2, 4:7] = x.transpose()
  467. A[1:2 * npts:2, 7] = 1
  468. b = np.reshape(xp.transpose(), [2 * npts, 1])
  469. k, _, _, _ = np.linalg.lstsq(A, b)
  470. R1 = k[0:3]
  471. R2 = k[4:7]
  472. sTx = k[3]
  473. sTy = k[7]
  474. s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2
  475. t = np.stack([sTx, sTy], axis=0)
  476. return t, s
  477. # bounding box for 68 landmark detection
  478. def BBRegression(points, params):
  479. w1 = params['W1']
  480. b1 = params['B1']
  481. w2 = params['W2']
  482. b2 = params['B2']
  483. data = points.copy()
  484. data = data.reshape([5, 2])
  485. data_mean = np.mean(data, axis=0)
  486. x_mean = data_mean[0]
  487. y_mean = data_mean[1]
  488. data[:, 0] = data[:, 0] - x_mean
  489. data[:, 1] = data[:, 1] - y_mean
  490. rms = np.sqrt(np.sum(data**2) / 5)
  491. data = data / rms
  492. data = data.reshape([1, 10])
  493. data = np.transpose(data)
  494. inputs = np.matmul(w1, data) + b1
  495. inputs = 2 / (1 + np.exp(-2 * inputs)) - 1
  496. inputs = np.matmul(w2, inputs) + b2
  497. inputs = np.transpose(inputs)
  498. x = inputs[:, 0] * rms + x_mean
  499. y = inputs[:, 1] * rms + y_mean
  500. w = 224 / inputs[:, 2] * rms
  501. rects = [x, y, w, w]
  502. return np.array(rects).reshape([4])
  503. # utils for landmark detection
  504. def img_padding(img, box):
  505. success = True
  506. bbox = box.copy()
  507. res = np.zeros([2 * img.shape[0], 2 * img.shape[1], 3])
  508. res[img.shape[0] // 2:img.shape[0] + img.shape[0] // 2,
  509. img.shape[1] // 2:img.shape[1] + img.shape[1] // 2] = img
  510. bbox[0] = bbox[0] + img.shape[1] // 2
  511. bbox[1] = bbox[1] + img.shape[0] // 2
  512. if bbox[0] < 0 or bbox[1] < 0:
  513. success = False
  514. return res, bbox, success
  515. # utils for landmark detection
  516. def crop(img, bbox):
  517. padded_img, padded_bbox, flag = img_padding(img, bbox)
  518. if flag:
  519. crop_img = padded_img[padded_bbox[1]:padded_bbox[1] + padded_bbox[3],
  520. padded_bbox[0]:padded_bbox[0] + padded_bbox[2]]
  521. crop_img = cv2.resize(
  522. crop_img.astype(np.uint8), (224, 224),
  523. interpolation=cv2.INTER_CUBIC)
  524. scale = 224 / padded_bbox[3]
  525. return crop_img, scale
  526. else:
  527. return padded_img, 0
  528. # utils for landmark detection
  529. def scale_trans(img, lm, t, s):
  530. imgw = img.shape[1]
  531. imgh = img.shape[0]
  532. M_s = np.array(
  533. [[1, 0, -t[0] + imgw // 2 + 0.5], [0, 1, -imgh // 2 + t[1]]],
  534. dtype=np.float32)
  535. img = cv2.warpAffine(img, M_s, (imgw, imgh))
  536. w = int(imgw / s * 100)
  537. h = int(imgh / s * 100)
  538. img = cv2.resize(img, (w, h))
  539. lm = np.stack([lm[:, 0] - t[0] + imgw // 2, lm[:, 1] - t[1] + imgh // 2],
  540. axis=1) / s * 100
  541. left = w // 2 - 112
  542. up = h // 2 - 112
  543. bbox = [left, up, 224, 224]
  544. cropped_img, scale2 = crop(img, bbox)
  545. assert (scale2 != 0)
  546. t1 = np.array([bbox[0], bbox[1]])
  547. # back to raw img s * crop + s * t1 + t2
  548. t1 = np.array([w // 2 - 112, h // 2 - 112])
  549. scale = s / 100
  550. t2 = np.array([t[0] - imgw / 2, t[1] - imgh / 2])
  551. inv = (scale / scale2, scale * t1 + t2.reshape([2]))
  552. return cropped_img, inv
  553. # utils for landmark detection
  554. def align_for_lm(img, five_points, params):
  555. five_points = np.array(five_points).reshape([1, 10])
  556. bbox = BBRegression(five_points, params)
  557. assert (bbox[2] != 0)
  558. bbox = np.round(bbox).astype(np.int32)
  559. crop_img, scale = crop(img, bbox)
  560. return crop_img, scale, bbox
  561. # resize and crop images for face reconstruction
  562. def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
  563. w0, h0 = img.size
  564. w = (w0 * s).astype(np.int32)
  565. h = (h0 * s).astype(np.int32)
  566. left = (w / 2 - target_size / 2 + float(
  567. (t[0] - w0 / 2) * s)).astype(np.int32)
  568. right = left + target_size
  569. up = (h / 2 - target_size / 2 + float(
  570. (h0 / 2 - t[1]) * s)).astype(np.int32)
  571. below = up + target_size
  572. new_img = img.resize((w, h), resample=Image.BICUBIC)
  573. new_img = new_img.crop((left, up, right, below))
  574. if mask is not None:
  575. mask = mask.resize((w, h), resample=Image.BICUBIC)
  576. mask = mask.crop((left, up, right, below))
  577. new_lm = np.stack([lm[:, 0] - t[0] + w0 / 2, lm[:, 1] - t[1] + h0 / 2],
  578. axis=1) * s
  579. new_lm = new_lm - np.reshape(
  580. np.array([(w / 2 - target_size / 2),
  581. (h / 2 - target_size / 2)]), [1, 2])
  582. return new_img, new_lm, mask
  583. # utils for face reconstruction
  584. def extract_5p(lm):
  585. lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
  586. value_list = [
  587. lm[lm_idx[0], :],
  588. np.mean(lm[lm_idx[[1, 2]], :], 0),
  589. np.mean(lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]
  590. ]
  591. lm5p = np.stack(value_list, axis=0)
  592. lm5p = lm5p[[1, 2, 0, 3, 4], :]
  593. return lm5p
  594. # utils for face reconstruction
  595. def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):
  596. """
  597. Return:
  598. transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
  599. img_new --PIL.Image (target_size, target_size, 3)
  600. lm_new --numpy.array (68, 2), y direction is opposite to v direction
  601. mask_new --PIL.Image (target_size, target_size)
  602. Parameters:
  603. img --PIL.Image (raw_H, raw_W, 3)
  604. lm --numpy.array (68, 2), y direction is opposite to v direction
  605. lm3D --numpy.array (5, 3)
  606. mask --PIL.Image (raw_H, raw_W, 3)
  607. """
  608. w0, h0 = img.size
  609. if lm.shape[0] != 5:
  610. lm5p = extract_5p(lm)
  611. else:
  612. lm5p = lm
  613. # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
  614. t, s = POS(lm5p.transpose(), lm3D.transpose())
  615. t = t.squeeze()
  616. s = rescale_factor / s
  617. # processing the image
  618. img_new, lm_new, mask_new = resize_n_crop_img(
  619. img, lm, t, s, target_size=target_size, mask=mask)
  620. trans_params = np.array([w0, h0, s, t[0], t[1]])
  621. return trans_params, img_new, lm_new, mask_new
  622. def normalize_v3(arr):
  623. ''' Normalize a numpy array of 3 component vectors shape=(n,3) '''
  624. lens = np.sqrt(arr[:, 0]**2 + arr[:, 1]**2 + arr[:, 2]**2)[:, None]
  625. arr /= lens
  626. return arr
  627. def estimate_normals(vertices, faces):
  628. norm = np.zeros(vertices.shape, dtype=vertices.dtype)
  629. tris = vertices[faces]
  630. n = np.cross(tris[::, 1] - tris[::, 0], tris[::, 2] - tris[::, 0])
  631. n[(n[:, 0] == 0) * (n[:, 1] == 0) * (n[:, 2] == 0)] = [0, 0, 1.0]
  632. n = normalize_v3(n)
  633. for i in range(3):
  634. for j in range(faces.shape[0]):
  635. norm[faces[j, i]] += n[j]
  636. inds = (norm[:, 0] == 0) * (norm[:, 1] == 0) * (norm[:, 2] == 0)
  637. norm[inds] = [0, 0, 1.0]
  638. result = normalize_v3(norm)
  639. return result
  640. def draw_landmarks(img, landmark, color='r', step=2):
  641. """
  642. Return:
  643. img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)
  644. Parameters:
  645. img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255)
  646. landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction
  647. color -- str, 'r' or 'b' (red or blue)
  648. """
  649. if color == 'r':
  650. c = np.array([255., 0, 0])
  651. else:
  652. c = np.array([0, 0, 255.])
  653. _, H, W, _ = img.shape
  654. img, landmark = img.copy(), landmark.copy()
  655. landmark[..., 1] = H - 1 - landmark[..., 1]
  656. landmark = np.round(landmark).astype(np.int32)
  657. for i in range(landmark.shape[1]):
  658. x, y = landmark[:, i, 0], landmark[:, i, 1]
  659. for j in range(-step, step):
  660. for k in range(-step, step):
  661. u = np.clip(x + j, 0, W - 1)
  662. v = np.clip(y + k, 0, H - 1)
  663. for m in range(landmark.shape[0]):
  664. img[m, v[m], u[m]] = c
  665. return img
  666. def split_vis(img_path, target_dir=None):
  667. img = cv2.imread(img_path)
  668. h, w = img.shape[:2]
  669. n_split = w // h
  670. if target_dir is None:
  671. target_dir = os.path.dirname(img_path)
  672. base_name = os.path.splitext(os.path.basename(img_path))[0]
  673. for i in range(n_split):
  674. img_i = img[:, i * h:(i + 1) * h, :]
  675. cv2.imwrite(
  676. os.path.join(target_dir, '{}_{:0>2d}.jpg'.format(base_name,
  677. i + 1)), img_i)
  678. def write_video(image_list, save_path, fps=20.0):
  679. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  680. # fourcc = cv2.VideoWriter_fourcc(*'MJPG') # avi格式
  681. h, w = image_list[0].shape[:2]
  682. out = cv2.VideoWriter(save_path, fourcc, fps, (w, h), True)
  683. for frame in image_list:
  684. out.write(frame)
  685. out.release()
  686. # ---------------------------- process/generate vertices, normals, faces
  687. def generate_triangles(h, w, margin_x=2, margin_y=5, mask=None):
  688. # quad layout:
  689. # 0 1 ... w-1
  690. # w w+1
  691. # .
  692. # w*h
  693. triangles = []
  694. for x in range(margin_x, w - 1 - margin_x):
  695. for y in range(margin_y, h - 1 - margin_y):
  696. triangle0 = [y * w + x, y * w + x + 1, (y + 1) * w + x]
  697. triangle1 = [y * w + x + 1, (y + 1) * w + x + 1, (y + 1) * w + x]
  698. triangles.append(triangle0)
  699. triangles.append(triangle1)
  700. triangles = np.array(triangles)
  701. triangles = triangles[:, [0, 2, 1]]
  702. return triangles
  703. def face_vertices(vertices, faces):
  704. """
  705. :param vertices: [batch size, number of vertices, 3]
  706. :param faces: [batch size, number of faces, 3]
  707. :return: [batch size, number of faces, 3, 3]
  708. """
  709. assert (vertices.ndimension() == 3)
  710. assert (faces.ndimension() == 3)
  711. assert (vertices.shape[0] == faces.shape[0])
  712. assert (vertices.shape[2] == 3)
  713. assert (faces.shape[2] == 3)
  714. bs, nv = vertices.shape[:2]
  715. bs, nf = faces.shape[:2]
  716. device = vertices.device
  717. faces = faces + (torch.arange(bs, dtype=torch.int32).to(device)
  718. * nv)[:, None, None]
  719. vertices = vertices.reshape((bs * nv, 3))
  720. # pytorch only supports long and byte tensors for indexing
  721. return vertices[faces.long()]
  722. def vertex_normals(vertices, faces):
  723. """
  724. :param vertices: [batch size, number of vertices, 3]
  725. :param faces: [batch size, number of faces, 3]
  726. :return: [batch size, number of vertices, 3]
  727. """
  728. assert (vertices.ndimension() == 3)
  729. assert (faces.ndimension() == 3)
  730. assert (vertices.shape[0] == faces.shape[0])
  731. assert (vertices.shape[2] == 3)
  732. assert (faces.shape[2] == 3)
  733. bs, nv = vertices.shape[:2]
  734. bs, nf = faces.shape[:2]
  735. device = vertices.device
  736. normals = torch.zeros(bs * nv, 3).to(device)
  737. faces = faces + (torch.arange(bs, dtype=torch.int32).to(device)
  738. * nv)[:, None, None] # expanded faces
  739. vertices_faces = vertices.reshape((bs * nv, 3))[faces.long()]
  740. faces = faces.reshape(-1, 3)
  741. vertices_faces = vertices_faces.reshape(-1, 3, 3)
  742. normals.index_add_(
  743. 0, faces[:, 1].long(),
  744. torch.cross(vertices_faces[:, 2] - vertices_faces[:, 1],
  745. vertices_faces[:, 0] - vertices_faces[:, 1]))
  746. normals.index_add_(
  747. 0, faces[:, 2].long(),
  748. torch.cross(vertices_faces[:, 0] - vertices_faces[:, 2],
  749. vertices_faces[:, 1] - vertices_faces[:, 2]))
  750. normals.index_add_(
  751. 0, faces[:, 0].long(),
  752. torch.cross(vertices_faces[:, 1] - vertices_faces[:, 0],
  753. vertices_faces[:, 2] - vertices_faces[:, 0]))
  754. normals = F.normalize(normals, eps=1e-6, dim=1)
  755. normals = normals.reshape((bs, nv, 3))
  756. # pytorch only supports long and byte tensors for indexing
  757. return normals
  758. def dict2obj(d):
  759. # if isinstance(d, list):
  760. # d = [dict2obj(x) for x in d]
  761. if not isinstance(d, dict):
  762. return d
  763. class C(object):
  764. pass
  765. o = C()
  766. for k in d:
  767. o.__dict__[k] = dict2obj(d[k])
  768. return o
  769. def enlarged_bbox(bbox, img_width, img_height, enlarge_ratio=0.2):
  770. '''
  771. :param bbox: [xmin,ymin,xmax,ymax]
  772. :return: bbox: [xmin,ymin,xmax,ymax]
  773. '''
  774. left = bbox[0]
  775. top = bbox[1]
  776. right = bbox[2]
  777. bottom = bbox[3]
  778. roi_width = right - left
  779. roi_height = bottom - top
  780. new_left = left - int(roi_width * enlarge_ratio)
  781. new_left = 0 if new_left < 0 else new_left
  782. new_top = top - int(roi_height * enlarge_ratio)
  783. new_top = 0 if new_top < 0 else new_top
  784. new_right = right + int(roi_width * enlarge_ratio)
  785. new_right = img_width if new_right > img_width else new_right
  786. new_bottom = bottom + int(roi_height * enlarge_ratio)
  787. new_bottom = img_height if new_bottom > img_height else new_bottom
  788. bbox = [new_left, new_top, new_right, new_bottom]
  789. bbox = [int(x) for x in bbox]
  790. return bbox
  791. def draw_line(im, points, color, stroke_size=2, closed=False):
  792. points = points.astype(np.int32)
  793. for i in range(len(points) - 1):
  794. cv2.line(im, tuple(points[i]), tuple(points[i + 1]), color,
  795. stroke_size)
  796. if closed:
  797. cv2.line(im, tuple(points[0]), tuple(points[-1]), color, stroke_size)