gpt_moe_trainer.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from collections.abc import Mapping
  4. from typing import List
  5. import torch
  6. from megatron_util import mpu
  7. from modelscope.metainfo import Trainers
  8. from modelscope.models import TorchModel
  9. from modelscope.trainers.builder import TRAINERS
  10. from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
  11. from modelscope.utils.config import Config
  12. from modelscope.utils.file_utils import func_receive_dict_inputs
  @TRAINERS.register_module(module_name=Trainers.gpt_moe_trainer)
  class GPTMoETrainer(NlpEpochBasedTrainer):
      """Epoch-based trainer registered as ``Trainers.gpt_moe_trainer``.

      On top of ``NlpEpochBasedTrainer`` this class:

      * copies launch settings from environment variables into ``cfg.model``;
      * passes every training batch through ``mpu.broadcast_data`` before
        delegating to the parent's ``train_step``;
      * evaluates through ``model.generate`` and attaches decoded predictions
        and targets to the outputs.
      """

      def rebuild_config(self, cfg: Config):
          """Rebuild ``cfg`` via the parent, then add env-derived model fields.

          Args:
              cfg: The configuration to rebuild.

          Returns:
              The rebuilt configuration, with ``cfg.model.rank``,
              ``cfg.model.master_ip`` and ``cfg.model.master_port`` set.
          """
          # The parent hook may hand back a different Config object than the
          # one it received; keep that one so its changes are not dropped.
          # If the parent returns nothing, continue with the original argument.
          rebuilt = super().rebuild_config(cfg)
          if rebuilt is not None:
              cfg = rebuilt

          # LOCAL_RANK falls back to -1 when the variable is not set.
          cfg.model.rank = int(os.environ.get('LOCAL_RANK', -1))
          cfg.model.master_ip = os.environ.get('MASTER_ADDR', '127.0.0.1')
          # Kept as a string: no int() conversion is applied to the port.
          cfg.model.master_port = os.environ.get('MASTER_PORT', '29500')
          return cfg

      def train_step(self, model: TorchModel, inputs: Mapping):
          """Broadcast the batch, then run the parent's training step.

          Args:
              model: Forwarded unchanged to the parent's ``train_step``.
              inputs: Batch mapping; all of its keys are handed to
                  ``mpu.broadcast_data``.

          Returns:
              Whatever the parent's ``train_step`` returns.
          """
          keys = list(inputs)
          # NOTE(review): every entry is declared as int64 to broadcast_data;
          # presumably all batch values are integer tensors -- confirm.
          inputs = mpu.broadcast_data(keys, inputs, torch.int64)
          return super().train_step(model, inputs)

      def _decode(self, tokens):
          """Detokenize ``tokens`` with the evaluation preprocessor's tokenizer.

          Args:
              tokens: Object exposing ``tolist()`` (presumably a tensor of
                  token ids -- TODO confirm).

          Returns:
              The result of ``tokenizer.detokenize`` on the listed tokens.
          """
          tokenizer = self.eval_preprocessor.tokenizer
          return tokenizer.detokenize(tokens.tolist())

      def evaluation_step(self, data):
          """Generate for one evaluation batch and decode predictions/targets.

          Args:
              data: Evaluation batch. It is indexed with ``'prompt_length'``
                  and ``'labels'``, and is mutated: a ``'tgts'`` entry holding
                  the decoded targets is added.

          Returns:
              The output of ``model.generate`` with an added ``'preds'`` entry
              holding the decoded continuation of every generated sequence.
          """
          # Use the wrapped module when running distributed.
          model = self.model.module if self._dist else self.model
          model.eval()

          with torch.no_grad():
              # Unpack the batch into keyword arguments unless ``generate``
              # accepts the mapping as a single argument.
              unpack = isinstance(data, Mapping) \
                  and not func_receive_dict_inputs(model.generate)
              if unpack:
                  result = model.generate(**data)
              else:
                  result = model.generate(data)

          prompt_length: List[int] = data['prompt_length']

          # Drop the prompt tokens so only the generated part is decoded.
          result['preds'] = [
              self._decode(seq[skip_len:])
              for seq, skip_len in zip(result['sequences'], prompt_length)
          ]
          # Labels are cut one position earlier than the generated sequences
          # (presumably because labels are shifted by one token -- confirm).
          # The start is clamped at 0: a prompt length of 0 would otherwise
          # give the slice ``seq[-1:]`` and keep only the final token.
          data['tgts'] = [
              self._decode(seq[max(skip_len - 1, 0):])
              for seq, skip_len in zip(data['labels'], prompt_length)
          ]
          assert len(result['preds']) == len(data['tgts'])

          return result