dataset.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
from typing import List, Optional

import paddle

from ..features import MFCC, LogMelSpectrogram, MelSpectrogram, Spectrogram
  17. feat_funcs = {
  18. 'raw': None,
  19. 'melspectrogram': MelSpectrogram,
  20. 'mfcc': MFCC,
  21. 'logmelspectrogram': LogMelSpectrogram,
  22. 'spectrogram': Spectrogram,
  23. }
  24. class AudioClassificationDataset(paddle.io.Dataset):
  25. """
  26. Base class of audio classification dataset.
  27. """
  28. def __init__(
  29. self,
  30. files: List[str],
  31. labels: List[int],
  32. feat_type: str = 'raw',
  33. sample_rate: int = None,
  34. **kwargs,
  35. ):
  36. """
  37. Ags:
  38. files (:obj:`List[str]`): A list of absolute path of audio files.
  39. labels (:obj:`List[int]`): Labels of audio files.
  40. feat_type (:obj:`str`, `optional`, defaults to `raw`):
  41. It identifies the feature type that user wants to extract an audio file.
  42. """
  43. super().__init__()
  44. if feat_type not in feat_funcs.keys():
  45. raise RuntimeError(
  46. f"Unknown feat_type: {feat_type}, it must be one in {list(feat_funcs.keys())}"
  47. )
  48. self.files = files
  49. self.labels = labels
  50. self.feat_type = feat_type
  51. self.sample_rate = sample_rate
  52. self.feat_config = (
  53. kwargs # Pass keyword arguments to customize feature config
  54. )
  55. def _get_data(self, input_file: str):
  56. raise NotImplementedError
  57. def _convert_to_record(self, idx):
  58. file, label = self.files[idx], self.labels[idx]
  59. waveform, sample_rate = paddle.audio.load(file)
  60. self.sample_rate = sample_rate
  61. feat_func = feat_funcs[self.feat_type]
  62. record = {}
  63. if len(waveform.shape) == 2:
  64. waveform = waveform.squeeze(0) # 1D input
  65. waveform = paddle.to_tensor(waveform, dtype=paddle.float32)
  66. if feat_func is not None:
  67. waveform = waveform.unsqueeze(0) # (batch_size, T)
  68. if self.feat_type != 'spectrogram':
  69. feature_extractor = feat_func(
  70. sr=self.sample_rate, **self.feat_config
  71. )
  72. else:
  73. feature_extractor = feat_func(**self.feat_config)
  74. record['feat'] = feature_extractor(waveform).squeeze(0)
  75. else:
  76. record['feat'] = waveform
  77. record['label'] = label
  78. return record
  79. def __getitem__(self, idx):
  80. record = self._convert_to_record(idx)
  81. return record['feat'], record['label']
  82. def __len__(self):
  83. return len(self.files)