tess.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import collections
  15. import os
  16. from typing import List, Tuple
  17. from paddle.dataset.common import DATA_HOME
  18. from paddle.utils import download
  19. from .dataset import AudioClassificationDataset
  20. __all__ = []
  21. class TESS(AudioClassificationDataset):
  22. """
  23. TESS is a set of 200 target words were spoken in the carrier phrase
  24. "Say the word _____' by two actresses (aged 26 and 64 years) and
  25. recordings were made of the set portraying each of seven emotions(anger,
  26. disgust, fear, happiness, pleasant surprise, sadness, and neutral).
  27. There are 2800 stimuli in total.
  28. Reference:
  29. Toronto emotional speech set (TESS) https://tspace.library.utoronto.ca/handle/1807/24487
  30. https://doi.org/10.5683/SP2/E8H2MF
  31. Args:
  32. mode (str, optional): It identifies the dataset mode (train or dev). Defaults to train.
  33. n_folds (int, optional): Split the dataset into n folds. 1 fold for dev dataset and n-1 for train dataset. Defaults to 5.
  34. split (int, optional): It specify the fold of dev dataset. Defaults to 1.
  35. feat_type (str, optional): It identifies the feature type that user wants to extract of an audio file. Defaults to raw.
  36. archive(dict): it tells where to download the audio archive. Defaults to None.
  37. Returns:
  38. :ref:`api_paddle_io_Dataset`. An instance of TESS dataset.
  39. Examples:
  40. .. code-block:: python
  41. >>> import paddle
  42. >>> mode = 'dev'
  43. >>> tess_dataset = paddle.audio.datasets.TESS(mode=mode,
  44. ... feat_type='raw')
  45. >>> for idx in range(5):
  46. ... audio, label = tess_dataset[idx]
  47. ... # do something with audio, label
  48. ... print(audio.shape, label)
  49. ... # [audio_data_length] , label_id
  50. >>> tess_dataset = paddle.audio.datasets.TESS(mode=mode,
  51. ... feat_type='mfcc',
  52. ... n_mfcc=40)
  53. >>> for idx in range(5):
  54. ... audio, label = tess_dataset[idx]
  55. ... # do something with mfcc feature, label
  56. ... print(audio.shape, label)
  57. ... # [feature_dim, num_frames] , label_id
  58. """
  59. archive = {
  60. 'url': 'https://bj.bcebos.com/paddleaudio/datasets/TESS_Toronto_emotional_speech_set.zip',
  61. 'md5': '1465311b24d1de704c4c63e4ccc470c7',
  62. }
  63. label_list = [
  64. 'angry',
  65. 'disgust',
  66. 'fear',
  67. 'happy',
  68. 'neutral',
  69. 'ps', # pleasant surprise
  70. 'sad',
  71. ]
  72. meta_info = collections.namedtuple(
  73. 'META_INFO', ('speaker', 'word', 'emotion')
  74. )
  75. audio_path = 'TESS_Toronto_emotional_speech_set'
  76. def __init__(
  77. self,
  78. mode: str = 'train',
  79. n_folds: int = 5,
  80. split: int = 1,
  81. feat_type: str = 'raw',
  82. archive=None,
  83. **kwargs,
  84. ):
  85. assert isinstance(n_folds, int) and (
  86. n_folds >= 1
  87. ), f'the n_folds should be integer and n_folds >= 1, but got {n_folds}'
  88. assert split in range(
  89. 1, n_folds + 1
  90. ), f'The selected split should be integer and should be 1 <= split <= {n_folds}, but got {split}'
  91. if archive is not None:
  92. self.archive = archive
  93. files, labels = self._get_data(mode, n_folds, split)
  94. super().__init__(
  95. files=files, labels=labels, feat_type=feat_type, **kwargs
  96. )
  97. def _get_meta_info(self, files) -> List[collections.namedtuple]:
  98. ret = []
  99. for file in files:
  100. basename_without_extend = os.path.basename(file)[:-4]
  101. ret.append(self.meta_info(*basename_without_extend.split('_')))
  102. return ret
  103. def _get_data(
  104. self, mode: str, n_folds: int, split: int
  105. ) -> Tuple[List[str], List[int]]:
  106. if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)):
  107. download.get_path_from_url(
  108. self.archive['url'],
  109. DATA_HOME,
  110. self.archive['md5'],
  111. decompress=True,
  112. )
  113. wav_files = []
  114. for root, _, files in os.walk(os.path.join(DATA_HOME, self.audio_path)):
  115. for file in files:
  116. if file.endswith('.wav'):
  117. wav_files.append(os.path.join(root, file))
  118. meta_info = self._get_meta_info(wav_files)
  119. files = []
  120. labels = []
  121. for idx, sample in enumerate(meta_info):
  122. _, _, emotion = sample
  123. target = self.label_list.index(emotion)
  124. fold = idx % n_folds + 1
  125. if mode == 'train' and int(fold) != split:
  126. files.append(wav_files[idx])
  127. labels.append(target)
  128. if mode != 'train' and int(fold) == split:
  129. files.append(wav_files[idx])
  130. labels.append(target)
  131. return files, labels