|
import argparse |
|
import logging |
|
import os |
|
import pathlib |
|
from functools import partial |
|
from typing import List, NoReturn, Tuple
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning.plugins import DDPPlugin |
|
|
|
from bytesep.callbacks import get_callbacks |
|
from bytesep.data.augmentors import Augmentor |
|
from bytesep.data.batch_data_preprocessors import ( |
|
get_batch_data_preprocessor_class, |
|
) |
|
from bytesep.data.data_modules import DataModule, Dataset |
|
from bytesep.data.samplers import SegmentSampler |
|
from bytesep.losses import get_loss_function |
|
from bytesep.models.lightning_modules import ( |
|
LitSourceSeparation, |
|
get_model_class, |
|
) |
|
from bytesep.optimizers.lr_schedulers import get_lr_lambda |
|
from bytesep.utils import ( |
|
create_logging, |
|
get_pitch_shift_factor, |
|
read_yaml, |
|
check_configs_gramma, |
|
) |
|
|
|
|
|
def get_dirs(
    workspace: str, task_name: str, filename: str, config_yaml: str, gpus: int
) -> Tuple[str, str, "pl.loggers.TensorBoardLogger", str]:
    r"""Create (if needed) and return the artifact directories for a training run.

    Args:
        workspace: str, root directory of the workspace
        task_name: str, e.g., 'musdb18'
        filename: str, stem of the training script, used to namespace outputs
        config_yaml: str, path of config file for training
        gpus: int, e.g., 0 for cpu and 8 for training with 8 gpu cards

    Returns:
        checkpoints_dir: str
        logs_dir: str
        logger: pl.loggers.TensorBoardLogger
        statistics_path: str
    """
    # All artifact paths share the same experiment identifier, so build it once.
    config_name = "config={},gpus={}".format(pathlib.Path(config_yaml).stem, gpus)

    # Directory for saving checkpoints.
    checkpoints_dir = os.path.join(
        workspace, "checkpoints", task_name, filename, config_name
    )
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Directory for text logs.
    logs_dir = os.path.join(workspace, "logs", task_name, filename, config_name)
    os.makedirs(logs_dir, exist_ok=True)

    create_logging(logs_dir, filemode='w')
    # NOTE(review): relies on the module-level `args` bound in the __main__
    # block — only valid when this module is run as a script; consider
    # passing args in explicitly.
    logging.info(args)

    # TensorBoard logger rooted at the workspace.
    tb_logs_dir = os.path.join(workspace, "tensorboard_logs")
    os.makedirs(tb_logs_dir, exist_ok=True)

    experiment_name = os.path.join(task_name, filename, pathlib.Path(config_yaml).stem)
    logger = pl.loggers.TensorBoardLogger(save_dir=tb_logs_dir, name=experiment_name)

    # Path for the evaluation-statistics pickle.
    statistics_path = os.path.join(
        workspace, "statistics", task_name, filename, config_name, "statistics.pkl"
    )
    os.makedirs(os.path.dirname(statistics_path), exist_ok=True)

    return checkpoints_dir, logs_dir, logger, statistics_path
|
|
|
|
|
def _get_data_module(
    workspace: str, config_yaml: str, num_workers: int, distributed: bool
) -> DataModule:
    r"""Build the data module that feeds mini-batches to the trainer.

    Mini-batch data can be obtained by:

    code-block:: python

        data_module.setup()
        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break

    Args:
        workspace: str
        config_yaml: str
        num_workers: int, e.g., 0 for non-parallel and 8 for using cpu cores
            for preparing data in parallel
        distributed: bool

    Returns:
        data_module: DataModule
    """
    configs = read_yaml(config_yaml)
    train_configs = configs['train']

    input_source_types = train_configs['input_source_types']
    indexes_path = os.path.join(workspace, train_configs['indexes_dict'])
    sample_rate = train_configs['sample_rate']
    segment_seconds = train_configs['segment_seconds']
    augmentations = train_configs['augmentations']
    mixaudio_dict = augmentations['mixaudio']
    batch_size = train_configs['batch_size']
    steps_per_epoch = train_configs['steps_per_epoch']

    # The largest pitch shift across input sources decides how much extra
    # audio each sampled segment must carry before augmentation.
    max_pitch_shift = max(
        augmentations['pitch_shift'][source_type]
        for source_type in input_source_types
    )

    segment_samples = int(segment_seconds * sample_rate)
    ex_segment_samples = int(segment_samples * get_pitch_shift_factor(max_pitch_shift))

    # Sampler yields index lists describing which segments to load per batch.
    train_sampler = SegmentSampler(
        indexes_path=indexes_path,
        segment_samples=ex_segment_samples,
        mixaudio_dict=mixaudio_dict,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
    )

    # Dataset applies augmentation and trims segments back to segment_samples.
    augmentor = Augmentor(augmentations=augmentations)
    train_dataset = Dataset(augmentor, segment_samples)

    return DataModule(
        train_sampler=train_sampler,
        train_dataset=train_dataset,
        num_workers=num_workers,
        distributed=distributed,
    )
|
|
|
|
|
def train(args) -> None:
    r"""Train & evaluate and save checkpoints.

    Args:
        args: argparse.Namespace with attributes:
            workspace: str, directory of workspace
            gpus: int
            config_yaml: str, path of config file for training
            filename: str, stem of the training script

    Returns:
        None (returns after trainer.fit finishes; previously mis-annotated
        as NoReturn, which means "never returns")
    """
    # Arguments & parameters.
    workspace = args.workspace
    gpus = args.gpus
    config_yaml = args.config_yaml
    filename = args.filename

    num_workers = 8
    distributed = gpus > 1
    evaluate_device = "cuda" if gpus > 0 else "cpu"

    # Read and validate the training configuration.
    configs = read_yaml(config_yaml)
    check_configs_gramma(configs)
    task_name = configs['task_name']
    target_source_types = configs['train']['target_source_types']
    target_sources_num = len(target_source_types)
    channels = configs['train']['channels']
    batch_data_preprocessor_type = configs['train']['batch_data_preprocessor']
    model_type = configs['train']['model_type']
    loss_type = configs['train']['loss_type']
    optimizer_type = configs['train']['optimizer_type']
    learning_rate = float(configs['train']['learning_rate'])
    precision = configs['train']['precision']
    early_stop_steps = configs['train']['early_stop_steps']
    warm_up_steps = configs['train']['warm_up_steps']
    reduce_lr_steps = configs['train']['reduce_lr_steps']

    # Paths for checkpoints, logs, and evaluation statistics.
    checkpoints_dir, logs_dir, logger, statistics_path = get_dirs(
        workspace, task_name, filename, config_yaml, gpus
    )

    # Data module provides training mini-batches.
    data_module = _get_data_module(
        workspace=workspace,
        config_yaml=config_yaml,
        num_workers=num_workers,
        distributed=distributed,
    )

    # Batch data preprocessor splits a mini-batch into inputs and targets.
    BatchDataPreprocessor = get_batch_data_preprocessor_class(
        batch_data_preprocessor_type=batch_data_preprocessor_type
    )
    batch_data_preprocessor = BatchDataPreprocessor(
        target_source_types=target_source_types
    )

    # Model.
    Model = get_model_class(model_type=model_type)
    model = Model(input_channels=channels, target_sources_num=target_sources_num)

    # Loss function.
    loss_function = get_loss_function(loss_type=loss_type)

    # Callbacks (checkpointing, evaluation, etc.).
    callbacks = get_callbacks(
        task_name=task_name,
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )

    # Learning-rate schedule: warm up, then step-wise reduction.
    lr_lambda = partial(
        get_lr_lambda, warm_up_steps=warm_up_steps, reduce_lr_steps=reduce_lr_steps
    )

    # Lightning module wraps model, loss, and optimization.
    pl_model = LitSourceSeparation(
        batch_data_preprocessor=batch_data_preprocessor,
        model=model,
        optimizer_type=optimizer_type,
        loss_function=loss_function,
        learning_rate=learning_rate,
        lr_lambda=lr_lambda,
    )

    # NOTE(review): accelerator="ddp" is hard-coded even for gpus <= 1 —
    # presumably Lightning ignores it on CPU/single GPU here; confirm
    # against the pinned pytorch_lightning version.
    trainer = pl.Trainer(
        checkpoint_callback=False,  # checkpointing is handled by callbacks
        gpus=gpus,
        callbacks=callbacks,
        max_steps=early_stop_steps,
        accelerator="ddp",
        sync_batchnorm=True,
        precision=precision,
        replace_sampler_ddp=False,  # keep the custom SegmentSampler under DDP
        plugins=[DDPPlugin(find_unused_parameters=True)],
        profiler='simple',
    )

    # Train the model.
    trainer.fit(pl_model, data_module)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line interface: currently only a "train" sub-command.
    parser = argparse.ArgumentParser(description="")
    subparsers = parser.add_subparsers(dest="mode")

    train_parser = subparsers.add_parser("train")
    train_parser.add_argument(
        "--workspace", type=str, required=True, help="Directory of workspace."
    )
    train_parser.add_argument("--gpus", type=int, required=True)
    train_parser.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="Path of config file for training.",
    )

    args = parser.parse_args()
    # Script stem is used downstream to namespace checkpoint/log directories.
    args.filename = pathlib.Path(__file__).stem

    # Guard clause: reject any mode other than "train".
    if args.mode != "train":
        raise Exception("Error argument!")

    train(args)
|
|