# CodeFormer / python/gradio_demo.py
# Uploaded by wli1995 ("Upload gradio demo", commit 442b121, verified)
import gradio as gr
import os
import tempfile
import numpy as np
import axengine as axe
import cv2
from utils.restoration_helper import RestoreHelper
# Directory holding the .axmodel files. It is resolved relative to this script
# rather than the current working directory, so the demo can be started from
# any directory (previously "../model/..." only worked when launched from the
# script's own directory).
_MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "model")

# Single shared helper used by both UI modes: "Face" uses its restoration
# session directly, "Full image" uses the whole detect/align/restore/paste
# pipeline. Models: face detector (det), CodeFormer restorer (res) and
# Real-ESRGAN x2 background upsampler (bg).
restore_helper = RestoreHelper(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model=os.path.join(_MODEL_DIR, "yolov5l-face.axmodel"),
    res_model=os.path.join(_MODEL_DIR, "codeformer.axmodel"),
    bg_model=os.path.join(_MODEL_DIR, "realesrgan-x2.axmodel"),
    save_ext='png',
    use_parse=True,
)
def face(img_path, session):
    """Restore an image as a single face with the given CodeFormer session.

    No face detection or alignment is performed: the whole image is resized to
    512x512 (aspect ratio is not preserved), run through ``session``, and the
    result is resized back to the original resolution.

    Args:
        img_path: Path to the input image, read with ``cv2.imread``.
        session: Inference session exposing ``get_inputs()``, ``get_outputs()``
            and ``run(output_names, feed)``; only its first input is fed.

    Returns:
        The restored image as a BGR ``uint8`` array with the input's height
        and width.

    Raises:
        ValueError: If ``img_path`` cannot be read as an image.
    """
    output_names = [x.name for x in session.get_outputs()]
    input_name = session.get_inputs()[0].name

    ori_image = cv2.imread(img_path)
    if ori_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Failed to read image: {img_path}")
    h, w = ori_image.shape[:2]

    # Resize to 512x512, BGR -> RGB, scale to [0, 1], then normalize to [-1, 1].
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    image = cv2.resize(ori_image, (512, 512))
    image = (image[..., ::-1] / 255.0).astype(np.float32)
    image = ((image - mean) / std).astype(np.float32)
    # HWC -> NCHW with a batch dimension of 1.
    img = np.transpose(np.expand_dims(np.ascontiguousarray(image), axis=0), (0, 3, 1, 2))

    sr = session.run(output_names, {input_name: img})

    # First output: drop the batch dim, CHW -> HWC, undo the [-1, 1] normalization.
    sr = np.transpose(sr[0].squeeze(0), (1, 2, 0))
    sr = (sr * std + mean).astype(np.float32)
    ndarr = np.clip(sr * 255.0, 0, 255.0).astype(np.uint8)
    # RGB -> BGR and back to the original resolution.
    return cv2.resize(ndarr[..., ::-1], (w, h))
def _restore_cropped_face(helper, cropped_face):
    """Run ``helper``'s CodeFormer session on one aligned face crop.

    The crop is converted BGR -> RGB, scaled to [-1, 1] and fed as a 1xCxHxW
    batch; the model output is mapped back to BGR ``uint8``. If inference or
    post-processing raises, the error is printed and the original crop is
    returned unchanged, so the rest of the image is still processed.
    """
    # NOTE(review): assumes helper.cropped_faces are BGR uint8 crops — confirm.
    face_t = (cropped_face.astype(np.float32) / 255.0) * 2.0 - 1.0
    face_t = np.transpose(
        np.expand_dims(np.ascontiguousarray(face_t[..., ::-1]), axis=0),
        (0, 3, 1, 2),
    )
    try:
        outputs = helper.rs_sessison.run(helper.rs_output, {helper.rs_input: face_t})
        restored = (outputs[0].squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
        # RGB -> BGR to match the OpenCV image the face is pasted into.
        return np.clip(restored[..., ::-1], 0, 255).astype(np.uint8)
    except Exception as error:
        print(f'\tFailed inference for CodeFormer: {error}')
        # Fall back to the untouched crop. It is already in BGR order, so it must
        # not be channel-flipped (the old fallback rebuilt it from the RGB tensor
        # and pasted a red/blue-swapped face).
        return np.clip(cropped_face, 0, 255).astype(np.uint8)


def full_image(img_path, restore_helper=restore_helper):
    """Detect, restore and paste back every face in an image.

    Pipeline: detect 5-point landmarks for all faces, align/warp each face,
    restore each crop with CodeFormer, upsample the background, then paste the
    restored faces into the upsampled background.

    Args:
        img_path: Path to the input image, read with ``cv2.imread``.
        restore_helper: The ``RestoreHelper`` to use; it is reset with
            ``clean_all()`` first, so it must not be shared by concurrent calls.

    Returns:
        Whatever ``restore_helper.paste_faces_to_input_image`` returns
        (the final restored image).

    Raises:
        ValueError: If ``img_path`` cannot be read as an image.
    """
    restore_helper.clean_all()
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Failed to read image: {img_path}")
    restore_helper.read_image(img)

    # Landmarks for every face (not only the center one), then align and warp each face.
    restore_helper.get_face_landmarks_5(
        only_center_face=False, resize=640, eye_dist_threshold=5)
    restore_helper.align_warp_face()

    for cropped_face in restore_helper.cropped_faces:
        restored_face = _restore_cropped_face(restore_helper, cropped_face)
        restore_helper.add_restored_face(restored_face, cropped_face)

    # Background upsampling (only RealESRGAN is supported).
    bg_img = restore_helper.background_upsampling(img)
    restore_helper.get_inverse_affine(None)
    return restore_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
def colorize_image(input_img_path: str, model_name: str, progress=gr.Progress()):
    """Run the selected restoration mode on an image and save the result as a JPEG.

    The name is historical: this performs face restoration, not colorization.

    Args:
        input_img_path: Path of the uploaded image.
        model_name: ``"Face"`` restores the whole image as a single face via
            ``face``; any other value runs the full detect/align/paste
            pipeline in ``full_image``.
        progress: Gradio progress tracker, passed through from the click handler.

    Returns:
        Path of the written JPEG in the system temp directory.

    Raises:
        gr.Error: If no image path is given or the result cannot be written.
    """
    if not input_img_path:
        raise gr.Error("未上传图片")
    # Load the image (the actual read happens inside face()/full_image()).
    progress(0.3, desc="加载图像...")
    # Dispatch on the selected mode.
    if model_name == "Face":
        out = face(input_img_path, session=restore_helper.rs_sessison)
    else:
        out = full_image(input_img_path, restore_helper=restore_helper)
    progress(0.9, desc="保存结果...")
    # NOTE(review): every request writes to this same fixed filename, so a later
    # request overwrites an earlier result. This is presumably fine while Gradio
    # runs this event one at a time; revisit if concurrency is raised.
    output_path = os.path.join(tempfile.gettempdir(), "restore_output.jpg")
    # cv2.imwrite reports failure through its return value rather than raising.
    # Without this check, a stale result from an earlier request (or a missing
    # file) would be returned to the UI.
    if out is None or not cv2.imwrite(output_path, out):
        raise gr.Error("保存结果失败")
    progress(1.0, desc="完成!")
    return output_path
# ==============================
# Gradio 界面
# ==============================
# Custom CSS passed to demo.launch(): sets the page font stack and styles the
# mode selector (elements with the "model-buttons" class) as a row of flat,
# equal-width (flex: 1) buttons with a hover colour and a green selected state.
# NOTE(review): the "input:checked + label" rule only matches if Gradio renders
# the radio <input> as the label's preceding sibling; confirm against the
# Gradio version in use, otherwise the selected highlight never shows.
custom_css = """
body, .gradio-container {
font-family: 'Microsoft YaHei', 'PingFang SC', 'Helvetica Neue', Arial, sans-serif;
}
.model-buttons .wrap {
display: flex;
gap: 10px;
}
.model-buttons .wrap label {
background-color: #f0f0f0;
padding: 10px 20px;
border-radius: 8px;
cursor: pointer;
text-align: center;
font-weight: 600;
border: 2px solid transparent;
flex: 1;
}
.model-buttons .wrap label:hover {
background-color: #e0e0e0;
}
.model-buttons .wrap input[type="radio"]:checked + label {
background-color: #4CAF50;
color: white;
border-color: #45a049;
}
"""
with gr.Blocks(title="人脸修复工具") as demo:
    gr.Markdown("## 🎨 人脸修复演示DEMO")
    with gr.Row(equal_height=True):
        # Left column: input area (image upload + mode selector + run button).
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📤 输入")
            input_image = gr.Image(
                type="filepath",
                label="上传图片",
                sources=["upload"],
                height=300,
            )
            gr.Markdown("### 🔧 选择修复模式")
            model_choice = gr.Radio(
                choices=["Face", "Full image"],
                value="Face",
                label=None,
                elem_classes="model-buttons",
            )
            run_btn = gr.Button("🚀 开始修复", variant="primary")
        # Right column: output area (preview + downloadable file).
        with gr.Column(scale=1, min_width=600):
            gr.Markdown("### 🖼️ 修复结果")
            output_image = gr.Image(
                label="修复后图片",
                interactive=False,
                height=600,
            )
            download_btn = gr.File(label="📥 下载修复图片")

    # Event wiring.
    def on_colorize(img_path, model, progress=gr.Progress()):
        """Click handler: restore the uploaded image in the selected mode.

        Returns the result path twice, once for the preview image and once for
        the download file component.

        Raises:
            gr.Error: With a user-facing message on any failure.
        """
        import traceback

        if img_path is None:
            raise gr.Error("请先上传图片!")
        try:
            result_path = colorize_image(img_path, model, progress=progress)
        except gr.Error:
            # Already a user-facing message; re-raise it as-is instead of
            # wrapping it in a second "处理失败: ..." prefix.
            raise
        except Exception as e:
            # Keep the full traceback in the server log; the UI only gets the message.
            traceback.print_exc()
            raise gr.Error(f"处理失败: {str(e)}") from e
        return result_path, result_path

    run_btn.click(
        fn=on_colorize,
        inputs=[input_image, model_choice],
        outputs=[output_image, download_btn],
    )
# 启动
if __name__ == "__main__":
    # Bind to 0.0.0.0 so the demo is reachable from other hosts on port 7860;
    # theme and custom CSS are supplied at launch time.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        theme=gr.themes.Soft(),
        css=custom_css,
    )
    demo.launch(**launch_options)