# uni_fold.py (21 KB)
# The Uni-Fold implementation is also open-sourced by the authors under the Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. import gzip
  4. import hashlib
  5. import logging
  6. import os
  7. import pickle
  8. import random
  9. import re
  10. import tarfile
  11. import time
  12. from pathlib import Path
  13. from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
  14. from unittest import result
  15. import json
  16. import numpy as np
  17. import requests
  18. import torch
  19. from tqdm import tqdm
  20. from modelscope.metainfo import Preprocessors
  21. from modelscope.models.science.unifold.data import protein, residue_constants
  22. from modelscope.models.science.unifold.data.protein import PDB_CHAIN_IDS
  23. from modelscope.models.science.unifold.data.utils import compress_features
  24. from modelscope.models.science.unifold.msa import parsers, pipeline, templates
  25. from modelscope.models.science.unifold.msa.tools import hhsearch
  26. from modelscope.models.science.unifold.msa.utils import divide_multi_chains
  27. from modelscope.preprocessors.base import Preprocessor
  28. from modelscope.preprocessors.builder import PREPROCESSORS
  29. from modelscope.utils.constant import Fields
# Public names exported by this module.
__all__ = [
    'UniFoldPreprocessor',
]
# Layout string handed to tqdm as `bar_format` by run_mmseqs2.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
# Default `host_url` of UniFoldPreprocessor.get_msa_and_templates.
DEFAULT_API_SERVER = 'https://api.colabfold.com'
  35. def run_mmseqs2(
  36. x,
  37. prefix,
  38. use_env=True,
  39. use_templates=False,
  40. use_pairing=False,
  41. host_url='https://api.colabfold.com') -> Tuple[List[str], List[str]]:
  42. submission_endpoint = 'ticket/pair' if use_pairing else 'ticket/msa'
  43. def submit(seqs, mode, N=101):
  44. n, query = N, ''
  45. for seq in seqs:
  46. query += f'>{n}\n{seq}\n'
  47. n += 1
  48. res = requests.post(
  49. f'{host_url}/{submission_endpoint}',
  50. data={
  51. 'q': query,
  52. 'mode': mode
  53. })
  54. try:
  55. out = res.json()
  56. except ValueError:
  57. out = {'status': 'ERROR'}
  58. return out
  59. def status(ID):
  60. res = requests.get(f'{host_url}/ticket/{ID}')
  61. try:
  62. out = res.json()
  63. except ValueError:
  64. out = {'status': 'ERROR'}
  65. return out
  66. def download(ID, path):
  67. res = requests.get(f'{host_url}/result/download/{ID}')
  68. with open(path, 'wb') as out:
  69. out.write(res.content)
  70. # process input x
  71. seqs = [x] if isinstance(x, str) else x
  72. mode = 'env'
  73. if use_pairing:
  74. mode = ''
  75. use_templates = False
  76. use_env = False
  77. # define path
  78. path = f'{prefix}'
  79. if not os.path.isdir(path):
  80. os.mkdir(path)
  81. # call mmseqs2 api
  82. tar_gz_file = f'{path}/out_{mode}.tar.gz'
  83. N, REDO = 101, True
  84. # deduplicate and keep track of order
  85. seqs_unique = []
  86. # TODO this might be slow for large sets
  87. [seqs_unique.append(x) for x in seqs if x not in seqs_unique]
  88. Ms = [N + seqs_unique.index(seq) for seq in seqs]
  89. # lets do it!
  90. if not os.path.isfile(tar_gz_file):
  91. TIME_ESTIMATE = 150 * len(seqs_unique)
  92. with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
  93. while REDO:
  94. pbar.set_description('SUBMIT')
  95. # Resubmit job until it goes through
  96. out = submit(seqs_unique, mode, N)
  97. while out['status'] in ['UNKNOWN', 'RATELIMIT']:
  98. sleep_time = 5 + random.randint(0, 5)
  99. # logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
  100. # resubmit
  101. time.sleep(sleep_time)
  102. out = submit(seqs_unique, mode, N)
  103. if out['status'] == 'ERROR':
  104. error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.'
  105. error = error + 'If error persists, please try again an hour later.'
  106. raise Exception(error)
  107. if out['status'] == 'MAINTENANCE':
  108. raise Exception(
  109. 'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.'
  110. )
  111. # wait for job to finish
  112. ID, TIME = out['id'], 0
  113. pbar.set_description(out['status'])
  114. while out['status'] in ['UNKNOWN', 'RUNNING', 'PENDING']:
  115. t = 5 + random.randint(0, 5)
  116. # logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
  117. time.sleep(t)
  118. out = status(ID)
  119. pbar.set_description(out['status'])
  120. if out['status'] == 'RUNNING':
  121. TIME += t
  122. pbar.update(n=t)
  123. if out['status'] == 'COMPLETE':
  124. if TIME < TIME_ESTIMATE:
  125. pbar.update(n=(TIME_ESTIMATE - TIME))
  126. REDO = False
  127. if out['status'] == 'ERROR':
  128. REDO = False
  129. error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.'
  130. error = error + 'If error persists, please try again an hour later.'
  131. raise Exception(error)
  132. # Download results
  133. download(ID, tar_gz_file)
  134. # prep list of a3m files
  135. if use_pairing:
  136. a3m_files = [f'{path}/pair.a3m']
  137. else:
  138. a3m_files = [f'{path}/uniref.a3m']
  139. if use_env:
  140. a3m_files.append(f'{path}/bfd.mgnify30.metaeuk30.smag30.a3m')
  141. # extract a3m files
  142. if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
  143. with tarfile.open(tar_gz_file) as tar_gz:
  144. tar_gz.extractall(path)
  145. # templates
  146. if use_templates:
  147. templates = {}
  148. with open(f'{path}/pdb70.m8', 'r') as f:
  149. lines = f.readlines()
  150. for line in lines:
  151. p = line.rstrip().split()
  152. M, pdb, _, _ = p[0], p[1], p[2], p[10] # qid, e_value
  153. M = int(M)
  154. if M not in templates:
  155. templates[M] = []
  156. templates[M].append(pdb)
  157. template_paths = {}
  158. for k, TMPL in templates.items():
  159. TMPL_PATH = f'{prefix}/templates_{k}'
  160. if not os.path.isdir(TMPL_PATH):
  161. os.mkdir(TMPL_PATH)
  162. TMPL_LINE = ','.join(TMPL[:20])
  163. os.system(
  164. f'curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/'
  165. )
  166. os.system(
  167. f'cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex'
  168. )
  169. os.system(f'touch {TMPL_PATH}/pdb70_cs219.ffdata')
  170. template_paths[k] = TMPL_PATH
  171. # gather a3m lines
  172. a3m_lines = {}
  173. for a3m_file in a3m_files:
  174. update_M, M = True, None
  175. with open(a3m_file, 'r', encoding='utf-8') as f:
  176. lines = f.readlines()
  177. for line in lines:
  178. if len(line) > 0:
  179. if '\x00' in line:
  180. line = line.replace('\x00', '')
  181. update_M = True
  182. if line.startswith('>') and update_M:
  183. M = int(line[1:].rstrip())
  184. update_M = False
  185. if M not in a3m_lines:
  186. a3m_lines[M] = []
  187. a3m_lines[M].append(line)
  188. # return results
  189. a3m_lines = [''.join(a3m_lines[n]) for n in Ms]
  190. if use_templates:
  191. template_paths_ = []
  192. for n in Ms:
  193. if n not in template_paths:
  194. template_paths_.append(None)
  195. # print(f"{n-N}\tno_templates_found")
  196. else:
  197. template_paths_.append(template_paths[n])
  198. template_paths = template_paths_
  199. return (a3m_lines, template_paths) if use_templates else a3m_lines
  200. def get_null_template(query_sequence: Union[List[str], str],
  201. num_temp: int = 1) -> Dict[str, Any]:
  202. ln = (
  203. len(query_sequence) if isinstance(query_sequence, str) else sum(
  204. len(s) for s in query_sequence))
  205. output_templates_sequence = 'A' * ln
  206. # output_confidence_scores = np.full(ln, 1.0)
  207. templates_all_atom_positions = np.zeros(
  208. (ln, templates.residue_constants.atom_type_num, 3))
  209. templates_all_atom_masks = np.zeros(
  210. (ln, templates.residue_constants.atom_type_num))
  211. templates_aatype = templates.residue_constants.sequence_to_onehot(
  212. output_templates_sequence,
  213. templates.residue_constants.HHBLITS_AA_TO_ID)
  214. template_features = {
  215. 'template_all_atom_positions':
  216. np.tile(templates_all_atom_positions[None], [num_temp, 1, 1, 1]),
  217. 'template_all_atom_masks':
  218. np.tile(templates_all_atom_masks[None], [num_temp, 1, 1]),
  219. 'template_sequence': ['none'.encode()] * num_temp,
  220. 'template_aatype':
  221. np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]),
  222. 'template_domain_names': ['none'.encode()] * num_temp,
  223. 'template_sum_probs':
  224. np.zeros([num_temp], dtype=np.float32),
  225. }
  226. return template_features
  227. def get_template(a3m_lines: str, template_path: str,
  228. query_sequence: str) -> Dict[str, Any]:
  229. template_featurizer = templates.HhsearchHitFeaturizer(
  230. mmcif_dir=template_path,
  231. max_template_date='2100-01-01',
  232. max_hits=20,
  233. kalign_binary_path='kalign',
  234. release_dates_path=None,
  235. obsolete_pdbs_path=None,
  236. )
  237. hhsearch_pdb70_runner = hhsearch.HHSearch(
  238. binary_path='hhsearch', databases=[f'{template_path}/pdb70'])
  239. hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines)
  240. hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result)
  241. templates_result = template_featurizer.get_templates(
  242. query_sequence=query_sequence, hits=hhsearch_hits)
  243. return dict(templates_result.features)
  244. @PREPROCESSORS.register_module(
  245. Fields.science, module_name=Preprocessors.unifold_preprocessor)
  246. class UniFoldPreprocessor(Preprocessor):
  247. def __init__(self, **cfg):
  248. self.symmetry_group = cfg['symmetry_group'] # "C1"
  249. if not self.symmetry_group:
  250. self.symmetry_group = None
  251. self.MIN_SINGLE_SEQUENCE_LENGTH = 16 # TODO: change to cfg
  252. self.MAX_SINGLE_SEQUENCE_LENGTH = 1000
  253. self.MAX_MULTIMER_LENGTH = 1000
  254. self.jobname = 'unifold'
  255. self.output_dir_base = './unifold-predictions'
  256. os.makedirs(self.output_dir_base, exist_ok=True)
  257. def clean_and_validate_sequence(self, input_sequence: str, min_length: int,
  258. max_length: int) -> str:
  259. clean_sequence = input_sequence.translate(
  260. str.maketrans('', '', ' \n\t')).upper()
  261. aatypes = set(residue_constants.restypes) # 20 standard aatypes.
  262. if not set(clean_sequence).issubset(aatypes):
  263. raise ValueError(
  264. f'Input sequence contains non-amino acid letters: '
  265. f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard '
  266. 'amino acids as inputs.')
  267. if len(clean_sequence) < min_length:
  268. raise ValueError(
  269. f'Input sequence is too short: {len(clean_sequence)} amino acids, '
  270. f'while the minimum is {min_length}')
  271. if len(clean_sequence) > max_length:
  272. raise ValueError(
  273. f'Input sequence is too long: {len(clean_sequence)} amino acids, while '
  274. f'the maximum is {max_length}. You may be able to run it with the full '
  275. f'Uni-Fold system depending on your resources (system memory, '
  276. f'GPU memory).')
  277. return clean_sequence
  278. def validate_input(self, input_sequences: Sequence[str],
  279. symmetry_group: str, min_length: int, max_length: int,
  280. max_multimer_length: int) -> Tuple[Sequence[str], bool]:
  281. """Validates and cleans input sequences and determines which model to use."""
  282. sequences = []
  283. for input_sequence in input_sequences:
  284. if input_sequence.strip():
  285. input_sequence = self.clean_and_validate_sequence(
  286. input_sequence=input_sequence,
  287. min_length=min_length,
  288. max_length=max_length)
  289. sequences.append(input_sequence)
  290. if symmetry_group is not None and symmetry_group != 'C1':
  291. if symmetry_group.startswith(
  292. 'C') and symmetry_group[1:].isnumeric():
  293. print(
  294. f'Using UF-Symmetry with group {symmetry_group}. If you do not '
  295. f'want to use UF-Symmetry, please use `C1` and copy the AU '
  296. f'sequences to the count in the assembly.')
  297. is_multimer = (len(sequences) > 1)
  298. return sequences, is_multimer, symmetry_group
  299. else:
  300. raise ValueError(
  301. f'UF-Symmetry does not support symmetry group '
  302. f'{symmetry_group} currently. Cyclic groups (Cx) are '
  303. f'supported only.')
  304. elif len(sequences) == 1:
  305. print('Using the single-chain model.')
  306. return sequences, False, None
  307. elif len(sequences) > 1:
  308. total_multimer_length = sum([len(seq) for seq in sequences])
  309. if total_multimer_length > max_multimer_length:
  310. raise ValueError(
  311. f'The total length of multimer sequences is too long: '
  312. f'{total_multimer_length}, while the maximum is '
  313. f'{max_multimer_length}. Please use the full AlphaFold '
  314. f'system for long multimers.')
  315. print(f'Using the multimer model with {len(sequences)} sequences.')
  316. return sequences, True, None
  317. else:
  318. raise ValueError(
  319. 'No input amino acid sequence provided, please provide at '
  320. 'least one sequence.')
  321. def add_hash(self, x, y):
  322. return x + '_' + hashlib.sha1(y.encode()).hexdigest()[:5]
  323. def get_msa_and_templates(
  324. self,
  325. jobname: str,
  326. query_seqs_unique: Union[str, List[str]],
  327. result_dir: Path,
  328. msa_mode: str,
  329. use_templates: bool,
  330. homooligomers_num: int = 1,
  331. host_url: str = DEFAULT_API_SERVER,
  332. ) -> Tuple[Optional[List[str]], Optional[List[str]], List[str], List[int],
  333. List[Dict[str, Any]]]:
  334. use_env = msa_mode == 'MMseqs2'
  335. template_features = []
  336. if use_templates:
  337. a3m_lines_mmseqs2, template_paths = run_mmseqs2(
  338. query_seqs_unique,
  339. str(result_dir.joinpath(jobname)),
  340. use_env,
  341. use_templates=True,
  342. host_url=host_url,
  343. )
  344. if template_paths is None:
  345. for index in range(0, len(query_seqs_unique)):
  346. template_feature = get_null_template(
  347. query_seqs_unique[index])
  348. template_features.append(template_feature)
  349. else:
  350. for index in range(0, len(query_seqs_unique)):
  351. if template_paths[index] is not None:
  352. template_feature = get_template(
  353. a3m_lines_mmseqs2[index],
  354. template_paths[index],
  355. query_seqs_unique[index],
  356. )
  357. if len(template_feature['template_domain_names']) == 0:
  358. template_feature = get_null_template(
  359. query_seqs_unique[index])
  360. else:
  361. template_feature = get_null_template(
  362. query_seqs_unique[index])
  363. template_features.append(template_feature)
  364. else:
  365. for index in range(0, len(query_seqs_unique)):
  366. template_feature = get_null_template(query_seqs_unique[index])
  367. template_features.append(template_feature)
  368. if msa_mode == 'single_sequence':
  369. a3m_lines = []
  370. num = 101
  371. for i, seq in enumerate(query_seqs_unique):
  372. a3m_lines.append('>' + str(num + i) + '\n' + seq)
  373. else:
  374. # find normal a3ms
  375. a3m_lines = run_mmseqs2(
  376. query_seqs_unique,
  377. str(result_dir.joinpath(jobname)),
  378. use_env,
  379. use_pairing=False,
  380. host_url=host_url,
  381. )
  382. if len(query_seqs_unique) > 1:
  383. # find paired a3m if not a homooligomers
  384. paired_a3m_lines = run_mmseqs2(
  385. query_seqs_unique,
  386. str(result_dir.joinpath(jobname)),
  387. use_env,
  388. use_pairing=True,
  389. host_url=host_url,
  390. )
  391. else:
  392. num = 101
  393. paired_a3m_lines = []
  394. for i in range(0, homooligomers_num):
  395. paired_a3m_lines.append('>' + str(num + i) + '\n'
  396. + query_seqs_unique[0] + '\n')
  397. return (
  398. a3m_lines,
  399. paired_a3m_lines,
  400. template_features,
  401. )
  402. def __call__(self, data: Union[str, Tuple]):
  403. if isinstance(data, str):
  404. data = data.strip().split()
  405. if len(data) < 4:
  406. data = data + [''] * (4 - len(data))
  407. basejobname = ''.join(data)
  408. basejobname = re.sub(r'\W+', '', basejobname)
  409. target_id = self.add_hash(self.jobname, basejobname)
  410. sequences, is_multimer, _ = self.validate_input(
  411. input_sequences=data,
  412. symmetry_group=self.symmetry_group,
  413. min_length=self.MIN_SINGLE_SEQUENCE_LENGTH,
  414. max_length=self.MAX_SINGLE_SEQUENCE_LENGTH,
  415. max_multimer_length=self.MAX_MULTIMER_LENGTH)
  416. descriptions = [
  417. '> ' + target_id + ' seq' + str(ii)
  418. for ii in range(len(sequences))
  419. ]
  420. if is_multimer:
  421. divide_multi_chains(target_id, self.output_dir_base, sequences,
  422. descriptions)
  423. s = []
  424. for des, seq in zip(descriptions, sequences):
  425. s += [des, seq]
  426. unique_sequences = []
  427. [
  428. unique_sequences.append(x) for x in sequences
  429. if x not in unique_sequences
  430. ]
  431. if len(unique_sequences) == 1:
  432. homooligomers_num = len(sequences)
  433. else:
  434. homooligomers_num = 1
  435. with open(f'{self.jobname}.fasta', 'w') as f:
  436. f.write('\n'.join(s))
  437. result_dir = Path(self.output_dir_base)
  438. output_dir = os.path.join(self.output_dir_base, target_id)
  439. # msa_mode = 'single_sequence'
  440. msa_mode = 'MMseqs2'
  441. use_templates = True
  442. unpaired_msa, paired_msa, template_results = self.get_msa_and_templates(
  443. target_id,
  444. unique_sequences,
  445. result_dir=result_dir,
  446. msa_mode=msa_mode,
  447. use_templates=use_templates,
  448. homooligomers_num=homooligomers_num)
  449. features = []
  450. pair_features_list = []
  451. for idx, seq in enumerate(unique_sequences):
  452. chain_id = PDB_CHAIN_IDS[idx]
  453. sequence_features = pipeline.make_sequence_features(
  454. sequence=seq,
  455. description=f'> {self.jobname} seq {chain_id}',
  456. num_res=len(seq))
  457. monomer_msa = parsers.parse_a3m(unpaired_msa[idx])
  458. msa_features = pipeline.make_msa_features([monomer_msa])
  459. template_features = template_results[idx]
  460. feature_dict = {
  461. **sequence_features,
  462. **msa_features,
  463. **template_features
  464. }
  465. feature_dict = compress_features(feature_dict)
  466. features_output_path = os.path.join(
  467. output_dir, '{}.feature.pkl.gz'.format(chain_id))
  468. pickle.dump(
  469. feature_dict,
  470. gzip.GzipFile(features_output_path, 'wb'),
  471. protocol=4)
  472. features.append(feature_dict)
  473. if is_multimer:
  474. multimer_msa = parsers.parse_a3m(paired_msa[idx])
  475. pair_features = pipeline.make_msa_features([multimer_msa])
  476. pair_feature_dict = compress_features(pair_features)
  477. uniprot_output_path = os.path.join(
  478. output_dir, '{}.uniprot.pkl.gz'.format(chain_id))
  479. pickle.dump(
  480. pair_feature_dict,
  481. gzip.GzipFile(uniprot_output_path, 'wb'),
  482. protocol=4,
  483. )
  484. pair_features_list.append(pair_feature_dict)
  485. # return features, pair_features, target_id
  486. return {
  487. 'features': features,
  488. 'pair_features': pair_features_list,
  489. 'target_id': target_id,
  490. 'is_multimer': is_multimer,
  491. }
  492. if __name__ == '__main__':
  493. proc = UniFoldPreprocessor()
  494. protein_example = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + \
  495. 'TVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI'
  496. features, pair_features = proc.__call__(protein_example)
  497. import ipdb
  498. ipdb.set_trace()