process.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
  2. # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. from typing import Optional
  4. import numpy as np
  5. import torch
  6. from modelscope.models.science.unifold.data import data_ops
  7. def nonensembled_fns(common_cfg, mode_cfg):
  8. """Input pipeline data transformers that are not ensembled."""
  9. v2_feature = common_cfg.v2_feature
  10. operators = []
  11. if mode_cfg.random_delete_msa:
  12. operators.append(
  13. data_ops.random_delete_msa(common_cfg.random_delete_msa))
  14. operators.extend([
  15. data_ops.cast_to_64bit_ints,
  16. data_ops.correct_msa_restypes,
  17. data_ops.squeeze_features,
  18. data_ops.randomly_replace_msa_with_unknown(0.0),
  19. data_ops.make_seq_mask,
  20. data_ops.make_msa_mask,
  21. ])
  22. operators.append(data_ops.make_hhblits_profile_v2
  23. if v2_feature else data_ops.make_hhblits_profile)
  24. if common_cfg.use_templates:
  25. operators.extend([
  26. data_ops.make_template_mask,
  27. data_ops.make_pseudo_beta('template_'),
  28. ])
  29. operators.append(
  30. data_ops.crop_templates(
  31. max_templates=mode_cfg.max_templates,
  32. subsample_templates=mode_cfg.subsample_templates,
  33. ))
  34. if common_cfg.use_template_torsion_angles:
  35. operators.extend([
  36. data_ops.atom37_to_torsion_angles('template_'),
  37. ])
  38. operators.append(data_ops.make_atom14_masks)
  39. operators.append(data_ops.make_target_feat)
  40. return operators
  41. def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed):
  42. operators = []
  43. if common_cfg.reduce_msa_clusters_by_max_templates:
  44. pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
  45. else:
  46. pad_msa_clusters = mode_cfg.max_msa_clusters
  47. crop_feats = dict(common_cfg.features)
  48. if mode_cfg.fixed_size:
  49. if mode_cfg.crop:
  50. if common_cfg.is_multimer:
  51. crop_fn = data_ops.crop_to_size_multimer(
  52. crop_size=mode_cfg.crop_size,
  53. shape_schema=crop_feats,
  54. seed=crop_and_fix_size_seed,
  55. spatial_crop_prob=mode_cfg.spatial_crop_prob,
  56. ca_ca_threshold=mode_cfg.ca_ca_threshold,
  57. )
  58. else:
  59. crop_fn = data_ops.crop_to_size_single(
  60. crop_size=mode_cfg.crop_size,
  61. shape_schema=crop_feats,
  62. seed=crop_and_fix_size_seed,
  63. )
  64. operators.append(crop_fn)
  65. operators.append(data_ops.select_feat(crop_feats))
  66. operators.append(
  67. data_ops.make_fixed_size(
  68. crop_feats,
  69. pad_msa_clusters,
  70. common_cfg.max_extra_msa,
  71. mode_cfg.crop_size,
  72. mode_cfg.max_templates,
  73. ))
  74. return operators
  75. def ensembled_fns(common_cfg, mode_cfg):
  76. """Input pipeline data transformers that can be ensembled and averaged."""
  77. operators = []
  78. multimer_mode = common_cfg.is_multimer
  79. v2_feature = common_cfg.v2_feature
  80. # multimer don't use block delete msa
  81. if mode_cfg.block_delete_msa and not multimer_mode:
  82. operators.append(
  83. data_ops.block_delete_msa(common_cfg.block_delete_msa))
  84. if 'max_distillation_msa_clusters' in mode_cfg:
  85. operators.append(
  86. data_ops.sample_msa_distillation(
  87. mode_cfg.max_distillation_msa_clusters))
  88. if common_cfg.reduce_msa_clusters_by_max_templates:
  89. pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
  90. else:
  91. pad_msa_clusters = mode_cfg.max_msa_clusters
  92. max_msa_clusters = pad_msa_clusters
  93. max_extra_msa = common_cfg.max_extra_msa
  94. assert common_cfg.resample_msa_in_recycling
  95. gumbel_sample = common_cfg.gumbel_sample
  96. operators.append(
  97. data_ops.sample_msa(
  98. max_msa_clusters,
  99. keep_extra=True,
  100. gumbel_sample=gumbel_sample,
  101. biased_msa_by_chain=mode_cfg.biased_msa_by_chain,
  102. ))
  103. if 'masked_msa' in common_cfg:
  104. # Masked MSA should come *before* MSA clustering so that
  105. # the clustering and full MSA profile do not leak information about
  106. # the masked locations and secret corrupted locations.
  107. operators.append(
  108. data_ops.make_masked_msa(
  109. common_cfg.masked_msa,
  110. mode_cfg.masked_msa_replace_fraction,
  111. gumbel_sample=gumbel_sample,
  112. share_mask=mode_cfg.share_mask,
  113. ))
  114. if common_cfg.msa_cluster_features:
  115. if v2_feature:
  116. operators.append(data_ops.nearest_neighbor_clusters_v2())
  117. else:
  118. operators.append(data_ops.nearest_neighbor_clusters())
  119. operators.append(data_ops.summarize_clusters)
  120. if v2_feature:
  121. operators.append(data_ops.make_msa_feat_v2)
  122. else:
  123. operators.append(data_ops.make_msa_feat)
  124. # Crop after creating the cluster profiles.
  125. if max_extra_msa:
  126. if v2_feature:
  127. operators.append(data_ops.make_extra_msa_feat(max_extra_msa))
  128. else:
  129. operators.append(data_ops.crop_extra_msa(max_extra_msa))
  130. else:
  131. operators.append(data_ops.delete_extra_msa)
  132. # operators.append(data_operators.select_feat(common_cfg.recycling_features))
  133. return operators
  134. def process_features(tensors, common_cfg, mode_cfg):
  135. """Based on the config, apply filters and transformations to the data."""
  136. is_distillation = bool(tensors.get('is_distillation', 0))
  137. multimer_mode = common_cfg.is_multimer
  138. crop_and_fix_size_seed = int(tensors['crop_and_fix_size_seed'])
  139. crop_fn = crop_and_fix_size_fns(
  140. common_cfg,
  141. mode_cfg,
  142. crop_and_fix_size_seed,
  143. )
  144. def wrap_ensemble_fn(data, i):
  145. """Function to be mapped over the ensemble dimension."""
  146. d = data.copy()
  147. fns = ensembled_fns(
  148. common_cfg,
  149. mode_cfg,
  150. )
  151. new_d = compose(fns)(d)
  152. if not multimer_mode or is_distillation:
  153. new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d)
  154. return compose(crop_fn)(new_d)
  155. else: # select after crop for spatial cropping
  156. d = compose(crop_fn)(d)
  157. d = data_ops.select_feat(common_cfg.recycling_features)(d)
  158. return d
  159. nonensembled = nonensembled_fns(common_cfg, mode_cfg)
  160. if mode_cfg.supervised and (not multimer_mode or is_distillation):
  161. nonensembled.extend(label_transform_fn())
  162. tensors = compose(nonensembled)(tensors)
  163. num_recycling = int(tensors['num_recycling_iters']) + 1
  164. num_ensembles = mode_cfg.num_ensembles
  165. ensemble_tensors = map_fn(
  166. lambda x: wrap_ensemble_fn(tensors, x),
  167. torch.arange(num_recycling * num_ensembles),
  168. )
  169. tensors = compose(crop_fn)(tensors)
  170. # add a dummy dim to align with recycling features
  171. tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors}
  172. tensors.update(ensemble_tensors)
  173. return tensors
  174. @data_ops.curry1
  175. def compose(x, fs):
  176. for f in fs:
  177. x = f(x)
  178. return x
  179. def pad_then_stack(values, ):
  180. if len(values[0].shape) >= 1:
  181. size = max(v.shape[0] for v in values)
  182. new_values = []
  183. for v in values:
  184. if v.shape[0] < size:
  185. res = values[0].new_zeros(size, *v.shape[1:])
  186. res[:v.shape[0], ...] = v
  187. else:
  188. res = v
  189. new_values.append(res)
  190. else:
  191. new_values = values
  192. return torch.stack(new_values, dim=0)
  193. def map_fn(fun, x):
  194. ensembles = [fun(elem) for elem in x]
  195. features = ensembles[0].keys()
  196. ensembled_dict = {}
  197. for feat in features:
  198. ensembled_dict[feat] = pad_then_stack(
  199. [dict_i[feat] for dict_i in ensembles])
  200. return ensembled_dict
  201. def process_single_label(label: dict,
  202. num_ensemble: Optional[int] = None) -> dict:
  203. assert 'aatype' in label
  204. assert 'all_atom_positions' in label
  205. assert 'all_atom_mask' in label
  206. label = compose(label_transform_fn())(label)
  207. if num_ensemble is not None:
  208. label = {
  209. k: torch.stack([v for _ in range(num_ensemble)])
  210. for k, v in label.items()
  211. }
  212. return label
  213. def process_labels(labels_list, num_ensemble: Optional[int] = None):
  214. return [process_single_label(ll, num_ensemble) for ll in labels_list]
  215. def label_transform_fn():
  216. return [
  217. data_ops.make_atom14_masks,
  218. data_ops.make_atom14_positions,
  219. data_ops.atom37_to_frames,
  220. data_ops.atom37_to_torsion_angles(''),
  221. data_ops.make_pseudo_beta(''),
  222. data_ops.get_backbone_frames,
  223. data_ops.get_chi_angles,
  224. ]