|
from typing import Dict, List, Optional

import h5py
import librosa
import numpy as np
import torch
from pytorch_lightning.core.datamodule import LightningDataModule

from bytesep.data.samplers import DistributedSamplerWrapper
from bytesep.utils import int16_to_float32
|
|
class DataModule(LightningDataModule):
    def __init__(
        self,
        train_sampler: object,
        train_dataset: object,
        num_workers: int,
        distributed: bool,
    ):
        r"""Data module that provides the train loader.

        Args:
            train_sampler: Sampler object, yields mini-batches of segment metas
            train_dataset: Dataset object, maps segment metas to waveform data
            num_workers: int, number of data-loader worker processes
            distributed: bool, whether training runs on multiple devices
        """
        super().__init__()
        self._train_sampler = train_sampler
        self.train_dataset = train_dataset
        self.num_workers = num_workers
        self.distributed = distributed
|
    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device. In distributed training, wrap the sampler
        so that each device receives its own portion of the batches.
        """
        if self.distributed:
            self.train_sampler = DistributedSamplerWrapper(self._train_sampler)
        else:
            self.train_sampler = self._train_sampler
|
    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get the train loader."""
        train_loader = torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            batch_sampler=self.train_sampler,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        return train_loader
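
# A minimal usage sketch (hypothetical setup; a real training script would
# build the sampler and dataset from its own configuration):
#
#     data_module = DataModule(
#         train_sampler=train_sampler,    # yields lists of segment metas
#         train_dataset=train_dataset,    # a Dataset instance, defined below
#         num_workers=8,
#         distributed=True,
#     )
#     data_module.setup()
#     train_loader = data_module.train_dataloader()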
|
|
class Dataset:
    def __init__(self, augmentor: object, segment_samples: int):
        r"""Used for getting data according to a meta.

        Args:
            augmentor: Augmentor object, applied to each loaded segment
            segment_samples: int, target segment length in samples
        """
        self.augmentor = augmentor
        self.segment_samples = segment_samples
|
    def __getitem__(self, meta: Dict) -> Dict:
        r"""Return data according to a meta. E.g., an input meta looks like: {
            'vocals': [
                {'hdf5_path': 'song_A.h5', 'key_in_hdf5': 'vocals',
                 'begin_sample': 6332760, 'end_sample': 6465060},
                {'hdf5_path': 'song_B.h5', 'key_in_hdf5': 'vocals',
                 'begin_sample': 198450, 'end_sample': 330750}],
            'accompaniment': [
                {'hdf5_path': 'song_C.h5', 'key_in_hdf5': 'accompaniment',
                 'begin_sample': 24232920, 'end_sample': 24365250},
                {'hdf5_path': 'song_D.h5', 'key_in_hdf5': 'accompaniment',
                 'begin_sample': 1569960, 'end_sample': 1702260}]}.

        Then, the vocals segments of song_A and song_B are mixed (mix-audio
        augmentation), and the accompaniment segments of song_C and song_D are
        mixed (mix-audio augmentation). Finally, the mixture is created by
        summing the vocals and the accompaniment.

        Args:
            meta: dict, as in the example above

        Returns:
            data_dict: dict, e.g., {
                'vocals': (channels_num, segment_samples),
                'accompaniment': (channels_num, segment_samples),
                'mixture': (channels_num, segment_samples),
            }
        """
        source_types = meta.keys()
        data_dict = {}

        for source_type in source_types:

            waveforms = []  # Segments of this source type, to be mixed.

            for m in meta[source_type]:

                hdf5_path = m['hdf5_path']
                key_in_hdf5 = m['key_in_hdf5']
                bgn_sample = m['begin_sample']
                end_sample = m['end_sample']

                with h5py.File(hdf5_path, 'r') as hf:

                    if source_type == 'audioset':
                        # AudioSet clips are stored as mono waveforms indexed
                        # inside one dataset, so add a channel dimension.
                        index_in_hdf5 = m['index_in_hdf5']
                        waveform = int16_to_float32(
                            hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
                        )
                        waveform = waveform[None, :]
                    else:
                        waveform = int16_to_float32(
                            hf[key_in_hdf5][:, bgn_sample:end_sample]
                        )

                if self.augmentor:
                    waveform = self.augmentor(waveform, source_type)

                # Pad or truncate to exactly segment_samples.
                waveform = librosa.util.fix_length(
                    waveform, size=self.segment_samples, axis=1
                )
                # (channels_num, segment_samples)

                waveforms.append(waveform)

            # Mix-audio augmentation: sum the segments of the same source type.
            data_dict[source_type] = np.sum(waveforms, axis=0)
            # (channels_num, segment_samples)

        # The mixture is the sum of all source types.
        mixture = np.sum(
            [data_dict[source_type] for source_type in source_types], axis=0
        )
        data_dict['mixture'] = mixture
        # (channels_num, segment_samples)

        return data_dict
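
# A minimal sketch of indexing the dataset directly (the path, HDF5 key, and
# sample indices below are hypothetical; real metas come from the sampler):
#
#     dataset = Dataset(augmentor=None, segment_samples=44100 * 3)
#     meta = {
#         'vocals': [
#             {'hdf5_path': 'song_A.h5', 'key_in_hdf5': 'vocals',
#              'begin_sample': 0, 'end_sample': 44100 * 3},
#         ],
#     }
#     data_dict = dataset[meta]
#     data_dict['vocals'].shape   # (channels_num, 44100 * 3)
#     data_dict['mixture'].shape  # same as 'vocals' here, the only source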
|
|
def collate_fn(list_data_dict: List[Dict]) -> Dict:
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples)},
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples)},
            ...]

    Returns:
        data_dict: e.g., {
            'vocals': (batch_size, channels_num, segment_samples),
            'accompaniment': (batch_size, channels_num, segment_samples),
            'mixture': (batch_size, channels_num, segment_samples),
        }
    """
    data_dict = {}

    for key in list_data_dict[0].keys():
        # Stack the per-example arrays along a new batch dimension.
        data_dict[key] = torch.Tensor(
            np.array([d[key] for d in list_data_dict])
        )

    return data_dict
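
# A minimal sketch of what collate_fn produces (shapes are illustrative):
#
#     item = {
#         'vocals': np.zeros((2, 44100)),
#         'mixture': np.zeros((2, 44100)),
#     }
#     batch = collate_fn([item, item, item, item])
#     batch['vocals'].shape   # torch.Size([4, 2, 44100])
#     batch['mixture'].shape  # torch.Size([4, 2, 44100])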
|
|