hongyu12321 commited on
Commit
97ae321
·
verified ·
1 Parent(s): 26f7527

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -135
app.py CHANGED
@@ -1,170 +1,192 @@
1
- # app.py
 
 
2
  import os
3
  os.environ["TRANSFORMERS_NO_TF"] = "1"
4
  os.environ["TRANSFORMERS_NO_FLAX"] = "1"
5
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
6
 
7
  import gradio as gr
8
- from PIL import Image
9
  import numpy as np
10
  import torch
11
 
12
- from hf_model import PretrainedAgeEstimator
13
- from face_utils import FaceCropper
14
-
15
- # NEW: diffusers for cartoonizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from diffusers import StableDiffusionImg2ImgPipeline
17
 
18
- # ---------- Load models once ----------
19
- est = PretrainedAgeEstimator()
20
- cropper = FaceCropper(device=est.device)
21
-
22
- # A solid, public SD 1.5 img2img pipeline; fast and reliable
23
  SD15_ID = "runwayml/stable-diffusion-v1-5"
24
- sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
25
- SD15_ID,
26
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
- safety_checker=None, # rely on prompts; HF Spaces also has a global filter
28
- ).to(est.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # ---------- Helpers ----------
31
  def _ensure_pil(img):
32
- if isinstance(img, Image.Image):
33
- return img
34
- return Image.fromarray(img)
35
 
36
- # ----- Age: single image -----
37
- def predict_single(img, auto_crop=True, topk=5, show_annot=True):
38
  if img is None:
39
- return {}, "No image provided.", None
 
40
  img = _ensure_pil(img).convert("RGB")
41
 
42
- preview = img
43
  face = None
 
44
  if auto_crop:
45
- face, annotated, _ = cropper.detect_and_crop(img, select="largest")
46
- preview = annotated if show_annot else img
47
-
48
- target = face if face is not None else img
49
- age, top = est.predict(target, topk=topk)
50
 
51
- probs = {lbl: float(prob) for lbl, prob in top}
 
 
 
52
  summary = f"**Estimated age:** {age:.1f} years"
53
- return probs, summary, preview
54
-
55
- # ----- Age: batch -----
56
- def predict_batch(files, auto_crop=True, topk=5):
57
- if not files:
58
- return "No files uploaded."
59
- rows = ["| File | Estimated Age | Top-1 | p |", "|---|---:|---|---:|"]
60
- for f in files:
61
- try:
62
- img = Image.open(f.name).convert("RGB")
63
- face = None
64
- if auto_crop:
65
- face, _, _ = cropper.detect_and_crop(img, select="largest")
66
- target = face if face is not None else img
67
- age, top = est.predict(target, topk=topk)
68
- top1_lbl, top1_p = top[0]
69
- rows.append(f"| {os.path.basename(f.name)} | {age:.1f} | {top1_lbl} | {top1_p:.3f} |")
70
- except Exception:
71
- rows.append(f"| {os.path.basename(f.name)} | (error) | - | - |")
72
- return "\n".join(rows)
73
-
74
- # ----- NEW: Cartoonizer (img2img) -----
75
- def cartoonize(img, prompt, strength=0.6, guidance=7.5, steps=25, seed=0, use_face_crop=True):
76
- """
77
- img: PIL or numpy
78
- prompt: text description, e.g. "cute cel-shaded cartoon, soft outlines, vibrant colors"
79
- strength: how much to deviate from the input (0.3 subtle → 0.8 strong)
80
- guidance: prompt strength (5–12 typical)
81
- steps: diffusion steps (20–40 typical)
82
- seed: reproducibility (-1 for random)
83
- """
84
- if img is None:
85
- return None
86
-
87
- img = _ensure_pil(img).convert("RGB")
88
-
89
- # optional crop to the largest face for better identity preservation
90
- if use_face_crop:
91
- face, _, _ = cropper.detect_and_crop(img, select="largest")
92
- if face is not None:
93
- img = face
94
 
95
- # cartoon-y defaults (you can tweak in UI)
96
- base_prompt = (
97
- "cartoon, cel-shaded, clean lineart, smooth shading, high contrast, vibrant, studio ghibli style, "
98
- "pixar style, highly detailed, 2D illustration"
99
- )
100
- full_prompt = f"{base_prompt}, {prompt}".strip().strip(",")
101
 
102
  generator = None
103
- if seed and seed >= 0:
104
- generator = torch.Generator(device=est.device).manual_seed(int(seed))
105
 
 
106
  out = sd_pipe(
107
- prompt=full_prompt,
108
- image=img,
109
- strength=float(strength),
110
- guidance_scale=float(guidance),
111
  num_inference_steps=int(steps),
112
  generator=generator,
113
  )
114
- result = out.images[0]
115
- return result
116
-
117
- # ---------- UI ----------
118
- with gr.Blocks(title="Pretrained Age Estimator + Cartoonizer") as demo:
119
- gr.Markdown("# Pretrained Age Estimator + Cartoonizer")
120
- gr.Markdown("Detects age from a face and can also generate a cartoonized image guided by your text description.")
121
-
122
- with gr.Tabs():
123
- with gr.Tab("Age (Single)"):
124
- with gr.Row():
125
- with gr.Column():
126
- inp = gr.Image(type="pil", label="Upload a face image")
127
- cam = gr.Image(sources=["webcam"], type="pil", label="Webcam (optional)")
128
- auto = gr.Checkbox(True, label="Auto face crop (MTCNN)")
129
- topk = gr.Slider(3, 9, value=5, step=1, label="Top-K age ranges")
130
- annot = gr.Checkbox(True, label="Show detection preview")
131
- btn = gr.Button("Predict Age", variant="primary")
132
- with gr.Column():
133
- out_label = gr.Label(num_top_classes=5, label="Age Prediction (probabilities)")
134
- out_md = gr.Markdown(label="Summary")
135
- out_prev = gr.Image(label="Preview", visible=True)
136
-
137
- def run_single(img, cam_img, auto_crop, topk_val, show_annot):
138
- chosen = cam_img if cam_img is not None else img
139
- return predict_single(chosen, auto_crop, int(topk_val), show_annot)
140
-
141
- btn.click(fn=run_single, inputs=[inp, cam, auto, topk, annot],
142
- outputs=[out_label, out_md, out_prev])
143
-
144
- with gr.Tab("Age (Batch)"):
145
- files = gr.Files(label="Upload multiple images")
146
- auto_b = gr.Checkbox(True, label="Auto face crop (MTCNN)")
147
- topk_b = gr.Slider(3, 9, value=5, step=1, label="Top-K age ranges")
148
- btn_b = gr.Button("Run batch")
149
- out_table = gr.Markdown()
150
- btn_b.click(fn=predict_batch, inputs=[files, auto_b, topk_b], outputs=out_table)
151
-
152
- with gr.Tab("Cartoonizer"):
153
- src = gr.Image(type="pil", label="Source image (face or any photo)")
154
- prompt = gr.Textbox(label="Your style prompt",
155
- value="cute cel-shaded cartoon, clean lines, soft colors")
156
  with gr.Row():
157
- strength = gr.Slider(0.2, 0.95, value=0.6, step=0.05, label="Transformation strength")
158
- guidance = gr.Slider(3, 15, value=7.5, step=0.5, label="Guidance scale")
159
  steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
160
- seed = gr.Number(value=0, precision=0, label="Seed (0 or -1 = random)")
161
- use_crop = gr.Checkbox(True, label="Crop to largest face before stylizing")
162
- btn_c = gr.Button("Generate Cartoon", variant="primary")
163
- out_img = gr.Image(label="Cartoon result")
164
-
165
- btn_c.click(fn=cartoonize,
166
- inputs=[src, prompt, strength, guidance, steps, seed, use_crop],
167
- outputs=out_img)
 
 
 
168
 
169
  if __name__ == "__main__":
170
  demo.launch()
 
1
+ # app.py — One-page Age + Cartoon app (no extra modules needed)
2
+
3
+ # Quiet TF/Flax logs (PyTorch-only)
4
  import os
5
  os.environ["TRANSFORMERS_NO_TF"] = "1"
6
  os.environ["TRANSFORMERS_NO_FLAX"] = "1"
7
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
8
 
9
  import gradio as gr
10
+ from PIL import Image, ImageDraw
11
  import numpy as np
12
  import torch
13
 
14
+ # ---------------------------
15
+ # 1) Pretrained Age Estimator
16
+ # ---------------------------
17
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
18
+
19
+ HF_MODEL_ID = "nateraw/vit-age-classifier"
20
+ AGE_RANGE_TO_MID = {
21
+ "0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35,
22
+ "40-49": 45, "50-59": 55, "60-69": 65, "70+": 75
23
+ }
24
+
25
+ class PretrainedAgeEstimator:
26
+ def __init__(self, model_id: str = HF_MODEL_ID, device: str | None = None):
27
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
28
+ self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
29
+ self.model = AutoModelForImageClassification.from_pretrained(model_id)
30
+ self.model.to(self.device).eval()
31
+ self.id2label = self.model.config.id2label
32
+
33
+ @torch.inference_mode()
34
+ def predict(self, img: Image.Image, topk: int = 5):
35
+ if img.mode != "RGB":
36
+ img = img.convert("RGB")
37
+ inputs = self.processor(images=img, return_tensors="pt").to(self.device)
38
+ logits = self.model(**inputs).logits
39
+ probs = logits.softmax(dim=-1).squeeze(0)
40
+ k = min(topk, probs.numel())
41
+ values, indices = torch.topk(probs, k=k)
42
+ top = [(self.id2label[i.item()], float(v.item())) for i, v in zip(indices, values)]
43
+ expected = sum(AGE_RANGE_TO_MID.get(self.id2label[i], 35) * float(p)
44
+ for i, p in enumerate(probs))
45
+ return expected, top
46
+
47
+ # ---------------------------
48
+ # 2) Face detector / cropper (MTCNN)
49
+ # ---------------------------
50
+ from facenet_pytorch import MTCNN
51
+ class FaceCropper:
52
+ """Detect faces and return (cropped_face, annotated_image)."""
53
+ def __init__(self, device: str | None = None):
54
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
55
+ self.mtcnn = MTCNN(keep_all=True, device=self.device)
56
+
57
+ def _ensure_pil(self, img):
58
+ if isinstance(img, Image.Image):
59
+ return img.convert("RGB")
60
+ return Image.fromarray(img).convert("RGB")
61
+
62
+ def detect_and_crop(self, img, select="largest"):
63
+ pil = self._ensure_pil(img)
64
+ boxes, probs = self.mtcnn.detect(pil)
65
+
66
+ annotated = pil.copy()
67
+ draw = ImageDraw.Draw(annotated)
68
+
69
+ if boxes is None or len(boxes) == 0:
70
+ return None, annotated
71
+
72
+ # draw boxes
73
+ for b, p in zip(boxes, probs):
74
+ x1, y1, x2, y2 = map(float, b)
75
+ draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0), width=3)
76
+ draw.text((x1, max(0, y1-12)), f"{p:.2f}", fill=(255, 0, 0))
77
+
78
+ # choose largest by area
79
+ idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
80
+ if isinstance(select, int) and 0 <= select < len(boxes):
81
+ idx = select
82
+ x1, y1, x2, y2 = boxes[idx].astype(int)
83
+ face = pil.crop((x1, y1, x2, y2))
84
+ return face, annotated
85
+
86
+ # ---------------------------
87
+ # 3) Cartoonizer (Stable Diffusion img2img)
88
+ # ---------------------------
89
  from diffusers import StableDiffusionImg2ImgPipeline
90
 
 
 
 
 
 
91
  SD15_ID = "runwayml/stable-diffusion-v1-5"
92
+ def load_sd_pipe(device):
93
+ dtype = torch.float16 if (device == "cuda") else torch.float32
94
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
95
+ SD15_ID,
96
+ torch_dtype=dtype,
97
+ safety_checker=None, # rely on prompts; HF has global content filters
98
+ )
99
+ return pipe.to(device)
100
+
101
+ # ---------------------------
102
+ # 4) Initialize models once
103
+ # ---------------------------
104
+ age_est = PretrainedAgeEstimator()
105
+ cropper = FaceCropper(device=age_est.device)
106
+ sd_pipe = load_sd_pipe(age_est.device)
107
+
108
+ # ---------------------------
109
+ # 5) App logic (one click does both)
110
+ # ---------------------------
111
+ DEFAULT_PROMPT = (
112
+ "cartoon, cel-shaded, clean lineart, smooth shading, vibrant colors, "
113
+ "studio ghibli style, pixar style, 2D illustration, high quality"
114
+ )
115
 
 
116
  def _ensure_pil(img):
117
+ return img if isinstance(img, Image.Image) else Image.fromarray(img)
 
 
118
 
119
+ @torch.inference_mode()
120
+ def run_all(img, prompt, auto_crop=True, strength=0.6, guidance=7.5, steps=25, seed=-1):
121
  if img is None:
122
+ return {}, "Please upload an image.", None
123
+
124
  img = _ensure_pil(img).convert("RGB")
125
 
126
+ # ---- choose region for both age + cartoon ----
127
  face = None
128
+ annotated = None
129
  if auto_crop:
130
+ face, annotated = cropper.detect_and_crop(img, select="largest")
 
 
 
 
131
 
132
+ target_for_age = face if face is not None else img
133
+ # Age prediction
134
+ age, top = age_est.predict(target_for_age, topk=5)
135
+ probs = {lbl: float(p) for lbl, p in top}
136
  summary = f"**Estimated age:** {age:.1f} years"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # Cartoon generation
139
+ txt = (prompt or "").strip()
140
+ if not txt:
141
+ txt = DEFAULT_PROMPT
142
+ else:
143
+ txt = f"{DEFAULT_PROMPT}, {txt}"
144
 
145
  generator = None
146
+ if isinstance(seed, (int, float)) and int(seed) >= 0:
147
+ generator = torch.Generator(device=age_est.device).manual_seed(int(seed))
148
 
149
+ base_img = face if face is not None else img
150
  out = sd_pipe(
151
+ prompt=txt,
152
+ image=base_img,
153
+ strength=float(strength), # 0.3 subtle → 0.8 strong
154
+ guidance_scale=float(guidance), # 5–12 typical
155
  num_inference_steps=int(steps),
156
  generator=generator,
157
  )
158
+ cartoon = out.images[0]
159
+ return probs, summary, cartoon
160
+
161
+ # ---------------------------
162
+ # 6) Gradio UI (single page)
163
+ # ---------------------------
164
+ with gr.Blocks(title="Age + Cartoon (One Page)") as demo:
165
+ gr.Markdown("# Age Estimator + Cartoonizer")
166
+ gr.Markdown("Upload or capture once — get **age prediction** and a **cartoon** of the same image.")
167
+
168
+ with gr.Row():
169
+ with gr.Column(scale=1):
170
+ img_in = gr.Image(sources=["upload", "webcam"], type="pil",
171
+ label="Upload / Webcam")
172
+ prompt = gr.Textbox(label="(Optional) Extra cartoon style",
173
+ placeholder="e.g., comic-book halftone, bold lines, neon palette")
174
+ auto = gr.Checkbox(True, label="Auto face crop (recommended)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  with gr.Row():
176
+ strength = gr.Slider(0.2, 0.95, value=0.6, step=0.05, label="Cartoon strength")
177
+ guidance = gr.Slider(3, 15, value=7.5, step=0.5, label="Guidance")
178
  steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
179
+ seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
180
+ go = gr.Button("Predict Age + Generate Cartoon", variant="primary", size="lg")
181
+
182
+ with gr.Column(scale=1):
183
+ probs_out = gr.Label(num_top_classes=5, label="Age Prediction (probabilities)")
184
+ age_md = gr.Markdown(label="Age Summary")
185
+ cartoon_out = gr.Image(label="Cartoon Result")
186
+
187
+ go.click(fn=run_all,
188
+ inputs=[img_in, prompt, auto, strength, guidance, steps, seed],
189
+ outputs=[probs_out, age_md, cartoon_out])
190
 
191
  if __name__ == "__main__":
192
  demo.launch()