text_generation.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from collections import OrderedDict
  3. from typing import Dict, Generator
  4. import torch
  5. from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
  6. from modelscope.metainfo import Models
  7. from modelscope.models.base import Tensor, TorchModel
  8. from modelscope.models.builder import MODELS
  9. from modelscope.utils.constant import Tasks
  10. from modelscope.utils.logger import get_logger
  11. from modelscope.utils.streaming_output import StreamingOutputMixin
# Module-level logger obtained from modelscope's `get_logger` helper.
logger = get_logger()
# Names exported by `from <module> import *`: only the model class below.
__all__ = ['PolyLMForTextGeneration']
  14. @MODELS.register_module(Tasks.text_generation, module_name=Models.polylm)
  15. class PolyLMForTextGeneration(TorchModel, StreamingOutputMixin):
  16. def __init__(self, model_dir: str, *args, **kwargs):
  17. """initialize the text generation model from the `model_dir` path.
  18. Args:
  19. model_dir (str): the model path.
  20. """
  21. super().__init__(model_dir, *args, **kwargs)
  22. self.tokenizer = AutoTokenizer.from_pretrained(
  23. model_dir, legacy=False, use_fast=False)
  24. self.check_trust_remote_code(
  25. info_str=
  26. f'Use trust_remote_code=True. Will invoke codes from {model_dir}. Please make sure '
  27. 'that you can trust the external codes.',
  28. model_dir=model_dir)
  29. self.model = AutoModelForCausalLM.from_pretrained(
  30. model_dir,
  31. device_map='auto',
  32. trust_remote_code=self.trust_remote_code)
  33. self.model.eval()
  34. def forward(self, input: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
  35. """return the result by the model
  36. Args:
  37. input (Dict[str, Tensor]): the preprocessed data
  38. Returns:
  39. Dict[str, Tensor]: results
  40. """
  41. res = self.generate(input, **kwargs)
  42. return res
  43. def generate(self, input: Dict[str, Tensor],
  44. **kwargs) -> Dict[str, Tensor]:
  45. device = self.model.device
  46. inputs = self.tokenizer(input, return_tensors='pt')
  47. outputs = self.model.generate(
  48. inputs.input_ids.to(device),
  49. attention_mask=inputs.attention_mask.to(device),
  50. **kwargs)
  51. pred = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
  52. return pred