pipeline.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. # Copyright 2021 DeepMind Technologies Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Functions for building the input features for the unifold model."""
  15. import os
  16. from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union
  17. import numpy as np
  18. from absl import logging
  19. from modelscope.models.science.unifold.data import residue_constants
  20. from modelscope.models.science.unifold.msa import (msa_identifiers, parsers,
  21. templates)
  22. from modelscope.models.science.unifold.msa.tools import (hhblits, hhsearch,
  23. hmmsearch, jackhmmer)
# Type alias: a mutable mapping from feature name to its NumPy array.
FeatureDict = MutableMapping[str, np.ndarray]
# Type alias: either of the two template-search tool wrappers imported above.
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
  26. def make_sequence_features(sequence: str, description: str,
  27. num_res: int) -> FeatureDict:
  28. """Constructs a feature dict of sequence features."""
  29. features = {}
  30. features['aatype'] = residue_constants.sequence_to_onehot(
  31. sequence=sequence,
  32. mapping=residue_constants.restype_order_with_x,
  33. map_unknown_to_x=True,
  34. )
  35. features['between_segment_residues'] = np.zeros((num_res, ),
  36. dtype=np.int32)
  37. features['domain_name'] = np.array([description.encode('utf-8')],
  38. dtype=np.object_)
  39. features['residue_index'] = np.array(range(num_res), dtype=np.int32)
  40. features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
  41. features['sequence'] = np.array([sequence.encode('utf-8')],
  42. dtype=np.object_)
  43. return features
  44. def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
  45. """Constructs a feature dict of MSA features."""
  46. if not msas:
  47. raise ValueError('At least one MSA must be provided.')
  48. int_msa = []
  49. deletion_matrix = []
  50. species_ids = []
  51. seen_sequences = set()
  52. for msa_index, msa in enumerate(msas):
  53. if not msa:
  54. raise ValueError(
  55. f'MSA {msa_index} must contain at least one sequence.')
  56. for sequence_index, sequence in enumerate(msa.sequences):
  57. if sequence in seen_sequences:
  58. continue
  59. seen_sequences.add(sequence)
  60. int_msa.append(
  61. [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
  62. deletion_matrix.append(msa.deletion_matrix[sequence_index])
  63. identifiers = msa_identifiers.get_identifiers(
  64. msa.descriptions[sequence_index])
  65. species_ids.append(identifiers.species_id.encode('utf-8'))
  66. num_res = len(msas[0].sequences[0])
  67. num_alignments = len(int_msa)
  68. features = {}
  69. features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
  70. features['msa'] = np.array(int_msa, dtype=np.int32)
  71. features['num_alignments'] = np.array(
  72. [num_alignments] * num_res, dtype=np.int32)
  73. features['msa_species_identifiers'] = np.array(
  74. species_ids, dtype=np.object_)
  75. return features
  76. def run_msa_tool(
  77. msa_runner,
  78. input_fasta_path: str,
  79. msa_out_path: str,
  80. msa_format: str,
  81. use_precomputed_msas: bool,
  82. ) -> Mapping[str, Any]:
  83. """Runs an MSA tool, checking if output already exists first."""
  84. if not use_precomputed_msas or not os.path.exists(msa_out_path):
  85. result = msa_runner.query(input_fasta_path)[0]
  86. with open(msa_out_path, 'w') as f:
  87. f.write(result[msa_format])
  88. else:
  89. logging.warning('Reading MSA from file %s', msa_out_path)
  90. with open(msa_out_path, 'r', encoding='utf-8') as f:
  91. result = {msa_format: f.read()}
  92. return result
  93. class DataPipeline:
  94. """Runs the alignment tools and assembles the input features."""
  95. def __init__(
  96. self,
  97. jackhmmer_binary_path: str,
  98. hhblits_binary_path: str,
  99. uniref90_database_path: str,
  100. mgnify_database_path: str,
  101. bfd_database_path: Optional[str],
  102. uniclust30_database_path: Optional[str],
  103. small_bfd_database_path: Optional[str],
  104. uniprot_database_path: Optional[str],
  105. template_searcher: TemplateSearcher,
  106. template_featurizer: templates.TemplateHitFeaturizer,
  107. use_small_bfd: bool,
  108. mgnify_max_hits: int = 501,
  109. uniref_max_hits: int = 10000,
  110. use_precomputed_msas: bool = False,
  111. ):
  112. """Initializes the data pipeline."""
  113. self._use_small_bfd = use_small_bfd
  114. self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
  115. binary_path=jackhmmer_binary_path,
  116. database_path=uniref90_database_path)
  117. if use_small_bfd:
  118. self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
  119. binary_path=jackhmmer_binary_path,
  120. database_path=small_bfd_database_path)
  121. else:
  122. self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
  123. binary_path=hhblits_binary_path,
  124. databases=[bfd_database_path, uniclust30_database_path],
  125. )
  126. self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
  127. binary_path=jackhmmer_binary_path,
  128. database_path=mgnify_database_path)
  129. self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
  130. binary_path=jackhmmer_binary_path,
  131. database_path=uniprot_database_path)
  132. self.template_searcher = template_searcher
  133. self.template_featurizer = template_featurizer
  134. self.mgnify_max_hits = mgnify_max_hits
  135. self.uniref_max_hits = uniref_max_hits
  136. self.use_precomputed_msas = use_precomputed_msas
  137. def process(self, input_fasta_path: str,
  138. msa_output_dir: str) -> FeatureDict:
  139. """Runs alignment tools on the input sequence and creates features."""
  140. with open(input_fasta_path, encoding='utf-8') as f:
  141. input_fasta_str = f.read()
  142. input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
  143. if len(input_seqs) != 1:
  144. raise ValueError(
  145. f'More than one input sequence found in {input_fasta_path}.')
  146. input_sequence = input_seqs[0]
  147. input_description = input_descs[0]
  148. num_res = len(input_sequence)
  149. uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
  150. jackhmmer_uniref90_result = run_msa_tool(
  151. self.jackhmmer_uniref90_runner,
  152. input_fasta_path,
  153. uniref90_out_path,
  154. 'sto',
  155. self.use_precomputed_msas,
  156. )
  157. mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
  158. jackhmmer_mgnify_result = run_msa_tool(
  159. self.jackhmmer_mgnify_runner,
  160. input_fasta_path,
  161. mgnify_out_path,
  162. 'sto',
  163. self.use_precomputed_msas,
  164. )
  165. msa_for_templates = jackhmmer_uniref90_result['sto']
  166. msa_for_templates = parsers.truncate_stockholm_msa(
  167. msa_for_templates, max_sequences=self.uniref_max_hits)
  168. msa_for_templates = parsers.deduplicate_stockholm_msa(
  169. msa_for_templates)
  170. msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
  171. msa_for_templates)
  172. if self.template_searcher.input_format == 'sto':
  173. pdb_templates_result = self.template_searcher.query(
  174. msa_for_templates)
  175. elif self.template_searcher.input_format == 'a3m':
  176. uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
  177. msa_for_templates)
  178. pdb_templates_result = self.template_searcher.query(
  179. uniref90_msa_as_a3m)
  180. else:
  181. raise ValueError('Unrecognized template input format: '
  182. f'{self.template_searcher.input_format}')
  183. pdb_hits_out_path = os.path.join(
  184. msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
  185. with open(pdb_hits_out_path, 'w') as f:
  186. f.write(pdb_templates_result)
  187. uniref90_msa = parsers.parse_stockholm(
  188. jackhmmer_uniref90_result['sto'])
  189. uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
  190. mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
  191. mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)
  192. pdb_template_hits = self.template_searcher.get_template_hits(
  193. output_string=pdb_templates_result, input_sequence=input_sequence)
  194. if self._use_small_bfd:
  195. bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
  196. jackhmmer_small_bfd_result = run_msa_tool(
  197. self.jackhmmer_small_bfd_runner,
  198. input_fasta_path,
  199. bfd_out_path,
  200. 'sto',
  201. self.use_precomputed_msas,
  202. )
  203. bfd_msa = parsers.parse_stockholm(
  204. jackhmmer_small_bfd_result['sto'])
  205. else:
  206. bfd_out_path = os.path.join(msa_output_dir,
  207. 'bfd_uniclust_hits.a3m')
  208. hhblits_bfd_uniclust_result = run_msa_tool(
  209. self.hhblits_bfd_uniclust_runner,
  210. input_fasta_path,
  211. bfd_out_path,
  212. 'a3m',
  213. self.use_precomputed_msas,
  214. )
  215. bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
  216. templates_result = self.template_featurizer.get_templates(
  217. query_sequence=input_sequence, hits=pdb_template_hits)
  218. sequence_features = make_sequence_features(
  219. sequence=input_sequence,
  220. description=input_description,
  221. num_res=num_res)
  222. msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))
  223. logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
  224. logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
  225. logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
  226. logging.info(
  227. 'Final (deduplicated) MSA size: %d sequences.',
  228. msa_features['num_alignments'][0],
  229. )
  230. logging.info(
  231. 'Total number of templates (NB: this can include bad '
  232. 'templates and is later filtered to top 4): %d.',
  233. templates_result.features['template_domain_names'].shape[0],
  234. )
  235. return {
  236. **sequence_features,
  237. **msa_features,
  238. **templates_result.features
  239. }
  240. def process_uniprot(self, input_fasta_path: str,
  241. msa_output_dir: str) -> FeatureDict:
  242. uniprot_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
  243. uniprot_result = run_msa_tool(
  244. self.jackhmmer_uniprot_runner,
  245. input_fasta_path,
  246. uniprot_path,
  247. 'sto',
  248. self.use_precomputed_msas,
  249. )
  250. msa = parsers.parse_stockholm(uniprot_result['sto'])
  251. msa = msa.truncate(max_seqs=50000)
  252. all_seq_dict = make_msa_features([msa])
  253. return all_seq_dict