| """ | |
| 🖼️→📝 Chest X-ray Report Generation + Attention Visualizer + Classification | |
| - Loads generation model (complete_model.safetensor) | |
| - Loads classification model (classification.pth) | |
| - Generates report and visualizes attention. | |
| - Lists disease probabilities. | |
| """ | |
import os
import re
import random
import logging
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_model
from transformers import AutoImageProcessor, AutoModel

# Optional: nicer colormap
try:
    import matplotlib as mpl
    _HAS_MPL = True
    _COLORMAP = mpl.colormaps.get_cmap("magma")
except Exception:
    _HAS_MPL = False
    _COLORMAP = None

# ========= Your utilities & model =========
from utils.processing import image_transform, pil_from_path
from utils.complete_model import create_complete_model

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ==============================================================================
# 1. CLASSIFIER LOGIC (Added)
# ==============================================================================
class EmbeddingClassifier(nn.Module):
    def __init__(self, embedding_dim, num_classes, custom_dims=(512, 256, 256),
                 activation="gelu", dropout=0.05, bn=False, use_layernorm=True):
        super().__init__()
        layers = [nn.Linear(embedding_dim, custom_dims[0])]
        if use_layernorm:
            layers.append(nn.LayerNorm(custom_dims[0]))
        elif bn:
            layers.append(nn.BatchNorm1d(custom_dims[0]))
        layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU())
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        for i in range(len(custom_dims) - 1):
            layers.append(nn.Linear(custom_dims[i], custom_dims[i + 1]))
            if use_layernorm:
                layers.append(nn.LayerNorm(custom_dims[i + 1]))
            elif bn:
                layers.append(nn.BatchNorm1d(custom_dims[i + 1]))
            layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(custom_dims[-1], num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, embeddings):
        return self.classifier(embeddings)
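# Shape sketch (illustrative numbers, not read from the checkpoint): a
# DINOv3-S/16 backbone emits 384-dim embeddings, so with the default dims:
#   clf = EmbeddingClassifier(embedding_dim=384, num_classes=6)
#   logits = clf(torch.randn(8, 384))  # (8, 384) -> (8, 6), one logit per label
# Multi-label probabilities then come from sigmoid(logits), not softmax.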
class ChestXrayPredictor:
    def __init__(self, base_model, classifier, processor, label_cols, device):
        self.base_model = base_model
        self.classifier = classifier
        self.processor = processor
        self.label_cols = label_cols
        self.device = device
        self.base_model.eval()
        self.classifier.eval()

    def predict(self, image_source):
        try:
            if isinstance(image_source, str):
                image = Image.open(image_source).convert("RGB")
            else:
                image = image_source.convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt")
            pixel_values = inputs["pixel_values"].to(self.device)
            with torch.no_grad():
                outputs = self.base_model(pixel_values=pixel_values)
                # Mean-pool token embeddings into one image-level embedding.
                if hasattr(outputs, "last_hidden_state"):
                    embeddings = outputs.last_hidden_state.mean(dim=1)
                else:
                    embeddings = outputs[0].mean(dim=1)
                logits = self.classifier(embeddings)
                probs = torch.sigmoid(logits).cpu().numpy()[0].tolist()
            return {label: prob for label, prob in zip(self.label_cols, probs)}
        except Exception as e:
            print(f"Prediction Error: {e}")
            return {}
def create_classifier(checkpoint_path, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", device=None):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading Classifier from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    label_cols = checkpoint.get("label_cols", [
        "Cardiomegaly", "Consolidation", "Edema",
        "Atelectasis", "Pleural Effusion", "No Findings",
    ])
    base_model = AutoModel.from_pretrained(model_id).to(device)
    if "base_model_state_dict" in checkpoint:
        base_model.load_state_dict(checkpoint["base_model_state_dict"])
    processor = AutoImageProcessor.from_pretrained(model_id)

    # Detect the embedding dim with a dummy forward pass.
    with torch.no_grad():
        dummy = torch.randn(1, 3, 224, 224).to(device)
        out = base_model(pixel_values=dummy)
        embedding_dim = out.last_hidden_state.shape[-1]

    # Rebuild the MLP head from the checkpoint's linear-layer shapes.
    model_state = checkpoint["model_state_dict"]
    linear_layers = []
    for key, val in model_state.items():
        if "classifier" in key and key.endswith(".weight") and len(val.shape) == 2:
            match = re.search(r"classifier\.(\d+)\.weight", key)
            if match:
                linear_layers.append((int(match.group(1)), val.shape[1], val.shape[0]))
    linear_layers.sort(key=lambda x: x[0])
    num_classes = linear_layers[-1][2]
    hidden_dims = tuple(x[2] for x in linear_layers[:-1])
    uses_bn = any("running_mean" in k for k in model_state)
    has_norm = any(
        k.endswith(".weight") and len(model_state[k].shape) == 1
        for k in model_state if "classifier" in k
    )
    classifier = EmbeddingClassifier(
        embedding_dim, num_classes,
        custom_dims=hidden_dims,
        bn=uses_bn,
        use_layernorm=(has_norm and not uses_bn),
    )
    classifier.load_state_dict(model_state)
    classifier.to(device)
    return ChestXrayPredictor(base_model, classifier, processor, label_cols, device)
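# Introspection sketch: with the default head, state-dict keys look roughly
# like "classifier.0.weight" (shape [512, emb]), then deeper linears, and the
# highest-indexed linear's out-features give num_classes. The exact Sequential
# indices depend on the norm/dropout flags, which is why the code sorts by the
# parsed index instead of hard-coding positions. (Indices here are
# illustrative, not checked against the actual checkpoint.)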
# ==============================================================================
# 2. LOAD MODELS
# ==============================================================================
# A. Load Generator
print("Loading Generation Model...")
model = create_complete_model(device=DEVICE, attention_implementation="eager")
SAFETENSOR_PATH = "complete_model.safetensor"
try:
    load_model(model, SAFETENSOR_PATH)
except Exception as e:
    print(f"Error loading generation model: {e}")
model.eval()
# B. Load Classifier
print("Loading Classification Model...")
CLASSIFIER_PATH = "classification.pth"
classifier_model = None
try:
    if os.path.exists(CLASSIFIER_PATH):
        classifier_model = create_classifier(CLASSIFIER_PATH, device=DEVICE)
        print("✅ Classifier loaded.")
    else:
        print(f"⚠️ Classifier not found at {CLASSIFIER_PATH}")
except Exception as e:
    print(f"⚠️ Error loading classifier: {e}")
# --- Tokenizer setup ---
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is None:
    raise ValueError("Expected `model.tokenizer` to exist.")
pad_id = getattr(tokenizer, "pad_token_id", None)
eos_id = getattr(tokenizer, "eos_token_id", None)
needs_resize = False
if pad_id is None or (eos_id is not None and pad_id == eos_id):
    tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    needs_resize = True
if needs_resize:
    resize_fns = [
        getattr(getattr(model, "decoder", None), "resize_token_embeddings", None),
        getattr(model, "resize_token_embeddings", None),
    ]
    for fn in resize_fns:
        if callable(fn):
            try:
                fn(len(tokenizer))
                break
            except Exception:
                pass
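# Why the pad/eos split matters: if pad_token_id aliases eos_token_id, padded
# positions are indistinguishable from genuine end-of-sequence tokens, which
# can truncate generation or corrupt attention masks. Adding a dedicated
# <|pad|> token grows the vocabulary, so whichever module owns the embedding
# table (the decoder or the wrapper itself, probed above) must be resized.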
WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")

# ========= Logic =========
def model_heads_layers():
    def _get(obj, *names, default=None):
        for n in names:
            if obj is None:
                return default
            if hasattr(obj, n):
                return int(getattr(obj, n))
        return default

    cfg_candidates = [
        getattr(model, "config", None),
        getattr(getattr(model, "decoder", None), "config", None),
        getattr(getattr(model, "lm_head", None), "config", None),
    ]
    L = H = None
    for cfg in cfg_candidates:
        if L is None:
            L = _get(cfg, "num_hidden_layers", "n_layer")
        if H is None:
            H = _get(cfg, "num_attention_heads", "n_head")
    return max(1, L or 12), max(1, H or 12)
def get_attention_for_token_layer(attentions, token_index, layer_index, batch_index=0,
                                  head_index=0, mean_across_layers=True, mean_across_heads=True):
    token_attention = attentions[token_index]
    if mean_across_layers:
        layer_attention = torch.stack(token_attention).mean(dim=0)
    else:
        layer_attention = token_attention[int(layer_index)]
    batch_attention = layer_attention[int(batch_index)]
    if mean_across_heads:
        head_attention = batch_attention.mean(dim=0)
    else:
        head_attention = batch_attention[int(head_index)]
    return head_attention.squeeze(0)
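# Expected layout (assumed from the indexing above, mirroring HF-style
# generate outputs): attentions[step] is a tuple of per-layer tensors, each
# shaped (batch, heads, q_len, k_len). For later decoding steps q_len is
# typically 1, so squeeze(0) leaves a 1-D vector over all k_len keys; the
# first step may carry the full prompt, which update_visualization handles
# by taking the last query row of a 2-D result.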
def _words_and_map_from_tokens_simple(token_ids: List[int]) -> Tuple[List[str], List[int]]:
    if not token_ids:
        return [], []
    toks = tokenizer.convert_ids_to_tokens(token_ids)
    detok = tokenizer.convert_tokens_to_string(toks)
    matches = list(re.finditer(WORD_RE, detok))
    words = [m.group(0) for m in matches]
    ends = [m.span()[1] for m in matches]
    word2tok: List[int] = []
    for we in ends:
        # Re-encode the prefix up to this word's end to find the last token
        # that contributes to it.
        prefix_ids = tokenizer.encode(detok[:we], add_special_tokens=False)
        if not prefix_ids:
            word2tok.append(0)
            continue
        last_idx = len(prefix_ids) - 1
        last_idx = max(0, min(last_idx, len(token_ids) - 1))
        word2tok.append(last_idx)
    return words, word2tok
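# Example of the intent (illustrative tokens, not this tokenizer's actual
# pieces): "no effusion." might tokenize as ["no", " eff", "usion", "."];
# the words are ["no", "effusion", "."] and word2tok maps each word to the
# index of the LAST token that finishes it -> [0, 2, 3]. Re-encoding prefixes
# keeps the map robust when token pieces don't align with word boundaries.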
def _strip_trailing_special(ids: List[int]) -> List[int]:
    specials = set(getattr(tokenizer, "all_special_ids", []) or [])
    j = len(ids)
    while j > 0 and ids[j - 1] in specials:
        j -= 1
    return ids[:j]
def generate_word_visualization_gen_only(words_gen, word_ends_rel, gen_attn_values, selected_token_rel_idx):
    if not words_gen or gen_attn_values is None or len(gen_attn_values) == 0:
        return "<div style='width:100%;'>No text attention values.</div>"
    starts = []
    for i, end in enumerate(word_ends_rel):
        if i == 0:
            starts.append(0)
        else:
            starts.append(min(word_ends_rel[i - 1] + 1, end))
    word_scores = []
    T = len(gen_attn_values)
    for i, end in enumerate(word_ends_rel):
        start = starts[i]
        if start > end:
            start = end
        s = max(0, min(start, T - 1))
        e = max(0, min(end, T - 1))
        if e < s:
            s, e = e, s
        word_scores.append(float(gen_attn_values[s:e + 1].sum()))
    max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)
    selected_word_idx = None
    for i, end in enumerate(word_ends_rel):
        if selected_token_rel_idx <= end:
            selected_word_idx = i
            break
    if selected_word_idx is None and word_ends_rel:
        selected_word_idx = len(word_ends_rel) - 1
    spans = []
    for i, w in enumerate(words_gen):
        alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
        bg = f"rgba(66,133,244,{alpha:.3f})"
        border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
        spans.append(
            f"<span style='display:inline-block;background:{bg};border:{border};"
            f"border-radius:6px;padding:2px 6px;margin:2px 4px 4px 0;color:#fff;'>{w}</span>"
        )
    return (
        "<div style='width:100%;'><div style='background:#444;border:1px solid #eee;"
        "border-radius:8px;padding:10px;'><div style='white-space:normal;line-height:1.8;'>"
        + "".join(spans) + "</div></div></div>"
    )
def _attention_to_heatmap_uint8(attn_1d: np.ndarray, img_token_len: int = 1024, side: int = 32) -> np.ndarray:
    # Keep only the image-token portion of the attention vector, zero-padding
    # if the key sequence is shorter than expected.
    if attn_1d.shape[0] < img_token_len:
        img_part = np.zeros(img_token_len, dtype=float)
        img_part[: attn_1d.shape[0]] = attn_1d
    else:
        img_part = attn_1d[:img_token_len]
    mn, mx = float(img_part.min()), float(img_part.max())
    denom = (mx - mn) if (mx - mn) > 1e-12 else 1.0
    norm = (img_part - mn) / denom
    return (norm.reshape(side, side) * 255.0).astype(np.uint8)
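# The 1024/32 defaults assume the vision encoder turns a 512x512 input into a
# 32x32 grid of patches (patch size 16), so the image-token attention reshapes
# cleanly into a square heatmap. If your encoder differs, adjust img_token_len
# and side together so that side**2 == img_token_len.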
def _colorize_heatmap(heatmap_u8: np.ndarray) -> Image.Image:
    if _HAS_MPL and _COLORMAP is not None:
        colored = (_COLORMAP(heatmap_u8.astype(np.float32) / 255.0)[:, :, :3] * 255.0).astype(np.uint8)
        return Image.fromarray(colored)
    # Fallback: a cheap red/yellow ramp when matplotlib is unavailable.
    g = heatmap_u8.astype(np.float32) / 255.0
    r = (g * 255.0).clip(0, 255).astype(np.uint8)
    g2 = (np.sqrt(g) * 255.0).clip(0, 255).astype(np.uint8)
    b = np.zeros_like(r, dtype=np.uint8)
    rgb = np.stack([r, g2, b], axis=-1)
    return Image.fromarray(rgb)
def _resize_like(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
    return img.resize(target_size, resample=Image.BILINEAR)

def _make_overlay(orig: Image.Image, heatmap_rgb: Image.Image, alpha: float = 0.35) -> Image.Image:
    if heatmap_rgb.size != orig.size:
        heatmap_rgb = _resize_like(heatmap_rgb, orig.size)
    base = orig.convert("RGBA")
    overlay = heatmap_rgb.convert("RGBA")
    r, g, b = overlay.split()[:3]
    a = Image.new("L", overlay.size, int(alpha * 255))
    overlay = Image.merge("RGBA", (r, g, b, a))
    return Image.alpha_composite(base, overlay).convert("RGB")
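# Usage sketch (xray_pil is a stand-in name for any RGB PIL image): alpha
# blending keeps the radiograph legible under the heatmap.
#   hm = _colorize_heatmap(_attention_to_heatmap_uint8(attn_vec))
#   out = _make_overlay(xray_pil, hm, alpha=0.35)
# Raising alpha toward 1.0 shows more heatmap and less anatomy.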
def _prepare_image_tensor(pil_img, img_size=512):
    tfm = image_transform(img_size=img_size)
    return tfm(pil_img).unsqueeze(0).to(DEVICE, non_blocking=True)
def run_generation(pil_image, max_new_tokens, layer, head, mean_layers, mean_heads):
    if pil_image is None:
        blank = Image.new("RGB", (256, 256), "black")
        return (None, None, 1024, None, None, gr.update(choices=[], value=None),
                blank, blank, np.zeros((256, 256, 3), dtype=np.uint8),
                "<div style='text-align:center;'>Upload image first.</div>")
    pixel_values = _prepare_image_tensor(pil_image, img_size=512)
    with torch.no_grad():
        gen_ids, gen_text, attentions = model.generate(
            pixel_values=pixel_values,
            max_new_tokens=int(max_new_tokens),
            output_attentions=True,
        )
    if isinstance(gen_ids, torch.Tensor):
        gen_ids = gen_ids[0].tolist()
    gen_ids = _strip_trailing_special(gen_ids)
    words_gen, gen_word2tok_rel = _words_and_map_from_tokens_simple(gen_ids)
    display_choices = [(w, i) for i, w in enumerate(words_gen)]
    if not display_choices:
        blank_hm = np.zeros((32, 32), dtype=np.uint8)
        hm_rgb = _colorize_heatmap(blank_hm).resize(pil_image.size, resample=Image.NEAREST)
        overlay = _make_overlay(pil_image, hm_rgb, alpha=0.35)
        return (attentions, gen_ids, 1024, words_gen, gen_word2tok_rel,
                gr.update(choices=[], value=None), pil_image, overlay,
                np.array(hm_rgb), "<div>No tokens.</div>")
    first_idx = 0
    hm_rgb_init, overlay_init, html_init = update_visualization(
        first_idx, attentions, gen_ids, layer, head, mean_layers, mean_heads,
        words_gen, gen_word2tok_rel, pil_image,
    )
    return (attentions, gen_ids, 1024, words_gen, gen_word2tok_rel,
            gr.update(choices=display_choices, value=first_idx),
            pil_image, overlay_init, hm_rgb_init, html_init)
def update_visualization(selected_gen_index, attentions, gen_token_ids, layer, head,
                         mean_layers, mean_heads, words_gen, gen_word2tok_rel,
                         pil_image: Optional[Image.Image] = None):
    if selected_gen_index is None or attentions is None or gen_word2tok_rel is None:
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div>Generate first.</div>"
    gidx = int(selected_gen_index)
    if not (0 <= gidx < len(gen_word2tok_rel)):
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div>Invalid selection.</div>"
    step_index = int(gen_word2tok_rel[gidx])
    if not attentions or step_index >= len(attentions):
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div>No attention.</div>"
    token_attn = get_attention_for_token_layer(
        attentions, token_index=step_index, layer_index=int(layer),
        head_index=int(head), mean_across_layers=bool(mean_layers),
        mean_across_heads=bool(mean_heads),
    )
    attn_vals = token_attn.detach().cpu().numpy()
    if attn_vals.ndim == 2:
        attn_vals = attn_vals[-1]
    heatmap_u8 = _attention_to_heatmap_uint8(attn_1d=attn_vals, img_token_len=1024, side=32)
    hm_rgb_pil = _colorize_heatmap(heatmap_u8)
    if pil_image is None:
        pil_image = Image.new("RGB", (256, 256), "black")
    hm_rgb_pil_up = hm_rgb_pil.resize(pil_image.size, resample=Image.NEAREST)
    overlay_pil = _make_overlay(pil_image, hm_rgb_pil_up, alpha=0.35)
    # Keys are laid out as [1024 image tokens | generated tokens so far], so
    # the tail of the attention vector scores the already-generated text.
    k_len = int(attn_vals.shape[0])
    observed_gen = max(0, min(step_index + 1, max(0, k_len - 1024)))
    total_gen = len(gen_token_ids)
    gen_vec = np.zeros(total_gen, dtype=float)
    if observed_gen > 0:
        start = 1024
        end = min(1024 + observed_gen, k_len)
        gen_slice = attn_vals[start:end]
        gen_vec[: len(gen_slice)] = gen_slice
    html_words = generate_word_visualization_gen_only(words_gen, gen_word2tok_rel, gen_vec, step_index)
    return np.array(hm_rgb_pil_up), overlay_pil, html_words
def toggle_slider(is_mean):
    return gr.update(interactive=not bool(is_mean))
# ========= Gradio UI =========
EXAMPLES_DIR = "examples"

with gr.Blocks() as demo:
    gr.Markdown("# 🖼️→📝 Chest X-ray Report Generation & Classification")

    # States
    state_attentions = gr.State(None)
    state_gen_token_ids = gr.State(None)
    state_img_token_len = gr.State(1024)
    state_words_gen = gr.State(None)
    state_gen_word2tok_rel = gr.State(None)
    state_last_image = gr.State(None)

    L, H = model_heads_layers()
    with gr.Row():
        # LEFT COLUMN
        with gr.Column(scale=1):
            gr.Markdown("### 1) Input")
            img_input = gr.Image(type="pil", label="Upload image", height=280)
            btn_load_sample = gr.Button("Load random sample", variant="secondary")
            sample_status = gr.Markdown("")

            gr.Markdown("### 2) Generation Settings")
            slider_max_tokens = gr.Slider(5, 200, value=100, step=5, label="Max New Tokens")
            btn_generate = gr.Button("GENERATE REPORT & CLASSIFY", variant="primary")

            gr.Markdown("### 3) Attention Visualization")
            check_mean_layers = gr.Checkbox(False, label="Mean Across Layers")
            check_mean_heads = gr.Checkbox(False, label="Mean Across Heads")
            slider_layer = gr.Slider(0, max(0, L - 1), value=0, step=1, label="Layer", interactive=True)
            slider_head = gr.Slider(0, max(0, H - 1), value=0, step=1, label="Head", interactive=True)

            # --- NEW CLASSIFICATION SECTION ---
            gr.Markdown("### 4) Disease Probability")
            classification_output = gr.Dataframe(
                headers=["Disease", "Probability"],
                datatype=["str", "str"],
                label="Predictions",
                interactive=False,
            )
            # ----------------------------------

        # RIGHT COLUMN
        with gr.Column(scale=3):
            with gr.Row():
                img_original_view = gr.Image(label="Original", image_mode="RGB", height=256)
                img_overlay_view = gr.Image(label="Attention Overlay", image_mode="RGB", height=256)
                heatmap_view = gr.Image(label="Heatmap", image_mode="RGB", height=256)
            radio_word_selector = gr.Radio([], label="Select Generated Word", info="Shows attention for this word")
            html_visualization = gr.HTML("<div style='text-align:center;padding:20px;color:#888;'>Text attention visualization will appear here.</div>")
    # Sample loader
    def _load_sample_from_examples():
        try:
            files = [f for f in os.listdir(EXAMPLES_DIR) if not f.startswith(".")]
            if not files:
                return gr.update(), "No files."
            fp = os.path.join(EXAMPLES_DIR, random.choice(files))
            pil_img = pil_from_path(fp)
            return gr.update(value=pil_img), f"Loaded: {os.path.basename(fp)}"
        except Exception as e:
            return gr.update(), f"Error: {e}"

    btn_load_sample.click(_load_sample_from_examples, inputs=[], outputs=[img_input, sample_status])
    # MAIN RUN FUNCTION (GENERATION + CLASSIFICATION)
    def _run_all_logic(pil_image, *args):
        # 1. Run generation (returns 10 items).
        gen_results = run_generation(pil_image, *args)
        # 2. Run classification.
        classification_data = []
        if pil_image is not None and classifier_model is not None:
            try:
                preds = classifier_model.predict(pil_image)
                # Sort by probability, descending.
                sorted_preds = sorted(preds.items(), key=lambda x: x[1], reverse=True)
                # Format as rows for the Gradio Dataframe, e.g. ["Edema", "95.5%"].
                # Probabilities arrive in [0, 1], so scale to percent here.
                classification_data = [[k, f"{v * 100:.1f}%"] for k, v in sorted_preds]
            except Exception as e:
                print(f"Classification runtime error: {e}")
                classification_data = [["Error", str(e)]]
        # Combine: gen_results + original image (for state) + classification rows.
        return (*gen_results, pil_image, classification_data)
    btn_generate.click(
        fn=_run_all_logic,
        inputs=[img_input, slider_max_tokens, slider_layer, slider_head,
                check_mean_layers, check_mean_heads],
        outputs=[
            state_attentions,
            state_gen_token_ids,
            state_img_token_len,
            state_words_gen,
            state_gen_word2tok_rel,
            radio_word_selector,
            img_original_view,
            img_overlay_view,
            heatmap_view,
            html_visualization,
            state_last_image,        # Added to outputs
            classification_output,   # Added to outputs
        ],
    )
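    # Gradio maps the returned tuple positionally onto `outputs`, so the ten
    # values from run_generation, then pil_image, then classification_data
    # must stay in exactly this order.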
    # UI updates for visualizer controls
    def _update_wrapper(selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH, words, word2tok, last_img):
        hm_rgb, overlay, html = update_visualization(
            selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH,
            words, word2tok, pil_image=last_img,
        )
        return overlay, hm_rgb, html

    for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
        control.change(
            fn=_update_wrapper,
            inputs=[radio_word_selector, state_attentions, state_gen_token_ids,
                    slider_layer, slider_head, check_mean_layers, check_mean_heads,
                    state_words_gen, state_gen_word2tok_rel, state_last_image],
            outputs=[img_overlay_view, heatmap_view, html_visualization],
        )

    check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
    check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)
if __name__ == "__main__":
    print(f"Device: {DEVICE}")
    demo.launch(debug=True)