Spaces:
Sleeping
Sleeping
| import spaces | |
| import jiwer | |
| import numpy as np | |
| import gradio as gr | |
def calculate_wer(reference, hypothesis):
    """Return the word error rate of the whole hypothesis text against the reference.

    Each input is a sequence of lines. Both are flattened into one
    space-separated string and scored with a single ``jiwer.wer`` call, so the
    result is one rate for the whole text rather than a per-line average.
    """
    flat_reference, flat_hypothesis = (" ".join(lines) for lines in (reference, hypothesis))
    return jiwer.wer(flat_reference, flat_hypothesis)
def calculate_cer(reference, hypothesis):
    """Return the character error rate of the whole hypothesis text against the reference.

    Each input is a sequence of lines. Both are flattened into one
    space-separated string (the joining spaces take part in the character
    comparison) and scored with a single ``jiwer.cer`` call.
    """
    joined = [" ".join(reference), " ".join(hypothesis)]
    return jiwer.cer(*joined)
def calculate_sentence_metrics(reference, hypothesis):
    """Score aligned line pairs and summarise the per-line error rates.

    Every line of both inputs is stripped of surrounding whitespace. Lines are
    then paired by position; lines beyond the length of the shorter input are
    not scored.

    Returns:
        A dict with the per-pair lists ``sentence_wers`` and ``sentence_cers``,
        plus ``average_wer``/``average_cer`` (``np.mean``) and
        ``std_dev_wer``/``std_dev_cer`` (``np.std``, numpy's default ddof=0).
        All four summary values are ``0.0`` when there are no pairs.
    """
    ref_lines = [line.strip() for line in reference]
    hyp_lines = [line.strip() for line in hypothesis]

    sentence_wers, sentence_cers = [], []
    # NOTE(review): jiwer is believed to reject an empty reference string, so a
    # blank reference line may raise here (the caller turns it into an error
    # message) - confirm and decide how blank lines should be scored.
    for ref, hyp in zip(ref_lines, hyp_lines):
        sentence_wers.append(jiwer.wer(ref, hyp))
        sentence_cers.append(jiwer.cer(ref, hyp))

    def _mean_and_std(values):
        # Empty input would make numpy warn and return nan, so report zeros instead.
        if not values:
            return 0.0, 0.0
        return np.mean(values), np.std(values)

    average_wer, std_dev_wer = _mean_and_std(sentence_wers)
    average_cer, std_dev_cer = _mean_and_std(sentence_cers)

    return {
        "sentence_wers": sentence_wers,
        "sentence_cers": sentence_cers,
        "average_wer": average_wer,
        "average_cer": average_cer,
        "std_dev_wer": std_dev_wer,
        "std_dev_cer": std_dev_cer,
    }
def _first_difference(ref_words, hyp_words):
    """Return the index of the first word position where the two word lists differ."""
    for position, (ref_word, hyp_word) in enumerate(zip(ref_words, hyp_words)):
        if ref_word != hyp_word:
            return position
    # The shared prefix matched completely, so the lists differ only after it
    # (one is longer) - the misalignment starts right where the shorter ends.
    return min(len(ref_words), len(hyp_words))


def _highlight_context(words, position):
    """Join ``words`` up to ``position``, bolding the word at ``position`` if it exists."""
    if position < len(words):
        return " ".join(words[:position] + ["**" + words[position] + "**"])
    # This side ran out of words before the misalignment point: show all of it.
    return " ".join(words[:position])


def identify_misaligned_sentences(reference, hypothesis):
    """Report line pairs whose reference and hypothesis text differ.

    Lines are stripped of surrounding whitespace and compared pairwise by
    position. For every differing pair the first differing word is located;
    the context strings contain the words up to and including that word,
    which is wrapped in ``**`` (Markdown bold) on each side that has a word
    there. Lines present in only one input are reported with the placeholder
    ``"No corresponding sentence"`` on the missing side.

    Args:
        reference: Iterable of reference lines (strings).
        hypothesis: Iterable of hypothesis lines (strings).

    Returns:
        A list of dicts with keys ``index`` (1-based line number),
        ``reference``, ``hypothesis``, ``misalignment_start`` (0-based index
        of the first differing word), ``context_ref`` and ``context_hyp``.
    """
    missing = "No corresponding sentence"
    reference_sentences = [line.strip() for line in reference]
    hypothesis_sentences = [line.strip() for line in hypothesis]

    misaligned = []
    for i, (ref, hyp) in enumerate(zip(reference_sentences, hypothesis_sentences)):
        if ref == hyp:
            continue
        ref_words = ref.split()
        hyp_words = hyp.split()
        start = _first_difference(ref_words, hyp_words)
        misaligned.append({
            "index": i + 1,
            "reference": ref,
            "hypothesis": hyp,
            "misalignment_start": start,
            "context_ref": _highlight_context(ref_words, start),
            "context_hyp": _highlight_context(hyp_words, start),
        })

    # Unpaired trailing lines; at most one of these two ranges is non-empty.
    for i in range(len(hypothesis_sentences), len(reference_sentences)):
        misaligned.append({
            "index": i + 1,
            "reference": reference_sentences[i],
            "hypothesis": missing,
            "misalignment_start": 0,
            "context_ref": reference_sentences[i],
            "context_hyp": missing,
        })
    for i in range(len(reference_sentences), len(hypothesis_sentences)):
        misaligned.append({
            "index": i + 1,
            "reference": missing,
            "hypothesis": hypothesis_sentences[i],
            "misalignment_start": 0,
            "context_ref": missing,
            "context_hyp": hypothesis_sentences[i],
        })
    return misaligned
def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer):
    """Render sentence-level WER/CER statistics as a Markdown string.

    The summary values and every per-sentence value are shown with two
    decimals. Sentences are numbered from 1 in list order.
    """
    parts = [
        "### Sentence-level Metrics\n\n",
        f"**Average WER**: {average_wer:.2f}\n\n",
        f"**Standard Deviation WER**: {std_dev_wer:.2f}\n\n",
        f"**Average CER**: {average_cer:.2f}\n\n",
        f"**Standard Deviation CER**: {std_dev_cer:.2f}\n\n",
        "---\n**WER by Sentence**\n",
    ]
    parts.extend(f"- Sentence {n}: {value:.2f}\n" for n, value in enumerate(sentence_wers, start=1))
    parts.append("\n**CER by Sentence**\n")
    parts.extend(f"- Sentence {n}: {value:.2f}\n" for n, value in enumerate(sentence_cers, start=1))
    return "".join(parts)
| def process_files(reference_file, hypothesis_file): | |
| try: | |
| with open(reference_file.name, 'r', encoding='utf-8') as f: | |
| reference_text = f.read().splitlines() | |
| with open(hypothesis_file.name, 'r', encoding='utf-8') as f: | |
| hypothesis_text = f.read().splitlines() | |
| overall_wer = calculate_wer(reference_text, hypothesis_text) | |
| overall_cer = calculate_cer(reference_text, hypothesis_text) | |
| sentence_metrics = calculate_sentence_metrics(reference_text, hypothesis_text) | |
| misaligned = identify_misaligned_sentences(reference_text, hypothesis_text) | |
| return { | |
| "Overall WER": overall_wer, | |
| "Overall CER": overall_cer, | |
| **sentence_metrics, | |
| "Misaligned Sentences": misaligned | |
| } | |
| except Exception as e: | |
| return {"error": str(e)} | |
def process_and_display(ref_file, hyp_file):
    """Run the analysis and shape it for the three Gradio output components.

    Returns:
        A 3-tuple ``(summary, metrics_markdown, misaligned_markdown)``. When
        ``process_files`` reports an error, the summary is
        ``{"error": message}`` and both Markdown strings are empty.
    """
    result = process_files(ref_file, hyp_file)
    if "error" in result:
        return {"error": result["error"]}, "", ""

    summary = {key: result[key] for key in ("Overall WER", "Overall CER")}

    metric_keys = ("sentence_wers", "sentence_cers", "average_wer",
                   "average_cer", "std_dev_wer", "std_dev_cer")
    metrics_md = format_sentence_metrics(*(result[key] for key in metric_keys))

    sections = ["### Misaligned Sentences\n\n"]
    entries = result["Misaligned Sentences"]
    if not entries:
        sections.append("* No misaligned sentences found.")
    for mis in entries:
        sections.append(
            f"**Sentence {mis['index']}**\n"
            f"- Reference: {mis['context_ref']}\n"
            f"- Hypothesis: {mis['context_hyp']}\n"
            f"- Misalignment starts at position: {mis['misalignment_start']}\n\n"
        )
    return summary, metrics_md, "".join(sections)
def main():
    """Build the Gradio interface and start the server with SSR disabled."""
    with gr.Blocks() as app:
        gr.Markdown("# π ASR Metrics Analysis Tool")
        with gr.Row():
            # Left column: file inputs and the trigger button.
            with gr.Column():
                gr.Markdown("### Upload your reference and hypothesis files")
                ref_upload = gr.File(label="Reference File (.txt)")
                hyp_upload = gr.File(label="Hypothesis File (.txt)")
                run_button = gr.Button("Compute Metrics", variant="primary")
            # Right column: outputs, listed in the order process_and_display returns them.
            with gr.Column():
                summary_view = gr.JSON(label="Results Summary")
                metrics_view = gr.Markdown(label="Sentence Metrics")
                misaligned_view = gr.Markdown(label="Misaligned Sentences")
        uploads = [ref_upload, hyp_upload]
        views = [summary_view, metrics_view, misaligned_view]
        run_button.click(fn=process_and_display, inputs=uploads, outputs=views)
    app.launch(ssr_mode=False)
if __name__ == "__main__":
    # Launch the app only when this file is executed as a script, not on import.
    main()