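"""Visualization utilities for cross- and self-attention maps.

Helpers to save binary attention masks, overlay per-token cross-attention
heatmaps on an image, tile images into a grid, and plot the top SVD
components of self-attention maps.
"""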
import math
import os
from typing import List

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from utils import ptp_utils


def save_binary_masks(
    attention_masks,
    word: str,
    res: int = 16,
    orig_image=None,
    save_path=None,
    txt_under_img: bool = False,
):
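    """Binarize `attention_masks` and save the result as a 256x256 RGB image.

    `res` and `orig_image` are accepted for call-site compatibility but are
    currently unused. When `txt_under_img` is True, `word` is rendered under
    the mask via `ptp_utils.text_under_image`.
    """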
    if save_path:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

    if isinstance(attention_masks, torch.Tensor):
        attention_masks = attention_masks.squeeze().cpu().numpy()
    elif isinstance(attention_masks, np.ndarray):
        attention_masks = attention_masks.squeeze()
    else:
        raise TypeError("attention_masks must be torch.Tensor or np.ndarray")

    mask = (attention_masks > 0).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask, mode='L')
    mask_image = mask_image.resize((256, 256), resample=Image.NEAREST)  # NEAREST keeps the mask strictly binary
    mask_image = mask_image.convert('RGB')
    mask_np = np.array(mask_image)
    if txt_under_img:
        mask_with_text = ptp_utils.text_under_image(mask_np, word)
        final_image = Image.fromarray(mask_with_text)
    else:
        final_image = Image.fromarray(mask_np)
    final_image = final_image.resize((256, 256), resample=Image.BILINEAR)  # back to 256x256 in case the text strip changed the height
    if save_path:
        final_image.save(save_path)


def show_cross_attention(prompt: str,
                         attention_store,
                         tokenizer,
                         res: int,
                         from_where: List[str],
                         subject_words: List[str],
                         bs: int = 2,
                         select: int = 0,
                         orig_image=None,
                         text_under_img: bool = True):
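    """Display per-token cross-attention heatmaps overlaid on `orig_image`.

    Only tokens whose decoded text appears in `subject_words` are visualized.
    """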
    tokens = tokenizer.encode(prompt)
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(attention_store, res, from_where, True, select, bs).detach().cpu()
    images = []
    token_texts = [decoder(int(token)) for token in tokens]
    token_indices = [i for i, text in enumerate(token_texts) if text in subject_words]

    # show spatial attention for indices of tokens to strengthen
    for i in token_indices:
        image = attention_maps[:, :, i]         # (res, res) map for token i
        image = show_image_relevance(image, orig_image)
        image = image.astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((res ** 2, res ** 2)))
        if text_under_img:
            image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
        images.append(image)

    if images:  # np.stack raises on an empty list (no subject tokens matched)
        ptp_utils.view_images(np.stack(images, axis=0))


def show_image_relevance(image_relevance, image: Image.Image, relevance_res=32):
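    """Overlay a relevance map on `image` using a JET colormap; returns a uint8 array."""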
    # create heatmap from mask on image
    def show_cam_on_image(img, mask):
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        cam = heatmap + np.float32(img)
        cam = cam / np.max(cam)
        return cam

    image = image.resize((relevance_res ** 2, relevance_res ** 2))
    image = np.array(image)

    image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1])  # (1, 1, res, res)
    # bilinear interpolation is not implemented for float16 on the CPU, so run
    # on the GPU when available and upcast to float32 otherwise
    if torch.cuda.is_available():
        image_relevance = image_relevance.cuda()
    else:
        image_relevance = image_relevance.float()
    image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevance_res ** 2, mode='bilinear')
    image_relevance = image_relevance.cpu()
    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
    image_relevance = image_relevance.reshape(relevance_res ** 2, relevance_res ** 2)
    image = (image - image.min()) / (image.max() - image.min())
    vis = show_cam_on_image(image, image_relevance)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    return vis


def get_image_grid(images: List[Image.Image]) -> Image.Image:
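    """Tile equally sized images into a roughly square grid.

    Example (hypothetical images):
        grid = get_image_grid([img_a, img_b, img_c])
        grid.save('grid.png')
    """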
    num_images = len(images)
    cols = int(math.ceil(math.sqrt(num_images)))
    rows = int(math.ceil(num_images / cols))
    width, height = images[0].size
    grid_image = Image.new('RGB', (cols * width, rows * height))
    for i, img in enumerate(images):
        x = i % cols
        y = i // cols
        grid_image.paste(img, (x * width, y * height))
    return grid_image


def aggregate_attention(attention_store, res: int, from_where: List[str], is_cross: bool, select: int, bs: int = 2):
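    """Average attention maps with spatial size `res`**2 over heads and layers.

    Returns a (res, res, num_tokens) tensor for cross-attention, or
    (res, res, res**2) for self-attention, for batch element `select`.
    """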
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(bs, -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0).mean(dim=0)  # average over heads and layers
    return out.cpu()


def show_self_attention(attention_stores, from_where: str, layers: int):
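    """Plot the top-6 SVD components of averaged self-attention maps.

    The first figure uses the first 8 rows of the head dimension of
    `attention_stores[0]` (assumed to be the unconditional half of the
    batch); one figure per store follows, using the remaining rows.
    """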
    top_components = []
    # SVD of the first self-attention map, averaged over the first 8 rows of
    # the head dimension
    first_attention_map = attention_stores[0].get_average_attention()[from_where][layers][:8].mean(dim=0)
    # torch.linalg.svd replaces the deprecated torch.svd and returns Vh
    U, S, Vh = torch.linalg.svd(first_attention_map.to(torch.float32))
    top_components.append(U[:, :6])
    for attention_store in attention_stores:
        attention_map = attention_store.get_average_attention()[from_where][layers][8:].mean(dim=0).to(torch.float32)
        U, S, Vh = torch.linalg.svd(attention_map)
        top_components.append(U[:, :6])

    for batch_idx, components in enumerate(top_components):
        side = math.isqrt(components.shape[0])  # spatial side length of the map
        plt.figure(figsize=(24, 4))
        for comp_idx in range(6):
            plt.subplot(1, 6, comp_idx + 1)
            component = components[:, comp_idx].reshape(side, side).to('cpu')
            plt.imshow(component, cmap='viridis')
            # plt.colorbar()
            plt.axis('off')
            plt.title(f'prompt {batch_idx + 1} Top {comp_idx + 1}')
        plt.tight_layout()
        plt.show()