|
import os |
|
import sys |
|
import numpy as np |
|
import argparse |
|
import h5py |
|
import math |
|
import time |
|
import logging |
|
import pickle |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def load_sdrs(workspace, task_name, filename, config, gpus):
    """Load the per-evaluation test SDR values for one training run.

    The statistics file is read from
    ``<workspace>/statistics/<task_name>/<filename>/config=<config>,gpus=<gpus>/statistics.pkl``.

    Args:
        workspace: root directory that contains the ``statistics`` folder.
        task_name: task sub-directory name.
        filename: sub-directory name under the task.
        config: config name, embedded in the run directory name.
        gpus: GPU count, embedded in the run directory name.

    Returns:
        list: the ``'sdr'`` value of each entry in the pickled dict's
        ``'test'`` list, in stored order.

    Raises:
        FileNotFoundError: if the statistics file does not exist.
        KeyError: if the pickle lacks the ``'test'`` key or an entry lacks ``'sdr'``.
    """
    stat_path = os.path.join(
        workspace,
        "statistics",
        task_name,
        filename,
        "config={},gpus={}".format(config, gpus),
        "statistics.pkl",
    )

    # NOTE: pickle can execute arbitrary code on load; only point this at
    # statistics files produced by your own training runs.
    # Use a context manager so the file handle is closed deterministically.
    with open(stat_path, 'rb') as f:
        stat_dict = pickle.load(f)

    # The variable name suggests these are median SDRs per evaluation step;
    # that is not verifiable here — the values are returned as stored.
    median_sdrs = [e['sdr'] for e in stat_dict['test']]

    return median_sdrs
|
|
|
|
|
def plot_statistics(args):
    """Plot test-set SDR curves for one experiment group and save them as a PDF.

    The figure is written to ``results/vctk-musdb18/sdr_<select>.pdf``; the
    parent directory is created if needed.

    Args:
        args: namespace with ``workspace`` (root directory passed to
            ``load_sdrs``) and ``select`` (experiment-group key; only ``'1a'``
            is currently defined).

    Raises:
        ValueError: if ``args.select`` is not a known experiment group. This is
            checked before any directory is created or figure is opened.
    """
    workspace = args.workspace
    select = args.select

    task_name = "vctk-musdb18"
    filename = "train"

    # Experiment group -> curves to draw, each as (config, gpus, legend label).
    curves_by_select = {
        '1a': [('unet', 1, 'UNet,l1_wav')],
    }
    if select not in curves_by_select:
        # ValueError subclasses Exception, so callers catching the original
        # generic Exception keep working.
        raise ValueError(
            "Unknown select {!r}; expected one of {}".format(
                select, sorted(curves_by_select)
            )
        )

    fig_path = os.path.join('results', task_name, "sdr_{}.pdf".format(select))
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)

    linewidth = 1
    ylim = 30
    # X-tick labels are expressed in training iterations, assuming one stored
    # SDR point per `eval_every_iterations` iterations — confirm against the
    # training loop that writes statistics.pkl.
    eval_every_iterations = 10000
    total_ticks = 50
    ticks_freq = 10

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    try:
        lines = []
        for config, gpus, label in curves_by_select[select]:
            sdrs = load_sdrs(workspace, task_name, filename, config=config, gpus=gpus)
            (line,) = ax.plot(sdrs, label=label, linewidth=linewidth)
            lines.append(line)

        ax.set_ylim(0, ylim)
        ax.set_xlim(0, total_ticks)
        ax.xaxis.set_ticks(np.arange(0, total_ticks + 1, ticks_freq))
        ax.xaxis.set_ticklabels(
            np.arange(
                0,
                total_ticks * eval_every_iterations + 1,
                ticks_freq * eval_every_iterations,
            )
        )
        ax.yaxis.set_ticks(np.arange(ylim + 1))
        ax.yaxis.set_ticklabels(np.arange(ylim + 1))
        ax.grid(color='b', linestyle='solid', linewidth=0.3)
        ax.legend(handles=lines, loc=4)

        fig.savefig(fig_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print('Save figure to {}'.format(fig_path))
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: both flags are mandatory strings, and the
    # parsed namespace is handed straight to plot_statistics.
    cli = argparse.ArgumentParser()
    for flag in ('--workspace', '--select'):
        cli.add_argument(flag, type=str, required=True)

    plot_statistics(cli.parse_args())
|
|