utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import cv2
  4. import numpy as np
  5. import nvdiffrast.torch as dr
  6. import torch
  7. import torch.nn.functional as F
  def read_obj(obj_path, print_shape=False):
      """Parse a Wavefront OBJ file into a mesh dictionary.

      Only the ``v``, ``vt``, ``vn`` and ``f`` records are read; every other
      record is ignored. Fields may be separated by any run of blanks or tabs.
      Indices are stored exactly as written in the file (no offset is applied).

      Args:
          obj_path: path of the OBJ file to read.
          print_shape: when True, print the number of vertices and faces.

      Returns:
          dict with the keys
              'vertices': float32 array holding the first three values of
                  every ``v`` record,
              'colors': float32 array with the remaining ``v`` values (only
                  present when the ``v`` records carry more than 3 values),
              'faces': int32 array of vertex indices, or a plain list of
                  lists when some face has more than 3 corners,
              'uvs' / 'normals': float32 arrays (only when ``vt`` / ``vn``
                  records exist),
              'faces_uv' / 'faces_normal': per-face texture / normal indices
                  (only when the face records carry them).
      """
      vertices = []
      faces = []
      uvs = []
      vns = []
      faces_uv = []
      faces_normal = []
      max_face_length = 0

      with open(obj_path, 'r') as f:
          for line in f:
              # split() without arguments drops empty tokens, so repeated
              # blanks or tabs never yield an empty first item.
              tokens = line.split()
              if not tokens:
                  continue
              tag, items = tokens[0], tokens[1:]
              if tag == 'v':
                  vertices.append([float(a) for a in items])
              elif tag == 'vt':
                  uvs.append([float(a) for a in items])
              elif tag == 'vn':
                  vns.append([float(a) for a in items])
              elif tag == 'f' and items:
                  refs = [a.split('/') for a in items]
                  face = [int(ref[0]) for ref in refs]
                  max_face_length = max(max_face_length, len(face))
                  faces.append(face)
                  # The first corner decides which optional indices the
                  # face carries: "v/vt", "v//vn" or "v/vt/vn".
                  first = refs[0]
                  if len(first) >= 2 and first[1]:
                      faces_uv.append([int(ref[1]) for ref in refs])
                  if len(first) >= 3 and first[2]:
                      faces_normal.append([int(ref[2]) for ref in refs])

      vertices = np.array(vertices).astype(np.float32)
      if vertices.size == 0:
          # Keep the array two-dimensional so the column test below works
          # for files without any vertex record.
          vertices = vertices.reshape(0, 3)

      is_triangle_mesh = max_face_length <= 3
      if is_triangle_mesh:
          faces = np.array(faces).astype(np.int32)
      else:
          print('not a triangle face mesh!')

      if vertices.shape[1] == 3:
          mesh = {
              'vertices': vertices,
              'faces': faces,
          }
      else:
          mesh = {
              'vertices': vertices[:, :3],
              'colors': vertices[:, 3:],
              'faces': faces,
          }

      if len(uvs) > 0:
          mesh['uvs'] = np.array(uvs).astype(np.float32)
      if len(vns) > 0:
          mesh['normals'] = np.array(vns).astype(np.float32)
      if len(faces_uv) > 0:
          if is_triangle_mesh:
              faces_uv = np.array(faces_uv).astype(np.int32)
          mesh['faces_uv'] = faces_uv
      if len(faces_normal) > 0:
          if is_triangle_mesh:
              faces_normal = np.array(faces_normal).astype(np.int32)
          mesh['faces_normal'] = faces_normal

      if print_shape:
          print('num of vertices', len(vertices))
          print('num of faces', len(faces))
      return mesh
  def write_obj(save_path, mesh):
      """Write a mesh dictionary to a Wavefront OBJ file.

      When the mesh has a 'texture_map', the image is saved with
      ``cv2.imwrite`` as ``<name>.png`` next to ``save_path`` together with a
      ``<name>.mtl`` material file, and the OBJ references that material
      library.

      Args:
          save_path: destination path of the OBJ file.
          mesh: dict with 'vertices' and optionally 'colors', 'uvs',
              'normals', 'faces', 'faces_uv', 'faces_normal' and
              'texture_map'. Indices are written exactly as stored.

      Face records are written as plain ``v`` when the mesh has neither
      'faces_uv' nor 'faces_normal'. Otherwise a missing index list falls
      back to the vertex indices, but only if the matching attribute array
      ('uvs' / 'normals') exists; if it does not, that slot is left out
      (``v/vt`` or ``v//vn``) so the file never references ``vt`` / ``vn``
      records it does not contain.
      """
      save_dir = os.path.dirname(save_path)
      save_name = os.path.splitext(os.path.basename(save_path))[0]
      has_texture = 'texture_map' in mesh

      if has_texture:
          cv2.imwrite(
              os.path.join(save_dir, save_name + '.png'), mesh['texture_map'])
          with open(os.path.join(save_dir, save_name + '.mtl'), 'w') as wf:
              wf.write('newmtl material_0\n')
              wf.write('Ka 1.000000 0.000000 0.000000\n')
              wf.write('Kd 1.000000 1.000000 1.000000\n')
              wf.write('Ks 0.000000 0.000000 0.000000\n')
              wf.write('Tr 0.000000\n')
              wf.write('illum 0\n')
              wf.write('Ns 0.000000\n')
              wf.write('map_Kd {}\n'.format(save_name + '.png'))

      with open(save_path, 'w') as wf:
          if has_texture:
              wf.write('# Create by ModelScope\n')
              wf.write('mtllib ./{}.mtl\n'.format(save_name))

          if 'colors' in mesh:
              colors = mesh['colors']
              for i, v in enumerate(mesh['vertices']):
                  wf.write('v {} {} {} {} {} {}\n'.format(
                      v[0], v[1], v[2], colors[i][0], colors[i][1],
                      colors[i][2]))
          else:
              for v in mesh['vertices']:
                  wf.write('v {} {} {}\n'.format(v[0], v[1], v[2]))

          if 'uvs' in mesh:
              for uv in mesh['uvs']:
                  wf.write('vt {} {}\n'.format(uv[0], uv[1]))

          if 'normals' in mesh:
              for vn in mesh['normals']:
                  wf.write('vn {} {} {}\n'.format(vn[0], vn[1], vn[2]))

          if 'faces' in mesh:
              has_face_uv = 'faces_uv' in mesh
              has_face_normal = 'faces_normal' in mesh
              # A slot is written when explicit indices exist, or when the
              # attribute array exists and can share the vertex indices.
              write_uv = has_face_uv or 'uvs' in mesh
              write_normal = has_face_normal or 'normals' in mesh

              for ind, face in enumerate(mesh['faces']):
                  if not (has_face_uv or has_face_normal):
                      corners = ['{}'.format(v) for v in face]
                  else:
                      face_uv = mesh['faces_uv'][ind] if has_face_uv else face
                      face_normal = (
                          mesh['faces_normal'][ind]
                          if has_face_normal else face)
                      if write_uv and write_normal:
                          corners = [
                              '{}/{}/{}'.format(v, face_uv[i], face_normal[i])
                              for i, v in enumerate(face)
                          ]
                      elif write_uv:
                          corners = [
                              '{}/{}'.format(v, face_uv[i])
                              for i, v in enumerate(face)
                          ]
                      else:
                          corners = [
                              '{}//{}'.format(v, face_normal[i])
                              for i, v in enumerate(face)
                          ]
                  wf.write('f ' + ' '.join(corners) + '\n')
  def projection(x=0.1, n=1.0, f=50.0):
      """Build a 4x4 float32 projection matrix.

      Args:
          x: divisor of ``n`` on the first two diagonal entries (presumably
              the half extent of the view volume at the near plane - confirm
              against callers).
          n: near value used in the depth terms.
          f: far value used in the depth terms.

      Returns:
          float32 array of shape (4, 4).
      """
      scale = n / x
      depth = f - n
      rows = [
          [scale, 0, 0, 0],
          [0, scale, 0, 0],
          [0, 0, -(f + n) / depth, -(2 * f * n) / depth],
          [0, 0, -1, 0],
      ]
      return np.asarray(rows).astype(np.float32)
  def translate(x, y, z):
      """Return the 4x4 float32 homogeneous translation matrix for (x, y, z).

      The offsets are placed in the last column of an identity matrix.
      """
      rows = [
          [1, 0, 0, x],
          [0, 1, 0, y],
          [0, 0, 1, z],
          [0, 0, 0, 1],
      ]
      return np.asarray(rows).astype(np.float32)
  def rotate_x(a):
      """Return a 4x4 float32 matrix rotating about the x axis by ``a`` radians.

      The sine sits at row 1, column 2 and its negation at row 2, column 1.
      """
      sin_a = np.sin(a)
      cos_a = np.cos(a)
      rows = [
          [1, 0, 0, 0],
          [0, cos_a, sin_a, 0],
          [0, -sin_a, cos_a, 0],
          [0, 0, 0, 1],
      ]
      return np.asarray(rows).astype(np.float32)
  def rotate_y(a):
      """Return a 4x4 float32 matrix rotating about the y axis by ``a`` radians.

      The sine sits at row 0, column 2 and its negation at row 2, column 0.
      """
      sin_a = np.sin(a)
      cos_a = np.cos(a)
      rows = [
          [cos_a, 0, sin_a, 0],
          [0, 1, 0, 0],
          [-sin_a, 0, cos_a, 0],
          [0, 0, 0, 1],
      ]
      return np.asarray(rows).astype(np.float32)
  def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
      """Dot product over the last dimension, which is kept with size 1."""
      elementwise = x * y
      return elementwise.sum(dim=-1, keepdim=True)
  def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
      """Reflect ``x`` about ``n``: returns ``2 * <x, n> * n - x``.

      The inner product is taken over the last dimension.
      """
      # Same reduction as the module's dot(): sum over the last axis,
      # keeping it so the result broadcasts against n.
      x_dot_n = torch.sum(x * n, -1, keepdim=True)
      return 2 * x_dot_n * n - x
  def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
      """Euclidean length over the last dimension, which is kept with size 1.

      The squared length is clamped to at least ``eps`` before the square
      root, to avoid nan gradients because grad(sqrt(0)) = NaN.
      """
      # Same reduction as the module's dot(x, x).
      squared = torch.sum(x * x, -1, keepdim=True)
      clamped = torch.clamp(squared, min=eps)
      return torch.sqrt(clamped)
  def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
      """Divide ``x`` by its clamped length over the last dimension.

      ``eps`` is forwarded to ``length``, which clamps the squared length
      before the square root.
      """
      norm = length(x, eps)
      return x / norm
  def transform_pos(mtx, pos):
      """Apply a homogeneous transform to a set of positions.

      A column of ones is appended to ``pos`` and the result is multiplied by
      the transpose of ``mtx``. All work is done on the device that holds
      ``pos`` instead of unconditionally on CUDA.

      Args:
          mtx: 2-D transform as a numpy array or torch tensor.
          pos: 2-D tensor with one position per row.

      Returns:
          Tensor of shape (1, N, R), where N is the number of rows of
          ``pos`` and R the number of rows of ``mtx``.
      """
      t_mtx = torch.from_numpy(mtx) if isinstance(mtx, np.ndarray) else mtx
      t_mtx = t_mtx.to(pos.device)
      ones = torch.ones([pos.shape[0], 1], device=pos.device)
      posw = torch.cat([pos, ones], dim=1)
      # Leading axis of size 1 is added to the result.
      return torch.matmul(posw, t_mtx.t())[None, ...]
  def render(glctx, mtx, pos, pos_idx, uv, uv_idx, tex, resolution, enable_mip,
             max_mip_level):
      """Rasterize a textured mesh and return colour, mask and normal images.

      Args:
          glctx: nvdiffrast rasterization context passed to ``dr.rasterize``.
          mtx: transform handed to ``transform_pos`` together with ``pos``.
          pos: vertex positions, one per row.
          pos_idx: triangle vertex indices, one triangle per row.
          uv: texture coordinates interpolated over the triangles.
          uv_idx: triangle indices into ``uv``.
          tex: texture sampled with ``dr.texture``.
          resolution: side length of the square output image.
          enable_mip: when True, sample with 'linear-mipmap-linear'
              filtering, otherwise with plain 'linear' filtering.
          max_mip_level: forwarded to ``dr.texture`` when mip-mapping is on.

      Returns:
          Tuple ``(color, mask, normal)``. Pixels where the mask is 0 are set
          to 1 in both ``color`` and ``normal``. ``normal`` holds the
          flat per-face normals remapped from [-1, 1] to [0, 1].
      """
      pos_clip = transform_pos(mtx, pos)
      rast_out, rast_out_db = dr.rasterize(
          glctx, pos_clip, pos_idx, resolution=[resolution, resolution])

      if enable_mip:
          texc, texd = dr.interpolate(
              uv[None, ...],
              rast_out,
              uv_idx,
              rast_db=rast_out_db,
              diff_attrs='all')
          color = dr.texture(
              tex[None, ...],
              texc,
              texd,
              filter_mode='linear-mipmap-linear',
              max_mip_level=max_mip_level)
      else:
          texc, _ = dr.interpolate(uv[None, ...], rast_out, uv_idx)
          color = dr.texture(tex[None, ...], texc, filter_mode='linear')

      pos_idx = pos_idx.type(torch.long)
      v0 = pos[pos_idx[:, 0], :]
      v1 = pos[pos_idx[:, 1], :]
      v2 = pos[pos_idx[:, 2], :]
      # dim must be explicit: without it torch.cross picks the first
      # dimension of size 3, which is the face axis for a 3-face mesh.
      face_normals = safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))
      # Every corner of face i reads attribute row i, so the interpolated
      # value is constant over each triangle.
      face_normal_indices = torch.arange(
          0, face_normals.shape[0], dtype=torch.int32,
          device=pos.device)[:, None].repeat(1, 3)
      gb_geometric_normal, _ = dr.interpolate(face_normals[None, ...],
                                              rast_out, face_normal_indices)
      normal = (gb_geometric_normal + 1) * 0.5

      # Last rasterizer channel clamped to [0, 1] serves as coverage mask
      # (presumably triangle id + 1, 0 for empty pixels - see nvdiffrast docs).
      mask = torch.clamp(rast_out[..., -1:], 0, 1)
      color = color * mask + (1 - mask) * torch.ones_like(color)
      normal = normal * mask + (1 - mask) * torch.ones_like(normal)
      return color, mask, normal
  199. # The following code is based on https://github.com/Mathux/ACTOR.git
  200. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  201. # Check PYTORCH3D_LICENCE before use
  202. def _copysign(a, b):
  203. """
  204. Return a tensor where each element has the absolute value taken from the,
  205. corresponding element of a, with sign taken from the corresponding
  206. element of b. This is like the standard copysign floating-point operation,
  207. but is not careful about negative 0 and NaN.
  208. Args:
  209. a: source tensor.
  210. b: tensor whose signs will be used, of the same shape as a.
  211. Returns:
  212. Tensor of the same shape as a with the signs of b.
  213. """
  214. signs_differ = (a < 0) != (b < 0)
  215. return torch.where(signs_differ, -a, a)
  216. def _sqrt_positive_part(x):
  217. """
  218. Returns torch.sqrt(torch.max(0, x))
  219. but with a zero subgradient where x is 0.
  220. """
  221. ret = torch.zeros_like(x)
  222. positive_mask = x > 0
  223. ret[positive_mask] = torch.sqrt(x[positive_mask])
  224. return ret
  def matrix_to_quaternion(matrix):
      """
      Convert rotations given as rotation matrices to quaternions.

      Args:
          matrix: Rotation matrices as tensor of shape (..., 3, 3).

      Returns:
          quaternions with real part first, as tensor of shape (..., 4).

      Raises:
          ValueError: if the last two dimensions of ``matrix`` are not 3x3.
      """
      if matrix.size(-1) != 3 or matrix.size(-2) != 3:
          raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
      m00 = matrix[..., 0, 0]
      m11 = matrix[..., 1, 1]
      m22 = matrix[..., 2, 2]
      # Magnitudes of the four components come from the diagonal alone.
      o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
      x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
      y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
      z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
      # Signs of the vector part come from the antisymmetric off-diagonal
      # differences.
      # NOTE(review): for a symmetric matrix these differences are all 0, so
      # _copysign leaves x, y, z non-negative - confirm this is acceptable
      # for 180-degree rotations.
      o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
      o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
      o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
      return torch.stack((o0, o1, o2, o3), -1)
  def quaternion_to_axis_angle(quaternions):
      """
      Convert rotations given as quaternions to axis/angle.

      Args:
          quaternions: quaternions with real part first,
              as tensor of shape (..., 4).

      Returns:
          Rotations given as a vector in axis angle form, as a tensor
              of shape (..., 3), where the magnitude is the angle
              turned anticlockwise in radians around the vector's
              direction.
      """
      real_part = quaternions[..., :1]
      vector_part = quaternions[..., 1:]
      vector_norm = torch.norm(vector_part, p=2, dim=-1, keepdim=True)
      half_angle = torch.atan2(vector_norm, real_part)
      angle = 2 * half_angle

      eps = 1e-6
      is_small = angle.abs() < eps
      is_regular = ~is_small

      # ratio holds sin(angle / 2) / angle for every rotation.
      ratio = torch.empty_like(angle)
      ratio[is_regular] = (
          torch.sin(half_angle[is_regular]) / angle[is_regular])
      # for x small, sin(x/2) is about x/2 - (x/2)^3/6
      # so sin(x/2)/x is about 1/2 - (x*x)/48
      tiny = angle[is_small]
      ratio[is_small] = 0.5 - (tiny * tiny) / 48
      return vector_part / ratio
  def matrix_to_axis_angle(matrix):
      """
      Convert rotations given as rotation matrices to axis/angle.

      The matrix is first converted to a quaternion, which is then converted
      to axis/angle form.

      Args:
          matrix: Rotation matrices as tensor of shape (..., 3, 3).

      Returns:
          Rotations given as a vector in axis angle form, as a tensor
              of shape (..., 3), where the magnitude is the angle
              turned anticlockwise in radians around the vector's
              direction.
      """
      quaternions = matrix_to_quaternion(matrix)
      return quaternion_to_axis_angle(quaternions)
  def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
      """
      Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
      using Gram--Schmidt orthogonalisation per Section B of [1].

      Args:
          d6: 6D rotation representation, of size (*, 6)

      Returns:
          batch of rotation matrices of size (*, 3, 3)

      [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
      On the Continuity of Rotation Representations in Neural Networks.
      IEEE Conference on Computer Vision and Pattern Recognition, 2019.
      Retrieved from http://arxiv.org/abs/1812.07035
      """
      first = d6[..., :3]
      second = d6[..., 3:]
      # First row: the normalised first vector.
      row0 = F.normalize(first, dim=-1)
      # Second row: remove the component along row0, then normalise.
      along_row0 = (row0 * second).sum(-1, keepdim=True)
      row1 = F.normalize(second - along_row0 * row0, dim=-1)
      # Third row: cross product of the first two.
      row2 = torch.cross(row0, row1, dim=-1)
      return torch.stack((row0, row1, row2), dim=-2)
  299. b2 = F.normalize(b2, dim=-1)
  300. b3 = torch.cross(b1, b2, dim=-1)
  301. return torch.stack((b1, b2, b3), dim=-2)