# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import os from collections import OrderedDict from typing import Callable, Dict, List, Optional, Union import torch from megatron_util import get_args, mpu from megatron_util.global_vars import get_global_memory_buffer from megatron_util.model import (AttnMaskType, Float16Module, LayerNorm, bias_gelu_impl) from megatron_util.model.fused_softmax import FusedScaleMaskSoftmax from torch import nn from torch.nn import functional as F from transformers.modeling_utils import PreTrainedModel from modelscope.models import TorchModel from modelscope.models.nlp.gpt3 import GPT3Config from modelscope.outputs import TextGenerationModelOutput, TokenGeneratorOutput from modelscope.utils.megatron_utils import init_megatron_util from modelscope.utils.nlp.load_checkpoint import pre_load from modelscope.utils.streaming_output import StreamingOutputMixin class GPT3ParallelMLP(nn.Module): """MLP. MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the state back into h hidden dimension. """ def __init__(self, config, init_method, output_layer_init_method): super().__init__() # Project to 4h. self.dense_h_to_4h = mpu.ColumnParallelLinear( config.hidden_size, config.ffn_hidden_size, gather_output=False, init_method=init_method, skip_bias_add=True) self.bias_gelu_fusion = config.bias_gelu_fusion self.activation_func = F.gelu # Project back to h. self.dense_4h_to_h = mpu.RowParallelLinear( config.ffn_hidden_size, config.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) def forward(self, hidden_states): # [s, b, 4hp] intermediate_parallel, bias_parallel = self.dense_h_to_4h( hidden_states) if self.bias_gelu_fusion: intermediate_parallel = \ bias_gelu_impl(intermediate_parallel, bias_parallel) else: intermediate_parallel = \ self.activation_func(intermediate_parallel + bias_parallel) # [s, b, h] output, output_bias = self.dense_4h_to_h(intermediate_parallel) return output, output_bias class GPT3Embedding(nn.Module): """Language model embeddings. Arguments: hidden_size: hidden size vocab_size: vocabulary size max_sequence_length: maximum size of sequence. This is used for positional embedding embedding_dropout_prob: dropout probability for embeddings init_method: weight initialization method num_tokentypes: size of the token-type embeddings. 0 value will ignore this embedding """ def __init__(self, config, init_method): super().__init__() self.hidden_size = config.hidden_size self.init_method = init_method # Word embeddings (parallel). self.word_embeddings = mpu.VocabParallelEmbedding( config.vocab_size, self.hidden_size, init_method=self.init_method) # Position embedding (serial). self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.hidden_size) # Initialize the position embeddings. self.init_method(self.position_embeddings.weight) self.fp32_residual_connection = config.fp32_residual_connection self.sequence_parallel = config.sequence_parallel # Embeddings dropout self.embedding_dropout = nn.Dropout(config.hidden_dropout) def zero_parameters(self): """Zero out all parameters in embedding.""" self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True self.position_embeddings.weight.data.fill_(0) self.position_embeddings.weight.shared = True def forward(self, input_ids, position_ids): # Embeddings. words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) embeddings = words_embeddings + position_embeddings # Data format change to avoid explicit transposes : [b s h] --> [s b h]. embeddings = embeddings.transpose(0, 1).contiguous() # If the input flag for fp32 residual connection is set, convert for float. if self.fp32_residual_connection: embeddings = embeddings.float() # Dropout. if self.sequence_parallel: embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) with mpu.get_cuda_rng_tracker().fork(): embeddings = self.embedding_dropout(embeddings) else: embeddings = self.embedding_dropout(embeddings) return embeddings class NoopTransformerLayer(nn.Module): def __init__(self, layer_number): super().__init__() self.layer_number = layer_number def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None): return hidden_states.clone() def attention_mask_func(attention_scores, attention_mask): attention_scores.masked_fill_(attention_mask, -10000.0) return attention_scores class GPT3CoreAttention(nn.Module): def __init__(self, config, layer_number, attn_mask_type=AttnMaskType.padding): super().__init__() self.fp16 = config.fp16 self.bf16 = config.bf16 self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 if self.apply_query_key_layer_scaling: self.attention_softmax_in_fp32 = True self.layer_number = max(1, layer_number) self.attn_mask_type = attn_mask_type self.sequence_parallel = config.sequence_parallel projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. world_size = mpu.get_tensor_model_parallel_world_size() self.hidden_size_per_partition = mpu.divide(projection_size, world_size) self.hidden_size_per_attention_head = mpu.divide( projection_size, config.num_attention_heads) self.num_attention_heads_per_partition = mpu.divide( config.num_attention_heads, world_size) coeff = None self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) if self.apply_query_key_layer_scaling: coeff = self.layer_number self.norm_factor *= coeff self.scale_mask_softmax = FusedScaleMaskSoftmax( self.fp16, self.bf16, self.attn_mask_type, config.masked_softmax_fusion, attention_mask_func, self.attention_softmax_in_fp32, coeff) # Dropout. Note that for a single iteration, this layer will generate # different outputs on different number of parallel partitions but # on average it should not be partition dependent. self.attention_dropout = nn.Dropout(config.attention_dropout) def forward(self, query_layer, key_layer, value_layer, attention_mask): # =================================== # Raw attention scores. [b, np, s, s] # =================================== # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) # [sq, b, np, hn] -> [sq, b * np, hn] query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = get_global_memory_buffer().get_tensor( (output_size[0] * output_size[1], output_size[2], output_size[3]), query_layer.dtype, 'mpu') # Raw attention scores. [b * np, sq, sk] matmul_result = torch.baddbmm( matmul_input_buffer, query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] beta=0.0, alpha=(1.0 / self.norm_factor)) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) # =========================== # Attention probs and dropout # =========================== # attention scores and attention mask [b, np, sq, sk] attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. if not self.sequence_parallel: with mpu.get_cuda_rng_tracker().fork(): attention_probs = self.attention_dropout(attention_probs) else: attention_probs = self.attention_dropout(attention_probs) # ========================= # Context layer. [sq, b, hp] # ========================= # value_layer -> context layer. # [sk, b, np, hn] --> [b, np, sq, hn] # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) # change view [sk, b * np, hn] value_layer = value_layer.view( value_layer.size(0), output_size[0] * output_size[1], -1) # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] new_context_layer_shape = context_layer.size()[:-2] + \ (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) return context_layer class GPT3ParallelAttention(nn.Module): """Parallel self-attention layer abstract class. Self-attention layer takes input with size [s, b, h] and returns output of the same size. """ def __init__(self, config, init_method, output_layer_init_method, layer_number): super().__init__() self.layer_number = max(1, layer_number) self.params_dtype = config.params_dtype projection_size = config.kv_channels * config.num_attention_heads # Per attention head and per partition values. world_size = mpu.get_tensor_model_parallel_world_size() self.hidden_size_per_attention_head = mpu.divide( projection_size, config.num_attention_heads) self.num_attention_heads_per_partition = mpu.divide( config.num_attention_heads, world_size) # Strided linear layer. self.query_key_value = mpu.ColumnParallelLinear( config.hidden_size, 3 * projection_size, gather_output=False, init_method=init_method) self.core_attention = GPT3CoreAttention(config, self.layer_number) # Output. self.dense = mpu.RowParallelLinear( projection_size, config.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, skip_bias_add=True) def _allocate_memory(self, inference_max_sequence_len, batch_size): return torch.empty( inference_max_sequence_len, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, dtype=self.params_dtype, device=torch.cuda.current_device()) def forward(self, hidden_states, attention_mask, inference_params=None): # hidden_states: [sq, b, h] # ================================================= # Pre-allocate memory for key-values for inference. # ================================================= if inference_params: if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_len inf_max_batch_size = inference_params.max_batch_size inference_key_memory = self._allocate_memory( inf_max_seq_len, inf_max_batch_size) inference_value_memory = self._allocate_memory( inf_max_seq_len, inf_max_batch_size) inference_params.key_value_memory_dict[self.layer_number] = ( inference_key_memory, inference_value_memory) else: inference_key_memory, inference_value_memory = \ inference_params.key_value_memory_dict[self.layer_number] # ===================== # Query, Key, and Value # ===================== # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] (query_layer, key_layer, value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) # ================================== # Adjust key and value for inference # ================================== if inference_params: batch_start = inference_params.batch_size_offset batch_end = batch_start + key_layer.size(1) assert batch_end <= inference_key_memory.size(1) sequence_start = inference_params.sequence_len_offset sequence_end = sequence_start + key_layer.size(0) assert sequence_end <= inference_key_memory.size(0) # Copy key and values. inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] # ================================== # core attention computation # ================================== context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # ================= # Output. [sq, b, h] # ================= output, bias = self.dense(context_layer) return output, bias class nullcontext: def __init__(self, enter_result=None): self.enter_result = enter_result def __enter__(self): return self.enter_result def __exit__(self, *excinfo): pass def bias_dropout_add(x, bias, residual, prob, training): # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor out = F.dropout(x + bias, p=prob, training=training) out = residual + out return out def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) return _bias_dropout_add @torch.jit.script def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) class GPT3ParallelTransformerLayer(nn.Module): """A single transformer layer. Transformer layer takes input with size [s, b, h] and returns an output of the same size. """ def __init__(self, config, init_method, output_layer_init_method, layer_number): super().__init__() self.layer_number = layer_number self.apply_residual_connection_post_layernorm \ = config.apply_residual_connection_post_layernorm self.bf16 = config.bf16 self.fp32_residual_connection = config.fp32_residual_connection # Layernorm on the input data. self.input_layernorm = LayerNorm( config.hidden_size, eps=config.layernorm_epsilon, no_persist_layer_norm=config.no_persist_layer_norm, sequence_parallel=config.sequence_parallel) # Self attention. self.self_attention = GPT3ParallelAttention(config, init_method, output_layer_init_method, layer_number) self.hidden_dropout = config.hidden_dropout self.bias_dropout_fusion = config.bias_dropout_fusion # Layernorm on the attention output self.post_attention_layernorm = LayerNorm( config.hidden_size, eps=config.layernorm_epsilon, no_persist_layer_norm=config.no_persist_layer_norm, sequence_parallel=config.sequence_parallel) # MLP self.mlp = GPT3ParallelMLP(config, init_method, output_layer_init_method) # Set bias+dropout+add fusion grad_enable execution handler. TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) self.bias_dropout_add_exec_handler = \ nullcontext if use_nvfuser else torch.enable_grad def forward(self, hidden_states, attention_mask, inference_params=None): # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) # Self attention. attention_output, attention_bias = \ self.self_attention( layernorm_output, attention_mask, inference_params=inference_params) # Residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = hidden_states if self.bias_dropout_fusion: if self.training: bias_dropout_add_func = bias_dropout_add_fused_train else: bias_dropout_add_func = bias_dropout_add_fused_inference else: bias_dropout_add_func = get_bias_dropout_add(self.training) with self.bias_dropout_add_exec_handler(): layernorm_input = bias_dropout_add_func( attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout) # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. mlp_output, mlp_bias = self.mlp(layernorm_output) # Second residual connection. if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input with self.bias_dropout_add_exec_handler(): output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout) # Jit compiled function creates 'view' tensor. This tensor # potentially gets saved in the MPU checkpoint function context, # which rejects view tensors. While making a viewless tensor here # won't result in memory savings (like the data loader, or # p2p_communication), it serves to document the origin of this # 'view' tensor. output = mpu.make_viewless_tensor( inp=output, requires_grad=output.requires_grad, keep_graph=True) return output class GPT3ParallelTransformer(nn.Module): """Transformer class.""" def __init__(self, config, init_method, output_layer_init_method, post_layer_norm=True, pre_process=True, post_process=True): super().__init__() self.bf16 = config.bf16 self.fp32_residual_connection = config.fp32_residual_connection self.post_layer_norm = post_layer_norm self.pre_process = pre_process self.post_process = post_process self.input_tensor = None self.sequence_parallel = config.sequence_parallel # Number of layers. self.num_layers = config.num_hidden_layers # Transformer layers. def build_layer(layer_number): return GPT3ParallelTransformerLayer(config, init_method, output_layer_init_method, layer_number) if self.num_layers == 0: self.num_layers = 1 self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) else: self.layers = torch.nn.ModuleList( [build_layer(i + 1) for i in range(self.num_layers)]) if self.post_process and self.post_layer_norm: # Final layer norm before output. self.final_layernorm = LayerNorm( config.hidden_size, eps=config.layernorm_epsilon, no_persist_layer_norm=config.no_persist_layer_norm, sequence_parallel=config.sequence_parallel) def _get_layer(self, layer_number): return self.layers[layer_number] def forward(self, hidden_states, attention_mask, inference_params=None): # hidden_states: [s, b, h] if not self.pre_process: # See set_input_tensor() hidden_states = self.input_tensor # Viewless tensor. # - We only need to create a viewless tensor in the case of micro batch # size (mbs) == 1, since in this case, 'hidden_states.transpose()' # above creates a view tensor, and '.contiguous()' is a pass-through. # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating # the need to make it viewless. # # However, we don't explicitly check mbs == 1 here because # make_viewless_tensor() has negligible overhead when its input # is already viewless. # # - For the 'else' case above, calling make_viewless_tensor() here is # likely redundant, since p2p_communication.py (likely originator) # already creates viewless tensors. That said, make_viewless_tensor() # is called here to be future-proof and corner-case-proof. hidden_states = mpu.make_viewless_tensor( hidden_states, requires_grad=True, keep_graph=True, ) if self.sequence_parallel: rng_context = mpu.get_cuda_rng_tracker().fork() else: rng_context = nullcontext() with rng_context: # Forward pass. for index in range(self.num_layers): layer = self._get_layer(index) hidden_states = layer( hidden_states, attention_mask, inference_params=inference_params) # Final layer norm. if self.post_process and self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) return hidden_states class GPT3TransformerLanguageModel(nn.Module): """Transformer language model. Arguments: transformer_hparams: transformer hyperparameters vocab_size: vocabulary size max_sequence_length: maximum size of sequence. This is used for positional embedding embedding_dropout_prob: dropout probability for embeddings num_tokentypes: size of the token-type embeddings. 0 value will ignore this embedding """ def __init__(self, config, init_method, output_layer_init_method): super().__init__() self.hidden_size = config.hidden_size self.init_method = init_method self.encoder_hidden_state = None # Embeddings. self.embedding = GPT3Embedding(config, self.init_method) # Transformer. self.encoder = GPT3ParallelTransformer( config, self.init_method, output_layer_init_method, ) def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, inference_params=None, enc_hidden_states=None): # Encoder embedding. encoder_input = self.embedding(enc_input_ids, enc_position_ids) # Run encoder. if enc_hidden_states is None: if self.encoder is not None: encoder_output = self.encoder( encoder_input, enc_attn_mask, inference_params=inference_params) else: encoder_output = self.encoder_hidden_state else: encoder_output = enc_hidden_states.to(encoder_input.dtype) return encoder_output def init_method_normal(sigma): """Init method based on N(0, sigma).""" def init_(tensor): return nn.init.normal_(tensor, mean=0.0, std=sigma) return init_ def scaled_init_method_normal(sigma, num_layers): """Init method based on N(0, sigma/sqrt(2*num_layers).""" std = sigma / math.sqrt(2.0 * num_layers) def init_(tensor): return nn.init.normal_(tensor, mean=0.0, std=std) return init_ class GPT3Model(PreTrainedModel): config_class = GPT3Config def __init__(self, config): super().__init__(config) self.language_model = GPT3TransformerLanguageModel( config, init_method_normal(config.init_method_std), scaled_init_method_normal(config.init_method_std, config.num_hidden_layers)) def word_embeddings_weight(self): return self.language_model.embedding.word_embeddings.weight @staticmethod def build_attention_mask_and_position_ids(tokens): seq_length = tokens.size(1) attention_mask = torch.tril( torch.ones((1, 1, seq_length, seq_length), device=tokens.device)) attention_mask = (attention_mask < 0.5) position_ids = torch.arange( seq_length, dtype=torch.long, device=tokens.device) position_ids = position_ids.unsqueeze(0).expand_as(tokens) return attention_mask, position_ids def forward(self, input_ids, attention_mask=None, position_ids=None, inference_params=None, labels=None, **kwargs): if attention_mask is None and position_ids is None: attention_mask, position_ids = \ self.build_attention_mask_and_position_ids(input_ids) lm_output = self.language_model( input_ids, position_ids, attention_mask, inference_params=inference_params) logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( lm_output, self.word_embeddings_weight(), None, False, True, self.config.sequence_parallel) losses = None if labels is not None: # [b s] => [s b] labels = labels.transpose(0, 1).contiguous() losses = mpu.vocab_parallel_cross_entropy( logits_parallel.clone().float(), labels) # [s b] => [b s] losses = losses.transpose(0, 1).contiguous() # Gather if needed. logits = mpu.gather_from_tensor_model_parallel_region(logits_parallel) # [s b h] => [b s h] logits = logits.transpose(0, 1).contiguous() return logits, losses def modify_logits_for_top_k_filtering(logits, top_k): """Set the logits for none top-k values to -inf.""" filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None] logits.masked_fill_(filter_, float('-Inf')) def modify_logits_for_top_p_filtering(logits, top_p): """Set the logits for none top-p values to -inf.""" # First sort and calculate cumulative sum of probabilities. sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) # Filteration based on the cumulative sum. filter_ = cumulative_probs > top_p # This shift by 1 is weird and I cannot justify it. This existed # in the original implementation: # https://github.com/ari-holtzman/degen/blob/master/gen.py # and I guess it is needed so keeping it for now. filter_[:, 1:] = filter_[:, :-1].clone() # Make sure we at least have one token to select from. filter_[..., 0] = 0 # Fill in the filtered part filter_ = filter_.scatter(1, sorted_indices, filter_) logits.masked_fill_(filter_, float('-Inf')) def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None): """ Sample and generate a token. Note: logits has the dimension [b, v] where b is the batch size and v is the vocabulary size. If vocab_size is provided, we will make sure the sample that is generated is in [0, vocab-size). This will avoid out of vocabulary generations due to padding. """ # Check logits for consistency. assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.' # Greedy is just simple argmax. if top_k == 1: assert top_p == 0.0, 'cannot set both greedy and top-p samplings.' samples = torch.argmax(logits, dim=-1) # Top-k or top-p sampling. else: # Clone so we do not modify the inputs, logits = logits.clone() # Apply temperature in place. if temperature != 1.0: logits.div_(temperature) if top_k > 1: assert top_p == 0.0, 'cannot set both top-k and top-p samplings.' assert top_k <= logits.size(1), 'top-k is larger than logit size.' if vocab_size: assert top_k < vocab_size, 'top-k is larger than vocab size.' modify_logits_for_top_k_filtering(logits, top_k) elif top_p > 0.0: assert top_p <= 1.0, 'top-p should be in (0, 1].' modify_logits_for_top_p_filtering(logits, top_p) # After filtering, we need to recalculate the distribution. probs = logits.softmax(dim=-1) samples = torch.multinomial(probs, num_samples=1).view(-1) # If vocab size is provided, make sure the samples are in # in the range [0, vocab-size). if vocab_size: samples = torch.clamp(samples, min=0, max=(vocab_size - 1)) return samples class InferenceParams: """Inference parameters that are passed to the main model in order to efficienly calculate and store the context during inference.""" def __init__(self, max_batch_size, max_sequence_len): """Note that offsets are set to zero and we always set the flag to allocate memory. After the first call, make sure to set this flag to False.""" self.max_sequence_len = max_sequence_len self.max_batch_size = max_batch_size self.sequence_len_offset = 0 self.batch_size_offset = 0 self.key_value_memory_dict = {} def swap_key_value_dict(self, batch_idx): 'swap between batches' if len(self.key_value_memory_dict) == 0: raise ValueError('should not swap when dict in empty') for layer_number in self.key_value_memory_dict.keys(): inference_key_memory, inference_value_memory = self.key_value_memory_dict[ layer_number] assert len(batch_idx) == inference_key_memory.shape[ 1] # make sure batch size is the same new_inference_key_memory = inference_key_memory[:, batch_idx] new_inference_value_memory = inference_value_memory[:, batch_idx] self.key_value_memory_dict[layer_number] = ( new_inference_key_memory, new_inference_value_memory) def split_into_partitions(tensor, num_partitions, partition_dim, stride): per_partition_size = mpu.utils.divide( tensor.size(partition_dim), num_partitions) per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride) partitions_list = torch.split( tensor, per_partition_per_stride_size, dim=partition_dim) partitions = [] for i in range(num_partitions): partition = torch.cat( partitions_list[i::num_partitions], dim=partition_dim) partitions.append(partition) return partitions def split_state_dict(state_dict: Dict[str, torch.Tensor], model: GPT3Model, partitions: int) -> Dict[str, torch.Tensor]: if partitions == 1: return state_dict rank: int = mpu.get_tensor_model_parallel_rank() for name, parameters in model.named_parameters(): if parameters.shape == state_dict[name].shape: continue dim = max(parameters.partition_dim, 0) stride = parameters.partition_stride state_dict[name] = split_into_partitions(state_dict[name], partitions, dim, stride)[rank] return state_dict class DistributedGPT3(TorchModel, StreamingOutputMixin): def __init__(self, model_dir, rank, path_load_tag='model', *args, megatron_cfg=None, **kwargs): super().__init__(model_dir, *args, **kwargs) init_megatron_util(megatron_cfg, model_dir, rank=rank) self.config = GPT3Config.from_pretrained(model_dir) # Build model. model = GPT3Model(self.config) for param in model.parameters(): mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param) # GPU allocation. model.cuda(torch.cuda.current_device()) # Fp16 conversion. if self.config.fp16 or self.config.bf16: model = Float16Module(model, self.config) self.dist_model = model tensor_ws = mpu.get_tensor_model_parallel_world_size() ckpt_ws = get_args().get('checkpoint_tensor_model_parallel_size', None) ckpt_ws = tensor_ws if ckpt_ws is None else ckpt_ws ckpt_rank = mpu.get_tensor_model_parallel_rank() * ckpt_ws // tensor_ws load_model = pre_load(ckpt_rank, model_dir, tag=path_load_tag) load_model = split_state_dict(load_model, model, tensor_ws // ckpt_ws) self.dist_model.load_state_dict( load_model, strict=kwargs.get('strict', True)) self.inference_params = None def train(self, mode: bool = True): if mode: self.inference_params = None return super().train(mode) def forward(self, tokens, attention_mask=None, position_ids=None, labels=None, prompts_len=None, inputs_len=None): logits, losses = self.dist_model( tokens, attention_mask, position_ids, inference_params=self.inference_params, labels=labels) loss = None if labels is None: self.inference_params.sequence_len_offset += tokens.size(1) else: loss_mask = torch.ones( labels.size(), dtype=torch.float, device=tokens.device) if inputs_len is None: for i, l in enumerate(prompts_len): loss_mask[i, l:] = 0 else: for i, l in enumerate(inputs_len): loss_mask[i, l - 1:] = 0 for i, l in enumerate(prompts_len): loss_mask[i, :l - 1] = 0 losses = losses.float() loss_mask = loss_mask.view(-1).float() mask_sum = loss_mask.sum() if mask_sum == 0: loss = torch.sum(losses.view(-1)).zero_() else: loss = torch.sum(losses.view(-1) * loss_mask) / mask_sum return TextGenerationModelOutput(logits=logits, loss=loss) def sample(self, tokens, prompts_len=None, use_eod_token_for_early_termination=True, stop_on_double_eol=False, stop_on_eol=False, **kwargs): top_k = kwargs.pop('top_k', self.config.top_k) top_p = kwargs.pop('top_p', self.config.top_p) temperature = kwargs.pop('temperature', self.config.temperature) max_length = kwargs.pop( 'max_length', tokens.size(1) + self.config.tokens_to_generate) batch_size = tokens.size(0) lengths = prompts_len if lengths is None: lengths = torch.tensor([tokens.size(1)], device=tokens.device) min_prompt_length = lengths.min().item() max_sequence_length = min(max_length, self.config.max_position_embeddings) # If the context is too big, this happens if min_prompt_length >= max_sequence_length: raise ValueError('context length + tokens_to_generate too large') pad_length = max_sequence_length - tokens.size(1) if pad_length > 0: pads = torch.zeros( batch_size, pad_length, device=tokens.device).long() tokens = torch.cat((tokens, pads), dim=-1) # Initialize inference parameters. self.inference_params = InferenceParams(batch_size, max_sequence_length) # Added termination_id to support the case that we want to terminate the # generation once that id is generated. termination_id = self.config.eod_id # Whether we have reached a termination id. is_generation_done = torch.zeros( batch_size, dtype=torch.uint8, device=torch.cuda.current_device()) # ============= # Run infernece # ============= attention_mask, position_ids = \ GPT3Model.build_attention_mask_and_position_ids(tokens) prev_context_length = 0 for context_length in range(min_prompt_length, max_sequence_length): # Pick the slice that we need to pass through the network. tokens2use = tokens[:, prev_context_length:context_length] positions2use = position_ids[:, prev_context_length:context_length] attention_mask2use = attention_mask[ ..., prev_context_length:context_length, :context_length] # logits will be meanigful only in the last pipeline stage. logits = self(tokens2use, attention_mask2use, positions2use).logits # Sample. last_token_logits = logits[:, -1, :] new_sample = sample( last_token_logits, top_k=top_k, top_p=top_p, temperature=temperature, vocab_size=self.config.vocab_size) # If a prompt length is smaller or equal th current context # length, it means we have started generating tokens started = lengths <= context_length # Update the tokens. tokens[started, context_length] = new_sample[started] # streaming output yield TokenGeneratorOutput(sequences=tokens[:, :(context_length + 1)]) # Update the context length for the next token generation. prev_context_length = context_length # instead tokenization should be in the inference loop so stop sequences can be used if stop_on_double_eol: hit_double_eol = (new_sample == 628).byte() & started.byte() hit_two_eols = (new_sample == 198).byte() & ( tokens[:, context_length - 1] == 198).byte() & started.byte() done_token = hit_double_eol | hit_two_eols elif stop_on_eol: hit_double_eol = (new_sample == 628).byte() & started.byte() hit_eol = (new_sample == 198).byte() & started.byte() done_token = hit_double_eol | hit_eol else: done_token = (new_sample == termination_id).byte() & \ started.byte() is_generation_done = is_generation_done | done_token done = torch.all(is_generation_done) if use_eod_token_for_early_termination and done: break def beam_search(self, tokens, beam_size=5, num_return_gen=1, **kwargs): batch_size = tokens.size(0) assert (batch_size == 1) prompt_length = kwargs.pop( 'prompt_length', torch.tensor([tokens.size(1)], device=tokens.device)).item() stop_token = self.config.eod_id pads = torch.ones( 1, self.config.tokens_to_generate, device=tokens.device).long() * stop_token tokens = torch.cat((tokens, pads), dim=-1) final_sequence_length = tokens.size(1) final_sequence_length = min(final_sequence_length, self.config.max_position_embeddings) # If the context is too big, this happens if prompt_length >= final_sequence_length: raise ValueError('context length + tokens_to_generate too large') # Initialize inference parameters. self.inference_params = InferenceParams(beam_size, final_sequence_length) beam_hyp = BeamHypotheses(beam_size) done = False scores = torch.zeros( beam_size, dtype=torch.float32, device=torch.cuda.current_device()).unsqueeze(1) # ============= # Run infernece # ============= tokens = tokens.repeat(beam_size, 1) attention_mask, position_ids = \ GPT3Model.build_attention_mask_and_position_ids(tokens) prev_context_length = 0 for context_length in range(prompt_length, final_sequence_length): # Pick the slice that we need to pass through the network. tokens2use = tokens[:, prev_context_length:context_length] positions2use = position_ids[:, prev_context_length:context_length] attention_mask2use = attention_mask[ ..., prev_context_length:context_length, :context_length] # logits will be meanigful only in the last pipeline stage. logits = self(tokens2use, attention_mask2use, positions2use).logits vocab_size = logits.size(2) log_probs = F.log_softmax(logits, dim=2) new_scores = log_probs[:, -1, :] + scores if context_length == prompt_length: # if this is the first one sorted_scores, indices = torch.sort( new_scores[0, :], descending=True) else: sorted_scores, indices = torch.sort( new_scores.view(-1), descending=True) best_beam_ids = torch.div(indices[:2 * beam_size], vocab_size).trunc().long() best_words = indices[:2 * beam_size] % vocab_size best_scores = sorted_scores[:2 * beam_size] next_beams = [] for beam_token_rank, (token_id, beam_score, beam_id) in enumerate( zip(best_words, best_scores, best_beam_ids)): if token_id.item() == stop_token: # if beam_token does not belong to top num_beams tokens, it should not be added is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size if is_beam_token_worse_than_top_num_beams: continue beam_hyp.add(tokens[beam_id].clone(), beam_score, context_length + 1 - prompt_length) else: # add next predicted token since it is not eos_token next_beams.append((token_id, beam_score, beam_id)) if len(next_beams) == beam_size: break if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length): done = True break best_batches = tokens.new([item[2] for item in next_beams]) tokens = tokens[best_batches, :] tokens[:, context_length] = tokens.new( [item[0] for item in next_beams]) scores = scores.new([item[1] for item in next_beams]).unsqueeze(1) # set inference key values to make it consistent with best beam index self.inference_params.swap_key_value_dict(best_batches) # Update the context length for the next token generation. prev_context_length = context_length # if cannot find stop token, add open beams to hyps if not done: for beam_id in range(beam_size): beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length) # rank based on scores sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True) num_return_gen = min(num_return_gen, len(sorted_hyps)) scores = [sorted_hyps[i][0] for i in range(num_return_gen)] tokens = [sorted_hyps[i][1] for i in range(num_return_gen)] scores = torch.stack(scores, dim=0) tokens = torch.stack(tokens, dim=0) return TokenGeneratorOutput(sequences=tokens, scores=scores) @torch.no_grad() def generate(self, tokens, do_sample=True, *args, **kwargs): if do_sample: last_output = None for output in self.sample(tokens, *args, **kwargs): last_output = output return last_output else: return self.beam_search(tokens, *args, **kwargs) @torch.no_grad() def stream_generate(self, tokens, *args, **kwargs): return self.sample(tokens, *args, **kwargs) def state_dict(self, destination=None, prefix='', keep_vars=False): return self.dist_model.state_dict(destination, prefix, keep_vars) def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): return self.dist_model.load_state_dict(state_dict, strict) def save_pretrained(self, target_folder: Union[str, os.PathLike], save_checkpoint_names: Union[str, List[str]] = None, save_function: Callable = None, config: Optional[dict] = None, **kwargs): # DistributedPipeline type is different from task name config['pipeline']['type'] = 'gpt3-generation' config['model'].pop('rank', None) config['model'].pop('megatron_cfg', None) config['megatron'].pop('rank', None) config['megatron'].pop('checkpoint_tensor_model_parallel_size', None) tp_size = get_args().tensor_model_parallel_size pp_size = get_args().pipeline_model_parallel_size config['megatron']['world_size'] = tp_size * pp_size return super().save_pretrained(target_folder, save_checkpoint_names, save_function, config, **kwargs) class BeamHypotheses: def __init__(self, num_beams: int, length_penalty: float = 1.0, early_stopping: bool = False): """ Initialize n-best list of hypotheses. """ self.length_penalty = length_penalty self.early_stopping = early_stopping self.num_beams = num_beams self.beams = [] self.worst_score = 1e9 def __len__(self): """ Number of hypotheses in the list. """ return len(self.beams) def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None): """ Add a new hypothesis to the list. """ score = sum_logprobs / (hyp.shape[-1]**self.length_penalty) if len(self) < self.num_beams or score > self.worst_score: self.beams.append((score, hyp, beam_indices)) if len(self) > self.num_beams: sorted_next_scores = sorted([ (s, idx) for idx, (s, _, _) in enumerate(self.beams) ]) del self.beams[sorted_next_scores[0][1]] self.worst_score = sorted_next_scores[1][0] else: self.worst_score = min(score, self.worst_score) def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: """ If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst one in the heap, then we are done with this sentence. """ if len(self) < self.num_beams: return False elif self.early_stopping: return True else: cur_score = best_sum_logprobs / cur_len**self.length_penalty ret = self.worst_score >= cur_score return ret