WJ88's picture
Update app.py
4b17dd2 verified
from __future__ import annotations
import copy
import logging
import os
import threading
import uuid
from typing import List, Optional, Tuple, Dict
# Reduce progress/log spam before the heavy imports below. setdefault only
# fills in values the user has not already exported.
for _var, _value in (("TQDM_DISABLE", "1"), ("TOKENIZERS_PARALLELISM", "false")):
    os.environ.setdefault(_var, _value)
del _var, _value
import numpy as np
import torch
import torchaudio
import soundfile as sf
import gradio as gr
# NeMo
from nemo.collections.asr.models import ASRModel
from omegaconf import OmegaConf
from nemo.utils import logging as nemo_logging
# ----------------------------
# Config
# ----------------------------
MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
TARGET_SR = 16_000  # every waveform is resampled to this rate (Hz) before inference


def _env_number(name, default, cast, *, allow_zero=False):
    """Read numeric env var *name* (or *default* when unset), converted with *cast*.

    The value must be strictly positive, or non-negative when *allow_zero* is
    set. A zero chunk length is unusable for streaming, zero-sized batches or
    beams are meaningless, and a negative pad would only fail later, inside a
    request.

    Raises:
        ValueError: if the value cannot be parsed or is out of range. The
            message names the offending variable.
    """
    raw = os.environ.get(name, default)
    try:
        value = cast(raw)
    except ValueError as err:
        raise ValueError(f"{name}={raw!r} is not a valid {cast.__name__}") from err
    in_range = value >= 0 if allow_zero else value > 0
    if not in_range:  # the comparison is False for NaN too, so NaN is rejected
        bound = ">= 0" if allow_zero else "> 0"
        raise ValueError(f"{name} must be {bound}, got {raw!r}")
    return value


BEAM_SIZE = _env_number("PARAKEET_BEAM_SIZE", "32", int)  # Increased for subtle quality gains
OFFLINE_BATCH = _env_number("PARAKEET_BATCH", "8", int)
CHUNK_S = _env_number("PARAKEET_CHUNK_S", "2.0", float)  # seconds of mic audio per decode step
FLUSH_PAD_S = _env_number("PARAKEET_FLUSH_PAD_S", "2.0", float, allow_zero=True)  # trailing silence (s) on flush
# ----------------------------
# Logging (unified)
# ----------------------------
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()

# One stream handler with a timestamped format. Propagation is off so records
# are not duplicated by any root-logger handlers.
_handler = logging.StreamHandler()
_handler.setFormatter(
    logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
)

logger = logging.getLogger("parakeet_app")
logger.handlers = [_handler]
logger.propagate = False
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))  # unknown names fall back to INFO

# Quiet NeMo: raise its own logging wrapper and the stdlib "nemo" loggers to ERROR.
nemo_logging.setLevel(logging.ERROR)
for _name in ("nemo", "nemo.collections.asr"):
    logging.getLogger(_name).setLevel(logging.ERROR)
del _name

# Inference only: disable autograd process-wide.
torch.set_grad_enabled(False)
# ----------------------------
# Audio utils
# ----------------------------
def to_mono_np(x: np.ndarray) -> np.ndarray:
    """Down-mix to mono and return float32 samples.

    A 2-D input is averaged over axis 1 (assumed to be the channel axis, i.e.
    ``(samples, channels)``). Any other rank is passed through unchanged. The
    float32 cast does not copy when the data is already float32.
    """
    mono = x.mean(axis=1) if x.ndim == 2 else x
    return mono.astype(np.float32, copy=False)
class ResamplerCache:
    """Resample audio to TARGET_SR, reusing one torchaudio Resample per source rate.

    Each ``Resample`` transform precomputes its filter kernel, so transforms are
    cached by source sample rate and reused across calls.
    """

    def __init__(self):
        # source sample rate (Hz) -> transform converting it to TARGET_SR
        self._cache: Dict[int, torchaudio.transforms.Resample] = {}

    def resample(self, wav: np.ndarray, src_sr: int) -> np.ndarray:
        """Return *wav* resampled from *src_sr* Hz to TARGET_SR Hz.

        Args:
            wav: 1-D signal, or a 2-D array resampled along its last axis (a
                leading axis of size 1 is squeezed away in the result).
            src_sr: sample rate of *wav* in Hz.

        Returns:
            *wav* itself, unchanged, when *src_sr* already equals TARGET_SR.
            Otherwise a new float32 array.
        """
        if src_sr == TARGET_SR:
            return wav
        resampler = self._cache.get(src_sr)
        if resampler is None:
            logger.debug(f"create_resampler src_sr={src_sr} -> {TARGET_SR}")
            resampler = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=TARGET_SR)
            self._cache[src_sr] = resampler
        # torch.from_numpy rejects negative-stride arrays, and the cached kernel is
        # float32 by default (conv1d needs matching dtypes). Normalise the input to
        # a contiguous float32 buffer first. This is a no-op for contiguous float32.
        t = torch.from_numpy(np.ascontiguousarray(wav, dtype=np.float32))
        if t.ndim == 1:
            t = t.unsqueeze(0)
        return resampler(t).squeeze(0).numpy()


# Process-wide resampler cache, shared by file loading and mic streaming.
RESAMPLER = ResamplerCache()
def load_mono16k(path: str) -> np.ndarray:
    """Load any audio file, convert to mono float32 at 16 kHz.

    Decoding is tried with soundfile first. If soundfile raises, the file is
    decoded with ``torchaudio.load`` instead. Channels are averaged to mono and
    the result is resampled to TARGET_SR. Only the decode step sits inside the
    fallback, so a resampling error surfaces directly instead of triggering a
    redundant second decode.

    Raises:
        Whatever ``torchaudio.load`` raises when neither backend can decode
        *path*. The soundfile error is attached as the exception context.
    """
    try:
        data, sr = sf.read(path, dtype="float32", always_2d=True)  # (T, C)
        mono = data.mean(axis=1)
    except Exception as sf_err:
        logger.debug(f"soundfile failed for {path!r} ({sf_err}); falling back to torchaudio")
        wav_t, sr = torchaudio.load(path)  # (C, T)
        mono = wav_t.float().mean(dim=0).numpy()
    return RESAMPLER.resample(mono.astype(np.float32, copy=False), int(sr))
# ----------------------------
# Model manager (MALSD batched beam everywhere, loop_labels=True)
# ----------------------------
class ParakeetManager:
    """Load one NeMo ASR model and serve offline and streaming transcription.

    On construction the model named by MODEL_NAME is loaded, frozen, and
    switched to MALSD batched beam search with BEAM_SIZE beams. Its encoder
    attention window is restricted when the encoder supports that.

    All ``transcribe`` calls go through one lock. The module shares a single
    instance across every Gradio callback, and Gradio may run different
    handlers (e.g. a mic stream and a file run) on separate threads at once.
    NOTE(review): NeMo does not document ``transcribe`` as thread-safe, so
    calls are serialised rather than assumed safe.
    """

    def __init__(self, device: str = "cpu"):
        self.device = torch.device(device)
        self._lock = threading.Lock()  # serialises access to self.model.transcribe
        logger.info(f"loading_model name={MODEL_NAME} device={self.device}")
        self.model: ASRModel = ASRModel.from_pretrained(model_name=MODEL_NAME)
        self.model.to(self.device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False
        # Base decoding cfg differs by class.
        # NOTE(review): for RNNT/TDT checkpoints the ``cfg.decoding`` fallback is
        # presumably the branch taken; confirm what ``decoder.decoder.cfg`` holds
        # for the model classes where the first branch applies.
        if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"):
            self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg)
        else:
            self._base_decoding = copy.deepcopy(self.model.cfg.decoding)
        self._set_malsd_beam()
        # Restrict encoder self-attention to 512 frames of left context and 16 of
        # right context. This sets an attention window; it is not a cache. No
        # encoder state is carried between the transcribe() calls below. The
        # setting also applies to file transcription.
        if hasattr(self.model.encoder, "set_default_att_context_size"):
            self.model.encoder.set_default_att_context_size([512, 16])
            logger.info("encoder_caching_enabled left=512 right=16")
        logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")

    def _set_malsd_beam(self):
        """Switch the model to MALSD batched beam search with BEAM_SIZE beams."""
        cfg = copy.deepcopy(self._base_decoding)
        # Unlock the copy before any assignment. A struct-mode config rejects
        # keys it does not already contain (e.g. ``beam`` or ``loop_labels``).
        OmegaConf.set_struct(cfg, False)
        cfg.strategy = "malsd_batch"
        cfg.beam = OmegaConf.create({
            "beam_size": BEAM_SIZE,
            "return_best_hypothesis": True,
            "score_norm": True,
            "allow_cuda_graphs": False,  # CPU-only
            "max_symbols_per_step": 10,
        })
        cfg["loop_labels"] = True
        cfg["fused_batch_size"] = -1
        cfg["compute_timestamps"] = False
        if hasattr(cfg, "greedy"):
            cfg.greedy.use_cuda_graph_decoder = False
        self.model.change_decoding_strategy(cfg)
        logger.info("decoding_set strategy=malsd_batch loop_labels=True")

    def _transcribe(self, items: List, *, partial=None):
        """Run ``model.transcribe`` on *items* under the model lock.

        A single item uses batch size 1; otherwise OFFLINE_BATCH is used.
        Hypothesis objects are returned (``return_hypotheses=True``), and
        *partial* is forwarded as ``partial_hypothesis``.
        """
        with self._lock, torch.inference_mode():
            return self.model.transcribe(
                items,
                batch_size=1 if len(items) == 1 else OFFLINE_BATCH,
                num_workers=0,
                return_hypotheses=True,
                partial_hypothesis=partial,
            )

    # Offline batch
    def transcribe_files(self, paths: List[str]):
        """Transcribe audio files and return ``[{"path": p, "text": t}, ...]`` in input order."""
        n = len(paths) if paths else 0
        logger.info(f"files_run start count={n} batch={OFFLINE_BATCH}")
        if not paths:
            return []
        arrays = [load_mono16k(p) for p in paths]
        out = self._transcribe(arrays, partial=None)
        results = []
        for p, o in zip(paths, out):
            # Each entry may be an n-best list; take its first (best) element.
            h = o[0] if isinstance(o, list) and o else o
            text = h if isinstance(h, str) else getattr(h, "text", "")
            results.append({"path": p, "text": text})
        logger.info("files_run ok")
        return results

    # Streaming step (rolling hypothesis)
    def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
        """Decode one 16 kHz chunk, seeding the decoder with *prev_hyp* if given.

        Returns the best hypothesis for the chunk, or None when the decoder
        returned an empty n-best list. Previously that case raised IndexError.
        """
        partial = [prev_hyp] if prev_hyp is not None else None
        first = self._transcribe([audio_16k], partial=partial)[0]
        if isinstance(first, list):
            return first[0] if first else None
        return first  # Hypothesis
# ----------------------------
# Streaming session (no overlap, rolling hypothesis)
# ----------------------------
class StreamingSession:
    """Buffer mic audio for one browser session and transcribe it chunk by chunk.

    Incoming audio is down-mixed, resampled to TARGET_SR and appended to
    ``pending``. Whenever ``chunk_s`` seconds are buffered, that slice is decoded
    with the previous hypothesis passed as the partial hypothesis. Chunks do not
    overlap. ``text`` holds the running transcript.
    """

    # Trailing characters treated as already closing a sentence on flush.
    _SENTENCE_END = (".", "!", "?", "…", "。", "！", "？")

    def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float):
        self.mgr = manager
        self.chunk_s = chunk_s
        self.flush_pad_s = flush_pad_s
        self.hyp = None  # last hypothesis returned by the manager
        self.pending = np.zeros(0, dtype=np.float32)  # undecoded TARGET_SR samples
        self.text = ""
        logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")

    def add_audio(self, audio: np.ndarray, src_sr: int):
        """Append a raw mic chunk (any rate, mono or multi-channel) and decode full chunks."""
        mono = to_mono_np(audio)
        res = RESAMPLER.resample(mono, src_sr)
        self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
        self._drain()

    def _merge_text(self, new_text: str) -> None:
        """Fold a decoded hypothesis text into the running transcript.

        If *new_text* extends the current transcript (a cumulative hypothesis),
        it replaces it. Otherwise it is appended as a new space-separated
        segment. Empty text is ignored.
        """
        if not new_text:
            return
        if self.text and new_text.startswith(self.text):
            self.text = new_text
        else:
            self.text += (" " if self.text else "") + new_text

    def _drain(self):
        """Decode every complete chunk currently buffered in ``pending``."""
        # At least one sample per chunk, otherwise the loop below never terminates.
        chunk_len = max(1, int(self.chunk_s * TARGET_SR))
        while self.pending.size >= chunk_len:
            chunk = self.pending[:chunk_len]
            self.pending = self.pending[chunk_len:]
            try:
                self.hyp = self.mgr.stream_step(chunk, self.hyp)
                self._merge_text(getattr(self.hyp, "text", ""))
            except Exception:
                logger.exception("mic_step failed")
                break

    def flush(self) -> str:
        """Decode any remaining audio (padded with trailing silence) and return the transcript.

        A closing period is added only when the transcript is non-empty and does
        not already end with sentence punctuation. The model can emit its own
        punctuation, and an unconditional '.' produced ".." or a lone ".".
        """
        if self.pending.size:
            pad = np.zeros(max(0, int(self.flush_pad_s * TARGET_SR)), dtype=np.float32)
            final = np.concatenate([self.pending, pad])
            try:
                self.hyp = self.mgr.stream_step(final, self.hyp)
                self._merge_text(getattr(self.hyp, "text", ""))
                if self.text and not self.text.endswith(self._SENTENCE_END):
                    self.text += "."  # sentence closure on flush
            except Exception:
                logger.exception("mic_flush failed")
        self.pending = np.zeros(0, dtype=np.float32)
        return self.text
# ----------------------------
# Simple session registry (avoid deepcopy in gr.State)
# ----------------------------
# Live mic sessions keyed by an opaque id. Only the id string is kept in
# gr.State, so Gradio never has to deepcopy the session object itself.
SESS: Dict[str, StreamingSession] = {}


def _new_session_id() -> str:
    """Return a fresh random session id: a uuid4 as 32 lowercase hex characters."""
    return str(uuid.uuid4()).replace("-", "")
# ----------------------------
# Gradio callbacks
# ----------------------------
# Single shared CPU model instance used by every Gradio callback. It is built
# (ASRModel.from_pretrained) when this module is imported.
MANAGER = ParakeetManager("cpu")
def _pcm_to_float32(data) -> np.ndarray:
    """Return *data* as a new float32 array, scaling integer PCM arrays into [-1, 1].

    Gradio's numpy audio payloads are typically int16 PCM. A plain float32 cast
    would leave samples around +/-32768, unlike the [-1, 1] floats the file path
    gets from ``sf.read(dtype="float32")``. Float arrays and non-ndarray inputs
    (e.g. Python lists) are only cast, as before.
    """
    if isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer):
        info = np.iinfo(data.dtype)
        samples = data.astype(np.float32)
        if info.min < 0:
            # Signed PCM: the most negative value maps to -1.0.
            return samples / np.float32(-float(info.min))
        # Unsigned PCM (e.g. 8-bit WAV) is offset-binary around its midpoint.
        mid = np.float32((float(info.max) + 1.0) / 2.0)
        return (samples - mid) / mid
    return np.array(data, dtype=np.float32)


def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
    """Normalise a Gradio audio payload to ``(float32 samples, sample_rate)``.

    Accepted payloads:
      * ``None``: an empty array at TARGET_SR;
      * a ``(sample_rate, data)`` tuple, as from ``gr.Audio(type="numpy")``;
      * a ``{"data": ..., "sampling_rate": ...}`` dict;
      * a bare ``np.ndarray``, whose sample rate is assumed to be TARGET_SR.

    Integer PCM sample arrays are scaled into [-1, 1].

    Raises:
        ValueError: for any other payload type.
    """
    if x is None:
        return np.zeros(0, dtype=np.float32), TARGET_SR
    if isinstance(x, tuple) and len(x) == 2:
        sr = int(x[0])
        return _pcm_to_float32(x[1]), sr
    if isinstance(x, dict) and "data" in x and "sampling_rate" in x:
        arr = _pcm_to_float32(x["data"])
        return arr, int(x["sampling_rate"])
    if isinstance(x, np.ndarray):
        return _pcm_to_float32(x), TARGET_SR
    logger.error(f"unsupported_gr_audio_payload type={type(x)}")
    raise ValueError("Unsupported audio payload")
def mic_step(audio_chunk, sess_id: Optional[str]):
    """Gradio stream callback: feed one mic chunk into the caller's session.

    A new StreamingSession is registered when *sess_id* is empty or unknown.
    A chunk that cannot be parsed is logged and skipped. Returns
    ``(session_id, transcript_so_far)`` for the state and textbox outputs.
    """
    sess = SESS.get(sess_id) if sess_id else None
    if sess is None:
        sess_id = _new_session_id()
        sess = StreamingSession(MANAGER, CHUNK_S, FLUSH_PAD_S)
        SESS[sess_id] = sess
    try:
        wav, sr = _parse_gr_audio(audio_chunk)
    except Exception:
        logger.exception("mic_parse failed")
    else:
        if wav.size:
            sess.add_audio(wav, sr)
    return sess_id, sess.text
def mic_flush(sess_id: Optional[str]):
    """Gradio click handler: finalize the caller's mic session.

    Returns ``(None, final_text)``. The ``None`` clears the gr.State id, so the
    next recording starts a fresh session. The finished session is therefore
    removed from SESS. Previously it stayed registered but unreachable, leaking
    its audio buffer and transcript for the life of the process.
    """
    sess = SESS.pop(sess_id, None) if sess_id else None
    if sess is None:
        return None, ""
    text = sess.flush()
    logger.info("mic_flush ok")
    return None, text
def files_run(files):
    """Gradio click handler: transcribe uploads into ``[[basename, text], ...]`` rows.

    Entries may be path strings or objects exposing ``.name``; anything else is
    skipped. Transcription errors are logged and re-raised to Gradio.
    """
    logger.info(f"files_ui start count={len(files) if files else 0}")
    if not files:
        return []
    paths: List[str] = [
        f if isinstance(f, str) else f.name
        for f in files
        if isinstance(f, str) or hasattr(f, "name")
    ]
    try:
        results = MANAGER.transcribe_files(paths)
    except Exception:
        logger.exception("files_run failed")
        raise
    rows = [[os.path.basename(r["path"]), r["text"]] for r in results]
    logger.info("files_ui ok")
    return rows
# ----------------------------
# UI Definition
# ----------------------------
# Build the Gradio UI and launch it. Both happen at import time, so this module
# appears to be the Space's entry point.
with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo:
    gr.Markdown("### RELEASE: GIGA-CHAD-v.0.7")
    # Static feature overview shown above the tabs.
    features_data = [
        ["Model Setup", "Loads Parakeet-TDT-0.6b-v3 (RNNT-based) with MALSD "
         "batched beam-search decoding (label-looping variant enabled)."],
        ["Audio Handling", "Resamples to 16kHz mono, supports various formats."],
        ["Streaming (Mic)", "Partial hypotheses for seamless updates, "
         "session-based for multi-chunk context."],
        ["UI", "Gradio tabs—Mic for live input/output (flush to finalize), "
         "Files for batch results table."],
        ["Tech Stack", "NeMo (ASR core), Gradio (web UI), Torchaudio/Soundfile "
         "(audio utils)."],
    ]
    gr.Dataframe(
        value=features_data,
        headers=["Feature", "Description"],
        datatype=["text", "text"],
        row_count=(len(features_data), "fixed"),
        col_count=(2, "fixed"),
        interactive=False,
        wrap=True,
    )

    with gr.Tab("Mic"):
        mic = gr.Audio(
            sources=["microphone"], type="numpy", streaming=True, label="Speak"
        )
        text_out = gr.Textbox(label="Transcript", lines=4)
        flush_btn = gr.Button("Flush")
        # Holds only the session id; the StreamingSession itself lives in SESS.
        state_id = gr.State()
        # Every streamed chunk goes to mic_step, which returns (session id, transcript).
        mic.stream(
            mic_step, inputs=[mic, state_id], outputs=[state_id, text_out]
        )
        # Flush decodes the remaining audio and clears the session id.
        flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])

    with gr.Tab("Files"):
        files = gr.File(
            file_count="multiple", type="filepath", label="Upload audio files"
        )
        run_btn = gr.Button("Run")
        results_table = gr.Dataframe(
            headers=["file", "text"],
            label="Results",
            row_count=(5, "dynamic"),
            col_count=(2, "fixed"),
            wrap=True,
        )
        run_btn.click(files_run, inputs=[files], outputs=[results_table])

    with gr.Row():
        with gr.Column():
            demo_description = (
                "<p><strong>Parakeet-TDT v3 ASR Demo: Real-Time Mic & File "
                "Transcription on CPU</strong></p>"
                "<p>This Hugging Face Space demonstrates a lightweight, CPU-based "
                "Automatic Speech Recognition (ASR) application using NVIDIA's "
                "Parakeet-TDT-0.6b-v3 model from NeMo. Unlike NVIDIA's official demo "
                "(which only supports file uploads), this app shines with "
                "<strong>real-time microphone streaming</strong>: transcribe live "
                "speech incrementally with high quality and context retention. "
                "It's perfect for interactive demos, voice notes, or testing "
                "multilingual ASR without a GPU.</p>"
            )
            gr.HTML(demo_description)
        with gr.Column():
            usage_html = (
                "<h3>Usage</h3>"
                "<ol>"
                "<li><strong>Mic Tab</strong>: Click \"RECORD\" then speak into "
                "your mic - text updates live. Click \"Flush\" to transcribe any "
                "audio still buffered and finalize the transcript; the next "
                "recording starts a fresh session.</li>"
                "<li><strong>Files Tab</strong>: Upload audio files (WAV); click "
                "\"Run\" for transcripts. (Tested only WAV files, TODO: handle "
                "more types like mp4)</li>"
                "</ol>"
            )
            gr.HTML(usage_html)
            limitations_html = (
                "<h3>Limitations</h3>"
                "<ul>"
                "<li>Sessions are per browser tab (Gradio state), but all users "
                "share a single CPU model instance; behaviour with many "
                "simultaneous users has not been tested.</li>"
                "<li>To be sure, duplicate this Space or clone it to your own PC - "
                "for full privacy; no GPU needed.</li>"
                "</ul>"
            )
            gr.HTML(limitations_html)
            highlights_html = (
                "<h3>Why is this Space amazing? (For people looking for low-level stuff "
                "of \"AI\" - yeah, I did it! BEAM! Streaming, no greedy_batch trash)</h3>"
                "<ul>"
                f"<li><strong>Real-Time Mic Mode</strong>: Streams audio in {CHUNK_S:g}s chunks, "
                "merging hypotheses for smooth, cumulative transcripts. Handles "
                "conversations with retained context.</li>"
                "<li><strong>Advanced Decoding</strong>: Uses modern MALSD batch beam "
                f"search (beam={BEAM_SIZE}) for accurate, error-resistant results, outperforming "
                "basic greedy methods in ambiguous audio.</li>"
                "<li><strong>CPU Efficiency</strong>: Runs fast on standard hardware (no "
                "GPU needed), with inference-only configs such as timestamp "
                "computation disabled.</li>"
                "<li><strong>File Mode Bonus</strong>: Batch transcribes uploads for "
                "quick comparisons.</li>"
                "<li><strong>Quality Edge</strong>: Approaches ideal transcripts with "
                "minimal artifacts, making it well suited for developers/testing vs. "
                "static NVIDIA spaces.</li>"
                "</ul>"
            )
            gr.HTML(highlights_html)
            todo_html = (
                "<h3>TODO:</h3>"
                "<ul>"
                "<li>Change string-level to token level (y_sequence) hypothesis alignment "
                "(quality improvement, advanced technical stuff ;))</li>"
                "</ul>"
                "<p>Contributions welcome! Fork and PR improvements.</p>"
                "<p>Built with ❤️ using Grok's guidance.</p>"
            )
            gr.HTML(todo_html)
    gr.HTML(
        "<p>If you redistribute transcripts or fine-tuned weights, "
        "please retain the CC-BY-4.0 attribution notice.</p>"
    )

demo.queue().launch(ssr_mode=False)