audio_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import re
  4. import struct
  5. import sys
  6. import tempfile
  7. from typing import Union
  8. from urllib.parse import urlparse
  9. import numpy as np
  10. from modelscope.fileio.file import HTTPStorage
  11. from modelscope.utils.file_utils import get_model_cache_root
  12. from modelscope.utils.hub import snapshot_download
  13. from modelscope.utils.logger import get_logger
  14. logger = get_logger()
# Default number of array elements per segment; used as the default
# ``segment_length`` argument of ``to_segment``.
SEGMENT_LENGTH_TRAIN = 16000
# File suffixes treated as local audio inputs. The scp helpers in this module
# pass this tuple to ``str.endswith`` after lower-casing the path.
SUPPORT_AUDIO_TYPE_SETS = ('.flac', '.mp3', '.ogg', '.opus', '.wav', '.pcm')
  17. class TtsTrainType(object):
  18. TRAIN_TYPE_SAMBERT = 'train-type-sambert'
  19. TRAIN_TYPE_BERT = 'train-type-bert'
  20. TRAIN_TYPE_VOC = 'train-type-voc'
  21. class TtsCustomParams(object):
  22. VOICE_NAME = 'voice_name'
  23. AM_CKPT = 'am_ckpt'
  24. VOC_CKPT = 'voc_ckpt'
  25. AM_CONFIG = 'am_config'
  26. VOC_CONFIG = 'voc_config'
  27. AUIDO_CONFIG = 'audio_config'
  28. SE_FILE = 'se_file'
  29. SE_MODEL = 'se_model'
  30. MVN_FILE = 'mvn_file'
  31. def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN):
  32. """
  33. Dataset mapping function to split one audio into segments.
  34. It only works in batch mode.
  35. """
  36. noisy_arrays = []
  37. clean_arrays = []
  38. for x, y in zip(batch['noisy'], batch['clean']):
  39. length = min(len(x['array']), len(y['array']))
  40. noisy = x['array']
  41. clean = y['array']
  42. for offset in range(segment_length, length + 1, segment_length):
  43. noisy_arrays.append(noisy[offset - segment_length:offset])
  44. clean_arrays.append(clean[offset - segment_length:offset])
  45. return {'noisy': noisy_arrays, 'clean': clean_arrays}
  46. def audio_norm(x):
  47. rms = (x**2).mean()**0.5
  48. scalar = 10**(-25 / 20) / rms
  49. x = x * scalar
  50. pow_x = x**2
  51. avg_pow_x = pow_x.mean()
  52. rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5
  53. scalarx = 10**(-25 / 20) / rmsx
  54. x = x * scalarx
  55. return x
  56. def update_conf(origin_config_file, new_config_file, conf_item: [str, str]):
  57. def repl(matched):
  58. key = matched.group(1)
  59. if key in conf_item:
  60. value = conf_item[key]
  61. if not isinstance(value, str):
  62. value = str(value)
  63. return value
  64. else:
  65. return None
  66. with open(origin_config_file, encoding='utf-8') as f:
  67. lines = f.readlines()
  68. with open(new_config_file, 'w') as f:
  69. for line in lines:
  70. line = re.sub(r'\$\{(.*)\}', repl, line)
  71. f.write(line)
  72. def extract_pcm_from_wav(wav: bytes) -> bytes:
  73. data = wav
  74. sample_rate = None
  75. if len(data) > 44:
  76. frame_len = 44
  77. file_len = len(data)
  78. try:
  79. header_fields = {}
  80. header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
  81. header_fields['Format'] = str(data[8:12], 'UTF-8')
  82. header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
  83. if header_fields['ChunkID'] == 'RIFF' and header_fields[
  84. 'Format'] == 'WAVE' and header_fields[
  85. 'Subchunk1ID'] == 'fmt ':
  86. header_fields['SubChunk1Size'] = struct.unpack(
  87. '<I', data[16:20])[0]
  88. header_fields['SampleRate'] = struct.unpack('<I',
  89. data[24:28])[0]
  90. sample_rate = header_fields['SampleRate']
  91. if header_fields['SubChunk1Size'] == 16:
  92. frame_len = 44
  93. elif header_fields['SubChunk1Size'] == 18:
  94. frame_len = 46
  95. else:
  96. return data, sample_rate
  97. data = wav[frame_len:file_len]
  98. except Exception:
  99. # no treatment
  100. pass
  101. return data, sample_rate
  102. def expect_token_number(instr, token):
  103. first_token = re.match(r'^\s*' + token, instr)
  104. if first_token is None:
  105. return None
  106. instr = instr[first_token.end():]
  107. lr = re.match(r'^\s*(-?\d+\.?\d*e?-?\d*?)', instr)
  108. if lr is None:
  109. return None
  110. return instr[lr.end():], lr.groups()[0]
  111. def expect_kaldi_matrix(instr):
  112. pos2 = instr.find('[', 0)
  113. pos3 = instr.find(']', pos2)
  114. mat = []
  115. for stt in instr[pos2 + 1:pos3].split('\n'):
  116. tmp_mat = np.fromstring(stt, dtype=np.float32, sep=' ')
  117. if tmp_mat.size > 0:
  118. mat.append(tmp_mat)
  119. return instr[pos3 + 1:], np.array(mat)
  120. # This implementation is adopted from scipy.io.wavfile.write,
  121. # made publicly available under the BSD-3-Clause license at
  122. # https://github.com/scipy/scipy/blob/v1.9.3/scipy/io/wavfile.py
  123. def ndarray_pcm_to_wav(fs: int, data: np.ndarray) -> bytes:
  124. dkind = data.dtype.kind
  125. if not (dkind == 'i' or dkind == 'f' or # noqa W504
  126. (dkind == 'u' and data.dtype.itemsize == 1)):
  127. raise ValueError(f'Unsupported data type {data.dtype}')
  128. header_data = bytearray()
  129. header_data += b'RIFF'
  130. header_data += b'\x00\x00\x00\x00'
  131. header_data += b'WAVE'
  132. header_data += b'fmt '
  133. if dkind == 'f':
  134. format_tag = 0x0003
  135. else:
  136. format_tag = 0x0001
  137. if data.ndim == 1:
  138. channels = 1
  139. else:
  140. channels = data.shape[1]
  141. bit_depth = data.dtype.itemsize * 8
  142. bytes_per_second = fs * (bit_depth // 8) * channels
  143. block_align = channels * (bit_depth // 8)
  144. fmt_chunk_data = struct.pack('<HHIIHH', format_tag, channels, fs,
  145. bytes_per_second, block_align, bit_depth)
  146. if not (dkind == 'i' or dkind == 'u'):
  147. fmt_chunk_data += b'\x00\x00'
  148. header_data += struct.pack('<I', len(fmt_chunk_data))
  149. header_data += fmt_chunk_data
  150. if not (dkind == 'i' or dkind == 'u'):
  151. header_data += b'fact'
  152. header_data += struct.pack('<II', 4, data.shape[0])
  153. if ((len(header_data) - 8) + (8 + data.nbytes)) > 0xFFFFFFFF:
  154. raise ValueError('Data exceeds wave file size limit')
  155. header_data += b'data'
  156. header_data += struct.pack('<I', data.nbytes)
  157. if data.dtype.byteorder == '>' or (data.dtype.byteorder == '='
  158. and sys.byteorder == 'big'):
  159. data = data.byteswap()
  160. header_data += data.ravel().view('b').data
  161. size = len(header_data)
  162. header_data[4:8] = struct.pack('<I', size - 8)
  163. return bytes(header_data)
  164. def load_bytes_from_url(url: str) -> Union[bytes, str]:
  165. sample_rate = None
  166. result = urlparse(url)
  167. if result.scheme is not None and len(result.scheme) > 0:
  168. storage = HTTPStorage()
  169. data = storage.read(url)
  170. data, sample_rate = extract_pcm_from_wav(data)
  171. else:
  172. data = url
  173. return data, sample_rate
  174. def generate_scp_from_url(url: str, key: str = None):
  175. wav_scp_path = None
  176. raw_inputs = None
  177. # for local inputs
  178. if os.path.exists(url):
  179. wav_scp_path = url
  180. return wav_scp_path, raw_inputs
  181. # for wav url, download bytes data
  182. if url.startswith('http'):
  183. result = urlparse(url)
  184. if result.scheme is not None and len(result.scheme) > 0:
  185. storage = HTTPStorage()
  186. # bytes
  187. data = storage.read(url)
  188. work_dir = tempfile.TemporaryDirectory().name
  189. if not os.path.exists(work_dir):
  190. os.makedirs(work_dir)
  191. wav_path = os.path.join(work_dir, os.path.basename(url))
  192. with open(wav_path, 'wb') as fb:
  193. fb.write(data)
  194. return wav_path, raw_inputs
  195. return wav_scp_path, raw_inputs
  196. def generate_text_from_url(url: str):
  197. text_file_path = None
  198. raw_inputs = None
  199. # for text str input
  200. if not os.path.exists(url) and not url.startswith('http'):
  201. raw_inputs = url
  202. return text_file_path, raw_inputs
  203. # for local txt inputs
  204. if os.path.exists(url) and (url.lower().endswith('.txt')
  205. or url.lower().endswith('.scp')):
  206. text_file_path = url
  207. return text_file_path, raw_inputs
  208. # for url, download and generate txt
  209. result = urlparse(url)
  210. if result.scheme is not None and len(result.scheme) > 0:
  211. storage = HTTPStorage()
  212. data = storage.read(url)
  213. work_dir = tempfile.TemporaryDirectory().name
  214. if not os.path.exists(work_dir):
  215. os.makedirs(work_dir)
  216. text_file_path = os.path.join(work_dir, os.path.basename(url))
  217. with open(text_file_path, 'wb') as fp:
  218. fp.write(data)
  219. return text_file_path, raw_inputs
  220. return text_file_path, raw_inputs
  221. def generate_scp_for_sv(url: str, key: str = None):
  222. wav_scp_path = None
  223. wav_name = key if key is not None else os.path.basename(url)
  224. # for local wav.scp inputs
  225. if os.path.exists(url) and url.lower().endswith('.scp'):
  226. wav_scp_path = url
  227. return wav_scp_path
  228. # for local wav file inputs
  229. if os.path.exists(url) and (url.lower().endswith(SUPPORT_AUDIO_TYPE_SETS)):
  230. wav_path = url
  231. work_dir = tempfile.TemporaryDirectory().name
  232. if not os.path.exists(work_dir):
  233. os.makedirs(work_dir)
  234. wav_scp_path = os.path.join(work_dir, 'wav.scp')
  235. with open(wav_scp_path, 'w') as ft:
  236. scp_content = '\t'.join([wav_name, wav_path]) + '\n'
  237. ft.writelines(scp_content)
  238. return wav_scp_path
  239. # for wav url, download and generate wav.scp
  240. result = urlparse(url)
  241. if result.scheme is not None and len(result.scheme) > 0:
  242. storage = HTTPStorage()
  243. wav_scp_path = storage.read(url)
  244. return wav_scp_path
  245. return wav_scp_path
  246. def generate_sv_scp_from_url(urls: Union[tuple, list]):
  247. """
  248. generate audio_scp files from url input for speaker verification.
  249. """
  250. audio_scps = []
  251. for url in urls:
  252. audio_scp = generate_scp_for_sv(url, key='test1')
  253. audio_scps.append(audio_scp)
  254. return audio_scps
  255. def generate_sd_scp_from_url(urls: Union[tuple, list]):
  256. """
  257. generate audio_scp files from url input for speaker diarization.
  258. """
  259. audio_scps = []
  260. for url in urls:
  261. if os.path.exists(url) and (
  262. url.lower().endswith(SUPPORT_AUDIO_TYPE_SETS)):
  263. audio_scp = url
  264. else:
  265. result = urlparse(url)
  266. if result.scheme is not None and len(result.scheme) > 0:
  267. storage = HTTPStorage()
  268. wav_bytes = storage.read(url)
  269. audio_scp = wav_bytes
  270. else:
  271. raise ValueError("Can't download from {}.".format(url))
  272. audio_scps.append(audio_scp)
  273. return audio_scps
  274. def update_local_model(model_config, model_path, extra_args):
  275. if 'update_model' in extra_args and not extra_args['update_model']:
  276. return
  277. model_revision = None
  278. if 'update_model' in extra_args:
  279. if extra_args['update_model'] == 'latest':
  280. model_revision = None
  281. else:
  282. model_revision = extra_args['update_model']
  283. if model_config.__contains__('model'):
  284. model_name = model_config['model']
  285. dst_dir_root = get_model_cache_root()
  286. if isinstance(model_path, str) and os.path.exists(
  287. model_path) and not model_path.startswith(dst_dir_root):
  288. try:
  289. dst = os.path.join(dst_dir_root, '.cache/' + model_name)
  290. dst_dir = os.path.dirname(dst)
  291. os.makedirs(dst_dir, exist_ok=True)
  292. if not os.path.exists(dst):
  293. os.symlink(os.path.abspath(model_path), dst)
  294. snapshot_download(
  295. model_name,
  296. cache_dir=dst_dir_root,
  297. revision=model_revision)
  298. except Exception as e:
  299. logger.warning(str(e))
  300. else:
  301. logger.warning('Can not find model name in configuration')