import gradio as gr
import os
import tempfile
import numpy as np
import axengine as axe
import cv2
from utils.restoration_helper import RestoreHelper
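# Note (assumption): axengine sessions expose an onnxruntime-style API
# (get_inputs / get_outputs / run), which is how the .axmodel sessions are
# driven in face() and inside RestoreHelper below.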

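# Shared RestoreHelper instance: YOLOv5-face detection, CodeFormer face
# restoration and RealESRGAN x2 background upsampling, all loaded from
# .axmodel files (relative paths are resolved against the working directory).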
restore_helper = RestoreHelper(
    upscale_factor=1,
    face_size=512,
    crop_ratio=(1, 1),
    det_model="../model/yolov5l-face.axmodel",
    res_model="../model/codeformer.axmodel",
    bg_model="../model/realesrgan-x2.axmodel",
    save_ext='png',
    use_parse=True,
)

def face(img_path, session):
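    """Run the face-restoration session on the whole image in one pass.

    The image is resized to the model's 512x512 input, normalized, restored,
    then de-normalized and resized back to the original resolution.
    Returns a BGR uint8 image of the same size as the input.
    """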

    output_names = [x.name for x in session.get_outputs()]
    input_name = session.get_inputs()[0].name

    ori_image = cv2.imread(img_path)
    h, w = ori_image.shape[:2]

    # Preprocess: resize to the model's 512x512 input, BGR -> RGB, scale to [0, 1]
    image = cv2.resize(ori_image, (512, 512))
    image = (image[..., ::-1] / 255.0).astype(np.float32)

    # Normalize to [-1, 1]
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    image = ((image - mean) / std).astype(np.float32)

    # HWC -> NCHW with a leading batch dimension
    img = np.transpose(np.expand_dims(np.ascontiguousarray(image), axis=0), (0, 3, 1, 2))

    # Run the restoration model
    sr = session.run(output_names, {input_name: img})

    # Postprocess: NCHW -> HWC, de-normalize, RGB -> BGR, resize back to the original size
    sr = np.transpose(sr[0].squeeze(0), (1, 2, 0))
    sr = (sr * std + mean).astype(np.float32)
    ndarr = np.clip(sr * 255.0, 0, 255).astype(np.uint8)
    out_image = cv2.resize(ndarr[..., ::-1], (w, h))

    return out_image

def full_image(img_path, restore_helper=restore_helper):
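    """Detect, restore, and re-composite every face in the image.

    RestoreHelper detects face landmarks, aligns and crops each face, restores
    each crop with the CodeFormer session, upsamples the background with
    RealESRGAN, and pastes the restored faces back onto the upsampled
    background. Returns a BGR uint8 image.
    """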

    restore_helper.clean_all()
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)

    restore_helper.read_image(img)
    # get face landmarks for each face
    num_det_faces = restore_helper.get_face_landmarks_5(
        only_center_face=False, resize=640, eye_dist_threshold=5)
    # align and warp each face
    restore_helper.align_warp_face()
    # face restoration for each cropped face
    for idx, cropped_face in enumerate(restore_helper.cropped_faces):
        # Prepare data: scale to [-1, 1], BGR -> RGB, HWC -> NCHW with batch dim
        cropped_face_t = (cropped_face.astype(np.float32) / 255.0) * 2.0 - 1.0
        cropped_face_t = np.transpose(
            np.expand_dims(np.ascontiguousarray(cropped_face_t[..., ::-1]), axis=0),
            (0, 3, 1, 2)
        )

        try:
            ort_outs = restore_helper.rs_sessison.run(
                restore_helper.rs_output,
                {restore_helper.rs_input: cropped_face_t}
            )
            restored_face = ort_outs[0]
            restored_face = (restored_face.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
            restored_face = np.clip(restored_face[..., ::-1], 0, 255).astype(np.uint8)
        except Exception as error:
            print(f'\tFailed inference for CodeFormer: {error}')
            # Fall back to the unrestored crop: undo the normalization and flip RGB back to BGR
            restored_face = (cropped_face_t.squeeze().transpose(1, 2, 0) * 0.5 + 0.5) * 255
            restored_face = np.clip(restored_face[..., ::-1], 0, 255).astype(np.uint8)

        restored_face = restored_face.astype('uint8')
        restore_helper.add_restored_face(restored_face, cropped_face)

    # Upsample the background (currently only RealESRGAN is supported)
    bg_img = restore_helper.background_upsampling(img)
    restore_helper.get_inverse_affine(None)
    # paste each restored face to the input image
    restored_img = restore_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False)

    return restored_img
    

def colorize_image(input_img_path: str, model_name: str, progress=gr.Progress()):
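    """Restore the uploaded image with the selected mode and return the path
    of the result written to the system temp directory."""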
    if not input_img_path:
        raise gr.Error("No image uploaded")

    progress(0.3, desc="Loading image...")

    # Dispatch to the selected restoration mode
    if model_name == "Face":
        out = face(input_img_path, session=restore_helper.rs_sessison)
    else:
        out = full_image(input_img_path, restore_helper=restore_helper)

    progress(0.9, desc="Saving result...")

    # Write the BGR result to a temporary file (cv2.imwrite expects BGR)
    output_path = os.path.join(tempfile.gettempdir(), "restore_output.jpg")
    cv2.imwrite(output_path, out)

    progress(1.0, desc="Done!")
    return output_path


# ==============================
# Gradio UI
# ==============================
custom_css = """
body, .gradio-container {
    font-family: 'Microsoft YaHei', 'PingFang SC', 'Helvetica Neue', Arial, sans-serif;
}
.model-buttons .wrap {
    display: flex;
    gap: 10px;
}
.model-buttons .wrap label {
    background-color: #f0f0f0;
    padding: 10px 20px;
    border-radius: 8px;
    cursor: pointer;
    text-align: center;
    font-weight: 600;
    border: 2px solid transparent;
    flex: 1;
}
.model-buttons .wrap label:hover {
    background-color: #e0e0e0;
}
.model-buttons .wrap input[type="radio"]:checked + label {
    background-color: #4CAF50;
    color: white;
    border-color: #45a049;
}
"""

with gr.Blocks(title="Face Restoration Tool", theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("## 🎨 Face Restoration Demo")

    with gr.Row(equal_height=True):
        # Left column: input
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📤 Input")
            input_image = gr.Image(
                type="filepath",
                label="Upload image",
                sources=["upload"],
                height=300
            )
            
            gr.Markdown("### 🔧 Restoration mode")
            model_choice = gr.Radio(
                choices=["Face", "Full image"],
                value="Face",
                label=None,
                elem_classes="model-buttons"
            )
            
            run_btn = gr.Button("🚀 Restore", variant="primary")

        # Right column: output
        with gr.Column(scale=1, min_width=600):
            gr.Markdown("### 🖼️ Result")
            output_image = gr.Image(
                label="Restored image",
                interactive=False,
                height=600
            )
            download_btn = gr.File(label="📥 Download restored image")

    # Wire up the run button
    def on_colorize(img_path, model, progress=gr.Progress()):
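        """Validate the upload, run colorize_image, and return the result path
        for both the preview image and the download widget."""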
        if img_path is None:
            raise gr.Error("Please upload an image first!")
        try:
            result_path = colorize_image(img_path, model, progress=progress)
            return result_path, result_path
        except Exception as e:
            raise gr.Error(f"Processing failed: {e}")

    run_btn.click(
        fn=on_colorize,
        inputs=[input_image, model_choice],
        outputs=[output_image, download_btn]
    )

# Launch the Gradio server
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
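# Run this file directly (e.g. `python3 <this_script>.py`; the script name is
# illustrative) and open http://<device-ip>:7860 in a browser to use the demo.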