Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| os.environ["TRANSFORMERS_NO_TF"] = "1" | |
| os.environ["TRANSFORMERS_NO_FLAX"] = "1" | |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" | |
| import gradio as gr | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from hf_model import PretrainedAgeEstimator | |
| from face_utils import FaceCropper | |
| # NEW: diffusers for cartoonizer | |
| from diffusers import StableDiffusionImg2ImgPipeline | |
# ---------- Load models once ----------
# These statements run at import time, so every model is instantiated a single
# time per process and then shared by all request handlers below.
est = PretrainedAgeEstimator()
cropper = FaceCropper(device=est.device)

# A solid, public SD 1.5 img2img pipeline; fast and reliable
SD15_ID = "runwayml/stable-diffusion-v1-5"
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    SD15_ID,
    # Pick the dtype from the device the pipeline is actually moved to below,
    # not from global CUDA availability: half precision is only used when the
    # target device is CUDA, so the weights never end up as fp16 on another
    # device type.
    torch_dtype=(
        torch.float16
        if torch.device(est.device).type == "cuda"
        else torch.float32
    ),
    # Built-in safety checker disabled; rely on prompts.
    # NOTE(review): the original comment says HF Spaces also applies a global
    # filter - confirm before relying on it.
    safety_checker=None,
).to(est.device)
| # ---------- Helpers ---------- | |
def _ensure_pil(img):
    """Return ``img`` as a PIL image.

    A value that is already a ``PIL.Image.Image`` is handed back untouched;
    anything else is passed to ``Image.fromarray``.
    """
    already_pil = isinstance(img, Image.Image)
    return img if already_pil else Image.fromarray(img)
| # ----- Age: single image ----- | |
def predict_single(img, auto_crop=True, topk=5, show_annot=True):
    """Estimate the age for one image.

    Args:
        img: PIL image or array-like accepted by ``_ensure_pil``; may be None.
        auto_crop: When true, run the face cropper and feed the largest
            detected face to the estimator (falling back to the full image
            if the cropper returns no face).
        topk: Forwarded to ``est.predict`` as ``topk``.
        show_annot: When true (and ``auto_crop`` is on), the preview is the
            cropper's annotated image instead of the RGB input.

    Returns:
        A 3-tuple ``(probs, summary, preview)``: a ``{label: probability}``
        dict built from the estimator's top results, a Markdown summary
        string, and the preview image. For a missing image the tuple is
        ``({}, "No image provided.", None)``.
    """
    if img is None:
        return {}, "No image provided.", None

    rgb = _ensure_pil(img).convert("RGB")
    model_input = rgb
    shown = rgb

    if auto_crop:
        cropped, boxed, _ = cropper.detect_and_crop(rgb, select="largest")
        if cropped is not None:
            model_input = cropped
        if show_annot:
            shown = boxed

    estimated_age, ranked = est.predict(model_input, topk=topk)

    distribution = {}
    for label, p in ranked:
        distribution[label] = float(p)

    return distribution, f"**Estimated age:** {estimated_age:.1f} years", shown
| # ----- Age: batch ----- | |
def predict_batch(files, auto_crop=True, topk=5):
    """Estimate ages for several uploaded images and return a Markdown table.

    Args:
        files: Iterable of uploads. Each item may be a path (``str`` or
            ``os.PathLike``) or an object exposing its path as ``.name``.
            An empty or missing value short-circuits with a message.
        auto_crop: When true, the largest detected face is cropped and used
            as the estimator input (full image if no face is returned).
        topk: Number of top results requested from the estimator; cast to
            ``int`` because UI sliders may deliver floats.

    Returns:
        ``"No files uploaded."`` for empty input; otherwise a Markdown table
        with one row per file (file name, estimated age, top-1 label and its
        probability). A file that fails for any reason gets an ``(error)``
        row instead of aborting the batch.
    """
    # Function-scope import: this module has no top-level logging import.
    import logging

    if not files:
        return "No files uploaded."

    # Same cast run_single applies before calling predict_single.
    topk = int(topk)

    rows = ["| File | Estimated Age | Top-1 | p |", "|---|---:|---|---:|"]
    for f in files:
        # Resolve the path outside the try block so the error row can always
        # be rendered, even when the item is a plain path without `.name`.
        path = f if isinstance(f, (str, os.PathLike)) else getattr(f, "name", f)
        display_name = os.path.basename(str(path))
        try:
            # Context manager closes the underlying file handle; convert()
            # returns an independent, fully loaded copy.
            with Image.open(path) as opened:
                img = opened.convert("RGB")
            face = None
            if auto_crop:
                face, _, _ = cropper.detect_and_crop(img, select="largest")
            target = face if face is not None else img
            age, top = est.predict(target, topk=topk)
            top1_lbl, top1_p = top[0]
            rows.append(f"| {display_name} | {age:.1f} | {top1_lbl} | {top1_p:.3f} |")
        except Exception:
            # Best-effort per file: keep the batch going, but leave a trace
            # of what went wrong instead of discarding the exception.
            logging.getLogger(__name__).exception(
                "Age estimation failed for %s", path
            )
            rows.append(f"| {display_name} | (error) | - | - |")
    return "\n".join(rows)
| # ----- NEW: Cartoonizer (img2img) ----- | |
def cartoonize(img, prompt, strength=0.6, guidance=7.5, steps=25, seed=0, use_face_crop=True):
    """Generate a cartoon-styled version of an image with the img2img pipeline.

    Args:
        img: PIL image or numpy array. ``None`` returns ``None`` immediately.
        prompt: Text description, e.g. "cute cel-shaded cartoon, soft
            outlines, vibrant colors". It is appended to a built-in cartoon
            base prompt; ``None`` or a blank string uses the base prompt only.
        strength: How much to deviate from the input (0.3 subtle to 0.8
            strong).
        guidance: Prompt strength (5-12 typical).
        steps: Diffusion steps (20-40 typical).
        seed: A positive value seeds the generator for reproducibility;
            0, a negative value or ``None`` leaves it unseeded (random).
        use_face_crop: When true, crop to the largest detected face before
            stylizing (the full image is kept if no face is returned).

    Returns:
        The first image of the pipeline output, or ``None`` when no input
        image was given.
    """
    if img is None:
        return None

    img = _ensure_pil(img).convert("RGB")

    # optional crop to the largest face for better identity preservation
    if use_face_crop:
        face, _, _ = cropper.detect_and_crop(img, select="largest")
        if face is not None:
            img = face

    # cartoon-y defaults (you can tweak in UI)
    base_prompt = (
        "cartoon, cel-shaded, clean lineart, smooth shading, high contrast, vibrant, studio ghibli style, "
        "pixar style, highly detailed, 2D illustration"
    )
    # Treat a missing prompt as empty so the literal text "None" never
    # reaches the model, and drop stray surrounding commas/whitespace.
    user_prompt = (prompt or "").strip().strip(",").strip()
    full_prompt = f"{base_prompt}, {user_prompt}" if user_prompt else base_prompt

    # Only strictly positive seeds are honoured; everything else means random.
    generator = None
    if seed is not None and seed > 0:
        generator = torch.Generator(device=est.device).manual_seed(int(seed))

    out = sd_pipe(
        prompt=full_prompt,
        image=img,
        strength=float(strength),
        guidance_scale=float(guidance),
        num_inference_steps=int(steps),
        generator=generator,
    )
    return out.images[0]
| # ---------- UI ---------- | |
# Gradio interface with three tabs: single-image age estimation, batch age
# estimation, and the img2img cartoonizer. Each button is wired to one of the
# handler functions defined above.
with gr.Blocks(title="Pretrained Age Estimator + Cartoonizer") as demo:
    gr.Markdown("# Pretrained Age Estimator + Cartoonizer")
    gr.Markdown("Detects age from a face and can also generate a cartoonized image guided by your text description.")

    with gr.Tabs():
        with gr.Tab("Age (Single)"):
            with gr.Row():
                with gr.Column():
                    inp = gr.Image(type="pil", label="Upload a face image")
                    cam = gr.Image(sources=["webcam"], type="pil", label="Webcam (optional)")
                    auto = gr.Checkbox(True, label="Auto face crop (MTCNN)")
                    topk = gr.Slider(3, 9, value=5, step=1, label="Top-K age ranges")
                    annot = gr.Checkbox(True, label="Show detection preview")
                    btn = gr.Button("Predict Age", variant="primary")
                with gr.Column():
                    # Allow as many classes as the Top-K slider's maximum (9);
                    # a cap of 5 would hide results whenever the user asks
                    # for more than five ranges.
                    out_label = gr.Label(num_top_classes=9, label="Age Prediction (probabilities)")
                    out_md = gr.Markdown(label="Summary")
                    out_prev = gr.Image(label="Preview", visible=True)

            def run_single(img, cam_img, auto_crop, topk_val, show_annot):
                """Run single-image prediction, preferring the webcam capture over the upload."""
                chosen = cam_img if cam_img is not None else img
                return predict_single(chosen, auto_crop, int(topk_val), show_annot)

            btn.click(
                fn=run_single,
                inputs=[inp, cam, auto, topk, annot],
                outputs=[out_label, out_md, out_prev],
            )

        with gr.Tab("Age (Batch)"):
            files = gr.Files(label="Upload multiple images")
            auto_b = gr.Checkbox(True, label="Auto face crop (MTCNN)")
            topk_b = gr.Slider(3, 9, value=5, step=1, label="Top-K age ranges")
            btn_b = gr.Button("Run batch")
            out_table = gr.Markdown()
            btn_b.click(fn=predict_batch, inputs=[files, auto_b, topk_b], outputs=out_table)

        with gr.Tab("Cartoonizer"):
            src = gr.Image(type="pil", label="Source image (face or any photo)")
            prompt = gr.Textbox(
                label="Your style prompt",
                value="cute cel-shaded cartoon, clean lines, soft colors",
            )
            with gr.Row():
                strength = gr.Slider(0.2, 0.95, value=0.6, step=0.05, label="Transformation strength")
                guidance = gr.Slider(3, 15, value=7.5, step=0.5, label="Guidance scale")
                steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
                seed = gr.Number(value=0, precision=0, label="Seed (0 or -1 = random)")
            use_crop = gr.Checkbox(True, label="Crop to largest face before stylizing")
            btn_c = gr.Button("Generate Cartoon", variant="primary")
            out_img = gr.Image(label="Cartoon result")
            btn_c.click(
                fn=cartoonize,
                inputs=[src, prompt, strength, guidance, steps, seed, use_crop],
                outputs=out_img,
            )

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()