manu02 committed
Commit b492e55 · 1 Parent(s): 689ded9

Update app.py

Files changed (1)
  1. app.py +631 -631
app.py CHANGED
@@ -1,631 +1,631 @@
 # app.py
 """
 🖼️→📝 Image-to-Text Attention Visualizer (Custom Model)
 - Loads your custom model via create_complete_model()
 - Accepts an image, applies your transform, then calls:
   model.generate(pixel_values=..., max_new_tokens=..., output_attentions=True)
 - Selector lists ONLY generated words (no prompt tokens).
 - Viewer (single row) shows:
   (1) original image,
   (2) original + colored attention heatmap overlay,
   (3) heatmap alone (colored).
 - Heatmap is built from the first 1024 image tokens (32×32), then upscaled to the image size.
 - Text block below shows word-level attention over generated tokens (no return_offsets_mapping used).
 - Fixes deprecations: Matplotlib colormap API & Pillow mode inference.
 """

 import os
 import re
 import random
 from typing import List, Tuple, Optional

 import gradio as gr
 import torch
 import numpy as np
 from PIL import Image
 from safetensors.torch import load_model

 # Optional: nicer colormap (Matplotlib >=3.7 API; no deprecation warnings)
 try:
     import matplotlib as mpl
     _HAS_MPL = True
     _COLORMAP = mpl.colormaps.get_cmap("magma")
 except Exception:
     _HAS_MPL = False
     _COLORMAP = None

 # ========= Your utilities & model =========
 from utils.processing import image_transform, pil_from_path
 from utils.complete_model import create_complete_model

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model = create_complete_model(device=DEVICE, attention_implementation="eager")
 SAFETENSOR_PATH = "complete_model.safetensor"
 try:
     load_model(model, SAFETENSOR_PATH)
 except Exception as e:
     print(f"Error loading model: {e}, continuing with uninitialized weights.")
 model.eval()
 device = DEVICE

 # --- Grab tokenizer from your model ---
 tokenizer = getattr(model, "tokenizer", None)
 if tokenizer is None:
     raise ValueError("Expected `model.tokenizer` to exist and be a HF-like tokenizer.")

 # --- Fix PAD/EOS ambiguity (and resize embeddings if applicable) ---
 needs_resize = False
 pad_id = getattr(tokenizer, "pad_token_id", None)
 eos_id = getattr(tokenizer, "eos_token_id", None)
 if pad_id is None or (eos_id is not None and pad_id == eos_id):
     tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
     needs_resize = True

 # Try common resize hooks safely (only if your decoder actually uses tokenizer vocab)
 if needs_resize:
     resize_fns = [
         getattr(getattr(model, "decoder", None), "resize_token_embeddings", None),
         getattr(model, "resize_token_embeddings", None),
     ]
     for fn in resize_fns:
         if callable(fn):
             try:
                 fn(len(tokenizer))
                 break
             except Exception:
                 # If your model doesn't need resizing (separate vocab), it's fine.
                 pass

 # ========= Regex for words (words + punctuation) =========
 WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")

 # ========= Model metadata (for slider ranges) =========
 def model_heads_layers():
     def _get(obj, *names, default=None):
         for n in names:
             if obj is None:
                 return default
             if hasattr(obj, n):
                 try:
                     return int(getattr(obj, n))
                 except Exception:
                     return default
         return default

     cfg_candidates = [
         getattr(model, "config", None),
         getattr(getattr(model, "decoder", None), "config", None),
         getattr(getattr(model, "lm_head", None), "config", None),
     ]
     L = H = None
     for cfg in cfg_candidates:
         if L is None:
             L = _get(cfg, "num_hidden_layers", "n_layer", default=None)
         if H is None:
             H = _get(cfg, "num_attention_heads", "n_head", default=None)
     if L is None: L = 12
     if H is None: H = 12
     return max(1, L), max(1, H)

 # ========= Attention utils =========
 def get_attention_for_token_layer(
     attentions,
     token_index,
     layer_index,
     batch_index=0,
     head_index=0,
     mean_across_layers=True,
     mean_across_heads=True,
 ):
     """
     `attentions`:
       tuple length = #generated tokens
       attentions[t] -> tuple over layers; each layer tensor is (batch, heads, q, k)
     """
     token_attention = attentions[token_index]

     if mean_across_layers:
         layer_attention = torch.stack(token_attention).mean(dim=0) # (batch, heads, q, k)
     else:
         layer_attention = token_attention[int(layer_index)] # (batch, heads, q, k)

     batch_attention = layer_attention[int(batch_index)] # (heads, q, k)

     if mean_across_heads:
         head_attention = batch_attention.mean(dim=0) # (q, k)
     else:
         head_attention = batch_attention[int(head_index)] # (q, k)

     return head_attention.squeeze(0) # q==1 -> (k,)

 # ========= Tokens → words mapping (no offset_mapping needed) =========
 def _words_and_map_from_tokens_simple(token_ids: List[int]) -> Tuple[List[str], List[int]]:
     """
     Works with slow/fast tokenizers. No return_offsets_mapping.
     Steps:
       1) detok token_ids
       2) regex-split words and get their char-end positions
       3) for each word-end (we), encode detok[:we] w/ add_special_tokens=False
          last token index = len(prefix_ids) - 1
     """
     if not token_ids:
         return [], []

     toks = tokenizer.convert_ids_to_tokens(token_ids)
     detok = tokenizer.convert_tokens_to_string(toks)

     matches = list(re.finditer(WORD_RE, detok))
     words = [m.group(0) for m in matches]
     ends = [m.span()[1] for m in matches] # char end (exclusive)

     word2tok: List[int] = []
     for we in ends:
         prefix_ids = tokenizer.encode(detok[:we], add_special_tokens=False)
         if not prefix_ids:
             word2tok.append(0)
             continue
         last_idx = len(prefix_ids) - 1
         last_idx = max(0, min(last_idx, len(token_ids) - 1))
         word2tok.append(last_idx)

     return words, word2tok

 def _strip_trailing_special(ids: List[int]) -> List[int]:
     specials = set(getattr(tokenizer, "all_special_ids", []) or [])
     j = len(ids)
     while j > 0 and ids[j - 1] in specials:
         j -= 1
     return ids[:j]

 # ========= Visualization (word-level for generated text) =========
 def generate_word_visualization_gen_only(
     words_gen: List[str],
     word_ends_rel: List[int],
     gen_attn_values: np.ndarray,
     selected_token_rel_idx: int,
 ) -> str:
     """
     words_gen: generated words only
     word_ends_rel: last-token indices of each generated word (relative to generation)
     gen_attn_values: length == len(gen_token_ids), attention over generated tokens only
                      (zeros for future tokens padded at the end)
     """
     if not words_gen or gen_attn_values is None or len(gen_attn_values) == 0:
         return (
             "<div style='width:100%;'>"
             " <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
             " <div style='color:#ddd;'>No text attention values.</div>"
             " </div>"
             "</div>"
         )

     # compute word starts from ends (inclusive indexing)
     starts = []
     for i, end in enumerate(word_ends_rel):
         if i == 0:
             starts.append(0)
         else:
             starts.append(min(word_ends_rel[i - 1] + 1, end))

     # sum attention per word
     word_scores = []
     T = len(gen_attn_values)
     for i, end in enumerate(word_ends_rel):
         start = starts[i]
         if start > end:
             start = end
         s = max(0, min(start, T - 1))
         e = max(0, min(end, T - 1))
         if e < s:
             s, e = e, s
         word_scores.append(float(gen_attn_values[s:e + 1].sum()))

     max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)

     # find selected word (contains selected token idx)
     selected_word_idx = None
     for i, end in enumerate(word_ends_rel):
         if selected_token_rel_idx <= end:
             selected_word_idx = i
             break
     if selected_word_idx is None and word_ends_rel:
         selected_word_idx = len(word_ends_rel) - 1

     spans = []
     for i, w in enumerate(words_gen):
         alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
         bg = f"rgba(66,133,244,{alpha:.3f})"
         border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
         spans.append(
             f"<span style='display:inline-block;background:{bg};border:{border};"
             f"border-radius:6px;padding:2px 6px;margin:2px 4px 4px 0;color:#fff;'>"
             f"{w}</span>"
         )

     return (
         "<div style='width:100%;'>"
         " <div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'>"
         " <div style='white-space:normal;line-height:1.8;'>"
         f" {''.join(spans)}"
         " </div>"
         " </div>"
         "</div>"
     )

 # ========= Heatmap helpers for 1024 image tokens =========
 def _attention_to_heatmap_uint8(attn_1d: np.ndarray, img_token_len: int = 1024, side: int = 32) -> np.ndarray:
     """
     attn_1d: (k,) attention over keys for a given generation step; first 1024 are image tokens.
     Returns a (32, 32) uint8 grayscale array.
     """
     # take first 1024 (image tokens); pad/truncate as needed
     if attn_1d.shape[0] < img_token_len:
         img_part = np.zeros(img_token_len, dtype=float)
         img_part[: attn_1d.shape[0]] = attn_1d
     else:
         img_part = attn_1d[:img_token_len]

     # normalize to [0,1]
     mn, mx = float(img_part.min()), float(img_part.max())
     denom = (mx - mn) if (mx - mn) > 1e-12 else 1.0
     norm = (img_part - mn) / denom

     # return uint8 (0–255)
     return (norm.reshape(side, side) * 255.0).astype(np.uint8)

 def _colorize_heatmap(heatmap_u8: np.ndarray) -> Image.Image:
     """
     Convert (H,W) uint8 grayscale to RGB heatmap using matplotlib (if available) or a simple fallback.
     """
     if _HAS_MPL and _COLORMAP is not None:
         colored = (_COLORMAP(heatmap_u8.astype(np.float32) / 255.0)[:, :, :3] * 255.0).astype(np.uint8)
         return Image.fromarray(colored) # Pillow infers RGB
     else:
         # Fallback: map grayscale to red-yellow (simple linear)
         g = heatmap_u8.astype(np.float32) / 255.0
         r = (g * 255.0).clip(0, 255).astype(np.uint8)
         g2 = (np.sqrt(g) * 255.0).clip(0, 255).astype(np.uint8)
         b = np.zeros_like(r, dtype=np.uint8)
         rgb = np.stack([r, g2, b], axis=-1)
         return Image.fromarray(rgb) # Pillow infers RGB

 def _resize_like(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
     return img.resize(target_size, resample=Image.BILINEAR)

 def _make_overlay(orig: Image.Image, heatmap_rgb: Image.Image, alpha: float = 0.35) -> Image.Image:
     """
     Blend heatmap over original. alpha in [0,1].
     """
     if heatmap_rgb.size != orig.size:
         heatmap_rgb = _resize_like(heatmap_rgb, orig.size)
     base = orig.convert("RGBA")
     overlay = heatmap_rgb.convert("RGBA")
     # set global alpha
     r, g, b = overlay.split()[:3]
     a = Image.new("L", overlay.size, int(alpha * 255))
     overlay = Image.merge("RGBA", (r, g, b, a))
     return Image.alpha_composite(base, overlay).convert("RGB")

 # ========= Core (image → generate) =========
 def _prepare_image_tensor(pil_img, img_size=512):
     tfm = image_transform(img_size=img_size)
     tens = tfm(pil_img).unsqueeze(0).to(device, non_blocking=True) # [1,3,H,W]
     return tens

 def run_generation(pil_image, max_new_tokens, layer, head, mean_layers, mean_heads):
     """
     1) Transform image
     2) model.generate(pixel_values=..., max_new_tokens=..., output_attentions=True)
        expected to return (gen_ids, gen_text, attentions)
     3) Build selector over generated words only
     4) Initial visualization -> (orig, overlay, heatmap, word HTML)
     """
     if pil_image is None:
         # Return placeholders
         blank = Image.new("RGB", (256, 256), "black")
         return (
             None, None, 1024, None, None,
             gr.update(choices=[], value=None),
             blank, # original
             blank, # overlay
             np.zeros((256, 256, 3), dtype=np.uint8), # heatmap RGB upscaled (placeholder)
             "<div style='text-align:center;padding:20px;'>Upload or load an image first.</div>",
         )

     pixel_values = _prepare_image_tensor(pil_image, img_size=512)

     with torch.no_grad():
         gen_ids, gen_text, attentions = model.generate(
             pixel_values=pixel_values,
             max_new_tokens=int(max_new_tokens),
             output_attentions=True
         )

     # Expect batch size 1
     if isinstance(gen_ids, torch.Tensor):
         gen_ids = gen_ids[0].tolist()
     gen_ids = _strip_trailing_special(gen_ids)

     words_gen, gen_word2tok_rel = _words_and_map_from_tokens_simple(gen_ids)

     display_choices = [(w, i) for i, w in enumerate(words_gen)]
     if not display_choices:
         # No generated tokens; still show original and blank heatmap/overlay
         blank_hm = np.zeros((32, 32), dtype=np.uint8)
         hm_rgb = _colorize_heatmap(blank_hm).resize(pil_image.size, resample=Image.NEAREST)
         overlay = _make_overlay(pil_image, hm_rgb, alpha=0.35)
         return (
             attentions, gen_ids, 1024, words_gen, gen_word2tok_rel,
             gr.update(choices=[], value=None),
             pil_image, # original
             overlay, # overlay
             np.array(hm_rgb), # heatmap RGB
             "<div style='text-align:center;padding:20px;'>No generated tokens to visualize.</div>",
         )

     first_idx = 0
     hm_rgb_init, overlay_init, html_init = update_visualization(
         selected_gen_index=first_idx,
         attentions=attentions,
         gen_token_ids=gen_ids,
         layer=layer,
         head=head,
         mean_layers=mean_layers,
         mean_heads=mean_heads,
         words_gen=words_gen,
         gen_word2tok_rel=gen_word2tok_rel,
         pil_image=pil_image,
     )

     return (
         attentions, # state_attentions
         gen_ids, # state_gen_token_ids
         1024, # state_img_token_len (fixed)
         words_gen, # state_words_gen
         gen_word2tok_rel, # state_gen_word2tok_rel
         gr.update(choices=display_choices, value=first_idx),
         pil_image, # original image view
         overlay_init, # overlay (PIL)
         hm_rgb_init, # heatmap RGB (np array or PIL)
         html_init, # HTML words viz
     )

 def update_visualization(
     selected_gen_index,
     attentions,
     gen_token_ids,
     layer,
     head,
     mean_layers,
     mean_heads,
     words_gen,
     gen_word2tok_rel,
     pil_image: Optional[Image.Image] = None,
 ):
     """
     Recompute visualization for the chosen GENERATED word:
       - Extract attention vector for that generation step.
       - Build 32×32 heatmap from first 1024 values (image tokens), colorize and upscale to original image size.
       - Create overlay (original + heatmap with alpha).
       - Build word HTML from the portion corresponding to generated tokens.
         For step t, keys cover: 1024 image tokens + (t+1) generated tokens so far.
     """
     if selected_gen_index is None or attentions is None or gen_word2tok_rel is None:
         blank = np.zeros((256, 256, 3), dtype=np.uint8)
         return Image.fromarray(blank), Image.fromarray(blank), "<div style='text-align:center;padding:20px;'>Generate first.</div>"

     gidx = int(selected_gen_index)
     if not (0 <= gidx < len(gen_word2tok_rel)):
         blank = np.zeros((256, 256, 3), dtype=np.uint8)
         return Image.fromarray(blank), Image.fromarray(blank), "<div style='text-align:center;padding:20px;'>Invalid selection.</div>"

     step_index = int(gen_word2tok_rel[gidx]) # last token of that word (relative to generation)
     if not attentions or step_index >= len(attentions):
         blank = np.zeros((256, 256, 3), dtype=np.uint8)
         return Image.fromarray(blank), Image.fromarray(blank), "<div style='text-align:center;padding:20px;'>No attention for this step.</div>"

     token_attn = get_attention_for_token_layer(
         attentions,
         token_index=step_index,
         layer_index=int(layer),
         head_index=int(head),
         mean_across_layers=bool(mean_layers),
         mean_across_heads=bool(mean_heads),
     )

     attn_vals = token_attn.detach().cpu().numpy()
     if attn_vals.ndim == 2:
         attn_vals = attn_vals[-1] # (k,) from (q,k)

     # ---- Heatmap over 1024 image tokens (colorized and upscaled to original size) ----
     heatmap_u8 = _attention_to_heatmap_uint8(attn_1d=attn_vals, img_token_len=1024, side=32)
     hm_rgb_pil = _colorize_heatmap(heatmap_u8)

     # If original image not provided (should be), create a placeholder size
     if pil_image is None:
         pil_image = Image.new("RGB", (256, 256), "black")

     hm_rgb_pil_up = hm_rgb_pil.resize(pil_image.size, resample=Image.NEAREST)
     overlay_pil = _make_overlay(pil_image, hm_rgb_pil_up, alpha=0.35)

     # ---- Word-level viz over generated tokens only ----
     k_len = int(attn_vals.shape[0])
     observed_gen = max(0, min(step_index + 1, max(0, k_len - 1024)))
     total_gen = len(gen_token_ids)

     gen_vec = np.zeros(total_gen, dtype=float)
     if observed_gen > 0:
         # slice generated part of attention vector
         start = 1024
         end = min(1024 + observed_gen, k_len)
         gen_slice = attn_vals[start:end]
         gen_vec[: len(gen_slice)] = gen_slice

     selected_token_rel_idx = step_index

     html_words = generate_word_visualization_gen_only(
         words_gen=words_gen,
         word_ends_rel=gen_word2tok_rel,
         gen_attn_values=gen_vec,
         selected_token_rel_idx=selected_token_rel_idx,
     )

     # Return (heatmap RGB, overlay, html)
     return np.array(hm_rgb_pil_up), overlay_pil, html_words

 def toggle_slider(is_mean):
     return gr.update(interactive=not bool(is_mean))

 # ========= Gradio UI =========
 EXAMPLES_DIR = "examples"

 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🖼️→📝 Image-to-Text Attention Visualizer (three views + text)")
     gr.Markdown(
         "Upload an image or click **Load random sample**, generate text, then select a **generated word**. "
         "Above: original image, overlay (original + attention), and heatmap (colored). "
         "Below: word-level attention over generated text."
     )

     # States
     state_attentions = gr.State(None) # tuple over generation steps
     state_gen_token_ids = gr.State(None) # list[int]
     state_img_token_len = gr.State(1024) # fixed
     state_words_gen = gr.State(None) # list[str]
     state_gen_word2tok_rel = gr.State(None) # list[int]
     state_last_image = gr.State(None) # PIL image of last input

     L, H = model_heads_layers()

     with gr.Row():
         with gr.Column(scale=1):
             gr.Markdown("### 1) Image")
             img_input = gr.Image(type="pil", label="Upload image", height=280)
             btn_load_sample = gr.Button("Load random sample from /examples", variant="secondary")
             sample_status = gr.Markdown("")

             gr.Markdown("### 2) Generation")
-            slider_max_tokens = gr.Slider(5, 200, value=40, step=5, label="Max New Tokens")
+            slider_max_tokens = gr.Slider(5, 200, value=100, step=5, label="Max New Tokens")
             btn_generate = gr.Button("Generate", variant="primary")

             gr.Markdown("### 3) Attention")
-            check_mean_layers = gr.Checkbox(True, label="Mean Across Layers")
-            check_mean_heads = gr.Checkbox(True, label="Mean Across Heads")
-            slider_layer = gr.Slider(0, max(0, L - 1), value=0, step=1, label="Layer", interactive=False)
-            slider_head = gr.Slider(0, max(0, H - 1), value=0, step=1, label="Head", interactive=False)
+            check_mean_layers = gr.Checkbox(False, label="Mean Across Layers")
+            check_mean_heads = gr.Checkbox(False, label="Mean Across Heads")
+            slider_layer = gr.Slider(0, max(0, L - 1), value=0, step=1, label="Layer", interactive=True)
+            slider_head = gr.Slider(0, max(0, H - 1), value=0, step=1, label="Head", interactive=True)

         with gr.Column(scale=3):
             # Three views row
             with gr.Row():
                 img_original_view = gr.Image(
                     value=None,
                     label="Original image",
                     image_mode="RGB",
                     height=256
                 )
                 img_overlay_view = gr.Image(
                     value=None,
                     label="Overlay (image + attention)",
                     image_mode="RGB",
                     height=256
                 )
                 heatmap_view = gr.Image(
                     value=None,
                     label="Heatmap (colored)",
                     image_mode="RGB",
                     height=256
                 )

             # Word selector & HTML viz below
             radio_word_selector = gr.Radio(
                 [], label="Select Generated Word",
                 info="Selector lists only generated words"
             )
             html_visualization = gr.HTML(
                 "<div style='text-align:center;padding:20px;color:#888;border:1px dashed #888;border-radius:8px;'>"
                 "Text attention visualization will appear here.</div>"
             )

     # Sample loader: always use `examples/`
     def _load_sample_from_examples():
         try:
             files = [f for f in os.listdir(EXAMPLES_DIR) if not f.startswith(".")]
             if not files:
                 return gr.update(), "No files in /examples."
             fp = os.path.join(EXAMPLES_DIR, random.choice(files))
             pil_img = pil_from_path(fp)
             return gr.update(value=pil_img), f"Loaded sample: {os.path.basename(fp)}"
         except Exception as e:
             return gr.update(), f"Error loading sample: {e}"

     btn_load_sample.click(
         fn=_load_sample_from_examples,
         inputs=[],
         outputs=[img_input, sample_status]
     )

     # Generate
     def _run_and_store(pil_image, *args):
         out = run_generation(pil_image, *args)
         # store the original image for later updates
         return (*out, pil_image)

     btn_generate.click(
         fn=_run_and_store,
         inputs=[img_input, slider_max_tokens, slider_layer, slider_head, check_mean_layers, check_mean_heads],
         outputs=[
             state_attentions,
             state_gen_token_ids,
             state_img_token_len,
             state_words_gen,
             state_gen_word2tok_rel,
             radio_word_selector,
             img_original_view, # original
             img_overlay_view, # overlay
             heatmap_view, # heatmap
             html_visualization, # words HTML
             state_last_image, # store original PIL
         ],
     )

     # Update viz on any control change
     def _update_wrapper(selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH, words, word2tok, last_img):
         hm_rgb, overlay, html = update_visualization(
             selected_gen_index,
             attn,
             gen_ids,
             lyr,
             hed,
             meanL,
             meanH,
             words,
             word2tok,
             pil_image=last_img
         )
         return overlay, hm_rgb, html

     for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
         control.change(
             fn=_update_wrapper,
             inputs=[
                 radio_word_selector,
                 state_attentions,
                 state_gen_token_ids,
                 slider_layer,
                 slider_head,
                 check_mean_layers,
                 check_mean_heads,
                 state_words_gen,
                 state_gen_word2tok_rel,
                 state_last_image,
             ],
             outputs=[img_overlay_view, heatmap_view, html_visualization],
         )

     # Toggle slider interactivity
     check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
     check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)

 if __name__ == "__main__":
     print(f"Device: {device}")
     demo.launch(debug=True)
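
For readers following the attention plumbing above: the docstring of get_attention_for_token_layer assumes a per-step, per-layer nesting of attention tensors from generate(..., output_attentions=True). The standalone sketch below (not part of app.py; dummy tensors stand in for real model output) shows how that nesting is indexed and reduced to a single (k,)-length vector for one generation step.

import torch

num_steps, num_layers, num_heads, num_image_tokens = 3, 2, 4, 1024
# attentions[t] is a tuple over layers; each tensor is (batch, heads, q, k),
# with q == 1 at decode time and k growing by one key per generated token.
attentions = tuple(
    tuple(
        torch.rand(1, num_heads, 1, num_image_tokens + t + 1)
        for _ in range(num_layers)
    )
    for t in range(num_steps)
)

t = 2                                      # pick a generation step
step_layers = torch.stack(attentions[t])   # (layers, batch, heads, 1, k)
# mean over layers, take batch 0, mean over heads, drop q -> (k,)
vec = step_layers.mean(dim=0)[0].mean(dim=0).squeeze(0)
print(vec.shape)                           # torch.Size([1027]) = 1024 image keys + 3 text keys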
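Likewise, a minimal standalone sketch of the heatmap path described in the module docstring: take the first 1024 attention values as a 32×32 grid, normalize, upscale with nearest-neighbour resampling, and blend over the input image. Image.blend is used here as a simpler stand-in for the app's per-channel RGBA compositing, and the 512×512 gray image is only a placeholder for a real upload.

import numpy as np
from PIL import Image

attn = np.random.rand(1024 + 7)              # keys: 1024 image tokens + a few generated tokens
img_part = attn[:1024]
norm = (img_part - img_part.min()) / max(float(img_part.max() - img_part.min()), 1e-12)
heat_u8 = (norm.reshape(32, 32) * 255.0).astype(np.uint8)

orig = Image.new("RGB", (512, 512), "gray")  # placeholder for the uploaded image
heat_gray = Image.fromarray(heat_u8)         # mode "L", 32x32
heat_rgb = Image.merge("RGB", (heat_gray, heat_gray, heat_gray)).resize(orig.size, Image.NEAREST)
overlay = Image.blend(orig, heat_rgb, alpha=0.35)
overlay.save("overlay_preview.png")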
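Finally, a sketch of the prefix re-encoding trick that _words_and_map_from_tokens_simple uses to map each regex word to the index of its last token without return_offsets_mapping. It assumes the Hugging Face GPT-2 tokenizer as a stand-in for model.tokenizer (downloaded on first use); the exact indices printed depend on the tokenizer in play.

import re
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")

ids = tok.encode("A red bus parked near the station.", add_special_tokens=False)
text = tok.convert_tokens_to_string(tok.convert_ids_to_tokens(ids))

word2tok = []
for m in WORD_RE.finditer(text):
    # re-encode the text up to the word's end; the last id's position is the word's last token
    prefix_ids = tok.encode(text[: m.end()], add_special_tokens=False)
    last_idx = max(0, min(len(prefix_ids) - 1, len(ids) - 1)) if prefix_ids else 0
    word2tok.append(last_idx)

print(list(zip(WORD_RE.findall(text), word2tok)))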
 