config.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
  2. # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
import copy
from typing import Any, Optional

import ml_collections as mlc
  6. N_RES = 'number of residues'
  7. N_MSA = 'number of MSA sequences'
  8. N_EXTRA_MSA = 'number of extra MSA sequences'
  9. N_TPL = 'number of templates'
  10. d_pair = mlc.FieldReference(128, field_type=int)
  11. d_msa = mlc.FieldReference(256, field_type=int)
  12. d_template = mlc.FieldReference(64, field_type=int)
  13. d_extra_msa = mlc.FieldReference(64, field_type=int)
  14. d_single = mlc.FieldReference(384, field_type=int)
  15. max_recycling_iters = mlc.FieldReference(3, field_type=int)
  16. chunk_size = mlc.FieldReference(4, field_type=int)
  17. aux_distogram_bins = mlc.FieldReference(64, field_type=int)
  18. eps = mlc.FieldReference(1e-8, field_type=float)
  19. inf = mlc.FieldReference(3e4, field_type=float)
  20. use_templates = mlc.FieldReference(True, field_type=bool)
  21. is_multimer = mlc.FieldReference(False, field_type=bool)
def base_config():
    """Return the default configuration tree (monomer: ``is_multimer`` False).

    The tree has four top-level sections:

    * ``data``: the feature-shape schema, data-pipeline settings under
      ``common``/``supervised``, and per-mode overrides under ``predict``,
      ``eval`` and ``train``;
    * ``globals``: values shared across the model;
    * ``model``: per-module architecture settings;
    * ``loss``: per-term loss settings and weights.

    Many entries are the module-level ``mlc.FieldReference`` objects
    (``d_pair``, ``d_msa``, ``use_templates``, ``is_multimer``, ...). Entries
    that use the same reference are linked and share a single value.
    """
    return mlc.ConfigDict({
        'data': {
            'common': {
                # Feature name -> expected shape. The N_* strings label
                # symbolic dimensions; None presumably marks a dimension that
                # is not resized by the pipeline (TODO confirm).
                'features': {
                    'aatype': [N_RES],
                    'all_atom_mask': [N_RES, None],
                    'all_atom_positions': [N_RES, None, None],
                    'alt_chi_angles': [N_RES, None],
                    'atom14_alt_gt_exists': [N_RES, None],
                    'atom14_alt_gt_positions': [N_RES, None, None],
                    'atom14_atom_exists': [N_RES, None],
                    'atom14_atom_is_ambiguous': [N_RES, None],
                    'atom14_gt_exists': [N_RES, None],
                    'atom14_gt_positions': [N_RES, None, None],
                    'atom37_atom_exists': [N_RES, None],
                    'frame_mask': [N_RES],
                    'true_frame_tensor': [N_RES, None, None],
                    'bert_mask': [N_MSA, N_RES],
                    'chi_angles_sin_cos': [N_RES, None, None],
                    'chi_mask': [N_RES, None],
                    'extra_msa_deletion_value': [N_EXTRA_MSA, N_RES],
                    'extra_msa_has_deletion': [N_EXTRA_MSA, N_RES],
                    'extra_msa': [N_EXTRA_MSA, N_RES],
                    'extra_msa_mask': [N_EXTRA_MSA, N_RES],
                    'extra_msa_row_mask': [N_EXTRA_MSA],
                    'is_distillation': [],
                    'msa_feat': [N_MSA, N_RES, None],
                    'msa_mask': [N_MSA, N_RES],
                    'msa_chains': [N_MSA, None],
                    'msa_row_mask': [N_MSA],
                    'num_recycling_iters': [],
                    'pseudo_beta': [N_RES, None],
                    'pseudo_beta_mask': [N_RES],
                    'residue_index': [N_RES],
                    'residx_atom14_to_atom37': [N_RES, None],
                    'residx_atom37_to_atom14': [N_RES, None],
                    'resolution': [],
                    'rigidgroups_alt_gt_frames': [N_RES, None, None, None],
                    'rigidgroups_group_exists': [N_RES, None],
                    'rigidgroups_group_is_ambiguous': [N_RES, None],
                    'rigidgroups_gt_exists': [N_RES, None],
                    'rigidgroups_gt_frames': [N_RES, None, None, None],
                    'seq_length': [],
                    'seq_mask': [N_RES],
                    'target_feat': [N_RES, None],
                    'template_aatype': [N_TPL, N_RES],
                    'template_all_atom_mask': [N_TPL, N_RES, None],
                    'template_all_atom_positions': [N_TPL, N_RES, None, None],
                    'template_alt_torsion_angles_sin_cos': [
                        N_TPL,
                        N_RES,
                        None,
                        None,
                    ],
                    'template_frame_mask': [N_TPL, N_RES],
                    'template_frame_tensor': [N_TPL, N_RES, None, None],
                    'template_mask': [N_TPL],
                    'template_pseudo_beta': [N_TPL, N_RES, None],
                    'template_pseudo_beta_mask': [N_TPL, N_RES],
                    'template_sum_probs': [N_TPL, None],
                    'template_torsion_angles_mask': [N_TPL, N_RES, None],
                    'template_torsion_angles_sin_cos':
                        [N_TPL, N_RES, None, None],
                    'true_msa': [N_MSA, N_RES],
                    'use_clamped_fape': [],
                    # Multimer-only features (see 'multimer_features' below).
                    'assembly_num_chains': [1],
                    'asym_id': [N_RES],
                    'sym_id': [N_RES],
                    'entity_id': [N_RES],
                    'num_sym': [N_RES],
                    'asym_len': [None],
                    'cluster_bias_mask': [N_MSA],
                },
                'masked_msa': {
                    'profile_prob': 0.1,
                    'same_prob': 0.1,
                    'uniform_prob': 0.1,
                },
                'block_delete_msa': {
                    'msa_fraction_per_block': 0.3,
                    'randomize_num_blocks': False,
                    'num_blocks': 5,
                    'min_num_msa': 16,
                },
                'random_delete_msa': {
                    'max_msa_entry': 1 << 25,  # := 33554432
                },
                # v2_feature / gumbel_sample are switched on by the v2 and
                # multimer presets in model_config().
                'v2_feature': False,
                'gumbel_sample': False,
                'max_extra_msa': 1024,
                'msa_cluster_features': True,
                'reduce_msa_clusters_by_max_templates': True,
                'resample_msa_in_recycling': True,
                # Feature-name groups; presumably consumed by the data
                # pipeline to decide which features to load (TODO confirm).
                'template_features': [
                    'template_all_atom_positions',
                    'template_sum_probs',
                    'template_aatype',
                    'template_all_atom_mask',
                ],
                'unsupervised_features': [
                    'aatype',
                    'residue_index',
                    'msa',
                    'msa_chains',
                    'num_alignments',
                    'seq_length',
                    'between_segment_residues',
                    'deletion_matrix',
                    'num_recycling_iters',
                    'crop_and_fix_size_seed',
                ],
                'recycling_features': [
                    'msa_chains',
                    'msa_mask',
                    'msa_row_mask',
                    'bert_mask',
                    'true_msa',
                    'msa_feat',
                    'extra_msa_deletion_value',
                    'extra_msa_has_deletion',
                    'extra_msa',
                    'extra_msa_mask',
                    'extra_msa_row_mask',
                    'is_distillation',
                ],
                'multimer_features': [
                    'assembly_num_chains',
                    'asym_id',
                    'sym_id',
                    'num_sym',
                    'entity_id',
                    'asym_len',
                    'cluster_bias_mask',
                ],
                # Linked FieldReferences: 'use_template_torsion_angles' is
                # the same reference as 'use_templates'.
                'use_templates': use_templates,
                'is_multimer': is_multimer,
                'use_template_torsion_angles': use_templates,
                'max_recycling_iters': max_recycling_iters,
            },
            'supervised': {
                'use_clamped_fape_prob': 1.0,
                'supervised_features': [
                    'all_atom_mask',
                    'all_atom_positions',
                    'resolution',
                    'use_clamped_fape',
                    'is_distillation',
                ],
            },
            # Per-mode settings; presumably selected by the caller according
            # to the run mode (TODO confirm).
            'predict': {
                'fixed_size': True,
                'subsample_templates': False,
                'block_delete_msa': False,
                'random_delete_msa': True,
                'masked_msa_replace_fraction': 0.15,
                'max_msa_clusters': 128,
                'max_templates': 4,
                'num_ensembles': 2,
                'crop': False,
                'crop_size': None,
                'supervised': False,
                'biased_msa_by_chain': False,
                'share_mask': False,
            },
            'eval': {
                'fixed_size': True,
                'subsample_templates': False,
                'block_delete_msa': False,
                'random_delete_msa': True,
                'masked_msa_replace_fraction': 0.15,
                'max_msa_clusters': 128,
                'max_templates': 4,
                'num_ensembles': 1,
                'crop': False,
                'crop_size': None,
                'spatial_crop_prob': 0.5,
                'ca_ca_threshold': 10.0,
                'supervised': True,
                'biased_msa_by_chain': False,
                'share_mask': False,
            },
            'train': {
                'fixed_size': True,
                'subsample_templates': True,
                'block_delete_msa': True,
                'random_delete_msa': True,
                'masked_msa_replace_fraction': 0.15,
                'max_msa_clusters': 128,
                'max_templates': 4,
                'num_ensembles': 1,
                'crop': True,
                'crop_size': 256,
                'spatial_crop_prob': 0.5,
                'ca_ca_threshold': 10.0,
                'supervised': True,
                'use_clamped_fape_prob': 1.0,
                'max_distillation_msa_clusters': 1000,
                'biased_msa_by_chain': True,
                'share_mask': True,
            },
        },
        'globals': {
            # model_config(..., train=True) sets chunk_size to None.
            'chunk_size': chunk_size,
            'block_size': None,
            'd_pair': d_pair,
            'd_msa': d_msa,
            'd_template': d_template,
            'd_extra_msa': d_extra_msa,
            'd_single': d_single,
            'eps': eps,
            'inf': inf,
            'max_recycling_iters': max_recycling_iters,
            # Set to True by the *_af2 presets in model_config().
            'alphafold_original_mode': False,
        },
        'model': {
            'is_multimer': is_multimer,
            'input_embedder': {
                'tf_dim': 22,
                'msa_dim': 49,
                'd_pair': d_pair,
                'd_msa': d_msa,
                'relpos_k': 32,
                'max_relative_chain': 2,
            },
            # The per-module 'inf'/'eps' literals below are all overwritten
            # by model_config(..., train=True) (inf -> 3e4, eps -> 1e-5).
            'recycling_embedder': {
                'd_pair': d_pair,
                'd_msa': d_msa,
                'min_bin': 3.25,
                'max_bin': 20.75,
                'num_bins': 15,
                'inf': 1e8,
            },
            'template': {
                'distogram': {
                    'min_bin': 3.25,
                    'max_bin': 50.75,
                    'num_bins': 39,
                },
                'template_angle_embedder': {
                    'd_in': 57,
                    'd_out': d_msa,
                },
                'template_pair_embedder': {
                    'd_in': 88,
                    'v2_d_in': [39, 1, 22, 22, 1, 1, 1, 1],
                    'd_pair': d_pair,
                    'd_out': d_template,
                    'v2_feature': False,
                },
                'template_pair_stack': {
                    'd_template': d_template,
                    'd_hid_tri_att': 16,
                    'd_hid_tri_mul': 64,
                    'num_blocks': 2,
                    'num_heads': 4,
                    'pair_transition_n': 2,
                    'dropout_rate': 0.25,
                    'inf': 1e9,
                    'tri_attn_first': True,
                },
                'template_pointwise_attention': {
                    'enabled': True,
                    'd_template': d_template,
                    'd_pair': d_pair,
                    'd_hid': 16,
                    'num_heads': 4,
                    'inf': 1e5,
                },
                'inf': 1e5,
                'eps': 1e-6,
                # Both linked to the shared 'use_templates' reference.
                'enabled': use_templates,
                'embed_angles': use_templates,
            },
            'extra_msa': {
                'extra_msa_embedder': {
                    'd_in': 25,
                    'd_out': d_extra_msa,
                },
                'extra_msa_stack': {
                    'd_msa': d_extra_msa,
                    'd_pair': d_pair,
                    'd_hid_msa_att': 8,
                    'd_hid_opm': 32,
                    'd_hid_mul': 128,
                    'd_hid_pair_att': 32,
                    'num_heads_msa': 8,
                    'num_heads_pair': 4,
                    'num_blocks': 4,
                    'transition_n': 4,
                    'msa_dropout': 0.15,
                    'pair_dropout': 0.25,
                    'inf': 1e9,
                    'eps': 1e-10,
                    'outer_product_mean_first': False,
                },
                'enabled': True,
            },
            'evoformer_stack': {
                'd_msa': d_msa,
                'd_pair': d_pair,
                'd_hid_msa_att': 32,
                'd_hid_opm': 32,
                'd_hid_mul': 128,
                'd_hid_pair_att': 32,
                'd_single': d_single,
                'num_heads_msa': 8,
                'num_heads_pair': 4,
                'num_blocks': 48,
                'transition_n': 4,
                'msa_dropout': 0.15,
                'pair_dropout': 0.25,
                'inf': 1e9,
                'eps': 1e-10,
                'outer_product_mean_first': False,
            },
            'structure_module': {
                'd_single': d_single,
                'd_pair': d_pair,
                'd_ipa': 16,
                'd_angle': 128,
                'num_heads_ipa': 12,
                'num_qk_points': 4,
                'num_v_points': 8,
                'dropout_rate': 0.1,
                'num_blocks': 8,
                'no_transition_layers': 1,
                'num_resnet_blocks': 2,
                'num_angles': 7,
                'trans_scale_factor': 10,
                'epsilon': 1e-12,
                'inf': 1e5,
                'separate_kv': False,
                'ipa_bias': True,
            },
            'heads': {
                'plddt': {
                    'num_bins': 50,
                    'd_in': d_single,
                    'd_hid': 128,
                },
                'distogram': {
                    'd_pair': d_pair,
                    'num_bins': aux_distogram_bins,
                    'disable_enhance_head': False,
                },
                # Disabled by default; enabled by the multimer presets.
                'pae': {
                    'd_pair': d_pair,
                    'num_bins': aux_distogram_bins,
                    'enabled': False,
                    'iptm_weight': 0.8,
                    'disable_enhance_head': False,
                },
                'masked_msa': {
                    'd_msa': d_msa,
                    'd_out': 23,
                    'disable_enhance_head': False,
                },
                # Disabled by default; enabled by the *_af2 presets.
                'experimentally_resolved': {
                    'd_single': d_single,
                    'd_out': 37,
                    'enabled': False,
                    'disable_enhance_head': False,
                },
            },
        },
        # Terms with 'weight': 0.0 are off by default; model_config() presets
        # raise experimentally_resolved, violation, pae and chain_centre_mass.
        # The 'eps' entries here are excluded from the train-time eps reset.
        'loss': {
            'distogram': {
                'min_bin': 2.3125,
                'max_bin': 21.6875,
                # NOTE(review): literal 64, not linked to aux_distogram_bins
                # used by heads.distogram.num_bins; keep the two in sync.
                'num_bins': 64,
                'eps': 1e-6,
                'weight': 0.3,
            },
            'experimentally_resolved': {
                'eps': 1e-8,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'weight': 0.0,
            },
            'fape': {
                'backbone': {
                    'clamp_distance': 10.0,
                    'clamp_distance_between_chains': 30.0,
                    'loss_unit_distance': 10.0,
                    'loss_unit_distance_between_chains': 20.0,
                    'weight': 0.5,
                    'eps': 1e-4,
                },
                'sidechain': {
                    'clamp_distance': 10.0,
                    'length_scale': 10.0,
                    'weight': 0.5,
                    'eps': 1e-4,
                },
                'weight': 1.0,
            },
            'plddt': {
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'cutoff': 15.0,
                'num_bins': 50,
                'eps': 1e-10,
                'weight': 0.01,
            },
            'masked_msa': {
                'eps': 1e-8,
                'weight': 2.0,
            },
            'supervised_chi': {
                'chi_weight': 0.5,
                'angle_norm_weight': 0.01,
                'eps': 1e-6,
                'weight': 1.0,
            },
            'violation': {
                'violation_tolerance_factor': 12.0,
                'clash_overlap_tolerance': 1.5,
                'bond_angle_loss_weight': 0.3,
                'eps': 1e-6,
                'weight': 0.0,
            },
            'pae': {
                'max_bin': 31,
                # NOTE(review): literal 64, not linked to aux_distogram_bins
                # used by heads.pae.num_bins; keep the two in sync.
                'num_bins': 64,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'eps': 1e-8,
                'weight': 0.0,
            },
            'repr_norm': {
                'weight': 0.01,
                'tolerance': 1.0,
            },
            'chain_centre_mass': {
                'weight': 0.0,
                'eps': 1e-8,
            },
        },
    })
  473. def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None):
  474. with c.unlocked():
  475. for k, v in c.items():
  476. if ignore is not None and k == ignore:
  477. continue
  478. if isinstance(v, mlc.ConfigDict):
  479. recursive_set(v, key, value)
  480. elif k == key:
  481. c[k] = value
  482. def model_config(name, train=False):
  483. c = copy.deepcopy(base_config())
  484. def model_2_v2(c):
  485. recursive_set(c, 'v2_feature', True)
  486. recursive_set(c, 'gumbel_sample', True)
  487. c.model.heads.masked_msa.d_out = 22
  488. c.model.structure_module.separate_kv = True
  489. c.model.structure_module.ipa_bias = False
  490. c.model.template.template_angle_embedder.d_in = 34
  491. return c
  492. def multimer(c):
  493. recursive_set(c, 'is_multimer', True)
  494. recursive_set(c, 'max_extra_msa', 1152)
  495. recursive_set(c, 'max_msa_clusters', 128)
  496. recursive_set(c, 'v2_feature', True)
  497. recursive_set(c, 'gumbel_sample', True)
  498. c.model.template.template_angle_embedder.d_in = 34
  499. c.model.template.template_pair_stack.tri_attn_first = False
  500. c.model.template.template_pointwise_attention.enabled = False
  501. c.model.heads.pae.enabled = True
  502. # we forget to enable it in our training, so disable it here
  503. c.model.heads.pae.disable_enhance_head = True
  504. c.model.heads.masked_msa.d_out = 22
  505. c.model.structure_module.separate_kv = True
  506. c.model.structure_module.ipa_bias = False
  507. c.model.structure_module.trans_scale_factor = 20
  508. c.loss.pae.weight = 0.1
  509. c.model.input_embedder.tf_dim = 21
  510. c.data.train.crop_size = 384
  511. c.loss.violation.weight = 0.02
  512. c.loss.chain_centre_mass.weight = 1.0
  513. return c
  514. if name == 'model_1':
  515. pass
  516. elif name == 'model_1_ft':
  517. recursive_set(c, 'max_extra_msa', 5120)
  518. recursive_set(c, 'max_msa_clusters', 512)
  519. c.data.train.crop_size = 384
  520. c.loss.violation.weight = 0.02
  521. elif name == 'model_1_af2':
  522. recursive_set(c, 'max_extra_msa', 5120)
  523. recursive_set(c, 'max_msa_clusters', 512)
  524. c.data.train.crop_size = 384
  525. c.loss.violation.weight = 0.02
  526. c.loss.repr_norm.weight = 0
  527. c.model.heads.experimentally_resolved.enabled = True
  528. c.loss.experimentally_resolved.weight = 0.01
  529. c.globals.alphafold_original_mode = True
  530. elif name == 'model_2':
  531. pass
  532. elif name == 'model_init':
  533. pass
  534. elif name == 'model_init_af2':
  535. c.globals.alphafold_original_mode = True
  536. pass
  537. elif name == 'model_2_ft':
  538. recursive_set(c, 'max_extra_msa', 1024)
  539. recursive_set(c, 'max_msa_clusters', 512)
  540. c.data.train.crop_size = 384
  541. c.loss.violation.weight = 0.02
  542. elif name == 'model_2_af2':
  543. recursive_set(c, 'max_extra_msa', 1024)
  544. recursive_set(c, 'max_msa_clusters', 512)
  545. c.data.train.crop_size = 384
  546. c.loss.violation.weight = 0.02
  547. c.loss.repr_norm.weight = 0
  548. c.model.heads.experimentally_resolved.enabled = True
  549. c.loss.experimentally_resolved.weight = 0.01
  550. c.globals.alphafold_original_mode = True
  551. elif name == 'model_2_v2':
  552. c = model_2_v2(c)
  553. elif name == 'model_2_v2_ft':
  554. c = model_2_v2(c)
  555. recursive_set(c, 'max_extra_msa', 1024)
  556. recursive_set(c, 'max_msa_clusters', 512)
  557. c.data.train.crop_size = 384
  558. c.loss.violation.weight = 0.02
  559. elif name == 'model_3_af2' or name == 'model_4_af2':
  560. recursive_set(c, 'max_extra_msa', 5120)
  561. recursive_set(c, 'max_msa_clusters', 512)
  562. c.data.train.crop_size = 384
  563. c.loss.violation.weight = 0.02
  564. c.loss.repr_norm.weight = 0
  565. c.model.heads.experimentally_resolved.enabled = True
  566. c.loss.experimentally_resolved.weight = 0.01
  567. c.globals.alphafold_original_mode = True
  568. c.model.template.enabled = False
  569. c.model.template.embed_angles = False
  570. recursive_set(c, 'use_templates', False)
  571. recursive_set(c, 'use_template_torsion_angles', False)
  572. elif name == 'model_5_af2':
  573. recursive_set(c, 'max_extra_msa', 1024)
  574. recursive_set(c, 'max_msa_clusters', 512)
  575. c.data.train.crop_size = 384
  576. c.loss.violation.weight = 0.02
  577. c.loss.repr_norm.weight = 0
  578. c.model.heads.experimentally_resolved.enabled = True
  579. c.loss.experimentally_resolved.weight = 0.01
  580. c.globals.alphafold_original_mode = True
  581. c.model.template.enabled = False
  582. c.model.template.embed_angles = False
  583. recursive_set(c, 'use_templates', False)
  584. recursive_set(c, 'use_template_torsion_angles', False)
  585. elif name == 'multimer':
  586. c = multimer(c)
  587. elif name == 'multimer_ft':
  588. c = multimer(c)
  589. recursive_set(c, 'max_extra_msa', 1152)
  590. recursive_set(c, 'max_msa_clusters', 256)
  591. c.data.train.crop_size = 384
  592. c.loss.violation.weight = 0.5
  593. elif name == 'multimer_af2':
  594. recursive_set(c, 'max_extra_msa', 1152)
  595. recursive_set(c, 'max_msa_clusters', 256)
  596. recursive_set(c, 'is_multimer', True)
  597. recursive_set(c, 'v2_feature', True)
  598. recursive_set(c, 'gumbel_sample', True)
  599. c.model.template.template_angle_embedder.d_in = 34
  600. c.model.template.template_pair_stack.tri_attn_first = False
  601. c.model.template.template_pointwise_attention.enabled = False
  602. c.model.heads.pae.enabled = True
  603. c.model.heads.experimentally_resolved.enabled = True
  604. c.model.heads.masked_msa.d_out = 22
  605. c.model.structure_module.separate_kv = True
  606. c.model.structure_module.ipa_bias = False
  607. c.model.structure_module.trans_scale_factor = 20
  608. c.loss.pae.weight = 0.1
  609. c.loss.violation.weight = 0.5
  610. c.loss.experimentally_resolved.weight = 0.01
  611. c.model.input_embedder.tf_dim = 21
  612. c.globals.alphafold_original_mode = True
  613. c.data.train.crop_size = 384
  614. c.loss.repr_norm.weight = 0
  615. c.loss.chain_centre_mass.weight = 1.0
  616. recursive_set(c, 'outer_product_mean_first', True)
  617. else:
  618. raise ValueError(f'invalid --model-name: {name}.')
  619. if train:
  620. c.globals.chunk_size = None
  621. recursive_set(c, 'inf', 3e4)
  622. recursive_set(c, 'eps', 1e-5, 'loss')
  623. return c