Spaces:
Running
Running
File size: 4,297 Bytes
65d7391 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import numpy as np
import torch
from PIL import Image
from .models.cmx.builder_np_conf import myEncoderDecoder as TruForNetwork
# Module-level logger named after this module; used below to report
# missing/unexpected keys when the checkpoint is loaded non-strictly.
LOGGER = logging.getLogger(__name__)
@dataclass(frozen=True)
class TruForOutputs:
    """Lightweight, immutable container for TruFor inference outputs."""

    # Softmax probability map taken at class index 1 of the network's
    # prediction for the single input image.
    tamper_map: np.ndarray
    # Sigmoid-activated confidence map, or None when the network returns
    # no confidence output.
    confidence_map: Optional[np.ndarray]
    # Sigmoid-activated scalar detection score, or None when the network
    # returns no detection output.
    detection_score: Optional[float]
class TruForBundledModel:
    """Loads the TruFor network from the vendored sources and runs inference."""

    def __init__(self, weights_path: Path | str, device: str = "cpu") -> None:
        """Build the network, load its weights and switch it to eval mode.

        Args:
            weights_path: Location of the TruFor checkpoint file.
            device: Any string accepted by ``torch.device`` (e.g. ``"cpu"``).

        Raises:
            FileNotFoundError: If ``weights_path`` is not an existing file.
            ValueError: If ``device`` is rejected by ``torch.device``.
        """
        self.weights_path = Path(weights_path)
        # is_file() (rather than exists()) also rejects directories up front,
        # instead of letting torch.load fail later with an unrelated error.
        if not self.weights_path.is_file():
            raise FileNotFoundError(f"TruFor weights missing at {self.weights_path}")

        try:
            self.device = torch.device(device)
        except RuntimeError as exc:  # pragma: no cover - defensive path for invalid strings
            raise ValueError(f"Unsupported torch device '{device}'") from exc

        self.model = self._build_model().to(self.device)
        self.model.eval()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def predict(self, image: Image.Image) -> TruForOutputs:
        """Run the network on a single image and post-process its outputs.

        Args:
            image: Input image; it is converted to RGB before inference.

        Returns:
            A ``TruForOutputs`` holding the softmax map at class index 1, plus
            the sigmoid-activated confidence map and detection score when the
            network provides them (``None`` otherwise).

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("An input image is required for TruFor inference.")

        tensor = self._prepare_tensor(image).to(self.device)

        confidence_map: Optional[np.ndarray] = None
        detection_score: Optional[float] = None

        with torch.inference_mode():
            # The network returns four values; the last one is unused here.
            pred, conf, det, _ = self.model(tensor)

            # pred[0] drops the batch axis; softmax runs over the leading
            # (class) axis and index 1 is kept -- presumably the "tampered"
            # class; confirm against the vendored network.
            tamper_map = torch.softmax(pred[0], dim=0)[1].cpu().numpy()

            if conf is not None:
                confidence_map = torch.sigmoid(conf[0][0]).cpu().numpy()

            if det is not None:
                # .item() requires a single-element tensor (batch size 1).
                detection_score = torch.sigmoid(det).item()

        return TruForOutputs(
            tamper_map=tamper_map,
            confidence_map=confidence_map,
            detection_score=detection_score,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_model(self) -> torch.nn.Module:
        """Instantiate the network and load the checkpoint non-strictly.

        Missing and unexpected state-dict keys are logged as warnings rather
        than raised, because ``strict=False`` is used.
        """
        cfg = self._default_config()
        model = TruForNetwork(cfg=cfg)

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only point weights_path at
        # checkpoints from a trusted source.
        checkpoint = torch.load(self.weights_path, map_location="cpu", weights_only=False)

        # Accept both a wrapped checkpoint ({"state_dict": ...}) and a bare
        # state dict.
        state_dict = checkpoint.get("state_dict", checkpoint)

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            LOGGER.warning("TruFor missing keys: %s", sorted(missing))
        if unexpected:
            LOGGER.warning("TruFor unexpected keys: %s", sorted(unexpected))
        return model

    @staticmethod
    def _prepare_tensor(image: Image.Image) -> torch.Tensor:
        """Convert an image to a float32 tensor of shape (1, 3, H, W).

        The image is converted to RGB, moved from HWC to CHW order, given a
        leading batch axis, and scaled by 1/256.
        """
        rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
        tensor = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0)
        tensor = tensor / 256.0  # matches the reference implementation
        return tensor

    class AttrNamespace(dict):
        """A ``dict`` whose keys can also be read and written as attributes.

        Membership tests (``key in ns``) use the inherited ``dict`` behaviour.
        """

        def __getattr__(self, item):
            # Only called when normal attribute lookup fails; translate the
            # KeyError so hasattr()/getattr() behave as expected.
            try:
                return self[item]
            except KeyError as exc:
                raise AttributeError(item) from exc

        def __setattr__(self, key, value):
            # Attribute assignment is stored as a dictionary entry.
            self[key] = value

    @classmethod
    def _default_config(cls) -> "TruForBundledModel.AttrNamespace":
        """Return the fixed configuration passed to the vendored network."""
        extra = cls.AttrNamespace(
            BACKBONE="mit_b2",
            DECODER="MLPDecoder",
            DECODER_EMBED_DIM=512,
            PREPRC="imagenet",
            BN_EPS=0.001,
            BN_MOMENTUM=0.1,
            DETECTION="confpool",
            CONF=True,
            NP_WEIGHTS="",
        )
        model = cls.AttrNamespace(
            NAME="detconfcmx",
            MODS=("RGB", "NP++"),
            PRETRAINED="",
            EXTRA=extra,
        )
        dataset = cls.AttrNamespace(NUM_CLASSES=2)
        return cls.AttrNamespace(MODEL=model, DATASET=dataset)
|