audio_noise_metric.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. from modelscope.metainfo import Metrics
  4. from modelscope.metrics.base import Metric
  5. from modelscope.metrics.builder import METRICS, MetricKeys
  6. from modelscope.utils.registry import default_group
  7. @METRICS.register_module(
  8. group_key=default_group, module_name=Metrics.audio_noise_metric)
  9. class AudioNoiseMetric(Metric):
  10. """
  11. The metric computation class for acoustic noise suppression task.
  12. """
  13. def __init__(self):
  14. self.loss = []
  15. self.amp_loss = []
  16. self.phase_loss = []
  17. self.sisnr = []
  18. def add(self, outputs: Dict, inputs: Dict):
  19. self.loss.append(outputs['loss'].data.cpu())
  20. self.amp_loss.append(outputs['amp_loss'].data.cpu())
  21. self.phase_loss.append(outputs['phase_loss'].data.cpu())
  22. self.sisnr.append(outputs['sisnr'].data.cpu())
  23. def evaluate(self):
  24. avg_loss = sum(self.loss) / len(self.loss)
  25. avg_sisnr = sum(self.sisnr) / len(self.sisnr)
  26. avg_amp = sum(self.amp_loss) / len(self.amp_loss)
  27. avg_phase = sum(self.phase_loss) / len(self.phase_loss)
  28. total_loss = avg_loss + avg_amp + avg_phase + avg_sisnr
  29. return {
  30. 'total_loss': total_loss.item(),
  31. # model use opposite number of sisnr as a calculation shortcut.
  32. # revert it in evaluation result
  33. 'avg_sisnr': -avg_sisnr.item(),
  34. MetricKeys.AVERAGE_LOSS: avg_loss.item()
  35. }
  36. def merge(self, other: 'AudioNoiseMetric'):
  37. self.loss.extend(other.loss)
  38. self.amp_loss.extend(other.amp_loss)
  39. self.phase_loss.extend(other.phase_loss)
  40. self.sisnr.extend(other.sisnr)
  41. def __getstate__(self):
  42. return self.loss, self.amp_loss, self.phase_loss, self.sisnr
  43. def __setstate__(self, state):
  44. self.loss, self.amp_loss, self.phase_loss, self.sisnr = state