from PIL import Image
import einops
import numpy as np
import torch
from hydra.utils import instantiate
from lightly.models import utils
# https://docs.lightly.ai/self-supervised-learning/examples/mae.html
from lightly.models.modules import MAEDecoderTIMM, MaskedVisionTransformerTIMM
from timm.models.vision_transformer import VisionTransformer
from huggingface_hub import PyTorchModelHubMixin

class MAE(torch.nn.Module, PyTorchModelHubMixin):

    def __init__(self, cfg):
        super().__init__()

        vit: VisionTransformer = instantiate(cfg.ssl_model.vit, img_size=cfg.ssl_aug.standard_view.output_size)
        self.patch_size = vit.patch_embed.patch_size[0]

        # Get MAE backbone
        self.backbone = MaskedVisionTransformerTIMM(vit=vit)
        self.sequence_length = self.backbone.sequence_length
        self.encoder_dim = vit.embed_dim  # for convenience later

        # Get decoder
        self.decoder = MAEDecoderTIMM(
            num_patches=vit.patch_embed.num_patches,
            patch_size=self.patch_size,
            embed_dim=vit.embed_dim,
            decoder_embed_dim=cfg.ssl_model.decoder.embed_dim,
            decoder_depth=cfg.ssl_model.decoder.depth,
            decoder_num_heads=cfg.ssl_model.decoder.num_heads,
            mlp_ratio=cfg.ssl_model.decoder.mlp_ratio,
            proj_drop_rate=cfg.ssl_model.decoder.dropout,
            attn_drop_rate=cfg.ssl_model.decoder.attention_dropout,
        )

        self.mask_ratio = cfg.ssl_model.mask_ratio  # saved as model parameter, not aug, since it is applied within the model
        self.criterion = torch.nn.MSELoss()
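    # Illustrative sketch of the cfg fields read in __init__ (hydra/omegaconf style).
    # Field names follow the attribute accesses above; the vit _target_ and the values
    # shown are assumptions, not the project's actual config:
    #
    #   ssl_aug:
    #     standard_view:
    #       output_size: 224
    #   ssl_model:
    #     vit:
    #       _target_: timm.models.vision_transformer.VisionTransformer
    #       patch_size: 16
    #     mask_ratio: 0.75
    #     decoder:
    #       embed_dim: 512
    #       depth: 8
    #       num_heads: 16
    #       mlp_ratio: 4.0
    #       dropout: 0.0
    #       attention_dropout: 0.0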
    def forward_encoder(self, images, idx_keep=None):
        return self.backbone.encode(images=images, idx_keep=idx_keep)

    def forward_decoder(self, x_encoded, idx_keep, idx_mask):
        # build decoder input
        batch_size = x_encoded.shape[0]
        x_decode = self.decoder.embed(x_encoded)
        x_masked = utils.repeat_token(self.decoder.mask_token, (batch_size, self.sequence_length))
        x_masked = utils.set_at_index(x_masked, idx_keep, x_decode.type_as(x_masked))

        # decoder forward pass
        x_decoded = self.decoder.decode(x_masked)

        # predict pixel values for masked tokens
        x_pred = utils.get_at_index(x_decoded, idx_mask)
        x_pred = self.decoder.predict(x_pred)
        return x_pred
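    # Shape reference for the two methods above (assuming 3-channel images and
    # lightly's defaults): x_encoded is [batch, num_kept_tokens, embed_dim];
    # x_pred is [batch, num_masked_tokens, patch_size**2 * 3], i.e. one flattened
    # predicted patch per masked token.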
    def training_step(self, batch, batch_idx):
        images = batch["image"]  # the batch contains only a single view
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)

        # decode and calculate loss (the encoder output is not used directly beyond this)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # idx_mask indexes the token sequence, which includes the class token at position 0;
        # the patchify output has no class token, so subtract 1 to convert token indices
        # to patch indices
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)
        return loss, x_encoded
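    # Worked example of the offset: with (say) a 224-pixel input and 16-pixel patches
    # there are 196 patches, so the token sequence has length 197 (class token at
    # index 0). A masked token index of 5 therefore corresponds to patch index 4 in
    # the patchify output.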
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        images = batch["image"]  # the batch contains only a single view
        batch_size = images.shape[0]
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch_size, self.sequence_length),
            mask_ratio=self.mask_ratio,
            device=images.device,
        )
        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get image patches for masked tokens
        patches = utils.patchify(images, self.patch_size)
        # must adjust idx_mask for the missing class token (see training_step)
        target = utils.get_at_index(patches, idx_mask - 1)

        loss = self.criterion(x_pred, target)
        return loss, None
    def predict_step(self, batch, batch_idx):
        idx_keep, idx_mask = self.mask_random_indices(batch)
        return self.predict(batch, idx_mask=idx_mask, idx_keep=idx_keep)

    def mask_random_indices(self, batch):
        idx_keep, idx_mask = utils.random_token_mask(
            size=(batch["image"].shape[0], self.sequence_length),  # (batch_size, seq_len)
            mask_ratio=self.mask_ratio,
            device=batch["image"].device,
        )
        return idx_keep, idx_mask
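    # Note (under lightly's default behaviour): random_token_mask never masks index 0,
    # the class token, so idx_mask contains token indices >= 1 and idx_keep always
    # includes index 0.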
    def predict(self, batch, idx_mask, idx_keep=None):
        # not used during training etc., only offered as a handy API
        # note the order of arguments: idx_mask comes first, as this is what most callers change
        # index 0 is the class token and is never masked
        # callers must add 1 to all patch indices before passing them in; this method assumes
        # that offset has already been applied
        assert idx_mask is not None

        if idx_keep is None:  # probably a caller providing only idx_mask, not using predict_step above
            # keep every token that is not masked (including the class token at index 0)
            all_indices = set(range(0, self.sequence_length))
            idx_keep = []
            for row in idx_mask:
                keep_row = list(all_indices - set(row.tolist()))
                idx_keep.append(keep_row)
            idx_keep = torch.tensor(idx_keep).to(idx_mask.device)

        images = batch["image"]
        batch_size = images.shape[0]

        x_encoded = self.forward_encoder(images=images, idx_keep=idx_keep)
        x_pred = self.forward_decoder(x_encoded=x_encoded, idx_keep=idx_keep, idx_mask=idx_mask)

        # get masked and reconstructed images
        im_masked, im_reconstructed = self.mask_and_reconstruct_images(mask=idx_mask, num_images=batch_size, y=x_pred, x=images)

        # calculate MSE (as in training_step, but with per-image rather than per-batch reduction)
        patches = utils.patchify(images, self.patch_size)  # does not change the batch dim
        target = utils.get_at_index(patches, idx_mask - 1)
        mse_per_patch = torch.nn.MSELoss(reduction="none")(x_pred, target)
        mse_per_image = mse_per_patch.view(batch_size, -1).mean(dim=1)  # reduce all dimensions but batch

        return {
            'id_str': batch['id_str'],
            'images': image_batch_to_pil_list(images),
            'encoded': x_encoded,
            'masked': image_batch_to_pil_list(im_masked),
            'reconstructed': image_batch_to_pil_list(im_reconstructed),
            'reconstruction_error': mse_per_image
        }
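    # Illustrative usage (index values are arbitrary): reconstruct three chosen patches
    # of the first image, remembering the +1 class-token offset on patch indices.
    #   patch_indices = torch.tensor([[4, 16, 41]])
    #   out = model.predict(batch, idx_mask=patch_indices + 1)
    #   out['reconstructed'][0]  # PIL image with those patches filled in by the decoder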
    def mask_and_reconstruct_images(self, mask, num_images, y, x):
        im_masked = self.patchify(x)  # still the original image, just reshaped into patches
        im_reconstructed = im_masked.clone()  # same for now, but will become the reconstructed images
        # if mask is None, both masked and reconstructed are just the original image, so do nothing
        # otherwise, replace the masked patches
        if mask is not None:
            for batch_index in range(num_images):
                # stop if we run out of images in the batch
                if batch_index >= x.shape[0]:
                    break
                # replace values with either 0 or the predicted fill values
                for mask_idx, token_idx in enumerate(mask[batch_index]):
                    im_masked[batch_index, token_idx - 1] = 0  # set masked patch to 0
                    im_reconstructed[batch_index, token_idx - 1, :] = y[batch_index, mask_idx, :]  # set masked patch to predicted pixels

        # depatchify i.e. reshape back to the original image layout
        im_masked = self.unpatchify(im_masked)
        im_reconstructed = self.unpatchify(im_reconstructed)
        return im_masked, im_reconstructed
    def unpatchify(self, x):
        # i.e. [b, h*w, p*p*c] -> [b, c, h*p, w*p], where p is patch size
        return einops.rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=int(np.sqrt(x.shape[1])),
            w=int(np.sqrt(x.shape[1])),
        )
    def patchify(self, x):
        # note: "h" in the pattern below is height // patch_size, i.e. the number of patches per side,
        # and p is the patch size
        # in more normal terms:
        # x is an image batch of shape [b, c, h, w]
        # reshape to [b, (h // p) * (w // p), p * p * c], i.e. one flattened patch per row
        return einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_size,
            p2=self.patch_size,
            b=x.shape[0],
            c=3,
            h=x.shape[-2] // self.patch_size,
            w=x.shape[-1] // self.patch_size,
        )
    def encoder(self):
        return self.backbone.vit  # hopefully equivalent to self.backbone.encode(x, idx_keep=all)

def image_batch_to_pil_list(images):
    # convert a [b, c, h, w] float tensor with values in [0, 1] to a list of PIL images
    images = einops.rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(images, 0, 1) * 255
    images = images.cpu().numpy()
    images = images.astype(np.uint8)
    return [Image.fromarray(im) for im in images]
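
# Minimal smoke test for the module-level helper and for the patchify/unpatchify
# rearrangements (written inline, since the methods themselves need a configured
# MAE instance). The 224-pixel image size and 16-pixel patch size are illustrative.
if __name__ == "__main__":
    dummy = torch.rand(2, 3, 224, 224)  # fake [0, 1] image batch

    # tensor batch -> list of PIL images
    pil_images = image_batch_to_pil_list(dummy)
    print([im.size for im in pil_images])  # [(224, 224), (224, 224)]

    # patchify/unpatchify round trip, matching the patterns used inside MAE
    patch_size = 16
    patches = einops.rearrange(
        dummy, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size
    )
    print(patches.shape)  # torch.Size([2, 196, 768])
    restored = einops.rearrange(
        patches, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
        p1=patch_size, p2=patch_size, c=3, h=224 // patch_size, w=224 // patch_size,
    )
    assert torch.equal(dummy, restored)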