File size: 6,876 Bytes
8fc4073
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# app.py
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import gradio as gr
from PIL import Image
import numpy as np
import torch

from hf_model import PretrainedAgeEstimator
from face_utils import FaceCropper

# NEW: diffusers for cartoonizer
from diffusers import StableDiffusionImg2ImgPipeline

# ---------- Load models once ----------
# Loaded at import time so every Gradio request reuses the same weights.
est = PretrainedAgeEstimator()
# The face cropper shares the estimator's device.
cropper = FaceCropper(device=est.device)

# Public SD 1.5 img2img checkpoint.
# NOTE(review): the original "runwayml/stable-diffusion-v1-5" Hub repo appears to
# have been taken down; if loading fails, the community mirror
# "stable-diffusion-v1-5/stable-diffusion-v1-5" is the usual replacement - confirm.
SD15_ID = "runwayml/stable-diffusion-v1-5"

# Choose the dtype from the device the pipeline will actually run on, not merely
# from whether CUDA exists on the host: if est.device is CPU while CUDA is
# available, fp16 weights would otherwise end up on the CPU.
# str() works whether est.device is a string or a torch.device.
_SD_ON_CUDA = str(est.device).startswith("cuda")
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    SD15_ID,
    torch_dtype=torch.float16 if _SD_ON_CUDA else torch.float32,
    safety_checker=None,   # rely on prompts; HF Spaces also has a global filter
).to(est.device)

# ---------- Helpers ----------
def _ensure_pil(img):
    """Return ``img`` unchanged if it is already a PIL image, else wrap it via ``Image.fromarray``."""
    return img if isinstance(img, Image.Image) else Image.fromarray(img)

# ----- Age: single image -----
def predict_single(img, auto_crop=True, topk=5, show_annot=True):
    """Estimate the age shown in one image.

    Args:
        img: PIL image or array; ``None`` yields an empty result.
        auto_crop: when True, run the estimator on the largest detected face
            (falls back to the whole image if no face is returned).
        topk: number of age ranges requested from the estimator.
        show_annot: when auto-cropping, show the cropper's annotated image as
            the preview instead of the plain input.

    Returns:
        Tuple of (label -> probability dict, Markdown summary, preview image).
    """
    if img is None:
        return {}, "No image provided.", None
    rgb = _ensure_pil(img).convert("RGB")

    crop, shown = None, rgb
    if auto_crop:
        crop, annotated, _ = cropper.detect_and_crop(rgb, select="largest")
        if show_annot:
            shown = annotated

    age, ranked = est.predict(rgb if crop is None else crop, topk=topk)

    probabilities = {}
    for label, p in ranked:
        probabilities[label] = float(p)
    return probabilities, f"**Estimated age:** {age:.1f} years", shown

# ----- Age: batch -----
def predict_batch(files, auto_crop=True, topk=5):
    """Estimate ages for several uploaded images and render a Markdown table.

    Args:
        files: uploaded items. Each may be a plain path string or an object that
            exposes the path as ``.name`` (file wrappers used by Gradio).
        auto_crop: when True, estimate on the largest detected face instead of
            the whole image.
        topk: number of age ranges requested from the estimator; coerced to
            ``int`` because slider values may arrive as floats.

    Returns:
        A Markdown table with one row per file, or "No files uploaded." when the
        input is empty. A file that fails is reported as an "(error)" row naming
        the exception type; it does not abort the batch.
    """
    if not files:
        return "No files uploaded."
    k = int(topk)
    rows = ["| File | Estimated Age | Top-1 | p |", "|---|---:|---|---:|"]
    for f in files:
        # Resolve the path before the try: the error row needs the name as well,
        # and a lookup failure must not be raised from inside the except handler.
        path = getattr(f, "name", f)
        # Escape pipes so an unusual filename cannot break the table layout.
        display = os.path.basename(str(path)).replace("|", "\\|")
        try:
            # Context manager closes the underlying file handle after decoding.
            with Image.open(path) as src:
                img = src.convert("RGB")
            face = None
            if auto_crop:
                face, _, _ = cropper.detect_and_crop(img, select="largest")
            target = face if face is not None else img
            age, top = est.predict(target, topk=k)
            top1_lbl, top1_p = top[0]
            rows.append(f"| {display} | {age:.1f} | {top1_lbl} | {top1_p:.3f} |")
        except Exception as exc:  # best-effort: one bad file must not sink the batch
            rows.append(f"| {display} | (error) | {type(exc).__name__} | - |")
    return "\n".join(rows)

# ----- NEW: Cartoonizer (img2img) -----
def cartoonize(img, prompt, strength=0.6, guidance=7.5, steps=25, seed=0, use_face_crop=True):
    """
    img: PIL or numpy
    prompt: text description, e.g. "cute cel-shaded cartoon, soft outlines, vibrant colors"
    strength: how much to deviate from the input (0.3 subtle → 0.8 strong)
    guidance: prompt strength (5–12 typical)
    steps: diffusion steps (20–40 typical)
    seed: reproducibility (-1 for random)
    """
    if img is None:
        return None

    img = _ensure_pil(img).convert("RGB")

    # optional crop to the largest face for better identity preservation
    if use_face_crop:
        face, _, _ = cropper.detect_and_crop(img, select="largest")
        if face is not None:
            img = face

    # cartoon-y defaults (you can tweak in UI)
    base_prompt = (
        "cartoon, cel-shaded, clean lineart, smooth shading, high contrast, vibrant, studio ghibli style, "
        "pixar style, highly detailed, 2D illustration"
    )
    full_prompt = f"{base_prompt}, {prompt}".strip().strip(",")

    generator = None
    if seed and seed >= 0:
        generator = torch.Generator(device=est.device).manual_seed(int(seed))

    out = sd_pipe(
        prompt=full_prompt,
        image=img,
        strength=float(strength),
        guidance_scale=float(guidance),
        num_inference_steps=int(steps),
        generator=generator,
    )
    result = out.images[0]
    return result

# ---------- UI ----------
# Largest Top-K value offered by the sliders below. The Label must be allowed to
# display that many classes, otherwise slider values above its limit are cut off.
_MAX_TOPK = 9

with gr.Blocks(title="Pretrained Age Estimator + Cartoonizer") as demo:
    gr.Markdown("# Pretrained Age Estimator + Cartoonizer")
    gr.Markdown("Detects age from a face and can also generate a cartoonized image guided by your text description.")

    with gr.Tabs():
        with gr.Tab("Age (Single)"):
            with gr.Row():
                with gr.Column():
                    inp = gr.Image(type="pil", label="Upload a face image")
                    cam = gr.Image(sources=["webcam"], type="pil", label="Webcam (optional)")
                    auto = gr.Checkbox(True, label="Auto face crop (MTCNN)")
                    topk = gr.Slider(3, _MAX_TOPK, value=5, step=1, label="Top-K age ranges")
                    annot = gr.Checkbox(True, label="Show detection preview")
                    btn = gr.Button("Predict Age", variant="primary")
                with gr.Column():
                    out_label = gr.Label(num_top_classes=_MAX_TOPK, label="Age Prediction (probabilities)")
                    out_md = gr.Markdown(label="Summary")
                    out_prev = gr.Image(label="Preview", visible=True)

            def run_single(img, cam_img, auto_crop, topk_val, show_annot):
                """Prefer the webcam frame when one was captured, else the uploaded image."""
                chosen = cam_img if cam_img is not None else img
                return predict_single(chosen, auto_crop, int(topk_val), show_annot)

            btn.click(fn=run_single, inputs=[inp, cam, auto, topk, annot],
                      outputs=[out_label, out_md, out_prev])

        with gr.Tab("Age (Batch)"):
            files = gr.Files(label="Upload multiple images")
            auto_b = gr.Checkbox(True, label="Auto face crop (MTCNN)")
            topk_b = gr.Slider(3, _MAX_TOPK, value=5, step=1, label="Top-K age ranges")
            btn_b = gr.Button("Run batch")
            out_table = gr.Markdown()

            def run_batch(file_list, auto_crop, topk_val):
                """Same int coercion as run_single: slider values may arrive as floats."""
                return predict_batch(file_list, auto_crop, int(topk_val))

            btn_b.click(fn=run_batch, inputs=[files, auto_b, topk_b], outputs=out_table)

        with gr.Tab("Cartoonizer"):
            src = gr.Image(type="pil", label="Source image (face or any photo)")
            prompt = gr.Textbox(label="Your style prompt",
                                value="cute cel-shaded cartoon, clean lines, soft colors")
            with gr.Row():
                strength = gr.Slider(0.2, 0.95, value=0.6, step=0.05, label="Transformation strength")
                guidance = gr.Slider(3, 15, value=7.5, step=0.5, label="Guidance scale")
                steps = gr.Slider(10, 50, value=25, step=1, label="Steps")
                seed = gr.Number(value=0, precision=0, label="Seed (0 or -1 = random)")
            use_crop = gr.Checkbox(True, label="Crop to largest face before stylizing")
            btn_c = gr.Button("Generate Cartoon", variant="primary")
            out_img = gr.Image(label="Cartoon result")

            # Inputs map positionally onto cartoonize's parameters.
            btn_c.click(fn=cartoonize,
                        inputs=[src, prompt, strength, guidance, steps, seed, use_crop],
                        outputs=out_img)

if __name__ == "__main__":
    demo.launch()