text_generation.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from collections import OrderedDict
  3. from typing import Dict, Generator
  4. import torch
  5. from transformers import BertTokenizer
  6. from modelscope.metainfo import Models
  7. from modelscope.models.base import Tensor, TorchModel
  8. from modelscope.models.builder import MODELS
  9. from modelscope.models.nlp.gpt3 import GPT3Model
  10. from modelscope.utils.constant import Tasks
  11. from modelscope.utils.hub import read_config
  12. from modelscope.utils.streaming_output import StreamingOutputMixin
  13. __all__ = ['GPT3ForTextGeneration']
  @MODELS.register_module(Tasks.text_generation, module_name=Models.gpt3)
  class GPT3ForTextGeneration(TorchModel, StreamingOutputMixin):
      """GPT-3 text-generation model wrapper.

      The actual network is held in ``self.model``: a ``DistributedGPT3`` when
      the model config contains a ``megatron`` entry, otherwise a
      ``GPT3Model``. Most methods delegate to it.
      """

      def __init__(self, model_dir: str, *args, **kwargs):
          """Initialize the text generation model from the `model_dir` path.

          Args:
              model_dir (str): the model path.
              *args: forwarded to ``TorchModel.__init__``.
              **kwargs: forwarded to ``TorchModel.__init__`` and, for the
                  megatron variant, to the ``DistributedGPT3`` constructor.
          """
          super().__init__(model_dir, *args, **kwargs)
          # Temporary compatibility with both DistributedGPT3 and GPT3Model:
          # the base/large models based on GPT3Model will be replaced in the
          # future, and GPT3Model will be deprecated.
          if 'megatron' in read_config(model_dir):
              # Local import; presumably to avoid loading the distributed
              # stack (or an import cycle) unless a megatron config is present.
              from modelscope.models.nlp import DistributedGPT3
              self.model = DistributedGPT3(model_dir, **kwargs)
          else:
              self.model = GPT3Model.from_pretrained(model_dir)
              self.tokenizer = BertTokenizer.from_pretrained(model_dir)

      def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
          """Return the result of the wrapped model.

          Args:
              input (Dict[str, Tensor]): the preprocessed data, unpacked as
                  keyword arguments into ``self.model``.

          Returns:
              Dict[str, Tensor]: results

          Example:
              >>> {
              >>>     'logits': Tensor([[0.54, 0.32, ...]]),  # logits
              >>> }
          """
          return self.model(**input)

      def generate(self, inputs: Dict[str, Tensor],
                   **kwargs) -> Dict[str, Tensor]:
          """Generate text from the preprocessed ``inputs``.

          When ``self.model`` is not a ``GPT3Model`` (i.e. the
          ``DistributedGPT3`` built in ``__init__``), ``inputs`` and ``kwargs``
          are forwarded unchanged as keyword arguments. For ``GPT3Model``,
          ``input_ids`` is passed as the token tensor and the prompt length is
          derived from ``attention_mask`` via ``_get_length``.

          Args:
              inputs (Dict[str, Tensor]): preprocessed data; the ``GPT3Model``
                  path reads the ``input_ids`` and ``attention_mask`` keys.
              **kwargs: extra generation options forwarded to the model.
          """
          if not isinstance(self.model, GPT3Model):
              return self.model.generate(**inputs, **kwargs)
          tokens = inputs['input_ids']
          lengths = self._get_length(inputs['attention_mask'])
          return self.model.generate(tokens, prompt_length=lengths, **kwargs)

      @staticmethod
      def _get_length(attention_mask: torch.Tensor) -> Tensor:
          """Sum ``attention_mask`` over its last axis and subtract one.

          NOTE(review): the ``- 1`` presumably excludes a trailing special
          token (e.g. the ``[SEP]`` appended by ``BertTokenizer``) from the
          prompt length; confirm against the preprocessor.
          """
          return attention_mask.sum(-1) - 1

      def save_pretrained(self, *args, **kwargs):
          """Save the model.

          A non-``GPT3Model`` backend uses its own ``save_pretrained``; a
          ``GPT3Model`` falls back to the inherited implementation. All
          arguments are forwarded unchanged.
          """
          if not isinstance(self.model, GPT3Model):
              return self.model.save_pretrained(*args, **kwargs)
          return super().save_pretrained(*args, **kwargs)

      def state_dict(self, destination=None, prefix='', keep_vars=False):
          """Return the wrapped model's state dict.

          Keys are relative to ``self.model`` (plus ``prefix``), i.e. they do
          not include a ``model.`` component, matching what
          ``load_state_dict`` of this class expects.
          """
          # Pass by keyword: positional arguments to ``nn.Module.state_dict``
          # are deprecated in recent PyTorch and emit a warning; keywords work
          # on both old and new versions.
          return self.model.state_dict(
              destination=destination, prefix=prefix, keep_vars=keep_vars)

      def load_state_dict(self,
                          state_dict: 'OrderedDict[str, Tensor]',
                          strict: bool = True):
          """Load ``state_dict`` into the wrapped model.

          Args:
              state_dict: parameters keyed relative to ``self.model``, as
                  produced by ``state_dict`` of this class.
              strict (bool): forwarded to ``self.model.load_state_dict``.

          Returns:
              Whatever ``self.model.load_state_dict`` returns.
          """
          return self.model.load_state_dict(state_dict, strict=strict)

      def stream_generate(self, inputs, **kwargs) -> Generator:
          """Stream generated output via ``self.model.streaming_generate``.

          Args:
              inputs: preprocessed data providing ``input_ids`` and
                  ``attention_mask``.
              **kwargs: extra generation options forwarded to the model.

          Returns:
              Generator: the object returned by ``streaming_generate``.
          """
          # NOTE(review): unlike ``generate``, there is no branch for a
          # non-``GPT3Model`` (DistributedGPT3) backend; confirm that backend
          # supports ``streaming_generate(tokens, prompt_length=...)``.
          tokens = inputs['input_ids']
          lengths = self._get_length(inputs['attention_mask'])
          return self.model.streaming_generate(
              tokens, prompt_length=lengths, **kwargs)