import os
import time
import torch
import torchaudio
from transformers import (
    Wav2Vec2Processor, HubertForCTC,
    WhisperProcessor, WhisperForConditionalGeneration,
    Wav2Vec2ForCTC, AutoProcessor, AutoModelForCTC,
)
from .cmu_process import text_to_phoneme, cmu_to_ipa, clean_cmu
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# === Helper: move all tensors to model device ===
def to_device(batch, device):
    """Move a dict of tensors, or a bare tensor, onto the target device."""
    if isinstance(batch, dict):
        return {k: v.to(device) for k, v in batch.items()}
    elif isinstance(batch, torch.Tensor):
        return batch.to(device)
    return batch
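# Illustrative usage (a sketch; the tensors below are placeholders):
#   to_device({"input_values": torch.zeros(1, 16000)}, device)  # dict of tensors
#   to_device(torch.zeros(1, 16000), device)                    # bare tensor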

# === Setup: Load all 6 models ===
# 1. Base HuBERT
base_proc = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
base_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device).eval()
# 2. Whisper + phonemizer
whisper_proc = WhisperProcessor.from_pretrained("openai/whisper-base")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device).eval()
# 3. Fine-tuned HuBERT: tecasoftai/hubert-finetune (optional HF token via env)
HF_TOKEN = os.environ.get("HF_TOKEN")
proc = Wav2Vec2Processor.from_pretrained("tecasoftai/hubert-finetune", token=HF_TOKEN)
model = HubertForCTC.from_pretrained("tecasoftai/hubert-finetune", token=HF_TOKEN).to(device).eval()
# 4. vitouphy/wav2vec2-xls-r-300m-timit-phoneme
timit_proc = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
timit_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme").to(device).eval()
# 5. bookbot/wav2vec2-ljspeech-gruut
gruut_processor = AutoProcessor.from_pretrained("bookbot/wav2vec2-ljspeech-gruut")
gruut_model = AutoModelForCTC.from_pretrained("bookbot/wav2vec2-ljspeech-gruut").to(device).eval()
# 6. speech31/wavlm-large-english-phoneme
wavlm_proc = AutoProcessor.from_pretrained("speech31/wavlm-large-english-phoneme")
wavlm_model = AutoModelForCTC.from_pretrained("speech31/wavlm-large-english-phoneme").to(device).eval()

# === Inference functions ===
def run_hubert_base(wav):
    start = time.time()
    inputs = base_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).input_values
    inputs = inputs.to(device)
    with torch.no_grad():
        logits = base_model(inputs).logits
    ids = torch.argmax(logits, dim=-1)
    text = base_proc.batch_decode(ids)[0]
    # Convert transcript → CMU phonemes (stress stripped), then CMU → IPA
    phonemes = text_to_phoneme(text)
    phonemes = cmu_to_ipa(phonemes)
    return phonemes.strip(), time.time() - start
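
# Sketch of the text → IPA path used above (outputs are illustrative
# assumptions; the actual mapping lives in cmu_process):
#   text_to_phoneme("hello")  ~> "HH AH L OW"   (CMU, stress stripped)
#   cmu_to_ipa("HH AH L OW")  ~> "hʌloʊ"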

def run_whisper(wav):
    start = time.time()
    inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features.to(device)
    attention_mask = inputs.get("attention_mask", None)
    gen_kwargs = {"language": "en"}
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask.to(device)
    with torch.no_grad():
        pred_ids = whisper_model.generate(input_features, **gen_kwargs)
    text = whisper_proc.batch_decode(pred_ids, skip_special_tokens=True)[0]
    # Whisper emits orthographic text; convert it to CMU phonemes, then IPA
    phonemes = text_to_phoneme(text)
    phonemes = cmu_to_ipa(phonemes)
    return phonemes.strip(), time.time() - start

def run_model(wav):
    start = time.time()
    # Prepare input (BatchEncoding supports .to(device))
    inputs = proc(wav, sampling_rate=16000, return_tensors="pt").to(device)
    # Forward pass
    with torch.no_grad():
        logits = model(**inputs).logits
    # Greedy decode
    ids = torch.argmax(logits, dim=-1)
    phonemes = proc.batch_decode(ids)[0]
    phonemes = cmu_to_ipa(phonemes)
    return phonemes.strip(), time.time() - start

def run_timit(wav):
    start = time.time()
    # Read and process the input
    inputs = timit_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True)
    inputs = inputs.to(device)
    # Forward pass
    with torch.no_grad():
        logits = timit_model(inputs.input_values, attention_mask=inputs.attention_mask).logits
    # Decode ids into a phoneme string
    predicted_ids = torch.argmax(logits, dim=-1)
    phonemes = timit_proc.batch_decode(predicted_ids)[0]
    return phonemes.strip(), time.time() - start

def run_gruut(wav):
    start = time.time()
    # Preprocess waveform → model input
    inputs = gruut_processor(
        wav,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    ).to(device)
    # Forward pass
    with torch.no_grad():
        logits = gruut_model(**inputs).logits
    # Greedy decode → IPA phonemes
    pred_ids = torch.argmax(logits, dim=-1)
    phonemes = gruut_processor.batch_decode(pred_ids)[0]
    return phonemes.strip(), time.time() - start

def run_wavlm_large_phoneme(wav):
    start = time.time()
    # Preprocess waveform → model input
    inputs = wavlm_proc(
        wav,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    ).to(device)
    input_values = inputs.input_values
    attention_mask = inputs.get("attention_mask", None)
    # Forward pass
    with torch.no_grad():
        logits = wavlm_model(input_values, attention_mask=attention_mask).logits
    # Greedy decode → phoneme tokens
    pred_ids = torch.argmax(logits, dim=-1)
    phonemes = wavlm_proc.batch_decode(pred_ids)[0]
    return phonemes.strip(), time.time() - start
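
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original pipeline): "sample.wav"
    # is a hypothetical path; swap in any speech clip. All models above expect
    # mono 16 kHz audio.
    wav, sr = torchaudio.load("sample.wav")
    wav = wav.mean(dim=0)  # downmix to mono
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    wav = wav.numpy()

    for name, fn in [
        ("hubert-base", run_hubert_base),
        ("whisper", run_whisper),
        ("hubert-finetune", run_model),
        ("timit-phoneme", run_timit),
        ("ljspeech-gruut", run_gruut),
        ("wavlm-phoneme", run_wavlm_large_phoneme),
    ]:
        phonemes, secs = fn(wav)
        print(f"{name:>16}: {phonemes}  ({secs:.2f}s)")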