| 1234567891011121314151617181920212223242526272829303132333435363738 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from transformers import BloomConfig
- from transformers import BloomModel as BloomModelTransform
- from modelscope.metainfo import Models
- from modelscope.models import TorchModel
- from modelscope.models.builder import BACKBONES
- from modelscope.utils.constant import Tasks
class MsModelMixin:
    """Mixin that adds a ModelScope-style ``_instantiate`` factory to a model class."""

    @classmethod
    def _instantiate(cls, **kwargs):
        """Build a model either from a checkpoint directory or from config kwargs.

        Args:
            kwargs: Input args.
                model_dir: Optional directory holding the checkpoint. When
                    given, the model is loaded with ``from_pretrained``;
                    otherwise the remaining kwargs are used to construct a
                    ``BloomConfig`` for a freshly created model.
                device: Accepted but removed before the remaining kwargs
                    are forwarded.

        Returns:
            The model instance, with its ``model_dir`` attribute set to the
            supplied directory (``None`` when no directory was given).
        """
        checkpoint_dir = kwargs.pop('model_dir', None)
        # 'device' is stripped so it reaches neither BloomConfig nor
        # from_pretrained.
        kwargs.pop('device', None)

        if checkpoint_dir is not None:
            # Skip this mixin in the MRO and use the next class's loader.
            instance = super(MsModelMixin, cls).from_pretrained(
                pretrained_model_name_or_path=checkpoint_dir, **kwargs)
        else:
            instance = cls(BloomConfig(**kwargs))

        instance.model_dir = checkpoint_dir
        return instance
@BACKBONES.register_module(
    group_key=Tasks.backbone,
    module_name=Models.bloom,
)
class BloomModel(MsModelMixin, BloomModelTransform, TorchModel):
    """Bloom backbone registered in ``BACKBONES`` under ``Tasks.backbone``.

    Combines ``MsModelMixin``, the transformers ``BloomModel`` and
    ``TorchModel``; it defines no members of its own.
    """
|