text_generation.py 501 B

123456789101112131415
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from transformers import BloomForCausalLM as BloomForCausalLMTransform
  3. from modelscope.metainfo import Models
  4. from modelscope.models import MODELS
  5. from modelscope.utils.constant import Tasks
  6. from .backbone import MsModelMixin, TorchModel
  7. @MODELS.register_module(
  8. group_key=Tasks.text_generation, module_name=Models.bloom)
  9. class BloomForTextGeneration(MsModelMixin, BloomForCausalLMTransform,
  10. TorchModel):
  11. pass