from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
import torchaudio
import soundfile as sf
import numpy as np
import scipy.signal

#------------------------------------------
# setup wav2vec2
#------------------------------------------

# Run on GPU when one is available; every model below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.random.manual_seed(0)

# Each language section defines the same set of names, prefixed by language:
#   *_MODEL_PATH            Hugging Face model id
#   *_model_blank_token     CTC blank symbol (important to know for CTC decoding)
#   *_model_word_separator  symbol standing for a space between words
#   *_labels_dict           label -> class index, hard-coded here
#                           (presumably mirrors the model's vocab.json -- verify
#                           when changing models)
#   *_inverse_dict          class index -> label
#   *_all_labels            labels in dict order
#   *_blank_id              class index of the blank token

# --- Icelandic ---
# info: https://huggingface.co/carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h/blob/main/vocab.json
is_MODEL_PATH="carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
is_model_blank_token = '[PAD]' # important to know for CTC decoding
is_model_word_separator = '|'
is_labels_dict = {
    "f": 0, "a": 1, "é": 2, "t": 3, "o": 4, "n": 5, "e": 6, "y": 8,
    "k": 9, "j": 10, "u": 11, "d": 12, "w": 13, "l": 14, "ú": 15, "q": 16,
    "g": 17, "í": 18, "s": 19, "r": 20, "ý": 21, "i": 22, "z": 23, "m": 24,
    "h": 25, "ó": 26, "þ": 27, "æ": 28, "c": 29, "á": 30, "v": 31, "b": 32,
    "ð": 33, "x": 34, "ö": 35, "p": 36, "|": 7, "[UNK]": 37, "[PAD]": 38,
}
is_model = Wav2Vec2ForCTC.from_pretrained(is_MODEL_PATH).to(device)
is_processor = Wav2Vec2Processor.from_pretrained(is_MODEL_PATH)
is_inverse_dict = {v:k for k,v in is_labels_dict.items()}
is_all_labels = tuple(is_labels_dict.keys())
is_blank_id = is_labels_dict[is_model_blank_token]

# --- Faroese ---
fo_MODEL_PATH="carlosdanielhernandezmena/wav2vec2-large-xlsr-53-faroese-100h"
fo_model_blank_token = '[PAD]' # important to know for CTC decoding
fo_model_word_separator = '|'
fo_labels_dict = {
    "w": 0, "i": 1, "6": 2, "s": 3, "_": 4, "k": 5, "l": 6, "ú": 7,
    "2": 8, "4": 9, "d": 10, "z": 11, "3": 12, "ð": 13, "t": 15, "ø": 16,
    "x": 17, "p": 18, "o": 19, "æ": 20, "n": 21, "f": 22, "á": 23, "5": 24,
    "g": 25, "ý": 26, "r": 27, "é": 28, "u": 29, "ü": 30, "y": 31, "í": 32,
    "h": 33, "q": 34, "b": 35, "e": 36, "v": 37, "-": 38, "c": 39, "j": 40,
    ".": 41, "ó": 42, "'": 43, "m": 44, "a": 45, "|": 14, "[UNK]": 46,
    "[PAD]": 47,
}
fo_model = Wav2Vec2ForCTC.from_pretrained(fo_MODEL_PATH).to(device)
fo_processor = Wav2Vec2Processor.from_pretrained(fo_MODEL_PATH)
fo_inverse_dict = {v:k for k,v in fo_labels_dict.items()}
fo_all_labels = tuple(fo_labels_dict.keys())
fo_blank_id = fo_labels_dict[fo_model_blank_token]

# --- Norwegian (bokmaal) ---
no_MODEL_PATH="NbAiLab/nb-wav2vec2-1b-bokmaal"
no_model_blank_token = '[PAD]' # important to know for CTC decoding
no_model_word_separator = '|'
no_labels_dict = {
    "a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8,
    "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16,
    "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24,
    "y": 25, "z": 26, "å": 27, "æ": 28, "ø": 29, "|": 0, "[UNK]": 30,
    "[PAD]": 31,
}
no_model = Wav2Vec2ForCTC.from_pretrained(no_MODEL_PATH).to(device)
no_processor = Wav2Vec2Processor.from_pretrained(no_MODEL_PATH)
no_inverse_dict = {v:k for k,v in no_labels_dict.items()}
no_all_labels = tuple(no_labels_dict.keys())
no_blank_id = no_labels_dict[no_model_blank_token]

# Registry of per-language resources, keyed by the language name that callers
# pass as `lang`.
d = {
    "Icelandic": {
        'model': is_model,
        'processor': is_processor,
        'inverse_dict': is_inverse_dict,
        'labels_dict': is_labels_dict,
        'all_labels': is_all_labels,
        'blank_id': is_blank_id,
        'model_blank_token': is_model_blank_token,
        'model_word_separator': is_model_word_separator,
    },
    "Faroese": {
        'model': fo_model,
        'processor': fo_processor,
        'inverse_dict': fo_inverse_dict,
        'labels_dict': fo_labels_dict,
        'all_labels': fo_all_labels,
        'blank_id': fo_blank_id,
        'model_blank_token': fo_model_blank_token,
        'model_word_separator': fo_model_word_separator,
    },
    "Norwegian": {
        'model': no_model,
        'processor': no_processor,
        'inverse_dict': no_inverse_dict,
        'labels_dict': no_labels_dict,
        'all_labels': no_all_labels,
        'blank_id': no_blank_id,
        'model_blank_token': no_model_blank_token,
        'model_word_separator': no_model_word_separator,
    },
}


#------------------------------------------
# forced alignment with ctc decoder
# originally based on implementation of
# https://pytorch.org/audio/main/tutorials/forced_alignment_tutorial.html
#------------------------------------------


# return the label class probability of each audio frame
def get_frame_probs(wav_path,lang):
    """Return per-frame log-probabilities over the label set for one audio file.

    Args:
        wav_path: path of the audio file, loaded through `readwav`.
        lang: key into the module-level registry `d` selecting the
            processor and model to use.

    Returns:
        A CPU tensor holding the log-softmax of the model logits for the
        first (only) batch item; presumably shaped (frames, labels) --
        confirm against the model's output.
    """
    wav = readwav(wav_path)
    with torch.inference_mode(): # similar to with torch.no_grad():
        input_values = d[lang]['processor'](wav,sampling_rate=16000).input_values[0]
        input_values = torch.tensor(input_values, device=device).unsqueeze(0)
        emits = d[lang]['model'](input_values).logits
        emits = torch.log_softmax(emits, dim=-1)
    emit = emits[0].cpu().detach()
    return emit


def get_trellis(emission, tokens, blank_id):
    """Build the forward-score trellis used for CTC forced alignment.

    Args:
        emission: 2-D tensor of frame-wise log-probabilities, indexed
            [frame, label].
        tokens: label indices of the transcript, in order.
        blank_id: label index of the CTC blank token.

    Returns:
        Tensor of shape (num_frame + 1, num_tokens + 1); entry [t, j] is the
        best score found for having consumed the first j tokens after t frames.
    """
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    # Staying in <SoS> means emitting blank on every frame so far, so this
    # column accumulates the blank score. It must use blank_id: label 0 is a
    # real character in these vocabularies, not the blank.
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0) # len of this slice of trellis is len of audio frames
    trellis[0, -num_tokens:] = -float("inf") # len of this slice of trellis is len of transcript tokens
    # Kept from the reference implementation this code is based on.
    trellis[-num_tokens:, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.max(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis


def backtrack(trellis, emission, tokens, blank_id):
    """Trace the best alignment path backwards through the trellis.

    Args:
        trellis: output of `get_trellis` for the same emission and tokens.
        emission: 2-D tensor of frame-wise log-probabilities, indexed
            [frame, label].
        tokens: label indices of the transcript, in order.
        blank_id: label index of the CTC blank token.

    Returns:
        List of (token_index, frame_index, probability) tuples in
        chronological order, using transcript/emission (non-trellis)
        coordinates. `probability` is the frame's probability of the
        token when the path moved to it, otherwise of the blank.

    Raises:
        ValueError: if the start of the transcript is not reached before
            the frames run out.
    """
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    # Start from the frame where the score of the final token peaks.
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[T-1]` is the emission at time frame `T` of trellis dimension.
        # Score for token staying the same from time frame T-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for token changing from J-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
        took_token = bool(changed > stayed)

        # 2. Store the path with frame-wise probability.
        # A "stay" step is scored with the blank label, so report the blank's
        # probability (blank_id, not a hard-coded label 0).
        label = tokens[j - 1] if took_token else blank_id
        prob = emission[t - 1, label].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append((j - 1, t - 1, prob))

        # 3. Update the token
        if took_token:
            j -= 1
            if j == 0:
                break
    else:
        # Loop ran out of frames (or never ran) without reaching <SoS>.
        raise ValueError("Failed to align")
    return path[::-1]


def merge_repeats(path,transcript):
    """Collapse consecutive path points of the same token into segments.

    Args:
        path: output of `backtrack`.
        transcript: sequence indexed by token index, giving each token's label.

    Returns:
        List of (label, start_frame, end_frame) tuples; end_frame is
        exclusive (last frame of the token plus one).
    """
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        # Advance i2 while both path steps point to the same token index.
        while i2 < len(path) and path[i1][0] == path[i2][0]:
            i2 += 1
        # i2 now points at the first step of a different token (or the end).
        segments.append(
            (
                transcript[path[i1][0]],  # label of the token at i1
                path[i1][1],              # time of the first path-point of that token
                path[i2 - 1][1] + 1,      # time of the final path-point for that token, plus one
            )
        )
        i1 = i2
    return segments


def merge_words(segments, separator):
    """Join character segments into word segments, splitting on `separator`.

    Args:
        segments: output of `merge_repeats`.
        separator: label that marks a word boundary; separator segments are
            dropped from the result.

    Returns:
        List of (word, start_frame, end_frame) tuples. Empty words (for
        example between adjacent separators) are skipped.
    """
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2][0] == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join(seg[0] for seg in segs)
                words.append((word, segments[i1][1], segments[i2 - 1][2]))
            # Skip past the separator and start collecting the next word.
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words


#------------------------------------------
# handle in/out/etc.
#------------------------------------------


def readwav(wav_path):
    """Load an audio file as a mono float32 waveform at 16 kHz.

    Args:
        wav_path: path of the audio file, read with `soundfile`.

    Returns:
        1-D array of samples. Two-dimensional (multi-channel) input is
        averaged across axis 1, and audio at any other sample rate is
        resampled to 16000 Hz.
    """
    wav, sr = sf.read(wav_path, dtype=np.float32)
    if len(wav.shape) == 2:
        wav = wav.mean(1)
    if sr != 16000:
        wlen = int(wav.shape[0] / sr * 16000)
        # The file imports `scipy.signal`, so the function must be reached
        # through the `scipy` name (a bare `signal` is undefined).
        wav = scipy.signal.resample(wav, wlen)
    return wav


#convert frame-numbers to timestamps in seconds
# w2v2 step size is about 20ms, or 50 frames per second
def f2s(fr):
    """Convert a frame number to seconds at 50 frames per second."""
    return fr/50


def fmt(frame_aligns):
    """Convert (label, start_frame, end_frame) tuples to times in seconds."""
    return [(label,f2s(start),f2s(end)) for label,start,end in frame_aligns]


# generate mfa format for character (phone) and word alignments
def mfalike(chars,wds):
    """Render alignments as CSV text with a `Begin,End,Label,Type,Speaker` header.

    Args:
        chars: (label, start, end) tuples written as `phones` rows.
        wds: (word, start, end) tuples written as `words` rows.

    Returns:
        One string: the header line, then all word rows, then all phone
        rows, each terminated by a newline. Speaker is always `000`.
    """
    hed = ['Begin,End,Label,Type,Speaker\n']
    wlines = [f'{s},{e},{w},words,000\n' for w,s,e in wds]
    slines = [f'{s},{e},{sg},phones,000\n' for sg,s,e in chars]
    return (''.join(hed+wlines+slines))


# prepare the input transcript text string
# TODO:
# handle input strings that still have punctuation,
# or that have characters not present in labels_dict
def prep_transcript(xcp,lang):
    """Normalise a transcript and map it to the model's label indices.

    Args:
        xcp: transcript text.
        lang: key into the module-level registry `d`.

    Returns:
        (normalised_text, label_ids): the lower-cased text with every run
        of spaces replaced by one word-separator symbol, and the label
        index of each of its characters.

    Raises:
        KeyError: if a character is not present in the language's
            `labels_dict`.
    """
    xcp = xcp.lower()
    # Collapse runs of spaces to a single space. The search string must be
    # two spaces: searching for a single space never terminates.
    double_space = ' ' * 2
    while double_space in xcp:
        xcp = xcp.replace(double_space, ' ')
    xcp = xcp.replace(' ',d[lang]['model_word_separator'])
    label_ids = [d[lang]['labels_dict'][c] for c in xcp]
    return xcp, label_ids


def langsalign(wav_path,transcript_string,lang):
    """Force-align a transcript to an audio file and return MFA-like CSV text.

    Args:
        wav_path: path of the audio file.
        transcript_string: text spoken in the audio.
        lang: key into the module-level registry `d`.

    Returns:
        The string produced by `mfalike`, with word rows and character rows
        timed in seconds.
    """
    norm_txt, rec_label_ids = prep_transcript(transcript_string, lang)
    emit = get_frame_probs(wav_path, lang)
    trellis = get_trellis(emit, rec_label_ids, d[lang]['blank_id'])
    path = backtrack(trellis, emit, rec_label_ids, d[lang]['blank_id'])
    segments = merge_repeats(path,norm_txt)
    words = merge_words(segments, d[lang]['model_word_separator'])
    # NOTE: word-separator segments are left in `segments`, so they appear
    # among the `phones` rows of the output.
    return mfalike(fmt(segments), fmt(words))