| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from collections import OrderedDict
- from typing import Dict, Generator
- import torch
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
- from modelscope.metainfo import Models
- from modelscope.models.base import Tensor, TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.utils.constant import Tasks
- from modelscope.utils.logger import get_logger
- from modelscope.utils.streaming_output import StreamingOutputMixin
# Public API of this module: only the model class is exported.
__all__ = ['PolyLMForTextGeneration']

# Module-level logger from modelscope's logging utility.
logger = get_logger()
@MODELS.register_module(Tasks.text_generation, module_name=Models.polylm)
class PolyLMForTextGeneration(TorchModel, StreamingOutputMixin):
    """Text-generation wrapper around a PolyLM causal language model.

    Loads a slow (``use_fast=False``) tokenizer and an
    ``AutoModelForCausalLM`` from ``model_dir``, and exposes
    ``forward``/``generate``, which turn a prompt string into decoded text.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the text generation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, legacy=False, use_fast=False)
        # Loading the model may invoke code shipped inside model_dir (see the
        # message below), so the trust check runs before from_pretrained.
        # NOTE(review): self.trust_remote_code is read below but not set in
        # this class; presumably the base class / check_trust_remote_code
        # sets it -- confirm.
        self.check_trust_remote_code(
            info_str=
            f'Use trust_remote_code=True. Will invoke codes from {model_dir}. Please make sure '
            'that you can trust the external codes.',
            model_dir=model_dir)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            device_map='auto',
            trust_remote_code=self.trust_remote_code)
        # This wrapper only runs generation, so keep the model in eval mode.
        self.model.eval()

    def forward(self, input: str, **kwargs) -> str:
        """Return the generated text for a prompt.

        Args:
            input (str): the prompt text; it is handed to the tokenizer
                unchanged by ``generate``.
            **kwargs: forwarded unchanged to ``generate`` and from there to
                ``self.model.generate``.

        Returns:
            str: the decoded output sequence (see ``generate``).
        """
        return self.generate(input, **kwargs)

    def generate(self, input: str, **kwargs) -> str:
        """Tokenize ``input``, run the model's ``generate`` and decode the result.

        Args:
            input (str): the prompt text, passed directly to the tokenizer.
            **kwargs: generation options forwarded to ``self.model.generate``.

        Returns:
            str: the first returned sequence, decoded with special tokens
                skipped. Only ``outputs[0]`` is decoded, so additional
                sequences (if requested via kwargs) are dropped.
                NOTE(review): for decoder-only models the HF output usually
                contains the prompt tokens too, so the returned text likely
                starts with the prompt -- callers may rely on this.
        """
        device = self.model.device
        encoded = self.tokenizer(input, return_tensors='pt')
        # Only input_ids and attention_mask are forwarded (not **encoded);
        # presumably to avoid passing extra tokenizer outputs to generate --
        # confirm before changing.
        output_ids = self.model.generate(
            encoded.input_ids.to(device),
            attention_mask=encoded.attention_mask.to(device),
            **kwargs)
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
|