| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import csv
- import os
- from typing import Dict, Optional, Union
- import numpy as np
- import speechbrain as sb
- import speechbrain.nnet.schedulers as schedulers
- import torch
- import torch.nn.functional as F
- import torchaudio
- from torch.cuda.amp import autocast
- from torch.utils.data import Dataset
- from tqdm import tqdm
- from modelscope.metainfo import Trainers
- from modelscope.models import Model, TorchModel
- from modelscope.msdatasets import MsDataset
- from modelscope.trainers.base import BaseTrainer
- from modelscope.trainers.builder import TRAINERS
- from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
- from modelscope.utils.device import create_device
- from modelscope.utils.logger import get_logger
- from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
- init_dist)
# Metric name used both as the key of the dict returned by
# SeparationTrainer.evaluate() and as the `min_key` handed to the brain's
# evaluate() call.
EVAL_KEY = 'si-snr'
# Module-level logger shared by the trainer and the Separation brain.
logger = get_logger()
- @TRAINERS.register_module(module_name=Trainers.speech_separation)
- class SeparationTrainer(BaseTrainer):
- """A trainer is used for speech separation.
- Args:
- model: id or local path of the model
- work_dir: local path to store all training outputs
- cfg_file: config file of the model
- train_dataset: dataset for training
- eval_dataset: dataset for evaluation
- model_revision: the git version of model on modelhub
- """
def __init__(self,
             model: str,
             work_dir: str,
             cfg_file: Optional[str] = None,
             train_dataset: Optional[Union[MsDataset, Dataset]] = None,
             eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
             model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
             **kwargs):
    """Build the model, load the hyper-parameters and create the brain.

    Args:
        model: id or local path of the model; resolved with
            ``get_or_download_model_dir`` when it is a string.
        work_dir: local path to store all training outputs; also passed
            to the hyper-parameter file as ``output_folder``.
        cfg_file: config file of the model. When omitted and ``model`` is
            a string, the configuration file inside the model directory
            is used.
        train_dataset: dataset for training.
        eval_dataset: dataset for evaluation.
        model_revision: the git version of model on modelhub.
        **kwargs: optional settings read here:
            ``launcher`` (forwarded to ``init_dist``),
            ``device`` (defaults to ``'gpu'``) and
            ``max_epochs`` (overrides ``cfg.train.max_epochs``).
    """
    if isinstance(model, str):
        self.model_dir = self.get_or_download_model_dir(
            model, model_revision)
        if cfg_file is None:
            cfg_file = os.path.join(self.model_dir,
                                    ModelFile.CONFIGURATION)
    else:
        assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
        self.model_dir = os.path.dirname(cfg_file)

    BaseTrainer.__init__(self, cfg_file)
    self.model = self.build_model()
    self.work_dir = work_dir

    # Distributed setup: the launcher is optional, the world size decides
    # whether we run in distributed mode.
    launcher = kwargs.get('launcher', None)
    if launcher is not None:
        init_dist(launcher)
    _, world_size = get_dist_info()
    self._dist = world_size > 1

    device_name = kwargs.get('device', 'gpu')
    if self._dist:
        # In distributed mode every process is pinned to its local rank.
        device_name = f'cuda:{get_local_rank()}'
    self.device = create_device(device_name)

    if 'max_epochs' in kwargs:
        self._max_epochs = kwargs['max_epochs']
    else:
        assert hasattr(
            self.cfg.train, 'max_epochs'
        ), 'max_epochs is missing from the configuration file'
        self._max_epochs = self.cfg.train.max_epochs

    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset

    # Values from the modelscope configuration override the entries of
    # the hparams.yaml shipped next to the model.
    train_cfg = self.cfg.train
    hparams_file = os.path.join(self.model_dir, 'hparams.yaml')
    overrides = {
        'output_folder': self.work_dir,
        'seed': train_cfg.seed,
        'lr': train_cfg.optimizer.lr,
        'weight_decay': train_cfg.optimizer.weight_decay,
        'clip_grad_norm': train_cfg.optimizer.clip_grad_norm,
        'factor': train_cfg.lr_scheduler.factor,
        'patience': train_cfg.lr_scheduler.patience,
        'dont_halve_until_epoch':
        train_cfg.lr_scheduler.dont_halve_until_epoch,
    }

    # load hyper params
    from hyperpyyaml import load_hyperpyyaml
    with open(hparams_file, encoding='utf-8') as fin:
        self.hparams = load_hyperpyyaml(fin, overrides=overrides)

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=self.work_dir,
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    run_opts = {
        'debug': False,
        'device': 'cpu',
        'data_parallel_backend': False,
        'distributed_launch': False,
        'distributed_backend': 'nccl',
        'find_unused_parameters': False
    }
    if self.device.type == 'cuda':
        # A torch.device created from plain 'cuda' has index None; fall
        # back to the current CUDA device instead of emitting 'cuda:None'.
        cuda_index = self.device.index
        if cuda_index is None:
            cuda_index = torch.cuda.current_device()
        run_opts['device'] = f'{self.device.type}:{cuda_index}'

    # The epoch counter and the model modules are registered with the
    # checkpointer so they are saved/restored together with checkpoints.
    self.epoch_counter = sb.utils.epoch_loop.EpochCounter(self._max_epochs)
    self.hparams['epoch_counter'] = self.epoch_counter
    checkpointer = self.hparams['checkpointer']
    checkpointer.add_recoverables({'counter': self.epoch_counter})
    modules = self.model.as_dict()
    checkpointer.add_recoverables(modules)

    # Brain class initialization
    self.separator = Separation(
        modules=modules,
        opt_class=self.hparams['optimizer'],
        hparams=self.hparams,
        run_opts=run_opts,
        checkpointer=checkpointer,
    )
def build_model(self) -> torch.nn.Module:
    """Instantiate a pytorch model and return it.

    The model is loaded from ``self.model_dir`` with ``self.cfg`` in
    training mode. A ``TorchModel`` wrapper exposing a ``model``
    attribute is unwrapped; any other ``torch.nn.Module`` is returned
    as is.

    Raises:
        TypeError: if the loaded object is not a ``torch.nn.Module``.
            Previously ``None`` was returned implicitly in that case and
            the failure only showed up later in the constructor.
    """
    model = Model.from_pretrained(
        self.model_dir, cfg_dict=self.cfg, training=True)
    if isinstance(model, TorchModel) and hasattr(model, 'model'):
        return model.model
    if isinstance(model, torch.nn.Module):
        return model
    raise TypeError(
        'Expected a torch.nn.Module from Model.from_pretrained, got {}'.
        format(type(model).__name__))
def train(self, *args, **kwargs):
    """Run the brain's ``fit`` loop on the training and evaluation datasets.

    Both data loaders are configured from the ``dataloader_opts`` entry of
    the hyper-parameters. Positional and keyword arguments are accepted
    but not used.
    """
    loader_opts = self.hparams['dataloader_opts']
    fit_kwargs = {
        'train_loader_kwargs': loader_opts,
        'valid_loader_kwargs': self.hparams['dataloader_opts'],
    }
    self.separator.fit(self.epoch_counter, self.train_dataset,
                       self.eval_dataset, **fit_kwargs)
def evaluate(self, checkpoint_path: str, *args,
             **kwargs) -> Dict[str, float]:
    """Evaluate the model on ``self.eval_dataset``.

    Args:
        checkpoint_path: directory holding the checkpoints to evaluate.
            When empty/None, the model's own ``load_check_point`` is
            called instead.

    Returns:
        A dict mapping ``EVAL_KEY`` to the value returned by the brain's
        ``evaluate`` call.
    """
    if checkpoint_path:
        from pathlib import Path

        # self.hparams is the mapping returned by load_hyperpyyaml (it is
        # subscripted everywhere else in this class), so the checkpointer
        # must be fetched by key; attribute access raises AttributeError.
        # NOTE(review): SpeechBrain's Checkpointer keeps checkpoints_dir
        # as a pathlib.Path, hence the conversion - confirm against the
        # installed SpeechBrain version.
        self.hparams['checkpointer'].checkpoints_dir = Path(
            checkpoint_path)
    else:
        self.model.load_check_point(device=self.device)
    value = self.separator.evaluate(
        self.eval_dataset,
        test_loader_kwargs=self.hparams['dataloader_opts'],
        min_key=EVAL_KEY)
    return {EVAL_KEY: value}
- class Separation(sb.Brain):
- """A subclass of speechbrain.Brain implements training steps."""
def compute_forward(self, mix, targets, stage, noise=None):
    """Run the encoder / mask network / decoder chain on a mixture.

    Args:
        mix: pair ``(signal, lengths)``; both are moved to ``self.device``.
        targets: per-speaker entries whose first element is the source
            signal; the first ``hparams.num_spks`` entries are used.
        stage: ``sb.Stage`` value; augmentation only runs for TRAIN.
        noise: accepted but not used.

    Returns:
        ``(est_source, targets)`` where the sources are stacked on the
        last dimension and ``est_source`` has the same length on
        dimension 1 as the (possibly augmented) mixture.
    """
    num_spks = self.hparams.num_spks

    # Unpack the mixture and put tensors on the right device
    mix_sig, mix_lens = mix
    mix = mix_sig.to(self.device)
    mix_lens = mix_lens.to(self.device)

    # Convert targets to a single tensor with one source per last-dim slot
    source_list = [targets[spk][0] for spk in range(num_spks)]
    targets = torch.stack(source_list, dim=-1).to(self.device)

    # Add speech distortions (training only, never tracked by autograd)
    if stage == sb.Stage.TRAIN:
        with torch.no_grad():
            if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
                mix, targets = self.add_speed_perturb(targets, mix_lens)
                mix = targets.sum(-1)
            if self.hparams.use_wavedrop:
                mix = self.hparams.wavedrop(mix, mix_lens)
            if self.hparams.limit_training_signal_len:
                mix, targets = self.cut_signals(mix, targets)

    # Separation: one mask per speaker applied to the encoded mixture
    encoded = self.modules['encoder'](mix)
    masks = self.modules['masknet'](encoded)
    encoded = torch.stack([encoded] * num_spks)
    masked = encoded * masks

    # Decoding
    decoder = self.modules['decoder']
    decoded = [decoder(masked[spk]) for spk in range(num_spks)]
    est_source = torch.stack(decoded, dim=-1)

    # T changed after conv1d in encoder, fix it here
    len_mix = mix.size(1)
    len_est = est_source.size(1)
    if len_mix > len_est:
        est_source = F.pad(est_source, (0, 0, 0, len_mix - len_est))
    else:
        est_source = est_source[:, :len_mix, :]

    return est_source, targets
def compute_objectives(self, predictions, targets):
    """Apply the configured loss (``hparams.loss``) to ``(targets, predictions)``."""
    loss_fn = self.hparams.loss
    return loss_fn(targets, predictions)
- # yapf: disable
def fit_batch(self, batch):
    """Train on one batch and return the detached loss on the CPU.

    The forward pass and loss reduction run under ``autocast`` when
    ``self.auto_mix_prec`` is set. With ``hparams.threshold_byloss`` only
    the items whose loss exceeds ``hparams.threshold`` contribute to the
    mean.

    The batch is skipped (no backward pass or optimizer step,
    ``self.nonfinite_count`` incremented, zero loss returned) when

    * no item survives the threshold or the loss is empty, or
    * the reduced loss is not below ``hparams.loss_upper_lim``; a NaN
      loss fails this comparison and is skipped as well.

    Both precision modes share one code path so they cannot drift apart.
    Previously an all-below-threshold batch left ``loss`` unreduced and
    made the ``loss < limit`` test ambiguous for batches larger than one.
    """
    # Unpacking batch list
    mixture = batch.mix_sig
    targets = [batch.s1_sig, batch.s2_sig]
    if self.hparams.num_spks == 3:
        targets.append(batch.s3_sig)

    use_amp = self.auto_mix_prec

    def reduced_loss():
        """Return the scalar training loss, or None when nothing is left."""
        predictions, references = self.compute_forward(
            mixture, targets, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, references)
        if self.hparams.threshold_byloss:
            # hard threshold the easy dataitems
            loss = loss[loss > self.hparams.threshold]
        if loss.nelement() == 0:
            return None
        return loss.mean()

    if use_amp:
        with autocast():
            loss = reduced_loss()
    else:
        loss = reduced_loss()

    # the fix for computational problems: only step on a usable loss
    if loss is not None and bool(loss < self.hparams.loss_upper_lim):
        clip_value = self.hparams.clip_grad_norm
        if use_amp:
            self.scaler.scale(loss).backward()
            if clip_value >= 0:
                # Gradients must be unscaled before their norm is clipped.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.modules.parameters(),
                                               clip_value)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            if clip_value >= 0:
                torch.nn.utils.clip_grad_norm_(self.modules.parameters(),
                                               clip_value)
            self.optimizer.step()
    else:
        self.nonfinite_count += 1
        logger.info(
            'infinite loss or empty loss! it happened {} times so far - skipping this batch'
            .format(self.nonfinite_count))
        # Report a zero loss for the skipped batch (float scalar rather
        # than overwriting ``loss.data`` with an integer tensor).
        loss = torch.zeros((), device=self.device)

    self.optimizer.zero_grad()
    return loss.detach().cpu()
- # yapf: enable
def evaluate_batch(self, batch, stage):
    """Compute the loss of one validation/test batch without gradients.

    During the TEST stage, and when ``hparams.save_audio`` is set, the
    audio of the first item is written with ``save_audio``. If the
    hyper-parameters define ``n_audio_to_save`` it acts as a countdown of
    how many items may still be saved.

    Returns:
        The detached mean loss of the batch.
    """
    utt_ids = batch.id
    mixture = batch.mix_sig
    targets = [batch.s1_sig, batch.s2_sig]
    if self.hparams.num_spks == 3:
        targets.append(batch.s3_sig)

    with torch.no_grad():
        predictions, targets = self.compute_forward(mixture, targets, stage)
        loss = self.compute_objectives(predictions, targets)

    # Manage audio file saving
    if stage == sb.Stage.TEST and self.hparams.save_audio:
        if not hasattr(self.hparams, 'n_audio_to_save'):
            # No budget configured: save every item.
            self.save_audio(utt_ids[0], mixture, targets, predictions)
        elif self.hparams.n_audio_to_save > 0:
            self.save_audio(utt_ids[0], mixture, targets, predictions)
            self.hparams.n_audio_to_save -= 1

    return loss.mean().detach()
def on_stage_end(self, stage, stage_loss, epoch):
    """Gets called at the end of an epoch stage.

    TRAIN: remembers the stage loss so it can be logged with the next
    validation result.
    VALID: anneals the learning rate when the configured scheduler is a
    ``ReduceLROnPlateau``, logs train/valid statistics and saves a
    checkpoint, keeping checkpoints according to the lowest ``si-snr``
    value.

    Args:
        stage: the ``sb.Stage`` that just finished.
        stage_loss: average loss of the stage, stored under ``'si-snr'``.
        epoch: current epoch number.
    """
    # Compute/store important stats
    stage_stats = {'si-snr': stage_loss}
    if stage == sb.Stage.TRAIN:
        self.train_stats = stage_stats

    # Perform end-of-iteration things, like annealing, logging, etc.
    if stage == sb.Stage.VALID:
        # Learning rate annealing
        if isinstance(self.hparams.lr_scheduler,
                      schedulers.ReduceLROnPlateau):
            current_lr, next_lr = self.hparams.lr_scheduler(
                [self.optimizer], epoch, stage_loss)
            schedulers.update_learning_rate(self.optimizer, next_lr)
        else:
            # if we do not use the reducelronplateau, we do not change the lr.
            # Read it from the live optimizer instance (the one stepped in
            # fit_batch): ``hparams.optimizer`` is the factory handed to
            # the brain as ``opt_class`` and has no ``optim`` attribute.
            current_lr = self.optimizer.param_groups[0]['lr']

        self.hparams.train_logger.log_stats(
            stats_meta={
                'epoch': epoch,
                'lr': current_lr
            },
            train_stats=self.train_stats,
            valid_stats=stage_stats,
        )
        self.checkpointer.save_and_keep_only(
            meta={'si-snr': stage_stats['si-snr']},
            min_keys=['si-snr'],
        )
def add_speed_perturb(self, targets, targ_lens):
    """Adds speed perturbation and random_shift to the input signals.

    Each source (last dimension of ``targets``) is processed
    independently: first ``hparams.speedperturb`` when
    ``hparams.use_speedperturb`` is set, then a random circular shift
    along dimension 1 when ``hparams.use_rand_shift`` is set. The shift
    is drawn with ``torch.randint(min_shift, max_shift)``. The sources
    are then cut to their common minimum length and recombined.

    Args:
        targets: tensor indexed as ``[:, :, source]``.
        targ_lens: lengths forwarded to ``hparams.speedperturb``.

    Returns:
        ``(mix, targets)`` where ``mix`` is the sum of the processed
        sources over the last dimension. When neither augmentation is
        enabled the targets are returned unchanged.
    """
    use_speed = self.hparams.use_speedperturb
    use_shift = self.hparams.use_rand_shift

    if not (use_speed or use_shift):
        # Nothing to do; previously this path hit an undefined ``mix``.
        return targets.sum(-1), targets

    num_sources = targets.shape[-1]
    if use_speed:
        # Performing speed change (independently on each source)
        sources = [
            self.hparams.speedperturb(targets[:, :, idx], targ_lens)
            for idx in range(num_sources)
        ]
    else:
        # Shift-only mode starts from the untouched sources; previously
        # the source list was never created in this case.
        sources = [targets[:, :, idx] for idx in range(num_sources)]

    if use_shift:
        # Performing random_shift (independently on each source)
        shifted = []
        for source in sources:
            shift = int(
                torch.randint(self.hparams.min_shift,
                              self.hparams.max_shift, (1, )).item())
            shifted.append(
                torch.roll(source.to(self.device), shifts=shift, dims=1))
        sources = shifted

    # Re-combination: speed changes may yield different lengths, so all
    # sources are trimmed to the shortest one before stacking.
    min_len = min(source.shape[-1] for source in sources)
    combined = torch.stack(
        [source[:, :min_len] for source in sources], dim=-1)
    combined = combined.to(targets.device)
    if use_speed:
        # The speed-perturbed result has always been returned as float32.
        combined = combined.to(torch.float)

    mix = combined.sum(-1)
    return mix, combined
def cut_signals(self, mixture, targets):
    """Crop a random window of ``hparams.training_signal_len`` samples.

    The same window (along dimension 1) is taken from the mixture and
    from the targets. When the mixture is not longer than the requested
    length the window starts at 0 and the slices return what is
    available.
    """
    segment_len = self.hparams.training_signal_len
    latest_start = max(0, mixture.shape[1] - segment_len)
    start = torch.randint(0, 1 + latest_start, (1, )).item()
    stop = start + segment_len
    cropped_targets = targets[:, start:stop, :]
    cropped_mixture = mixture[:, start:stop]
    return cropped_mixture, cropped_targets
def reset_layer_recursively(self, layer):
    """Reinitializes the parameters of the neural networks.

    Calls ``reset_parameters()`` once on ``layer`` and on every
    sub-module that defines it.

    ``layer.modules()`` already walks the whole module tree (including
    ``layer`` itself), so a single pass is enough. Recursing on each
    yielded descendant re-initialised nested layers repeatedly.
    """
    for module in layer.modules():
        if hasattr(module, 'reset_parameters'):
            module.reset_parameters()
def save_results(self, test_data):
    """Compute SDR and SI-SNR metrics on ``test_data`` and save them to csv.

    One row per batch is written to ``test_results.csv`` inside
    ``hparams.output_folder``, followed by an ``avg`` row; the averages
    are also logged. The row id and the SDR are taken from the first
    item of each batch and the SI-SNR is read with ``.item()``, so the
    data loader is expected to yield batches of a single item.

    Args:
        test_data: dataset passed to SpeechBrain's ``make_dataloader``
            together with ``hparams.dataloader_opts``.
    """
    # This package is required for SDR computation
    from mir_eval.separation import bss_eval_sources

    save_file = os.path.join(self.hparams.output_folder,
                             'test_results.csv')

    # Variable init
    all_sdrs = []
    all_sdrs_i = []
    all_sisnrs = []
    all_sisnrs_i = []
    csv_columns = ['snt_id', 'sdr', 'sdr_i', 'si-snr', 'si-snr_i']

    test_loader = sb.dataio.dataloader.make_dataloader(
        test_data, **self.hparams.dataloader_opts)

    # newline='' is what the csv module requires of files it writes to;
    # without it extra blank rows appear on platforms using \r\n.
    with open(save_file, 'w', newline='') as results_csv:
        writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
        writer.writeheader()

        # Loop over all test sentences
        with tqdm(test_loader, dynamic_ncols=True) as progress:
            for batch in progress:
                # Apply Separation
                mixture, _ = batch.mix_sig
                snt_id = batch.id
                targets = [batch.s1_sig, batch.s2_sig]
                if self.hparams.num_spks == 3:
                    targets.append(batch.s3_sig)

                with torch.no_grad():
                    predictions, targets = self.compute_forward(
                        batch.mix_sig, targets, sb.Stage.TEST)

                # Compute SI-SNR
                sisnr = self.compute_objectives(predictions, targets)

                # Compute SI-SNR improvement: the baseline uses the
                # unprocessed mixture as the estimate of every source.
                mixture_signal = torch.stack(
                    [mixture] * self.hparams.num_spks, dim=-1)
                mixture_signal = mixture_signal.to(targets.device)
                sisnr_baseline = self.compute_objectives(
                    mixture_signal, targets)
                sisnr_i = sisnr.mean() - sisnr_baseline.mean()

                # Compute SDR
                reference = targets[0].t().cpu().numpy()
                sdr, _, _, _ = bss_eval_sources(
                    reference,
                    predictions[0].t().detach().cpu().numpy(),
                )
                sdr_baseline, _, _, _ = bss_eval_sources(
                    reference,
                    mixture_signal[0].t().detach().cpu().numpy(),
                )
                sdr_i = sdr.mean() - sdr_baseline.mean()

                # The loss is negated before it is reported as SI-SNR.
                row = {
                    'snt_id': snt_id[0],
                    'sdr': sdr.mean(),
                    'sdr_i': sdr_i,
                    'si-snr': -sisnr.item(),
                    'si-snr_i': -sisnr_i.item(),
                }
                writer.writerow(row)

                # Metric Accumulation
                all_sdrs.append(row['sdr'])
                all_sdrs_i.append(sdr_i.mean())
                all_sisnrs.append(row['si-snr'])
                all_sisnrs_i.append(row['si-snr_i'])

        mean_sdr = np.array(all_sdrs).mean()
        mean_sdr_i = np.array(all_sdrs_i).mean()
        mean_sisnr = np.array(all_sisnrs).mean()
        mean_sisnr_i = np.array(all_sisnrs_i).mean()
        writer.writerow({
            'snt_id': 'avg',
            'sdr': mean_sdr,
            'sdr_i': mean_sdr_i,
            'si-snr': mean_sisnr,
            'si-snr_i': mean_sisnr_i,
        })

    logger.info('Mean SISNR is {}'.format(mean_sisnr))
    logger.info('Mean SISNRi is {}'.format(mean_sisnr_i))
    logger.info('Mean SDR is {}'.format(mean_sdr))
    logger.info('Mean SDRi is {}'.format(mean_sdr_i))
def save_audio(self, snt_id, mixture, targets, predictions):
    """Save the test audio (mixture, targets, and estimated sources) on disk.

    Files are written to ``<hparams.save_folder>/audio_results`` as
    ``item{id}_source{k}hat.wav`` (estimates), ``item{id}_source{k}.wav``
    (references) and ``item{id}_mix.wav``. Only the first item of the
    batch is saved. Every signal is scaled so that its peak amplitude
    is 0.5.

    Args:
        snt_id: identifier used in the file names.
        mixture: pair whose first element is the mixture tensor.
        targets: reference sources, indexed as ``[0, :, source]``.
        predictions: estimated sources, indexed as ``[0, :, source]``.
    """
    # Create output folder (parents included; safe if it already exists)
    save_path = os.path.join(self.hparams.save_folder, 'audio_results')
    os.makedirs(save_path, exist_ok=True)

    def write_wav(signal, file_name):
        """Peak-normalise ``signal`` to 0.5 and write it as ``file_name``."""
        # Clamp the peak so an all-zero signal is written as silence
        # instead of NaNs from a division by zero.
        peak = signal.abs().max().clamp(min=1e-8)
        scaled = signal / peak * 0.5
        torchaudio.save(
            os.path.join(save_path, file_name),
            scaled.unsqueeze(0).cpu(), self.hparams.sample_rate)

    for ns in range(self.hparams.num_spks):
        # Estimated source
        write_wav(predictions[0, :, ns],
                  'item{}_source{}hat.wav'.format(snt_id, ns + 1))
        # Original source
        write_wav(targets[0, :, ns],
                  'item{}_source{}.wav'.format(snt_id, ns + 1))

    # Mixture
    write_wav(mixture[0][0, :], 'item{}_mix.wav'.format(snt_id))
|