import math
import os
import dataclasses

import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

import draggan.dnnlib as dnnlib
from draggan.viz import renderer
from .lpips import util


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    # Linear warmup over the first `rampup` fraction of steps, a plateau at
    # `initial_lr`, then a half-cosine decay over the final `rampdown` fraction.
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)

    return initial_lr * lr_ramp
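
# Illustrative values under the default ramps (rampup=0.05, rampdown=0.25),
# with initial_lr=0.1:
#
#     get_lr(0.0, 0.1)    # -> 0.0, start of warmup
#     get_lr(0.025, 0.1)  # -> 0.05, halfway through warmup
#     get_lr(0.5, 0.1)    # -> 0.1, plateau
#     get_lr(0.875, 0.1)  # -> 0.05, halfway down the cosine decay
#     get_lr(1.0, 0.1)    # -> 0.0, end of schedule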


def make_image(tensor):
    # Map a batch of generator outputs from [-1, 1] (NCHW) to uint8 numpy
    # arrays in HWC layout.
    return (
        tensor.detach()
        .clamp_(min=-1, max=1)
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )
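
# Example: convert a synthesized batch to a saveable PIL image.
#
#     arr = make_image(img_gen)        # [N, H, W, 3] uint8 numpy array
#     Image.fromarray(arr[0]).save("out.png")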


class InverseConfig:
    lr_warmup = 0.05
    lr_decay = 0.25
    lr = 0.1
    noise = 0.05
    noise_decay = 0.75
    step = 1000
    noise_regularize = 1e5
    mse = 0.1
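
# Illustrative override (the values below are examples, not tuned settings):
#
#     cfg = InverseConfig()
#     cfg.step = 500   # fewer optimization steps: faster, coarser inversion
#     cfg.mse = 0.5    # weight pixel-space fidelity more heavily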


def inverse_image(
    g_ema,
    image,
    percept,
    image_size=256,
    w_plus=False,
    config=InverseConfig(),
    device='cuda:0',
):
    args = config

    n_mean_latent = 10000
    resize = min(image_size, 256)

    if not torch.is_tensor(image):
        # PIL input: resize, center-crop, and normalize to [-1, 1].
        transform = transforms.Compose(
            [
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        img = transform(image)
    else:
        # Tensor input is assumed to already be scaled to [0, 1].
        img = transforms.functional.resize(image, resize)
        transform = transforms.Compose(
            [
                transforms.CenterCrop(resize),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        img = transform(img)

    imgs = img.unsqueeze(0).to(device)

    with torch.no_grad():
        # Estimate the mean and standard deviation of W from random z samples.
        noise_sample = torch.randn(n_mean_latent, g_ema.z_dim, device=device)
        w_samples = g_ema.mapping(noise_sample, None)
        w_samples = w_samples[:, :1, :]
        w_avg = w_samples.mean(0)
        w_std = ((w_samples - w_avg).pow(2).sum() / n_mean_latent) ** 0.5

    # Collect the per-layer noise buffers and make them trainable. The
    # assignment must be in-place so the buffers inside the generator are
    # actually re-initialized (a plain rebinding would be silently discarded).
    noises = {name: buf for (name, buf) in g_ema.synthesis.named_buffers() if 'noise_const' in name}
    for noise in noises.values():
        noise[:] = torch.randn_like(noise)
        noise.requires_grad = True

    # Optimize a single w repeated across layers, or a full w+ tensor.
    w_opt = w_avg.detach().clone()
    if w_plus:
        w_opt = w_opt.repeat(1, g_ema.mapping.num_ws, 1)
    w_opt.requires_grad = True

    optimizer = optim.Adam([w_opt] + list(noises.values()), lr=args.lr)

    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr, rampdown=args.lr_decay, rampup=args.lr_warmup)
        optimizer.param_groups[0]["lr"] = lr

        # Anneal the latent exploration noise to zero over training.
        noise_strength = w_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
        w_noise = torch.randn_like(w_opt) * noise_strength
        if w_plus:
            ws = w_opt + w_noise
        else:
            ws = (w_opt + w_noise).repeat([1, g_ema.mapping.num_ws, 1])
        img_gen = g_ema.synthesis(ws, noise_mode='const', force_fp32=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built
        # for 224x224 images.
        if img_gen.shape[2] > 256:
            img_gen = F.interpolate(img_gen, size=(256, 256), mode='area')

        p_loss = percept(img_gen, imgs)

        # Noise regularization: penalize spatial autocorrelation of each noise
        # map at multiple scales.
        reg_loss = 0.0
        for v in noises.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)

        mse_loss = F.mse_loss(img_gen, imgs)

        loss = p_loss + args.noise_regularize * reg_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Re-normalize the noise buffers to zero mean and unit variance.
        with torch.no_grad():
            for buf in noises.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

        if (i + 1) % 100 == 0:
            latent_path.append(w_opt.detach().clone())

        pbar.set_description(
            f"perceptual: {p_loss.item():.4f}; noise regularize: {reg_loss:.4f};"
            f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
        )

    # Synthesize the final reconstruction from the last saved pivot.
    if w_plus:
        ws = latent_path[-1]
    else:
        ws = latent_path[-1].repeat([1, g_ema.mapping.num_ws, 1])
    img_gen = g_ema.synthesis(ws, noise_mode='const')

    result = {
        "latent": latent_path[-1],
        "sample": img_gen,
        "real": imgs,
    }
    return result
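
# A minimal usage sketch (assumptions: `g_ema` is a loaded StyleGAN2-ADA
# generator and `percept` is an LPIPS-style callable returning a one-element
# perceptual distance; constructing either is outside this module):
#
#     target = Image.open("face.png").convert("RGB")  # hypothetical input path
#     result = inverse_image(g_ema, target, percept, image_size=256, w_plus=True)
#     w_pivot = result["latent"]  # [1, num_ws, 512] with w_plus, else [1, 1, 512]
#     Image.fromarray(make_image(result["sample"])[0]).save("inversion.png")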


def toogle_grad(model, flag=True):
    # Enable or disable gradients for every parameter of `model`.
    for p in model.parameters():
        p.requires_grad = flag


class PTI:
    def __init__(self, G, percept, l2_lambda=1, max_pti_step=400, pti_lr=3e-4):
        self.g_ema = G
        self.l2_lambda = l2_lambda
        self.max_pti_step = max_pti_step
        self.pti_lr = pti_lr
        self.percept = percept

    def cacl_loss(self, percept, generated_image, real_image):
        # Pivotal tuning loss: perceptual distance plus weighted L2.
        mse_loss = F.mse_loss(generated_image, real_image)
        p_loss = percept(generated_image, real_image).sum()
        loss = p_loss + self.l2_lambda * mse_loss
        return loss

    def train(self, img, w_plus=False):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if not torch.is_tensor(img):
            # PIL input: resize, center-crop, and normalize to [-1, 1].
            transform = transforms.Compose(
                [
                    transforms.Resize(self.g_ema.img_resolution),
                    transforms.CenterCrop(self.g_ema.img_resolution),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
            real_img = transform(img).to(device).unsqueeze(0)
        else:
            # Tensor input is assumed to already be scaled to [0, 1].
            img = transforms.functional.resize(img, self.g_ema.img_resolution)
            transform = transforms.Compose(
                [
                    transforms.CenterCrop(self.g_ema.img_resolution),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
            real_img = transform(img).to(device).unsqueeze(0)

        # Step 1: invert the image to a pivot latent with the generator frozen.
        inversed_result = inverse_image(
            self.g_ema, img, self.percept, self.g_ema.img_resolution, w_plus, device=device
        )
        w_pivot = inversed_result['latent']
        if w_plus:
            ws = w_pivot
        else:
            ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])

        # Step 2: fine-tune the generator weights around the fixed pivot.
        toogle_grad(self.g_ema, True)
        optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
        print('start PTI')
        pbar = tqdm(range(self.max_pti_step))
        for i in pbar:
            t = i / self.max_pti_step
            lr = get_lr(t, self.pti_lr)
            optimizer.param_groups[0]["lr"] = lr

            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
            loss = self.cacl_loss(self.percept, generated_image, real_img)
            pbar.set_description(f"loss: {loss.item():.4f}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            generated_image = self.g_ema.synthesis(ws, noise_mode='const')

        return generated_image, ws
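
# A minimal end-to-end sketch (assumptions: `load_generator` is a hypothetical
# helper and `percept` an LPIPS-style loss, as above; neither is defined here):
#
#     g_ema = load_generator("stylegan2_ffhq.pkl")    # hypothetical
#     target = Image.open("face.png").convert("RGB")  # hypothetical input
#     pti = PTI(g_ema, percept, l2_lambda=1, max_pti_step=400, pti_lr=3e-4)
#     generated, ws = pti.train(target, w_plus=True)
#     # `generated` reconstructs `target`; `ws` is the pivot in W+ layout.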