# Detail-plus-plus: utils/ptp_utils.py
# Author: Westlake-AGI-Lab
# Last change: "Update utils/ptp_utils.py" (commit 2993b8e, verified)
import cv2
import numpy as np
import torch
from PIL import Image
from einops import rearrange
import math
from typing import Union, Tuple, List
def aggregate_attention(attention_store, res: List[int], from_where: List[str], is_cross: bool, select: int, batch_size: int = 1):
    """Average the stored attention maps at one spatial size for one batch element.

    ``attention_store.get_average_attention()`` must return a mapping whose
    ``"<location>_cross"`` / ``"<location>_self"`` entries are iterables of tensors.
    For every location in ``from_where``, each tensor whose second dimension equals
    ``res[0] * res[1]`` is reshaped to ``(batch_size, -1, res[0], res[1], C)``, where
    ``C`` is the tensor's last dimension. The ``select``-th batch element is kept.
    All kept maps are then concatenated along their leading axis and averaged over it.

    Args:
        attention_store: object providing ``get_average_attention()``.
        res: two-element size, unpacked as ``(W, H)``.
        from_where: location prefixes used to build the mapping keys.
        is_cross: read the ``*_cross`` entries if True, ``*_self`` otherwise.
        select: batch index kept after reshaping.
        batch_size: number of batch elements folded into each tensor's leading axis.

    Returns:
        CPU tensor of shape ``(res[0], res[1], C)``.

    If no tensor matches, ``torch.cat`` fails on the empty list.

    NOTE(review): the reshape puts ``res[0]`` (named W) on the row axis and ``res[1]``
    (named H) on the column axis. For non-square maps, confirm that callers pass
    ``res`` in the order this layout expects.
    """
    averaged = attention_store.get_average_attention()
    width, height = res
    target_tokens = height * width
    kind = 'cross' if is_cross else 'self'

    selected = [
        layer_map.reshape(batch_size, -1, width, height, layer_map.shape[-1])[select]
        for place in from_where
        for layer_map in averaged[f"{place}_{kind}"]
        if layer_map.shape[1] == target_tokens
    ]

    stacked = torch.cat(selected, dim=0)
    mean_map = stacked.sum(0) / stacked.shape[0]
    return mean_map.cpu()
def aggregate_attention_intermediate(
        attention_store,
        res: int,
        from_where: List[str],
        from_res: List[int],
        is_cross: bool,
        select: int,
        batch_size: int = 1) -> torch.Tensor:
    """Aggregate attention maps from several square resolutions onto one ``res`` x ``res`` grid.

    Every stored tensor is selected if its second dimension equals ``r ** 2`` for some
    ``r`` in ``from_res``. Each selected tensor is:

    1. reshaped to ``(batch_size, -1, r, r, C)`` (``C`` = its last dimension);
    2. reduced to batch element ``select``;
    3. resized to ``res`` x ``res`` with nearest-neighbour interpolation.

    All resized maps are concatenated along their leading axis and averaged over it.

    Args:
        attention_store: object providing ``get_average_attention()``, which must return
            a mapping whose ``"<location>_cross"`` / ``"<location>_self"`` entries are
            iterables of tensors. Presumably each tensor is ``(batch * heads, tokens, C)``.
            TODO confirm against the attention store.
        res: side length of the output grid.
        from_where: location prefixes used to build the mapping keys.
        from_res: side lengths of the square maps to include.
        is_cross: read the ``*_cross`` entries if True, ``*_self`` otherwise.
        select: batch index kept after reshaping.
        batch_size: number of batch elements folded into each tensor's leading axis.
            The default of 1 reproduces the previous hard-coded behaviour.

    Returns:
        Tensor of shape ``(res, res, C)``. It stays on the device of the stored maps
        and is not moved to CPU.

    Raises:
        ValueError: if no stored map matches any resolution in ``from_res``.
    """
    kind = 'cross' if is_cross else 'self'
    attention_maps = attention_store.get_average_attention()
    wanted_tokens = {r ** 2 for r in from_res}

    out = []
    for location in from_where:
        for item in attention_maps[f"{location}_{kind}"]:
            if item.shape[1] not in wanted_tokens:
                continue
            # Only perfect squares from from_res pass the filter, so isqrt is exact.
            cur_res = math.isqrt(item.shape[1])
            cross_maps = item.reshape(batch_size, -1, cur_res, cur_res, item.shape[-1])[select]
            # Move channels before the spatial axes so interpolate resizes (h, w) only.
            cross_maps = cross_maps.permute(0, 3, 1, 2)
            cross_maps = torch.nn.functional.interpolate(cross_maps, size=(res, res), mode='nearest')
            out.append(cross_maps.permute(0, 2, 3, 1))

    if not out:
        raise ValueError(
            f"no {kind}-attention maps with resolution in {list(from_res)} "
            f"found for locations {list(from_where)}"
        )

    out = torch.cat(out, dim=0)  # (total selected maps * heads, res, res, C)
    return out.sum(0) / out.shape[0]
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return ``image`` stacked on top of a white strip captioned with ``text``.

    Args:
        image: array unpacked as ``(H, W, C)``; it is copied into a new ``uint8`` canvas.
        text: caption, drawn with ``cv2.FONT_HERSHEY_SIMPLEX`` at scale 1, thickness 2.
        text_color: colour forwarded unchanged to ``cv2.putText``.

    Returns:
        New ``uint8`` array of shape ``(H + int(0.2 * H), W, C)``. The first ``H`` rows
        hold ``image``; the added rows start white (255) and receive the caption,
        centred horizontally.
    """
    height, width, channels = image.shape
    strip_height = int(height * .2)  # caption band is 20% of the image height

    canvas = np.full((height + strip_height, width, channels), 255, dtype=np.uint8)
    canvas[:height] = image

    face = cv2.FONT_HERSHEY_SIMPLEX
    (label_w, label_h), _ = cv2.getTextSize(text, face, 1, 2)

    # Centre horizontally; the putText origin sits half a text-height above the bottom edge.
    origin = ((width - label_w) // 2, height + strip_height - label_h // 2)
    cv2.putText(canvas, text, origin, face, 1, text_color, 2)
    return canvas
def view_images(images, num_rows=1, offset_ratio=0.02):
if type(images) is list:
if isinstance(images[0], Image.Image):
h, w = images[0].size
images = [np.array(img) for img in images]
num_empty = len(images) % num_rows
elif images.ndim == 4:
num_empty = images.shape[0] % num_rows
else:
images = [images]
num_empty = 0
empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
num_items = len(images)
h, w, c = images[0].shape
offset = int(h * offset_ratio)
num_cols = num_items // num_rows
image_ = np.ones((h * num_rows + offset * (num_rows - 1),
w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
for i in range(num_rows):
for j in range(num_cols):
image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
i * num_cols + j]
pil_img = Image.fromarray(image_)