| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from modelscope.metainfo import Trainers
- from modelscope.trainers import EpochBasedTrainer
- from modelscope.trainers.builder import TRAINERS
- from modelscope.utils.constant import TrainerStages
- from modelscope.utils.data_utils import to_device
- from modelscope.utils.logger import get_logger
# Module-level logger obtained from modelscope's logging utility.
logger = get_logger()
@TRAINERS.register_module(module_name=Trainers.speech_frcrn_ans_cirm_16k)
class ANSTrainer(EpochBasedTrainer):
    """
    A trainer used for acoustic noise suppression.

    Overrides ``train_loop()`` so the training data is iterated only once:
    a single iterator over the data loader is shared by all epochs, and each
    "epoch" consumes at most ``iters_per_epoch`` batches from it.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``EpochBasedTrainer``."""
        super().__init__(*args, **kwargs)

    def train_loop(self, data_loader):
        """
        Run training, updating the epoch by step count.

        Based on the parent implementation, but instead of restarting the
        data loader for every epoch, one iterator is created up front and
        shared across epochs. An epoch ends after ``iters_per_epoch`` steps
        or when the shared iterator is exhausted, whichever comes first.

        Args:
            data_loader: Iterable yielding training batches; each batch is
                moved to ``self.device`` before the training step.
        """
        self.invoke_hook(TrainerStages.before_run)
        # NOTE(review): the epoch counter is reset *after* the before_run
        # hooks fire, so any epoch value a hook sets there is discarded,
        # while ``self._iter`` is left untouched. Confirm this is intended.
        self._epoch = 0
        self.model.train()
        # One shared iterator: each epoch resumes where the previous one
        # stopped instead of restarting the dataset. Once it is exhausted,
        # any remaining epochs run zero steps but still fire their
        # before/after epoch hooks.
        batches = iter(data_loader)
        for _ in range(self._epoch, self._max_epochs):
            self.invoke_hook(TrainerStages.before_train_epoch)
            self._inner_iter = 0
            for data_batch in batches:
                data_batch = to_device(data_batch, self.device)
                self.data_batch = data_batch
                # NOTE(review): incremented before the iteration hooks run,
                # so hooks observe a 1-based inner iteration counter here;
                # confirm this matches what the registered hooks expect.
                self._inner_iter += 1
                self.invoke_hook(TrainerStages.before_train_iter)
                self.train_step(self.model, data_batch)
                self.invoke_hook(TrainerStages.after_train_iter)
                del self.data_batch
                self._iter += 1
                if self._inner_iter >= self.iters_per_epoch:
                    break
            self.invoke_hook(TrainerStages.after_train_epoch)
            self._epoch += 1
        self.invoke_hook(TrainerStages.after_run)

    def prediction_step(self, model, inputs):
        """Intentionally a no-op for this trainer; always returns ``None``."""
|