| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from scipy.signal import get_window
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    """Build the convolution kernel and window used by ConvSTFT/ConviSTFT.

    Args:
        win_len: number of samples per analysis frame (kernel width).
        win_inc: hop size; accepted for signature compatibility but not
            used when building the kernel.
        fft_len: size of the identity matrix fed to ``np.fft.rfft``.
        win_type: window name forwarded to ``scipy.signal.get_window``;
            ``None`` or the string ``'None'`` selects an all-ones window.
        invers: when True, return the transposed pseudo-inverse of the
            analysis kernel (synthesis kernel) instead.

    Returns:
        A tuple ``(kernel, window)`` of float32 tensors: ``kernel`` has a
        singleton axis inserted at dim 1 and ``window`` has shape
        ``[1, win_len, 1]``.
    """
    if win_type is None or win_type == 'None':
        window = np.ones(win_len)
    else:
        # The square root of the named window is what gets applied here.
        window = get_window(win_type, win_len, fftbins=True)**0.5

    # rfft of the identity gives one row of Fourier coefficients per sample
    # position; only the first win_len rows are kept.
    basis = np.fft.rfft(np.eye(fft_len))[:win_len]

    # Real parts first, imaginary parts second, then transpose so that each
    # row is one filter over the frame samples.
    stacked = np.concatenate([np.real(basis), np.imag(basis)], 1).T
    if invers:
        stacked = np.linalg.pinv(stacked).T

    # Apply the window to every filter and add the in-channel axis.
    stacked = (stacked * window)[:, None, :]

    kernel_tensor = torch.from_numpy(stacked.astype(np.float32))
    window_tensor = torch.from_numpy(
        window[None, :, None].astype(np.float32))
    return kernel_tensor, window_tensor
class ConvSTFT(nn.Module):
    """Short-time Fourier transform implemented as a strided 1-D convolution.

    The convolution weight comes from ``init_kernels``; it is frozen unless
    ``fix`` is False.
    """

    def __init__(self,
                 win_len,
                 win_inc,
                 fft_len=None,
                 win_type='hamming',
                 feature_type='real',
                 fix=True):
        """
        Args:
            win_len: frame length in samples.
            win_inc: hop size, used as the convolution stride.
            fft_len: FFT size; defaults to the next power of two that is
                not smaller than ``win_len``.
            win_type: window name passed on to ``init_kernels``.
            feature_type: ``'complex'`` makes ``forward`` return the raw
                convolution output; anything else yields magnitude/phase.
            fix: when True the kernel does not require gradients.
        """
        super().__init__()

        self.fft_len = (
            int(2**np.ceil(np.log2(win_len))) if fft_len is None else fft_len)

        analysis_kernel, _ = init_kernels(win_len, win_inc, self.fft_len,
                                          win_type)
        self.weight = nn.Parameter(analysis_kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        """Transform a waveform batch.

        Args:
            inputs: a 2-D tensor gets a channel axis inserted at dim 1
                before the convolution; other ranks are passed through
                unchanged.

        Returns:
            The stacked real/imaginary convolution output when
            ``feature_type == 'complex'``, otherwise a ``(mags, phase)``
            tuple.
        """
        if inputs.dim() == 2:
            inputs = inputs.unsqueeze(1)
        spec = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == 'complex':
            return spec

        # The first fft_len // 2 + 1 channels are the real part, the rest
        # the imaginary part.
        n_bins = self.dim // 2 + 1
        real = spec[:, :n_bins, :]
        imag = spec[:, n_bins:, :]
        mags = torch.sqrt(real**2 + imag**2)
        phase = torch.atan2(imag, real)
        return mags, phase
class ConviSTFT(nn.Module):
    """Inverse STFT implemented as a strided 1-D transposed convolution.

    The synthesis kernel and the window both come from ``init_kernels``
    called with ``invers=True``.
    """

    def __init__(self,
                 win_len,
                 win_inc,
                 fft_len=None,
                 win_type='hamming',
                 feature_type='real',
                 fix=True):
        """
        Args:
            win_len: frame length in samples.
            win_inc: hop size, used as the transposed-convolution stride.
            fft_len: FFT size; defaults to the next power of two that is
                not smaller than ``win_len``.
            win_type: window name passed on to ``init_kernels``.
            feature_type: stored on the module; not read by ``forward``.
            fix: when True the kernel does not require gradients.
        """
        super().__init__()

        self.fft_len = (
            int(2**np.ceil(np.log2(win_len))) if fft_len is None else fft_len)

        synthesis_kernel, window = init_kernels(
            win_len, win_inc, self.fft_len, win_type, invers=True)
        self.weight = nn.Parameter(synthesis_kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.win_inc = win_inc
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        # Identity kernel: overlap-adds per-frame values back onto the
        # sample axis when used with conv_transpose1d.
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """
        Args:
            inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
            phase: [B, N//2+1, T] (if not none)

        Returns:
            The overlap-added output of the transposed convolution, divided
            by the overlap-added squared window.
        """
        if phase is not None:
            # Rebuild the stacked real/imaginary layout from polar form.
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)

        wave = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        num_frames = inputs.size(-1)
        squared_window = self.window.repeat(1, 1, num_frames)**2
        overlap = F.conv_transpose1d(
            squared_window, self.enframe, stride=self.stride)
        # The small constant keeps the division finite where overlap is 0.
        return wave / (overlap + 1e-8)
|