cati committed
Commit 69d94dc · Parent(s): 8d8a9b2
Browse files
- app.py +3 -1
- ctcalign.py +187 -0
app.py
CHANGED
@@ -29,7 +29,9 @@ with bl:
 # Long and short Icelandic vowels
 Choose a word, speaker group, and aligner type. Available speaker groups are native speakers, second-language speakers, or all. Aligner options are Montreal Forced Aligner (MFA) and CTC decoding with Wav2vec-2.0.
 
-The general expectation is that syllables with long stressed vowels followed by short consonants have a higher vowel:consonant duration ratio, while syllables with short stressed vowels followed by long consonants have a lower vowel:consonant ratio. However, a great many other factors affect the relative duration in any one recorded token. See Pind 1999, 'Speech segment durations and quantity in Icelandic' (J. Acoustical Society of America, 106(2)) for a review of the acoustics of Icelandic vowel duration.
+The general expectation is that syllables with long stressed vowels followed by short consonants have a higher vowel:consonant duration ratio, while syllables with short stressed vowels followed by long consonants have a lower vowel:consonant ratio. However, a great many other factors affect the relative duration in any one recorded token. See Pind 1999, 'Speech segment durations and quantity in Icelandic' (J. Acoustical Society of America, 106(2)) for a review of the acoustics of Icelandic vowel duration.
+
+All phoneme durations are measured automatically with no human correction. The purpose of this demo is to evaluate the role of such tools in large-scale phonetic research. Therefore, no measurements shown in this demo should be taken as conclusive without some independent verification.
 """
 )
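A minimal sketch of the ratio in question, in Python (the segment values here are hypothetical; the (label, start, end) tuple format, with times in seconds, matches what ctcalign.py below returns):

vowel = ('a', 0.48, 0.66)      # stressed vowel segment (hypothetical times)
consonant = ('k', 0.66, 0.74)  # following consonant segment (hypothetical times)

def dur(seg):
    label, start, end = seg
    return end - start

print(round(dur(vowel) / dur(consonant), 2))  # 2.25: long vowel + short consonant -> high ratio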
ctcalign.py
ADDED
@@ -0,0 +1,187 @@
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch, torchaudio
import soundfile as sf
import numpy as np
from scipy import signal  # needed by readwav() below for resampling

#------------------------------------------
# setup wav2vec2
#------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.random.manual_seed(0)

# info: https://huggingface.co/carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h/blob/main/vocab.json
MODEL_PATH="/work/caitlinr/w2vrec/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
model_blank_token = '[PAD]' # important to know for CTC decoding
model_word_separator = '|'
labels_dict = {"f": 0, "a": 1, "é": 2, "t": 3, "o": 4, "n": 5, "e": 6, "y": 8, "k": 9, "j": 10, "u": 11, "d": 12, "w": 13, "l": 14, "ú": 15, "q": 16, "g": 17, "í": 18, "s": 19, "r": 20, "ý": 21, "i": 22, "z": 23, "m": 24, "h": 25, "ó": 26, "þ": 27, "æ": 28, "c": 29, "á": 30, "v": 31, "b": 32, "ð": 33, "x": 34, "ö": 35, "p": 36, "|": 7, "[UNK]": 37, "[PAD]": 38}

model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH).to(device)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
inverse_dict = {v: k for k, v in labels_dict.items()}
all_labels = tuple(labels_dict.keys())
blank_id = labels_dict[model_blank_token]
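# Sanity checks on the vocabulary setup (a minimal sketch; the values follow
# directly from labels_dict above):
assert inverse_dict[labels_dict['|']] == '|'  # ids round-trip through inverse_dict
assert blank_id == 38                         # the CTC blank is [PAD], not id 0
assert len(all_labels) == 39                  # 37 character symbols + [UNK] + [PAD]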

#------------------------------------------
# forced alignment with ctc decoder
# originally based on the implementation at
# https://pytorch.org/audio/main/tutorials/forced_alignment_tutorial.html
#------------------------------------------

# return the per-frame log-probability of each label class
def get_frame_probs(wav_path):
    wav = readwav(wav_path)
    with torch.inference_mode(): # similar to with torch.no_grad():
        input_values = processor(wav, sampling_rate=16000).input_values[0]
        input_values = torch.tensor(input_values, device=device).unsqueeze(0)
        emits = model(input_values).logits
        emits = torch.log_softmax(emits, dim=-1)
    emit = emits[0].cpu().detach()
    return emit
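# Usage sketch, assuming a hypothetical 16 kHz mono file 'sample.wav':
#   emit = get_frame_probs('sample.wav')
#   emit.shape  ->  (num_frames, 39): one row per ~20 ms frame, ~50 rows per second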

def get_trellis(emission, tokens, blank_id):
    num_frame = emission.size(0)
    num_tokens = len(tokens)
    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence);
    # the extra dim for the time axis simplifies the code below.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0) # cumulative score of emitting only blanks; len of this slice is len of audio frames
    trellis[0, -num_tokens:] = -float("inf") # len of this slice is len of transcript tokens
    trellis[-num_tokens:, 0] = float("inf")
    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.max(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis
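# Toy sketch of the trellis shape: three random frames over a three-label
# vocabulary, aligning the two-token transcript [1, 2], with blank id 0 for
# the toy (the real vocabulary's blank is 38).
toy_emission = torch.log_softmax(torch.randn(3, 3), dim=-1)
toy_trellis = get_trellis(toy_emission, [1, 2], blank_id=0)
print(toy_trellis.shape)  # torch.Size([4, 3]) = (num_frame + 1, num_tokens + 1)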

def backtrack(trellis, emission, tokens, blank_id):
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[T-1]` is the emission at time frame `T` of trellis dimension.
        # Score for the token staying the same from time frame T-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for the token changing from J-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # 2. Store the path with frame-wise probability.
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append((j - 1, t - 1, prob))

        # 3. Update the token
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        raise ValueError("Failed to align")
    return path[::-1]

def merge_repeats(path, transcript):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1][0] == path[i2][0]: # while both path steps point to the same token index
            i2 += 1
        segments.append( # when i2 finally switches to a different token,
            (transcript[path[i1][0]], # to the list of segments, append the token from i1
             path[i1][1], # time of the first path-point of that token
             path[i2 - 1][1] + 1, # time of the final path-point for that token.
            )
        )
        i1 = i2
    return segments

def merge_words(segments, separator):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2][0] == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg[0] for seg in segs])
                words.append((word, segments[i1][1], segments[i2 - 1][2]))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words
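# Toy walkthrough of both merge steps: a hypothetical backtrack path of
# (token_index, time_index, probability) triples for the normalized string 'á|a'.
toy_path = [(0, 0, 1.0), (0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0), (2, 4, 1.0)]
toy_segs = merge_repeats(toy_path, 'á|a')
print(toy_segs)                    # [('á', 0, 2), ('|', 2, 3), ('a', 3, 5)]
print(merge_words(toy_segs, '|'))  # [('á', 0, 2), ('a', 3, 5)]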

#------------------------------------------
# handle in/out/etc.
#------------------------------------------

def readwav(wav_path):
    wav, sr = sf.read(wav_path, dtype=np.float32)
    if len(wav.shape) == 2:
        wav = wav.mean(1) # mix stereo down to mono
    if sr != 16000:
        wlen = int(wav.shape[0] / sr * 16000)
        wav = signal.resample(wav, wlen) # resample to the model's 16 kHz rate
    return wav

# convert frame numbers to timestamps in seconds
# w2v2 step size is about 20ms, or 50 frames per second
def f2s(fr):
    return fr/50

def fmt(frame_aligns):
    return [(label, f2s(start), f2s(end)) for label, start, end in frame_aligns]

# prepare the input transcript text string
# TODO:
# handle input strings that still have punctuation,
# or that have characters not present in labels_dict
def prep_transcript(xcp):
    xcp = xcp.lower()
    while '  ' in xcp:
        xcp = xcp.replace('  ', ' ') # collapse runs of spaces
    xcp = xcp.replace(' ', model_word_separator)
    label_ids = [labels_dict[c] for c in xcp]
    return xcp, label_ids
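# For example ('Halló heimur' is a stand-in input; ids read off labels_dict above):
norm, ids = prep_transcript('Halló heimur')
print(norm)  # 'halló|heimur'
print(ids)   # [25, 1, 14, 14, 26, 7, 25, 6, 22, 24, 11, 20]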

def ctcalign(wav_path, transcript_string):
    norm_txt, rec_label_ids = prep_transcript(transcript_string)
    emit = get_frame_probs(wav_path)
    trellis = get_trellis(emit, rec_label_ids, blank_id)
    path = backtrack(trellis, emit, rec_label_ids, blank_id)
    segments = merge_repeats(path, norm_txt)
    words = merge_words(segments, model_word_separator)

    #segments = [s for s in segments if s[0] != model_word_separator]
    return fmt(segments), fmt(words)
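# End-to-end usage sketch (the wav path and transcript are hypothetical stand-ins):
#   segments, words = ctcalign('recordings/0001.wav', 'halló heimur')
# segments: per-character (label, start_s, end_s) tuples, word separators included
# words:    per-word (label, start_s, end_s) tuples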