|
import pickle
from typing import Dict, Iterator, List, NoReturn

import numpy as np
import torch.distributed as dist
|
|
|
|
|
class SegmentSampler:
    r"""Infinite sampler of per-source segment metas for mix-audio training batches.

    Each source type in the indexes dict keeps its own shuffled permutation
    and read pointer. Segments are consumed in permutation order, and the
    permutation is reset and reshuffled once too few unread entries remain.
    """

    def __init__(
        self,
        indexes_path: str,
        segment_samples: int,
        mixaudio_dict: Dict,
        batch_size: int,
        steps_per_epoch: int,
        random_seed: int = 1234,
    ):
        r"""Sample training indexes of sources.

        Args:
            indexes_path: str, path of a pickled indexes dict that maps each
                source type to a list of segment metas
            segment_samples: int, stored on the instance; not used by this
                class itself
            mixaudio_dict: dict, hyper-parameters for mix-audio data
                augmentation: how many segments of each source type are drawn
                per batch item, e.g., {'vocals': 2, 'accompaniment': 2}
            batch_size: int, number of items per yielded batch
            steps_per_epoch: int, number of steps regarded as one `epoch`;
                reported by ``__len__``
            random_seed: int, seed of the RandomState used for shuffling

        Raises:
            ValueError: if a source type in the indexes dict has no entry in
                ``mixaudio_dict``, or has fewer segments than that entry asks
                to draw per batch item.
        """
        self.segment_samples = segment_samples
        self.mixaudio_dict = mixaudio_dict
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

        # NOTE: pickle can execute arbitrary code while loading; only point
        # indexes_path at index files you created or otherwise trust.
        with open(indexes_path, "rb") as f:
            self.meta_dict = pickle.load(f)

        self.source_types = self.meta_dict.keys()

        # Fail fast: otherwise these problems only surface later as a
        # KeyError / IndexError inside __iter__.
        for source_type in self.source_types:
            if source_type not in mixaudio_dict:
                raise ValueError(
                    "mixaudio_dict has no entry for source type {!r}".format(
                        source_type
                    )
                )
            if len(self.meta_dict[source_type]) < mixaudio_dict[source_type]:
                raise ValueError(
                    "source type {!r} has {} segments but mixaudio_dict "
                    "requires {} per batch item".format(
                        source_type,
                        len(self.meta_dict[source_type]),
                        mixaudio_dict[source_type],
                    )
                )

        # Read position into each source type's shuffled permutation.
        self.pointers_dict = {source_type: 0 for source_type in self.source_types}

        # Permutation of positions into self.meta_dict[source_type].
        self.indexes_dict = {
            source_type: np.arange(len(self.meta_dict[source_type]))
            for source_type in self.source_types
        }

        self.random_state = np.random.RandomState(random_seed)

        for source_type in self.source_types:
            self.random_state.shuffle(self.indexes_dict[source_type])
            print("{}: {}".format(source_type, len(self.indexes_dict[source_type])))

    def __iter__(self) -> Iterator[List[Dict]]:
        r"""Yield batches of meta info forever.

        Yields:
            batch_meta_list: (batch_size,) e.g., when mix-audio is 2, looks like [
                {'vocals': [
                    {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
                    {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}],
                 'accompaniment': [
                    {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
                    {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
                }
                ...
            ]
            The exact meta dict contents are whatever the indexes file holds.
        """
        batch_size = self.batch_size

        while True:
            # Draw all items of one source type before moving to the next,
            # which fixes the order in which the RNG is consumed.
            batch_meta_dict = {
                source_type: [
                    self._sample_source_metas(source_type) for _ in range(batch_size)
                ]
                for source_type in self.source_types
            }

            # Transpose {source_type: [item_0, item_1, ...]} into a list of
            # per-item {source_type: item_i} dicts.
            batch_meta_list = [
                {
                    source_type: batch_meta_dict[source_type][i]
                    for source_type in self.source_types
                }
                for i in range(batch_size)
            ]

            yield batch_meta_list

    def _sample_source_metas(self, source_type) -> List[Dict]:
        r"""Take the next ``mixaudio_dict[source_type]`` metas of one source type, reshuffling when needed."""
        mix_audios_num = self.mixaudio_dict[source_type]
        indexes = self.indexes_dict[source_type]

        # Reset and reshuffle (in place) when fewer than mix_audios_num
        # unread entries remain in the current permutation.
        if self.pointers_dict[source_type] > len(indexes) - mix_audios_num:
            self.pointers_dict[source_type] = 0
            self.random_state.shuffle(indexes)

        pointer = self.pointers_dict[source_type]
        self.pointers_dict[source_type] = pointer + mix_audios_num

        return [
            self.meta_dict[source_type][index]
            for index in indexes[pointer : pointer + mix_audios_num]
        ]

    def __len__(self) -> int:
        r"""Return the number of steps regarded as one epoch."""
        return self.steps_per_epoch

    def state_dict(self) -> Dict:
        r"""Return a snapshot of the sampling progress for checkpointing.

        The pointers, the permutations and the RNG state are copied, so further
        sampling does not alter the returned snapshot.
        """
        state = {
            'pointers_dict': dict(self.pointers_dict),
            'indexes_dict': {
                source_type: np.copy(indexes)
                for source_type, indexes in self.indexes_dict.items()
            },
            'random_state': self.random_state.get_state(),
        }
        return state

    def load_state_dict(self, state: Dict) -> None:
        r"""Restore sampling progress from a dict produced by ``state_dict``.

        Copies are stored so the caller's ``state`` is never mutated by
        subsequent sampling. Checkpoints written without a 'random_state'
        entry are still accepted; the current RNG state is kept in that case.
        """
        self.pointers_dict = dict(state['pointers_dict'])
        self.indexes_dict = {
            source_type: np.copy(indexes)
            for source_type, indexes in state['indexes_dict'].items()
        }
        if 'random_state' in state:
            self.random_state.set_state(state['random_state'])
|
|
|
|
|
class DistributedSamplerWrapper:
    r"""Split every batch of a wrapped sampler across distributed ranks."""

    def __init__(self, sampler):
        r"""Distributed wrapper of sampler.

        Args:
            sampler: an iterable of sliceable batches that also supports
                ``len()`` (e.g. a ``SegmentSampler``).
        """
        self.sampler = sampler

    def __iter__(self):
        r"""Yield this rank's share of every batch produced by the wrapped sampler.

        Items are dealt round-robin: the process with rank ``r`` out of ``w``
        receives positions ``r, r + w, r + 2w, ...`` of each batch. World size
        and rank are queried from ``torch.distributed`` when iteration starts,
        so the default process group must already be initialized.
        """
        world_size = dist.get_world_size()
        my_rank = dist.get_rank()

        yield from (batch[my_rank::world_size] for batch in self.sampler)

    def __len__(self) -> int:
        r"""Number of batches per epoch, delegated to the wrapped sampler."""
        return len(self.sampler)
|
|