# Source path: maskgct/models/svc/base/svc_dataset.py
# (Hugging Face file-viewer residue: uploaded by Hecheng0625, "Upload 409 files",
# commit c968fc3, verified, 22.1 kB.)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
from torch.nn.utils.rnn import pad_sequence
import json
import os
import numpy as np
import librosa
from utils.data_utils import *
from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema
from processors.content_extractor import (
ContentvecExtractor,
WhisperExtractor,
WenetExtractor,
)
from models.base.base_dataset import (
BaseOfflineDataset,
BaseOfflineCollator,
BaseOnlineDataset,
BaseOnlineCollator,
)
from models.base.new_dataset import BaseTestDataset
# Small positive constant. It is not referenced anywhere in the visible part of
# this file. NOTE(review): presumably a numerical-stability epsilon kept for
# importers of this module -- confirm before removing.
EPS = 1.0e-12
class SVCOfflineDataset(BaseOfflineDataset):
    """Offline SVC dataset that adds pre-extracted content features.

    Each item starts from the feature dict produced by
    ``BaseOfflineDataset.__getitem__`` and, depending on the
    ``cfg.model.condition_encoder.use_*`` switches, is extended with
    ``whisper_feat``, ``contentvec_feat``, ``mert_feat`` and ``wenet_feat``
    loaded from ``.npy`` files and aligned to ``target_len`` frames. Items
    longer than the clip length are randomly cropped.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        """Set up aligners and utterance -> feature-path maps.

        Args:
            cfg: Experiment config; ``cfg.model.condition_encoder`` and
                ``cfg.preprocess`` are read here.
            dataset: Forwarded unchanged to ``BaseOfflineDataset``.
            is_valid: Forwarded unchanged to ``BaseOfflineDataset``.
        """
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)

        # Use the config stored by the base class from here on.
        cfg = self.cfg

        if cfg.model.condition_encoder.use_whisper:
            self.whisper_aligner = WhisperExtractor(self.cfg)
            self.utt2whisper_path = load_content_feature_path(
                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
            )

        if cfg.model.condition_encoder.use_contentvec:
            self.contentvec_aligner = ContentvecExtractor(self.cfg)
            self.utt2contentVec_path = load_content_feature_path(
                self.metadata,
                cfg.preprocess.processed_dir,
                cfg.preprocess.contentvec_dir,
            )

        if cfg.model.condition_encoder.use_mert:
            # MERT has no extractor object here; it is aligned with
            # align_content_feature_length in __getitem__.
            self.utt2mert_path = load_content_feature_path(
                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
            )

        if cfg.model.condition_encoder.use_wenet:
            self.wenet_aligner = WenetExtractor(self.cfg)
            self.utt2wenet_path = load_content_feature_path(
                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
            )

    def __getitem__(self, index):
        """Return the feature dict of utterance ``index``, clipped if too long.

        Every enabled content feature requires the base class to have put
        ``target_len`` into the feature dict (checked with ``assert``).
        """
        single_feature = BaseOfflineDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        # Key format used by the utt2*_path lookups below.
        utt = "{}_{}".format(dataset, uid)

        if self.cfg.model.condition_encoder.use_whisper:
            assert "target_len" in single_feature.keys()
            aligned_whisper_feat = (
                self.whisper_aligner.offline_resolution_transformation(
                    np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
                )
            )
            single_feature["whisper_feat"] = aligned_whisper_feat

        if self.cfg.model.condition_encoder.use_contentvec:
            assert "target_len" in single_feature.keys()
            aligned_contentvec = (
                self.contentvec_aligner.offline_resolution_transformation(
                    np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
                )
            )
            single_feature["contentvec_feat"] = aligned_contentvec

        if self.cfg.model.condition_encoder.use_mert:
            assert "target_len" in single_feature.keys()
            aligned_mert_feat = align_content_feature_length(
                np.load(self.utt2mert_path[utt]),
                single_feature["target_len"],
                source_hop=self.cfg.preprocess.mert_hop_size,
            )
            single_feature["mert_feat"] = aligned_mert_feat

        if self.cfg.model.condition_encoder.use_wenet:
            assert "target_len" in single_feature.keys()
            aligned_wenet_feat = self.wenet_aligner.offline_resolution_transformation(
                np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
            )
            single_feature["wenet_feat"] = aligned_wenet_feat

        return self.clip_if_too_long(single_feature)

    def __len__(self):
        """Return the number of utterances in the metadata."""
        return len(self.metadata)

    def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
        """Pick a random ``[start, end)`` frame window of length ``max_seq_len``.

        Args:
            feature_seq_len: Number of frames available.
            max_seq_len: Window length in frames.
            ending_ts: to avoid invalid whisper features for over 30s audios
                2812 = 30 * 24000 // 256

        Returns:
            Tuple ``(start, end)`` with ``end - start == max_seq_len``.
        """
        latest_start = max(feature_seq_len - max_seq_len, 0)
        # Clamp at 0: when max_seq_len exceeds ending_ts the bound would be
        # negative and random.randint(0, negative) raises ValueError.
        latest_start = max(min(latest_start, ending_ts - max_seq_len), 0)

        start = random.randint(0, latest_start)
        end = start + max_seq_len
        return start, end

    def clip_if_too_long(self, sample, max_seq_len=512):
        """Randomly crop ``sample`` in place to at most ``max_seq_len`` frames.

        sample :
            {
                'spk_id': (1,),
                'target_len': int
                'mel': (seq_len, dim),
                'frame_pitch': (seq_len,)
                'frame_energy': (seq_len,)
                'contentvec_feat': (seq_len, dim)
            }

        ``spk_id`` is left untouched, ``audio``/``audio_len`` are handled in
        sample units (frames * hop_size), and every other entry is sliced on
        its first axis. The same (mutated) dict is returned.
        """
        if sample["target_len"] <= max_seq_len:
            return sample

        start, end = self.random_select(sample["target_len"], max_seq_len)
        sample["target_len"] = end - start

        hop_size = self.cfg.preprocess.hop_size
        # Only existing keys are reassigned, so iterating the live dict is safe.
        for k in sample.keys():
            if k == "audio":
                # audio should be clipped in hop_size scale
                sample[k] = sample[k][start * hop_size : end * hop_size]
            elif k == "audio_len":
                sample[k] = (end - start) * hop_size
            elif k not in ["spk_id", "target_len"]:
                sample[k] = sample[k][start:end]

        return sample
class SVCOnlineDataset(BaseOnlineDataset):
    """Online SVC dataset that yields raw waveforms at every required rate.

    The waveform is loaded once at the highest sample rate among the target
    rate and the enabled pretrained content models, randomly cropped to
    ``max_duration`` seconds, then resampled to each of the other rates.
    NOTE(review): ``sample_rate``, ``hop_size``, ``metadata``, ``spk2id`` and
    ``utt2spk`` are not set in this class -- presumably by the base class.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        super().__init__(cfg, dataset, is_valid=is_valid)

        encoder_cfg = self.cfg.model.condition_encoder
        preprocess_cfg = self.cfg.preprocess

        # Audio pretrained models' sample rates (plus the target rate itself)
        self.all_sample_rates = {self.sample_rate}
        if encoder_cfg.use_whisper:
            self.all_sample_rates.add(preprocess_cfg.whisper_sample_rate)
        if encoder_cfg.use_contentvec:
            self.all_sample_rates.add(preprocess_cfg.contentvec_sample_rate)
        if encoder_cfg.use_wenet:
            self.all_sample_rates.add(preprocess_cfg.wenet_sample_rate)

        self.highest_sample_rate = max(self.all_sample_rates)

        # The maximum duration (seconds) for one training sample
        self.max_duration = 6.0
        self.max_n_frames = int(self.max_duration * self.highest_sample_rate)

    def random_select(self, wav, duration, wav_path):
        """Return a random crop of ``wav`` no longer than ``max_duration``.

        Args:
            wav: (T,) waveform at ``highest_sample_rate``.
            duration: Duration in seconds taken from the metadata.
            wav_path: Only used in the warning message.

        If the randomly chosen window is entirely zero, a warning is printed
        and the window is moved to start at the first non-zero sample.
        """
        if duration <= self.max_duration:
            return wav

        latest_start = int((duration - self.max_duration) * self.highest_sample_rate)
        begin = random.randint(0, latest_start)
        stop = begin + self.max_n_frames

        if np.all(wav[begin:stop] == 0):
            print("*" * 20)
            print("Warning! The wav file {} has a lot of silience.".format(wav_path))

            # At least one sample must be non-silent; restart the window there.
            nonzero_positions = np.where(wav != 0)[0]
            assert len(nonzero_positions) > 0
            begin = nonzero_positions[0]
            stop = begin + self.max_n_frames

        return wav[begin:stop]

    def __getitem__(self, index):
        """Build the feature dict for utterance ``index``.

        single_feature: dict,
            wav: (T,)
            wav_len: int
            target_len: int
            mask: (n_frames, 1)
            spk_id
            wav_{sr}: (T,)
            wav_{sr}_len: int
        """
        features = {}
        item = self.metadata[index]
        path = item["Path"]

        # Load once at the highest rate and crop before any resampling.
        base_wav, _ = librosa.load(path, sr=self.highest_sample_rate)
        base_wav = self.random_select(base_wav, item["Duration"], path)

        # One waveform per required sample rate.
        for sr in self.all_sample_rates:
            if sr == self.highest_sample_rate:
                resampled = base_wav
            else:
                resampled = librosa.resample(
                    base_wav, orig_sr=self.highest_sample_rate, target_sr=sr
                )
            resampled = torch.as_tensor(resampled, dtype=torch.float32)

            features[f"wav_{sr}"] = resampled
            features[f"wav_{sr}_len"] = len(resampled)

            # The target-rate waveform also defines the frame-level entries.
            if sr == self.sample_rate:
                n_samples = len(resampled)
                n_frames = n_samples // self.hop_size

                features["wav"] = resampled
                features["wav_len"] = n_samples
                features["target_len"] = n_frames
                features["mask"] = torch.ones(n_frames, 1, dtype=torch.long)

        # Speaker ID
        if self.cfg.preprocess.use_spkid:
            utt = "{}_{}".format(item["Dataset"], item["Uid"])
            speaker_index = self.spk2id[self.utt2spk[utt]]
            features["spk_id"] = torch.tensor([speaker_index], dtype=torch.int32)

        return features

    def __len__(self):
        """Return the number of utterances in the metadata."""
        return len(self.metadata)
class SVCOfflineCollator(BaseOfflineCollator):
    """Collator for ``SVCOfflineDataset``; defers entirely to the base class."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, batch):
        """Return whatever the base collator produces for ``batch``."""
        return super().__call__(batch)
class SVCOnlineCollator(BaseOnlineCollator):
    """Collator for ``SVCOnlineDataset`` items."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, batch):
        """Stack length entries and zero-pad everything else.

        SVCOnlineDataset.__getitem__:
            wav: (T,)
            wav_len: int
            target_len: int
            mask: (n_frames, 1)
            spk_id: (1)

            wav_{sr}: (T,)
            wav_{sr}_len: int

        Returns:
            wav: (B, T), torch.float32
            wav_len: (B), torch.long
            target_len: (B), torch.long
            mask: (B, n_frames, 1), torch.long
            spk_id: (B, 1), torch.int32

            wav_{sr}: (B, T)
            wav_{sr}_len: (B), torch.long
        """
        collated = {}

        # The first item decides which keys are collated.
        for name in batch[0]:
            column = [sample[name] for sample in batch]
            if "_len" not in name:
                collated[name] = pad_sequence(
                    column, batch_first=True, padding_value=0
                )
            else:
                # Any key containing "_len" holds one scalar per item.
                collated[name] = torch.LongTensor(column)

        return collated
class SVCTestDataset(BaseTestDataset):
    """Inference dataset for SVC.

    Loads pre-extracted acoustic features (mel, frame pitch, uv, frame
    energy) and content features (whisper, contentvec, mert, wenet) of the
    source utterances from ``.npy`` files, and attaches the speaker id of the
    target singer given by ``args.target_singer``.
    """

    def __init__(self, args, cfg, infer_type):
        """Resolve the target singer and build all feature-path lookups.

        Args:
            args: Needs ``target_singer`` (str ``"<dataset>_<singer>"``),
                ``trans_key`` and, when speaker ids are used,
                ``acoustics_dir``.
            cfg: Experiment config (``cfg.preprocess``, ``cfg.model`` and,
                for ``infer_type == "from_file"``, ``cfg.inference``).
            infer_type: Forwarded to ``BaseTestDataset``; the value
                ``"from_file"`` additionally loads the source F0 median.
        """
        BaseTestDataset.__init__(self, args, cfg, infer_type)
        self.metadata = self.get_metadata()

        target_singer = args.target_singer
        self.cfg = cfg
        self.trans_key = args.trans_key
        assert isinstance(target_singer, str)

        # "<dataset>_<singer>": the singer is the part after the last "_".
        self.target_singer = target_singer.split("_")[-1]
        # Strip only the trailing "_<singer>"; str.replace would also remove
        # earlier occurrences of the same substring inside the dataset name.
        singer_suffix = "_{}".format(self.target_singer)
        if target_singer.endswith(singer_suffix):
            self.target_dataset = target_singer[: -len(singer_suffix)]
        else:
            self.target_dataset = target_singer

        if cfg.preprocess.mel_min_max_norm:
            if self.cfg.preprocess.features_extraction_mode == "online":
                # TODO: Change the hard code
                # Using an empirical mel extrema to normalize
                mel_extrema = load_mel_extrema(cfg.preprocess, "vctk")
            else:
                mel_extrema = load_mel_extrema(cfg.preprocess, self.target_dataset)

            self.target_mel_extrema = (
                torch.as_tensor(mel_extrema[0]),
                torch.as_tensor(mel_extrema[1]),
            )

        ######### Load source acoustic features #########
        if cfg.preprocess.use_spkid:
            spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id)

            with open(spk2id_path, "r", encoding="utf-8") as f:
                self.spk2id = json.load(f)

        if cfg.preprocess.use_uv:
            self.utt2uv_path = self._feature_path_map(cfg.preprocess.uv_dir)

        if cfg.preprocess.use_frame_pitch:
            self.utt2frame_pitch_path = self._feature_path_map(
                cfg.preprocess.pitch_dir
            )

            # Target F0 median
            target_f0_statistics_path = os.path.join(
                cfg.preprocess.processed_dir,
                self.target_dataset,
                cfg.preprocess.pitch_dir,
                "statistics.json",
            )
            self.target_pitch_median = self._load_voiced_pitch_median(
                target_f0_statistics_path,
                f"{self.target_dataset}_{self.target_singer}",
            )

            # Source F0 median (if infer from file)
            if infer_type == "from_file":
                source_audio_name = cfg.inference.source_audio_name
                source_f0_statistics_path = os.path.join(
                    cfg.preprocess.processed_dir,
                    source_audio_name,
                    cfg.preprocess.pitch_dir,
                    "statistics.json",
                )
                self.source_pitch_median = self._load_voiced_pitch_median(
                    source_f0_statistics_path,
                    f"{source_audio_name}_{source_audio_name}",
                )
            else:
                self.source_pitch_median = None

        if cfg.preprocess.use_frame_energy:
            self.utt2frame_energy_path = self._feature_path_map(
                cfg.preprocess.energy_dir
            )

        if cfg.preprocess.use_mel:
            self.utt2mel_path = self._feature_path_map(cfg.preprocess.mel_dir)

        ######### Load source content features' path #########
        if cfg.model.condition_encoder.use_whisper:
            self.whisper_aligner = WhisperExtractor(cfg)
            self.utt2whisper_path = load_content_feature_path(
                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
            )

        if cfg.model.condition_encoder.use_contentvec:
            self.contentvec_aligner = ContentvecExtractor(cfg)
            self.utt2contentVec_path = load_content_feature_path(
                self.metadata,
                cfg.preprocess.processed_dir,
                cfg.preprocess.contentvec_dir,
            )

        if cfg.model.condition_encoder.use_mert:
            self.utt2mert_path = load_content_feature_path(
                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
            )

        if cfg.model.condition_encoder.use_wenet:
            self.wenet_aligner = WenetExtractor(cfg)
            self.utt2wenet_path = load_content_feature_path(
                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
            )

    def _feature_path_map(self, feature_dir):
        """Map "<Dataset>_<Uid>" to <processed_dir>/<Dataset>/<feature_dir>/<Uid>.npy."""
        processed_dir = self.cfg.preprocess.processed_dir
        return {
            f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
                processed_dir,
                utt_info["Dataset"],
                feature_dir,
                utt_info["Uid"] + ".npy",
            )
            for utt_info in self.metadata
        }

    @staticmethod
    def _load_voiced_pitch_median(statistics_path, singer_key):
        """Read ``[singer_key]["voiced_positions"]["median"]`` from a statistics JSON."""
        # Context manager so the file handle is closed deterministically.
        with open(statistics_path, "r", encoding="utf-8") as f:
            statistics = json.load(f)
        return statistics[singer_key]["voiced_positions"]["median"]

    def __getitem__(self, index):
        """Return the feature dict (numpy arrays plus ``target_len``) of one utterance.

        ``target_len`` is taken from the first enabled feature among mel,
        frame pitch and frame energy (in that order); the remaining features
        are aligned to it.
        """
        single_feature = {}

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        if self.cfg.preprocess.use_spkid:
            # The speaker id is the *target* singer's, not the source's.
            single_feature["spk_id"] = np.array(
                [self.spk2id[f"{self.target_dataset}_{self.target_singer}"]],
                dtype=np.int32,
            )

        ######### Get Acoustic Features Item #########
        if self.cfg.preprocess.use_mel:
            mel = np.load(self.utt2mel_path[utt])
            assert mel.shape[0] == self.cfg.preprocess.n_mel  # [n_mels, T]
            if self.cfg.preprocess.use_min_max_norm_mel:
                # mel norm (uses the source utterance's dataset)
                mel = cal_normalized_mel(mel, dataset, self.cfg.preprocess)

            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = mel.shape[1]
            single_feature["mel"] = mel.T  # [T, n_mels]

        if self.cfg.preprocess.use_frame_pitch:
            frame_pitch_path = self.utt2frame_pitch_path[utt]
            frame_pitch = np.load(frame_pitch_path)

            if self.trans_key:
                try:
                    # A numeric key is cached as int on the instance.
                    self.trans_key = int(self.trans_key)
                except (TypeError, ValueError):
                    # Non-numeric key: fall through to median-based shifting.
                    pass
                if isinstance(self.trans_key, int):
                    frame_pitch = transpose_key(frame_pitch, self.trans_key)
                else:
                    assert self.target_singer
                    # NOTE(review): source_pitch_median is None unless
                    # infer_type == "from_file" -- confirm that
                    # pitch_shift_to_target accepts None.
                    frame_pitch = pitch_shift_to_target(
                        frame_pitch, self.target_pitch_median, self.source_pitch_median
                    )

            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = len(frame_pitch)
            aligned_frame_pitch = align_length(
                frame_pitch, single_feature["target_len"]
            )
            single_feature["frame_pitch"] = aligned_frame_pitch

            if self.cfg.preprocess.use_uv:
                frame_uv_path = self.utt2uv_path[utt]
                frame_uv = np.load(frame_uv_path)
                aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
                # Invert the stored flags: truthy -> 0, falsy -> 1.
                aligned_frame_uv = np.array(
                    [0 if flag else 1 for flag in aligned_frame_uv]
                )
                single_feature["frame_uv"] = aligned_frame_uv

        if self.cfg.preprocess.use_frame_energy:
            frame_energy_path = self.utt2frame_energy_path[utt]
            frame_energy = np.load(frame_energy_path)
            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = len(frame_energy)
            aligned_frame_energy = align_length(
                frame_energy, single_feature["target_len"]
            )
            single_feature["frame_energy"] = aligned_frame_energy

        ######### Get Content Features Item #########
        if self.cfg.model.condition_encoder.use_whisper:
            assert "target_len" in single_feature.keys()
            aligned_whisper_feat = (
                self.whisper_aligner.offline_resolution_transformation(
                    np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
                )
            )
            single_feature["whisper_feat"] = aligned_whisper_feat

        if self.cfg.model.condition_encoder.use_contentvec:
            assert "target_len" in single_feature.keys()
            aligned_contentvec = (
                self.contentvec_aligner.offline_resolution_transformation(
                    np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
                )
            )
            single_feature["contentvec_feat"] = aligned_contentvec

        if self.cfg.model.condition_encoder.use_mert:
            assert "target_len" in single_feature.keys()
            aligned_mert_feat = align_content_feature_length(
                np.load(self.utt2mert_path[utt]),
                single_feature["target_len"],
                source_hop=self.cfg.preprocess.mert_hop_size,
            )
            single_feature["mert_feat"] = aligned_mert_feat

        if self.cfg.model.condition_encoder.use_wenet:
            assert "target_len" in single_feature.keys()
            aligned_wenet_feat = self.wenet_aligner.offline_resolution_transformation(
                np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
            )
            single_feature["wenet_feat"] = aligned_wenet_feat

        return single_feature

    def __len__(self):
        """Return the number of utterances in the metadata."""
        return len(self.metadata)
class SVCTestCollator:
    """Collate ``SVCTestDataset`` items into zero-padded batch tensors."""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        """Pad every array entry to the longest item and build the frame mask.

        ``target_len`` becomes a LongTensor of shape (B,) and additionally
        yields ``mask`` of shape (B, T_max, 1) with ones on valid frames.
        Every other key is expected to hold a numpy array per item and is
        zero-padded along its first axis, batch first.
        """
        collated = dict()

        # The first item decides which keys are collated.
        for name in batch[0]:
            if name != "target_len":
                tensors = [torch.from_numpy(sample[name]) for sample in batch]
                collated[name] = pad_sequence(
                    tensors, batch_first=True, padding_value=0
                )
                continue

            lengths = [sample["target_len"] for sample in batch]
            collated["target_len"] = torch.LongTensor(lengths)

            # One (len, 1) block of ones per item; padding marks invalid frames.
            frame_masks = [torch.ones((n, 1), dtype=torch.long) for n in lengths]
            collated["mask"] = pad_sequence(
                frame_masks, batch_first=True, padding_value=0
            )

        return collated