attentions.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
  2. # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. from functools import partialmethod
  4. from typing import List, Optional
  5. import torch
  6. import torch.nn as nn
  7. from unicore.modules import LayerNorm, softmax_dropout
  8. from unicore.utils import permute_final_dims
  9. from .common import Linear, chunk_layer
  10. def gen_attn_mask(mask, neg_inf):
  11. assert neg_inf < -1e4
  12. attn_mask = torch.zeros_like(mask)
  13. attn_mask[mask == 0] = neg_inf
  14. return attn_mask
  15. class Attention(nn.Module):
  16. def __init__(
  17. self,
  18. q_dim: int,
  19. k_dim: int,
  20. v_dim: int,
  21. head_dim: int,
  22. num_heads: int,
  23. gating: bool = True,
  24. ):
  25. super(Attention, self).__init__()
  26. self.num_heads = num_heads
  27. total_dim = head_dim * self.num_heads
  28. self.gating = gating
  29. self.linear_q = Linear(q_dim, total_dim, bias=False, init='glorot')
  30. self.linear_k = Linear(k_dim, total_dim, bias=False, init='glorot')
  31. self.linear_v = Linear(v_dim, total_dim, bias=False, init='glorot')
  32. self.linear_o = Linear(total_dim, q_dim, init='final')
  33. self.linear_g = None
  34. if self.gating:
  35. self.linear_g = Linear(q_dim, total_dim, init='gating')
  36. # precompute the 1/sqrt(head_dim)
  37. self.norm = head_dim**-0.5
  38. def forward(
  39. self,
  40. q: torch.Tensor,
  41. k: torch.Tensor,
  42. v: torch.Tensor,
  43. mask: torch.Tensor = None,
  44. bias: Optional[torch.Tensor] = None,
  45. ) -> torch.Tensor:
  46. g = None
  47. if self.linear_g is not None:
  48. # gating, use raw query input
  49. g = self.linear_g(q)
  50. q = self.linear_q(q)
  51. q *= self.norm
  52. k = self.linear_k(k)
  53. v = self.linear_v(v)
  54. q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(
  55. -2, -3).contiguous()
  56. k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(
  57. -2, -3).contiguous()
  58. v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3)
  59. attn = torch.matmul(q, k.transpose(-1, -2))
  60. del q, k
  61. attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
  62. o = torch.matmul(attn, v)
  63. del attn, v
  64. o = o.transpose(-2, -3).contiguous()
  65. o = o.view(*o.shape[:-2], -1)
  66. if g is not None:
  67. o = torch.sigmoid(g) * o
  68. # merge heads
  69. o = nn.functional.linear(o, self.linear_o.weight)
  70. return o
  71. def get_output_bias(self):
  72. return self.linear_o.bias
  73. class GlobalAttention(nn.Module):
  74. def __init__(self, input_dim, head_dim, num_heads, inf, eps):
  75. super(GlobalAttention, self).__init__()
  76. self.num_heads = num_heads
  77. self.inf = inf
  78. self.eps = eps
  79. self.linear_q = Linear(
  80. input_dim, head_dim * num_heads, bias=False, init='glorot')
  81. self.linear_k = Linear(input_dim, head_dim, bias=False, init='glorot')
  82. self.linear_v = Linear(input_dim, head_dim, bias=False, init='glorot')
  83. self.linear_g = Linear(input_dim, head_dim * num_heads, init='gating')
  84. self.linear_o = Linear(head_dim * num_heads, input_dim, init='final')
  85. self.sigmoid = nn.Sigmoid()
  86. # precompute the 1/sqrt(head_dim)
  87. self.norm = head_dim**-0.5
  88. def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
  89. # gating
  90. g = self.sigmoid(self.linear_g(x))
  91. k = self.linear_k(x)
  92. v = self.linear_v(x)
  93. q = torch.sum(
  94. x * mask.unsqueeze(-1), dim=-2) / (
  95. torch.sum(mask, dim=-1, keepdims=True) + self.eps)
  96. q = self.linear_q(q)
  97. q *= self.norm
  98. q = q.view(q.shape[:-1] + (self.num_heads, -1))
  99. attn = torch.matmul(q, k.transpose(-1, -2))
  100. del q, k
  101. attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
  102. attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)
  103. o = torch.matmul(
  104. attn,
  105. v,
  106. )
  107. del attn, v
  108. g = g.view(g.shape[:-1] + (self.num_heads, -1))
  109. o = o.unsqueeze(-3) * g
  110. del g
  111. # merge heads
  112. o = o.reshape(o.shape[:-2] + (-1, ))
  113. return self.linear_o(o)
  114. def gen_msa_attn_mask(mask, inf, gen_col_mask=True):
  115. row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
  116. if gen_col_mask:
  117. col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None,
  118. None, :]
  119. return row_mask, col_mask
  120. else:
  121. return row_mask
  122. class MSAAttention(nn.Module):
  123. def __init__(
  124. self,
  125. d_in,
  126. d_hid,
  127. num_heads,
  128. pair_bias=False,
  129. d_pair=None,
  130. ):
  131. super(MSAAttention, self).__init__()
  132. self.pair_bias = pair_bias
  133. self.layer_norm_m = LayerNorm(d_in)
  134. self.layer_norm_z = None
  135. self.linear_z = None
  136. if self.pair_bias:
  137. self.layer_norm_z = LayerNorm(d_pair)
  138. self.linear_z = Linear(
  139. d_pair, num_heads, bias=False, init='normal')
  140. self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)
  141. @torch.jit.ignore
  142. def _chunk(
  143. self,
  144. m: torch.Tensor,
  145. mask: Optional[torch.Tensor] = None,
  146. bias: Optional[torch.Tensor] = None,
  147. chunk_size: int = None,
  148. ) -> torch.Tensor:
  149. return chunk_layer(
  150. self._attn_forward,
  151. {
  152. 'm': m,
  153. 'mask': mask,
  154. 'bias': bias
  155. },
  156. chunk_size=chunk_size,
  157. num_batch_dims=len(m.shape[:-2]),
  158. )
  159. @torch.jit.ignore
  160. def _attn_chunk_forward(
  161. self,
  162. m: torch.Tensor,
  163. mask: Optional[torch.Tensor] = None,
  164. bias: Optional[torch.Tensor] = None,
  165. chunk_size: Optional[int] = 2560,
  166. ) -> torch.Tensor:
  167. m = self.layer_norm_m(m)
  168. num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size
  169. outputs = []
  170. for i in range(num_chunk):
  171. chunk_start = i * chunk_size
  172. chunk_end = min(m.shape[-3], chunk_start + chunk_size)
  173. cur_m = m[..., chunk_start:chunk_end, :, :]
  174. cur_mask = (
  175. mask[..., chunk_start:chunk_end, :, :, :]
  176. if mask is not None else None)
  177. outputs.append(
  178. self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias))
  179. return torch.cat(outputs, dim=-3)
  180. def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None):
  181. m = self.layer_norm_m(m)
  182. return self.mha(q=m, k=m, v=m, mask=mask, bias=bias)
  183. def forward(
  184. self,
  185. m: torch.Tensor,
  186. z: Optional[torch.Tensor] = None,
  187. attn_mask: Optional[torch.Tensor] = None,
  188. chunk_size: Optional[int] = None,
  189. ) -> torch.Tensor:
  190. bias = None
  191. if self.pair_bias:
  192. z = self.layer_norm_z(z)
  193. bias = (
  194. permute_final_dims(self.linear_z(z),
  195. (2, 0, 1)).unsqueeze(-4).contiguous())
  196. if chunk_size is not None:
  197. m = self._chunk(m, attn_mask, bias, chunk_size)
  198. else:
  199. attn_chunk_size = 2560
  200. if m.shape[-3] <= attn_chunk_size:
  201. m = self._attn_forward(m, attn_mask, bias)
  202. else:
  203. # reduce the peak memory cost in extra_msa_stack
  204. return self._attn_chunk_forward(
  205. m, attn_mask, bias, chunk_size=attn_chunk_size)
  206. return m
  207. def get_output_bias(self):
  208. return self.mha.get_output_bias()
  209. class MSARowAttentionWithPairBias(MSAAttention):
  210. def __init__(self, d_msa, d_pair, d_hid, num_heads):
  211. super(MSARowAttentionWithPairBias, self).__init__(
  212. d_msa,
  213. d_hid,
  214. num_heads,
  215. pair_bias=True,
  216. d_pair=d_pair,
  217. )
  218. class MSAColumnAttention(MSAAttention):
  219. def __init__(self, d_msa, d_hid, num_heads):
  220. super(MSAColumnAttention, self).__init__(
  221. d_in=d_msa,
  222. d_hid=d_hid,
  223. num_heads=num_heads,
  224. pair_bias=False,
  225. d_pair=None,
  226. )
  227. def forward(
  228. self,
  229. m: torch.Tensor,
  230. attn_mask: Optional[torch.Tensor] = None,
  231. chunk_size: Optional[int] = None,
  232. ) -> torch.Tensor:
  233. m = m.transpose(-2, -3)
  234. m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size)
  235. m = m.transpose(-2, -3)
  236. return m
  237. class MSAColumnGlobalAttention(nn.Module):
  238. def __init__(
  239. self,
  240. d_in,
  241. d_hid,
  242. num_heads,
  243. inf=1e9,
  244. eps=1e-10,
  245. ):
  246. super(MSAColumnGlobalAttention, self).__init__()
  247. self.layer_norm_m = LayerNorm(d_in)
  248. self.global_attention = GlobalAttention(
  249. d_in,
  250. d_hid,
  251. num_heads,
  252. inf=inf,
  253. eps=eps,
  254. )
  255. @torch.jit.ignore
  256. def _chunk(
  257. self,
  258. m: torch.Tensor,
  259. mask: torch.Tensor,
  260. chunk_size: int,
  261. ) -> torch.Tensor:
  262. return chunk_layer(
  263. self._attn_forward,
  264. {
  265. 'm': m,
  266. 'mask': mask
  267. },
  268. chunk_size=chunk_size,
  269. num_batch_dims=len(m.shape[:-2]),
  270. )
  271. def _attn_forward(self, m, mask):
  272. m = self.layer_norm_m(m)
  273. return self.global_attention(m, mask=mask)
  274. def forward(
  275. self,
  276. m: torch.Tensor,
  277. mask: Optional[torch.Tensor] = None,
  278. chunk_size: Optional[int] = None,
  279. ) -> torch.Tensor:
  280. m = m.transpose(-2, -3)
  281. mask = mask.transpose(-1, -2)
  282. if chunk_size is not None:
  283. m = self._chunk(m, mask, chunk_size)
  284. else:
  285. m = self._attn_forward(m, mask=mask)
  286. m = m.transpose(-2, -3)
  287. return m
  288. def gen_tri_attn_mask(mask, inf):
  289. start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
  290. end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None,
  291. None, :]
  292. return start_mask, end_mask
  293. class TriangleAttention(nn.Module):
  294. def __init__(
  295. self,
  296. d_in,
  297. d_hid,
  298. num_heads,
  299. starting,
  300. ):
  301. super(TriangleAttention, self).__init__()
  302. self.starting = starting
  303. self.layer_norm = LayerNorm(d_in)
  304. self.linear = Linear(d_in, num_heads, bias=False, init='normal')
  305. self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)
  306. @torch.jit.ignore
  307. def _chunk(
  308. self,
  309. x: torch.Tensor,
  310. mask: Optional[torch.Tensor] = None,
  311. bias: Optional[torch.Tensor] = None,
  312. chunk_size: int = None,
  313. ) -> torch.Tensor:
  314. return chunk_layer(
  315. self.mha,
  316. {
  317. 'q': x,
  318. 'k': x,
  319. 'v': x,
  320. 'mask': mask,
  321. 'bias': bias
  322. },
  323. chunk_size=chunk_size,
  324. num_batch_dims=len(x.shape[:-2]),
  325. )
  326. def forward(
  327. self,
  328. x: torch.Tensor,
  329. attn_mask: Optional[torch.Tensor] = None,
  330. chunk_size: Optional[int] = None,
  331. ) -> torch.Tensor:
  332. if not self.starting:
  333. x = x.transpose(-2, -3)
  334. x = self.layer_norm(x)
  335. triangle_bias = (
  336. permute_final_dims(self.linear(x),
  337. (2, 0, 1)).unsqueeze(-4).contiguous())
  338. if chunk_size is not None:
  339. x = self._chunk(x, attn_mask, triangle_bias, chunk_size)
  340. else:
  341. x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias)
  342. if not self.starting:
  343. x = x.transpose(-2, -3)
  344. return x
  345. def get_output_bias(self):
  346. return self.mha.get_output_bias()
  347. class TriangleAttentionStarting(TriangleAttention):
  348. __init__ = partialmethod(TriangleAttention.__init__, starting=True)
  349. class TriangleAttentionEnding(TriangleAttention):
  350. __init__ = partialmethod(TriangleAttention.__init__, starting=False)