| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429 |
- # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
- # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- import os
- from typing import Optional, Union
- import addict
- import torch
- from torch import nn
- from torch.nn import functional as F
- from transformers.modeling_utils import PreTrainedModel
- from modelscope.outputs import TokenGeneratorOutput
- from modelscope.utils.constant import ModelFile
- from .configuration import GPT3Config
- from .distributed_gpt3 import sample
class GPT3SelfAttention(nn.Module):
    """Multi-head self-attention with a left-to-right (causal) mask.

    Consumes hidden states of shape [b, s, h] (batch, sequence, hidden) and
    returns a tensor of the same shape.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        heads = config.num_attention_heads
        self.hidden_size = hidden
        self.num_attention_heads = heads
        # Width of a single attention head (hn).
        self.hidden_size_per_attention_head = hidden // heads
        # One fused projection yields query, key and value side by side.
        self.query_key_value = nn.Linear(hidden, 3 * hidden)
        self.softmax = nn.Softmax(dim=-1)
        self.attention_dropout = nn.Dropout(
            config.attention_probs_dropout_prob)
        # Projection of the concatenated heads back to the model width.
        self.dense = nn.Linear(hidden, hidden)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)

    def _transpose_for_scores(self, tensor):
        """Reshape [b, s, np*hn] into [b, np, s, hn] (np heads of width hn)."""
        split_shape = (*tensor.size()[:-1], self.num_attention_heads,
                       self.hidden_size_per_attention_head)
        return tensor.view(*split_shape).permute(0, 2, 1, 3)

    def _split_tensor_along_last_dim(self,
                                     tensor,
                                     num_partitions,
                                     contiguous_split_chunks=False):
        """Split ``tensor`` into ``num_partitions`` equal chunks along its
        last dimension, optionally making every chunk contiguous."""
        chunk_size = tensor.size(-1) // num_partitions
        chunks = torch.split(tensor, chunk_size, dim=tensor.dim() - 1)
        # torch.split hands back views that are not contiguous by default.
        if not contiguous_split_chunks:
            return chunks
        return tuple(chunk.contiguous() for chunk in chunks)

    def forward(self, hidden_states, ltor_mask, is_infer=False):
        """Attend over ``hidden_states`` ([b, s, h]) and return [b, s, h].

        ``ltor_mask`` is reshaped to [1, 1, s, s] and shared across the batch.
        With ``is_infer=True`` it is replaced by a freshly built
        lower-triangular mask.
        """
        seq_len = hidden_states.size(1)
        ltor_mask = torch.reshape(ltor_mask, [1, 1, seq_len, seq_len])

        # Project once, then split into query / key / value, each [b, s, hp].
        query, key, value = self._split_tensor_along_last_dim(
            self.query_key_value(hidden_states), 3)
        # Per-head layout [b, np, s, hn].
        query = self._transpose_for_scores(query)
        key = self._transpose_for_scores(key)
        value = self._transpose_for_scores(value)
        compute_type = value.type()

        # Scaled dot-product scores, [b, np, s, s].
        scores = torch.matmul(query, key.transpose(-1, -2))
        scores = scores / math.sqrt(self.hidden_size_per_attention_head)

        if is_infer:
            # Build a lower-triangular mask instead of using the supplied one.
            key_len = key.size(2)
            ltor_mask = torch.tril(
                torch.ones((1, seq_len, key_len),
                           device=hidden_states.device)).view(
                               1, 1, seq_len, key_len).type(compute_type)
        # Assuming a 0/1 mask: kept positions are unchanged and masked
        # positions become -10000.
        penalty = 10000.0 * (1.0 - ltor_mask)
        scores = (torch.mul(scores, ltor_mask) - penalty).type(compute_type)

        # Dropout on the probabilities removes whole attended-to tokens,
        # following the original Transformer paper.
        probs = self.attention_dropout(self.softmax(scores))

        # [b, np, s, hn] -> [b, s, np, hn] -> [b, s, hp]
        context = torch.matmul(probs, value).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.hidden_size)

        return self.output_dropout(self.dense(context))
class GPT3MLP(nn.Module):
    """Position-wise feed-forward block.

    Expands the last dimension from h to 4*h, applies GELU, projects back to
    h and finally applies dropout.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner_width = 4 * width
        # h -> 4h expansion.
        self.dense_h_to_4h = nn.Linear(width, inner_width)
        self.activation_func = F.gelu
        # 4h -> h contraction.
        self.dense_4h_to_h = nn.Linear(inner_width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states):
        # [..., h] -> [..., 4h] -> [..., h]
        expanded = self.activation_func(self.dense_h_to_4h(hidden_states))
        return self.dropout(self.dense_4h_to_h(expanded))
class GPT3TransformerLayer(nn.Module):
    """One pre-LayerNorm transformer block.

    Maps hidden states of shape [b, s, h] to a tensor of the same shape by
    computing ``y = x + attention(LN1(x))`` and then ``y + mlp(LN2(y))``.
    """

    def __init__(self, config):
        super().__init__()
        eps = config.layernorm_epsilon
        # Normalizes the block input before self-attention.
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=eps)
        self.attention = GPT3SelfAttention(config)
        # Normalizes the post-attention residual stream before the MLP.
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, eps=eps)
        self.mlp = GPT3MLP(config)

    def forward(self, hidden_states, ltor_mask):
        # hidden_states: [b, s, h]; ltor_mask is handed to the attention.
        residual = hidden_states + self.attention(
            self.input_layernorm(hidden_states), ltor_mask)
        return residual + self.mlp(self.post_attention_layernorm(residual))
class GPT3Transformer(nn.Module):
    """Stack of ``config.num_hidden_layers`` transformer layers followed by
    a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # Set to None and not read anywhere in this class.
        self.input_tensor = None
        self.num_layers = config.num_hidden_layers
        self.layers = torch.nn.ModuleList(
            GPT3TransformerLayer(config) for _ in range(self.num_layers))
        # Normalization applied to the output of the last layer.
        self.final_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layernorm_epsilon)

    def _get_layer(self, layer_number):
        """Return the layer at index ``layer_number``."""
        return self.layers[layer_number]

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [b, s, h]. Layers run in order, each one consuming
        # the previous layer's output; all share the same attention_mask.
        for layer in map(self._get_layer, range(self.num_layers)):
            hidden_states = layer(hidden_states, attention_mask)
        return self.final_layernorm(hidden_states)
class GPT3TransformerLanguageModel(nn.Module):
    """Embeddings, transformer stack and a weight-tied output projection.

    Args:
        config: model configuration. This class reads ``vocab_size``,
            ``hidden_size``, ``max_position_embeddings`` and
            ``hidden_dropout_prob``, and passes the whole config on to
            ``GPT3Transformer``.
    """

    def __init__(self, config):
        super().__init__()
        # Token embeddings; their weight is also reused as the output layer.
        self.word_embeddings = nn.Embedding(config.vocab_size,
                                            config.hidden_size)
        # Learned position embeddings indexed by position_ids.
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.transformer = GPT3Transformer(config)

    def forward(self, input_ids, attention_mask, position_ids):
        """Return vocabulary logits for ``input_ids``.

        The output projection reuses ``word_embeddings.weight`` (weight
        tying), so the last dimension of the result equals ``vocab_size``.
        """
        embedded = (self.word_embeddings(input_ids)
                    + self.position_embeddings(position_ids))
        hidden = self.transformer(
            self.embedding_dropout(embedded), attention_mask)
        return F.linear(hidden, self.word_embeddings.weight)
class GPT3Model(PreTrainedModel):
    """GPT-3 causal language model with sampling-based generation helpers."""

    config_class = GPT3Config

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        Linear and Embedding weights are drawn from
        N(0, config.initializer_range); Linear biases and Embedding padding
        rows are zeroed; LayerNorm is reset to weight 1 and bias 0.
        NOTE(review): nothing in this class calls this method; presumably it
        is invoked by PreTrainedModel machinery — confirm.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def __init__(self, config):
        super().__init__(config)
        self.language_model = GPT3TransformerLanguageModel(config)

    def forward(self,
                input_ids,
                attention_mask=None,
                position_ids=None,
                labels=None,
                **kwargs):
        """Run the language model on ``input_ids`` ([b, s] token ids).

        Args:
            input_ids: token ids; dimension 1 is the sequence length.
            attention_mask: ignored. A lower-triangular causal mask is always
                built from the sequence length, so padding masks have no
                effect.
            position_ids: optional position ids; defaults to 0..s-1 for
                every row.
            labels: optional target ids with one entry per position. They
                are compared position-by-position with the logits; no shift
                is applied here.
            **kwargs: ignored.

        Returns:
            ``addict.Dict`` with ``logits`` (last dimension ``vocab_size``)
            and ``loss`` (mean cross entropy, or None without labels).
        """
        seq_length = input_ids.size(1)
        # Any caller-supplied attention_mask is overwritten with a causal one.
        attention_mask = torch.tril(
            torch.ones((1, 1, seq_length, seq_length),
                       dtype=torch.long,
                       device=input_ids.device))
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        logits = self.language_model(input_ids, attention_mask, position_ids)
        loss = None
        if labels is not None:
            # NOTE(review): labels are not shifted here; presumably callers
            # pass next-token targets already aligned with each position.
            loss_fct = nn.CrossEntropyLoss()
            # reshape (not view) also accepts non-contiguous labels.
            loss = loss_fct(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1))
        return addict.Dict(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(
            cls, pretrained_model_name_or_path: Optional[Union[str,
                                                               os.PathLike]]):
        """Build a model from a local directory holding the config and
        ``ModelFile.TORCH_MODEL_BIN_FILE``.

        Keys prefixed with ``model.language_model`` (optionally nested under
        ``'state_dict'``) are remapped to ``language_model``.
        """
        config = cls.config_class.from_pretrained(
            pretrained_model_name_or_path)
        model = cls(config)
        state_dict_file = os.path.join(pretrained_model_name_or_path,
                                       ModelFile.TORCH_MODEL_BIN_FILE)
        # map_location='cpu' lets checkpoints saved from GPU load on CPU-only
        # hosts; load_state_dict then copies the tensors into the model's
        # own parameters on whatever device they live on.
        # NOTE(review): torch.load unpickles the file - only load trusted
        # checkpoints.
        state_dict = torch.load(state_dict_file, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        state_dict = {
            k.replace('model.language_model', 'language_model'): v
            for k, v in state_dict.items()
        }
        model.load_state_dict(state_dict)
        return model

    def streaming_generate(self, tokens, temperature=1.0, **kwargs):
        """Sample continuations one token at a time, yielding after each step.

        Args:
            tokens: prompt token ids, [b, s]. Not modified in place.
            temperature: sampling temperature forwarded to ``sample``.
            **kwargs: ``top_k`` / ``top_p`` (default: config values),
                ``max_length`` (default: s + 100) and ``prompt_length`` -
                per-row prompt lengths as a [b] tensor, or a single value
                used for every row (default: s).

        Yields:
            ``TokenGeneratorOutput`` whose ``sequences`` is the working buffer
            truncated after the position just filled. It is a view, but later
            steps only write to later positions.

        Raises:
            ValueError: if the shortest prompt already reaches the maximum
                sequence length.
        """
        top_k = kwargs.pop('top_k', self.config.top_k)
        top_p = kwargs.pop('top_p', self.config.top_p)
        max_length = kwargs.pop('max_length', tokens.size(1) + 100)
        batch_size = tokens.size(0)
        lengths = kwargs.pop('prompt_length', None)
        if lengths is None:
            # One entry per row: a single-element default made the boolean
            # row indexing below fail whenever batch_size > 1.
            lengths = torch.full((batch_size, ),
                                 tokens.size(1),
                                 dtype=torch.long,
                                 device=tokens.device)
        else:
            lengths = torch.as_tensor(lengths, device=tokens.device)
            if lengths.numel() == 1:
                lengths = lengths.reshape(1).expand(batch_size)
        min_prompt_length = lengths.min().item()
        max_sequence_length = min(max_length,
                                  self.config.max_position_embeddings)
        # The shortest prompt already fills the allowed length.
        if min_prompt_length >= max_sequence_length:
            raise ValueError('context length too large')

        pad_length = max_sequence_length - tokens.size(1)
        if pad_length > 0:
            pads = torch.zeros(
                batch_size, pad_length, device=tokens.device).long()
            tokens = torch.cat((tokens, pads), dim=-1)
        else:
            # Work on a copy so the in-place writes below never touch the
            # caller's tensor.
            tokens = tokens.clone()

        # Generation for a row stops once it emits this id after its prompt.
        termination_id = self.config.eod_id
        is_generation_done = torch.zeros(
            batch_size, dtype=torch.bool, device=tokens.device)

        for context_length in range(min_prompt_length, max_sequence_length):
            # no_grad is scoped to the model work only: grad mode is
            # thread-local, and a `with` block spanning the `yield` would
            # leave autograd disabled in the caller's code while this
            # generator is suspended.
            with torch.no_grad():
                # The whole prefix is re-run every step (no key/value cache).
                logits = self(tokens[:, :context_length]).logits
                new_sample = sample(
                    logits[:, -1, :],
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    vocab_size=self.config.vocab_size)
            # Rows whose prompt ends at or before this position take the
            # sampled token; the others keep their own prompt token here.
            started = lengths <= context_length
            tokens[started, context_length] = new_sample[started]
            yield TokenGeneratorOutput(sequences=tokens[:, :(context_length
                                                             + 1)])
            is_generation_done |= (new_sample == termination_id) & started
            if torch.all(is_generation_done):
                break

    def generate(self, tokens, temperature=1.0, **kwargs):
        """Run ``streaming_generate`` to completion and return its last
        output (None only if it yields nothing)."""
        last_output = None
        for last_output in self.streaming_generate(tokens, temperature,
                                                   **kwargs):
            pass
        return last_output
|