import os
import sys
import subprocess


def install(package):
    # Install a dependency into the current interpreter at startup
    # (a common pattern for Spaces that bootstrap without requirements.txt).
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


for _pkg in [
    "torchvision", "loguru", "imageio", "modelscope", "einops",
    "safetensors", "transformers", "ftfy", "accelerate", "sentencepiece",
    "spaces", "opencv-python", "trimesh", "gradio_litmodel3d", "open3d",
]:
    install(_pkg)

import gradio as gr
import numpy as np
import torch
from PIL import Image
from loguru import logger
from tqdm import tqdm

from tools.common_utils import save_video
from dkt.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

# Guard against older gradio_client versions that crash when a component's
# JSON schema is not a dict (e.g. a bare boolean schema).
try:
    import gradio_client.utils as _gc_utils

    if hasattr(_gc_utils, "get_type"):
        _orig_get_type = _gc_utils.get_type

        def _get_type_safe(schema):
            if not isinstance(schema, dict):
                return "Any"
            return _orig_get_type(schema)

        _gc_utils.get_type = _get_type_safe
except Exception:
    pass

# Additional guard: handle boolean JSON Schemas and parsing errors.
try:
    import gradio_client.utils as _gc_utils

    # Wrap the internal _json_schema_to_python_type if present.
    if hasattr(_gc_utils, "_json_schema_to_python_type"):
        _orig_internal = _gc_utils._json_schema_to_python_type

        def _json_schema_to_python_type_safe(schema, defs=None):
            if isinstance(schema, bool):
                return "Any"
            try:
                return _orig_internal(schema, defs)
            except Exception:
                return "Any"

        _gc_utils._json_schema_to_python_type = _json_schema_to_python_type_safe

    # Also wrap the public json_schema_to_python_type to be extra defensive.
    if hasattr(_gc_utils, "json_schema_to_python_type"):
        _orig_public = _gc_utils.json_schema_to_python_type

        def json_schema_to_python_type_safe(schema):
            try:
                return _orig_public(schema)
            except Exception:
                return "Any"

        _gc_utils.json_schema_to_python_type = json_schema_to_python_type_safe
except Exception:
    pass

import cv2
import copy
import trimesh
import open3d as o3d
from gradio_litmodel3d import LitModel3D
from tools.depth2pcd import depth2pcd

try:
    from moge.model.v2 import MoGeModel
except ImportError:
    os.system(
        'pip install git+https://github.com/microsoft/MoGe.git -i https://pypi.org/simple/ '
        '--trusted-host pypi.org --trusted-host pypi.python.org --trusted-host files.pythonhosted.org'
    )
    from moge.model.v2 import MoGeModel

from tools.eval_utils import transfer_pred_disp2depth, colorize_depth_map
import glob
import datetime
import shutil
import tempfile
import spaces

# Lazily initialized singletons (loaded once in main()).
PIPE_1_3B = None
MOGE_MODULE = None
#* better for bg: logs/outs/train/remote/sft-T2SQNet_glassverse_cleargrasp_HISS_DREDS_DREDS_glassverse_interiorverse-4gpus-origin-lora128-1.3B-rgb_depth-w832-h480-Wan2.1-Fun-Control-2025-10-28-23:26:41/epoch-0-20000.safetensors
PROMPT = 'depth'
NEGATIVE_PROMPT = ''


def resize_frame(frame, height, width):
    # Bicubic, antialiased resize of a PIL image via torch, returned as PIL.
    frame = np.array(frame)
    frame = torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    frame = torch.nn.functional.interpolate(
        frame, (height, width), mode="bicubic", align_corners=False, antialias=True
    )
    frame = (frame.squeeze(0).permute(1, 2, 0).clamp(0, 1) * 255).byte().numpy()
    return Image.fromarray(frame)


def pmap_to_glb(point_map, valid_mask, frame) -> trimesh.Scene:
    # Flip the x/y axes so the camera-frame points display upright in the
    # glTF viewer, then pack the valid points into a trimesh point cloud.
    pts_3d = point_map[valid_mask] * np.array([-1, -1, 1])
    pts_rgb = frame[valid_mask]
    # Initialize a 3D scene and add the point cloud to it.
    scene_3d = trimesh.Scene()
    point_cloud_data = trimesh.PointCloud(vertices=pts_3d, colors=pts_rgb)
    scene_3d.add_geometry(point_cloud_data)
    return scene_3d
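# Usage sketch for pmap_to_glb. The shapes below are assumptions inferred
# from the call sites in this file, not a documented contract: `point_map`
# holds xyz coordinates, `valid_mask` is a boolean selector over points,
# and `frame` carries per-point RGB aligned with `point_map`:
#
#   points = np.random.rand(100, 3).astype(np.float32)  # dummy xyz
#   colors = np.random.rand(100, 3).astype(np.float32)  # dummy RGB in [0, 1]
#   mask = np.ones(100, dtype=bool)
#   pmap_to_glb(points, mask, colors).export("preview.glb")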
def create_simple_glb_from_pointcloud(points, colors, glb_filename):
    try:
        if len(points) == 0:
            logger.warning(f"No valid points to create GLB for {glb_filename}")
            return False
        if colors is not None:
            # logger.info(f"Adding colors to GLB: shape={colors.shape}, range=[{colors.min():.3f}, {colors.max():.3f}]")
            pts_rgb = colors
        else:
            logger.info("No colors provided, adding default white colors")
            pts_rgb = np.ones((len(points), 3))
        valid_mask = np.ones(len(points), dtype=bool)
        scene_3d = pmap_to_glb(points, valid_mask, pts_rgb)
        scene_3d.export(glb_filename)
        # logger.info(f"Saved GLB file using trimesh: {glb_filename}")
        return True
    except Exception as e:
        logger.error(f"Error creating GLB from pointcloud using trimesh: {str(e)}")
        return False


def extract_frames_from_video_file(video_path):
    # Decode all frames of a video into RGB PIL images; fall back to 15 fps
    # when the container does not report a frame rate.
    try:
        cap = cv2.VideoCapture(video_path)
        frames = []
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 15.0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
        cap.release()
        return frames, fps
    except Exception as e:
        logger.error(f"Error extracting frames from {video_path}: {str(e)}")
        return [], 15.0


def load_moge_model(device="cuda:0"):
    global MOGE_MODULE
    if MOGE_MODULE is not None:
        return MOGE_MODULE
    logger.info(f"Loading MoGe model on {device}...")
    MOGE_MODULE = MoGeModel.from_pretrained('Ruicheng/moge-2-vitl-normal').to(device)
    return MOGE_MODULE


def load_model_1_3b(device="cuda:0"):
    global PIPE_1_3B
    if PIPE_1_3B is not None:
        return PIPE_1_3B
    logger.info(f"Loading 1.3B model on {device}...")
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device=device,
        model_configs=[
            ModelConfig(
                model_id="PAI/Wan2.1-Fun-1.3B-Control",
                origin_file_pattern="diffusion_pytorch_model*.safetensors",
                offload_device="cpu",
            ),
            ModelConfig(
                model_id="PAI/Wan2.1-Fun-1.3B-Control",
                origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
                offload_device="cpu",
            ),
            ModelConfig(
                model_id="PAI/Wan2.1-Fun-1.3B-Control",
                origin_file_pattern="Wan2.1_VAE.pth",
                offload_device="cpu",
            ),
            ModelConfig(
                model_id="PAI/Wan2.1-Fun-1.3B-Control",
                origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
                offload_device="cpu",
            ),
        ],
        training_strategy="origin",
    )
    # Fetch the DKT depth LoRA and merge it into the DiT backbone.
    lora_config = ModelConfig(
        model_id="Daniellesry/DKT-Depth-1-3B",
        origin_file_pattern="dkt-1-3B.safetensors",
        offload_device="cpu",
    )
    lora_config.download_if_necessary(use_usp=False)
    pipe.load_lora(pipe.dit, lora_config.path, alpha=1.0)  # TODO: verify the LoRA merge takes effect
    pipe.enable_vram_management()
    PIPE_1_3B = pipe
    return pipe


def get_model(model_size):
    # Only the 1.3B pipeline is wired up; the "14B" option offered by the UI
    # falls through to the ValueError below.
    if model_size == "1.3B":
        assert PIPE_1_3B is not None, "1.3B model not initialized"
        return PIPE_1_3B
    raise ValueError(f"Unsupported model size: {model_size}")
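# process_video below runs the full demo pipeline: extract frames, resize,
# run sliding-window diffusion inference, colorize the depth, then build
# point clouds for four key frames. A usage sketch (the example path is
# hypothetical; the other arguments mirror the UI defaults set in main()):
#
#   output_path, glb_files = process_video(
#       "examples/clip.mp4",  # hypothetical input video
#       "1.3B",               # model size
#       480, 832,             # inference height/width
#       5,                    # num_inference_steps
#       21,                   # sliding window size (frames)
#       3,                    # window overlap (frames)
#   )
#
# On success it returns the colorized depth video path and a list of GLB
# files; on failure it returns (None, error_message).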
def process_video(
    video_file,
    model_size,
    height,
    width,
    num_inference_steps,
    window_size,
    overlap
):
    try:
        pipe = get_model(model_size)
        if pipe is None:
            return None, f"Model {model_size} not initialized. Please restart the application."

        tmp_video_path = video_file
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # Store all intermediate and output files in a temporary directory.
        cur_save_dir = tempfile.mkdtemp(prefix=f'dkt_{timestamp}_{model_size}_')
        original_filename = f"input_{timestamp}.mp4"
        dst_path = os.path.join(cur_save_dir, original_filename)
        shutil.copy2(tmp_video_path, dst_path)

        origin_frames, input_fps = extract_frames_from_video_file(tmp_video_path)
        if not origin_frames:
            return None, "Failed to extract frames from video"
        logger.info(f"Extracted {len(origin_frames)} frames from video")

        # Rotate portrait videos to landscape for inference; rotated back below.
        original_width, original_height = origin_frames[0].size
        ROTATE = False
        if original_width < original_height:
            ROTATE = True
            origin_frames = [x.transpose(Image.ROTATE_90) for x in origin_frames]
            original_width, original_height = original_height, original_width

        frames = [resize_frame(frame, height, width) for frame in origin_frames]

        # Pad the clip to a frame count of 4k + 1 by repeating the last frame.
        frame_length = len(frames)
        if (frame_length - 1) % 4 != 0:
            new_len = ((frame_length - 1) // 4 + 1) * 4 + 1
            frames = frames + [copy.deepcopy(frames[-1]) for _ in range(new_len - frame_length)]

        control_video = frames
        video, vae_outs = pipe(
            prompt=PROMPT,
            negative_prompt=NEGATIVE_PROMPT,
            control_video=control_video,
            height=height,
            width=width,
            num_frames=len(control_video),
            seed=1,
            tiled=False,
            num_inference_steps=num_inference_steps,
            sliding_window_size=window_size,
            sliding_window_stride=window_size - overlap,
            cfg_scale=1.0,
        )

        #* MoGe processing
        torch.cuda.empty_cache()

        # Drop padding frames and restore the original resolution/orientation.
        processed_video = video[:frame_length]
        processed_video = [resize_frame(frame, original_height, original_width) for frame in processed_video]
        if ROTATE:
            processed_video = [x.transpose(Image.ROTATE_270) for x in processed_video]
            origin_frames = [x.transpose(Image.ROTATE_270) for x in origin_frames]

        output_filename = f"output_{timestamp}.mp4"
        output_path = os.path.join(cur_save_dir, output_filename)

        color_predictions = []
        if PROMPT == 'depth':
            # Min-max normalize the grayscale predictions to [0, 1] before colorizing.
            predicted_depth_map_np = [np.array(item).astype(np.float32).mean(-1) for item in processed_video]
            predicted_depth_map_np = np.stack(predicted_depth_map_np) / 255.0
            __min = predicted_depth_map_np.min()
            __max = predicted_depth_map_np.max()
            predicted_depth_map_np = (predicted_depth_map_np - __min) / (__max - __min)
            color_predictions = [colorize_depth_map(item) for item in predicted_depth_map_np]
        else:
            color_predictions = processed_video
        save_video(color_predictions, output_path, fps=input_fps, quality=5)

        # Pick four evenly spaced key frames for point-cloud visualization.
        frame_num = len(origin_frames)
        resize_W, resize_H = origin_frames[0].size
        vis_pc_num = 4
        indices = np.linspace(0, frame_num - 1, vis_pc_num)
        indices = np.round(indices).astype(np.int32)

        pc_save_dir = os.path.join(cur_save_dir, 'pointclouds')
        os.makedirs(pc_save_dir, exist_ok=True)
        glb_files = []
        moge_device = MOGE_MODULE.device if MOGE_MODULE is not None else torch.device("cuda:0")
        for idx in tqdm(indices):
            origin_rgb_frame = origin_frames[idx]
            predicted_depth = processed_video[idx]

            # Convert the input frame to a (3, H, W) tensor with RGB values in [0, 1].
            input_image_np = np.array(origin_rgb_frame)
            input_image = torch.tensor(input_image_np / 255, dtype=torch.float32, device=moge_device).permute(2, 0, 1)
            output = MOGE_MODULE.infer(input_image)
            #* output keys: dict_keys(['points', 'intrinsics', 'depth', 'mask', 'normal'])
            moge_intrinsics = output['intrinsics'].cpu().numpy()
            moge_mask = output['mask'].cpu().numpy()
            moge_depth = output['depth'].cpu().numpy()

            predicted_depth = np.array(predicted_depth)
            predicted_depth = predicted_depth.mean(-1) / 255.0
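            # Align the relative (affine-invariant) prediction with MoGe's metric
            # depth. transfer_pred_disp2depth comes from tools.eval_utils; a fit of
            # scale and shift in disparity space over the valid mask is one plausible
            # implementation (an assumption for illustration, not the actual helper):
            #
            #   gt_disp = 1.0 / np.clip(moge_depth, 1e-6, None)
            #   s, t = np.polyfit(predicted_depth[moge_mask], gt_disp[moge_mask], 1)
            #   metric_depth = 1.0 / np.clip(s * predicted_depth + t, 1e-6, None)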
            metric_depth = transfer_pred_disp2depth(predicted_depth, moge_depth, moge_mask)

            # MoGe returns intrinsics normalized by image size; scale them to
            # pixel units for the target resolution (e.g. fx = fx_norm * W).
            moge_intrinsics[0, 0] *= resize_W
            moge_intrinsics[1, 1] *= resize_H
            moge_intrinsics[0, 2] *= resize_W
            moge_intrinsics[1, 2] *= resize_H

            # pcd = depth2pcd(metric_depth, moge_intrinsics, color=cv2.cvtColor(input_image_np, cv2.COLOR_BGR2RGB), input_mask=moge_mask, ret_pcd=True)
            pcd = depth2pcd(metric_depth, moge_intrinsics, color=input_image_np, input_mask=moge_mask, ret_pcd=True)
            # pcd.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) * np.array([1, -1, -1], dtype=np.float32))

            # Statistical outlier removal: drop points whose mean distance to
            # their 20 nearest neighbors exceeds 3 standard deviations.
            apply_filter = True
            if apply_filter:
                cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=3.0)
                pcd = pcd.select_by_index(ind)

            #* save pcd:
            o3d.io.write_point_cloud(f'{pc_save_dir}/{timestamp}_{idx:02d}.ply', pcd)

            points = np.asarray(pcd.points)
            colors = np.asarray(pcd.colors) if pcd.has_colors() else None
            glb_filename = os.path.join(pc_save_dir, f'{timestamp}_{idx:02d}.glb')
            success = create_simple_glb_from_pointcloud(points, colors, glb_filename)
            if not success:
                logger.warning(f"Failed to save GLB file: {glb_filename}")
            glb_files.append(glb_filename)

        return output_path, glb_files
    except Exception as e:
        logger.error(f"Error processing video: {str(e)}")
        return None, f"Error: {str(e)}"
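# depth2pcd (from tools.depth2pcd) back-projects a depth map into a point
# cloud. The standard pinhole relation it presumably applies for each pixel
# (u, v) with depth z is:
#
#   x = (u - cx) * z / fx
#   y = (v - cy) * z / fy
#
# which is why the intrinsics are converted to pixel units before the call.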
def main():
    #* Gradio UI creation and initialization
    css = """
    #video-display-container { max-height: 100vh; }
    #video-display-input { max-height: 80vh; }
    #video-display-output { max-height: 80vh; }
    #download { height: 62px; }
    .title { text-align: center; }
    .description { text-align: center; }
    .gradio-examples { max-height: 400px; overflow-y: auto; }
    .gradio-examples .examples-container {
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
        gap: 10px;
        padding: 10px;
    }
    .gradio-container .gradio-examples .pagination,
    .gradio-container .gradio-examples .pagination button,
    div[data-testid="examples"] .pagination,
    div[data-testid="examples"] .pagination button {
        font-size: 28px !important;
        font-weight: bold !important;
        padding: 15px 20px !important;
        min-width: 60px !important;
        height: 60px !important;
        border-radius: 10px !important;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
        color: white !important;
        border: none !important;
        cursor: pointer !important;
        margin: 8px !important;
        display: inline-block !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
        transition: all 0.3s ease !important;
    }
    div[data-testid="examples"] .pagination button:not(.active),
    .gradio-container .gradio-examples .pagination button:not(.active) {
        font-size: 32px !important;
        font-weight: bold !important;
        padding: 15px 20px !important;
        min-width: 60px !important;
        height: 60px !important;
        background: linear-gradient(135deg, #8a9cf0 0%, #9a6bb2 100%) !important;
        opacity: 0.8 !important;
    }
    div[data-testid="examples"] .pagination button:hover,
    .gradio-container .gradio-examples .pagination button:hover {
        background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%) !important;
        transform: translateY(-2px) !important;
        box-shadow: 0 6px 12px rgba(0,0,0,0.3) !important;
        opacity: 1 !important;
    }
    div[data-testid="examples"] .pagination button.active,
    .gradio-container .gradio-examples .pagination button.active {
        background: linear-gradient(135deg, #11998e 0%, #38ef7d 100%) !important;
        box-shadow: 0 4px 8px rgba(17,153,142,0.4) !important;
        opacity: 1 !important;
    }
    button[class*="pagination"],
    button[class*="page"] {
        font-size: 28px !important;
        font-weight: bold !important;
        padding: 15px 20px !important;
        min-width: 60px !important;
        height: 60px !important;
        border-radius: 10px !important;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
        color: white !important;
        border: none !important;
        cursor: pointer !important;
        margin: 8px !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
        transition: all 0.3s ease !important;
    }
    """
    head_html = """ """
    # title = "# Diffusion Knows Transparency: Repurposing Video Diffusion for Transparent Object Depth and Normal Estimation"
    # description = """Official demo for **DKT**."""
    # with gr.Blocks(css=css, title="DKT - Diffusion Knows Transparency", favicon_path="favicon.ico") as demo:

    height = 480
    width = 832
    window_size = 21

    with gr.Blocks(css=css, title="DKT", head=head_html) as demo:
        # gr.Markdown(title, elem_classes=["title"])
        # gr.Markdown(description, elem_classes=["description"])
        # gr.Markdown("### Video Processing Demo", elem_classes=["description"])
        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Input Video", elem_id='video-display-input')
                model_size = gr.Radio(
                    choices=["1.3B", "14B"],
                    value="1.3B",
                    label="Model Size"
                )
                with gr.Accordion("Advanced Parameters", open=False):
                    num_inference_steps = gr.Slider(
                        minimum=1, maximum=50, value=5, step=1,
                        label="Number of Inference Steps"
                    )
                    overlap = gr.Slider(
                        minimum=1, maximum=20, value=3, step=1,
                        label="Overlap"
                    )
                submit = gr.Button(value="Compute Depth", variant="primary")
            with gr.Column():
                output_video = gr.Video(
                    label="Depth Outputs",
                    elem_id='video-display-output',
                    autoplay=True
                )
                vis_video = gr.Video(
                    label="Visualization Video",
                    visible=False,
                    autoplay=True
                )
        with gr.Row():
            gr.Markdown("### 3D Point Cloud Visualization", elem_classes=["title"])
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                output_point_map0 = LitModel3D(
                    label="Point Cloud Key Frame 1",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    interactive=False,
                    # height=400,
                )
            with gr.Column(scale=1):
                output_point_map1 = LitModel3D(
                    label="Point Cloud Key Frame 2",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    interactive=False
                )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                output_point_map2 = LitModel3D(
                    label="Point Cloud Key Frame 3",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    interactive=False
                )
            with gr.Column(scale=1):
                output_point_map3 = LitModel3D(
                    label="Point Cloud Key Frame 4",
                    clear_color=[1.0, 1.0, 1.0, 1.0],
                    interactive=False
                )

        def on_submit(video_file, model_size, num_inference_steps, overlap):
            # Surface failures via gr.Error so the return arity always matches
            # the six output components wired up in submit.click below.
            if video_file is None:
                raise gr.Error("Please upload a video file")
            try:
                output_path, glb_files = process_video(
                    video_file, model_size, height, width,
                    num_inference_steps, window_size, overlap
                )
                if output_path is None:
                    # On failure, process_video returns its error message
                    # in place of glb_files.
                    raise gr.Error(str(glb_files))
                model3d_outputs = [None] * 4
                if glb_files:
                    for i, glb_file in enumerate(glb_files[:4]):
                        if os.path.exists(glb_file):
                            model3d_outputs[i] = glb_file
                return output_path, None, *model3d_outputs
            except gr.Error:
                raise
            except Exception as e:
                raise gr.Error(f"Error: {str(e)}")

        submit.click(
            on_submit,
            inputs=[input_video, model_size, num_inference_steps, overlap],
            outputs=[
                output_video, vis_video,
                output_point_map0, output_point_map1,
                output_point_map2, output_point_map3
            ]
        )

        example_files = glob.glob('examples/*')
        if example_files:
            example_inputs = [[file_path, "1.3B"] for file_path in example_files]
            gr.Examples(
                examples=example_inputs,
                inputs=[input_video, model_size],
                outputs=[
                    output_video, vis_video,
                    output_point_map0, output_point_map1,
                    output_point_map2, output_point_map3
                ],
                fn=on_submit,
                examples_per_page=6
            )
    #* Main code: model and MoGe model initialization
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model_1_3b(device=device)
    load_moge_model(device=device)
    torch.cuda.empty_cache()

    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)


if __name__ == '__main__':
    # Wrap the heavy entry point for ZeroGPU scheduling on Hugging Face Spaces,
    # so a GPU is attached only while a request is being processed. on_submit
    # resolves process_video through module globals, so it picks up the wrapper.
    process_video = spaces.GPU(process_video)
    main()
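# To run locally (assuming this file is the Space's app.py):
#
#   python app.py
#
# The UI is then served on port 7860; launch(share=True) additionally prints
# a temporary public gradio.live URL.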