import math
import os
from typing import List

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from utils import ptp_utils


def save_binary_masks(
    attention_masks,
    word: str,
    res: int = 16,
    orig_image=None,
    save_path=None,
    txt_under_img: bool = False,
):
    if save_path:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    if isinstance(attention_masks, torch.Tensor):
        attention_masks = attention_masks.squeeze().cpu().numpy()
    elif isinstance(attention_masks, np.ndarray):
        attention_masks = attention_masks.squeeze()
    else:
        raise TypeError("attention_masks must be torch.Tensor or np.ndarray")
    # binarize and scale to 8-bit so PIL can treat it as a greyscale image
    mask = (attention_masks > 0).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask, mode='L')
    # nearest-neighbour resampling keeps the upscaled mask strictly binary
    mask_image = mask_image.resize((256, 256), resample=Image.NEAREST)
    mask_image = mask_image.convert('RGB')
    mask_np = np.array(mask_image)
    if txt_under_img:
        mask_with_text = ptp_utils.text_under_image(mask_np, word)
        final_image = Image.fromarray(mask_with_text)
    else:
        final_image = Image.fromarray(mask_np)
    final_image = final_image.resize((256, 256), resample=Image.BILINEAR)
    if save_path:
        final_image.save(save_path)
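
# Usage sketch (illustrative only, not part of the original pipeline): given a 16x16
# attention mask that was already thresholded upstream, the call below would write a
# 256x256 binary PNG for the hypothetical word "cat":
#
#     mask = torch.rand(16, 16) > 0.5
#     save_binary_masks(mask, word="cat", save_path="outputs/cat_mask.png")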


def show_cross_attention(prompt: str,
                         attention_store,
                         tokenizer,
                         res: int,
                         from_where: List[str],
                         subject_words: List[str],
                         bs: int = 2,
                         select: int = 0,
                         orig_image=None,
                         text_under_img: bool = True):
    tokens = tokenizer.encode(prompt)
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(attention_store, res, from_where, True, select, bs).detach().cpu()
    images = []
    token_texts = [decoder(int(token)) for token in tokens]
    token_indices = [i for i, text in enumerate(token_texts) if text in subject_words]
    # show spatial attention for the indices of the tokens to strengthen
    for i in token_indices:
        image = attention_maps[:, :, i]  # (res, res)
        image = show_image_relevance(image, orig_image)
        image = image.astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((res ** 2, res ** 2)))
        if text_under_img:
            image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    ptp_utils.view_images(np.stack(images, axis=0))
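
# Usage sketch (illustrative; the names `controller`, `pipe`, and `generated_image` are
# hypothetical and assume an AttentionStore was registered on the UNet during sampling,
# as in prompt-to-prompt):
#
#     show_cross_attention("a photo of a cat on the beach", controller, pipe.tokenizer,
#                          res=16, from_where=["up", "down"], subject_words=["cat"],
#                          orig_image=generated_image)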


def show_image_relevance(image_relevance, image: Image.Image, relevance_res=32):
    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    image = image.resize((relevance_res ** 2, relevance_res ** 2))
    image = np.array(image)

    image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1])  # (1, 1, res, res)
    image_relevance = image_relevance.cuda()  # float16 interpolation is not supported on the CPU
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevance_res ** 2, mode='bilinear')  # (1, 1, relevance_res**2, relevance_res**2)
    image_relevance = image_relevance.cpu()  # send it back to the CPU
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image_relevance = image_relevance.reshape(relevance_res ** 2, relevance_res ** 2)

    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis
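
# Usage sketch (illustrative; requires a CUDA device because of the .cuda() call above):
#
#     relevance = torch.rand(16, 16, dtype=torch.float16)
#     photo = Image.fromarray(np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8))
#     heatmap = show_image_relevance(relevance, photo)  # BGR uint8 heatmap overlay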


def get_image_grid(images: List[Image.Image]) -> Image.Image:
    num_images = len(images)
    cols = int(math.ceil(math.sqrt(num_images)))
    rows = int(math.ceil(num_images / cols))
    width, height = images[0].size
    grid_image = Image.new('RGB', (cols * width, rows * height))
    for i, img in enumerate(images):
        x = i % cols
        y = i // cols
        grid_image.paste(img, (x * width, y * height))
    return grid_image
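
# Usage sketch (illustrative): tile four solid-colour 64x64 images into a 2x2 grid.
#
#     tiles = [Image.new('RGB', (64, 64), c) for c in ('red', 'green', 'blue', 'white')]
#     get_image_grid(tiles).save("outputs/grid.png")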


def aggregate_attention(attention_store, res: int, from_where: List[str], is_cross: bool, select: int, bs: int = 2):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                # keep only maps at the requested resolution and the `select`-th prompt
                cross_maps = item.reshape(bs, -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]  # average over heads and layers
    return out.cpu()
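
# Shape sketch (illustrative, assuming the usual prompt-to-prompt AttentionStore layout):
#
#     item = torch.rand(16, 256, 77)          # (bs * heads, res**2, tokens) with bs=2, heads=8
#     maps = item.reshape(2, -1, 16, 16, 77)  # (bs, heads, res, res, tokens)
#     selected = maps[0]                      # maps for the selected prompt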


def show_self_attention(attention_stores, from_where: str, layers: int):
    top_components = []
    # the first self-attention map
    first_attention_map = attention_stores[0].get_average_attention()[from_where][layers][:8].mean(dim=0)
    U, S, V = torch.svd(first_attention_map.to(torch.float32))
    top_U = U[:, :6]
    top_components.append(top_U)
    for attention_store in attention_stores:
        attention_map = attention_store.get_average_attention()[from_where][layers][8:].mean(dim=0).to(torch.float32)
        U, S, V = torch.svd(attention_map)
        top_U = U[:, :6]
        top_components.append(top_U)
    # visualise the top-6 left singular vectors of each self-attention map as 16x16 images
    for batch_idx, components in enumerate(top_components):
        plt.figure(figsize=(24, 4))
        for comp_idx in range(6):
            plt.subplot(1, 6, comp_idx + 1)
            component = components[:, comp_idx].reshape(16, 16).to('cpu')
            plt.imshow(component, cmap='viridis')
            # plt.colorbar()
            plt.axis('off')
            plt.title(f'prompt {batch_idx + 1} Top {comp_idx + 1}')
        plt.tight_layout()
        plt.show()
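

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, synthetic data): write a binary mask for a
    # random 16x16 attention map and tile copies of it into a grid. The real entry points
    # (show_cross_attention / show_self_attention) need a populated AttentionStore from a
    # diffusion sampling run and are not exercised here.
    dummy_attention = torch.rand(16, 16)
    save_binary_masks(dummy_attention > 0.5, word="subject", save_path="outputs/subject_mask.png")
    mask_img = Image.open("outputs/subject_mask.png")
    get_image_grid([mask_img] * 4).save("outputs/subject_mask_grid.png")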