|
import math |
|
from typing import Callable |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torchlibrosa.stft import STFT |
|
|
|
from bytesep.models.pytorch_modules import Base |
|
|
|
|
|
def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Mean absolute error between two tensors.

    Args:
        output: torch.Tensor, estimated values
        target: torch.Tensor, reference values
        **kwargs: accepted and ignored

    Returns:
        loss: torch.float, mean of the element-wise absolute difference
    """
    difference = output - target
    return difference.abs().mean()
|
|
|
|
|
def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Time-domain L1 loss; delegates to `l1`.

    Args:
        output: torch.Tensor, estimated values
        target: torch.Tensor, reference values
        **kwargs: accepted and ignored (not forwarded to `l1`)

    Returns:
        loss: torch.float
    """
    loss = l1(output, target)
    return loss
|
|
|
|
|
class L1_Wav_L1_Sp(nn.Module, Base):
    r"""Sum of the time-domain L1 loss and the L1 loss on the spectrogram."""

    def __init__(self):
        r"""L1 loss in the time-domain and L1 loss on the spectrogram."""
        super(L1_Wav_L1_Sp, self).__init__()

        self.window_size = 2048
        hop_size = 441
        center = True
        pad_mode = "reflect"
        window = "hann"

        # Not referenced directly in this class; presumably consumed by the
        # inherited `wav_to_spectrogram` helper -- confirm against `Base`.
        self.stft = STFT(
            n_fft=self.window_size,
            hop_length=hop_size,
            win_length=self.window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

    # NOTE(review): this overrides `__call__` rather than `forward`, so
    # nn.Module's hook dispatch is bypassed. Kept as-is to preserve the
    # existing calling interface.
    def __call__(
        self, output: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        r"""L1 loss in the time-domain and on the spectrogram.

        Args:
            output: torch.Tensor, estimated waveform
            target: torch.Tensor, reference waveform
            **kwargs: accepted and ignored

        Returns:
            loss: torch.float, time-domain L1 loss plus spectrogram L1 loss
        """
        # L1 loss in the time-domain.
        wav_loss = l1_wav(output, target)

        # L1 loss on the spectrogram. `wav_to_spectrogram` is not defined in
        # this class (presumably inherited from `Base`).
        sp_loss = l1(
            self.wav_to_spectrogram(output, eps=1e-8),
            self.wav_to_spectrogram(target, eps=1e-8),
        )

        # Total loss. The former trailing `return sp_loss` was unreachable
        # and has been removed.
        return wav_loss + sp_loss
|
|
|
|
|
def get_loss_function(loss_type: str) -> Callable:
    r"""Get loss function.

    Args:
        loss_type: str, name of the loss; "l1_wav" or "l1_wav_l1_sp"

    Returns:
        loss function: Callable, the `l1_wav` function for "l1_wav", or a
            newly constructed `L1_Wav_L1_Sp` instance for "l1_wav_l1_sp"

    Raises:
        NotImplementedError: if `loss_type` is not a supported name
    """
    if loss_type == "l1_wav":
        return l1_wav

    if loss_type == "l1_wav_l1_sp":
        # A fresh module instance is built on every call.
        return L1_Wav_L1_Sp()

    # Name the offending value so a misconfiguration is easy to diagnose.
    raise NotImplementedError(
        'Unsupported loss_type: {!r}. Expected "l1_wav" or '
        '"l1_wav_l1_sp".'.format(loss_type)
    )
|
|