# app.py
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import gradio as gr
from PIL import Image
import numpy as np
import torch
from hf_model import PretrainedAgeEstimator
from face_utils import FaceCropper
# NEW: diffusers for cartoonizer
from diffusers import StableDiffusionImg2ImgPipeline
# ---------- Load models once ----------
est = PretrainedAgeEstimator()
cropper = FaceCropper(device=est.device)
# A solid, public SD 1.5 img2img pipeline; fast and reliable
SD15_ID = "runwayml/stable-diffusion-v1-5"
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    SD15_ID,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    safety_checker=None,  # rely on prompts; HF Spaces also has a global filter
).to(est.device)
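# Optional (not in the original app): attention slicing trims VRAM use on
# small GPUs at a modest speed cost; uncomment if you hit CUDA out-of-memory.
# sd_pipe.enable_attention_slicing()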
# ---------- Helpers ----------
def _ensure_pil(img):
    if isinstance(img, Image.Image):
        return img
    return Image.fromarray(img)
# ----- Age: single image -----
def predict_single(img, auto_crop=True, topk=5, show_annot=True):
    if img is None:
        return {}, "No image provided.", None
    img = _ensure_pil(img).convert("RGB")
    preview = img
    face = None
    if auto_crop:
        face, annotated, _ = cropper.detect_and_crop(img, select="largest")
        preview = annotated if show_annot else img
    target = face if face is not None else img
    age, top = est.predict(target, topk=topk)
    probs = {lbl: float(prob) for lbl, prob in top}
    summary = f"**Estimated age:** {age:.1f} years"
    return probs, summary, preview
# ----- Age: batch -----
def predict_batch(files, auto_crop=True, topk=5):
    if not files:
        return "No files uploaded."
    rows = ["| File | Estimated Age | Top-1 | p |", "|---|---:|---|---:|"]
    for f in files:
        try:
            img = Image.open(f.name).convert("RGB")
            face = None
            if auto_crop:
                face, _, _ = cropper.detect_and_crop(img, select="largest")
            target = face if face is not None else img
            age, top = est.predict(target, topk=topk)
            top1_lbl, top1_p = top[0]
            rows.append(f"| {os.path.basename(f.name)} | {age:.1f} | {top1_lbl} | {top1_p:.3f} |")
        except Exception:
            rows.append(f"| {os.path.basename(f.name)} | (error) | - | - |")
    return "\n".join(rows)
# ----- NEW: Cartoonizer (img2img) -----
def cartoonize(img, prompt, strength=0.6, guidance=7.5, steps=25, seed=0, use_face_crop=True):
    """
    img: PIL or numpy
    prompt: text description, e.g. "cute cel-shaded cartoon, soft outlines, vibrant colors"
    strength: how much to deviate from the input (0.3 subtle → 0.8 strong)
    guidance: prompt adherence (5–12 typical)
    steps: diffusion steps (20–40 typical)
    seed: reproducibility (0 or -1 = random)
    """
    if img is None:
        return None
    img = _ensure_pil(img).convert("RGB")
    # optional crop to the largest face for better identity preservation
    if use_face_crop:
        face, _, _ = cropper.detect_and_crop(img, select="largest")
        if face is not None:
            img = face
    # cartoon-y defaults (you can tweak them in the UI)
    base_prompt = (
        "cartoon, cel-shaded, clean lineart, smooth shading, high contrast, vibrant, studio ghibli style, "
        "pixar style, highly detailed, 2D illustration"
    )
    full_prompt = f"{base_prompt}, {prompt}".strip().strip(",")
    # seed > 0 gives a fixed generator; 0 or -1 leaves sampling random
    generator = None
    if seed and seed > 0:
        generator = torch.Generator(device=est.device).manual_seed(int(seed))
    out = sd_pipe(
        prompt=full_prompt,
        image=img,
        strength=float(strength),
        guidance_scale=float(guidance),
        num_inference_steps=int(steps),
        generator=generator,
    )
    return out.images[0]
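# Example call outside the UI (a sketch; "face.jpg" is a hypothetical local file):
#     cartoon = cartoonize(Image.open("face.jpg"), "watercolor style", seed=42)
#     cartoon.save("cartoon.png")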
# ---------- UI ----------
with gr.Blocks(title="Pretrained Age Estimator + Cartoonizer") as demo:
    gr.Markdown("# Pretrained Age Estimator + Cartoonizer")
    gr.Markdown("Detects age from a face and can also generate a cartoonized image guided by your text description.")
    with gr.Tabs():
        with gr.Tab("Age (Single)"):
            with gr.Row():
                with gr.Column():
                    inp = gr.Image(type="pil", label="Upload a face image")
                    cam = gr.Image(sources=["webcam"], type="pil", label="Webcam (optional)")
                    auto = gr.Checkbox(True, label="Auto face crop (MTCNN)")
                    topk = gr.Slider(3, 9, value=5, step=1, label="Top-K age ranges")
                    annot = gr.Checkbox(True, label="Show detection preview")
                    btn = gr.Button("Predict Age", variant="primary")
                with gr.Column():
                    out_label = gr.Label(num_top_classes=5, label="Age Prediction (probabilities)")
                    out_md = gr.Markdown(label="Summary")
                    out_prev = gr.Image(label="Preview", visible=True)

            def run_single(img, cam_img, auto_crop, topk_val, show_annot):
                chosen = cam_img if cam_img is not None else img
                return predict_single(chosen, auto_crop, int(topk_val), show_annot)

            btn.click(fn=run_single, inputs=[inp, cam, auto, topk, annot],
                      outputs=[out_label, out_md, out_prev])
        with gr.Tab("Age (Batch)"):
            files = gr.Files(label="Upload multiple images")
            auto_b = gr.Checkbox(True, label="Auto face crop (MTCNN)")
            topk_b = gr.Slider(3, 9, value=5, step=1, label="Top-K age ranges")
            btn_b = gr.Button("Run batch")
            out_table = gr.Markdown()
            btn_b.click(fn=predict_batch, inputs=[files, auto_b, topk_b], outputs=out_table)
        with gr.Tab("Cartoonizer"):
            src = gr.Image(type="pil", label="Source image (face or any photo)")
            prompt = gr.Textbox(label="Your style prompt",
                                value="cute cel-shaded cartoon, clean lines, soft colors")
            with gr.Row():
                strength = gr.Slider(0.2, 0.95, value=0.6, step=0.05, label="Transformation strength")
                guidance = gr.Slider(3, 15, value=7.5, step=0.5, label="Guidance scale")
                steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
            seed = gr.Number(value=0, precision=0, label="Seed (0 or -1 = random)")
            use_crop = gr.Checkbox(True, label="Crop to largest face before stylizing")
            btn_c = gr.Button("Generate Cartoon", variant="primary")
            out_img = gr.Image(label="Cartoon result")
            btn_c.click(fn=cartoonize,
                        inputs=[src, prompt, strength, guidance, steps, seed, use_crop],
                        outputs=out_img)
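# Optional (assumption, not in the original app): enabling Gradio's queue helps
# long-running diffusion jobs avoid request timeouts on Spaces, e.g.
#     demo.queue().launch()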
if __name__ == "__main__":
    demo.launch()