protein_structure_pipeline.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import time
  4. from typing import Any, Dict, List, Optional, Union
  5. import json
  6. import numpy as np
  7. import torch
  8. from unicore.utils import tensor_tree_map
  9. from modelscope.metainfo import Pipelines
  10. from modelscope.models.base import Model
  11. from modelscope.models.science.unifold.config import model_config
  12. from modelscope.models.science.unifold.data import protein, residue_constants
  13. from modelscope.models.science.unifold.dataset import (UnifoldDataset,
  14. load_and_process)
  15. from modelscope.outputs import OutputKeys
  16. from modelscope.pipelines.base import Pipeline, Tensor
  17. from modelscope.pipelines.builder import PIPELINES
  18. from modelscope.preprocessors import Preprocessor, build_preprocessor
  19. from modelscope.utils.constant import Fields, Frameworks, Tasks
  20. from modelscope.utils.device import device_placement
  21. from modelscope.utils.hub import read_config
  22. from modelscope.utils.logger import get_logger
# Module-level logger from modelscope's logging helper, used by the pipeline.
logger = get_logger()
# Only the pipeline class is part of this module's public API.
__all__ = ['ProteinStructurePipeline']
  25. def automatic_chunk_size(seq_len):
  26. if seq_len < 512:
  27. chunk_size = 256
  28. elif seq_len < 1024:
  29. chunk_size = 128
  30. elif seq_len < 2048:
  31. chunk_size = 32
  32. elif seq_len < 3072:
  33. chunk_size = 16
  34. else:
  35. chunk_size = 1
  36. return chunk_size
  37. def load_feature_for_one_target(
  38. config,
  39. data_folder,
  40. seed=0,
  41. is_multimer=False,
  42. use_uniprot=False,
  43. symmetry_group=None,
  44. ):
  45. if not is_multimer:
  46. uniprot_msa_dir = None
  47. sequence_ids = ['A']
  48. if use_uniprot:
  49. uniprot_msa_dir = data_folder
  50. else:
  51. uniprot_msa_dir = data_folder
  52. sequence_ids = open(
  53. os.path.join(data_folder, 'chains.txt'),
  54. encoding='utf-8').readline().split()
  55. if symmetry_group is None:
  56. batch, _ = load_and_process(
  57. config=config.data,
  58. mode='predict',
  59. seed=seed,
  60. batch_idx=None,
  61. data_idx=0,
  62. is_distillation=False,
  63. sequence_ids=sequence_ids,
  64. monomer_feature_dir=data_folder,
  65. uniprot_msa_dir=uniprot_msa_dir,
  66. )
  67. else:
  68. # Not for unifold-symmetry
  69. # only for unifold-multimer
  70. batch, _ = load_and_process(
  71. config=config.data,
  72. mode='predict',
  73. seed=seed,
  74. batch_idx=None,
  75. data_idx=0,
  76. is_distillation=False,
  77. sequence_ids=sequence_ids,
  78. monomer_feature_dir=data_folder,
  79. uniprot_msa_dir=uniprot_msa_dir,
  80. )
  81. batch = UnifoldDataset.collater([batch])
  82. return batch
  83. @PIPELINES.register_module(
  84. Tasks.protein_structure, module_name=Pipelines.protein_structure)
  85. class ProteinStructurePipeline(Pipeline):
  86. def __init__(self,
  87. model: Union[Model, str],
  88. preprocessor: Optional[Preprocessor] = None,
  89. **kwargs):
  90. """Use `model` and `preprocessor` to create a protein structure pipeline for prediction.
  91. Args:
  92. model (str or Model): Supply either a local model dir which supported the protein structure task,
  93. or a model id from the model hub, or a torch model instance.
  94. preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
  95. the model if supplied.
  96. Examples:
  97. >>> from modelscope.pipelines import pipeline
  98. >>> pipeline_ins = pipeline(task='protein-structure',
  99. >>> model='DPTech/uni-fold-monomer')
  100. >>> protein = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC'
  101. >>> print(pipeline_ins(protein))
  102. """
  103. super().__init__(model=model, preprocessor=preprocessor, **kwargs)
  104. self.cfg = read_config(self.model.model_dir)
  105. self.config = model_config(
  106. self.cfg['pipeline']['model_name']) # alphafold config
  107. self.postprocessor = self.cfg.pop('postprocessor', None)
  108. if preprocessor is None:
  109. preprocessor_cfg = self.cfg.preprocessor
  110. self.preprocessor = build_preprocessor(preprocessor_cfg,
  111. Fields.science)
  112. self.model.eval()
  113. def _sanitize_parameters(self, **pipeline_parameters):
  114. return pipeline_parameters, pipeline_parameters, pipeline_parameters
  115. def _process_single(self, input, *args, **kwargs) -> Dict[str, Any]:
  116. preprocess_params = kwargs.get('preprocess_params', {})
  117. forward_params = kwargs.get('forward_params', {})
  118. postprocess_params = kwargs.get('postprocess_params', {})
  119. out = self.preprocess(input, **preprocess_params)
  120. with device_placement(self.framework, self.device_name):
  121. with torch.no_grad():
  122. out = self.forward(out, **forward_params)
  123. out = self.postprocess(out, **postprocess_params)
  124. return out
  125. def forward(self, inputs: Dict[str, Any],
  126. **forward_params) -> Dict[str, Any]:
  127. plddts = {}
  128. ptms = {}
  129. output_dir = os.path.join(self.preprocessor.output_dir_base,
  130. inputs['target_id'])
  131. pdbs = []
  132. for seed in range(self.cfg['pipeline']['times']):
  133. cur_seed = hash((42, seed)) % 100000
  134. batch = load_feature_for_one_target(
  135. self.config,
  136. output_dir,
  137. cur_seed,
  138. is_multimer=inputs['is_multimer'],
  139. use_uniprot=inputs['is_multimer'],
  140. symmetry_group=self.preprocessor.symmetry_group,
  141. )
  142. seq_len = batch['aatype'].shape[-1]
  143. self.model.model.globals.chunk_size = automatic_chunk_size(seq_len)
  144. with torch.no_grad():
  145. batch = {
  146. k: torch.as_tensor(v, device='cuda:0')
  147. for k, v in batch.items()
  148. }
  149. out = self.model(batch)
  150. def to_float(x):
  151. if x.dtype == torch.bfloat16 or x.dtype == torch.half:
  152. return x.float()
  153. else:
  154. return x
  155. # Toss out the recycling dimensions --- we don't need them anymore
  156. batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch)
  157. batch = tensor_tree_map(to_float, batch)
  158. out = tensor_tree_map(lambda t: t[0, ...], out[0])
  159. out = tensor_tree_map(to_float, out)
  160. batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch)
  161. out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
  162. plddt = out['plddt']
  163. mean_plddt = np.mean(plddt)
  164. plddt_b_factors = np.repeat(
  165. plddt[..., None], residue_constants.atom_type_num, axis=-1)
  166. # TODO: , may need to reorder chains, based on entity_ids
  167. cur_protein = protein.from_prediction(
  168. features=batch, result=out, b_factors=plddt_b_factors)
  169. cur_save_name = (f'{cur_seed}')
  170. plddts[cur_save_name] = str(mean_plddt)
  171. if inputs[
  172. 'is_multimer'] and self.preprocessor.symmetry_group is None:
  173. ptms[cur_save_name] = str(np.mean(out['iptm+ptm']))
  174. with open(os.path.join(output_dir, cur_save_name + '.pdb'),
  175. 'w') as f:
  176. f.write(protein.to_pdb(cur_protein))
  177. pdbs.append(protein.to_pdb(cur_protein))
  178. logger.info('plddts:' + str(plddts))
  179. model_name = self.cfg['pipeline']['model_name']
  180. score_name = f'{model_name}'
  181. plddt_fname = score_name + '_plddt.json'
  182. with open(os.path.join(output_dir, plddt_fname), 'w') as f:
  183. json.dump(plddts, f, indent=4)
  184. if ptms:
  185. logger.info('ptms' + str(ptms))
  186. ptm_fname = score_name + '_ptm.json'
  187. with open(os.path.join(output_dir, ptm_fname), 'w') as f:
  188. json.dump(ptms, f, indent=4)
  189. return pdbs
  190. def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params):
  191. return inputs