separation_trainer.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import csv
  3. import os
  4. from typing import Dict, Optional, Union
  5. import numpy as np
  6. import speechbrain as sb
  7. import speechbrain.nnet.schedulers as schedulers
  8. import torch
  9. import torch.nn.functional as F
  10. import torchaudio
  11. from torch.cuda.amp import autocast
  12. from torch.utils.data import Dataset
  13. from tqdm import tqdm
  14. from modelscope.metainfo import Trainers
  15. from modelscope.models import Model, TorchModel
  16. from modelscope.msdatasets import MsDataset
  17. from modelscope.trainers.base import BaseTrainer
  18. from modelscope.trainers.builder import TRAINERS
  19. from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
  20. from modelscope.utils.device import create_device
  21. from modelscope.utils.logger import get_logger
  22. from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
  23. init_dist)
# Metric name used as the ``min_key`` for checkpoint selection and as the key
# of the dict returned by ``SeparationTrainer.evaluate``.
EVAL_KEY = 'si-snr'
# Module-level logger shared by the trainer and the Brain subclass below.
logger = get_logger()
  26. @TRAINERS.register_module(module_name=Trainers.speech_separation)
  27. class SeparationTrainer(BaseTrainer):
  28. """A trainer is used for speech separation.
  29. Args:
  30. model: id or local path of the model
  31. work_dir: local path to store all training outputs
  32. cfg_file: config file of the model
  33. train_dataset: dataset for training
  34. eval_dataset: dataset for evaluation
  35. model_revision: the git version of model on modelhub
  36. """
  37. def __init__(self,
  38. model: str,
  39. work_dir: str,
  40. cfg_file: Optional[str] = None,
  41. train_dataset: Optional[Union[MsDataset, Dataset]] = None,
  42. eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
  43. model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
  44. **kwargs):
  45. if isinstance(model, str):
  46. self.model_dir = self.get_or_download_model_dir(
  47. model, model_revision)
  48. if cfg_file is None:
  49. cfg_file = os.path.join(self.model_dir,
  50. ModelFile.CONFIGURATION)
  51. else:
  52. assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
  53. self.model_dir = os.path.dirname(cfg_file)
  54. BaseTrainer.__init__(self, cfg_file)
  55. self.model = self.build_model()
  56. self.work_dir = work_dir
  57. if kwargs.get('launcher', None) is not None:
  58. init_dist(kwargs['launcher'])
  59. _, world_size = get_dist_info()
  60. self._dist = world_size > 1
  61. device_name = kwargs.get('device', 'gpu')
  62. if self._dist:
  63. local_rank = get_local_rank()
  64. device_name = f'cuda:{local_rank}'
  65. self.device = create_device(device_name)
  66. if 'max_epochs' not in kwargs:
  67. assert hasattr(
  68. self.cfg.train, 'max_epochs'
  69. ), 'max_epochs is missing from the configuration file'
  70. self._max_epochs = self.cfg.train.max_epochs
  71. else:
  72. self._max_epochs = kwargs['max_epochs']
  73. self.train_dataset = train_dataset
  74. self.eval_dataset = eval_dataset
  75. hparams_file = os.path.join(self.model_dir, 'hparams.yaml')
  76. overrides = {
  77. 'output_folder':
  78. self.work_dir,
  79. 'seed':
  80. self.cfg.train.seed,
  81. 'lr':
  82. self.cfg.train.optimizer.lr,
  83. 'weight_decay':
  84. self.cfg.train.optimizer.weight_decay,
  85. 'clip_grad_norm':
  86. self.cfg.train.optimizer.clip_grad_norm,
  87. 'factor':
  88. self.cfg.train.lr_scheduler.factor,
  89. 'patience':
  90. self.cfg.train.lr_scheduler.patience,
  91. 'dont_halve_until_epoch':
  92. self.cfg.train.lr_scheduler.dont_halve_until_epoch,
  93. }
  94. # load hyper params
  95. from hyperpyyaml import load_hyperpyyaml
  96. with open(hparams_file) as fin:
  97. self.hparams = load_hyperpyyaml(fin, overrides=overrides)
  98. # Create experiment directory
  99. sb.create_experiment_directory(
  100. experiment_directory=self.work_dir,
  101. hyperparams_to_save=hparams_file,
  102. overrides=overrides,
  103. )
  104. run_opts = {
  105. 'debug': False,
  106. 'device': 'cpu',
  107. 'data_parallel_backend': False,
  108. 'distributed_launch': False,
  109. 'distributed_backend': 'nccl',
  110. 'find_unused_parameters': False
  111. }
  112. if self.device.type == 'cuda':
  113. run_opts['device'] = f'{self.device.type}:{self.device.index}'
  114. self.epoch_counter = sb.utils.epoch_loop.EpochCounter(self._max_epochs)
  115. self.hparams['epoch_counter'] = self.epoch_counter
  116. self.hparams['checkpointer'].add_recoverables(
  117. {'counter': self.epoch_counter})
  118. modules = self.model.as_dict()
  119. self.hparams['checkpointer'].add_recoverables(modules)
  120. # Brain class initialization
  121. self.separator = Separation(
  122. modules=modules,
  123. opt_class=self.hparams['optimizer'],
  124. hparams=self.hparams,
  125. run_opts=run_opts,
  126. checkpointer=self.hparams['checkpointer'],
  127. )
  128. def build_model(self) -> torch.nn.Module:
  129. """ Instantiate a pytorch model and return.
  130. """
  131. model = Model.from_pretrained(
  132. self.model_dir, cfg_dict=self.cfg, training=True)
  133. if isinstance(model, TorchModel) and hasattr(model, 'model'):
  134. return model.model
  135. elif isinstance(model, torch.nn.Module):
  136. return model
  137. def train(self, *args, **kwargs):
  138. self.separator.fit(
  139. self.epoch_counter,
  140. self.train_dataset,
  141. self.eval_dataset,
  142. train_loader_kwargs=self.hparams['dataloader_opts'],
  143. valid_loader_kwargs=self.hparams['dataloader_opts'],
  144. )
  145. def evaluate(self, checkpoint_path: str, *args,
  146. **kwargs) -> Dict[str, float]:
  147. if checkpoint_path:
  148. self.hparams.checkpointer.checkpoints_dir = checkpoint_path
  149. else:
  150. self.model.load_check_point(device=self.device)
  151. value = self.separator.evaluate(
  152. self.eval_dataset,
  153. test_loader_kwargs=self.hparams['dataloader_opts'],
  154. min_key=EVAL_KEY)
  155. return {EVAL_KEY: value}
  156. class Separation(sb.Brain):
  157. """A subclass of speechbrain.Brain implements training steps."""
  158. def compute_forward(self, mix, targets, stage, noise=None):
  159. """Forward computations from the mixture to the separated signals."""
  160. # Unpack lists and put tensors in the right device
  161. mix, mix_lens = mix
  162. mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)
  163. # Convert targets to tensor
  164. targets = torch.cat(
  165. [
  166. targets[i][0].unsqueeze(-1)
  167. for i in range(self.hparams.num_spks)
  168. ],
  169. dim=-1,
  170. ).to(self.device)
  171. # Add speech distortions
  172. if stage == sb.Stage.TRAIN:
  173. with torch.no_grad():
  174. if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
  175. mix, targets = self.add_speed_perturb(targets, mix_lens)
  176. mix = targets.sum(-1)
  177. if self.hparams.use_wavedrop:
  178. mix = self.hparams.wavedrop(mix, mix_lens)
  179. if self.hparams.limit_training_signal_len:
  180. mix, targets = self.cut_signals(mix, targets)
  181. # Separation
  182. mix_w = self.modules['encoder'](mix)
  183. est_mask = self.modules['masknet'](mix_w)
  184. mix_w = torch.stack([mix_w] * self.hparams.num_spks)
  185. sep_h = mix_w * est_mask
  186. # Decoding
  187. est_source = torch.cat(
  188. [
  189. self.modules['decoder'](sep_h[i]).unsqueeze(-1)
  190. for i in range(self.hparams.num_spks)
  191. ],
  192. dim=-1,
  193. )
  194. # T changed after conv1d in encoder, fix it here
  195. T_origin = mix.size(1)
  196. T_est = est_source.size(1)
  197. if T_origin > T_est:
  198. est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
  199. else:
  200. est_source = est_source[:, :T_origin, :]
  201. return est_source, targets
  202. def compute_objectives(self, predictions, targets):
  203. """Computes the sinr loss"""
  204. return self.hparams.loss(targets, predictions)
  205. # yapf: disable
  206. def fit_batch(self, batch):
  207. """Trains one batch"""
  208. # Unpacking batch list
  209. mixture = batch.mix_sig
  210. targets = [batch.s1_sig, batch.s2_sig]
  211. if self.hparams.num_spks == 3:
  212. targets.append(batch.s3_sig)
  213. if self.auto_mix_prec:
  214. with autocast():
  215. predictions, targets = self.compute_forward(
  216. mixture, targets, sb.Stage.TRAIN)
  217. loss = self.compute_objectives(predictions, targets)
  218. # hard threshold the easy dataitems
  219. if self.hparams.threshold_byloss:
  220. th = self.hparams.threshold
  221. loss_to_keep = loss[loss > th]
  222. if loss_to_keep.nelement() > 0:
  223. loss = loss_to_keep.mean()
  224. else:
  225. print('loss has zero elements!!')
  226. else:
  227. loss = loss.mean()
  228. # the fix for computational problems
  229. if loss < self.hparams.loss_upper_lim and loss.nelement() > 0:
  230. self.scaler.scale(loss).backward()
  231. if self.hparams.clip_grad_norm >= 0:
  232. self.scaler.unscale_(self.optimizer)
  233. torch.nn.utils.clip_grad_norm_(
  234. self.modules.parameters(),
  235. self.hparams.clip_grad_norm,
  236. )
  237. self.scaler.step(self.optimizer)
  238. self.scaler.update()
  239. else:
  240. self.nonfinite_count += 1
  241. logger.info(
  242. 'infinite loss or empty loss! it happened {} times so far - skipping this batch'
  243. .format(self.nonfinite_count))
  244. loss.data = torch.tensor(0).to(self.device)
  245. else:
  246. predictions, targets = self.compute_forward(
  247. mixture, targets, sb.Stage.TRAIN)
  248. loss = self.compute_objectives(predictions, targets)
  249. if self.hparams.threshold_byloss:
  250. th = self.hparams.threshold
  251. loss_to_keep = loss[loss > th]
  252. if loss_to_keep.nelement() > 0:
  253. loss = loss_to_keep.mean()
  254. else:
  255. loss = loss.mean()
  256. # the fix for computational problems
  257. if loss < self.hparams.loss_upper_lim and loss.nelement() > 0:
  258. loss.backward()
  259. if self.hparams.clip_grad_norm >= 0:
  260. torch.nn.utils.clip_grad_norm_(self.modules.parameters(),
  261. self.hparams.clip_grad_norm)
  262. self.optimizer.step()
  263. else:
  264. self.nonfinite_count += 1
  265. logger.info(
  266. 'infinite loss or empty loss! it happened {} times so far - skipping this batch'
  267. .format(self.nonfinite_count))
  268. loss.data = torch.tensor(0).to(self.device)
  269. self.optimizer.zero_grad()
  270. return loss.detach().cpu()
  271. # yapf: enable
  272. def evaluate_batch(self, batch, stage):
  273. """Computations needed for validation/test batches"""
  274. snt_id = batch.id
  275. mixture = batch.mix_sig
  276. targets = [batch.s1_sig, batch.s2_sig]
  277. if self.hparams.num_spks == 3:
  278. targets.append(batch.s3_sig)
  279. with torch.no_grad():
  280. predictions, targets = self.compute_forward(
  281. mixture, targets, stage)
  282. loss = self.compute_objectives(predictions, targets)
  283. # Manage audio file saving
  284. if stage == sb.Stage.TEST and self.hparams.save_audio:
  285. if hasattr(self.hparams, 'n_audio_to_save'):
  286. if self.hparams.n_audio_to_save > 0:
  287. self.save_audio(snt_id[0], mixture, targets, predictions)
  288. self.hparams.n_audio_to_save += -1
  289. else:
  290. self.save_audio(snt_id[0], mixture, targets, predictions)
  291. return loss.mean().detach()
  292. def on_stage_end(self, stage, stage_loss, epoch):
  293. """Gets called at the end of a epoch."""
  294. # Compute/store important stats
  295. stage_stats = {'si-snr': stage_loss}
  296. if stage == sb.Stage.TRAIN:
  297. self.train_stats = stage_stats
  298. # Perform end-of-iteration things, like annealing, logging, etc.
  299. if stage == sb.Stage.VALID:
  300. # Learning rate annealing
  301. if isinstance(self.hparams.lr_scheduler,
  302. schedulers.ReduceLROnPlateau):
  303. current_lr, next_lr = self.hparams.lr_scheduler(
  304. [self.optimizer], epoch, stage_loss)
  305. schedulers.update_learning_rate(self.optimizer, next_lr)
  306. else:
  307. # if we do not use the reducelronplateau, we do not change the lr
  308. current_lr = self.hparams.optimizer.optim.param_groups[0]['lr']
  309. self.hparams.train_logger.log_stats(
  310. stats_meta={
  311. 'epoch': epoch,
  312. 'lr': current_lr
  313. },
  314. train_stats=self.train_stats,
  315. valid_stats=stage_stats,
  316. )
  317. self.checkpointer.save_and_keep_only(
  318. meta={'si-snr': stage_stats['si-snr']},
  319. min_keys=['si-snr'],
  320. )
  321. def add_speed_perturb(self, targets, targ_lens):
  322. """Adds speed perturbation and random_shift to the input signals"""
  323. min_len = -1
  324. recombine = False
  325. if self.hparams.use_speedperturb:
  326. # Performing speed change (independently on each source)
  327. new_targets = []
  328. recombine = True
  329. for i in range(targets.shape[-1]):
  330. new_target = self.hparams.speedperturb(targets[:, :, i],
  331. targ_lens)
  332. new_targets.append(new_target)
  333. if i == 0:
  334. min_len = new_target.shape[-1]
  335. else:
  336. if new_target.shape[-1] < min_len:
  337. min_len = new_target.shape[-1]
  338. if self.hparams.use_rand_shift:
  339. # Performing random_shift (independently on each source)
  340. recombine = True
  341. for i in range(targets.shape[-1]):
  342. rand_shift = torch.randint(self.hparams.min_shift,
  343. self.hparams.max_shift, (1, ))
  344. new_targets[i] = new_targets[i].to(self.device)
  345. new_targets[i] = torch.roll(
  346. new_targets[i], shifts=(rand_shift[0], ), dims=1)
  347. # Re-combination
  348. if recombine:
  349. if self.hparams.use_speedperturb:
  350. targets = torch.zeros(
  351. targets.shape[0],
  352. min_len,
  353. targets.shape[-1],
  354. device=targets.device,
  355. dtype=torch.float,
  356. )
  357. for i, new_target in enumerate(new_targets):
  358. targets[:, :, i] = new_targets[i][:, 0:min_len]
  359. mix = targets.sum(-1)
  360. return mix, targets
  361. def cut_signals(self, mixture, targets):
  362. """This function selects a random segment of a given length within the mixture.
  363. The corresponding targets are selected accordingly"""
  364. randstart = torch.randint(
  365. 0,
  366. 1 + max(0, mixture.shape[1] - self.hparams.training_signal_len),
  367. (1, ),
  368. ).item()
  369. targets = targets[:, randstart:randstart
  370. + self.hparams.training_signal_len, :]
  371. mixture = mixture[:, randstart:randstart
  372. + self.hparams.training_signal_len]
  373. return mixture, targets
  374. def reset_layer_recursively(self, layer):
  375. """Reinitializes the parameters of the neural networks"""
  376. if hasattr(layer, 'reset_parameters'):
  377. layer.reset_parameters()
  378. for child_layer in layer.modules():
  379. if layer != child_layer:
  380. self.reset_layer_recursively(child_layer)
  381. def save_results(self, test_data):
  382. """This script computes the SDR and SI-SNR metrics and saves
  383. them into a csv file"""
  384. # This package is required for SDR computation
  385. from mir_eval.separation import bss_eval_sources
  386. # Create folders where to store audio
  387. save_file = os.path.join(self.hparams.output_folder,
  388. 'test_results.csv')
  389. # Variable init
  390. all_sdrs = []
  391. all_sdrs_i = []
  392. all_sisnrs = []
  393. all_sisnrs_i = []
  394. csv_columns = ['snt_id', 'sdr', 'sdr_i', 'si-snr', 'si-snr_i']
  395. test_loader = sb.dataio.dataloader.make_dataloader(
  396. test_data, **self.hparams.dataloader_opts)
  397. with open(save_file, 'w') as results_csv:
  398. writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
  399. writer.writeheader()
  400. # Loop over all test sentence
  401. with tqdm(test_loader, dynamic_ncols=True) as t:
  402. for i, batch in enumerate(t):
  403. # Apply Separation
  404. mixture, mix_len = batch.mix_sig
  405. snt_id = batch.id
  406. targets = [batch.s1_sig, batch.s2_sig]
  407. if self.hparams.num_spks == 3:
  408. targets.append(batch.s3_sig)
  409. with torch.no_grad():
  410. predictions, targets = self.compute_forward(
  411. batch.mix_sig, targets, sb.Stage.TEST)
  412. # Compute SI-SNR
  413. sisnr = self.compute_objectives(predictions, targets)
  414. # Compute SI-SNR improvement
  415. mixture_signal = torch.stack(
  416. [mixture] * self.hparams.num_spks, dim=-1)
  417. mixture_signal = mixture_signal.to(targets.device)
  418. sisnr_baseline = self.compute_objectives(
  419. mixture_signal, targets)
  420. sisnr_i = sisnr.mean() - sisnr_baseline.mean()
  421. # Compute SDR
  422. sdr, _, _, _ = bss_eval_sources(
  423. targets[0].t().cpu().numpy(),
  424. predictions[0].t().detach().cpu().numpy(),
  425. )
  426. sdr_baseline, _, _, _ = bss_eval_sources(
  427. targets[0].t().cpu().numpy(),
  428. mixture_signal[0].t().detach().cpu().numpy(),
  429. )
  430. sdr_i = sdr.mean() - sdr_baseline.mean()
  431. # Saving on a csv file
  432. row = {
  433. 'snt_id': snt_id[0],
  434. 'sdr': sdr.mean(),
  435. 'sdr_i': sdr_i,
  436. 'si-snr': -sisnr.item(),
  437. 'si-snr_i': -sisnr_i.item(),
  438. }
  439. writer.writerow(row)
  440. # Metric Accumulation
  441. all_sdrs.append(sdr.mean())
  442. all_sdrs_i.append(sdr_i.mean())
  443. all_sisnrs.append(-sisnr.item())
  444. all_sisnrs_i.append(-sisnr_i.item())
  445. row = {
  446. 'snt_id': 'avg',
  447. 'sdr': np.array(all_sdrs).mean(),
  448. 'sdr_i': np.array(all_sdrs_i).mean(),
  449. 'si-snr': np.array(all_sisnrs).mean(),
  450. 'si-snr_i': np.array(all_sisnrs_i).mean(),
  451. }
  452. writer.writerow(row)
  453. logger.info('Mean SISNR is {}'.format(np.array(all_sisnrs).mean()))
  454. logger.info('Mean SISNRi is {}'.format(np.array(all_sisnrs_i).mean()))
  455. logger.info('Mean SDR is {}'.format(np.array(all_sdrs).mean()))
  456. logger.info('Mean SDRi is {}'.format(np.array(all_sdrs_i).mean()))
  457. def save_audio(self, snt_id, mixture, targets, predictions):
  458. 'saves the test audio (mixture, targets, and estimated sources) on disk'
  459. # Create output folder
  460. save_path = os.path.join(self.hparams.save_folder, 'audio_results')
  461. if not os.path.exists(save_path):
  462. os.mkdir(save_path)
  463. for ns in range(self.hparams.num_spks):
  464. # Estimated source
  465. signal = predictions[0, :, ns]
  466. signal = signal / signal.abs().max() * 0.5
  467. save_file = os.path.join(
  468. save_path, 'item{}_source{}hat.wav'.format(snt_id, ns + 1))
  469. torchaudio.save(save_file,
  470. signal.unsqueeze(0).cpu(),
  471. self.hparams.sample_rate)
  472. # Original source
  473. signal = targets[0, :, ns]
  474. signal = signal / signal.abs().max() * 0.5
  475. save_file = os.path.join(
  476. save_path, 'item{}_source{}.wav'.format(snt_id, ns + 1))
  477. torchaudio.save(save_file,
  478. signal.unsqueeze(0).cpu(),
  479. self.hparams.sample_rate)
  480. # Mixture
  481. signal = mixture[0][0, :]
  482. signal = signal / signal.abs().max() * 0.5
  483. save_file = os.path.join(save_path, 'item{}_mix.wav'.format(snt_id))
  484. torchaudio.save(save_file,
  485. signal.unsqueeze(0).cpu(), self.hparams.sample_rate)