embedders.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
  2. # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. from typing import Optional, Tuple
  4. import torch
  5. import torch.nn as nn
  6. from unicore.modules import LayerNorm
  7. from unicore.utils import one_hot
  8. from .common import Linear, SimpleModuleList, residual
  class InputEmbedder(nn.Module):
      """Build the initial MSA and pair embeddings from the input features.

      The module owns four linear projections (target features into the pair
      representation twice, target features into the MSA representation, and
      MSA features into the MSA representation) plus one projection for the
      relative-position encoding.

      Args:
          tf_dim: channel count fed to the target-feature projections.
          msa_dim: channel count fed to the MSA projection.
          d_pair: channel count of the pair embedding.
          d_msa: channel count of the MSA embedding.
          relpos_k: residue-index offsets are clipped to ``[-relpos_k, relpos_k]``.
          use_chain_relative: when true, the relative-position encoding also
              carries chain information (see ``_relpos_indices``).
          max_relative_chain: clipping bound for the ``sym_id`` offset; required
              when ``use_chain_relative`` is true.
          **kwargs: accepted and ignored.

      Raises:
          ValueError: if ``use_chain_relative`` is true but
              ``max_relative_chain`` is ``None``.
      """

      def __init__(
          self,
          tf_dim: int,
          msa_dim: int,
          d_pair: int,
          d_msa: int,
          relpos_k: int,
          use_chain_relative: bool = False,
          max_relative_chain: Optional[int] = None,
          **kwargs,
      ):
          super(InputEmbedder, self).__init__()

          # Fail early with a readable message instead of the opaque
          # "unsupported operand" TypeError raised by ``2 * None`` below.
          if use_chain_relative and max_relative_chain is None:
              raise ValueError(
                  "max_relative_chain must be provided when "
                  "use_chain_relative is True")

          self.tf_dim = tf_dim
          self.msa_dim = msa_dim
          self.d_pair = d_pair
          self.d_msa = d_msa

          self.linear_tf_z_i = Linear(tf_dim, d_pair)
          self.linear_tf_z_j = Linear(tf_dim, d_pair)
          self.linear_tf_m = Linear(tf_dim, d_msa)
          self.linear_msa_m = Linear(msa_dim, d_msa)

          # RPE stuff
          self.relpos_k = relpos_k
          self.use_chain_relative = use_chain_relative
          self.max_relative_chain = max_relative_chain
          if not self.use_chain_relative:
              # one bin per clipped offset in [-k, k]
              self.num_bins = 2 * self.relpos_k + 1
          else:
              # clipped offsets plus one extra bin for pairs whose asym_id differ
              self.num_bins = 2 * self.relpos_k + 2
              self.num_bins += 1  # entity id
              # clipped sym_id offsets plus one extra bin for pairs whose
              # entity_id differ
              self.num_bins += 2 * max_relative_chain + 2
          self.linear_relpos = Linear(self.num_bins, d_pair)

      def _relpos_indices(
          self,
          res_id: torch.Tensor,
          sym_id: Optional[torch.Tensor] = None,
          asym_id: Optional[torch.Tensor] = None,
          entity_id: Optional[torch.Tensor] = None,
      ):
          """Compute integer relative-position indices for every residue pair.

          Args:
              res_id: residue indices; the pairwise difference is taken over
                  the last dimension.
              sym_id, asym_id, entity_id: per-residue ids, only read (and then
                  required) when ``use_chain_relative`` is true.

          Returns:
              Without chain-relative encoding: a single tensor of clipped
              offsets shifted into ``[0, 2 * relpos_k]``.
              With chain-relative encoding: a 3-tuple
              ``(rp, rp_entity_id, clipped_rel_chain)`` where ``rp`` uses the
              extra index ``2 * relpos_k + 1`` for pairs with different
              ``asym_id``, ``rp_entity_id`` is a same-entity indicator with a
              trailing singleton dimension, and ``clipped_rel_chain`` uses the
              extra index ``2 * max_relative_chain + 1`` for pairs with
              different ``entity_id``.

          Raises:
              ValueError: if chain-relative encoding is enabled and any of the
                  three id tensors is ``None``.
          """
          max_rel_res = self.relpos_k
          rp = res_id[..., None] - res_id[..., None, :]
          rp = rp.clip(-max_rel_res, max_rel_res) + max_rel_res
          if not self.use_chain_relative:
              return rp

          provided = (
              ("sym_id", sym_id),
              ("asym_id", asym_id),
              ("entity_id", entity_id),
          )
          missing = [name for name, value in provided if value is None]
          if missing:
              raise ValueError(
                  "use_chain_relative is True but these inputs are missing: "
                  + ", ".join(missing))

          asym_id_same = asym_id[..., :, None] == asym_id[..., None, :]
          # pairs with different asym_id all share one dedicated index
          rp[~asym_id_same] = 2 * max_rel_res + 1

          entity_id_same = entity_id[..., :, None] == entity_id[..., None, :]
          rp_entity_id = entity_id_same.type(rp.dtype)[..., None]

          rel_sym_id = sym_id[..., :, None] - sym_id[..., None, :]
          max_rel_chain = self.max_relative_chain
          clipped_rel_chain = torch.clamp(
              rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain)
          # pairs with different entity_id all share one dedicated index
          clipped_rel_chain[~entity_id_same] = 2 * max_rel_chain + 1
          return rp, rp_entity_id, clipped_rel_chain

      def relpos_emb(
          self,
          res_id: torch.Tensor,
          sym_id: Optional[torch.Tensor] = None,
          asym_id: Optional[torch.Tensor] = None,
          entity_id: Optional[torch.Tensor] = None,
          num_sym: Optional[torch.Tensor] = None,
      ):
          """Project the one-hot relative-position features through ``linear_relpos``.

          Args:
              res_id, sym_id, asym_id, entity_id: forwarded to
                  ``_relpos_indices``.
              num_sym: not read by this method; kept in the signature so
                  existing call sites keep working.

          Returns:
              The output of ``linear_relpos`` applied to the encoded features.
          """
          dtype = self.linear_relpos.weight.dtype
          if not self.use_chain_relative:
              rp = self._relpos_indices(res_id=res_id)
              return self.linear_relpos(
                  one_hot(rp, num_classes=self.num_bins, dtype=dtype))

          rp, rp_entity_id, rp_rel_chain = self._relpos_indices(
              res_id=res_id,
              sym_id=sym_id,
              asym_id=asym_id,
              entity_id=entity_id)
          rp = one_hot(rp, num_classes=(2 * self.relpos_k + 2), dtype=dtype)
          rp_entity_id = rp_entity_id.type(dtype)
          rp_rel_chain = one_hot(
              rp_rel_chain,
              num_classes=(2 * self.max_relative_chain + 2),
              dtype=dtype)
          # concatenation order must match the num_bins layout built in __init__
          return self.linear_relpos(
              torch.cat([rp, rp_entity_id, rp_rel_chain], dim=-1))

      def forward(
          self,
          tf: torch.Tensor,
          msa: torch.Tensor,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Embed target and MSA features.

          Args:
              tf: target features; presumably ``[*, N_res, C]`` - TODO confirm.
                  When ``tf_dim == 21`` the first channel is dropped before
                  projection.
              msa: MSA features; ``msa.shape[-3]`` is treated as the
                  sequence/cluster dimension that ``tf`` is broadcast over.

          Returns:
              ``(msa_emb, pair_emb)``: the MSA projection with the target
              projection added to every row, and the outer sum of the two
              target-to-pair projections.
          """
          if self.tf_dim == 21:
              # multimer use 21 target dim
              tf = tf[..., 1:]

          # convert type if necessary
          dtype = self.linear_tf_z_i.weight.dtype
          tf = tf.type(dtype)
          msa = msa.type(dtype)

          msa_emb = self.linear_msa_m(msa)
          # target_feat (aatype) into msa representation: the singleton
          # inserted at dim -3 broadcasts over the MSA row dimension, which is
          # what the former explicit expand(...) spelled out by hand.
          msa_emb += self.linear_tf_m(tf).unsqueeze(-3)

          tf_emb_i = self.linear_tf_z_i(tf)
          tf_emb_j = self.linear_tf_z_j(tf)
          # outer sum over the residue dimension
          pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
          return msa_emb, pair_emb
  class RecyclingEmbedder(nn.Module):
      """Embed recycled MSA / pair representations and recycled positions.

      ``forward`` layer-normalises the two recycled representations.
      ``recyle_pos`` bins pairwise squared distances between positions and
      projects the bin indicators to ``d_pair`` channels.

      Args:
          d_msa: channel count normalised by ``layer_norm_m``.
          d_pair: channel count normalised by ``layer_norm_z`` and produced by
              the distance projection.
          min_bin: first (unsquared) bin edge.
          max_bin: last (unsquared) bin edge.
          num_bins: number of bin edges generated between ``min_bin`` and
              ``max_bin``.
          inf: value used as the upper edge of the last bin.
          **kwargs: accepted and ignored.
      """

      def __init__(
          self,
          d_msa: int,
          d_pair: int,
          min_bin: float,
          max_bin: float,
          num_bins: int,
          inf: float = 1e8,
          **kwargs,
      ):
          super(RecyclingEmbedder, self).__init__()

          self.d_msa = d_msa
          self.d_pair = d_pair
          self.min_bin = min_bin
          self.max_bin = max_bin
          self.num_bins = num_bins
          self.inf = inf

          # Lazily built cache of the squared bin edges; a plain attribute (not
          # a registered buffer), so it does not follow ``.to(...)`` calls.
          self.squared_bins = None

          self.linear = Linear(self.num_bins, self.d_pair)
          self.layer_norm_m = LayerNorm(self.d_msa)
          self.layer_norm_z = LayerNorm(self.d_pair)

      def forward(
          self,
          m: torch.Tensor,
          z: torch.Tensor,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Return ``(layer_norm_m(m), layer_norm_z(z))``."""
          m_update = self.layer_norm_m(m)
          z_update = self.layer_norm_z(z)
          return m_update, z_update

      def recyle_pos(
          self,
          x: torch.Tensor,
      ) -> torch.Tensor:
          """Embed pairwise squared distances between the positions in ``x``.

          The method name keeps its original spelling so existing call sites
          continue to work.

          Args:
              x: positions; pairwise differences are taken over the
                  second-to-last dimension and squared-summed over the last
                  (presumably ``[*, N_res, 3]`` - TODO confirm).

          Returns:
              A single tensor: ``self.linear`` applied to the per-pair bin
              indicators (``1`` where ``lower < d < upper``, else ``0``).
          """
          # Bin edges are kept in float32 while training, otherwise they follow
          # the input dtype.
          bins_dtype = torch.float if self.training else x.dtype

          bins = self.squared_bins
          # Rebuild the cache when it is missing or no longer matches the
          # device / dtype needed now (e.g. the module was moved to another
          # device, or train/eval mode changed after the first call).
          if (bins is None or bins.device != x.device
                  or bins.dtype != bins_dtype):
              bins = torch.linspace(
                  self.min_bin,
                  self.max_bin,
                  self.num_bins,
                  dtype=bins_dtype,
                  device=x.device,
                  requires_grad=False,
              ) ** 2
              self.squared_bins = bins

          # upper edge of each bin: the next lower edge, and ``inf`` for the last
          upper = torch.cat([bins[1:], bins.new_tensor([self.inf])], dim=-1)

          if self.training:
              x = x.float()

          # squared distance per pair, trailing singleton kept so it broadcasts
          # against the vector of bin edges
          d = torch.sum(
              (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True)
          d = ((d > bins) * (d < upper)).type(self.linear.weight.dtype)
          return self.linear(d)
  class TemplateAngleEmbedder(nn.Module):
      """Two-layer projection of template angle features.

      Applies ``linear_1`` (``d_in`` -> ``d_out``), a GELU activation, then
      ``linear_2`` (``d_out`` -> ``d_out``).

      Args:
          d_in: input channel count.
          d_out: output channel count.
          **kwargs: accepted and ignored.
      """

      def __init__(
          self,
          d_in: int,
          d_out: int,
          **kwargs,
      ):
          super().__init__()

          self.d_out = d_out
          self.d_in = d_in

          self.linear_1 = Linear(d_in, d_out, init='relu')
          self.act = nn.GELU()
          self.linear_2 = Linear(d_out, d_out, init='relu')

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Cast ``x`` to the weight dtype of ``linear_1`` and run the stack."""
          weight_dtype = self.linear_1.weight.dtype
          hidden = self.linear_1(x.type(weight_dtype))
          return self.linear_2(self.act(hidden))
  class TemplatePairEmbedder(nn.Module):
      """Project template pair features to ``d_out`` channels.

      Two modes, selected by ``v2_feature``:

      * ``False``: a single linear layer maps ``d_in`` -> ``d_out``.
      * ``True``: one linear layer per entry of ``v2_d_in`` plus a
        layer-normalised projection of the pair representation; the pieces
        are combined with the project's ``residual`` helper.

      Args:
          d_in: input channel count for the single-projection mode.
          v2_d_in: per-feature input channel counts for the v2 mode.
          d_out: output channel count.
          d_pair: channel count of the pair representation (v2 mode only).
          v2_feature: selects the v2 mode.
          **kwargs: accepted and ignored.
      """

      def __init__(
          self,
          d_in: int,
          v2_d_in: list,
          d_out: int,
          d_pair: int,
          v2_feature: bool = False,
          **kwargs,
      ):
          super().__init__()

          self.d_out = d_out
          self.v2_feature = v2_feature

          if not self.v2_feature:
              self.d_in = d_in
              self.linear = Linear(self.d_in, self.d_out, init='relu')
          else:
              self.d_in = v2_d_in
              self.linear = SimpleModuleList()
              for feat_dim in self.d_in:
                  self.linear.append(Linear(feat_dim, self.d_out, init='relu'))
              self.z_layer_norm = LayerNorm(d_pair)
              self.z_linear = Linear(d_pair, self.d_out, init='relu')

      def forward(
          self,
          x,
          z,
      ) -> torch.Tensor:
          """Embed the template pair features.

          Args:
              x: a tensor in single-projection mode; in v2 mode an indexable
                  sequence of tensors, one per layer in ``self.linear``.
              z: pair representation; only read in v2 mode.

          Returns:
              The projected features.
          """
          if self.v2_feature:
              weight_dtype = self.z_linear.weight.dtype
              acc = self.linear[0](x[0].type(weight_dtype))
              # remaining features are folded in one by one, in order
              for idx, feat in enumerate(x[1:], start=1):
                  projected = self.linear[idx](feat.type(weight_dtype))
                  acc = residual(acc, projected, self.training)
              pair_term = self.z_linear(self.z_layer_norm(z))
              return residual(acc, pair_term, self.training)

          return self.linear(x.type(self.linear.weight.dtype))
  class ExtraMSAEmbedder(nn.Module):
      """Single linear projection of extra-MSA features (``d_in`` -> ``d_out``).

      Args:
          d_in: input channel count.
          d_out: output channel count.
          **kwargs: accepted and ignored.
      """

      def __init__(
          self,
          d_in: int,
          d_out: int,
          **kwargs,
      ):
          super().__init__()

          self.d_in = d_in
          self.d_out = d_out
          self.linear = Linear(d_in, d_out)

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Cast ``x`` to the layer's weight dtype and project it."""
          weight_dtype = self.linear.weight.dtype
          return self.linear(x.type(weight_dtype))