# utils.py (4.5 KB)
# The Uni-Fold implementation is also open-sourced by the authors under the Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. import copy as copy_lib
  4. import functools
  5. import gzip
  6. import pickle
  7. from typing import Any, Dict
  8. import json
  9. import numpy as np
  10. from scipy import sparse as sp
  11. from . import residue_constants as rc
  12. from .data_ops import NumpyDict
  13. # from typing import *
  14. def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False):
  15. if deepcopy:
  16. def decorator(f):
  17. cached_func = functools.lru_cache(maxsize, typed)(f)
  18. @functools.wraps(f)
  19. def wrapper(*args, **kwargs):
  20. return copy_lib.deepcopy(cached_func(*args, **kwargs))
  21. return wrapper
  22. elif copy:
  23. def decorator(f):
  24. cached_func = functools.lru_cache(maxsize, typed)(f)
  25. @functools.wraps(f)
  26. def wrapper(*args, **kwargs):
  27. return copy_lib.copy(cached_func(*args, **kwargs))
  28. return wrapper
  29. else:
  30. decorator = functools.lru_cache(maxsize, typed)
  31. return decorator
  32. @lru_cache(maxsize=8, deepcopy=True)
  33. def load_pickle_safe(path: str) -> Dict[str, Any]:
  34. def load(path):
  35. assert path.endswith('.pkl') or path.endswith(
  36. '.pkl.gz'), f'bad suffix in {path} as pickle file.'
  37. open_fn = gzip.open if path.endswith('.gz') else open
  38. with open_fn(path, 'rb') as f:
  39. return pickle.load(f)
  40. ret = load(path)
  41. ret = uncompress_features(ret)
  42. return ret
  43. @lru_cache(maxsize=8, copy=True)
  44. def load_pickle(path: str) -> Dict[str, Any]:
  45. def load(path):
  46. assert path.endswith('.pkl') or path.endswith(
  47. '.pkl.gz'), f'bad suffix in {path} as pickle file.'
  48. open_fn = gzip.open if path.endswith('.gz') else open
  49. with open_fn(path, 'rb') as f:
  50. return pickle.load(f)
  51. ret = load(path)
  52. ret = uncompress_features(ret)
  53. return ret
  54. def correct_template_restypes(feature):
  55. """Correct template restype to have the same order as residue_constants."""
  56. feature = np.argmax(feature, axis=-1).astype(np.int32)
  57. new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
  58. feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
  59. return feature
  60. def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict:
  61. feature['msa'] = feature['msa'].astype(np.uint8)
  62. if 'num_alignments' in feature:
  63. feature.pop('num_alignments')
  64. # make_all_seq_key = lambda k: f'{k}_all_seq' if not k.endswith('_all_seq') else k
  65. def make_all_seq_key(k):
  66. if not k.endswith('_all_seq'):
  67. return f'{k}_all_seq'
  68. return k
  69. return {make_all_seq_key(k): v for k, v in feature.items()}
  70. def to_dense_matrix(spmat_dict: NumpyDict):
  71. spmat = sp.coo_matrix(
  72. (spmat_dict['data'], (spmat_dict['row'], spmat_dict['col'])),
  73. shape=spmat_dict['shape'],
  74. dtype=np.float32,
  75. )
  76. return spmat.toarray()
  77. FEATS_DTYPE = {'msa': np.int32}
  78. def uncompress_features(feats: NumpyDict) -> NumpyDict:
  79. if 'sparse_deletion_matrix_int' in feats:
  80. v = feats.pop('sparse_deletion_matrix_int')
  81. v = to_dense_matrix(v)
  82. feats['deletion_matrix'] = v
  83. return feats
  84. def filter(feature: NumpyDict, **kwargs) -> NumpyDict:
  85. assert len(kwargs) == 1, f'wrong usage of filter with kwargs: {kwargs}'
  86. if 'desired_keys' in kwargs:
  87. feature = {
  88. k: v
  89. for k, v in feature.items() if k in kwargs['desired_keys']
  90. }
  91. elif 'required_keys' in kwargs:
  92. for k in kwargs['required_keys']:
  93. assert k in feature, f'cannot find required key {k}.'
  94. elif 'ignored_keys' in kwargs:
  95. feature = {
  96. k: v
  97. for k, v in feature.items() if k not in kwargs['ignored_keys']
  98. }
  99. else:
  100. raise AssertionError(f'wrong usage of filter with kwargs: {kwargs}')
  101. return feature
  102. def compress_features(features: NumpyDict):
  103. change_dtype = {
  104. 'msa': np.uint8,
  105. }
  106. sparse_keys = ['deletion_matrix_int']
  107. compressed_features = {}
  108. for k, v in features.items():
  109. if k in change_dtype:
  110. v = v.astype(change_dtype[k])
  111. if k in sparse_keys:
  112. v = sp.coo_matrix(v, dtype=v.dtype)
  113. sp_v = {
  114. 'shape': v.shape,
  115. 'row': v.row,
  116. 'col': v.col,
  117. 'data': v.data
  118. }
  119. k = f'sparse_{k}'
  120. v = sp_v
  121. compressed_features[k] = v
  122. return compressed_features