ans_trainer.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from modelscope.metainfo import Trainers
  3. from modelscope.trainers import EpochBasedTrainer
  4. from modelscope.trainers.builder import TRAINERS
  5. from modelscope.utils.constant import TrainerStages
  6. from modelscope.utils.data_utils import to_device
  7. from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's logging utility.
logger = get_logger()
  9. @TRAINERS.register_module(module_name=Trainers.speech_frcrn_ans_cirm_16k)
  10. class ANSTrainer(EpochBasedTrainer):
  11. """
  12. A trainer is used for acoustic noise suppression.
  13. Override train_loop() to use dataset just one time.
  14. """
  15. def __init__(self, *args, **kwargs):
  16. super().__init__(*args, **kwargs)
  17. def train_loop(self, data_loader):
  18. """
  19. Update epoch by step number, based on super method.
  20. """
  21. self.invoke_hook(TrainerStages.before_run)
  22. self._epoch = 0
  23. kwargs = {}
  24. self.model.train()
  25. enumerated = enumerate(data_loader)
  26. for _ in range(self._epoch, self._max_epochs):
  27. self.invoke_hook(TrainerStages.before_train_epoch)
  28. self._inner_iter = 0
  29. for i, data_batch in enumerated:
  30. data_batch = to_device(data_batch, self.device)
  31. self.data_batch = data_batch
  32. self._inner_iter += 1
  33. self.invoke_hook(TrainerStages.before_train_iter)
  34. self.train_step(self.model, data_batch, **kwargs)
  35. self.invoke_hook(TrainerStages.after_train_iter)
  36. del self.data_batch
  37. self._iter += 1
  38. if self._inner_iter >= self.iters_per_epoch:
  39. break
  40. self.invoke_hook(TrainerStages.after_train_epoch)
  41. self._epoch += 1
  42. self.invoke_hook(TrainerStages.after_run)
  43. def prediction_step(self, model, inputs):
  44. pass