File size: 4,297 Bytes
65d7391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from PIL import Image

from .models.cmx.builder_np_conf import myEncoderDecoder as TruForNetwork

# Module-level logger named after this module; used to report checkpoint
# key mismatches (missing / unexpected keys) when TruFor weights are loaded.
LOGGER: logging.Logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class TruForOutputs:
    """Frozen bundle of the results of a single TruFor inference call.

    Attributes cannot be reassigned after construction. Positional
    construction follows the declaration order below.
    """

    # Per-pixel tamper probability map (softmax output of the network).
    tamper_map: np.ndarray
    # Per-pixel confidence map, or None when the network returned none.
    confidence_map: np.ndarray | None
    # Image-level detection score, or None when the network returned none.
    detection_score: float | None


class TruForBundledModel:
    """Loads the TruFor network from the vendored sources and runs inference.

    The network is built from :meth:`_default_config`. Its weights are loaded
    from ``weights_path`` with ``strict=False``, so mismatched keys are logged
    rather than raised. The module is then moved to ``device`` and switched to
    eval mode.
    """

    def __init__(self, weights_path: Path | str, device: str = "cpu") -> None:
        """Build the network and load its weights.

        Args:
            weights_path: Path to the TruFor checkpoint file.
            device: A string accepted by ``torch.device`` (e.g. ``"cpu"``).

        Raises:
            FileNotFoundError: If ``weights_path`` is not an existing file.
            ValueError: If ``torch.device`` rejects ``device``.
        """
        self.weights_path = Path(weights_path)
        # is_file() rather than exists(): a directory would otherwise pass this
        # check and only fail later, less clearly, inside torch.load.
        if not self.weights_path.is_file():
            raise FileNotFoundError(f"TruFor weights missing at {self.weights_path}")

        try:
            self.device = torch.device(device)
        except RuntimeError as exc:  # pragma: no cover - defensive path for invalid strings
            raise ValueError(f"Unsupported torch device '{device}'") from exc

        self.model = self._build_model().to(self.device)
        self.model.eval()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def predict(self, image: Image.Image) -> TruForOutputs:
        """Run TruFor on a single image.

        Args:
            image: PIL image; it is converted to RGB before inference.

        Returns:
            TruForOutputs holding three results. The tamper map is channel 1
            of the 2-class softmax for the first batch item. The confidence
            map is the sigmoid of the network's confidence output, if any.
            The detection score is the sigmoid of the detection output, if any.

        Raises:
            ValueError: If ``image`` is None.
        """
        if image is None:
            raise ValueError("An input image is required for TruFor inference.")

        tensor = self._prepare_tensor(image).to(self.device)

        with torch.inference_mode():
            # The fourth network output is not used by this wrapper.
            pred, conf, det, _ = self.model(tensor)

            # Batch of one: softmax over the class axis, keep channel 1.
            tamper_map = torch.softmax(pred[0], dim=0)[1].cpu().numpy()

            confidence_map: Optional[np.ndarray] = None
            if conf is not None:
                confidence_map = torch.sigmoid(conf[0][0]).cpu().numpy()

            detection_score: Optional[float] = None
            if det is not None:
                # .item() requires det to hold exactly one element.
                detection_score = torch.sigmoid(det).item()

        return TruForOutputs(
            tamper_map=tamper_map,
            confidence_map=confidence_map,
            detection_score=detection_score,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_model(self) -> torch.nn.Module:
        """Instantiate the network and load weights from ``self.weights_path``.

        Weights are loaded with ``strict=False``. Missing and unexpected keys
        are logged as warnings rather than raised.
        """
        cfg = self._default_config()
        model = TruForNetwork(cfg=cfg)
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load trusted checkpoints.
        # NOTE(review): presumably required because the checkpoint stores
        # non-tensor metadata; try weights_only=True if that is not the case.
        checkpoint = torch.load(self.weights_path, map_location="cpu", weights_only=False)
        # Accept both a training checkpoint ({"state_dict": ...}) and a bare
        # state dict. Only dicts support .get(). Anything else goes straight to
        # load_state_dict, which reports the type problem itself instead of
        # failing here with an AttributeError.
        if isinstance(checkpoint, dict):
            state_dict = checkpoint.get("state_dict", checkpoint)
        else:
            state_dict = checkpoint

        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            LOGGER.warning("TruFor missing keys: %s", sorted(missing))
        if unexpected:
            LOGGER.warning("TruFor unexpected keys: %s", sorted(unexpected))

        return model

    @staticmethod
    def _prepare_tensor(image: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a float32 tensor of shape (1, 3, H, W)."""
        rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
        # HWC -> CHW, then add a leading batch axis.
        tensor = torch.from_numpy(rgb.transpose(2, 0, 1)).unsqueeze(0)
        tensor = tensor / 256.0  # matches the reference implementation
        return tensor

    class AttrNamespace(dict):
        """A dict whose items can also be read, set and deleted as attributes.

        A missing key raises AttributeError (not KeyError) on attribute access,
        so ``hasattr`` and ``getattr(ns, name, default)`` behave normally.
        Membership tests (``key in ns``) use the inherited dict behaviour.
        """

        def __getattr__(self, item):
            try:
                return self[item]
            except KeyError as exc:
                raise AttributeError(item) from exc

        def __setattr__(self, key, value):
            self[key] = value

        def __delattr__(self, item):
            # Mirror __setattr__: attributes live in the dict items.
            try:
                del self[item]
            except KeyError as exc:
                raise AttributeError(item) from exc

    @classmethod
    def _default_config(cls) -> TruForBundledModel.AttrNamespace:
        """Return the nested configuration used to build the vendored network.

        NOTE(review): these values presumably must match the configuration the
        checkpoint was trained with; confirm before changing any of them.
        """
        extra = cls.AttrNamespace(
            BACKBONE="mit_b2",
            DECODER="MLPDecoder",
            DECODER_EMBED_DIM=512,
            PREPRC="imagenet",
            BN_EPS=0.001,
            BN_MOMENTUM=0.1,
            DETECTION="confpool",
            CONF=True,
            NP_WEIGHTS="",
        )

        model = cls.AttrNamespace(
            NAME="detconfcmx",
            MODS=("RGB", "NP++"),
            PRETRAINED="",
            EXTRA=extra,
        )

        dataset = cls.AttrNamespace(NUM_CLASSES=2)

        return cls.AttrNamespace(MODEL=model, DATASET=dataset)