gpt3_trainer.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from copy import deepcopy
  4. from typing import Any, Dict, List, Union
  5. import torch
  6. from torch import nn
  7. from modelscope.metainfo import Trainers
  8. from modelscope.models.base import Model, TorchModel
  9. from modelscope.models.nlp import GPT3ForTextGeneration
  10. from modelscope.trainers.builder import TRAINERS
  11. from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
  12. from modelscope.trainers.parallel.builder import build_parallel
  13. from modelscope.utils.config import Config
  14. from modelscope.utils.megatron_utils import is_megatron_initialized
  15. @TRAINERS.register_module(module_name=Trainers.gpt3_trainer)
  16. class GPT3Trainer(NlpEpochBasedTrainer):
  17. def rebuild_config(self, cfg: Config):
  18. cfg = super().rebuild_config(cfg)
  19. cfg.model.rank = int(os.environ.get('RANK', 0))
  20. return cfg
  21. def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
  22. # config format to reserve custom ddp
  23. if self.cfg.get('parallel', None) is not None:
  24. dp_cfg = deepcopy(self.cfg['parallel'])
  25. dp_cfg.update(
  26. dict(module=model, device_ids=[torch.cuda.current_device()]))
  27. return build_parallel(dp_cfg)
  28. dp_cfg = dict(
  29. type='DistributedDataParallel',
  30. module=model,
  31. find_unused_parameters=True,
  32. device_ids=[torch.cuda.current_device()])
  33. if is_megatron_initialized():
  34. from megatron_util import mpu
  35. dp_cfg.update({
  36. 'output_device': torch.cuda.current_device(),
  37. 'process_group': mpu.get_data_parallel_group()
  38. })
  39. return build_parallel(dp_cfg)
  40. def _decode(self, tokens):
  41. tokenizer = self.eval_preprocessor.tokenizer
  42. return tokenizer.detokenize(tokens.tolist())
  43. def evaluation_step(self, data):
  44. model = self.model.module if self._dist else self.model
  45. model.eval()
  46. if 'inputs_len' in data:
  47. return self._generate_eval(model, data)
  48. else:
  49. return self._forward_eval(model, data)
  50. def _generate_eval(self, model: GPT3ForTextGeneration,
  51. data: Dict[str, Any]) -> Dict[str, Any]:
  52. # Force greedy decoding in non-open tasks
  53. data.update(top_k=1, top_p=0.)
  54. result = model.generate(data)
  55. prompts_len: List[int] = data['prompts_len']
  56. result['preds'] = [
  57. self._decode(seq[skip_len:])
  58. for seq, skip_len in zip(result['sequences'], prompts_len)
  59. ]
  60. data['tgts'] = [
  61. self._decode(seq[skip_len - 1:])
  62. for seq, skip_len in zip(data['labels'], prompts_len)
  63. ]
  64. return result
  65. def _forward_eval(self, model: GPT3ForTextGeneration,
  66. data: Dict[str, Any]) -> Dict[str, Any]:
  67. return model.forward(data)
  68. def build_model(self) -> TorchModel:
  69. return Model.from_pretrained(
  70. self.model_dir, cfg_dict=self.cfg, megatron_cfg=self.cfg.megatron)