|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
import soundfile as sf
import numpy as np
from scipy import signal

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.random.manual_seed(0)


# Icelandic acoustic model, with its CTC vocabulary and blank/word-separator tokens.
is_MODEL_PATH = "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
is_model_blank_token = '[PAD]'
is_model_word_separator = '|'
is_labels_dict = {"f": 0, "a": 1, "é": 2, "t": 3, "o": 4, "n": 5, "e": 6, "y": 8, "k": 9, "j": 10, "u": 11, "d": 12, "w": 13, "l": 14, "ú": 15, "q": 16, "g": 17, "í": 18, "s": 19, "r": 20, "ý": 21, "i": 22, "z": 23, "m": 24, "h": 25, "ó": 26, "þ": 27, "æ": 28, "c": 29, "á": 30, "v": 31, "b": 32, "ð": 33, "x": 34, "ö": 35, "p": 36, "|": 7, "[UNK]": 37, "[PAD]": 38}

is_model = Wav2Vec2ForCTC.from_pretrained(is_MODEL_PATH).to(device)
is_processor = Wav2Vec2Processor.from_pretrained(is_MODEL_PATH)
is_inverse_dict = {v: k for k, v in is_labels_dict.items()}
is_all_labels = tuple(is_labels_dict.keys())
is_blank_id = is_labels_dict[is_model_blank_token]


# Faroese acoustic model, with its CTC vocabulary and blank/word-separator tokens.
fo_MODEL_PATH = "carlosdanielhernandezmena/wav2vec2-large-xlsr-53-faroese-100h"
fo_model_blank_token = '[PAD]'
fo_model_word_separator = '|'
fo_labels_dict = {"w": 0, "i": 1, "6": 2, "s": 3, "_": 4, "k": 5, "l": 6, "ú": 7, "2": 8, "4": 9, "d": 10, "z": 11, "3": 12, "ð": 13, "t": 15, "ø": 16, "x": 17, "p": 18, "o": 19, "æ": 20, "n": 21, "f": 22, "á": 23, "5": 24, "g": 25, "ý": 26, "r": 27, "é": 28, "u": 29, "ü": 30, "y": 31, "í": 32, "h": 33, "q": 34, "b": 35, "e": 36, "v": 37, "-": 38, "c": 39, "j": 40, ".": 41, "ó": 42, "'": 43, "m": 44, "a": 45, "|": 14, "[UNK]": 46, "[PAD]": 47}

fo_model = Wav2Vec2ForCTC.from_pretrained(fo_MODEL_PATH).to(device)
fo_processor = Wav2Vec2Processor.from_pretrained(fo_MODEL_PATH)
fo_inverse_dict = {v: k for k, v in fo_labels_dict.items()}
fo_all_labels = tuple(fo_labels_dict.keys())
fo_blank_id = fo_labels_dict[fo_model_blank_token]


# Norwegian Bokmål acoustic model, with its CTC vocabulary and blank/word-separator tokens.
no_MODEL_PATH = "NbAiLab/nb-wav2vec2-1b-bokmaal"
no_model_blank_token = '[PAD]'
no_model_word_separator = '|'
no_labels_dict = {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "å": 27, "æ": 28, "ø": 29, "|": 0, "[UNK]": 30, "[PAD]": 31}

no_model = Wav2Vec2ForCTC.from_pretrained(no_MODEL_PATH).to(device)
no_processor = Wav2Vec2Processor.from_pretrained(no_MODEL_PATH)
no_inverse_dict = {v: k for k, v in no_labels_dict.items()}
no_all_labels = tuple(no_labels_dict.keys())
no_blank_id = no_labels_dict[no_model_blank_token]
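
# The alignment functions below read the unprefixed globals model, processor,
# labels_dict, model_word_separator, and blank_id. A minimal binding, assuming
# Icelandic as the active language; point these at the fo_* or no_* variables
# to align Faroese or Norwegian instead.
model = is_model
processor = is_processor
labels_dict = is_labels_dict
model_word_separator = is_model_word_separator
blank_id = is_blank_id

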
def get_frame_probs(wav_path):
    """Return per-frame CTC log-probabilities (emissions) for one audio file."""
    wav = readwav(wav_path)
    with torch.inference_mode():
        input_values = processor(wav, sampling_rate=16000).input_values[0]
        input_values = torch.tensor(input_values, device=device).unsqueeze(0)
        emits = model(input_values).logits
        emits = torch.log_softmax(emits, dim=-1)
    emit = emits[0].cpu().detach()
    return emit


def get_trellis(emission, tokens, blank_id):
    """Build the alignment trellis.

    trellis[t, j] is the log-probability of the best path that has emitted
    the first j transcript tokens within the first t frames; the extra row
    and column represent the start of sequence.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)
    trellis[0, -num_tokens:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")
    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.max(
            # score for staying on the same token (emit a blank)
            trellis[t, 1:] + emission[t, blank_id],
            # score for advancing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis


def backtrack(trellis, emission, tokens, blank_id):
    """Walk back through the trellis to recover the most likely path."""
    # Start from the frame where the completed transcript scores highest.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # Score if frame t-1 emitted a blank (stay on token j)
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score if frame t-1 emitted token j-1 (advance)
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # Frame-level posterior of whichever label was emitted at this frame
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()
        path.append((j - 1, t - 1, prob))

        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError("Failed to align")
    return path[::-1]


def merge_repeats(path, transcript):
    """Collapse consecutive path points on the same token into
    (label, start_frame, end_frame) segments."""
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        # advance i2 past every path point that stays on the same token
        while i2 < len(path) and path[i1][0] == path[i2][0]:
            i2 += 1
        segments.append(
            (transcript[path[i1][0]],
             path[i1][1],
             path[i2 - 1][1] + 1,
             )
        )
        i1 = i2
    return segments

def merge_words(segments, separator):
    """Group character segments into (word, start_frame, end_frame) spans,
    splitting on the word-separator label."""
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2][0] == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg[0] for seg in segs])
                words.append((word, segments[i1][1], segments[i2 - 1][2]))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words
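
# A quick illustration of the two merge steps on a toy alignment of the
# transcript "hi|go" (the frame indices are made up):
#   segments = [('h', 0, 3), ('i', 3, 5), ('|', 5, 6), ('g', 6, 8), ('o', 8, 9)]
#   merge_words(segments, '|')  ->  [('hi', 0, 5), ('go', 6, 9)]

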
def readwav(wav_path):
    """Load audio as mono float32 at 16 kHz, resampling if necessary."""
    wav, sr = sf.read(wav_path, dtype=np.float32)
    if len(wav.shape) == 2:
        wav = wav.mean(1)  # mix multi-channel audio down to mono
    if sr != 16000:
        wlen = int(wav.shape[0] / sr * 16000)
        wav = signal.resample(wav, wlen)
    return wav


def f2s(fr):
    """Convert a frame index to seconds (the model emits 50 frames per second)."""
    return fr / 50


def fmt(frame_aligns):
    return [(label, f2s(start), f2s(end)) for label, start, end in frame_aligns]
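
# f2s hard-codes 50 frames per second, i.e. wav2vec2's 20 ms frame stride at a
# 16 kHz sampling rate. A sketch of deriving the rate from the loaded model
# instead, assuming the standard wav2vec2 feature-extractor strides
# (5*2*2*2*2*2*2 = 320 samples per frame):
def frames_per_second(model, sampling_rate=16000):
    stride = 1
    for s in model.config.conv_stride:
        stride *= s
    return sampling_rate / stride  # 16000 / 320 = 50.0

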
def prep_transcript(xcp):
    """Lowercase the transcript, collapse runs of spaces, and replace spaces
    with the model's word-separator label."""
    xcp = xcp.lower()
    while '  ' in xcp:
        xcp = xcp.replace('  ', ' ')
    xcp = xcp.replace(' ', model_word_separator)
    label_ids = [labels_dict[c] for c in xcp]
    return xcp, label_ids
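
# For example, with the Icelandic vocabulary bound above:
#   prep_transcript("Góðan dag")  ->  ("góðan|dag", [17, 26, 33, 1, 5, 7, 12, 1, 17])
# Characters missing from labels_dict (digits, punctuation, ...) raise KeyError,
# so transcripts are assumed to be pre-normalized to the model's alphabet.

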
def ctcalign(wav_path, transcript_string):
    """Align a transcript to its audio; returns (segment, word) timings in seconds."""
    norm_txt, rec_label_ids = prep_transcript(transcript_string)
    emit = get_frame_probs(wav_path)
    trellis = get_trellis(emit, rec_label_ids, blank_id)
    path = backtrack(trellis, emit, rec_label_ids, blank_id)
    segments = merge_repeats(path, norm_txt)
    words = merge_words(segments, model_word_separator)
    return fmt(segments), fmt(words)
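
# Example usage — a minimal sketch; the wav path and transcript are placeholders,
# and the audio is expected to be 16 kHz-compatible speech matching the transcript.
if __name__ == "__main__":
    seg_times, word_times = ctcalign("example.wav", "góðan dag")
    for word, start, end in word_times:
        print(f"{word}\t{start:.2f}\t{end:.2f}")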