| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518 |
- # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
- # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
- import copy
- import logging
- import os
- # from typing import *
- from typing import Dict, Iterable, List, Optional, Tuple, Union
- import json
- import ml_collections as mlc
- import numpy as np
- import torch
- from unicore.data import UnicoreDataset, data_utils
- from unicore.distributed import utils as distributed_utils
- from .data import utils
- from .data.data_ops import NumpyDict, TorchDict
- from .data.process import process_features, process_labels
- from .data.process_multimer import (add_assembly_features,
- convert_monomer_features, merge_msas,
- pair_and_merge, post_process)
# A rotation given as nested iterables; ``process_label`` reshapes it to 3x3.
Rotation = Iterable[Iterable]
# A translation vector; ``process_label`` reshapes it to shape (3,).
Translation = Iterable
# A symmetry operation: either the string 'I' (identity, returned unchanged by
# ``process_label``) or a ``(rotation, translation)`` pair.
Operation = Union[str, Tuple[Rotation, Translation]]
# ``(features, labels)`` pairs: a feature dict plus an optional list of
# per-chain label dicts. ``load`` produces the numpy form, ``process`` the
# torch form.
NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]]
TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]]
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
def make_data_config(
    config: mlc.ConfigDict,
    mode: str,
    num_res: int,
) -> Tuple[mlc.ConfigDict, List[str]]:
    """Derive a per-example data config and the feature names to keep.

    The input ``config`` is deep-copied and never modified.

    Args:
        config: Data config with a ``common`` section, a ``supervised``
            section and one section per mode (e.g. ``train``).
        mode: Name of the mode section to use.
        num_res: Number of residues; used as the mode's ``crop_size`` when
            that is unset (None).

    Returns:
        ``(cfg, feature_names)``: the adjusted config copy and the list of
        feature names selected by the ``use_templates`` / ``is_multimer`` /
        ``supervised`` flags.
    """
    data_cfg = copy.deepcopy(config)
    mode_section = data_cfg[mode]
    with data_cfg.unlocked():
        # An unset crop size falls back to the full sequence length.
        if mode_section.crop_size is None:
            mode_section.crop_size = num_res

    common_cfg = data_cfg.common
    names = common_cfg.unsupervised_features + common_cfg.recycling_features
    if common_cfg.use_templates:
        names += common_cfg.template_features
    if common_cfg.is_multimer:
        names += common_cfg.multimer_features
    if mode_section.supervised:
        names += data_cfg.supervised.supervised_features
    return data_cfg, names
def process_label(all_atom_positions: np.ndarray,
                  operation: Operation) -> np.ndarray:
    """Apply a rigid-body symmetry operation to atom coordinates.

    Args:
        all_atom_positions: Coordinate array whose last axis has size 3.
        operation: Either the string ``'I'`` (identity) or a
            ``(rotation, translation)`` pair. The rotation must hold 9 numbers
            (reshaped to 3x3) and the translation 3 numbers.

    Returns:
        ``all_atom_positions`` itself for the identity. Otherwise a new array
        ``x @ R.T + t``: every coordinate row rotated by ``R``, then shifted
        by ``t``.
    """
    is_identity = operation == 'I'
    if is_identity:
        return all_atom_positions
    raw_rotation, raw_translation = operation
    rotation = np.reshape(np.array(raw_rotation), (3, 3))
    translation = np.reshape(np.array(raw_translation), (3,))
    rotated = np.matmul(all_atom_positions, rotation.T)
    return rotated + translation
@utils.lru_cache(maxsize=8, copy=True)
def load_single_feature(
    sequence_id: str,
    monomer_feature_dir: str,
    uniprot_msa_dir: Optional[str] = None,
    is_monomer: bool = False,
) -> NumpyDict:
    """Load the feature dict for one chain, optionally adding UniProt MSA data.

    Results are memoized through ``utils.lru_cache(maxsize=8, copy=True)``.

    Args:
        sequence_id: Sequence id. Selects ``{sequence_id}.feature.pkl.gz`` and
            ``{sequence_id}.uniprot.pkl.gz``.
        monomer_feature_dir: Directory holding the per-sequence feature pickles.
        uniprot_msa_dir: Optional directory holding the UniProt MSA pickles.
            When None, only the converted monomer features are returned.
        is_monomer: If True, the UniProt MSA is merged into ``msa`` /
            ``deletion_matrix``. Otherwise it is converted and stored under
            the ``*_all_seq`` keys.

    Returns:
        A new dict of chain features.
    """
    feature_pickle = os.path.join(monomer_feature_dir,
                                  f'{sequence_id}.feature.pkl.gz')
    chain_feature = {
        **convert_monomer_features(utils.load_pickle(feature_pickle))
    }
    if uniprot_msa_dir is None:
        return chain_feature

    uniprot_pickle = os.path.join(uniprot_msa_dir,
                                  f'{sequence_id}.uniprot.pkl.gz')
    uniprot_feature = utils.load_pickle(uniprot_pickle)
    if not is_monomer:
        # Multimer path: keep the UniProt MSA as separate ``*_all_seq``
        # features (presumably consumed later by cross-chain pairing).
        uniprot_feature = utils.convert_all_seq_feature(uniprot_feature)
        for key in (
                'msa_all_seq',
                'msa_species_identifiers_all_seq',
                'deletion_matrix_all_seq',
        ):
            chain_feature[key] = uniprot_feature[key]
        return chain_feature

    # Monomer path: fold the UniProt alignment into the chain's own MSA.
    merged_msa, merged_deletion_matrix = merge_msas(
        chain_feature['msa'],
        chain_feature['deletion_matrix'],
        uniprot_feature['msa'],
        uniprot_feature['deletion_matrix'],
    )
    chain_feature['msa'] = merged_msa
    chain_feature['deletion_matrix'] = merged_deletion_matrix
    return chain_feature
def load_single_label(
    label_id: str,
    label_dir: str,
    symmetry_operation: Optional[Operation] = None,
) -> NumpyDict:
    """Load one chain's label pickle and keep only the label keys.

    Args:
        label_id: Label id; the file read is ``{label_id}.label.pkl.gz``.
        label_dir: Directory holding the label pickles.
        symmetry_operation: Optional operation passed to ``process_label`` to
            transform ``all_atom_positions`` before filtering.

    Returns:
        A dict restricted to ``aatype``, ``all_atom_positions``,
        ``all_atom_mask`` and ``resolution`` (whichever of those are present),
        in the pickle's key order.
    """
    wanted_keys = ('aatype', 'all_atom_positions', 'all_atom_mask',
                   'resolution')
    raw_label = utils.load_pickle(
        os.path.join(label_dir, f'{label_id}.label.pkl.gz'))
    if symmetry_operation is not None:
        raw_label['all_atom_positions'] = process_label(
            raw_label['all_atom_positions'], symmetry_operation)
    return {
        key: value
        for key, value in raw_label.items()
        if key in wanted_keys
    }
def load(
    sequence_ids: List[str],
    monomer_feature_dir: str,
    uniprot_msa_dir: Optional[str] = None,
    label_ids: Optional[List[str]] = None,
    label_dir: Optional[str] = None,
    symmetry_operations: Optional[List[Operation]] = None,
    is_monomer: bool = False,
) -> NumpyExample:
    """Load per-chain features (and optionally labels) and assemble one example.

    Args:
        sequence_ids: One sequence id per chain; each is loaded with
            ``load_single_feature``.
        monomer_feature_dir: Directory containing ``{sequence_id}.feature.pkl.gz``.
        uniprot_msa_dir: Optional directory with ``{sequence_id}.uniprot.pkl.gz``.
        label_ids: Optional label ids, one per entry of ``sequence_ids``.
        label_dir: Directory containing ``{label_id}.label.pkl.gz``. Required
            when ``label_ids`` is given.
        symmetry_operations: Optional per-label operation (``'I'`` or
            ``(rotation, translation)``). Defaults to identity for every label.
        is_monomer: If True, only the first chain's features are used.
            Otherwise the chains go through ``pair_and_merge``.

    Returns:
        ``(features, labels)``. ``features`` is post-processed and carries an
        ``asym_len`` array of per-chain ``seq_length``. ``labels`` is a list
        of per-chain label dicts, or None when ``label_ids`` is None.

    Raises:
        AssertionError: if label inputs are missing or their lengths disagree.
    """
    label_keys = ('aatype', 'all_atom_positions', 'all_atom_mask',
                  'resolution')
    all_chain_features = [
        load_single_feature(s, monomer_feature_dir, uniprot_msa_dir,
                            is_monomer) for s in sequence_ids
    ]
    if label_ids is not None:
        # load labels
        assert len(label_ids) == len(sequence_ids)
        assert label_dir is not None
        if symmetry_operations is None:
            symmetry_operations = ['I'] * len(label_ids)
        # zip() would silently drop labels on a length mismatch, leaving some
        # chains without label keys.
        assert len(symmetry_operations) == len(label_ids)
        all_chain_labels = [
            load_single_label(label_id, label_dir, operation)
            for label_id, operation in zip(label_ids, symmetry_operations)
        ]
        # Update labels into features to calculate spatial cropping etc.
        # (plain loop: this is a side effect, not a collection build).
        for chain_feature, chain_label in zip(all_chain_features,
                                              all_chain_labels):
            chain_feature.update(chain_label)
    all_chain_features = add_assembly_features(all_chain_features)
    # Get labels back from the features, as add_assembly_features may alter
    # the order of inputs.
    if label_ids is not None:
        all_chain_labels = [{k: f[k] for k in label_keys}
                            for f in all_chain_features]
    else:
        all_chain_labels = None
    asym_len = np.array([c['seq_length'] for c in all_chain_features],
                        dtype=np.int64)
    if is_monomer:
        all_chain_features = all_chain_features[0]
    else:
        all_chain_features = pair_and_merge(all_chain_features)
    all_chain_features = post_process(all_chain_features)
    all_chain_features['asym_len'] = asym_len
    return all_chain_features, all_chain_labels
def process(
    config: mlc.ConfigDict,
    mode: str,
    features: NumpyDict,
    labels: Optional[List[NumpyDict]] = None,
    seed: int = 0,
    batch_idx: Optional[int] = None,
    data_idx: Optional[int] = None,
    is_distillation: bool = False,
) -> TorchExample:
    """Turn raw numpy features/labels into processed torch tensors.

    ``features`` is modified in place before filtering. Scalar flags
    (``num_recycling_iters``, ``use_clamped_fape``, ``is_distillation``,
    ``crop_and_fix_size_seed``) are added. ``resolution`` is copied from the
    first label when labels are given, and ``msa_chains`` is dropped for
    distillation samples.

    Args:
        config: Data config; see ``make_data_config``.
        mode: Mode section name. ``'train'`` enables random recycling and
            random clamped-FAPE selection.
        features: Numpy feature dict for one example.
        labels: Optional list of per-chain numpy label dicts.
        seed: Base seed for ``data_utils.numpy_seed``.
        batch_idx: Required in train mode; seeds the recycling draw.
        data_idx: Seeds the crop seed draw.
        is_distillation: Whether the example is a self-distillation sample.

    Returns:
        ``(features, labels)`` after ``process_features`` /
        ``process_labels``. ``labels`` stays None when none were given.
    """
    if mode != 'train':
        # Outside training: always use the maximum recycling count and
        # clamped FAPE.
        recycle_count = config.common.max_recycling_iters
        clamp_fape = 1
    else:
        assert batch_idx is not None
        # Seeded on batch_idx rather than data_idx, presumably so that every
        # sample of one batch draws the same values.
        with data_utils.numpy_seed(seed, batch_idx, key='recycling'):
            recycle_count = np.random.randint(
                0, config.common.max_recycling_iters + 1)
            clamp_fape = (np.random.rand()
                          < config[mode].use_clamped_fape_prob)

    features.update({
        'num_recycling_iters': int(recycle_count),
        'use_clamped_fape': int(clamp_fape),
        'is_distillation': int(is_distillation),
    })
    if is_distillation:
        features.pop('msa_chains', None)

    cfg, feature_names = make_data_config(
        config, mode=mode, num_res=int(features['seq_length']))
    if labels is not None:
        features['resolution'] = labels[0]['resolution'].reshape(-1)
    with data_utils.numpy_seed(seed, data_idx, key='protein_feature'):
        # NOTE(review): 63355 may have been intended as 65535; kept as-is
        # because changing it would change every derived crop seed.
        features['crop_and_fix_size_seed'] = np.random.randint(0, 63355)

    kept = utils.filter(features, desired_keys=feature_names)
    feature_tensors = {name: torch.tensor(value) for name, value in kept.items()}
    with torch.no_grad():
        out_features = process_features(feature_tensors, cfg.common, cfg[mode])

    if labels is None:
        return out_features, None
    label_tensors = [{
        name: torch.tensor(value) for name, value in chain_label.items()
    } for chain_label in labels]
    with torch.no_grad():
        out_labels = process_labels(label_tensors)
    return out_features, out_labels
def load_and_process(
    config: mlc.ConfigDict,
    mode: str,
    seed: int = 0,
    batch_idx: Optional[int] = None,
    data_idx: Optional[int] = None,
    is_distillation: bool = False,
    **load_kwargs,
):
    """Run ``load`` followed by ``process`` for one example.

    Args:
        config, mode, seed, batch_idx, data_idx, is_distillation: Forwarded
            to ``process``.
        **load_kwargs: Forwarded to ``load``. An optional ``is_monomer`` entry
            is removed and passed explicitly; when absent it defaults to
            ``is_distillation``.

    Returns:
        The ``(features, labels)`` pair produced by ``process``.
    """
    is_monomer = load_kwargs.pop('is_monomer', is_distillation)
    loaded_features, loaded_labels = load(is_monomer=is_monomer,
                                          **load_kwargs)
    out_features, out_labels = process(
        config,
        mode,
        loaded_features,
        loaded_labels,
        seed=seed,
        batch_idx=batch_idx,
        data_idx=data_idx,
        is_distillation=is_distillation,
    )
    return out_features, out_labels
class UnifoldDataset(UnicoreDataset):
    """Monomer dataset that samples one chain per index.

    JSON index files read from ``data_path``:
      * ``{json_prefix}{mode}_sample_weight.json``: sequence id -> weight.
      * ``{json_prefix}{mode}_multi_label.json``: sequence id -> list of
        label (chain) ids sharing that sequence.
      * ``{json_prefix}sd_train_sample_weight.json``: optional
        self-distillation weights. Used only in train mode and only when
        ``disable_sd`` is False.
    """

    def __init__(
        self,
        args,
        seed,
        config,
        data_path,
        mode='train',
        max_step=None,
        disable_sd=False,
        json_prefix='',
    ):
        """Read the JSON indices and build the samplers.

        Args:
            args: Runtime arguments. ``batch_size``, ``update_freq`` and
                ``sd_prob`` are read from it.
            seed: Base seed for per-index deterministic sampling.
            config: Model config; only ``config.data`` is kept.
            data_path: Root directory holding the JSON indices and the
                feature/label directories.
            mode: Split name (e.g. ``'train'``). Also selects the JSON files.
            max_step: If given, the dataset length is
                ``max_step * self.batch_size``. Otherwise it is the number of
                chains.
            disable_sd: Skip self-distillation data even if its index exists.
            json_prefix: Prefix prepended to every JSON index filename.
        """
        self.path = data_path

        def load_json(filename):
            # ``with`` closes the handle deterministically; a bare ``open``
            # would leave it to the garbage collector.
            with open(filename, 'r', encoding='utf-8') as handle:
                return json.load(handle)

        sample_weight = load_json(
            os.path.join(self.path,
                         json_prefix + mode + '_sample_weight.json'))
        self.multi_label = load_json(
            os.path.join(self.path, json_prefix + mode + '_multi_label.json'))
        self.inverse_multi_label = self._inverse_map(self.multi_label)
        # Each chain inherits the sampling weight of its sequence.
        self.sample_weight = {
            chain: sample_weight[entity]
            for chain, entity in self.inverse_multi_label.items()
        }
        self.seq_sample_weight = sample_weight
        logger.info('load %d chains (unique %d sequences)',
                    len(self.sample_weight), len(self.seq_sample_weight))
        self.feature_path = os.path.join(self.path, 'pdb_features')
        self.label_path = os.path.join(self.path, 'pdb_labels')
        sd_sample_weight_path = os.path.join(
            self.path, json_prefix + 'sd_train_sample_weight.json')
        if (mode == 'train' and os.path.isfile(sd_sample_weight_path)
                and not disable_sd):
            self.sd_sample_weight = load_json(sd_sample_weight_path)
            logger.info('load %d self-distillation samples.',
                        len(self.sd_sample_weight))
            self.sd_feature_path = os.path.join(self.path, 'sd_features')
            self.sd_label_path = os.path.join(self.path, 'sd_labels')
        else:
            self.sd_sample_weight = None
        # Global batch size: per-device batch * data-parallel world size *
        # update_freq[0] (presumably gradient-accumulation steps).
        self.batch_size = (
            args.batch_size * distributed_utils.get_data_parallel_world_size()
            * args.update_freq[0])
        self.data_len = (
            max_step * self.batch_size
            if max_step is not None else len(self.sample_weight))
        self.mode = mode
        self.num_seq, self.seq_keys, self.seq_sample_prob = self.cal_sample_weight(
            self.seq_sample_weight)
        self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight(
            self.sample_weight)
        if self.sd_sample_weight is not None:
            (
                self.sd_num_chain,
                self.sd_chain_keys,
                self.sd_sample_prob,
            ) = self.cal_sample_weight(self.sd_sample_weight)
        self.config = config.data
        self.seed = seed
        self.sd_prob = args.sd_prob

    def cal_sample_weight(self, sample_weight):
        """Normalize a ``{key: weight}`` dict into a sampling distribution.

        Returns:
            ``(num_keys, keys, probs)``, where ``probs[i]`` is
            ``weight[keys[i]] / sum(weights)``.
        """
        prot_keys = list(sample_weight.keys())
        sum_weight = sum(sample_weight.values())
        sample_prob = [sample_weight[k] / sum_weight for k in prot_keys]
        num_prot = len(prot_keys)
        return num_prot, prot_keys, sample_prob

    def sample_chain(self, idx, sample_by_seq=False):
        """Pick the ``(sequence id, label id)`` pair for dataset index ``idx``.

        Train mode: the draw is reproducible, made under
        ``data_utils.numpy_seed(self.seed, idx, key='data_sample')``.
          * With probability ``self.sd_prob`` (only when self-distillation
            data is loaded) a distillation sample is drawn.
          * Otherwise a PDB chain is drawn by chain weight.
          * If ``sample_by_seq`` is set, a sequence is drawn by sequence
            weight instead, then one of its chains uniformly.

        Other modes: ``idx`` indexes ``self.chain_keys`` directly.

        Returns:
            ``(seq_name, label_name, is_distillation)``.
        """
        is_distillation = False
        if self.mode == 'train':
            with data_utils.numpy_seed(self.seed, idx, key='data_sample'):
                is_distillation = ((np.random.rand(1)[0] < self.sd_prob)
                                   if self.sd_sample_weight is not None else
                                   False)
                if is_distillation:
                    prot_idx = np.random.choice(
                        self.sd_num_chain, p=self.sd_sample_prob)
                    label_name = self.sd_chain_keys[prot_idx]
                    seq_name = label_name
                else:
                    if not sample_by_seq:
                        prot_idx = np.random.choice(
                            self.num_chain, p=self.sample_prob)
                        label_name = self.chain_keys[prot_idx]
                        seq_name = self.inverse_multi_label[label_name]
                    else:
                        seq_idx = np.random.choice(
                            self.num_seq, p=self.seq_sample_prob)
                        seq_name = self.seq_keys[seq_idx]
                        label_name = np.random.choice(
                            self.multi_label[seq_name])
        else:
            label_name = self.chain_keys[idx]
            seq_name = self.inverse_multi_label[label_name]
        return seq_name, label_name, is_distillation

    def __getitem__(self, idx):
        """Load and process one monomer example; returns only its features."""
        sequence_id, label_id, is_distillation = self.sample_chain(
            idx, sample_by_seq=True)
        feature_dir, label_dir = ((self.feature_path,
                                   self.label_path) if not is_distillation else
                                  (self.sd_feature_path, self.sd_label_path))
        features, _ = load_and_process(
            self.config,
            self.mode,
            self.seed,
            batch_idx=(idx // self.batch_size),
            data_idx=idx,
            is_distillation=is_distillation,
            sequence_ids=[sequence_id],
            monomer_feature_dir=feature_dir,
            uniprot_msa_dir=None,
            label_ids=[label_id],
            label_dir=label_dir,
            symmetry_operations=None,
            is_monomer=True,
        )
        return features

    def __len__(self):
        return self.data_len

    @staticmethod
    def collater(samples):
        """Collate feature dicts along dim 1 (dim 0 is recycling)."""
        # first dim is recycling. bsz is at the 2nd dim
        return data_utils.collate_dict(samples, dim=1)

    @staticmethod
    def _inverse_map(mapping: Dict[str, List[str]]):
        """Invert ``{entity: [refs]}`` into ``{ref: entity}``.

        Raises:
            AssertionError: if one ref is listed under two different entities.
        """
        inverse_mapping = {}
        for ent, refs in mapping.items():
            for ref in refs:
                if ref in inverse_mapping:  # duplicated ent for this ref.
                    ent_2 = inverse_mapping[ref]
                    assert (
                        ent == ent_2
                    ), f'multiple entities ({ent_2}, {ent}) exist for reference {ref}.'
                inverse_mapping[ref] = ent
        return inverse_mapping
class UnifoldMultimerDataset(UnifoldDataset):
    """Multimer dataset: each index loads all chains of the sampled PDB entry."""

    def __init__(
        self,
        args: mlc.ConfigDict,
        seed: int,
        config: mlc.ConfigDict,
        data_path: str,
        mode: str = 'train',
        max_step: Optional[int] = None,
        disable_sd: bool = False,
        json_prefix: str = '',
        **kwargs,
    ):
        """Extend ``UnifoldDataset`` with assembly info and chain grouping.

        Additionally reads ``{json_prefix}pdb_assembly.json`` and
        ``args.max_chains``. In train mode, PDB entries are filtered with
        ``filter_pdb_by_max_chains`` and the chain sampler is rebuilt.
        ``**kwargs`` is accepted and ignored.
        """
        super().__init__(args, seed, config, data_path, mode, max_step,
                         disable_sd, json_prefix)
        self.data_path = data_path
        assembly_file = os.path.join(self.data_path,
                                     json_prefix + 'pdb_assembly.json')
        # ``with`` closes the handle instead of leaking it until GC.
        with open(assembly_file, encoding='utf-8') as handle:
            self.pdb_assembly = json.load(handle)
        self.pdb_chains = self.get_chains(self.inverse_multi_label)
        self.monomer_feature_path = os.path.join(self.data_path,
                                                 'pdb_features')
        self.uniprot_msa_path = os.path.join(self.data_path, 'pdb_uniprots')
        self.label_path = os.path.join(self.data_path, 'pdb_labels')
        self.max_chains = args.max_chains
        if self.mode == 'train':
            self.pdb_chains, self.sample_weight = self.filter_pdb_by_max_chains(
                self.pdb_chains, self.pdb_assembly, self.sample_weight,
                self.max_chains)
            self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight(
                self.sample_weight)

    def __getitem__(self, idx):
        """Load and process the example for ``idx``; returns ``(features, labels)``."""
        seq_id, label_id, is_distillation = self.sample_chain(idx)
        if is_distillation:
            # Self-distillation: a single chain, no UniProt MSA directory.
            label_ids = [label_id]
            sequence_ids = [seq_id]
            monomer_feature_path, uniprot_msa_path, label_path = (
                self.sd_feature_path,
                None,
                self.sd_label_path,
            )
            symmetry_operations = None
        else:
            pdb_id = self.get_pdb_name(label_id)
            if pdb_id in self.pdb_assembly and self.mode == 'train':
                # Use the assembly record: its chains, each paired with the
                # corresponding entry of 'opers'.
                assembly = self.pdb_assembly[pdb_id]
                label_ids = [
                    pdb_id + '_' + chain_name
                    for chain_name in assembly['chains']
                ]
                symmetry_operations = list(assembly['opers'])
            else:
                label_ids = self.pdb_chains[pdb_id]
                symmetry_operations = None
            sequence_ids = [
                self.inverse_multi_label[chain_id] for chain_id in label_ids
            ]
            monomer_feature_path, uniprot_msa_path, label_path = (
                self.monomer_feature_path,
                self.uniprot_msa_path,
                self.label_path,
            )
        return load_and_process(
            self.config,
            self.mode,
            self.seed,
            batch_idx=(idx // self.batch_size),
            data_idx=idx,
            is_distillation=is_distillation,
            sequence_ids=sequence_ids,
            monomer_feature_dir=monomer_feature_path,
            uniprot_msa_dir=uniprot_msa_path,
            label_ids=label_ids,
            label_dir=label_path,
            symmetry_operations=symmetry_operations,
            is_monomer=False,
        )

    @staticmethod
    def collater(samples):
        """Collate ``(features, labels)`` samples into a batch.

        Features are collated along dim 1 (the first dim is recycling).
        Non-None labels are gathered into a list, or None if there are none.
        Returns None for an empty batch.

        Raises:
            ValueError: if the features cannot be collated. The underlying
                error is chained as ``__cause__``.
        """
        if not samples:  # tackle empty batch
            return None
        feats = [sample[0] for sample in samples]
        labs = [sample[1] for sample in samples if sample[1] is not None]
        try:
            feats = data_utils.collate_dict(feats, dim=1)
        except Exception as err:
            # Narrowed from BaseException so KeyboardInterrupt/SystemExit
            # propagate unchanged; keep the original cause chained.
            raise ValueError('cannot collate features', feats) from err
        return feats, (labs or None)

    @staticmethod
    def get_pdb_name(chain):
        """Return the PDB id: the part of ``chain`` before the first '_'."""
        return chain.split('_')[0]

    @staticmethod
    def get_chains(canon_chain_map):
        """Group the chain ids (keys of ``canon_chain_map``) by PDB id."""
        pdb_chains = {}
        for chain in canon_chain_map:
            pdb_chains.setdefault(
                UnifoldMultimerDataset.get_pdb_name(chain), []).append(chain)
        return pdb_chains

    @staticmethod
    def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight,
                                 max_chains):
        """Drop PDB entries too large to sample.

        An entry that has an assembly record is kept when that assembly lists
        at most ``max_chains`` chains. An entry without one is kept only if it
        has exactly one chain. ``sample_weight`` is restricted to chains of
        the kept entries.

        Returns:
            ``(new_pdb_chains, new_sample_weight)``.
        """
        new_pdb_chains = {}
        for pdb_id, chains in pdb_chains.items():
            if pdb_id in pdb_assembly:
                keep = len(pdb_assembly[pdb_id]['chains']) <= max_chains
            else:
                keep = len(chains) == 1
            if keep:
                new_pdb_chains[pdb_id] = chains
        new_sample_weight = {
            k: weight
            for k, weight in sample_weight.items()
            if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains
        }
        logger.info(
            'filtered out %d / %d PDBs (%d / %d chains) by max_chains %s',
            len(pdb_chains) - len(new_pdb_chains), len(pdb_chains),
            len(sample_weight) - len(new_sample_weight), len(sample_weight),
            max_chains)
        return new_pdb_chains, new_sample_weight
|