train_utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. # Copyright (c) 2022 Zhipu.AI
  2. import deepspeed
  3. import torch
  4. from apex.optimizers import FusedAdam as Adam
  5. from megatron_util import mpu
  6. from megatron_util.fp16 import DynamicLossScaler, FP16_Module, FP16_Optimizer
  7. from torch import distributed as dist
  8. from .model import DistributedDataParallel as LocalDDP
  9. from .model import (GLMForMultiTokenCloze, GLMForMultiTokenClozeFast,
  10. GLMForSequenceClassification, GLMForSingleTokenCloze,
  11. GLMModel)
  12. from .model import PyTorchDistributedDataParallel as TorchDDP
  13. from .model import glm_get_params_for_weight_decay_optimization
  14. from .utils import get_checkpoint_iteration, get_checkpoint_name, print_rank_0
  15. def load_pretrained(model, checkpoint_path, args, task_tokens=None):
  16. load_dir, tag, release, success = get_checkpoint_iteration(checkpoint_path)
  17. checkpoint_name = get_checkpoint_name(load_dir, tag, release)
  18. if mpu.get_data_parallel_rank() == 0:
  19. print('global rank {} is loading pretrained model {}'.format(
  20. torch.distributed.get_rank(), checkpoint_name))
  21. # Load the checkpoint.
  22. sd = torch.load(checkpoint_name, map_location='cpu')
  23. if args.deepspeed:
  24. model = model.module
  25. if isinstance(model, TorchDDP):
  26. model = model.module
  27. if isinstance(model, FP16_Module):
  28. model = model.module
  29. if hasattr(model, 'model'):
  30. model = model.model
  31. # Model.
  32. def extend_embedding_weights(state_weights, model_weights):
  33. original_length = state_weights.shape[0]
  34. assert original_length <= args.max_position_embeddings + 1
  35. new_weights = model_weights.clone()
  36. new_weights[:original_length] = state_weights
  37. return new_weights
  38. if args.block_lm:
  39. if 'transformer.block_position_embeddings.weight' in sd['module']:
  40. position_weights = sd['module'][
  41. 'transformer.position_embeddings.weight']
  42. if args.max_position_embeddings + 1 > position_weights.shape[0]:
  43. sd['module'][
  44. 'transformer.position_embeddings.weight'] = extend_embedding_weights(
  45. position_weights,
  46. model.state_dict()
  47. ['transformer.position_embeddings.weight'].data)
  48. print_rank_0(
  49. f'Extend position embedding to {args.max_position_embeddings + 1}'
  50. )
  51. if 'transformer.block_position_embeddings.weight' in sd['module']:
  52. block_position_weights = sd['module'][
  53. 'transformer.block_position_embeddings.weight']
  54. if args.max_position_embeddings + 1 > block_position_weights.shape[
  55. 0]:
  56. sd['module'][
  57. 'transformer.block_position_embeddings.weight'] = extend_embedding_weights(
  58. block_position_weights,
  59. model.state_dict()
  60. ['transformer.block_position_embeddings.weight'].data)
  61. print_rank_0(
  62. f'Extend block position embedding to {args.max_position_embeddings + 1}'
  63. )
  64. for key in list(model.state_dict().keys()):
  65. print(key)
  66. model.state_dict()[key.replace(
  67. 'mixins.block_position_embedding.block_position_embeddings.weight',
  68. 'transformer.block_position_embeddings.weight').replace(
  69. 'transformer.word_embeddings.weight',
  70. 'word_embeddings.weight')] = model.state_dict().pop(key)
  71. missing_keys, unexpected_keys = model.load_state_dict(
  72. sd['module'], strict=False)
  73. if missing_keys or unexpected_keys:
  74. print_rank_0(
  75. f'Missing keys {missing_keys}, unexpected keys {unexpected_keys}')
  76. if args.continuous_prompt and args.prompt_init:
  77. model.prompt_spell.init_embedding(model.word_embeddings.weight.data,
  78. task_tokens)
  79. def get_model(args,
  80. model_type=None,
  81. multi_token=True,
  82. num_labels=None,
  83. spell_length=None):
  84. """Build the model."""
  85. print_rank_0('building GPT2 model ...')
  86. if args.pretrained_bert:
  87. if model_type == 'multiple_choice':
  88. model = BertForMultipleChoice.from_pretrained(
  89. args.tokenizer_model_type,
  90. cache_dir=args.cache_dir,
  91. fp32_layernorm=args.fp32_layernorm,
  92. fp32_embedding=args.fp32_embedding,
  93. layernorm_epsilon=args.layernorm_epsilon)
  94. elif model_type == 'classification':
  95. model = BertForSequenceClassification.from_pretrained(
  96. args.tokenizer_model_type,
  97. cache_dir=args.cache_dir,
  98. fp32_layernorm=args.fp32_layernorm,
  99. fp32_embedding=args.fp32_embedding,
  100. layernorm_epsilon=args.layernorm_epsilon,
  101. num_labels=num_labels)
  102. else:
  103. raise NotImplementedError
  104. else:
  105. output_predict, paralle_output = True, True
  106. if (model_type == 'multiple_choice'
  107. or model_type == 'classification') and not args.cloze_eval:
  108. output_predict = False
  109. if model_type is not None:
  110. paralle_output = False
  111. if spell_length is not None:
  112. print_rank_0(f'Continuous spell length {spell_length}')
  113. model = GLMModel(
  114. num_layers=args.num_layers,
  115. vocab_size=args.vocab_size,
  116. hidden_size=args.hidden_size,
  117. num_attention_heads=args.num_attention_heads,
  118. embedding_dropout_prob=args.hidden_dropout,
  119. attention_dropout_prob=args.attention_dropout,
  120. output_dropout_prob=args.hidden_dropout,
  121. max_sequence_length=args.max_position_embeddings,
  122. max_memory_length=args.mem_length,
  123. checkpoint_activations=args.checkpoint_activations,
  124. checkpoint_num_layers=args.checkpoint_num_layers,
  125. parallel_output=paralle_output,
  126. relative_encoding=args.transformer_xl,
  127. block_position_encoding=args.block_lm and not args.masked_lm,
  128. output_predict=output_predict,
  129. spell_length=spell_length,
  130. spell_func=args.prompt_func,
  131. attention_scale=args.attention_scale)
  132. if args.freeze_transformer:
  133. model.freeze_transformer(
  134. tune_prefix_layers=args.tune_prefix_layers)
  135. if model_type is not None:
  136. if model_type == 'multiple_choice':
  137. if args.cloze_eval:
  138. if multi_token:
  139. if args.fast_decode:
  140. model = GLMForMultiTokenClozeFast(
  141. model, length_penalty=args.length_penalty)
  142. else:
  143. model = GLMForMultiTokenCloze(
  144. model, length_penalty=args.length_penalty)
  145. else:
  146. model = GLMForSingleTokenCloze(
  147. model, take_softmax=args.adapet)
  148. else:
  149. model = GLMForSequenceClassification(
  150. model,
  151. args.hidden_size,
  152. args.output_dropout,
  153. args.pool_token,
  154. num_class=num_labels)
  155. elif model_type == 'classification':
  156. model = GLMForSequenceClassification(
  157. model,
  158. args.hidden_size,
  159. args.output_dropout,
  160. args.pool_token,
  161. num_class=num_labels)
  162. elif model_type == 'generation':
  163. pass
  164. else:
  165. raise NotImplementedError(model_type)
  166. if mpu.get_data_parallel_rank() == 0:
  167. print(
  168. ' > number of parameters on model parallel rank {}: {}'.format(
  169. mpu.get_model_parallel_rank(),
  170. sum([p.nelement() for p in model.parameters()])),
  171. flush=True)
  172. # To prevent OOM for model sizes that cannot fit in GPU memory in full precision
  173. if args.fp16:
  174. model.half()
  175. # GPU allocation.
  176. model.cuda(torch.cuda.current_device())
  177. # Fp16 conversion.
  178. if args.fp16:
  179. model = FP16_Module(model)
  180. # Wrap model for distributed training.
  181. if not args.deepspeed and (args.train_iters or args.epochs):
  182. if args.DDP_impl == 'torch':
  183. i = torch.cuda.current_device()
  184. model = TorchDDP(
  185. model,
  186. device_ids=[i],
  187. output_device=i,
  188. process_group=mpu.get_data_parallel_group())
  189. elif args.DDP_impl == 'local':
  190. model = LocalDDP(model)
  191. else:
  192. print_rank_0('Skip DDP model')
  193. return model
  194. def get_optimizer_param_groups(model):
  195. # Build parameter groups (weight decay and non-decay).
  196. while isinstance(model, (LocalDDP, TorchDDP, FP16_Module)):
  197. model = model.module
  198. param_groups = glm_get_params_for_weight_decay_optimization(model)
  199. # Add model parallel attribute if it is not set.
  200. for param_group in param_groups:
  201. # print('## param_group', len(param_group['params']))
  202. for param in param_group['params']:
  203. if not hasattr(param, 'model_parallel'):
  204. param.model_parallel = False
  205. return param_groups
  206. def get_optimizer(param_groups, args):
  207. """Set up the optimizer."""
  208. if args.cpu_optimizer:
  209. # Apex FusedAdam uses decoupled weight decay so use the same here
  210. if args.cpu_torch_adam:
  211. cpu_adam_optimizer = torch.optim.AdamW
  212. else:
  213. from deepspeed.ops.adam import DeepSpeedCPUAdam
  214. cpu_adam_optimizer = DeepSpeedCPUAdam
  215. optimizer = cpu_adam_optimizer(
  216. param_groups, lr=args.lr, weight_decay=args.weight_decay)
  217. else:
  218. # Use FusedAdam.
  219. if args.optimizer == 'adam':
  220. optimizer = Adam(
  221. param_groups,
  222. lr=args.lr,
  223. weight_decay=args.weight_decay,
  224. betas=(args.adam_beta1, args.adam_beta2),
  225. eps=args.adam_eps)
  226. elif args.optimizer == 'adafactor':
  227. from transformers import Adafactor
  228. optimizer = Adafactor(
  229. param_groups,
  230. lr=args.lr,
  231. relative_step=False,
  232. warmup_init=False)
  233. else:
  234. raise NotImplementedError
  235. print(f'Optimizer = {optimizer.__class__.__name__}')
  236. if hasattr(args, 'deepspeed') and args.deepspeed:
  237. raise NotImplementedError
  238. # fp16 wrapper is not required for DeepSpeed.
  239. # return optimizer
  240. # Wrap into fp16 optimizer.
  241. if args.fp16:
  242. optimizer = FP16_Optimizer(
  243. optimizer,
  244. static_loss_scale=args.loss_scale,
  245. dynamic_loss_scale=args.dynamic_loss_scale,
  246. dynamic_loss_args={
  247. 'scale_window': args.loss_scale_window,
  248. 'min_scale': args.min_scale,
  249. 'delayed_shift': args.hysteresis
  250. })
  251. return optimizer
  252. def get_learning_rate_scheduler(optimizer, args):
  253. """Build the learning rate scheduler."""
  254. # Add linear learning rate scheduler.
  255. if args.lr_decay_iters is not None:
  256. num_iters = args.lr_decay_iters
  257. else:
  258. num_iters = args.train_iters
  259. if args.finetune:
  260. num_iters = num_iters // args.gradient_accumulation_steps
  261. num_iters = max(1, num_iters)
  262. init_step = -1
  263. warmup_iter = args.warmup * num_iters
  264. lr_scheduler = AnnealingLR(
  265. optimizer,
  266. start_lr=args.lr,
  267. warmup_iter=warmup_iter,
  268. num_iters=num_iters - warmup_iter,
  269. decay_style=args.lr_decay_style,
  270. last_iter=init_step,
  271. decay_ratio=args.lr_decay_ratio)
  272. return lr_scheduler
  273. def setup_model_and_optimizer(args,
  274. model_type=None,
  275. multi_token=True,
  276. num_labels=None,
  277. spell_length=None):
  278. """Setup model and optimizer."""
  279. model = get_model(
  280. args,
  281. model_type=model_type,
  282. multi_token=multi_token,
  283. num_labels=num_labels,
  284. spell_length=spell_length)
  285. param_groups = get_optimizer_param_groups(model)
  286. if args.train_data is not None or args.data_dir is not None and (
  287. args.epochs > 0 or args.train_iters > 0):
  288. if args.deepspeed:
  289. print_rank_0('DeepSpeed is enabled.')
  290. model, optimizer, _, _ = deepspeed.initialize(
  291. model=model,
  292. model_parameters=param_groups,
  293. args=args,
  294. mpu=mpu,
  295. dist_init_required=False)
  296. else:
  297. optimizer = get_optimizer(param_groups, args)
  298. lr_scheduler = get_learning_rate_scheduler(optimizer, args)
  299. else:
  300. optimizer, lr_scheduler = None, None
  301. return model, optimizer, lr_scheduler
  302. def backward_step(optimizer, model, lm_loss, args, timers):
  303. """Backward step."""
  304. # Total loss.
  305. loss = lm_loss
  306. # Backward pass.
  307. if args.deepspeed:
  308. model.backward(loss)
  309. else:
  310. # optimizer.zero_grad()
  311. if args.fp16:
  312. optimizer.backward(loss, update_master_grads=False)
  313. else:
  314. loss.backward()
  315. if args.deepspeed or args.DDP_impl == 'torch':
  316. # DeepSpeed backward propagation already addressed all reduce communication.
  317. # Reset the timer to avoid breaking timer logs below.
  318. timers('allreduce').reset()
  319. else:
  320. timers('allreduce').start()
  321. model.allreduce_params(
  322. reduce_after=False, fp32_allreduce=args.fp32_allreduce)
  323. timers('allreduce').stop()
  324. # Update master gradients.
  325. if not args.deepspeed:
  326. if args.fp16:
  327. optimizer.update_master_grads()
  328. # Clipping gradients helps prevent the exploding gradient.
  329. if args.clip_grad > 0:
  330. if not args.fp16:
  331. mpu.clip_grad_norm(model.parameters(), args.clip_grad)
  332. else:
  333. optimizer.clip_master_grads(args.clip_grad)
  334. return lm_loss
  335. def see_memory_usage(message, force=False):
  336. if not force:
  337. return
  338. dist.barrier()
  339. if dist.get_rank() == 0:
  340. print(message)
  341. print('Memory Allocated ',
  342. torch.cuda.memory_allocated() / (1024 * 1024 * 1024),
  343. 'GigaBytes')
  344. print('Max Memory Allocated ',
  345. torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),
  346. 'GigaBytes')
  347. print('Cache Allocated ',
  348. torch.cuda.memory_cached() / (1024 * 1024 * 1024), 'GigaBytes')
  349. print('Max cache Allocated ',
  350. torch.cuda.max_memory_cached() / (1024 * 1024 * 1024),
  351. 'GigaBytes')
  352. print(' ')
  353. # input("Press Any Key To Continue ..")
  354. def train_step(data_iterator,
  355. model,
  356. optimizer,
  357. lr_scheduler,
  358. args,
  359. timers,
  360. forward_step_func,
  361. mems=None,
  362. single_step=False):
  363. """Single training step."""
  364. lm_loss_total, count = 0.0, 0
  365. mems = [] if mems is None else mems
  366. if not args.deepspeed:
  367. optimizer.zero_grad()
  368. while True:
  369. skipped_iter, complete = 0, False
  370. # Forward model for one step.
  371. timers('forward').start()
  372. lm_loss, mems, _ = forward_step_func(data_iterator, model, args,
  373. timers, mems)
  374. timers('forward').stop()
  375. # print_rank_0("Forward step")
  376. if not args.deepspeed:
  377. lm_loss /= args.gradient_accumulation_steps
  378. reduced_loss = lm_loss.detach().clone().view(1)
  379. torch.distributed.all_reduce(
  380. reduced_loss.data, group=mpu.get_data_parallel_group())
  381. reduced_loss.data = reduced_loss.data / (
  382. args.world_size / args.model_parallel_size)
  383. if not DynamicLossScaler._has_inf_or_nan(reduced_loss):
  384. lm_loss_total += reduced_loss
  385. count += 1
  386. # Calculate gradients, reduce across processes, and clip.
  387. timers('backward').start()
  388. backward_step(optimizer, model, lm_loss, args, timers)
  389. timers('backward').stop()
  390. # print_rank_0("Backward step")
  391. # Update parameters.
  392. timers('optimizer').start()
  393. if args.deepspeed:
  394. if model.is_gradient_accumulation_boundary():
  395. model.step()
  396. complete = True
  397. if not (args.fp16 and optimizer.overflow):
  398. lr_scheduler.step()
  399. else:
  400. skipped_iter = 1
  401. else:
  402. model.step()
  403. else:
  404. if count == args.gradient_accumulation_steps:
  405. optimizer.step()
  406. complete = True
  407. # Update learning rate.
  408. if not (args.fp16 and optimizer.overflow):
  409. lr_scheduler.step()
  410. else:
  411. skipped_iter = 1
  412. # print_rank_0("Optimizer step")
  413. timers('optimizer').stop()
  414. if complete:
  415. break
  416. else:
  417. print_rank_0('Found NaN loss, skip backward')
  418. del lm_loss, reduced_loss
  419. mems = []
  420. if single_step:
  421. break
  422. if args.deepspeed:
  423. lm_loss_total = lm_loss_total / count
  424. return lm_loss_total, skipped_iter, mems