|
from typing import Any, Callable, Dict |
|
|
|
import pytorch_lightning as pl |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.optim.lr_scheduler import LambdaLR |
|
|
|
|
|
class LitSourceSeparation(pl.LightningModule):
    r"""Pytorch Lightning wrapper around a source separation model. It
    computes the training loss of a mini-batch and configures the optimizer
    and the learning rate scheduler.
    """

    def __init__(
        self,
        batch_data_preprocessor,
        model: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda: Callable,
    ):
        r"""Pytorch Lightning wrapper of PyTorch model, including forward,
        optimization of model, etc.

        Args:
            batch_data_preprocessor: object, used for preparing inputs and
                targets for training. E.g., BasicBatchDataPreprocessor is used
                for preparing data in dictionary into tensor. It is called
                with the batch dictionary and must return
                (input_dict, target_dict).
            model: nn.Module, called with input_dict; its output must be a
                dictionary containing the key 'waveform'.
            loss_function: function, called with the keyword arguments
                `output`, `target` and `mixture`.
            optimizer_type: str, "Adam" | "AdamW"
            learning_rate: float, initial learning rate of the optimizer
            lr_lambda: function, passed to LambdaLR to scale the learning
                rate; the scheduler is stepped once per training step.
        """
        super().__init__()

        self.batch_data_preprocessor = batch_data_preprocessor
        self.model = model
        self.optimizer_type = optimizer_type
        self.loss_function = loss_function
        self.learning_rate = learning_rate
        self.lr_lambda = lr_lambda

    def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.Tensor:
        r"""Forward a mini-batch data to model and calculate the loss
        function of this mini-batch.

        Args:
            batch_data_dict: e.g. {
                'vocals': (batch_size, channels_num, segment_samples),
                'accompaniment': (batch_size, channels_num, segment_samples),
                'mixture': (batch_size, channels_num, segment_samples)
            }
            batch_idx: int, not used here

        Returns:
            loss: value returned by loss_function for this mini-batch
                (presumably a scalar tensor -- depends on loss_function).
        """
        input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)

        # Explicitly put the wrapped model into training mode before the
        # forward pass.
        self.model.train()

        output_dict = self.model(input_dict)
        outputs = output_dict['waveform']

        # The mixture is handed to the loss function together with the
        # estimate and the target.
        loss = self.loss_function(
            output=outputs,
            target=target_dict['waveform'],
            mixture=input_dict['waveform'],
        )

        return loss

    def configure_optimizers(self) -> Any:
        r"""Configure optimizer and learning rate scheduler.

        Returns:
            ([optimizer], [scheduler]): the optimizer selected by
                `self.optimizer_type`, and a LambdaLR scheduler configured to
                be stepped after every training step.

        Raises:
            NotImplementedError: if `self.optimizer_type` is neither "Adam"
                nor "AdamW".
        """
        if self.optimizer_type == "Adam":
            optimizer_class = optim.Adam
        elif self.optimizer_type == "AdamW":
            optimizer_class = optim.AdamW
        else:
            raise NotImplementedError(
                f"Unsupported optimizer_type: {self.optimizer_type!r}. "
                "Supported values are 'Adam' and 'AdamW'."
            )

        # Both optimizers share the same hyper-parameters; only the class
        # differs.
        optimizer = optimizer_class(
            self.model.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0.0,
            amsgrad=True,
        )

        scheduler = {
            'scheduler': LambdaLR(optimizer, self.lr_lambda),
            'interval': 'step',
            'frequency': 1,
        }

        return [optimizer], [scheduler]
|
|
|
|
|
def get_model_class(model_type):
    r"""Get the model class registered under a model type name. The module
    defining the class is only imported when the class is requested.

    Args:
        model_type: str, e.g., 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021'

    Returns:
        The class (not an instance) named `model_type`, looked up in the
        corresponding `bytesep.models` module.

    Raises:
        NotImplementedError: if `model_type` is not a supported name.
        ImportError: if the module cannot be imported or does not define
            the class.
    """
    import importlib

    # Every supported class is named exactly like its model type, so the
    # table only has to record which module defines it.
    module_names = {
        'ResUNet143_DecouplePlusInplaceABN_ISMIR2021': 'bytesep.models.resunet_ismir2021',
        'UNet': 'bytesep.models.unet',
        'UNetSubbandTime': 'bytesep.models.unet_subbandtime',
        'ResUNet143_Subbandtime': 'bytesep.models.resunet_subbandtime',
        'ResUNet143_DecouplePlus': 'bytesep.models.resunet',
        'ConditionalUNet': 'bytesep.models.conditional_unet',
        'LevelRNN': 'bytesep.models.levelrnn',
        'WavUNet': 'bytesep.models.wavunet',
        'WavUNetLevelRNN': 'bytesep.models.wavunet_levelrnn',
        'TTnet': 'bytesep.models.ttnet',
        'TTnetNoTransformer': 'bytesep.models.ttnet_no_transformer',
    }

    try:
        module_name = module_names[model_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which are unsupported as well.
        raise NotImplementedError(
            f"Unsupported model_type: {model_type!r}. "
            f"Supported model types: {sorted(module_names)}"
        ) from None

    module = importlib.import_module(module_name)

    try:
        return getattr(module, model_type)
    except AttributeError as err:
        # Keep the exception type of a failing `from module import name`.
        raise ImportError(
            f"cannot import name {model_type!r} from {module_name!r}"
        ) from err
|
|