hongyu12321 commited on
Commit
df3ab33
·
verified ·
1 Parent(s): c5f0322

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +40 -70
  2. hf_model.py +73 -0
  3. infer.py +18 -0
app.py CHANGED
@@ -1,70 +1,40 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
-
19
- messages = [{"role": "system", "content": system_message}]
20
-
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
-
25
- response = ""
26
-
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
-
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
67
-
68
-
69
- if __name__ == "__main__":
70
- demo.launch()
 
1
+ # app.py
2
+ import os
3
+ os.environ["TRANSFORMERS_NO_TF"] = "1"
4
+ os.environ["TRANSFORMERS_NO_FLAX"] = "1"
5
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
6
+
7
+ import gradio as gr
8
+ from PIL import Image
9
+ from hf_model import PretrainedAgeEstimator
10
+
11
+ est = PretrainedAgeEstimator()
12
+
13
+ def predict(img):
14
+ # Gradio may pass PIL or numpy; handle both
15
+ if not isinstance(img, Image.Image):
16
+ img = Image.fromarray(img)
17
+
18
+ age, top = est.predict(img, topk=5)
19
+
20
+ # 1) dict[str, float] for Label
21
+ probs = {lbl: float(prob) for lbl, prob in top}
22
+
23
+ # 2) plain string for the estimate
24
+ summary = f"Estimated age: **{age:.1f}** years"
25
+
26
+ return probs, summary
27
+
28
+ demo = gr.Interface(
29
+ fn=predict,
30
+ inputs=gr.Image(type="pil", label="Upload a face image"),
31
+ outputs=[
32
+ gr.Label(num_top_classes=5, label="Age Prediction (probabilities)"),
33
+ gr.Markdown(label="Summary"),
34
+ ],
35
+ title="Pretrained Age Estimator",
36
+ description="Runs a pretrained ViT-based age classifier and reports a point estimate from class probabilities."
37
+ )
38
+
39
+ if __name__ == "__main__":
40
+ demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
hf_model.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # hf_model.py
2
+ # ---- keep TF/JAX silent and unused ----
3
+ import os
4
+ os.environ["TRANSFORMERS_NO_TF"] = "1" # don't load TensorFlow
5
+ os.environ["TRANSFORMERS_NO_FLAX"] = "1" # don't load Flax/JAX
6
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # silence TF C++ logs (INFO/WARN)
7
+
8
+ # optional: extra hush
9
+ try:
10
+ from transformers.utils import logging as hf_logging
11
+ hf_logging.set_verbosity_error()
12
+ import tensorflow as tf
13
+ tf.get_logger().setLevel("ERROR")
14
+ import absl.logging
15
+ absl.logging.set_verbosity(absl.logging.ERROR)
16
+ except Exception:
17
+ pass
18
+
19
+
20
+
21
+ from typing import Dict, List, Tuple
22
+ from PIL import Image
23
+ import torch
24
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
25
+
26
+
27
+
28
+ # A solid, ready-made age classifier (ViT backbone).
29
+ # It predicts age ranges as classes; we map ranges -> a point estimate.
30
+ HF_MODEL_ID = "nateraw/vit-age-classifier"
31
+
32
+ # Map model class labels to numeric midpoints for an age estimate
33
+ # (these labels come with the above model).
34
+ AGE_RANGE_TO_MID = {
35
+ "0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35,
36
+ "40-49": 45, "50-59": 55, "60-69": 65, "70+": 75
37
+ }
38
+
39
+ class PretrainedAgeEstimator:
40
+ def __init__(self, model_id: str = HF_MODEL_ID, device: str | None = None, use_fast: bool = True):
41
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
42
+ self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
43
+
44
+ self.model = AutoModelForImageClassification.from_pretrained(model_id)
45
+ self.model.to(self.device).eval()
46
+
47
+ # keep label mapping
48
+ self.id2label: Dict[int, str] = self.model.config.id2label
49
+ self.label2id: Dict[str, int] = self.model.config.label2id
50
+
51
+ @torch.inference_mode()
52
+ def predict(self, img: Image.Image, topk: int = 3) -> Tuple[float, List[Tuple[str, float]]]:
53
+ """Return (point_estimate_age, [(label, prob), ...])."""
54
+ if img.mode != "RGB":
55
+ img = img.convert("RGB")
56
+
57
+ inputs = self.processor(images=img, return_tensors="pt").to(self.device)
58
+ logits = self.model(**inputs).logits
59
+ probs = logits.softmax(dim=-1).squeeze(0) # [num_classes]
60
+
61
+ # Top-k labels
62
+ topk = min(topk, probs.numel())
63
+ values, indices = torch.topk(probs, k=topk)
64
+ top = [(self.id2label[i.item()], v.item()) for i, v in zip(indices, values)]
65
+
66
+ # Point estimate from expectation over range midpoints
67
+ expected = 0.0
68
+ for i, p in enumerate(probs):
69
+ label = self.id2label[i]
70
+ mid = AGE_RANGE_TO_MID.get(label, 35) # default mid if unseen label
71
+ expected += mid * p.item()
72
+
73
+ return expected, top
infer.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # infer.py
2
+ from PIL import Image
3
+ from hf_model import PretrainedAgeEstimator
4
+
5
+ def predict_age(image_path: str):
6
+ est = PretrainedAgeEstimator()
7
+ img = Image.open(image_path)
8
+ age, top = est.predict(img, topk=3)
9
+ return age, top
10
+
11
+ if __name__ == "__main__":
12
+ import sys, json
13
+ path = sys.argv[1]
14
+ age, top = predict_age(path)
15
+ print(json.dumps({
16
+ "estimated_age": round(age, 1),
17
+ "top_classes": [(lbl, round(prob, 4)) for lbl, prob in top]
18
+ }, indent=2))