# Phoneme transcription backends. Each run_*(wav) wrapper below takes 16 kHz
# mono audio and returns a (phoneme string, elapsed seconds) tuple.
import os
import time

import torch
import torchaudio
from dotenv import load_dotenv
from transformers import (
    Wav2Vec2Processor, HubertForCTC,
    WhisperProcessor, WhisperForConditionalGeneration,
    Wav2Vec2ForCTC, AutoProcessor, AutoModelForCTC,
)

from .cmu_process import text_to_phoneme, cmu_to_ipa, clean_cmu

load_dotenv()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


def to_device(batch, device):
    """Move a tensor, or a dict of tensors, onto the given device."""
    if isinstance(batch, dict):
        return {k: v.to(device) for k, v in batch.items()}
    elif isinstance(batch, torch.Tensor):
        return batch.to(device)
    return batch
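
# Illustrative helper (an assumption, not part of the original pipeline): the
# run_*() functions below all expect 16 kHz mono audio, and a caller could
# prepare such input with torchaudio along these lines.
def load_wav_16k(path):
    waveform, sr = torchaudio.load(path)              # (channels, samples)
    if waveform.size(0) > 1:                          # downmix to mono
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != 16000:                                   # resample to model rate
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    return waveform.squeeze(0).numpy()                # 1-D float array
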

# HuBERT-large fine-tuned on LibriSpeech: English ASR, text output.
base_proc = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
base_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(device).eval()

# Whisper base: multilingual seq2seq ASR, text output.
whisper_proc = WhisperProcessor.from_pretrained("openai/whisper-base")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device).eval()

# Fine-tuned HuBERT phoneme model; may need an HF access token from .env.
HF_TOKEN = os.environ.get("HF_TOKEN")
proc = Wav2Vec2Processor.from_pretrained("tecasoftai/hubert-finetune", token=HF_TOKEN)
model = HubertForCTC.from_pretrained("tecasoftai/hubert-finetune", token=HF_TOKEN).to(device).eval()

# wav2vec2 XLS-R fine-tuned on TIMIT: direct phoneme output.
timit_proc = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
timit_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme").to(device).eval()

# wav2vec2 trained on LJSpeech with gruut phonemization: direct IPA output.
gruut_processor = AutoProcessor.from_pretrained("bookbot/wav2vec2-ljspeech-gruut")
gruut_model = AutoModelForCTC.from_pretrained("bookbot/wav2vec2-ljspeech-gruut").to(device).eval()

# WavLM-large English phoneme model: direct phoneme output.
wavlm_proc = AutoProcessor.from_pretrained("speech31/wavlm-large-english-phoneme")
wavlm_model = AutoModelForCTC.from_pretrained("speech31/wavlm-large-english-phoneme").to(device).eval()

def run_hubert_base(wav):
    """ASR with HuBERT, then convert the transcript to IPA via the CMU dict."""
    start = time.time()
    inputs = base_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True).input_values
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = base_model(inputs).logits
    ids = torch.argmax(logits, dim=-1)
    text = base_proc.batch_decode(ids)[0]

    phonemes = text_to_phoneme(text)
    phonemes = cmu_to_ipa(phonemes)
    return phonemes.strip(), time.time() - start


def run_whisper(wav):
    """ASR with Whisper, then convert the transcript to IPA via the CMU dict."""
    start = time.time()

    inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
    input_features = inputs.input_features.to(device)
    attention_mask = inputs.get("attention_mask", None)
    gen_kwargs = {"language": "en"}
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask.to(device)

    with torch.no_grad():
        pred_ids = whisper_model.generate(input_features, **gen_kwargs)

    text = whisper_proc.batch_decode(pred_ids, skip_special_tokens=True)[0]
    phonemes = text_to_phoneme(text)
    phonemes = cmu_to_ipa(phonemes)
    return phonemes.strip(), time.time() - start


def run_model(wav):
    """Phoneme CTC decoding with the fine-tuned HuBERT; CMU output mapped to IPA."""
    start = time.time()
    inputs = proc(wav, sampling_rate=16000, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    ids = torch.argmax(logits, dim=-1)
    phonemes = proc.batch_decode(ids)[0]
    phonemes = cmu_to_ipa(phonemes)
    return phonemes.strip(), time.time() - start


def run_timit(wav):
    """Direct phoneme CTC decoding with the TIMIT-tuned wav2vec2 model."""
    start = time.time()
    inputs = timit_proc(wav, sampling_rate=16000, return_tensors="pt", padding=True)
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = timit_model(inputs.input_values, attention_mask=inputs.attention_mask).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    phonemes = "".join(timit_proc.batch_decode(predicted_ids))
    return phonemes.strip(), time.time() - start


def run_gruut(wav):
    """Direct IPA phoneme CTC decoding with the gruut/LJSpeech wav2vec2 model."""
    start = time.time()
    inputs = gruut_processor(
        wav,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        logits = gruut_model(**inputs).logits

    pred_ids = torch.argmax(logits, dim=-1)
    phonemes = gruut_processor.batch_decode(pred_ids)[0]
    return phonemes.strip(), time.time() - start


def run_wavlm_large_phoneme(wav):
    """Direct phoneme CTC decoding with the WavLM-large English phoneme model."""
    start = time.time()
    inputs = wavlm_proc(
        wav,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    ).to(device)

    input_values = inputs.input_values
    attention_mask = inputs.get("attention_mask", None)

    with torch.no_grad():
        logits = wavlm_model(input_values, attention_mask=attention_mask).logits

    pred_ids = torch.argmax(logits, dim=-1)
    phonemes = wavlm_proc.batch_decode(pred_ids)[0]
    return phonemes.strip(), time.time() - start
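

# Minimal smoke test (a sketch; "sample.wav" is a hypothetical local file):
# run every backend on the same clip and print phonemes plus wall-clock time.
# Note: the relative import at the top means this module has to be run as part
# of its package (python -m <package>.<module>), not as a standalone script.
if __name__ == "__main__":
    wav = load_wav_16k("sample.wav")
    runners = {
        "hubert-base": run_hubert_base,
        "whisper": run_whisper,
        "hubert-finetune": run_model,
        "timit": run_timit,
        "gruut": run_gruut,
        "wavlm": run_wavlm_large_phoneme,
    }
    for name, fn in runners.items():
        phonemes, elapsed = fn(wav)
        print(f"{name:>16}: {phonemes}  ({elapsed:.2f}s)")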