| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450 |
# The Uni-Fold implementation is also open-sourced by the authors under the Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
- import torch
- import torch.nn as nn
- from unicore.utils import tensor_tree_map
- from ..data import residue_constants
- from .attentions import gen_msa_attn_mask, gen_tri_attn_mask
- from .auxillary_heads import AuxiliaryHeads
- from .common import residual
- from .embedders import (ExtraMSAEmbedder, InputEmbedder, RecyclingEmbedder,
- TemplateAngleEmbedder, TemplatePairEmbedder)
- from .evoformer import EvoformerStack, ExtraMSAStack
- from .featurization import (atom14_to_atom37, build_extra_msa_feat,
- build_template_angle_feat,
- build_template_pair_feat,
- build_template_pair_feat_v2, pseudo_beta_fn)
- from .structure_module import StructureModule
- from .template import (TemplatePairStack, TemplatePointwiseAttention,
- TemplateProjection)
class AlphaFold(nn.Module):
    """AlphaFold-style network assembled from Uni-Fold building blocks.

    The module wires together an input embedder, a recycling embedder,
    optional template modules, an extra-MSA embedder/stack, an Evoformer
    stack, a structure module and auxiliary heads. ``forward`` runs the
    trunk for ``num_recycling_iters + 1`` iterations, feeding the outputs of
    one iteration back into the next.
    """

    def __init__(self, config):
        """Build every sub-module from ``config``.

        Args:
            config: Configuration object exposing ``globals`` and ``model``.
                ``config.model`` is read both by attribute
                (``config.template``) and by key (``config['heads']``), so it
                must support both access styles.
        """
        super().__init__()

        self.globals = config.globals
        config = config.model
        template_config = config.template
        extra_msa_config = config.extra_msa

        self.input_embedder = InputEmbedder(
            **config['input_embedder'],
            use_chain_relative=config.is_multimer,
        )
        self.recycling_embedder = RecyclingEmbedder(
            **config['recycling_embedder'])

        if config.template.enabled:
            self.template_angle_embedder = TemplateAngleEmbedder(
                **template_config['template_angle_embedder'])
            self.template_pair_embedder = TemplatePairEmbedder(
                **template_config['template_pair_embedder'])
            self.template_pair_stack = TemplatePairStack(
                **template_config['template_pair_stack'])
        else:
            self.template_pair_stack = None

        # NOTE(review): the pointwise-attention / projection module is built
        # regardless of `template.enabled` in this reconstruction -- TODO
        # confirm against the upstream file (indentation was lost in SOURCE).
        self.enable_template_pointwise_attention = template_config[
            'template_pointwise_attention'].enabled
        if self.enable_template_pointwise_attention:
            self.template_pointwise_att = TemplatePointwiseAttention(
                **template_config['template_pointwise_attention'])
        else:
            # The projection is configured from the same config section as
            # the pointwise attention it replaces.
            self.template_proj = TemplateProjection(
                **template_config['template_pointwise_attention'])

        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config['extra_msa_embedder'])
        self.extra_msa_stack = ExtraMSAStack(
            **extra_msa_config['extra_msa_stack'])
        self.evoformer = EvoformerStack(**config['evoformer_stack'])
        self.structure_module = StructureModule(**config['structure_module'])
        self.aux_heads = AuxiliaryHeads(config['heads'])

        self.config = config
        # Compute dtype of the trunk; updated by `half()` / `bfloat16()`.
        self.dtype = torch.float
        self.inf = self.globals.inf

        if self.globals.alphafold_original_mode:
            self.alphafold_original_mode()

    def __make_input_float__(self):
        """Cast the input and recycling embedders back to float32."""
        self.input_embedder = self.input_embedder.float()
        self.recycling_embedder = self.recycling_embedder.float()

    def half(self):
        """Cast the model to float16 and record it as the compute dtype.

        Unless the ``inference`` flag has been set (see ``inference_mode``),
        the input and recycling embedders are kept in float32.

        Returns:
            ``self``, to allow call chaining like ``nn.Module.half``.
        """
        super().half()
        if not getattr(self, 'inference', False):
            self.__make_input_float__()
        self.dtype = torch.half
        return self

    def bfloat16(self):
        """Cast the model to bfloat16 and record it as the compute dtype.

        Unless the ``inference`` flag has been set (see ``inference_mode``),
        the input and recycling embedders are kept in float32.

        Returns:
            ``self``, to allow call chaining like ``nn.Module.bfloat16``.
        """
        super().bfloat16()
        if not getattr(self, 'inference', False):
            self.__make_input_float__()
        self.dtype = torch.bfloat16
        return self

    def alphafold_original_mode(self):
        """Switch every sub-module to its "original AlphaFold" variant.

        For each module in the tree (including ``self``): call its
        ``apply_alphafold_original_mode`` hook when it has one, and replace
        any ``act`` attribute with ``nn.ReLU()``.
        """

        def set_alphafold_original_mode(module):
            if hasattr(module, 'apply_alphafold_original_mode'):
                module.apply_alphafold_original_mode()
            if hasattr(module, 'act'):
                module.act = nn.ReLU()

        self.apply(set_alphafold_original_mode)

    def inference_mode(self):
        """Set ``inference = True`` on this module and all sub-modules."""

        def set_inference_mode(module):
            module.inference = True

        self.apply(set_inference_mode)

    def __convert_input_dtype__(self, batch):
        """Cast mask features of ``batch`` to the model's compute dtype.

        Only entries whose key contains ``'mask'`` are converted; all other
        entries are left untouched. ``batch`` is modified in place.

        Args:
            batch: Mapping from feature name to tensor.

        Returns:
            The same ``batch`` object.
        """
        for key in batch:
            # only convert features with mask
            if batch[key].dtype != self.dtype and 'mask' in key:
                batch[key] = batch[key].type(self.dtype)
        return batch

    def embed_templates_pair_core(self, batch, z, pair_mask,
                                  tri_start_attn_mask, tri_end_attn_mask,
                                  templ_dim, multichain_mask_2d):
        """Embed template pair features and run the template pair stack.

        Args:
            batch: Features consumed by the template pair featurizers.
            z: Current pair representation.
            pair_mask: Pair mask forwarded to the template pair stack.
            tri_start_attn_mask: Triangle-attention (start) mask.
            tri_end_attn_mask: Triangle-attention (end) mask.
            templ_dim: Dimension indexing templates in the features.
            multichain_mask_2d: Optional same-chain mask; only used by the
                v2 featurizer.

        Returns:
            Output of ``self.template_pair_stack``. The stack is asked to
            return the mean over templates when pointwise attention is
            disabled.
        """
        template_cfg = self.config.template
        if template_cfg.template_pair_embedder.v2_feature:
            feats = build_template_pair_feat_v2(
                batch,
                inf=template_cfg.inf,
                eps=template_cfg.eps,
                multichain_mask_2d=multichain_mask_2d,
                **template_cfg.distogram,
            )
            # The v2 featurizer returns a sequence of tensors; the template
            # axis is taken at dim -4 of each.
            num_template = feats[0].shape[-4]
            single_templates = [
                self.template_pair_embedder(
                    [x[..., ti, :, :, :] for x in feats], z)
                for ti in range(num_template)
            ]
        else:
            feats = build_template_pair_feat(
                batch,
                inf=template_cfg.inf,
                eps=template_cfg.eps,
                **template_cfg.distogram,
            )
            single_templates = [
                self.template_pair_embedder(x, z)
                for x in torch.unbind(feats, dim=templ_dim)
            ]

        return self.template_pair_stack(
            single_templates,
            pair_mask,
            tri_start_attn_mask=tri_start_attn_mask,
            tri_end_attn_mask=tri_end_attn_mask,
            templ_dim=templ_dim,
            chunk_size=self.globals.chunk_size,
            block_size=self.globals.block_size,
            return_mean=not self.enable_template_pointwise_attention,
        )

    def embed_templates_pair(self, batch, z, pair_mask, tri_start_attn_mask,
                             tri_end_attn_mask, templ_dim):
        """Compute the template contribution to the pair representation.

        During training, or whenever pointwise attention is enabled, all
        templates are embedded in one call. Otherwise templates are embedded
        one at a time and averaged, then passed through ``template_proj``.

        Args:
            batch: Feature dict; must contain the ``template_*`` features
                and may contain ``asym_id``.
            z: Current pair representation.
            pair_mask: Pair mask.
            tri_start_attn_mask: Triangle-attention (start) mask.
            tri_end_attn_mask: Triangle-attention (end) mask.
            templ_dim: Dimension indexing templates in the features.

        Returns:
            The template embedding to be added to ``z``.
        """
        use_v2 = self.config.template.template_pair_embedder.v2_feature
        if use_v2 and 'asym_id' in batch:
            # True where two residues carry the same asym_id.
            multichain_mask_2d = (
                batch['asym_id'][..., :, None]
                == batch['asym_id'][..., None, :])
            multichain_mask_2d = multichain_mask_2d.unsqueeze(0)
        else:
            multichain_mask_2d = None

        if self.training or self.enable_template_pointwise_attention:
            t = self.embed_templates_pair_core(batch, z, pair_mask,
                                               tri_start_attn_mask,
                                               tri_end_attn_mask, templ_dim,
                                               multichain_mask_2d)
            if self.enable_template_pointwise_attention:
                t = self.template_pointwise_att(
                    t,
                    z,
                    template_mask=batch['template_mask'],
                    chunk_size=self.globals.chunk_size,
                )
                # Zero the result when no template is marked valid.
                t_mask = torch.sum(
                    batch['template_mask'], dim=-1, keepdim=True) > 0
                t_mask = t_mask[..., None, None].type(t.dtype)
                t *= t_mask
            else:
                t = self.template_proj(t, z)
        else:
            template_aatype_shape = batch['template_aatype'].shape
            # template_aatype is either [n_template, n_res] or
            # [1, n_template, n_res]
            batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0
            n_templ = template_aatype_shape[batch_templ_dim]

            if n_templ <= 0:
                t = None
            else:
                template_batch = {
                    k: v
                    for k, v in batch.items() if k.startswith('template_')
                }

                def embed_one_template(i):
                    """Embed only the i-th template (kept as a size-1 axis)."""

                    def slice_template_tensor(tensor):
                        index = [slice(None)] * len(tensor.shape)
                        index[batch_templ_dim] = slice(i, i + 1)
                        # Index with a tuple: indexing with a list of slices
                        # is deprecated non-tuple sequence indexing.
                        return tensor[tuple(index)]

                    template_feats = tensor_tree_map(
                        slice_template_tensor,
                        template_batch,
                    )
                    return self.embed_templates_pair_core(
                        template_feats, z, pair_mask, tri_start_attn_mask,
                        tri_end_attn_mask, templ_dim, multichain_mask_2d)

                t = embed_one_template(0)
                # iterate templates one by one
                for i in range(1, n_templ):
                    t += embed_one_template(i)
                t /= n_templ
            # NOTE(review): called for both branches above, so template_proj
            # presumably accepts t=None -- TODO confirm against upstream
            # (indentation was lost in SOURCE).
            t = self.template_proj(t, z)
        return t

    def embed_templates_angle(self, batch):
        """Embed template angle features.

        Args:
            batch: Feature dict passed to ``build_template_angle_feat``.

        Returns:
            Tuple ``(embedding, template_angle_mask)``.
        """
        template_angle_feat, template_angle_mask = build_template_angle_feat(
            batch,
            v2_feature=self.config.template.template_pair_embedder.v2_feature)
        t = self.template_angle_embedder(template_angle_feat)
        return t, template_angle_mask

    def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev):
        """Run one pass of embedding, templates, extra-MSA and Evoformer.

        Args:
            feats: Feature dict for the current recycling/ensemble index.
            m_1_prev: Recycled first-row MSA representation, or ``None``.
            z_prev: Recycled pair representation, or ``None``.
            x_prev: Recycled atom positions, or ``None``.
                Missing recycled inputs are replaced by zeros.

        Returns:
            Tuple ``(m, z, s, msa_mask, m_1_prev_emb, z_prev_emb)``: the
            MSA, pair and single representations from the Evoformer, the MSA
            mask actually used (extended when template angles are embedded),
            and the two recycling embeddings that were added to ``m``/``z``.
        """
        batch_dims = feats['target_feat'].shape[:-2]
        n = feats['target_feat'].shape[-2]
        seq_mask = feats['seq_mask']
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        msa_mask = feats['msa_mask']

        m, z = self.input_embedder(
            feats['target_feat'],
            feats['msa_feat'],
        )

        # First iteration: nothing to recycle yet, start from zeros.
        if m_1_prev is None:
            m_1_prev = m.new_zeros(
                (*batch_dims, n, self.config.input_embedder.d_msa),
                requires_grad=False,
            )
        if z_prev is None:
            z_prev = z.new_zeros(
                (*batch_dims, n, n, self.config.input_embedder.d_pair),
                requires_grad=False,
            )
        if x_prev is None:
            x_prev = z.new_zeros(
                (*batch_dims, n, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )
        x_prev = pseudo_beta_fn(feats['aatype'], x_prev, None)

        # `recyle_pos` (sic) is the method name exposed by RecyclingEmbedder.
        z += self.recycling_embedder.recyle_pos(x_prev)

        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
        )

        # The recycled MSA embedding is added to the first MSA row only.
        m[..., 0, :, :] += m_1_prev_emb
        z += z_prev_emb

        z += self.input_embedder.relpos_emb(
            feats['residue_index'].long(),
            feats.get('sym_id'),
            feats.get('asym_id'),
            feats.get('entity_id'),
            feats.get('num_sym'),
        )

        # Embedders may run in float32 (see `half`); cast to compute dtype.
        m = m.type(self.dtype)
        z = z.type(self.dtype)
        tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask(
            pair_mask, self.inf)

        if self.config.template.enabled:
            template_mask = feats['template_mask']
            # Skip template embedding entirely when no template is valid.
            if torch.any(template_mask):
                z = residual(
                    z,
                    self.embed_templates_pair(
                        feats,
                        z,
                        pair_mask,
                        tri_start_attn_mask,
                        tri_end_attn_mask,
                        templ_dim=-4,
                    ),
                    self.training,
                )

        if self.config.extra_msa.enabled:
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))
            extra_msa_row_mask = gen_msa_attn_mask(
                feats['extra_msa_mask'],
                inf=self.inf,
                gen_col_mask=False,
            )
            z = self.extra_msa_stack(
                a,
                z,
                msa_mask=feats['extra_msa_mask'],
                chunk_size=self.globals.chunk_size,
                block_size=self.globals.block_size,
                pair_mask=pair_mask,
                msa_row_attn_mask=extra_msa_row_mask,
                msa_col_attn_mask=None,
                tri_start_attn_mask=tri_start_attn_mask,
                tri_end_attn_mask=tri_end_attn_mask,
            )

        if self.config.template.embed_angles:
            template_1d_feat, template_1d_mask = self.embed_templates_angle(
                feats)
            # Template angle embeddings are appended as extra MSA rows.
            m = torch.cat([m, template_1d_feat], dim=-3)
            msa_mask = torch.cat(
                [feats['msa_mask'], template_1d_mask], dim=-2)

        msa_row_mask, msa_col_mask = gen_msa_attn_mask(
            msa_mask,
            inf=self.inf,
        )

        m, z, s = self.evoformer(
            m,
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            msa_row_attn_mask=msa_row_mask,
            msa_col_attn_mask=msa_col_mask,
            tri_start_attn_mask=tri_start_attn_mask,
            tri_end_attn_mask=tri_end_attn_mask,
            chunk_size=self.globals.chunk_size,
            block_size=self.globals.block_size,
        )
        return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb

    def iteration_evoformer_structure_module(self,
                                             batch,
                                             m_1_prev,
                                             z_prev,
                                             x_prev,
                                             cycle_no,
                                             num_recycling,
                                             num_ensembles=1):
        """Run one recycling iteration: Evoformer trunk + structure module.

        Args:
            batch: Full feature dict; each tensor's leading dimension is
                indexed by ``cycle_no * num_ensembles + ensemble_no``
                (clamped to the last available entry).
            m_1_prev: Recycled first-row MSA representation, or ``None``.
            z_prev: Recycled pair representation, or ``None``.
            x_prev: Recycled atom positions, or ``None``.
            cycle_no: Zero-based index of the current recycling iteration.
            num_recycling: Total number of iterations that will be run.
            num_ensembles: Number of ensemble members averaged per iteration.

        Returns:
            Tuple ``(outputs, m_1_prev, z_prev, x_prev)`` where ``outputs``
            is the dict of predictions and the remaining items are the
            values to recycle into the next iteration.

        Raises:
            AssertionError: If ``num_ensembles`` is smaller than 1.
        """
        z, s = 0, 0
        n_seq = batch['msa_feat'].shape[-3]
        assert num_ensembles >= 1, 'num_ensembles must be at least 1'

        for ensemble_no in range(num_ensembles):
            idx = cycle_no * num_ensembles + ensemble_no

            def fetch_cur_batch(t):
                # Clamp so features with fewer leading entries reuse the
                # last one.
                return t[min(t.shape[0] - 1, idx), ...]

            feats = tensor_tree_map(fetch_cur_batch, batch)
            (m, z0, s0, msa_mask, m_1_prev_emb,
             z_prev_emb) = self.iteration_evoformer(feats, m_1_prev, z_prev,
                                                    x_prev)
            # Pair/single representations are accumulated over the ensemble;
            # `m`, `msa_mask` and the recycling embeddings come from the last
            # ensemble member.
            z += z0
            s += s0
            del z0, s0

        if num_ensembles > 1:
            z /= float(num_ensembles)
            s /= float(num_ensembles)

        outputs = {}
        # Drop any rows appended after the first n_seq (see
        # `iteration_evoformer`, which may append template angle rows).
        outputs['msa'] = m[..., :n_seq, :, :]
        outputs['pair'] = z
        outputs['single'] = s

        # norm loss: only outside inference mode, on the final iteration.
        is_inference = getattr(self, 'inference', False)
        if (not is_inference) and num_recycling == (cycle_no + 1):
            # NOTE(review): `delta_msa` aliases `m`, so this in-place
            # subtraction also changes `outputs['msa']` (a slice of `m`) and
            # the `m_1_prev` returned below. Behaviour kept as in the
            # original -- confirm whether this is intended.
            delta_msa = m
            delta_msa[..., 0, :, :] = (
                delta_msa[..., 0, :, :] - m_1_prev_emb.detach())
            delta_pair = z - z_prev_emb.detach()
            outputs['delta_msa'] = delta_msa
            outputs['delta_pair'] = delta_pair
            outputs['msa_norm_mask'] = msa_mask

        outputs['sm'] = self.structure_module(
            s,
            z,
            feats['aatype'],
            mask=feats['seq_mask'],
        )
        outputs['final_atom_positions'] = atom14_to_atom37(
            outputs['sm']['positions'], feats)
        outputs['final_atom_mask'] = feats['atom37_atom_exists']
        outputs['pred_frame_tensor'] = outputs['sm']['frames'][-1]

        # use float32 for numerical stability
        if not is_inference:
            m_1_prev = m[..., 0, :, :].float()
            z_prev = z.float()
            x_prev = outputs['final_atom_positions'].float()
        else:
            m_1_prev = m[..., 0, :, :]
            z_prev = z
            x_prev = outputs['final_atom_positions']

        return outputs, m_1_prev, z_prev, x_prev

    def forward(self, batch):
        """Run all recycling iterations and the auxiliary heads.

        Gradients are only enabled for the final iteration (and only when
        they were enabled on entry). Mask features in ``batch`` are cast to
        the model's compute dtype in place.

        Args:
            batch: Feature dict. Must contain ``num_recycling_iters`` and
                ``msa_mask``; may contain initial ``m_1_prev``, ``z_prev``,
                ``x_prev`` values and ``asym_id``.

        Returns:
            Dict with the outputs of the last iteration, updated with the
            auxiliary head outputs (and ``asym_id`` when present in
            ``batch``).

        Raises:
            AssertionError: If more than one ensemble is requested while the
                module is in training mode.
        """
        m_1_prev = batch.get('m_1_prev')
        z_prev = batch.get('z_prev')
        x_prev = batch.get('x_prev')

        is_grad_enabled = torch.is_grad_enabled()

        num_iters = int(batch['num_recycling_iters']) + 1
        # The leading dimension of the features holds
        # num_iters * num_ensembles entries.
        num_ensembles = int(batch['msa_mask'].shape[0]) // num_iters
        if self.training:
            # don't use ensemble during training
            assert num_ensembles == 1, 'ensembling is not supported in training'

        # convert dtypes in batch
        batch = self.__convert_input_dtype__(batch)
        for cycle_no in range(num_iters):
            is_final_iter = cycle_no == (num_iters - 1)
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                (
                    outputs,
                    m_1_prev,
                    z_prev,
                    x_prev,
                ) = self.iteration_evoformer_structure_module(
                    batch,
                    m_1_prev,
                    z_prev,
                    x_prev,
                    cycle_no=cycle_no,
                    num_recycling=num_iters,
                    num_ensembles=num_ensembles,
                )
            if not is_final_iter:
                # Free intermediate predictions before the next iteration.
                del outputs

        if 'asym_id' in batch:
            outputs['asym_id'] = batch['asym_id'][0, ...]
        outputs.update(self.aux_heads(outputs))
        return outputs
|