initialize.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Copyright (c) 2022 Zhipu.AI
  2. import argparse
  3. import time
  4. import torch
  5. from SwissArmyTransformer import get_args, get_tokenizer
  6. from SwissArmyTransformer.arguments import initialize_distributed
  7. from SwissArmyTransformer.model import GLM130B
  8. from SwissArmyTransformer.mpu import (get_model_parallel_group,
  9. get_model_parallel_rank,
  10. get_model_parallel_world_size)
  11. from SwissArmyTransformer.training import load_checkpoint
  12. from .quantization import quantize
  13. def add_bminf_args(parser):
  14. """Arguments for BMInf"""
  15. group = parser.add_argument_group('BMInf')
  16. group.add_argument(
  17. '--bminf',
  18. action='store_true',
  19. help='Use BMInf to support low resource evaluation')
  20. group.add_argument(
  21. '--bminf-memory-limit',
  22. type=int,
  23. default=20,
  24. help='Max memory for model per GPU (in GB)')
  25. return parser
  26. def add_quantization_args(parser):
  27. group = parser.add_argument_group('Quantization')
  28. group.add_argument('--quantization-bit-width', type=int, default=4)
  29. group.add_argument(
  30. '--from-quantized-checkpoint',
  31. type=bool,
  32. default=True,
  33. help='Loading from a quantized checkpoint')
  34. def add_initialization_args(parser):
  35. group = parser.add_argument_group('Initialization')
  36. group.add_argument(
  37. '--sequential-initialization',
  38. action='store_true',
  39. help=
  40. 'Initialize sequentially in tensor parallel group (reduce CPU RAM for initialization)',
  41. )
  42. def set_up_model_args(args):
  43. args.model_parallel_size = 4
  44. args.num_layers = 70
  45. args.hidden_size = 12288
  46. args.inner_hidden_size = 32768
  47. args.vocab_size = 150528
  48. args.num_attention_heads = 96
  49. args.max_sequence_length = 2048
  50. args.tokenizer_type = 'icetk-glm-130B'
  51. args.layernorm_order = 'post'
  52. args.skip_init = True
  53. args.fp16 = True
  54. args.mode = 'inference'
  55. return args
  56. def initialize(extra_args_provider):
  57. parser = argparse.ArgumentParser(add_help=False)
  58. add_bminf_args(parser)
  59. add_quantization_args(parser)
  60. add_initialization_args(parser)
  61. GLM130B.add_model_specific_args(parser)
  62. extra_args_provider(parser)
  63. known, args_list = parser.parse_known_args()
  64. args_list += ['--model-parallel-size', '4', '--mode', 'inference']
  65. args = get_args(args_list)
  66. args = set_up_model_args(args)
  67. args = argparse.Namespace(**vars(args), **vars(known))
  68. args.do_train = False
  69. initialize_distributed(args)
  70. return args
  71. def initialize_model_and_tokenizer(args):
  72. tokenizer = get_tokenizer(args)
  73. torch.distributed.barrier()
  74. start = time.time()
  75. for i in range(get_model_parallel_world_size()):
  76. if get_model_parallel_rank() == i:
  77. # Initialize model
  78. model = GLM130B(args).half()
  79. if args.from_quantized_checkpoint:
  80. assert args.quantization_bit_width is not None
  81. # Quantize model before moving to GPU
  82. model = quantize(model, args.quantization_bit_width)
  83. # Load checkpoint
  84. load_checkpoint(model, args)
  85. if args.quantization_bit_width is not None and not args.from_quantized_checkpoint:
  86. # Quantize model before moving to GPU
  87. model = quantize(model, args.quantization_bit_width)
  88. if args.bminf:
  89. import bminf
  90. if torch.distributed.get_rank() == 0:
  91. print(
  92. f'> BMInf activated, memory limit: {args.bminf_memory_limit} GB'
  93. )
  94. with torch.cuda.device(args.device):
  95. model = bminf.wrapper(
  96. model,
  97. quantization=False,
  98. memory_limit=args.bminf_memory_limit << 30)
  99. else:
  100. model = model.to(args.device)
  101. if args.sequential_initialization:
  102. torch.distributed.barrier(group=get_model_parallel_group())
  103. torch.distributed.barrier()
  104. if torch.distributed.get_rank() == 0:
  105. print(f'> Model initialized in {time.time() - start:.1f}s')
  106. torch.cuda.empty_cache()
  107. model.eval()
  108. # generate rotary embedding cache
  109. original_parallel_output = model.transformer.parallel_output
  110. model.transformer.parallel_output = True
  111. with torch.no_grad():
  112. _, *_ = model(
  113. torch.ones(
  114. 1,
  115. args.max_sequence_length,
  116. device=torch.cuda.current_device(),
  117. dtype=torch.int64),
  118. torch.arange(
  119. args.max_sequence_length,
  120. device=torch.cuda.current_device(),
  121. dtype=torch.int64).view(1, -1),
  122. torch.randn(
  123. 1,
  124. 1,
  125. args.max_sequence_length,
  126. args.max_sequence_length,
  127. device=torch.cuda.current_device(),
  128. ) < 0.5,
  129. )
  130. model.transformer.parallel_output = original_parallel_output
  131. torch.distributed.barrier()
  132. return model, tokenizer