|
from torchlibrosa.stft import STFT, ISTFT, magphase |
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
from tools.pytorch.modules.pqmf import PQMF |
|
|
|
|
|
class FDomainHelper(nn.Module):
    """Convert waveforms to and from STFT-domain representations.

    Wraps torchlibrosa ``STFT`` / ``ISTFT``. When ``subband`` is given (and
    ``root`` is not ``None``), a ``PQMF`` filter bank is also built so that
    waveforms can be split into sub-bands before the STFT and recombined
    after the ISTFT.
    """

    def __init__(
        self,
        window_size=2048,
        hop_size=441,
        center=True,
        pad_mode='reflect',
        window='hann',
        freeze_parameters=True,
        subband=None,
        root="/Users/admin/Documents/projects/",
    ):
        """
        Args:
            window_size: FFT / window length in samples. Divided by
                ``subband`` when sub-band mode is enabled.
            hop_size: hop length in samples. Divided by ``subband`` when
                sub-band mode is enabled.
            center, pad_mode, window, freeze_parameters: forwarded unchanged
                to both torchlibrosa ``STFT`` and ``ISTFT``.
            subband: number of PQMF sub-bands, or ``None`` for full-band only.
            root: third argument passed to ``PQMF``; no filter bank is built
                when it is ``None``. NOTE(review): the default is a
                machine-specific absolute path. Callers using sub-band mode
                should pass their own.
        """
        super().__init__()
        self.subband = subband

        # In sub-band mode the STFT geometry is scaled down by the number of
        # bands; otherwise it is used as given. STFT and ISTFT share the exact
        # same configuration so that they invert each other.
        divisor = 1 if subband is None else subband
        stft_kwargs = dict(
            n_fft=window_size // divisor,
            hop_length=hop_size // divisor,
            win_length=window_size // divisor,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=freeze_parameters,
        )
        self.stft = STFT(**stft_kwargs)
        self.istft = ISTFT(**stft_kwargs)

        if subband is not None and root is not None:
            self.qmf = PQMF(subband, 64, root)

    def _require_qmf(self):
        """Return the PQMF filter bank, or raise a descriptive AttributeError."""
        qmf = getattr(self, "qmf", None)
        if qmf is None:
            raise AttributeError(
                "FDomainHelper has no PQMF filter bank; construct it with "
                "`subband` and `root` set to use the sub-band methods."
            )
        return qmf

    def complex_spectrogram(self, input, eps=0.0):
        """STFT one channel and stack real/imag parts along dim 1.

        Args:
            input: single-channel waveform batch as accepted by ``self.stft``.
            eps: unused; kept for signature symmetry with the other methods.

        Returns:
            ``torch.cat([real, imag], dim=1)``.
        """
        real, imag = self.stft(input)
        return torch.cat([real, imag], dim=1)

    def reverse_complex_spectrogram(self, input, eps=0.0, length=None):
        """Inverse of :meth:`complex_spectrogram`.

        Args:
            input: tensor whose dim-1 slice 0 is the real part and slice 1 is
                the imaginary part.
            eps: unused; kept for signature symmetry.
            length: optional output length forwarded to ``self.istft``.

        Returns:
            The waveform produced by ``self.istft``.
        """
        return self.istft(input[:, 0:1, ...], input[:, 1:2, ...], length=length)

    def spectrogram(self, input, eps=0.0):
        """Return the STFT magnitude ``sqrt(max(real**2 + imag**2, eps))``."""
        real, imag = self.stft(input.float())
        return torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5

    def spectrogram_phase(self, input, eps=0.0):
        """Return the STFT magnitude plus the cosine and sine of its phase.

        Args:
            input: single-channel waveform batch as accepted by ``self.stft``.
            eps: lower clamp on the squared magnitude. With ``eps=0`` a bin
                whose magnitude is exactly zero yields NaN for ``cos``/``sin``
                (0 / 0), so callers should pass a small positive value.

        Returns:
            ``(mag, cos, sin)`` where ``mag * cos == real`` and
            ``mag * sin == imag`` (up to the eps clamp).
        """
        real, imag = self.stft(input.float())
        mag = torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def wav_to_spectrogram_phase(self, input, eps=1e-8):
        """Waveform to magnitude spectrogram and phase (cos, sin).

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: forwarded to :meth:`spectrogram_phase`.

        Returns:
            ``(sps, coss, sins)``, each the per-channel results concatenated
            along dim 1: (batch_size, channels_num, time_steps, freq_bins).
        """
        sp_list, cos_list, sin_list = [], [], []
        for channel in range(input.shape[1]):
            mag, cos, sin = self.spectrogram_phase(input[:, channel, :], eps=eps)
            sp_list.append(mag)
            cos_list.append(cos)
            sin_list.append(sin)

        sps = torch.cat(sp_list, dim=1)
        coss = torch.cat(cos_list, dim=1)
        sins = torch.cat(sin_list, dim=1)
        return sps, coss, sins

    def spectrogram_phase_to_wav(self, sps, coss, sins, length):
        """Inverse of :meth:`wav_to_spectrogram_phase`.

        Each channel is rebuilt as ``istft(mag * cos, mag * sin, length)``.

        Returns:
            (batch_size, channels_num, samples): per-channel ISTFT outputs
            with a channel axis inserted at dim 1, then concatenated.
        """
        res = [
            self.istft(
                sps[:, i : i + 1, ...] * coss[:, i : i + 1, ...],
                sps[:, i : i + 1, ...] * sins[:, i : i + 1, ...],
                length,
            ).unsqueeze(1)
            for i in range(sps.size()[1])
        ]
        return torch.cat(res, dim=1)

    def wav_to_spectrogram(self, input, eps=1e-8):
        """Waveform to magnitude spectrogram.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: forwarded to :meth:`spectrogram`.

        Returns:
            (batch_size, channels_num, time_steps, freq_bins)
        """
        sp_list = [
            self.spectrogram(input[:, channel, :], eps=eps)
            for channel in range(input.shape[1])
        ]
        return torch.cat(sp_list, dim=1)

    def spectrogram_to_wav(self, input, spectrogram, length=None):
        """Rebuild a waveform from a magnitude spectrogram plus reference phase.

        The phase of each channel comes from the STFT of ``input`` and is
        combined with the supplied magnitudes before the ISTFT.

        Args:
            input: reference waveform, (batch_size, channels_num, segment_samples).
            spectrogram: magnitudes, (batch_size, channels_num, time_steps, freq_bins).
            length: optional output length forwarded to ``self.istft``.

        Returns:
            (batch_size, channels_num, samples): per-channel ISTFT outputs
            stacked along dim 1.
        """
        wav_list = []
        for channel in range(input.shape[1]):
            real, imag = self.stft(input[:, channel, :])
            _, cos, sin = magphase(real, imag)
            mag = spectrogram[:, channel : channel + 1, :, :]
            wav_list.append(self.istft(mag * cos, mag * sin, length))

        return torch.stack(wav_list, dim=1)

    def wav_to_complex_spectrogram(self, input, eps=0.0):
        """Waveform to complex spectrogram.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: forwarded to :meth:`complex_spectrogram` (currently unused there).

        Returns:
            Per-channel ``[real, imag]`` pairs concatenated along dim 1, so
            channel ``i`` occupies dim-1 slices ``2*i`` (real) and
            ``2*i + 1`` (imag).
        """
        res = [
            self.complex_spectrogram(input[:, channel, :], eps=eps)
            for channel in range(input.shape[1])
        ]
        return torch.cat(res, dim=1)

    def complex_spectrogram_to_wav(self, input, eps=0.0, length=None):
        """Inverse of :meth:`wav_to_complex_spectrogram`.

        Args:
            input: complex spectrogram with ``2 * channels`` slices on dim 1,
                interleaved as (real, imag) per channel.
            eps: forwarded (unused) to :meth:`reverse_complex_spectrogram`.
            length: optional output length forwarded to the ISTFT. Pass it by
                keyword, because the second positional parameter is ``eps``.

        Returns:
            (batch_size, channels, samples)
        """
        channels = input.size()[1] // 2
        wavs = [
            self.reverse_complex_spectrogram(
                input[:, 2 * i : 2 * i + 2, ...], eps=eps, length=length
            ).unsqueeze(1)
            for i in range(channels)
        ]
        return torch.cat(wavs, dim=1)

    def wav_to_complex_subband_spectrogram(self, input, eps=0.0):
        """PQMF-analyse ``input`` into sub-bands, then take the complex STFT.

        Requires the instance to have been built with ``subband`` and ``root``.
        """
        subwav = self._require_qmf().analysis(input)
        return self.wav_to_complex_spectrogram(subwav, eps=eps)

    def complex_subband_spectrogram_to_wav(self, input, eps=0.0, length=None):
        """Inverse of :meth:`wav_to_complex_subband_spectrogram`.

        Args:
            input: complex sub-band spectrogram.
            eps: forwarded (unused) to :meth:`complex_spectrogram_to_wav`.
            length: optional per-sub-band output length for the ISTFT. The
                default ``None`` preserves the previous behaviour.

        Returns:
            The output of the PQMF synthesis of the reconstructed sub-bands.
        """
        subwav = self.complex_spectrogram_to_wav(input, eps=eps, length=length)
        return self._require_qmf().synthesis(subwav)

    def wav_to_mag_phase_subband_spectrogram(self, input, eps=1e-8):
        """PQMF-analyse ``input``, then compute magnitude and phase (cos, sin).

        Args:
            input: waveform batch accepted by ``self.qmf.analysis``; assumed
                (batch_size, channels_num, samples). TODO confirm against PQMF.
            eps: forwarded to :meth:`wav_to_spectrogram_phase`.

        Returns:
            ``(sps, coss, sins)`` as returned by :meth:`wav_to_spectrogram_phase`
            applied to the sub-band signals.

        Example::

            loss = torch.nn.L1Loss()
            model = FDomainHelper(subband=4)
            data = torch.randn((3, 1, 44100 * 3))

            sps, coss, sins = model.wav_to_mag_phase_subband_spectrogram(data)
            wav = model.mag_phase_subband_spectrogram_to_wav(
                sps, coss, sins, 44100 * 3 // 4
            )

            print(loss(data, wav))
            print(torch.max(torch.abs(data - wav)))
        """
        subwav = self._require_qmf().analysis(input)
        return self.wav_to_spectrogram_phase(subwav, eps=eps)

    def mag_phase_subband_spectrogram_to_wav(self, sps, coss, sins, length, eps=0.0):
        """Inverse of :meth:`wav_to_mag_phase_subband_spectrogram`.

        Args:
            sps, coss, sins: magnitude, cos(phase), sin(phase) per sub-band channel.
            length: per-sub-band output length for the ISTFT (e.g. the
                full-band length divided by ``subband``).
            eps: unused; kept for signature symmetry.

        Returns:
            The output of the PQMF synthesis of the reconstructed sub-bands.
        """
        subwav = self.spectrogram_phase_to_wav(sps, coss, sins, length)
        return self._require_qmf().synthesis(subwav)
|
|
|
|
|
if __name__ == "__main__":
    # Round-trip smoke test: waveform -> complex spectrogram -> waveform,
    # then report how far the reconstruction is from the input.
    loss = torch.nn.L1Loss()
    model = FDomainHelper()
    data = torch.randn((3, 1, 44100 * 5))

    sps = model.wav_to_complex_spectrogram(data)
    print(sps.size())
    # `length` must be passed by keyword: the second positional parameter of
    # complex_spectrogram_to_wav is `eps`, not `length`.
    wav = model.complex_spectrogram_to_wav(sps, length=44100 * 5)

    print(loss(data, wav))
    print(torch.max(torch.abs(data - wav)))
|
|