Update ctcalign.py
Browse files- ctcalign.py +38 -14
ctcalign.py
CHANGED
@@ -3,6 +3,7 @@ import torch, torchaudio
|
|
3 |
import soundfile as sf
|
4 |
import numpy as np
|
5 |
from scipy import signal
|
|
|
6 |
|
7 |
#------------------------------------------
|
8 |
# setup wav2vec2
|
@@ -96,6 +97,29 @@ def get_trellis(emission, tokens, blank_id):
|
|
96 |
return trellis
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def backtrack(trellis, emission, tokens, blank_id):
|
100 |
# Note:
|
101 |
# j and t are indices for trellis, which has extra dimensions
|
@@ -106,11 +130,10 @@ def backtrack(trellis, emission, tokens, blank_id):
|
|
106 |
# the corresponding index in transcript is `J-1`.
|
107 |
j = trellis.size(1) - 1
|
108 |
t_start = torch.argmax(trellis[:, j]).item()
|
109 |
-
|
110 |
path = []
|
111 |
for t in range(t_start, 0, -1):
|
112 |
# 1. Figure out if the current position was stay or change
|
113 |
-
# Note (again):
|
114 |
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
|
115 |
# Score for token staying the same from time frame J-1 to T.
|
116 |
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
|
@@ -120,7 +143,7 @@ def backtrack(trellis, emission, tokens, blank_id):
|
|
120 |
# 2. Store the path with frame-wise probability.
|
121 |
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
|
122 |
# Return token index and time index in non-trellis coordinate.
|
123 |
-
path.append((j - 1, t - 1, prob))
|
124 |
|
125 |
# 3. Update the token
|
126 |
if changed > stayed:
|
@@ -132,32 +155,35 @@ def backtrack(trellis, emission, tokens, blank_id):
|
|
132 |
return path[::-1]
|
133 |
|
134 |
|
135 |
-
|
136 |
def merge_repeats(path,transcript):
|
137 |
i1, i2 = 0, 0
|
138 |
segments = []
|
139 |
while i1 < len(path):
|
140 |
-
while i2 < len(path) and path[i1]
|
141 |
i2 += 1
|
|
|
142 |
segments.append( # when i2 finally switches to a different token,
|
143 |
-
|
144 |
-
|
145 |
-
path[i1]
|
146 |
-
path[i2 - 1]
|
|
|
147 |
)
|
148 |
)
|
149 |
i1 = i2
|
150 |
return segments
|
151 |
|
|
|
|
|
152 |
def merge_words(segments, separator):
|
153 |
words = []
|
154 |
i1, i2 = 0, 0
|
155 |
while i1 < len(segments):
|
156 |
-
if i2 >= len(segments) or segments[i2]
|
157 |
if i1 != i2:
|
158 |
segs = segments[i1:i2]
|
159 |
-
word = "".join([seg
|
160 |
-
words.append((word, segments[i1]
|
161 |
i1 = i2 + 1
|
162 |
i2 = i1
|
163 |
else:
|
@@ -165,8 +191,6 @@ def merge_words(segments, separator):
|
|
165 |
return words
|
166 |
|
167 |
|
168 |
-
|
169 |
-
|
170 |
#------------------------------------------
|
171 |
# handle in/out/etc.
|
172 |
#------------------------------------------
|
|
|
3 |
import soundfile as sf
|
4 |
import numpy as np
|
5 |
from scipy import signal
|
6 |
+
from dataclasses import dataclass
|
7 |
|
8 |
#------------------------------------------
|
9 |
# setup wav2vec2
|
|
|
97 |
return trellis
|
98 |
|
99 |
|
100 |
+
|
101 |
+
@dataclass
class Point:
    """One frame-level point on the backtracked CTC alignment path."""
    token_index: int  # index of the token within the transcript
    time_index: int   # emission frame index (non-trellis coordinate)
    score: float      # frame-wise posterior probability for this token
|
106 |
+
|
107 |
+
@dataclass
class Segment:
    """A run of identical tokens on the alignment path.

    Covers the half-open frame interval [start, end) with an averaged
    frame-wise score.
    """
    label: str    # the transcript character this segment represents
    start: int    # first emission frame of the run (inclusive)
    end: int      # one past the last emission frame (exclusive)
    score: float  # mean frame-wise probability across the run

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        # Frame count spanned by this segment.
        return self.end - self.start
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
def backtrack(trellis, emission, tokens, blank_id):
|
124 |
# Note:
|
125 |
# j and t are indices for trellis, which has extra dimensions
|
|
|
130 |
# the corresponding index in transcript is `J-1`.
|
131 |
j = trellis.size(1) - 1
|
132 |
t_start = torch.argmax(trellis[:, j]).item()
|
133 |
+
|
134 |
path = []
|
135 |
for t in range(t_start, 0, -1):
|
136 |
# 1. Figure out if the current position was stay or change
|
|
|
137 |
# `emission[J-1]` is the emission at time frame `J` of trellis dimension.
|
138 |
# Score for token staying the same from time frame J-1 to T.
|
139 |
stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
|
|
|
143 |
# 2. Store the path with frame-wise probability.
|
144 |
prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
|
145 |
# Return token index and time index in non-trellis coordinate.
|
146 |
+
path.append(Point(j - 1, t - 1, prob))
|
147 |
|
148 |
# 3. Update the token
|
149 |
if changed > stayed:
|
|
|
155 |
return path[::-1]
|
156 |
|
157 |
|
|
|
158 |
def merge_repeats(path, transcript):
    """Collapse consecutive path points with the same token into Segments.

    path: list of Point objects in ascending time order, as produced by
        backtrack().
    transcript: the token sequence the path was aligned against; indexed
        by Point.token_index to recover each segment's label.
    Returns a list of Segment, one per maximal run of equal token indices,
    each scored with the mean frame-wise probability over its run.
    """
    segments = []
    start = 0
    while start < len(path):
        end = start
        # Advance `end` past every point that still refers to the same token.
        while end < len(path) and path[end].token_index == path[start].token_index:
            end += 1
        run = path[start:end]
        avg_score = sum(pt.score for pt in run) / len(run)
        segments.append(
            Segment(
                transcript[path[start].token_index],  # label for this run
                path[start].time_index,               # first frame of the run
                path[end - 1].time_index + 1,         # one past the last frame
                avg_score,
            )
        )
        start = end
    return segments
|
175 |
|
176 |
+
|
177 |
+
|
178 |
def merge_words(segments, separator):
|
179 |
words = []
|
180 |
i1, i2 = 0, 0
|
181 |
while i1 < len(segments):
|
182 |
+
if i2 >= len(segments) or segments[i2].label == separator:
|
183 |
if i1 != i2:
|
184 |
segs = segments[i1:i2]
|
185 |
+
word = "".join([seg.label for seg in segs])
|
186 |
+
words.append((word, segments[i1].start, segments[i2 - 1].end))
|
187 |
i1 = i2 + 1
|
188 |
i2 = i1
|
189 |
else:
|
|
|
191 |
return words
|
192 |
|
193 |
|
|
|
|
|
194 |
#------------------------------------------
|
195 |
# handle in/out/etc.
|
196 |
#------------------------------------------
|