Upload 3 files

Changed files:
- app.py       +40 -70
- hf_model.py  +73 -0 (new file)
- infer.py     +18 -0 (new file)
app.py CHANGED

@@ -1,70 +1,40 @@
-):
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-chatbot = gr.ChatInterface(
-    respond,
-    type="messages",
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-with gr.Blocks() as demo:
-    with gr.Sidebar():
-        gr.LoginButton()
-    chatbot.render()
-
-
-if __name__ == "__main__":
-    demo.launch()
+# app.py
+import os
+os.environ["TRANSFORMERS_NO_TF"] = "1"
+os.environ["TRANSFORMERS_NO_FLAX"] = "1"
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+import gradio as gr
+from PIL import Image
+from hf_model import PretrainedAgeEstimator
+
+est = PretrainedAgeEstimator()
+
+def predict(img):
+    # Gradio may pass PIL or numpy; handle both
+    if not isinstance(img, Image.Image):
+        img = Image.fromarray(img)
+
+    age, top = est.predict(img, topk=5)
+
+    # 1) dict[str, float] for Label
+    probs = {lbl: float(prob) for lbl, prob in top}
+
+    # 2) plain string for the estimate
+    summary = f"Estimated age: **{age:.1f}** years"
+
+    return probs, summary
+
+demo = gr.Interface(
+    fn=predict,
+    inputs=gr.Image(type="pil", label="Upload a face image"),
+    outputs=[
+        gr.Label(num_top_classes=5, label="Age Prediction (probabilities)"),
+        gr.Markdown(label="Summary"),
+    ],
+    title="Pretrained Age Estimator",
+    description="Runs a pretrained ViT-based age classifier and reports a point estimate from class probabilities.",
+)
+
+if __name__ == "__main__":
+    demo.launch(share=True)
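As a quick sanity check of the new app.py wiring outside the Gradio UI, a minimal sketch along these lines exercises predict() directly. The file name check_app.py and the image path face.jpg are hypothetical placeholders, and importing app will load the model on first run:

# check_app.py -- illustrative sketch, not part of this commit.
# Assumes app.py is importable and a test image exists at "face.jpg" (hypothetical path).
from PIL import Image
from app import predict

probs, summary = predict(Image.open("face.jpg"))
print(summary)                                        # e.g. "Estimated age: **31.4** years" (value illustrative)
print(sorted(probs.items(), key=lambda kv: -kv[1]))   # top-5 (label, probability) pairs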
hf_model.py ADDED

@@ -0,0 +1,73 @@
+# hf_model.py
+# ---- keep TF/JAX silent and unused ----
+import os
+os.environ["TRANSFORMERS_NO_TF"] = "1"    # don't load TensorFlow
+os.environ["TRANSFORMERS_NO_FLAX"] = "1"  # don't load Flax/JAX
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # silence TF C++ logs (INFO/WARN)
+
+# optional: extra hush
+try:
+    from transformers.utils import logging as hf_logging
+    hf_logging.set_verbosity_error()
+    import tensorflow as tf
+    tf.get_logger().setLevel("ERROR")
+    import absl.logging
+    absl.logging.set_verbosity(absl.logging.ERROR)
+except Exception:
+    pass
+
+
+
+from typing import Dict, List, Tuple
+from PIL import Image
+import torch
+from transformers import AutoImageProcessor, AutoModelForImageClassification
+
+
+
+# A solid, ready-made age classifier (ViT backbone).
+# It predicts age ranges as classes; we map ranges -> a point estimate.
+HF_MODEL_ID = "nateraw/vit-age-classifier"
+
+# Map model class labels to numeric midpoints for an age estimate
+# (these labels come with the above model).
+AGE_RANGE_TO_MID = {
+    "0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35,
+    "40-49": 45, "50-59": 55, "60-69": 65, "70+": 75
+}
+
+class PretrainedAgeEstimator:
+    def __init__(self, model_id: str = HF_MODEL_ID, device: str | None = None, use_fast: bool = True):
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=use_fast)
+
+        self.model = AutoModelForImageClassification.from_pretrained(model_id)
+        self.model.to(self.device).eval()
+
+        # keep label mapping
+        self.id2label: Dict[int, str] = self.model.config.id2label
+        self.label2id: Dict[str, int] = self.model.config.label2id
+
+    @torch.inference_mode()
+    def predict(self, img: Image.Image, topk: int = 3) -> Tuple[float, List[Tuple[str, float]]]:
+        """Return (point_estimate_age, [(label, prob), ...])."""
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+
+        inputs = self.processor(images=img, return_tensors="pt").to(self.device)
+        logits = self.model(**inputs).logits
+        probs = logits.softmax(dim=-1).squeeze(0)  # [num_classes]
+
+        # Top-k labels
+        topk = min(topk, probs.numel())
+        values, indices = torch.topk(probs, k=topk)
+        top = [(self.id2label[i.item()], v.item()) for i, v in zip(indices, values)]
+
+        # Point estimate from expectation over range midpoints
+        expected = 0.0
+        for i, p in enumerate(probs):
+            label = self.id2label[i]
+            mid = AGE_RANGE_TO_MID.get(label, 35)  # default mid if unseen label
+            expected += mid * p.item()
+
+        return expected, top
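To make the expectation-over-midpoints point estimate in predict() concrete, here is a small worked sketch; the probabilities are invented for illustration and only a subset of the label-to-midpoint mapping is shown:

# Toy illustration of the point-estimate step (probabilities are made up).
toy_probs = {"20-29": 0.10, "30-39": 0.60, "40-49": 0.30}   # pretend all other classes are ~0
midpoints = {"20-29": 25, "30-39": 35, "40-49": 45}         # subset of AGE_RANGE_TO_MID
expected = sum(midpoints[label] * p for label, p in toy_probs.items())
print(expected)  # 25*0.10 + 35*0.60 + 45*0.30 = 37.0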
infer.py ADDED

@@ -0,0 +1,18 @@
+# infer.py
+from PIL import Image
+from hf_model import PretrainedAgeEstimator
+
+def predict_age(image_path: str):
+    est = PretrainedAgeEstimator()
+    img = Image.open(image_path)
+    age, top = est.predict(img, topk=3)
+    return age, top
+
+if __name__ == "__main__":
+    import sys, json
+    path = sys.argv[1]
+    age, top = predict_age(path)
+    print(json.dumps({
+        "estimated_age": round(age, 1),
+        "top_classes": [(lbl, round(prob, 4)) for lbl, prob in top]
+    }, indent=2))
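infer.py is runnable from the command line as python infer.py path/to/image.jpg and prints a small JSON object. The sketch below shows the equivalent programmatic call and the shape of the result; the path and printed values are illustrative only:

# Programmatic use of predict_age(); the path and printed values are illustrative.
from infer import predict_age

age, top = predict_age("path/to/image.jpg")   # hypothetical image path
print(round(age, 1))   # e.g. 34.7
print(top)             # e.g. [("30-39", 0.62), ("20-29", 0.21), ("40-49", 0.11)]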