# hf_model.py # ---- keep TF/JAX silent and unused ---- import os os.environ["TRANSFORMERS_NO_TF"] = "1" # don't load TensorFlow os.environ["TRANSFORMERS_NO_FLAX"] = "1" # don't load Flax/JAX os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # silence TF C++ logs (INFO/WARN) # optional: extra hush try: from transformers.utils import logging as hf_logging hf_logging.set_verbosity_error() import tensorflow as tf tf.get_logger().setLevel("ERROR") import absl.logging absl.logging.set_verbosity(absl.logging.ERROR) except Exception: pass from typing import Dict, List, Tuple from PIL import Image import torch from transformers import AutoImageProcessor, AutoModelForImageClassification # A solid, ready-made age classifier (ViT backbone). # It predicts age ranges as classes; we map ranges -> a point estimate. HF_MODEL_ID = "nateraw/vit-age-classifier" # Map model class labels to numeric midpoints for an age estimate # (these labels come with the above model). AGE_RANGE_TO_MID = { "0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35, "40-49": 45, "50-59": 55, "60-69": 65, "70+": 75 } class PretrainedAgeEstimator: def __init__(self, model_id: str = HF_MODEL_ID, device: str | None = None, use_fast: bool = True): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True) self.model = AutoModelForImageClassification.from_pretrained(model_id) self.model.to(self.device).eval() # keep label mapping self.id2label: Dict[int, str] = self.model.config.id2label self.label2id: Dict[str, int] = self.model.config.label2id @torch.inference_mode() def predict(self, img: Image.Image, topk: int = 3) -> Tuple[float, List[Tuple[str, float]]]: """Return (point_estimate_age, [(label, prob), ...]).""" if img.mode != "RGB": img = img.convert("RGB") inputs = self.processor(images=img, return_tensors="pt").to(self.device) logits = self.model(**inputs).logits probs = logits.softmax(dim=-1).squeeze(0) # [num_classes] # Top-k labels topk = min(topk, probs.numel()) values, indices = torch.topk(probs, k=topk) top = [(self.id2label[i.item()], v.item()) for i, v in zip(indices, values)] # Point estimate from expectation over range midpoints expected = 0.0 for i, p in enumerate(probs): label = self.id2label[i] mid = AGE_RANGE_TO_MID.get(label, 35) # default mid if unseen label expected += mid * p.item() return expected, top