audio.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import io
  3. import os
  4. from typing import Any, Dict, Tuple, Union
  5. import numpy as np
  6. import scipy.io.wavfile as wav
  7. import torch
  8. from modelscope.fileio import File
  9. from modelscope.preprocessors import Preprocessor
  10. from modelscope.preprocessors.builder import PREPROCESSORS
  11. from modelscope.utils.constant import Fields, ModeKeys
  12. class AudioBrainPreprocessor(Preprocessor):
  13. """A preprocessor takes audio file path and reads it into tensor
  14. Args:
  15. takes: the audio file field name
  16. provides: the tensor field name
  17. mode: process mode, default 'inference'
  18. """
  19. def __init__(self,
  20. takes: str,
  21. provides: str,
  22. mode=ModeKeys.INFERENCE,
  23. *args,
  24. **kwargs):
  25. super(AudioBrainPreprocessor, self).__init__(mode, *args, **kwargs)
  26. self.takes = takes
  27. self.provides = provides
  28. import speechbrain as sb
  29. self.read_audio = sb.dataio.dataio.read_audio
  30. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  31. result = self.read_audio(data[self.takes])
  32. data[self.provides] = result
  33. return data
  34. def load_kaldi_feature_transform(filename):
  35. fp = open(filename, 'r', encoding='utf-8')
  36. all_str = fp.read()
  37. pos1 = all_str.find('AddShift')
  38. pos2 = all_str.find('[', pos1)
  39. pos3 = all_str.find(']', pos2)
  40. mean = np.fromstring(all_str[pos2 + 1:pos3], dtype=np.float32, sep=' ')
  41. pos1 = all_str.find('Rescale')
  42. pos2 = all_str.find('[', pos1)
  43. pos3 = all_str.find(']', pos2)
  44. scale = np.fromstring(all_str[pos2 + 1:pos3], dtype=np.float32, sep=' ')
  45. fp.close()
  46. return mean, scale
  47. class Feature:
  48. r"""Extract feat from one utterance.
  49. """
  50. def __init__(self,
  51. fbank_config,
  52. feat_type='spec',
  53. mvn_file=None,
  54. cuda=False):
  55. r"""
  56. Args:
  57. fbank_config (dict):
  58. feat_type (str):
  59. raw: do nothing
  60. fbank: use kaldi.fbank
  61. spec: Real/Imag
  62. logpow: log(1+|x|^2)
  63. mvn_file (str): the path of data file for mean variance normalization
  64. cuda:
  65. """
  66. self.fbank_config = fbank_config
  67. self.feat_type = feat_type
  68. self.n_fft = fbank_config['frame_length'] * fbank_config[
  69. 'sample_frequency'] // 1000
  70. self.hop_length = fbank_config['frame_shift'] * fbank_config[
  71. 'sample_frequency'] // 1000
  72. self.window = torch.hamming_window(self.n_fft, periodic=False)
  73. self.mvn = False
  74. if mvn_file is not None and os.path.exists(mvn_file):
  75. print(f'loading mvn file: {mvn_file}')
  76. shift, scale = load_kaldi_feature_transform(mvn_file)
  77. self.shift = torch.from_numpy(shift)
  78. self.scale = torch.from_numpy(scale)
  79. self.mvn = True
  80. if cuda:
  81. self.window = self.window.cuda()
  82. if self.mvn:
  83. self.shift = self.shift.cuda()
  84. self.scale = self.scale.cuda()
  85. def compute(self, utt):
  86. r"""
  87. Args:
  88. utt: in [-32768, 32767] range
  89. Returns:
  90. [..., T, F]
  91. """
  92. if self.feat_type == 'raw':
  93. return utt
  94. elif self.feat_type == 'fbank':
  95. # have to use local import before modelscope framework support lazy loading
  96. import torchaudio.compliance.kaldi as kaldi
  97. if len(utt.shape) == 1:
  98. utt = utt.unsqueeze(0)
  99. feat = kaldi.fbank(utt, **self.fbank_config)
  100. elif self.feat_type == 'spec':
  101. spec = torch.stft(
  102. utt / 32768,
  103. self.n_fft,
  104. self.hop_length,
  105. self.n_fft,
  106. self.window,
  107. center=False,
  108. return_complex=True)
  109. feat = torch.cat([spec.real, spec.imag], dim=-2).permute(-1, -2)
  110. elif self.feat_type == 'logpow':
  111. spec = torch.stft(
  112. utt,
  113. self.n_fft,
  114. self.hop_length,
  115. self.n_fft,
  116. self.window,
  117. center=False,
  118. return_complex=True)
  119. abspow = torch.abs(spec)**2
  120. feat = torch.log(1 + abspow).permute(-1, -2)
  121. return feat
  122. def normalize(self, feat):
  123. if self.mvn:
  124. feat = feat + self.shift
  125. feat = feat * self.scale
  126. return feat
  127. @PREPROCESSORS.register_module(Fields.audio)
  128. class LinearAECAndFbank(Preprocessor):
  129. SAMPLE_RATE = 16000
  130. def __init__(self, io_config):
  131. import MinDAEC
  132. self.trunc_length = 7200 * self.SAMPLE_RATE
  133. self.linear_aec_delay = io_config['linear_aec_delay']
  134. self.feature = Feature(io_config['fbank_config'],
  135. io_config['feat_type'], io_config['mvn'])
  136. self.mitaec = MinDAEC.load()
  137. self.mask_on_mic = io_config['mask_on'] == 'nearend_mic'
  138. def __call__(self, data: Union[Tuple, Dict[str, Any]]) -> Dict[str, Any]:
  139. """ Linear filtering the near end mic and far end audio, then extract the feature.
  140. Args:
  141. data: Dict with two keys and correspond audios: "nearend_mic" and "farend_speech".
  142. Returns:
  143. Dict with two keys and Tensor values: "base" linear filtered audio,and "feature"
  144. """
  145. if isinstance(data, tuple):
  146. nearend_mic, fs = self.load_wav(data[0])
  147. farend_speech, fs = self.load_wav(data[1])
  148. nearend_speech = np.zeros_like(nearend_mic)
  149. else:
  150. # read files
  151. nearend_mic, fs = self.load_wav(data['nearend_mic'])
  152. farend_speech, fs = self.load_wav(data['farend_speech'])
  153. if 'nearend_speech' in data:
  154. nearend_speech, fs = self.load_wav(data['nearend_speech'])
  155. else:
  156. nearend_speech = np.zeros_like(nearend_mic)
  157. out_mic, out_ref, out_linear, out_echo = self.mitaec.do_linear_aec(
  158. nearend_mic, farend_speech)
  159. # fix 20ms linear aec delay by delaying the target speech
  160. extra_zeros = np.zeros([int(self.linear_aec_delay * fs)])
  161. nearend_speech = np.concatenate([extra_zeros, nearend_speech])
  162. # truncate files to the same length
  163. flen = min(
  164. len(out_mic), len(out_ref), len(out_linear), len(out_echo),
  165. len(nearend_speech))
  166. fstart = 0
  167. flen = min(flen, self.trunc_length)
  168. nearend_mic, out_ref, out_linear, out_echo, nearend_speech = (
  169. out_mic[fstart:flen], out_ref[fstart:flen],
  170. out_linear[fstart:flen], out_echo[fstart:flen],
  171. nearend_speech[fstart:flen])
  172. # extract features (frames, [mic, linear, ref, aes?])
  173. feat = torch.FloatTensor()
  174. nearend_mic = torch.from_numpy(np.float32(nearend_mic))
  175. fbank_nearend_mic = self.feature.compute(nearend_mic)
  176. feat = torch.cat([feat, fbank_nearend_mic], dim=1)
  177. out_linear = torch.from_numpy(np.float32(out_linear))
  178. fbank_out_linear = self.feature.compute(out_linear)
  179. feat = torch.cat([feat, fbank_out_linear], dim=1)
  180. out_echo = torch.from_numpy(np.float32(out_echo))
  181. fbank_out_echo = self.feature.compute(out_echo)
  182. feat = torch.cat([feat, fbank_out_echo], dim=1)
  183. # feature transform
  184. feat = self.feature.normalize(feat)
  185. # prepare target
  186. if nearend_speech is not None:
  187. nearend_speech = torch.from_numpy(np.float32(nearend_speech))
  188. if self.mask_on_mic:
  189. base = nearend_mic
  190. else:
  191. base = out_linear
  192. out_data = {'base': base, 'target': nearend_speech, 'feature': feat}
  193. return out_data
  194. @staticmethod
  195. def load_wav(inputs):
  196. import librosa
  197. if isinstance(inputs, bytes):
  198. inputs = io.BytesIO(inputs)
  199. elif isinstance(inputs, str):
  200. file_bytes = File.read(inputs)
  201. inputs = io.BytesIO(file_bytes)
  202. else:
  203. raise TypeError(f'Unsupported input type: {type(inputs)}.')
  204. sample_rate, data = wav.read(inputs)
  205. if len(data.shape) > 1:
  206. raise ValueError('modelscope error:The audio must be mono.')
  207. if sample_rate != LinearAECAndFbank.SAMPLE_RATE:
  208. data = librosa.resample(data, sample_rate,
  209. LinearAECAndFbank.SAMPLE_RATE)
  210. return data.astype(np.float32), LinearAECAndFbank.SAMPLE_RATE