"""Evaluate a HuggingFace dataset against a set of AI judges.

Loads judge model definitions from a JSONL file, runs each sample of the chosen
dataset split through every judge via JudgeManager, and writes raw results plus
per-judge metrics to CSV files under benchmarks/.
"""

import argparse
import json
import os
import re
import time
from typing import Any, Dict, List, Optional

import pandas as pd
from datasets import load_dataset
from loguru import logger
from sklearn.metrics import balanced_accuracy_score, f1_score
from tqdm import tqdm

from src.judge import JudgeManager


def load_models(
    models_path: str = "models.jsonl",
) -> List[Dict[str, Any]]:
    """Load models from a JSONL file."""
    models = []
    with open(models_path, "r") as f:
        for line in f:
            if line.strip():
                try:
                    models.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse line: {line}")
    return models


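# Illustrative models.jsonl entry. Only the "id" and "name" fields are read by
# this script; any other fields (model name, provider, temperature, ...) are
# assumptions about what JudgeManager expects and will vary per setup:
# {"id": "judge-1", "name": "Example Judge", "model": "gpt-4o"}

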
def get_test_type_from_dataset(dataset_name: str) -> str:
    """Determine the appropriate test type based on the dataset name."""
    name = dataset_name.lower()
    if any(kw in name for kw in ["injection", "jailbreak"]):
        return "prompt_injections"
    elif any(kw in name for kw in ["safety", "harm", "toxic"]):
        return "safety"
    elif any(kw in name for kw in ["fact", "ground", "truthful"]):
        return "grounding"
    elif any(kw in name for kw in ["policy", "guideline", "ethics"]):
        return "policy"
    else:
        return "generic"


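# Keyword-matching examples (dataset names are illustrative):
#   get_test_type_from_dataset("allenai/wildjailbreak") -> "prompt_injections"
#   get_test_type_from_dataset("truthful_qa")           -> "grounding"
#   get_test_type_from_dataset("some/unknown-dataset")  -> "generic"

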
def parse_ground_truth(
    label_column: str,
    label: Any,
    safe_label: str = "SAFE",
    unsafe_label: str = "UNSAFE",
) -> str:
    """Convert various label formats to a standardized format for comparison."""
    if isinstance(label, bool):
        # Mirror the numeric convention below: a column literally named "safe"
        # treats a truthy value as safe; any other boolean column is assumed to
        # flag the unsafe/positive class.
        if label_column.lower() == "safe":
            return safe_label if label else unsafe_label
        return unsafe_label if label else safe_label
    elif isinstance(label, (int, float)):
        if label_column.lower() == "safe":
            return safe_label if label > 0.5 else unsafe_label
        return unsafe_label if label > 0.5 else safe_label
    elif isinstance(label, str):
        label = label.upper()

        # Map common dataset-specific labels onto the judge label vocabulary.
        mappings = {
            "TRUE": "GROUNDED",
            "FALSE": "UNGROUNDED",
            "YES": "GROUNDED",
            "NO": "UNGROUNDED",
            "PASS": "SAFE",
            "FAIL": "UNSAFE",
            "UNSAFE": "UNSAFE",
            "SAFE": "SAFE",
            "HARMFUL": "UNSAFE",
            "NOT_HARMFUL": "SAFE",
            "COMPLIANT": "COMPLIES",
            "NONCOMPLIANT": "VIOLATION",
            "GOOD": "GOOD_RESPONSE",
            "BAD": "POOR_RESPONSE",
            "BENIGN": "SAFE",
            "JAILBREAK": "PROMPT_INJECTION",
        }
        return mappings.get(label, label)
    else:
        logger.warning(f"Unrecognized label format: {label}, type: {type(label)}")
        return str(label).upper()


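# Normalization examples (assuming the boolean/numeric convention noted above):
#   parse_ground_truth("safe", True)        -> "SAFE"
#   parse_ground_truth("label", 1)          -> "UNSAFE"
#   parse_ground_truth("type", "jailbreak") -> "PROMPT_INJECTION"

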
def calculate_metrics(
    predictions: List[str],
    ground_truth: List[str],
) -> Dict[str, float]:
    """Calculate performance metrics."""
    metrics = {}

    # Only score samples where both a prediction and a ground truth label exist.
    valid_pairs = [(p, gt) for p, gt in zip(predictions, ground_truth) if p and gt]

    if not valid_pairs:
        logger.warning("No valid prediction-ground truth pairs for metric calculation")
        return {"f1": 0.0, "balanced_accuracy": 0.0}

    preds, gts = zip(*valid_pairs)

    unique_labels = list(set(preds) | set(gts))

    if len(unique_labels) == 1:
        # Only one label appears anywhere, so predictions and ground truth agree.
        metrics["balanced_accuracy"] = 1.0 if preds == gts else 0.0
    else:
        try:
            metrics["balanced_accuracy"] = balanced_accuracy_score(gts, preds)
        except Exception as e:
            logger.error(f"Error calculating balanced accuracy: {e}")
            metrics["balanced_accuracy"] = 0.0

    try:
        # Weighted F1 handles multi-class label sets and class imbalance.
        metrics["f1"] = f1_score(gts, preds, average="weighted", zero_division=0)
    except Exception as e:
        logger.error(f"Error calculating F1 score: {e}")
        metrics["f1"] = 0.0

    return metrics


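# Example (values computed by hand for this sketch):
#   calculate_metrics(["SAFE", "UNSAFE", "SAFE"], ["SAFE", "UNSAFE", "UNSAFE"])
#   -> {"balanced_accuracy": 0.75, "f1": ~0.67}

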
def extract_label_from_evaluation(evaluation: Dict[str, Any]) -> str:
    """Extract the label from the judge evaluation result."""
    # Prefer parsing a "LABEL: <value>" line from the free-text evaluation.
    if "evaluation" in evaluation:
        eval_text = evaluation["evaluation"]
        label_match = re.search(r"LABEL:\s*(\w+(?:_\w+)*)", eval_text, re.IGNORECASE)
        if label_match:
            return label_match.group(1).upper()

    # Fall back to an explicit label field when present.
    if "label" in evaluation:
        return str(evaluation["label"]).upper()

    logger.warning(f"Could not extract label from evaluation: {evaluation}")
    return ""


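# Example payloads this parser handles (the "LABEL:" line format is inferred
# from the regex above; the exact text is whatever JudgeManager returns):
#   {"evaluation": "LABEL: PROMPT_INJECTION\nREASON: ..."} -> "PROMPT_INJECTION"
#   {"label": "safe"}                                      -> "SAFE"

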
def evaluate_dataset(
    dataset_name: str,
    models_path: str = "models.jsonl",
    max_samples: Optional[int] = None,
    test_type: Optional[str] = None,
    dataset_config: Optional[str] = None,
) -> None:
    """Main function to evaluate a dataset against AI judges."""
    logger.info(f"Evaluating dataset: {dataset_name}")

    # Load the judge definitions.
    models = load_models(models_path)
    if not models:
        logger.error(f"No models found in {models_path}")
        return

    logger.info(f"Loaded {len(models)} models")

    judge_manager = JudgeManager(models)

    # Load the dataset, optionally with an explicit config name.
    try:
        if dataset_config:
            logger.info(f"Using dataset config: {dataset_config}")
            dataset = load_dataset(dataset_name, dataset_config)
        else:
            try:
                dataset = load_dataset(dataset_name)
            except ValueError as e:
                if "Config name is missing" in str(e):
                    logger.error(f"This dataset requires a config name. {str(e)}")
                    logger.error("Please use --dataset-config to specify the config.")
                    return
                raise

        logger.info(f"Available splits: {list(dataset.keys())}")

        # Prefer the test split, then validation, then train, then whatever exists.
        if "test" in dataset:
            split = "test"
        elif "validation" in dataset:
            split = "validation"
        elif "train" in dataset:
            split = "train"
        else:
            split = list(dataset.keys())[0]

        logger.info(f"Using split: {split}")
        data = dataset[split]

        # Optionally cap the number of samples to evaluate.
        if max_samples and max_samples > 0:
            data = data.select(range(min(max_samples, len(data))))

        logger.info(f"Dataset contains {len(data)} samples")
    except Exception as e:
        logger.error(f"Error loading dataset {dataset_name}: {e}")
        return

    # Inspect the columns and heuristically map them to input/output/label roles.
    column_names = data.column_names
    logger.info(f"Dataset columns: {column_names}")

    # Pick the input column by name, falling back to the first string column.
    input_column = None
    possible_input_names = [
        "input",
        "question",
        "prompt",
        "instruction",
        "context",
        "text",
        "adversarial",
        "doc",
    ]
    for possible_name in possible_input_names:
        matches = [col for col in column_names if possible_name in col.lower()]
        if matches:
            input_column = matches[0]
            break

    if not input_column:
        for col in column_names:
            if isinstance(data[0][col], str):
                input_column = col
                break

    # Pick the model-output column, if any.
    output_column = None
    possible_output_names = [
        "output",
        "answer",
        "response",
        "completion",
        "generation",
        "claim",
    ]
    for possible_name in possible_output_names:
        matches = [col for col in column_names if possible_name in col.lower()]
        if matches:
            output_column = matches[0]
            break

    # Pick the ground-truth label column, if any.
    label_column = None
    possible_label_names = [
        "label",
        "ground_truth",
        "class",
        "target",
        "gold",
        "correct",
        "type",
        "safe",
    ]
    for possible_name in possible_label_names:
        matches = [col for col in column_names if possible_name in col.lower()]
        if matches:
            label_column = matches[0]
            break

    # Resolve the test type, auto-detecting from the dataset name if not given.
    if test_type:
        logger.info(f"Using provided test type: {test_type}")
    else:
        test_type = get_test_type_from_dataset(dataset_name)
        logger.info(f"Auto-detected test type: {test_type}")

    # Safety and prompt-injection tests judge the input alone; the other test
    # types also need a model output to evaluate.
    input_only_test_types = ["safety", "prompt_injections"]
    requires_output = test_type not in input_only_test_types

    if not input_column:
        logger.error("Could not determine input column, required for all test types.")
        return

    if requires_output and not output_column:
        logger.error(f"Test type '{test_type}' requires output column, none found.")
        return

    column_info = f"Using columns: input={input_column}"
    if output_column:
        column_info += f", output={output_column}"
    if label_column:
        column_info += f", label={label_column}"
    else:
        logger.warning("No label column found. Cannot compare against ground truth.")

    logger.info(column_info)

    # Accumulate per-sample results and simple per-judge bookkeeping.
    raw_results = []
    judge_metrics = {
        judge["id"]: {
            "judge_id": judge["id"],
            "judge_name": judge["name"],
            "predictions": [],
            "ground_truths": [],
            "total_time": 0,
            "count": 0,
            "correct": 0,
        }
        for judge in models
    }

    for i, sample in enumerate(tqdm(data)):
        input_text = sample[input_column]

        # Resolve the output text; input-only test types can proceed without it.
        output_text = ""
        if output_column and output_column in sample:
            output_text = sample[output_column]
        elif requires_output:
            logger.warning(f"Sample {i} missing output field required for '{test_type}'")
            continue

        # Normalize the ground truth label when one is available.
        ground_truth = None
        if label_column and label_column in sample:
            ground_truth = parse_ground_truth(label_column, sample[label_column])

        # Ask every judge to evaluate the sample.
        for judge in models:
            judge_id = judge["id"]

            try:
                start_time = time.time()
                logger.info(f"Evaluating sample {i} with judge {judge_id}")

                evaluation = judge_manager.get_evaluation(
                    judge=judge,
                    input_text=input_text,
                    output_text=output_text,
                    test_type=test_type,
                )

                elapsed_time = time.time() - start_time

                prediction = extract_label_from_evaluation(evaluation)

                raw_result = {
                    "dataset": dataset_name,
                    "sample_id": i,
                    "judge_id": judge_id,
                    "judge_name": judge["name"],
                    "input": input_text,
                    "output": output_text,
                    "prediction": prediction,
                    "ground_truth": ground_truth,
                    "latency": elapsed_time,
                    "evaluation": evaluation.get("evaluation", ""),
                }
                raw_results.append(raw_result)

                # Update the per-judge bookkeeping.
                judge_metrics[judge_id]["predictions"].append(prediction)
                judge_metrics[judge_id]["total_time"] += elapsed_time
                judge_metrics[judge_id]["count"] += 1

                if ground_truth:
                    judge_metrics[judge_id]["ground_truths"].append(ground_truth)
                    if prediction == ground_truth:
                        judge_metrics[judge_id]["correct"] += 1

            except Exception as e:
                logger.error(f"Error evaluating sample {i} with judge {judge_id}: {e}")

    if not raw_results:
        logger.error("No evaluations were produced; nothing to save.")
        return

    # Save the raw per-sample results (creates the local benchmarks/ dir if missing).
    os.makedirs("benchmarks", exist_ok=True)
    raw_df = pd.DataFrame(raw_results)
    raw_results_filename = f"{dataset_name.replace('/', '-')}-raw-results.csv"
    raw_results_path = f"benchmarks/{raw_results_filename}"
    raw_df.to_csv(raw_results_path, index=False)
    logger.info(f"Raw results saved to {raw_results_path}")

    # Compute aggregated metrics per judge.
    judges_metrics = []

    for judge_id in raw_df["judge_id"].unique():
        judge_results = raw_df[raw_df["judge_id"] == judge_id].copy()

        # Binarize labels around the "positive" class for this test type.
        if test_type == "prompt_injections":
            gt_label = "PROMPT_INJECTION"
        elif test_type == "grounding":
            gt_label = "GROUNDED"
        else:
            gt_label = "UNSAFE"

        judge_results["ground_truth_binary"] = judge_results["ground_truth"].isin(
            [gt_label],
        )
        judge_results["prediction_binary"] = judge_results["prediction"].isin(
            [gt_label],
        )

        if judge_results["ground_truth_binary"].any():
            f1 = f1_score(
                judge_results["ground_truth_binary"],
                judge_results["prediction_binary"],
                average="binary",
                pos_label=True,
            )
        else:
            logger.warning(f"No positive ground truth samples for judge {judge_id}, setting F1 to 0.")
            f1 = 0.0

        # Balanced accuracy over the full (multi-class) label strings.
        bacc = balanced_accuracy_score(
            judge_results["ground_truth"].astype(str),
            judge_results["prediction"].astype(str),
        )

        judge_results["correct"] = judge_results["prediction"] == judge_results["ground_truth"]

        avg_latency = judge_results["latency"].mean()
        total_time = judge_results["latency"].sum()

        print(
            f"Judge {judge_id} F1: {f1:.4f}, BAcc: {bacc:.4f}, "
            f"Avg Latency: {avg_latency:.2f}s, Total Time: {total_time:.2f}s"
        )

        judges_metrics.append(
            {
                "judge_id": judge_id,
                "judge_name": judge_results["judge_name"].iloc[0],
                "dataset": dataset_name,
                "f1": f1,
                "bacc": bacc,
                "avg_latency": avg_latency,
                "total_time": total_time,
                "count": len(judge_results),
                "correct": judge_results["correct"].sum(),
            },
        )

    judges_metrics_df = pd.DataFrame(judges_metrics)
    judges_metrics_filename = f"{dataset_name.replace('/', '-')}-judges-metrics.csv"
    judges_metrics_path = f"benchmarks/{judges_metrics_filename}"
    judges_metrics_df.to_csv(judges_metrics_path, index=False)
    logger.info(f"Judge metrics saved to {judges_metrics_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate HuggingFace datasets against AI judges")
    parser.add_argument("dataset", help="HuggingFace dataset name (e.g., 'truthful_qa')")
    parser.add_argument("--models", default="models.jsonl", help="Path to models JSONL file")
    parser.add_argument(
        "--max-samples",
        type=int,
        help="Maximum number of samples to evaluate",
    )
    parser.add_argument(
        "--test-type",
        choices=[
            "prompt_injections",
            "safety",
            "grounding",
            "policy",
            "generic",
        ],
        help="Override test type (default: auto-detect from dataset name)",
    )
    parser.add_argument(
        "--dataset-config",
        help="Dataset config name (e.g., 'train' for allenai/wildjailbreak)",
    )

    args = parser.parse_args()

    evaluate_dataset(
        args.dataset,
        args.models,
        args.max_samples,
        args.test_type,
        args.dataset_config,
    )
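# Example invocations (replace "evaluate.py" with this file's actual name):
#   python evaluate.py truthful_qa --max-samples 50
#   python evaluate.py allenai/wildjailbreak --dataset-config train --test-type prompt_injections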