breadlicker45 committed
Commit
c34b897
1 Parent(s): e09f850

Upload 54 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. encoder/__init__.py +12 -0
  2. encoder/__pycache__/__init__.cpython-310.pyc +0 -0
  3. encoder/__pycache__/__init__.cpython-38.pyc +0 -0
  4. encoder/__pycache__/__init__.cpython-39.pyc +0 -0
  5. encoder/__pycache__/distrib.cpython-310.pyc +0 -0
  6. encoder/__pycache__/distrib.cpython-38.pyc +0 -0
  7. encoder/__pycache__/distrib.cpython-39.pyc +0 -0
  8. encoder/__pycache__/model.cpython-310.pyc +0 -0
  9. encoder/__pycache__/model.cpython-38.pyc +0 -0
  10. encoder/__pycache__/model.cpython-39.pyc +0 -0
  11. encoder/__pycache__/utils.cpython-310.pyc +0 -0
  12. encoder/__pycache__/utils.cpython-38.pyc +0 -0
  13. encoder/__pycache__/utils.cpython-39.pyc +0 -0
  14. encoder/distrib.py +124 -0
  15. encoder/model.py +324 -0
  16. encoder/modules/__init__.py +22 -0
  17. encoder/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  18. encoder/modules/__pycache__/__init__.cpython-38.pyc +0 -0
  19. encoder/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  20. encoder/modules/__pycache__/conv.cpython-310.pyc +0 -0
  21. encoder/modules/__pycache__/conv.cpython-38.pyc +0 -0
  22. encoder/modules/__pycache__/conv.cpython-39.pyc +0 -0
  23. encoder/modules/__pycache__/lstm.cpython-310.pyc +0 -0
  24. encoder/modules/__pycache__/lstm.cpython-38.pyc +0 -0
  25. encoder/modules/__pycache__/lstm.cpython-39.pyc +0 -0
  26. encoder/modules/__pycache__/norm.cpython-310.pyc +0 -0
  27. encoder/modules/__pycache__/norm.cpython-38.pyc +0 -0
  28. encoder/modules/__pycache__/norm.cpython-39.pyc +0 -0
  29. encoder/modules/__pycache__/seanet.cpython-310.pyc +0 -0
  30. encoder/modules/__pycache__/seanet.cpython-38.pyc +0 -0
  31. encoder/modules/__pycache__/seanet.cpython-39.pyc +0 -0
  32. encoder/modules/__pycache__/transformer.cpython-310.pyc +0 -0
  33. encoder/modules/__pycache__/transformer.cpython-38.pyc +0 -0
  34. encoder/modules/__pycache__/transformer.cpython-39.pyc +0 -0
  35. encoder/modules/conv.py +253 -0
  36. encoder/modules/lstm.py +39 -0
  37. encoder/modules/norm.py +28 -0
  38. encoder/modules/seanet.py +253 -0
  39. encoder/modules/transformer.py +119 -0
  40. encoder/msstftd.py +147 -0
  41. encoder/quantization/__init__.py +8 -0
  42. encoder/quantization/__pycache__/__init__.cpython-310.pyc +0 -0
  43. encoder/quantization/__pycache__/__init__.cpython-38.pyc +0 -0
  44. encoder/quantization/__pycache__/__init__.cpython-39.pyc +0 -0
  45. encoder/quantization/__pycache__/core_vq.cpython-310.pyc +0 -0
  46. encoder/quantization/__pycache__/core_vq.cpython-38.pyc +0 -0
  47. encoder/quantization/__pycache__/core_vq.cpython-39.pyc +0 -0
  48. encoder/quantization/__pycache__/vq.cpython-310.pyc +0 -0
  49. encoder/quantization/__pycache__/vq.cpython-38.pyc +0 -0
  50. encoder/quantization/__pycache__/vq.cpython-39.pyc +0 -0
encoder/__init__.py ADDED
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# flake8: noqa
+
+"""EnCodec neural audio codec."""
+
+__version__ = "0.1.2a3"
+
+from .model import EncodecModel
encoder/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (268 Bytes).
encoder/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (278 Bytes).
encoder/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (266 Bytes).
encoder/__pycache__/distrib.cpython-310.pyc ADDED
Binary file (3.74 kB).
encoder/__pycache__/distrib.cpython-38.pyc ADDED
Binary file (3.79 kB).
encoder/__pycache__/distrib.cpython-39.pyc ADDED
Binary file (3.76 kB).
encoder/__pycache__/model.cpython-310.pyc ADDED
Binary file (11.7 kB).
encoder/__pycache__/model.cpython-38.pyc ADDED
Binary file (11.7 kB).
encoder/__pycache__/model.cpython-39.pyc ADDED
Binary file (11.6 kB).
encoder/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.68 kB).
encoder/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.65 kB).
encoder/__pycache__/utils.cpython-39.pyc ADDED
Binary file (2.66 kB).
encoder/distrib.py ADDED
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch distributed utilities."""
+
+import typing as tp
+
+import torch
+
+
+def rank():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    else:
+        return 0
+
+
+def world_size():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+    else:
+        return 1
+
+
+def is_distributed():
+    return world_size() > 1
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+    if is_distributed():
+        return torch.distributed.all_reduce(tensor, op)
+
+
+def _is_complex_or_float(tensor):
+    return torch.is_floating_point(tensor) or torch.is_complex(tensor)
+
+
+def _check_number_of_params(params: tp.List[torch.Tensor]):
+    # utility function to check that the number of params in all workers is the same,
+    # and thus avoid a deadlock with distributed all reduce.
+    if not is_distributed() or not params:
+        return
+    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+    all_reduce(tensor)
+    if tensor.item() != len(params) * world_size():
+        # If not all the workers have the same number, for at least one of them,
+        # this inequality will be verified.
+        raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, "
+                           "at least one worker has a different one.")
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+    """Broadcast the tensors from the given parameters to all workers.
+    This can be used to ensure that all workers have the same model to start with.
+    """
+    if not is_distributed():
+        return
+    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+    _check_number_of_params(tensors)
+    handles = []
+    for tensor in tensors:
+        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+        handles.append(handle)
+    for handle in handles:
+        handle.wait()
+
+
+def sync_buffer(buffers, average=True):
+    """
+    Sync buffers across workers. If average is False, broadcast instead of averaging.
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for buffer in buffers:
+        if torch.is_floating_point(buffer.data):
+            if average:
+                handle = torch.distributed.all_reduce(
+                    buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+            else:
+                handle = torch.distributed.broadcast(
+                    buffer.data, src=0, async_op=True)
+            handles.append((buffer, handle))
+    for buffer, handle in handles:
+        handle.wait()
+        if average:
+            buffer.data /= world_size()
+
+
+def sync_grad(params):
+    """
+    Simpler alternative to DistributedDataParallel, that doesn't rely
+    on any black magic. For simple models it can also be as fast.
+    Just call this on your model parameters after the call to backward!
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for p in params:
+        if p.grad is not None:
+            handle = torch.distributed.all_reduce(
+                p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+            handles.append((p, handle))
+    for p, handle in handles:
+        handle.wait()
+        p.grad.data /= world_size()
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.):
+    """Average a dictionary of metrics across all workers, using the optional
+    `count` as unnormalized weight.
+    """
+    if not is_distributed():
+        return metrics
+    keys, values = zip(*metrics.items())
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+    tensor *= count
+    all_reduce(tensor)
+    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+    return dict(zip(keys, averaged))
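
For reference, a minimal sketch (not part of the commit) of how these helpers are typically wired into a training loop. The tiny model, optimizer, and random data below are placeholders, and torch.distributed is assumed to have been initialized already (e.g. via torchrun); otherwise every helper silently no-ops.

import torch
from encoder import distrib

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

distrib.broadcast_tensors(model.parameters())      # start all workers from rank 0's weights
for _ in range(10):
    x = torch.randn(4, 8)                          # placeholder batch
    loss = model(x).pow(2).mean()                  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    distrib.sync_grad(model.parameters())          # all-reduce, then average the gradients
    optimizer.step()
metrics = distrib.average_metrics({'loss': loss.item()})  # average logged metrics across workers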
encoder/model.py ADDED
@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""EnCodec model implementation."""
+
+import math
+from pathlib import Path
+import typing as tp
+
+import numpy as np
+import torch
+from torch import nn
+
+from . import quantization as qt
+from . import modules as m
+from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url
+
+
+ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'
+
+EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
+
+
+class LMModel(nn.Module):
+    """Language Model to estimate probabilities of each codebook entry.
+    We predict all codebooks in parallel for a given time step.
+
+    Args:
+        n_q (int): number of codebooks.
+        card (int): codebook cardinality.
+        dim (int): transformer dimension.
+        **kwargs: passed to `encoder.modules.transformer.StreamingTransformerEncoder`.
+    """
+    def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
+        super().__init__()
+        self.card = card
+        self.n_q = n_q
+        self.dim = dim
+        self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
+        self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
+        self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])
+
+    def forward(self, indices: torch.Tensor,
+                states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
+        """
+        Args:
+            indices (torch.Tensor): indices from the previous time step. Indices
+                should be 1 + actual index in the codebook. The value 0 is reserved for
+                when the index is missing (i.e. first time step). Shape should be
+                `[B, n_q, T]`.
+            states: state for the streaming decoding.
+            offset: offset of the current time step.
+
+        Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities
+        with a shape `[B, card, n_q, T]`.
+        """
+        B, K, T = indices.shape
+        input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
+        out, states, offset = self.transformer(input_, states, offset)
+        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
+        return torch.softmax(logits, dim=1), states, offset
+
+
+class EncodecModel(nn.Module):
+    """EnCodec model operating on the raw waveform.
+    Args:
+        target_bandwidths (list of float): Target bandwidths.
+        encoder (nn.Module): Encoder network.
+        decoder (nn.Module): Decoder network.
+        sample_rate (int): Audio sample rate.
+        channels (int): Number of audio channels.
+        normalize (bool): Whether to apply audio normalization.
+        segment (float or None): segment duration in sec. when doing overlap-add.
+        overlap (float): overlap between segments, given as a fraction of the segment duration.
+        name (str): name of the model, used as metadata when compressing audio.
+    """
+    def __init__(self,
+                 encoder: m.SEANetEncoder,
+                 decoder: m.SEANetDecoder,
+                 quantizer: qt.ResidualVectorQuantizer,
+                 target_bandwidths: tp.List[float],
+                 sample_rate: int,
+                 channels: int,
+                 normalize: bool = False,
+                 segment: tp.Optional[float] = None,
+                 overlap: float = 0.01,
+                 name: str = 'unset'):
+        super().__init__()
+        self.bandwidth: tp.Optional[float] = None
+        self.target_bandwidths = target_bandwidths
+        self.encoder = encoder
+        self.quantizer = quantizer
+        self.decoder = decoder
+        self.sample_rate = sample_rate
+        self.channels = channels
+        self.normalize = normalize
+        self.segment = segment
+        self.overlap = overlap
+        self.frame_rate = math.ceil(self.sample_rate / np.prod(self.encoder.ratios))
+        self.name = name
+        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
+        assert 2 ** self.bits_per_codebook == self.quantizer.bins, \
+            "quantizer bins must be a power of 2."
+
+    @property
+    def segment_length(self) -> tp.Optional[int]:
+        if self.segment is None:
+            return None
+        return int(self.segment * self.sample_rate)
+
+    @property
+    def segment_stride(self) -> tp.Optional[int]:
+        segment_length = self.segment_length
+        if segment_length is None:
+            return None
+        return max(1, int((1 - self.overlap) * segment_length))
+
+    def encode(self, x: torch.Tensor) -> tp.List[EncodedFrame]:
+        """Given a tensor `x`, returns a list of frames containing
+        the discrete encoded codes for `x`, along with rescaling factors
+        for each segment, when `self.normalize` is True.
+
+        Each frame is a tuple `(codebook, scale)`, with `codebook` of
+        shape `[B, K, T]`, with `K` the number of codebooks.
+        """
+        assert x.dim() == 3
+        _, channels, length = x.shape
+        assert channels > 0 and channels <= 2
+        segment_length = self.segment_length
+        if segment_length is None:
+            segment_length = length
+            stride = length
+        else:
+            stride = self.segment_stride  # type: ignore
+            assert stride is not None
+
+        encoded_frames: tp.List[EncodedFrame] = []
+        for offset in range(0, length, stride):
+            frame = x[:, :, offset: offset + segment_length]
+            encoded_frames.append(self._encode_frame(frame))
+        return encoded_frames
+
+    def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
+        length = x.shape[-1]
+        duration = length / self.sample_rate
+        assert self.segment is None or duration <= 1e-5 + self.segment
+
+        if self.normalize:
+            mono = x.mean(dim=1, keepdim=True)
+            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+            scale = 1e-8 + volume
+            x = x / scale
+            scale = scale.view(-1, 1)
+        else:
+            scale = None
+
+        emb = self.encoder(x)
+        codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
+        codes = codes.transpose(0, 1)
+        # codes is [B, K, T], with T frames, K nb of codebooks.
+        return codes, scale
+
+    def decode(self, encoded_frames: tp.List[EncodedFrame]) -> torch.Tensor:
+        """Decode the given frames into a waveform.
+        Note that the output might be a bit bigger than the input. In that case,
+        any extra steps at the end can be trimmed.
+        """
+        segment_length = self.segment_length
+        if segment_length is None:
+            assert len(encoded_frames) == 1
+            return self._decode_frame(encoded_frames[0])
+
+        frames = [self._decode_frame(frame) for frame in encoded_frames]
+        return _linear_overlap_add(frames, self.segment_stride or 1)
+
+    def _decode_frame(self, encoded_frame: EncodedFrame) -> torch.Tensor:
+        codes, scale = encoded_frame
+        codes = codes.transpose(0, 1)
+        emb = self.quantizer.decode(codes)
+        out = self.decoder(emb)
+        if scale is not None:
+            out = out * scale.view(-1, 1, 1)
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        frames = self.encode(x)
+        return self.decode(frames)[:, :, :x.shape[-1]]
+
+    def set_target_bandwidth(self, bandwidth: float):
+        if bandwidth not in self.target_bandwidths:
+            raise ValueError(f"This model doesn't support the bandwidth {bandwidth}. "
+                             f"Select one of {self.target_bandwidths}.")
+        self.bandwidth = bandwidth
+
+    def get_lm_model(self) -> LMModel:
+        """Return the associated LM model to improve the compression rate.
+        """
+        device = next(self.parameters()).device
+        lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
+                     past_context=int(3.5 * self.frame_rate)).to(device)
+        checkpoints = {
+            'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
+            'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
+        }
+        try:
+            checkpoint_name = checkpoints[self.name]
+        except KeyError:
+            raise RuntimeError("No LM pre-trained for the current Encodec model.")
+        url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
+        state = torch.hub.load_state_dict_from_url(
+            url, map_location='cpu', check_hash=True)  # type: ignore
+        lm.load_state_dict(state)
+        lm.eval()
+        return lm
+
+    @staticmethod
+    def _get_model(target_bandwidths: tp.List[float],
+                   sample_rate: int = 24_000,
+                   channels: int = 1,
+                   causal: bool = True,
+                   model_norm: str = 'weight_norm',
+                   audio_normalize: bool = False,
+                   segment: tp.Optional[float] = None,
+                   name: str = 'unset'):
+        encoder = m.SEANetEncoder(channels=channels, norm=model_norm, causal=causal)
+        decoder = m.SEANetDecoder(channels=channels, norm=model_norm, causal=causal)
+        n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / encoder.hop_length) * 10))
+        quantizer = qt.ResidualVectorQuantizer(
+            dimension=encoder.dimension,
+            n_q=n_q,
+            bins=1024,
+        )
+        model = EncodecModel(
+            encoder,
+            decoder,
+            quantizer,
+            target_bandwidths,
+            sample_rate,
+            channels,
+            normalize=audio_normalize,
+            segment=segment,
+            name=name,
+        )
+        return model
+
+    @staticmethod
+    def _get_pretrained(checkpoint_name: str, repository: tp.Optional[Path] = None):
+        if repository is not None:
+            if not repository.is_dir():
+                raise ValueError(f"{repository} must exist and be a directory.")
+            file = repository / checkpoint_name
+            checksum = file.stem.split('-')[1]
+            _check_checksum(file, checksum)
+            return torch.load(file)
+        else:
+            url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
+            return torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)  # type: ignore
+
+    @staticmethod
+    def encodec_model_24khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
+        """Return the pretrained causal 24khz model.
+        """
+        if repository:
+            assert pretrained
+        target_bandwidths = [1.5, 3., 6, 12., 24.]
+        checkpoint_name = 'encodec_24khz-d7cc33bc.th'
+        sample_rate = 24_000
+        channels = 1
+        model = EncodecModel._get_model(
+            target_bandwidths, sample_rate, channels,
+            causal=True, model_norm='weight_norm', audio_normalize=False,
+            name='encodec_24khz' if pretrained else 'unset')
+        if pretrained:
+            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
+            model.load_state_dict(state_dict)
+        model.eval()
+        return model
+
+    @staticmethod
+    def encodec_model_48khz(pretrained: bool = True, repository: tp.Optional[Path] = None):
+        """Return the pretrained 48khz model.
+        """
+        if repository:
+            assert pretrained
+        target_bandwidths = [3., 6., 12., 24.]
+        checkpoint_name = 'encodec_48khz-7e698e3e.th'
+        sample_rate = 48_000
+        channels = 2
+        model = EncodecModel._get_model(
+            target_bandwidths, sample_rate, channels,
+            causal=False, model_norm='time_group_norm', audio_normalize=True,
+            segment=1., name='encodec_48khz' if pretrained else 'unset')
+        if pretrained:
+            state_dict = EncodecModel._get_pretrained(checkpoint_name, repository)
+            model.load_state_dict(state_dict)
+        model.eval()
+        return model
+
+
+def test():
+    from itertools import product
+    import torchaudio
+    bandwidths = [3, 6, 12, 24]
+    models = {
+        'encodec_24khz': EncodecModel.encodec_model_24khz,
+        'encodec_48khz': EncodecModel.encodec_model_48khz
+    }
+    for model_name, bw in product(models.keys(), bandwidths):
+        model = models[model_name]()
+        model.set_target_bandwidth(bw)
+        audio_suffix = model_name.split('_')[1][:3]
+        wav, sr = torchaudio.load(f"test_{audio_suffix}.wav")
+        wav = wav[:, :model.sample_rate * 2]
+        wav_in = wav.unsqueeze(0)
+        wav_dec = model(wav_in)[0]
+        assert wav.shape == wav_dec.shape, (wav.shape, wav_dec.shape)
+
+
+if __name__ == '__main__':
+    test()
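
For reference, a minimal round-trip sketch (not part of the commit) mirroring the `test()` function above. The input file path is a placeholder, and the pretrained weights are fetched from `ROOT_URL` on first use.

import torchaudio
from encoder import EncodecModel

model = EncodecModel.encodec_model_24khz()        # downloads encodec_24khz-d7cc33bc.th on first call
model.set_target_bandwidth(6.0)                   # kbps; must be one of model.target_bandwidths

wav, sr = torchaudio.load("some_24khz_mono.wav")  # placeholder path; expected mono at 24 kHz
frames = model.encode(wav.unsqueeze(0))           # list of (codes [B, K, T], scale) tuples
audio = model.decode(frames)[:, :, :wav.shape[-1]]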
encoder/modules/__init__.py ADDED
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Torch modules."""
+
+# flake8: noqa
+from .conv import (
+    pad1d,
+    unpad1d,
+    NormConv1d,
+    NormConvTranspose1d,
+    NormConv2d,
+    NormConvTranspose2d,
+    SConv1d,
+    SConvTranspose1d,
+)
+from .lstm import SLSTM
+from .seanet import SEANetEncoder, SEANetDecoder
+from .transformer import StreamingTransformerEncoder
encoder/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (557 Bytes).
encoder/modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (567 Bytes).
encoder/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (555 Bytes).
encoder/modules/__pycache__/conv.cpython-310.pyc ADDED
Binary file (9.2 kB).
encoder/modules/__pycache__/conv.cpython-38.pyc ADDED
Binary file (9.48 kB).
encoder/modules/__pycache__/conv.cpython-39.pyc ADDED
Binary file (9.43 kB).
encoder/modules/__pycache__/lstm.cpython-310.pyc ADDED
Binary file (1.05 kB).
encoder/modules/__pycache__/lstm.cpython-38.pyc ADDED
Binary file (1.06 kB).
encoder/modules/__pycache__/lstm.cpython-39.pyc ADDED
Binary file (1.05 kB).
encoder/modules/__pycache__/norm.cpython-310.pyc ADDED
Binary file (1.15 kB).
encoder/modules/__pycache__/norm.cpython-38.pyc ADDED
Binary file (1.15 kB).
encoder/modules/__pycache__/norm.cpython-39.pyc ADDED
Binary file (1.14 kB).
encoder/modules/__pycache__/seanet.cpython-310.pyc ADDED
Binary file (9.72 kB).
encoder/modules/__pycache__/seanet.cpython-38.pyc ADDED
Binary file (9.65 kB).
encoder/modules/__pycache__/seanet.cpython-39.pyc ADDED
Binary file (9.48 kB).
encoder/modules/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (4.54 kB).
encoder/modules/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (4.53 kB).
encoder/modules/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (4.47 kB).
encoder/modules/conv.py ADDED
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Convolutional layers wrappers and utilities."""
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+from .norm import ConvLayerNorm
+
+
+CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
+                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'weight_norm':
+        return weight_norm(module)
+    elif norm == 'spectral_norm':
+        return spectral_norm(module)
+    else:
+        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other choice
+        # doesn't need reparametrization.
+        return module
+
+
+def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or return an error if the normalization doesn't support causal evaluation.
+    """
+    assert norm in CONV_NORMALIZATIONS
+    if norm == 'layer_norm':
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return ConvLayerNorm(module.out_channels, **norm_kwargs)
+    elif norm == 'time_group_norm':
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+                                 padding_total: int = 0) -> int:
+    """See `pad_for_conv1d`.
+    """
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+        1   2   3           # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+            1 2 3 4         # once you removed padding, we are missing one time step !
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))
+
+
+def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right before the reflection happens.
+    """
+    length = x.shape[-1]
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    if mode == 'reflect':
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            x = F.pad(x, (0, extra_pad))
+        padded = F.pad(x, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+    else:
+        return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+    """Remove padding from x, handling properly zero padding. Only for 1d!"""
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    assert (padding_left + padding_right) <= x.shape[-1]
+    end = x.shape[-1] - padding_right
+    return x[..., padding_left: end]
+
+
+class NormConv1d(nn.Module):
+    """Wrapper around Conv1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConv2d(nn.Module):
+    """Wrapper around Conv2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose1d(nn.Module):
+    """Wrapper around ConvTranspose1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, causal: bool = False, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose2d(nn.Module):
+    """Wrapper around ConvTranspose2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+    def __init__(self, *args, norm: str = 'none',
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+    def forward(self, x):
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class SConv1d(nn.Module):
+    """Conv1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, dilation: int = 1,
+                 groups: int = 1, bias: bool = True, causal: bool = False,
+                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+                 pad_mode: str = 'reflect'):
+        super().__init__()
+        # warn user on unusual setup between dilation and stride
+        if stride > 1 and dilation > 1:
+            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
+                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
+        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+                               dilation=dilation, groups=groups, bias=bias, causal=causal,
+                               norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.pad_mode = pad_mode
+
+    def forward(self, x):
+        B, C, T = x.shape
+        kernel_size = self.conv.conv.kernel_size[0]
+        stride = self.conv.conv.stride[0]
+        dilation = self.conv.conv.dilation[0]
+        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
+        padding_total = kernel_size - stride
+        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+        if self.causal:
+            # Left padding for causal
+            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+        return self.conv(x)
+
+
+class SConvTranspose1d(nn.Module):
+    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+    def __init__(self, in_channels: int, out_channels: int,
+                 kernel_size: int, stride: int = 1, causal: bool = False,
+                 norm: str = 'none', trim_right_ratio: float = 1.,
+                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
+        super().__init__()
+        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+        self.causal = causal
+        self.trim_right_ratio = trim_right_ratio
+        assert self.causal or self.trim_right_ratio == 1., \
+            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+    def forward(self, x):
+        kernel_size = self.convtr.convtr.kernel_size[0]
+        stride = self.convtr.convtr.stride[0]
+        padding_total = kernel_size - stride
+
+        y = self.convtr(x)
+
+        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+        # removed at the very end, when keeping only the right length for the output,
+        # as removing it here would require also passing the length at the matching layer
+        # in the encoder.
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
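
For reference, a quick shape check (not part of the commit) of the padding logic above: with these arbitrary kernel and stride settings, `SConv1d` produces ceil(T / stride) frames and `SConvTranspose1d` maps them back to the original length.

import torch
from encoder.modules import SConv1d, SConvTranspose1d

x = torch.randn(1, 4, 1000)
down = SConv1d(4, 8, kernel_size=8, stride=4, causal=True)
up = SConvTranspose1d(8, 4, kernel_size=8, stride=4, causal=True)
y = down(x)
print(y.shape)      # torch.Size([1, 8, 250]) -> ceil(1000 / 4) frames
print(up(y).shape)  # torch.Size([1, 4, 1000]) -> back to the input length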
encoder/modules/lstm.py ADDED
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""LSTM layers module."""
+
+from torch import nn
+
+
+class SLSTM(nn.Module):
+    """
+    LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+        super().__init__()
+        self.skip = skip
+        self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+    # def forward(self, x):
+    #     x = x.permute(2, 0, 1)
+    #     y, _ = self.lstm(x)
+    #     if self.skip:
+    #         y = y + x
+    #     y = y.permute(1, 2, 0)
+    #     return y
+
+    # Changed the transpose order (the residual is now added in [B, C, T] layout)
+    def forward(self, x):
+        # # Insert a reshape
+        # x = x.reshape(x.shape)
+        x1 = x.permute(2, 0, 1)
+        y, _ = self.lstm(x1)
+        y = y.permute(1, 2, 0)
+        if self.skip:
+            y = y + x
+        return y
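
For reference, a tiny sketch (not part of the commit) showing that `SLSTM` keeps the convolutional `[B, C, T]` layout end to end; the dimension and length below are arbitrary.

import torch
from encoder.modules import SLSTM

layer = SLSTM(dimension=32, num_layers=2)
x = torch.randn(1, 32, 50)   # [B, C, T], convolutional layout
y = layer(x)
print(y.shape)               # torch.Size([1, 32, 50])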
encoder/modules/norm.py ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Normalization modules."""
+
+import typing as tp
+
+import einops
+import torch
+from torch import nn
+
+
+class ConvLayerNorm(nn.LayerNorm):
+    """
+    Convolution-friendly LayerNorm that moves channels to last dimensions
+    before running the normalization and moves them back to original position right after.
+    """
+    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
+        super().__init__(normalized_shape, **kwargs)
+
+    def forward(self, x):
+        x = einops.rearrange(x, 'b ... t -> b t ...')
+        x = super().forward(x)
+        x = einops.rearrange(x, 'b t ... -> b ... t')
+        return x
encoder/modules/seanet.py ADDED
@@ -0,0 +1,253 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Encodec SEANet-based encoder and decoder implementation."""
+
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+
+from . import (
+    SConv1d,
+    SConvTranspose1d,
+    SLSTM
+)
+
+
+class SEANetResnetBlock(nn.Module):
+    """Residual block from SEANet model.
+    Args:
+        dim (int): Dimension of the input/output.
+        kernel_sizes (list): List of kernel sizes for the convolutions.
+        dilations (list): List of dilations for the convolutions.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection.
+    """
+    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
+                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
+                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
+        super().__init__()
+        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
+        act = getattr(nn, activation)
+        hidden = dim // compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [
+                act(**activation_params),
+                SConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
+                        norm=norm, norm_kwargs=norm_params,
+                        causal=causal, pad_mode=pad_mode),
+            ]
+        self.block = nn.Sequential(*block)
+        self.shortcut: nn.Module
+        if true_skip:
+            self.shortcut = nn.Identity()
+        else:
+            self.shortcut = SConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
+                                    causal=causal, pad_mode=pad_mode)
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+
+class SEANetEncoder(nn.Module):
+    """SEANet encoder.
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
+            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
+            that must match the decoder order.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the final convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2):
+        super().__init__()
+        self.channels = channels
+        self.dimension = dimension
+        self.n_filters = n_filters
+        self.ratios = list(reversed(ratios))
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+
+        act = getattr(nn, activation)
+        mult = 1
+        model: tp.List[nn.Module] = [
+            SConv1d(channels, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
+                    causal=causal, pad_mode=pad_mode)
+        ]
+        # Downsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      norm=norm, norm_params=norm_params,
+                                      activation=activation, activation_params=activation_params,
+                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            # Add downsampling layers
+            model += [
+                act(**activation_params),
+                SConv1d(mult * n_filters, mult * n_filters * 2,
+                        kernel_size=ratio * 2, stride=ratio,
+                        norm=norm, norm_kwargs=norm_params,
+                        causal=causal, pad_mode=pad_mode),
+            ]
+            mult *= 2
+
+        if lstm:
+            model += [SLSTM(mult * n_filters, num_layers=lstm)]
+
+        model += [
+            act(**activation_params),
+            SConv1d(mult * n_filters, dimension, last_kernel_size, norm=norm, norm_kwargs=norm_params,
+                    causal=causal, pad_mode=pad_mode)
+        ]
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+    """SEANet decoder.
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): nb of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        final_activation (str): Final activation function after all convolutions.
+        final_activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the final convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        lstm (int): Number of LSTM layers at the end of the encoder.
+        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+            If equal to 1.0, it means that all the trimming is done at the right.
+    """
+    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 1,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
+                 norm: str = 'weight_norm', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = False, compress: int = 2, lstm: int = 2,
+                 trim_right_ratio: float = 1.0):
+        super().__init__()
+        self.dimension = dimension
+        self.channels = channels
+        self.n_filters = n_filters
+        self.ratios = ratios
+        del ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = np.prod(self.ratios)
+
+        act = getattr(nn, activation)
+        mult = int(2 ** len(self.ratios))
+        model: tp.List[nn.Module] = [
+            SConv1d(dimension, mult * n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
+                    causal=causal, pad_mode=pad_mode)
+        ]
+
+        if lstm:
+            model += [SLSTM(mult * n_filters, num_layers=lstm)]
+
+        # Upsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            # Add upsampling layers
+            model += [
+                act(**activation_params),
+                SConvTranspose1d(mult * n_filters, mult * n_filters // 2,
+                                 kernel_size=ratio * 2, stride=ratio,
+                                 norm=norm, norm_kwargs=norm_params,
+                                 causal=causal, trim_right_ratio=trim_right_ratio),
+            ]
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
+                                      dilations=[dilation_base ** j, 1],
+                                      activation=activation, activation_params=activation_params,
+                                      norm=norm, norm_params=norm_params, causal=causal,
+                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+            mult //= 2
+
+        # Add final layers
+        model += [
+            act(**activation_params),
+            SConv1d(n_filters, channels, last_kernel_size, norm=norm, norm_kwargs=norm_params,
+                    causal=causal, pad_mode=pad_mode)
+        ]
+        # Add optional final activation to decoder (eg. tanh)
+        if final_activation is not None:
+            final_act = getattr(nn, final_activation)
+            final_activation_params = final_activation_params or {}
+            model += [
+                final_act(**final_activation_params)
+            ]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, z):
+        y = self.model(z)
+        return y
+
+
+def test():
+    import torch
+    encoder = SEANetEncoder()
+    decoder = SEANetDecoder()
+    x = torch.randn(1, 1, 24000)
+    z = encoder(x)
+    assert list(z.shape) == [1, 128, 75], z.shape
+    y = decoder(z)
+    assert y.shape == x.shape, (x.shape, y.shape)
+
+
+if __name__ == '__main__':
+    test()
encoder/modules/transformer.py ADDED
@@ -0,0 +1,119 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""A streamable transformer."""
+
+import typing as tp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
+    """Create time embedding for the given positions, target dimension `dim`.
+    """
+    # We aim for BTC format
+    assert dim % 2 == 0
+    half_dim = dim // 2
+    adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
+    phase = positions / (max_period ** (adim / (half_dim - 1)))
+    return torch.cat([
+        torch.cos(phase),
+        torch.sin(phase),
+    ], dim=-1)
+
+
+class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
+    def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
+        if self.norm_first:
+            sa_input = self.norm1(x)
+            x = x + self._sa_block(sa_input, x_past, past_context)
+            x = x + self._ff_block(self.norm2(x))
+        else:
+            sa_input = x
+            x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
+            x = self.norm2(x + self._ff_block(x))
+
+        return x, sa_input
+
+    # self-attention block
+    def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):  # type: ignore
+        _, T, _ = x.shape
+        _, H, _ = x_past.shape
+
+        queries = x
+        keys = torch.cat([x_past, x], dim=1)
+        values = keys
+
+        queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
+        keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
+        delta = queries_pos - keys_pos
+        valid_access = (delta >= 0) & (delta <= past_context)
+        x = self.self_attn(queries, keys, values,
+                           attn_mask=~valid_access,
+                           need_weights=False)[0]
+        return self.dropout1(x)
+
+
+class StreamingTransformerEncoder(nn.Module):
+    """TransformerEncoder with streaming support.
+
+    Args:
+        dim (int): dimension of the data.
+        hidden_scale (int): intermediate dimension of FF module is this times the dimension.
+        num_heads (int): number of heads.
+        num_layers (int): number of layers.
+        max_period (float): maximum period of cosines in the positional embedding.
+        past_context (int or None): receptive field for the causal mask, infinite if None.
+        gelu (bool): if true uses GeLUs, otherwise use ReLUs.
+        norm_in (bool): normalize the input.
+        dropout (float): dropout probability.
+        **kwargs: See `nn.TransformerEncoderLayer`.
+    """
+    def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5,
+                 max_period: float = 10000, past_context: int = 1000, gelu: bool = True,
+                 norm_in: bool = True, dropout: float = 0., **kwargs):
+        super().__init__()
+        assert dim % num_heads == 0
+        hidden_dim = int(dim * hidden_scale)
+
+        self.max_period = max_period
+        self.past_context = past_context
+        activation: tp.Any = F.gelu if gelu else F.relu
+
+        self.norm_in: nn.Module
+        if norm_in:
+            self.norm_in = nn.LayerNorm(dim)
+        else:
+            self.norm_in = nn.Identity()
+
+        self.layers = nn.ModuleList()
+        for idx in range(num_layers):
+            self.layers.append(
+                StreamingTransformerEncoderLayer(
+                    dim, num_heads, hidden_dim,
+                    activation=activation, batch_first=True, dropout=dropout, **kwargs))
+
+    def forward(self, x: torch.Tensor,
+                states: tp.Optional[tp.List[torch.Tensor]] = None,
+                offset: tp.Union[int, torch.Tensor] = 0):
+        B, T, C = x.shape
+        if states is None:
+            states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]
+
+        positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
+        pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)
+
+        new_state: tp.List[torch.Tensor] = []
+        x = self.norm_in(x)
+        x = x + pos_emb
+
+        for layer_state, layer in zip(states, self.layers):
+            x, new_layer_state = layer(x, layer_state, self.past_context)
+            new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
+            new_state.append(new_layer_state[:, -self.past_context:, :])
+        return x, new_state, offset + T
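
For reference, a small sketch (not part of the commit) of the streaming interface: a long sequence is fed chunk by chunk while `states` and `offset` are carried between calls. The dimensions and chunk size below are arbitrary.

import torch
from encoder.modules import StreamingTransformerEncoder

enc = StreamingTransformerEncoder(dim=64, num_heads=4, num_layers=2, past_context=100).eval()
x = torch.randn(1, 300, 64)              # [B, T, C]
states, offset, outputs = None, 0, []
with torch.no_grad():
    for chunk in x.split(100, dim=1):    # process 100 frames at a time
        y, states, offset = enc(chunk, states, offset)
        outputs.append(y)
y_stream = torch.cat(outputs, dim=1)     # streamed output for the whole sequence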
encoder/msstftd.py ADDED
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""MS-STFT discriminator, provided here for reference."""
+
+import typing as tp
+
+import torchaudio
+import torch
+from torch import nn
+from einops import rearrange
+
+from .modules import NormConv2d
+
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
+    return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
+
+
+class DiscriminatorSTFT(nn.Module):
+    """STFT sub-discriminator.
+    Args:
+        filters (int): Number of filters in convolutions
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_fft (int): Size of FFT for each scale. Default: 1024
+        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
+        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
+        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
+        dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
+        win_length (int): Window size for each scale. Default: 1024
+        normalized (bool): Whether to normalize by magnitude after stft. Default: True
+        norm (str): Normalization method. Default: `'weight_norm'`
+        activation (str): Activation function. Default: `'LeakyReLU'`
+        activation_params (dict): Parameters to provide to the activation function.
+        growth (int): Growth factor for the filters. Default: 1
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+                 n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
+                 filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
+                 stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
+                 activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        self.filters = filters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
+            normalized=self.normalized, center=False, pad_mode=None, power=None)
+        spec_channels = 2 * self.in_channels
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
+        )
+        in_chs = min(filters_scale * self.filters, max_filters)
+        for i, dilation in enumerate(dilations):
+            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
+                                         dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
+                                         norm=norm))
+            in_chs = out_chs
+        out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
+        self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
+                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                     norm=norm))
+        self.conv_post = NormConv2d(out_chs, self.out_channels,
+                                    kernel_size=(kernel_size[0], kernel_size[0]),
+                                    padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                    norm=norm)
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
+        z = torch.cat([z.real, z.imag], dim=1)
+        z = rearrange(z, 'b c w t -> b c t w')
+        for i, layer in enumerate(self.convs):
+            z = layer(z)
+            z = self.activation(z)
+            fmap.append(z)
+        z = self.conv_post(z)
+        return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(nn.Module):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+    Args:
+        filters (int): Number of filters in convolutions
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_ffts (Sequence[int]): Size of FFT for each scale
+        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
+        win_lengths (Sequence[int]): Window size for each scale
+        **kwargs: additional args for STFTDiscriminator
+    """
+    def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+                 n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
+                 win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.discriminators = nn.ModuleList([
+            DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
+                              n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
+            for i in range(len(n_ffts))
+        ])
+        self.num_discriminators = len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+        return logits, fmaps
+
+
+def test():
+    disc = MultiScaleSTFTDiscriminator(filters=32)
+    y = torch.randn(1, 1, 24000)
+    y_hat = torch.randn(1, 1, 24000)
+
+    y_disc_r, fmap_r = disc(y)
+    y_disc_gen, fmap_gen = disc(y_hat)
+    assert len(y_disc_r) == len(y_disc_gen) == len(fmap_r) == len(fmap_gen) == disc.num_discriminators
+
+    assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
+    assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
+    assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])
+
+
+if __name__ == '__main__':
+    test()
encoder/quantization/__init__.py ADDED
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .vq import QuantizedResult, ResidualVectorQuantizer
encoder/quantization/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (243 Bytes).
encoder/quantization/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (253 Bytes).
encoder/quantization/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (241 Bytes).
encoder/quantization/__pycache__/core_vq.cpython-310.pyc ADDED
Binary file (12.3 kB).
encoder/quantization/__pycache__/core_vq.cpython-38.pyc ADDED
Binary file (13.4 kB).
encoder/quantization/__pycache__/core_vq.cpython-39.pyc ADDED
Binary file (12.6 kB).
encoder/quantization/__pycache__/vq.cpython-310.pyc ADDED
Binary file (5.14 kB).
encoder/quantization/__pycache__/vq.cpython-38.pyc ADDED
Binary file (4.8 kB).
encoder/quantization/__pycache__/vq.cpython-39.pyc ADDED
Binary file (5.12 kB).