|
from typing import List |
|
|
|
import pytorch_lightning as pl |
|
import torch.nn as nn |
|
|
|
|
|
def get_callbacks(
    task_name: str,
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get callbacks of a task and config yaml file.

    Dispatches on ``task_name`` to a task-specific callback factory and
    forwards all remaining arguments to it unchanged.

    Args:
        task_name: str, one of 'musdb18', 'voicebank-demand',
            'vctk-musdb18', 'violin-piano', 'piano-symphony'
        config_yaml: str
        workspace: str, containing useful files such as audios for evaluation
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path for saving statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Returns:
        callbacks: List[pl.Callback], as returned by the selected factory

    Raises:
        NotImplementedError: if ``task_name`` is not a supported task.
    """
    # Every task-specific factory below is called with exactly these keyword
    # arguments, so build them once instead of repeating them per branch.
    factory_kwargs = dict(
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )

    # Imports stay inside each branch so that only the selected task's
    # callback module is imported.
    if task_name == 'musdb18':
        from bytesep.callbacks.musdb18 import get_musdb18_callbacks

        return get_musdb18_callbacks(**factory_kwargs)

    elif task_name == 'voicebank-demand':
        from bytesep.callbacks.voicebank_demand import get_voicebank_demand_callbacks

        return get_voicebank_demand_callbacks(**factory_kwargs)

    elif task_name in ['vctk-musdb18', 'violin-piano', 'piano-symphony']:
        from bytesep.callbacks.instruments_callbacks import get_instruments_callbacks

        return get_instruments_callbacks(**factory_kwargs)

    else:
        raise NotImplementedError(
            "Callbacks for task_name={!r} are not implemented.".format(task_name)
        )
|
|