| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520 |
- # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Utilities for logging and serialization"""
- import os
- import random
- import subprocess
- import time
- import json
- import numpy as np
- import torch
- from megatron_util import mpu, print_rank_0
- from megatron_util.fp16 import FP16_Optimizer
# Name of the sub-directory (below a run's base directory) that holds the
# summary-writer logs.
SUMMARY_WRITER_DIR_NAME = 'runs'


def get_log_dir(name, base):
    """Return the summary-writer log directory ``<base>/runs/<name>``."""
    parts = (base, SUMMARY_WRITER_DIR_NAME, name)
    return os.path.join(*parts)
def get_hostname():
    """Return the first IP address reported by ``hostname -I``.

    Despite its name this returns an address (the original code called it
    ``master_addr``), not a host name. NOTE(review): the ``-I`` flag is not
    supported by every ``hostname`` implementation (it is Linux-specific).

    Raises:
        subprocess.CalledProcessError: if the command exits with an error.
        OSError: if the ``hostname`` executable cannot be found.
        RuntimeError: if the command reports no address at all.
    """
    # A fixed argument list needs no shell.
    result = subprocess.check_output(['hostname', '-I'])
    addresses = result.decode('utf-8').split()
    if not addresses:
        raise RuntimeError('`hostname -I` did not report any IP address')
    return addresses[0]
def get_spare_port(args):
    """Pick a random port in [10000, 65535] on rank 0 and share it with all ranks.

    The chosen port is only guaranteed to differ from ``args.master_port``;
    it is NOT checked for actually being free on the host.

    Must be called by every rank of the default process group: rank 0 draws
    the port and broadcasts it through a CUDA tensor.

    Returns:
        int: the port drawn by rank 0 (identical on every rank).
    """
    if torch.distributed.get_rank() == 0:
        # SystemRandom draws from the OS, so the global ``random`` state
        # (which may have been seeded for reproducibility) is not consumed.
        rng = random.SystemRandom()
        port = rng.randint(10000, 65535)
        # Keep drawing until the port differs from the master port.
        while port == args.master_port:
            port = rng.randint(10000, 65535)
        port_tensor = torch.cuda.LongTensor([port])
    else:
        port_tensor = torch.cuda.LongTensor([0])
    torch.distributed.broadcast(port_tensor, 0)
    return port_tensor.item()
def print_and_save_args(args, verbose=True, log_dir=None):
    """Print arguments and optionally save them as JSON.

    Args:
        args: an ``argparse.Namespace``-like object (``vars(args)`` must work).
        verbose: if True, print every argument as ``name .... value``.
        log_dir: if given, write ``vars(args)`` to ``<log_dir>/config.json``.
            When ``args.deepspeed`` is truthy and ``args.deepspeed_config``
            names a JSON file, that file is also copied to
            ``<log_dir>/config_gpt_large.json``. Namespaces without those
            DeepSpeed attributes are accepted and skip the copy.
    """
    if verbose:
        print('arguments:', flush=True)
        for arg in vars(args):
            dots = '.' * (29 - len(arg))
            print(
                ' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True)
    if log_dir is None:
        return
    json_file = os.path.join(log_dir, 'config.json')
    with open(json_file, 'w', encoding='utf-8') as output:
        json.dump(vars(args), output, sort_keys=True)
    # getattr: tolerate argument namespaces that define no DeepSpeed options.
    deepspeed_config_path = getattr(args, 'deepspeed_config', None)
    if getattr(args, 'deepspeed', False) and deepspeed_config_path is not None:
        with open(deepspeed_config_path, encoding='utf-8') as file:
            deepspeed_config = json.load(file)
        deepspeed_json_file = os.path.join(log_dir, 'config_gpt_large.json')
        with open(deepspeed_json_file, 'w', encoding='utf-8') as output:
            json.dump(deepspeed_config, output)
def print_params_min_max_norm(optimizer, iteration):
    """Print min, max and norm of every parameter held by *optimizer*.

    Prints a header line followed by one comma-separated line per parameter:
    iteration, global rank, 1-based parameter index, model-parallel flag,
    min, max and ``Tensor.norm()`` (torch default: 2-norm).
    An ``FP16_Optimizer`` is unwrapped to reach its inner optimizer's groups.
    """
    rank = torch.distributed.get_rank()
    if isinstance(optimizer, FP16_Optimizer):
        optimizer_ = optimizer.optimizer
    else:
        optimizer_ = optimizer
    lines = ['iteration, rank, index, model-parallel, min, max, norm']
    index = 0
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            data = param.data
            # Parameters never tagged as model-parallel are reported as 0
            # instead of raising AttributeError.
            model_parallel = int(getattr(param, 'model_parallel', False))
            lines.append(
                '{:7d}, {:4d}, {:4d}, {:2d}, {:.6E}, {:.6E}, {:.6E}'.format(
                    iteration, rank, index, model_parallel,
                    data.min().item(), data.max().item(),
                    data.norm().item()))
    # Trailing newline kept so the printed text ends like before.
    print('\n'.join(lines) + '\n', flush=True)
class Timers:
    """Group of named timers, created lazily on first use."""

    class Timer:
        """Accumulating timer that can be started and stopped repeatedly.

        When CUDA is available the device is synchronized before the clock is
        read, so queued asynchronous GPU work is included in the measurement.
        """

        def __init__(self, name):
            self.name_ = name
            self.elapsed_ = 0.0  # accumulated seconds over all start/stop spans
            self.started_ = False
            self.start_time = time.perf_counter()

        @staticmethod
        def _synchronize():
            """Wait for queued CUDA work; a no-op on hosts without CUDA."""
            # torch.cuda.synchronize() raises when CUDA is unavailable, which
            # would make every timer unusable on CPU-only runs.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

        def start(self):
            """Start the timer."""
            assert not self.started_, 'timer has already been started'
            self._synchronize()
            # perf_counter is monotonic, unlike time.time(), so system clock
            # adjustments cannot skew or negate the measured interval.
            self.start_time = time.perf_counter()
            self.started_ = True

        def stop(self):
            """Stop the timer and add the interval to the accumulated time."""
            assert self.started_, 'timer is not started'
            self._synchronize()
            self.elapsed_ += time.perf_counter() - self.start_time
            self.started_ = False

        def reset(self):
            """Clear the accumulated time and mark the timer as stopped."""
            self.elapsed_ = 0.0
            self.started_ = False

        def elapsed(self, reset=True):
            """Return the accumulated seconds, optionally resetting the total.

            A running timer is stopped (to include the in-progress interval)
            and restarted afterwards, so it keeps running after the call.
            """
            was_running = self.started_
            if was_running:
                self.stop()
            elapsed_ = self.elapsed_
            if reset:
                self.reset()
            if was_running:
                self.start()
            return elapsed_

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        """Return the timer called *name*, creating it on first use."""
        if name not in self.timers:
            self.timers[name] = self.Timer(name)
        return self.timers[name]

    def log(self, names, normalizer=1.0, reset=True):
        """Report the named timers' elapsed milliseconds via ``print_rank_0``.

        Each value is divided by *normalizer* (e.g. a number of iterations).
        """
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            elapsed_time = self.timers[name].elapsed(
                reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        print_rank_0(string)
def report_memory(name):
    """Report CUDA memory statistics in MB via ``print_rank_0``.

    "cached" means memory reserved by PyTorch's caching allocator, i.e.
    ``torch.cuda.memory_reserved`` (``memory_cached`` is its deprecated alias);
    "max cached" is the peak of that value.
    """
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(torch.cuda.memory_allocated()
                                        / mega_bytes)
    string += ' | max allocated: {}'.format(torch.cuda.max_memory_allocated()
                                            / mega_bytes)
    string += ' | cached: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
    # Peak (not current) reserved memory, matching the "max" label.
    string += ' | max cached: {}'.format(torch.cuda.max_memory_reserved()
                                         / mega_bytes)
    print_rank_0(string)
def get_checkpoint_name(checkpoints_path,
                        iteration,
                        release=False,
                        zero=False):
    """Return this model-parallel rank's checkpoint file path.

    Layout: ``<checkpoints_path>/<dir>/mp_rank_<NN>_model_states.pt`` where
    ``<dir>`` is ``release`` or the iteration/tag, suffixed with
    ``_zero_dp_rank_<data-parallel rank>`` when *zero* is True.
    """
    directory = 'release' if release else f'{iteration}'
    if zero:
        directory = f'{directory}_zero_dp_rank_{mpu.get_data_parallel_rank()}'
    mp_rank = mpu.get_model_parallel_rank()
    file_name = f'mp_rank_{mp_rank:02d}_model_states.pt'
    return os.path.join(checkpoints_path, directory, file_name)
def ensure_directory_exists(filename):
    """Create the parent directory of *filename* (and missing ancestors).

    A bare file name has no directory component; it refers to the current
    working directory and nothing needs to be created.
    """
    dirname = os.path.dirname(filename)
    # os.makedirs('') raises FileNotFoundError, so skip an empty dirname.
    if dirname and not os.path.exists(dirname):
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(dirname, exist_ok=True)
def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the text file that records the latest checkpoint tag."""
    tracker_name = 'latest_checkpointed_iteration.txt'
    return os.path.join(checkpoints_path, tracker_name)
def save_zero_checkpoint(args, iteration, optimizer):
    """Save this rank's optimizer state to a per-data-parallel-rank file.

    The file is ``get_checkpoint_name(args.save, iteration, zero=True)``,
    whose directory name carries the data-parallel rank; presumably this holds
    ZeRO-partitioned optimizer state (TODO confirm with callers).
    """
    state = dict(iteration=iteration,
                 optimizer_state_dict=optimizer.state_dict())
    target_path = get_checkpoint_name(args.save, iteration, zero=True)
    ensure_directory_exists(target_path)
    torch.save(state, target_path)
    print(f' successfully saved {target_path}')
def save_checkpoint(iteration,
                    model,
                    optimizer,
                    lr_scheduler,
                    args,
                    tag=None,
                    barrier=True,
                    only_changed_parameters=False,
                    no_deepspeed=False,
                    no_save_optim=False):
    """Save a model checkpoint and record its tag in the tracker file.

    Args:
        iteration: training iteration stored in the checkpoint.
        model: the model; unwrapped via ``.module`` when ``args.deepspeed``.
        optimizer, lr_scheduler: saved unless ``args.no_save_optim`` or
            *no_save_optim*; either may be None.
        args: run arguments; reads ``save``, ``deepspeed``, ``no_save_optim``
            and ``no_save_rng``.
        tag: checkpoint sub-directory name; defaults to ``str(iteration)``.
        barrier: if True, all ranks synchronize before rank 0 writes the
            tracker file.
        only_changed_parameters: if True, keep only state-dict entries that
            are parameters with ``requires_grad`` set (buffers are dropped).
        no_deepspeed: bypass DeepSpeed checkpointing even if it is enabled.
        no_save_optim: skip optimizer and LR-scheduler state.
    """
    if tag is None:
        tag = str(iteration)
    if args.deepspeed and not no_deepspeed:
        save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag)
    elif mpu.get_data_parallel_rank() == 0:
        # Only rank zero of the data-parallel group writes to the disk.
        checkpoint_name = get_checkpoint_name(args.save, tag)
        print(
            'global rank {} is saving checkpoint at iteration {:7d} to {}'.
            format(torch.distributed.get_rank(), iteration,
                   checkpoint_name))
        sd = {'iteration': iteration}
        if args.deepspeed:
            model = model.module
        state_dict = model.state_dict()
        if only_changed_parameters:
            trainable = {
                name
                for name, parameter in model.named_parameters()
                if parameter.requires_grad
            }
            # state_dict() also contains buffers (and aliases of shared
            # parameters) that named_parameters() does not yield; a direct
            # lookup by those keys would raise KeyError.
            state_dict = {
                key: value
                for key, value in state_dict.items() if key in trainable
            }
        sd['module'] = state_dict
        # Optimizer stuff.
        if not args.no_save_optim and not no_save_optim:
            if optimizer is not None:
                sd['optimizer'] = optimizer.state_dict()
            if lr_scheduler is not None:
                sd['lr_scheduler'] = lr_scheduler.state_dict()
        # rng states.
        if not args.no_save_rng:
            sd['random_rng_state'] = random.getstate()
            sd['np_rng_state'] = np.random.get_state()
            sd['torch_rng_state'] = torch.get_rng_state()
            sd['cuda_rng_state'] = torch.cuda.get_rng_state()
            sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker(
            ).get_states()
        ensure_directory_exists(checkpoint_name)
        torch.save(sd, checkpoint_name)
        print(' successfully saved {}'.format(checkpoint_name))
    # Wait so everyone is done (necessary)
    if barrier:
        torch.distributed.barrier()
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w', encoding='utf-8') as f:
            f.write(tag)
def save_ds_checkpoint(iteration, model, lr_scheduler, args, tag):
    """Save a checkpoint through ``model.save_checkpoint`` (DeepSpeed engine).

    The client state passed along contains the iteration, the LR-scheduler
    state (when a scheduler is given) and, unless ``args.no_save_rng``, the
    Python, NumPy, torch CPU, CUDA and model-parallel RNG-tracker states.
    """
    client_state = {'iteration': iteration}
    if lr_scheduler is not None:
        client_state['client_lr_scheduler'] = lr_scheduler.state_dict()
    if not args.no_save_rng:
        client_state.update(
            random_rng_state=random.getstate(),
            np_rng_state=np.random.get_state(),
            torch_rng_state=torch.get_rng_state(),
            cuda_rng_state=torch.cuda.get_rng_state(),
            rng_tracker_states=mpu.get_cuda_rng_tracker().get_states(),
        )
    model.save_checkpoint(args.save, tag, client_state=client_state)
def get_checkpoint_iteration(load_path):
    """Work out which checkpoint under *load_path* should be loaded.

    Returns:
        A tuple ``(load_dir, tag, release, success)``:

        * tracker file present: ``(load_path, tag, tag == 'release', True)``,
          where *tag* is the stripped file contents (a string, not
          necessarily an integer);
        * no tracker file but *load_path* is a directory: that directory is
          treated as the checkpoint itself, giving
          ``(parent_dir, basename, False, True)``;
        * otherwise ``(load_path, 0, False, False)``; callers should check
          *success* before using the other fields.
    """
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_path)
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        if os.path.isdir(load_path):
            path = os.path.normpath(load_path)
            load_dir, tag = os.path.split(path)
            print_rank_0(
                'Try to directly load the checkpoint from the directory')
            return load_dir, tag, False, True
        print_rank_0('    will not load any checkpoints and will start from '
                     'random')
        return load_path, 0, False, False
    with open(tracker_filename, 'r', encoding='utf-8') as f:
        metastring = f.read().strip()
    release = metastring == 'release'
    return load_path, metastring, release, True
def load_checkpoint(model,
                    optimizer,
                    lr_scheduler,
                    args,
                    no_deepspeed=False,
                    no_load_optim=False):
    """Load a model checkpoint located through ``args.load``.

    With DeepSpeed enabled (and *no_deepspeed* False), the engine's
    ``load_checkpoint`` restores the state. The LR scheduler is restored from
    the returned client state when present. Otherwise this rank's checkpoint
    file is read with ``torch.load``. The model is restored non-strictly, and
    optimizer / LR scheduler are restored unless fine-tuning, loading a
    release checkpoint, or disabled by flags. The iteration and RNG states
    are then taken from the loaded state dict under the same conditions.

    Returns:
        The iteration to resume from, as stored in the checkpoint. Returns 0
        when no checkpoint could be found or loaded, when fine-tuning, for
        release checkpoints, or when the checkpoint records no iteration.
    """
    load_dir, tag, release, success = get_checkpoint_iteration(args.load)
    if not success:
        return 0
    if args.deepspeed and not no_deepspeed:
        checkpoint_name, sd = model.load_checkpoint(
            load_dir,
            tag,
            load_optimizer_states=not args.no_load_optim and not no_load_optim,
            load_lr_scheduler_states=not args.no_load_lr_scheduler)
        # Check for failure before reading ``sd``: a failed load returns no
        # client state (None), and ``'key' in None`` raises TypeError.
        if checkpoint_name is None:
            if mpu.get_data_parallel_rank() == 0:
                print('Unable to load checkpoint.')
            # Start from iteration 0, as when no checkpoint is found at all,
            # rather than handing callers the (string) tag as an iteration.
            return 0
        if (not args.no_load_lr_scheduler and lr_scheduler is not None
                and 'client_lr_scheduler' in sd):
            lr_scheduler.load_state_dict(sd['client_lr_scheduler'])
            print_rank_0('Load lr scheduler state')
    else:
        # Checkpoint.
        checkpoint_name = get_checkpoint_name(load_dir, tag, release)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))
        # Load the checkpoint.
        sd = torch.load(checkpoint_name, map_location='cpu')
        # Model.
        if args.deepspeed:
            model = model.module
        missing_keys, unexpected_keys = model.load_state_dict(
            sd['module'], strict=False)
        if missing_keys or unexpected_keys:
            print_rank_0(
                f'Missing keys {missing_keys}, unexpected keys {unexpected_keys}'
            )
        # Optimizer.
        if (not release and not args.finetune and not args.no_load_optim
                and not no_load_optim):
            try:
                if optimizer is not None:
                    optimizer.load_state_dict(sd['optimizer'])
                if lr_scheduler is not None:
                    lr_scheduler.load_state_dict(sd['lr_scheduler'])
            except KeyError:
                print_rank_0(
                    'Unable to load optimizer from checkpoint {}, exiting. '
                    'Specify --no-load-optim or --finetune to prevent '
                    'attempting to load the optimizer '
                    'state.'.format(checkpoint_name))
    # Iterations.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = sd['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = sd['total_iters']
            except KeyError:
                print_rank_0(
                    'A metadata file exists but Unable to load iteration '
                    ' from checkpoint {}, starting from 0 iteration'.format(
                        checkpoint_name))
                iteration = 0
    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(sd['random_rng_state'])
            np.random.set_state(sd['np_rng_state'])
            torch.set_rng_state(sd['torch_rng_state'])
            torch.cuda.set_rng_state(sd['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
        except KeyError:
            print_rank_0(
                'Unable to load random state from checkpoint {}, exiting. '
                'Specify --no-load-rng or --finetune to prevent '
                'attempting to load the random '
                'state.'.format(checkpoint_name))
    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))
    return iteration
def load_weights(src, dst, dst2src=False):
    """Copy parameters in place between a HuggingFace GPT-2 module and ours.

    Parameters are matched by name: every name from ``src.named_parameters()``
    is looked up in ``dst._parameters``. Values flow src -> dst by default and
    dst -> src when ``dst2src=True`` (that direction is still untested).
    If *src*'s type name contains ``Conv1D``, weight tensors are transposed
    during the copy (HuggingFace's Conv1D stores weights transposed relative
    to a linear layer).
    """
    transpose_weights = 'Conv1D' in str(type(src))
    for param_name, src_param in src.named_parameters():
        dst_tensor = dst._parameters[param_name].data
        if dst2src:
            source, target = dst_tensor, src_param.data
        else:
            source, target = src_param.data, dst_tensor
        if transpose_weights and 'weight' in param_name:
            source = source.t().contiguous()
        target.copy_(source)
def load_mlp(our, oai, dst2src=False):
    """Copy MLP weights: HF ``c_fc``/``c_proj`` <-> our ``dense_h_to_4h``/``dense_4h_to_h``."""
    load_weights(oai.c_fc, our.dense_h_to_4h, dst2src=dst2src)
    load_weights(oai.c_proj, our.dense_4h_to_h, dst2src=dst2src)
def load_attention(our, oai, dst2src=False):
    """Copy attention weights: HF ``c_attn``/``c_proj`` <-> our ``query_key_value``/``dense``."""
    load_weights(oai.c_attn, our.query_key_value, dst2src=dst2src)
    load_weights(oai.c_proj, our.dense, dst2src=dst2src)
def load_transformer_layer(our, oai, dst2src=False):
    """Copy one transformer block: both layer norms, then the MLP, then attention."""
    load_weights(oai.ln_1, our.input_layernorm, dst2src=dst2src)
    load_weights(oai.ln_2, our.post_attention_layernorm, dst2src=dst2src)
    load_mlp(our.mlp, oai.mlp, dst2src=dst2src)
    load_attention(our.attention, oai.attn, dst2src=dst2src)
def move_weights(our, oai, dst2src=False):
    """Copy all GPT-2 weights in place between ``oai`` and ``our``.

    ``oai`` is a HuggingFace GPT-2 model and ``our`` is one of our models.
    By default weights flow oai -> our; ``dst2src=True`` reverses the
    direction (still untested). Copies the final layer norm, the token and
    position embeddings, then each transformer layer pairwise (``zip`` stops
    at the shorter layer list).

    NOTE(review): ``our`` is used as-is; wrappers such as DDP / FP16 modules
    are not unwrapped here, so presumably callers pass the bare model.
    """
    hf_transformer = oai.transformer
    load_weights(hf_transformer.ln_f, our.transformer.final_layernorm,
                 dst2src=dst2src)
    load_weights(hf_transformer.wte, our.word_embeddings, dst2src=dst2src)
    load_weights(hf_transformer.wpe, our.position_embeddings,
                 dst2src=dst2src)
    layer_pairs = zip(our.transformer.layers, hf_transformer.h)
    for our_layer, hf_layer in layer_pairs:
        load_transformer_layer(our_layer, hf_layer, dst2src=dst2src)
def debug_finetune_data(local_vars, batch_id, tokenizer):
    """Print a readable view of sample *batch_id* of a fine-tuning batch.

    Reads ``tokens``, ``target_ids``, ``attention_mask``, ``logit_mask`` and
    ``position_ids`` from *local_vars* (e.g. a ``locals()`` dict), then prints:

    1. the first ``attention_mask[batch_id]`` tokens, with every ``[MASK]``
       replaced by ``[<position_ids[batch_id][0, i]>]``;
    2. the positions after that prefix where ``logit_mask`` is set;
    3. the decoded tokens at those positions;
    4. the decoded targets (restricted to those positions only when
       ``target_ids`` has more than two dimensions);
    5. ``position_ids[batch_id]`` at those positions.
    """
    tokens = local_vars['tokens']
    target_ids = local_vars['target_ids']
    attention_mask = local_vars['attention_mask']
    logit_mask = local_vars['logit_mask']
    position_ids = local_vars['position_ids']

    sample_tokens = tokens[batch_id]
    sample_positions = position_ids[batch_id]
    prefix_len = attention_mask[batch_id].item()

    rendered = []
    for idx, token_id in enumerate(sample_tokens[:prefix_len].tolist()):
        text = tokenizer.IdToToken(token_id)
        if text == '[MASK]':
            text = f'[{sample_positions[0, idx].item()}]'
        rendered.append(text)
    print(' '.join(rendered))

    seq_len = tokens.size(-1)
    target_positions = [
        pos for pos in range(prefix_len, seq_len) if logit_mask[batch_id][pos]
    ]
    print(target_positions)
    print(tokenizer.DecodeIds(sample_tokens[target_positions].tolist()))
    sample_targets = target_ids[batch_id]
    if len(target_ids.shape) > 2:
        sample_targets = sample_targets[target_positions]
    print(tokenizer.DecodeIds(sample_targets.tolist()))
    print(sample_positions[:, target_positions])
|