# app.py — Age-first + FAST group cartoons (SD-Turbo), single page (HF Spaces safe)
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import math
from typing import Optional

import gradio as gr
from PIL import Image, ImageDraw
import numpy as np
import torch

# ------------------ Age estimator ------------------
from transformers import AutoImageProcessor, AutoModelForImageClassification

HF_MODEL_ID = "nateraw/vit-age-classifier"

AGE_RANGE_TO_MID = {
    "0-2": 1, "3-9": 6, "10-19": 15, "20-29": 25, "30-39": 35,
    "40-49": 45, "50-59": 55, "60-69": 65, "70+": 75
}
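# The classifier predicts a distribution over age buckets; a scalar age is
# recovered below as the probability-weighted mean of the bucket midpoints,
# e.g. P("20-29")=0.6 and P("30-39")=0.4 gives 0.6*25 + 0.4*35 = 29.0 years.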


class PretrainedAgeEstimator:
    def __init__(self, model_id: str = HF_MODEL_ID, device: Optional[str] = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
        self.model = AutoModelForImageClassification.from_pretrained(model_id)
        self.model.to(self.device).eval()
        self.id2label = self.model.config.id2label

    @torch.no_grad()  # inference only: skip autograd bookkeeping
    def predict(self, img: Image.Image, topk: int = 5):
        if img.mode != "RGB":
            img = img.convert("RGB")
        inputs = self.processor(images=img, return_tensors="pt").to(self.device)
        logits = self.model(**inputs).logits
        probs = logits.softmax(dim=-1).squeeze(0)
        k = min(topk, probs.numel())
        values, indices = torch.topk(probs, k=k)
        top = [(self.id2label[i.item()], float(v.item())) for i, v in zip(indices, values)]
        expected = sum(AGE_RANGE_TO_MID.get(self.id2label[i], 35) * float(p)
                       for i, p in enumerate(probs))
        return expected, top
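
# Usage sketch (hypothetical path "photo.jpg"):
#   est = PretrainedAgeEstimator()
#   age, top = est.predict(Image.open("photo.jpg"))
#   # -> age is a float, top is [(label, prob), ...] sorted by prob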


# ------------------ Face detection (single & group) ------------------
from facenet_pytorch import MTCNN


class FaceCropper:
    """
    Detect faces.
    - detect_one_wide: returns (crop_with_margin, annotated)
    - detect_all_wide: returns (list[crops], annotated, list[boxes])
    Boxes are (x1, y1, x2, y2) floats.
    """
    def __init__(self, device: Optional[str] = None, margin_scale: float = 1.8):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.mtcnn = MTCNN(keep_all=True, device=self.device)
        self.margin_scale = margin_scale

    def _ensure_pil(self, img):
        if isinstance(img, Image.Image):
            return img.convert("RGB")
        return Image.fromarray(img).convert("RGB")

    def _expand_box(self, box, W, H, aspect=0.8):  # ~4:5 portrait (w/h = 0.8)
        x1, y1, x2, y2 = box
        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
        w, h = (x2 - x1), (y2 - y1)
        side = max(w, h) * self.margin_scale
        tw = side
        th = side / aspect  # taller than wide
        nx1 = int(max(0, cx - tw / 2)); nx2 = int(min(W, cx + tw / 2))
        ny1 = int(max(0, cy - th / 2)); ny2 = int(min(H, cy + th / 2))
        return nx1, ny1, nx2, ny2
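
    # Crop geometry in words: the detector's tight box is grown to
    # margin_scale × its longer side, then stretched vertically to ~4:5,
    # so downstream models see hair and shoulders, not just a bare face;
    # clamping to the image bounds means edge faces get asymmetric margins.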

    def detect_one_wide(self, img):
        pil = self._ensure_pil(img)
        W, H = pil.size
        boxes, probs = self.mtcnn.detect(pil)
        annotated = pil.copy()
        draw = ImageDraw.Draw(annotated)
        if boxes is None or len(boxes) == 0:
            return None, annotated
        # draw all boxes
        for b, p in zip(boxes, probs):
            bx1, by1, bx2, by2 = map(float, b)
            draw.rectangle([bx1, by1, bx2, by2], outline=(255, 0, 0), width=3)
            draw.text((bx1, max(0, by1 - 12)), f"{p:.2f}", fill=(255, 0, 0))
        # choose largest
        idx = int(np.argmax([(b[2] - b[0]) * (b[3] - b[1]) for b in boxes]))
        nx1, ny1, nx2, ny2 = self._expand_box(boxes[idx], W, H)
        crop = pil.crop((nx1, ny1, nx2, ny2))
        return crop, annotated

    def detect_all_wide(self, img):
        pil = self._ensure_pil(img)
        W, H = pil.size
        boxes, probs = self.mtcnn.detect(pil)
        annotated = pil.copy()
        draw = ImageDraw.Draw(annotated)
        crops = []
        ordered = []
        if boxes is None or len(boxes) == 0:
            return crops, annotated, []
        # sort roughly left->right for table order
        for b, p in sorted(zip(boxes, probs), key=lambda x: (x[0][0] + x[0][2]) / 2):
            bx1, by1, bx2, by2 = map(float, b)
            draw.rectangle([bx1, by1, bx2, by2], outline=(0, 200, 255), width=3)
            draw.text((bx1, max(0, by1 - 12)), f"{p:.2f}", fill=(0, 200, 255))
            nx1, ny1, nx2, ny2 = self._expand_box(b, W, H)
            crops.append(pil.crop((nx1, ny1, nx2, ny2)))
            ordered.append((bx1, by1, bx2, by2))
        return crops, annotated, ordered
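
# Usage sketch (hypothetical `photo` PIL image):
#   crops, annotated, boxes = FaceCropper().detect_all_wide(photo)
#   # crops: widened per-face portraits, ordered left to right
#   # annotated: the original image with probability-labelled boxes drawn on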


# ------------------ FAST Cartoonizer (SD-Turbo) ------------------
from diffusers import AutoPipelineForImage2Image
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

TURBO_ID = "stabilityai/sd-turbo"


def load_turbo_pipe(device):
    dtype = torch.float16 if (device == "cuda") else torch.float32
    pipe = AutoPipelineForImage2Image.from_pretrained(
        TURBO_ID,
        dtype=dtype,  # ✅ no deprecation warning
    ).to(device)
    # safety checker ON for public Spaces; keep it on the pipe's device/dtype
    # so the CLIP input the pipeline hands it does not hit a device mismatch
    pipe.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device=device, dtype=dtype)
    pipe.feature_extractor = AutoFeatureExtractor.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    try:
        pipe.enable_attention_slicing()  # lower peak VRAM; harmless if unsupported
    except Exception:
        pass
    return pipe
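
# Why these settings: SD-Turbo is an adversarially distilled SD 2.1 variant
# meant to run with classifier-free guidance disabled (guidance_scale=0.0)
# and only 1-4 denoising steps, which is why the sd_pipe calls below look
# different from a vanilla Stable Diffusion setup.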

# init models once
age_est = PretrainedAgeEstimator()
cropper = FaceCropper(device=age_est.device, margin_scale=1.9)
sd_pipe = load_turbo_pipe(age_est.device)

# prompts
DEFAULT_POSITIVE = (
    "beautiful princess portrait, elegant gown, tiara, soft magical lighting, "
    "sparkles, dreamy castle background, painterly, clean lineart, vibrant but natural colors, "
    "storybook illustration, high quality"
)
DEFAULT_NEGATIVE = (
    "deformed, disfigured, ugly, extra limbs, extra fingers, bad anatomy, low quality, "
    "blurry, watermark, text, logo"
)


def _ensure_pil(img):
    return img if isinstance(img, Image.Image) else Image.fromarray(img)


def _resize_512(im: Image.Image):
    w, h = im.size
    scale = 512 / max(w, h)
    if scale < 1.0:  # only downscale; never upscale small crops
        im = im.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
    return im
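
# 512 px matches SD-Turbo's training resolution. Exact multiples of 8 are not
# required here: diffusers' VAE image processor snaps dimensions down to the
# VAE scale factor during preprocessing (an assumption about recent diffusers).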


# ------------- AGE (single/group) -------------
def predict_age(img, group_mode=False, auto_crop=True):
    if img is None:
        return {}, "Please upload an image.", None
    pil = _ensure_pil(img).convert("RGB")
    if group_mode:
        crops, annotated, boxes = cropper.detect_all_wide(pil)
        if not crops:
            # fallback to full image
            age, top = age_est.predict(pil, topk=5)
            probs = {lbl: float(p) for lbl, p in top}
            md = f"**Estimated age (whole image):** {age:.1f} years"
            return probs, md, pil
        # per-face ages
        rows = ["| # | Age (yrs) | Top-1 | p |", "|---:|---:|---|---:|"]
        for i, face in enumerate(crops, 1):
            age, top = age_est.predict(face, topk=3)
            top1, p1 = top[0]
            rows.append(f"| {i} | {age:.1f} | {top1} | {p1:.2f} |")
        md = "\n".join(rows)
        # also return a simple dict from the first face just to feed Label
        age0, top0 = age_est.predict(crops[0], topk=5)
        probs0 = {lbl: float(p) for lbl, p in top0}
        return probs0, md, annotated
    # single
    face_wide = None
    annotated = None
    if auto_crop:
        face_wide, annotated = cropper.detect_one_wide(pil)
    target = face_wide if face_wide is not None else pil
    age, top = age_est.predict(target, topk=5)
    probs = {lbl: float(p) for lbl, p in top}
    md = f"**Estimated age:** {age:.1f} years"
    return probs, md, (annotated if annotated is not None else pil)
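
# Note: in group mode only the first face's distribution reaches gr.Label,
# since that component renders a single distribution; per-face numbers go in
# the markdown table instead.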


# ------------- CARTOON (single/group) -------------
def cartoonize(img, prompt="", group_mode=False, auto_crop=True, strength=0.5, steps=2, seed=-1):
    if img is None:
        return None
    pil = _ensure_pil(img).convert("RGB")
    user = (prompt or "").strip()
    pos = DEFAULT_POSITIVE if not user else f"{DEFAULT_POSITIVE}, {user}"
    neg = DEFAULT_NEGATIVE
    generator = None
    if isinstance(seed, (int, float)) and int(seed) >= 0:
        generator = torch.Generator(device=age_est.device).manual_seed(int(seed))
    if group_mode:
        # detect all faces, stylize each, assemble grid
        crops, _, _ = cropper.detect_all_wide(pil)
        if not crops:
            crops = [pil]  # fallback
        proc = []
        for c in crops:
            c = _resize_512(c)
            out = sd_pipe(
                prompt=pos, negative_prompt=neg, image=c,
                strength=float(strength), guidance_scale=0.0,
                num_inference_steps=int(steps), generator=generator
            )
            proc.append(out.images[0])
        # tile into a grid
        n = len(proc)
        cols = int(math.ceil(math.sqrt(n)))
        rows = int(math.ceil(n / cols))
        cell_w = max(im.width for im in proc)
        cell_h = max(im.height for im in proc)
        grid = Image.new("RGB", (cols * cell_w, rows * cell_h), (240, 240, 240))
        for i, im in enumerate(proc):
            r, c = divmod(i, cols)
            grid.paste(im, (c * cell_w, r * cell_h))
        return grid
    # single person
    face_wide = None
    if auto_crop:
        face_wide, _ = cropper.detect_one_wide(pil)
    base = face_wide if face_wide is not None else pil
    base = _resize_512(base)
    out = sd_pipe(
        prompt=pos, negative_prompt=neg, image=base,
        strength=float(strength), guidance_scale=0.0,
        num_inference_steps=int(steps), generator=generator
    )
    return out.images[0]
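
# Knob interplay (assumption about diffusers img2img internals): the pipeline
# denoises roughly int(steps * strength) steps, so steps=2 at strength=0.5 is
# a single real step; keep steps * strength >= 1 or the output barely changes.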


# ------------------ UI ------------------
with gr.Blocks(title="Group Age + Cartoons (Fast)") as demo:
    gr.Markdown("# Predict ages and make fast cartoons — single or group photos")
    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(sources=["upload", "webcam"], type="pil", label="Upload / Webcam")
            group_mode = gr.Checkbox(False, label="Group photo (detect everyone)")
            auto = gr.Checkbox(True, label="Auto face crop (wide)")
            prompt = gr.Textbox(label="(Optional) Extra cartoon style",
                                placeholder="e.g., studio ghibli watercolor, soft bokeh, pastel palette")
            with gr.Row():
                strength = gr.Slider(0.3, 0.8, value=0.5, step=0.05, label="Cartoon strength")
                steps = gr.Slider(1, 4, value=2, step=1, label="Turbo steps (1–4)")
            seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
            btn_age = gr.Button("Predict Age(s) (fast)", variant="primary")
            btn_cartoon = gr.Button("Make Cartoon(s) (fast)", variant="secondary")
        with gr.Column(scale=1):
            probs_out = gr.Label(num_top_classes=5, label="Age Prediction (probabilities, first face)")
            age_md = gr.Markdown(label="Age Table / Summary")
            preview = gr.Image(label="Detection Preview (boxes)")
            cartoon_out = gr.Image(label="Cartoon Result (grid for groups)")

    btn_age.click(fn=predict_age, inputs=[img_in, group_mode, auto],
                  outputs=[probs_out, age_md, preview])
    btn_cartoon.click(fn=cartoonize, inputs=[img_in, prompt, group_mode, auto, strength, steps, seed],
                      outputs=cartoon_out)

# Expose for Hugging Face Spaces
app = demo

if __name__ == "__main__":
    app.queue().launch()