featurization.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
# The Uni-Fold implementation is also open-sourced by the authors under the
# Apache-2.0 license, and is publicly available at
# https://github.com/dptech-corp/Uni-Fold.
  3. from typing import Dict
  4. import torch
  5. import torch.nn as nn
  6. from unicore.utils import batched_gather, one_hot
  7. from modelscope.models.science.unifold.data import residue_constants as rc
  8. from .frame import Frame
  9. def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
  10. is_gly = aatype == rc.restype_order['G']
  11. ca_idx = rc.atom_order['CA']
  12. cb_idx = rc.atom_order['CB']
  13. pseudo_beta = torch.where(
  14. is_gly[..., None].expand(*((-1, ) * len(is_gly.shape)), 3),
  15. all_atom_positions[..., ca_idx, :],
  16. all_atom_positions[..., cb_idx, :],
  17. )
  18. if all_atom_masks is not None:
  19. pseudo_beta_mask = torch.where(
  20. is_gly,
  21. all_atom_masks[..., ca_idx],
  22. all_atom_masks[..., cb_idx],
  23. )
  24. return pseudo_beta, pseudo_beta_mask
  25. else:
  26. return pseudo_beta
  27. def atom14_to_atom37(atom14, batch):
  28. atom37_data = batched_gather(
  29. atom14,
  30. batch['residx_atom37_to_atom14'],
  31. dim=-2,
  32. num_batch_dims=len(atom14.shape[:-2]),
  33. )
  34. atom37_data = atom37_data * batch['atom37_atom_exists'][..., None]
  35. return atom37_data
  36. def build_template_angle_feat(template_feats, v2_feature=False):
  37. template_aatype = template_feats['template_aatype']
  38. torsion_angles_sin_cos = template_feats['template_torsion_angles_sin_cos']
  39. torsion_angles_mask = template_feats['template_torsion_angles_mask']
  40. if not v2_feature:
  41. alt_torsion_angles_sin_cos = template_feats[
  42. 'template_alt_torsion_angles_sin_cos']
  43. template_angle_feat = torch.cat(
  44. [
  45. one_hot(template_aatype, 22),
  46. torsion_angles_sin_cos.reshape(
  47. *torsion_angles_sin_cos.shape[:-2], 14),
  48. alt_torsion_angles_sin_cos.reshape(
  49. *alt_torsion_angles_sin_cos.shape[:-2], 14),
  50. torsion_angles_mask,
  51. ],
  52. dim=-1,
  53. )
  54. template_angle_mask = torsion_angles_mask[..., 2]
  55. else:
  56. chi_mask = torsion_angles_mask[..., 3:]
  57. chi_angles_sin = torsion_angles_sin_cos[..., 3:, 0] * chi_mask
  58. chi_angles_cos = torsion_angles_sin_cos[..., 3:, 1] * chi_mask
  59. template_angle_feat = torch.cat(
  60. [
  61. one_hot(template_aatype, 22),
  62. chi_angles_sin,
  63. chi_angles_cos,
  64. chi_mask,
  65. ],
  66. dim=-1,
  67. )
  68. template_angle_mask = chi_mask[..., 0]
  69. return template_angle_feat, template_angle_mask
  70. def build_template_pair_feat(
  71. batch,
  72. min_bin,
  73. max_bin,
  74. num_bins,
  75. eps=1e-20,
  76. inf=1e8,
  77. ):
  78. template_mask = batch['template_pseudo_beta_mask']
  79. template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
  80. tpb = batch['template_pseudo_beta']
  81. dgram = torch.sum(
  82. (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True)
  83. lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2
  84. upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
  85. dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
  86. to_concat = [dgram, template_mask_2d[..., None]]
  87. aatype_one_hot = nn.functional.one_hot(
  88. batch['template_aatype'],
  89. rc.restype_num + 2,
  90. )
  91. n_res = batch['template_aatype'].shape[-1]
  92. to_concat.append(aatype_one_hot[..., None, :, :].expand(
  93. *aatype_one_hot.shape[:-2], n_res, -1, -1))
  94. to_concat.append(aatype_one_hot[...,
  95. None, :].expand(*aatype_one_hot.shape[:-2],
  96. -1, n_res, -1))
  97. to_concat.append(template_mask_2d.new_zeros(*template_mask_2d.shape, 3))
  98. to_concat.append(template_mask_2d[..., None])
  99. act = torch.cat(to_concat, dim=-1)
  100. act = act * template_mask_2d[..., None]
  101. return act
  102. def build_template_pair_feat_v2(
  103. batch,
  104. min_bin,
  105. max_bin,
  106. num_bins,
  107. multichain_mask_2d=None,
  108. eps=1e-20,
  109. inf=1e8,
  110. ):
  111. template_mask = batch['template_pseudo_beta_mask']
  112. template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
  113. if multichain_mask_2d is not None:
  114. template_mask_2d *= multichain_mask_2d
  115. tpb = batch['template_pseudo_beta']
  116. dgram = torch.sum(
  117. (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True)
  118. lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2
  119. upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
  120. dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
  121. dgram *= template_mask_2d[..., None]
  122. to_concat = [dgram, template_mask_2d[..., None]]
  123. aatype_one_hot = one_hot(
  124. batch['template_aatype'],
  125. rc.restype_num + 2,
  126. )
  127. n_res = batch['template_aatype'].shape[-1]
  128. to_concat.append(aatype_one_hot[..., None, :, :].expand(
  129. *aatype_one_hot.shape[:-2], n_res, -1, -1))
  130. to_concat.append(aatype_one_hot[...,
  131. None, :].expand(*aatype_one_hot.shape[:-2],
  132. -1, n_res, -1))
  133. n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]
  134. rigids = Frame.make_transform_from_reference(
  135. n_xyz=batch['template_all_atom_positions'][..., n, :],
  136. ca_xyz=batch['template_all_atom_positions'][..., ca, :],
  137. c_xyz=batch['template_all_atom_positions'][..., c, :],
  138. eps=eps,
  139. )
  140. points = rigids.get_trans()[..., None, :, :]
  141. rigid_vec = rigids[..., None].invert_apply(points)
  142. inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))
  143. t_aa_masks = batch['template_all_atom_mask']
  144. backbone_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[...,
  145. c]
  146. backbone_mask_2d = backbone_mask[..., :, None] * backbone_mask[...,
  147. None, :]
  148. if multichain_mask_2d is not None:
  149. backbone_mask_2d *= multichain_mask_2d
  150. inv_distance_scalar = inv_distance_scalar * backbone_mask_2d
  151. unit_vector_data = rigid_vec * inv_distance_scalar[..., None]
  152. to_concat.extend(torch.unbind(unit_vector_data[..., None, :], dim=-1))
  153. to_concat.append(backbone_mask_2d[..., None])
  154. return to_concat
  155. def build_extra_msa_feat(batch):
  156. msa_1hot = one_hot(batch['extra_msa'], 23)
  157. msa_feat = [
  158. msa_1hot,
  159. batch['extra_msa_has_deletion'].unsqueeze(-1),
  160. batch['extra_msa_deletion_value'].unsqueeze(-1),
  161. ]
  162. return torch.cat(msa_feat, dim=-1)