from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch, torchaudio
import soundfile as sf
import numpy as np
# --- wav2vec2 setup ----------------------------------------------------------
# Every model below is moved to the first CUDA device when one is available,
# otherwise it stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_ctc_model(model_path, labels, blank_token):
    """Load a wav2vec2 CTC model/processor pair and derive its label tables.

    Returns a 5-tuple ``(model, processor, inverse_labels, all_labels,
    blank_id)``: ``inverse_labels`` maps id -> label, ``all_labels`` is the
    tuple of label strings in dict order, and ``blank_id`` is the id of
    ``blank_token``.
    """
    ctc_model = Wav2Vec2ForCTC.from_pretrained(model_path).to(device)
    ctc_processor = Wav2Vec2Processor.from_pretrained(model_path)
    inverse_labels = {label_id: label for label, label_id in labels.items()}
    return (
        ctc_model,
        ctc_processor,
        inverse_labels,
        tuple(labels),
        labels[blank_token],
    )


# One model per label set. The prefixes are presumably language codes
# (is = Icelandic, fo = Faroese, no = Norwegian), judging by the letters in
# each label set. Label ids index the model's output columns during
# alignment, so each dict must agree with its model's vocabulary.
# NOTE(review): is_MODEL_PATH, fo_MODEL_PATH and no_MODEL_PATH are not
# defined in the visible source; they must exist before this code runs.

is_model_blank_token = '[PAD]'  # CTC blank; its id is needed for decoding
is_model_word_separator = '|'
is_labels_dict = {"f": 0, "a": 1, "é": 2, "t": 3, "o": 4, "n": 5, "e": 6, "y": 8, "k": 9, "j": 10, "u": 11, "d": 12, "w": 13, "l": 14, "ú": 15, "q": 16, "g": 17, "í": 18, "s": 19, "r": 20, "ý": 21, "i": 22, "z": 23, "m": 24, "h": 25, "ó": 26, "þ": 27, "æ": 28, "c": 29, "á": 30, "v": 31, "b": 32, "ð": 33, "x": 34, "ö": 35, "p": 36, "|": 7, "[UNK]": 37, "[PAD]": 38}
(is_model, is_processor, is_inverse_dict,
 is_all_labels, is_blank_id) = _load_ctc_model(
    is_MODEL_PATH, is_labels_dict, is_model_blank_token)

fo_model_blank_token = '[PAD]'  # CTC blank; its id is needed for decoding
fo_model_word_separator = '|'
fo_labels_dict = {"w": 0, "i": 1, "6": 2, "s": 3, "_": 4, "k": 5, "l": 6, "ú": 7, "2": 8, "4": 9, "d": 10, "z": 11, "3": 12, "ð": 13, "t": 15, "ø": 16, "x": 17, "p": 18, "o": 19, "æ": 20, "n": 21, "f": 22, "á": 23, "5": 24, "g": 25, "ý": 26, "r": 27, "é": 28, "u": 29, "ü": 30, "y": 31, "í": 32, "h": 33, "q": 34, "b": 35, "e": 36, "v": 37, "-": 38, "c": 39, "j": 40, ".": 41, "ó": 42, "'": 43, "m": 44, "a": 45, "|": 14, "[UNK]": 46, "[PAD]": 47}
(fo_model, fo_processor, fo_inverse_dict,
 fo_all_labels, fo_blank_id) = _load_ctc_model(
    fo_MODEL_PATH, fo_labels_dict, fo_model_blank_token)

no_model_blank_token = '[PAD]'  # CTC blank; its id is needed for decoding
no_model_word_separator = '|'
no_labels_dict = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "å": 27, "æ": 28, "ø": 29, "|": 0, "[UNK]": 30, "[PAD]": 31}
(no_model, no_processor, no_inverse_dict,
 no_all_labels, no_blank_id) = _load_ctc_model(
    no_MODEL_PATH, no_labels_dict, no_model_blank_token)
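# Forced alignment of a transcript to audio with a CTC model.
# Originally based on an existing implementation (the source reference is
# missing here); the code closely follows the torchaudio "Forced Alignment
# with Wav2Vec2" tutorial.
# Return the log-probability of every label class for each audio frame.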
def get_frame_probs(wav_path, asr_model=None, asr_processor=None):
    """Return per-frame log-probabilities over the model's label classes.

    Args:
        wav_path: path of the audio file. It is loaded with ``readwav``,
            i.e. down-mixed to mono and resampled to 16 kHz.
        asr_model: wav2vec2 CTC model to run. Defaults to the module-level
            ``model``. NOTE(review): no plain ``model`` is defined in the
            visible source (only ``is_model``/``fo_model``/``no_model``), so
            pass one explicitly unless it is defined elsewhere.
        asr_processor: the matching ``Wav2Vec2Processor``. Defaults to the
            module-level ``processor`` (same caveat as above).

    Returns:
        A CPU tensor indexed as ``[frame, label_id]`` holding the
        ``log_softmax`` of the model logits for the single input utterance.
    """
    if asr_model is None:
        asr_model = model
    if asr_processor is None:
        asr_processor = processor
    wav = readwav(wav_path)
    # Put the input on whatever device the model lives on, so that a model
    # passed in by the caller need not be on the module-level ``device``.
    model_device = next(asr_model.parameters()).device
    with torch.inference_mode():  # no autograd bookkeeping needed
        input_values = asr_processor(wav, sampling_rate=16000).input_values[0]
        input_values = torch.tensor(input_values, device=model_device).unsqueeze(0)
        logits = asr_model(input_values).logits
        log_probs = torch.log_softmax(logits, dim=-1)
    # Drop the batch dimension; tensors created under inference_mode never
    # require grad, so no detach() is needed.
    return log_probs[0].cpu()
def get_trellis(emission, tokens, blank_id):
    """Build the CTC forced-alignment score table (trellis).

    ``trellis[t, j]`` is the best log-score of having emitted the first ``j``
    transcript tokens after the first ``t`` audio frames. The trellis has an
    extra leading row ("before any frame") and an extra leading column
    (<SoS>, i.e. only blanks emitted so far).

    Args:
        emission: tensor indexed as ``[frame, label_id]`` holding per-frame
            log-probabilities.
        tokens: label ids of the transcript, in order.
        blank_id: label id of the CTC blank token.

    Returns:
        Tensor of shape ``(num_frames + 1, len(tokens) + 1)``.

    Raises:
        ValueError: if ``tokens`` is empty.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)
    if num_tokens == 0:
        # With zero tokens the ``-num_tokens:`` slice below would cover the
        # whole column and the table would be meaningless.
        raise ValueError("Cannot align an empty transcript")

    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    # Staying at <SoS> means emitting only blanks, so accumulate the blank
    # column; the blank label is not necessarily label 0.
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    # No transcript token can have been emitted before the first frame.
    trellis[0, 1:] = -float("inf")
    # Carried over unchanged from the reference implementation.
    # NOTE(review): intent is undocumented there; confirm before changing.
    trellis[-num_tokens:, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying on the same token (emitting a blank).
            trellis[t, 1:] + emission[t, blank_id],
            # Score for advancing to the next token.
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis
def backtrack(trellis, emission, tokens, blank_id):
    """Recover the best token/frame alignment path from a filled trellis.

    Args:
        trellis: score table of shape ``(num_frames + 1, len(tokens) + 1)``
            as built by ``get_trellis``.
        emission: tensor indexed as ``[frame, label_id]`` of per-frame
            log-probabilities (the one used to build ``trellis``).
        tokens: label ids of the transcript, in order.
        blank_id: label id of the CTC blank token.

    Returns:
        List of ``(token_index, frame_index, prob)`` tuples in time order.
        ``token_index`` indexes ``tokens`` and ``frame_index`` indexes
        ``emission``. ``prob`` is the frame's probability of the chosen
        label: the token when advancing, the blank when staying.

    Raises:
        ValueError: if the path cannot be traced back to the first token.
    """
    # j and t index the trellis, which has an extra leading row and column:
    # trellis time t is emission frame t - 1, and trellis token j is
    # transcript token j - 1.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # Score for staying on token j from frame t-1 to t (emitting blank).
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for advancing from token j-1 at t-1 to token j at t.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
        advanced = bool(changed > stayed)

        label_id = tokens[j - 1] if advanced else blank_id
        prob = emission[t - 1, label_id].exp().item()
        # Store in non-trellis coordinates.
        path.append((j - 1, t - 1, prob))

        if advanced:
            j -= 1
            if j == 0:
                # Reached the first token: the alignment is complete.
                break
    else:
        # Ran out of frames before reaching the first token.
        raise ValueError("Failed to align")
    return path[::-1]
def merge_repeats(path, transcript):
    """Collapse consecutive path points on the same token into segments.

    Args:
        path: ``(token_index, frame_index, prob)`` tuples in time order, as
            produced by ``backtrack``.
        transcript: the normalised transcript; ``transcript[token_index]``
            is the label of that token.

    Returns:
        List of ``(label, start_frame, end_frame)`` tuples. ``end_frame`` is
        exclusive: one past the last frame assigned to the token.
    """
    from itertools import groupby

    segments = []
    # groupby only merges *consecutive* points with the same token index,
    # which is exactly one run per token in a monotonic alignment path.
    for token_index, run in groupby(path, key=lambda point: point[0]):
        points = list(run)
        segments.append((transcript[token_index], points[0][1], points[-1][1] + 1))
    return segments
def merge_words(segments, separator):
    """Group character segments into word segments.

    Args:
        segments: ``(label, start, end)`` tuples in time order, as produced
            by ``merge_repeats``.
        separator: the word-separator label (e.g. ``'|'``).

    Returns:
        List of ``(word, start, end)`` tuples. ``start`` is the start of the
        word's first segment and ``end`` is the end of its last segment.
        Separator segments are dropped; leading, trailing or repeated
        separators never produce empty words or leak into a word.
    """
    from itertools import groupby

    words = []
    for is_separator, run in groupby(segments, key=lambda seg: seg[0] == separator):
        if is_separator:
            continue
        word_segs = list(run)
        word = "".join(seg[0] for seg in word_segs)
        words.append((word, word_segs[0][1], word_segs[-1][2]))
    return words
# Input/output helpers: audio loading and conversion of results to seconds.
def readwav(wav_path, target_sr=16000):
    """Load an audio file as a mono float32 signal at ``target_sr`` Hz.

    Args:
        wav_path: path to an audio file readable by ``soundfile``.
        target_sr: output sample rate in Hz (default 16000).

    Returns:
        1-D NumPy array of samples. Multi-channel audio is down-mixed by
        averaging the channels. Other sample rates are resampled with
        ``scipy.signal.resample`` (FFT-based).
    """
    from scipy import signal  # only this helper needs scipy

    wav, sr = sf.read(wav_path, dtype=np.float32)
    if wav.ndim == 2:
        # soundfile returns (frames, channels) for multi-channel files.
        wav = wav.mean(axis=1)
    if sr != target_sr:
        num_samples = int(wav.shape[0] / sr * target_sr)
        wav = signal.resample(wav, num_samples)
    return wav
# Convert frame numbers to timestamps in seconds.
# The wav2vec2 frame step is 20 ms, i.e. 50 frames per second of 16 kHz audio.
def f2s(fr):
    """Convert a wav2vec2 frame index to a time in seconds (50 frames/s)."""
    frames_per_second = 50
    return fr / frames_per_second
def fmt(frame_aligns):
    """Convert ``(label, start_frame, end_frame)`` tuples to seconds.

    Returns a new list of ``(label, start, end)`` tuples in the same order.
    Labels pass through unchanged and both frame numbers go through ``f2s``.
    """
    converted = []
    for label, start_frame, end_frame in frame_aligns:
        converted.append((label, f2s(start_frame), f2s(end_frame)))
    return converted
# Prepare the input transcript text string for alignment.
# Input strings may still contain punctuation or other characters that are
# not present in labels_dict, so they have to be dealt with here.
def prep_transcript(xcp, labels=None, separator=None):
    """Normalise a transcript and map it to model label ids.

    Processing steps:
      1. Lowercase and NFC-normalise the text.
      2. Drop every character that is neither whitespace nor a key of the
         label dict (e.g. punctuation the model has no label for).
      3. Collapse whitespace runs to single spaces and strip the ends.
      4. Replace each space with the word separator.

    Args:
        xcp: raw transcript text.
        labels: label -> id dict. Defaults to the module-level
            ``labels_dict``.
        separator: word-separator label. Defaults to the module-level
            ``model_word_separator``.

    Returns:
        ``(normalised_text, label_ids)``, where ``label_ids[i]`` is the id
        of ``normalised_text[i]``.
    """
    import unicodedata

    if labels is None:
        labels = labels_dict
    if separator is None:
        separator = model_word_separator

    # NFC so that e.g. "e" + combining acute matches the precomposed "é" label.
    xcp = unicodedata.normalize("NFC", xcp.lower())
    # Remove characters the model cannot emit; keep whitespace as word breaks.
    xcp = "".join(c for c in xcp if c in labels or c.isspace())
    # split()/join collapses any whitespace run (tabs, newlines, repeated
    # spaces) into one space and drops leading/trailing whitespace.
    xcp = " ".join(xcp.split()).replace(" ", separator)
    label_ids = [labels[c] for c in xcp]
    return xcp, label_ids
def ctcalign(wav_path, transcript_string):
    """Force-align a transcript to an audio file.

    Returns:
        ``(char_alignments, word_alignments)``: two lists of
        ``(label, start, end)`` tuples with times in seconds. There is one
        entry per normalised transcript character (word separators
        included) and one entry per word.

    NOTE(review): relies on the module-level names ``blank_id`` and
    ``model_word_separator`` (and, through the helpers, ``model``,
    ``processor`` and ``labels_dict``). Only their ``is_``/``fo_``/``no_``
    prefixed variants are defined in the visible source.
    """
    normalized, token_ids = prep_transcript(transcript_string)
    emission = get_frame_probs(wav_path)
    trellis = get_trellis(emission, token_ids, blank_id)
    best_path = backtrack(trellis, emission, token_ids, blank_id)
    char_segments = merge_repeats(best_path, normalized)
    word_segments = merge_words(char_segments, model_word_separator)
    # Separator segments are kept in the character-level output. Filter out
    # entries whose label equals model_word_separator if they are unwanted.
    return fmt(char_segments), fmt(word_segments)