csukuangfj committed
Commit c58376d
1 Parent(s): 9be5254

first working version

Files changed (3)
  1. app.py +296 -0
  2. model.py +152 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,296 @@
+ #!/usr/bin/env python3
+ #
+ # Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+ #
+ # See LICENSE for clarification regarding multiple authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # References:
+ # https://gradio.app/docs/#dropdown
+
+ import logging
+ import os
+ import time
+ import uuid
+ from datetime import datetime
+
+ import gradio as gr
+ import torch
+
+ from model import (
+     embedding2models,
+     get_speaker_diarization,
+     read_wave,
+     speaker_segmentation_models,
+ )
+
+ embedding_frameworks = list(embedding2models.keys())
+
+
+ def MyPrint(s):
+     now = datetime.now()
+     date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
+     print(f"{date_time}: {s}")
+
+
+ def convert_to_wav(in_filename: str) -> str:
+     """Convert the input audio file to a wave file"""
+     out_filename = str(uuid.uuid4())
+     out_filename = f"{out_filename}.wav"
+
+     MyPrint(f"Converting '{in_filename}' to '{out_filename}'")
+     _ = os.system(
+         f"ffmpeg -hide_banner -loglevel error -i '{in_filename}' -ar 16000 -ac 1 '{out_filename}' -y"
+     )
+
+     return out_filename
+
+
+ def build_html_output(s: str, style: str = "result_item_success"):
+     return f"""
+     <div class='result'>
+         <div class='result_item {style}'>
+           {s}
+         </div>
+     </div>
+     """
+
+
+ def process_uploaded_file(
+     embedding_framework: str,
+     embedding_model: str,
+     speaker_segmentation_model: str,
+     input_num_speakers: str,
+     input_threshold: str,
+     in_filename: str,
+ ):
+     if in_filename is None or in_filename == "":
+         return "", build_html_output(
+             "Please first upload a file and then click "
+             'the button "Submit for speaker diarization"',
+             "result_item_error",
+         )
+
+     try:
+         input_num_speakers = int(input_num_speakers)
+     except ValueError:
+         return "", build_html_output(
+             "Please set a valid number of speakers",
+             "result_item_error",
+         )
+
+     try:
+         input_threshold = float(input_threshold)
+         if input_threshold < 0 or input_threshold > 10:
+             raise ValueError("")
+     except ValueError:
+         return "", build_html_output(
+             "Please set a valid threshold in the range (0, 10)",
+             "result_item_error",
+         )
+
+     MyPrint(f"Processing uploaded file: {in_filename}")
+     try:
+         return process(
+             in_filename=in_filename,
+             embedding_framework=embedding_framework,
+             embedding_model=embedding_model,
+             speaker_segmentation_model=speaker_segmentation_model,
+             input_num_speakers=input_num_speakers,
+             input_threshold=input_threshold,
+         )
+     except Exception as e:
+         MyPrint(str(e))
+         return "", build_html_output(str(e), "result_item_error")
+
+
+ @torch.no_grad()
+ def process(
+     embedding_framework: str,
+     embedding_model: str,
+     speaker_segmentation_model: str,
+     input_num_speakers: str,
+     input_threshold: str,
+     in_filename: str,
+ ):
+     MyPrint(f"embedding_framework: {embedding_framework}")
+     MyPrint(f"embedding_model: {embedding_model}")
+     MyPrint(f"speaker_segmentation_model: {speaker_segmentation_model}")
+     MyPrint(f"input_num_speakers: {input_num_speakers}")
+     MyPrint(f"input_threshold: {input_threshold}")
+     MyPrint(f"in_filename: {in_filename}")
+
+     filename = convert_to_wav(in_filename)
+
+     now = datetime.now()
+     date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
+     MyPrint(f"Started at {date_time}")
+
+     start = time.time()
+
+     sd = get_speaker_diarization(
+         segmentation_model=speaker_segmentation_model,
+         embedding_model=embedding_model,
+         num_clusters=input_num_speakers,
+         threshold=input_threshold,
+     )
+
+     audio = read_wave(filename)[0]
+     segments = sd.process(audio).sort_by_start_time()
+     s = ""
+     for seg in segments:
+         s += f"{seg.start:.3f} -- {seg.end:.3f} speaker_{seg.speaker:02d}\n"
+
+     date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
+     end = time.time()
+
+     duration = audio.shape[0] / sd.sample_rate
+     rtf = (end - start) / duration
+
+     MyPrint(f"Finished at {date_time}. Elapsed: {end - start: .3f} s")
+
+     info = f"""
+     Wave duration : {duration: .3f} s <br/>
+     Processing time: {end - start: .3f} s <br/>
+     RTF: {end - start: .3f}/{duration: .3f} = {rtf:.3f} <br/>
+     """
+     if rtf > 1:
+         info += (
+             "<br/>We are loading the model for the first run. "
+             "Please run again to measure the real RTF.<br/>"
+         )
+
+     MyPrint(info)
+     MyPrint(f"\nembedding_model: {embedding_model}\nSegments: {s}")
+
+     return s, build_html_output(info)
+
+
+ title = "# Speaker diarization with Next-gen Kaldi"
+ description = """
+ This space shows how to do speaker diarization with Next-gen Kaldi.
+
+ It is running on CPU within a docker container provided by Hugging Face.
+
+ For more information, please visit
+ <https://k2-fsa.github.io/sherpa/onnx/speaker-diarization/index.html>
+
+ If you want to try it on Android, please download pre-built Android
+ APKs for speaker diarization from
+ <https://k2-fsa.github.io/sherpa/onnx/speaker-diarization/android.html>
+ """
+
+ # css style is copied from
+ # https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
+ css = """
+ .result {display:flex;flex-direction:column}
+ .result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
+ .result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
+ .result_item_error {background-color:#ff7070;color:white;align-self:start}
+ """
+
+
+ def update_embedding_model_dropdown(framework: str):
+     if framework in embedding2models:
+         choices = embedding2models[framework]
+         return gr.Dropdown(
+             choices=choices,
+             value=choices[0],
+             interactive=True,
+         )
+
+     raise ValueError(f"Unsupported framework: {framework}")
+
+
+ demo = gr.Blocks(css=css)
+
+
+ with demo:
+     gr.Markdown(title)
+
+     embedding_framework_choices = list(embedding2models.keys())
+     embedding_framework_radio = gr.Radio(
+         label="Speaker embedding frameworks",
+         choices=embedding_framework_choices,
+         value=embedding_framework_choices[0],
+     )
+
+     embedding_model_dropdown = gr.Dropdown(
+         choices=embedding2models[embedding_framework_choices[0]],
+         label="Select a speaker embedding model",
+         value=embedding2models[embedding_framework_choices[0]][0],
+     )
+
+     embedding_framework_radio.change(
+         update_embedding_model_dropdown,
+         inputs=embedding_framework_radio,
+         outputs=embedding_model_dropdown,
+     )
+
+     speaker_segmentation_model_dropdown = gr.Dropdown(
+         choices=speaker_segmentation_models,
+         label="Select a speaker segmentation model",
+         value=speaker_segmentation_models[0],
+     )
+
+     input_num_speakers = gr.Textbox(
+         label="Number of speakers",
+         info="Number of speakers in the test file; use 0 if it is unknown",
+         lines=1,
+         max_lines=1,
+         value="0",
+         placeholder="Specify number of speakers in the test file",
+     )
+
+     input_threshold = gr.Textbox(
+         label="Clustering threshold",
+         info="Threshold for clustering",
+         lines=1,
+         max_lines=1,
+         value="0.5",
+         placeholder="Threshold for clustering",
+     )
+
+     with gr.Tabs():
+         with gr.TabItem("Upload from disk"):
+             uploaded_file = gr.Audio(
+                 sources=["upload"],  # Choose between "microphone", "upload"
+                 type="filepath",
+                 label="Upload from disk",
+             )
+             upload_button = gr.Button("Submit for speaker diarization")
+             uploaded_output = gr.Textbox(label="Result from uploaded file")
+             uploaded_html_info = gr.HTML(label="Info")
+
+             upload_button.click(
+                 process_uploaded_file,
+                 inputs=[
+                     embedding_framework_radio,
+                     embedding_model_dropdown,
+                     speaker_segmentation_model_dropdown,
+                     input_num_speakers,
+                     input_threshold,
+                     uploaded_file,
+                 ],
+                 outputs=[uploaded_output, uploaded_html_info],
+             )
+
+     gr.Markdown(description)
+
+ if __name__ == "__main__":
+     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+     logging.basicConfig(format=formatter, level=logging.INFO)
+
+     demo.launch()
model.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+ #
+ # See LICENSE for clarification regarding multiple authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import wave
+ from typing import Tuple
+
+ import numpy as np
+ import sherpa_onnx
+ from huggingface_hub import hf_hub_download
+
+
+ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
+     """
+     Args:
+       wave_filename:
+         Path to a wave file. It should be single channel and each sample should
+         be 16-bit. Its sample rate does not need to be 16kHz.
+     Returns:
+       Return a tuple containing:
+        - A 1-D array of dtype np.float32 containing the samples, which are
+          normalized to the range [-1, 1].
+        - sample rate of the wave file
+     """
+
+     with wave.open(wave_filename) as f:
+         assert f.getnchannels() == 1, f.getnchannels()
+         assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+         num_samples = f.getnframes()
+         samples = f.readframes(num_samples)
+         samples_int16 = np.frombuffer(samples, dtype=np.int16)
+         samples_float32 = samples_int16.astype(np.float32)
+
+         samples_float32 = samples_float32 / 32768
+         return samples_float32, f.getframerate()
+
+
+ def _get_nn_model_filename(
+     repo_id: str,
+     filename: str,
+     subfolder: str = ".",
+ ) -> str:
+     nn_model_filename = hf_hub_download(
+         repo_id=repo_id,
+         filename=filename,
+         subfolder=subfolder,
+     )
+     return nn_model_filename
+
+
+ def get_speaker_segmentation_model(repo_id) -> str:
+     assert repo_id in ("pyannote/segmentation-3.0",)
+
+     if repo_id == "pyannote/segmentation-3.0":
+         return _get_nn_model_filename(
+             repo_id="csukuangfj/sherpa-onnx-pyannote-segmentation-3-0",
+             filename="model.onnx",
+         )
+
+
+ def get_speaker_embedding_model(model_name) -> str:
+     assert (
+         model_name
+         in three_d_speaker_embedding_models
+         + nemo_speaker_embedding_models
+         + wespeaker_embedding_models
+     )
+
+     return _get_nn_model_filename(
+         repo_id="csukuangfj/speaker-embedding-models",
+         filename=model_name,
+     )
+
+
+ def get_speaker_diarization(
+     segmentation_model: str, embedding_model: str, num_clusters: int, threshold: float
+ ):
+     segmentation = get_speaker_segmentation_model(segmentation_model)
+     embedding = get_speaker_embedding_model(embedding_model)
+
+     config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
+         segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
+             pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
+                 model=segmentation
+             ),
+         ),
+         embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig(model=embedding),
+         clustering=sherpa_onnx.FastClusteringConfig(
+             num_clusters=num_clusters,
+             threshold=threshold,
+         ),
+         min_duration_on=0.3,
+         min_duration_off=0.5,
+     )
+     if not config.validate():
+         raise RuntimeError(
+             "Please check your config and make sure all required files exist"
+         )
+
+     # The config is valid, i.e., all required model files exist
+     return sherpa_onnx.OfflineSpeakerDiarization(config)
+
+
+ speaker_segmentation_models = ["pyannote/segmentation-3.0"]
+
+
+ nemo_speaker_embedding_models = [
+     "nemo_en_speakerverification_speakernet.onnx",
+     "nemo_en_titanet_large.onnx",
+     "nemo_en_titanet_small.onnx",
+ ]
+
+ three_d_speaker_embedding_models = [
+     "3dspeaker_speech_campplus_sv_en_voxceleb_16k.onnx",
+     "3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx",
+     "3dspeaker_speech_campplus_sv_zh_en_16k-common_advanced.onnx",
+     "3dspeaker_speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx",
+     "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx",
+     "3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx",
+     "3dspeaker_speech_eres2net_sv_en_voxceleb_16k.onnx",
+     "3dspeaker_speech_eres2net_sv_zh-cn_16k-common.onnx",
+     "3dspeaker_speech_eres2netv2_sv_zh-cn_16k-common.onnx",
+ ]
+ wespeaker_embedding_models = [
+     "wespeaker_en_voxceleb_CAM++.onnx",
+     "wespeaker_en_voxceleb_CAM++_LM.onnx",
+     "wespeaker_en_voxceleb_resnet152_LM.onnx",
+     "wespeaker_en_voxceleb_resnet221_LM.onnx",
+     "wespeaker_en_voxceleb_resnet293_LM.onnx",
+     "wespeaker_en_voxceleb_resnet34.onnx",
+     "wespeaker_en_voxceleb_resnet34_LM.onnx",
+     "wespeaker_zh_cnceleb_resnet34.onnx",
+     "wespeaker_zh_cnceleb_resnet34_LM.onnx",
+ ]
+
+ embedding2models = {
+     "3D-Speaker": three_d_speaker_embedding_models,
+     "NeMo": nemo_speaker_embedding_models,
+     "WeSpeaker": wespeaker_embedding_models,
+ }
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ huggingface_hub
+
+ #https://huggingface.co/csukuangfj/sherpa-onnx-wheels/resolve/main/sherpa_onnx-1.9.26-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
+
+ sherpa-onnx>=1.10.28
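
For quick local testing outside the Gradio UI, the helpers added in model.py can also be driven directly. The following is a minimal sketch, not part of this commit; it assumes model.py is importable and that "test.wav" is a hypothetical 16 kHz, 16-bit, single-channel recording of two speakers. It reuses the same calls that app.py makes, and the first run downloads the segmentation and embedding models from Hugging Face via hf_hub_download, just as the Space does.

    # Minimal sketch (not part of this commit); "test.wav" is a hypothetical
    # 16 kHz, 16-bit, mono wave file containing speech from two speakers.
    from model import get_speaker_diarization, read_wave

    sd = get_speaker_diarization(
        segmentation_model="pyannote/segmentation-3.0",
        embedding_model="3dspeaker_speech_eres2net_sv_en_voxceleb_16k.onnx",
        num_clusters=2,  # known number of speakers in the file
        threshold=0.5,   # clustering threshold
    )

    samples, sample_rate = read_wave("test.wav")
    assert sample_rate == sd.sample_rate, (sample_rate, sd.sample_rate)

    # Same call chain as in app.py: run diarization and print the segments
    # sorted by start time.
    for seg in sd.process(samples).sort_by_start_time():
        print(f"{seg.start:.3f} -- {seg.end:.3f} speaker_{seg.speaker:02d}")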