| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from collections import OrderedDict
- from typing import Dict, Generator
- import torch
- from transformers import BertTokenizer
- from modelscope.metainfo import Models
- from modelscope.models.base import Tensor, TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.models.nlp.gpt3 import GPT3Model
- from modelscope.utils.constant import Tasks
- from modelscope.utils.hub import read_config
- from modelscope.utils.streaming_output import StreamingOutputMixin
- __all__ = ['GPT3ForTextGeneration']
@MODELS.register_module(Tasks.text_generation, module_name=Models.gpt3)
class GPT3ForTextGeneration(TorchModel, StreamingOutputMixin):
    """Text-generation wrapper around a GPT-3 backbone.

    The wrapped ``self.model`` is a ``DistributedGPT3`` when the
    configuration read from ``model_dir`` contains a ``megatron`` entry,
    and a ``GPT3Model`` loaded via ``from_pretrained`` otherwise. Most
    methods dispatch on which of the two backbones is in use.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the text generation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)

        # Temporarily compatible with DistributedGPT3 and GPT3Model,
        # the base/large model based on GPT3Model will be replaced in the
        # future, and GPT3Model will be deprecated
        if 'megatron' in read_config(model_dir):
            # Imported lazily, only when the megatron backbone is selected.
            from modelscope.models.nlp import DistributedGPT3
            self.model = DistributedGPT3(model_dir, **kwargs)
        else:
            self.model = GPT3Model.from_pretrained(model_dir)
        # NOTE(review): indentation was lost in the reviewed copy; the
        # tokenizer is assumed to be loaded for both backbones -- confirm.
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Return the result produced by the wrapped model.

        Args:
            input (Dict[str, Tensor]): the preprocessed data, unpacked as
                keyword arguments into the wrapped model.

        Returns:
            Dict[str, Tensor]: results

        Example:
            >>> {
            >>>     'logits': Tensor([[0.54, 0.32...])]), # logits
            >>> }
        """
        return self.model(**input)

    def generate(self, inputs: Dict[str, Tensor],
                 **kwargs) -> Dict[str, Tensor]:
        """Generate with the wrapped model.

        Args:
            inputs (Dict[str, Tensor]): the preprocessed data. On the
                ``GPT3Model`` path it must contain ``input_ids`` and
                ``attention_mask``.
            **kwargs: extra options forwarded unchanged to the wrapped
                model's ``generate``.

        Returns:
            Whatever the wrapped model's ``generate`` returns.
        """
        if not isinstance(self.model, GPT3Model):
            # Non-GPT3Model backbones take the whole input dict as keywords.
            return self.model.generate(**inputs, **kwargs)

        tokens = inputs['input_ids']
        lengths = self._get_length(inputs['attention_mask'])
        return self.model.generate(tokens, prompt_length=lengths, **kwargs)

    @staticmethod
    def _get_length(attention_mask: torch.Tensor) -> Tensor:
        """Return the sum of `attention_mask` over its last dim, minus one.

        The result is passed as ``prompt_length`` to the wrapped model;
        presumably it is the index of the last prompt token per sample --
        confirm against the backbone's ``generate``.
        """
        return attention_mask.sum(-1) - 1

    def save_pretrained(self, *args, **kwargs):
        """Save the model.

        Delegates to the wrapped model's ``save_pretrained`` unless it is a
        ``GPT3Model``, in which case the base-class implementation is used.
        All arguments are forwarded unchanged.
        """
        if not isinstance(self.model, GPT3Model):
            return self.model.save_pretrained(*args, **kwargs)
        return super().save_pretrained(*args, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped model's state dict (not this wrapper's).

        The arguments are forwarded to ``self.model.state_dict``.
        """
        # Forward by keyword: positional arguments to
        # ``torch.nn.Module.state_dict`` are deprecated in recent PyTorch
        # releases and emit a warning there.
        return self.model.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self,
                        state_dict: 'OrderedDict[str, Tensor]',
                        strict: bool = True):
        """Load `state_dict` directly into the wrapped model.

        Args:
            state_dict: the state to load, forwarded unchanged.
            strict (bool): forwarded to the wrapped model's
                ``load_state_dict``.

        Returns:
            Whatever the wrapped model's ``load_state_dict`` returns.
        """
        return self.model.load_state_dict(state_dict, strict=strict)

    def stream_generate(self, inputs, **kwargs) -> Generator:
        """Generate in streaming mode via the model's ``streaming_generate``.

        Args:
            inputs: the preprocessed data; must contain ``input_ids`` and
                ``attention_mask``.
            **kwargs: extra options forwarded unchanged.

        Returns:
            Generator: the object returned by
                ``self.model.streaming_generate``.
        """
        tokens = inputs['input_ids']
        lengths = self._get_length(inputs['attention_mask'])
        return self.model.streaming_generate(
            tokens, prompt_length=lengths, **kwargs)
|