import logging
from typing import Any, Dict

import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from PIL import Image
from torch import nn

logger: logging.Logger = logging.getLogger(__name__)
class LeffaTransform(nn.Module):
    """Preprocesses a batch of source images, reference images, masks, and
    DensePose maps into normalized, batched tensors for the Leffa pipeline."""

    def __init__(
        self,
        height: int = 1024,
        width: int = 768,
        dataset: str = "virtual_tryon",  # virtual_tryon or pose_transfer
    ):
        super().__init__()
        self.height = height
        self.width = width
        self.dataset = dataset
        self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=8,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        batch_size = len(batch["src_image"])
        src_image_list = []
        ref_image_list = []
        mask_list = []
        densepose_list = []
        for i in range(batch_size):
            # 1. Get the original data.
            src_image = batch["src_image"][i]
            ref_image = batch["ref_image"][i]
            mask = batch["mask"][i]
            densepose = batch["densepose"][i]
            # 2. Resize and normalize each input.
            src_image = self.vae_processor.preprocess(
                src_image, self.height, self.width)[0]
            ref_image = self.vae_processor.preprocess(
                ref_image, self.height, self.width)[0]
            mask = self.mask_processor.preprocess(
                mask, self.height, self.width)[0]
            if self.dataset in ["pose_transfer"]:
                densepose = densepose.resize(
                    (self.width, self.height), Image.NEAREST)
            else:
                densepose = self.vae_processor.preprocess(
                    densepose, self.height, self.width
                )[0]
            src_image = self.prepare_image(src_image)
            ref_image = self.prepare_image(ref_image)
            mask = self.prepare_mask(mask)
            if self.dataset in ["pose_transfer"]:
                densepose = self.prepare_densepose(densepose)
            else:
                densepose = self.prepare_image(densepose)
            src_image_list.append(src_image)
            ref_image_list.append(ref_image)
            mask_list.append(mask)
            densepose_list.append(densepose)
        src_image = torch.cat(src_image_list, dim=0)
        ref_image = torch.cat(ref_image_list, dim=0)
        mask = torch.cat(mask_list, dim=0)
        densepose = torch.cat(densepose_list, dim=0)
        batch["src_image"] = src_image
        batch["ref_image"] = ref_image
        batch["mask"] = mask
        batch["densepose"] = densepose
        return batch
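
    # The helpers below accept a PIL image, a numpy array, a tensor, or a
    # list of these, and return batched float tensors: images and densepose
    # maps scaled to [-1, 1], masks binarized to {0, 1}.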
    def prepare_image(self, image):
        if isinstance(image, torch.Tensor):
            # Batch a single image.
            if image.ndim == 3:
                image = image.unsqueeze(0)
            image = image.to(dtype=torch.float32)
        else:
            # Preprocess PIL images or numpy arrays.
            if isinstance(image, (Image.Image, np.ndarray)):
                image = [image]
            if isinstance(image, list) and isinstance(image[0], Image.Image):
                image = [np.array(i.convert("RGB"))[None, :] for i in image]
                image = np.concatenate(image, axis=0)
            elif isinstance(image, list) and isinstance(image[0], np.ndarray):
                image = np.concatenate([i[None, :] for i in image], axis=0)
            # NHWC -> NCHW, then scale [0, 255] to [-1, 1].
            image = image.transpose(0, 3, 1, 2)
            image = torch.from_numpy(image).to(
                dtype=torch.float32) / 127.5 - 1.0
        return image
    def prepare_mask(self, mask):
        if isinstance(mask, torch.Tensor):
            if mask.ndim == 2:
                # Batch and add a channel dim for a single mask.
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.ndim == 3 and mask.shape[0] == 1:
                # Single mask: the 0th dimension is treated as an
                # existing batch size of 1.
                mask = mask.unsqueeze(0)
            elif mask.ndim == 3 and mask.shape[0] != 1:
                # Batch of masks: the 0th dimension is treated as the
                # batching dimension.
                mask = mask.unsqueeze(1)
            # Binarize the mask.
            mask[mask < 0.5] = 0
            mask[mask >= 0.5] = 1
        else:
            # Preprocess PIL masks or numpy arrays.
            if isinstance(mask, (Image.Image, np.ndarray)):
                mask = [mask]
            if isinstance(mask, list) and isinstance(mask[0], Image.Image):
                mask = np.concatenate(
                    [np.array(m.convert("L"))[None, None, :] for m in mask],
                    axis=0,
                )
                mask = mask.astype(np.float32) / 255.0
            elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
                mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
            mask[mask < 0.5] = 0
            mask[mask >= 0.5] = 1
            mask = torch.from_numpy(mask)
        return mask
    def prepare_densepose(self, densepose):
        """For internal (Meta) DensePose, the first and second channels are
        normalized to [0, 1] by dividing by 255.0, and the third channel by
        24.0 (DensePose defines 24 body-part labels)."""
        if isinstance(densepose, torch.Tensor):
            # Batch a single densepose map.
            if densepose.ndim == 3:
                densepose = densepose.unsqueeze(0)
            densepose = densepose.to(dtype=torch.float32)
        else:
            # Preprocess PIL images or numpy arrays.
            if isinstance(densepose, (Image.Image, np.ndarray)):
                densepose = [densepose]
            if isinstance(densepose, list) and isinstance(
                densepose[0], Image.Image
            ):
                densepose = [np.array(i.convert("RGB"))[None, :]
                             for i in densepose]
                densepose = np.concatenate(densepose, axis=0)
            elif isinstance(densepose, list) and isinstance(densepose[0], np.ndarray):
                densepose = np.concatenate(
                    [i[None, :] for i in densepose], axis=0)
            # NHWC -> NCHW, per-channel normalization to [0, 1].
            densepose = densepose.transpose(0, 3, 1, 2)
            densepose = densepose.astype(np.float32)
            densepose[:, 0:2, :, :] /= 255.0
            densepose[:, 2:3, :, :] /= 24.0
            # Rescale [0, 1] to [-1, 1].
            densepose = torch.from_numpy(densepose).to(
                dtype=torch.float32) * 2.0 - 1.0
        return densepose
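

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# It assumes each batch entry is a list of PIL images, one per sample; the
# file paths below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    transform = LeffaTransform(height=1024, width=768, dataset="virtual_tryon")
    batch = {
        "src_image": [Image.open("person.jpg")],     # hypothetical path
        "ref_image": [Image.open("garment.jpg")],    # hypothetical path
        "mask": [Image.open("agnostic_mask.png")],   # hypothetical path
        "densepose": [Image.open("densepose.png")],  # hypothetical path
    }
    batch = transform(batch)
    # Images and densepose come back as float tensors in [-1, 1], the mask
    # as a binary {0, 1} tensor, each shaped (N, C, H, W).
    print(batch["src_image"].shape, batch["mask"].shape)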