# Hugging Face Space app (running on ZeroGPU)
import os
import sys
import subprocess

# Install runtime dependencies at startup (the Space image does not ship them).
def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install("torchvision")
install("loguru")
install("imageio")
install("modelscope")
install("einops")
install("safetensors")
install("transformers")
install("ftfy")
install("accelerate")
install("sentencepiece")
install("spaces")
install("opencv-python")
install("trimesh")
install("gradio_litmodel3d")
install("open3d")
import gradio as gr
import numpy as np
import torch
from PIL import Image
from loguru import logger
from tqdm import tqdm
from tools.common_utils import save_video
from dkt.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
# Work around gradio_client failing on non-dict JSON Schemas when building API info.
try:
    import gradio_client.utils as _gc_utils
    if hasattr(_gc_utils, "get_type"):
        _orig_get_type = _gc_utils.get_type
        def _get_type_safe(schema):
            if not isinstance(schema, dict):
                return "Any"
            return _orig_get_type(schema)
        _gc_utils.get_type = _get_type_safe
except Exception:
    pass
# Additional guard: handle boolean JSON Schemas and parsing errors
try:
    import gradio_client.utils as _gc_utils
    # Wrap the internal _json_schema_to_python_type if present
    if hasattr(_gc_utils, "_json_schema_to_python_type"):
        _orig_internal = _gc_utils._json_schema_to_python_type
        def _json_schema_to_python_type_safe(schema, defs=None):
            if isinstance(schema, bool):
                return "Any"
            try:
                return _orig_internal(schema, defs)
            except Exception:
                return "Any"
        _gc_utils._json_schema_to_python_type = _json_schema_to_python_type_safe
    # Also wrap the public json_schema_to_python_type to be extra defensive
    if hasattr(_gc_utils, "json_schema_to_python_type"):
        _orig_public = _gc_utils.json_schema_to_python_type
        def json_schema_to_python_type_safe(schema):
            try:
                return _orig_public(schema)
            except Exception:
                return "Any"
        _gc_utils.json_schema_to_python_type = json_schema_to_python_type_safe
except Exception:
    pass
import cv2
import copy
import trimesh
from gradio_litmodel3d import LitModel3D
from os.path import join
from tools.depth2pcd import depth2pcd

try:
    from moge.model.v2 import MoGeModel
except ImportError:
    os.system('pip install git+https://github.com/microsoft/MoGe.git -i https://pypi.org/simple/ --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host files.pythonhosted.org')
    from moge.model.v2 import MoGeModel

from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map
import glob
import datetime
import shutil
import tempfile
import spaces
PIPE_1_3B = None
MOGE_MODULE = None
#* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors
PROMPT = 'depth'
NEGATIVE_PROMPT = ''
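# resize_frame: bicubic-resize a PIL frame to the model's working resolution.
# Values round-trip through a float tensor in [0, 1] and come back as uint8 in [0, 255].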
def resize_frame(frame, height, width):
    frame = np.array(frame)
    frame = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    frame = torch.nn.functional.interpolate(frame, (height, width), mode="bicubic", align_corners=False, antialias=True)
    frame = (frame.squeeze(0).permute(1, 2, 0).clamp(0, 1) * 255).byte().numpy()
    frame = Image.fromarray(frame)
    return frame
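# pmap_to_glb: wrap an (N, 3) point map and per-point colors into a trimesh.Scene.
# The [-1, -1, 1] sign flip converts camera-space points into the orientation expected
# by the glTF viewer (assumption; adjust if exported point clouds appear mirrored).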
def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
    pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
    pts_rgb = frame[valid_mask]
    # Initialize a 3D scene
    scene_3d = trimesh.Scene()
    # Add point cloud data to the scene
    point_cloud_data = trimesh.PointCloud(
        vertices=pts_3d, colors=pts_rgb
    )
    scene_3d.add_geometry(point_cloud_data)
    return scene_3d
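# create_simple_glb_from_pointcloud: export an (N, 3) point array (plus optional (N, 3)
# colors) to a .glb file that the LitModel3D components can display. Returns True on
# success so callers can fall back gracefully.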
def create_simple_glb_from_pointcloud(points, colors, glb_filename):
    try:
        if len(points) == 0:
            logger.warning(f"No valid points to create GLB for {glb_filename}")
            return False
        if colors is not None:
            # logger.info(f"Adding colors to GLB: shape={colors.shape}, range=[{colors.min():.3f}, {colors.max():.3f}]")
            pts_rgb = colors
        else:
            logger.info("No colors provided, adding default white colors")
            pts_rgb = np.ones((len(points), 3))
        valid_mask = np.ones(len(points), dtype=bool)
        scene_3d = pmap_to_glb(points, valid_mask, pts_rgb)
        scene_3d.export(glb_filename)
        # logger.info(f"Saved GLB file using trimesh: {glb_filename}")
        return True
    except Exception as e:
        logger.error(f"Error creating GLB from pointcloud using trimesh: {str(e)}")
        return False
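# extract_frames_from_video_file: decode a video into a list of RGB PIL frames with
# OpenCV; falls back to 15 fps when the container reports no frame rate.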
def extract_frames_from_video_file(video_path):
    try:
        cap = cv2.VideoCapture(video_path)
        frames = []
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 15.0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_rgb = Image.fromarray(frame_rgb)
            frames.append(frame_rgb)
        cap.release()
        return frames, fps
    except Exception as e:
        logger.error(f"Error extracting frames from {video_path}: {str(e)}")
        return [], 15.0
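# load_moge_model: lazily load the MoGe-2 geometry model once and cache it in the
# MOGE_MODULE global. It supplies per-frame depth, a validity mask, and camera
# intrinsics, used later to give the predicted depth a metric scale and to back-project point clouds.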
def load_moge_model(device="cuda:0"):
    global MOGE_MODULE
    if MOGE_MODULE is not None:
        return MOGE_MODULE
    logger.info(f"Loading MoGe model on {device}...")
    MOGE_MODULE = MoGeModel.from_pretrained('Ruicheng/moge-2-vitl-normal').to(device)
    return MOGE_MODULE
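# load_model_1_3b: build the Wan2.1-Fun-1.3B-Control video pipeline (DiT weights, T5
# text encoder, VAE, CLIP) with CPU offload, apply the DKT depth LoRA on top of the
# DiT, and enable VRAM management. Cached in the PIPE_1_3B global.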
def load_model_1_3b(device="cuda:0"):
    global PIPE_1_3B
    if PIPE_1_3B is not None:
        return PIPE_1_3B
    logger.info(f"Loading 1.3B model on {device}...")
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device=device,
        model_configs=[
            ModelConfig(
                model_id="PAI/Wan2.1-Fun-1.3B-Control",
                origin_file_pattern="diffusion_pytorch_model*.safetensors",
                offload_device="cpu",
            ),
            ModelConfig(
                model_id="PAI/Wan2.1-Fun-1.3B-Control",
                origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
                offload_device="cpu",
            ),
            ModelConfig(
                model_id="PAI/Wan2.1-Fun-1.3B-Control",
                origin_file_pattern="Wan2.1_VAE.pth",
                offload_device="cpu",
            ),
            ModelConfig(
                model_id="PAI/Wan2.1-Fun-1.3B-Control",
                origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
                offload_device="cpu",
            ),
        ],
        training_strategy="origin",
    )
    lora_config = ModelConfig(
        model_id="Daniellesry/DKT-Depth-1-3B",
        origin_file_pattern="dkt-1-3B.safetensors",
        offload_device="cpu",
    )
    lora_config.download_if_necessary(use_usp=False)
    pipe.load_lora(pipe.dit, lora_config.path, alpha=1.0)  # TODO: verify the LoRA weights are actually applied
    pipe.enable_vram_management()
    PIPE_1_3B = pipe
    return pipe
def get_model(model_size):
    if model_size == "1.3B":
        assert PIPE_1_3B is not None, "1.3B model not initialized"
        return PIPE_1_3B
    else:
        raise ValueError(f"Unsupported model size: {model_size}")
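# process_video: end-to-end inference for one uploaded clip.
#   1. Copy the upload into a temp dir and decode it into RGB frames.
#   2. Rotate portrait clips to landscape and resize to the model resolution.
#   3. Pad the clip to a valid frame count and run the DKT pipeline with a sliding window.
#   4. Colorize and save the depth video, then lift a few key frames to metric point
#      clouds with MoGe and export them as .glb files for the 3D viewers.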
def process_video(
    video_file,
    model_size,
    height,
    width,
    num_inference_steps,
    window_size,
    overlap
):
    try:
        pipe = get_model(model_size)
        if pipe is None:
            return None, f"Model {model_size} not initialized. Please restart the application."
        tmp_video_path = video_file
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # Store all intermediate files in a temporary directory
        cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')
        original_filename = f"input_{timestamp}.mp4"
        dst_path = os.path.join(cur_save_dir, original_filename)
        shutil.copy2(tmp_video_path, dst_path)
        origin_frames, input_fps = extract_frames_from_video_file(tmp_video_path)
        if not origin_frames:
            return None, "Failed to extract frames from video"
        logger.info(f"Extracted {len(origin_frames)} frames from video")
        original_width, original_height = origin_frames[0].size
        ROTATE = False
        if original_width < original_height:
            # Portrait input: rotate to landscape so it matches the model's aspect ratio
            ROTATE = True
            origin_frames = [x.transpose(Image.ROTATE_90) for x in origin_frames]
            original_width, original_height = original_height, original_width
        frames = [resize_frame(frame, height, width) for frame in origin_frames]
        frame_length = len(frames)
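        # The Wan video VAE downsamples time by 4x, so the pipeline expects 4k + 1 frames
        # (assumption based on the padding below); pad by repeating the last frame if needed.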
        if (frame_length - 1) % 4 != 0:
            new_len = ((frame_length - 1) // 4 + 1) * 4 + 1
            frames = frames + [copy.deepcopy(frames[-1]) for _ in range(new_len - frame_length)]
        control_video = frames
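        # Run the depth pipeline over the control video in sliding windows of
        # `window_size` frames with `overlap` frames shared between consecutive windows.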
        video, vae_outs = pipe(
            prompt=PROMPT,
            negative_prompt=NEGATIVE_PROMPT,
            control_video=control_video,
            height=height,
            width=width,
            num_frames=len(control_video),
            seed=1,
            tiled=False,
            num_inference_steps=num_inference_steps,
            sliding_window_size=window_size,
            sliding_window_stride=window_size - overlap,
            cfg_scale=1.0,
        )
        #* free VRAM before post-processing; then undo padding, resizing, and rotation
        torch.cuda.empty_cache()
        processed_video = video[:frame_length]
        processed_video = [resize_frame(frame, original_height, original_width) for frame in processed_video]
        if ROTATE:
            processed_video = [x.transpose(Image.ROTATE_270) for x in processed_video]
            origin_frames = [x.transpose(Image.ROTATE_270) for x in origin_frames]
        output_filename = f"output_{timestamp}.mp4"
        output_path = os.path.join(cur_save_dir, output_filename)
        color_predictions = []
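        # The prediction video encodes relative depth/disparity in its gray values:
        # average the RGB channels, min-max normalize to [0, 1], and apply a color map.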
        if PROMPT == 'depth':
            predicted_depth_map_np = [np.array(item).astype(np.float32).mean(-1) for item in processed_video]
            predicted_depth_map_np = np.stack(predicted_depth_map_np)
            predicted_depth_map_np = predicted_depth_map_np / 255.0
            __min = predicted_depth_map_np.min()
            __max = predicted_depth_map_np.max()
            predicted_depth_map_np = (predicted_depth_map_np - __min) / (__max - __min)
            color_predictions = [colorize_depth_map(item) for item in predicted_depth_map_np]
        else:
            color_predictions = processed_video
        save_video(color_predictions, output_path, fps=input_fps, quality=5)
        frame_num = len(origin_frames)
        resize_W, resize_H = origin_frames[0].size
        vis_pc_num = 4
        indices = np.linspace(0, frame_num - 1, vis_pc_num)
        indices = np.round(indices).astype(np.int32)
        pc_save_dir = os.path.join(cur_save_dir, 'pointclouds')
        os.makedirs(pc_save_dir, exist_ok=True)
        glb_files = []
        moge_device = MOGE_MODULE.device if MOGE_MODULE is not None else torch.device("cuda:0")
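        # For each key frame: run MoGe to get depth, a validity mask, and intrinsics,
        # align the predicted depth to MoGe's metric scale, back-project to a point
        # cloud, and export it as a .glb for the 3D viewers.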
        for idx in tqdm(indices):
            origin_rgb_frame = origin_frames[idx]
            predicted_depth = processed_video[idx]
            # Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
            input_image_np = np.array(origin_rgb_frame)  # Convert PIL Image to numpy array
            input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
            output = MOGE_MODULE.infer(input_image)
            #* output keys: dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])
            moge_intrinsics = output['intrinsics'].cpu().numpy()
            moge_mask = output['mask'].cpu().numpy()
            moge_depth = output['depth'].cpu().numpy()
            predicted_depth = np.array(predicted_depth)
            predicted_depth = predicted_depth.mean(-1) / 255.0
            metric_depth = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask)
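            # MoGe returns intrinsics normalized by the image size (assumption, consistent
            # with the scaling below), so convert them to pixel units before back-projection.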
            moge_intrinsics[0, 0] *= resize_W
            moge_intrinsics[1, 1] *= resize_H
            moge_intrinsics[0, 2] *= resize_W
            moge_intrinsics[1, 2] *= resize_H
            # pcd = depth2pcd(metric_depth, moge_intrinsics, color=cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB), input_mask=moge_mask, ret_pcd=True)
            pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=True)
            # pcd.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) * np.array([1, -1, -1], dtype=np.float32))
            apply_filter = True
            if apply_filter:
                # Drop statistical outliers to clean up the back-projected point cloud
                cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=3.0)
                pcd = pcd.select_by_index(ind)
            #* save pcd: o3d.io.write_point_cloud(f'{pc_save_dir}/{timestamp}_{idx:02d}.ply', pcd)
            points = np.asarray(pcd.points)
            colors = np.asarray(pcd.colors) if pcd.has_colors() else None
            glb_filename = os.path.join(pc_save_dir, f'{timestamp}_{idx:02d}.glb')
            success = create_simple_glb_from_pointcloud(points, colors, glb_filename)
            if not success:
                logger.warning(f"Failed to save GLB file: {glb_filename}")
            glb_files.append(glb_filename)
        return output_path, glb_files
    except Exception as e:
        logger.error(f"Error processing video: {str(e)}")
        return None, f"Error: {str(e)}"
def main():
    #* gradio creation and initialization
    css = """
    #video-display-container {
        max-height: 100vh;
    }
    #video-display-input {
        max-height: 80vh;
    }
    #video-display-output {
        max-height: 80vh;
    }
    #download {
        height: 62px;
    }
    .title {
        text-align: center;
    }
    .description {
        text-align: center;
    }
    .gradio-examples {
        max-height: 400px;
        overflow-y: auto;
    }
    .gradio-examples .examples-container {
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
        gap: 10px;
        padding: 10px;
    }
    .gradio-container .gradio-examples .pagination,
    .gradio-container .gradio-examples .pagination button,
    div[data-testid="examples"] .pagination,
    div[data-testid="examples"] .pagination button {
        font-size: 28px !important;
        font-weight: bold !important;
        padding: 15px 20px !important;
        min-width: 60px !important;
        height: 60px !important;
        border-radius: 10px !important;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
        color: white !important;
        border: none !important;
        cursor: pointer !important;
        margin: 8px !important;
        display: inline-block !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
        transition: all 0.3s ease !important;
    }
    div[data-testid="examples"] .pagination button:not(.active),
    .gradio-container .gradio-examples .pagination button:not(.active) {
        font-size: 32px !important;
        font-weight: bold !important;
        padding: 15px 20px !important;
        min-width: 60px !important;
        height: 60px !important;
        background: linear-gradient(135deg, #8a9cf0 0%, #9a6bb2 100%) !important;
        opacity: 0.8 !important;
    }
    div[data-testid="examples"] .pagination button:hover,
    .gradio-container .gradio-examples .pagination button:hover {
        background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%) !important;
        transform: translateY(-2px) !important;
        box-shadow: 0 6px 12px rgba(0,0,0,0.3) !important;
        opacity: 1 !important;
    }
    div[data-testid="examples"] .pagination button.active,
    .gradio-container .gradio-examples .pagination button.active {
        background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%) !important;
        box-shadow: 0 4px 8px rgba(17,153,142,0.4) !important;
        opacity: 1 !important;
    }
    button[class*="pagination"],
    button[class*="page"] {
        font-size: 28px !important;
        font-weight: bold !important;
        padding: 15px 20px !important;
        min-width: 60px !important;
        height: 60px !important;
        border-radius: 10px !important;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
        color: white !important;
        border: none !important;
        cursor: pointer !important;
        margin: 8px !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
        transition: all 0.3s ease !important;
    }
    """
    head_html = """
    <link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Ctext y='.9em' font-size='90'%3E🦾%3C/text%3E%3C/svg%3E">
    <link rel="shortcut icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Ctext y='.9em' font-size='90'%3E🦾%3C/text%3E%3C/svg%3E">
    <link rel="icon" type="image/png" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'%3E%3Ctext y='.9em' font-size='90'%3E🦾%3C/text%3E%3C/svg%3E">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    """
    # title = "# Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation"
    # description = """Official demo for **DKT**."""
    # with gr.Blocks(css=css, title="DKT - Diffusion Knows Transparency", favicon_path="favicon.ico") as demo:
    height = 480
    width = 832
    window_size = 21
    with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
        # gr.Markdown(title, elem_classes=["title"])
        # gr.Markdown(description, elem_classes=["description"])
        # gr.Markdown("### Video Processing Demo", elem_classes=["description"])
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video", elem_id='video-display-input')
                model_size = gr.Radio(
                    choices=["1.3B", "14B"],
                    value="1.3B",
                    label="Model Size"
                )
                with gr.Accordion("Advanced Parameters", open=False):
                    num_inference_steps = gr.Slider(
                        minimum=1, maximum=50, value=5, step=1,
                        label="Number of Inference Steps"
                    )
                    overlap = gr.Slider(
                        minimum=1, maximum=20, value=3, step=1,
                        label="Overlap"
                    )
                submit = gr.Button(value="Compute Depth", variant="primary")
            with gr.Column():
                output_video = gr.Video(
                    label="Depth Outputs",
                    elem_id='video-display-output',
                    autoplay=True
                )
                vis_video = gr.Video(
                    label="Visualization Video",
                    visible=False,
                    autoplay=True
                )
        with gr.Row():
            gr.Markdown("### 3D Point Cloud Visualization", elem_classes=["title"])
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                output_point_map0 = LitModel3D(
                    label="Point Cloud Key Frame 1",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    interactive=False,
                    # height=400,
                )
            with gr.Column(scale=1):
                output_point_map1 = LitModel3D(
                    label="Point Cloud Key Frame 2",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    interactive=False
                )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                output_point_map2 = LitModel3D(
                    label="Point Cloud Key Frame 3",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    interactive=False
                )
            with gr.Column(scale=1):
                output_point_map3 = LitModel3D(
                    label="Point Cloud Key Frame 4",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    interactive=False
                )
        def on_submit(video_file, model_size, num_inference_steps, overlap):
            # The click handler is wired to 6 outputs (two videos and four 3D viewers),
            # so every return path must yield exactly 6 values.
            if video_file is None:
                logger.warning("Please upload a video file")
                return None, None, None, None, None, None
            try:
                output_path, glb_files = process_video(
                    video_file, model_size, height, width, num_inference_steps, window_size, overlap
                )
                if output_path is None:
                    # On failure, glb_files holds the error message from process_video
                    logger.error(str(glb_files))
                    return None, None, None, None, None, None
                model3d_outputs = [None] * 4
                if glb_files:
                    for i, glb_file in enumerate(glb_files[:4]):
                        if os.path.exists(glb_file):
                            model3d_outputs[i] = glb_file
                return output_path, None, *model3d_outputs
            except Exception as e:
                logger.error(f"Error in on_submit: {str(e)}")
                return None, None, None, None, None, None
        submit.click(
            on_submit,
            inputs=[
                input_video, model_size, num_inference_steps, overlap
            ],
            outputs=[
                output_video, vis_video,
                output_point_map0, output_point_map1, output_point_map2, output_point_map3
            ]
        )
        example_files = glob.glob('examples/*')
        if example_files:
            example_inputs = []
            for file_path in example_files:
                example_inputs.append([file_path, "1.3B"])
            examples = gr.Examples(
                examples=example_inputs,
                inputs=[input_video, model_size],
                outputs=[
                    output_video, vis_video,
                    output_point_map0, output_point_map1, output_point_map2, output_point_map3
                ],
                fn=on_submit,
                examples_per_page=6
            )
    #* main code: initialize the diffusion pipeline and the MoGe model once at startup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model_1_3b(device=device)
    load_moge_model(device=device)
    torch.cuda.empty_cache()
    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)

if __name__ == '__main__':
    # Wrap the heavy processing function so ZeroGPU allocates a GPU only while it runs
    process_video = spaces.GPU(process_video)
    main()