import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz

from decoder.spectral_ops import IMDCT, ISTFT
from decoder.modules import symexp


class FourierHead(nn.Module):
    """Base class for inverse fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
                          the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        # The projection emits n_fft + 2 channels, which forward() splits into two
        # equal halves of n_fft // 2 + 1 channels each (log-magnitude and phase).
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        # Project, then move the channel axis to dim 1 so it can be split there.
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes

        # Phase wrapping happens here: cos/sin give the real and imaginary parts
        # of the unit phasor directly. Recovering the angle with atan2 and calling
        # exp(1j * phase) would yield the same value and only cost extra time.
        real = torch.cos(p)
        imag = torch.sin(p)
        S = mag * (real + 1j * imag)

        audio = self.istft(S)
        return audio


class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
                                     based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        mdct_frame_len: int,
        padding: str = "same",
        sample_rate: int = None,
        clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # optionally init the last layer following mel-scale:
            # output rows are scaled by 1 - f / f_max, where f are mel-spaced
            # frequencies between 0 and sample_rate // 2.
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            # Clip the synthesized waveform, not the MDCT coefficients.
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio


class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
        super().__init__()
        self.clip_audio = clip_audio
        # mdct_frame_len output channels are split in forward() into the
        # log-magnitude half (m) and the phase half (p).
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            # Clip the synthesized waveform, not the raw projection output.
            audio = torch.clip(audio, min=-1.0, max=1.0)

        return audio