import glob
import os

import torch
import torch.nn.functional as F
import numpy as np

from zoedepth.utils.misc import colorize
from zoedepth.utils.config import get_config
from zoedepth.models.builder import build_model
from zoedepth.models.model_io import load_wts

from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline
def load_ckpt(config, model, checkpoint_dir: str = "./checkpoints", ckpt_type: str = "best"):
    """Load weights from an explicit checkpoint path or from a glob pattern in `checkpoint_dir`."""
    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        pattern = config.ckpt_pattern
        matches = glob.glob(os.path.join(checkpoint_dir, f"*{pattern}*{ckpt_type}*"))
        if not matches:
            raise ValueError(f"No matches found for the pattern {pattern}")
        checkpoint = matches[0]
    else:
        # Nothing to load; return the model unchanged.
        return model
    model = load_wts(model, checkpoint)
    print(f"Loaded weights from {checkpoint}")
    return model
def get_zoe_dc_model(vanilla: bool = False, ckpt_path: str = None, **kwargs):
    """Build the ZoeDepth model, optionally extended with sparse-depth input channels."""

    def ZoeD_N(midas_model_type="DPT_BEiT_L_384", vanilla=False, ckpt_path=None, **kwargs):
        if midas_model_type != "DPT_BEiT_L_384":
            raise ValueError(
                f"Only the DPT_BEiT_L_384 MiDaS model is supported for the pretrained ZoeD_N model, got: {midas_model_type}")

        zoedepth_config = get_config("zoedepth", "train", **kwargs)
        model = build_model(zoedepth_config)

        if vanilla:
            model.vanilla = True
            return model
        model.vanilla = False

        if zoedepth_config.add_depth_channel:
            # Widen the MiDaS patch embedding so it accepts two extra input
            # channels: the scaled sparse depth map and its validity mask.
            proj = model.core.core.pretrained.model.patch_embed.proj
            model.core.core.pretrained.model.patch_embed.proj = torch.nn.Conv2d(
                proj.in_channels + 2,
                proj.out_channels,
                kernel_size=proj.kernel_size,
                stride=proj.stride,
                padding=proj.padding,
                bias=True)

        if ckpt_path is not None:
            assert os.path.exists(ckpt_path)
            zoedepth_config.checkpoint = ckpt_path
        else:
            assert vanilla, "ckpt_path must be provided for the non-vanilla model"
        model = load_ckpt(zoedepth_config, model)

        return model

    return ZoeD_N(vanilla=vanilla, ckpt_path=ckpt_path, **kwargs)
def infer_with_pad(zoe, x, pad_input: bool = True, fh: float = 3, fw: float = 3,
                   upsampling_mode: str = "bicubic", padding_mode: str = "reflect", **kwargs):
    """Run ZoeDepth on a (possibly padded) batch and return depth at the input resolution."""
    assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim())

    if pad_input:
        assert fh > 0 or fw > 0, "at least one of fh and fw must be greater than 0"
        pad_h = int(np.sqrt(x.shape[2] / 2) * fh)
        pad_w = int(np.sqrt(x.shape[3] / 2) * fw)
        padding = [pad_w, pad_w]
        if pad_h > 0:
            padding += [pad_h, pad_h]
        # Pad the RGB channels with reflection, and the extra channels (sparse
        # depth and mask) with zeros so padded pixels stay marked as invalid.
        x_rgb = x[:, :3]
        x_remaining = x[:, 3:]
        x_rgb = F.pad(x_rgb, padding, mode=padding_mode, **kwargs)
        x_remaining = F.pad(x_remaining, padding, mode="constant", value=0, **kwargs)
        x = torch.cat([x_rgb, x_remaining], dim=1)

    out = zoe(x)["metric_depth"]
    if out.shape[-2:] != x.shape[-2:]:
        out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False)

    if pad_input:
        # Crop back to the original size, handling the case where pad_h or pad_w is 0.
        if pad_h > 0:
            out = out[:, :, pad_h:-pad_h, :]
        if pad_w > 0:
            out = out[:, :, :, pad_w:-pad_w]
    return out
def infer_with_zoe_dc(zoe_dc, image, sparse_depth, scaling: float = 1):
    """Depth completion: predict dense metric depth from an RGB image and a sparse depth map."""
    sparse_depth_mask = (sparse_depth[None, None, ...] > 0).float()
    # The metric depth range defined during training is [1e-3, 10], so the sparse
    # depth is normalized into that range before being concatenated to the image.
    x = torch.cat(
        [image[None, ...],
         sparse_depth[None, None, ...] / (float(scaling) * 10.0),
         sparse_depth_mask],
        dim=1).to(zoe_dc.device)

    # Average with the prediction on the horizontally flipped input
    # as a simple test-time augmentation.
    out = infer_with_pad(zoe_dc, x)
    out_flip = infer_with_pad(zoe_dc, torch.flip(x, dims=[3]))
    out = (out + torch.flip(out_flip, dims=[3])) / 2

    pred_depth = float(scaling) * out
    return F.interpolate(pred_depth, image.shape[-2:], mode="bilinear", align_corners=True)[0, 0]
def get_sd_pipeline():
    """Create a Stable Diffusion 2 inpainting pipeline with the asymmetric VAE decoder."""
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
    )
    pipe.vae = AsymmetricAutoencoderKL.from_pretrained(
        "cross-attention/asymmetric-autoencoder-kl-x-2",
        torch_dtype=torch.float16,
    )
    return pipe
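

# ---------------------------------------------------------------------------
# Minimal usage sketch of how the helpers above fit together. The checkpoint
# path and the input tensors are hypothetical placeholders, and it assumes the
# "zoedepth" train config enables add_depth_channel so the model accepts the
# extra sparse-depth and mask channels.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Depth-completion model: RGB + sparse depth + mask -> dense metric depth.
    zoe_dc = get_zoe_dc_model(ckpt_path="./checkpoints/zoe_dc_best.pt").to(device).eval()

    rgb = torch.rand(3, 480, 640, device=device)    # RGB image in [0, 1], CHW
    sparse = torch.zeros(480, 640, device=device)   # sparse metric depth, 0 = missing
    sparse[::40, ::40] = 2.5                        # a few synthetic depth samples (meters)

    with torch.no_grad():
        dense_depth = infer_with_zoe_dc(zoe_dc, rgb, sparse, scaling=1.0)
    print(dense_depth.shape)                        # (480, 640)

    # Inpainting pipeline used alongside the depth model (fp16, intended for GPU).
    pipe = get_sd_pipeline()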