  # dataset.py (19 KB)
  # The Uni-Fold implementation is also open-sourced by its authors under the Apache-2.0 license,
  # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. import copy
  4. import logging
  5. import os
  6. # from typing import *
  7. from typing import Dict, Iterable, List, Optional, Tuple, Union
  8. import json
  9. import ml_collections as mlc
  10. import numpy as np
  11. import torch
  12. from unicore.data import UnicoreDataset, data_utils
  13. from unicore.distributed import utils as distributed_utils
  14. from .data import utils
  15. from .data.data_ops import NumpyDict, TorchDict
  16. from .data.process import process_features, process_labels
  17. from .data.process_multimer import (add_assembly_features,
  18. convert_monomer_features, merge_msas,
  19. pair_and_merge, post_process)
  # A rotation given as nested iterables; process_label reshapes it to (3, 3).
  Rotation = Iterable[Iterable]
  # A translation; process_label reshapes it to a 3-vector.
  Translation = Iterable
  # Either the identity marker 'I' or an explicit (rotation, translation) pair.
  Operation = Union[str, Tuple[Rotation, Translation]]
  # (features, per-chain labels) as numpy dicts; labels are None when no label ids are given.
  NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]]
  # Same pair after conversion to torch tensors.
  TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]]

  logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
  def make_data_config(
      config: mlc.ConfigDict,
      mode: str,
      num_res: int,
  ) -> Tuple[mlc.ConfigDict, List[str]]:
      """Build the per-example data config and the list of features to keep.

      The incoming config is deep-copied, so the caller's object is never
      modified. When the selected mode has no crop size, it is set to
      ``num_res`` on the copy.

      Args:
          config: data config exposing ``common``, ``supervised`` and one
              section per mode.
          mode: key of the per-mode section to use (e.g. ``'train'``).
          num_res: number of residues of the current example.

      Returns:
          ``(cfg, feature_names)``: the adjusted copy of the config and the
          feature names selected by the ``use_templates``, ``is_multimer`` and
          per-mode ``supervised`` flags.
      """
      cfg = copy.deepcopy(config)
      mode_cfg = cfg[mode]
      with cfg.unlocked():
          if mode_cfg.crop_size is None:
              mode_cfg.crop_size = num_res

      common = cfg.common
      # Explicit copies: the in-place += below must never write through to a
      # sequence that is still owned by the config.
      feature_names = list(common.unsupervised_features) + list(
          common.recycling_features)
      if common.use_templates:
          feature_names += common.template_features
      if common.is_multimer:
          feature_names += common.multimer_features
      if mode_cfg.supervised:
          feature_names += cfg.supervised.supervised_features
      return cfg, feature_names
  def process_label(all_atom_positions: np.ndarray,
                    operation: Operation) -> np.ndarray:
      """Apply a symmetry operation to atom coordinates.

      Args:
          all_atom_positions: coordinates whose last axis has length 3.
          operation: the string ``'I'`` (identity), or a ``(rotation,
              translation)`` pair; the rotation must be reshapeable to
              ``(3, 3)`` and the translation to ``(3,)``.

      Returns:
          The input object itself for ``'I'``; otherwise a new array holding
          ``all_atom_positions @ rotation.T + translation``.
      """
      # Check the type first: comparing an array-valued operation with a
      # string via == is evaluated element-wise by numpy and cannot be used
      # as a plain boolean.
      if isinstance(operation, str) and operation == 'I':
          return all_atom_positions
      rot, trans = operation
      rot = np.array(rot).reshape(3, 3)
      trans = np.array(trans).reshape(3)
      return all_atom_positions @ rot.T + trans
  # NOTE(review): utils.lru_cache is a project helper; presumably it keeps up to
  # 8 results and hands out copies (copy=True) - confirm in .data.utils.
  @utils.lru_cache(maxsize=8, copy=True)
  def load_single_feature(
      sequence_id: str,
      monomer_feature_dir: str,
      uniprot_msa_dir: Optional[str] = None,
      is_monomer: bool = False,
  ) -> NumpyDict:
      """Load the features of one sequence, optionally with its UniProt MSA.

      Args:
          sequence_id: stem of the pickle files to read.
          monomer_feature_dir: directory holding ``<id>.feature.pkl.gz``.
          uniprot_msa_dir: directory holding ``<id>.uniprot.pkl.gz``; when
              ``None`` only the monomer features are returned.
          is_monomer: if true, the UniProt MSA is merged into ``msa`` /
              ``deletion_matrix``; otherwise it is attached under the
              ``*_all_seq`` keys.

      Returns:
          A new dict with the converted monomer features plus the MSA data
          described above.
      """
      monomer_feature = utils.load_pickle(
          os.path.join(monomer_feature_dir, f'{sequence_id}.feature.pkl.gz'))
      # Shallow copy so the additions below do not touch the converted dict.
      chain_feature = {**convert_monomer_features(monomer_feature)}
      if uniprot_msa_dir is None:
          return chain_feature

      all_seq_feature = utils.load_pickle(
          os.path.join(uniprot_msa_dir, f'{sequence_id}.uniprot.pkl.gz'))
      if is_monomer:
          merged_msa, merged_deletions = merge_msas(
              chain_feature['msa'],
              chain_feature['deletion_matrix'],
              all_seq_feature['msa'],
              all_seq_feature['deletion_matrix'],
          )  # noqa
          chain_feature['msa'] = merged_msa
          chain_feature['deletion_matrix'] = merged_deletions
      else:
          all_seq_feature = utils.convert_all_seq_feature(all_seq_feature)
          for key in (
                  'msa_all_seq',
                  'msa_species_identifiers_all_seq',
                  'deletion_matrix_all_seq',
          ):
              chain_feature[key] = all_seq_feature[key]
      return chain_feature
  def load_single_label(
      label_id: str,
      label_dir: str,
      symmetry_operation: Optional[Operation] = None,
  ) -> NumpyDict:
      """Load one chain label and keep only the fields used downstream.

      Args:
          label_id: stem of the pickle file to read.
          label_dir: directory holding ``<id>.label.pkl.gz``.
          symmetry_operation: when not ``None``, applied to
              ``all_atom_positions`` through ``process_label``.

      Returns:
          A dict restricted to ``aatype``, ``all_atom_positions``,
          ``all_atom_mask`` and ``resolution`` (whichever are present).
      """
      label = utils.load_pickle(
          os.path.join(label_dir, f'{label_id}.label.pkl.gz'))
      if symmetry_operation is not None:
          label['all_atom_positions'] = process_label(
              label['all_atom_positions'], symmetry_operation)
      kept_keys = ('aatype', 'all_atom_positions', 'all_atom_mask',
                   'resolution')
      return {k: v for k, v in label.items() if k in kept_keys}
  def load(
      sequence_ids: List[str],
      monomer_feature_dir: str,
      uniprot_msa_dir: Optional[str] = None,
      label_ids: Optional[List[str]] = None,
      label_dir: Optional[str] = None,
      symmetry_operations: Optional[List[Operation]] = None,
      is_monomer: bool = False,
  ) -> NumpyExample:
      """Load features (and optionally labels) for all chains of one example.

      Args:
          sequence_ids: one feature id per chain.
          monomer_feature_dir: directory of the per-sequence feature pickles.
          uniprot_msa_dir: directory of the UniProt MSA pickles, or ``None``.
          label_ids: one label id per chain, or ``None`` to load no labels.
          label_dir: directory of the label pickles; required with labels.
          symmetry_operations: one operation per label; defaults to the
              identity ``'I'`` for every chain.
          is_monomer: if true, the first chain's features are returned as-is
              instead of being paired and merged.

      Returns:
          ``(features, labels)`` where ``labels`` is a list with one dict per
          chain, or ``None`` when ``label_ids`` is ``None``.
      """
      label_keys = ('aatype', 'all_atom_positions', 'all_atom_mask',
                    'resolution')

      all_chain_features = [
          load_single_feature(seq_id, monomer_feature_dir, uniprot_msa_dir,
                              is_monomer) for seq_id in sequence_ids
      ]

      if label_ids is not None:
          # load labels
          assert len(label_ids) == len(sequence_ids), (
              'label_ids and sequence_ids must have the same length')
          assert label_dir is not None, 'label_dir is required with label_ids'
          if symmetry_operations is None:
              symmetry_operations = ['I' for _ in label_ids]
          all_chain_labels = [
              load_single_label(label_id, label_dir, oper)
              for label_id, oper in zip(label_ids, symmetry_operations)
          ]
          # update labels into features to calculate spatial cropping etc.
          for chain_feature, chain_label in zip(all_chain_features,
                                                all_chain_labels):
              chain_feature.update(chain_label)

      all_chain_features = add_assembly_features(all_chain_features)

      # get labels back from features, as add_assembly_features may alter the
      # order of inputs.
      if label_ids is not None:
          all_chain_labels = [{key: feat[key]
                               for key in label_keys}
                              for feat in all_chain_features]
      else:
          all_chain_labels = None

      asym_len = np.array([feat['seq_length'] for feat in all_chain_features],
                          dtype=np.int64)
      if is_monomer:
          all_chain_features = all_chain_features[0]
      else:
          all_chain_features = pair_and_merge(all_chain_features)
          all_chain_features = post_process(all_chain_features)
      all_chain_features['asym_len'] = asym_len
      return all_chain_features, all_chain_labels
  def process(
      config: mlc.ConfigDict,
      mode: str,
      features: NumpyDict,
      labels: Optional[List[NumpyDict]] = None,
      seed: int = 0,
      batch_idx: Optional[int] = None,
      data_idx: Optional[int] = None,
      is_distillation: bool = False,
  ) -> TorchExample:
      """Turn loaded numpy features/labels into processed torch tensors.

      ``features`` is modified in place (bookkeeping keys are added and
      ``msa_chains`` is dropped for distillation samples) before it is filtered
      and converted.

      Args:
          config: data config; ``config.common`` and ``config[mode]`` are read.
          mode: config section to use; ``'train'`` enables random recycling
              iterations and random FAPE clamping.
          features: numpy feature dict as produced by ``load``.
          labels: per-chain numpy label dicts, or ``None``.
          seed: base seed for the seeded numpy sections.
          batch_idx: seeds the per-batch draws; required in train mode.
          data_idx: seeds the per-example crop seed.
          is_distillation: marks the example as a self-distillation sample.

      Returns:
          ``(features, labels)`` after ``process_features`` /
          ``process_labels``; ``labels`` stays ``None`` when none were given.
      """
      if mode == 'train':
          assert batch_idx is not None
          # Seeded by batch so the draws depend only on (seed, batch_idx).
          with data_utils.numpy_seed(seed, batch_idx, key='recycling'):
              num_iters = np.random.randint(
                  0, config.common.max_recycling_iters + 1)
              use_clamped_fape = (
                  np.random.rand() < config[mode].use_clamped_fape_prob)
      else:
          num_iters = config.common.max_recycling_iters
          use_clamped_fape = 1

      features['num_recycling_iters'] = int(num_iters)
      features['use_clamped_fape'] = int(use_clamped_fape)
      features['is_distillation'] = int(is_distillation)
      if is_distillation:
          features.pop('msa_chains', None)

      num_res = int(features['seq_length'])
      cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

      if labels is not None:
          features['resolution'] = labels[0]['resolution'].reshape(-1)

      # NOTE(review): the source listing lost its indentation, so the extent of
      # this seeded block is reconstructed; feature processing is kept inside it
      # so any numpy randomness there stays reproducible - confirm.
      with data_utils.numpy_seed(seed, data_idx, key='protein_feature'):
          # NOTE(review): 63355 looks like a transposition of 65535; left
          # unchanged because it only bounds a derived seed.
          features['crop_and_fix_size_seed'] = np.random.randint(0, 63355)
          features = utils.filter(features, desired_keys=feature_names)
          features = {k: torch.tensor(v) for k, v in features.items()}
          with torch.no_grad():
              features = process_features(features, cfg.common, cfg[mode])

      if labels is not None:
          labels = [{k: torch.tensor(v)
                     for k, v in chain_label.items()}
                    for chain_label in labels]
          with torch.no_grad():
              labels = process_labels(labels)

      return features, labels
  def load_and_process(
      config: mlc.ConfigDict,
      mode: str,
      seed: int = 0,
      batch_idx: Optional[int] = None,
      data_idx: Optional[int] = None,
      is_distillation: bool = False,
      **load_kwargs,
  ) -> TorchExample:
      """Run ``load`` followed by ``process`` for one example.

      Args:
          config: data config forwarded to ``process``.
          mode: config section name forwarded to ``process``.
          seed: base seed forwarded to ``process``.
          batch_idx: batch index forwarded to ``process``.
          data_idx: example index forwarded to ``process``.
          is_distillation: distillation flag; also the default for
              ``is_monomer`` when the caller does not pass one.
          **load_kwargs: forwarded to ``load``; an ``is_monomer`` entry is
              removed from the mapping and passed explicitly.

      Returns:
          The ``(features, labels)`` pair returned by ``process``.
      """
      is_monomer = load_kwargs.pop('is_monomer', is_distillation)
      features, labels = load(**load_kwargs, is_monomer=is_monomer)
      return process(config, mode, features, labels, seed, batch_idx, data_idx,
                     is_distillation)
  class UnifoldDataset(UnicoreDataset):
      """Single-chain dataset over pre-computed feature and label pickles.

      Sample weights and the entity -> chains map are read from JSON files under
      ``data_path``. In train mode each index draws a random chain (optionally
      a self-distillation sample); otherwise the index selects a chain directly.
      """

      def __init__(
          self,
          args,
          seed,
          config,
          data_path,
          mode='train',
          max_step=None,
          disable_sd=False,
          json_prefix='',
      ):
          """Read the sampling metadata and pre-compute sampling probabilities.

          Args:
              args: run arguments; ``batch_size``, ``update_freq`` and
                  ``sd_prob`` are read.
              seed: base seed for per-index sampling.
              config: model config; only ``config.data`` is kept.
              data_path: dataset root directory.
              mode: split name; also used in the JSON file names.
              max_step: when set, the dataset length is
                  ``max_step * batch_size``.
              disable_sd: if true, self-distillation data is ignored.
              json_prefix: prefix prepended to every JSON file name.
          """
          self.path = data_path

          def load_json(filename):
              # Context manager so the handle is closed right after parsing.
              with open(filename, 'r', encoding='utf-8') as handle:
                  return json.load(handle)

          sample_weight = load_json(
              os.path.join(self.path,
                           json_prefix + mode + '_sample_weight.json'))
          self.multi_label = load_json(
              os.path.join(self.path, json_prefix + mode + '_multi_label.json'))
          self.inverse_multi_label = self._inverse_map(self.multi_label)
          # Every chain inherits the weight of the entity it maps to.
          self.sample_weight = {
              chain: sample_weight[entity]
              for chain, entity in self.inverse_multi_label.items()
          }
          self.seq_sample_weight = sample_weight
          logger.info('load %s chains (unique %s sequences)',
                      len(self.sample_weight), len(self.seq_sample_weight))
          self.feature_path = os.path.join(self.path, 'pdb_features')
          self.label_path = os.path.join(self.path, 'pdb_labels')

          sd_sample_weight_path = os.path.join(
              self.path, json_prefix + 'sd_train_sample_weight.json')
          use_sd = (mode == 'train' and os.path.isfile(sd_sample_weight_path)
                    and not disable_sd)
          if use_sd:
              self.sd_sample_weight = load_json(sd_sample_weight_path)
              logger.info('load %s self-distillation samples.',
                          len(self.sd_sample_weight))
              self.sd_feature_path = os.path.join(self.path, 'sd_features')
              self.sd_label_path = os.path.join(self.path, 'sd_labels')
          else:
              self.sd_sample_weight = None

          self.batch_size = (
              args.batch_size * distributed_utils.get_data_parallel_world_size()
              * args.update_freq[0])
          self.data_len = (
              max_step * self.batch_size
              if max_step is not None else len(self.sample_weight))
          self.mode = mode

          (self.num_seq, self.seq_keys,
           self.seq_sample_prob) = self.cal_sample_weight(
               self.seq_sample_weight)
          (self.num_chain, self.chain_keys,
           self.sample_prob) = self.cal_sample_weight(self.sample_weight)
          if self.sd_sample_weight is not None:
              (
                  self.sd_num_chain,
                  self.sd_chain_keys,
                  self.sd_sample_prob,
              ) = self.cal_sample_weight(self.sd_sample_weight)
          self.config = config.data
          self.seed = seed
          self.sd_prob = args.sd_prob

      def cal_sample_weight(self, sample_weight):
          """Return ``(count, keys, probabilities)`` for a name -> weight map.

          Probabilities are the weights divided by their sum, in key order.
          """
          prot_keys = list(sample_weight.keys())
          sum_weight = sum(sample_weight.values())
          sample_prob = [sample_weight[k] / sum_weight for k in prot_keys]
          return len(prot_keys), prot_keys, sample_prob

      def sample_chain(self, idx, sample_by_seq=False):
          """Pick the ``(seq_name, label_name, is_distillation)`` for ``idx``.

          In train mode the choice is random but seeded by ``(seed, idx)``;
          with ``sample_by_seq`` a sequence is drawn first and then one of its
          labels. Outside train mode ``idx`` indexes ``chain_keys`` directly.
          """
          is_distillation = False
          if self.mode != 'train':
              label_name = self.chain_keys[idx]
              seq_name = self.inverse_multi_label[label_name]
              return seq_name, label_name, is_distillation

          with data_utils.numpy_seed(self.seed, idx, key='data_sample'):
              # The random draw is only consumed when distillation data exists,
              # so the sequence of later draws stays the same without it.
              if self.sd_sample_weight is not None:
                  is_distillation = np.random.rand(1)[0] < self.sd_prob
              if is_distillation:
                  prot_idx = np.random.choice(
                      self.sd_num_chain, p=self.sd_sample_prob)
                  label_name = self.sd_chain_keys[prot_idx]
                  seq_name = label_name
              elif not sample_by_seq:
                  prot_idx = np.random.choice(
                      self.num_chain, p=self.sample_prob)
                  label_name = self.chain_keys[prot_idx]
                  seq_name = self.inverse_multi_label[label_name]
              else:
                  seq_idx = np.random.choice(
                      self.num_seq, p=self.seq_sample_prob)
                  seq_name = self.seq_keys[seq_idx]
                  label_name = np.random.choice(self.multi_label[seq_name])
          return seq_name, label_name, is_distillation

      def __getitem__(self, idx):
          """Load and process one monomer example; only features are returned."""
          sequence_id, label_id, is_distillation = self.sample_chain(
              idx, sample_by_seq=True)
          if is_distillation:
              feature_dir, label_dir = self.sd_feature_path, self.sd_label_path
          else:
              feature_dir, label_dir = self.feature_path, self.label_path
          features, _ = load_and_process(
              self.config,
              self.mode,
              self.seed,
              batch_idx=(idx // self.batch_size),
              data_idx=idx,
              is_distillation=is_distillation,
              sequence_ids=[sequence_id],
              monomer_feature_dir=feature_dir,
              uniprot_msa_dir=None,
              label_ids=[label_id],
              label_dir=label_dir,
              symmetry_operations=None,
              is_monomer=True,
          )
          return features

      def __len__(self):
          """Return the configured dataset length."""
          return self.data_len

      @staticmethod
      def collater(samples):
          """Collate feature dicts along dimension 1."""
          # first dim is recycling. bsz is at the 2nd dim
          return data_utils.collate_dict(samples, dim=1)

      @staticmethod
      def _inverse_map(mapping: Dict[str, List[str]]):
          """Invert an entity -> references map into reference -> entity.

          Raises:
              AssertionError: if a reference is listed under two different
                  entities.
          """
          inverse_mapping = {}
          for ent, refs in mapping.items():
              for ref in refs:
                  if ref in inverse_mapping:  # duplicated ent for this ref.
                      ent_2 = inverse_mapping[ref]
                      assert (
                          ent == ent_2
                      ), f'multiple entities ({ent_2}, {ent}) exist for reference {ref}.'
                  inverse_mapping[ref] = ent
          return inverse_mapping
  class UnifoldMultimerDataset(UnifoldDataset):
      """Multi-chain variant of UnifoldDataset that loads whole assemblies.

      A chain is sampled as in the parent class; unless it is a
      self-distillation sample, all chains of its PDB entry are then loaded
      together, using the assembly definition from ``pdb_assembly.json`` in
      train mode when one exists.
      """

      def __init__(
          self,
          args: mlc.ConfigDict,
          seed: int,
          config: mlc.ConfigDict,
          data_path: str,
          mode: str = 'train',
          max_step: Optional[int] = None,
          disable_sd: bool = False,
          json_prefix: str = '',
          **kwargs,
      ):
          """Initialise the parent dataset and read the assembly metadata.

          Arguments are as for ``UnifoldDataset``; ``args.max_chains``
          additionally bounds the assembly size in train mode. Extra keyword
          arguments are accepted and ignored.
          """
          super().__init__(args, seed, config, data_path, mode, max_step,
                           disable_sd, json_prefix)
          self.data_path = data_path
          assembly_file = os.path.join(self.data_path,
                                       json_prefix + 'pdb_assembly.json')
          # Context manager so the handle is closed right after parsing.
          with open(assembly_file, encoding='utf-8') as handle:
              self.pdb_assembly = json.load(handle)
          self.pdb_chains = self.get_chains(self.inverse_multi_label)
          self.monomer_feature_path = os.path.join(self.data_path,
                                                   'pdb_features')
          self.uniprot_msa_path = os.path.join(self.data_path, 'pdb_uniprots')
          self.label_path = os.path.join(self.data_path, 'pdb_labels')
          self.max_chains = args.max_chains
          if self.mode == 'train':
              (self.pdb_chains,
               self.sample_weight) = self.filter_pdb_by_max_chains(
                   self.pdb_chains, self.pdb_assembly, self.sample_weight,
                   self.max_chains)
              # Sampling probabilities must follow the filtered weights.
              (self.num_chain, self.chain_keys,
               self.sample_prob) = self.cal_sample_weight(self.sample_weight)

      def __getitem__(self, idx):
          """Load and process one example; returns ``(features, labels)``."""
          seq_id, label_id, is_distillation = self.sample_chain(idx)
          if is_distillation:
              label_ids = [label_id]
              sequence_ids = [seq_id]
              monomer_feature_path, uniprot_msa_path, label_path = (
                  self.sd_feature_path,
                  None,
                  self.sd_label_path,
              )
              symmetry_operations = None
          else:
              pdb_id = self.get_pdb_name(label_id)
              if pdb_id in self.pdb_assembly and self.mode == 'train':
                  assembly = self.pdb_assembly[pdb_id]
                  label_ids = [
                      pdb_id + '_' + chain_name
                      for chain_name in assembly['chains']
                  ]
                  symmetry_operations = list(assembly['opers'])
              else:
                  label_ids = self.pdb_chains[pdb_id]
                  symmetry_operations = None
              sequence_ids = [
                  self.inverse_multi_label[chain_id] for chain_id in label_ids
              ]
              monomer_feature_path, uniprot_msa_path, label_path = (
                  self.monomer_feature_path,
                  self.uniprot_msa_path,
                  self.label_path,
              )

          return load_and_process(
              self.config,
              self.mode,
              self.seed,
              batch_idx=(idx // self.batch_size),
              data_idx=idx,
              is_distillation=is_distillation,
              sequence_ids=sequence_ids,
              monomer_feature_dir=monomer_feature_path,
              uniprot_msa_dir=uniprot_msa_path,
              label_ids=label_ids,
              label_dir=label_path,
              symmetry_operations=symmetry_operations,
              is_monomer=False,
          )

      @staticmethod
      def collater(samples):
          """Collate ``(features, labels)`` samples into one batch.

          Returns:
              ``None`` for an empty batch; otherwise ``(feats, labs)`` where
              ``feats`` is collated along dimension 1 and ``labs`` is the list
              of non-``None`` labels, or ``None`` if there are none.

          Raises:
              ValueError: if the features cannot be collated; the original
                  error is attached as the cause.
          """
          # first dim is recycling. bsz is at the 2nd dim
          if len(samples) <= 0:  # tackle empty batch
              return None
          feats = [sample[0] for sample in samples]
          labs = [sample[1] for sample in samples if sample[1] is not None]
          try:
              feats = data_utils.collate_dict(feats, dim=1)
          except Exception as err:
              # Exception, not BaseException: KeyboardInterrupt / SystemExit
              # must propagate instead of being turned into a ValueError.
              raise ValueError('cannot collate features', feats) from err
          if not labs:
              labs = None
          return feats, labs

      @staticmethod
      def get_pdb_name(chain):
          """Return the part of a chain id before the first underscore."""
          return chain.split('_')[0]

      @staticmethod
      def get_chains(canon_chain_map):
          """Group the chain ids of ``canon_chain_map`` by their PDB name."""
          pdb_chains = {}
          for chain in canon_chain_map:
              pdb = UnifoldMultimerDataset.get_pdb_name(chain)
              pdb_chains.setdefault(pdb, []).append(chain)
          return pdb_chains

      @staticmethod
      def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight,
                                   max_chains):
          """Drop PDB entries that are too large, and their chain weights.

          An entry with an assembly definition is kept when the assembly has at
          most ``max_chains`` chains; an entry without one is kept only when it
          has exactly one chain.

          Returns:
              ``(new_pdb_chains, new_sample_weight)`` restricted to the kept
              entries.
          """
          new_pdb_chains = {}
          for pdb, chains in pdb_chains.items():
              if pdb in pdb_assembly:
                  keep = len(pdb_assembly[pdb]['chains']) <= max_chains
              else:
                  keep = len(chains) == 1
              if keep:
                  new_pdb_chains[pdb] = chains
          new_sample_weight = {
              k: sample_weight[k]
              for k in sample_weight
              if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains
          }
          logger.info(
              f'filtered out {len(pdb_chains) - len(new_pdb_chains)} / {len(pdb_chains)} PDBs '
              f'({len(sample_weight) - len(new_sample_weight)} / {len(sample_weight)} chains) '
              f'by max_chains {max_chains}')
          return new_pdb_chains, new_sample_weight