Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import scipy.ndimage | |
| import torch | |
| from PIL import ImageFilter | |
| from scipy.ndimage import binary_closing, binary_fill_holes | |
| from .image_convert import pil2tensor, tensor2pil | |
| def combine_mask(destination, source, x, y): | |
| output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() | |
| source = source.reshape((-1, source.shape[-2], source.shape[-1])) | |
| left, top = ( | |
| x, | |
| y, | |
| ) | |
| right, bottom = ( | |
| min(left + source.shape[-1], destination.shape[-1]), | |
| min(top + source.shape[-2], destination.shape[-2]), | |
| ) | |
| visible_width, visible_height = ( | |
| right - left, | |
| bottom - top, | |
| ) | |
| source_portion = source[:, :visible_height, :visible_width] | |
| destination_portion = destination[:, top:bottom, left:right] | |
| # operation == "subtract": | |
| output[:, top:bottom, left:right] = destination_portion - source_portion | |
| output = torch.clamp(output, 0.0, 1.0) | |
| return output | |
| def grow_mask(mask, expand, tapered_corners): | |
| if expand == 0: | |
| return mask | |
| device = mask.device | |
| mask = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) # 添加通道维度 | |
| # 创建卷积核 | |
| c = 0.0 if tapered_corners else 1.0 | |
| kernel = torch.tensor([[c, 1.0, c], [1.0, 1.0, 1.0], [c, 1.0, c]], device=device).unsqueeze(0).unsqueeze(0) | |
| # 计算填充 | |
| padding = abs(expand) | |
| if expand > 0: | |
| # 膨胀操作 | |
| mask = torch.nn.functional.pad(mask, (padding, padding, padding, padding), mode='constant', value=0) | |
| mask = torch.nn.functional.conv2d(mask, kernel, padding=1, dilation=expand) | |
| else: | |
| # 腐蚀操作 | |
| mask = 1 - mask | |
| mask = torch.nn.functional.pad(mask, (padding, padding, padding, padding), mode='constant', value=1) | |
| mask = torch.nn.functional.conv2d(mask, kernel, padding=1, dilation=-expand) | |
| mask = 1 - mask | |
| # 移除额外的padding | |
| if padding > 0: | |
| mask = mask[:, :, padding:-padding, padding:-padding] | |
| # 将结果转回原始形状 | |
| output = mask.squeeze(1) | |
| return torch.clamp(output, 0.0, 1.0) | |
| def fill_holes(mask): | |
| holemask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu() | |
| out = [] | |
| for m in holemask: | |
| mask_np = m.numpy() | |
| binary_mask = mask_np > 0 | |
| struct = np.ones((5, 5)) | |
| closed_mask = binary_closing(binary_mask, structure=struct, border_value=1) | |
| filled_mask = binary_fill_holes(closed_mask) | |
| output = filled_mask.astype(np.float32) * 255 # type: ignore | |
| output = torch.from_numpy(output) | |
| out.append(output) | |
| mask = torch.stack(out, dim=0) | |
| mask = torch.clamp(mask, 0.0, 1.0) | |
| return mask | |
| def invert_mask(mask): | |
| return 1.0 - mask | |
| def expand_mask(mask, expand, tapered_corners): | |
| c = 0 if tapered_corners else 1 | |
| kernel = np.array([[c, 1, c], [1, 1, 1], [c, 1, c]]) | |
| device = mask.device | |
| mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu() | |
| out = [] | |
| for m in mask: | |
| output = m.numpy() | |
| for _ in range(abs(expand)): | |
| if expand < 0: | |
| output = scipy.ndimage.grey_erosion(output, footprint=kernel) | |
| else: | |
| output = scipy.ndimage.grey_dilation(output, footprint=kernel) | |
| output = torch.from_numpy(output) | |
| out.append(output) | |
| return torch.stack(out, dim=0).to(device) | |
| def blur_mask(mask, radius): | |
| pil_image = tensor2pil(mask) | |
| return pil2tensor(pil_image.filter(ImageFilter.GaussianBlur(radius))) | |
| def solid_mask(width, height, value=1): | |
| return torch.full((1, height, width), value, dtype=torch.float32, device='cpu') | |
| def mask_floor(mask, threshold: float = 0.99): | |
| # 将遮罩二值化,大于等于阈值的设为1,小于阈值的设为0 | |
| return (mask >= threshold).to(mask.dtype) | |
| def mask_unsqueeze(mask): | |
| # 调整遮罩的维度,确保输出的遮罩形状为 B1HW | |
| if len(mask.shape) == 3: # BHW -> B1HW | |
| mask = mask.unsqueeze(1) | |
| elif len(mask.shape) == 2: # HW -> B1HW | |
| mask = mask.unsqueeze(0).unsqueeze(0) | |
| return mask | |