EnglishToucan / Preprocessing /AudioPreprocessor.py
Flux9665's picture
initial commit
6faeba1
raw
history blame
6.45 kB
import numpy
import pyloudnorm as pyln
import torch
from torchaudio.transforms import MelSpectrogram
from torchaudio.transforms import Resample
class AudioPreprocessor:
    """
    Waveform preprocessing utilities.

    ``normalize_audio`` applies, in order: optional loudness + peak
    normalization, resampling from ``input_sr`` to ``output_sr`` (if given),
    and optional removal of leading/trailing silence with Silero VAD.
    ``audio_to_mel_spec_tensor`` turns a waveform into a log-mel spectrogram
    via ``LogMelSpec`` built for the final sampling rate.
    """

    def __init__(self, input_sr, output_sr=None, cut_silence=False, do_loudnorm=False, device="cpu"):
        """
        The parameters are by default set up to do well
        on a 16kHz signal. A different sampling rate may
        require different hop_length and n_fft (e.g.
        doubling frequency --> doubling hop_length and
        doubling n_fft)

        Args:
            input_sr: sampling rate of the audio given to ``normalize_audio``.
            output_sr: target sampling rate, or None to keep ``input_sr``.
            cut_silence: if True, load the Silero VAD model through torch.hub
                and trim leading/trailing silence in ``normalize_audio``.
            do_loudnorm: if True, apply loudness + peak normalization in
                ``normalize_audio`` (before resampling).
            device: torch device for the resampler, spectrogram and VAD model.
        """
        self.cut_silence = cut_silence
        self.do_loudnorm = do_loudnorm
        self.device = device
        self.input_sr = input_sr
        self.output_sr = output_sr
        # loudness is measured before resampling, hence input_sr
        self.meter = pyln.Meter(input_sr)
        self.final_sr = input_sr
        self.wave_to_spectrogram = LogMelSpec(output_sr if output_sr is not None else input_sr).to(device)
        # Resamplers for rates other than input_sr that are passed to
        # audio_to_mel_spec_tensor; keyed by source rate, built lazily and reused.
        self._extra_resamplers = {}
        if cut_silence:
            torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
            # careful: assumes 16kHz or 8kHz audio
            self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      verbose=False)
            (self.get_speech_timestamps,
             self.save_audio,
             self.read_audio,
             self.VADIterator,
             self.collect_chunks) = utils
            # silero disables autograd globally while loading instead of using
            # inference_mode/no_grad, so it has to be switched back on here
            torch.set_grad_enabled(True)
            self.silero_model = self.silero_model.to(self.device)
        if output_sr is not None and output_sr != input_sr:
            self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
            self.final_sr = output_sr
        else:
            self.resample = lambda x: x

    def cut_leading_and_trailing_silence(self, audio):
        """
        Trim audio to the span from the start of the first to the end of the
        last speech segment found by Silero VAD (run at ``final_sr``).

        https://github.com/snakers4/silero-vad

        Returns the audio unchanged (after printing a note) if no speech
        segment is detected.
        """
        with torch.inference_mode():
            speech_timestamps = self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr)
        if not speech_timestamps:
            print("Audio might be too short to cut silences from front and back.")
            return audio
        return audio[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]

    def normalize_loudness(self, audio):
        """
        normalize the amplitudes according to
        their decibels, so this should turn any
        signal with different magnitudes into
        the same magnitude by analysing loudness

        The signal is first normalized to -30 LUFS and then scaled so its
        absolute peak is 1. Audio that is too short to measure, or whose
        measured loudness is not finite (e.g. digital silence), is returned
        unchanged instead of being turned into NaNs.
        """
        try:
            loudness = self.meter.integrated_loudness(audio)
        except ValueError:
            # if the audio is too short, a value error will arise
            return audio
        if not numpy.isfinite(loudness):
            # -inf loudness would mean an infinite gain and NaN output
            return audio
        loud_normed = pyln.normalize.loudness(audio, loudness, -30.0)
        return self._peak_normalize(loud_normed)

    @staticmethod
    def _peak_normalize(audio):
        """Scale audio so its maximum absolute value is 1; all-zero input is returned as is."""
        peak = numpy.amax(numpy.abs(audio))
        if peak == 0:
            return audio
        return numpy.divide(audio, peak)

    def normalize_audio(self, audio):
        """
        one function to apply them all in an
        order that makes sense.

        Loudness normalization (if enabled) -> float32 tensor on ``device``
        -> resampling to ``final_sr`` -> silence trimming (if enabled).
        """
        if self.do_loudnorm:
            audio = self.normalize_loudness(audio)
        audio = torch.tensor(audio, device=self.device, dtype=torch.float32)
        audio = self.resample(audio)
        if self.cut_silence:
            audio = self.cut_leading_and_trailing_silence(audio)
        return audio

    def _get_resampler(self, source_sr):
        """Return a resampler from ``source_sr`` to ``final_sr`` without touching the normalize_audio pipeline."""
        if source_sr == self.input_sr:
            return self.resample
        if source_sr not in self._extra_resamplers:
            print("WARNING: different sampling rate used, this will be very slow if it happens often. Consider creating a dedicated audio processor.")
            self._extra_resamplers[source_sr] = Resample(orig_freq=source_sr, new_freq=self.final_sr).to(self.device)
        return self._extra_resamplers[source_sr]

    def audio_to_mel_spec_tensor(self, audio, normalize=False, explicit_sampling_rate=None):
        """
        explicit_sampling_rate is for when
        normalization has already been applied
        and that included resampling. No way
        to detect the current input_sr of the incoming
        audio

        Args:
            audio: waveform as a torch tensor, numpy array or list.
            normalize: accepted for backward compatibility; it currently has
                no effect. Call ``normalize_audio`` beforehand if needed.
            explicit_sampling_rate: sampling rate of ``audio``. None means
                the audio is already at ``final_sr``.

        Returns:
            The log-mel spectrogram computed by ``wave_to_spectrogram``.
        """
        if isinstance(audio, torch.Tensor):
            audio = audio.to(self.device)
        else:
            audio = torch.tensor(audio, device=self.device)
        audio = audio.float()
        if explicit_sampling_rate is None or explicit_sampling_rate == self.final_sr:
            return self.wave_to_spectrogram(audio)
        return self.wave_to_spectrogram(self._get_resampler(explicit_sampling_rate)(audio))
class LogMelSpec(torch.nn.Module):
    """
    Log10 mel power spectrogram.

    Fixed analysis settings: 1024-point FFT and window, hop of 256 samples,
    128 HTK-scale mel bands between 40 Hz and the Nyquist frequency.
    Exact zeros in the mel power spectrogram are replaced by 1e-8 before
    the log, so fully silent bins map to -8 instead of -inf.
    """

    def __init__(self, sr, *args, **kwargs):
        super().__init__(*args, **kwargs)
        nyquist = sr // 2
        self.spec = MelSpectrogram(
            sample_rate=sr,
            n_fft=1024,
            win_length=1024,
            hop_length=256,
            f_min=40.0,
            f_max=nyquist,
            pad=0,
            n_mels=128,
            power=2.0,
            normalized=False,
            center=True,
            pad_mode='reflect',
            mel_scale='htk',
        )

    def forward(self, audio):
        """Return log10 of the mel power spectrogram of ``audio`` (cast to float)."""
        power_mel = self.spec(audio.float())
        floored = power_mel.masked_fill(power_mel == 0, 1e-8)
        return torch.log10(floored)
if __name__ == '__main__':
    # Demo: plot the log-mel spectrogram of one example recording.
    import librosa.display as lbd
    import matplotlib.pyplot as plt
    import soundfile

    wav, sr = soundfile.read("../audios/ad00_0004.wav")
    ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True)
    # Resample to 16 kHz and trim silence first: the spectrogram module is
    # built for 16 kHz, so feeding it audio still at ``sr`` would give a
    # wrongly scaled mel axis whenever sr != 16000.
    clean_wave = ap.normalize_audio(wav)
    spectrogram = ap.audio_to_mel_spec_tensor(clean_wave, explicit_sampling_rate=ap.final_sr)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 6))
    lbd.specshow(spectrogram.cpu().numpy(),
                 ax=ax,
                 sr=16000,
                 cmap='GnBu',
                 y_axis='features',
                 x_axis=None,
                 hop_length=256)
    plt.show()