WJ88 committed
Commit 65ce328 · verified · 1 parent: 5d860fb

init but CHAD 0.7 RELEASE commit

Files changed (1)
  1. app.py +304 -0
app.py ADDED
@@ -0,0 +1,304 @@
from __future__ import annotations
import os
import copy
import uuid
import logging
from typing import List, Optional, Tuple, Dict

# Reduce progress/log spam before heavy imports
os.environ.setdefault("TQDM_DISABLE", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

import numpy as np
import torch
import torchaudio
import soundfile as sf
import gradio as gr

# NeMo
from nemo.collections.asr.models import ASRModel
from omegaconf import OmegaConf
from nemo.utils import logging as nemo_logging

# ----------------------------
# Config
# ----------------------------
MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
TARGET_SR = 16_000
BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32"))  # Increased for subtle quality gains
OFFLINE_BATCH = int(os.environ.get("PARAKEET_BATCH", "8"))
CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "2.0"))
FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))

# ----------------------------
# Logging (unified)
# ----------------------------
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
logger = logging.getLogger("parakeet_app")
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
logger.handlers = [_handler]
logger.propagate = False

# Quiet NeMo logs
nemo_logging.setLevel(logging.ERROR)
logging.getLogger("nemo").setLevel(logging.ERROR)
logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR)

torch.set_grad_enabled(False)

# ----------------------------
# Audio utils
# ----------------------------
def to_mono_np(x: np.ndarray) -> np.ndarray:
    if x.ndim == 2:
        x = x.mean(axis=1)
    return x.astype(np.float32, copy=False)

class ResamplerCache:
    def __init__(self):
        self._cache: Dict[int, torchaudio.transforms.Resample] = {}

    def resample(self, wav: np.ndarray, src_sr: int) -> np.ndarray:
        if src_sr == TARGET_SR:
            return wav
        if src_sr not in self._cache:
            logger.debug(f"create_resampler src_sr={src_sr} -> {TARGET_SR}")
            self._cache[src_sr] = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=TARGET_SR)
        t = torch.from_numpy(wav)
        if t.ndim == 1:
            t = t.unsqueeze(0)
        y = self._cache[src_sr](t)
        return y.squeeze(0).numpy()

RESAMPLER = ResamplerCache()

def load_mono16k(path: str) -> np.ndarray:
    """Load any audio file, convert to mono float32 at 16 kHz."""
    try:
        wav, sr = sf.read(path, dtype="float32", always_2d=True)  # (T,C)
        wav = wav.mean(axis=1).astype(np.float32, copy=False)
        return RESAMPLER.resample(wav, sr)
    except Exception:
        wav_t, sr = torchaudio.load(path)  # (C,T)
        if wav_t.dtype != torch.float32:
            wav_t = wav_t.float()
        wav = wav_t.mean(dim=0).numpy()
        return RESAMPLER.resample(wav, int(sr))

# ----------------------------
# Model manager (MALSD batched beam everywhere, loop_labels=True)
# ----------------------------
class ParakeetManager:
    def __init__(self, device: str = "cpu"):
        self.device = torch.device(device)
        logger.info(f"loading_model name={MODEL_NAME} device={self.device}")
        self.model: ASRModel = ASRModel.from_pretrained(model_name=MODEL_NAME)
        self.model.to(self.device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

        # Base decoding cfg differs by class
        if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"):
            self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg)
        else:
            self._base_decoding = copy.deepcopy(self.model.cfg.decoding)

        self._set_malsd_beam()

        # Enable encoder caching for better streaming context (per NeMo docs/tutorials)
        if hasattr(self.model.encoder, "set_default_att_context_size"):
            self.model.encoder.set_default_att_context_size([512, 16])  # Large left for cumulative context, small right for buffering
            logger.info("encoder_caching_enabled left=512 right=16")

        logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")

    def _set_malsd_beam(self):
        cfg = copy.deepcopy(self._base_decoding)
        cfg.strategy = "malsd_batch"
        cfg.beam = OmegaConf.create({
            "beam_size": BEAM_SIZE,
            "return_best_hypothesis": True,
            "score_norm": True,
            "allow_cuda_graphs": False,  # CPU-only
            "max_symbols_per_step": 10,
        })
        OmegaConf.set_struct(cfg, False)
        cfg["loop_labels"] = True
        cfg["fused_batch_size"] = -1
        cfg["compute_timestamps"] = False
        if hasattr(cfg, "greedy"):
            cfg.greedy.use_cuda_graph_decoder = False
        self.model.change_decoding_strategy(cfg)
        logger.info("decoding_set strategy=malsd_batch loop_labels=True")

    def _transcribe(self, items: List, *, partial=None):
        with torch.inference_mode():
            return self.model.transcribe(
                items,
                batch_size=1 if len(items) == 1 else OFFLINE_BATCH,
                num_workers=0,
                return_hypotheses=True,
                partial_hypothesis=partial,
            )

    # Offline batch
    def transcribe_files(self, paths: List[str]):
        n = 0 if not paths else len(paths)
        logger.info(f"files_run start count={n} batch={OFFLINE_BATCH}")
        if not paths:
            return []
        arrays = [load_mono16k(p) for p in paths]
        out = self._transcribe(arrays, partial=None)
        results = []
        for p, o in zip(paths, out):
            h = o[0] if isinstance(o, list) and o else o
            text = h if isinstance(h, str) else getattr(h, "text", "")
            results.append({"path": p, "text": text})
        logger.info("files_run ok")
        return results

    # Streaming step (rolling hypothesis)
    def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
        out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None)
        h = out[0][0] if isinstance(out[0], list) else out[0]
        return h  # Hypothesis

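# Note on the streaming design: stream_step hands the previous Hypothesis back to the
# decoder via partial_hypothesis, so each chunk extends the running transcript instead
# of restarting decoding; StreamingSession below owns that rolling state and the buffer.
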
# ----------------------------
# Streaming session (no overlap, rolling hypothesis)
# ----------------------------
class StreamingSession:
    def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float):
        self.mgr = manager
        self.chunk_s = chunk_s
        self.flush_pad_s = flush_pad_s
        self.hyp = None
        self.pending = np.zeros(0, dtype=np.float32)
        self.text = ""
        logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")

    def add_audio(self, audio: np.ndarray, src_sr: int):
        mono = to_mono_np(audio)
        res = RESAMPLER.resample(mono, src_sr)
        self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
        self._drain()

    def _drain(self):
        C = int(self.chunk_s * TARGET_SR)
        while self.pending.size >= C:
            chunk = self.pending[:C]
            self.pending = self.pending[C:]
            try:
                self.hyp = self.mgr.stream_step(chunk, self.hyp)
                new_text = getattr(self.hyp, "text", "")
                if new_text:
                    if self.text and new_text.startswith(self.text):  # If cumulative (partial extends), replace with extended
                        self.text = new_text
                    else:  # Else append (handles per-chunk case)
                        self.text += (' ' if self.text else '') + new_text
            except Exception:
                logger.exception("mic_step failed")
                break

    def flush(self) -> str:
        if self.pending.size:
            pad = np.zeros(int(self.flush_pad_s * TARGET_SR), dtype=np.float32)
            final = np.concatenate([self.pending, pad])
            try:
                self.hyp = self.mgr.stream_step(final, self.hyp)
                new_text = getattr(self.hyp, "text", "")
                if new_text:
                    if self.text and new_text.startswith(self.text):
                        self.text = new_text
                    else:
                        self.text += (' ' if self.text else '') + new_text
                self.text += '.'  # Add period for sentence closure on flush
            except Exception:
                logger.exception("mic_flush failed")
        self.pending = np.zeros(0, dtype=np.float32)
        return self.text

# ----------------------------
# Simple session registry (avoid deepcopy in gr.State)
# ----------------------------
SESS: Dict[str, StreamingSession] = {}
def _new_session_id() -> str:
    return uuid.uuid4().hex

# ----------------------------
# Gradio callbacks
# ----------------------------
MANAGER = ParakeetManager(device="cpu")

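# Note: MANAGER above is built at import time, so the 0.6B-parameter checkpoint is
# downloaded and loaded before the UI exists; a Space pays this cost once per cold start.
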
def _as_float32(arr) -> np.ndarray:
    # Gradio's mic stream delivers integer PCM (typically int16); scale it to [-1, 1]
    # before handing it to the model, otherwise sample values land in the thousands.
    arr = np.asarray(arr)
    if np.issubdtype(arr.dtype, np.integer):
        return arr.astype(np.float32) / float(np.iinfo(arr.dtype).max)
    return arr.astype(np.float32, copy=False)

def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
    if x is None:
        return np.zeros(0, dtype=np.float32), TARGET_SR
    if isinstance(x, tuple) and len(x) == 2:
        return _as_float32(x[1]), int(x[0])
    if isinstance(x, dict) and "data" in x and "sampling_rate" in x:
        return _as_float32(x["data"]), int(x["sampling_rate"])
    if isinstance(x, np.ndarray):
        return _as_float32(x), TARGET_SR
    logger.error(f"unsupported_gr_audio_payload type={type(x)}")
    raise ValueError("Unsupported audio payload")

def mic_step(audio_chunk, sess_id: Optional[str]):
    if not sess_id or sess_id not in SESS:
        sess_id = _new_session_id()
        SESS[sess_id] = StreamingSession(MANAGER, CHUNK_S, FLUSH_PAD_S)
    sess = SESS[sess_id]
    try:
        wav, sr = _parse_gr_audio(audio_chunk)
    except Exception:
        logger.exception("mic_parse failed")
        return sess_id, sess.text
    if wav.size:
        sess.add_audio(wav, sr)
    return sess_id, sess.text

def mic_flush(sess_id: Optional[str]):
    if not sess_id or sess_id not in SESS:
        return None, ""
    text = SESS[sess_id].flush()
    SESS.pop(sess_id, None)  # Drop the finished session so the registry does not grow without bound
    logger.info("mic_flush ok")
    return None, text

def files_run(files):
    n = 0 if not files else len(files)
    logger.info(f"files_ui start count={n}")
    if not files:
        return []
    paths: List[str] = []
    for f in files:
        if isinstance(f, str):
            paths.append(f)
        elif hasattr(f, "name"):
            paths.append(f.name)
    try:
        results = MANAGER.transcribe_files(paths)
    except Exception:
        logger.exception("files_run failed")
        raise
    table = [[os.path.basename(r["path"]), r["text"]] for r in results]
    logger.info("files_ui ok")
    return table

# ----------------------------
# UI
# ----------------------------
with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo:
    with gr.Tab("Mic"):
        mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak")
        text_out = gr.Textbox(label="Transcript", lines=8)
        flush_btn = gr.Button("Flush")
        state_id = gr.State()  # only a string id
        mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out])
        flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])

    with gr.Tab("Files"):
        files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
        run_btn = gr.Button("Run")
        results_table = gr.Dataframe(headers=["file", "text"], label="Results",
                                     row_count=(0, "dynamic"), col_count=(2, "fixed"))
        run_btn.click(files_run, inputs=[files], outputs=[results_table])

# Guard the launch so importing this module (e.g. from a test) does not start the server
if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)
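
For a quick smoke test of the offline path without the UI, a minimal sketch, assuming app.py is importable from the working directory and that sample.wav is a placeholder file you supply. BEAM_SIZE is read from the environment at import time, so any override must be set before the import; the __main__ guard above keeps the import from launching the server, though the model still loads once.

import os
os.environ["PARAKEET_BEAM_SIZE"] = "4"  # smaller beam for a faster CPU check

from app import MANAGER  # loads the Parakeet model once at import

for r in MANAGER.transcribe_files(["sample.wav"]):  # placeholder path
    print(r["path"], "->", r["text"])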