templates.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110
  1. # Copyright 2021 DeepMind Technologies Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Functions for getting templates and calculating template features."""
  15. import abc
  16. import dataclasses
  17. import datetime
  18. import functools
  19. import glob
  20. import os
  21. import re
  22. from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
  23. import numpy as np
  24. from absl import logging
  25. from modelscope.models.science.unifold.data import residue_constants
  26. from modelscope.models.science.unifold.msa import mmcif, parsers
  27. from modelscope.models.science.unifold.msa.tools import kalign
  28. class Error(Exception):
  29. """Base class for exceptions."""
  30. class NoChainsError(Error):
  31. """An error indicating that template mmCIF didn't have any chains."""
  32. class SequenceNotInTemplateError(Error):
  33. """An error indicating that template mmCIF didn't contain the sequence."""
  34. class NoAtomDataInTemplateError(Error):
  35. """An error indicating that template mmCIF didn't contain atom positions."""
  36. class TemplateAtomMaskAllZerosError(Error):
  37. """An error indicating that template mmCIF had all atom positions masked."""
  38. class QueryToTemplateAlignError(Error):
  39. """An error indicating that the query can't be aligned to the template."""
  40. class CaDistanceError(Error):
  41. """An error indicating that a CA atom distance exceeds a threshold."""
  42. class MultipleChainsError(Error):
  43. """An error indicating that multiple chains were found for a given ID."""
  44. # Prefilter exceptions.
  45. class PrefilterError(Exception):
  46. """A base class for template prefilter exceptions."""
  47. class DateError(PrefilterError):
  48. """An error indicating that the hit date was after the max allowed date."""
  49. class AlignRatioError(PrefilterError):
  50. """An error indicating that the hit align ratio to the query was too small."""
  51. class DuplicateError(PrefilterError):
  52. """An error indicating that the hit was an exact subsequence of the query."""
  53. class LengthError(PrefilterError):
  54. """An error indicating that the hit was too short."""
  55. TEMPLATE_FEATURES = {
  56. 'template_aatype': np.float32,
  57. 'template_all_atom_mask': np.float32,
  58. 'template_all_atom_positions': np.float32,
  59. 'template_domain_names': np.object_,
  60. 'template_sequence': np.object_,
  61. 'template_sum_probs': np.float32,
  62. }
  63. def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
  64. """Returns PDB id and chain id for an HHSearch Hit."""
  65. # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
  66. id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
  67. if not id_match:
  68. raise ValueError(
  69. f'hit.name did not start with PDBID_chain: {hit.name}')
  70. pdb_id, chain_id = id_match.group(0).split('_')
  71. return pdb_id.lower(), chain_id
  72. def _is_after_cutoff(
  73. pdb_id: str,
  74. release_dates: Mapping[str, datetime.datetime],
  75. release_date_cutoff: Optional[datetime.datetime],
  76. ) -> bool:
  77. """Checks if the template date is after the release date cutoff.
  78. Args:
  79. pdb_id: 4 letter pdb code.
  80. release_dates: Dictionary mapping PDB ids to their structure release dates.
  81. release_date_cutoff: Max release date that is valid for this query.
  82. Returns:
  83. True if the template release date is after the cutoff, False otherwise.
  84. """
  85. if release_date_cutoff is None:
  86. raise ValueError('The release_date_cutoff must not be None.')
  87. if pdb_id in release_dates:
  88. return release_dates[pdb_id] > release_date_cutoff
  89. else:
  90. # Since this is just a quick prefilter to reduce the number of mmCIF files
  91. # we need to parse, we don't have to worry about returning True here.
  92. return False
  93. def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, Optional[str]]:
  94. """Parses the data file from PDB that lists which pdb_ids are obsolete."""
  95. with open(obsolete_file_path) as f:
  96. result = {}
  97. for line in f:
  98. line = line.strip()
  99. # Format: Date From To
  100. # 'OBSLTE 06-NOV-19 6G9Y' - Removed, rare
  101. # 'OBSLTE 31-JUL-94 116L 216L' - Replaced, common
  102. # 'OBSLTE 26-SEP-06 2H33 2JM5 2OWI' - Replaced by multiple, rare
  103. if line.startswith('OBSLTE'):
  104. if len(line) > 30:
  105. # Replaced by at least one structure.
  106. from_id = line[20:24].lower()
  107. to_id = line[29:33].lower()
  108. result[from_id] = to_id
  109. elif len(line) == 24:
  110. # Removed.
  111. from_id = line[20:24].lower()
  112. result[from_id] = None
  113. return result
  114. def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
  115. """Parses release dates file, returns a mapping from PDBs to release dates."""
  116. if path.endswith('txt'):
  117. release_dates = {}
  118. with open(path, 'r', encoding='utf-8') as f:
  119. for line in f:
  120. pdb_id, date = line.split(':')
  121. date = date.strip()
  122. # Python 3.6 doesn't have datetime.date.fromisoformat() which is about
  123. # 90x faster than strptime. However, splitting the string manually is
  124. # about 10x faster than strptime.
  125. release_dates[pdb_id.strip()] = datetime.datetime(
  126. year=int(date[:4]),
  127. month=int(date[5:7]),
  128. day=int(date[8:10]))
  129. return release_dates
  130. else:
  131. raise ValueError('Invalid format of the release date file %s.' % path)
  132. def _assess_hhsearch_hit(
  133. hit: parsers.TemplateHit,
  134. hit_pdb_code: str,
  135. query_sequence: str,
  136. release_dates: Mapping[str, datetime.datetime],
  137. release_date_cutoff: datetime.datetime,
  138. max_subsequence_ratio: float = 0.95,
  139. min_align_ratio: float = 0.1,
  140. ) -> bool:
  141. """Determines if template is valid (without parsing the template mmcif file).
  142. Args:
  143. hit: HhrHit for the template.
  144. hit_pdb_code: The 4 letter pdb code of the template hit. This might be
  145. different from the value in the actual hit since the original pdb might
  146. have become obsolete.
  147. query_sequence: Amino acid sequence of the query.
  148. release_dates: Dictionary mapping pdb codes to their structure release
  149. dates.
  150. release_date_cutoff: Max release date that is valid for this query.
  151. max_subsequence_ratio: Exclude any exact matches with this much overlap.
  152. min_align_ratio: Minimum overlap between the template and query.
  153. Returns:
  154. True if the hit passed the prefilter. Raises an exception otherwise.
  155. Raises:
  156. DateError: If the hit date was after the max allowed date.
  157. AlignRatioError: If the hit align ratio to the query was too small.
  158. DuplicateError: If the hit was an exact subsequence of the query.
  159. LengthError: If the hit was too short.
  160. """
  161. aligned_cols = hit.aligned_cols
  162. align_ratio = aligned_cols / len(query_sequence)
  163. template_sequence = hit.hit_sequence.replace('-', '')
  164. length_ratio = float(len(template_sequence)) / len(query_sequence)
  165. # Check whether the template is a large subsequence or duplicate of original
  166. # query. This can happen due to duplicate entries in the PDB database.
  167. duplicate = (
  168. template_sequence in query_sequence
  169. and length_ratio > max_subsequence_ratio)
  170. if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
  171. raise DateError(
  172. f'Date ({release_dates[hit_pdb_code]}) > max template date '
  173. f'({release_date_cutoff}).')
  174. if align_ratio <= min_align_ratio:
  175. raise AlignRatioError(
  176. 'Proportion of residues aligned to query too small. '
  177. f'Align ratio: {align_ratio}.')
  178. if duplicate:
  179. raise DuplicateError(
  180. 'Template is an exact subsequence of query with large '
  181. f'coverage. Length ratio: {length_ratio}.')
  182. if len(template_sequence) < 10:
  183. raise LengthError(
  184. f'Template too short. Length: {len(template_sequence)}.')
  185. return True
  186. def _find_template_in_pdb(
  187. template_chain_id: str, template_sequence: str,
  188. mmcif_object: mmcif.MmcifObject) -> Tuple[str, str, int]:
  189. """Tries to find the template chain in the given pdb file.
  190. This method tries the three following things in order:
  191. 1. Tries if there is an exact match in both the chain ID and the sequence.
  192. If yes, the chain sequence is returned. Otherwise:
  193. 2. Tries if there is an exact match only in the sequence.
  194. If yes, the chain sequence is returned. Otherwise:
  195. 3. Tries if there is a fuzzy match (X = wildcard) in the sequence.
  196. If yes, the chain sequence is returned.
  197. If none of these succeed, a SequenceNotInTemplateError is thrown.
  198. Args:
  199. template_chain_id: The template chain ID.
  200. template_sequence: The template chain sequence.
  201. mmcif_object: The PDB object to search for the template in.
  202. Returns:
  203. A tuple with:
  204. * The chain sequence that was found to match the template in the PDB object.
  205. * The ID of the chain that is being returned.
  206. * The offset where the template sequence starts in the chain sequence.
  207. Raises:
  208. SequenceNotInTemplateError: If no match is found after the steps described
  209. above.
  210. """
  211. # Try if there is an exact match in both the chain ID and the (sub)sequence.
  212. pdb_id = mmcif_object.file_id
  213. chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
  214. if chain_sequence and (template_sequence in chain_sequence):
  215. logging.info('Found an exact template match %s_%s.', pdb_id,
  216. template_chain_id)
  217. mapping_offset = chain_sequence.find(template_sequence)
  218. return chain_sequence, template_chain_id, mapping_offset
  219. # Try if there is an exact match in the (sub)sequence only.
  220. for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
  221. if chain_sequence and (template_sequence in chain_sequence):
  222. logging.info('Found a sequence-only match %s_%s.', pdb_id,
  223. chain_id)
  224. mapping_offset = chain_sequence.find(template_sequence)
  225. return chain_sequence, chain_id, mapping_offset
  226. # Return a chain sequence that fuzzy matches (X = wildcard) the template.
  227. # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
  228. regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence]
  229. regex = re.compile(''.join(regex))
  230. for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
  231. match = re.search(regex, chain_sequence)
  232. if match:
  233. logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id,
  234. chain_id)
  235. mapping_offset = match.start()
  236. return chain_sequence, chain_id, mapping_offset
  237. # No hits, raise an error.
  238. raise SequenceNotInTemplateError(
  239. 'Could not find the template sequence in %s_%s. Template sequence: %s, '
  240. 'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence,
  241. mmcif_object.chain_to_seqres))
def _realign_pdb_template_to_query(
        old_template_sequence: str,
        template_chain_id: str,
        mmcif_object: mmcif.MmcifObject,
        old_mapping: Mapping[int, int],
        kalign_binary_path: str,
) -> Tuple[str, Mapping[int, int]]:
    """Aligns template from the mmcif_object to the query.

    In case PDB70 contains a different version of the template sequence, we need
    to perform a realignment to the actual sequence that is in the mmCIF file.
    This method performs such realignment, but returns the new sequence and
    mapping only if the sequence in the mmCIF file is at least 90% identical to
    the old sequence (identity measured over the shorter of the two).

    Note that the old_template_sequence comes from the hit, and contains only that
    part of the chain that matches with the query while the new_template_sequence
    is the full chain.

    Args:
      old_template_sequence: The template sequence that was returned by the PDB
        template search (typically done using HHSearch).
      template_chain_id: The template chain id was returned by the PDB template
        search (typically done using HHSearch). This is used to find the right
        chain in the mmcif_object chain_to_seqres mapping. If it is not found
        but the mmcif_object has exactly one chain, that chain is used.
      mmcif_object: A mmcif_object which holds the actual template data.
      old_mapping: A mapping from the query sequence to the template sequence.
        This mapping will be used to compute the new mapping from the query
        sequence to the actual mmcif_object template sequence by aligning the
        old_template_sequence and the actual template sequence.
      kalign_binary_path: The path to a kalign executable.

    Returns:
      A tuple (new_template_sequence, new_query_to_template_mapping) where:
      * new_template_sequence is the actual template sequence that was found in
        the mmcif_object.
      * new_query_to_template_mapping is the new mapping from the query to the
        actual template found in the mmcif_object. Query positions whose old
        template residue has no aligned counterpart in the new sequence are
        mapped to -1.

    Raises:
      QueryToTemplateAlignError:
      * If the chain cannot be found and the mmcif_object has more than one
        chain.
      * If there was an error thrown by the alignment tool.
      * Or if the actual template sequence differs by more than 10% from the
        old_template_sequence.
    """
    aligner = kalign.Kalign(binary_path=kalign_binary_path)
    new_template_sequence = mmcif_object.chain_to_seqres.get(
        template_chain_id, '')

    # Sometimes the template chain id is unknown. But if there is only a single
    # sequence within the mmcif_object, it is safe to assume it is that one.
    if not new_template_sequence:
        if len(mmcif_object.chain_to_seqres) == 1:
            logging.info(
                'Could not find %s in %s, but there is only 1 sequence, so '
                'using that one.',
                template_chain_id,
                mmcif_object.file_id,
            )
            new_template_sequence = list(
                mmcif_object.chain_to_seqres.values())[0]
        else:
            raise QueryToTemplateAlignError(
                f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. '
                'If there are no mmCIF parsing errors, it is possible it was not a '
                'protein chain.')

    try:
        # The unpacking below assumes the aligner output contains exactly two
        # rows, old template first, matching the order they were passed in.
        # Any failure here (tool error, parse error, unexpected row count) is
        # reported as a QueryToTemplateAlignError.
        parsed_a3m = parsers.parse_a3m(
            aligner.align([old_template_sequence, new_template_sequence]))
        old_aligned_template, new_aligned_template = parsed_a3m.sequences
    except Exception as e:
        raise QueryToTemplateAlignError(
            'Could not align old template %s to template %s (%s_%s). Error: %s'
            % (
                old_template_sequence,
                new_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                str(e),
            ))

    logging.info(
        'Old aligned template: %s\nNew aligned template: %s',
        old_aligned_template,
        new_aligned_template,
    )

    # Walk the two aligned rows column by column. Each index only advances on
    # a non-gap character of its own row, so it is a position in that row's
    # ungapped sequence. Columns where both rows have a residue link the two
    # positions; identical residues in such columns count toward identity.
    old_to_new_template_mapping = {}
    old_template_index = -1
    new_template_index = -1
    num_same = 0
    for old_template_aa, new_template_aa in zip(old_aligned_template,
                                                new_aligned_template):
        if old_template_aa != '-':
            old_template_index += 1
        if new_template_aa != '-':
            new_template_index += 1
        if old_template_aa != '-' and new_template_aa != '-':
            old_to_new_template_mapping[
                old_template_index] = new_template_index
            if old_template_aa == new_template_aa:
                num_same += 1

    # Require at least 90 % sequence identity wrt to the shorter of the sequences.
    if (float(num_same)
            / min(len(old_template_sequence), len(new_template_sequence))
            < # noqa W504
            0.9):
        raise QueryToTemplateAlignError(
            'Insufficient similarity of the sequence in the database: %s to the '
            'actual sequence in the mmCIF file %s_%s: %s. We require at least '
            '90 %% similarity wrt to the shorter of the sequences. This is not a '
            'problem unless you think this is a template that should be included.'
            % (
                old_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                new_template_sequence,
            ))

    # Compose query -> old template -> new template. Old template positions
    # that fell on a gap in the alignment become -1.
    new_query_to_template_mapping = {}
    for query_index, old_template_index in old_mapping.items():
        new_query_to_template_mapping[
            query_index] = old_to_new_template_mapping.get(
                old_template_index, -1)

    # Defensive: make sure the returned sequence carries no gap characters.
    new_template_sequence = new_template_sequence.replace('-', '')

    return new_template_sequence, new_query_to_template_mapping
  359. def _check_residue_distances(all_positions: np.ndarray,
  360. all_positions_mask: np.ndarray,
  361. max_ca_ca_distance: float):
  362. """Checks if the distance between unmasked neighbor residues is ok."""
  363. ca_position = residue_constants.atom_order['CA']
  364. prev_is_unmasked = False
  365. prev_calpha = None
  366. for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
  367. this_is_unmasked = bool(mask[ca_position])
  368. if this_is_unmasked:
  369. this_calpha = coords[ca_position]
  370. if prev_is_unmasked:
  371. distance = np.linalg.norm(this_calpha - prev_calpha)
  372. if distance > max_ca_ca_distance:
  373. raise CaDistanceError(
  374. 'The distance between residues %d and %d is %f > limit %f.'
  375. % (i, i + 1, distance, max_ca_ca_distance))
  376. prev_calpha = this_calpha
  377. prev_is_unmasked = this_is_unmasked
  378. def _get_atom_positions(
  379. mmcif_object: mmcif.MmcifObject, auth_chain_id: str,
  380. max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
  381. """Gets atom positions and mask from a list of Biopython Residues."""
  382. num_res = len(mmcif_object.chain_to_seqres[auth_chain_id])
  383. relevant_chains = [
  384. c for c in mmcif_object.structure.get_chains() if c.id == auth_chain_id
  385. ]
  386. if len(relevant_chains) != 1:
  387. raise MultipleChainsError(
  388. f'Expected exactly one chain in structure with id {auth_chain_id}.'
  389. )
  390. chain = relevant_chains[0]
  391. all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3])
  392. all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num],
  393. dtype=np.int64)
  394. for res_index in range(num_res):
  395. pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
  396. mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
  397. res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][
  398. res_index]
  399. if not res_at_position.is_missing:
  400. res = chain[(
  401. res_at_position.hetflag,
  402. res_at_position.position.residue_number,
  403. res_at_position.position.insertion_code,
  404. )]
  405. for atom in res.get_atoms():
  406. atom_name = atom.get_name()
  407. x, y, z = atom.get_coord()
  408. if atom_name in residue_constants.atom_order.keys():
  409. pos[residue_constants.atom_order[atom_name]] = [x, y, z]
  410. mask[residue_constants.atom_order[atom_name]] = 1.0
  411. elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
  412. # Put the coordinates of the selenium atom in the sulphur column.
  413. pos[residue_constants.atom_order['SD']] = [x, y, z]
  414. mask[residue_constants.atom_order['SD']] = 1.0
  415. # Fix naming errors in arginine residues where NH2 is incorrectly
  416. # assigned to be closer to CD than NH1.
  417. cd = residue_constants.atom_order['CD']
  418. nh1 = residue_constants.atom_order['NH1']
  419. nh2 = residue_constants.atom_order['NH2']
  420. if (res.get_resname() == 'ARG'
  421. and all(mask[atom_index] for atom_index in (cd, nh1, nh2))
  422. and (np.linalg.norm(pos[nh1] - pos[cd]) > # noqa W504
  423. np.linalg.norm(pos[nh2] - pos[cd]))):
  424. pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
  425. mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()
  426. all_positions[res_index] = pos
  427. all_positions_mask[res_index] = mask
  428. _check_residue_distances(all_positions, all_positions_mask,
  429. max_ca_ca_distance)
  430. return all_positions, all_positions_mask
  431. def _extract_template_features(
  432. mmcif_object: mmcif.MmcifObject,
  433. pdb_id: str,
  434. mapping: Mapping[int, int],
  435. template_sequence: str,
  436. query_sequence: str,
  437. template_chain_id: str,
  438. kalign_binary_path: str,
  439. ) -> Tuple[Dict[str, Any], Optional[str]]:
  440. """Parses atom positions in the target structure and aligns with the query.
  441. Atoms for each residue in the template structure are indexed to coincide
  442. with their corresponding residue in the query sequence, according to the
  443. alignment mapping provided.
  444. Args:
  445. mmcif_object: mmcif_parsing.MmcifObject representing the template.
  446. pdb_id: PDB code for the template.
  447. mapping: Dictionary mapping indices in the query sequence to indices in
  448. the template sequence.
  449. template_sequence: String describing the amino acid sequence for the
  450. template protein.
  451. query_sequence: String describing the amino acid sequence for the query
  452. protein.
  453. template_chain_id: String ID describing which chain in the structure proto
  454. should be used.
  455. kalign_binary_path: The path to a kalign executable used for template
  456. realignment.
  457. Returns:
  458. A tuple with:
  459. * A dictionary containing the extra features derived from the template
  460. protein structure.
  461. * A warning message if the hit was realigned to the actual mmCIF sequence.
  462. Otherwise None.
  463. Raises:
  464. NoChainsError: If the mmcif object doesn't contain any chains.
  465. SequenceNotInTemplateError: If the given chain id / sequence can't
  466. be found in the mmcif object.
  467. QueryToTemplateAlignError: If the actual template in the mmCIF file
  468. can't be aligned to the query.
  469. NoAtomDataInTemplateError: If the mmcif object doesn't contain
  470. atom positions.
  471. TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
  472. unmasked residues.
  473. """
  474. if mmcif_object is None or not mmcif_object.chain_to_seqres:
  475. raise NoChainsError('No chains in PDB: %s_%s' %
  476. (pdb_id, template_chain_id))
  477. warning = None
  478. try:
  479. seqres, chain_id, mapping_offset = _find_template_in_pdb(
  480. template_chain_id=template_chain_id,
  481. template_sequence=template_sequence,
  482. mmcif_object=mmcif_object,
  483. )
  484. except SequenceNotInTemplateError:
  485. # If PDB70 contains a different version of the template, we use the sequence
  486. # from the mmcif_object.
  487. chain_id = template_chain_id
  488. warning = (
  489. f'The exact sequence {template_sequence} was not found in '
  490. f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.'
  491. )
  492. logging.warning(warning)
  493. # This throws an exception if it fails to realign the hit.
  494. seqres, mapping = _realign_pdb_template_to_query(
  495. old_template_sequence=template_sequence,
  496. template_chain_id=template_chain_id,
  497. mmcif_object=mmcif_object,
  498. old_mapping=mapping,
  499. kalign_binary_path=kalign_binary_path,
  500. )
  501. logging.info(
  502. 'Sequence in %s_%s: %s successfully realigned to %s',
  503. pdb_id,
  504. chain_id,
  505. template_sequence,
  506. seqres,
  507. )
  508. # The template sequence changed.
  509. template_sequence = seqres
  510. # No mapping offset, the query is aligned to the actual sequence.
  511. mapping_offset = 0
  512. try:
  513. # Essentially set to infinity - we don't want to reject templates unless
  514. # they're really really bad.
  515. all_atom_positions, all_atom_mask = _get_atom_positions(
  516. mmcif_object, chain_id, max_ca_ca_distance=150.0)
  517. except (CaDistanceError, KeyError) as ex:
  518. raise NoAtomDataInTemplateError('Could not get atom data (%s_%s): %s' %
  519. (pdb_id, chain_id, str(ex))) from ex
  520. all_atom_positions = np.split(all_atom_positions,
  521. all_atom_positions.shape[0])
  522. all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])
  523. output_templates_sequence = []
  524. templates_all_atom_positions = []
  525. templates_all_atom_masks = []
  526. for _ in query_sequence:
  527. # Residues in the query_sequence that are not in the template_sequence:
  528. templates_all_atom_positions.append(
  529. np.zeros((residue_constants.atom_type_num, 3)))
  530. templates_all_atom_masks.append(
  531. np.zeros(residue_constants.atom_type_num))
  532. output_templates_sequence.append('-')
  533. for k, v in mapping.items():
  534. template_index = v + mapping_offset
  535. templates_all_atom_positions[k] = all_atom_positions[template_index][0]
  536. templates_all_atom_masks[k] = all_atom_masks[template_index][0]
  537. output_templates_sequence[k] = template_sequence[v]
  538. # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
  539. if np.sum(templates_all_atom_masks) < 5:
  540. raise TemplateAtomMaskAllZerosError(
  541. 'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d'
  542. % (
  543. pdb_id,
  544. chain_id,
  545. min(mapping.values()) + mapping_offset,
  546. max(mapping.values()) + mapping_offset,
  547. ))
  548. output_templates_sequence = ''.join(output_templates_sequence)
  549. templates_aatype = residue_constants.sequence_to_onehot(
  550. output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID)
  551. return (
  552. {
  553. 'template_all_atom_positions':
  554. np.array(templates_all_atom_positions),
  555. 'template_all_atom_mask': np.array(templates_all_atom_masks),
  556. 'template_sequence': output_templates_sequence.encode(),
  557. 'template_aatype': np.array(templates_aatype),
  558. 'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
  559. },
  560. warning,
  561. )
  562. def _build_query_to_hit_index_mapping(
  563. hit_query_sequence: str,
  564. hit_sequence: str,
  565. indices_hit: Sequence[int],
  566. indices_query: Sequence[int],
  567. original_query_sequence: str,
  568. ) -> Mapping[int, int]:
  569. """Gets mapping from indices in original query sequence to indices in the hit.
  570. hit_query_sequence and hit_sequence are two aligned sequences containing gap
  571. characters. hit_query_sequence contains only the part of the original query
  572. sequence that matched the hit. When interpreting the indices from the .hhr, we
  573. need to correct for this to recover a mapping from original query sequence to
  574. the hit sequence.
  575. Args:
  576. hit_query_sequence: The portion of the query sequence that is in the .hhr
  577. hit
  578. hit_sequence: The portion of the hit sequence that is in the .hhr
  579. indices_hit: The indices for each aminoacid relative to the hit sequence
  580. indices_query: The indices for each aminoacid relative to the original query
  581. sequence
  582. original_query_sequence: String describing the original query sequence.
  583. Returns:
  584. Dictionary with indices in the original query sequence as keys and indices
  585. in the hit sequence as values.
  586. """
  587. # If the hit is empty (no aligned residues), return empty mapping
  588. if not hit_query_sequence:
  589. return {}
  590. # Remove gaps and find the offset of hit.query relative to original query.
  591. hhsearch_query_sequence = hit_query_sequence.replace('-', '')
  592. hit_sequence = hit_sequence.replace('-', '')
  593. hhsearch_query_offset = original_query_sequence.find(
  594. hhsearch_query_sequence)
  595. # Index of -1 used for gap characters. Subtract the min index ignoring gaps.
  596. min_idx = min(x for x in indices_hit if x > -1)
  597. fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
  598. min_idx = min(x for x in indices_query if x > -1)
  599. fixed_indices_query = [
  600. x - min_idx if x > -1 else -1 for x in indices_query
  601. ]
  602. # Zip the corrected indices, ignore case where both seqs have gap characters.
  603. mapping = {}
  604. for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
  605. if q_t != -1 and q_i != -1:
  606. if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
  607. original_query_sequence):
  608. continue
  609. mapping[q_i + hhsearch_query_offset] = q_t
  610. return mapping
  611. @dataclasses.dataclass(frozen=True)
  612. class SingleHitResult:
  613. features: Optional[Mapping[str, Any]]
  614. error: Optional[str]
  615. warning: Optional[str]
  616. @functools.lru_cache(16, typed=False)
  617. def _read_file(path):
  618. with open(path, 'r') as f:
  619. file_data = f.read()
  620. return file_data
  621. def _process_single_hit(
  622. query_sequence: str,
  623. hit: parsers.TemplateHit,
  624. mmcif_dir: str,
  625. max_template_date: datetime.datetime,
  626. release_dates: Mapping[str, datetime.datetime],
  627. obsolete_pdbs: Mapping[str, Optional[str]],
  628. kalign_binary_path: str,
  629. strict_error_check: bool = False,
  630. ) -> SingleHitResult:
  631. """Tries to extract template features from a single HHSearch hit."""
  632. # Fail hard if we can't get the PDB ID and chain name from the hit.
  633. hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
  634. # This hit has been removed (obsoleted) from PDB, skip it.
  635. if hit_pdb_code in obsolete_pdbs and obsolete_pdbs[hit_pdb_code] is None:
  636. return SingleHitResult(
  637. features=None,
  638. error=None,
  639. warning=f'Hit {hit_pdb_code} is obsolete.')
  640. if hit_pdb_code not in release_dates:
  641. if hit_pdb_code in obsolete_pdbs:
  642. hit_pdb_code = obsolete_pdbs[hit_pdb_code]
  643. # Pass hit_pdb_code since it might have changed due to the pdb being obsolete.
  644. try:
  645. _assess_hhsearch_hit(
  646. hit=hit,
  647. hit_pdb_code=hit_pdb_code,
  648. query_sequence=query_sequence,
  649. release_dates=release_dates,
  650. release_date_cutoff=max_template_date,
  651. )
  652. except PrefilterError as e:
  653. msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}'
  654. logging.info(msg)
  655. if strict_error_check and isinstance(e, (DateError, DuplicateError)):
  656. # In strict mode we treat some prefilter cases as errors.
  657. return SingleHitResult(features=None, error=msg, warning=None)
  658. return SingleHitResult(features=None, error=None, warning=None)
  659. mapping = _build_query_to_hit_index_mapping(hit.query, hit.hit_sequence,
  660. hit.indices_hit,
  661. hit.indices_query,
  662. query_sequence)
  663. # The mapping is from the query to the actual hit sequence, so we need to
  664. # remove gaps (which regardless have a missing confidence score).
  665. template_sequence = hit.hit_sequence.replace('-', '')
  666. cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
  667. logging.debug(
  668. 'Reading PDB entry from %s. Query: %s, template: %s',
  669. cif_path,
  670. query_sequence,
  671. template_sequence,
  672. )
  673. # Fail if we can't find the mmCIF file.
  674. cif_string = _read_file(cif_path)
  675. parsing_result = mmcif.parse(file_id=hit_pdb_code, mmcif_string=cif_string)
  676. if parsing_result.mmcif_object is not None:
  677. hit_release_date = datetime.datetime.strptime(
  678. parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d')
  679. if hit_release_date > max_template_date:
  680. error = 'Template %s date (%s) > max template date (%s).' % (
  681. hit_pdb_code,
  682. hit_release_date,
  683. max_template_date,
  684. )
  685. if strict_error_check:
  686. return SingleHitResult(
  687. features=None, error=error, warning=None)
  688. else:
  689. logging.debug(error)
  690. return SingleHitResult(features=None, error=None, warning=None)
  691. try:
  692. features, realign_warning = _extract_template_features(
  693. mmcif_object=parsing_result.mmcif_object,
  694. pdb_id=hit_pdb_code,
  695. mapping=mapping,
  696. template_sequence=template_sequence,
  697. query_sequence=query_sequence,
  698. template_chain_id=hit_chain_id,
  699. kalign_binary_path=kalign_binary_path,
  700. )
  701. if hit.sum_probs is None:
  702. features['template_sum_probs'] = [0]
  703. else:
  704. features['template_sum_probs'] = [hit.sum_probs]
  705. # It is possible there were some errors when parsing the other chains in the
  706. # mmCIF file, but the template features for the chain we want were still
  707. # computed. In such case the mmCIF parsing errors are not relevant.
  708. return SingleHitResult(
  709. features=features, error=None, warning=realign_warning)
  710. except (
  711. NoChainsError,
  712. NoAtomDataInTemplateError,
  713. TemplateAtomMaskAllZerosError,
  714. ) as e:
  715. # These 3 errors indicate missing mmCIF experimental data rather than a
  716. # problem with the template search, so turn them into warnings.
  717. warning = (
  718. '%s_%s (sum_probs: %s, rank: %s): feature extracting errors: '
  719. '%s, mmCIF parsing errors: %s' % (
  720. hit_pdb_code,
  721. hit_chain_id,
  722. hit.sum_probs,
  723. hit.index,
  724. str(e),
  725. parsing_result.errors,
  726. ))
  727. if strict_error_check:
  728. return SingleHitResult(features=None, error=warning, warning=None)
  729. else:
  730. return SingleHitResult(features=None, error=None, warning=warning)
  731. except Error as e:
  732. error = (
  733. '%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
  734. '%s, mmCIF parsing errors: %s' % (
  735. hit_pdb_code,
  736. hit_chain_id,
  737. hit.sum_probs,
  738. hit.index,
  739. str(e),
  740. parsing_result.errors,
  741. ))
  742. return SingleHitResult(features=None, error=error, warning=None)
  743. @dataclasses.dataclass(frozen=True)
  744. class TemplateSearchResult:
  745. features: Mapping[str, Any]
  746. errors: Sequence[str]
  747. warnings: Sequence[str]
  748. class TemplateHitFeaturizer(abc.ABC):
  749. """An abstract base class for turning template hits to template features."""
  750. def __init__(
  751. self,
  752. mmcif_dir: str,
  753. max_template_date: str,
  754. max_hits: int,
  755. kalign_binary_path: str,
  756. release_dates_path: Optional[str],
  757. obsolete_pdbs_path: Optional[str],
  758. strict_error_check: bool = False,
  759. ):
  760. """Initializes the Template Search.
  761. Args:
  762. mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
  763. is found by HHSearch, this directory is used to retrieve the template
  764. data.
  765. max_template_date: The maximum date permitted for template structures. No
  766. template with date higher than this date will be returned. In ISO8601
  767. date format, YYYY-MM-DD.
  768. max_hits: The maximum number of templates that will be returned.
  769. kalign_binary_path: The path to a kalign executable used for template
  770. realignment.
  771. release_dates_path: An optional path to a file with a mapping from PDB IDs
  772. to their release dates. Thanks to this we don't have to redundantly
  773. parse mmCIF files to get that information.
  774. obsolete_pdbs_path: An optional path to a file containing a mapping from
  775. obsolete PDB IDs to the PDB IDs of their replacements.
  776. strict_error_check: If True, then the following will be treated as errors:
  777. * If any template date is after the max_template_date.
  778. * If any template has identical PDB ID to the query.
  779. * If any template is a duplicate of the query.
  780. * Any feature computation errors.
  781. """
  782. self._mmcif_dir = mmcif_dir
  783. if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')):
  784. logging.error('Could not find CIFs in %s', self._mmcif_dir)
  785. raise ValueError(f'Could not find CIFs in {self._mmcif_dir}')
  786. try:
  787. self._max_template_date = datetime.datetime.strptime(
  788. max_template_date, '%Y-%m-%d')
  789. except ValueError:
  790. raise ValueError(
  791. 'max_template_date must be set and have format YYYY-MM-DD.')
  792. self._max_hits = max_hits
  793. self._kalign_binary_path = kalign_binary_path
  794. self._strict_error_check = strict_error_check
  795. if release_dates_path:
  796. logging.info('Using precomputed release dates %s.',
  797. release_dates_path)
  798. self._release_dates = _parse_release_dates(release_dates_path)
  799. else:
  800. self._release_dates = {}
  801. if obsolete_pdbs_path:
  802. logging.info('Using precomputed obsolete pdbs %s.',
  803. obsolete_pdbs_path)
  804. self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
  805. else:
  806. self._obsolete_pdbs = {}
  807. @abc.abstractmethod
  808. def get_templates(
  809. self, query_sequence: str,
  810. hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
  811. """Computes the templates for given query sequence."""
  812. class HhsearchHitFeaturizer(TemplateHitFeaturizer):
  813. """A class for turning a3m hits from hhsearch to template features."""
  814. def get_templates(
  815. self, query_sequence: str,
  816. hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
  817. """Computes the templates for given query sequence (more details above)."""
  818. logging.info('Searching for template for: %s', query_sequence)
  819. template_features = {}
  820. for template_feature_name in TEMPLATE_FEATURES:
  821. template_features[template_feature_name] = []
  822. num_hits = 0
  823. errors = []
  824. warnings = []
  825. for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True):
  826. # We got all the templates we wanted, stop processing hits.
  827. if num_hits >= self._max_hits:
  828. break
  829. result = _process_single_hit(
  830. query_sequence=query_sequence,
  831. hit=hit,
  832. mmcif_dir=self._mmcif_dir,
  833. max_template_date=self._max_template_date,
  834. release_dates=self._release_dates,
  835. obsolete_pdbs=self._obsolete_pdbs,
  836. strict_error_check=self._strict_error_check,
  837. kalign_binary_path=self._kalign_binary_path,
  838. )
  839. if result.error:
  840. errors.append(result.error)
  841. # There could be an error even if there are some results, e.g. thrown by
  842. # other unparsable chains in the same mmCIF file.
  843. if result.warning:
  844. warnings.append(result.warning)
  845. if result.features is None:
  846. logging.info(
  847. 'Skipped invalid hit %s, error: %s, warning: %s',
  848. hit.name,
  849. result.error,
  850. result.warning,
  851. )
  852. else:
  853. # Increment the hit counter, since we got features out of this hit.
  854. num_hits += 1
  855. for k in template_features:
  856. template_features[k].append(result.features[k])
  857. for name in template_features:
  858. if num_hits > 0:
  859. template_features[name] = np.stack(
  860. template_features[name],
  861. axis=0).astype(TEMPLATE_FEATURES[name])
  862. else:
  863. # Make sure the feature has correct dtype even if empty.
  864. template_features[name] = np.array(
  865. [], dtype=TEMPLATE_FEATURES[name])
  866. return TemplateSearchResult(
  867. features=template_features, errors=errors, warnings=warnings)
  868. class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
  869. """A class for turning a3m hits from hmmsearch to template features."""
  870. def get_templates(
  871. self, query_sequence: str,
  872. hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
  873. """Computes the templates for given query sequence (more details above)."""
  874. logging.info('Searching for template for: %s', query_sequence)
  875. template_features = {}
  876. for template_feature_name in TEMPLATE_FEATURES:
  877. template_features[template_feature_name] = []
  878. already_seen = set()
  879. errors = []
  880. warnings = []
  881. if not hits or hits[0].sum_probs is None:
  882. sorted_hits = hits
  883. else:
  884. sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True)
  885. for hit in sorted_hits:
  886. # We got all the templates we wanted, stop processing hits.
  887. if len(already_seen) >= self._max_hits:
  888. break
  889. result = _process_single_hit(
  890. query_sequence=query_sequence,
  891. hit=hit,
  892. mmcif_dir=self._mmcif_dir,
  893. max_template_date=self._max_template_date,
  894. release_dates=self._release_dates,
  895. obsolete_pdbs=self._obsolete_pdbs,
  896. strict_error_check=self._strict_error_check,
  897. kalign_binary_path=self._kalign_binary_path,
  898. )
  899. if result.error:
  900. errors.append(result.error)
  901. # There could be an error even if there are some results, e.g. thrown by
  902. # other unparsable chains in the same mmCIF file.
  903. if result.warning:
  904. warnings.append(result.warning)
  905. if result.features is None:
  906. logging.debug(
  907. 'Skipped invalid hit %s, error: %s, warning: %s',
  908. hit.name,
  909. result.error,
  910. result.warning,
  911. )
  912. else:
  913. already_seen_key = result.features['template_sequence']
  914. if already_seen_key in already_seen:
  915. continue
  916. # Increment the hit counter, since we got features out of this hit.
  917. already_seen.add(already_seen_key)
  918. for k in template_features:
  919. template_features[k].append(result.features[k])
  920. if already_seen:
  921. for name in template_features:
  922. template_features[name] = np.stack(
  923. template_features[name],
  924. axis=0).astype(TEMPLATE_FEATURES[name])
  925. else:
  926. num_res = len(query_sequence)
  927. # Construct a default template with all zeros.
  928. template_features = {
  929. 'template_aatype':
  930. np.zeros(
  931. (1, num_res, len(
  932. residue_constants.restypes_with_x_and_gap)),
  933. np.float32,
  934. ),
  935. 'template_all_atom_mask':
  936. np.zeros((1, num_res, residue_constants.atom_type_num),
  937. np.float32),
  938. 'template_all_atom_positions':
  939. np.zeros((1, num_res, residue_constants.atom_type_num, 3),
  940. np.float32),
  941. 'template_domain_names':
  942. np.array([''.encode()], dtype=np.object_),
  943. 'template_sequence':
  944. np.array([''.encode()], dtype=np.object_),
  945. 'template_sum_probs':
  946. np.array([0], dtype=np.float32),
  947. }
  948. return TemplateSearchResult(
  949. features=template_features, errors=errors, warnings=warnings)