|
from tasks.tts.fs2 import FastSpeech2Task |
|
from modules.syntaspeech.multi_window_disc import Discriminator |
|
from utils.hparams import hparams |
|
from torch import nn |
|
import torch |
|
import torch.optim |
|
import torch.utils.data |
|
import utils |
|
|
|
|
|
class FastSpeech2AdvTask(FastSpeech2Task):
    """FastSpeech 2 task trained against a multi-window mel discriminator.

    Two optimizers are used:
      * index 0 trains the generator (``self.model``) on the parent task's
        losses plus, once the discriminator is active, adversarial terms that
        push the discriminator's outputs on generated mels towards 1;
      * index 1 trains the discriminator (``self.mel_disc``) to output 1 on
        ground-truth mels and 0 on generated mels, every
        ``hparams['disc_interval']`` steps.
    The adversarial terms use ``self.mse_loss_fn`` (provided by the parent
    task) against all-ones / all-zeros targets.
    """

    def build_model(self):
        """Build the generator, optionally warm-start it, then build the discriminator.

        Returns:
            The generator module (``self.model``).
        """
        self.build_tts_model()
        if hparams['load_ckpt'] != '':
            # NOTE(review): strict=False presumably tolerates key mismatches
            # between the checkpoint and this model — confirm in load_ckpt.
            self.load_ckpt(hparams['load_ckpt'], strict=False)
        utils.print_arch(self.model, 'Generator')
        self.build_disc_model()
        if not hasattr(self, 'gen_params'):
            self.gen_params = list(self.model.parameters())
        return self.model

    def build_disc_model(self):
        """Create ``self.mel_disc`` and cache its parameters in ``self.disc_params``.

        The window lengths ``[32, 64, 128]`` are truncated to
        ``hparams['disc_win_num']`` entries. The frequency size follows
        ``hparams['audio_num_mel_bins']`` when present and falls back to 80
        (the value this method previously hard-coded).
        """
        disc_win_num = hparams['disc_win_num']
        h = hparams['mel_disc_hidden_size']
        self.mel_disc = Discriminator(
            time_lengths=[32, 64, 128][:disc_win_num],
            freq_length=hparams.get('audio_num_mel_bins', 80),
            hidden_size=h,
            kernel=(3, 3),
        )
        self.disc_params = list(self.mel_disc.parameters())
        utils.print_arch(self.mel_disc, model_name='Mel Disc')

    def _to_mel(self, mel_p):
        """Map generator output into mel space when the model defines ``out2mel``."""
        if hasattr(self.model, 'out2mel'):
            mel_p = self.model.out2mel(mel_p)
        return mel_p

    def _training_step(self, sample, batch_idx, optimizer_idx):
        """Compute the weighted loss for one optimizer step.

        Args:
            sample: batch dict; must contain ``'mels'`` (ground-truth mels).
            batch_idx: not used here; kept for the trainer's call signature.
            optimizer_idx: 0 for the generator step, anything else for the
                discriminator step.

        Returns:
            ``(total_loss, log_outputs)``. ``log_outputs`` also records the
            batch size under ``'bs'``. Returns ``None`` when no loss term was
            produced (e.g. a discriminator step that is skipped).
        """
        # The discriminator (and adversarial generator terms) only kick in
        # after `disc_start_steps` and only when mel-GAN training is enabled.
        disc_start = hparams['mel_gan'] and self.global_step >= hparams["disc_start_steps"] and \
            hparams['lambda_mel_adv'] > 0
        if optimizer_idx == 0:
            log_outputs, loss_weights = self._generator_losses(sample, disc_start)
        else:
            log_outputs, loss_weights = self._discriminator_losses(sample, disc_start)

        if not log_outputs:
            return None
        # Terms without an explicit weight contribute with weight 1.
        total_loss = sum(loss_weights.get(k, 1) * v for k, v in log_outputs.items())
        log_outputs['bs'] = sample['mels'].shape[0]
        return total_loss, log_outputs

    def _generator_losses(self, sample, disc_start):
        """Run the generator; return its losses and weights, adding adversarial terms if active."""
        loss_weights = {}
        log_outputs, model_out = self.run_model(self.model, sample, return_output=True)
        # Detached tensor outputs, reused by the discriminator step unless
        # `rerun_gen` asks it to run the generator again.
        self.model_out = {k: v.detach() for k, v in model_out.items() if isinstance(v, torch.Tensor)}
        if disc_start:
            # Cached so the discriminator step for this batch uses the same
            # conditioning input (or None when conditioning is disabled).
            self.disc_cond = disc_cond = self.model_out['decoder_inp'].detach() \
                if hparams['use_cond_disc'] else None
            if hparams['mel_loss_no_noise']:
                # Extra mel loss on `mel_out_nonoise` (presumably the output
                # without injected noise — confirm against the model).
                self.add_mel_loss(model_out['mel_out_nonoise'], sample['mels'], log_outputs)
            o_ = self.mel_disc(self._to_mel(model_out['mel_out']), disc_cond)
            p_, pc_ = o_['y'], o_['y_c']
            # Generator objective: make the discriminator output 1 on fakes.
            if p_ is not None:
                log_outputs['a'] = self.mse_loss_fn(p_, p_.new_ones(p_.size()))
                loss_weights['a'] = hparams['lambda_mel_adv']
            if pc_ is not None:
                log_outputs['ac'] = self.mse_loss_fn(pc_, pc_.new_ones(pc_.size()))
                loss_weights['ac'] = hparams['lambda_mel_adv']
        return log_outputs, loss_weights

    def _discriminator_losses(self, sample, disc_start):
        """Return real/fake discriminator losses (unit weights); empty when the discriminator is idle."""
        log_outputs = {}
        if not (disc_start and self.global_step % hparams['disc_interval'] == 0):
            return log_outputs, {}

        if hparams['rerun_gen']:
            with torch.no_grad():
                _, model_out = self.run_model(self.model, sample, return_output=True)
        else:
            model_out = self.model_out
        mel_g = sample['mels']
        mel_p = self._to_mel(model_out['mel_out'])

        o = self.mel_disc(mel_g, self.disc_cond)
        p, pc = o['y'], o['y_c']
        o_ = self.mel_disc(mel_p, self.disc_cond)
        p_, pc_ = o_['y'], o_['y_c']

        # Discriminator objective: 1 for ground-truth mels, 0 for generated ones.
        if p_ is not None:
            log_outputs["r"] = self.mse_loss_fn(p, p.new_ones(p.size()))
            log_outputs["f"] = self.mse_loss_fn(p_, p_.new_zeros(p_.size()))
        if pc_ is not None:
            log_outputs["rc"] = self.mse_loss_fn(pc, pc.new_ones(pc.size()))
            log_outputs["fc"] = self.mse_loss_fn(pc_, pc_.new_zeros(pc_.size()))
        return log_outputs, {}

    def configure_optimizers(self):
        """Create AdamW optimizers for the generator and, if it has parameters, the discriminator.

        Also builds ``self.scheduler`` via ``build_scheduler``.

        Returns:
            ``[optimizer_gen, optimizer_disc]``. ``optimizer_disc`` is ``None``
            when the discriminator has no parameters.
        """
        if not hasattr(self, 'gen_params'):
            self.gen_params = list(self.model.parameters())
        betas = (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])
        optimizer_gen = torch.optim.AdamW(
            self.gen_params,
            lr=hparams['lr'],
            betas=betas,
            weight_decay=hparams['weight_decay'])
        optimizer_disc = torch.optim.AdamW(
            self.disc_params,
            lr=hparams['disc_lr'],
            betas=betas,
            **hparams["discriminator_optimizer_params"]) if len(self.disc_params) > 0 else None
        self.scheduler = self.build_scheduler({'gen': optimizer_gen, 'disc': optimizer_disc})
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
        """Build one LR scheduler per optimizer.

        Args:
            optimizer: dict with ``'gen'`` and ``'disc'`` optimizers; ``'disc'``
                may be ``None``.

        Returns:
            A dict with two entries:
              * ``'gen'``: the parent task's scheduler for the generator;
              * ``'disc'``: a ``StepLR`` configured by
                ``hparams['discriminator_scheduler_params']``, or ``None``
                when there is no discriminator optimizer.
        """
        gen_scheduler = super().build_scheduler(optimizer['gen'])
        disc_scheduler = None
        if optimizer["disc"] is not None:
            disc_scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer["disc"],
                **hparams["discriminator_scheduler_params"])
        return {"gen": gen_scheduler, "disc": disc_scheduler}

    def on_before_optimization(self, opt_idx):
        """Clip the gradient norm of the parameter set that is about to be stepped."""
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.gen_params, hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.disc_params, hparams["discriminator_grad_norm"])

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Advance the scheduler belonging to the optimizer that just stepped.

        The discriminator scheduler is driven by the number of steps since
        ``disc_start_steps`` (clamped to at least 1). It is skipped when no
        discriminator optimizer/scheduler was built; ``build_scheduler``
        returns ``None`` in that case.
        """
        if optimizer_idx == 0:
            self.scheduler['gen'].step(self.global_step)
        elif self.scheduler['disc'] is not None:
            self.scheduler['disc'].step(max(self.global_step - hparams["disc_start_steps"], 1))
|
|