backbone.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """ PyTorch LLaMA model."""
  21. from transformers.models.llama import LlamaConfig
  22. from transformers.models.llama import LlamaModel as LlamaModelHF
  23. from transformers.models.llama import \
  24. LlamaPreTrainedModel as LlamaPreTrainedModelHF
  25. from modelscope.metainfo import Models
  26. from modelscope.models import Model, TorchModel
  27. from modelscope.models.builder import MODELS
  28. from modelscope.utils.constant import Tasks
  29. from modelscope.utils.logger import get_logger
# Module-level logger obtained from ModelScope's ``get_logger``.
# NOTE(review): not used by any of the code shown in this module.
logger = get_logger()
  31. class MsModelMixin:
  32. @classmethod
  33. def _instantiate(cls, **kwargs):
  34. """Instantiate the model.
  35. Args:
  36. kwargs: Input args.
  37. model_dir: The model dir used to load the checkpoint and the label information.
  38. num_labels: An optional arg to tell the model how many classes to initialize.
  39. Method will call utils.parse_label_mapping if num_labels not supplied.
  40. If num_labels is not found, the model will use the default setting (2 classes).
  41. Returns:
  42. The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
  43. """
  44. model_dir = kwargs.pop('model_dir', None)
  45. device = kwargs.pop('device', None)
  46. if model_dir is None:
  47. config = LlamaConfig(**kwargs)
  48. model = cls(config)
  49. else:
  50. model = super(MsModelMixin, cls).from_pretrained(
  51. pretrained_model_name_or_path=model_dir, **kwargs)
  52. model.model_dir = model_dir
  53. return model if 'device_map' in kwargs \
  54. or device is None else model.to(device)
  55. class LlamaPreTrainedModel(MsModelMixin, LlamaPreTrainedModelHF, TorchModel):
  56. pass
  57. @MODELS.register_module(Tasks.backbone, module_name=Models.llama2)
  58. @MODELS.register_module(Tasks.backbone, module_name=Models.llama)
  59. class LlamaModel(MsModelMixin, LlamaModelHF, TorchModel):
  60. pass