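"""Benchmark phoneme-recognition models on the configured datasets.

For every dataset in DATASETS, each model in MODEL_RUNNERS is run over a
capped sample of examples; phoneme error rate (PER) is computed against the
reference transcription, and one leaderboard JSON file per model is written
to eval-results/.

Usage (the script name here is illustrative):
    python evaluate.py [--max-samples N] [--dataset NAME]
"""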
import argparse
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
from datasets import Audio, Dataset, load_dataset

from constants import DATASETS, FINAL_SIZE
from utils.audio_process import calculate_error_rate, load_audio
from utils.cmu_process import clean_cmu, cmu_to_ipa, text_to_phoneme
from utils.load_model import (
    run_gruut,
    run_hubert_base,
    run_model,
    run_timit,
    run_wavlm_large_phoneme,
    run_whisper,
)

# Map model names to their runner functions
MODEL_RUNNERS = {
"HuBERT-Base": run_hubert_base,
"Whisper": run_whisper,
"HuBERT fine-tuned": run_model,
"Timit": run_timit,
"WavLM": run_wavlm_large_phoneme,
"LJSpeech Gruut": run_gruut,
}
def set_output(model, pre_pho, ref_pho, duration, per, score):
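    """Package one model's predicted phonemes and metrics into a result row."""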
return {
"model": model,
"phonemes": pre_pho,
"ref_phonemes": ref_pho,
"duration": duration,
"PER": per,
"score": score
}
def get_output(model, wav, reference_phoneme):
"""
Run the given model, compute error rate, and return formatted output.
"""
if model not in MODEL_RUNNERS:
raise ValueError(f"Unknown model: {model}")
run_func = MODEL_RUNNERS[model]
phonemes, dur = run_func(wav)
per, score = calculate_error_rate(reference_phoneme, phonemes)
return set_output(model, phonemes, reference_phoneme, dur, per, score)
def benchmark_all(example):
"""
Run all models on a single dataset example in parallel.
"""
# Load waveform manually to avoid datasets' torchcodec dependency
wav = load_audio(example["audio"])
reference_phoneme = example["phonetic"]
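    # Normalize the CMU/ARPAbet reference to IPA before comparing with model output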
reference_phoneme = cmu_to_ipa(clean_cmu(reference_phoneme))
    # Run all models in parallel using ThreadPoolExecutor
    # Benchmark every registered model; dict order is insertion order
    models = list(MODEL_RUNNERS)
with ThreadPoolExecutor(max_workers=len(models)) as executor:
futures = [
executor.submit(get_output, model, wav, reference_phoneme)
for model in models
]
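        # result() blocks until the worker finishes and re-raises its exceptions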
results = [future.result() for future in futures]
return pd.DataFrame(results)
def benchmark_dataset(dataset):
"""
Run benchmark_all on each sample and compute average PER and duration per model.
"""
all_results = []
for example in dataset:
df = benchmark_all(example)
all_results.append(df)
full_df = pd.concat(all_results, ignore_index=True)
# Compute average PER and duration per model
avg_stats = (
full_df.groupby("model")[["PER", "duration"]]
.mean()
.reset_index()
.rename(columns={"PER": "Average PER", "duration": "Average Duration (s)"})
)
return full_df, avg_stats
def load_dataset_with_limits(dataset_config, max_samples=None, use_streaming=False):
"""
Load a dataset with optional size limits and streaming.
Args:
dataset_config: Dictionary containing dataset configuration
max_samples: Maximum number of samples to load (None for no limit)
use_streaming: Whether to use streaming for large datasets
Returns:
Dataset object
"""
try:
# Prepare load_dataset arguments
load_args = {
"path": dataset_config["name"],
"split": dataset_config["split"]
}
# Add config if specified
if "config" in dataset_config:
load_args["name"] = dataset_config["config"]
# Add streaming if requested
if use_streaming:
load_args["streaming"] = True
print(f"Loading {dataset_config['name']} with streaming...")
else:
print(f"Loading {dataset_config['name']}...")
dataset = load_dataset(**load_args)
# Apply size limits
if max_samples is not None:
print(f"Limiting dataset to {max_samples} samples...")
if use_streaming:
dataset = dataset.take(max_samples)
else:
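            # Non-streaming datasets support len() and random access via select()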
dataset = dataset.select(range(min(max_samples, len(dataset))))
return dataset
except Exception as e:
print(f"[warn] skip dataset {dataset_config['name']}: {e}")
return None
def parse_cli_args():
"""
Parse and return CLI arguments for the evaluation script.
"""
parser = argparse.ArgumentParser(description='Phoneme Detection Evaluation')
parser.add_argument('--max-samples', type=int, default=None,
help='Override max_samples for all datasets')
    parser.add_argument('--dataset', type=str, default=None,
                        help='Process only a specific dataset (substring match on its name)')
return parser.parse_args()
def cast_audio_column_safely(dataset):
"""
Ensure the dataset's 'audio' column is set to non-decoding Audio.
"""
try:
dataset = dataset.cast_column("audio", Audio(decode=False))
except Exception:
pass
return dataset
def prepare_dataset_for_evaluation(dataset, dataset_config, max_samples):
"""
Normalize, deduplicate, and filter dataset examples for evaluation.
Handles both streaming and non-streaming datasets.
Returns a finalized small dataset suitable for benchmarking.
"""
field = dataset_config["field"]
use_streaming = dataset_config.get("use_streaming", False)
if use_streaming:
print("Processing streaming dataset...")
valid_samples = []
        # max_samples may be None (no CLI/config limit); guard before min()
        streaming_limit = FINAL_SIZE if max_samples is None else min(max_samples, FINAL_SIZE)
for example in dataset:
if field == "text":
phonetic_text = text_to_phoneme(example[field])
example = {**example, "phonetic": phonetic_text}
current_field = "phonetic"
else:
current_field = field
if current_field in example:
phoneme_tokens = example[current_field].split()
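                # Skip very short references (fewer than 10 phoneme tokens)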
if len(phoneme_tokens) >= 10:
valid_samples.append(example)
if len(valid_samples) >= streaming_limit:
break
print(f"Found {len(valid_samples)} valid samples")
if len(valid_samples) == 0:
print("No valid samples found, skipping dataset")
return None
dataset_final = Dataset.from_list(valid_samples)
return dataset_final
else:
if field == "text":
dataset = dataset.map(lambda x: {"phonetic": text_to_phoneme(x[field])})
field = "phonetic"
        # Deduplicate: keep only the first occurrence of each phonetic string
        # (dataset.unique() only lists values; filtering needs first-seen tracking)
        unique_texts = dataset.unique(field)
        print(f"Unique phonetic strings ({dataset_config['name']}): {len(unique_texts)}")
        seen = set()
        def is_first(example):
            if example[field] in seen:
                return False
            seen.add(example[field])
            return True
        dataset_unique = dataset.filter(is_first)
        def is_valid(example):
            return len(example[field].split()) >= 10
        dataset_filtered = dataset_unique.filter(is_valid)
final_size = min(FINAL_SIZE, len(dataset_filtered))
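        # Fixed seed keeps the sampled evaluation subset reproducible across runs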
dataset_final = dataset_filtered.shuffle(seed=42).select(range(final_size))
return dataset_final
def evaluate_dataset(dataset_final):
"""
Run benchmarking on a capped subset of the dataset and return both
the full per-example results and the aggregated stats per model.
"""
benchmark_size = min(FINAL_SIZE, len(dataset_final))
return benchmark_dataset(dataset_final.select(range(benchmark_size)))
def update_aggregates(per_model_results, avg_stats, dataset_name):
"""
Update the aggregate dictionary per model with results from one dataset.
"""
dataset_key = dataset_name.split("/")[-1]
for _, row in avg_stats.iterrows():
model_name = str(row["model"]).replace(" ", "-")
        # groupby().mean() produces NaN (not None) when values are missing
        per = float(row["Average PER"]) if pd.notna(row["Average PER"]) else None
        avg_dur = float(row["Average Duration (s)"]) if pd.notna(row["Average Duration (s)"]) else None
if model_name not in per_model_results:
per_model_results[model_name] = {}
per_model_results[model_name][dataset_key] = {"per": per, "avg_duration": avg_dur}
def save_leaderboard_results(per_model_results, results_dir="eval-results"):
"""
Persist one JSON file per model for the leaderboard app to consume.
"""
os.makedirs(results_dir, exist_ok=True)
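    # A timestamped filename keeps repeated runs from overwriting old results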
timestamp = int(time.time())
for model_name, task_results in per_model_results.items():
        payload = {
            "config": {
                "model_name": model_name,
"model_dtype": "float32",
"model_sha": ""
},
"results": task_results
}
out_path = os.path.join(results_dir, f"results_{timestamp}_{model_name}.json")
with open(out_path, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
print(f"Saved leaderboard result: {out_path}")
def process_single_dataset(dataset_config, args, per_model_results):
"""
Load, normalize, evaluate a single dataset and update aggregates.
"""
if args.dataset and args.dataset not in dataset_config["name"]:
return
max_samples = args.max_samples if args.max_samples is not None else dataset_config.get("max_samples")
use_streaming = dataset_config.get("use_streaming", False)
dataset = load_dataset_with_limits(
dataset_config,
max_samples=max_samples,
use_streaming=use_streaming
)
if dataset is None:
return
dataset = cast_audio_column_safely(dataset)
dataset_final = prepare_dataset_for_evaluation(dataset, dataset_config, max_samples)
if dataset_final is None:
return
print(dataset_final)
print("Final size:", len(dataset_final))
    full_results, avg_stats = evaluate_dataset(dataset_final)
    print(f"Average statistics per model ({dataset_config['name']}):")
    print(avg_stats)
update_aggregates(per_model_results, avg_stats, dataset_config["name"])
def main():
args = parse_cli_args()
per_model_results = {}
for dataset_config in DATASETS:
process_single_dataset(dataset_config, args, per_model_results)
save_leaderboard_results(per_model_results)
if __name__ == "__main__":
main()