File size: 16,794 Bytes
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b070cd
5a05fa9
 
 
 
 
 
 
 
 
 
 
fba21bb
5a05fa9
 
fba21bb
5a05fa9
 
 
fba21bb
 
5a05fa9
 
fba21bb
 
 
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
424c620
 
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df184ed
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df184ed
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
a0c1734
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0c1734
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4403e4e
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b070cd
5a05fa9
 
 
df184ed
5a05fa9
 
 
df184ed
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df184ed
5a05fa9
 
 
 
 
fba21bb
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df184ed
 
5a05fa9
 
 
 
 
 
 
df184ed
6b070cd
df184ed
a43203f
 
df184ed
 
5a05fa9
df184ed
 
5a05fa9
df184ed
 
 
 
 
 
 
 
 
 
 
 
 
 
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
df184ed
 
5a05fa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df184ed
 
 
 
5a05fa9
 
 
 
 
 
df184ed
 
 
 
 
5a05fa9
 
df184ed
6b070cd
df184ed
 
 
 
 
 
5a05fa9
 
df184ed
 
5a05fa9
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
#!/usr/bin/env python3
# eval_arena.py - Evaluate HuggingFace datasets against AI judges

import argparse
import json
import os
import re
import time
from typing import Any, Dict, List, Optional

import pandas as pd
from datasets import load_dataset
from loguru import logger
from sklearn.metrics import balanced_accuracy_score, f1_score
from tqdm import tqdm

from src.judge import JudgeManager


def load_models(
    models_path: str = "models.jsonl",
) -> List[Dict[str, Any]]:
    """Load judge model definitions from a JSON Lines file.

    Each non-blank line is parsed as one JSON value. Lines that fail to parse
    are logged and skipped, so the rest of the file still loads.

    Args:
        models_path: Path to the JSONL file.

    Returns:
        The parsed entries, in file order.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    models: List[Dict[str, Any]] = []
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform's locale encoding.
    with open(models_path, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                continue  # Skip blank / whitespace-only lines
            try:
                models.append(json.loads(stripped))
            except json.JSONDecodeError as e:
                logger.warning(
                    f"Failed to parse line {line_number} of {models_path}: {e}: {stripped}"
                )
    return models


def get_test_type_from_dataset(dataset_name: str) -> str:
    """Guess the judge test type from keywords in a dataset name.

    Rules are checked in order and the first rule whose keyword occurs
    (case-insensitively) anywhere in the name wins. Names that match no rule
    fall back to ``"generic"``.
    """
    lowered_name = dataset_name.lower()
    keyword_rules = (
        ("prompt_injections", ("injection", "jailbreak")),
        ("safety", ("safety", "harm", "toxic")),
        ("grounding", ("fact", "ground", "truthful")),
        ("policy", ("policy", "guideline", "ethics")),
    )
    for detected_type, keywords in keyword_rules:
        if any(keyword in lowered_name for keyword in keywords):
            return detected_type
    # Generic evaluation as fallback
    return "generic"


# Common dataset label spellings (upper-cased), mapped to the labels used by judges.
_LABEL_ALIASES = {
    "TRUE": "GROUNDED",
    "FALSE": "UNGROUNDED",
    "YES": "GROUNDED",
    "NO": "UNGROUNDED",
    "PASS": "SAFE",
    "FAIL": "UNSAFE",
    "UNSAFE": "UNSAFE",
    "SAFE": "SAFE",
    "HARMFUL": "UNSAFE",
    "NOT_HARMFUL": "SAFE",
    "COMPLIANT": "COMPLIES",
    "NONCOMPLIANT": "VIOLATION",
    "GOOD": "GOOD_RESPONSE",
    "BAD": "POOR_RESPONSE",
    "BENIGN": "SAFE",
    "JAILBREAK": "PROMPT_INJECTION",
}


def parse_ground_truth(
    label_column: str,
    label: Any,
    safe_label: str = "SAFE",
    unsafe_label: str = "UNSAFE",
) -> str:
    """Convert a dataset label to the label vocabulary used by the judges.

    Bool and numeric labels share one polarity rule. A value above 0.5 (this
    includes ``True``) means "safe" when the column is literally named ``safe``
    (case-insensitive). For any other column name it means "unsafe".

    String labels are stripped, upper-cased and mapped through
    ``_LABEL_ALIASES``. Strings with no alias are returned upper-cased.

    Args:
        label_column: Name of the dataset column the label came from.
        label: The raw label value.
        safe_label: Value returned for samples judged safe.
        unsafe_label: Value returned for samples judged unsafe.

    Returns:
        The normalized label string.
    """
    # bool is a subclass of int, so True/False follow the same rule as 1/0.
    if isinstance(label, (int, float)):
        is_high = label > 0.5
        if label_column.lower() == "safe":
            return safe_label if is_high else unsafe_label
        return unsafe_label if is_high else safe_label
    if isinstance(label, str):
        normalized = label.strip().upper()
        return _LABEL_ALIASES.get(normalized, normalized)
    logger.warning(f"Unrecognized label format: {label}, type: {type(label)}")
    return str(label).upper()


def calculate_metrics(
    predictions: List[str],
    ground_truth: List[str],
) -> Dict[str, float]:
    """Score predicted labels against ground-truth labels.

    Pairs where either the prediction or the ground truth is empty/falsy are
    dropped first. If nothing remains, both scores are 0.

    Returns:
        A dict with ``"balanced_accuracy"`` and ``"f1"`` (weighted F1). When only
        one distinct label appears across the kept pairs, balanced accuracy is
        1.0 if the kept predictions equal the kept truths, otherwise 0.0.
        Metric errors are logged and the failing metric is reported as 0.
    """
    kept_preds: List[str] = []
    kept_truths: List[str] = []
    for pred, truth in zip(predictions, ground_truth):
        if pred and truth:
            kept_preds.append(pred)
            kept_truths.append(truth)

    if not kept_preds:
        logger.warning("No valid prediction-ground truth pairs for metric calculation")
        return {"f1": 0, "balanced_accuracy": 0}

    metrics: Dict[str, float] = {}

    distinct_labels = set(kept_preds).union(kept_truths)
    if len(distinct_labels) != 1:
        try:
            metrics["balanced_accuracy"] = balanced_accuracy_score(kept_truths, kept_preds)
        except Exception as e:
            logger.error(f"Error calculating balanced accuracy: {e}")
            metrics["balanced_accuracy"] = 0
    else:
        # With a single class present, balanced accuracy is degenerate.
        metrics["balanced_accuracy"] = 1.0 if kept_preds == kept_truths else 0.0

    try:
        # Weighted F1 copes with any number of classes.
        metrics["f1"] = f1_score(kept_truths, kept_preds, average="weighted", zero_division=0)
    except Exception as e:
        logger.error(f"Error calculating F1 score: {e}")
        metrics["f1"] = 0

    return metrics


# Matches e.g. "LABEL: prompt_injection" in free-form judge output; \w already
# covers underscores, so multi-word labels are captured whole.
_LABEL_PATTERN = re.compile(r"LABEL:\s*(\w+)", re.IGNORECASE)


def extract_label_from_evaluation(evaluation: Dict[str, Any]) -> str:
    """Extract the predicted label from a judge evaluation result.

    The first ``LABEL: <word>`` tag in ``evaluation["evaluation"]`` (when it is a
    string) is used. Failing that, a non-blank string ``evaluation["label"]`` is
    used instead.

    Args:
        evaluation: The mapping returned by the judge.

    Returns:
        The upper-cased label, or ``""`` (with a warning logged) when no usable
        label is present. Missing or non-string fields are treated as absent
        rather than raising.
    """
    eval_text = evaluation.get("evaluation")
    if isinstance(eval_text, str):
        label_match = _LABEL_PATTERN.search(eval_text)
        if label_match:
            return label_match.group(1).upper()

    # If no label found in evaluation text, fall back to an explicit field.
    label = evaluation.get("label")
    if isinstance(label, str) and label.strip():
        return label.strip().upper()

    logger.warning(f"Could not extract label from evaluation: {evaluation}")
    return ""


# Column-name substrings tried in priority order when guessing each column's role.
_INPUT_COLUMN_HINTS = [
    "input",
    "question",
    "prompt",
    "instruction",
    "context",
    "text",
    "adversarial",
    "doc",
]
_OUTPUT_COLUMN_HINTS = [
    "output",
    "answer",
    "response",
    "completion",
    "generation",
    "claim",
]
_LABEL_COLUMN_HINTS = [
    "label",
    "ground_truth",
    "class",
    "target",
    "gold",
    "correct",
    "type",
    "safe",
]
# Test types whose judges can work from the input text alone.
_INPUT_ONLY_TEST_TYPES = ("safety", "prompt_injections")
# Directory where raw results and per-judge metric CSVs are written.
_BENCHMARKS_DIR = "benchmarks"


def _load_dataset_split(
    dataset_name: str,
    dataset_config: Optional[str],
    max_samples: Optional[int],
):
    """Load the dataset, pick a split (test > validation > train > first) and truncate it.

    Returns the selected split, or None after logging an error if loading fails.
    """
    try:
        if dataset_config:
            logger.info(f"Using dataset config: {dataset_config}")
            dataset = load_dataset(dataset_name, dataset_config)
        else:
            try:
                dataset = load_dataset(dataset_name)
            except ValueError as e:
                # If error mentions config name is missing, provide help
                if "Config name is missing" in str(e):
                    logger.error(f"This dataset requires a config name. {str(e)}")
                    logger.error("Please use --dataset-config to specify the config.")
                    return None
                raise

        logger.info(f"Available splits: {list(dataset.keys())}")

        split = next(
            (name for name in ("test", "validation", "train") if name in dataset),
            None,
        )
        if split is None:
            split = list(dataset.keys())[0]

        logger.info(f"Using split: {split}")
        data = dataset[split]

        if max_samples and max_samples > 0:
            data = data.select(range(min(max_samples, len(data))))

        logger.info(f"Dataset contains {len(data)} samples")
    except Exception as e:
        logger.error(f"Error loading dataset {dataset_name}: {e}")
        return None
    return data


def _find_column(column_names: List[str], hints: List[str]) -> Optional[str]:
    """Return the first column containing a hint (hints in priority order), else None."""
    for hint in hints:
        for column in column_names:
            if hint in column.lower():
                return column
    return None


def _first_string_column(data) -> Optional[str]:
    """Return the first column whose value in the first row is a string, else None."""
    first_row = data[0]
    return next(
        (column for column in data.column_names if isinstance(first_row[column], str)),
        None,
    )


def _benchmark_path(dataset_name: str, suffix: str) -> str:
    """Build the CSV path for *dataset_name* (slashes replaced) under the benchmarks dir."""
    return os.path.join(_BENCHMARKS_DIR, f"{dataset_name.replace('/', '-')}-{suffix}.csv")


def _run_judges(
    data,
    models: List[Dict[str, Any]],
    judge_manager: "JudgeManager",
    *,
    dataset_name: str,
    test_type: str,
    input_column: str,
    output_column: Optional[str],
    label_column: Optional[str],
    requires_output: bool,
) -> List[Dict[str, Any]]:
    """Evaluate every sample with every judge and return one raw-result row per pair.

    A failure for one (sample, judge) pair is logged and skipped.
    """
    raw_results: List[Dict[str, Any]] = []
    for i, sample in enumerate(tqdm(data)):
        input_text = sample[input_column]

        # Use empty string as output when no output column is available,
        # which is only acceptable for input-only test types.
        output_text = ""
        if output_column and output_column in sample:
            output_text = sample[output_column]
        elif requires_output:
            logger.warning(f"Sample {i} missing output field required for '{test_type}'")
            continue

        ground_truth = None
        if label_column and label_column in sample:
            ground_truth = parse_ground_truth(label_column, sample[label_column])

        for judge in models:
            judge_id = judge["id"]
            try:
                # perf_counter is monotonic, so latencies cannot go negative.
                start_time = time.perf_counter()
                logger.info(f"Evaluating sample {i} with judge {judge_id}")
                evaluation = judge_manager.get_evaluation(
                    judge=judge,
                    input_text=input_text,
                    output_text=output_text,
                    test_type=test_type,
                )
                elapsed_time = time.perf_counter() - start_time

                raw_results.append(
                    {
                        "dataset": dataset_name,
                        "sample_id": i,
                        "judge_id": judge_id,
                        "judge_name": judge["name"],
                        "input": input_text,
                        "output": output_text,
                        "prediction": extract_label_from_evaluation(evaluation),
                        "ground_truth": ground_truth,
                        "latency": elapsed_time,
                        "evaluation": evaluation.get("evaluation", ""),
                    }
                )
            except Exception as e:
                logger.error(f"Error evaluating sample {i} with judge {judge_id}: {e}")
    return raw_results


def _positive_label(test_type: str) -> str:
    """Return the label treated as the positive class for binary F1."""
    if test_type == "prompt_injections":
        return "PROMPT_INJECTION"
    if test_type == "grounding":
        return "GROUNDED"
    # NOTE(review): "policy"/"generic" also use UNSAFE here, although
    # parse_ground_truth maps some policy labels to VIOLATION/COMPLIES — confirm.
    return "UNSAFE"


def _summarize_judges(
    raw_df: "pd.DataFrame",
    test_type: str,
    dataset_name: str,
) -> List[Dict[str, Any]]:
    """Compute F1, balanced accuracy, latency and correct-count rows per judge."""
    positive_label = _positive_label(test_type)
    summaries: List[Dict[str, Any]] = []

    for judge_id in raw_df["judge_id"].unique():
        judge_results = raw_df[raw_df["judge_id"] == judge_id]
        # Score only rows that actually carry a ground-truth label.
        labelled = judge_results[judge_results["ground_truth"].notna()]

        if labelled.empty:
            logger.warning(f"No ground truth labels for judge {judge_id}, setting F1 and BAcc to 0.")
            f1 = 0.0
            bacc = 0.0
        else:
            gt_binary = labelled["ground_truth"] == positive_label
            pred_binary = labelled["prediction"] == positive_label
            if gt_binary.any():
                f1 = f1_score(gt_binary, pred_binary, average="binary", pos_label=True)
            else:
                logger.warning(f"No positive ground truth samples for judge {judge_id}, setting F1 to 0.")
                f1 = 0.0
            bacc = balanced_accuracy_score(
                labelled["ground_truth"].astype(str),
                labelled["prediction"].astype(str),
            )

        correct = int((judge_results["prediction"] == judge_results["ground_truth"]).sum())
        avg_latency = judge_results["latency"].mean()
        total_time = judge_results["latency"].sum()

        print(
            f"Judge {judge_id} F1: {f1:.4f}, BAcc: {bacc:.4f}, "
            f"Avg Latency: {avg_latency:.2f}s, Total Time: {total_time:.2f}s"
        )

        summaries.append(
            {
                "judge_id": judge_id,
                "judge_name": judge_results["judge_name"].iloc[0],
                "dataset": dataset_name,
                "f1": f1,
                "bacc": bacc,
                "avg_latency": avg_latency,
                "total_time": total_time,
                "count": len(judge_results),
                "correct": correct,
            }
        )
    return summaries


def evaluate_dataset(
    dataset_name: str,
    models_path: str = "models.jsonl",
    max_samples: Optional[int] = None,
    test_type: Optional[str] = None,
    dataset_config: Optional[str] = None,
) -> None:
    """Evaluate a HuggingFace dataset with every configured AI judge.

    Steps:
    1. Load the judges from *models_path*, then load the dataset.
    2. Guess the input, output and label columns from their names.
    3. Ask each judge about each sample.
    4. Write two CSVs under ``benchmarks/``: ``<dataset>-raw-results.csv`` and
       ``<dataset>-judges-metrics.csv``.

    Problems (no models, dataset load failure, missing required columns, an
    empty split, no results) are logged and end the run early rather than
    raising.

    Args:
        dataset_name: HuggingFace dataset identifier.
        models_path: Path to the judges JSONL file.
        max_samples: If positive, evaluate at most this many samples.
        test_type: Judge test type. Auto-detected from the dataset name when omitted.
        dataset_config: Optional HuggingFace dataset config name.
    """
    logger.info(f"Evaluating dataset: {dataset_name}")

    models = load_models(models_path)
    if not models:
        logger.error(f"No models found in {models_path}")
        return
    logger.info(f"Loaded {len(models)} models")

    judge_manager = JudgeManager(models)

    data = _load_dataset_split(dataset_name, dataset_config, max_samples)
    if data is None:
        return
    if len(data) == 0:
        logger.error(f"Dataset {dataset_name} has no samples to evaluate.")
        return

    # Column roles are guessed heuristically; datasets use many different schemas.
    column_names = data.column_names
    logger.info(f"Dataset columns: {column_names}")
    input_column = _find_column(column_names, _INPUT_COLUMN_HINTS) or _first_string_column(data)
    output_column = _find_column(column_names, _OUTPUT_COLUMN_HINTS)
    label_column = _find_column(column_names, _LABEL_COLUMN_HINTS)

    if test_type:
        logger.info(f"Using provided test type: {test_type}")
    else:
        test_type = get_test_type_from_dataset(dataset_name)
        logger.info(f"Auto-detected test type: {test_type}")

    requires_output = test_type not in _INPUT_ONLY_TEST_TYPES

    if not input_column:
        logger.error("Could not determine input column, required for all test types.")
        return
    if requires_output and not output_column:
        logger.error(f"Test type '{test_type}' requires output column, none found.")
        return

    column_info = f"Using columns: input={input_column}"
    if output_column:
        column_info += f", output={output_column}"
    if label_column:
        column_info += f", label={label_column}"
    else:
        logger.warning("No label column found. Cannot compare against ground truth.")
    logger.info(column_info)

    raw_results = _run_judges(
        data,
        models,
        judge_manager,
        dataset_name=dataset_name,
        test_type=test_type,
        input_column=input_column,
        output_column=output_column,
        label_column=label_column,
        requires_output=requires_output,
    )

    # to_csv does not create missing parent directories.
    os.makedirs(_BENCHMARKS_DIR, exist_ok=True)

    raw_df = pd.DataFrame(raw_results)
    raw_results_path = _benchmark_path(dataset_name, "raw-results")
    raw_df.to_csv(raw_results_path, index=False)
    logger.info(f"Raw results saved to {raw_results_path}")

    if raw_df.empty:
        # An empty frame has no "judge_id" column, so metrics cannot be computed.
        logger.warning("No evaluation results were produced; skipping judge metrics.")
        return

    judges_metrics_df = pd.DataFrame(_summarize_judges(raw_df, test_type, dataset_name))
    judges_metrics_path = _benchmark_path(dataset_name, "judges-metrics")
    judges_metrics_df.to_csv(judges_metrics_path, index=False)
    logger.info(f"Judge metrics saved to {judges_metrics_path}")


if __name__ == "__main__":
    # Command-line entry point: parse options, then run a single dataset evaluation.
    cli = argparse.ArgumentParser(description="Evaluate HuggingFace datasets against AI judges")
    cli.add_argument("dataset", help="HuggingFace dataset name (e.g., 'truthful_qa')")
    cli.add_argument("--models", default="models.jsonl", help="Path to models JSONL file")
    cli.add_argument("--max-samples", type=int, help="Maximum number of samples to evaluate")
    cli.add_argument(
        "--test-type",
        choices=["prompt_injections", "safety", "grounding", "policy", "generic"],
        help="Override test type (default: auto-detect from dataset name)",
    )
    cli.add_argument(
        "--dataset-config",
        help="Dataset config name (e.g., 'train' for allenai/wildjailbreak)",
    )

    cli_args = cli.parse_args()

    evaluate_dataset(
        dataset_name=cli_args.dataset,
        models_path=cli_args.models,
        max_samples=cli_args.max_samples,
        test_type=cli_args.test_type,
        dataset_config=cli_args.dataset_config,
    )