| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430 |
- # The Uni-Fold implementation is also open-sourced by the authors under the Apache-2.0 license,
- # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
- from functools import partialmethod
- from typing import List, Optional
- import torch
- import torch.nn as nn
- from unicore.modules import LayerNorm, softmax_dropout
- from unicore.utils import permute_final_dims
- from .common import Linear, chunk_layer
def gen_attn_mask(mask, neg_inf):
    """Turn a validity mask into an additive attention mask.

    Entries where ``mask == 0`` are set to ``neg_inf``; every other entry is
    set to 0.

    Args:
        mask: Tensor whose zero entries mark positions to suppress.
        neg_inf: Fill value for suppressed positions; must be below ``-1e4``.

    Returns:
        A new tensor with the same shape as ``mask``. A floating-point
        ``mask`` keeps its dtype. A bool/integer ``mask`` is promoted to the
        default floating dtype, because those dtypes cannot hold ``neg_inf``
        (a bool tensor would collapse it to ``True``).

    Raises:
        AssertionError: If ``neg_inf`` is not below ``-1e4``.
    """
    assert neg_inf < -1e4
    if mask.is_floating_point():
        attn_mask = torch.zeros_like(mask)
    else:
        # zeros_like would inherit the bool/int dtype and silently corrupt the
        # fill value, so allocate the result in a floating dtype instead.
        attn_mask = torch.zeros_like(mask, dtype=torch.get_default_dtype())
    attn_mask[mask == 0] = neg_inf
    return attn_mask
class Attention(nn.Module):
    """Multi-head attention with an optional sigmoid gate on the output.

    ``forward`` applies the output projection weight only; the projection
    bias is exposed separately through ``get_output_bias``.
    """

    def __init__(
        self,
        q_dim: int,
        k_dim: int,
        v_dim: int,
        head_dim: int,
        num_heads: int,
        gating: bool = True,
    ):
        """Build the q/k/v/output projections and, if requested, the gate.

        Args:
            q_dim: Feature size of the query input (also the output size).
            k_dim: Feature size of the key input.
            v_dim: Feature size of the value input.
            head_dim: Feature size of each attention head.
            num_heads: Number of attention heads.
            gating: Whether to create the gate projection.
        """
        super().__init__()
        self.num_heads = num_heads
        hidden_dim = head_dim * num_heads
        self.gating = gating

        self.linear_q = Linear(q_dim, hidden_dim, bias=False, init='glorot')
        self.linear_k = Linear(k_dim, hidden_dim, bias=False, init='glorot')
        self.linear_v = Linear(v_dim, hidden_dim, bias=False, init='glorot')
        self.linear_o = Linear(hidden_dim, q_dim, init='final')
        # The gate projection exists only when gating is enabled.
        self.linear_g = (
            Linear(q_dim, hidden_dim, init='gating') if gating else None)

        # Query scaling factor, 1 / sqrt(head_dim), computed once.
        self.norm = head_dim**-0.5

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from ``q`` over ``k``/``v`` and project the result.

        Args:
            q: Query input; its last dim is projected by ``linear_q``.
            k: Key input; its last dim is projected by ``linear_k``.
            v: Value input; its last dim is projected by ``linear_v``.
            mask: Forwarded as ``mask`` to ``softmax_dropout``.
            bias: Forwarded as ``bias`` to ``softmax_dropout``.

        Returns:
            The attention output after the output projection weight; the
            output bias is not added here.
        """
        # The gate is computed from the raw (un-projected) query input.
        gate = self.linear_g(q) if self.linear_g is not None else None

        q = self.linear_q(q)
        q *= self.norm
        k = self.linear_k(k)
        v = self.linear_v(v)

        # Split the feature dim into heads and move the head axis in front of
        # the sequence axis.
        heads = self.num_heads
        q = q.view(*q.shape[:-1], heads, -1).transpose(-2, -3).contiguous()
        k = k.view(*k.shape[:-1], heads, -1).transpose(-2, -3).contiguous()
        v = v.view(*v.shape[:-1], heads, -1).transpose(-2, -3)

        weights = q @ k.transpose(-1, -2)
        del q, k  # drop references early to keep peak memory down

        weights = softmax_dropout(
            weights, 0, self.training, mask=mask, bias=bias)
        out = weights @ v
        del weights, v

        # Move the head axis back and fold heads into the feature dim.
        out = out.transpose(-2, -3).contiguous()
        out = out.view(out.shape[:-2] + (-1, ))

        if gate is not None:
            out = torch.sigmoid(gate) * out

        # Output projection without bias (see get_output_bias).
        return nn.functional.linear(out, self.linear_o.weight)

    def get_output_bias(self):
        """Return the bias of the output projection ``linear_o``."""
        return self.linear_o.bias
class GlobalAttention(nn.Module):
    """Gated attention using one mask-averaged query per group of rows.

    Queries are multi-head; the key and value projections each produce a
    single ``head_dim``-sized vector per position.
    """

    def __init__(self, input_dim, head_dim, num_heads, inf, eps):
        """Create the projections.

        Args:
            input_dim: Feature size of the input (and of the output).
            head_dim: Feature size of each head.
            num_heads: Number of query heads.
            inf: Magnitude of the additive mask value; ``-inf`` is handed to
                ``gen_attn_mask``.
            eps: Added to the mask sum to avoid dividing by zero.
        """
        super().__init__()
        self.num_heads = num_heads
        self.inf = inf
        self.eps = eps

        hidden_dim = head_dim * num_heads
        self.linear_q = Linear(
            input_dim, hidden_dim, bias=False, init='glorot')
        self.linear_k = Linear(input_dim, head_dim, bias=False, init='glorot')
        self.linear_v = Linear(input_dim, head_dim, bias=False, init='glorot')
        self.linear_g = Linear(input_dim, hidden_dim, init='gating')
        self.linear_o = Linear(hidden_dim, input_dim, init='final')
        self.sigmoid = nn.Sigmoid()

        # Query scaling factor, 1 / sqrt(head_dim), computed once.
        self.norm = head_dim**-0.5

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply global attention to ``x``.

        Args:
            x: Input features; the second-to-last dim is the one averaged
                over to build the query.
            mask: Validity mask matching ``x`` without its last dim.

        Returns:
            The output of ``linear_o`` applied to the gated attention result.
        """
        gate = self.sigmoid(self.linear_g(x))
        k = self.linear_k(x)
        v = self.linear_v(x)

        # Mask-weighted mean of x over dim -2 is the input to the query
        # projection.
        masked_sum = torch.sum(x * mask.unsqueeze(-1), dim=-2)
        valid_count = torch.sum(mask, dim=-1, keepdims=True) + self.eps
        q = self.linear_q(masked_sum / valid_count)
        q *= self.norm
        q = q.view(*q.shape[:-1], self.num_heads, -1)

        attn = torch.matmul(q, k.transpose(-1, -2))
        del q, k  # drop references early to keep peak memory down

        # Insert a singleton axis so the mask broadcasts over heads.
        attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
        attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)

        out = torch.matmul(attn, v)
        del attn, v

        gate = gate.view(*gate.shape[:-1], self.num_heads, -1)
        out = out.unsqueeze(-3) * gate
        del gate

        # Fold heads back into the feature dim before the output projection.
        out = out.reshape(*out.shape[:-2], -1)
        return self.linear_o(out)
def gen_msa_attn_mask(mask, inf, gen_col_mask=True):
    """Build additive attention mask(s) from ``mask``.

    Args:
        mask: Validity mask; zero entries are suppressed.
        inf: Magnitude of the fill value; ``-inf`` is passed to
            ``gen_attn_mask``.
        gen_col_mask: Also build the mask for ``mask`` with its last two
            dims swapped.

    Returns:
        ``(row_mask, col_mask)`` when ``gen_col_mask`` is true, otherwise
        just ``row_mask``. Each has two singleton axes inserted before its
        last dim.
    """
    neg_inf = -inf
    row_mask = gen_attn_mask(mask, neg_inf)[..., :, None, None, :]
    if not gen_col_mask:
        return row_mask

    transposed = mask.transpose(-1, -2)
    col_mask = gen_attn_mask(transposed, neg_inf)[..., :, None, None, :]
    return row_mask, col_mask
class MSAAttention(nn.Module):
    """Layer-normed self-attention over ``m`` with an optional pair bias."""

    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        pair_bias=False,
        d_pair=None,
    ):
        """Create the sub-modules.

        Args:
            d_in: Feature size of ``m``.
            d_hid: Per-head feature size passed to ``Attention``.
            num_heads: Number of attention heads.
            pair_bias: Whether an attention bias is derived from ``z``.
            d_pair: Feature size of ``z``; used only when ``pair_bias``.
        """
        super().__init__()
        self.pair_bias = pair_bias
        self.layer_norm_m = LayerNorm(d_in)
        # Pair-bias modules are created only when pair_bias is enabled.
        self.layer_norm_z = LayerNorm(d_pair) if pair_bias else None
        self.linear_z = (
            Linear(d_pair, num_heads, bias=False, init='normal')
            if pair_bias else None)
        self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: int = None,
    ) -> torch.Tensor:
        """Run ``_attn_forward`` through ``chunk_layer``."""
        layer_inputs = {'m': m, 'mask': mask, 'bias': bias}
        batch_dims = len(m.shape[:-2])
        return chunk_layer(
            self._attn_forward,
            layer_inputs,
            chunk_size=chunk_size,
            num_batch_dims=batch_dims,
        )

    @torch.jit.ignore
    def _attn_chunk_forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = 2560,
    ) -> torch.Tensor:
        """Attend slice by slice along dim -3 of ``m`` and concatenate.

        ``mask``, when given, is sliced along its dim -4 with the same
        bounds; ``bias`` is shared by every slice.
        """
        m = self.layer_norm_m(m)
        total = m.shape[-3]
        n_chunks = (total + chunk_size - 1) // chunk_size  # ceil division

        pieces = []
        for idx in range(n_chunks):
            lo = idx * chunk_size
            hi = min(total, lo + chunk_size)
            m_part = m[..., lo:hi, :, :]
            mask_part = None
            if mask is not None:
                mask_part = mask[..., lo:hi, :, :, :]
            pieces.append(
                self.mha(
                    q=m_part, k=m_part, v=m_part, mask=mask_part, bias=bias))
        return torch.cat(pieces, dim=-3)

    def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None):
        """Layer-norm ``m`` and self-attend over it in one call."""
        normed = self.layer_norm_m(m)
        return self.mha(q=normed, k=normed, v=normed, mask=mask, bias=bias)

    def forward(
        self,
        m: torch.Tensor,
        z: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Compute the attention update for ``m``.

        Args:
            m: Input features.
            z: Pair features; read only when ``pair_bias`` is set.
            attn_mask: Additive attention mask handed to the attention.
            chunk_size: If given, evaluate through ``chunk_layer``.

        Returns:
            The attention output (without the output-projection bias).
        """
        bias = None
        if self.pair_bias:
            z = self.layer_norm_z(z)
            # Bring the head axis forward and add a broadcast axis at -4.
            bias = (
                permute_final_dims(self.linear_z(z),
                                   (2, 0, 1)).unsqueeze(-4).contiguous())

        if chunk_size is not None:
            return self._chunk(m, attn_mask, bias, chunk_size)

        attn_chunk_size = 2560
        if m.shape[-3] > attn_chunk_size:
            # reduce the peak memory cost in extra_msa_stack
            return self._attn_chunk_forward(
                m, attn_mask, bias, chunk_size=attn_chunk_size)
        return self._attn_forward(m, attn_mask, bias)

    def get_output_bias(self):
        """Return the output-projection bias of the inner attention."""
        return self.mha.get_output_bias()
class MSARowAttentionWithPairBias(MSAAttention):
    """``MSAAttention`` configured with the pair bias enabled."""

    def __init__(self, d_msa, d_pair, d_hid, num_heads):
        """Forward the sizes to ``MSAAttention`` with ``pair_bias=True``."""
        super().__init__(
            d_in=d_msa,
            d_hid=d_hid,
            num_heads=num_heads,
            pair_bias=True,
            d_pair=d_pair,
        )
class MSAColumnAttention(MSAAttention):
    """``MSAAttention`` without pair bias, applied with dims -2/-3 swapped."""

    def __init__(self, d_msa, d_hid, num_heads):
        """Forward the sizes to ``MSAAttention`` with the pair bias off."""
        super().__init__(
            d_msa,
            d_hid,
            num_heads,
            pair_bias=False,
            d_pair=None,
        )

    def forward(
        self,
        m: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Swap dims -2 and -3 of ``m``, run the parent forward, swap back."""
        swapped = m.transpose(-2, -3)
        attended = super().forward(
            swapped, attn_mask=attn_mask, chunk_size=chunk_size)
        return attended.transpose(-2, -3)
class MSAColumnGlobalAttention(nn.Module):
    """Layer norm followed by ``GlobalAttention`` with dims -2/-3 swapped."""

    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        inf=1e9,
        eps=1e-10,
    ):
        """Create the layer norm and the global-attention sub-module.

        Args:
            d_in: Feature size of ``m``.
            d_hid: Per-head feature size passed to ``GlobalAttention``.
            num_heads: Number of attention heads.
            inf: Mask magnitude forwarded to ``GlobalAttention``.
            eps: Numerical-stability term forwarded to ``GlobalAttention``.
        """
        super(MSAColumnGlobalAttention, self).__init__()
        self.layer_norm_m = LayerNorm(d_in)
        self.global_attention = GlobalAttention(
            d_in,
            d_hid,
            num_heads,
            inf=inf,
            eps=eps,
        )

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        """Run ``_attn_forward`` through ``chunk_layer``."""
        return chunk_layer(
            self._attn_forward,
            {
                'm': m,
                'mask': mask
            },
            chunk_size=chunk_size,
            num_batch_dims=len(m.shape[:-2]),
        )

    def _attn_forward(self, m, mask):
        """Layer-norm ``m`` and apply global attention with ``mask``."""
        m = self.layer_norm_m(m)
        return self.global_attention(m, mask=mask)

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Apply global attention to ``m`` with dims -2 and -3 swapped.

        Args:
            m: Input features.
            mask: Validity mask shaped like ``m`` without its last dim. When
                omitted, every position is treated as valid.
            chunk_size: If given, evaluate through ``chunk_layer``.

        Returns:
            The attention output, with the dim swap undone.
        """
        if mask is None:
            # The signature advertises mask as optional, so fall back to an
            # all-ones mask instead of failing on ``None.transpose``.
            mask = m.new_ones(m.shape[:-1])

        m = m.transpose(-2, -3)
        mask = mask.transpose(-1, -2)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self._attn_forward(m, mask=mask)

        m = m.transpose(-2, -3)
        return m
def gen_tri_attn_mask(mask, inf):
    """Build the two additive masks used by the triangle attentions.

    Args:
        mask: Validity mask; zero entries are suppressed.
        inf: Magnitude of the fill value; ``-inf`` is passed to
            ``gen_attn_mask``.

    Returns:
        ``(start_mask, end_mask)``: the additive mask of ``mask`` and of
        ``mask`` with its last two dims swapped, each with two singleton
        axes inserted before its last dim.
    """
    neg_inf = -inf
    transposed = mask.transpose(-1, -2)
    start_mask = gen_attn_mask(mask, neg_inf)[..., :, None, None, :]
    end_mask = gen_attn_mask(transposed, neg_inf)[..., :, None, None, :]
    return start_mask, end_mask
class TriangleAttention(nn.Module):
    """Self-attention over ``x`` with a bias projected from ``x`` itself.

    When ``starting`` is false, dims -2 and -3 of the input are swapped
    before the attention and swapped back afterwards.
    """

    def __init__(
        self,
        d_in,
        d_hid,
        num_heads,
        starting,
    ):
        """Create the sub-modules.

        Args:
            d_in: Feature size of ``x``.
            d_hid: Per-head feature size passed to ``Attention``.
            num_heads: Number of attention heads.
            starting: If false, operate on ``x`` with dims -2/-3 swapped.
        """
        super().__init__()
        self.starting = starting
        self.layer_norm = LayerNorm(d_in)
        self.linear = Linear(d_in, num_heads, bias=False, init='normal')
        self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        chunk_size: int = None,
    ) -> torch.Tensor:
        """Run the inner attention through ``chunk_layer``."""
        mha_inputs = {'q': x, 'k': x, 'v': x, 'mask': mask, 'bias': bias}
        batch_dims = len(x.shape[:-2])
        return chunk_layer(
            self.mha,
            mha_inputs,
            chunk_size=chunk_size,
            num_batch_dims=batch_dims,
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Compute the attention update for ``x``.

        Args:
            x: Input features.
            attn_mask: Additive attention mask handed to the attention.
            chunk_size: If given, evaluate through ``chunk_layer``.

        Returns:
            The attention output (without the output-projection bias), in
            the same dim order as the input.
        """
        swap_dims = not self.starting
        if swap_dims:
            x = x.transpose(-2, -3)

        x = self.layer_norm(x)
        # Bring the head axis forward and add a broadcast axis at -4.
        triangle_bias = (
            permute_final_dims(self.linear(x),
                               (2, 0, 1)).unsqueeze(-4).contiguous())

        if chunk_size is None:
            x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias)
        else:
            x = self._chunk(x, attn_mask, triangle_bias, chunk_size)

        return x.transpose(-2, -3) if swap_dims else x

    def get_output_bias(self):
        """Return the output-projection bias of the inner attention."""
        return self.mha.get_output_bias()
class TriangleAttentionStarting(TriangleAttention):
    """``TriangleAttention`` whose ``starting`` argument defaults to True."""

    def __init__(self, *args, **kwargs):
        # Same call semantics as a partialmethod: the preset keyword applies
        # unless the caller passes ``starting`` explicitly.
        kwargs.setdefault('starting', True)
        TriangleAttention.__init__(self, *args, **kwargs)
class TriangleAttentionEnding(TriangleAttention):
    """``TriangleAttention`` whose ``starting`` argument defaults to False."""

    def __init__(self, *args, **kwargs):
        # Same call semantics as a partialmethod: the preset keyword applies
        # unless the caller passes ``starting`` explicitly.
        kwargs.setdefault('starting', False)
        TriangleAttention.__init__(self, *args, **kwargs)
|