RedFish / app.py
hongyu12321's picture
Update app.py
8fc4073 verified
raw
history blame
6.88 kB
# app.py
# Gradio demo: pretrained age estimation (single image + batch) and a Stable
# Diffusion 1.5 img2img "cartoonizer".
import os
# These variables are assigned before the third-party imports below so they
# are already set when those libraries are first imported.
# NOTE(review): presumably these stop `transformers` from loading its
# TensorFlow / Flax backends and quiet TensorFlow's C++ logging — confirm the
# exact variable names against the transformers / TensorFlow docs.
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import gradio as gr
from PIL import Image
# NOTE(review): `np` is not referenced anywhere else in this file.
import numpy as np
import torch
# Project-local modules: the age model and the face detector/cropper.
from hf_model import PretrainedAgeEstimator
from face_utils import FaceCropper
# diffusers supplies the Stable Diffusion img2img pipeline used by the cartoonizer.
from diffusers import StableDiffusionImg2ImgPipeline
# ---------- Load models once ----------
# Built at import time so every Gradio request reuses the same instances.
est = PretrainedAgeEstimator()
# The face detector runs on whatever device the age model selected.
cropper = FaceCropper(device=est.device)

# Public SD 1.5 img2img checkpoint used by the cartoonizer.
# NOTE(review): the `runwayml` repo may no longer be downloadable from the Hub;
# if loading fails, the community mirror
# "stable-diffusion-v1-5/stable-diffusion-v1-5" is the usual replacement.
SD15_ID = "runwayml/stable-diffusion-v1-5"


def _sd_dtype(device):
    """Return the weight dtype for a pipeline that will run on *device*.

    Half precision is chosen only when the target device is CUDA. Keying on
    the actual target, rather than on ``torch.cuda.is_available()``, avoids
    loading fp16 weights and then moving them onto a CPU device, where
    half-precision kernels are slow or missing.
    """
    # NOTE(review): assumes est.device is a str or torch.device — confirm in hf_model.
    return torch.float16 if torch.device(device).type == "cuda" else torch.float32


sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    SD15_ID,
    torch_dtype=_sd_dtype(est.device),
    # The built-in safety checker is disabled. requires_safety_checker=False
    # suppresses the warning diffusers logs when safety_checker is None.
    # NOTE(review): moderation here relies on prompts and (per the original
    # author) a global HF Spaces filter — confirm that filter exists.
    safety_checker=None,
    requires_safety_checker=False,
).to(est.device)
# ---------- Helpers ----------
def _ensure_pil(img):
    """Coerce *img* into a ``PIL.Image.Image``.

    An object that already is a PIL image is returned as-is (the very same
    object). Anything else, e.g. a numpy array coming from ``gr.Image``, is
    converted with ``Image.fromarray``.
    """
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    return img
# ----- Age: single image -----
def predict_single(img, auto_crop=True, topk=5, show_annot=True):
    """Estimate the age of the person in one image.

    Args:
        img: PIL image or numpy array (as delivered by ``gr.Image``), or None.
        auto_crop: if True, run the face cropper and feed the largest face to
            the estimator; the whole image is used when no face is returned.
        topk: number of age ranges to request from the estimator.
        show_annot: when cropping, show the cropper's annotated image as the
            preview instead of the plain input.

    Returns:
        ``(probs, summary, preview)``: a ``{label: probability}`` dict for
        ``gr.Label``, a Markdown summary string, and the preview image.
        For a missing image: ``({}, "No image provided.", None)``.
    """
    if img is None:
        return {}, "No image provided.", None
    img = _ensure_pil(img).convert("RGB")

    preview = img
    face = None
    if auto_crop:
        face, annotated, _ = cropper.detect_and_crop(img, select="largest")
        # Only swap in the annotated preview when the cropper actually
        # produced one; otherwise keep showing the input instead of blanking
        # the preview pane.
        if show_annot and annotated is not None:
            preview = annotated

    target = face if face is not None else img
    # int() so direct callers may pass slider floats, matching run_single.
    age, top = est.predict(target, topk=int(topk))
    probs = {lbl: float(prob) for lbl, prob in top}
    summary = f"**Estimated age:** {age:.1f} years"
    return probs, summary, preview
# ----- Age: batch -----
def predict_batch(files, auto_crop=True, topk=5):
    """Estimate ages for several uploaded images and return a Markdown table.

    Args:
        files: iterable of uploads from ``gr.Files``. Each item may be a path
            (``str`` / ``os.PathLike``, including Gradio's str-based wrapper)
            or a file-like object whose ``.name`` holds the path.
        auto_crop: if True, estimate on the largest detected face; fall back
            to the whole image when none is returned.
        topk: number of age ranges requested from the estimator; only the
            top-1 entry is shown in the table.

    Returns:
        ``"No files uploaded."`` for an empty/None input. Otherwise a Markdown
        table with one row per file. A file that fails to load or predict gets
        an ``(error: <ExceptionType>)`` row instead of aborting the batch.
    """
    if not files:
        return "No files uploaded."
    topk = int(topk)  # slider values may arrive as floats; match run_single
    rows = ["| File | Estimated Age | Top-1 | p |", "|---|---:|---|---:|"]
    for f in files:
        # Resolve the path OUTSIDE the try: the error row needs it too, so
        # computing it in the handler could itself raise and sink the batch.
        if isinstance(f, (str, os.PathLike)):
            path = os.fspath(f)
        else:
            path = getattr(f, "name", str(f))
        # Escape pipes so an unusual file name cannot break the table layout.
        fname = os.path.basename(path).replace("|", "\\|")
        try:
            with Image.open(path) as src:  # closes the file handle promptly
                img = src.convert("RGB")
            face = None
            if auto_crop:
                face, _, _ = cropper.detect_and_crop(img, select="largest")
            target = face if face is not None else img
            age, top = est.predict(target, topk=topk)
            top1_lbl, top1_p = top[0]
        except Exception as exc:  # best-effort per file: report, keep going
            rows.append(f"| {fname} | (error: {type(exc).__name__}) | - | - |")
            continue
        rows.append(f"| {fname} | {age:.1f} | {top1_lbl} | {top1_p:.3f} |")
    return "\n".join(rows)
# ----- NEW: Cartoonizer (img2img) -----
# Style keywords always placed in front of the user's prompt.
_CARTOON_BASE_PROMPT = (
    "cartoon, cel-shaded, clean lineart, smooth shading, high contrast, vibrant, studio ghibli style, "
    "pixar style, highly detailed, 2D illustration"
)


def _build_cartoon_prompt(prompt):
    """Return the base cartoon keywords followed by the user's *prompt*.

    Surrounding whitespace and commas are trimmed from *prompt*. An empty or
    None prompt yields the base keywords alone (no dangling comma and no
    literal "None").
    """
    extra = (prompt or "").strip().strip(",").strip()
    return f"{_CARTOON_BASE_PROMPT}, {extra}" if extra else _CARTOON_BASE_PROMPT


def _resize_for_sd(img, resolution):
    """Scale *img* so its longer side is roughly *resolution* pixels.

    Both sides are rounded to multiples of 8, because SD 1.5 works on latents
    1/8 the image size. The aspect ratio is kept. *img* is returned unchanged
    when *resolution* is falsy or the size already matches.
    """
    if not resolution:
        return img
    width, height = img.size
    scale = float(resolution) / max(width, height)
    new_w = max(8, int(round(width * scale / 8)) * 8)
    new_h = max(8, int(round(height * scale / 8)) * 8)
    if (new_w, new_h) == (width, height):
        return img
    return img.resize((new_w, new_h), Image.LANCZOS)


def cartoonize(img, prompt, strength=0.6, guidance=7.5, steps=25, seed=0,
               use_face_crop=True, resolution=512):
    """Stylize *img* as a cartoon with Stable Diffusion 1.5 img2img.

    Args:
        img: PIL image or numpy array, or None.
        prompt: extra style description appended to the built-in cartoon
            keywords, e.g. "cute cel-shaded cartoon, soft outlines".
        strength: how far to move away from the input
            (about 0.3 = subtle -> 0.8 = strong).
        guidance: classifier-free guidance scale (5-12 typical).
        steps: number of diffusion steps (20-40 typical).
        seed: a positive value gives reproducible output; 0, a negative value,
            or None draws a random seed.
        use_face_crop: if True, stylize only the largest detected face
            (falls back to the whole image when none is returned).
        resolution: the longer side, in pixels, that the input is scaled to
            before diffusion. It bounds memory and time on large photos and
            brings small face crops up to SD 1.5's native ~512 px scale.
            None or 0 keeps the original size.

    Returns:
        The generated PIL image, or None when *img* is None.
    """
    if img is None:
        return None
    img = _ensure_pil(img).convert("RGB")

    # Optionally crop to the largest face for better identity preservation.
    if use_face_crop:
        face, _, _ = cropper.detect_and_crop(img, select="largest")
        if face is not None:
            img = face

    img = _resize_for_sd(img, resolution)

    # Only a positive seed pins the RNG; anything else leaves it unseeded.
    generator = None
    if seed and seed > 0:
        generator = torch.Generator(device=est.device).manual_seed(int(seed))

    out = sd_pipe(
        prompt=_build_cartoon_prompt(prompt),
        image=img,
        strength=float(strength),
        guidance_scale=float(guidance),
        num_inference_steps=int(steps),
        generator=generator,
    )
    return out.images[0]
# ---------- UI ----------
# Upper bound of the Top-K sliders. The probability Label must be allowed to
# show that many classes; otherwise the extra ranges are silently dropped.
_MAX_TOPK = 9

with gr.Blocks(title="Pretrained Age Estimator + Cartoonizer") as demo:
    gr.Markdown("# Pretrained Age Estimator + Cartoonizer")
    gr.Markdown("Detects age from a face and can also generate a cartoonized image guided by your text description.")
    with gr.Tabs():
        with gr.Tab("Age (Single)"):
            with gr.Row():
                with gr.Column():
                    inp = gr.Image(type="pil", label="Upload a face image")
                    cam = gr.Image(sources=["webcam"], type="pil", label="Webcam (optional)")
                    auto = gr.Checkbox(True, label="Auto face crop (MTCNN)")
                    topk = gr.Slider(3, _MAX_TOPK, value=5, step=1, label="Top-K age ranges")
                    annot = gr.Checkbox(True, label="Show detection preview")
                    btn = gr.Button("Predict Age", variant="primary")
                with gr.Column():
                    # Sized to the slider's maximum so every requested range is shown.
                    out_label = gr.Label(num_top_classes=_MAX_TOPK, label="Age Prediction (probabilities)")
                    out_md = gr.Markdown(label="Summary")
                    out_prev = gr.Image(label="Preview", visible=True)

            def run_single(img, cam_img, auto_crop, topk_val, show_annot):
                """Predict on the webcam frame when present, else on the upload."""
                chosen = cam_img if cam_img is not None else img
                return predict_single(chosen, auto_crop, int(topk_val), show_annot)

            btn.click(fn=run_single, inputs=[inp, cam, auto, topk, annot],
                      outputs=[out_label, out_md, out_prev])

        with gr.Tab("Age (Batch)"):
            # Restrict the picker to images; non-image files would only
            # produce error rows.
            files = gr.Files(label="Upload multiple images", file_types=["image"])
            auto_b = gr.Checkbox(True, label="Auto face crop (MTCNN)")
            topk_b = gr.Slider(3, _MAX_TOPK, value=5, step=1, label="Top-K age ranges")
            btn_b = gr.Button("Run batch")
            out_table = gr.Markdown()
            btn_b.click(fn=predict_batch, inputs=[files, auto_b, topk_b], outputs=out_table)

        with gr.Tab("Cartoonizer"):
            src = gr.Image(type="pil", label="Source image (face or any photo)")
            prompt = gr.Textbox(label="Your style prompt",
                                value="cute cel-shaded cartoon, clean lines, soft colors")
            with gr.Row():
                strength = gr.Slider(0.2, 0.95, value=0.6, step=0.05, label="Transformation strength")
                guidance = gr.Slider(3, 15, value=7.5, step=0.5, label="Guidance scale")
                steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
            seed = gr.Number(value=0, precision=0, label="Seed (0 or -1 = random)")
            use_crop = gr.Checkbox(True, label="Crop to largest face before stylizing")
            btn_c = gr.Button("Generate Cartoon", variant="primary")
            out_img = gr.Image(label="Cartoon result")
            btn_c.click(fn=cartoonize,
                        inputs=[src, prompt, strength, guidance, steps, seed, use_crop],
                        outputs=out_img)
def _main():
    """Launch the Gradio app defined above as ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()