backbone.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from transformers import BloomConfig
  3. from transformers import BloomModel as BloomModelTransform
  4. from modelscope.metainfo import Models
  5. from modelscope.models import TorchModel
  6. from modelscope.models.builder import BACKBONES
  7. from modelscope.utils.constant import Tasks
  8. class MsModelMixin:
  9. @classmethod
  10. def _instantiate(cls, **kwargs):
  11. """Instantiate the model.
  12. Args:
  13. kwargs: Input args.
  14. model_dir: The model dir used to load the checkpoint and the label information.
  15. Returns:
  16. The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
  17. """
  18. model_dir = kwargs.pop('model_dir', None)
  19. kwargs.pop('device', None)
  20. if model_dir is None:
  21. config = BloomConfig(**kwargs)
  22. model = cls(config)
  23. else:
  24. model = super(MsModelMixin, cls).from_pretrained(
  25. pretrained_model_name_or_path=model_dir, **kwargs)
  26. model.model_dir = model_dir
  27. return model
  28. @BACKBONES.register_module(group_key=Tasks.backbone, module_name=Models.bloom)
  29. class BloomModel(MsModelMixin, BloomModelTransform, TorchModel):
  30. pass