| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- # Copyright 2021 DeepMind Technologies Limited
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Protein data type."""
- import dataclasses
- import io
- from typing import Any, Mapping, Optional
- import numpy as np
- from Bio.PDB import PDBParser
- from modelscope.models.science.unifold.data import residue_constants
# Type aliases: model inputs map feature names to arrays; model outputs are
# (possibly nested) dictionaries.
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]

# Every chain identifier the single-character PDB chain-ID column can hold;
# chain number i is written as PDB_CHAIN_IDS[i].
PDB_CHAIN_IDS = ('ABCDEFGHIJKLMNOPQRSTUVWXYZ'
                 'abcdefghijklmnopqrstuvwxyz'
                 '0123456789')
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # 26 + 26 + 10 == 62.
@dataclasses.dataclass(frozen=True)
class Protein:
    """Immutable protein structure made of per-residue numpy arrays.

    Every field shares a leading `num_res` axis. Construction fails with
    ValueError when `chain_index` holds more distinct values than
    `PDB_MAX_CHAINS`, since such a structure could not be written as PDB.
    """

    # Cartesian atom coordinates in angstroms, [num_res, num_atom_type, 3].
    # The atom axis is ordered like residue_constants.atom_types.
    atom_positions: np.ndarray

    # Amino-acid type of each residue as an integer between 0 and 20, where
    # 20 means 'X'. [num_res]
    aatype: np.ndarray

    # 1.0 where an atom is present and 0.0 where it is not; meant for loss
    # masking. [num_res, num_atom_type]
    atom_mask: np.ndarray

    # Residue numbering as used in PDB; not necessarily contiguous or
    # 0-based. [num_res]
    residue_index: np.ndarray

    # 0-based index of the chain each residue belongs to. [num_res]
    chain_index: np.ndarray

    # B-factors (temperature factors) in square angstroms, describing each
    # atom's displacement from its mean position. [num_res, num_atom_type]
    b_factors: np.ndarray

    def __post_init__(self):
        distinct_chains = np.unique(self.chain_index).size
        if distinct_chains <= PDB_MAX_CHAINS:
            return
        raise ValueError(
            f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
            'because these cannot be written to PDB format.')
def _residue_atom_arrays(res):
    """Scatters one Bio.PDB residue's atoms into per-atom-type arrays.

    Atoms whose names are not in residue_constants.atom_types are ignored.

    Returns:
      A tuple (positions [atom_type_num, 3], mask [atom_type_num],
      b_factors [atom_type_num]) where mask is 1.0 for every atom found.
    """
    num_atom_types = residue_constants.atom_type_num
    pos = np.zeros((num_atom_types, 3))
    mask = np.zeros((num_atom_types, ))
    res_b_factors = np.zeros((num_atom_types, ))
    for atom in res:
        if atom.name not in residue_constants.atom_types:
            continue
        slot = residue_constants.atom_order[atom.name]
        pos[slot] = atom.coord
        mask[slot] = 1.0
        res_b_factors[slot] = atom.bfactor
    return pos, mask, res_b_factors


def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Takes a PDB string and constructs a Protein object.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Args:
      pdb_str: The contents of the pdb file.
      chain_id: If chain_id is specified (e.g. A), then only that chain
        is parsed. Otherwise all chains are parsed.

    Returns:
      A new `Protein` parsed from the pdb contents. Residues with no atom of a
      known type are skipped. If nothing is parsed (e.g. `chain_id` does not
      occur in the file) the arrays are empty but keep their per-atom
      trailing dimensions and integer dtypes.

    Raises:
      ValueError: If the file does not contain exactly one model, or a residue
        carries an insertion code.
    """
    pdb_fh = io.StringIO(pdb_str)
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('none', pdb_fh)
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f'Only single model PDBs are supported. Found {len(models)} models.'
        )
    model = models[0]

    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    chain_ids = []
    b_factors = []

    for chain in model:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            # res.id[1] is the residue number, res.id[2] the insertion code.
            if res.id[2] != ' ':
                raise ValueError(
                    f'PDB contains an insertion code at chain {chain.id} and residue '
                    f'index {res.id[1]}. These are not supported.')
            res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num)
            # NOTE(review): hetero residues are not filtered here; any whose
            # atom names appear in residue_constants.atom_types are kept as 'X'.
            pos, mask, res_b_factors = _residue_atom_arrays(res)
            if np.sum(mask) < 0.5:
                # If no known atom positions are reported for the residue then skip it.
                continue
            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints (in sorted order).
    unique_chain_ids = np.unique(chain_ids)
    chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids],
                           dtype=int)

    # reshape() is a no-op for non-empty input and restores the trailing
    # dimensions when nothing was parsed (np.array([]) would be shape (0,)).
    num_atom_types = residue_constants.atom_type_num
    return Protein(
        atom_positions=np.array(atom_positions).reshape(-1, num_atom_types, 3),
        atom_mask=np.array(atom_mask).reshape(-1, num_atom_types),
        aatype=np.array(aatype, dtype=int),
        residue_index=np.array(residue_index, dtype=int),
        chain_index=chain_index,
        b_factors=np.array(b_factors).reshape(-1, num_atom_types),
    )
- def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
- chain_end = 'TER'
- return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
- f'{chain_name:>1}{residue_index:>4}')
def to_pdb(prot: Protein) -> str:
    """Converts a `Protein` instance to a PDB string.

    The output is a single MODEL holding one ATOM record per atom whose mask
    is >= 0.5, a TER record after each chain, then ENDMDL and END. Every line
    is padded to 80 characters. ATOM fields follow the fixed-column PDB layout
    (coordinates in columns 31-54, occupancy 55-60, B-factor 61-66, element
    77-78); blank gaps are written as width specifiers rather than literal
    runs of spaces so they cannot be silently shortened.

    Args:
      prot: The protein to convert to PDB.

    Returns:
      PDB string. A protein with no residues yields only MODEL, ENDMDL and
      END lines.

    Raises:
      ValueError: If any aatype exceeds residue_constants.restype_num, or a
        chain index is >= PDB_MAX_CHAINS.
    """
    restypes = residue_constants.restypes + ['X']

    def res_1to3(r):
        # Letters absent from restype_1to3 are written as 'UNK'.
        return residue_constants.restype_1to3.get(restypes[r], 'UNK')

    atom_types = residue_constants.atom_types

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    chain_index = prot.chain_index.astype(np.int32)
    b_factors = prot.b_factors

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError('Invalid aatypes.')

    # Construct a mapping from chain integer indices to chain ID strings.
    chain_ids = {}
    for i in np.unique(chain_index):  # np.unique gives sorted output.
        if i >= PDB_MAX_CHAINS:
            raise ValueError(
                f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
        chain_ids[i] = PDB_CHAIN_IDS[i]

    # MODEL serial number is right-justified in columns 11-14.
    pdb_lines = [f'{"MODEL":<10}{1:>4}']

    atom_index = 1
    num_res = aatype.shape[0]
    # Add all atom sites.
    for i in range(num_res):
        # Close the previous chain if in a multichain PDB.
        if i > 0 and chain_index[i] != chain_index[i - 1]:
            pdb_lines.append(
                _chain_end(
                    atom_index,
                    res_1to3(aatype[i - 1]),
                    chain_ids[chain_index[i - 1]],
                    residue_index[i - 1],
                ))
            atom_index += 1  # Atom index increases at the TER symbol.

        res_name_3 = res_1to3(aatype[i])
        chain_name = chain_ids[chain_index[i]]
        for atom_name, pos, mask, b_factor in zip(
                atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
            if mask < 0.5:
                continue
            record_type = 'ATOM'
            # Names shorter than 4 characters start one column into the field.
            name = atom_name if len(atom_name) == 4 else f' {atom_name}'
            alt_loc = ''
            insertion_code = ''
            occupancy = 1.00
            element = atom_name[0]  # Protein supports only C, N, O, S, this works.
            charge = ''
            # PDB is a columnar format, every space matters here!
            atom_line = (
                f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                f'{res_name_3:>3} {chain_name:>1}'
                f'{residue_index[i]:>4}{insertion_code:>1}{"":3}'
                f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                f'{occupancy:>6.2f}{b_factor:>6.2f}{"":10}'
                f'{element:>2}{charge:>2}')
            pdb_lines.append(atom_line)
            atom_index += 1

    # Close the final chain (there is none when the protein is empty).
    if num_res:
        pdb_lines.append(
            _chain_end(
                atom_index,
                res_1to3(aatype[-1]),
                chain_ids[chain_index[-1]],
                residue_index[-1],
            ))
    pdb_lines.append('ENDMDL')
    pdb_lines.append('END')

    # Pad all lines to 80 characters.
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.
def ideal_atom_mask(prot: Protein) -> np.ndarray:
    """Returns the atom mask that a protein's sequence calls for.

    `Protein.atom_mask` usually reflects which atoms were actually reported
    (e.g. in a PDB file). This instead looks up, for every residue, the row of
    `residue_constants.STANDARD_ATOM_MASK` selected by its `aatype`, i.e. the
    heavy atoms that should be present for that amino acid.

    Args:
      prot: `Protein` whose fields are `numpy.ndarray` objects.

    Returns:
      The ideal atom mask, one row per entry of `prot.aatype`.
    """
    standard_mask = residue_constants.STANDARD_ATOM_MASK
    ideal = standard_mask[prot.aatype]
    return ideal
def from_prediction(features: FeatureDict,
                    result: ModelOutput,
                    b_factors: Optional[np.ndarray] = None) -> Protein:
    """Assembles a protein from a prediction.

    Args:
      features: Dictionary holding model inputs. Reads 'aatype' and
        'residue_index', plus 'asym_id' when present.
      result: Dictionary holding model outputs. Reads 'final_atom_positions'
        and 'final_atom_mask'.
      b_factors: (Optional) B-factors to use for the protein. Defaults to
        zeros shaped like result['final_atom_mask'].

    Returns:
      A protein instance.
    """
    if 'asym_id' in features:
        # Shifted down by one to form chain_index; presumably asym_id is
        # 1-based — TODO confirm against the feature pipeline.
        chain_index = features['asym_id'] - 1
    else:
        # No chain information: every residue is assigned to chain 0.
        chain_index = np.zeros_like(features['aatype'])
    if b_factors is None:
        b_factors = np.zeros_like(result['final_atom_mask'])
    return Protein(
        aatype=features['aatype'],
        atom_positions=result['final_atom_positions'],
        atom_mask=result['final_atom_mask'],
        # +1 presumably turns 0-based model indices into 1-based PDB
        # numbering — confirm.
        residue_index=features['residue_index'] + 1,
        chain_index=chain_index,
        b_factors=b_factors,
    )
def from_feature(features: FeatureDict,
                 b_factors: Optional[np.ndarray] = None) -> Protein:
    """Builds a `Protein` from the atom positions and mask stored in features.

    Args:
      features: Dictionary holding model inputs. Reads 'aatype',
        'all_atom_positions', 'all_atom_mask', 'residue_index', and
        'asym_id' when present.
      b_factors: (Optional) B-factors to use for the protein. Defaults to
        zeros shaped like features['all_atom_mask'].

    Returns:
      A protein instance.
    """
    # Without 'asym_id' every residue goes to chain 0; otherwise asym_id is
    # shifted down by one (presumably 1-based — confirm).
    chain_index = (features['asym_id'] - 1
                   if 'asym_id' in features
                   else np.zeros_like(features['aatype']))
    mask = features['all_atom_mask']
    return Protein(
        aatype=features['aatype'],
        atom_positions=features['all_atom_positions'],
        atom_mask=mask,
        residue_index=features['residue_index'] + 1,
        chain_index=chain_index,
        b_factors=np.zeros_like(mask) if b_factors is None else b_factors,
    )
|