| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282 |
- # Copyright 2021 DeepMind Technologies Limited
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Functions for building the input features for the unifold model."""
- import os
- from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union
- import numpy as np
- from absl import logging
- from modelscope.models.science.unifold.data import residue_constants
- from modelscope.models.science.unifold.msa import (msa_identifiers, parsers,
- templates)
- from modelscope.models.science.unifold.msa.tools import (hhblits, hhsearch,
- hmmsearch, jackhmmer)
# Union of the template-search tool wrappers DataPipeline accepts; the
# pipeline dispatches on the searcher's `input_format` attribute
# ('sto' or 'a3m').
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
# Feature name -> numpy array. Every feature builder in this module returns
# this type.
FeatureDict = MutableMapping[str, np.ndarray]
def make_sequence_features(sequence: str, description: str,
                           num_res: int) -> FeatureDict:
    """Builds the per-sequence input features.

    Args:
      sequence: Query residue string. It is one-hot encoded with
        `residue_constants.restype_order_with_x`, and unknown characters
        are mapped to X.
      description: Free-text description; stored UTF-8 encoded under
        'domain_name'.
      num_res: Number of residues, used to size the per-residue arrays.
        It is expected to equal len(sequence), but this is not checked here.

    Returns:
      A FeatureDict with the keys 'aatype', 'between_segment_residues',
      'domain_name', 'residue_index' (0..num_res-1), 'seq_length' (num_res
      repeated num_res times) and 'sequence' (UTF-8 bytes in an object
      array).
    """
    return {
        'aatype':
        residue_constants.sequence_to_onehot(
            sequence=sequence,
            mapping=residue_constants.restype_order_with_x,
            map_unknown_to_x=True,
        ),
        # Single-chain input: no residue sits at a segment boundary.
        'between_segment_residues':
        np.zeros((num_res, ), dtype=np.int32),
        'domain_name':
        np.array([description.encode('utf-8')], dtype=np.object_),
        'residue_index':
        np.arange(num_res, dtype=np.int32),
        'seq_length':
        np.full((num_res, ), num_res, dtype=np.int32),
        'sequence':
        np.array([sequence.encode('utf-8')], dtype=np.object_),
    }
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
    """Merges MSAs into one deduplicated set of MSA features.

    Rows are visited MSA by MSA, top to bottom. A row whose sequence string
    was already seen, in this or an earlier MSA, is skipped. Only the first
    occurrence contributes its residues, deletion row and species id.

    Args:
      msas: MSAs to merge. The first row of the first MSA sets the residue
        count used to size 'num_alignments'.

    Returns:
      A FeatureDict with:
        'deletion_matrix_int': int32 array of the kept deletion rows.
        'msa': int32 array of the kept rows, with residues mapped through
          residue_constants.HHBLITS_AA_TO_ID.
        'num_alignments': int32 array holding the number of kept rows,
          repeated once per residue of the first sequence.
        'msa_species_identifiers': object array of UTF-8 encoded species
          ids.

    Raises:
      ValueError: If `msas` is empty, or when an empty MSA is reached.
    """
    if not msas:
        raise ValueError('At least one MSA must be provided.')

    kept_rows = []
    kept_deletions = []
    kept_species = []
    seen = set()
    for msa_idx, alignment in enumerate(msas):
        if not alignment:
            raise ValueError(
                f'MSA {msa_idx} must contain at least one sequence.')
        for row_idx, row_seq in enumerate(alignment.sequences):
            if row_seq in seen:
                continue  # Duplicate: the earlier occurrence wins.
            seen.add(row_seq)
            kept_rows.append([
                residue_constants.HHBLITS_AA_TO_ID[aa] for aa in row_seq
            ])
            kept_deletions.append(alignment.deletion_matrix[row_idx])
            row_ids = msa_identifiers.get_identifiers(
                alignment.descriptions[row_idx])
            kept_species.append(row_ids.species_id.encode('utf-8'))

    query_len = len(msas[0].sequences[0])
    return {
        'deletion_matrix_int':
        np.array(kept_deletions, dtype=np.int32),
        'msa':
        np.array(kept_rows, dtype=np.int32),
        'num_alignments':
        np.full((query_len, ), len(kept_rows), dtype=np.int32),
        'msa_species_identifiers':
        np.array(kept_species, dtype=np.object_),
    }
def run_msa_tool(
    msa_runner,
    input_fasta_path: str,
    msa_out_path: str,
    msa_format: str,
    use_precomputed_msas: bool,
) -> Mapping[str, Any]:
    """Runs an MSA tool, or reuses its saved output when allowed.

    Args:
      msa_runner: Tool wrapper exposing `query(input_fasta_path)`. The call
        must return a sequence whose first element is a mapping that
        contains `msa_format` as a key.
      input_fasta_path: Path to the query FASTA file.
      msa_out_path: File the raw MSA text is written to, or read back from.
      msa_format: Key of the tool output to persist or read, e.g. 'sto' or
        'a3m'.
      use_precomputed_msas: When True and `msa_out_path` already exists,
        the tool is not run and the file contents are returned instead.

    Returns:
      If the tool runs, its full first result mapping. If a precomputed
      file is reused, only `{msa_format: <file contents>}`.
    """
    if use_precomputed_msas and os.path.exists(msa_out_path):
        logging.warning('Reading MSA from file %s', msa_out_path)
        with open(msa_out_path, 'r', encoding='utf-8') as f:
            return {msa_format: f.read()}

    result = msa_runner.query(input_fasta_path)[0]
    # Write with the same explicit encoding the precomputed branch reads
    # with. The platform default encoding could otherwise make a later
    # cached read fail or mis-decode non-ASCII content.
    with open(msa_out_path, 'w', encoding='utf-8') as f:
        f.write(result[msa_format])
    return result
class DataPipeline:
    """Runs the alignment tools and assembles the input features."""

    def __init__(
        self,
        jackhmmer_binary_path: str,
        hhblits_binary_path: str,
        uniref90_database_path: str,
        mgnify_database_path: str,
        bfd_database_path: Optional[str],
        uniclust30_database_path: Optional[str],
        small_bfd_database_path: Optional[str],
        uniprot_database_path: Optional[str],
        template_searcher: TemplateSearcher,
        template_featurizer: templates.TemplateHitFeaturizer,
        use_small_bfd: bool,
        mgnify_max_hits: int = 501,
        uniref_max_hits: int = 10000,
        use_precomputed_msas: bool = False,
        uniprot_max_hits: int = 50000,
    ):
        """Initializes the data pipeline.

        Args:
          jackhmmer_binary_path: Path to the jackhmmer executable.
          hhblits_binary_path: Path to the HHblits executable. It is only
            used when `use_small_bfd` is False.
          uniref90_database_path: UniRef90 database searched with jackhmmer.
          mgnify_database_path: MGnify database searched with jackhmmer.
          bfd_database_path: BFD database for HHblits. Only used when
            `use_small_bfd` is False.
          uniclust30_database_path: Uniclust30 database for HHblits. Only
            used when `use_small_bfd` is False.
          small_bfd_database_path: Small BFD database for jackhmmer. Only
            used when `use_small_bfd` is True.
          uniprot_database_path: UniProt database, searched by
            `process_uniprot`.
          template_searcher: Tool used to find structural templates.
          template_featurizer: Turns template hits into template features.
          use_small_bfd: If True, search small BFD with jackhmmer. Otherwise,
            search BFD + Uniclust30 with HHblits.
          mgnify_max_hits: Maximum rows kept from the MGnify MSA.
          uniref_max_hits: Maximum rows kept from the UniRef90 MSA. This also
            caps the MSA passed to the template searcher.
          use_precomputed_msas: Reuse MSA files already present in the
            output directory instead of rerunning the tools.
          uniprot_max_hits: Maximum rows kept from the UniProt MSA in
            `process_uniprot`.
        """
        self._use_small_bfd = use_small_bfd
        self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniref90_database_path)
        # Only the runner for the selected BFD mode is created. The other
        # attribute is left unset.
        if use_small_bfd:
            self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=small_bfd_database_path)
        else:
            self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
                binary_path=hhblits_binary_path,
                databases=[bfd_database_path, uniclust30_database_path],
            )
        self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=mgnify_database_path)
        # NOTE(review): this runner is built even when uniprot_database_path
        # is None. Confirm that Jackhmmer tolerates a None database path.
        self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniprot_database_path)
        self.template_searcher = template_searcher
        self.template_featurizer = template_featurizer
        self.mgnify_max_hits = mgnify_max_hits
        self.uniref_max_hits = uniref_max_hits
        self.use_precomputed_msas = use_precomputed_msas
        self.uniprot_max_hits = uniprot_max_hits

    def process(self, input_fasta_path: str,
                msa_output_dir: str) -> FeatureDict:
        """Runs alignment tools on the input sequence and creates features.

        The steps are, in order:
          1. Search UniRef90 and MGnify with jackhmmer.
          2. Run the template searcher on the filtered UniRef90 MSA.
          3. Search BFD, using small BFD via jackhmmer or BFD + Uniclust30
             via HHblits.
          4. Combine sequence, merged-MSA and template features.
        Raw tool outputs are written into `msa_output_dir`, which is not
        created here.

        Args:
          input_fasta_path: FASTA file that must contain exactly one sequence.
          msa_output_dir: Directory for the raw tool outputs.

        Returns:
          A single FeatureDict merging sequence, MSA and template features.

        Raises:
          ValueError: If the FASTA file does not contain exactly one sequence,
            or if the template searcher's input format is neither 'sto' nor
            'a3m'.
        """
        with open(input_fasta_path, encoding='utf-8') as f:
            input_fasta_str = f.read()
        input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
        # Report the empty case separately; the previous single check
        # claimed "more than one" sequence even when there were none.
        if not input_seqs:
            raise ValueError(
                f'No input sequence found in {input_fasta_path}.')
        if len(input_seqs) > 1:
            raise ValueError(
                f'More than one input sequence found in {input_fasta_path}.')
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)

        uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
        jackhmmer_uniref90_result = run_msa_tool(
            self.jackhmmer_uniref90_runner,
            input_fasta_path,
            uniref90_out_path,
            'sto',
            self.use_precomputed_msas,
        )
        mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
        jackhmmer_mgnify_result = run_msa_tool(
            self.jackhmmer_mgnify_runner,
            input_fasta_path,
            mgnify_out_path,
            'sto',
            self.use_precomputed_msas,
        )

        pdb_templates_result = self._search_templates(
            jackhmmer_uniref90_result['sto'], msa_output_dir)

        uniref90_msa = parsers.parse_stockholm(
            jackhmmer_uniref90_result['sto'])
        uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
        mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
        mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)

        pdb_template_hits = self.template_searcher.get_template_hits(
            output_string=pdb_templates_result, input_sequence=input_sequence)

        bfd_msa = self._search_bfd(input_fasta_path, msa_output_dir)

        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence, hits=pdb_template_hits)

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res)
        msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))

        logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
        logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
        logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
        logging.info(
            'Final (deduplicated) MSA size: %d sequences.',
            msa_features['num_alignments'][0],
        )
        logging.info(
            'Total number of templates (NB: this can include bad '
            'templates and is later filtered to top 4): %d.',
            templates_result.features['template_domain_names'].shape[0],
        )
        return {
            **sequence_features,
            **msa_features,
            **templates_result.features
        }

    def _search_templates(self, uniref90_sto: str,
                          msa_output_dir: str) -> str:
        """Runs the template searcher on a filtered UniRef90 MSA.

        The Stockholm MSA is truncated to `uniref_max_hits`, deduplicated,
        stripped of empty columns, and converted to A3M if the searcher
        expects that. The raw searcher output is written to
        `pdb_hits.<output_format>` in `msa_output_dir` and also returned.
        """
        msa_for_templates = parsers.truncate_stockholm_msa(
            uniref90_sto, max_sequences=self.uniref_max_hits)
        msa_for_templates = parsers.deduplicate_stockholm_msa(
            msa_for_templates)
        msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
            msa_for_templates)

        input_format = self.template_searcher.input_format
        if input_format == 'sto':
            pdb_templates_result = self.template_searcher.query(
                msa_for_templates)
        elif input_format == 'a3m':
            uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
                msa_for_templates)
            pdb_templates_result = self.template_searcher.query(
                uniref90_msa_as_a3m)
        else:
            raise ValueError('Unrecognized template input format: '
                             f'{input_format}')

        pdb_hits_out_path = os.path.join(
            msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
        # Explicit encoding, consistent with how MSA files are read back.
        with open(pdb_hits_out_path, 'w', encoding='utf-8') as f:
            f.write(pdb_templates_result)
        return pdb_templates_result

    def _search_bfd(self, input_fasta_path: str,
                    msa_output_dir: str) -> parsers.Msa:
        """Searches the configured BFD database(s) and parses the result MSA."""
        if self._use_small_bfd:
            bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
            jackhmmer_small_bfd_result = run_msa_tool(
                self.jackhmmer_small_bfd_runner,
                input_fasta_path,
                bfd_out_path,
                'sto',
                self.use_precomputed_msas,
            )
            return parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])

        bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
        hhblits_bfd_uniclust_result = run_msa_tool(
            self.hhblits_bfd_uniclust_runner,
            input_fasta_path,
            bfd_out_path,
            'a3m',
            self.use_precomputed_msas,
        )
        return parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])

    def process_uniprot(self, input_fasta_path: str,
                        msa_output_dir: str) -> FeatureDict:
        """Searches UniProt with jackhmmer and returns its MSA features.

        The raw Stockholm output goes to `uniprot_hits.sto` in
        `msa_output_dir`. If `use_precomputed_msas` is set and that file
        exists, it is read back instead. The MSA is truncated to
        `uniprot_max_hits` rows before featurization.

        Args:
          input_fasta_path: Query FASTA file.
          msa_output_dir: Directory for the raw UniProt hits.

        Returns:
          The FeatureDict produced by `make_msa_features` for the UniProt MSA.
        """
        uniprot_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
        uniprot_result = run_msa_tool(
            self.jackhmmer_uniprot_runner,
            input_fasta_path,
            uniprot_path,
            'sto',
            self.use_precomputed_msas,
        )
        msa = parsers.parse_stockholm(uniprot_result['sto'])
        msa = msa.truncate(max_seqs=self.uniprot_max_hits)
        return make_msa_features([msa])
|