import gradio as gr
import torch
import torchvision
from PIL import Image
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
import os

# Import your models
from models.feature_extractor import FeatureExtractor, FeatureExtractorDepth
from models.projector import SiameseProjector
from models.fuser import DoubleCrossAttentionFusion
from loaders.loader_utils import SquarePad
# Configuration
CHECKPOINT_PATH = './checkpoints'
MODEL_LABEL = 'multimodal_15k_10inp'
EPOCHS = 120
BATCH_SIZE = 4
IMAGE_SIZE = 896
# Load models
print("Loading models...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model_name = f'{MODEL_LABEL}_{EPOCHS}ep_{BATCH_SIZE}bs'

rgb_transform = torchvision.transforms.Compose([
    SquarePad(),
    torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE),
                                  interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
    torchvision.transforms.Lambda(lambda img: img.unsqueeze(0)),
])
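# Note (assumption, for illustration): SquarePad presumably pads the input to a square
# before the resize, which is what the centre-crop in detect_manipulation below undoes.
# After ToTensor, Normalize, and the unsqueeze Lambda, each image becomes a
# [1, 3, IMAGE_SIZE, IMAGE_SIZE] tensor ready for the feature extractors.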
fe_rgb = FeatureExtractor().to(device).eval()
fe_depth = FeatureExtractorDepth().to(device).eval()

fusion_block = DoubleCrossAttentionFusion(hidden_dim=fe_rgb.embed_dim).to(device)
fusion_block.load_state_dict(torch.load(
    os.path.join(CHECKPOINT_PATH, f'fusion_block_{model_name}.pth'),
    weights_only=False,
    map_location=device
))
fusion_block.eval()

projector = SiameseProjector(inner_features=fe_rgb.embed_dim).to(device)
projector.load_state_dict(torch.load(
    os.path.join(CHECKPOINT_PATH, f'projector_{model_name}.pth'),
    weights_only=False,
    map_location=device
))
projector.eval()

print("Models loaded successfully!")
def detect_manipulation(image):
    """Process an image and return a manipulation heatmap."""
    if image is None:
        return None

    # Convert to PIL
    if isinstance(image, np.ndarray):
        rgb_input = Image.fromarray(image.astype('uint8')).convert('RGB')
    else:
        rgb_input = image.convert('RGB')

    original_size = rgb_input.size  # (width, height)

    # Transform and process
    rgb = rgb_transform(rgb_input)
    rgb = rgb.to(device)

    with torch.no_grad():
        rgb_feat = fe_rgb(rgb)
        depth_feat = fe_depth(rgb)
        fused_feat = fusion_block(rgb_feat, depth_feat)
        _, segmentation_map = projector(fused_feat)
        segmentation_map = torch.sigmoid(segmentation_map)

    # Undo the square padding: resize to the padded square size, then
    # centre-crop back to the original (height, width)
    segmentation_map = torch.nn.functional.interpolate(
        segmentation_map,
        size=[max(original_size), max(original_size)],
        mode='bilinear'
    ).squeeze()
    segmentation_map = torchvision.transforms.functional.center_crop(
        segmentation_map,
        original_size[::-1]
    )
    heatmap = segmentation_map.cpu().detach().numpy()

    # Create a visualization whose pixel dimensions match the input image
    dpi = 100
    fig_height = original_size[1] / dpi
    fig_width = original_size[0] / dpi

    fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
    ax = fig.add_axes([0, 0, 1, 1])  # No margins
    ax.imshow(heatmap, cmap='jet')
    ax.axis('off')

    # Convert the figure to a numpy array
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=dpi)
    buf.seek(0)
    result_image = Image.open(buf)

    # Ensure an exact size match by resizing if needed
    if result_image.size != original_size:
        result_image = result_image.resize(original_size, Image.LANCZOS)

    result_array = np.array(result_image)
    plt.close(fig)

    return result_array
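# Optional headless smoke test (illustrative sketch, not part of the original app).
# 'sample.jpg' and 'heatmap.png' are placeholder paths; call this manually to
# exercise the detection pipeline without starting the Gradio UI.
def run_smoke_test(path='sample.jpg', out_path='heatmap.png'):
    img = np.array(Image.open(path).convert('RGB'))
    heatmap = detect_manipulation(img)
    if heatmap is not None:
        Image.fromarray(heatmap).save(out_path)
        print(f"Saved heatmap to {out_path}")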
# Custom CSS for styling
custom_css = """
.gradio-container {
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
}
#title {
    text-align: center;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 0.5em;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
}
#subtitle {
    text-align: center;
    font-size: 1.2em;
    color: #666;
    margin-bottom: 1em;
}
#info {
    background: #e8f4fd;
    border-left: 4px solid #2196F3;
    padding: 15px;
    border-radius: 5px;
    margin-bottom: 20px;
    color: #1976D2;
}
"""
# Create interface using Gradio 4.x Blocks
with gr.Blocks(css=custom_css, title="RADAR - Image Manipulation Detection") as demo:
    gr.HTML('<h1 id="title">🎯 RADAR</h1>')
    gr.HTML('<p id="subtitle">ReliAble iDentification of inpainted AReas</p>')
    gr.HTML('''
    <div id="info">
        <strong>ℹ️ About RADAR:</strong> Upload an image to detect and localize regions
        that have been manipulated using diffusion-based inpainting models.
        The output shows a heatmap where red areas indicate detected manipulations.
    </div>
    ''')

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="numpy")
        with gr.Column():
            output_image = gr.Image(label="Manipulation Heatmap", type="numpy")

    submit_btn = gr.Button("🔍 Detect Manipulations", variant="primary")

    # Connect the button
    submit_btn.click(
        fn=detect_manipulation,
        inputs=input_image,
        outputs=output_image
    )

    # Also trigger on image upload
    input_image.change(
        fn=detect_manipulation,
        inputs=input_image,
        outputs=output_image
    )
# Launch
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)