template.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
  2. # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. import math
  4. from functools import partial
  5. from typing import List, Optional, Tuple
  6. import torch
  7. import torch.nn as nn
  8. from unicore.modules import LayerNorm
  9. from unicore.utils import (checkpoint_sequential, permute_final_dims,
  10. tensor_tree_map)
  11. from .attentions import (Attention, TriangleAttentionEnding,
  12. TriangleAttentionStarting, gen_attn_mask)
  13. from .common import (Linear, SimpleModuleList, Transition,
  14. bias_dropout_residual, chunk_layer, residual,
  15. tri_mul_residual)
  16. from .featurization import build_template_pair_feat_v2
  17. from .triangle_multiplication import (TriangleMultiplicationIncoming,
  18. TriangleMultiplicationOutgoing)
  19. class TemplatePointwiseAttention(nn.Module):
  20. def __init__(self, d_template, d_pair, d_hid, num_heads, inf, **kwargs):
  21. super(TemplatePointwiseAttention, self).__init__()
  22. self.inf = inf
  23. self.mha = Attention(
  24. d_pair,
  25. d_template,
  26. d_template,
  27. d_hid,
  28. num_heads,
  29. gating=False,
  30. )
  31. def _chunk(
  32. self,
  33. z: torch.Tensor,
  34. t: torch.Tensor,
  35. mask: torch.Tensor,
  36. chunk_size: int,
  37. ) -> torch.Tensor:
  38. mha_inputs = {
  39. 'q': z,
  40. 'k': t,
  41. 'v': t,
  42. 'mask': mask,
  43. }
  44. return chunk_layer(
  45. self.mha,
  46. mha_inputs,
  47. chunk_size=chunk_size,
  48. num_batch_dims=len(z.shape[:-2]),
  49. )
  50. def forward(
  51. self,
  52. t: torch.Tensor,
  53. z: torch.Tensor,
  54. template_mask: Optional[torch.Tensor] = None,
  55. chunk_size: Optional[int] = None,
  56. ) -> torch.Tensor:
  57. if template_mask is None:
  58. template_mask = t.new_ones(t.shape[:-3])
  59. mask = gen_attn_mask(template_mask, -self.inf)[..., None, None, None,
  60. None, :]
  61. z = z.unsqueeze(-2)
  62. t = permute_final_dims(t, (1, 2, 0, 3))
  63. if chunk_size is not None:
  64. z = self._chunk(z, t, mask, chunk_size)
  65. else:
  66. z = self.mha(z, t, t, mask=mask)
  67. z = z.squeeze(-2)
  68. return z
  69. class TemplateProjection(nn.Module):
  70. def __init__(self, d_template, d_pair, **kwargs):
  71. super(TemplateProjection, self).__init__()
  72. self.d_pair = d_pair
  73. self.act = nn.ReLU()
  74. self.output_linear = Linear(d_template, d_pair, init='relu')
  75. def forward(self, t, z) -> torch.Tensor:
  76. if t is None:
  77. # handle for non-template case
  78. shape = z.shape
  79. shape[-1] = self.d_pair
  80. t = torch.zeros(shape, dtype=z.dtype, device=z.device)
  81. t = self.act(t)
  82. z_t = self.output_linear(t)
  83. return z_t
  84. class TemplatePairStackBlock(nn.Module):
  85. def __init__(
  86. self,
  87. d_template: int,
  88. d_hid_tri_att: int,
  89. d_hid_tri_mul: int,
  90. num_heads: int,
  91. pair_transition_n: int,
  92. dropout_rate: float,
  93. tri_attn_first: bool,
  94. inf: float,
  95. **kwargs,
  96. ):
  97. super(TemplatePairStackBlock, self).__init__()
  98. self.tri_att_start = TriangleAttentionStarting(
  99. d_template,
  100. d_hid_tri_att,
  101. num_heads,
  102. )
  103. self.tri_att_end = TriangleAttentionEnding(
  104. d_template,
  105. d_hid_tri_att,
  106. num_heads,
  107. )
  108. self.tri_mul_out = TriangleMultiplicationOutgoing(
  109. d_template,
  110. d_hid_tri_mul,
  111. )
  112. self.tri_mul_in = TriangleMultiplicationIncoming(
  113. d_template,
  114. d_hid_tri_mul,
  115. )
  116. self.pair_transition = Transition(
  117. d_template,
  118. pair_transition_n,
  119. )
  120. self.tri_attn_first = tri_attn_first
  121. self.dropout = dropout_rate
  122. self.row_dropout_share_dim = -3
  123. self.col_dropout_share_dim = -2
  124. def forward(
  125. self,
  126. s: torch.Tensor,
  127. mask: torch.Tensor,
  128. tri_start_attn_mask: torch.Tensor,
  129. tri_end_attn_mask: torch.Tensor,
  130. chunk_size: Optional[int] = None,
  131. block_size: Optional[int] = None,
  132. ):
  133. if self.tri_attn_first:
  134. s = bias_dropout_residual(
  135. self.tri_att_start,
  136. s,
  137. self.tri_att_start(
  138. s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
  139. self.row_dropout_share_dim,
  140. self.dropout,
  141. self.training,
  142. )
  143. s = bias_dropout_residual(
  144. self.tri_att_end,
  145. s,
  146. self.tri_att_end(
  147. s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
  148. self.col_dropout_share_dim,
  149. self.dropout,
  150. self.training,
  151. )
  152. s = tri_mul_residual(
  153. self.tri_mul_out,
  154. s,
  155. self.tri_mul_out(s, mask=mask, block_size=block_size),
  156. self.row_dropout_share_dim,
  157. self.dropout,
  158. self.training,
  159. block_size=block_size,
  160. )
  161. s = tri_mul_residual(
  162. self.tri_mul_in,
  163. s,
  164. self.tri_mul_in(s, mask=mask, block_size=block_size),
  165. self.row_dropout_share_dim,
  166. self.dropout,
  167. self.training,
  168. block_size=block_size,
  169. )
  170. else:
  171. s = tri_mul_residual(
  172. self.tri_mul_out,
  173. s,
  174. self.tri_mul_out(s, mask=mask, block_size=block_size),
  175. self.row_dropout_share_dim,
  176. self.dropout,
  177. self.training,
  178. block_size=block_size,
  179. )
  180. s = tri_mul_residual(
  181. self.tri_mul_in,
  182. s,
  183. self.tri_mul_in(s, mask=mask, block_size=block_size),
  184. self.row_dropout_share_dim,
  185. self.dropout,
  186. self.training,
  187. block_size=block_size,
  188. )
  189. s = bias_dropout_residual(
  190. self.tri_att_start,
  191. s,
  192. self.tri_att_start(
  193. s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
  194. self.row_dropout_share_dim,
  195. self.dropout,
  196. self.training,
  197. )
  198. s = bias_dropout_residual(
  199. self.tri_att_end,
  200. s,
  201. self.tri_att_end(
  202. s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
  203. self.col_dropout_share_dim,
  204. self.dropout,
  205. self.training,
  206. )
  207. s = residual(s, self.pair_transition(
  208. s,
  209. chunk_size=chunk_size,
  210. ), self.training)
  211. return s
  212. class TemplatePairStack(nn.Module):
  213. def __init__(
  214. self,
  215. d_template,
  216. d_hid_tri_att,
  217. d_hid_tri_mul,
  218. num_blocks,
  219. num_heads,
  220. pair_transition_n,
  221. dropout_rate,
  222. tri_attn_first,
  223. inf=1e9,
  224. **kwargs,
  225. ):
  226. super(TemplatePairStack, self).__init__()
  227. self.blocks = SimpleModuleList()
  228. for _ in range(num_blocks):
  229. self.blocks.append(
  230. TemplatePairStackBlock(
  231. d_template=d_template,
  232. d_hid_tri_att=d_hid_tri_att,
  233. d_hid_tri_mul=d_hid_tri_mul,
  234. num_heads=num_heads,
  235. pair_transition_n=pair_transition_n,
  236. dropout_rate=dropout_rate,
  237. inf=inf,
  238. tri_attn_first=tri_attn_first,
  239. ))
  240. self.layer_norm = LayerNorm(d_template)
  241. def forward(
  242. self,
  243. single_templates: Tuple[torch.Tensor],
  244. mask: torch.tensor,
  245. tri_start_attn_mask: torch.Tensor,
  246. tri_end_attn_mask: torch.Tensor,
  247. templ_dim: int,
  248. chunk_size: int,
  249. block_size: int,
  250. return_mean: bool,
  251. ):
  252. def one_template(i):
  253. (s, ) = checkpoint_sequential(
  254. functions=[
  255. partial(
  256. b,
  257. mask=mask,
  258. tri_start_attn_mask=tri_start_attn_mask,
  259. tri_end_attn_mask=tri_end_attn_mask,
  260. chunk_size=chunk_size,
  261. block_size=block_size,
  262. ) for b in self.blocks
  263. ],
  264. input=(single_templates[i], ),
  265. )
  266. return s
  267. n_templ = len(single_templates)
  268. if n_templ > 0:
  269. new_single_templates = [one_template(0)]
  270. if return_mean:
  271. t = self.layer_norm(new_single_templates[0])
  272. for i in range(1, n_templ):
  273. s = one_template(i)
  274. if return_mean:
  275. t = residual(t, self.layer_norm(s), self.training)
  276. else:
  277. new_single_templates.append(s)
  278. if return_mean:
  279. if n_templ > 0:
  280. t /= n_templ
  281. else:
  282. t = None
  283. else:
  284. t = torch.cat(
  285. [s.unsqueeze(templ_dim) for s in new_single_templates],
  286. dim=templ_dim)
  287. t = self.layer_norm(t)
  288. return t