model.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
  2. # and is publicly available at https://github.com/dptech-corp/Uni-Fold.
  3. import argparse
  4. import os
  5. from typing import Any
  6. import torch
  7. from modelscope.metainfo import Models
  8. from modelscope.models import TorchModel
  9. from modelscope.models.builder import MODELS
  10. from modelscope.utils.constant import ModelFile, Tasks
  11. from .config import model_config
  12. from .modules.alphafold import AlphaFold
  13. __all__ = ['UnifoldForProteinStructrue']
  14. @MODELS.register_module(Tasks.protein_structure, module_name=Models.unifold)
  15. class UnifoldForProteinStructrue(TorchModel):
  16. @staticmethod
  17. def add_args(parser):
  18. """Add model-specific arguments to the parser."""
  19. parser.add_argument(
  20. '--model-name',
  21. help='choose the model config',
  22. )
  23. def __init__(self, **kwargs):
  24. super().__init__()
  25. parser = argparse.ArgumentParser()
  26. parse_comm = []
  27. for key in kwargs:
  28. parser.add_argument(f'--{key}')
  29. parse_comm.append(f'--{key}')
  30. parse_comm.append(kwargs[key])
  31. args = parser.parse_args(parse_comm)
  32. base_architecture(args)
  33. self.args = args
  34. config = model_config(
  35. self.args.model_name,
  36. train=True,
  37. )
  38. self.model = AlphaFold(config)
  39. self.config = config
  40. # load model state dict
  41. param_path = os.path.join(kwargs['model_dir'],
  42. ModelFile.TORCH_MODEL_BIN_FILE)
  43. state_dict = torch.load(param_path)['ema']['params']
  44. state_dict = {
  45. '.'.join(k.split('.')[1:]): v
  46. for k, v in state_dict.items()
  47. }
  48. self.model.load_state_dict(state_dict)
  49. def half(self):
  50. self.model = self.model.half()
  51. return self
  52. def bfloat16(self):
  53. self.model = self.model.bfloat16()
  54. return self
  55. @classmethod
  56. def build_model(cls, args, task):
  57. """Build a new model instance."""
  58. return cls(args)
  59. def forward(self, batch, **kwargs):
  60. outputs = self.model.forward(batch)
  61. return outputs, self.config.loss
  62. def base_architecture(args):
  63. args.model_name = getattr(args, 'model_name', 'model_2')