# Hydra defaults list: compose the shared `base` config first, then apply
# this file's own keys (`_self_`) on top so they take precedence.
defaults:
  - base
  - _self_

# Project / run-group name (presumably consumed by the `base` config for
# logging and output paths; confirm against base.yaml).
project: vq-gan-pretrain

# Trainer options (presumably forwarded to the Lightning Trainer declared in
# the `base` config; confirm there).
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  # Underscore digit grouping: YAML 1.1 loaders such as PyYAML read this as
  # the integer 1000000. Also referenced by the LR scheduler below.
  max_steps: 1_000_000
  # Reused by callbacks.model_checkpoint.every_n_train_steps.
  val_check_interval: 5000
  strategy: ddp_find_unused_parameters_true

# Shared audio / spectrogram constants. They must stay at the top level:
# the datasets, the model and the mel transforms reference them through
# absolute interpolations such as ${sample_rate} and ${hop_length}.
sample_rate: 44100
hop_length: 512
num_mels: 128
n_fft: 2048
win_length: 2048

# Training data: two VQGANDataset instances (one per filelist) concatenated
# through torch.utils.data.ConcatDataset. Both use the same audio settings
# and slice length.
train_dataset:
  _target_: torch.utils.data.ConcatDataset
  datasets:
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/gigaspeech/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512
    - _target_: fish_speech.datasets.vqgan.VQGANDataset
      filelist: data/sft/vq_train_filelist.txt
      sample_rate: ${sample_rate}
      hop_length: ${hop_length}
      slice_frames: 512

# Validation data: a single VQGANDataset. Unlike the training datasets, no
# `slice_frames` is set here.
val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/sft/vq_val_filelist.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}

# Data module wiring: reuses the train/val dataset nodes defined above via
# interpolation and sets the loader parameters.
data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 32
  val_batch_size: 32

# Model configuration.
# NOTE(review): the nesting of the sub-sections below under `model` was
# reconstructed from a flattened copy of this file; confirm it against the
# constructor of fish_speech.models.vqgan.VQGAN.
model:
  _target_: fish_speech.models.vqgan.VQGAN

  sampling_rate: ${sample_rate}
  # Presumably the relative weights of the adversarial, VQ and mel loss
  # terms; verify against the VQGAN implementation.
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0
  freeze_encoder: false

  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4

  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    # Same value as the encoder's residual_channels above.
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]

  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    # Same value as the quantizer's input_dim above.
    condition_channels: 768

  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator

  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    # No checkpoint path is configured here.
    ckpt_path: null

  # Mel transform with an explicit 0-8000 Hz band (f_min / f_max).
  encode_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0

  # Same transform class and STFT settings, but without f_min / f_max
  # overrides, so the class defaults apply.
  gt_mel_transform:
    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}

  # `_partial_: true` makes Hydra build a functools.partial instead of
  # calling the target immediately.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    # Written with an explicit mantissa so that plain YAML 1.1 loaders also
    # read these values as floats rather than strings.
    lr: 1.0e-4
    betas: [0.8, 0.99]
    eps: 1.0e-5
    weight_decay: 0.01

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      # Schedule length follows the trainer's step budget.
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0

# Callback settings.
# NOTE(review): `model_checkpoint` and `grad_norm_monitor` carry no `_target_`
# here, so they presumably override entries declared in the `base` config;
# confirm there.
callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  model_checkpoint:
    # Tied to the validation interval through interpolation.
    every_n_train_steps: ${trainer.val_check_interval}

  grad_norm_monitor:
    # Names match the sub-sections configured under `model`.
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator