plug_trainer.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import os
  2. from typing import Union
  3. import torch
  4. from deepspeed import DeepSpeedEngine
  5. from megatron_util import mpu
  6. from torch import nn
  7. from modelscope.metainfo import Trainers
  8. from modelscope.models.base import TorchModel
  9. from modelscope.models.nlp.plug import DistributedPlug
  10. from modelscope.models.nlp.plug.backbone import BertLayerNorm
  11. from modelscope.models.nlp.plug.generator import TextGenerator
  12. from modelscope.utils.constant import ModeKeys
  13. from ..base import TRAINERS
  14. from ..nlp_trainer import NlpEpochBasedTrainer
  15. @TRAINERS.register_module(module_name=Trainers.nlp_plug_trainer)
  16. class PlugTrainer(NlpEpochBasedTrainer):
  17. def build_model(self) -> Union[nn.Module, TorchModel]:
  18. rank = int(os.environ.get('LOCAL_RANK', -1))
  19. master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
  20. master_port = os.environ.get('MASTER_PORT', '29500')
  21. model = DistributedPlug(
  22. self.model_dir,
  23. rank,
  24. master_ip=master_ip,
  25. master_port=master_port,
  26. **self.cfg.model)
  27. self.unwrap_module(model.model).model_dir = self.model_dir
  28. return model.model
  29. def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
  30. from modelscope.utils.nlp.distributed import DistributedDataParallel as DDP
  31. return DDP(model)
  32. def _get_params_for_weight_decay_optimization(self, module):
  33. weight_decay_params = {'params': []}
  34. no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
  35. for module_ in module.modules():
  36. if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)):
  37. no_weight_decay_params['params'].extend([
  38. p for p in list(module_._parameters.values())
  39. if p is not None
  40. ])
  41. else:
  42. weight_decay_params['params'].extend([
  43. p for n, p in list(module_._parameters.items())
  44. if p is not None and 'mask_score' not in n
  45. and 'mask' not in n and n != 'bias'
  46. ])
  47. no_weight_decay_params['params'].extend([
  48. p for n, p in list(module_._parameters.items())
  49. if p is not None and n == 'bias'
  50. ])
  51. return weight_decay_params, no_weight_decay_params
  52. def create_optimizer_and_scheduler(self):
  53. optimizer, lr_scheduler = self.optimizers
  54. optimizer_cfg = self.cfg.train.get('optimizer', None)
  55. # optim_options = {}
  56. if optimizer_cfg is not None:
  57. optim_options = optimizer_cfg.pop('options', {})
  58. from deepspeed.ops.adam import DeepSpeedCPUAdam
  59. model = self.model
  60. embeddings = model.module.model.bert.embeddings
  61. layers = model.module.model.bert.encoder.layer
  62. dec_layers = model.module.model.decoder.decoder
  63. param_groups = []
  64. param_groups += list(
  65. self._get_params_for_weight_decay_optimization(layers))
  66. param_groups += list(
  67. self._get_params_for_weight_decay_optimization(embeddings))
  68. param_groups += list(
  69. self._get_params_for_weight_decay_optimization(dec_layers))
  70. for param_group in param_groups:
  71. for param in param_group['params']:
  72. if not hasattr(param, 'model_parallel'):
  73. param.model_parallel = False
  74. optimizer = DeepSpeedCPUAdam(
  75. param_groups,
  76. lr=optimizer_cfg.lr,
  77. weight_decay=optimizer_cfg.weight_decay)
  78. lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None)
  79. if lr_scheduler_cfg is not None:
  80. assert optimizer is not None
  81. lr_options = lr_scheduler_cfg.pop('options', {})
  82. from modelscope.models.nlp.plug.AnnealingLR import AnnealingLR
  83. num_iters = self.max_iters
  84. lr_scheduler = AnnealingLR(
  85. optimizer,
  86. start_lr=optimizer_cfg.lr,
  87. warmup_iter=lr_scheduler_cfg.warmup * num_iters,
  88. num_iters=num_iters,
  89. decay_style=lr_scheduler_cfg.decay_style,
  90. last_iter=-1)
  91. self.optimizer = optimizer
  92. self.lr_scheduler = lr_scheduler
  93. return self.optimizer, self.lr_scheduler, optim_options, lr_options
  94. def _get_masks_and_position_ids(self, data, eod_token):
  95. # Extract batch size and sequence length.
  96. batch_size, seq_length = data.size()
  97. # Attention mask (lower triangular).
  98. att_mask_batch = 1
  99. attention_mask = torch.tril(
  100. torch.ones((att_mask_batch, seq_length, seq_length),
  101. device=data.device)).view(att_mask_batch, 1, seq_length,
  102. seq_length)
  103. # Loss mask.
  104. loss_mask = torch.ones(
  105. data.size(), dtype=torch.float, device=data.device)
  106. loss_mask[data == eod_token] = 0.0
  107. # Position ids.
  108. position_ids = torch.arange(
  109. seq_length, dtype=torch.long, device=data.device)
  110. position_ids = position_ids.unsqueeze(0).expand_as(data)
  111. return attention_mask, loss_mask, position_ids
  112. def train_step(self, model, inputs):
  113. self._mode = ModeKeys.TRAIN
  114. # format inputs
  115. checkpoint_activations = getattr(self.cfg.train,
  116. 'checkpoint_activations', True)
  117. tgt_tokens = inputs['labels'][:, :-1].contiguous()
  118. tgt_labels = inputs['labels'][:, 1:].contiguous()
  119. tgt_attention_mask, dec_loss_mask, position_ids = self._get_masks_and_position_ids(
  120. tgt_tokens, 0)
  121. if getattr(self.cfg.train, 'fp16', None):
  122. tgt_attention_mask = tgt_attention_mask.half()
  123. # forward step
  124. _, output = model(
  125. inputs['input_ids'],
  126. None,
  127. inputs['attention_mask'],
  128. tgt_tokens,
  129. position_ids,
  130. tgt_attention_mask,
  131. checkpoint_activations=checkpoint_activations)
  132. losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
  133. tgt_labels)
  134. dec_loss_mask = dec_loss_mask.view(-1)
  135. loss = torch.sum(losses.view(-1) * dec_loss_mask) / dec_loss_mask.sum()
  136. # add model output info to log
  137. self.train_outputs = {'loss': loss}
  138. self.log_buffer.update(self.train_outputs)
  139. def evaluation_step(self, data):
  140. # wrapper 1: DeepspeedEngine, wrapper 2: DDP
  141. # model = self.model.module
  142. if isinstance(self.model, DeepSpeedEngine):
  143. model = self.model.module
  144. else:
  145. model = self.model
  146. model.eval()
  147. # model: fp16 wapper; model.module : distributedPlug
  148. vocab_size = self.unwrap_module(self.model).config.original_vocab_size
  149. batch_size = data['input_ids'].shape[0]
  150. beam_generator = TextGenerator(model,
  151. self.eval_preprocessor.nlp_tokenizer,
  152. None)
  153. with torch.no_grad():
  154. tokens = data['input_ids'].long()
  155. padding_mask = data['attention_mask'].byte()
  156. target_ids = data['labels'].long()
  157. target_labels = target_ids[:, 1:].contiguous()
  158. encoder_inputs = [tokens, None, padding_mask]
  159. result = beam_generator.translate_batch(encoder_inputs)
  160. pred_list = result['predictions']
  161. target_list = target_labels.cpu().numpy().tolist()
  162. result['preds'] = []
  163. data['tgts'] = []
  164. for i in range(batch_size):
  165. pred_ids = pred_list[i][0]
  166. pred_ids[pred_ids > vocab_size - 1] = 100
  167. pred_ids = pred_ids.cpu().numpy().tolist()
  168. gold_string = self.eval_preprocessor.decode(
  169. target_list[i], skip_special_tokens=True)
  170. pred_string = self.eval_preprocessor.decode(
  171. pred_ids, skip_special_tokens=True)
  172. result['preds'].append(pred_string)
  173. data['tgts'].append(gold_string)
  174. return result