text_generation_trainer.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import torch
  4. from modelscope.metainfo import Metrics, Trainers
  5. from modelscope.outputs.outputs import ModelOutputBase
  6. from modelscope.trainers import NlpEpochBasedTrainer
  7. from modelscope.trainers.builder import TRAINERS
  8. @TRAINERS.register_module(module_name=Trainers.text_generation_trainer)
  9. class TextGenerationTrainer(NlpEpochBasedTrainer):
  10. def _decode(self, tokens):
  11. return self.eval_preprocessor.decode(
  12. tokens.tolist(), skip_special_tokens=True)
  13. def evaluation_step(self, data):
  14. model = self.model.module if self._dist else self.model
  15. model.eval()
  16. output = dict()
  17. with torch.no_grad():
  18. if Metrics.text_gen_metric in self.metrics:
  19. output.update(self._eval_genarate(model, data))
  20. if Metrics.PPL in self.metrics or Metrics.loss_metric in self.metrics:
  21. output.update(model.forward(**data))
  22. return output
  23. def _eval_genarate(self, model, data) -> Dict[str, Any]:
  24. result = model.generate(data)
  25. if isinstance(result, ModelOutputBase):
  26. result = result.to_dict()
  27. result['preds'] = [self._decode(seq) for seq in result['sequences']]
  28. data['tgts'] = [self._decode(seq) for seq in data['labels']]
  29. assert len(result['preds']) == len(data['tgts'])
  30. return result