hongyu12321 commited on
Commit
6e327e0
·
verified ·
1 Parent(s): 2187ded

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -10
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py — Age-first + FAST group cartoons (SD-Turbo), single page
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
@@ -6,6 +6,8 @@ os.environ["TRANSFORMERS_NO_FLAX"] = "1"
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
 
8
  import math
 
 
9
  import gradio as gr
10
  from PIL import Image, ImageDraw
11
  import numpy as np
@@ -21,7 +23,7 @@ AGE_RANGE_TO_MID = {
21
  }
22
 
23
  class PretrainedAgeEstimator:
24
- def __init__(self, model_id: str = HF_MODEL_ID, device: str | None = None):
25
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
26
  self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
27
  self.model = AutoModelForImageClassification.from_pretrained(model_id)
@@ -52,7 +54,7 @@ class FaceCropper:
52
  - detect_all_wide: returns (list[crops], annotated, list[boxes])
53
  Boxes are (x1,y1,x2,y2) floats.
54
  """
55
- def __init__(self, device: str | None = None, margin_scale: float = 1.8):
56
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
57
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
58
  self.margin_scale = margin_scale
@@ -62,13 +64,13 @@ class FaceCropper:
62
  return img.convert("RGB")
63
  return Image.fromarray(img).convert("RGB")
64
 
65
- def _expand_box(self, box, W, H, aspect=0.8): # 4:5 portrait (w/h=0.8)
66
  x1, y1, x2, y2 = box
67
  cx, cy = (x1 + x2)/2, (y1 + y2)/2
68
  w, h = (x2 - x1), (y2 - y1)
69
  side = max(w, h) * self.margin_scale
70
  tw = side
71
- th = side / aspect # make it taller than wide
72
  nx1 = int(max(0, cx - tw/2)); nx2 = int(min(W, cx + tw/2))
73
  ny1 = int(max(0, cy - th/2)); ny2 = int(min(H, cy + th/2))
74
  return nx1, ny1, nx2, ny2
@@ -108,6 +110,7 @@ class FaceCropper:
108
  if boxes is None or len(boxes) == 0:
109
  return crops, annotated, []
110
 
 
111
  for b, p in sorted(zip(boxes, probs), key=lambda x: (x[0][0]+x[0][2])/2):
112
  bx1, by1, bx2, by2 = map(float, b)
113
  draw.rectangle([bx1, by1, bx2, by2], outline=(0, 200, 255), width=3)
@@ -121,15 +124,24 @@ class FaceCropper:
121
 
122
  # ------------------ FAST Cartoonizer (SD-Turbo) ------------------
123
  from diffusers import AutoPipelineForImage2Image
 
 
 
124
  TURBO_ID = "stabilityai/sd-turbo"
125
 
126
  def load_turbo_pipe(device):
127
  dtype = torch.float16 if (device == "cuda") else torch.float32
128
  pipe = AutoPipelineForImage2Image.from_pretrained(
129
  TURBO_ID,
130
- torch_dtype=dtype,
131
- safety_checker=None,
132
  ).to(device)
 
 
 
 
 
 
 
133
  try:
134
  pipe.enable_attention_slicing()
135
  except Exception:
@@ -186,7 +198,7 @@ def predict_age(img, group_mode=False, auto_crop=True):
186
  top1, p1 = top[0]
187
  rows.append(f"| {i} | {age:.1f} | {top1} | {p1:.2f} |")
188
  md = "\n".join(rows)
189
- # also return a simple dict from the largest (first) face just to feed Label
190
  age0, top0 = age_est.predict(crops[0], topk=5)
191
  probs0 = {lbl: float(p) for lbl, p in top0}
192
  return probs0, md, annotated
@@ -222,7 +234,69 @@ def cartoonize(img, prompt="", group_mode=False, auto_crop=True, strength=0.5, s
222
  if not crops:
223
  crops = [pil] # fallback
224
 
225
- # resize each to 384 for speed/variety
226
  proc = []
227
  for c in crops:
228
- c = _resiz_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — Age-first + FAST group cartoons (SD-Turbo), single page (HF Spaces safe)
2
 
3
  import os
4
  os.environ["TRANSFORMERS_NO_TF"] = "1"
 
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
7
 
8
  import math
9
+ from typing import Optional
10
+
11
  import gradio as gr
12
  from PIL import Image, ImageDraw
13
  import numpy as np
 
23
  }
24
 
25
  class PretrainedAgeEstimator:
26
+ def __init__(self, model_id: str = HF_MODEL_ID, device: Optional[str] = None):
27
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
28
  self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
29
  self.model = AutoModelForImageClassification.from_pretrained(model_id)
 
54
  - detect_all_wide: returns (list[crops], annotated, list[boxes])
55
  Boxes are (x1,y1,x2,y2) floats.
56
  """
57
+ def __init__(self, device: Optional[str] = None, margin_scale: float = 1.8):
58
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
59
  self.mtcnn = MTCNN(keep_all=True, device=self.device)
60
  self.margin_scale = margin_scale
 
64
  return img.convert("RGB")
65
  return Image.fromarray(img).convert("RGB")
66
 
67
+ def _expand_box(self, box, W, H, aspect=0.8): # ~4:5 portrait (w/h=0.8)
68
  x1, y1, x2, y2 = box
69
  cx, cy = (x1 + x2)/2, (y1 + y2)/2
70
  w, h = (x2 - x1), (y2 - y1)
71
  side = max(w, h) * self.margin_scale
72
  tw = side
73
+ th = side / aspect # taller than wide
74
  nx1 = int(max(0, cx - tw/2)); nx2 = int(min(W, cx + tw/2))
75
  ny1 = int(max(0, cy - th/2)); ny2 = int(min(H, cy + th/2))
76
  return nx1, ny1, nx2, ny2
 
110
  if boxes is None or len(boxes) == 0:
111
  return crops, annotated, []
112
 
113
+ # sort roughly left->right for table order
114
  for b, p in sorted(zip(boxes, probs), key=lambda x: (x[0][0]+x[0][2])/2):
115
  bx1, by1, bx2, by2 = map(float, b)
116
  draw.rectangle([bx1, by1, bx2, by2], outline=(0, 200, 255), width=3)
 
124
 
125
  # ------------------ FAST Cartoonizer (SD-Turbo) ------------------
126
  from diffusers import AutoPipelineForImage2Image
127
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
128
+ from transformers import AutoFeatureExtractor
129
+
130
  TURBO_ID = "stabilityai/sd-turbo"
131
 
132
  def load_turbo_pipe(device):
133
  dtype = torch.float16 if (device == "cuda") else torch.float32
134
  pipe = AutoPipelineForImage2Image.from_pretrained(
135
  TURBO_ID,
136
+ dtype=dtype, # ✅ no deprecation warning
 
137
  ).to(device)
138
+ # safety checker ON for public Spaces
139
+ pipe.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
140
+ "CompVis/stable-diffusion-safety-checker"
141
+ )
142
+ pipe.feature_extractor = AutoFeatureExtractor.from_pretrained(
143
+ "CompVis/stable-diffusion-safety-checker"
144
+ )
145
  try:
146
  pipe.enable_attention_slicing()
147
  except Exception:
 
198
  top1, p1 = top[0]
199
  rows.append(f"| {i} | {age:.1f} | {top1} | {p1:.2f} |")
200
  md = "\n".join(rows)
201
+ # also return a simple dict from the first face just to feed Label
202
  age0, top0 = age_est.predict(crops[0], topk=5)
203
  probs0 = {lbl: float(p) for lbl, p in top0}
204
  return probs0, md, annotated
 
234
  if not crops:
235
  crops = [pil] # fallback
236
 
 
237
  proc = []
238
  for c in crops:
239
+ c = _resize_512(c)
240
+ out = sd_pipe(
241
+ prompt=pos, negative_prompt=neg, image=c,
242
+ strength=float(strength), guidance_scale=0.0,
243
+ num_inference_steps=int(steps), generator=generator
244
+ )
245
+ proc.append(out.images[0])
246
+
247
+ # tile into a grid
248
+ n = len(proc)
249
+ cols = int(math.ceil(math.sqrt(n)))
250
+ rows = int(math.ceil(n / cols))
251
+ cell_w = max(im.width for im in proc)
252
+ cell_h = max(im.height for im in proc)
253
+ grid = Image.new("RGB", (cols * cell_w, rows * cell_h), (240, 240, 240))
254
+ for i, im in enumerate(proc):
255
+ r, c = divmod(i, cols)
256
+ grid.paste(im, (c * cell_w, r * cell_h))
257
+ return grid
258
+
259
+ # single person
260
+ face_wide = None
261
+ if auto_crop:
262
+ face_wide, _ = cropper.detect_one_wide(pil)
263
+ base = face_wide if face_wide is not None else pil
264
+ base = _resize_512(base)
265
+ out = sd_pipe(
266
+ prompt=pos, negative_prompt=neg, image=base,
267
+ strength=float(strength), guidance_scale=0.0,
268
+ num_inference_steps=int(steps), generator=generator
269
+ )
270
+ return out.images[0]
271
+
272
+ # ------------------ UI ------------------
273
+ with gr.Blocks(title="Group Age + Cartoons (Fast)") as demo:
274
+ gr.Markdown("# Predict ages and make fast cartoons — single or group photos")
275
+ with gr.Row():
276
+ with gr.Column(scale=1):
277
+ img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
278
+ group_mode = gr.Checkbox(False, label="Group photo (detect everyone)")
279
+ auto = gr.Checkbox(True, label="Auto face crop (wide)")
280
+ prompt = gr.Textbox(label="(Optional) Extra cartoon style",
281
+ placeholder="e.g., studio ghibli watercolor, soft bokeh, pastel palette")
282
+ with gr.Row():
283
+ strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
284
+ steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
285
+ seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
286
+ btn_age = gr.Button("Predict Age(s) (fast)", variant="primary")
287
+ btn_cartoon = gr.Button("Make Cartoon(s) (fast)", variant="secondary")
288
+
289
+ with gr.Column(scale=1):
290
+ probs_out = gr.Label(num_top_classes=5, label="Age Prediction (probabilities, first face)")
291
+ age_md = gr.Markdown(label="Age Table / Summary")
292
+ preview = gr.Image(label="Detection Preview (boxes)")
293
+ cartoon_out = gr.Image(label="Cartoon Result (grid for groups)")
294
+
295
+ btn_age.click(fn=predict_age, inputs=[img_in, group_mode, auto], outputs=[probs_out, age_md, preview])
296
+ btn_cartoon.click(fn=cartoonize, inputs=[img_in, prompt, group_mode, auto, strength, steps, seed], outputs=cartoon_out)
297
+
298
+ # Expose for Hugging Face Spaces
299
+ app = demo
300
+
301
+ if __name__ == "__main__":
302
+ app.queue().launch()