| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- # Copyright (c) 2022 Zhipu.AI
- import argparse
- import time
- import torch
- from SwissArmyTransformer import get_args, get_tokenizer
- from SwissArmyTransformer.arguments import initialize_distributed
- from SwissArmyTransformer.model import GLM130B
- from SwissArmyTransformer.mpu import (get_model_parallel_group,
- get_model_parallel_rank,
- get_model_parallel_world_size)
- from SwissArmyTransformer.training import load_checkpoint
- from .quantization import quantize
def add_bminf_args(parser):
    """Register the BMInf command-line options on ``parser``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object) that
            receives a new ``'BMInf'`` argument group.

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    bminf_group = parser.add_argument_group('BMInf')
    # (flag, keyword arguments) pairs, registered in declaration order.
    option_specs = (
        ('--bminf', {
            'action': 'store_true',
            'help': 'Use BMInf to support low resource evaluation',
        }),
        ('--bminf-memory-limit', {
            'type': int,
            'default': 20,
            'help': 'Max memory for model per GPU (in GB)',
        }),
    )
    for flag, spec in option_specs:
        bminf_group.add_argument(flag, **spec)
    return parser
def _str_to_bool(value):
    """Convert a command-line token such as ``"false"`` or ``"1"`` to a bool.

    Args:
        value: The raw string received from the command line (a ``bool`` is
            passed through unchanged).

    Returns:
        ``True`` or ``False`` according to the spelling of ``value``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays False, which is what ``bool('')`` used to give.
    if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value, got {value!r}')


def add_quantization_args(parser):
    """Register the quantization command-line options on ``parser``.

    Adds ``--quantization-bit-width`` (int, default 4) and
    ``--from-quantized-checkpoint`` (boolean, default ``True``).

    ``--from-quantized-checkpoint`` is parsed with an explicit converter
    rather than ``type=bool``: ``bool('False')`` is ``True``, so with
    ``type=bool`` every non-empty value enabled the flag and it could not be
    switched off from the command line.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object) that
            receives a new ``'Quantization'`` argument group.

    Returns:
        The same ``parser`` object, mirroring ``add_bminf_args``.
    """
    group = parser.add_argument_group('Quantization')
    group.add_argument('--quantization-bit-width', type=int, default=4)
    group.add_argument(
        '--from-quantized-checkpoint',
        type=_str_to_bool,
        default=True,
        help='Loading from a quantized checkpoint')
    return parser
def add_initialization_args(parser):
    """Register the initialization command-line options on ``parser``.

    Adds the ``--sequential-initialization`` switch (off by default) inside
    a new ``'Initialization'`` argument group.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object).

    Returns:
        None. The parser is modified in place.
    """
    init_group = parser.add_argument_group('Initialization')
    help_text = ('Initialize sequentially in tensor parallel group '
                 '(reduce CPU RAM for initialization)')
    init_group.add_argument(
        '--sequential-initialization', action='store_true', help=help_text)
def set_up_model_args(args):
    """Overwrite model settings on ``args`` with this module's fixed values.

    Every attribute listed below is set unconditionally, replacing whatever
    value ``args`` held before. Other attributes are left untouched.

    Args:
        args: A namespace-like object supporting attribute assignment.

    Returns:
        The same ``args`` object, mutated in place.
    """
    fixed_settings = {
        'model_parallel_size': 4,
        'num_layers': 70,
        'hidden_size': 12288,
        'inner_hidden_size': 32768,
        'vocab_size': 150528,
        'num_attention_heads': 96,
        'max_sequence_length': 2048,
        'tokenizer_type': 'icetk-glm-130B',
        'layernorm_order': 'post',
        'skip_init': True,
        'fp16': True,
        'mode': 'inference',
    }
    for attr_name, attr_value in fixed_settings.items():
        setattr(args, attr_name, attr_value)
    return args
def initialize(extra_args_provider):
    """Parse command-line arguments and initialize the distributed backend.

    A local parser collects the BMInf, quantization, initialization,
    ``GLM130B``-specific and caller-supplied options. Whatever it does not
    recognise is forwarded to ``get_args`` together with forced
    ``--model-parallel-size 4 --mode inference`` flags. The two results are
    merged into one namespace after ``set_up_model_args`` has applied its
    fixed values.

    Args:
        extra_args_provider: Callable invoked with the local parser so the
            caller can register additional options.

    Returns:
        The merged ``argparse.Namespace`` with ``do_train`` set to ``False``,
        after it has been passed to ``initialize_distributed``.

    Raises:
        TypeError: If the two namespaces share an attribute name, because the
            merge unpacks both into keyword arguments.
    """
    parser = argparse.ArgumentParser(add_help=False)
    # Registration order is kept: module groups, model options, caller options.
    registrars = (
        add_bminf_args,
        add_quantization_args,
        add_initialization_args,
        GLM130B.add_model_specific_args,
        extra_args_provider,
    )
    for register in registrars:
        register(parser)

    local_args, leftover = parser.parse_known_args()
    leftover.extend(['--model-parallel-size', '4', '--mode', 'inference'])

    base_args = set_up_model_args(get_args(leftover))
    merged = argparse.Namespace(**vars(base_args), **vars(local_args))
    merged.do_train = False

    initialize_distributed(merged)
    return merged
def initialize_model_and_tokenizer(args):
    """Build the tokenizer and the model, load weights, and run a warm-up pass.

    Steps, in order:
      1. Create the tokenizer and synchronise all ranks.
      2. Loop once per model-parallel rank; each rank builds its model only
         on its own turn. With ``args.sequential_initialization`` set, ranks
         in the model-parallel group wait on a barrier after every turn.
      3. On its turn a rank creates ``GLM130B(args).half()``, quantizes it
         before or after ``load_checkpoint`` depending on
         ``args.from_quantized_checkpoint``, then either wraps it with
         ``bminf.wrapper`` or moves it to ``args.device``.
      4. Put the model in eval mode and run one dummy forward pass of length
         ``args.max_sequence_length`` with ``parallel_output`` temporarily
         forced to ``True`` (the original comment says this generates the
         rotary embedding cache).

    Args:
        args: Namespace of the kind produced by ``initialize``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        AssertionError: If ``args.from_quantized_checkpoint`` is truthy while
            ``args.quantization_bit_width`` is ``None``.
    """
    tokenizer = get_tokenizer(args)

    torch.distributed.barrier()
    start = time.time()

    for turn in range(get_model_parallel_world_size()):
        if get_model_parallel_rank() == turn:
            # Initialize model
            model = GLM130B(args).half()

            bit_width = args.quantization_bit_width
            if args.from_quantized_checkpoint:
                assert bit_width is not None
                # A quantized checkpoint needs a quantized model to load into.
                model = quantize(model, bit_width)

            # Load checkpoint
            load_checkpoint(model, args)

            if bit_width is not None and not args.from_quantized_checkpoint:
                # Full-precision checkpoint: quantize before moving to GPU.
                model = quantize(model, bit_width)

            if not args.bminf:
                model = model.to(args.device)
            else:
                # Imported lazily so bminf is only required when requested.
                import bminf

                if torch.distributed.get_rank() == 0:
                    print(
                        f'> BMInf activated, memory limit: {args.bminf_memory_limit} GB'
                    )
                with torch.cuda.device(args.device):
                    model = bminf.wrapper(
                        model,
                        quantization=False,
                        # Convert GB to bytes.
                        memory_limit=args.bminf_memory_limit << 30,
                    )

        # Runs on every turn so all ranks in the group reach the barrier.
        if args.sequential_initialization:
            torch.distributed.barrier(group=get_model_parallel_group())

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(f'> Model initialized in {time.time() - start:.1f}s')

    torch.cuda.empty_cache()
    model.eval()

    # generate rotary embedding cache
    transformer = model.transformer
    original_parallel_output = transformer.parallel_output
    transformer.parallel_output = True
    seq_len = args.max_sequence_length
    with torch.no_grad():
        dummy_tokens = torch.ones(
            1, seq_len, device=torch.cuda.current_device(), dtype=torch.int64)
        dummy_positions = torch.arange(
            seq_len, device=torch.cuda.current_device(),
            dtype=torch.int64).view(1, -1)
        # Random boolean mask; the outputs of this pass are discarded.
        dummy_mask = torch.randn(
            1, 1, seq_len, seq_len,
            device=torch.cuda.current_device()) < 0.5
        _, *_ = model(dummy_tokens, dummy_positions, dummy_mask)
    transformer.parallel_output = original_parallel_output
    torch.distributed.barrier()

    return model, tokenizer
|