Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- coding:utf-8 -*- | |
| # Power by Zongsheng Yue 2022-06-09 14:59:55 | |
| import torch | |
| import random | |
| import numpy as np | |
| from einops import rearrange | |
def batch_inpainging_from_grad(im_in, mask, gradx, grady):
    '''
    Recovering from gradient for batch data (torch tensor).
    Input:
        im_in: N x c x h x w, torch tensor, masked image
        mask: N x 1 x h x w, torch tensor, 1 represents missing value
        gradx, grady: N x c x h x w, torch tensor, image gradient
    Output:
        im_out: N x c x h x w tensor with the device and dtype of im_in. It does not
                track gradients. The input tensors are left unchanged.
    '''
    def to_numpy(x):
        # detach(): Tensor.numpy() refuses tensors that require grad.
        # copy(): for CPU tensors .numpy() shares memory with the tensor, and
        # inpainting_from_grad may fill image/mask arrays in place; without the copy
        # those writes would land in the caller's tensors.
        return x.detach().cpu().numpy().copy()

    im_out = torch.zeros_like(im_in.detach())
    for ii in range(im_in.shape[0]):
        # c x h x w -> h x w x c, the layout used by inpainting_from_grad
        im_current, gradx_current, grady_current = [
            np.transpose(to_numpy(x[ii]), (1, 2, 0)) for x in (im_in, gradx, grady)
        ]
        mask_current = to_numpy(mask[ii, 0])
        out_current = inpainting_from_grad(im_current, mask_current, gradx_current, grady_current)
        # h x w x c -> c x h x w; make it contiguous before wrapping it as a tensor
        out_chw = np.ascontiguousarray(np.transpose(out_current, (2, 0, 1)))
        im_out[ii] = torch.from_numpy(out_chw).to(device=im_in.device, dtype=im_in.dtype)
    return im_out
def inpainting_from_grad(im_in, mask, gradx, grady):
    '''
    Recover the missing pixels of an image by integrating its gradients.

    An anchor line (a fully known column or row, or a column completed with
    fill_line) is chosen and the rest of the image is rebuilt from it by
    cumulative summation of the gradients. The integration helpers assume
    backward differences, e.g. gradx[:, j] = im[:, j] - im[:, j-1].

    Input:
        im_in: h x w x c (or h x w), masked image, numpy array
        mask: h x w, image mask, 1 represents missing value
        gradx: same shape as im_in, gradient along x-axis (axis 1)
        grady: same shape as im_in, gradient along y-axis (axis 0)
    Output:
        im_out: recovered image; known pixels come from im_in, missing pixels
                from the gradient integration.

    The arguments are not modified (the image and mask are copied before any
    in-place filling). When no complete line exists, the anchor column is drawn
    with the `random` module; seed it for reproducible choices.
    '''
    # Work on copies: fill_line writes into both the image column and the mask column.
    im_in = np.array(im_in, copy=True)
    mask = np.array(mask, copy=True)
    h, w = im_in.shape[:2]
    counts_h = np.sum(1 - mask, axis=0)  # known pixels per column, shape (w,)
    counts_w = np.sum(1 - mask, axis=1)  # known pixels per row, shape (h,)
    if np.any(counts_h[1:-1] == h):
        # An interior column is fully known: integrate gradx left/right from it.
        idx = find_first_index(counts_h[1:-1], h) + 1
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)
    elif np.any(counts_w[1:-1] == w):
        # An interior row is fully known: swap the two spatial axes so that row becomes a
        # column, integrate grady (the gradient along the swapped axis), then swap back.
        # swapaxes(0, 1) keeps the channel axis last, unlike .T which reverses all axes.
        idx = find_first_index(counts_w[1:-1], w) + 1
        im_out = fill_image_from_gradx(
            im_in.swapaxes(0, 1), mask.T, grady.swapaxes(0, 1), idx
        ).swapaxes(0, 1)
    else:
        # No complete interior line: pick a column (weighted by its number of known
        # pixels), complete it vertically from grady, then integrate gradx from it.
        idx = random.choices(list(range(1, w - 1)), k=1, weights=counts_h[1:-1])[0]
        im_in[:, idx] = fill_line(im_in[:, idx], mask[:, idx], grady[:, idx])
        im_out = fill_image_from_gradx(im_in, mask, gradx, idx)
    if im_in.ndim > mask.ndim:
        mask = mask[:, :, None]
    # Keep known pixels and take reconstructed values where pixels are missing;
    # weighting im_in by (1 - mask) means masked entries of im_in need not be zero.
    return im_in * (1 - mask) + im_out * mask
def fill_image_from_gradx(im_in, mask, gradx, idx):
    '''
    Rebuild a whole image from one known column and the horizontal gradient.

    Input:
        im_in: h x w x c (or h x w) array whose column `idx` holds valid values
        mask: not used by this function
        gradx: same shape as im_in; the integration assumes backward differences
               along axis 1, gradx[:, j] = im[:, j] - im[:, j-1] (gradx[:, 0] is never read)
        idx: index of the anchor column
    Output:
        array with the shape of im_in: column `idx` copied from im_in, every other
        column obtained by cumulative summation of gradx outward from the anchor.
    '''
    # Only the anchor column carries an absolute value; every other slot adds gradient only.
    seed = np.zeros_like(im_in)
    seed[:, idx] = im_in[:, idx]

    # Walk right of the anchor: x[j] = x[j-1] + gradx[j].
    steps_right = seed[:, idx:-1] + gradx[:, idx + 1:]
    right = steps_right.cumsum(axis=1)

    # Walk left of the anchor: x[j-1] = x[j] - gradx[j], accumulated outward and
    # then flipped back into left-to-right column order.
    steps_left = seed[:, idx:0:-1] - gradx[:, idx:0:-1]
    left = np.flip(steps_left.cumsum(axis=1), axis=1)

    anchor = np.expand_dims(im_in[:, idx], axis=1)  # h x 1 (x c)
    return np.concatenate([left, anchor, right], axis=1)
def fill_line(xx, mm, grad):
    '''
    Fill the missing entries of one line by integrating its gradient.

    Input:
        xx: n x c (or (n,)) array, masked vector; it is filled in place
        mm: (n,) array, mask, non-zero (normally 1) marks a missing value;
            every filled entry is reset to 0 in place
        grad: array with the shape of xx; the integration assumes backward
              differences along the line, grad[k] = x[k] - x[k-1] (grad[0] is never read)
    Output:
        xx itself, with every missing run reconstructed.
    Raises:
        ValueError: if no entry of the line is known (nothing to anchor on).

    Each run of missing entries is anchored on the known value just before it,
    or, for a run at the very start of the line, on the known value just after it.
    Masked entries of xx are overwritten, so they need not be zero.
    '''
    missing = np.asarray(mm) != 0
    if missing.all():
        raise ValueError('fill_line needs at least one known value to anchor the integration')
    # Locate maximal runs of missing entries: padding with False on both sides makes
    # np.diff emit +1 at each run start and -1 one past each run end.
    padded = np.concatenate(([False], missing, [False])).astype(np.int8)
    edges = np.flatnonzero(np.diff(padded))
    for start, stop in zip(edges[0::2], edges[1::2]):
        if start == 0:
            # Leading run: walk backwards from the first known value, x[k-1] = x[k] - grad[k].
            xx[:stop] = xx[stop] - np.cumsum(grad[stop:0:-1], axis=0)[::-1]
        else:
            # Interior or trailing run: walk forward from the known value before it,
            # x[k] = x[k-1] + grad[k].
            xx[start:stop] = xx[start - 1] + np.cumsum(grad[start:stop], axis=0)
        mm[start:stop] = 0
    return xx
def find_first_index(mm, value):
    '''
    Return the index of the first entry of `mm` equal to `value`.

    Input:
        mm: (n, ) array (any array-like is accepted)
        value: scalar
    Output:
        Index (int) along the first axis of the first match, or len(mm) when
        `value` does not occur, so the result can be used directly as a slice bound.
    '''
    arr = np.asarray(mm)
    # Vectorized comparison instead of a Python-level scan; argwhere lists matches
    # in C order, so its first row is the first match.
    hits = np.argwhere(arr == value)
    if hits.shape[0] == 0:
        return arr.shape[0]
    return int(hits[0, 0])
if __name__ == '__main__':
    # Demo: rebuild masked test images from their gradients and print the
    # largest absolute reconstruction error for each image.
    import sys
    from pathlib import Path

    # Make the parent directory of this file importable for the project modules below.
    sys.path.append(str(Path(__file__).resolve().parents[1]))
    from utils import util_image
    from datapipe.masks.train import process_mask

    mask_paths = list(Path('./testdata/inpainting/val/places/').glob('*mask*.png'))
    # The image path is the mask's stem cut at its last '_mask', plus '.png', in the same folder.
    image_paths = [p.parent / (p.stem.rsplit('_mask', 1)[0] + '.png') for p in mask_paths]

    for image_path, mask_path in zip(image_paths, mask_paths):
        image = util_image.imread(image_path, chn='rgb', dtype='float32')
        raw_mask = util_image.imread(mask_path, chn='rgb', dtype='float32')
        hole = process_mask(raw_mask[:, :, 0])
        grads = util_image.imgrad(image)
        masked = image * (1 - hole[:, :, None])
        recovered = inpainting_from_grad(masked, hole, grads['gradx'], grads['grady'])
        worst = np.abs(recovered - image).max()
        print('Error Max: {:.2e}'.format(worst))