File size: 2,764 Bytes
df3ab33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# hf_model.py
# ---- keep TF/JAX silent and unused ----
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"      # don't load TensorFlow
os.environ["TRANSFORMERS_NO_FLAX"] = "1"   # don't load Flax/JAX
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"   # silence TF C++ logs (INFO/WARN)

# Optional extra hush. Each silencer lives in its own try/except so a single
# missing package no longer aborts the remaining ones (previously one failed
# import skipped everything after it in the shared try block).
try:
    from transformers.utils import logging as hf_logging
    hf_logging.set_verbosity_error()
except Exception:
    pass

try:
    # NOTE(review): importing tensorflow here loads it, which is at odds with
    # the "don't load TensorFlow" env vars above — kept to preserve the
    # original behavior of silencing TF when it happens to be installed.
    import tensorflow as tf
    tf.get_logger().setLevel("ERROR")
except Exception:
    pass

try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except Exception:
    pass



from typing import Dict, List, Tuple
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification



# Pretrained age classifier (ViT backbone). It predicts age *ranges* as
# discrete classes; we convert those ranges to a point estimate downstream.
HF_MODEL_ID = "nateraw/vit-age-classifier"

# Numeric midpoint for each class label shipped with the model above,
# used to turn the class distribution into a single age estimate.
AGE_RANGE_TO_MID = {
    "0-2": 1,
    "3-9": 6,
    "10-19": 15,
    "20-29": 25,
    "30-39": 35,
    "40-49": 45,
    "50-59": 55,
    "60-69": 65,
    "70+": 75,
}

class PretrainedAgeEstimator:
    """Age estimator wrapping a pretrained HF image-classification model.

    The model predicts age-range classes; `predict` converts the class
    probabilities into a single point estimate via the expectation over
    the range midpoints in `AGE_RANGE_TO_MID`.
    """

    def __init__(self, model_id: str = HF_MODEL_ID, device: str | None = None, use_fast: bool = True):
        # Prefer CUDA when available; callers may pin an explicit device string.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # BUG FIX: `use_fast` was previously accepted but ignored — the call
        # hard-coded use_fast=True. Pass the caller's choice through.
        self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=use_fast)

        self.model = AutoModelForImageClassification.from_pretrained(model_id)
        self.model.to(self.device).eval()

        # Keep the label <-> index mappings from the model config.
        self.id2label: Dict[int, str] = self.model.config.id2label
        self.label2id: Dict[str, int] = self.model.config.label2id

    @torch.inference_mode()
    def predict(self, img: Image.Image, topk: int = 3) -> Tuple[float, List[Tuple[str, float]]]:
        """Return (point_estimate_age, [(label, prob), ...]).

        The list holds the top-k labels by probability; the point estimate
        is the expectation of the range midpoints under the predicted
        class distribution.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")

        inputs = self.processor(images=img, return_tensors="pt").to(self.device)
        logits = self.model(**inputs).logits
        probs = logits.softmax(dim=-1).squeeze(0)  # [num_classes]

        # Top-k labels; clamp k so it never exceeds the number of classes
        # (avoids mutating the `topk` parameter, which the original did).
        k = min(topk, probs.numel())
        values, indices = torch.topk(probs, k=k)
        top = [(self.id2label[i.item()], v.item()) for i, v in zip(indices, values)]

        # Point estimate from the expectation over range midpoints.
        expected = 0.0
        for i, p in enumerate(probs):
            label = self.id2label[i]
            mid = AGE_RANGE_TO_MID.get(label, 35)  # default midpoint if unseen label
            expected += mid * p.item()

        return expected, top