Flux9665's picture
update to current version
6a79837
import torch
from Modules.EmbeddingModel.GST import GSTStyleEncoder
from Modules.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder
class StyleEmbedding(torch.nn.Module):
    """
    The style embedding should provide information of the speaker and their speaking style.

    The feedback signal for the module will come from the TTS objective, so it doesn't have a
    dedicated train loop. The train loop does however supply supervision in the form of a
    barlow twins objective.

    See the git history for some other approaches for style embedding, like the SWIN transformer
    and a simple LSTM baseline. GST turned out to be the best.
    """

    def __init__(self, embedding_dim=16, style_tts_encoder=False):
        """
        Args:
            embedding_dim: size of the produced style embedding; forwarded to the wrapped
                           encoder as ``style_dim`` (StyleTTS) or ``gst_token_dim`` (GST)
            style_tts_encoder: if True, wrap the StyleTTS encoder, otherwise the GST encoder
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_gst = not style_tts_encoder
        if style_tts_encoder:
            self.style_encoder = StyleTTSEncoder(style_dim=embedding_dim)
        else:
            self.style_encoder = GSTStyleEncoder(gst_token_dim=embedding_dim)

    def forward(self,
                batch_of_feature_sequences,
                batch_of_feature_sequence_lengths):
        """
        Every sequence is stripped of its padding, tiled along the time axis until it covers
        512 frames and then cut to exactly 512 frames, so the encoder always sees a
        fixed-length input regardless of the true length of the utterance.

        Args:
            batch_of_feature_sequences: b is the batch axis, f features per timestep
                                        and l time-steps, which may include padding
                                        for most elements in the batch (b, l, f)
                                        NOTE(review): f was documented as 128, while the
                                        demo at the bottom of this file feeds 512 -- confirm
                                        against the wrapped encoder
            batch_of_feature_sequence_lengths: indicate for every element in the batch,
                                               what the true length is, since they are
                                               all padded to the length of the longest
                                               element in the batch (b, 1)

        Returns:
            the output of the wrapped style encoder for the (b, 512, f) batch;
            presumably a batch of n dimensional embeddings (b, n)

        Raises:
            ValueError: if a sequence has no frames left after removing the padding,
                        because an empty sequence cannot be tiled to the required length
        """
        minimum_sequence_length = 512
        specs = list()
        for index, spec_length in enumerate(batch_of_feature_sequence_lengths):
            spec = batch_of_feature_sequences[index][:int(spec_length)]
            # use the length of the slice rather than the reported one, since slicing
            # silently clamps lengths that exceed the padded length
            true_length = spec.size(0)
            if true_length == 0:
                # tiling an empty tensor never makes it longer, this used to loop forever
                raise ValueError(f"Sequence {index} in the batch has length 0, "
                                 f"cannot extract a style embedding from it.")
            # ceil division: smallest amount of copies that covers the minimum length
            repetitions = -(-minimum_sequence_length // true_length)
            if repetitions > 1:
                spec = spec.repeat((repetitions, 1))
            specs.append(spec[:minimum_sequence_length])

        spec_batch = torch.stack(specs, dim=0)
        return self.style_encoder(speech=spec_batch)
if __name__ == '__main__':
    # Smoke test: build both encoder variants, report their trainable parameter
    # counts and print the shape of the embedding for a random dummy batch.
    batch_size = 5
    seq_length = 398
    dummy_lengths = torch.tensor([seq_length] * batch_size)

    # GST based variant
    style_emb = StyleEmbedding(style_tts_encoder=False)
    trainable_parameters = sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)
    print(f"GST parameter count: {trainable_parameters}")
    dummy_features = torch.randn(batch_size, seq_length, 512)
    print(style_emb(dummy_features, dummy_lengths).shape)

    # StyleTTS based variant
    style_emb = StyleEmbedding(style_tts_encoder=True)
    trainable_parameters = sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)
    print(f"StyleTTS encoder parameter count: {trainable_parameters}")
    dummy_features = torch.randn(batch_size, seq_length, 512)
    print(style_emb(dummy_features, dummy_lengths).shape)