#!/usr/bin/env python3
# eval_arena.py - Evaluate HuggingFace datasets against AI judges
import argparse
import json
import os
import re
import time
from typing import Any, Dict, List, Optional

import pandas as pd
from datasets import load_dataset
from loguru import logger
from sklearn.metrics import balanced_accuracy_score, f1_score
from tqdm import tqdm

from src.judge import JudgeManager


def load_models(
    models_path: str = "models.jsonl",
) -> List[Dict[str, Any]]:
    """Load models from a JSONL file."""
    models = []
    with open(models_path, "r") as f:
        for line in f:
            if line.strip():  # Skip empty lines
                try:
                    models.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse line: {line}")
    return models
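

# Each line of the models file is a standalone JSON object. At minimum an entry
# is expected to carry "id" and "name" (both are read below); any other fields
# are passed through to JudgeManager as-is. A hypothetical entry:
#   {"id": "gpt-4o-judge", "name": "GPT-4o Judge", ...}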


def get_test_type_from_dataset(dataset_name: str) -> str:
    """Determine the appropriate test type based on the dataset name."""
    name = dataset_name.lower()
    if any(kw in name for kw in ["injection", "jailbreak"]):
        return "prompt_injections"
    elif any(kw in name for kw in ["safety", "harm", "toxic"]):
        return "safety"
    elif any(kw in name for kw in ["fact", "ground", "truthful"]):
        return "grounding"
    elif any(kw in name for kw in ["policy", "guideline", "ethics"]):
        return "policy"
    else:
        return "generic"  # Generic evaluation as fallback


def parse_ground_truth(
    label_column: str,
    label: Any,
    safe_label: str = "SAFE",
    unsafe_label: str = "UNSAFE",
) -> str:
    """Convert various label formats to a standardized format for comparison."""
    # bool must be checked before int/float because bool is a subclass of int.
    if isinstance(label, bool):
        # Mirror the numeric branch below: a truthy "safe" column means SAFE,
        # while a truthy value in any other boolean column marks the unsafe class.
        if label_column.lower() == "safe":
            return safe_label if label else unsafe_label
        return unsafe_label if label else safe_label
    elif isinstance(label, (int, float)):
        if label_column.lower() == "safe":
            return safe_label if label > 0.5 else unsafe_label
        return unsafe_label if label > 0.5 else safe_label
    elif isinstance(label, str):
        label = label.upper()
        # Map common label formats to the ones used by judges
        mappings = {
            "TRUE": "GROUNDED",
            "FALSE": "UNGROUNDED",
            "YES": "GROUNDED",
            "NO": "UNGROUNDED",
            "PASS": "SAFE",
            "FAIL": "UNSAFE",
            "UNSAFE": "UNSAFE",
            "SAFE": "SAFE",
            "HARMFUL": "UNSAFE",
            "NOT_HARMFUL": "SAFE",
            "COMPLIANT": "COMPLIES",
            "NONCOMPLIANT": "VIOLATION",
            "GOOD": "GOOD_RESPONSE",
            "BAD": "POOR_RESPONSE",
            "BENIGN": "SAFE",
            "JAILBREAK": "PROMPT_INJECTION",
        }
        return mappings.get(label, label)
    else:
        logger.warning(f"Unrecognized label format: {label}, type: {type(label)}")
        return str(label).upper()


def calculate_metrics(
    predictions: List[str],
    ground_truth: List[str],
) -> Dict[str, float]:
    """Calculate performance metrics."""
    metrics = {}
    # Filter out any pairs where we don't have both prediction and ground truth
    valid_pairs = [(p, gt) for p, gt in zip(predictions, ground_truth) if p and gt]
    if not valid_pairs:
        logger.warning("No valid prediction-ground truth pairs for metric calculation")
        return {"f1": 0, "balanced_accuracy": 0}

    preds, gts = zip(*valid_pairs)
    # Get unique labels
    unique_labels = list(set(preds) | set(gts))
    if len(unique_labels) == 1:
        # Only one class present, so balanced accuracy is degenerate
        metrics["balanced_accuracy"] = 1.0 if preds == gts else 0.0
    else:
        try:
            metrics["balanced_accuracy"] = balanced_accuracy_score(gts, preds)
        except Exception as e:
            logger.error(f"Error calculating balanced accuracy: {e}")
            metrics["balanced_accuracy"] = 0

    try:
        # Multi-class F1, weighted by class support
        metrics["f1"] = f1_score(gts, preds, average="weighted", zero_division=0)
    except Exception as e:
        logger.error(f"Error calculating F1 score: {e}")
        metrics["f1"] = 0
    return metrics


def extract_label_from_evaluation(evaluation: Dict[str, Any]) -> str:
    """Extract the label from the judge evaluation result."""
    # Check if we have a raw evaluation string
    if "evaluation" in evaluation:
        eval_text = evaluation["evaluation"]
        # Look for "LABEL:" in the evaluation text
        label_match = re.search(r"LABEL:\s*(\w+(?:_\w+)*)", eval_text, re.IGNORECASE)
        if label_match:
            return label_match.group(1).upper()
    # If no label was found in the evaluation text, try a structured field
    if "label" in evaluation:
        return evaluation["label"].upper()
    logger.warning(f"Could not extract label from evaluation: {evaluation}")
    return ""
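

# For reference, a hypothetical judge payload and the labels the helpers above
# would produce (the exact JudgeManager response shape is an assumption):
#   extract_label_from_evaluation({"evaluation": "Reasoning: ...\nLABEL: PROMPT_INJECTION"})
#   -> "PROMPT_INJECTION"
#   parse_ground_truth("label", "jailbreak") -> "PROMPT_INJECTION"
#   parse_ground_truth("safe", 1)            -> "SAFE"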


def evaluate_dataset(
    dataset_name: str,
    models_path: str = "models.jsonl",
    max_samples: Optional[int] = None,
    test_type: Optional[str] = None,
    dataset_config: Optional[str] = None,
) -> None:
    """Main function to evaluate a dataset against AI judges."""
    logger.info(f"Evaluating dataset: {dataset_name}")

    # Load models from the models JSONL file
    models = load_models(models_path)
    if not models:
        logger.error(f"No models found in {models_path}")
        return
    logger.info(f"Loaded {len(models)} models")

    # Initialize JudgeManager with the loaded models
    judge_manager = JudgeManager(models)

    # Load the dataset (with config if provided) and determine which split to use
    try:
        if dataset_config:
            logger.info(f"Using dataset config: {dataset_config}")
            dataset = load_dataset(dataset_name, dataset_config)
        else:
            try:
                dataset = load_dataset(dataset_name)
            except ValueError as e:
                # If the error says a config name is missing, point at the CLI flag
                if "Config name is missing" in str(e):
                    logger.error(f"This dataset requires a config name. {str(e)}")
                    logger.error("Please use --dataset-config to specify the config.")
                    return
                raise

        logger.info(f"Available splits: {list(dataset.keys())}")
        # Prefer the test split if available, otherwise fall back to validation or train
        if "test" in dataset:
            split = "test"
        elif "validation" in dataset:
            split = "validation"
        elif "train" in dataset:
            split = "train"
        else:
            # Use the first available split
            split = list(dataset.keys())[0]
        logger.info(f"Using split: {split}")
        data = dataset[split]

        # Limit the number of samples if specified
        if max_samples and max_samples > 0:
            data = data.select(range(min(max_samples, len(data))))
        logger.info(f"Dataset contains {len(data)} samples")
    except Exception as e:
        logger.error(f"Error loading dataset {dataset_name}: {e}")
        return

    # Try to determine the columns for input, output, and label.
    # This is a heuristic, as different datasets have different structures.
    column_names = data.column_names
    logger.info(f"Dataset columns: {column_names}")

    # Look for common column names that might contain the input text
    input_column = None
    possible_input_names = [
        "input",
        "question",
        "prompt",
        "instruction",
        "context",
        "text",
        "adversarial",
        "doc",
    ]
    for possible_name in possible_input_names:
        matches = [col for col in column_names if possible_name in col.lower()]
        if matches:
            input_column = matches[0]
            break
    # If still not found, fall back to the first string-valued column
    if not input_column:
        for col in column_names:
            if isinstance(data[0][col], str):
                input_column = col
                break

    # Similar approach for the output column
    output_column = None
    possible_output_names = [
        "output",
        "answer",
        "response",
        "completion",
        "generation",
        "claim",
    ]
    for possible_name in possible_output_names:
        matches = [col for col in column_names if possible_name in col.lower()]
        if matches:
            output_column = matches[0]
            break

    # Look for a label/ground-truth column
    label_column = None
    possible_label_names = [
        "label",
        "ground_truth",
        "class",
        "target",
        "gold",
        "correct",
        "type",
        "safe",
    ]
    for possible_name in possible_label_names:
        matches = [col for col in column_names if possible_name in col.lower()]
        if matches:
            label_column = matches[0]
            break

    # Determine the test type from the dataset name unless one was provided
    if test_type:
        logger.info(f"Using provided test type: {test_type}")
    else:
        test_type = get_test_type_from_dataset(dataset_name)
        logger.info(f"Auto-detected test type: {test_type}")

    # Check that we have the minimum required columns for this test type
    input_only_test_types = ["safety", "prompt_injections"]
    requires_output = test_type not in input_only_test_types
    if not input_column:
        logger.error("Could not determine input column, required for all test types.")
        return
    if requires_output and not output_column:
        logger.error(f"Test type '{test_type}' requires an output column, none found.")
        return

    # Log which columns we're using
    column_info = f"Using columns: input={input_column}"
    if output_column:
        column_info += f", output={output_column}"
    if label_column:
        column_info += f", label={label_column}"
    else:
        logger.warning("No label column found. Cannot compare against ground truth.")
    logger.info(column_info)

    # Initialize results storage. judge_metrics keeps running per-judge tallies;
    # the CSV metrics written at the end are recomputed from raw_results.
    raw_results = []
    judge_metrics = {
        judge["id"]: {
            "judge_id": judge["id"],
            "judge_name": judge["name"],
            "predictions": [],
            "ground_truths": [],
            "total_time": 0,
            "count": 0,
            "correct": 0,
        }
        for judge in models
    }

    # Process each sample in the dataset
    for i, sample in enumerate(tqdm(data)):
        input_text = sample[input_column]

        # Use an empty string as output if the output column is unavailable,
        # but only for test types that can work with just the input
        output_text = ""
        if output_column and output_column in sample:
            output_text = sample[output_column]
        elif requires_output:
            logger.warning(f"Sample {i} missing output field required for '{test_type}'")
            continue

        # Get ground truth if available
        ground_truth = None
        if label_column and label_column in sample:
            ground_truth = parse_ground_truth(label_column, sample[label_column])

        # Evaluate with each judge
        for judge in models:
            judge_id = judge["id"]
            try:
                # Time the evaluation
                start_time = time.time()
                logger.info(f"Evaluating sample {i} with judge {judge_id}")

                # Get the evaluation from the judge
                evaluation = judge_manager.get_evaluation(
                    judge=judge,
                    input_text=input_text,
                    output_text=output_text,
                    test_type=test_type,
                )
                elapsed_time = time.time() - start_time

                # Extract the label from the evaluation
                prediction = extract_label_from_evaluation(evaluation)

                # Store the raw result
                raw_result = {
                    "dataset": dataset_name,
                    "sample_id": i,
                    "judge_id": judge_id,
                    "judge_name": judge["name"],
                    "input": input_text,
                    "output": output_text,
                    "prediction": prediction,
                    "ground_truth": ground_truth,
                    "latency": elapsed_time,
                    "evaluation": evaluation.get("evaluation", ""),
                }
                raw_results.append(raw_result)

                # Update running metrics
                judge_metrics[judge_id]["predictions"].append(prediction)
                judge_metrics[judge_id]["total_time"] += elapsed_time
                judge_metrics[judge_id]["count"] += 1
                if ground_truth:
                    judge_metrics[judge_id]["ground_truths"].append(ground_truth)
                    if prediction == ground_truth:
                        judge_metrics[judge_id]["correct"] += 1
            except Exception as e:
                logger.error(f"Error evaluating sample {i} with judge {judge_id}: {e}")

    # Save raw results (ensure the output directory exists first)
    os.makedirs("benchmarks", exist_ok=True)
    raw_df = pd.DataFrame(raw_results)
    raw_results_filename = f"{dataset_name.replace('/', '-')}-raw-results.csv"
    raw_results_path = f"benchmarks/{raw_results_filename}"
    raw_df.to_csv(raw_results_path, index=False)
    logger.info(f"Raw results saved to {raw_results_path}")

    # Calculate final metrics for each judge
    judges_metrics = []
    for judge_id in raw_df["judge_id"].unique():
        judge_results = raw_df[raw_df["judge_id"] == judge_id].copy()

        # Binarize against the positive (unsafe) label for this test type
        if test_type == "prompt_injections":
            gt_label = "PROMPT_INJECTION"
        else:
            gt_label = "UNSAFE"
        judge_results["ground_truth_binary"] = judge_results["ground_truth"].isin(
            [gt_label],
        )
        judge_results["prediction_binary"] = judge_results["prediction"].isin(
            [gt_label],
        )

        if judge_results["ground_truth_binary"].any():
            f1 = f1_score(
                judge_results["ground_truth_binary"],
                judge_results["prediction_binary"],
                average="binary",
                pos_label=True,  # We are comparing boolean Series now
            )
        else:
            logger.warning(
                f"No positive ground truth samples for judge {judge_id}, setting F1 to 0."
            )
            f1 = 0.0

        bacc = balanced_accuracy_score(
            judge_results["ground_truth"].astype(str),
            judge_results["prediction"].astype(str),
        )
        judge_results["correct"] = (
            judge_results["prediction"] == judge_results["ground_truth"]
        )
        avg_latency = judge_results["latency"].mean()
        total_time = judge_results["latency"].sum()
        print(
            f"Judge {judge_id} F1: {f1:.4f}, BAcc: {bacc:.4f}, "
            f"Avg Latency: {avg_latency:.2f}s, Total Time: {total_time:.2f}s"
        )

        # Aggregate the metrics into a dataframe row
        judges_metrics.append(
            {
                "judge_id": judge_id,
                "judge_name": judge_results["judge_name"].iloc[0],
                "dataset": dataset_name,
                "f1": f1,
                "bacc": bacc,
                "avg_latency": avg_latency,
                "total_time": total_time,
                "count": len(judge_results),
                "correct": judge_results["correct"].sum(),
            },
        )

    judges_metrics_df = pd.DataFrame(judges_metrics)
    judges_metrics_filename = f"{dataset_name.replace('/', '-')}-judges-metrics.csv"
    judges_metrics_path = f"benchmarks/{judges_metrics_filename}"
    judges_metrics_df.to_csv(judges_metrics_path, index=False)
    logger.info(f"Judge metrics saved to {judges_metrics_path}")
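

# Outputs: a run writes benchmarks/<dataset>-raw-results.csv (one row per
# sample/judge pair: dataset, sample_id, judge_id, judge_name, input, output,
# prediction, ground_truth, latency, evaluation) and
# benchmarks/<dataset>-judges-metrics.csv (one row per judge), where "/" in the
# dataset name is replaced with "-".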
"prompt_injections": gt_label = "PROMPT_INJECTION" else: gt_label = "UNSAFE" judge_results["ground_truth_binary"] = judge_results["ground_truth"].isin( [gt_label], ) judge_results["prediction_binary"] = judge_results["prediction"].isin( [gt_label], ) if judge_results["ground_truth_binary"].any(): f1 = f1_score( judge_results["ground_truth_binary"], judge_results["prediction_binary"], average="binary", pos_label=True, # We are comparing boolean Series now ) else: logger.warning(f"No positive ground truth samples for judge {judge_id}, setting F1 to 0.") f1 = 0.0 bacc = balanced_accuracy_score( judge_results["ground_truth"].astype(str), judge_results["prediction"].astype(str), ) judge_results["correct"] = judge_results["prediction"] == judge_results["ground_truth"] avg_latency = judge_results["latency"].mean() total_time = judge_results["latency"].sum() print( f"Judge {judge_id} F1: {f1:.4f}, BAcc: {bacc:.4f}, " f"Avg Latency: {avg_latency:.2f}s, Total Time: {total_time:.2f}s" ) # aggregate the metrics to a dataframe judges_metrics.append( { "judge_id": judge_id, "judge_name": judge_results["judge_name"].iloc[0], "dataset": dataset_name, "f1": f1, "bacc": bacc, "avg_latency": avg_latency, "total_time": total_time, "count": len(judge_results), "correct": judge_results["correct"].sum(), }, ) judges_metrics_df = pd.DataFrame(judges_metrics) judges_metrics_filename = f"{dataset_name.replace('/', '-')}-judges-metrics.csv" judges_metrics_path = f"benchmarks/{judges_metrics_filename}" judges_metrics_df.to_csv(judges_metrics_path, index=False) logger.info(f"Judge metrics saved to {judges_metrics_path}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate HuggingFace datasets against AI judges") parser.add_argument("dataset", help="HuggingFace dataset name (e.g., 'truthful_qa')") parser.add_argument("--models", default="models.jsonl", help="Path to models JSONL file") parser.add_argument( "--max-samples", type=int, help="Maximum number of samples to evaluate", ) parser.add_argument( "--test-type", choices=[ "prompt_injections", "safety", "grounding", "policy", "generic", ], help="Override test type (default: auto-detect from dataset name)", ) parser.add_argument( "--dataset-config", help="Dataset config name (e.g., 'train' for allenai/wildjailbreak)", ) args = parser.parse_args() evaluate_dataset( args.dataset, args.models, args.max_samples, args.test_type, args.dataset_config, )