import json
import os
from io import BytesIO

import gradio as gr
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image

from unidepth.models import UniDepthV2
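
# The colorbar is rendered off-screen, so select a non-interactive backend
# up front (an assumption that no display is attached, as on a headless Space).
matplotlib.use('Agg')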

# Load the model configuration and initialize the network
def load_model(config_path, model_path, device):
    with open(config_path) as f:
        config = json.load(f)
    model = UniDepthV2(config)
    model.load_state_dict(torch.load(model_path, map_location=device)['model'], strict=True)
    return model.to(device).eval()
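
# A minimal caching sketch: keep the loaded model in memory so the checkpoint
# is not re-read from disk on every request (assumes the host has enough
# memory to keep the model resident).
_model_cache = {}

def get_model(config_path, model_path, device):
    key = (config_path, model_path, device)
    if key not in _model_cache:
        _model_cache[key] = load_model(config_path, model_path, device)
    return _model_cache[key]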

# Inference function
def depth_estimation(image, encoder='vits'):
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Only the ViT-S config and a local checkpoint are wired up here,
        # so the encoder dropdown currently has no effect on which model runs.
        config_path = 'configs/config_v2_vits14.json'
        model_path = 'checkpoint/latest.pth'
        if not os.path.exists(model_path):
            raise gr.Error("Model checkpoint not found at checkpoint/latest.pth. Please upload a valid checkpoint.")
        model = get_model(config_path, model_path, device)
        # Preprocess image: HWC uint8 array -> CHW tensor, as model.infer expects
        rgb = torch.from_numpy(np.array(image)).permute(2, 0, 1).to(device)  # C, H, W
        predictions = model.infer(rgb)
        # Move the prediction to the CPU before converting;
        # calling .numpy() on a CUDA tensor raises an error.
        depth = predictions["depth"].squeeze().cpu().numpy()
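        # Per the UniDepth API, `predictions` also exposes "points" (a metric
        # 3D point map) and "intrinsics" (the estimated camera matrix), which
        # this demo does not use.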
        # Normalize depth to [0, 1] for colormapping (guard against a constant map)
        min_depth = float(depth.min())
        max_depth = float(depth.max())
        depth_normalized = (depth - min_depth) / max(max_depth - min_depth, 1e-8)
        # Apply colormap
        cmap = matplotlib.colormaps['Spectral']
        depth_color = (cmap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
        # Create a figure and axis for the colorbar
        fig, ax = plt.subplots(figsize=(6, 0.4))
        fig.subplots_adjust(bottom=0.5)
        # Create a colorbar keyed to the metric depth range
        norm = matplotlib.colors.Normalize(vmin=min_depth, vmax=max_depth)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        fig.colorbar(sm, cax=ax, orientation='horizontal', label='Depth (meters)')
        # Render the colorbar to an in-memory PNG
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
        plt.close(fig)
        buf.seek(0)
        # Open the colorbar image
        colorbar_img = Image.open(buf)
        # Stack the depth map above the colorbar on a white canvas
        new_height = depth_color.shape[0] + colorbar_img.size[1]
        new_img = Image.new('RGB', (depth_color.shape[1], new_height), (255, 255, 255))
        # Paste the depth image and colorbar
        new_img.paste(Image.fromarray(depth_color), (0, 0))
        new_img.paste(colorbar_img, (0, depth_color.shape[0]))
        return new_img
    except gr.Error:
        raise
    except Exception as e:
        # Surface failures in the UI; returning a string would break
        # the Image output component.
        raise gr.Error(f"Error occurred: {e}")

# Gradio Interface
def main():
    iface = gr.Interface(
        fn=depth_estimation,
        inputs=[
            gr.Image(type="numpy", label="Input Image"),
            gr.Dropdown(choices=['vits', 'vitb', 'vitl', 'vitg'], value='vits', label='Encoder'),
        ],
        outputs=[
            gr.Image(type="pil", label="Predicted Depth")
        ],
        title="Metric Depth Estimation",
        description="Upload an image to get its estimated metric depth map from UniDepth V2.",
    )
    iface.launch()

if __name__ == "__main__":
    main()
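
# Quick local check without the UI (hypothetical test image path):
#   img = np.array(Image.open('test.jpg').convert('RGB'))
#   depth_estimation(img).save('depth.png')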