clr committed on
Commit
60edeed
1 Parent(s): dea339f

Update ctcalign.py

Browse files
Files changed (1) hide show
  1. ctcalign.py +38 -14
ctcalign.py CHANGED
@@ -3,6 +3,7 @@ import torch, torchaudio
3
  import soundfile as sf
4
  import numpy as np
5
  from scipy import signal
 
6
 
7
  #------------------------------------------
8
  # setup wav2vec2
@@ -96,6 +97,29 @@ def get_trellis(emission, tokens, blank_id):
96
  return trellis
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  def backtrack(trellis, emission, tokens, blank_id):
100
  # Note:
101
  # j and t are indices for trellis, which has extra dimensions
@@ -106,11 +130,10 @@ def backtrack(trellis, emission, tokens, blank_id):
106
  # the corresponding index in transcript is `J-1`.
107
  j = trellis.size(1) - 1
108
  t_start = torch.argmax(trellis[:, j]).item()
109
-
110
  path = []
111
  for t in range(t_start, 0, -1):
112
  # 1. Figure out if the current position was stay or change
113
- # Note (again):
114
  # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
115
  # Score for token staying the same from time frame J-1 to T.
116
  stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
@@ -120,7 +143,7 @@ def backtrack(trellis, emission, tokens, blank_id):
120
  # 2. Store the path with frame-wise probability.
121
  prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
122
  # Return token index and time index in non-trellis coordinate.
123
- path.append((j - 1, t - 1, prob))
124
 
125
  # 3. Update the token
126
  if changed > stayed:
@@ -132,32 +155,35 @@ def backtrack(trellis, emission, tokens, blank_id):
132
  return path[::-1]
133
 
134
 
135
-
136
  def merge_repeats(path,transcript):
137
  i1, i2 = 0, 0
138
  segments = []
139
  while i1 < len(path):
140
- while i2 < len(path) and path[i1][0] == path[i2][0]: # while both path steps point to the same token index
141
  i2 += 1
 
142
  segments.append( # when i2 finally switches to a different token,
143
- #Segment(
144
- (transcript[path[i1][0]], # to the list of segments, append the token from i1
145
- path[i1][1], # time of the first path-point of that token
146
- path[i2 - 1][1] + 1, # time of the final path-point for that token.
 
147
  )
148
  )
149
  i1 = i2
150
  return segments
151
 
 
 
152
  def merge_words(segments, separator):
153
  words = []
154
  i1, i2 = 0, 0
155
  while i1 < len(segments):
156
- if i2 >= len(segments) or segments[i2][0] == separator:
157
  if i1 != i2:
158
  segs = segments[i1:i2]
159
- word = "".join([seg[0] for seg in segs])
160
- words.append((word, segments[i1][1], segments[i2 - 1][2]))
161
  i1 = i2 + 1
162
  i2 = i1
163
  else:
@@ -165,8 +191,6 @@ def merge_words(segments, separator):
165
  return words
166
 
167
 
168
-
169
-
170
  #------------------------------------------
171
  # handle in/out/etc.
172
  #------------------------------------------
 
3
  import soundfile as sf
4
  import numpy as np
5
  from scipy import signal
6
+ from dataclasses import dataclass
7
 
8
  #------------------------------------------
9
  # setup wav2vec2
 
97
  return trellis
98
 
99
 
100
+
101
@dataclass
class Point:
    """One frame of the CTC alignment path, in non-trellis coordinates."""
    token_index: int  # position of the aligned token within the transcript
    time_index: int   # emission frame index this token occupies
    score: float      # frame-wise probability of the token at this frame
106
+
107
@dataclass
class Segment:
    """A contiguous span of emission frames aligned to one transcript label."""
    label: str    # the transcript token/character this segment covers
    start: int    # first frame index (inclusive)
    end: int      # one past the last frame index (exclusive)
    score: float  # average frame-wise probability over the span

    def __repr__(self):
        # Half-open interval notation [start, end) mirrors the exclusive `end`.
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        # Number of frames covered; `end` is exclusive, so no +1 needed.
        return self.end - self.start
120
+
121
+
122
+
123
  def backtrack(trellis, emission, tokens, blank_id):
124
  # Note:
125
  # j and t are indices for trellis, which has extra dimensions
 
130
  # the corresponding index in transcript is `J-1`.
131
  j = trellis.size(1) - 1
132
  t_start = torch.argmax(trellis[:, j]).item()
133
+
134
  path = []
135
  for t in range(t_start, 0, -1):
136
  # 1. Figure out if the current position was stay or change
 
137
  # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
138
  # Score for token staying the same from time frame J-1 to T.
139
  stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
 
143
  # 2. Store the path with frame-wise probability.
144
  prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
145
  # Return token index and time index in non-trellis coordinate.
146
+ path.append(Point(j - 1, t - 1, prob))
147
 
148
  # 3. Update the token
149
  if changed > stayed:
 
155
  return path[::-1]
156
 
157
 
 
158
def merge_repeats(path,transcript):
    """Collapse runs of consecutive path points that share a token index.

    Each run becomes one Segment spanning from the first frame of the run
    to one past its last frame, scored by the mean frame-wise probability
    of the points in the run.
    """
    segments = []
    start = 0
    while start < len(path):
        # Advance `end` past every point aligned to the same token as `start`.
        end = start
        while end < len(path) and path[end].token_index == path[start].token_index:
            end += 1
        run = path[start:end]
        avg_score = sum(p.score for p in run) / len(run)
        segments.append(
            Segment(
                transcript[path[start].token_index],  # the token's text label
                run[0].time_index,                    # first frame of the run
                run[-1].time_index + 1,               # one past the last frame
                avg_score,
            )
        )
        start = end
    return segments
175
 
176
+
177
+
178
  def merge_words(segments, separator):
179
  words = []
180
  i1, i2 = 0, 0
181
  while i1 < len(segments):
182
+ if i2 >= len(segments) or segments[i2].label == separator:
183
  if i1 != i2:
184
  segs = segments[i1:i2]
185
+ word = "".join([seg.label for seg in segs])
186
+ words.append((word, segments[i1].start, segments[i2 - 1].end))
187
  i1 = i2 + 1
188
  i2 = i1
189
  else:
 
191
  return words
192
 
193
 
 
 
194
  #------------------------------------------
195
  # handle in/out/etc.
196
  #------------------------------------------