# segmentation_clustering_pipeline.py (13 KB)
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import ast
  3. import io
  4. from typing import Any, Dict, List, Union
  5. import numpy as np
  6. import soundfile as sf
  7. import torch
  8. import torchaudio
  9. from modelscope.fileio import File
  10. from modelscope.metainfo import Pipelines
  11. from modelscope.outputs import OutputKeys
  12. from modelscope.pipelines import pipeline
  13. from modelscope.pipelines.base import InputModel, Pipeline
  14. from modelscope.pipelines.builder import PIPELINES
  15. from modelscope.utils.constant import Tasks
  16. from modelscope.utils.logger import get_logger
# Module-level logger shared by the pipeline below (obtained from get_logger()).
logger = get_logger()

# Names exported by `from <this module> import *`.
__all__ = ['SegmentationClusteringPipeline']
  19. @PIPELINES.register_module(
  20. Tasks.speaker_diarization, module_name=Pipelines.segmentation_clustering)
  21. class SegmentationClusteringPipeline(Pipeline):
  22. """Segmentation and Clustering Pipeline
  23. use `model` to create a Segmentation and Clustering Pipeline.
  24. Args:
  25. model (SegmentationClusteringPipeline): A model instance, or a model local dir, or a model id in the model hub.
  26. kwargs (dict, `optional`):
  27. Extra kwargs passed into the pipeline's constructor.
  28. Example:
  29. >>> from modelscope.pipelines import pipeline
  30. >>> from modelscope.utils.constant import Tasks
  31. >>> p = pipeline(
  32. >>> task=Tasks.speaker_diarization, model='damo/speech_campplus_speaker-diarization_common')
  33. >>> print(p(audio))
  34. """
  35. def __init__(self, model: InputModel, **kwargs):
  36. """use `model` to create a speaker diarization pipeline for prediction
  37. Args:
  38. model (str): a valid official model id
  39. """
  40. super().__init__(model=model, **kwargs)
  41. self.config = self.model.other_config
  42. config = {
  43. 'seg_dur': 1.5,
  44. 'seg_shift': 0.75,
  45. }
  46. self.config.update(config)
  47. self.fs = self.config['sample_rate']
  48. self.sv_pipeline = pipeline(
  49. task='speaker-verification', model=self.config['speaker_model'])
  50. def __call__(self, audio: Union[str, np.ndarray, list],
  51. **params) -> Dict[str, Any]:
  52. """ extract the speaker embeddings of input audio and do cluster
  53. Args:
  54. audio (str, np.ndarray, list): If it is represented as a str or a np.ndarray, it
  55. should be a complete speech signal and requires VAD preprocessing. If the audio
  56. is represented as a list, it should contain only the effective speech segments
  57. obtained through VAD preprocessing. The list should be formatted as [[0(s),3.2,
  58. np.ndarray], [5.3,9.1, np.ndarray], ...]. Each element is a sublist that contains
  59. the start time, end time, and the numpy array of the speech segment respectively.
  60. """
  61. self.config.update(params)
  62. # vad
  63. logger.info('Doing VAD...')
  64. vad_segments = self.preprocess(audio)
  65. # check input data
  66. self.check_audio_list(vad_segments)
  67. # segmentation
  68. logger.info('Doing segmentation...')
  69. segments = self.chunk(vad_segments)
  70. # embedding
  71. logger.info('Extracting embeddings...')
  72. embeddings = self.forward(segments)
  73. # clustering
  74. logger.info('Clustering...')
  75. labels = self.clustering(embeddings)
  76. # post processing
  77. logger.info('Post processing...')
  78. output = self.postprocess(segments, vad_segments, labels, embeddings)
  79. return {OutputKeys.TEXT: output}
  80. def forward(self, input: list) -> np.ndarray:
  81. embeddings = []
  82. for s in input:
  83. save_dict = self.sv_pipeline([s[2]], output_emb=True)
  84. if save_dict['embs'].shape == (1, 192):
  85. embeddings.append(save_dict['embs'])
  86. embeddings = np.concatenate(embeddings)
  87. return embeddings
  88. def clustering(self, embeddings: np.ndarray) -> np.ndarray:
  89. labels = self.model(embeddings, **self.config)
  90. return labels
  91. def postprocess(self, segments: list, vad_segments: list,
  92. labels: np.ndarray, embeddings: np.ndarray) -> list:
  93. assert len(segments) == len(labels)
  94. labels = self.correct_labels(labels)
  95. distribute_res = []
  96. for i in range(len(segments)):
  97. distribute_res.append([segments[i][0], segments[i][1], labels[i]])
  98. # merge the same speakers chronologically
  99. distribute_res = self.merge_seque(distribute_res)
  100. # acquire speaker center
  101. spk_embs = []
  102. for i in range(labels.max() + 1):
  103. spk_emb = embeddings[labels == i].mean(0)
  104. spk_embs.append(spk_emb)
  105. spk_embs = np.stack(spk_embs)
  106. def is_overlapped(t1, t2):
  107. if t1 > t2 + 1e-4:
  108. return True
  109. return False
  110. # distribute the overlap region
  111. for i in range(1, len(distribute_res)):
  112. if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
  113. p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
  114. if 'change_locator' in self.config:
  115. if not hasattr(self, 'change_locator_pipeline'):
  116. self.change_locator_pipeline = pipeline(
  117. task=Tasks.speaker_diarization,
  118. model=self.config['change_locator'])
  119. short_utt_st = max(p - 1.5, distribute_res[i - 1][0])
  120. short_utt_ed = min(p + 1.5, distribute_res[i][1])
  121. if short_utt_ed - short_utt_st > 1:
  122. audio_data = self.cut_audio(short_utt_st, short_utt_ed,
  123. vad_segments)
  124. spk1 = distribute_res[i - 1][2]
  125. spk2 = distribute_res[i][2]
  126. _, ct = self.change_locator_pipeline(
  127. audio_data, [spk_embs[spk1], spk_embs[spk2]],
  128. output_res=True)
  129. if ct is not None:
  130. p = short_utt_st + ct
  131. distribute_res[i][0] = p
  132. distribute_res[i - 1][1] = p
  133. # smooth the result
  134. distribute_res = self.smooth(distribute_res)
  135. return distribute_res
  136. def preprocess(self, audio: Union[str, np.ndarray, list]) -> list:
  137. if isinstance(audio, list):
  138. audio.sort(key=lambda x: x[0])
  139. return audio
  140. elif isinstance(audio, str):
  141. file_bytes = File.read(audio)
  142. audio, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
  143. if len(audio.shape) == 2:
  144. audio = audio[:, 0]
  145. if fs != self.fs:
  146. logger.info(
  147. f'[WARNING]: The sample rate of audio is not {self.fs}, resample it.'
  148. )
  149. audio, fs = torchaudio.sox_effects.apply_effects_tensor(
  150. torch.from_numpy(audio).unsqueeze(0),
  151. fs,
  152. effects=[['rate', str(self.fs)]])
  153. audio = audio.squeeze(0).numpy()
  154. assert len(audio.shape) == 1, 'modelscope error: Wrong audio format.'
  155. if audio.dtype in ['int16', 'int32', 'int64']:
  156. audio = (audio / (1 << 15)).astype('float32')
  157. else:
  158. audio = audio.astype('float32')
  159. if not hasattr(self, 'vad_pipeline'):
  160. self.vad_pipeline = pipeline(
  161. task=Tasks.voice_activity_detection,
  162. model=self.config['vad_model'],
  163. model_revision='v2.0.2')
  164. vad_time = self.vad_pipeline(
  165. audio, fs=self.fs, is_final=True)[0]['value']
  166. vad_segments = []
  167. if isinstance(vad_time, str):
  168. vad_time_list = ast.literal_eval(vad_time)
  169. elif isinstance(vad_time, list):
  170. vad_time_list = vad_time
  171. else:
  172. raise ValueError('Incorrect vad result. Get %s' % (type(vad_time)))
  173. for t in vad_time_list:
  174. st = int(t[0]) / 1000
  175. ed = int(t[1]) / 1000
  176. vad_segments.append(
  177. [st, ed, audio[int(st * self.fs):int(ed * self.fs)]])
  178. return vad_segments
  179. def check_audio_list(self, audio: list):
  180. audio_dur = 0
  181. for i in range(len(audio)):
  182. seg = audio[i]
  183. assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
  184. assert isinstance(seg[2],
  185. np.ndarray), 'modelscope error: Wrong data type.'
  186. assert int(seg[1] * self.fs) - int(
  187. seg[0] * self.fs
  188. ) == seg[2].shape[
  189. 0], 'modelscope error: audio data in list is inconsistent with time length.'
  190. if i > 0:
  191. assert seg[0] >= audio[
  192. i - 1][1], 'modelscope error: Wrong time stamps.'
  193. audio_dur += seg[1] - seg[0]
  194. assert audio_dur > 5, 'modelscope error: The effective audio duration is too short.'
  195. def chunk(self, vad_segments: list) -> list:
  196. def seg_chunk(seg_data):
  197. seg_st = seg_data[0]
  198. data = seg_data[2]
  199. chunk_len = int(self.config['seg_dur'] * self.fs)
  200. chunk_shift = int(self.config['seg_shift'] * self.fs)
  201. last_chunk_ed = 0
  202. seg_res = []
  203. for chunk_st in range(0, data.shape[0], chunk_shift):
  204. chunk_ed = min(chunk_st + chunk_len, data.shape[0])
  205. if chunk_ed <= last_chunk_ed:
  206. break
  207. last_chunk_ed = chunk_ed
  208. chunk_st = max(0, chunk_ed - chunk_len)
  209. chunk_data = data[chunk_st:chunk_ed]
  210. if chunk_data.shape[0] < chunk_len:
  211. chunk_data = np.pad(chunk_data,
  212. (0, chunk_len - chunk_data.shape[0]),
  213. 'constant')
  214. seg_res.append([
  215. chunk_st / self.fs + seg_st, chunk_ed / self.fs + seg_st,
  216. chunk_data
  217. ])
  218. return seg_res
  219. segs = []
  220. for i, s in enumerate(vad_segments):
  221. segs.extend(seg_chunk(s))
  222. return segs
  223. def cut_audio(self, cut_st: float, cut_ed: float,
  224. audio: Union[np.ndarray, list]) -> np.ndarray:
  225. # collect audio data given the start and end time.
  226. if isinstance(audio, np.ndarray):
  227. return audio[int(cut_st * self.fs):int(cut_ed * self.fs)]
  228. elif isinstance(audio, list):
  229. for i in range(len(audio)):
  230. if i == 0:
  231. if cut_st < audio[i][1]:
  232. st_i = i
  233. else:
  234. if cut_st >= audio[i - 1][1] and cut_st < audio[i][1]:
  235. st_i = i
  236. if i == len(audio) - 1:
  237. if cut_ed > audio[i][0]:
  238. ed_i = i
  239. else:
  240. if cut_ed > audio[i][0] and cut_ed <= audio[i + 1][0]:
  241. ed_i = i
  242. audio_segs = audio[st_i:ed_i + 1]
  243. cut_data = []
  244. for i in range(len(audio_segs)):
  245. s_st, s_ed, data = audio_segs[i]
  246. cut_data.append(
  247. data[int((max(cut_st, s_st) - s_st)
  248. * self.fs):int((min(cut_ed, s_ed) - s_st)
  249. * self.fs)])
  250. cut_data = np.concatenate(cut_data)
  251. return cut_data
  252. else:
  253. raise ValueError('modelscope error: Wrong audio format.')
  254. def correct_labels(self, labels):
  255. labels_id = 0
  256. id2id = {}
  257. new_labels = []
  258. for i in labels:
  259. if i not in id2id:
  260. id2id[i] = labels_id
  261. labels_id += 1
  262. new_labels.append(id2id[i])
  263. return np.array(new_labels)
  264. def merge_seque(self, distribute_res):
  265. res = [distribute_res[0]]
  266. for i in range(1, len(distribute_res)):
  267. if distribute_res[i][2] != res[-1][2] or distribute_res[i][
  268. 0] > res[-1][1]:
  269. res.append(distribute_res[i])
  270. else:
  271. res[-1][1] = distribute_res[i][1]
  272. return res
  273. def smooth(self, res, mindur=1):
  274. # short segments are assigned to nearest speakers.
  275. for i in range(len(res)):
  276. res[i][0] = round(res[i][0], 2)
  277. res[i][1] = round(res[i][1], 2)
  278. if res[i][1] - res[i][0] < mindur:
  279. if i == 0:
  280. res[i][2] = res[i + 1][2]
  281. elif i == len(res) - 1:
  282. res[i][2] = res[i - 1][2]
  283. elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
  284. res[i][2] = res[i - 1][2]
  285. else:
  286. res[i][2] = res[i + 1][2]
  287. # merge the speakers
  288. res = self.merge_seque(res)
  289. return res