| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636 |
- # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
- # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
- import copy
- from typing import Any
- import ml_collections as mlc
# Placeholder strings that name the variable-length axes in the per-feature
# shape lists under data.common.features in base_config().
N_RES = 'number of residues'
N_MSA = 'number of MSA sequences'
N_EXTRA_MSA = 'number of extra MSA sequences'
N_TPL = 'number of templates'

# Typed mlc.FieldReference objects. base_config() inserts each of them at one
# or more places in the config tree (e.g. d_pair appears in most sub-module
# configs), so those entries are declared through a single object here.
# NOTE(review): presumably assigning to any such entry updates every entry
# built from the same reference -- confirm against ml_collections docs.
d_pair = mlc.FieldReference(128, field_type=int)
d_msa = mlc.FieldReference(256, field_type=int)
d_template = mlc.FieldReference(64, field_type=int)
d_extra_msa = mlc.FieldReference(64, field_type=int)
d_single = mlc.FieldReference(384, field_type=int)
max_recycling_iters = mlc.FieldReference(3, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
# Used as num_bins for both the distogram head and the pae head.
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
# eps / inf are only wired into the 'globals' section; the other 'eps' and
# 'inf' entries in base_config() are plain literals.
eps = mlc.FieldReference(1e-8, field_type=float)
inf = mlc.FieldReference(3e4, field_type=float)
# use_templates backs data.common.use_templates,
# data.common.use_template_torsion_angles, model.template.enabled and
# model.template.embed_angles; is_multimer backs data.common.is_multimer and
# model.is_multimer.
use_templates = mlc.FieldReference(True, field_type=bool)
is_multimer = mlc.FieldReference(False, field_type=bool)
def base_config():
    """Build the default configuration tree.

    The tree has four top-level sections -- 'data', 'globals', 'model' and
    'loss' -- and is assembled below from one plain dict per section before
    being wrapped in an ``mlc.ConfigDict``. Several entries are the
    module-level ``mlc.FieldReference`` objects (``d_pair``, ``d_msa``,
    ``use_templates``, ...) rather than literals.

    Returns:
        mlc.ConfigDict: a freshly constructed config holding the defaults.
    """
    # ------------------------------------------------------------------ data
    # Per-feature shape lists; the N_* placeholders mark variable-size axes.
    feature_shapes = {
        'aatype': [N_RES],
        'all_atom_mask': [N_RES, None],
        'all_atom_positions': [N_RES, None, None],
        'alt_chi_angles': [N_RES, None],
        'atom14_alt_gt_exists': [N_RES, None],
        'atom14_alt_gt_positions': [N_RES, None, None],
        'atom14_atom_exists': [N_RES, None],
        'atom14_atom_is_ambiguous': [N_RES, None],
        'atom14_gt_exists': [N_RES, None],
        'atom14_gt_positions': [N_RES, None, None],
        'atom37_atom_exists': [N_RES, None],
        'frame_mask': [N_RES],
        'true_frame_tensor': [N_RES, None, None],
        'bert_mask': [N_MSA, N_RES],
        'chi_angles_sin_cos': [N_RES, None, None],
        'chi_mask': [N_RES, None],
        'extra_msa_deletion_value': [N_EXTRA_MSA, N_RES],
        'extra_msa_has_deletion': [N_EXTRA_MSA, N_RES],
        'extra_msa': [N_EXTRA_MSA, N_RES],
        'extra_msa_mask': [N_EXTRA_MSA, N_RES],
        'extra_msa_row_mask': [N_EXTRA_MSA],
        'is_distillation': [],
        'msa_feat': [N_MSA, N_RES, None],
        'msa_mask': [N_MSA, N_RES],
        'msa_chains': [N_MSA, None],
        'msa_row_mask': [N_MSA],
        'num_recycling_iters': [],
        'pseudo_beta': [N_RES, None],
        'pseudo_beta_mask': [N_RES],
        'residue_index': [N_RES],
        'residx_atom14_to_atom37': [N_RES, None],
        'residx_atom37_to_atom14': [N_RES, None],
        'resolution': [],
        'rigidgroups_alt_gt_frames': [N_RES, None, None, None],
        'rigidgroups_group_exists': [N_RES, None],
        'rigidgroups_group_is_ambiguous': [N_RES, None],
        'rigidgroups_gt_exists': [N_RES, None],
        'rigidgroups_gt_frames': [N_RES, None, None, None],
        'seq_length': [],
        'seq_mask': [N_RES],
        'target_feat': [N_RES, None],
        'template_aatype': [N_TPL, N_RES],
        'template_all_atom_mask': [N_TPL, N_RES, None],
        'template_all_atom_positions': [N_TPL, N_RES, None, None],
        'template_alt_torsion_angles_sin_cos': [N_TPL, N_RES, None, None],
        'template_frame_mask': [N_TPL, N_RES],
        'template_frame_tensor': [N_TPL, N_RES, None, None],
        'template_mask': [N_TPL],
        'template_pseudo_beta': [N_TPL, N_RES, None],
        'template_pseudo_beta_mask': [N_TPL, N_RES],
        'template_sum_probs': [N_TPL, None],
        'template_torsion_angles_mask': [N_TPL, N_RES, None],
        'template_torsion_angles_sin_cos': [N_TPL, N_RES, None, None],
        'true_msa': [N_MSA, N_RES],
        'use_clamped_fape': [],
        'assembly_num_chains': [1],
        'asym_id': [N_RES],
        'sym_id': [N_RES],
        'entity_id': [N_RES],
        'num_sym': [N_RES],
        'asym_len': [None],
        'cluster_bias_mask': [N_MSA],
    }

    # Settings shared by the predict / eval / train data modes.
    common_cfg = {
        'features': feature_shapes,
        'masked_msa': {
            'profile_prob': 0.1,
            'same_prob': 0.1,
            'uniform_prob': 0.1,
        },
        'block_delete_msa': {
            'msa_fraction_per_block': 0.3,
            'randomize_num_blocks': False,
            'num_blocks': 5,
            'min_num_msa': 16,
        },
        'random_delete_msa': {
            'max_msa_entry': 1 << 25,  # := 33554432
        },
        'v2_feature': False,
        'gumbel_sample': False,
        'max_extra_msa': 1024,
        'msa_cluster_features': True,
        'reduce_msa_clusters_by_max_templates': True,
        'resample_msa_in_recycling': True,
        'template_features': [
            'template_all_atom_positions',
            'template_sum_probs',
            'template_aatype',
            'template_all_atom_mask',
        ],
        'unsupervised_features': [
            'aatype',
            'residue_index',
            'msa',
            'msa_chains',
            'num_alignments',
            'seq_length',
            'between_segment_residues',
            'deletion_matrix',
            'num_recycling_iters',
            'crop_and_fix_size_seed',
        ],
        'recycling_features': [
            'msa_chains',
            'msa_mask',
            'msa_row_mask',
            'bert_mask',
            'true_msa',
            'msa_feat',
            'extra_msa_deletion_value',
            'extra_msa_has_deletion',
            'extra_msa',
            'extra_msa_mask',
            'extra_msa_row_mask',
            'is_distillation',
        ],
        'multimer_features': [
            'assembly_num_chains',
            'asym_id',
            'sym_id',
            'num_sym',
            'entity_id',
            'asym_len',
            'cluster_bias_mask',
        ],
        'use_templates': use_templates,
        'is_multimer': is_multimer,
        # Same reference object as 'use_templates' above.
        'use_template_torsion_angles': use_templates,
        'max_recycling_iters': max_recycling_iters,
    }

    supervised_cfg = {
        'use_clamped_fape_prob': 1.0,
        'supervised_features': [
            'all_atom_mask',
            'all_atom_positions',
            'resolution',
            'use_clamped_fape',
            'is_distillation',
        ],
    }

    predict_cfg = {
        'fixed_size': True,
        'subsample_templates': False,
        'block_delete_msa': False,
        'random_delete_msa': True,
        'masked_msa_replace_fraction': 0.15,
        'max_msa_clusters': 128,
        'max_templates': 4,
        'num_ensembles': 2,
        'crop': False,
        'crop_size': None,
        'supervised': False,
        'biased_msa_by_chain': False,
        'share_mask': False,
    }

    eval_cfg = {
        'fixed_size': True,
        'subsample_templates': False,
        'block_delete_msa': False,
        'random_delete_msa': True,
        'masked_msa_replace_fraction': 0.15,
        'max_msa_clusters': 128,
        'max_templates': 4,
        'num_ensembles': 1,
        'crop': False,
        'crop_size': None,
        'spatial_crop_prob': 0.5,
        'ca_ca_threshold': 10.0,
        'supervised': True,
        'biased_msa_by_chain': False,
        'share_mask': False,
    }

    train_cfg = {
        'fixed_size': True,
        'subsample_templates': True,
        'block_delete_msa': True,
        'random_delete_msa': True,
        'masked_msa_replace_fraction': 0.15,
        'max_msa_clusters': 128,
        'max_templates': 4,
        'num_ensembles': 1,
        'crop': True,
        'crop_size': 256,
        'spatial_crop_prob': 0.5,
        'ca_ca_threshold': 10.0,
        'supervised': True,
        'use_clamped_fape_prob': 1.0,
        'max_distillation_msa_clusters': 1000,
        'biased_msa_by_chain': True,
        'share_mask': True,
    }

    data_cfg = {
        'common': common_cfg,
        'supervised': supervised_cfg,
        'predict': predict_cfg,
        'eval': eval_cfg,
        'train': train_cfg,
    }

    # --------------------------------------------------------------- globals
    globals_cfg = {
        'chunk_size': chunk_size,
        'block_size': None,
        'd_pair': d_pair,
        'd_msa': d_msa,
        'd_template': d_template,
        'd_extra_msa': d_extra_msa,
        'd_single': d_single,
        'eps': eps,
        'inf': inf,
        'max_recycling_iters': max_recycling_iters,
        'alphafold_original_mode': False,
    }

    # ----------------------------------------------------------------- model
    template_cfg = {
        'distogram': {
            'min_bin': 3.25,
            'max_bin': 50.75,
            'num_bins': 39,
        },
        'template_angle_embedder': {
            'd_in': 57,
            'd_out': d_msa,
        },
        'template_pair_embedder': {
            'd_in': 88,
            'v2_d_in': [39, 1, 22, 22, 1, 1, 1, 1],
            'd_pair': d_pair,
            'd_out': d_template,
            'v2_feature': False,
        },
        'template_pair_stack': {
            'd_template': d_template,
            'd_hid_tri_att': 16,
            'd_hid_tri_mul': 64,
            'num_blocks': 2,
            'num_heads': 4,
            'pair_transition_n': 2,
            'dropout_rate': 0.25,
            'inf': 1e9,
            'tri_attn_first': True,
        },
        'template_pointwise_attention': {
            'enabled': True,
            'd_template': d_template,
            'd_pair': d_pair,
            'd_hid': 16,
            'num_heads': 4,
            'inf': 1e5,
        },
        'inf': 1e5,
        'eps': 1e-6,
        # Both flags are backed by the use_templates reference.
        'enabled': use_templates,
        'embed_angles': use_templates,
    }

    extra_msa_cfg = {
        'extra_msa_embedder': {
            'd_in': 25,
            'd_out': d_extra_msa,
        },
        'extra_msa_stack': {
            'd_msa': d_extra_msa,
            'd_pair': d_pair,
            'd_hid_msa_att': 8,
            'd_hid_opm': 32,
            'd_hid_mul': 128,
            'd_hid_pair_att': 32,
            'num_heads_msa': 8,
            'num_heads_pair': 4,
            'num_blocks': 4,
            'transition_n': 4,
            'msa_dropout': 0.15,
            'pair_dropout': 0.25,
            'inf': 1e9,
            'eps': 1e-10,
            'outer_product_mean_first': False,
        },
        'enabled': True,
    }

    evoformer_cfg = {
        'd_msa': d_msa,
        'd_pair': d_pair,
        'd_hid_msa_att': 32,
        'd_hid_opm': 32,
        'd_hid_mul': 128,
        'd_hid_pair_att': 32,
        'd_single': d_single,
        'num_heads_msa': 8,
        'num_heads_pair': 4,
        'num_blocks': 48,
        'transition_n': 4,
        'msa_dropout': 0.15,
        'pair_dropout': 0.25,
        'inf': 1e9,
        'eps': 1e-10,
        'outer_product_mean_first': False,
    }

    structure_cfg = {
        'd_single': d_single,
        'd_pair': d_pair,
        'd_ipa': 16,
        'd_angle': 128,
        'num_heads_ipa': 12,
        'num_qk_points': 4,
        'num_v_points': 8,
        'dropout_rate': 0.1,
        'num_blocks': 8,
        'no_transition_layers': 1,
        'num_resnet_blocks': 2,
        'num_angles': 7,
        'trans_scale_factor': 10,
        'epsilon': 1e-12,
        'inf': 1e5,
        'separate_kv': False,
        'ipa_bias': True,
    }

    heads_cfg = {
        'plddt': {
            'num_bins': 50,
            'd_in': d_single,
            'd_hid': 128,
        },
        'distogram': {
            'd_pair': d_pair,
            'num_bins': aux_distogram_bins,
            'disable_enhance_head': False,
        },
        'pae': {
            'd_pair': d_pair,
            'num_bins': aux_distogram_bins,
            'enabled': False,
            'iptm_weight': 0.8,
            'disable_enhance_head': False,
        },
        'masked_msa': {
            'd_msa': d_msa,
            'd_out': 23,
            'disable_enhance_head': False,
        },
        'experimentally_resolved': {
            'd_single': d_single,
            'd_out': 37,
            'enabled': False,
            'disable_enhance_head': False,
        },
    }

    model_cfg = {
        'is_multimer': is_multimer,
        'input_embedder': {
            'tf_dim': 22,
            'msa_dim': 49,
            'd_pair': d_pair,
            'd_msa': d_msa,
            'relpos_k': 32,
            'max_relative_chain': 2,
        },
        'recycling_embedder': {
            'd_pair': d_pair,
            'd_msa': d_msa,
            'min_bin': 3.25,
            'max_bin': 20.75,
            'num_bins': 15,
            'inf': 1e8,
        },
        'template': template_cfg,
        'extra_msa': extra_msa_cfg,
        'evoformer_stack': evoformer_cfg,
        'structure_module': structure_cfg,
        'heads': heads_cfg,
    }

    # ------------------------------------------------------------------ loss
    loss_cfg = {
        'distogram': {
            'min_bin': 2.3125,
            'max_bin': 21.6875,
            'num_bins': 64,
            'eps': 1e-6,
            'weight': 0.3,
        },
        'experimentally_resolved': {
            'eps': 1e-8,
            'min_resolution': 0.1,
            'max_resolution': 3.0,
            'weight': 0.0,
        },
        'fape': {
            'backbone': {
                'clamp_distance': 10.0,
                'clamp_distance_between_chains': 30.0,
                'loss_unit_distance': 10.0,
                'loss_unit_distance_between_chains': 20.0,
                'weight': 0.5,
                'eps': 1e-4,
            },
            'sidechain': {
                'clamp_distance': 10.0,
                'length_scale': 10.0,
                'weight': 0.5,
                'eps': 1e-4,
            },
            'weight': 1.0,
        },
        'plddt': {
            'min_resolution': 0.1,
            'max_resolution': 3.0,
            'cutoff': 15.0,
            'num_bins': 50,
            'eps': 1e-10,
            'weight': 0.01,
        },
        'masked_msa': {
            'eps': 1e-8,
            'weight': 2.0,
        },
        'supervised_chi': {
            'chi_weight': 0.5,
            'angle_norm_weight': 0.01,
            'eps': 1e-6,
            'weight': 1.0,
        },
        'violation': {
            'violation_tolerance_factor': 12.0,
            'clash_overlap_tolerance': 1.5,
            'bond_angle_loss_weight': 0.3,
            'eps': 1e-6,
            'weight': 0.0,
        },
        'pae': {
            'max_bin': 31,
            'num_bins': 64,
            'min_resolution': 0.1,
            'max_resolution': 3.0,
            'eps': 1e-8,
            'weight': 0.0,
        },
        'repr_norm': {
            'weight': 0.01,
            'tolerance': 1.0,
        },
        'chain_centre_mass': {
            'weight': 0.0,
            'eps': 1e-8,
        },
    }

    return mlc.ConfigDict({
        'data': data_cfg,
        'globals': globals_cfg,
        'model': model_cfg,
        'loss': loss_cfg,
    })
def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None):
    """Assign ``value`` to every existing entry named ``key`` inside ``c``.

    Walks ``c`` depth-first. Entries whose value is an ``mlc.ConfigDict`` are
    descended into (they are never overwritten themselves, even when their
    name equals ``key``); any other entry whose name equals ``key`` is
    replaced by ``value``. The config is modified in place.

    Args:
        c: config (sub)tree to update.
        key: entry name to look for.
        value: value written to each matching entry.
        ignore: optional entry name to skip entirely. It is compared only
            against the direct children of ``c``: the nested calls below do
            not forward it, so deeper entries with that name are still
            visited.
    """
    with c.unlocked():
        for name, entry in c.items():
            skipped = ignore is not None and name == ignore
            if skipped:
                continue
            if isinstance(entry, mlc.ConfigDict):
                # Descend into the sub-config; ``ignore`` is not passed on.
                recursive_set(entry, key, value)
                continue
            if name == key:
                c[name] = value
def model_config(name, train=False):
    """Return the configuration for the model preset ``name``.

    Starts from a deep copy of ``base_config()`` and applies the overrides
    belonging to the requested preset. Afterwards every 'inf' entry is set to
    3e4 and every 'eps' entry outside the top-level 'loss' section is set to
    1e-5.

    Args:
        name: preset identifier, e.g. 'model_1', 'model_2_ft', 'multimer'.
        train: when true, ``globals.chunk_size`` is set to ``None``.

    Returns:
        The resulting ``mlc.ConfigDict``.

    Raises:
        ValueError: if ``name`` is not one of the presets handled below.
    """
    # NOTE(review): the copy presumably keeps the overrides below from writing
    # through to the module-level FieldReference objects used by
    # base_config() -- confirm against ml_collections semantics.
    cfg = copy.deepcopy(base_config())

    def finetune(cfg, max_extra_msa, max_msa_clusters, violation_weight=0.02):
        # MSA sizes, training crop size and violation loss weight.
        recursive_set(cfg, 'max_extra_msa', max_extra_msa)
        recursive_set(cfg, 'max_msa_clusters', max_msa_clusters)
        cfg.data.train.crop_size = 384
        cfg.loss.violation.weight = violation_weight

    def af2_extras(cfg):
        # Overrides shared by all '*_af2' presets except 'model_init_af2'.
        cfg.loss.repr_norm.weight = 0
        cfg.model.heads.experimentally_resolved.enabled = True
        cfg.loss.experimentally_resolved.weight = 0.01
        cfg.globals.alphafold_original_mode = True

    def drop_templates(cfg):
        # Turn off every template-related switch.
        cfg.model.template.enabled = False
        cfg.model.template.embed_angles = False
        recursive_set(cfg, 'use_templates', False)
        recursive_set(cfg, 'use_template_torsion_angles', False)

    def v2_arch(cfg):
        # Overrides of the 'model_2_v2' preset; multimer presets reuse them.
        recursive_set(cfg, 'v2_feature', True)
        recursive_set(cfg, 'gumbel_sample', True)
        cfg.model.heads.masked_msa.d_out = 22
        cfg.model.structure_module.separate_kv = True
        cfg.model.structure_module.ipa_bias = False
        cfg.model.template.template_angle_embedder.d_in = 34

    def multimer_arch(cfg):
        # Overrides common to 'multimer', 'multimer_ft' and 'multimer_af2'.
        recursive_set(cfg, 'is_multimer', True)
        v2_arch(cfg)
        cfg.model.template.template_pair_stack.tri_attn_first = False
        cfg.model.template.template_pointwise_attention.enabled = False
        cfg.model.heads.pae.enabled = True
        cfg.model.structure_module.trans_scale_factor = 20
        cfg.loss.pae.weight = 0.1
        cfg.model.input_embedder.tf_dim = 21
        cfg.loss.chain_centre_mass.weight = 1.0

    def multimer(cfg):
        multimer_arch(cfg)
        finetune(cfg, 1152, 128)
        # The enhance head was not enabled during training of this preset,
        # so it is disabled here as well.
        cfg.model.heads.pae.disable_enhance_head = True

    if name in ('model_1', 'model_2', 'model_init'):
        # These presets use the base configuration unchanged.
        pass
    elif name == 'model_init_af2':
        cfg.globals.alphafold_original_mode = True
    elif name == 'model_1_ft':
        finetune(cfg, 5120, 512)
    elif name == 'model_1_af2':
        finetune(cfg, 5120, 512)
        af2_extras(cfg)
    elif name == 'model_2_ft':
        finetune(cfg, 1024, 512)
    elif name == 'model_2_af2':
        finetune(cfg, 1024, 512)
        af2_extras(cfg)
    elif name == 'model_2_v2':
        v2_arch(cfg)
    elif name == 'model_2_v2_ft':
        v2_arch(cfg)
        finetune(cfg, 1024, 512)
    elif name in ('model_3_af2', 'model_4_af2'):
        finetune(cfg, 5120, 512)
        af2_extras(cfg)
        drop_templates(cfg)
    elif name == 'model_5_af2':
        finetune(cfg, 1024, 512)
        af2_extras(cfg)
        drop_templates(cfg)
    elif name == 'multimer':
        multimer(cfg)
    elif name == 'multimer_ft':
        multimer(cfg)
        # Applied after multimer() so these values win.
        finetune(cfg, 1152, 256, violation_weight=0.5)
    elif name == 'multimer_af2':
        # Unlike 'multimer', pae.disable_enhance_head keeps its default.
        finetune(cfg, 1152, 256, violation_weight=0.5)
        multimer_arch(cfg)
        af2_extras(cfg)
        recursive_set(cfg, 'outer_product_mean_first', True)
    else:
        raise ValueError(f'invalid --model-name: {name}.')

    if train:
        cfg.globals.chunk_size = None

    # NOTE(review): the two calls below are treated as unconditional (not
    # part of the ``if train`` branch) -- confirm against upstream.
    recursive_set(cfg, 'inf', 3e4)
    recursive_set(cfg, 'eps', 1e-5, 'loss')
    return cfg
|