hongyu12321 commited on
Commit
8fc4073
·
verified ·
1 Parent(s): aca8858

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -40
app.py CHANGED
@@ -1,40 +1,170 @@
1
- # app.py
2
- import os
3
- os.environ["TRANSFORMERS_NO_TF"] = "1"
4
- os.environ["TRANSFORMERS_NO_FLAX"] = "1"
5
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
6
-
7
- import gradio as gr
8
- from PIL import Image
9
- from hf_model import PretrainedAgeEstimator
10
-
11
- est = PretrainedAgeEstimator()
12
-
13
- def predict(img):
14
- # Gradio may pass PIL or numpy; handle both
15
- if not isinstance(img, Image.Image):
16
- img = Image.fromarray(img)
17
-
18
- age, top = est.predict(img, topk=5)
19
-
20
- # 1) dict[str, float] for Label
21
- probs = {lbl: float(prob) for lbl, prob in top}
22
-
23
- # 2) plain string for the estimate
24
- summary = f"Estimated age: **{age:.1f}** years"
25
-
26
- return probs, summary
27
-
28
- demo = gr.Interface(
29
- fn=predict,
30
- inputs=gr.Image(type="pil", label="Upload a face image"),
31
- outputs=[
32
- gr.Label(num_top_classes=5, label="Age Prediction (probabilities)"),
33
- gr.Markdown(label="Summary"),
34
- ],
35
- title="Pretrained Age Estimator",
36
- description="Runs a pretrained ViT-based age classifier and reports a point estimate from class probabilities."
37
- )
38
-
39
- if __name__ == "__main__":
40
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ os.environ["TRANSFORMERS_NO_TF"] = "1"
4
+ os.environ["TRANSFORMERS_NO_FLAX"] = "1"
5
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
6
+
7
+ import gradio as gr
8
+ from PIL import Image
9
+ import numpy as np
10
+ import torch
11
+
12
+ from hf_model import PretrainedAgeEstimator
13
+ from face_utils import FaceCropper
14
+
15
+ # NEW: diffusers for cartoonizer
16
+ from diffusers import StableDiffusionImg2ImgPipeline
17
+
18
+ # ---------- Load models once ----------
19
+ est = PretrainedAgeEstimator()
20
+ cropper = FaceCropper(device=est.device)
21
+
22
+ # A solid, public SD 1.5 img2img pipeline; fast and reliable
23
+ SD15_ID = "runwayml/stable-diffusion-v1-5"
24
+ sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
25
+ SD15_ID,
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
+ safety_checker=None, # rely on prompts; HF Spaces also has a global filter
28
+ ).to(est.device)
29
+
30
+ # ---------- Helpers ----------
31
+ def _ensure_pil(img):
32
+ if isinstance(img, Image.Image):
33
+ return img
34
+ return Image.fromarray(img)
35
+
36
+ # ----- Age: single image -----
37
+ def predict_single(img, auto_crop=True, topk=5, show_annot=True):
38
+ if img is None:
39
+ return {}, "No image provided.", None
40
+ img = _ensure_pil(img).convert("RGB")
41
+
42
+ preview = img
43
+ face = None
44
+ if auto_crop:
45
+ face, annotated, _ = cropper.detect_and_crop(img, select="largest")
46
+ preview = annotated if show_annot else img
47
+
48
+ target = face if face is not None else img
49
+ age, top = est.predict(target, topk=topk)
50
+
51
+ probs = {lbl: float(prob) for lbl, prob in top}
52
+ summary = f"**Estimated age:** {age:.1f} years"
53
+ return probs, summary, preview
54
+
55
+ # ----- Age: batch -----
56
+ def predict_batch(files, auto_crop=True, topk=5):
57
+ if not files:
58
+ return "No files uploaded."
59
+ rows = ["| File | Estimated Age | Top-1 | p |", "|---|---:|---|---:|"]
60
+ for f in files:
61
+ try:
62
+ img = Image.open(f.name).convert("RGB")
63
+ face = None
64
+ if auto_crop:
65
+ face, _, _ = cropper.detect_and_crop(img, select="largest")
66
+ target = face if face is not None else img
67
+ age, top = est.predict(target, topk=topk)
68
+ top1_lbl, top1_p = top[0]
69
+ rows.append(f"| {os.path.basename(f.name)} | {age:.1f} | {top1_lbl} | {top1_p:.3f} |")
70
+ except Exception:
71
+ rows.append(f"| {os.path.basename(f.name)} | (error) | - | - |")
72
+ return "\n".join(rows)
73
+
74
+ # ----- NEW: Cartoonizer (img2img) -----
75
+ def cartoonize(img, prompt, strength=0.6, guidance=7.5, steps=25, seed=0, use_face_crop=True):
76
+ """
77
+ img: PIL or numpy
78
+ prompt: text description, e.g. "cute cel-shaded cartoon, soft outlines, vibrant colors"
79
+ strength: how much to deviate from the input (0.3 subtle → 0.8 strong)
80
+ guidance: prompt strength (5–12 typical)
81
+ steps: diffusion steps (20–40 typical)
82
+ seed: reproducibility (-1 for random)
83
+ """
84
+ if img is None:
85
+ return None
86
+
87
+ img = _ensure_pil(img).convert("RGB")
88
+
89
+ # optional crop to the largest face for better identity preservation
90
+ if use_face_crop:
91
+ face, _, _ = cropper.detect_and_crop(img, select="largest")
92
+ if face is not None:
93
+ img = face
94
+
95
+ # cartoon-y defaults (you can tweak in UI)
96
+ base_prompt = (
97
+ "cartoon, cel-shaded, clean lineart, smooth shading, high contrast, vibrant, studio ghibli style, "
98
+ "pixar style, highly detailed, 2D illustration"
99
+ )
100
+ full_prompt = f"{base_prompt}, {prompt}".strip().strip(",")
101
+
102
+ generator = None
103
+ if seed and seed >= 0:
104
+ generator = torch.Generator(device=est.device).manual_seed(int(seed))
105
+
106
+ out = sd_pipe(
107
+ prompt=full_prompt,
108
+ image=img,
109
+ strength=float(strength),
110
+ guidance_scale=float(guidance),
111
+ num_inference_steps=int(steps),
112
+ generator=generator,
113
+ )
114
+ result = out.images[0]
115
+ return result
116
+
117
+ # ---------- UI ----------
118
+ with gr.Blocks(title="Pretrained Age Estimator + Cartoonizer") as demo:
119
+ gr.Markdown("# Pretrained Age Estimator + Cartoonizer")
120
+ gr.Markdown("Detects age from a face and can also generate a cartoonized image guided by your text description.")
121
+
122
+ with gr.Tabs():
123
+ with gr.Tab("Age (Single)"):
124
+ with gr.Row():
125
+ with gr.Column():
126
+ inp = gr.Image(type="pil", label="Upload a face image")
127
+ cam = gr.Image(sources=["webcam"], type="pil", label="Webcam (optional)")
128
+ auto = gr.Checkbox(True, label="Auto face crop (MTCNN)")
129
+ topk = gr.Slider(3, 9, value=5, step=1, label="Top-K age ranges")
130
+ annot = gr.Checkbox(True, label="Show detection preview")
131
+ btn = gr.Button("Predict Age", variant="primary")
132
+ with gr.Column():
133
+ out_label = gr.Label(num_top_classes=5, label="Age Prediction (probabilities)")
134
+ out_md = gr.Markdown(label="Summary")
135
+ out_prev = gr.Image(label="Preview", visible=True)
136
+
137
+ def run_single(img, cam_img, auto_crop, topk_val, show_annot):
138
+ chosen = cam_img if cam_img is not None else img
139
+ return predict_single(chosen, auto_crop, int(topk_val), show_annot)
140
+
141
+ btn.click(fn=run_single, inputs=[inp, cam, auto, topk, annot],
142
+ outputs=[out_label, out_md, out_prev])
143
+
144
+ with gr.Tab("Age (Batch)"):
145
+ files = gr.Files(label="Upload multiple images")
146
+ auto_b = gr.Checkbox(True, label="Auto face crop (MTCNN)")
147
+ topk_b = gr.Slider(3, 9, value=5, step=1, label="Top-K age ranges")
148
+ btn_b = gr.Button("Run batch")
149
+ out_table = gr.Markdown()
150
+ btn_b.click(fn=predict_batch, inputs=[files, auto_b, topk_b], outputs=out_table)
151
+
152
+ with gr.Tab("Cartoonizer"):
153
+ src = gr.Image(type="pil", label="Source image (face or any photo)")
154
+ prompt = gr.Textbox(label="Your style prompt",
155
+ value="cute cel-shaded cartoon, clean lines, soft colors")
156
+ with gr.Row():
157
+ strength = gr.Slider(0.2, 0.95, value=0.6, step=0.05, label="Transformation strength")
158
+ guidance = gr.Slider(3, 15, value=7.5, step=0.5, label="Guidance scale")
159
+ steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
160
+ seed = gr.Number(value=0, precision=0, label="Seed (0 or -1 = random)")
161
+ use_crop = gr.Checkbox(True, label="Crop to largest face before stylizing")
162
+ btn_c = gr.Button("Generate Cartoon", variant="primary")
163
+ out_img = gr.Image(label="Cartoon result")
164
+
165
+ btn_c.click(fn=cartoonize,
166
+ inputs=[src, prompt, strength, guidance, steps, seed, use_crop],
167
+ outputs=out_img)
168
+
169
+ if __name__ == "__main__":
170
+ demo.launch()