WJ88 commited on
Commit
4b17dd2
·
verified ·
1 Parent(s): 85b6293

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -7
app.py CHANGED
@@ -27,7 +27,7 @@ MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
27
  TARGET_SR = 16_000
28
  BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32")) # Increased for subtle quality gains
29
  OFFLINE_BATCH= int(os.environ.get("PARAKEET_BATCH", "8"))
30
- CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "4.0"))
31
  FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))
32
 
33
  # ----------------------------
@@ -108,9 +108,9 @@ class ParakeetManager:
108
  self._set_malsd_beam()
109
 
110
  # Enable encoder caching for better streaming context (per NeMo docs/tutorials)
111
- # if hasattr(self.model.encoder, "set_default_att_context_size"):
112
- # self.model.encoder.set_default_att_context_size([512, 16]) # Large left for cumulative context, small right for buffering
113
- # logger.info("encoder_caching_enabled left=512 right=16")
114
 
115
  logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")
116
 
@@ -191,7 +191,12 @@ class StreamingSession:
191
  self.pending = self.pending[C:]
192
  try:
193
  self.hyp = self.mgr.stream_step(chunk, self.hyp)
194
- self.text = getattr(self.hyp, "text", self.text) # Simple overwrite: trusts cumulative hyp.text
 
 
 
 
 
195
  except Exception:
196
  logger.exception("mic_step failed")
197
  break
@@ -202,11 +207,17 @@ class StreamingSession:
202
  final = np.concatenate([self.pending, pad])
203
  try:
204
  self.hyp = self.mgr.stream_step(final, self.hyp)
205
- self.text = getattr(self.hyp, "text", self.text) # Simple overwrite: trusts cumulative hyp.text
 
 
 
 
 
 
206
  except Exception:
207
  logger.exception("mic_flush failed")
208
  self.pending = np.zeros(0, dtype=np.float32)
209
- return self.text # No forced punctuation—let model handle it
210
 
211
  # ----------------------------
212
  # Simple session registry (avoid deepcopy in gr.State)
 
27
  TARGET_SR = 16_000
28
  BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "32")) # Increased for subtle quality gains
29
  OFFLINE_BATCH= int(os.environ.get("PARAKEET_BATCH", "8"))
30
+ CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "2.0"))
31
  FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))
32
 
33
  # ----------------------------
 
108
  self._set_malsd_beam()
109
 
110
  # Enable encoder caching for better streaming context (per NeMo docs/tutorials)
111
+ if hasattr(self.model.encoder, "set_default_att_context_size"):
112
+ self.model.encoder.set_default_att_context_size([512, 16]) # Large left for cumulative context, small right for buffering
113
+ logger.info("encoder_caching_enabled left=512 right=16")
114
 
115
  logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")
116
 
 
191
  self.pending = self.pending[C:]
192
  try:
193
  self.hyp = self.mgr.stream_step(chunk, self.hyp)
194
+ new_text = getattr(self.hyp, "text", "")
195
+ if new_text:
196
+ if self.text and new_text.startswith(self.text): # If cumulative (partial extends), replace with extended
197
+ self.text = new_text
198
+ else: # Else append (handles per-chunk case)
199
+ self.text += (' ' if self.text else '') + new_text
200
  except Exception:
201
  logger.exception("mic_step failed")
202
  break
 
207
  final = np.concatenate([self.pending, pad])
208
  try:
209
  self.hyp = self.mgr.stream_step(final, self.hyp)
210
+ new_text = getattr(self.hyp, "text", "")
211
+ if new_text:
212
+ if self.text and new_text.startswith(self.text):
213
+ self.text = new_text
214
+ else:
215
+ self.text += (' ' if self.text else '') + new_text
216
+ self.text += '.' # Add period for sentence closure on flush
217
  except Exception:
218
  logger.exception("mic_flush failed")
219
  self.pending = np.zeros(0, dtype=np.float32)
220
+ return self.text
221
 
222
  # ----------------------------
223
  # Simple session registry (avoid deepcopy in gr.State)