Spaces:
Sleeping
Sleeping
breadlicker45
committed on
Commit
•
c34b897
1
Parent(s):
e09f850
Upload 54 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- encoder/__init__.py +12 -0
- encoder/__pycache__/__init__.cpython-310.pyc +0 -0
- encoder/__pycache__/__init__.cpython-38.pyc +0 -0
- encoder/__pycache__/__init__.cpython-39.pyc +0 -0
- encoder/__pycache__/distrib.cpython-310.pyc +0 -0
- encoder/__pycache__/distrib.cpython-38.pyc +0 -0
- encoder/__pycache__/distrib.cpython-39.pyc +0 -0
- encoder/__pycache__/model.cpython-310.pyc +0 -0
- encoder/__pycache__/model.cpython-38.pyc +0 -0
- encoder/__pycache__/model.cpython-39.pyc +0 -0
- encoder/__pycache__/utils.cpython-310.pyc +0 -0
- encoder/__pycache__/utils.cpython-38.pyc +0 -0
- encoder/__pycache__/utils.cpython-39.pyc +0 -0
- encoder/distrib.py +124 -0
- encoder/model.py +324 -0
- encoder/modules/__init__.py +22 -0
- encoder/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/__init__.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/conv.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/conv.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/conv.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/lstm.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/lstm.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/lstm.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/norm.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/norm.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/norm.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/seanet.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/seanet.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/seanet.cpython-39.pyc +0 -0
- encoder/modules/__pycache__/transformer.cpython-310.pyc +0 -0
- encoder/modules/__pycache__/transformer.cpython-38.pyc +0 -0
- encoder/modules/__pycache__/transformer.cpython-39.pyc +0 -0
- encoder/modules/conv.py +253 -0
- encoder/modules/lstm.py +39 -0
- encoder/modules/norm.py +28 -0
- encoder/modules/seanet.py +253 -0
- encoder/modules/transformer.py +119 -0
- encoder/msstftd.py +147 -0
- encoder/quantization/__init__.py +8 -0
- encoder/quantization/__pycache__/__init__.cpython-310.pyc +0 -0
- encoder/quantization/__pycache__/__init__.cpython-38.pyc +0 -0
- encoder/quantization/__pycache__/__init__.cpython-39.pyc +0 -0
- encoder/quantization/__pycache__/core_vq.cpython-310.pyc +0 -0
- encoder/quantization/__pycache__/core_vq.cpython-38.pyc +0 -0
- encoder/quantization/__pycache__/core_vq.cpython-39.pyc +0 -0
- encoder/quantization/__pycache__/vq.cpython-310.pyc +0 -0
- encoder/quantization/__pycache__/vq.cpython-38.pyc +0 -0
- encoder/quantization/__pycache__/vq.cpython-39.pyc +0 -0
encoder/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
# flake8: noqa
|
7 |
+
|
8 |
+
"""EnCodec neural audio codec."""
|
9 |
+
|
10 |
+
__version__ = "0.1.2a3"
|
11 |
+
|
12 |
+
from .model import EncodecModel
|
encoder/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (268 Bytes). View file
|
|
encoder/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (278 Bytes). View file
|
|
encoder/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (266 Bytes). View file
|
|
encoder/__pycache__/distrib.cpython-310.pyc
ADDED
Binary file (3.74 kB). View file
|
|
encoder/__pycache__/distrib.cpython-38.pyc
ADDED
Binary file (3.79 kB). View file
|
|
encoder/__pycache__/distrib.cpython-39.pyc
ADDED
Binary file (3.76 kB). View file
|
|
encoder/__pycache__/model.cpython-310.pyc
ADDED
Binary file (11.7 kB). View file
|
|
encoder/__pycache__/model.cpython-38.pyc
ADDED
Binary file (11.7 kB). View file
|
|
encoder/__pycache__/model.cpython-39.pyc
ADDED
Binary file (11.6 kB). View file
|
|
encoder/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.68 kB). View file
|
|
encoder/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.65 kB). View file
|
|
encoder/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (2.66 kB). View file
|
|
encoder/distrib.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""Torch distributed utilities."""
|
8 |
+
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import torch
|
12 |
+
|
13 |
+
|
14 |
+
def rank():
|
15 |
+
if torch.distributed.is_initialized():
|
16 |
+
return torch.distributed.get_rank()
|
17 |
+
else:
|
18 |
+
return 0
|
19 |
+
|
20 |
+
|
21 |
+
def world_size():
|
22 |
+
if torch.distributed.is_initialized():
|
23 |
+
return torch.distributed.get_world_size()
|
24 |
+
else:
|
25 |
+
return 1
|
26 |
+
|
27 |
+
|
28 |
+
def is_distributed():
|
29 |
+
return world_size() > 1
|
30 |
+
|
31 |
+
|
32 |
+
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
|
33 |
+
if is_distributed():
|
34 |
+
return torch.distributed.all_reduce(tensor, op)
|
35 |
+
|
36 |
+
|
37 |
+
def _is_complex_or_float(tensor):
|
38 |
+
return torch.is_floating_point(tensor) or torch.is_complex(tensor)
|
39 |
+
|
40 |
+
|
41 |
+
def _check_number_of_params(params: tp.List[torch.Tensor]):
|
42 |
+
# utility function to check that the number of params in all workers is the same,
|
43 |
+
# and thus avoid a deadlock with distributed all reduce.
|
44 |
+
if not is_distributed() or not params:
|
45 |
+
return
|
46 |
+
tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
|
47 |
+
all_reduce(tensor)
|
48 |
+
if tensor.item() != len(params) * world_size():
|
49 |
+
# If not all the workers have the same number, for at least one of them,
|
50 |
+
# this inequality will be verified.
|
51 |
+
raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
|
52 |
+
"at least one worker has a different one.")
|
53 |
+
|
54 |
+
|
55 |
+
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
|
56 |
+
"""Broadcast the tensors from the given parameters to all workers.
|
57 |
+
This can be used to ensure that all workers have the same model to start with.
|
58 |
+
"""
|
59 |
+
if not is_distributed():
|
60 |
+
return
|
61 |
+
tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
|
62 |
+
_check_number_of_params(tensors)
|
63 |
+
handles = []
|
64 |
+
for tensor in tensors:
|
65 |
+
handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
|
66 |
+
handles.append(handle)
|
67 |
+
for handle in handles:
|
68 |
+
handle.wait()
|
69 |
+
|
70 |
+
|
71 |
+
def sync_buffer(buffers, average=True):
|
72 |
+
"""
|
73 |
+
Sync grad for buffers. If average is False, broadcast instead of averaging.
|
74 |
+
"""
|
75 |
+
if not is_distributed():
|
76 |
+
return
|
77 |
+
handles = []
|
78 |
+
for buffer in buffers:
|
79 |
+
if torch.is_floating_point(buffer.data):
|
80 |
+
if average:
|
81 |
+
handle = torch.distributed.all_reduce(
|
82 |
+
buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
83 |
+
else:
|
84 |
+
handle = torch.distributed.broadcast(
|
85 |
+
buffer.data, src=0, async_op=True)
|
86 |
+
handles.append((buffer, handle))
|
87 |
+
for buffer, handle in handles:
|
88 |
+
handle.wait()
|
89 |
+
if average:
|
90 |
+
buffer.data /= world_size
|
91 |
+
|
92 |
+
|
93 |
+
def sync_grad(params):
|
94 |
+
"""
|
95 |
+
Simpler alternative to DistributedDataParallel, that doesn't rely
|
96 |
+
on any black magic. For simple models it can also be as fast.
|
97 |
+
Just call this on your model parameters after the call to backward!
|
98 |
+
"""
|
99 |
+
if not is_distributed():
|
100 |
+
return
|
101 |
+
handles = []
|
102 |
+
for p in params:
|
103 |
+
if p.grad is not None:
|
104 |
+
handle = torch.distributed.all_reduce(
|
105 |
+
p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
|
106 |
+
handles.append((p, handle))
|
107 |
+
for p, handle in handles:
|
108 |
+
handle.wait()
|
109 |
+
p.grad.data /= world_size()
|
110 |
+
|
111 |
+
|
112 |
+
def average_metrics(metrics: tp.Dict[str, float], count=1.):
|
113 |
+
"""Average a dictionary of metrics across all workers, using the optional
|
114 |
+
`count` as unnormalized weight.
|
115 |
+
"""
|
116 |
+
if not is_distributed():
|
117 |
+
return metrics
|
118 |
+
keys, values = zip(*metrics.items())
|
119 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
120 |
+
tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
|
121 |
+
tensor *= count
|
122 |
+
all_reduce(tensor)
|
123 |
+
averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
|
124 |
+
return dict(zip(keys, averaged))
|
encoder/model.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""EnCodec model implementation."""
|
8 |
+
|
9 |
+
import math
|
10 |
+
from pathlib import Path
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
from . import quantization as qt
|
18 |
+
from . import modules as m
|
19 |
+
from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url
|
20 |
+
|
21 |
+
|
22 |
+
ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'
|
23 |
+
|
24 |
+
EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
|
25 |
+
|
26 |
+
|
27 |
+
class LMModel(nn.Module):
    """Language Model to estimate probabilities of each codebook entry.
    We predict all codebooks in parallel for a given time step.

    Args:
        n_q (int): number of codebooks.
        card (int): codebook cardinality.
        dim (int): transformer dimension.
        **kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`.
    """
    def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
        super().__init__()
        self.card = card
        self.n_q = n_q
        self.dim = dim
        self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
        # card + 1 embedding entries: index 0 is the "missing" token,
        # 1..card map to actual codebook entries (see `forward`).
        self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
        self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])

    def forward(self, indices: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
        """
        Args:
            indices (torch.Tensor): indices from the previous time step. Indices
                should be 1 + actual index in the codebook. The value 0 is reserved for
                when the index is missing (i.e. first time step). Shape should be
                `[B, n_q, T]`.
            states: state for the streaming decoding.
            offset: offset of the current time step.

        Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities
        with a shape `[B, card, n_q, T]`.

        """
        B, K, T = indices.shape
        # Sum the per-codebook embeddings into a single transformer input.
        input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
        out, states, offset = self.transformer(input_, states, offset)
        # Stack per-codebook logits -> [B, K, T, card], then permute so that
        # the cardinality axis comes second as documented above.
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
        return torch.softmax(logits, dim=1), states, offset
|
66 |
+
|
67 |
+
|
68 |
+
class EncodecModel(nn.Module):
    """EnCodec model operating on the raw waveform.
    Args:
        target_bandwidths (list of float): Target bandwidths.
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        normalize (bool): Whether to apply audio normalization.
        segment (float or None): segment duration in sec. when doing overlap-add.
        overlap (float): overlap between segment, given as a fraction of the segment duration.
        name (str): name of the model, used as metadata when compressing audio.
    """
    def __init__(self,
                 encoder: m.SEANetEncoder,
                 decoder: m.SEANetDecoder,
                 quantizer: qt.ResidualVectorQuantizer,
                 target_bandwidths: tp.List[float],
                 sample_rate: int,
                 channels: int,
                 normalize: bool = False,
                 segment: tp.Optional[float] = None,
                 overlap: float = 0.01,
                 name: str = 'unset'):
        super().__init__()
        # Current operating bandwidth; must be set via `set_target_bandwidth`.
        self.bandwidth: tp.Optional[float] = None
        self.target_bandwidths = target_bandwidths
        self.encoder = encoder
        self.quantizer = quantizer
        self.decoder = decoder
        self.sample_rate = sample_rate
        self.channels = channels
        self.normalize = normalize
        self.segment = segment
        self.overlap = overlap
        # One latent frame per prod(encoder.ratios) input samples.
        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
        self.name = name
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
            "quantizer bins must be a power of 2."

    @property
    def segment_length(self) -> tp.Optional[int]:
        # Segment duration expressed in samples; None means whole-input mode.
        if self.segment is None:
            return None
        return int(self.segment * self.sample_rate)

    @property
    def segment_stride(self) -> tp.Optional[int]:
        # Hop between consecutive segments, leaving `overlap` fraction shared.
        segment_length = self.segment_length
        if segment_length is None:
            return None
        return max(1, int((1 - self.overlap) * segment_length))

    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
        """Given a tensor `x`, returns a list of frames containing
        the discrete encoded codes for `x`, along with rescaling factors
        for each segment, when `self.normalize` is True.

        Each frames is a tuple `(codebook, scale)`, with `codebook` of
        shape `[B, K, T]`, with `K` the number of codebooks.
        """
        assert x.dim() == 3
        _, channels, length = x.shape
        assert channels > 0 and channels <= 2
        segment_length = self.segment_length
        if segment_length is None:
            # No segmentation: a single frame covering the whole input.
            segment_length = length
            stride = length
        else:
            stride = self.segment_stride  # type: ignore
            assert stride is not None

        encoded_frames: tp.List[EncodedFrame] = []
        for offset in range(0, length, stride):
            frame = x[:, :, offset: offset + segment_length]
            encoded_frames.append(self._encode_frame(frame))
        return encoded_frames

    def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
        # Encode one segment; returns (codes, scale) where scale is the RMS
        # normalization factor (or None when `self.normalize` is False).
        length = x.shape[-1]
        duration = length / self.sample_rate
        assert self.segment is None or duration <= 1e-5 + self.segment

        if self.normalize:
            # Normalize by the mono RMS volume; epsilon guards against silence.
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None

        emb = self.encoder(x)
        codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes, scale

    def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
        """Decode the given frames into a waveform.
        Note that the output might be a bit bigger than the input. In that case,
        any extra steps at the end can be trimmed.
        """
        segment_length = self.segment_length
        if segment_length is None:
            assert len(encoded_frames) == 1
            return self._decode_frame(encoded_frames[0])

        frames = [self._decode_frame(frame) for frame in encoded_frames]
        return _linear_overlap_add(frames, self.segment_stride or 1)

    def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
        # Inverse of `_encode_frame`: dequantize, decode, undo normalization.
        codes, scale = encoded_frame
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        if scale is not None:
            out = out * scale.view(-1, 1, 1)
        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frames = self.encode(x)
        # Decoding may overshoot the input length; trim the tail.
        return self.decode(frames)[:, :, :x.shape[-1]]

    def set_target_bandwidth(self, bandwidth: float):
        # Select the operating bandwidth from the supported set.
        if bandwidth not in self.target_bandwidths:
            raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
                             f"Select one of {self.target_bandwidths}.")
        self.bandwidth = bandwidth

    def get_lm_model(self) -> LMModel:
        """Return the associated LM model to improve the compression rate.
        """
        device = next(self.parameters()).device
        lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
                     past_context=int(3.5 * self.frame_rate)).to(device)
        # Pretrained LM checkpoints exist only for the two released models.
        checkpoints = {
            'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
            'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
        }
        try:
            checkpoint_name = checkpoints[self.name]
        except KeyError:
            raise RuntimeError("No LM pre-trained for the current Encodec model.")
        url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
        state = torch.hub.load_state_dict_from_url(
            url, map_location='cpu', check_hash=True)  # type: ignore
        lm.load_state_dict(state)
        lm.eval()
        return lm

    @staticmethod
    def _get_model(target_bandwidths: tp.List[float],
                   sample_rate: int = 24_000,
                   channels: int = 1,
                   causal: bool = True,
                   model_norm: str = 'weight_norm',
                   audio_normalize: bool = False,
                   segment: tp.Optional[float] = None,
                   name: str = 'unset'):
        # Assemble encoder/decoder/quantizer for the given configuration.
        encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
        decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
        # Number of codebooks needed to reach the largest target bandwidth,
        # given 10 bits per codebook per frame.
        n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
        quantizer = qt.ResidualVectorQuantizer(
            dimension=encoder.dimension,
            n_q=n_q,
            bins=1024,
        )
        model = EncodecModel(
            encoder,
            decoder,
            quantizer,
            target_bandwidths,
            sample_rate,
            channels,
            normalize=audio_normalize,
            segment=segment,
            name=name,
        )
        return model

    @staticmethod
    def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
        # Load a checkpoint from a local repository dir, or from the web.
        if repository is not None:
            if not repository.is_dir():
                raise ValueError(f"{repository} must exist and be a directory.")
            file = repository / checkpoint_name
            # Checksum is embedded in the filename after the '-'.
            checksum = file.stem.split('-')[1]
            _check_checksum(file, checksum)
            return torch.load(file)
        else:
            url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
            return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)  # type:ignore

    @staticmethod
    def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained causal 24khz model.
        """
        if repository:
            assert pretrained
        target_bandwidths = [1.5, 3., 6, 12., 24.]
        checkpoint_name = 'encodec_24khz-d7cc33bc.th'
        sample_rate = 24_000
        channels = 1
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=True, model_norm='weight_norm', audio_normalize=False,
            name='encodec_24khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model

    @staticmethod
    def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
        """Return the pretrained 48khz model.
        """
        if repository:
            assert pretrained
        target_bandwidths = [3., 6., 12., 24.]
        checkpoint_name = 'encodec_48khz-7e698e3e.th'
        sample_rate = 48_000
        channels = 2
        model = EncodecModel._get_model(
            target_bandwidths, sample_rate, channels,
            causal=False, model_norm='time_group_norm', audio_normalize=True,
            segment=1., name='encodec_48khz' if pretrained else 'unset')
        if pretrained:
            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
            model.load_state_dict(state_dict)
        model.eval()
        return model
|
302 |
+
|
303 |
+
|
304 |
+
def test():
    """Smoke test: round-trip 2 seconds of audio through every model/bandwidth pair."""
    from itertools import product
    import torchaudio
    bandwidths = [3, 6, 12, 24]
    models = {
        'encodec_24khz': EncodecModel.encodec_model_24khz,
        'encodec_48khz': EncodecModel.encodec_model_48khz
    }
    for model_name, bw in product(models.keys(), bandwidths):
        model = models[model_name]()
        model.set_target_bandwidth(bw)
        # 'encodec_24khz' -> '24k' suffix for the fixture filename.
        audio_suffix = model_name.split('_')[1][:3]
        wav, sr = torchaudio.load(f"test_{audio_suffix}.wav")
        wav = wav[:, :model.sample_rate * 2]
        decoded = model(wav.unsqueeze(0))[0]
        assert wav.shape == decoded.shape, (wav.shape, decoded.shape)
|
321 |
+
|
322 |
+
|
323 |
+
if __name__ == '__main__':
|
324 |
+
test()
|
encoder/modules/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""Torch modules."""
|
8 |
+
|
9 |
+
# flake8: noqa
|
10 |
+
from .conv import (
|
11 |
+
pad1d,
|
12 |
+
unpad1d,
|
13 |
+
NormConv1d,
|
14 |
+
NormConvTranspose1d,
|
15 |
+
NormConv2d,
|
16 |
+
NormConvTranspose2d,
|
17 |
+
SConv1d,
|
18 |
+
SConvTranspose1d,
|
19 |
+
)
|
20 |
+
from .lstm import SLSTM
|
21 |
+
from .seanet import SEANetEncoder, SEANetDecoder
|
22 |
+
from .transformer import StreamingTransformerEncoder
|
encoder/modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (557 Bytes). View file
|
|
encoder/modules/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (567 Bytes). View file
|
|
encoder/modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (555 Bytes). View file
|
|
encoder/modules/__pycache__/conv.cpython-310.pyc
ADDED
Binary file (9.2 kB). View file
|
|
encoder/modules/__pycache__/conv.cpython-38.pyc
ADDED
Binary file (9.48 kB). View file
|
|
encoder/modules/__pycache__/conv.cpython-39.pyc
ADDED
Binary file (9.43 kB). View file
|
|
encoder/modules/__pycache__/lstm.cpython-310.pyc
ADDED
Binary file (1.05 kB). View file
|
|
encoder/modules/__pycache__/lstm.cpython-38.pyc
ADDED
Binary file (1.06 kB). View file
|
|
encoder/modules/__pycache__/lstm.cpython-39.pyc
ADDED
Binary file (1.05 kB). View file
|
|
encoder/modules/__pycache__/norm.cpython-310.pyc
ADDED
Binary file (1.15 kB). View file
|
|
encoder/modules/__pycache__/norm.cpython-38.pyc
ADDED
Binary file (1.15 kB). View file
|
|
encoder/modules/__pycache__/norm.cpython-39.pyc
ADDED
Binary file (1.14 kB). View file
|
|
encoder/modules/__pycache__/seanet.cpython-310.pyc
ADDED
Binary file (9.72 kB). View file
|
|
encoder/modules/__pycache__/seanet.cpython-38.pyc
ADDED
Binary file (9.65 kB). View file
|
|
encoder/modules/__pycache__/seanet.cpython-39.pyc
ADDED
Binary file (9.48 kB). View file
|
|
encoder/modules/__pycache__/transformer.cpython-310.pyc
ADDED
Binary file (4.54 kB). View file
|
|
encoder/modules/__pycache__/transformer.cpython-38.pyc
ADDED
Binary file (4.53 kB). View file
|
|
encoder/modules/__pycache__/transformer.cpython-39.pyc
ADDED
Binary file (4.47 kB). View file
|
|
encoder/modules/conv.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""Convolutional layers wrappers and utilities."""
|
8 |
+
|
9 |
+
import math
|
10 |
+
import typing as tp
|
11 |
+
import warnings
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
from torch.nn import functional as F
|
16 |
+
from torch.nn.utils import spectral_norm, weight_norm
|
17 |
+
|
18 |
+
from .norm import ConvLayerNorm
|
19 |
+
|
20 |
+
|
21 |
+
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
|
22 |
+
'time_layer_norm', 'layer_norm', 'time_group_norm'])
|
23 |
+
|
24 |
+
|
25 |
+
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Wrap `module` with a weight reparametrization when `norm` calls for one."""
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    if norm == 'spectral_norm':
        return spectral_norm(module)
    # All remaining members of CONV_NORMALIZATIONS are applied as separate
    # modules (see `get_norm_module`) and need no reparametrization here.
    return module
|
35 |
+
|
36 |
+
|
37 |
+
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        # A single group normalizes jointly over all channels per time step.
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    # 'none', 'weight_norm', 'spectral_norm', 'time_layer_norm': nothing to add.
    return nn.Identity()
|
52 |
+
|
53 |
+
|
54 |
+
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    # Possibly-fractional number of frames the convolution would produce.
    n_frames = (length - kernel_size + padding_total) / stride + 1
    # Smallest length for which the last window is exactly full.
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
|
62 |
+
|
63 |
+
|
64 |
+
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    extra = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra))
|
77 |
+
|
78 |
+
|
79 |
+
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happen.
    """
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)
    length = x.shape[-1]
    max_pad = max(padding_left, padding_right)
    # Reflection needs strictly more samples than the pad width; extend with
    # zeros first when the input is too short, then trim the extension back.
    extra_pad = max_pad - length + 1 if length <= max_pad else 0
    if extra_pad:
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    return padded[..., :padded.shape[-1] - extra_pad]
|
97 |
+
|
98 |
+
|
99 |
+
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert left + right <= x.shape[-1]
    # Slice off `left` samples from the start and `right` from the end.
    return x[..., left: x.shape[-1] - right]
|
106 |
+
|
107 |
+
|
108 |
+
class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        # Parametrization-style norms (e.g. weight norm) wrap the conv itself;
        # module-style norms become a separate `self.norm` stage.
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        # Convolution first, then the (possibly identity) normalization stage.
        return self.norm(self.conv(x))
|
123 |
+
|
124 |
+
|
125 |
+
class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        # 2d convs are never causal, hence causal=False for the norm module.
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
|
140 |
+
|
141 |
+
|
142 |
+
class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        # Parametrization-style norms wrap the transposed conv itself;
        # module-style norms become a separate `self.norm` stage.
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.convtr(x))
|
157 |
+
|
158 |
+
|
159 |
+
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Consistency fix: the NormConv1d/NormConv2d/NormConvTranspose1d siblings
        # all expose `norm_type`; this class was the only one missing it.
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
|
173 |
+
|
174 |
+
|
175 |
+
class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        inner = self.conv.conv
        kernel_size = inner.kernel_size[0]
        stride = inner.stride[0]
        dilation = inner.dilation[0]
        # Effective kernel size once dilation is accounted for.
        kernel_size = (kernel_size - 1) * dilation + 1
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Causal: all fixed padding goes to the left, extra padding to the right.
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides.
            right = padding_total // 2
            x = pad1d(x, (padding_total - right, right + extra_padding), mode=self.pad_mode)
        return self.conv(x)
|
212 |
+
|
213 |
+
|
214 |
+
class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert 0. <= self.trim_right_ratio <= 1.

    def forward(self, x):
        inner = self.convtr.convtr
        padding_total = inner.kernel_size[0] - inner.stride[0]

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
        return unpad1d(y, (padding_total - padding_right, padding_right))
|
encoder/modules/lstm.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""LSTM layers module."""
|
8 |
+
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
|
12 |
+
class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        # skip: add a residual connection around the LSTM.
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    # NOTE: this variant permutes back to (B, C, T) *before* applying the skip
    # connection, so the residual is added in convolutional layout (the
    # reference implementation added it in (T, B, C) layout instead).
    def forward(self, x):
        x1 = x.permute(2, 0, 1)  # (B, C, T) -> (T, B, C), the layout nn.LSTM expects
        y, _ = self.lstm(x1)
        y = y.permute(1, 2, 0)   # (T, B, C) -> (B, C, T)
        if self.skip:
            y = y + x
        return y
|
encoder/modules/norm.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""Normalization modules."""
|
8 |
+
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import einops
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
|
16 |
+
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        # Move time to the front of the trailing dims so LayerNorm normalizes
        # over the channel dims, then restore the original layout.
        x = einops.rearrange(x, 'b ... t -> b t ...')
        x = super().forward(x)
        x = einops.rearrange(x, 'b t ... -> b ... t')
        # BUG FIX: the original ended with a bare `return`, so forward() returned
        # None and silently discarded the normalized tensor.
        return x
|
encoder/modules/seanet.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""Encodec SEANet-based encoder and decoder implementation."""
|
8 |
+
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch.nn as nn
|
13 |
+
|
14 |
+
from . import (
|
15 |
+
SConv1d,
|
16 |
+
SConvTranspose1d,
|
17 |
+
SLSTM
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.
    Args:
        dim (int): Dimension of the input/output
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3)
        true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act = getattr(nn, activation)
        hidden = dim // compress  # bottleneck width inside the residual branch
        last = len(kernel_sizes) - 1
        layers: tp.List[nn.Module] = []
        for idx, (ksize, dil) in enumerate(zip(kernel_sizes, dilations)):
            # First conv enters the bottleneck, last conv exits back to `dim`.
            layers.append(act(**activation_params))
            layers.append(SConv1d(dim if idx == 0 else hidden,
                                  dim if idx == last else hidden,
                                  kernel_size=ksize, dilation=dil,
                                  norm=norm, norm_kwargs=norm_params,
                                  causal=causal, pad_mode=pad_mode))
        self.block = nn.Sequential(*layers)
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                                    causal=causal, pad_mode=pad_mode)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)
|
64 |
+
|
65 |
+
|
66 |
+
class SEANetEncoder(nn.Module):
    """SEANet encoder.
    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        # The encoder downsamples, so it consumes the ratios in the reverse of
        # the (decoder-ordered) argument.
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)

        act = getattr(nn, activation)
        mult = 1
        layers: tp.List[nn.Module] = [
            SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
                    causal=causal, pad_mode=pad_mode)
        ]
        # Downsample to raw audio scale
        for ratio in self.ratios:
            # Residual blocks with growing dilation, then a strided downsampling conv.
            for j in range(n_residual_layers):
                layers.append(SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
                                                dilations=[dilation_base ** j, 1],
                                                norm=norm, norm_params=norm_params,
                                                activation=activation, activation_params=activation_params,
                                                causal=causal, pad_mode=pad_mode, compress=compress,
                                                true_skip=true_skip))
            layers.append(act(**activation_params))
            layers.append(SConv1d(mult * n_filters, mult * n_filters * 2,
                                  kernel_size=ratio * 2, stride=ratio,
                                  norm=norm, norm_kwargs=norm_params,
                                  causal=causal, pad_mode=pad_mode))
            mult *= 2

        if lstm:
            layers.append(SLSTM(mult * n_filters, num_layers=lstm))

        # Final projection down to the latent dimension.
        layers.append(act(**activation_params))
        layers.append(SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
                              causal=causal, pad_mode=pad_mode))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
|
145 |
+
|
146 |
+
|
147 |
+
class SEANetDecoder(nn.Module):
    """SEANet decoder.
    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the start of the decoder.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2,
                 trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        layers: tp.List[nn.Module] = [
            SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
                    causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            layers.append(SLSTM(mult * n_filters, num_layers=lstm))

        # Upsample to raw audio scale
        for ratio in self.ratios:
            # Strided transposed conv for upsampling, then residual blocks.
            layers.append(act(**activation_params))
            layers.append(SConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                           kernel_size=ratio * 2, stride=ratio,
                                           norm=norm, norm_kwargs=norm_params,
                                           causal=causal, trim_right_ratio=trim_right_ratio))
            for j in range(n_residual_layers):
                layers.append(SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                                dilations=[dilation_base ** j, 1],
                                                activation=activation, activation_params=activation_params,
                                                norm=norm, norm_params=norm_params, causal=causal,
                                                pad_mode=pad_mode, compress=compress, true_skip=true_skip))
            mult //= 2

        # Final projection back to the audio channel count.
        layers.append(act(**activation_params))
        layers.append(SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params,
                              causal=causal, pad_mode=pad_mode))
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            layers.append(final_act(**(final_activation_params or {})))
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        return self.model(z)
|
239 |
+
|
240 |
+
|
241 |
+
def test():
    """Smoke test: a random waveform must round-trip through encoder/decoder shapes."""
    import torch
    enc = SEANetEncoder()
    dec = SEANetDecoder()
    audio = torch.randn(1, 1, 24000)
    latent = enc(audio)
    assert list(latent.shape) == [1, 128, 75], latent.shape
    rebuilt = dec(latent)
    assert rebuilt.shape == audio.shape, (audio.shape, rebuilt.shape)


if __name__ == '__main__':
    test()
|
encoder/modules/transformer.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""A streamable transformer."""
|
8 |
+
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
|
16 |
+
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
    """Create time embedding for the given positions, target dimension `dim`."""
    # We aim for BTC format
    assert dim % 2 == 0
    half = dim // 2
    # One geometric frequency per embedding channel, broadcast against positions.
    freqs = torch.arange(half, device=positions.device).view(1, 1, -1)
    angles = positions / (max_period ** (freqs / (half - 1)))
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
|
28 |
+
|
29 |
+
|
30 |
+
class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Transformer encoder layer that attends over `x_past` + `x` with a
    bounded causal window of `past_context` steps.
    Returns both the layer output and the self-attention input (used by the
    caller to build the next streaming state).
    """
    def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
        if self.norm_first:
            sa_input = self.norm1(x)
            x = x + self._sa_block(sa_input, x_past, past_context)
            x = x + self._ff_block(self.norm2(x))
        else:
            sa_input = x
            x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
            x = self.norm2(x + self._ff_block(x))

        return x, sa_input

    # self-attention block
    def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
        _, T, _ = x.shape
        _, H, _ = x_past.shape

        # Keys/values cover the cached past followed by the current chunk.
        keys = torch.cat([x_past, x], dim=1)
        # A key is visible to a query iff it is not in the future (delta >= 0)
        # and lies within the context window (delta <= past_context).
        q_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
        k_pos = torch.arange(T + H, device=x.device).view(1, -1)
        delta = q_pos - k_pos
        visible = (delta >= 0) & (delta <= past_context)
        out = self.self_attn(x, keys, keys,
                             attn_mask=~visible,
                             need_weights=False)[0]
        return self.dropout1(out)
|
60 |
+
|
61 |
+
|
62 |
+
class StreamingTransformerEncoder(nn.Module):
    """TransformerEncoder with streaming support.

    Args:
        dim (int): dimension of the data.
        hidden_scale (int): intermediate dimension of FF module is this times the dimension.
        num_heads (int): number of heads.
        num_layers (int): number of layers.
        max_period (float): maxium period of cosines in the positional embedding.
        past_context (int or None): receptive field for the causal mask, infinite if None.
        gelu (bool): if true uses GeLUs, otherwise use ReLUs.
        norm_in (bool): normalize the input.
        dropout (float): dropout probability.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5,
                 max_period: float = 10000, past_context: int = 1000, gelu: bool = True,
                 norm_in: bool = True, dropout: float = 0., **kwargs):
        super().__init__()
        assert dim % num_heads == 0
        hidden_dim = int(dim * hidden_scale)

        self.max_period = max_period
        self.past_context = past_context
        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in = nn.LayerNorm(dim) if norm_in else nn.Identity()

        self.layers = nn.ModuleList(
            StreamingTransformerEncoderLayer(
                dim, num_heads, hidden_dim,
                activation=activation, batch_first=True, dropout=dropout, **kwargs)
            for _ in range(num_layers))

    def forward(self, x: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None,
                offset: tp.Union[int, torch.Tensor] = 0):
        B, T, C = x.shape
        if states is None:
            # Bootstrap with one zero placeholder per layer (plus one spare).
            states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]

        # Absolute positions for this chunk, shifted by the streaming offset.
        positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
        pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)

        new_state: tp.List[torch.Tensor] = []
        x = self.norm_in(x)
        x = x + pos_emb

        for layer_state, layer in zip(states, self.layers):
            x, sa_input = layer(x, layer_state, self.past_context)
            # Keep only the most recent `past_context` steps as the next state.
            merged = torch.cat([layer_state, sa_input], dim=1)
            new_state.append(merged[:, -self.past_context:, :])
        return x, new_state, offset + T
|
encoder/msstftd.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""MS-STFT discriminator, provided here for reference."""
|
8 |
+
|
9 |
+
import typing as tp
|
10 |
+
|
11 |
+
import torchaudio
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
from einops import rearrange
|
15 |
+
|
16 |
+
from .modules import NormConv2d
|
17 |
+
|
18 |
+
|
19 |
+
FeatureMapType = tp.List[torch.Tensor]
|
20 |
+
LogitsType = torch.Tensor
|
21 |
+
DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
|
22 |
+
|
23 |
+
|
24 |
+
def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
    """'Same'-style padding for a 2d conv: half the dilated kernel extent per axis."""
    kh, kw = kernel_size
    dh, dw = dilation
    return ((kh - 1) * dh // 2, (kw - 1) * dw // 2)
|
26 |
+
|
27 |
+
|
28 |
+
class DiscriminatorSTFT(nn.Module):
    """STFT sub-discriminator operating on a single spectrogram resolution.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels. Default: 1
        out_channels (int): Number of output channels. Default: 1
        n_fft (int): Size of FFT for each scale. Default: 1024
        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
        win_length (int): Window size for each scale. Default: 1024
        max_filters (int): Cap on the channel count of every inner conv. Default: 1024
        filters_scale (int): Growth factor for the filters. Default: 1
        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
        dilations (list of int): Inner Conv2d dilation on the time dimension.
            Default: ``[1, 2, 4]`` (resolved per-call; do not share a mutable default).
        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
        normalized (bool): Whether to normalize by magnitude after stft. Default: True
        norm (str): Normalization method. Default: `'weight_norm'`
        activation (str): Activation function (looked up on ``torch.nn``). Default: `'LeakyReLU'`
        activation_params (dict): Parameters to provide to the activation function.
            Default: ``{'negative_slope': 0.2}`` (resolved per-call).
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9),
                 dilations: tp.Optional[tp.List[int]] = None,
                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
                 activation: str = 'LeakyReLU', activation_params: tp.Optional[dict] = None):
        super().__init__()
        # Resolve mutable defaults here instead of in the signature, so no
        # list/dict instance is shared across constructions.
        if dilations is None:
            dilations = [1, 2, 4]
        if activation_params is None:
            activation_params = {'negative_slope': 0.2}
        assert len(kernel_size) == 2
        assert len(stride) == 2
        self.filters = filters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.normalized = normalized
        self.activation = getattr(torch.nn, activation)(**activation_params)
        # Complex (power=None) spectrogram; real/imag are concatenated on the
        # channel axis in forward(), hence 2 * in_channels conv input channels.
        self.spec_transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
            normalized=self.normalized, center=False, pad_mode=None, power=None)
        spec_channels = 2 * self.in_channels
        self.convs = nn.ModuleList()
        # NOTE(review): the first conv does not forward `norm=norm` and thus
        # always uses NormConv2d's default; kept as-is to preserve checkpoint
        # compatibility — confirm whether this asymmetry is intentional.
        self.convs.append(
            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
        )
        in_chs = min(filters_scale * self.filters, max_filters)
        for i, dilation in enumerate(dilations):
            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
                                         norm=norm))
            in_chs = out_chs
        # Final square-kernel conv before the logits projection.
        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
                                     norm=norm))
        self.conv_post = NormConv2d(out_chs, self.out_channels,
                                    kernel_size=(kernel_size[0], kernel_size[0]),
                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
                                    norm=norm)

    def forward(self, x: torch.Tensor):
        """Return ``(logits, fmap)``: final conv output and per-layer feature maps.

        Args:
            x: waveform tensor; assumed shape (batch, in_channels, time) — TODO confirm.
        """
        fmap = []
        z = self.spec_transform(x)  # complex spectrogram [B, C, Freq, Frames]
        z = torch.cat([z.real, z.imag], dim=1)  # -> [B, 2*C, Freq, Frames]
        z = rearrange(z, 'b c w t -> b c t w')  # convs see (time, freq) spatial dims
        for i, layer in enumerate(self.convs):
            z = layer(z)
            z = self.activation(z)
            fmap.append(z)  # intermediate activations for feature-matching losses
        z = self.conv_post(z)
        return z, fmap
|
97 |
+
|
98 |
+
|
99 |
+
class MultiScaleSTFTDiscriminator(nn.Module):
    """Multi-Scale STFT (MS-STFT) discriminator.

    Wraps one :class:`DiscriminatorSTFT` per STFT resolution and runs the
    input through all of them.

    Args:
        filters (int): Number of filters in convolutions.
        in_channels (int): Number of input channels. Default: 1
        out_channels (int): Number of output channels. Default: 1
        n_ffts (Sequence[int]): Size of FFT for each scale.
            Default: ``[1024, 2048, 512]`` (resolved per-call; do not share a
            mutable default across instances).
        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
            Default: ``[256, 512, 128]`` (resolved per-call).
        win_lengths (Sequence[int]): Window size for each scale.
            Default: ``[1024, 2048, 512]`` (resolved per-call).
        **kwargs: additional args forwarded to each :class:`DiscriminatorSTFT`.
    """
    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
                 n_ffts: tp.Optional[tp.List[int]] = None,
                 hop_lengths: tp.Optional[tp.List[int]] = None,
                 win_lengths: tp.Optional[tp.List[int]] = None, **kwargs):
        super().__init__()
        # Resolve mutable defaults here instead of in the signature.
        if n_ffts is None:
            n_ffts = [1024, 2048, 512]
        if hop_lengths is None:
            hop_lengths = [256, 512, 128]
        if win_lengths is None:
            win_lengths = [1024, 2048, 512]
        # One sub-discriminator per scale; the three lists must line up.
        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
        self.discriminators = nn.ModuleList([
            DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
                              n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
            for i in range(len(n_ffts))
        ])
        self.num_discriminators = len(self.discriminators)

    def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
        """Run every sub-discriminator on ``x``.

        Returns:
            (logits, fmaps): a list of per-scale logits tensors and a list of
            per-scale feature-map lists, both of length ``num_discriminators``.
        """
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps
|
130 |
+
|
131 |
+
|
132 |
+
def test():
    """Smoke-test the MS-STFT discriminator on random mono audio."""
    disc = MultiScaleSTFTDiscriminator(filters=32)
    real = torch.randn(1, 1, 24000)
    fake = torch.randn(1, 1, 24000)

    logits_real, fmaps_real = disc(real)
    logits_fake, fmaps_fake = disc(fake)
    expected = disc.num_discriminators
    assert len(logits_real) == len(logits_fake) == len(fmaps_real) == len(fmaps_fake) == expected

    # Each scale yields 5 feature maps, all with batch 1 and 32 channels,
    # and a 4-dimensional logits tensor.
    for fmap in fmaps_real + fmaps_fake:
        assert len(fmap) == 5
        for feat in fmap:
            assert list(feat.shape)[:2] == [1, 32]
    for logits in logits_real + logits_fake:
        assert len(logits.shape) == 4
|
144 |
+
|
145 |
+
|
146 |
+
if __name__ == '__main__':
    # Run the quick self-test when this module is executed directly.
    test()
|
encoder/quantization/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# flake8: noqa
|
8 |
+
from .vq import QuantizedResult, ResidualVectorQuantizer
|
encoder/quantization/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (243 Bytes). View file
|
|
encoder/quantization/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (253 Bytes). View file
|
|
encoder/quantization/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (241 Bytes). View file
|
|
encoder/quantization/__pycache__/core_vq.cpython-310.pyc
ADDED
Binary file (12.3 kB). View file
|
|
encoder/quantization/__pycache__/core_vq.cpython-38.pyc
ADDED
Binary file (13.4 kB). View file
|
|
encoder/quantization/__pycache__/core_vq.cpython-39.pyc
ADDED
Binary file (12.6 kB). View file
|
|
encoder/quantization/__pycache__/vq.cpython-310.pyc
ADDED
Binary file (5.14 kB). View file
|
|
encoder/quantization/__pycache__/vq.cpython-38.pyc
ADDED
Binary file (4.8 kB). View file
|
|
encoder/quantization/__pycache__/vq.cpython-39.pyc
ADDED
Binary file (5.12 kB). View file
|
|