import math
from typing import List, Tuple, Union

import cv2
import numpy as np
import torch
from einops import rearrange
from PIL import Image
def aggregate_attention(attention_store, res: List[int], from_where: List[str],
                        is_cross: bool, select: int, batch_size: int = 1):
    """Aggregates attention across the selected UNet blocks and heads at one
    target resolution, returning an (H, W, tokens) tensor on the CPU."""
    out = []
    attention_maps = attention_store.get_average_attention()
    res_W, res_H = res
    num_pixels = res_H * res_W
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                # Query pixels are stored row-major, so the spatial axes unpack as (H, W).
                cross_maps = item.reshape(batch_size, -1, res_H, res_W, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]  # mean over all collected heads/layers
    return out.cpu()
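# Minimal usage sketch (an assumption, not part of the pipeline above):
# `_FakeStore` stands in for a prompt-to-prompt style AttentionStore that UNet
# hooks would fill during sampling. Shapes mimic 8 heads, a 16x16 attention
# map, and 77 text tokens per layer.
class _FakeStore:
    def get_average_attention(self):
        return {"down_cross": [torch.rand(8, 16 * 16, 77)],
                "mid_cross": [torch.rand(8, 16 * 16, 77)],
                "up_cross": [torch.rand(8, 16 * 16, 77)]}

def _demo_aggregate_attention():
    maps = aggregate_attention(_FakeStore(), res=[16, 16],
                               from_where=["down", "mid", "up"],
                               is_cross=True, select=0)
    print(maps.shape)  # torch.Size([16, 16, 77])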
def aggregate_attention_intermediate(
        attention_store,
        res: int,
        from_where: List[str],
        from_res: List[int],
        is_cross: bool,
        select: int) -> torch.Tensor:
    """Aggregates attention across the different layers and heads, resampling
    maps from each source resolution in `from_res` to the common `res`."""
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = [r ** 2 for r in from_res]
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] in num_pixels:
                cur_res = int(math.sqrt(item.shape[1]))
                cross_maps = item.reshape(1, -1, cur_res, cur_res, item.shape[-1])[select]
                # Resample every map to the target resolution before averaging.
                cross_maps = rearrange(cross_maps, 'b h w c -> b c h w')
                cross_maps = torch.nn.functional.interpolate(cross_maps, size=(res, res), mode='nearest')
                cross_maps = rearrange(cross_maps, 'b c h w -> b h w c')
                out.append(cross_maps)
    out = torch.cat(out, dim=0)  # e.g. (40, 16, 16, 77): all heads from all matching layers
    out = out.sum(0) / out.shape[0]  # mean over heads/layers
    return out
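# Usage sketch under the same fake-store assumption as above, but mixing
# layers at two source resolutions (16x16 and 32x32) that are resampled onto
# a common 16x16 grid before averaging.
def _demo_aggregate_attention_intermediate():
    class _Store:
        def get_average_attention(self):
            return {"up_cross": [torch.rand(8, 16 * 16, 77),
                                 torch.rand(8, 32 * 32, 77)]}
    maps = aggregate_attention_intermediate(_Store(), res=16, from_where=["up"],
                                            from_res=[16, 32], is_cross=True, select=0)
    print(maps.shape)  # torch.Size([16, 16, 77])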
def text_under_image(image: np.ndarray, text: str,
                     text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Appends a white strip below `image` and centers `text` inside it."""
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img
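# Usage sketch: caption a plain gray canvas; the returned array gains a white
# strip (20% of the original height) below the input pixels.
def _demo_text_under_image():
    canvas = np.full((256, 256, 3), 200, dtype=np.uint8)
    labeled = text_under_image(canvas, "a photo of a cat")
    print(labeled.shape)  # (307, 256, 3): 256 rows plus a 51-pixel strip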
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tiles a list (or 4-D array) of images into a grid and returns a PIL image."""
    if type(images) is list:
        if isinstance(images[0], Image.Image):
            images = [np.array(img) for img in images]
        # Pad with blank frames so the count divides evenly into `num_rows` rows.
        num_empty = (num_rows - len(images) % num_rows) % num_rows
    elif images.ndim == 4:
        num_empty = (num_rows - images.shape[0] % num_rows) % num_rows
    else:
        images = [images]
        num_empty = 0
    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)
    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h,
                   j * (w + offset): j * (w + offset) + w] = images[i * num_cols + j]
    pil_img = Image.fromarray(image_)
    return pil_img
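# Usage sketch: tile six random 64x64 frames into a 2x3 grid with small white
# gutters; view_images hands back a PIL.Image ready to save or display.
def _demo_view_images():
    frames = [np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
              for _ in range(6)]
    grid = view_images(frames, num_rows=2)
    grid.save("grid.png")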