backbone.py 518 B

12345678910111213141516
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from transformers import GPTNeoConfig
  3. from transformers import GPTNeoModel as GPTNeoModelTransform
  4. from modelscope.metainfo import Models
  5. from modelscope.models.builder import BACKBONES
  6. from modelscope.utils.constant import Tasks
  7. @BACKBONES.register_module(
  8. group_key=Tasks.backbone, module_name=Models.gpt_neo)
  9. class GPTNeoModel(GPTNeoModelTransform):
  10. def __init__(self, **kwargs):
  11. config = GPTNeoConfig(**kwargs)
  12. super().__init__(config)