hongyu12321 commited on
Commit
aec1787
Β·
verified Β·
1 Parent(s): 7093eab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -80
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py β€” Age-first + FAST cartoon (Turbo) with prompt hint pickers (largest face only)
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
@@ -42,12 +42,12 @@ class PretrainedAgeEstimator:
42
  for i, p in enumerate(probs))
43
  return expected, top
44
 
45
- # ------------------ Face detection with WIDER crop (largest face) ------------------
46
  from facenet_pytorch import MTCNN
47
 
48
  class FaceCropper:
49
- """Detect faces; return (cropped_wide, annotated). Adds margin so face isn't full screen."""
50
- def __init__(self, device: Optional[str] = None, margin_scale: float = 1.8):
51
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
52
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
53
  self.margin_scale = margin_scale
@@ -64,9 +64,8 @@ class FaceCropper:
64
 
65
  annotated = pil.copy()
66
  draw = ImageDraw.Draw(annotated)
67
-
68
  if boxes is None or len(boxes) == 0:
69
- return None, annotated # no faces
70
 
71
  # draw all boxes
72
  for b, p in zip(boxes, probs):
@@ -74,10 +73,10 @@ class FaceCropper:
74
  draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
75
  draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(255, 0, 0))
76
 
77
- # choose largest face
78
  idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
79
  x1, y1, x2, y2 = boxes[idx]
80
- # expand with margin (4:5 portrait feel)
81
  cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
82
  w, h = (x2 - x1), (y2 - y1)
83
  side = max(w, h) * self.margin_scale
@@ -92,21 +91,19 @@ class FaceCropper:
92
  crop = pil.crop((nx1, ny1, nx2, ny2))
93
  return crop, annotated
94
 
95
- # ------------------ FAST Cartoonizer (SD-Turbo) with safety ------------------
96
  from diffusers import AutoPipelineForImage2Image
97
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
98
  from transformers import AutoFeatureExtractor
99
 
100
- # Turbo is very fast (1–4 steps). Great for stylization on CPU/GPU.
101
  TURBO_ID = "stabilityai/sd-turbo"
102
 
103
  def load_turbo_pipe(device):
104
- dtype = torch.float16 if (device == "cuda") else torch.float32
105
  pipe = AutoPipelineForImage2Image.from_pretrained(
106
  TURBO_ID,
107
- dtype=dtype, # βœ… use dtype (no deprecation warning)
108
- )
109
- pipe = pipe.to(device)
110
  # Safety checker ON for public Spaces
111
  pipe.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
112
  "CompVis/stable-diffusion-safety-checker"
@@ -122,10 +119,10 @@ def load_turbo_pipe(device):
122
 
123
  # ------------------ Init models once ------------------
124
  age_est = PretrainedAgeEstimator()
125
- cropper = FaceCropper(device=age_est.device, margin_scale=1.85) # 1.6–2.0 feels good
126
  sd_pipe = load_turbo_pipe(age_est.device)
127
 
128
- # ------------------ Prompt hint dictionaries ------------------
129
  ROLE_CHOICES = [
130
  "Queen/Princess", "King/Prince", "Fairy", "Elf", "Knight", "Sorcerer/Sorceress",
131
  "Steampunk Royalty", "Cyberpunk Royalty", "Superhero", "Anime Protagonist"
@@ -157,8 +154,7 @@ EFFECTS_CHOICES = [
157
  ]
158
 
159
  NEGATIVE_PROMPT = (
160
- "deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, "
161
- "blurry, watermark, text, logo"
162
  )
163
 
164
  # ------------------ Helpers ------------------
@@ -166,7 +162,6 @@ def _ensure_pil(img):
166
  return img if isinstance(img, Image.Image) else Image.fromarray(img)
167
 
168
  def _resize_512(im: Image.Image):
169
- # keep aspect, fit longest side to 512 (faster, fewer artifacts)
170
  w, h = im.size
171
  scale = 512 / max(w, h)
172
  if scale < 1.0:
@@ -174,8 +169,16 @@ def _resize_512(im: Image.Image):
174
  return im
175
 
176
  def build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra):
177
- bits = []
178
- # role to base descriptors
 
 
 
 
 
 
 
 
179
  role_map = {
180
  "Queen/Princess": "regal queen/princess portrait",
181
  "King/Prince": "regal king/prince portrait",
@@ -186,61 +189,51 @@ def build_prompt(role, background, lighting, artstyle, colors, outfit, effects,
186
  "Steampunk Royalty": "steampunk royal portrait with brass filigree",
187
  "Cyberpunk Royalty": "cyberpunk royal portrait with neon accents",
188
  "Superhero": "heroic comic-style portrait",
189
- "Anime Protagonist": "anime protagonist portrait"
190
  }
191
- if role:
192
- bits.append(role_map.get(role, role))
193
 
194
- # the hint pickers
195
  for group in (background, lighting, artstyle, colors, outfit, effects):
196
  if group and isinstance(group, list):
197
- bits.append(", ".join(group))
 
198
 
199
- # strong general quality/style anchors
200
- bits.append("clean lineart, storybook illustration, high quality")
201
-
202
- # extra user text
203
  extra = (extra or "").strip()
204
  if extra:
205
- bits.append(extra)
206
 
207
- # join
208
- return ", ".join([b for b in bits if b])
209
 
210
- # ------------------ 1) Predict Age (fast, largest face) ------------------
211
  @torch.inference_mode()
212
  def predict_age_only(img, auto_crop=True):
213
  if img is None:
214
  return {}, "Please upload an image.", None
215
- img = _ensure_pil(img).convert("RGB")
216
 
217
- face_wide = None
218
- annotated = None
219
  if auto_crop:
220
- face_wide, annotated = cropper.detect_and_crop_wide(img)
221
- target = face_wide if face_wide is not None else img
222
 
 
223
  age, top = age_est.predict(target, topk=5)
224
  probs = {lbl: float(p) for lbl, p in top}
225
  summary = f"**Estimated age:** {age:.1f} years"
226
- return probs, summary, (annotated if annotated is not None else img)
227
 
228
- # ------------------ 2) Generate Cartoon (fast, largest face) ------------------
229
  @torch.inference_mode()
230
  def generate_cartoon(img, role, background, lighting, artstyle, colors, outfit, effects,
231
  extra_desc, auto_crop=True, strength=0.5, steps=2, seed=-1):
232
  if img is None:
233
  return None
 
234
 
235
- img = _ensure_pil(img).convert("RGB")
236
  if auto_crop:
237
- face_wide, _ = cropper.detect_and_crop_wide(img)
238
  if face_wide is not None:
239
- img = face_wide
240
-
241
- img = _resize_512(img)
242
 
243
- # prompt assembly from pickers
244
  prompt = build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra_desc)
245
 
246
  generator = None
@@ -250,54 +243,51 @@ def generate_cartoon(img, role, background, lighting, artstyle, colors, outfit,
250
  out = sd_pipe(
251
  prompt=prompt,
252
  negative_prompt=NEGATIVE_PROMPT,
253
- image=img,
254
- strength=float(strength), # 0.4–0.6 keeps identity & adds dress/background
255
- guidance_scale=0.0, # Turbo commonly uses 0
256
- num_inference_steps=int(steps), # 1–4 steps β†’ very fast
257
  generator=generator,
258
  )
259
  return out.images[0]
260
 
261
- # ------------------ UI ------------------
262
- with gr.Blocks(title="Age First + Fast Cartoon (with Hint Pickers)") as demo:
263
- gr.Markdown("# Upload or capture once β€” get age prediction first, then a beautiful cartoon ✨")
264
- gr.Markdown("Largest face is used if multiple people are present.")
265
 
266
  with gr.Row():
267
  with gr.Column(scale=1):
268
  img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
269
- auto = gr.Checkbox(True, label="Auto face crop (wide, recommended)")
270
-
271
- # --- Age first
272
- btn_age = gr.Button("Predict Age (fast)", variant="primary")
273
-
274
- gr.Markdown("### Cartoon Description Hints")
275
- role = gr.Dropdown(choices=ROLE_CHOICES, value="Queen/Princess", label="Role")
276
- background = gr.CheckboxGroup(choices=BACKGROUND_CHOICES, label="Background")
277
- lighting = gr.CheckboxGroup(choices=LIGHTING_CHOICES, label="Lighting")
278
- artstyle = gr.CheckboxGroup(choices=ARTSTYLE_CHOICES, label="Art Style")
279
- colors = gr.CheckboxGroup(choices=COLOR_CHOICES, label="Color Mood")
280
- outfit = gr.CheckboxGroup(choices=OUTFIT_CHOICES, label="Outfit / Accessories")
281
- effects = gr.CheckboxGroup(choices=EFFECTS_CHOICES, label="Magical Effects")
282
- extra = gr.Textbox(
283
- label="Extra description (optional)",
284
- placeholder="e.g., silver tiara, flowing gown, castle balcony at sunset"
285
- )
286
 
 
287
  with gr.Row():
288
- strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
289
- steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
290
- seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
291
-
292
- btn_cartoon = gr.Button("Make Cartoon (fast)", variant="secondary")
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  with gr.Column(scale=1):
295
- probs_out = gr.Label(num_top_classes=5, label="Age Prediction (probabilities)")
296
  age_md = gr.Markdown(label="Age Summary")
297
  preview = gr.Image(label="Detection Preview")
298
  cartoon_out = gr.Image(label="Cartoon Result")
299
 
300
- # Wire the buttons
301
  btn_age.click(fn=predict_age_only, inputs=[img_in, auto], outputs=[probs_out, age_md, preview])
302
  btn_cartoon.click(
303
  fn=generate_cartoon,
@@ -306,7 +296,7 @@ with gr.Blocks(title="Age First + Fast Cartoon (with Hint Pickers)") as demo:
306
  outputs=cartoon_out
307
  )
308
 
309
- # Expose app for HF Spaces
310
  app = demo
311
 
312
  if __name__ == "__main__":
 
1
+ # app.py β€” Compact UI: Age-first + FAST cartoon (Turbo) with collapsible advanced options
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
 
42
  for i, p in enumerate(probs))
43
  return expected, top
44
 
45
+ # ------------------ Largest-face detector with nice margin ------------------
46
  from facenet_pytorch import MTCNN
47
 
48
  class FaceCropper:
49
+ """Detect faces; return (wide_crop, annotated). Largest face only; adds margin so face isn't full screen."""
50
+ def __init__(self, device: Optional[str] = None, margin_scale: float = 1.85):
51
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
52
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
53
  self.margin_scale = margin_scale
 
64
 
65
  annotated = pil.copy()
66
  draw = ImageDraw.Draw(annotated)
 
67
  if boxes is None or len(boxes) == 0:
68
+ return None, annotated
69
 
70
  # draw all boxes
71
  for b, p in zip(boxes, probs):
 
73
  draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
74
  draw.text((bx1, max(0, by1-12)), f"{p:.2f}", fill=(255, 0, 0))
75
 
76
+ # choose largest
77
  idx = int(np.argmax([(b[2]-b[0])*(b[3]-b[1]) for b in boxes]))
78
  x1, y1, x2, y2 = boxes[idx]
79
+ # expand with margin (approx 4:5 portrait)
80
  cx, cy = (x1 + x2) / 2.0, (y1 + y2) / 2.0
81
  w, h = (x2 - x1), (y2 - y1)
82
  side = max(w, h) * self.margin_scale
 
91
  crop = pil.crop((nx1, ny1, nx2, ny2))
92
  return crop, annotated
93
 
94
+ # ------------------ Fast Cartoonizer (SD-Turbo) with safety ------------------
95
  from diffusers import AutoPipelineForImage2Image
96
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
97
  from transformers import AutoFeatureExtractor
98
 
 
99
  TURBO_ID = "stabilityai/sd-turbo"
100
 
101
  def load_turbo_pipe(device):
102
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
103
  pipe = AutoPipelineForImage2Image.from_pretrained(
104
  TURBO_ID,
105
+ dtype=dtype, # βœ… no deprecation warning
106
+ ).to(device)
 
107
  # Safety checker ON for public Spaces
108
  pipe.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
109
  "CompVis/stable-diffusion-safety-checker"
 
119
 
120
  # ------------------ Init models once ------------------
121
  age_est = PretrainedAgeEstimator()
122
+ cropper = FaceCropper(device=age_est.device, margin_scale=1.85)
123
  sd_pipe = load_turbo_pipe(age_est.device)
124
 
125
+ # ------------------ Hint choices (with defaults) ------------------
126
  ROLE_CHOICES = [
127
  "Queen/Princess", "King/Prince", "Fairy", "Elf", "Knight", "Sorcerer/Sorceress",
128
  "Steampunk Royalty", "Cyberpunk Royalty", "Superhero", "Anime Protagonist"
 
154
  ]
155
 
156
  NEGATIVE_PROMPT = (
157
+ "deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, blurry, watermark, text, logo"
 
158
  )
159
 
160
  # ------------------ Helpers ------------------
 
162
  return img if isinstance(img, Image.Image) else Image.fromarray(img)
163
 
164
  def _resize_512(im: Image.Image):
 
165
  w, h = im.size
166
  scale = 512 / max(w, h)
167
  if scale < 1.0:
 
169
  return im
170
 
171
  def build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra):
172
+ """Defaults always exist; user selections override them."""
173
+ # Defaults (applied if user doesn't choose)
174
+ role = role or "Queen/Princess"
175
+ background = background or ["castle balcony at sunset"]
176
+ lighting = lighting or ["soft magical lighting"]
177
+ artstyle = artstyle or ["storybook illustration"]
178
+ colors = colors or ["vibrant colors"]
179
+ outfit = outfit or ["elegant gown", "jeweled tiara/crown"]
180
+ effects = effects or ["sparkles", "glowing particles"]
181
+
182
  role_map = {
183
  "Queen/Princess": "regal queen/princess portrait",
184
  "King/Prince": "regal king/prince portrait",
 
189
  "Steampunk Royalty": "steampunk royal portrait with brass filigree",
190
  "Cyberpunk Royalty": "cyberpunk royal portrait with neon accents",
191
  "Superhero": "heroic comic-style portrait",
192
+ "Anime Protagonist": "anime protagonist portrait",
193
  }
 
 
194
 
195
+ parts = [role_map.get(role, role)]
196
  for group in (background, lighting, artstyle, colors, outfit, effects):
197
  if group and isinstance(group, list):
198
+ parts.append(", ".join(group))
199
+ parts.append("clean lineart, high quality")
200
 
 
 
 
 
201
  extra = (extra or "").strip()
202
  if extra:
203
+ parts.append(extra)
204
 
205
+ return ", ".join([p for p in parts if p])
 
206
 
207
+ # ------------------ Actions ------------------
208
  @torch.inference_mode()
209
  def predict_age_only(img, auto_crop=True):
210
  if img is None:
211
  return {}, "Please upload an image.", None
212
+ pil = _ensure_pil(img).convert("RGB")
213
 
214
+ face_wide, annotated = (None, None)
 
215
  if auto_crop:
216
+ face_wide, annotated = cropper.detect_and_crop_wide(pil)
 
217
 
218
+ target = face_wide if face_wide is not None else pil
219
  age, top = age_est.predict(target, topk=5)
220
  probs = {lbl: float(p) for lbl, p in top}
221
  summary = f"**Estimated age:** {age:.1f} years"
222
+ return probs, summary, (annotated if annotated is not None else pil)
223
 
 
224
  @torch.inference_mode()
225
  def generate_cartoon(img, role, background, lighting, artstyle, colors, outfit, effects,
226
  extra_desc, auto_crop=True, strength=0.5, steps=2, seed=-1):
227
  if img is None:
228
  return None
229
+ pil = _ensure_pil(img).convert("RGB")
230
 
 
231
  if auto_crop:
232
+ face_wide, _ = cropper.detect_and_crop_wide(pil)
233
  if face_wide is not None:
234
+ pil = face_wide
 
 
235
 
236
+ pil = _resize_512(pil)
237
  prompt = build_prompt(role, background, lighting, artstyle, colors, outfit, effects, extra_desc)
238
 
239
  generator = None
 
243
  out = sd_pipe(
244
  prompt=prompt,
245
  negative_prompt=NEGATIVE_PROMPT,
246
+ image=pil,
247
+ strength=float(strength), # 0.4–0.6 keeps identity & adds dress/background
248
+ guidance_scale=0.0, # Turbo likes 0
249
+ num_inference_steps=int(steps),# 1–4 β†’ fast
250
  generator=generator,
251
  )
252
  return out.images[0]
253
 
254
+ # ------------------ Compact UI ------------------
255
+ with gr.Blocks(title="Age + Cartoon (Compact)") as demo:
256
+ gr.Markdown("## Upload β†’ Predict Age β†’ Make Cartoon ✨")
257
+ gr.Markdown("Largest face is used if multiple people are present. Defaults are applied automatically.")
258
 
259
  with gr.Row():
260
  with gr.Column(scale=1):
261
  img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
262
+ auto = gr.Checkbox(True, label="Auto face crop (recommended)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
+ # Buttons visible immediately (no scrolling)
265
  with gr.Row():
266
+ btn_age = gr.Button("Predict Age", variant="primary")
267
+ btn_cartoon = gr.Button("Make Cartoon", variant="secondary")
268
+
269
+ # Collapsible advanced options
270
+ with gr.Accordion("🎨 Advanced Cartoon Options", open=False):
271
+ role = gr.Dropdown(choices=ROLE_CHOICES, value="Queen/Princess", label="Role")
272
+ background = gr.CheckboxGroup(choices=BACKGROUND_CHOICES, value=["castle balcony at sunset"], label="Background")
273
+ lighting = gr.CheckboxGroup(choices=LIGHTING_CHOICES, value=["soft magical lighting"], label="Lighting")
274
+ artstyle = gr.CheckboxGroup(choices=ARTSTYLE_CHOICES, value=["storybook illustration"], label="Art Style")
275
+ colors = gr.CheckboxGroup(choices=COLOR_CHOICES, value=["vibrant colors"], label="Color Mood")
276
+ outfit = gr.CheckboxGroup(choices=OUTFIT_CHOICES, value=["elegant gown", "jeweled tiara/crown"], label="Outfit / Accessories")
277
+ effects = gr.CheckboxGroup(choices=EFFECTS_CHOICES, value=["sparkles", "glowing particles"], label="Magical Effects")
278
+ extra = gr.Textbox(label="Extra description (optional)", placeholder="e.g., silver tiara, flowing gown, balcony at sunset")
279
+ with gr.Row():
280
+ strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
281
+ steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
282
+ seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
283
 
284
  with gr.Column(scale=1):
285
+ probs_out = gr.Label(num_top_classes=5, label="Age Prediction")
286
  age_md = gr.Markdown(label="Age Summary")
287
  preview = gr.Image(label="Detection Preview")
288
  cartoon_out = gr.Image(label="Cartoon Result")
289
 
290
+ # Wire events
291
  btn_age.click(fn=predict_age_only, inputs=[img_in, auto], outputs=[probs_out, age_md, preview])
292
  btn_cartoon.click(
293
  fn=generate_cartoon,
 
296
  outputs=cartoon_out
297
  )
298
 
299
+ # Expose for HF Spaces
300
  app = demo
301
 
302
  if __name__ == "__main__":