from __future__ import annotations import os import copy import uuid import logging from typing import List, Optional, Tuple, Dict # Reduce progress/log spam before heavy imports os.environ.setdefault("TQDM_DISABLE", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") import numpy as np import torch import torchaudio import soundfile as sf import gradio as gr # NeMo from nemo.collections.asr.models import ASRModel from omegaconf import OmegaConf from nemo.utils import logging as nemo_logging # ---------------------------- # Config # ---------------------------- MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3") TARGET_SR = 16_000 BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32")) # Increased for subtle quality gains OFFLINE_BATCH= int(os.environ.get("PARAKEET_BATCH", "8")) CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "2.0")) FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0")) # ---------------------------- # Logging (unified) # ---------------------------- LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper() logger = logging.getLogger("parakeet_app") logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO)) _handler = logging.StreamHandler() _handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")) logger.handlers = [_handler] logger.propagate = False # Quiet NeMo logs nemo_logging.setLevel(logging.ERROR) logging.getLogger("nemo").setLevel(logging.ERROR) logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR) torch.set_grad_enabled(False) # ---------------------------- # Audio utils # ---------------------------- def to_mono_np(x: np.ndarray) -> np.ndarray: if x.ndim == 2: x = x.mean(axis=1) return x.astype(np.float32, copy=False) class ResamplerCache: def __init__(self): self._cache: Dict[int, torchaudio.transforms.Resample] = {} def resample(self, wav: np.ndarray, src_sr: int) -> np.ndarray: if src_sr == TARGET_SR: return wav if src_sr not in self._cache: logger.debug(f"create_resampler src_sr={src_sr} -> {TARGET_SR}") self._cache[src_sr] = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=TARGET_SR) t = torch.from_numpy(wav) if t.ndim == 1: t = t.unsqueeze(0) y = self._cache[src_sr](t) return y.squeeze(0).numpy() RESAMPLER = ResamplerCache() def load_mono16k(path: str) -> np.ndarray: """Load any audio file, convert to mono float32 at 16 kHz.""" try: wav, sr = sf.read(path, dtype="float32", always_2d=True) # (T,C) wav = wav.mean(axis=1).astype(np.float32, copy=False) return RESAMPLER.resample(wav, sr) except Exception: wav_t, sr = torchaudio.load(path) # (C,T) if wav_t.dtype != torch.float32: wav_t = wav_t.float() wav = wav_t.mean(dim=0).numpy() return RESAMPLER.resample(wav, int(sr)) # ---------------------------- # Model manager (MALSD batched beam everywhere, loop_labels=True) # ---------------------------- class ParakeetManager: def __init__(self, device: str = "cpu"): self.device = torch.device(device) logger.info(f"loading_model name={MODEL_NAME} device={self.device}") self.model: ASRModel = ASRModel.from_pretrained(model_name=MODEL_NAME) self.model.to(self.device) self.model.eval() for p in self.model.parameters(): p.requires_grad = False # Base decoding cfg differs by class if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "decoder"): self._base_decoding = copy.deepcopy(self.model.decoder.decoder.cfg) else: self._base_decoding = copy.deepcopy(self.model.cfg.decoding) self._set_malsd_beam() # Enable encoder caching for better streaming context (per NeMo docs/tutorials) if hasattr(self.model.encoder, "set_default_att_context_size"): self.model.encoder.set_default_att_context_size([512, 16]) # Large left for cumulative context, small right for buffering logger.info("encoder_caching_enabled left=512 right=16") logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}") def _set_malsd_beam(self): cfg = copy.deepcopy(self._base_decoding) cfg.strategy = "malsd_batch" cfg.beam = OmegaConf.create({ "beam_size": BEAM_SIZE, "return_best_hypothesis": True, "score_norm": True, "allow_cuda_graphs": False, # CPU-only "max_symbols_per_step": 10, }) OmegaConf.set_struct(cfg, False) cfg["loop_labels"] = True cfg["fused_batch_size"] = -1 cfg["compute_timestamps"] = False if hasattr(cfg, "greedy"): cfg.greedy.use_cuda_graph_decoder = False self.model.change_decoding_strategy(cfg) logger.info("decoding_set strategy=malsd_batch loop_labels=True") def _transcribe(self, items: List, *, partial=None): with torch.inference_mode(): return self.model.transcribe( items, batch_size=1 if len(items) == 1 else OFFLINE_BATCH, num_workers=0, return_hypotheses=True, partial_hypothesis=partial, ) # Offline batch def transcribe_files(self, paths: List[str]): n = 0 if not paths else len(paths) logger.info(f"files_run start count={n} batch={OFFLINE_BATCH}") if not paths: return [] arrays = [load_mono16k(p) for p in paths] out = self._transcribe(arrays, partial=None) results = [] for p, o in zip(paths, out): h = o[0] if isinstance(o, list) and o else o text = h if isinstance(h, str) else getattr(h, "text", "") results.append({"path": p, "text": text}) logger.info("files_run ok") return results # Streaming step (rolling hypothesis) def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object: out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None) h = out[0][0] if isinstance(out[0], list) else out[0] return h # Hypothesis # ---------------------------- # Streaming session (no overlap, rolling hypothesis) # ---------------------------- class StreamingSession: def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float): self.mgr = manager self.chunk_s = chunk_s self.flush_pad_s = flush_pad_s self.hyp = None self.pending = np.zeros(0, dtype=np.float32) self.text = "" logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s") def add_audio(self, audio: np.ndarray, src_sr: int): mono = to_mono_np(audio) res = RESAMPLER.resample(mono, src_sr) self.pending = np.concatenate([self.pending, res]) if self.pending.size else res self._drain() def _drain(self): C = int(self.chunk_s * TARGET_SR) while self.pending.size >= C: chunk = self.pending[:C] self.pending = self.pending[C:] try: self.hyp = self.mgr.stream_step(chunk, self.hyp) new_text = getattr(self.hyp, "text", "") if new_text: if self.text and new_text.startswith(self.text): # If cumulative (partial extends), replace with extended self.text = new_text else: # Else append (handles per-chunk case) self.text += (' ' if self.text else '') + new_text except Exception: logger.exception("mic_step failed") break def flush(self) -> str: if self.pending.size: pad = np.zeros(int(self.flush_pad_s * TARGET_SR), dtype=np.float32) final = np.concatenate([self.pending, pad]) try: self.hyp = self.mgr.stream_step(final, self.hyp) new_text = getattr(self.hyp, "text", "") if new_text: if self.text and new_text.startswith(self.text): self.text = new_text else: self.text += (' ' if self.text else '') + new_text self.text += '.' # Add period for sentence closure on flush except Exception: logger.exception("mic_flush failed") self.pending = np.zeros(0, dtype=np.float32) return self.text # ---------------------------- # Simple session registry (avoid deepcopy in gr.State) # ---------------------------- SESS: Dict[str, StreamingSession] = {} def _new_session_id() -> str: return uuid.uuid4().hex # ---------------------------- # Gradio callbacks # ---------------------------- MANAGER = ParakeetManager(device="cpu") def _parse_gr_audio(x) -> Tuple[np.ndarray, int]: if x is None: return np.zeros(0, dtype=np.float32), TARGET_SR if isinstance(x, tuple) and len(x) == 2: sr = int(x[0]); arr = np.array(x[1], dtype=np.float32); return arr, sr if isinstance(x, dict) and "data" in x and "sampling_rate" in x: arr = np.array(x["data"], dtype=np.float32); sr = int(x["sampling_rate"]); return arr, sr if isinstance(x, np.ndarray): return x.astype(np.float32, copy=False), TARGET_SR logger.error(f"unsupported_gr_audio_payload type={type(x)}"); raise ValueError("Unsupported audio payload") def mic_step(audio_chunk, sess_id: Optional[str]): if not sess_id or sess_id not in SESS: sess_id = _new_session_id() SESS[sess_id] = StreamingSession(MANAGER, CHUNK_S, FLUSH_PAD_S) sess = SESS[sess_id] try: wav, sr = _parse_gr_audio(audio_chunk) except Exception: logger.exception("mic_parse failed") return sess_id, sess.text if wav.size: sess.add_audio(wav, sr) return sess_id, sess.text def mic_flush(sess_id: Optional[str]): if not sess_id or sess_id not in SESS: return None, "" text = SESS[sess_id].flush() logger.info("mic_flush ok") return None, text def files_run(files): n = 0 if not files else len(files) logger.info(f"files_ui start count={n}") if not files: return [] paths: List[str] = [] for f in files: if isinstance(f, str): paths.append(f) elif hasattr(f, "name"): paths.append(f.name) try: results = MANAGER.transcribe_files(paths) except Exception: logger.exception("files_run failed"); raise table = [[os.path.basename(r["path"]), r["text"]] for r in results] logger.info("files_ui ok") return table # ---------------------------- # UI Definition # ---------------------------- with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo: gr.Markdown("### RELEASE: GIGA-CHAD-v.0.7") features_data = [ ["Model Setup", "Loads Parakeet-TDT-0.6b-v3 (RNNT-based) with MALSD " "decoding for beam exploration and loop labels for alignments."], ["Audio Handling", "Resamples to 16kHz mono, supports various formats."], ["Streaming (Mic)", "Partial hypotheses for seamless updates, " "session-based for multi-chunk context."], ["UI", "Gradio tabs—Mic for live input/output (flush to finalize), " "Files for batch results table."], ["Tech Stack", "NeMo (ASR core), Gradio (web UI), Torchaudio/Soundfile " "(audio utils)."], ] gr.Dataframe( value=features_data, headers=["Feature", "Description"], datatype=["text", "text"], row_count=(len(features_data), "fixed"), col_count=(2, "fixed"), interactive=False, wrap=True, ) with gr.Tab("Mic"): mic = gr.Audio( sources=["microphone"], type="numpy", streaming=True, label="Speak" ) text_out = gr.Textbox(label="Transcript", lines=4) flush_btn = gr.Button("Flush") state_id = gr.State() mic.stream( mic_step, inputs=[mic, state_id], outputs=[state_id, text_out] ) flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out]) with gr.Tab("Files"): files = gr.File( file_count="multiple", type="filepath", label="Upload audio files" ) run_btn = gr.Button("Run") results_table = gr.Dataframe( headers=["file", "text"], label="Results", row_count=(5, "dynamic"), col_count=(2, "fixed"), wrap=True, ) run_btn.click(files_run, inputs=[files], outputs=[results_table]) with gr.Row(): with gr.Column(): demo_description = ( "

Parakeet-TDT v3 ASR Demo: Real-Time Mic & File " "Transcription on CPU

" "

This Hugging Face Space demonstrates a lightweight, CPU-based " "Automatic Speech Recognition (ASR) application using NVIDIA's " "Parakeet-TDT-0.6b-v3 model from NeMo. Unlike NVIDIA's official demo " "(which only supports file uploads), this app shines with " "real-time microphone streaming transcribe live " "speech incrementally with high quality and context retention. " "It's perfect for interactive demos, voice notes, or testing " "multilingual ASR without a GPU.

" ) gr.HTML(demo_description) with gr.Column(): usage_html = ( "

Usage

" "
    " "
  1. Mic Tab: Click \"RECORD\" then speak into " "your mic - text updates live. \"Flush\" button does nothing, " "it's a feature :)
  2. " "
  3. Files Tab: Upload audio files (WAV); click " "\"Run\" for transcripts. (Tested only WAV files, TODO: handle " "more types like mp4)
  4. " "
" ) gr.HTML(usage_html) limitations_html = ( "

Limitations

" "" ) gr.HTML(limitations_html) highlights_html = ( "

Why is this Space amazing? (For people looking for low-level stuff " "of \"AI\" - yeah, I did it! BEAM! Streaming, no greedy_batch trash)

" "" ) gr.HTML(highlights_html) todo_html = ( "

TODO:

" "" "

Contributions welcome! Fork and PR improvements.

" "

Built with ❤️ using Grok's guidance.

" ) gr.HTML(todo_html) gr.HTML( "

If you redistribute transcripts or fine-tuned weights, " "please retain the CC-BY-4.0 attribution notice.

" ) demo.queue().launch(ssr_mode=False)