|
import logging |
|
import os |
|
import time |
|
from typing import List, NoReturn |
|
|
|
import librosa |
|
import numpy as np |
|
import pysepm |
|
import pytorch_lightning as pl |
|
import torch.nn as nn |
|
from pesq import pesq |
|
from pytorch_lightning.utilities import rank_zero_only |
|
|
|
from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback |
|
from bytesep.inference import Separator |
|
from bytesep.utils import StatisticsContainer, read_yaml |
|
|
|
|
|
def get_voicebank_demand_callbacks(
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    """Build the Voicebank-Demand training callbacks described by a config yaml.

    Args:
        config_yaml: str
        workspace: str
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Return:
        callbacks: List[pl.Callback]
    """
    configs = read_yaml(config_yaml)

    task_name = configs['task_name']
    train_configs = configs['train']
    evaluate_configs = configs['evaluate']

    target_source_types = train_configs['target_source_types']
    input_channels = train_configs['channels']
    sample_rate = train_configs['sample_rate']
    evaluate_step_frequency = train_configs['evaluate_step_frequency']
    save_step_frequency = train_configs['save_step_frequency']
    test_batch_size = evaluate_configs['batch_size']
    test_segment_seconds = evaluate_configs['segment_seconds']

    evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)
    test_segment_samples = int(test_segment_seconds * sample_rate)

    # This task enhances exactly one source type, which must be speech.
    assert len(target_source_types) == 1
    target_source_type = target_source_types[0]
    assert target_source_type == 'speech'

    # Periodically persist model weights.
    save_checkpoints_callback = SaveCheckpointsCallback(
        model=model,
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # Accumulates evaluation metrics across training.
    statistics_container = StatisticsContainer(statistics_path)

    # Periodically evaluate on the Voicebank-Demand test set.
    evaluate_test_callback = EvaluationCallback(
        model=model,
        input_channels=input_channels,
        sample_rate=sample_rate,
        evaluation_audios_dir=evaluation_audios_dir,
        segment_samples=test_segment_samples,
        batch_size=test_batch_size,
        device=evaluate_device,
        evaluate_step_frequency=evaluate_step_frequency,
        logger=logger,
        statistics_container=statistics_container,
    )

    return [save_checkpoints_callback, evaluate_test_callback]
|
|
|
|
|
class EvaluationCallback(pl.Callback):
    r"""Evaluate speech-enhancement metrics on the Voicebank-Demand test set.

    Every #evaluate_step_frequency training steps, each noisy test audio is
    separated by the model and scored against its clean reference with PESQ,
    CSIG, CBAK, COVL and SSNR at 16 kHz.
    """

    def __init__(
        self,
        model: nn.Module,
        input_channels: int,
        evaluation_audios_dir: str,
        sample_rate: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #evaluate_step_frequency steps.

        Args:
            model: nn.Module
            input_channels: int, accepted for interface compatibility with
                sibling callbacks but not used by this one
            evaluation_audios_dir: str, directory containing audios for evaluation
            sample_rate: int, sample rate the model operates at
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
            logger: pl.loggers.TensorBoardLogger, stored but not referenced in
                this file
            statistics_container: StatisticsContainer, receives the averaged
                PESQ of each evaluation
        """
        super().__init__()  # fix: base pl.Callback was never initialized

        self.model = model
        self.mono = True  # evaluation audios are loaded as mono
        self.sample_rate = sample_rate
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container

        # Voicebank-Demand test-set layout: parallel clean / noisy directories
        # containing identically named wav files.
        self.clean_dir = os.path.join(evaluation_audios_dir, "clean_testset_wav")
        self.noisy_dir = os.path.join(evaluation_audios_dir, "noisy_testset_wav")

        # Wideband PESQ and the pysepm metrics are computed at 16 kHz, so the
        # separated audio is resampled to this rate before scoring.
        self.EVALUATION_SAMPLE_RATE = 16000

        self.separator = Separator(model, self.segment_samples, batch_size, device)

    def _evaluate_audio(self, audio_name: str):
        r"""Separate one noisy audio and score it against its clean reference.

        Args:
            audio_name: str, wav file name present in both clean_dir and noisy_dir

        Returns:
            (pesq, csig, cbak, covl, ssnr), five float metric values
        """
        clean_path = os.path.join(self.clean_dir, audio_name)
        mixture_path = os.path.join(self.noisy_dir, audio_name)

        mixture, _ = librosa.core.load(
            mixture_path, sr=self.sample_rate, mono=self.mono
        )

        # The separator consumes a (channels, samples) array; promote mono.
        if mixture.ndim == 1:
            mixture = mixture[None, :]

        sep_wav = self.separator.separate({'waveform': mixture})

        # Load the clean reference directly at the metric sample rate.
        clean, _ = librosa.core.load(
            clean_path, sr=self.EVALUATION_SAMPLE_RATE, mono=self.mono
        )

        sep_wav = np.squeeze(sep_wav)

        sep_wav = librosa.resample(
            sep_wav,
            orig_sr=self.sample_rate,
            target_sr=self.EVALUATION_SAMPLE_RATE,
        )

        # Resampling may leave a few-sample length mismatch; pad/trim to match.
        sep_wav = librosa.util.fix_length(sep_wav, size=len(clean), axis=0)

        pesq_ = pesq(self.EVALUATION_SAMPLE_RATE, clean, sep_wav, 'wb')

        (csig, cbak, covl) = pysepm.composite(
            clean, sep_wav, self.EVALUATION_SAMPLE_RATE
        )

        ssnr = pysepm.SNRseg(clean, sep_wav, self.EVALUATION_SAMPLE_RATE)

        return pesq_, csig, cbak, covl, ssnr

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> None:
        r"""Every #evaluate_step_frequency steps, evaluate the model on the
        whole test set, log per-file and averaged metrics, and persist the
        average PESQ into the statistics container.
        """
        global_step = trainer.global_step

        # Guard clause: only evaluate on the scheduled steps.
        if global_step % self.evaluate_step_frequency != 0:
            return

        # Fix: original applied sorted() twice; once is sufficient.
        audio_names = sorted(
            audio_name
            for audio_name in os.listdir(self.clean_dir)
            if audio_name.endswith('.wav')
        )

        error_str = "Directory {} does not contain audios for evaluation!".format(
            self.clean_dir
        )
        assert len(audio_names) > 0, error_str

        pesqs, csigs, cbaks, covls, ssnrs = [], [], [], [], []

        logging.info("--- Step {} ---".format(global_step))
        logging.info("Total {} pieces for evaluation:".format(len(audio_names)))

        eval_time = time.time()

        for n, audio_name in enumerate(audio_names):
            pesq_, csig, cbak, covl, ssnr = self._evaluate_audio(audio_name)

            pesqs.append(pesq_)
            csigs.append(csig)
            cbaks.append(cbak)
            covls.append(covl)
            ssnrs.append(ssnr)
            print(
                '{}, {}, PESQ: {:.3f}, CSIG: {:.3f}, CBAK: {:.3f}, COVL: {:.3f}, SSNR: {:.3f}'.format(
                    n, audio_name, pesq_, csig, cbak, covl, ssnr
                )
            )

        logging.info("-----------------------------")
        logging.info('Avg PESQ: {:.3f}'.format(np.mean(pesqs)))
        logging.info('Avg CSIG: {:.3f}'.format(np.mean(csigs)))
        logging.info('Avg CBAK: {:.3f}'.format(np.mean(cbaks)))
        logging.info('Avg COVL: {:.3f}'.format(np.mean(covls)))
        logging.info('Avg SSNR: {:.3f}'.format(np.mean(ssnrs)))

        # Fix: original log message was misspelled ("Evlauation").
        logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

        # Only the average PESQ is tracked persistently.
        statistics = {"pesq": np.mean(pesqs)}
        self.statistics_container.append(global_step, statistics, 'test')
        self.statistics_container.dump()
|
|