| | """ |
| | Sequence Prediction Evaluation with QwenImageEditPlusPipeline / Flux2KleinPipeline. |
| | |
| | Evaluates the model's ability to predict the next number in a sequence |
| | by generating images and extracting answers via OCR. |
| | """ |
| |
|
| | import json |
| | import re |
| | from pathlib import Path |
| | from dataclasses import dataclass, field |
| | from enum import Enum |
| |
|
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from tqdm import tqdm |
| |
|
| |
|
class ModelType(str, Enum):
    # Supported generation backends; the string values double as the CLI
    # `--model` choices (see main()).
    QWEN_IMAGE_EDIT = "qwen"
    FLUX2_KLEIN = "flux2-klein"
| |
|
| |
|
@dataclass
class EvalConfig:
    """Configuration for a sequence-prediction evaluation run."""

    # Dataset / output locations.
    dataset_dir: str = "sequence_dataset"
    output_dir: str = "eval_results"

    # Backend selection; model_id is resolved from model_type when empty.
    model_type: ModelType = ModelType.QWEN_IMAGE_EDIT
    model_id: str = ""

    # Prompts passed to the pipeline.
    prompt: str = (
        "Based on the number patterns shown in the previous images, "
        "fill in the missing number in the empty cell of the last image."
    )
    negative_prompt: str = ""

    # Sampling parameters.
    num_inference_steps: int = 5
    guidance_scale: float = 1.0
    true_cfg_scale: float = 4.0
    height: int = 210
    width: int = 750

    # Runtime settings.
    seed: int = 42
    device: str = "cuda"
    dtype: torch.dtype = field(default_factory=lambda: torch.bfloat16)

    def __post_init__(self):
        """Resolve the default model_id for the chosen backend when unset."""
        default_ids = {
            ModelType.QWEN_IMAGE_EDIT: "Qwen/Qwen-Image-Edit-2509",
            ModelType.FLUX2_KLEIN: "black-forest-labs/FLUX.2-klein-9B",
        }
        if not self.model_id:
            self.model_id = default_ids.get(self.model_type, "")
| |
|
| |
|
class OCRExtractor:
    """Extract numbers from grid images using OCR."""

    def __init__(self, backend: str = "easyocr"):
        """
        Args:
            backend: OCR backend ("easyocr" or "pytesseract").

        Raises:
            ValueError: If the backend name is not recognized.
        """
        self.backend = backend
        if backend == "easyocr":
            import easyocr
            self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
        elif backend == "pytesseract":
            import pytesseract
            self.pytesseract = pytesseract
        else:
            raise ValueError(f"Unknown backend: {backend}")

    def _read_cell(self, cell: Image.Image) -> int | None:
        """Run OCR on a single cell crop and return the first integer found.

        Shared by extract_last_number and extract_all_numbers (the parsing
        logic was previously duplicated in both).  Returns None when no
        parseable digits are detected.
        """
        if self.backend == "easyocr":
            # easyocr takes a numpy array; use the first detection that
            # contains digits (optional leading minus sign included).
            for _, text, _conf in self.reader.readtext(np.array(cell)):
                digits = re.findall(r'-?\d+', text)
                if digits:
                    return int(digits[0])
        elif self.backend == "pytesseract":
            # --psm 7: treat the crop as a single text line; whitelist keeps
            # tesseract from hallucinating letters.
            text = self.pytesseract.image_to_string(
                cell, config='--psm 7 -c tessedit_char_whitelist=0123456789-'
            )
            digits = re.findall(r'-?\d+', text)
            if digits:
                return int(digits[0])
        return None

    def extract_last_number(self, image: Image.Image) -> int | None:
        """
        Extract the last (rightmost) number from a grid image.

        Args:
            image: PIL Image of the number grid.

        Returns:
            Extracted number or None if extraction fails.
        """
        # NOTE(review): the crop takes the rightmost quarter of the image,
        # i.e. it assumes a 4-cell grid — confirm against the dataset layout.
        w, h = image.size
        return self._read_cell(image.crop((w * 3 // 4, 0, w, h)))

    def extract_all_numbers(self, image: Image.Image, num_cells: int = 4) -> list[int | None]:
        """Extract all numbers from a grid image, one entry per cell.

        Args:
            image: PIL Image of the number grid.
            num_cells: Number of equal-width cells to split the image into.

        Returns:
            List of length num_cells; None where a cell could not be read.
        """
        w, h = image.size
        cell_width = w // num_cells
        return [
            self._read_cell(image.crop((i * cell_width, 0, (i + 1) * cell_width, h)))
            for i in range(num_cells)
        ]
| |
|
| |
|
class SequenceEvaluator:
    """Evaluator for sequence prediction task.

    Generates a completion image for each sample with the configured
    diffusion pipeline, then recovers the predicted number via OCR and
    compares it against the ground-truth answer.
    """

    def __init__(self, config: EvalConfig):
        """
        Args:
            config: Evaluation configuration (model, prompts, sampling).
        """
        self.config = config
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Diffusion pipeline for the configured backend.
        self.pipeline = self._load_pipeline()

        # OCR reader used to pull the predicted digit out of generated images.
        self.ocr = OCRExtractor(backend="easyocr")

    def _load_pipeline(self):
        """Load pipeline based on model type.

        Raises:
            ValueError: If the configured model type is not supported.
        """
        if self.config.model_type == ModelType.QWEN_IMAGE_EDIT:
            return self._load_qwen_pipeline()
        elif self.config.model_type == ModelType.FLUX2_KLEIN:
            return self._load_flux2_klein_pipeline()
        else:
            raise ValueError(f"Unknown model type: {self.config.model_type}")

    def _load_qwen_pipeline(self):
        """Load QwenImageEditPlusPipeline fully onto the configured device."""
        from diffusers import QwenImageEditPlusPipeline

        pipeline = QwenImageEditPlusPipeline.from_pretrained(
            self.config.model_id,
            torch_dtype=self.config.dtype,
        )
        pipeline.to(self.config.device)
        pipeline.set_progress_bar_config(disable=True)
        return pipeline

    def _load_flux2_klein_pipeline(self):
        """Load Flux2KleinPipeline with model CPU offload."""
        from diffusers import Flux2KleinPipeline

        pipeline = Flux2KleinPipeline.from_pretrained(
            self.config.model_id,
            torch_dtype=self.config.dtype,
        )
        # CPU offload instead of .to(device): submodules are moved to the GPU
        # only while they run, trading speed for lower VRAM usage.
        pipeline.enable_model_cpu_offload()
        pipeline.set_progress_bar_config(disable=True)
        return pipeline

    def _load_images(self, image_paths: list[str], image_dir: Path) -> list[Image.Image]:
        """Load images (paths relative to image_dir), converted to RGB."""
        return [Image.open(image_dir / p).convert("RGB") for p in image_paths]

    def predict(self, images: list[Image.Image]) -> Image.Image:
        """
        Generate prediction image given input images.

        Args:
            images: List of input images (context + query).

        Returns:
            Generated image with predicted number.

        Raises:
            ValueError: If the configured model type is not supported.
        """
        # Fresh generator per call so every sample is generated from the same
        # seed regardless of evaluation order.
        generator = torch.Generator(device=self.config.device).manual_seed(self.config.seed)

        if self.config.model_type == ModelType.QWEN_IMAGE_EDIT:
            # The edit pipeline derives output size from the input images,
            # so height/width are intentionally not passed here.
            inputs = {
                "image": images,
                "prompt": self.config.prompt,
                "generator": generator,
                "true_cfg_scale": self.config.true_cfg_scale,
                "negative_prompt": self.config.negative_prompt,
                "num_inference_steps": self.config.num_inference_steps,
            }
        elif self.config.model_type == ModelType.FLUX2_KLEIN:
            inputs = {
                "image": images,
                "prompt": self.config.prompt,
                "generator": generator,
                "guidance_scale": self.config.guidance_scale,
                "num_inference_steps": self.config.num_inference_steps,
                "height": self.config.height,
                "width": self.config.width,
            }
        else:
            # Fix: previously an unknown model type fell through both branches
            # and crashed with UnboundLocalError on `inputs`; fail explicitly,
            # mirroring _load_pipeline.
            raise ValueError(f"Unknown model type: {self.config.model_type}")

        with torch.inference_mode():
            output = self.pipeline(**inputs)

        return output.images[0]

    def evaluate_sample(self, sample: dict, image_dir: Path) -> dict:
        """
        Evaluate a single sample.

        Args:
            sample: Sample metadata dict. Expected keys: "images" (list of
                relative paths), "id" (int), "seq_type", "answer".
            image_dir: Directory containing images.

        Returns:
            Evaluation result dict with prediction, ground truth, and
            the path of the saved generated image.
        """
        images = self._load_images(sample["images"], image_dir)

        pred_image = self.predict(images)

        # Persist the generated image for later inspection.
        pred_path = self.output_dir / f"{sample['id']:05d}_pred.png"
        pred_image.save(pred_path)

        # OCR the rightmost cell of the generated grid; may return None.
        pred_number = self.ocr.extract_last_number(pred_image)

        gt_number = sample["answer"]

        # None (OCR failure) never equals the ground truth, so it counts
        # as incorrect.
        correct = pred_number == gt_number

        return {
            "id": sample["id"],
            "seq_type": sample["seq_type"],
            "gt_answer": gt_number,
            "pred_answer": pred_number,
            "correct": correct,
            "pred_image": str(pred_path),
        }

    def evaluate(self, split: str = "test") -> dict:
        """
        Evaluate on entire dataset split.

        Args:
            split: Dataset split ("train" or "test").

        Returns:
            Evaluation results summary (also written to
            <output_dir>/<split>_results.json).
        """
        dataset_dir = Path(self.config.dataset_dir)

        with open(dataset_dir / f"{split}.json") as f:
            samples = json.load(f)

        image_dir = dataset_dir / split / "images"

        results = []
        for sample in tqdm(samples, desc=f"Evaluating {split}"):
            result = self.evaluate_sample(sample, image_dir)
            results.append(result)

        # Overall accuracy.
        total = len(results)
        correct = sum(r["correct"] for r in results)
        accuracy = correct / total if total > 0 else 0.0

        # Per-sequence-type accuracy breakdown.
        type_stats = {}
        for r in results:
            seq_type = r["seq_type"]
            if seq_type not in type_stats:
                type_stats[seq_type] = {"correct": 0, "total": 0}
            type_stats[seq_type]["total"] += 1
            if r["correct"]:
                type_stats[seq_type]["correct"] += 1

        type_accuracy = {
            k: v["correct"] / v["total"] for k, v in type_stats.items()
        }

        summary = {
            "split": split,
            "model_type": self.config.model_type.value,
            "model_id": self.config.model_id,
            "total": total,
            "correct": correct,
            "accuracy": accuracy,
            "type_accuracy": type_accuracy,
            "results": results,
        }

        with open(self.output_dir / f"{split}_results.json", "w") as f:
            json.dump(summary, f, indent=2)

        return summary
| |
|
| |
|
def main():
    """CLI entry point: parse arguments, run the test-split evaluation,
    and print a summary with per-type accuracy."""
    import argparse

    parser = argparse.ArgumentParser(description="Sequence Prediction Evaluation")
    parser.add_argument("--model", default="flux2-klein",
                        choices=["qwen", "flux2-klein"],
                        help="Model type to use")
    parser.add_argument("--model-id", default="",
                        help="Custom model ID (optional)")
    parser.add_argument("--dataset-dir", default="sequence_dataset",
                        help="Dataset directory")
    parser.add_argument("--output-dir", default="eval_results",
                        help="Output directory")
    parser.add_argument("--steps", type=int, default=50,
                        help="Number of inference steps")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    args = parser.parse_args()

    config = EvalConfig(
        dataset_dir=args.dataset_dir,
        output_dir=args.output_dir,
        model_type=ModelType(args.model),
        model_id=args.model_id,
        num_inference_steps=args.steps,
        seed=args.seed,
    )

    print(f"Model: {config.model_type.value} ({config.model_id})")

    evaluator = SequenceEvaluator(config)
    summary = evaluator.evaluate("test")

    sep = "=" * 50
    print(f"\n{sep}")
    print(f"Evaluation Results ({config.model_type.value})")
    print(sep)
    print(f"Total samples: {summary['total']}")
    print(f"Correct: {summary['correct']}")
    print(f"Accuracy: {summary['accuracy']:.2%}")
    print("\nPer-type accuracy:")
    for seq_type, acc in sorted(summary["type_accuracy"].items()):
        print(f"  {seq_type}: {acc:.2%}")


if __name__ == "__main__":
    main()