cati committed on
Commit
69d94dc
1 Parent(s): 8d8a9b2
Files changed (2)
  1. app.py +3 -1
  2. ctcalign.py +187 -0
app.py CHANGED
@@ -29,7 +29,9 @@ with bl:
  # Long and short Icelandic vowels
  Choose a word, speaker group, and aligner type. Available speaker groups are native speakers, second-language speakers, or all. Aligner options are Montreal Forced Aligner (MFA) and CTC decoding with Wav2vec-2.0.
 
- The general expectation is that syllables with long stressed vowels followed by short consonants have a higher vowel:consonant duration ratio, while syllables with short stressed vowels followed by long consonants have a lower vowel:consonant ratio. However, a great many other factors affect the relative duration in any one recorded token. See Pind 1999, 'Speech segment durations and quantity in Icelandic' (J. Acoustical Society of America, 106(2)) for a review of the acoustics of Icelandic vowel duration. All phoneme durations are measured automatically with no human correction. The purpose of this demo is to evaluate the role of such tools in large-scale phonetic research. Therefore, no measurements shown in this demo should be taken as conclusive without some independent verification.
+ The general expectation is that syllables with long stressed vowels followed by short consonants have a higher vowel:consonant duration ratio, while syllables with short stressed vowels followed by long consonants have a lower vowel:consonant ratio. However, a great many other factors affect the relative duration in any one recorded token. See Pind 1999, 'Speech segment durations and quantity in Icelandic' (J. Acoustical Society of America, 106(2)) for a review of the acoustics of Icelandic vowel duration.
+
+ All phoneme durations are measured automatically with no human correction. The purpose of this demo is to evaluate the role of such tools in large-scale phonetic research. Therefore, no measurements shown in this demo should be taken as conclusive without some independent verification.
  """
  )
 
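As a gloss on the ratio described in this docstring: a minimal illustrative sketch, not code from this commit, of how a vowel:consonant duration ratio could be computed from the (label, start, end) segment tuples that the aligners produce (the segment values below are made up):

# illustrative only: ratio of stressed-vowel duration to following-consonant duration
def vc_ratio(vowel_seg, consonant_seg):
    v_dur = vowel_seg[2] - vowel_seg[1]          # vowel duration in seconds
    c_dur = consonant_seg[2] - consonant_seg[1]  # consonant duration in seconds
    return v_dur / c_dur

vc_ratio(("a", 0.10, 0.26), ("t", 0.26, 0.31))   # 0.16/0.05 = 3.2 (long V + short C -> ratio > 1)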
ctcalign.py ADDED
@@ -0,0 +1,187 @@
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ import torch
+ import soundfile as sf
+ import numpy as np
+ from scipy import signal  # needed by readwav() for resampling
+
+ #------------------------------------------
+ # setup wav2vec2
+ #------------------------------------------
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ torch.random.manual_seed(0)
+
+ # info: https://huggingface.co/carlosdanielhernandezmena/wav2vec2-large-xlsr-53-icelandic-ep10-1000h/blob/main/vocab.json
+ MODEL_PATH="/work/caitlinr/w2vrec/wav2vec2-large-xlsr-53-icelandic-ep10-1000h"
+ model_blank_token = '[PAD]' # important to know for CTC decoding
+ model_word_separator = '|'
+ labels_dict = {"f": 0, "a": 1, "é": 2, "t": 3, "o": 4, "n": 5, "e": 6, "y": 8, "k": 9, "j": 10, "u": 11, "d": 12, "w": 13, "l": 14, "ú": 15, "q": 16, "g": 17, "í": 18, "s": 19, "r": 20, "ý": 21, "i": 22, "z": 23, "m": 24, "h": 25, "ó": 26, "þ": 27, "æ": 28, "c": 29, "á": 30, "v": 31, "b": 32, "ð": 33, "x": 34, "ö": 35, "p": 36, "|": 7, "[UNK]": 37, "[PAD]": 38}
+
+ model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH).to(device)
+ processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
+ inverse_dict = {v:k for k,v in labels_dict.items()}
+ all_labels = tuple(labels_dict.keys())
+ blank_id = labels_dict[model_blank_token]
+
+
+ #------------------------------------------
+ # forced alignment with ctc decoder
+ # originally based on implementation of
+ # https://pytorch.org/audio/main/tutorials/forced_alignment_tutorial.html
+ #------------------------------------------
+
+ # return the label class log-probability of each audio frame
+ def get_frame_probs(wav_path):
+     wav = readwav(wav_path)
+     with torch.inference_mode():  # similar to torch.no_grad()
+         input_values = processor(wav,sampling_rate=16000).input_values[0]
+         input_values = torch.tensor(input_values, device=device).unsqueeze(0)
+         emits = model(input_values).logits
+         emits = torch.log_softmax(emits, dim=-1)
+     emit = emits[0].cpu().detach()
+     return emit
+
+
+ def get_trellis(emission, tokens, blank_id):
+     num_frame = emission.size(0)
+     num_tokens = len(tokens)
+     # Trellis has extra dimensions for both time axis and tokens.
+     # The extra dim for tokens represents <SoS> (start-of-sentence)
+     # The extra dim for time axis is for simplification of the code.
+     trellis = torch.empty((num_frame + 1, num_tokens + 1))
+     trellis[0, 0] = 0
+     # use blank_id here: in this vocab the blank token [PAD] is index 38, not 0
+     trellis[1:, 0] = torch.cumsum(emission[:, blank_id], 0)  # len of this slice of trellis is len of audio frames
+     trellis[0, -num_tokens:] = -float("inf")                 # len of this slice of trellis is len of transcript tokens
+     trellis[-num_tokens:, 0] = float("inf")
+     for t in range(num_frame):
+         trellis[t + 1, 1:] = torch.max(
+             # Score for staying at the same token
+             trellis[t, 1:] + emission[t, blank_id],
+             # Score for changing to the next token
+             trellis[t, :-1] + emission[t, tokens],
+         )
+     return trellis
+
+
+ def backtrack(trellis, emission, tokens, blank_id):
+     # Note:
+     # j and t are indices for trellis, which has extra dimensions
+     # for time and tokens at the beginning.
+     # When referring to time frame index `T` in trellis,
+     # the corresponding index in emission is `T-1`.
+     # Similarly, when referring to token index `J` in trellis,
+     # the corresponding index in transcript is `J-1`.
+     j = trellis.size(1) - 1
+     t_start = torch.argmax(trellis[:, j]).item()
+
+     path = []
+     for t in range(t_start, 0, -1):
+         # 1. Figure out if the current position was stay or change.
+         # Note (again):
+         # `emission[t-1]` is the emission at time frame `t` of the trellis.
+         # Score for the token staying the same from time frame t-1 to t.
+         stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
+         # Score for changing from token j-1 at t-1 to token j at t.
+         changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
+
+         # 2. Store the path with frame-wise probability.
+         # (use blank_id rather than 0: the blank token [PAD] is not at index 0 in this vocab)
+         prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()
+         # Store token index and time index in non-trellis coordinates.
+         path.append((j - 1, t - 1, prob))
+
+         # 3. Update the token.
+         if changed > stayed:
+             j -= 1
+             if j == 0:
+                 break
+     else:
+         raise ValueError("Failed to align")
+     return path[::-1]
+
+
+ def merge_repeats(path,transcript):
+     i1, i2 = 0, 0
+     segments = []
+     while i1 < len(path):
+         while i2 < len(path) and path[i1][0] == path[i2][0]: # while both path steps point to the same token index
+             i2 += 1
+         # when i2 finally switches to a different token, append to the list of segments:
+         segments.append(
+             (transcript[path[i1][0]],  # the token from i1,
+              path[i1][1],              # the time of the first path-point of that token,
+              path[i2 - 1][1] + 1,      # and the time of the final path-point for that token.
+             )
+         )
+         i1 = i2
+     return segments
+
+
+ def merge_words(segments, separator):
+     words = []
+     i1, i2 = 0, 0
+     while i1 < len(segments):
+         if i2 >= len(segments) or segments[i2][0] == separator:
+             if i1 != i2:
+                 segs = segments[i1:i2]
+                 word = "".join([seg[0] for seg in segs])
+                 words.append((word, segments[i1][1], segments[i2 - 1][2]))
+             i1 = i2 + 1
+             i2 = i1
+         else:
+             i2 += 1
+     return words
+
+
+ #------------------------------------------
+ # handle in/out/etc.
+ #------------------------------------------
+
+ def readwav(wav_path):
+     wav, sr = sf.read(wav_path, dtype=np.float32)
+     if len(wav.shape) == 2:
+         wav = wav.mean(1)                 # downmix stereo to mono
+     if sr != 16000:
+         wlen = int(wav.shape[0] / sr * 16000)
+         wav = signal.resample(wav, wlen)  # resample to 16 kHz for wav2vec2
+     return wav
+
+
+ # convert frame numbers to timestamps in seconds
+ # w2v2 step size is about 20ms, or 50 frames per second
+ def f2s(fr):
+     return fr/50
+
+ def fmt(frame_aligns):
+     return [(label,f2s(start),f2s(end)) for label,start,end in frame_aligns]
+
+
+ # prepare the input transcript text string
+ # TODO:
+ # handle input strings that still have punctuation,
+ # or that have characters not present in labels_dict
+ def prep_transcript(xcp):
+     xcp = xcp.lower()
+     while '  ' in xcp:                # collapse runs of spaces
+         xcp = xcp.replace('  ', ' ')
+     xcp = xcp.replace(' ',model_word_separator)
+     label_ids = [labels_dict[c] for c in xcp]
+     return xcp, label_ids
+
+
+ def ctcalign(wav_path,transcript_string):
+     norm_txt, rec_label_ids = prep_transcript(transcript_string)
+     emit = get_frame_probs(wav_path)
+     trellis = get_trellis(emit, rec_label_ids, blank_id)
+     path = backtrack(trellis, emit, rec_label_ids, blank_id)
+     segments = merge_repeats(path,norm_txt)
+     words = merge_words(segments, model_word_separator)
+
+     #segments = [s for s in segments if s[0] != model_word_separator]
+     return fmt(segments), fmt(words)
+
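For reference, a minimal usage sketch of the module above; the wav path and transcript are placeholders rather than files from this repo, and the transcript must contain only characters present in labels_dict (see the TODO in prep_transcript). ctcalign() returns two lists of (label, start_sec, end_sec) tuples, one per character segment and one per word:

# hypothetical usage; "example.wav" and the transcript are placeholders
from ctcalign import ctcalign

segments, words = ctcalign("example.wav", "hver var að hlæja")
for label, start, end in words:
    print(f"{label}\t{start:.2f}\t{end:.2f}")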