frame.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
  2. # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. from __future__ import annotations # noqa
  4. from typing import Any, Callable, Iterable, Optional, Sequence, Tuple
  5. import numpy as np
  6. import torch
  7. def zero_translation(
  8. batch_dims: Tuple[int],
  9. dtype: Optional[torch.dtype] = torch.float,
  10. device: Optional[torch.device] = torch.device('cpu'),
  11. requires_grad: bool = False,
  12. ) -> torch.Tensor:
  13. trans = torch.zeros((*batch_dims, 3),
  14. dtype=dtype,
  15. device=device,
  16. requires_grad=requires_grad)
  17. return trans
  18. # pylint: disable=bad-whitespace
  19. _QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)
  20. _QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr
  21. _QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii
  22. _QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj
  23. _QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk
  24. _QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij
  25. _QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik
  26. _QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk
  27. _QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir
  28. _QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr
  29. _QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr
  30. _QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9)
  31. _QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT)
  32. _QUAT_MULTIPLY = np.zeros((4, 4, 4))
  33. _QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0],
  34. [0, 0, 0, -1]]
  35. _QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1],
  36. [0, 0, -1, 0]]
  37. _QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0],
  38. [0, 1, 0, 0]]
  39. _QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0],
  40. [1, 0, 0, 0]]
  41. _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
  42. _QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC)
  43. class Rotation:
  44. def __init__(
  45. self,
  46. mat: torch.Tensor,
  47. ):
  48. if mat.shape[-2:] != (3, 3):
  49. raise ValueError(f'incorrect rotation shape: {mat.shape}')
  50. self._mat = mat
  51. @staticmethod
  52. def identity(
  53. shape,
  54. dtype: Optional[torch.dtype] = torch.float,
  55. device: Optional[torch.device] = torch.device('cpu'),
  56. requires_grad: bool = False,
  57. ) -> Rotation:
  58. mat = torch.eye(
  59. 3, dtype=dtype, device=device, requires_grad=requires_grad)
  60. mat = mat.view(*((1, ) * len(shape)), 3, 3)
  61. mat = mat.expand(*shape, -1, -1)
  62. return Rotation(mat)
  63. @staticmethod
  64. def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
  65. return (a.float() @ b.float()).type(a.dtype)
  66. @staticmethod
  67. def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
  68. return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype)
  69. def __getitem__(self, index: Any) -> Rotation:
  70. if not isinstance(index, tuple):
  71. index = (index, )
  72. return Rotation(mat=self._mat[index + (slice(None), slice(None))])
  73. def __mul__(self, right: Any) -> Rotation:
  74. if isinstance(right, (int, float)):
  75. return Rotation(mat=self._mat * right)
  76. elif isinstance(right, torch.Tensor):
  77. return Rotation(mat=self._mat * right[..., None, None])
  78. else:
  79. raise TypeError(
  80. f'multiplicand must be a tensor or a number, got {type(right)}.'
  81. )
  82. def __rmul__(self, left: Any) -> Rotation:
  83. return self.__mul__(left)
  84. def __matmul__(self, other: Rotation) -> Rotation:
  85. new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat)
  86. return Rotation(mat=new_mat)
  87. @property
  88. def _inv_mat(self):
  89. return self._mat.transpose(-1, -2)
  90. @property
  91. def rot_mat(self) -> torch.Tensor:
  92. return self._mat
  93. def invert(self) -> Rotation:
  94. return Rotation(mat=self._inv_mat)
  95. def apply(self, pts: torch.Tensor) -> torch.Tensor:
  96. return Rotation.mat_mul_vec(self._mat, pts)
  97. def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
  98. return Rotation.mat_mul_vec(self._inv_mat, pts)
  99. # inherit tensor behaviors
  100. @property
  101. def shape(self) -> torch.Size:
  102. s = self._mat.shape[:-2]
  103. return s
  104. @property
  105. def dtype(self) -> torch.dtype:
  106. return self._mat.dtype
  107. @property
  108. def device(self) -> torch.device:
  109. return self._mat.device
  110. @property
  111. def requires_grad(self) -> bool:
  112. return self._mat.requires_grad
  113. def unsqueeze(self, dim: int) -> Rotation:
  114. if dim >= len(self.shape):
  115. raise ValueError('Invalid dimension')
  116. rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2)
  117. return Rotation(mat=rot_mats)
  118. def map_tensor_fn(self, fn: Callable[[torch.Tensor],
  119. torch.Tensor]) -> Rotation:
  120. mat = self._mat.view(self._mat.shape[:-2] + (9, ))
  121. mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1)
  122. mat = mat.view(mat.shape[:-1] + (3, 3))
  123. return Rotation(mat=mat)
  124. @staticmethod
  125. def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
  126. rot_mats = [r.rot_mat for r in rs]
  127. rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
  128. return Rotation(mat=rot_mats)
  129. def cuda(self) -> Rotation:
  130. return Rotation(mat=self._mat.cuda())
  131. def to(self, device: Optional[torch.device],
  132. dtype: Optional[torch.dtype]) -> Rotation:
  133. return Rotation(mat=self._mat.to(device=device, dtype=dtype))
  134. def type(self, dtype: Optional[torch.dtype]) -> Rotation:
  135. return Rotation(mat=self._mat.type(dtype))
  136. def detach(self) -> Rotation:
  137. return Rotation(mat=self._mat.detach())
  138. class Frame:
  139. def __init__(
  140. self,
  141. rotation: Optional[Rotation],
  142. translation: Optional[torch.Tensor],
  143. ):
  144. if rotation is None and translation is None:
  145. rotation = Rotation.identity((0, ))
  146. translation = zero_translation((0, ))
  147. elif translation is None:
  148. translation = zero_translation(rotation.shape, rotation.dtype,
  149. rotation.device,
  150. rotation.requires_grad)
  151. elif rotation is None:
  152. rotation = Rotation.identity(
  153. translation.shape[:-1],
  154. translation.dtype,
  155. translation.device,
  156. translation.requires_grad,
  157. )
  158. if (rotation.shape != translation.shape[:-1]) or (rotation.device
  159. != # noqa W504
  160. translation.device):
  161. raise ValueError('RotationMatrix and translation incompatible')
  162. self._r = rotation
  163. self._t = translation
  164. @staticmethod
  165. def identity(
  166. shape: Iterable[int],
  167. dtype: Optional[torch.dtype] = torch.float,
  168. device: Optional[torch.device] = torch.device('cpu'),
  169. requires_grad: bool = False,
  170. ) -> Frame:
  171. return Frame(
  172. Rotation.identity(shape, dtype, device, requires_grad),
  173. zero_translation(shape, dtype, device, requires_grad),
  174. )
  175. def __getitem__(
  176. self,
  177. index: Any,
  178. ) -> Frame:
  179. if type(index) != tuple:
  180. index = (index, )
  181. return Frame(
  182. self._r[index],
  183. self._t[index + (slice(None), )],
  184. )
  185. def __mul__(
  186. self,
  187. right: torch.Tensor,
  188. ) -> Frame:
  189. if not (isinstance(right, torch.Tensor)):
  190. raise TypeError('The other multiplicand must be a Tensor')
  191. new_rots = self._r * right
  192. new_trans = self._t * right[..., None]
  193. return Frame(new_rots, new_trans)
  194. def __rmul__(
  195. self,
  196. left: torch.Tensor,
  197. ) -> Frame:
  198. return self.__mul__(left)
  199. @property
  200. def shape(self) -> torch.Size:
  201. s = self._t.shape[:-1]
  202. return s
  203. @property
  204. def device(self) -> torch.device:
  205. return self._t.device
  206. def get_rots(self) -> Rotation:
  207. return self._r
  208. def get_trans(self) -> torch.Tensor:
  209. return self._t
  210. def compose(
  211. self,
  212. other: Frame,
  213. ) -> Frame:
  214. new_rot = self._r @ other._r
  215. new_trans = self._r.apply(other._t) + self._t
  216. return Frame(new_rot, new_trans)
  217. def apply(
  218. self,
  219. pts: torch.Tensor,
  220. ) -> torch.Tensor:
  221. rotated = self._r.apply(pts)
  222. return rotated + self._t
  223. def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
  224. pts = pts - self._t
  225. return self._r.invert_apply(pts)
  226. def invert(self) -> Frame:
  227. rot_inv = self._r.invert()
  228. trn_inv = rot_inv.apply(self._t)
  229. return Frame(rot_inv, -1 * trn_inv)
  230. def map_tensor_fn(self, fn: Callable[[torch.Tensor],
  231. torch.Tensor]) -> Frame:
  232. new_rots = self._r.map_tensor_fn(fn)
  233. new_trans = torch.stack(
  234. list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1)
  235. return Frame(new_rots, new_trans)
  236. def to_tensor_4x4(self) -> torch.Tensor:
  237. tensor = self._t.new_zeros((*self.shape, 4, 4))
  238. tensor[..., :3, :3] = self._r.rot_mat
  239. tensor[..., :3, 3] = self._t
  240. tensor[..., 3, 3] = 1
  241. return tensor
  242. @staticmethod
  243. def from_tensor_4x4(t: torch.Tensor) -> Frame:
  244. if t.shape[-2:] != (4, 4):
  245. raise ValueError('Incorrectly shaped input tensor')
  246. rots = Rotation(mat=t[..., :3, :3])
  247. trans = t[..., :3, 3]
  248. return Frame(rots, trans)
  249. @staticmethod
  250. def from_3_points(
  251. p_neg_x_axis: torch.Tensor,
  252. origin: torch.Tensor,
  253. p_xy_plane: torch.Tensor,
  254. eps: float = 1e-8,
  255. ) -> Frame:
  256. p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
  257. origin = torch.unbind(origin, dim=-1)
  258. p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
  259. e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
  260. e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
  261. denom = torch.sqrt(sum((c * c for c in e0)) + eps)
  262. e0 = [c / denom for c in e0]
  263. dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
  264. e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
  265. denom = torch.sqrt(sum((c * c for c in e1)) + eps)
  266. e1 = [c / denom for c in e1]
  267. e2 = [
  268. e0[1] * e1[2] - e0[2] * e1[1],
  269. e0[2] * e1[0] - e0[0] * e1[2],
  270. e0[0] * e1[1] - e0[1] * e1[0],
  271. ]
  272. rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
  273. rots = rots.reshape(rots.shape[:-1] + (3, 3))
  274. rot_obj = Rotation(mat=rots)
  275. return Frame(rot_obj, torch.stack(origin, dim=-1))
  276. def unsqueeze(
  277. self,
  278. dim: int,
  279. ) -> Frame:
  280. if dim >= len(self.shape):
  281. raise ValueError('Invalid dimension')
  282. rots = self._r.unsqueeze(dim)
  283. trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1)
  284. return Frame(rots, trans)
  285. @staticmethod
  286. def cat(
  287. Ts: Sequence[Frame],
  288. dim: int,
  289. ) -> Frame:
  290. rots = Rotation.cat([T._r for T in Ts], dim)
  291. trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1)
  292. return Frame(rots, trans)
  293. def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame:
  294. return Frame(fn(self._r), self._t)
  295. def apply_trans_fn(self, fn: Callable[[torch.Tensor],
  296. torch.Tensor]) -> Frame:
  297. return Frame(self._r, fn(self._t))
  298. def scale_translation(self, trans_scale_factor: float) -> Frame:
  299. # fn = lambda t: t * trans_scale_factor
  300. def fn(t):
  301. return t * trans_scale_factor
  302. return self.apply_trans_fn(fn)
  303. def stop_rot_gradient(self) -> Frame:
  304. # fn = lambda r: r.detach()
  305. def fn(r):
  306. return r.detach()
  307. return self.apply_rot_fn(fn)
  308. @staticmethod
  309. def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
  310. input_dtype = ca_xyz.dtype
  311. n_xyz = n_xyz.float()
  312. ca_xyz = ca_xyz.float()
  313. c_xyz = c_xyz.float()
  314. n_xyz = n_xyz - ca_xyz
  315. c_xyz = c_xyz - ca_xyz
  316. c_x, c_y, d_pair = [c_xyz[..., i] for i in range(3)]
  317. norm = torch.sqrt(eps + c_x**2 + c_y**2)
  318. sin_c1 = -c_y / norm
  319. cos_c1 = c_x / norm
  320. c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
  321. c1_rots[..., 0, 0] = cos_c1
  322. c1_rots[..., 0, 1] = -1 * sin_c1
  323. c1_rots[..., 1, 0] = sin_c1
  324. c1_rots[..., 1, 1] = cos_c1
  325. c1_rots[..., 2, 2] = 1
  326. norm = torch.sqrt(eps + c_x**2 + c_y**2 + d_pair**2)
  327. sin_c2 = d_pair / norm
  328. cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm
  329. c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
  330. c2_rots[..., 0, 0] = cos_c2
  331. c2_rots[..., 0, 2] = sin_c2
  332. c2_rots[..., 1, 1] = 1
  333. c2_rots[..., 2, 0] = -1 * sin_c2
  334. c2_rots[..., 2, 2] = cos_c2
  335. c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots)
  336. n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz)
  337. _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
  338. norm = torch.sqrt(eps + n_y**2 + n_z**2)
  339. sin_n = -n_z / norm
  340. cos_n = n_y / norm
  341. n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
  342. n_rots[..., 0, 0] = 1
  343. n_rots[..., 1, 1] = cos_n
  344. n_rots[..., 1, 2] = -1 * sin_n
  345. n_rots[..., 2, 1] = sin_n
  346. n_rots[..., 2, 2] = cos_n
  347. rots = Rotation.mat_mul_mat(n_rots, c_rots)
  348. rots = rots.transpose(-1, -2)
  349. rot_obj = Rotation(mat=rots.type(input_dtype))
  350. return Frame(rot_obj, ca_xyz.type(input_dtype))
  351. def cuda(self) -> Frame:
  352. return Frame(self._r.cuda(), self._t.cuda())
  353. @property
  354. def dtype(self) -> torch.dtype:
  355. assert self._r.dtype == self._t.dtype
  356. return self._r.dtype
  357. def type(self, dtype) -> Frame:
  358. return Frame(self._r.type(dtype), self._t.type(dtype))
  359. class Quaternion:
  360. def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor):
  361. if quaternion.shape[-1] != 4:
  362. raise ValueError(f'incorrect quaternion shape: {quaternion.shape}')
  363. self._q = quaternion
  364. self._t = translation
  365. @staticmethod
  366. def identity(
  367. shape: Iterable[int],
  368. dtype: Optional[torch.dtype] = torch.float,
  369. device: Optional[torch.device] = torch.device('cpu'),
  370. requires_grad: bool = False,
  371. ) -> Quaternion:
  372. trans = zero_translation(shape, dtype, device, requires_grad)
  373. quats = torch.zeros((*shape, 4),
  374. dtype=dtype,
  375. device=device,
  376. requires_grad=requires_grad)
  377. with torch.no_grad():
  378. quats[..., 0] = 1
  379. return Quaternion(quats, trans)
  380. def get_quats(self):
  381. return self._q
  382. def get_trans(self):
  383. return self._t
  384. def get_rot_mats(self):
  385. quats = self.get_quats()
  386. rot_mats = Quaternion.quat_to_rot(quats)
  387. return rot_mats
  388. @staticmethod
  389. def quat_to_rot(normalized_quat):
  390. global _QUAT_TO_ROT_tensor
  391. dtype = normalized_quat.dtype
  392. normalized_quat = normalized_quat.float()
  393. if _QUAT_TO_ROT_tensor.device != normalized_quat.device:
  394. _QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to(
  395. normalized_quat.device)
  396. rot_tensor = torch.sum(
  397. _QUAT_TO_ROT_tensor * normalized_quat[..., :, None, None]
  398. * normalized_quat[..., None, :, None],
  399. dim=(-3, -2),
  400. )
  401. rot_tensor = rot_tensor.type(dtype)
  402. rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3)
  403. return rot_tensor
  404. @staticmethod
  405. def normalize_quat(quats):
  406. dtype = quats.dtype
  407. quats = quats.float()
  408. quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
  409. quats = quats.type(dtype)
  410. return quats
  411. @staticmethod
  412. def quat_multiply_by_vec(quat, vec):
  413. dtype = quat.dtype
  414. quat = quat.float()
  415. vec = vec.float()
  416. global _QUAT_MULTIPLY_BY_VEC_tensor
  417. if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device:
  418. _QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to(
  419. quat.device)
  420. mat = _QUAT_MULTIPLY_BY_VEC_tensor
  421. reshaped_mat = mat.view((1, ) * len(quat.shape[:-1]) + mat.shape)
  422. return torch.sum(
  423. reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None],
  424. dim=(-3, -2),
  425. ).type(dtype)
  426. def compose_q_update_vec(self,
  427. q_update_vec: torch.Tensor,
  428. normalize_quats: bool = True) -> torch.Tensor:
  429. quats = self.get_quats()
  430. new_quats = quats + Quaternion.quat_multiply_by_vec(
  431. quats, q_update_vec)
  432. if normalize_quats:
  433. new_quats = Quaternion.normalize_quat(new_quats)
  434. return new_quats
  435. def compose_update_vec(
  436. self,
  437. update_vec: torch.Tensor,
  438. pre_rot_mat: Rotation,
  439. ) -> Quaternion:
  440. q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:]
  441. new_quats = self.compose_q_update_vec(q_vec)
  442. trans_update = pre_rot_mat.apply(t_vec)
  443. new_trans = self._t + trans_update
  444. return Quaternion(new_quats, new_trans)
  445. def stop_rot_gradient(self) -> Quaternion:
  446. return Quaternion(self._q.detach(), self._t)