backbone.py 498 B

123456789101112131415
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from transformers import GPT2Config
  3. from transformers import GPT2Model as GPT2ModelTransform
  4. from modelscope.metainfo import Models
  5. from modelscope.models.builder import BACKBONES
  6. from modelscope.utils.constant import Tasks
  @BACKBONES.register_module(group_key=Tasks.backbone, module_name=Models.gpt2)
  class GPT2Model(GPT2ModelTransform):
      """GPT-2 backbone exposed through the modelscope ``BACKBONES`` registry.

      Thin subclass of ``transformers.GPT2Model``, registered under group
      ``Tasks.backbone`` with module name ``Models.gpt2``.

      Two construction styles are supported:

      * ``GPT2Model(**kwargs)`` -- registry style: ``kwargs`` are forwarded
        to ``GPT2Config`` to build a fresh configuration.
      * ``GPT2Model(config)`` -- transformers style: a ready-made config
        object is handed straight to the parent constructor. This is the
        form transformers' own loaders use (``cls(config, ...)``).
      """

      def __init__(self, config=None, **kwargs):
          """Build the backbone from an existing config or from config kwargs.

          Args:
              config: An already-built config object, passed unchanged to the
                  parent constructor (presumably a ``GPT2Config`` -- confirm
                  at call sites). When ``None`` (the default), a
                  ``GPT2Config`` is built from ``kwargs`` exactly as before.
              **kwargs: Configuration fields forwarded to ``GPT2Config`` when
                  ``config`` is ``None``.

          Raises:
              TypeError: If both ``config`` and extra ``kwargs`` are supplied.
                  Silently ignoring the kwargs would hide caller mistakes.
          """
          if config is None:
              # Registry path: the config is described purely by keyword args.
              config = GPT2Config(**kwargs)
          elif kwargs:
              raise TypeError(
                  'GPT2Model received both a config object and config '
                  f'kwargs: {sorted(kwargs)}')
          super().__init__(config)