HoneyTian committed on
Commit ee326f0
1 Parent(s): d9015be

add examples

Files changed (4)
  1. main.py +9 -285
  2. tabs/__init__.py +6 -0
  3. tabs/shell_tab.py +30 -0
  4. tabs/vad_tab.py +297 -0
main.py CHANGED
@@ -1,34 +1,15 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import argparse
-from functools import lru_cache
-import json
 import logging
-from pathlib import Path
 import platform
-import shutil
-import tempfile
-import time
-from typing import Dict, Tuple
-import uuid
-import zipfile
 
 import gradio as gr
-import librosa
-from huggingface_hub import snapshot_download
-import matplotlib.pyplot as plt
-import numpy as np
-from scipy.io import wavfile
 
 import log
 from project_settings import environment, project_path, log_directory, time_zone_info
-from toolbox.os.command import Command
-from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
-from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
-from toolbox.torchaudio.models.vad.native_silero_vad.inference_native_silero_vad_onnx import InferenceNativeSileroVadOnnx
-from toolbox.torchaudio.utils.visualization import process_speech_probs
-from toolbox.vad.utils import PostProcess
-from toolbox.pydub.volume import get_volume
+from tabs.vad_tab import get_vad_tab
+from tabs.shell_tab import get_shell_tab
 
 log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
 
@@ -68,277 +49,20 @@ def get_args():
     return args
 
 
-def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
-    if signal.dtype != np.int16:
-        raise AssertionError(f"only support dtype np.int16, however: {signal.dtype}")
-    temp_audio_dir = Path(tempfile.gettempdir()) / "input_audio"
-    temp_audio_dir.mkdir(parents=True, exist_ok=True)
-    filename = temp_audio_dir / f"{uuid.uuid4()}.wav"
-    filename = filename.as_posix()
-    wavfile.write(
-        filename,
-        sample_rate, signal
-    )
-    return filename
-
-
-def convert_sample_rate(signal: np.ndarray, sample_rate: int, target_sample_rate: int):
-    filename = save_input_audio(sample_rate, signal)
-
-    signal, _ = librosa.load(filename, sr=target_sample_rate)
-    signal = np.array(signal * (1 << 15), dtype=np.int16)
-    return signal
-
-
-def shell(cmd: str):
-    return Command.popen(cmd)
-
-
-def get_infer_cls_by_model_name(model_name: str):
-    if model_name.__contains__("native_silero_vad"):
-        infer_cls = InferenceNativeSileroVadOnnx
-    elif model_name.__contains__("fsmn-vad"):
-        infer_cls = InferenceFSMNVadOnnx
-    elif model_name.__contains__("silero-vad"):
-        infer_cls = InferenceSileroVad
-    else:
-        raise AssertionError
-    return infer_cls
-
-
-vad_engines: Dict[str, dict] = None
-
-
-@lru_cache(maxsize=1)
-def load_vad_model(infer_cls, **kwargs):
-    infer_engine = infer_cls(**kwargs)
-    return infer_engine
-
-
-def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: int = 8000, title: str = ""):
-    duration = np.arange(0, len(signal)) / sample_rate
-    plt.figure(figsize=(12, 5))
-    plt.plot(duration, signal, color='b')
-    plt.plot(duration, speech_probs, color='gray')
-    plt.title(title)
-
-    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
-    plt.savefig(temp_file.name, bbox_inches="tight")
-    plt.close()
-    return temp_file.name
-
-
-def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
-                          start_ring_rate: float = 0.5, end_ring_rate: float = 0.3,
-                          ring_max_length: int = 10,
-                          min_silence_length: int = 2,
-                          max_speech_length: int = 10000, min_speech_length: int = 10,
-                          engine: str = None,
-                          ):
-    if audio_file_t is None and audio_microphone_t is None:
-        raise gr.Error(f"audio file and microphone is null.")
-    if audio_file_t is not None and audio_microphone_t is not None:
-        gr.Warning(f"both audio file and microphone file is provided, audio file taking priority.")
-    audio_t: Tuple = audio_file_t or audio_microphone_t
-
-    sample_rate, signal = audio_t
-    if sample_rate != 8000:
-        signal = convert_sample_rate(signal, sample_rate, 8000)
-        sample_rate = 8000
-
-    audio_duration = signal.shape[-1] // sample_rate
-    audio = np.array(signal / (1 << 15), dtype=np.float32)
-
-    infer_engine_param = vad_engines.get(engine)
-    if infer_engine_param is None:
-        raise gr.Error(f"invalid denoise engine: {engine}.")
-
-    try:
-        infer_cls = infer_engine_param["infer_cls"]
-        kwargs = infer_engine_param["kwargs"]
-        infer_engine = load_vad_model(infer_cls=infer_cls, **kwargs)
-
-        begin = time.time()
-        vad_info = infer_engine.infer(audio)
-        time_cost = time.time() - begin
-
-        probs: np.ndarray = vad_info["probs"]
-        lsnr: np.ndarray = vad_info["lsnr"]
-        # lsnr = lsnr / np.max(np.abs(lsnr))
-        lsnr = lsnr / 30
-
-        frame_step = infer_engine.config.hop_size
-
-        # post process
-        vad_post_process = PostProcess(
-            start_ring_rate=start_ring_rate,
-            end_ring_rate=end_ring_rate,
-            ring_max_length=ring_max_length,
-            min_silence_length=min_silence_length,
-            max_speech_length=max_speech_length,
-            min_speech_length=min_speech_length
-        )
-        vad_segments = vad_post_process.get_vad_segments(probs)
-        vad_flags = vad_post_process.get_vad_flags(probs, vad_segments)
-
-        # vad_image
-        vad_ = process_speech_probs(audio, vad_flags, frame_step)
-        vad_image = generate_image(audio, vad_)
-
-        # probs_image
-        probs_ = process_speech_probs(audio, probs, frame_step)
-        probs_image = generate_image(audio, probs_)
-
-        # lsnr_image
-        lsnr_ = process_speech_probs(audio, lsnr, frame_step)
-        lsnr_image = generate_image(audio, lsnr_)
-
-        # vad segment
-        vad_segments = [
-            [
-                v[0] * frame_step / sample_rate,
-                v[1] * frame_step / sample_rate
-            ] for v in vad_segments
-        ]
-
-        # volume
-        volume_map: dict = get_volume(audio, sample_rate)
-
-        # message
-        rtf = time_cost / audio_duration
-        info = {
-            "vad_segments": vad_segments,
-            "time_cost": round(time_cost, 4),
-            "duration": round(audio_duration, 4),
-            "rtf": round(rtf, 4),
-            **volume_map
-        }
-        message = json.dumps(info, ensure_ascii=False, indent=4)
-
-    except Exception as e:
-        raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.")
-
-    return vad_image, probs_image, lsnr_image, message
-
-
 def main():
     args = get_args()
 
-    examples_dir = Path(args.examples_dir)
-    trained_model_dir = Path(args.trained_model_dir)
-
-    # download models
-    if not trained_model_dir.exists():
-        trained_model_dir.mkdir(parents=True, exist_ok=True)
-        _ = snapshot_download(
-            repo_id=args.models_repo_id,
-            local_dir=trained_model_dir.as_posix(),
-            token=args.hf_token,
-        )
-
-    # engines
-    global vad_engines
-    vad_engines = {
-        filename.stem: {
-            "infer_cls": get_infer_cls_by_model_name(filename.stem),
-            "kwargs": {
-                "pretrained_model_path_or_zip_file": filename.as_posix()
-            }
-        }
-        for filename in (project_path / "trained_models").glob("*.zip")
-        if filename.name not in (
-            # "cnn-vad-by-webrtcvad-nx-dns3.zip",
-            # "fsmn-vad-by-webrtcvad-nx-dns3.zip",
-            "examples.zip",
-            "sound-2-ch32.zip",
-            "sound-3-ch32.zip",
-            "sound-4-ch32.zip",
-            "sound-8-ch32.zip",
-        )
-    }
-
-    # choices
-    vad_engine_choices = list(vad_engines.keys())
-
-    # examples
-    if not examples_dir.exists():
-        example_zip_file = trained_model_dir / "examples.zip"
-        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
-            out_root = examples_dir
-            if out_root.exists():
-                shutil.rmtree(out_root.as_posix())
-            out_root.mkdir(parents=True, exist_ok=True)
-            f_zip.extractall(path=out_root)
-
-    # examples
-    examples = list()
-    for filename in examples_dir.glob("**/*.wav"):
-        examples.append([
-            filename.as_posix(),
-            None,
-            vad_engine_choices[0],
-        ])
-
     # ui
     with gr.Blocks() as blocks:
         gr.Markdown(value="vad.")
         with gr.Tabs():
-            with gr.TabItem("vad"):
-                with gr.Row():
-                    with gr.Column(variant="panel", scale=5):
-                        with gr.Tabs():
-                            with gr.TabItem("file"):
-                                vad_audio_file = gr.Audio(label="audio")
-                            with gr.TabItem("microphone"):
-                                vad_audio_microphone = gr.Audio(sources="microphone", label="audio")
-
-                        with gr.Row():
-                            vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="start_ring_rate")
-                            vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="end_ring_rate")
-                        with gr.Row():
-                            vad_ring_max_length = gr.Number(value=10, label="ring_max_length (*10ms)")
-                            vad_min_silence_length = gr.Number(value=6, label="min_silence_length (*10ms)")
-                        with gr.Row():
-                            vad_max_speech_length = gr.Number(value=100000, label="max_speech_length (*10ms)")
-                            vad_min_speech_length = gr.Number(value=15, label="min_speech_length (*10ms)")
-                        vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
-                        vad_button = gr.Button(variant="primary")
-                    with gr.Column(variant="panel", scale=5):
-                        vad_vad_image = gr.Image(label="vad")
-                        vad_prob_image = gr.Image(label="prob")
-                        vad_lsnr_image = gr.Image(label="lsnr")
-                        vad_message = gr.Textbox(lines=1, max_lines=20, label="message")
-
-                vad_button.click(
-                    when_click_vad_button,
-                    inputs=[
-                        vad_audio_file, vad_audio_microphone,
-                        vad_start_ring_rate, vad_end_ring_rate,
-                        vad_ring_max_length,
-                        vad_min_silence_length,
-                        vad_max_speech_length, vad_min_speech_length,
-                        vad_engine,
-                    ],
-                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
-                )
-                gr.Examples(
-                    examples=examples,
-                    inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
-                    outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
-                    fn=when_click_vad_button,
-                    # cache_examples=True,
-                    # cache_mode="lazy",
-                )
-            with gr.TabItem("shell"):
-                shell_text = gr.Textbox(label="cmd")
-                shell_button = gr.Button("run")
-                shell_output = gr.Textbox(label="output")
-
-                shell_button.click(
-                    shell,
-                    inputs=[shell_text,],
-                    outputs=[shell_output],
-                )
+            _ = get_vad_tab(
+                trained_model_dir=args.trained_model_dir,
+                examples_dir=args.examples_dir,
+                models_repo_id=args.models_repo_id,
+                hf_token=args.hf_token,
+            )
+            _ = get_shell_tab()
 
     # http://127.0.0.1:7866/
     # http://10.75.27.247:7866/
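
For readability, this is how the tail of main() reads once the diff above is applied; it is reconstructed only from the hunks shown here, and anything beyond the visible context (e.g. the actual launch call) is unchanged and omitted:

def main():
    args = get_args()

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value="vad.")
        with gr.Tabs():
            _ = get_vad_tab(
                trained_model_dir=args.trained_model_dir,
                examples_dir=args.examples_dir,
                models_repo_id=args.models_repo_id,
                hf_token=args.hf_token,
            )
            _ = get_shell_tab()

    # http://127.0.0.1:7866/
    # http://10.75.27.247:7866/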
tabs/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
tabs/shell_tab.py ADDED
@@ -0,0 +1,30 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import gradio as gr
+
+from toolbox.os.command import Command
+
+
+def shell(cmd: str):
+    return Command.popen(cmd)
+
+
+def get_shell_tab():
+    with gr.TabItem("shell"):
+        shell_text = gr.Textbox(label="cmd")
+        shell_button = gr.Button("run")
+        shell_output = gr.Textbox(label="output", max_lines=100)
+
+        shell_button.click(
+            shell,
+            inputs=[shell_text, ],
+            outputs=[shell_output],
+        )
+
+    return locals()
+
+
+if __name__ == "__main__":
+    with gr.Blocks() as block:
+        fs_components = get_shell_tab()
+    block.launch()
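
Since get_shell_tab() returns locals(), a caller can keep wiring events against the tab's components after the tab is built. A minimal sketch, assuming the repo's modules are importable; the extra "clear" button below is a hypothetical addition, not part of this commit:

import gradio as gr

from tabs.shell_tab import get_shell_tab

with gr.Blocks() as block:
    components = get_shell_tab()
    # hypothetical extra control: empty the "output" textbox created inside the tab
    clear_button = gr.Button("clear")
    clear_button.click(lambda: "", inputs=[], outputs=[components["shell_output"]])

if __name__ == "__main__":
    block.launch()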
tabs/vad_tab.py ADDED
@@ -0,0 +1,297 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import argparse
+from collections import defaultdict
+from functools import lru_cache
+import json
+import logging
+from pathlib import Path
+import shutil
+import tempfile
+import time
+from typing import Dict, Tuple
+import uuid
+import zipfile
+
+import gradio as gr
+import librosa
+from huggingface_hub import snapshot_download
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.io import wavfile
+
+from project_settings import project_path
+from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad_onnx import InferenceFSMNVadOnnx
+from toolbox.torchaudio.models.vad.silero_vad.inference_silero_vad import InferenceSileroVad
+from toolbox.torchaudio.models.vad.native_silero_vad.inference_native_silero_vad_onnx import InferenceNativeSileroVadOnnx
+from toolbox.torchaudio.utils.visualization import process_speech_probs
+from toolbox.vad.utils import PostProcess
+from toolbox.pydub.volume import get_volume
+
+logger = logging.getLogger("main")
+
+
+def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
+    if signal.dtype != np.int16:
+        raise AssertionError(f"only support dtype np.int16, however: {signal.dtype}")
+    temp_audio_dir = Path(tempfile.gettempdir()) / "input_audio"
+    temp_audio_dir.mkdir(parents=True, exist_ok=True)
+    filename = temp_audio_dir / f"{uuid.uuid4()}.wav"
+    filename = filename.as_posix()
+    wavfile.write(
+        filename,
+        sample_rate, signal
+    )
+    return filename
+
+
+def convert_sample_rate(signal: np.ndarray, sample_rate: int, target_sample_rate: int):
+    filename = save_input_audio(sample_rate, signal)
+
+    signal, _ = librosa.load(filename, sr=target_sample_rate)
+    signal = np.array(signal * (1 << 15), dtype=np.int16)
+    return signal
+
+
+def get_infer_cls_by_model_name(model_name: str):
+    if model_name.__contains__("native_silero_vad"):
+        infer_cls = InferenceNativeSileroVadOnnx
+    elif model_name.__contains__("fsmn-vad"):
+        infer_cls = InferenceFSMNVadOnnx
+    elif model_name.__contains__("silero-vad"):
+        infer_cls = InferenceSileroVad
+    else:
+        raise AssertionError
+    return infer_cls
+
+
+vad_engines: Dict[str, dict] = None
+
+
+@lru_cache(maxsize=1)
+def load_vad_model(infer_cls, **kwargs):
+    infer_engine = infer_cls(**kwargs)
+    return infer_engine
+
+
+def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: int = 8000, title: str = ""):
+    duration = np.arange(0, len(signal)) / sample_rate
+    plt.figure(figsize=(12, 5))
+    plt.plot(duration, signal, color='b')
+    plt.plot(duration, speech_probs, color='gray')
+    plt.title(title)
+
+    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+    plt.savefig(temp_file.name, bbox_inches="tight")
+    plt.close()
+    return temp_file.name
+
+
+def when_click_vad_button(audio_file_t = None, audio_microphone_t = None,
+                          start_ring_rate: float = 0.5, end_ring_rate: float = 0.3,
+                          ring_max_length: int = 10,
+                          min_silence_length: int = 2,
+                          max_speech_length: int = 10000, min_speech_length: int = 10,
+                          engine: str = None,
+                          ):
+    if audio_file_t is None and audio_microphone_t is None:
+        raise gr.Error(f"audio file and microphone is null.")
+    if audio_file_t is not None and audio_microphone_t is not None:
+        gr.Warning(f"both audio file and microphone file is provided, audio file taking priority.")
+    audio_t: Tuple = audio_file_t or audio_microphone_t
+
+    sample_rate, signal = audio_t
+    if sample_rate != 8000:
+        signal = convert_sample_rate(signal, sample_rate, 8000)
+        sample_rate = 8000
+
+    audio_duration = signal.shape[-1] // sample_rate
+    audio = np.array(signal / (1 << 15), dtype=np.float32)
+
+    infer_engine_param = vad_engines.get(engine)
+    if infer_engine_param is None:
+        raise gr.Error(f"invalid denoise engine: {engine}.")
+
+    try:
+        infer_cls = infer_engine_param["infer_cls"]
+        kwargs = infer_engine_param["kwargs"]
+        infer_engine = load_vad_model(infer_cls=infer_cls, **kwargs)
+
+        begin = time.time()
+        vad_info = infer_engine.infer(audio)
+        time_cost = time.time() - begin
+
+        probs: np.ndarray = vad_info["probs"]
+        lsnr: np.ndarray = vad_info["lsnr"]
+        # lsnr = lsnr / np.max(np.abs(lsnr))
+        lsnr = lsnr / 30
+
+        frame_step = infer_engine.config.hop_size
+
+        # post process
+        vad_post_process = PostProcess(
+            start_ring_rate=start_ring_rate,
+            end_ring_rate=end_ring_rate,
+            ring_max_length=ring_max_length,
+            min_silence_length=min_silence_length,
+            max_speech_length=max_speech_length,
+            min_speech_length=min_speech_length
+        )
+        vad_segments = vad_post_process.get_vad_segments(probs)
+        vad_flags = vad_post_process.get_vad_flags(probs, vad_segments)
+
+        # vad_image
+        vad_ = process_speech_probs(audio, vad_flags, frame_step)
+        vad_image = generate_image(audio, vad_)
+
+        # probs_image
+        probs_ = process_speech_probs(audio, probs, frame_step)
+        probs_image = generate_image(audio, probs_)
+
+        # lsnr_image
+        lsnr_ = process_speech_probs(audio, lsnr, frame_step)
+        lsnr_image = generate_image(audio, lsnr_)
+
+        # vad segment
+        vad_segments = [
+            [
+                v[0] * frame_step / sample_rate,
+                v[1] * frame_step / sample_rate
+            ] for v in vad_segments
+        ]
+
+        # volume
+        volume_map: dict = get_volume(audio, sample_rate)
+
+        # message
+        rtf = time_cost / audio_duration
+        info = {
+            "vad_segments": vad_segments,
+            "time_cost": round(time_cost, 4),
+            "duration": round(audio_duration, 4),
+            "rtf": round(rtf, 4),
+            **volume_map
+        }
+        message = json.dumps(info, ensure_ascii=False, indent=4)
+
+    except Exception as e:
+        raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.")
+
+    return vad_image, probs_image, lsnr_image, message
+
+
+def get_vad_tab(trained_model_dir: str, examples_dir: str, models_repo_id: str, hf_token: str):
+    examples_dir = Path(examples_dir)
+    trained_model_dir = Path(trained_model_dir)
+
+    # download models
+    if not trained_model_dir.exists():
+        trained_model_dir.mkdir(parents=True, exist_ok=True)
+        _ = snapshot_download(
+            repo_id=models_repo_id,
+            local_dir=trained_model_dir.as_posix(),
+            token=hf_token,
+        )
+
+    # engines
+    global vad_engines
+    vad_engines = {
+        filename.stem: {
+            "infer_cls": get_infer_cls_by_model_name(filename.stem),
+            "kwargs": {
+                "pretrained_model_path_or_zip_file": filename.as_posix()
+            }
+        }
+        for filename in (project_path / "trained_models").glob("*.zip")
+        if filename.name not in (
+            # "cnn-vad-by-webrtcvad-nx-dns3.zip",
+            # "fsmn-vad-by-webrtcvad-nx-dns3.zip",
+            "examples.zip",
+            "sound-2-ch32.zip",
+            "sound-3-ch32.zip",
+            "sound-4-ch32.zip",
+            "sound-8-ch32.zip",
+        )
+    }
+
+    # choices
+    vad_engine_choices = list(vad_engines.keys())
+
+    # examples
+    if not examples_dir.exists():
+        example_zip_file = trained_model_dir / "examples.zip"
+        with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
+            out_root = examples_dir
+            if out_root.exists():
+                shutil.rmtree(out_root.as_posix())
+            out_root.mkdir(parents=True, exist_ok=True)
+            f_zip.extractall(path=out_root)
+
+    # examples
+    examples = defaultdict(list)
+    for filename in examples_dir.glob("**/*.wav"):
+        category = filename.parts[-2]
+        examples[category].append([
+            filename.as_posix(),
+            None,
+            vad_engine_choices[0],
+        ])
+
+    # ui
+    with gr.TabItem("vad"):
+        with gr.Row():
+            with gr.Column(variant="panel", scale=5):
+                with gr.Tabs():
+                    with gr.TabItem("file"):
+                        vad_audio_file = gr.Audio(label="audio")
+                    with gr.TabItem("microphone"):
+                        vad_audio_microphone = gr.Audio(sources="microphone", label="audio")
+
+                with gr.Row():
+                    vad_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="start_ring_rate")
+                    vad_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="end_ring_rate")
+                with gr.Row():
+                    vad_ring_max_length = gr.Number(value=10, label="ring_max_length (*10ms)")
+                    vad_min_silence_length = gr.Number(value=6, label="min_silence_length (*10ms)")
+                with gr.Row():
+                    vad_max_speech_length = gr.Number(value=100000, label="max_speech_length (*10ms)")
+                    vad_min_speech_length = gr.Number(value=15, label="min_speech_length (*10ms)")
+                vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
+                vad_button = gr.Button(variant="primary")
+            with gr.Column(variant="panel", scale=5):
+                vad_vad_image = gr.Image(label="vad")
+                vad_prob_image = gr.Image(label="prob")
+                vad_lsnr_image = gr.Image(label="lsnr")
+                vad_message = gr.Textbox(lines=1, max_lines=20, label="message")
+
+        # examples ui
+        with gr.Tabs():
+            for label, sub_examples in examples.items():
+                with gr.TabItem(label):
+                    gr.Examples(
+                        examples=sub_examples,
+                        inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
+                        outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
+                        fn=when_click_vad_button,
+                        # cache_examples=True,
+                        # cache_mode="lazy",
+                    )
+
+        vad_button.click(
+            when_click_vad_button,
+            inputs=[
+                vad_audio_file, vad_audio_microphone,
+                vad_start_ring_rate, vad_end_ring_rate,
+                vad_ring_max_length,
+                vad_min_silence_length,
+                vad_max_speech_length, vad_min_speech_length,
+                vad_engine,
+            ],
+            outputs=[vad_vad_image, vad_prob_image, vad_lsnr_image, vad_message],
+        )
+
+    return locals()
+
+
+if __name__ == "__main__":
+    pass
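
The "(*10ms)" hints on the UI controls and the segment conversion in when_click_vad_button both work in frame units. A quick worked conversion, assuming an 8 kHz model whose hop_size (frame_step) is 80 samples, i.e. 10 ms per frame; the real value comes from infer_engine.config.hop_size, and the segment numbers below are illustrative only:

sample_rate = 8000
frame_step = 80                       # assumed hop_size: 80 / 8000 = 0.01 s per frame

min_silence_length = 6                # 6 frames ~= 0.06 s (matches the "(*10ms)" hint)
min_speech_length = 15                # 15 frames ~= 0.15 s
max_speech_length = 100000            # 100000 frames ~= 1000 s, effectively unbounded

# frame-index segments in the shape returned by PostProcess.get_vad_segments
vad_segments = [[25, 180]]
vad_segments_seconds = [
    [start * frame_step / sample_rate, end * frame_step / sample_rate]
    for start, end in vad_segments
]
print(vad_segments_seconds)           # [[0.25, 1.8]]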