conv_stft.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from scipy.signal import get_window
  7. def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
  8. if win_type == 'None' or win_type is None:
  9. window = np.ones(win_len)
  10. else:
  11. window = get_window(win_type, win_len, fftbins=True)**0.5
  12. N = fft_len
  13. fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
  14. real_kernel = np.real(fourier_basis)
  15. imag_kernel = np.imag(fourier_basis)
  16. kernel = np.concatenate([real_kernel, imag_kernel], 1).T
  17. if invers:
  18. kernel = np.linalg.pinv(kernel).T
  19. kernel = kernel * window
  20. kernel = kernel[:, None, :]
  21. return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(
  22. window[None, :, None].astype(np.float32))
  23. class ConvSTFT(nn.Module):
  24. def __init__(self,
  25. win_len,
  26. win_inc,
  27. fft_len=None,
  28. win_type='hamming',
  29. feature_type='real',
  30. fix=True):
  31. super(ConvSTFT, self).__init__()
  32. if fft_len is None:
  33. self.fft_len = int(2**np.ceil(np.log2(win_len)))
  34. else:
  35. self.fft_len = fft_len
  36. kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
  37. self.weight = nn.Parameter(kernel, requires_grad=(not fix))
  38. self.feature_type = feature_type
  39. self.stride = win_inc
  40. self.win_len = win_len
  41. self.dim = self.fft_len
  42. def forward(self, inputs):
  43. if inputs.dim() == 2:
  44. inputs = torch.unsqueeze(inputs, 1)
  45. outputs = F.conv1d(inputs, self.weight, stride=self.stride)
  46. if self.feature_type == 'complex':
  47. return outputs
  48. else:
  49. dim = self.dim // 2 + 1
  50. real = outputs[:, :dim, :]
  51. imag = outputs[:, dim:, :]
  52. mags = torch.sqrt(real**2 + imag**2)
  53. phase = torch.atan2(imag, real)
  54. return mags, phase
  55. class ConviSTFT(nn.Module):
  56. def __init__(self,
  57. win_len,
  58. win_inc,
  59. fft_len=None,
  60. win_type='hamming',
  61. feature_type='real',
  62. fix=True):
  63. super(ConviSTFT, self).__init__()
  64. if fft_len is None:
  65. self.fft_len = int(2**np.ceil(np.log2(win_len)))
  66. else:
  67. self.fft_len = fft_len
  68. kernel, window = init_kernels(
  69. win_len, win_inc, self.fft_len, win_type, invers=True)
  70. self.weight = nn.Parameter(kernel, requires_grad=(not fix))
  71. self.feature_type = feature_type
  72. self.win_type = win_type
  73. self.win_len = win_len
  74. self.win_inc = win_inc
  75. self.stride = win_inc
  76. self.dim = self.fft_len
  77. self.register_buffer('window', window)
  78. self.register_buffer('enframe', torch.eye(win_len)[:, None, :])
  79. def forward(self, inputs, phase=None):
  80. """
  81. Args:
  82. inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
  83. phase: [B, N//2+1, T] (if not none)
  84. """
  85. if phase is not None:
  86. real = inputs * torch.cos(phase)
  87. imag = inputs * torch.sin(phase)
  88. inputs = torch.cat([real, imag], 1)
  89. outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
  90. # this is from torch-stft: https://github.com/pseeth/torch-stft
  91. t = self.window.repeat(1, 1, inputs.size(-1))**2
  92. coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
  93. outputs = outputs / (coff + 1e-8)
  94. return outputs