init but CHAD 0.7 RELEASE commit
Browse files
app.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import os
|
| 3 |
+
import copy
|
| 4 |
+
import uuid
|
| 5 |
+
import logging
|
| 6 |
+
from typing import List, Optional, Tuple, Dict
|
| 7 |
+
|
| 8 |
+
# Reduce progress/log spam before heavy imports
|
| 9 |
+
os.environ.setdefault("TQDM_DISABLE", "1")
|
| 10 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torchaudio
|
| 15 |
+
import soundfile as sf
|
| 16 |
+
import gradio as gr
|
| 17 |
+
|
| 18 |
+
# NeMo
|
| 19 |
+
from nemo.collections.asr.models import ASRModel
|
| 20 |
+
from omegaconf import OmegaConf
|
| 21 |
+
from nemo.utils import logging as nemo_logging
|
| 22 |
+
|
| 23 |
+
# ----------------------------
# Config
# ----------------------------
# Every value except TARGET_SR can be overridden through an environment variable;
# the second argument of each os.environ.get call is the default used otherwise.

# Model identifier handed to ASRModel.from_pretrained.
MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
# Sample rate (Hz) that all audio is resampled to before transcription.
TARGET_SR = 16_000
# Beam width placed into the decoding config's "beam_size".
BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32")) # Increased for subtle quality gains
# Batch size used when more than one item is transcribed in a single call.
OFFLINE_BATCH= int(os.environ.get("PARAKEET_BATCH", "8"))
# Length, in seconds, of each audio chunk fed to a streaming step.
CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "2.0"))
# Seconds of trailing silence appended to the leftover audio when a session is flushed.
FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))
|
| 32 |
+
|
| 33 |
+
# ----------------------------
# Logging (unified)
# ----------------------------
# Level name comes from the LOG_LEVEL environment variable (upper-cased).
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
logger = logging.getLogger("parakeet_app")
# Unknown level names fall back to INFO instead of raising.
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
# Replace (not extend) the handler list so this logger has exactly one handler,
# and stop propagation so records are not also emitted by ancestor loggers.
logger.handlers = [_handler]
logger.propagate = False

# Quiet NeMo logs
# Both NeMo's own logging wrapper and the stdlib loggers under "nemo" are
# limited to ERROR and above.
nemo_logging.setLevel(logging.ERROR)
logging.getLogger("nemo").setLevel(logging.ERROR)
logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR)

# Inference-only app: switch autograd off globally.
torch.set_grad_enabled(False)
|
| 50 |
+
|
| 51 |
+
# ----------------------------
|
| 52 |
+
# Audio utils
|
| 53 |
+
# ----------------------------
|
| 54 |
+
def to_mono_np(x: np.ndarray) -> np.ndarray:
    """Collapse 2-D audio to one channel and return it as float32.

    A 2-D input is averaged over axis 1; any other rank is passed through
    unchanged. The result is cast to float32 without copying when the dtype
    already matches.
    """
    mono = x.mean(axis=1) if x.ndim == 2 else x
    return mono.astype(np.float32, copy=False)
|
| 58 |
+
|
| 59 |
+
class ResamplerCache:
    """Creates torchaudio resamplers on demand and reuses one per source rate."""

    def __init__(self):
        # Maps a source sample rate to the transform converting it to TARGET_SR.
        self._cache: Dict[int, torchaudio.transforms.Resample] = {}

    def resample(self, wav: np.ndarray, src_sr: int) -> np.ndarray:
        """Return ``wav`` converted from ``src_sr`` to ``TARGET_SR``.

        Audio already at the target rate is returned untouched. Otherwise the
        array is wrapped as a tensor (a leading dimension is added for 1-D
        input), run through the cached transform, and handed back as a NumPy
        array with that leading dimension squeezed away.
        """
        if src_sr == TARGET_SR:
            return wav

        transform = self._cache.get(src_sr)
        if transform is None:
            logger.debug(f"create_resampler src_sr={src_sr} -> {TARGET_SR}")
            transform = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=TARGET_SR)
            self._cache[src_sr] = transform

        tensor = torch.from_numpy(wav)
        if tensor.ndim == 1:
            tensor = tensor.unsqueeze(0)
        return transform(tensor).squeeze(0).numpy()
|
| 73 |
+
|
| 74 |
+
RESAMPLER = ResamplerCache()
|
| 75 |
+
|
| 76 |
+
def load_mono16k(path: str) -> np.ndarray:
    """Load any audio file, convert to mono float32 at 16 kHz.

    Decoding is attempted with ``soundfile`` first; if that raises, the file
    is decoded with ``torchaudio`` instead. Channels are averaged and the
    result is resampled to ``TARGET_SR`` through the shared ``RESAMPLER``.

    Args:
        path: Filesystem path of the audio file.

    Returns:
        1-D float32 array of samples at ``TARGET_SR``.

    Raises:
        Exception: whatever ``torchaudio.load`` or the resampler raises when
            the fallback decode or the resampling itself fails.
    """
    # Only the decode is guarded: a resampling failure is a real error and
    # must not be hidden behind a second decode attempt.
    try:
        frames, sr = sf.read(path, dtype="float32", always_2d=True)  # (T,C)
    except Exception:
        # Best-effort fallback for formats soundfile cannot read; keep a trace
        # of why the primary decoder was skipped.
        logger.debug("soundfile_read_failed path=%s fallback=torchaudio", path, exc_info=True)
        wav_t, sr = torchaudio.load(path)  # (C,T)
        if wav_t.dtype != torch.float32:
            wav_t = wav_t.float()
        wav = wav_t.mean(dim=0).numpy()
    else:
        wav = frames.mean(axis=1).astype(np.float32, copy=False)
    return RESAMPLER.resample(wav, int(sr))
|
| 88 |
+
|
| 89 |
+
# ----------------------------
|
| 90 |
+
# Model manager (MALSD batched beam everywhere, loop_labels=True)
|
| 91 |
+
# ----------------------------
|
| 92 |
+
class ParakeetManager:
    """Loads the NeMo ASR model once and exposes offline and streaming transcription.

    On construction the model named by ``MODEL_NAME`` is fetched with
    ``ASRModel.from_pretrained``, moved to the requested device, put in eval
    mode with all parameters frozen, and switched to the "malsd_batch"
    decoding strategy with a beam of ``BEAM_SIZE``.
    """

    def __init__(self, device: str = "cpu"):
        """Load and configure the model.

        Args:
            device: Device string passed to ``torch.device`` (default "cpu").
        """
        self.device = torch.device(device)
        logger.info(f"loading_model name={MODEL_NAME} device={self.device}")
        self.model: ASRModel = ASRModel.from_pretrained(model_name=MODEL_NAME)
        self.model.to(self.device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

        # Base decoding cfg differs by class
        # A private deep copy is kept so later edits never mutate the model's own config.
        if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"):
            self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg)
        else:
            self._base_decoding = copy.deepcopy(self.model.cfg.decoding)

        self._set_malsd_beam()

        # Set the encoder's default attention context to [512, 16] when the encoder supports it.
        # NOTE(review): the log message below calls this "encoder caching"; the call itself only
        # sets an attention context size -- confirm against NeMo docs that this has the intended
        # effect for this checkpoint.
        if hasattr(self.model.encoder, "set_default_att_context_size"):
            self.model.encoder.set_default_att_context_size([512, 16]) # Large left for cumulative context, small right for buffering
            logger.info("encoder_caching_enabled left=512 right=16")

        logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")

    def _set_malsd_beam(self) -> None:
        """Build a decoding config from the saved base config and apply it to the model.

        The strategy is set to "malsd_batch" and the ``beam`` section is
        replaced wholesale; struct mode is then disabled so the extra keys
        (``loop_labels``, ``fused_batch_size``, ``compute_timestamps``) can be
        added before calling ``change_decoding_strategy``.
        """
        cfg = copy.deepcopy(self._base_decoding)
        cfg.strategy = "malsd_batch"
        cfg.beam = OmegaConf.create({
            "beam_size": BEAM_SIZE,
            "return_best_hypothesis": True,
            "score_norm": True,
            "allow_cuda_graphs": False, # CPU-only
            "max_symbols_per_step": 10,
        })
        OmegaConf.set_struct(cfg, False)
        cfg["loop_labels"] = True
        cfg["fused_batch_size"] = -1
        cfg["compute_timestamps"] = False
        if hasattr(cfg, "greedy"):
            cfg.greedy.use_cuda_graph_decoder = False
        self.model.change_decoding_strategy(cfg)
        logger.info("decoding_set strategy=malsd_batch loop_labels=True")

    def _transcribe(self, items: List, *, partial=None):
        """Run ``model.transcribe`` on ``items`` under ``torch.inference_mode``.

        Args:
            items: Inputs forwarded to ``model.transcribe`` (callers in this
                file pass lists of 1-D float32 arrays).
            partial: Forwarded as ``partial_hypothesis``; ``None`` for a fresh decode.

        Returns:
            Whatever ``model.transcribe`` returns with ``return_hypotheses=True``.
        """
        with torch.inference_mode():
            return self.model.transcribe(
                items,
                # A single item is decoded alone; larger inputs use the configured batch size.
                batch_size=1 if len(items) == 1 else OFFLINE_BATCH,
                num_workers=0,
                return_hypotheses=True,
                partial_hypothesis=partial,
            )

    # Offline batch
    def transcribe_files(self, paths: List[str]):
        """Transcribe audio files and return one ``{"path", "text"}`` dict per file.

        All files are loaded with ``load_mono16k`` up front and transcribed in
        one ``_transcribe`` call. An empty or ``None`` ``paths`` yields ``[]``.
        """
        n = 0 if not paths else len(paths)
        logger.info(f"files_run start count={n} batch={OFFLINE_BATCH}")
        if not paths:
            return []
        arrays = [load_mono16k(p) for p in paths]
        out = self._transcribe(arrays, partial=None)
        results = []
        for p, o in zip(paths, out):
            # Each output may be a non-empty list (first entry is used), a bare
            # string, or an object exposing ``.text``.
            h = o[0] if isinstance(o, list) and o else o
            text = h if isinstance(h, str) else getattr(h, "text", "")
            results.append({"path": p, "text": text})
        logger.info("files_run ok")
        return results

    # Streaming step (rolling hypothesis)
    def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
        """Transcribe one chunk, passing the previous hypothesis along when there is one.

        Args:
            audio_16k: Audio chunk; presumably mono float32 at ``TARGET_SR`` --
                callers in this file resample before calling.
            prev_hyp: Hypothesis returned by the previous step, or ``None``.

        Returns:
            The first hypothesis of the single-item result.
        """
        out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None)
        h = out[0][0] if isinstance(out[0], list) else out[0]
        return h # Hypothesis
|
| 167 |
+
|
| 168 |
+
# ----------------------------
|
| 169 |
+
# Streaming session (no overlap, rolling hypothesis)
|
| 170 |
+
# ----------------------------
|
| 171 |
+
class StreamingSession:
    """Accumulates microphone audio and transcribes it chunk by chunk.

    Incoming audio is converted to mono at ``TARGET_SR`` and buffered in
    ``pending``. Whenever at least one full chunk is buffered it is sent to
    the manager's ``stream_step`` together with the previous hypothesis, and
    the returned text is merged into ``text``.
    """

    # Characters that already close a sentence; flush() adds no extra period after them.
    _SENTENCE_END = (".", "!", "?", "\u2026")

    def __init__(self, manager: "ParakeetManager", chunk_s: float, flush_pad_s: float):
        """Create an empty session.

        Args:
            manager: Object providing ``stream_step(audio, prev_hyp)``.
            chunk_s: Chunk length in seconds.
            flush_pad_s: Seconds of silence appended to the remainder on flush.
        """
        self.mgr = manager
        self.chunk_s = chunk_s
        self.flush_pad_s = flush_pad_s
        self.hyp = None  # hypothesis returned by the most recent successful step
        self.pending = np.zeros(0, dtype=np.float32)  # audio not yet transcribed
        self.text = ""  # transcript accumulated so far
        logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")

    def add_audio(self, audio: np.ndarray, src_sr: int):
        """Buffer ``audio`` (sampled at ``src_sr``) and transcribe any complete chunks."""
        mono = to_mono_np(audio)
        res = RESAMPLER.resample(mono, src_sr)
        self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
        self._drain()

    def _merge_text(self, new_text: str) -> bool:
        """Fold ``new_text`` into ``self.text``; return True when something was merged.

        If ``new_text`` extends the current transcript (cumulative output) it
        replaces it; otherwise it is appended after a single space (per-chunk
        output). Empty or missing text is ignored.
        """
        if not new_text:
            return False
        if self.text and new_text.startswith(self.text):
            self.text = new_text
        else:
            self.text += (' ' if self.text else '') + new_text
        return True

    def _drain(self):
        """Transcribe every complete chunk currently buffered.

        A failing step is logged and stops the loop; the chunk that failed has
        already been removed from the buffer.
        """
        # At least one sample per chunk: a zero/negative chunk_s would otherwise
        # give a chunk length of 0 and loop forever on empty slices.
        chunk_len = max(1, int(self.chunk_s * TARGET_SR))
        while self.pending.size >= chunk_len:
            chunk = self.pending[:chunk_len]
            self.pending = self.pending[chunk_len:]
            try:
                self.hyp = self.mgr.stream_step(chunk, self.hyp)
                self._merge_text(getattr(self.hyp, "text", ""))
            except Exception:
                logger.exception("mic_step failed")
                break

    def flush(self) -> str:
        """Transcribe the leftover audio (padded with silence) and return the transcript.

        When the final step yields text, a closing period is added unless the
        transcript already ends with sentence punctuation. Errors are logged
        and the buffer is cleared either way.
        """
        if self.pending.size:
            # Clamp so a negative flush_pad_s cannot make np.zeros raise.
            pad_len = max(0, int(self.flush_pad_s * TARGET_SR))
            pad = np.zeros(pad_len, dtype=np.float32)
            final = np.concatenate([self.pending, pad])
            try:
                self.hyp = self.mgr.stream_step(final, self.hyp)
                merged = self._merge_text(getattr(self.hyp, "text", ""))
                # Add period for sentence closure on flush, without doubling
                # punctuation the model already produced.
                if merged and not self.text.endswith(self._SENTENCE_END):
                    self.text += '.'
            except Exception:
                logger.exception("mic_flush failed")
            self.pending = np.zeros(0, dtype=np.float32)
        return self.text
|
| 221 |
+
|
| 222 |
+
# ----------------------------
|
| 223 |
+
# Simple session registry (avoid deepcopy in gr.State)
|
| 224 |
+
# ----------------------------
|
| 225 |
+
SESS: Dict[str, StreamingSession] = {}
|
| 226 |
+
def _new_session_id() -> str:
|
| 227 |
+
return uuid.uuid4().hex
|
| 228 |
+
|
| 229 |
+
# ----------------------------
|
| 230 |
+
# Gradio callbacks
|
| 231 |
+
# ----------------------------
|
| 232 |
+
MANAGER = ParakeetManager(device="cpu")
|
| 233 |
+
|
| 234 |
+
def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
|
| 235 |
+
if x is None:
|
| 236 |
+
return np.zeros(0, dtype=np.float32), TARGET_SR
|
| 237 |
+
if isinstance(x, tuple) and len(x) == 2:
|
| 238 |
+
sr = int(x[0]); arr = np.array(x[1], dtype=np.float32); return arr, sr
|
| 239 |
+
if isinstance(x, dict) and "data" in x and "sampling_rate" in x:
|
| 240 |
+
arr = np.array(x["data"], dtype=np.float32); sr = int(x["sampling_rate"]); return arr, sr
|
| 241 |
+
if isinstance(x, np.ndarray):
|
| 242 |
+
return x.astype(np.float32, copy=False), TARGET_SR
|
| 243 |
+
logger.error(f"unsupported_gr_audio_payload type={type(x)}"); raise ValueError("Unsupported audio payload")
|
| 244 |
+
|
| 245 |
+
def mic_step(audio_chunk, sess_id: Optional[str]):
    """Gradio streaming callback: feed one microphone chunk to its session.

    A new session (and id) is created when ``sess_id`` is empty or unknown.
    A payload that cannot be parsed is logged and skipped; an empty chunk is
    ignored.

    Returns:
        ``(session_id, transcript_so_far)``.
    """
    session = SESS.get(sess_id) if sess_id else None
    if session is None:
        sess_id = _new_session_id()
        session = StreamingSession(MANAGER, CHUNK_S, FLUSH_PAD_S)
        SESS[sess_id] = session

    try:
        samples, rate = _parse_gr_audio(audio_chunk)
    except Exception:
        logger.exception("mic_parse failed")
    else:
        if samples.size:
            session.add_audio(samples, rate)
    return sess_id, session.text
|
| 258 |
+
|
| 259 |
+
def mic_flush(sess_id: Optional[str]):
    """Gradio callback: finish a streaming session and return its transcript.

    The session is removed from ``SESS`` before flushing. The returned state
    id is always ``None``, so the session could never be reached again and
    would otherwise stay in the registry for the life of the process.

    Args:
        sess_id: Id of the session to finish; may be ``None`` or unknown.

    Returns:
        ``(None, transcript)``; the transcript is ``""`` when no session matches.
    """
    sess = SESS.pop(sess_id, None) if sess_id else None
    if sess is None:
        return None, ""
    text = sess.flush()
    logger.info("mic_flush ok")
    return None, text
|
| 265 |
+
|
| 266 |
+
def files_run(files):
    """Gradio callback: transcribe uploaded files into ``[file name, text]`` rows.

    Each entry of ``files`` may be a path string or an object with a ``name``
    attribute; anything else is skipped. Failures are logged and re-raised.
    An empty upload returns ``[]``.
    """
    total = len(files) if files else 0
    logger.info(f"files_ui start count={total}")
    if not files:
        return []

    paths: List[str] = [
        item if isinstance(item, str) else item.name
        for item in files
        if isinstance(item, str) or hasattr(item, "name")
    ]

    try:
        results = MANAGER.transcribe_files(paths)
    except Exception:
        logger.exception("files_run failed")
        raise

    table = []
    for row in results:
        table.append([os.path.basename(row["path"]), row["text"]])
    logger.info("files_ui ok")
    return table
|
| 284 |
+
|
| 285 |
+
# ----------------------------
|
| 286 |
+
# UI
|
| 287 |
+
# ----------------------------
|
| 288 |
+
# Two-tab UI: live microphone streaming and batch file upload.
with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo:
    with gr.Tab("Mic"):
        mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak")
        text_out = gr.Textbox(label="Transcript", lines=8)
        flush_btn = gr.Button("Flush")
        # Per-browser-session state holds only the key into the SESS registry.
        state_id = gr.State() # only a string id
        # Each streamed chunk goes through mic_step, which returns the session id and transcript.
        mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out])
        # mic_flush returns None for the state, so the next chunk starts a new session.
        flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])

    with gr.Tab("Files"):
        files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
        run_btn = gr.Button("Run")
        results_table = gr.Dataframe(headers=["file", "text"], label="Results",
                                     row_count=(0, "dynamic"), col_count=(2, "fixed"))
        # files_run returns a list of [file name, text] rows for the table.
        run_btn.click(files_run, inputs=[files], outputs=[results_table])

# Enable request queuing and start the server when the module is executed.
demo.queue().launch(ssr_mode=False)
|