Update app.py
app.py
CHANGED
@@ -1,4 +1,4 @@
-# app.py — Age-first + FAST group cartoons (SD-Turbo), single page
+# app.py — Age-first + FAST group cartoons (SD-Turbo), single page (HF Spaces safe)
 
 import os
 os.environ["TRANSFORMERS_NO_TF"] = "1"
@@ -6,6 +6,8 @@ os.environ["TRANSFORMERS_NO_FLAX"] = "1"
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
 import math
+from typing import Optional
+
 import gradio as gr
 from PIL import Image, ImageDraw
 import numpy as np
@@ -21,7 +23,7 @@ AGE_RANGE_TO_MID = {
 }
 
 class PretrainedAgeEstimator:
-    def __init__(self, model_id: str = HF_MODEL_ID, device: str = None):
+    def __init__(self, model_id: str = HF_MODEL_ID, device: Optional[str] = None):
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
         self.model = AutoModelForImageClassification.from_pretrained(model_id)
@@ -52,7 +54,7 @@ class FaceCropper:
     - detect_all_wide: returns (list[crops], annotated, list[boxes])
       Boxes are (x1,y1,x2,y2) floats.
     """
-    def __init__(self, device: str = None, margin_scale: float = 1.8):
+    def __init__(self, device: Optional[str] = None, margin_scale: float = 1.8):
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.mtcnn = MTCNN(keep_all=True, device=self.device)
         self.margin_scale = margin_scale
@@ -62,13 +64,13 @@ class FaceCropper:
             return img.convert("RGB")
         return Image.fromarray(img).convert("RGB")
 
-    def _expand_box(self, box, W, H, aspect=0.8):  # 4:5 portrait (w/h=0.8)
+    def _expand_box(self, box, W, H, aspect=0.8):  # ~4:5 portrait (w/h=0.8)
         x1, y1, x2, y2 = box
         cx, cy = (x1 + x2)/2, (y1 + y2)/2
         w, h = (x2 - x1), (y2 - y1)
         side = max(w, h) * self.margin_scale
         tw = side
-        th = side / aspect  #
+        th = side / aspect  # taller than wide
         nx1 = int(max(0, cx - tw/2)); nx2 = int(min(W, cx + tw/2))
         ny1 = int(max(0, cy - th/2)); ny2 = int(min(H, cy + th/2))
         return nx1, ny1, nx2, ny2
@@ -108,6 +110,7 @@ class FaceCropper:
         if boxes is None or len(boxes) == 0:
             return crops, annotated, []
 
+        # sort roughly left->right for table order
         for b, p in sorted(zip(boxes, probs), key=lambda x: (x[0][0]+x[0][2])/2):
            bx1, by1, bx2, by2 = map(float, b)
            draw.rectangle([bx1, by1, bx2, by2], outline=(0, 200, 255), width=3)
@@ -121,15 +124,24 @@ class FaceCropper:
 
 # ------------------ FAST Cartoonizer (SD-Turbo) ------------------
 from diffusers import AutoPipelineForImage2Image
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+
 TURBO_ID = "stabilityai/sd-turbo"
 
 def load_turbo_pipe(device):
     dtype = torch.float16 if (device == "cuda") else torch.float32
     pipe = AutoPipelineForImage2Image.from_pretrained(
         TURBO_ID,
-        torch_dtype=dtype,
-        safety_checker=None,
+        dtype=dtype,  # ✅ no deprecation warning
     ).to(device)
+    # safety checker ON for public Spaces
+    pipe.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+        "CompVis/stable-diffusion-safety-checker"
+    )
+    pipe.feature_extractor = AutoFeatureExtractor.from_pretrained(
+        "CompVis/stable-diffusion-safety-checker"
+    )
     try:
         pipe.enable_attention_slicing()
     except Exception:
@@ -186,7 +198,7 @@ def predict_age(img, group_mode=False, auto_crop=True):
             top1, p1 = top[0]
             rows.append(f"| {i} | {age:.1f} | {top1} | {p1:.2f} |")
         md = "\n".join(rows)
-        # also return a simple dict from the
+        # also return a simple dict from the first face just to feed Label
         age0, top0 = age_est.predict(crops[0], topk=5)
         probs0 = {lbl: float(p) for lbl, p in top0}
         return probs0, md, annotated
@@ -222,7 +234,69 @@ def cartoonize(img, prompt="", group_mode=False, auto_crop=True, strength=0.5, s
         if not crops:
             crops = [pil]  # fallback
 
-        # resize each to 384 for speed/variety
         proc = []
         for c in crops:
-            c =
+            c = _resize_512(c)
+            out = sd_pipe(
+                prompt=pos, negative_prompt=neg, image=c,
+                strength=float(strength), guidance_scale=0.0,
+                num_inference_steps=int(steps), generator=generator
+            )
+            proc.append(out.images[0])
+
+        # tile into a grid
+        n = len(proc)
+        cols = int(math.ceil(math.sqrt(n)))
+        rows = int(math.ceil(n / cols))
+        cell_w = max(im.width for im in proc)
+        cell_h = max(im.height for im in proc)
+        grid = Image.new("RGB", (cols * cell_w, rows * cell_h), (240, 240, 240))
+        for i, im in enumerate(proc):
+            r, c = divmod(i, cols)
+            grid.paste(im, (c * cell_w, r * cell_h))
+        return grid
+
+    # single person
+    face_wide = None
+    if auto_crop:
+        face_wide, _ = cropper.detect_one_wide(pil)
+    base = face_wide if face_wide is not None else pil
+    base = _resize_512(base)
+    out = sd_pipe(
+        prompt=pos, negative_prompt=neg, image=base,
+        strength=float(strength), guidance_scale=0.0,
+        num_inference_steps=int(steps), generator=generator
+    )
+    return out.images[0]
+
+# ------------------ UI ------------------
+with gr.Blocks(title="Group Age + Cartoons (Fast)") as demo:
+    gr.Markdown("# Predict ages and make fast cartoons — single or group photos")
+    with gr.Row():
+        with gr.Column(scale=1):
+            img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
+            group_mode = gr.Checkbox(False, label="Group photo (detect everyone)")
+            auto = gr.Checkbox(True, label="Auto face crop (wide)")
+            prompt = gr.Textbox(label="(Optional) Extra cartoon style",
+                                placeholder="e.g., studio ghibli watercolor, soft bokeh, pastel palette")
+            with gr.Row():
+                strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
+                steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
+                seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
+            btn_age = gr.Button("Predict Age(s) (fast)", variant="primary")
+            btn_cartoon = gr.Button("Make Cartoon(s) (fast)", variant="secondary")
+
+        with gr.Column(scale=1):
+            probs_out = gr.Label(num_top_classes=5, label="Age Prediction (probabilities, first face)")
+            age_md = gr.Markdown(label="Age Table / Summary")
+            preview = gr.Image(label="Detection Preview (boxes)")
+            cartoon_out = gr.Image(label="Cartoon Result (grid for groups)")
+
+    btn_age.click(fn=predict_age, inputs=[img_in, group_mode, auto], outputs=[probs_out, age_md, preview])
+    btn_cartoon.click(fn=cartoonize, inputs=[img_in, prompt, group_mode, auto, strength, steps, seed], outputs=cartoon_out)
+
+# Expose for Hugging Face Spaces
+app = demo
+
+if __name__ == "__main__":
+    app.queue().launch()
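Note on the re-enabled safety checker: with pipe.safety_checker and pipe.feature_extractor attached as in load_turbo_pipe, the SD-Turbo img2img pipeline screens each generated image, blanks out anything flagged, and reports the result in nsfw_content_detected on the pipeline output. The snippet below is a minimal sanity check, not part of this commit; it assumes the load_turbo_pipe helper from the diff above, and the gray placeholder input image is made up purely for illustration.

    from PIL import Image

    pipe = load_turbo_pipe("cpu")  # CPU is enough for a one-off check
    dummy = Image.new("RGB", (512, 512), (128, 128, 128))  # placeholder input, not from app.py
    result = pipe(
        prompt="cartoon portrait", image=dummy,
        strength=0.5, guidance_scale=0.0, num_inference_steps=2,
    )
    print(result.nsfw_content_detected)  # e.g. [False] when nothing was flagged
    result.images[0].save("check.png")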