from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from PIL import Image

from .models.cmx.builder_np_conf import myEncoderDecoder as TruForNetwork

LOGGER = logging.getLogger(__name__)


@dataclass(frozen=True)
class TruForOutputs:
    """Lightweight container for TruFor inference outputs.

    Fields (as filled in by ``TruForBundledModel.predict``):
        tamper_map: Channel 1 of the softmax over the class dimension of the
            network's first output, for the first (only) batch item.
        confidence_map: Sigmoid of channel 0 of the confidence output for the
            first batch item, or ``None`` when the network returns no
            confidence output.
        detection_score: Sigmoid of the detection output as a Python float, or
            ``None`` when the network returns no detection output.
    """

    tamper_map: np.ndarray
    confidence_map: Optional[np.ndarray]
    detection_score: Optional[float]


class TruForBundledModel:
    """Loads the TruFor network from the vendored sources and runs inference."""

    def __init__(
        self, weights_path: Path | str, device: str | torch.device = "cpu"
    ) -> None:
        """Build the network, load the checkpoint and move it to ``device``.

        Args:
            weights_path: Path to a checkpoint file readable by ``torch.load``.
            device: Any value accepted by ``torch.device`` (e.g. ``"cpu"``,
                ``"cuda:0"``, or an existing ``torch.device``).

        Raises:
            FileNotFoundError: If ``weights_path`` is not an existing file.
            ValueError: If ``device`` is not a valid torch device string.
        """
        self.weights_path = Path(weights_path)
        # is_file() rather than exists(): a directory would otherwise pass this
        # check and only fail later with a less helpful error from torch.load.
        if not self.weights_path.is_file():
            raise FileNotFoundError(f"TruFor weights missing at {self.weights_path}")
        try:
            self.device = torch.device(device)
        except RuntimeError as exc:  # pragma: no cover - defensive path for invalid strings
            raise ValueError(f"Unsupported torch device '{device}'") from exc
        self.model = self._build_model().to(self.device)
        # Switch layers such as batch-norm/dropout to inference behaviour.
        self.model.eval()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def predict(self, image: Image.Image) -> TruForOutputs:
        """Run TruFor on a single PIL image.

        Args:
            image: Input image; it is converted to RGB before inference.

        Returns:
            A ``TruForOutputs`` whose arrays have been copied back to the CPU.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("An input image is required for TruFor inference.")
        tensor = self._prepare_tensor(image).to(self.device)

        confidence_map: Optional[np.ndarray] = None
        detection_score: Optional[float] = None
        with torch.inference_mode():
            # The network returns four values; the fourth is not used here.
            pred, conf, det, _ = self.model(tensor)
            # The batch holds a single image (see _prepare_tensor), hence [0];
            # softmax over the class axis, then keep class index 1.
            tamper_map = torch.softmax(pred[0], dim=0)[1].cpu().numpy()
            if conf is not None:
                confidence_map = torch.sigmoid(conf[0][0]).cpu().numpy()
            if det is not None:
                # .item() requires a single-element tensor (batch of one).
                detection_score = torch.sigmoid(det).item()

        return TruForOutputs(
            tamper_map=tamper_map,
            confidence_map=confidence_map,
            detection_score=detection_score,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_model(self) -> torch.nn.Module:
        """Instantiate the network and load weights from ``self.weights_path``.

        Weights are loaded with ``strict=False``: missing or unexpected keys are
        logged as warnings instead of raising.
        """
        model = TruForNetwork(cfg=self._default_config())
        # SECURITY: weights_only=False makes torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints from trusted
        # sources.
        checkpoint = torch.load(self.weights_path, map_location="cpu", weights_only=False)
        state_dict = self._extract_state_dict(checkpoint)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            LOGGER.warning("TruFor missing keys: %s", sorted(missing))
        if unexpected:
            LOGGER.warning("TruFor unexpected keys: %s", sorted(unexpected))
        return model

    @staticmethod
    def _extract_state_dict(checkpoint: object) -> dict:
        """Return the parameter mapping held by a loaded checkpoint object.

        Accepts a full training checkpoint (a dict with a ``"state_dict"``
        entry), a bare state dict, or a pickled ``nn.Module``.

        Raises:
            TypeError: If the checkpoint is none of the supported forms.
        """
        if isinstance(checkpoint, torch.nn.Module):
            return checkpoint.state_dict()
        if isinstance(checkpoint, dict):
            return checkpoint.get("state_dict", checkpoint)
        raise TypeError(
            f"Unsupported TruFor checkpoint type: {type(checkpoint).__name__}"
        )

    @staticmethod
    def _prepare_tensor(image: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a ``(1, 3, H, W)`` float32 tensor.

        Pixel values are divided by 256 (not 255).
        """
        rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
        # HWC -> CHW, then add a leading batch dimension of size 1.
        tensor = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0)
        tensor = tensor / 256.0  # matches the reference implementation
        return tensor

    class AttrNamespace(dict):
        """``dict`` whose entries can also be read and written as attributes.

        Used to build the config object handed to the vendored network, which
        presumably accesses fields as ``cfg.MODEL.EXTRA.BACKBONE`` — confirm
        against the vendored builder.
        """

        def __getattr__(self, item):
            # Only called when normal attribute lookup fails; fall back to keys.
            try:
                return self[item]
            except KeyError as exc:
                # AttributeError keeps hasattr()/getattr(default) working.
                raise AttributeError(item) from exc

        def __setattr__(self, key, value):
            # Store attributes as dict entries so both access styles agree.
            self[key] = value

    @classmethod
    def _default_config(cls) -> TruForBundledModel.AttrNamespace:
        """Return the nested configuration used to instantiate the network.

        The values are presumably those of the reference TruFor configuration
        for the bundled weights — verify against upstream before changing them.
        """
        extra = cls.AttrNamespace(
            BACKBONE="mit_b2",
            DECODER="MLPDecoder",
            DECODER_EMBED_DIM=512,
            PREPRC="imagenet",
            BN_EPS=0.001,
            BN_MOMENTUM=0.1,
            DETECTION="confpool",
            CONF=True,
            NP_WEIGHTS="",
        )
        model = cls.AttrNamespace(
            NAME="detconfcmx",
            MODS=("RGB", "NP++"),
            PRETRAINED="",
            EXTRA=extra,
        )
        dataset = cls.AttrNamespace(NUM_CLASSES=2)
        return cls.AttrNamespace(MODEL=model, DATASET=dataset)