TDNN.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. class Conv1d_O(nn.Module):
  7. def __init__(
  8. self,
  9. out_channels,
  10. kernel_size,
  11. input_shape=None,
  12. in_channels=None,
  13. stride=1,
  14. dilation=1,
  15. padding='same',
  16. groups=1,
  17. bias=True,
  18. padding_mode='reflect',
  19. skip_transpose=False,
  20. ):
  21. super().__init__()
  22. self.kernel_size = kernel_size
  23. self.stride = stride
  24. self.dilation = dilation
  25. self.padding = padding
  26. self.padding_mode = padding_mode
  27. self.unsqueeze = False
  28. self.skip_transpose = skip_transpose
  29. if input_shape is None and in_channels is None:
  30. raise ValueError('Must provide one of input_shape or in_channels')
  31. if in_channels is None:
  32. in_channels = self._check_input_shape(input_shape)
  33. self.conv = nn.Conv1d(
  34. in_channels,
  35. out_channels,
  36. self.kernel_size,
  37. stride=self.stride,
  38. dilation=self.dilation,
  39. padding=0,
  40. groups=groups,
  41. bias=bias,
  42. )
  43. def forward(self, x):
  44. """Returns the output of the convolution.
  45. Arguments
  46. ---------
  47. x : torch.Tensor (batch, time, channel)
  48. input to convolve. 2d or 4d tensors are expected.
  49. """
  50. if not self.skip_transpose:
  51. x = x.transpose(1, -1)
  52. if self.unsqueeze:
  53. x = x.unsqueeze(1)
  54. if self.padding == 'same':
  55. x = self._manage_padding(x, self.kernel_size, self.dilation,
  56. self.stride)
  57. elif self.padding == 'causal':
  58. num_pad = (self.kernel_size - 1) * self.dilation
  59. x = F.pad(x, (num_pad, 0))
  60. elif self.padding == 'valid':
  61. pass
  62. else:
  63. raise ValueError(
  64. "Padding must be 'same', 'valid' or 'causal'. Got "
  65. + self.padding)
  66. wx = self.conv(x)
  67. if self.unsqueeze:
  68. wx = wx.squeeze(1)
  69. if not self.skip_transpose:
  70. wx = wx.transpose(1, -1)
  71. return wx
  72. def _manage_padding(
  73. self,
  74. x,
  75. kernel_size: int,
  76. dilation: int,
  77. stride: int,
  78. ):
  79. # Detecting input shape
  80. L_in = x.shape[-1]
  81. # Time padding
  82. padding = get_padding_elem(L_in, stride, kernel_size, dilation)
  83. # Applying padding
  84. x = F.pad(x, padding, mode=self.padding_mode)
  85. return x
  86. def _check_input_shape(self, shape):
  87. """Checks the input shape and returns the number of input channels.
  88. """
  89. if len(shape) == 2:
  90. self.unsqueeze = True
  91. in_channels = 1
  92. elif self.skip_transpose:
  93. in_channels = shape[1]
  94. elif len(shape) == 3:
  95. in_channels = shape[2]
  96. else:
  97. raise ValueError('conv1d expects 2d, 3d inputs. Got '
  98. + str(len(shape)))
  99. # Kernel size must be odd
  100. if self.kernel_size % 2 == 0:
  101. raise ValueError(
  102. 'The field kernel size must be an odd number. Got %s.' %
  103. (self.kernel_size))
  104. return in_channels
  105. # Skip transpose as much as possible for efficiency
  106. class Conv1d(Conv1d_O):
  107. def __init__(self, *args, **kwargs):
  108. super().__init__(skip_transpose=True, *args, **kwargs)
  109. def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
  110. """This function computes the number of elements to add for zero-padding.
  111. Arguments
  112. ---------
  113. L_in : int
  114. stride: int
  115. kernel_size : int
  116. dilation : int
  117. """
  118. if stride > 1:
  119. n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
  120. L_out = stride * (n_steps - 1) + kernel_size * dilation
  121. padding = [kernel_size // 2, kernel_size // 2]
  122. else:
  123. L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
  124. padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
  125. return padding
  126. class BatchNorm1d_O(nn.Module):
  127. def __init__(
  128. self,
  129. input_shape=None,
  130. input_size=None,
  131. eps=1e-05,
  132. momentum=0.1,
  133. affine=True,
  134. track_running_stats=True,
  135. combine_batch_time=False,
  136. skip_transpose=False,
  137. ):
  138. super().__init__()
  139. self.combine_batch_time = combine_batch_time
  140. self.skip_transpose = skip_transpose
  141. if input_size is None and skip_transpose:
  142. input_size = input_shape[1]
  143. elif input_size is None:
  144. input_size = input_shape[-1]
  145. self.norm = nn.BatchNorm1d(
  146. input_size,
  147. eps=eps,
  148. momentum=momentum,
  149. affine=affine,
  150. track_running_stats=track_running_stats,
  151. )
  152. def forward(self, x):
  153. """Returns the normalized input tensor.
  154. Arguments
  155. ---------
  156. x : torch.Tensor (batch, time, [channels])
  157. input to normalize. 2d or 3d tensors are expected in input
  158. 4d tensors can be used when combine_dims=True.
  159. """
  160. shape_or = x.shape
  161. if self.combine_batch_time:
  162. if x.ndim == 3:
  163. x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
  164. else:
  165. x = x.reshape(shape_or[0] * shape_or[1], shape_or[3],
  166. shape_or[2])
  167. elif not self.skip_transpose:
  168. x = x.transpose(-1, 1)
  169. x_n = self.norm(x)
  170. if self.combine_batch_time:
  171. x_n = x_n.reshape(shape_or)
  172. elif not self.skip_transpose:
  173. x_n = x_n.transpose(1, -1)
  174. return x_n
  175. class BatchNorm1d(BatchNorm1d_O):
  176. def __init__(self, *args, **kwargs):
  177. super().__init__(skip_transpose=True, *args, **kwargs)
  178. class Xvector(torch.nn.Module):
  179. """This model extracts X-vectors for speaker recognition and diarization.
  180. Arguments
  181. ---------
  182. device : str
  183. Device used e.g. "cpu" or "cuda".
  184. activation : torch class
  185. A class for constructing the activation layers.
  186. tdnn_blocks : int
  187. Number of time-delay neural (TDNN) layers.
  188. tdnn_channels : list of ints
  189. Output channels for TDNN layer.
  190. tdnn_kernel_sizes : list of ints
  191. List of kernel sizes for each TDNN layer.
  192. tdnn_dilations : list of ints
  193. List of dilations for kernels in each TDNN layer.
  194. lin_neurons : int
  195. Number of neurons in linear layers.
  196. Example
  197. -------
  198. >>> compute_xvect = Xvector('cpu')
  199. >>> input_feats = torch.rand([5, 10, 40])
  200. >>> outputs = compute_xvect(input_feats)
  201. >>> outputs.shape
  202. torch.Size([5, 1, 512])
  203. """
  204. def __init__(
  205. self,
  206. device='cpu',
  207. activation=torch.nn.LeakyReLU,
  208. tdnn_blocks=5,
  209. tdnn_channels=[512, 512, 512, 512, 1500],
  210. tdnn_kernel_sizes=[5, 3, 3, 1, 1],
  211. tdnn_dilations=[1, 2, 3, 1, 1],
  212. lin_neurons=512,
  213. in_channels=80,
  214. ):
  215. super().__init__()
  216. self.blocks = nn.ModuleList()
  217. # TDNN layers
  218. for block_index in range(tdnn_blocks):
  219. out_channels = tdnn_channels[block_index]
  220. self.blocks.extend([
  221. Conv1d(
  222. in_channels=in_channels,
  223. out_channels=out_channels,
  224. kernel_size=tdnn_kernel_sizes[block_index],
  225. dilation=tdnn_dilations[block_index],
  226. ),
  227. activation(),
  228. BatchNorm1d(input_size=out_channels),
  229. ])
  230. in_channels = tdnn_channels[block_index]
  231. def forward(self, x, lens=None):
  232. """Returns the x-vectors.
  233. Arguments
  234. ---------
  235. x : torch.Tensor
  236. """
  237. x = x.transpose(1, 2)
  238. for layer in self.blocks:
  239. try:
  240. x = layer(x, lengths=lens)
  241. except TypeError:
  242. x = layer(x)
  243. x = x.transpose(1, 2)
  244. return x