cluster_backend.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Union
  3. import hdbscan
  4. import numpy as np
  5. import scipy
  6. import sklearn
  7. import umap
  8. from sklearn.cluster._kmeans import k_means
  9. from modelscope.metainfo import Models
  10. from modelscope.models import MODELS, TorchModel
  11. from modelscope.utils.constant import Tasks
  12. class SpectralCluster:
  13. r"""A spectral clustering method using unnormalized Laplacian of affinity matrix.
  14. This implementation is adapted from https://github.com/speechbrain/speechbrain.
  15. """
  16. def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
  17. self.min_num_spks = min_num_spks
  18. self.max_num_spks = max_num_spks
  19. self.pval = pval
  20. def __call__(self, X, oracle_num=None):
  21. # Similarity matrix computation
  22. sim_mat = self.get_sim_mat(X)
  23. # Refining similarity matrix with pval
  24. prunned_sim_mat = self.p_pruning(sim_mat)
  25. # Symmetrization
  26. sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
  27. # Laplacian calculation
  28. laplacian = self.get_laplacian(sym_prund_sim_mat)
  29. # Get Spectral Embeddings
  30. emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
  31. # Perform clustering
  32. labels = self.cluster_embs(emb, num_of_spk)
  33. return labels
  34. def get_sim_mat(self, X):
  35. # Cosine similarities
  36. M = sklearn.metrics.pairwise.cosine_similarity(X, X)
  37. return M
  38. def p_pruning(self, A):
  39. if A.shape[0] * self.pval < 6:
  40. pval = 6. / A.shape[0]
  41. else:
  42. pval = self.pval
  43. n_elems = int((1 - pval) * A.shape[0])
  44. # For each row in a affinity matrix
  45. for i in range(A.shape[0]):
  46. low_indexes = np.argsort(A[i, :])
  47. low_indexes = low_indexes[0:n_elems]
  48. # Replace smaller similarity values by 0s
  49. A[i, low_indexes] = 0
  50. return A
  51. def get_laplacian(self, M):
  52. M[np.diag_indices(M.shape[0])] = 0
  53. D = np.sum(np.abs(M), axis=1)
  54. D = np.diag(D)
  55. L = D - M
  56. return L
  57. def get_spec_embs(self, L, k_oracle=None):
  58. lambdas, eig_vecs = scipy.linalg.eigh(L)
  59. if k_oracle is not None:
  60. num_of_spk = k_oracle
  61. else:
  62. lambda_gap_list = self.getEigenGaps(
  63. lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
  64. num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
  65. emb = eig_vecs[:, :num_of_spk]
  66. return emb, num_of_spk
  67. def cluster_embs(self, emb, k):
  68. _, labels, _ = k_means(emb, k)
  69. return labels
  70. def getEigenGaps(self, eig_vals):
  71. eig_vals_gap_list = []
  72. for i in range(len(eig_vals) - 1):
  73. gap = float(eig_vals[i + 1]) - float(eig_vals[i])
  74. eig_vals_gap_list.append(gap)
  75. return eig_vals_gap_list
  76. class UmapHdbscan:
  77. r"""
  78. Reference:
  79. - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
  80. Emphasis On Topological Structure. ICASSP2022
  81. """
  82. def __init__(self,
  83. n_neighbors=20,
  84. n_components=60,
  85. min_samples=10,
  86. min_cluster_size=10,
  87. metric='cosine'):
  88. self.n_neighbors = n_neighbors
  89. self.n_components = n_components
  90. self.min_samples = min_samples
  91. self.min_cluster_size = min_cluster_size
  92. self.metric = metric
  93. def __call__(self, X):
  94. umap_X = umap.UMAP(
  95. n_neighbors=self.n_neighbors,
  96. min_dist=0.0,
  97. n_components=min(self.n_components, X.shape[0] - 2),
  98. metric=self.metric,
  99. ).fit_transform(X)
  100. labels = hdbscan.HDBSCAN(
  101. min_samples=self.min_samples,
  102. min_cluster_size=self.min_cluster_size,
  103. allow_single_cluster=True).fit_predict(umap_X)
  104. return labels
  105. @MODELS.register_module(
  106. Tasks.speaker_diarization, module_name=Models.cluster_backend)
  107. class ClusterBackend(TorchModel):
  108. r"""Perform clustering for input embeddings and output the labels.
  109. Args:
  110. model_dir: A model dir.
  111. model_config: The model config.
  112. """
  113. def __init__(self, model_dir, model_config: Dict[str, Any], *args,
  114. **kwargs):
  115. super().__init__(model_dir, model_config, *args, **kwargs)
  116. self.model_config = model_config
  117. self.other_config = kwargs
  118. self.spectral_cluster = SpectralCluster()
  119. self.umap_hdbscan_cluster = UmapHdbscan()
  120. def forward(self, X, **params):
  121. # clustering and return the labels
  122. k = params['oracle_num'] if 'oracle_num' in params else None
  123. assert len(
  124. X.shape
  125. ) == 2, 'modelscope error: the shape of input should be [N, C]'
  126. if X.shape[0] < 20:
  127. return np.zeros(X.shape[0], dtype='int')
  128. if X.shape[0] < 2048 or k is not None:
  129. labels = self.spectral_cluster(X, k)
  130. else:
  131. labels = self.umap_hdbscan_cluster(X)
  132. if k is None and 'merge_thr' in self.model_config:
  133. labels = self.merge_by_cos(labels, X,
  134. self.model_config['merge_thr'])
  135. return labels
  136. def merge_by_cos(self, labels, embs, cos_thr):
  137. # merge the similar speakers by cosine similarity
  138. assert cos_thr > 0 and cos_thr <= 1
  139. while True:
  140. spk_num = labels.max() + 1
  141. if spk_num == 1:
  142. break
  143. spk_center = []
  144. for i in range(spk_num):
  145. spk_emb = embs[labels == i].mean(0)
  146. spk_center.append(spk_emb)
  147. assert len(spk_center) > 0
  148. spk_center = np.stack(spk_center, axis=0)
  149. norm_spk_center = spk_center / np.linalg.norm(
  150. spk_center, axis=1, keepdims=True)
  151. affinity = np.matmul(norm_spk_center, norm_spk_center.T)
  152. affinity = np.triu(affinity, 1)
  153. spks = np.unravel_index(np.argmax(affinity), affinity.shape)
  154. if affinity[spks] < cos_thr:
  155. break
  156. for i in range(len(labels)):
  157. if labels[i] == spks[1]:
  158. labels[i] = spks[0]
  159. elif labels[i] > spks[1]:
  160. labels[i] -= 1
  161. return labels