# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict

import torch
from megatron_util import mpu, print_rank_0
from megatron_util.fp16 import FP16_Module
from torch.nn import functional as F

from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.utils.logger import get_logger
from modelscope.utils.megatron_utils import init_megatron_util
from modelscope.utils.nlp.load_checkpoint import pre_load
from . import PlugModel
from .configuration import PlugNLGConfig

logger = get_logger()


class DistributedPlug(TorchModel):
    """
    The wrapper class of PLUG Model to initialize parallel environment, load model weights, generate sentences.

    Parameters:
        model_dir (`str`, *required*):
            Path to model damo/nlp_plug_text-generation_27B.
            The model structure in model_dir should be like this:
            model_dir
                |_ config.json
                |_ configuration.json
                |_ ds_zero-offload_10B_config.json
                |_ vocab.txt
                |_ model <-- an empty directory

            Model binaries shall be downloaded separately to populate the model directory, so that
            the model directory would contain the following binaries:
                |_ model
                    |_ mp_rank_00_model_states.pt
                    |_ mp_rank_01_model_states.pt
                    |_ mp_rank_02_model_states.pt
                    |_ mp_rank_03_model_states.pt
                    |_ mp_rank_04_model_states.pt
                    |_ mp_rank_05_model_states.pt
                    |_ mp_rank_06_model_states.pt
                    |_ mp_rank_07_model_states.pt
        rank (`int`, *required*):
            Used to identify different GPUs in a tensor parallel environment. eg. The rank of GPU #0 is 0,
            and the model file `mp_rank_00_model_states.pt` will be loaded on this GPU.
        world_size (`int`, *required*, defaults to 8):
            The parallel size in total.
        model_parallel_size (`int`, *required*, defaults to 8):
            The parallel size of model(tensor parallel).
        master_ip (`str`, *required*):
            The master IP, can usually be set to `"127.0.0.1"`, used as part of
            [`~torch.distributed.init_process_group`] method parameter `init_method`.
            `init_method` = `"tcp://{master_ip}:{master_port}"`
        master_port (`str`, *required*):
            The master port, can usually be set to `"29500"`, used as part of
            [`~torch.distributed.init_process_group`] method parameter `init_method`.
            `init_method` = `"tcp://{master_ip}:{master_port}"`
        seed (`int`, *optional*, defaults to 42):
            Random seed to control sampling.
        temperature, top_k, top_p (*required* for `generate`):
            Sampling settings. All keyword arguments are stored in `self.model_cfg`, and
            `generate` reads these three keys from it without a fallback.
    """

    def __init__(self, model_dir, rank, **kwargs):
        """Load the config, initialise megatron_util for this rank and build the model.

        Args:
            model_dir: Directory holding the configuration and checkpoint shards.
            rank: Rank of this process, forwarded to `init_megatron_util`.
            **kwargs: Forwarded to `TorchModel.__init__` and kept verbatim in
                `self.model_cfg` (read later by `generate`).
        """
        super().__init__(model_dir, **kwargs)
        self.rank = rank
        self.model_cfg = kwargs
        self.config = PlugNLGConfig.from_pretrained(model_dir)
        init_megatron_util(model_dir=model_dir, rank=rank)

        self.iteration = 0
        self.model = self.initialize_model(path_load_tag='model')

    def initialize_model(self, path_load_tag='model'):
        """Build the model.

        Creates a `PlugModel` from `self.config`, moves it to the current CUDA
        device, optionally wraps it in `FP16_Module`, and loads the checkpoint
        shard of this tensor-parallel rank with `strict=False`.

        Args:
            path_load_tag: Passed to `pre_load` as `tag`.

        Returns:
            The (possibly FP16-wrapped) model with weights loaded.
        """
        print_rank_0('Building Plug model. It will take a few minutes ...')
        model = PlugModel(self.config)

        if mpu.get_data_parallel_rank() == 0:
            logger.info(
                ' > number of parameters on model parallel rank {}: {}'.format(
                    mpu.get_tensor_model_parallel_rank(),
                    sum(p.nelement() for p in model.parameters())))

        if self.config.deepspeed and self.config.fp16:
            model.half()

        # GPU allocation.
        model.cuda(torch.cuda.current_device())

        # Fp16 conversion.
        # NOTE(review): the fp32_* overrides are nested under the fp16 branch
        # because they go through `model.module`, which is only reached after
        # the FP16_Module wrapping — confirm against the upstream layout.
        if self.config.fp16:
            model = FP16_Module(model)
            if self.config.fp32_embedding:
                embeddings = model.module.model.bert.embeddings
                embeddings.word_embeddings.float()
                embeddings.position_embeddings.float()
                embeddings.token_type_embeddings.float()
            if self.config.fp32_tokentypes:
                model.module.model.bert.embeddings.token_type_embeddings.float()
            if self.config.fp32_layernorm:
                for name, _module in model.named_modules():
                    if 'LayerNorm' in name:
                        _module.float()

        load_model = pre_load(
            mpu.get_tensor_model_parallel_rank(),
            self.model_dir,
            tag=path_load_tag)
        # NOTE(review): `model.module` is accessed unconditionally here, so
        # this path presumably requires `config.fp16` to be enabled — confirm.
        model_dict = model.module.model.state_dict()
        # Report which checkpoint keys match the model; the load itself is
        # non-strict, so unmatched keys are ignored rather than raising.
        for key in load_model:
            if key not in model_dict:
                print_rank_0('Skip key: ' + key)
            else:
                print_rank_0('Loading key: ' + key)
        model.module.model.load_state_dict(load_model, strict=False)
        return model

    @staticmethod
    def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
        """Filter logits with top-k and/or nucleus (top-p) filtering.

        `logits` is modified in place. The top-p branch flattens the input
        with `logits.view(logits.size()[1])`, so it only works for a single
        row, i.e. an input of shape `(1, vocab)`.

        Args:
            logits: Logits tensor, indexed as `(batch, vocab)`.
            top_k: Keep only the `top_k` largest logits when > 0. Values larger
                than the vocabulary dimension are clamped to it.
            top_p: Keep the smallest set of tokens whose cumulative probability
                exceeds `top_p` when > 0.0.
            filter_value: Value written into the removed positions.

        Returns:
            The filtered logits as a 2-D tensor.
        """
        # This function has been mostly taken from huggingface conversational ai code at
        # https://medium.com/huggingface/how-to-build-a-state-of-the-art-
        # conversational-ai-with-transfer-learning-2d818ac26313

        if top_k > 0:
            # torch.topk raises when k exceeds the dimension size, so clamp.
            top_k = min(top_k, logits.size(-1))
            # Remove all tokens with a probability less than the last token of the top-k
            indices_to_remove = logits < torch.topk(logits,
                                                    top_k)[0][..., -1, None]
            logits[indices_to_remove] = filter_value

        if top_p > 0.0:
            # convert to 1D
            logits = logits.view(logits.size()[1]).contiguous()
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(
                F.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the indices to the right to keep also the first token above the threshold
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[indices_to_remove] = filter_value
            # going back to 2D
            logits = logits.view(1, -1).contiguous()
        return logits

    def forward(self,
                input_tokens,
                token_type_ids=None,
                attention_mask=None,
                target_tokens=None,
                position_ids=None,
                decode_attention_mask=None,
                checkpoint_activations=False,
                is_infer=False,
                sequence_output=None,
                parallel_output=True):
        """Delegate to the wrapped model, forwarding every argument unchanged.

        The first six arguments are passed positionally in the order listed,
        the remaining four as keyword arguments.

        Returns:
            Whatever the wrapped model returns.
        """
        return self.model(
            input_tokens,
            token_type_ids,
            attention_mask,
            target_tokens,
            position_ids,
            decode_attention_mask,
            checkpoint_activations=checkpoint_activations,
            is_infer=is_infer,
            sequence_output=sequence_output,
            parallel_output=parallel_output)

    def generate(self, input: Dict[str, Tensor], out_length=128, *kwargs):
        """Sample up to `out_length` tokens autoregressively.

        Only a batch size of 1 is supported. Sampling uses
        `self.model_cfg['temperature']`, `['top_k']` and `['top_p']`.

        Args:
            input: Dict with the tensors `input_ids`, `dec_input_ids` and
                `attention_mask`.
            out_length: Maximum number of decoding steps.
            *kwargs: Extra positional arguments; accepted but ignored.
                NOTE(review): the single star is kept for interface
                compatibility — `**kwargs` may have been intended.

        Returns:
            `{'generate_context': List[int]}`: the sampled token ids, with
            consecutive repeats of the out-of-vocabulary fallback id collapsed.
        """
        device = torch.cuda.current_device()
        batch_size = input['input_ids'].shape[0]
        tokens = input['input_ids'].view(1, -1).contiguous().to(device)
        dec_input_ids = input['dec_input_ids'].to(device)
        attention_mask = input['attention_mask'].to(device)
        self.model.eval()
        with torch.no_grad():
            # Only supports batch_size=1
            all_generate_tokens = []
            generate_tokens = []
            counter = 0
            sequence_output = None
            vocab_size = self.config.original_vocab_size
            sep_token_idx = 102  # index of [SEP] token in BertTokenizer
            # Substituted for any sampled id that is >= vocab_size.
            oov_token_idx = 100
            while counter < out_length:
                if counter % 128 == 0 and counter != 0:
                    # Sliding window
                    generate_tokens.append(sep_token_idx)
                    window = torch.tensor(
                        generate_tokens, dtype=torch.long, device=device)
                    start = (tokens == sep_token_idx).nonzero(
                        as_tuple=True)[-1]
                    if start + len(generate_tokens) >= 512:
                        # `tokens` is (1, N): slice the single row (not the
                        # batch dimension), append the generated window, keep
                        # the last 512 ids and restore the batch dimension.
                        tokens = torch.cat([tokens[0][:start], window],
                                           -1)[-512:].unsqueeze(0)
                    else:
                        tokens[0][start:start + len(generate_tokens)] = window

                    attention_mask = (tokens != 0)
                    dec_input_ids = input['dec_input_ids'].to(device)
                    generate_tokens = []
                    sequence_output = None

                position_ids = torch.full([batch_size, 1],
                                          len(generate_tokens),
                                          dtype=torch.long,
                                          device=device)
                # NOTE(review): positional order is kept exactly as in the
                # original call (attention_mask, then position_ids, in the 5th
                # and 6th slots) — confirm against the wrapped model's forward.
                _, logits, sequence_output = self.model(
                    tokens,
                    None,
                    attention_mask,
                    dec_input_ids,
                    attention_mask,
                    position_ids,
                    is_infer=True,
                    sequence_output=sequence_output,
                    parallel_output=False)
                logits = logits[:, -1, :]
                logits = logits / self.model_cfg['temperature']
                logits = self.top_k_logits(
                    logits,
                    top_k=self.model_cfg['top_k'],
                    top_p=self.model_cfg['top_p'])
                probs = F.softmax(logits, dim=-1)
                prev = torch.multinomial(probs, num_samples=1)
                prev_token = prev[0].item()
                if prev_token >= vocab_size:
                    prev_token = oov_token_idx
                    prev[0] = oov_token_idx
                if prev_token == sep_token_idx:
                    # Stop once more than 80% of the requested length exists;
                    # otherwise drop the [SEP] and keep sampling.
                    if len(all_generate_tokens) > int(
                            max(1, out_length) * 0.8):
                        break
                    counter += 1
                    continue
                dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
                generate_tokens.append(prev_token)
                all_generate_tokens.append(prev_token)
                counter += 1

            # Collapse runs of the out-of-vocabulary fallback id into one.
            generate_context = []
            for token in all_generate_tokens:
                if generate_context and generate_context[
                        -1] == oov_token_idx and token == oov_token_idx:
                    continue
                generate_context.append(token)
            return {'generate_context': generate_context}