confidence.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
  2. # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. from typing import Dict, Optional, Tuple
  4. import torch
  5. def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor:
  6. """Computes per-residue pLDDT from logits.
  7. Args:
  8. logits: [num_res, num_bins] output from the PredictedLDDTHead.
  9. Returns:
  10. plddt: [num_res] per-residue pLDDT.
  11. """
  12. num_bins = plddt_logits.shape[-1]
  13. bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1)
  14. bin_width = 1.0 / num_bins
  15. bounds = torch.arange(
  16. start=0.5 * bin_width,
  17. end=1.0,
  18. step=bin_width,
  19. device=plddt_logits.device)
  20. plddt = torch.sum(
  21. bin_probs
  22. * bounds.view(*((1, ) * len(bin_probs.shape[:-1])), *bounds.shape),
  23. dim=-1,
  24. )
  25. return plddt
  26. def compute_bin_values(breaks: torch.Tensor):
  27. """Gets the bin centers from the bin edges.
  28. Args:
  29. breaks: [num_bins - 1] the error bin edges.
  30. Returns:
  31. bin_centers: [num_bins] the error bin centers.
  32. """
  33. step = breaks[1] - breaks[0]
  34. bin_values = breaks + step / 2
  35. bin_values = torch.cat([bin_values, (bin_values[-1] + step).unsqueeze(-1)],
  36. dim=0)
  37. return bin_values
  38. def compute_predicted_aligned_error(
  39. bin_edges: torch.Tensor,
  40. bin_probs: torch.Tensor,
  41. ) -> Tuple[torch.Tensor, torch.Tensor]:
  42. """Calculates expected aligned distance errors for every pair of residues.
  43. Args:
  44. alignment_confidence_breaks: [num_bins - 1] the error bin edges.
  45. aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
  46. probs for each error bin, for each pair of residues.
  47. Returns:
  48. predicted_aligned_error: [num_res, num_res] the expected aligned distance
  49. error for each pair of residues.
  50. max_predicted_aligned_error: The maximum predicted error possible.
  51. """
  52. bin_values = compute_bin_values(bin_edges)
  53. return torch.sum(bin_probs * bin_values, dim=-1)
  54. def predicted_aligned_error(
  55. pae_logits: torch.Tensor,
  56. max_bin: int = 31,
  57. num_bins: int = 64,
  58. **kwargs,
  59. ) -> Dict[str, torch.Tensor]:
  60. """Computes aligned confidence metrics from logits.
  61. Args:
  62. logits: [num_res, num_res, num_bins] the logits output from
  63. PredictedAlignedErrorHead.
  64. breaks: [num_bins - 1] the error bin edges.
  65. Returns:
  66. aligned_confidence_probs: [num_res, num_res, num_bins] the predicted
  67. aligned error probabilities over bins for each residue pair.
  68. predicted_aligned_error: [num_res, num_res] the expected aligned distance
  69. error for each pair of residues.
  70. max_predicted_aligned_error: The maximum predicted error possible.
  71. """
  72. bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1)
  73. bin_edges = torch.linspace(
  74. 0, max_bin, steps=(num_bins - 1), device=pae_logits.device)
  75. predicted_aligned_error = compute_predicted_aligned_error(
  76. bin_edges=bin_edges,
  77. bin_probs=bin_probs,
  78. )
  79. return {
  80. 'aligned_error_probs_per_bin': bin_probs,
  81. 'predicted_aligned_error': predicted_aligned_error,
  82. }
  83. def predicted_tm_score(
  84. pae_logits: torch.Tensor,
  85. residue_weights: Optional[torch.Tensor] = None,
  86. max_bin: int = 31,
  87. num_bins: int = 64,
  88. eps: float = 1e-8,
  89. asym_id: Optional[torch.Tensor] = None,
  90. interface: bool = False,
  91. **kwargs,
  92. ) -> torch.Tensor:
  93. """Computes predicted TM alignment or predicted interface TM alignment score.
  94. Args:
  95. logits: [num_res, num_res, num_bins] the logits output from
  96. PredictedAlignedErrorHead.
  97. breaks: [num_bins] the error bins.
  98. residue_weights: [num_res] the per residue weights to use for the
  99. expectation.
  100. asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
  101. ipTM calculation, i.e. when interface=True.
  102. interface: If True, interface predicted TM score is computed.
  103. Returns:
  104. ptm_score: The predicted TM alignment or the predicted iTM score.
  105. """
  106. pae_logits = pae_logits.float()
  107. if residue_weights is None:
  108. residue_weights = pae_logits.new_ones(pae_logits.shape[:-2])
  109. breaks = torch.linspace(
  110. 0, max_bin, steps=(num_bins - 1), device=pae_logits.device)
  111. def tm_kernal(nres):
  112. clipped_n = max(nres, 19)
  113. d0 = 1.24 * (clipped_n - 15)**(1.0 / 3.0) - 1.8
  114. return lambda x: 1.0 / (1.0 + (x / d0)**2)
  115. def rmsd_kernal(eps): # leave for compute pRMS
  116. return lambda x: 1. / (x + eps)
  117. bin_centers = compute_bin_values(breaks)
  118. probs = torch.nn.functional.softmax(pae_logits, dim=-1)
  119. tm_per_bin = tm_kernal(nres=pae_logits.shape[-2])(bin_centers)
  120. # tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
  121. # rmsd_per_bin = rmsd_kernal()(bin_centers)
  122. predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
  123. pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape)
  124. if interface:
  125. assert asym_id is not None, 'must provide asym_id for iptm calculation.'
  126. pair_mask *= asym_id[..., :, None] != asym_id[..., None, :]
  127. predicted_tm_term *= pair_mask
  128. pair_residue_weights = pair_mask * (
  129. residue_weights[None, :] * residue_weights[:, None])
  130. normed_residue_mask = pair_residue_weights / (
  131. eps + pair_residue_weights.sum(dim=-1, keepdim=True))
  132. per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
  133. weighted = per_alignment * residue_weights
  134. ret = per_alignment.gather(
  135. dim=-1, index=weighted.max(dim=-1,
  136. keepdim=True).indices).squeeze(dim=-1)
  137. return ret