utils.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import os
  2. import mcubes
  3. import numpy as np
  4. import torch
  5. def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
  6. file = open(mesh_path, 'w')
  7. for idx, v in enumerate(verts):
  8. c = colors[idx]
  9. file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
  10. (v[0], v[1], v[2], c[0], c[1], c[2]))
  11. for f in faces:
  12. f_plus = f + 1
  13. file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
  14. file.close()
  15. def save_obj_mesh(mesh_path, verts, faces):
  16. file = open(mesh_path, 'w')
  17. for idx, v in enumerate(verts):
  18. file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
  19. for f in faces:
  20. f_plus = f + 1
  21. file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
  22. file.close()
  23. def to_tensor(img):
  24. if len(img.shape) == 2:
  25. img = img[:, :, np.newaxis]
  26. img = torch.from_numpy(img.transpose(2, 0, 1)).float()
  27. img = img / 255.
  28. return img
  29. def reconstruction(net, calib_tensor, coords, mat, num_samples=50000):
  30. def eval_func(points):
  31. points = np.expand_dims(points, axis=0)
  32. points = np.repeat(points, 1, axis=0)
  33. samples = torch.from_numpy(points).cuda().float()
  34. net.query(samples, calib_tensor)
  35. pred = net.get_preds()
  36. pred = pred[0]
  37. return pred.detach().cpu().numpy()
  38. sdf = eval_grid(coords, eval_func, num_samples=num_samples)
  39. vertices, faces = mcubes.marching_cubes(sdf, 0.5)
  40. verts = np.matmul(mat[:3, :3], vertices.T) + mat[:3, 3:4]
  41. verts = verts.T
  42. return verts, faces
  43. def keep_largest(mesh_big):
  44. mesh_lst = mesh_big.split(only_watertight=False)
  45. keep_mesh = mesh_lst[0]
  46. for mesh in mesh_lst:
  47. if mesh.vertices.shape[0] > keep_mesh.vertices.shape[0]:
  48. keep_mesh = mesh
  49. return keep_mesh
  50. def eval_grid(coords,
  51. eval_func,
  52. init_resolution=64,
  53. threshold=0.01,
  54. num_samples=512 * 512 * 512):
  55. resolution = coords.shape[1:4]
  56. sdf = np.zeros(resolution)
  57. dirty = np.ones(resolution, dtype=bool)
  58. grid_mask = np.zeros(resolution, dtype=bool)
  59. reso = resolution[0] // init_resolution
  60. while reso > 0:
  61. grid_mask[0:resolution[0]:reso, 0:resolution[1]:reso,
  62. 0:resolution[2]:reso] = True
  63. test_mask = np.logical_and(grid_mask, dirty)
  64. points = coords[:, test_mask]
  65. sdf[test_mask] = batch_eval(points, eval_func, num_samples=num_samples)
  66. dirty[test_mask] = False
  67. if reso <= 1:
  68. break
  69. for x in range(0, resolution[0] - reso, reso):
  70. for y in range(0, resolution[1] - reso, reso):
  71. for z in range(0, resolution[2] - reso, reso):
  72. if not dirty[x + reso // 2, y + reso // 2, z + reso // 2]:
  73. continue
  74. v0 = sdf[x, y, z]
  75. v1 = sdf[x, y, z + reso]
  76. v2 = sdf[x, y + reso, z]
  77. v3 = sdf[x, y + reso, z + reso]
  78. v4 = sdf[x + reso, y, z]
  79. v5 = sdf[x + reso, y, z + reso]
  80. v6 = sdf[x + reso, y + reso, z]
  81. v7 = sdf[x + reso, y + reso, z + reso]
  82. v = np.array([v0, v1, v2, v3, v4, v5, v6, v7])
  83. v_min = v.min()
  84. v_max = v.max()
  85. if (v_max - v_min) < threshold:
  86. sdf[x:x + reso, y:y + reso,
  87. z:z + reso] = (v_max + v_min) / 2
  88. dirty[x:x + reso, y:y + reso, z:z + reso] = False
  89. reso //= 2
  90. return sdf.reshape(resolution)
  91. def batch_eval(points, eval_func, num_samples=512 * 512 * 512):
  92. num_pts = points.shape[1]
  93. sdf = np.zeros(num_pts)
  94. num_batches = num_pts // num_samples
  95. for i in range(num_batches):
  96. sdf[i * num_samples:i * num_samples + num_samples] = eval_func(
  97. points[:, i * num_samples:i * num_samples + num_samples])
  98. if num_pts % num_samples:
  99. sdf[num_batches * num_samples:] = eval_func(points[:, num_batches
  100. * num_samples:])
  101. return sdf
  102. def create_grid(res,
  103. b_min=np.array([0, 0, 0]),
  104. b_max=np.array([1, 1, 1]),
  105. transform=None):
  106. coords = np.mgrid[:res, :res, :res]
  107. coords = coords.reshape(3, -1)
  108. coords_matrix = np.eye(4)
  109. length = b_max - b_min
  110. coords_matrix[0, 0] = length[0] / res
  111. coords_matrix[1, 1] = length[1] / res
  112. coords_matrix[2, 2] = length[2] / res
  113. coords_matrix[0:3, 3] = b_min
  114. coords = np.matmul(coords_matrix[:3, :3], coords) + coords_matrix[:3, 3:4]
  115. if transform is not None:
  116. coords = np.matmul(transform[:3, :3], coords) + transform[:3, 3:4]
  117. coords_matrix = np.matmul(transform, coords_matrix)
  118. coords = coords.reshape(3, res, res, res)
  119. return coords, coords_matrix
  120. def get_submesh(verts,
  121. faces,
  122. color,
  123. verts_retained=None,
  124. faces_retained=None,
  125. min_vert_in_face=2):
  126. verts = verts
  127. faces = faces
  128. colors = color
  129. if verts_retained is not None:
  130. if verts_retained.dtype != 'bool':
  131. vert_mask = np.zeros(len(verts), dtype=bool)
  132. vert_mask[verts_retained] = True
  133. else:
  134. vert_mask = verts_retained
  135. bool_faces = np.sum(
  136. vert_mask[faces.ravel()].reshape(-1, 3), axis=1) > min_vert_in_face
  137. elif faces_retained is not None:
  138. if faces_retained.dtype != 'bool':
  139. bool_faces = np.zeros(len(faces_retained), dtype=bool)
  140. else:
  141. bool_faces = faces_retained
  142. new_faces = faces[bool_faces]
  143. vertex_ids = list(set(new_faces.ravel()))
  144. oldtonew = -1 * np.ones([len(verts)])
  145. oldtonew[vertex_ids] = range(0, len(vertex_ids))
  146. new_verts = verts[vertex_ids]
  147. new_colors = colors[vertex_ids]
  148. new_faces = oldtonew[new_faces].astype('int32')
  149. return (new_verts, new_faces, new_colors, bool_faces, vertex_ids)