# RADAR-demo / app.py
import gradio as gr
import torch
import torchvision
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import io
import os
# RADAR model components and preprocessing utilities
from models.feature_extractor import FeatureExtractor, FeatureExtractorDepth
from models.projector import SiameseProjector
from models.fuser import DoubleCrossAttentionFusion
from loaders.loader_utils import SquarePad
# Configuration
CHECKPOINT_PATH = './checkpoints'
MODEL_LABEL = 'multimodal_15k_10inp'
EPOCHS = 120
BATCH_SIZE = 4
IMAGE_SIZE = 896
# Load models
print("Loading models...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model_name = f'{MODEL_LABEL}_{EPOCHS}ep_{BATCH_SIZE}bs'
# Preprocessing: pad to a square (avoids aspect-ratio distortion), resize with
# bicubic interpolation, normalize with ImageNet statistics, add a batch dim
rgb_transform = torchvision.transforms.Compose([
    SquarePad(),
    torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE),
                                  interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
    torchvision.transforms.Lambda(lambda img: img.unsqueeze(0)),  # (3, H, W) -> (1, 3, H, W)
])
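# Sanity-check sketch (assumption: SquarePad pads the shorter side to a square,
# so any RGB input comes out as a fixed-size batched tensor):
#
#     t = rgb_transform(Image.new('RGB', (640, 480)))
#     assert t.shape == (1, 3, IMAGE_SIZE, IMAGE_SIZE)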
fe_rgb = FeatureExtractor().to(device).eval()
fe_depth = FeatureExtractorDepth().to(device).eval()
fusion_block = DoubleCrossAttentionFusion(hidden_dim=fe_rgb.embed_dim).to(device)
fusion_block.load_state_dict(torch.load(
os.path.join(CHECKPOINT_PATH, f'fusion_block_{model_name}.pth'),
weights_only=False,
map_location=device
))
fusion_block.eval()
projector = SiameseProjector(inner_features=fe_rgb.embed_dim).to(device)
projector.load_state_dict(torch.load(
os.path.join(CHECKPOINT_PATH, f'projector_{model_name}.pth'),
weights_only=False,
map_location=device
))
projector.eval()
print("Models loaded successfully!")
def detect_manipulation(image):
    """Run RADAR on an input image and return the manipulation heatmap as a numpy array."""
    if image is None:
        return None
# Convert to PIL
if isinstance(image, np.ndarray):
rgb_input = Image.fromarray(image.astype('uint8')).convert('RGB')
else:
rgb_input = image.convert('RGB')
    original_size = rgb_input.size  # PIL size: (width, height)
# Transform and process
rgb = rgb_transform(rgb_input)
rgb = rgb.to(device)
    with torch.no_grad():
        rgb_feat = fe_rgb(rgb)                           # RGB-stream features
        depth_feat = fe_depth(rgb)                       # depth-stream features from the same RGB input
        fused_feat = fusion_block(rgb_feat, depth_feat)  # double cross-attention fusion
        _, segmentation_map = projector(fused_feat)
        segmentation_map = torch.sigmoid(segmentation_map)  # per-pixel manipulation probability
    # Undo SquarePad: upsample to the padded square (side = longest edge of the
    # original image), then center-crop back to the original height/width
    segmentation_map = torch.nn.functional.interpolate(
        segmentation_map,
        size=[max(original_size), max(original_size)],
        mode='bilinear',
        align_corners=False
    ).squeeze()
    segmentation_map = torchvision.transforms.functional.center_crop(
        segmentation_map,
        original_size[::-1]  # PIL size is (W, H); center_crop expects (H, W)
    )
    heatmap = segmentation_map.cpu().numpy()
    # Render the heatmap at the exact pixel dimensions of the input image
    dpi = 100
    fig_width = original_size[0] / dpi
    fig_height = original_size[1] / dpi
    fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
    ax = fig.add_axes([0, 0, 1, 1])  # full-bleed axes, no margins
    # Pin the color scale to [0, 1] so red always means high manipulation
    # probability rather than the per-image maximum
    ax.imshow(heatmap, cmap='jet', vmin=0, vmax=1)
    ax.axis('off')
    # Convert the rendered figure to a numpy array
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=dpi)
    buf.seek(0)
    result_image = Image.open(buf)
    # bbox_inches='tight' can drift by a few pixels; resize to match exactly
    if result_image.size != original_size:
        result_image = result_image.resize(original_size, Image.LANCZOS)
    result_array = np.array(result_image)
    plt.close(fig)
    return result_array
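# Standalone usage sketch (not executed by the app; 'example.jpg' is a
# hypothetical local file):
#
#     out = detect_manipulation(np.array(Image.open('example.jpg').convert('RGB')))
#     Image.fromarray(out).save('example_heatmap.png')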
# Custom CSS for styling
custom_css = """
.gradio-container {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
}
#title {
text-align: center;
font-size: 2.5em;
font-weight: bold;
margin-bottom: 0.5em;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
#subtitle {
text-align: center;
font-size: 1.2em;
color: #666;
margin-bottom: 1em;
}
#info {
background: #e8f4fd;
border-left: 4px solid #2196F3;
padding: 15px;
border-radius: 5px;
margin-bottom: 20px;
color: #1976D2;
}
"""
# Create interface using Gradio 4.x Blocks
with gr.Blocks(css=custom_css, title="RADAR - Image Manipulation Detection") as demo:
gr.HTML('<h1 id="title">🎯 RADAR</h1>')
gr.HTML('<p id="subtitle">ReliAble iDentification of inpainted AReas</p>')
gr.HTML('''
<div id="info">
<strong>ℹ️ About RADAR:</strong> Upload an image to detect and localize regions
that have been manipulated using diffusion-based inpainting models.
The output shows a heatmap where red areas indicate detected manipulations.
</div>
''')
with gr.Row():
with gr.Column():
input_image = gr.Image(label="Upload Image", type="numpy")
with gr.Column():
output_image = gr.Image(label="Manipulation Heatmap", type="numpy")
submit_btn = gr.Button("🔍 Detect Manipulations", variant="primary")
# Connect the button
submit_btn.click(
fn=detect_manipulation,
inputs=input_image,
outputs=output_image
)
    # Also run automatically when the input image changes (upload or clear;
    # the None guard in detect_manipulation handles the clear case)
input_image.change(
fn=detect_manipulation,
inputs=input_image,
outputs=output_image
)
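    # Concurrency sketch: on a shared GPU, Gradio's request queue keeps
    # simultaneous uploads from running forward passes at the same time.
    # Illustrative setting; uncomment to enable.
    # demo.queue(max_size=16)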
# Launch
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)