|
from typing import Dict, List, Tuple

import torch
|
|
|
|
|
class BasicBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Batch data preprocessor. Used for preparing mixtures and targets for
        training. If there are multiple target source types, the waveforms of
        those sources will be concatenated along the channel dimension (dim=1)
        in the order given by ``target_source_types``.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }

            target_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
            }
        """
        mixtures = batch_data_dict['mixture']

        # Concatenate every requested source along dim=1 so a single tensor
        # carries all targets; order follows self.target_source_types.
        targets = torch.cat(
            [batch_data_dict[source_type] for source_type in self.target_source_types],
            dim=1,
        )

        input_dict = {'waveform': mixtures}
        target_dict = {'waveform': targets}

        return input_dict, target_dict
|
|
|
|
|
class ConditionalSisoBatchDataPreprocessor:
    def __init__(self, target_source_types: List[str]):
        r"""Conditional single input single output (SISO) batch data
        preprocessor. Select one target source from several target sources as
        training target and prepare the corresponding conditional vector.

        Args:
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> Tuple[Dict, Dict]:
        r"""Format waveforms and targets for training.

        Example ``n`` of the batch is assigned target source
        ``target_source_types[n % target_sources_num]``, so the sources are
        assigned round-robin across the batch.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, channels_num, segment_samples),
                'vocals': (batch_size, channels_num, segment_samples),
                'bass': (batch_size, channels_num, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
                'condition': (batch_size, target_sources_num),
            }

            target_dict: dict, e.g., {
                'waveform': (batch_size, channels_num, segment_samples),
            }

        Raises:
            ValueError: if ``target_source_types`` is empty, or if the batch
                size is not a multiple of the number of target sources.
        """
        mixtures = batch_data_dict['mixture']
        batch_size = len(mixtures)
        target_sources_num = len(self.target_source_types)

        # Explicit checks (not ``assert``) so validation survives ``python -O``.
        if target_sources_num == 0:
            raise ValueError("target_source_types must not be empty.")

        if batch_size % target_sources_num != 0:
            raise ValueError(
                "Batch size should be evenly divided by target sources number "
                f"(batch size: {batch_size}, target sources number: "
                f"{target_sources_num})."
            )

        # One-hot condition vectors, allocated directly on the mixture's device
        # (float32 by default, same as the previous zeros().to(device)).
        conditions = torch.zeros(
            batch_size, target_sources_num, device=mixtures.device
        )

        targets = []

        for n in range(batch_size):
            # Round-robin assignment of target sources over the batch.
            k = n % target_sources_num
            source_type = self.target_source_types[k]

            targets.append(batch_data_dict[source_type][n])
            conditions[n, k] = 1

        targets = torch.stack(targets, dim=0)

        input_dict = {
            'waveform': mixtures,
            'condition': conditions,
        }

        target_dict = {'waveform': targets}

        return input_dict, target_dict
|
|
|
|
|
def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> type:
    r"""Get batch data preprocessor class by its name.

    Args:
        batch_data_preprocessor_type: str, either 'BasicBatchDataPreprocessor'
            or 'ConditionalSisoBatchDataPreprocessor'.

    Returns:
        The preprocessor class itself (not an instance).

    Raises:
        NotImplementedError: if ``batch_data_preprocessor_type`` is not one of
            the supported names.
    """
    if batch_data_preprocessor_type == 'BasicBatchDataPreprocessor':
        return BasicBatchDataPreprocessor

    if batch_data_preprocessor_type == 'ConditionalSisoBatchDataPreprocessor':
        return ConditionalSisoBatchDataPreprocessor

    raise NotImplementedError(
        f"Unknown batch data preprocessor type: {batch_data_preprocessor_type!r}. "
        "Supported types: 'BasicBatchDataPreprocessor', "
        "'ConditionalSisoBatchDataPreprocessor'."
    )
|
|