# File size: 6,624 Bytes
# Revision: 442b121
import gradio as gr
import os
import tempfile
import numpy as np
import axengine as axe
import cv2
from utils.restoration_helper import RestoreHelper
# Directory holding the .axmodel files: "../model" relative to this script.
# Resolved from the script's own location rather than the current working
# directory, so the demo starts no matter which directory it is launched from.
_MODEL_DIR = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "model")
)

# Shared restoration pipeline used by both UI modes: face detector
# (yolov5l-face), face restorer (CodeFormer) and background upsampler
# (RealESRGAN x2). The meaning of each keyword argument is defined by
# utils.restoration_helper.RestoreHelper.
restore_helper = RestoreHelper(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model=os.path.join(_MODEL_DIR, "yolov5l-face.axmodel"),
    res_model=os.path.join(_MODEL_DIR, "codeformer.axmodel"),
    bg_model=os.path.join(_MODEL_DIR, "realesrgan-x2.axmodel"),
    save_ext='png',
    use_parse=True,
)
def face(img_path, session):
    """Restore a whole image as a single face with the given restoration session.

    The image is resized to 512x512, converted to an RGB float32 tensor in
    [-1, 1] with layout (1, 3, 512, 512), run through ``session``, and the
    first output is mapped back to a BGR uint8 image at the original size.

    Args:
        img_path: Path of the image file, read with ``cv2.imread``.
        session: Inference session exposing ``get_inputs()``, ``get_outputs()``
            and ``run(output_names, feed)``. The input is fed to
            ``get_inputs()[0]``; assumed to accept a 1x3x512x512 float32
            tensor — TODO confirm against the model.

    Returns:
        The restored image as a BGR ``uint8`` array with the same height and
        width as the input image.

    Raises:
        ValueError: If ``img_path`` cannot be read as an image.
    """
    output_names = [x.name for x in session.get_outputs()]
    input_name = session.get_inputs()[0].name

    ori_image = cv2.imread(img_path)
    if ori_image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Cannot read image: {img_path}")
    h, w = ori_image.shape[:2]

    # Normalisation constants: maps [0, 1] to [-1, 1] and back.
    mean = 0.5
    std = 0.5

    # BGR uint8 -> RGB float32 in [-1, 1].
    image = cv2.resize(ori_image, (512, 512))
    image = (image[..., ::-1] / 255.0).astype(np.float32)
    image = ((image - mean) / std).astype(np.float32)
    # HWC -> NCHW with a batch of 1, materialised as one contiguous buffer.
    img = np.ascontiguousarray(np.transpose(image, (2, 0, 1))[np.newaxis])

    sr = session.run(output_names, {input_name: img})

    # First output: drop the batch axis, CHW -> HWC, [-1, 1] -> [0, 1].
    sr = np.transpose(sr[0].squeeze(0), (1, 2, 0))
    sr = (sr * std + mean).astype(np.float32)
    ndarr = np.clip(sr * 255.0, 0, 255.0).astype(np.uint8)

    # RGB -> BGR; copy the reversed view into a contiguous array for OpenCV.
    bgr = np.ascontiguousarray(ndarr[..., ::-1])
    return cv2.resize(bgr, (w, h))
def full_image(img_path, restore_helper=restore_helper):
    """Restore every detected face in an image and paste them back.

    Steps: detect faces and 5-point landmarks, align/warp each face crop,
    restore each crop with the CodeFormer session held by ``restore_helper``,
    upsample the background, then paste the restored faces onto it.

    If inference fails for a face, the error is printed and the unrestored
    crop is used instead, so one bad face does not abort the whole image.

    Args:
        img_path: Path of the image file, read with ``cv2.imread``.
        restore_helper: Pipeline object; defaults to the module-level one.
            Its state is reset with ``clean_all()`` at the start of each call.

    Returns:
        Whatever ``restore_helper.paste_faces_to_input_image`` returns
        (presumably a BGR uint8 image — confirm in RestoreHelper).

    Raises:
        ValueError: If ``img_path`` cannot be read as an image.
    """
    restore_helper.clean_all()
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Cannot read image: {img_path}")
    restore_helper.read_image(img)

    # Detect faces / 5-point landmarks; the returned face count is not needed.
    restore_helper.get_face_landmarks_5(
        only_center_face=False, resize=640, eye_dist_threshold=5)
    # Align and warp each detected face into restore_helper.cropped_faces.
    restore_helper.align_warp_face()

    for cropped_face in restore_helper.cropped_faces:
        # BGR crop in [0, 255] -> RGB float32 in [-1, 1], NCHW batch of 1.
        face_input = (cropped_face.astype(np.float32) / 255.0) * 2.0 - 1.0
        face_input = np.ascontiguousarray(
            np.transpose(face_input[..., ::-1], (2, 0, 1))[np.newaxis]
        )
        try:
            ort_outs = restore_helper.rs_sessison.run(
                restore_helper.rs_output,
                {restore_helper.rs_input: face_input},
            )
            # [-1, 1] CHW RGB -> [0, 255] HWC, then RGB -> BGR.
            restored_face = (ort_outs[0].squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
            restored_face = np.clip(restored_face[..., ::-1], 0, 255).astype(np.uint8)
        except Exception as error:
            # Best effort: fall back to the original crop. It is already in
            # BGR order, so no channel flip is needed (the model input above
            # is RGB and must not be reused here).
            print(f'\tFailed inference for CodeFormer: {error}')
            restored_face = np.clip(cropped_face, 0, 255).astype(np.uint8)
        restore_helper.add_restored_face(restored_face, cropped_face)

    # Upsample the background (only RealESRGAN is supported for this).
    bg_img = restore_helper.background_upsampling(img)
    restore_helper.get_inverse_affine(None)
    # Paste each restored face back onto the upsampled background.
    restored_img = restore_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)
    return restored_img
def colorize_image(input_img_path: str, model_name: str, progress=gr.Progress()):
    """Restore the uploaded image and write the result to a new JPEG file.

    Despite its name, this performs face restoration: ``model_name == "Face"``
    runs ``face`` on the whole image; any other value runs ``full_image``.

    Args:
        input_img_path: Path of the uploaded image.
        model_name: Mode selected in the UI ("Face" or "Full image").
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        Path of a freshly created ``.jpg`` file in the system temp directory.
        The file is not deleted here.

    Raises:
        gr.Error: If no image path was given.
        RuntimeError: If OpenCV fails to write the result file.
    """
    if not input_img_path:
        raise gr.Error("未上传图片")
    progress(0.3, desc="加载图像...")

    if model_name == "Face":
        out = face(input_img_path, session=restore_helper.rs_sessison)
    else:
        out = full_image(input_img_path, restore_helper=restore_helper)

    progress(0.9, desc="保存结果...")
    # One unique file per request: a fixed filename would let concurrent
    # users overwrite, and be served, each other's results.
    fd, output_path = tempfile.mkstemp(prefix="restore_output_", suffix=".jpg")
    os.close(fd)
    # cv2.imwrite signals failure by returning False rather than raising.
    if not cv2.imwrite(output_path, out):
        os.remove(output_path)
        raise RuntimeError(f"Failed to write result image: {output_path}")

    progress(1.0, desc="完成!")
    return output_path
# ==============================
# Gradio UI
# ==============================
# Extra CSS passed to demo.launch(css=...). It sets the page font stack and
# styles the options of the Radio created with elem_classes="model-buttons"
# as side-by-side, equally wide buttons.
# NOTE(review): the `input[type="radio"]:checked + label` rule assumes the
# <label> is the radio <input>'s next sibling. Confirm this against the DOM
# of the installed Gradio version; otherwise the "selected" highlight never
# applies.
custom_css = """
body, .gradio-container {
font-family: 'Microsoft YaHei', 'PingFang SC', 'Helvetica Neue', Arial, sans-serif;
}
.model-buttons .wrap {
display: flex;
gap: 10px;
}
.model-buttons .wrap label {
background-color: #f0f0f0;
padding: 10px 20px;
border-radius: 8px;
cursor: pointer;
text-align: center;
font-weight: 600;
border: 2px solid transparent;
flex: 1;
}
.model-buttons .wrap label:hover {
background-color: #e0e0e0;
}
.model-buttons .wrap input[type="radio"]:checked + label {
background-color: #4CAF50;
color: white;
border-color: #45a049;
}
"""
with gr.Blocks(title="人脸修复工具") as demo:
    gr.Markdown("## 🎨 人脸修复演示DEMO")
    with gr.Row(equal_height=True):
        # Left column: image upload and mode selection.
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📤 输入")
            input_image = gr.Image(
                type="filepath",
                label="上传图片",
                sources=["upload"],
                height=300,
            )
            gr.Markdown("### 🔧 选择修复模式")
            model_choice = gr.Radio(
                choices=["Face", "Full image"],
                value="Face",
                label=None,
                elem_classes="model-buttons",
            )
            run_btn = gr.Button("🚀 开始修复", variant="primary")
        # Right column: restored image preview and download.
        with gr.Column(scale=1, min_width=600):
            gr.Markdown("### 🖼️ 修复结果")
            output_image = gr.Image(
                label="修复后图片",
                interactive=False,
                height=600,
            )
            download_btn = gr.File(label="📥 下载修复图片")

    # Event wiring.
    def on_colorize(img_path, model, progress=gr.Progress()):
        """Click handler: restore the image and return its path twice.

        The same result path feeds both the preview image and the download
        component. Errors are surfaced to the user as ``gr.Error``.
        """
        if img_path is None:
            raise gr.Error("请先上传图片!")
        try:
            result_path = colorize_image(img_path, model, progress=progress)
        except gr.Error:
            # Already a user-facing message; do not wrap it a second time.
            raise
        except Exception as e:
            # Keep the original exception as the cause for server-side logs.
            raise gr.Error(f"处理失败: {str(e)}") from e
        return result_path, result_path

    run_btn.click(
        fn=on_colorize,
        inputs=[input_image, model_choice],
        outputs=[output_image, download_btn],
    )
# Launch the web server, listening on all network interfaces on port 7860.
# NOTE(review): theme/css are passed to launch() rather than gr.Blocks(); this
# requires a Gradio version whose launch() accepts them — confirm the pinned
# version.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft(), css=custom_css)