""" 🖼️→📝 Chest X-ray Report Generation + Attention Visualizer + Classification - Loads generation model (complete_model.safetensor) - Loads classification model (classification.pth) - Generates report and visualizes attention. - Lists disease probabilities. """ import os import re import random from typing import List, Tuple, Optional import logging import gradio as gr import torch import torch.nn as nn import numpy as np from PIL import Image from safetensors.torch import load_model from transformers import AutoModel, AutoImageProcessor # Optional: nicer colormap try: import matplotlib as mpl _HAS_MPL = True _COLORMAP = mpl.colormaps.get_cmap("magma") except Exception: _HAS_MPL = False _COLORMAP = None # ========= Your utilities & model ========= from utils.processing import image_transform, pil_from_path from utils.complete_model import create_complete_model DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ============================================================================== # 1. CLASSIFIER LOGIC (Added) # ============================================================================== class EmbeddingClassifier(nn.Module): def __init__(self, embedding_dim, num_classes, custom_dims=(512, 256, 256), activation="gelu", dropout=0.05, bn=False, use_layernorm=True): super().__init__() layers = [] layers.append(nn.Linear(embedding_dim, custom_dims[0])) if use_layernorm: layers.append(nn.LayerNorm(custom_dims[0])) elif bn: layers.append(nn.BatchNorm1d(custom_dims[0])) layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU()) if dropout > 0: layers.append(nn.Dropout(dropout)) for i in range(len(custom_dims) - 1): layers.append(nn.Linear(custom_dims[i], custom_dims[i + 1])) if use_layernorm: layers.append(nn.LayerNorm(custom_dims[i + 1])) elif bn: layers.append(nn.BatchNorm1d(custom_dims[i + 1])) layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU()) if dropout > 0: layers.append(nn.Dropout(dropout)) layers.append(nn.Linear(custom_dims[-1], num_classes)) self.classifier = nn.Sequential(*layers) def forward(self, embeddings): return self.classifier(embeddings) class ChestXrayPredictor: def __init__(self, base_model, classifier, processor, label_cols, device): self.base_model = base_model self.classifier = classifier self.processor = processor self.label_cols = label_cols self.device = device self.base_model.eval() self.classifier.eval() def predict(self, image_source): try: if isinstance(image_source, str): image = Image.open(image_source).convert('RGB') else: image = image_source.convert('RGB') inputs = self.processor(images=image, return_tensors="pt") pixel_values = inputs['pixel_values'].to(self.device) with torch.no_grad(): outputs = self.base_model(pixel_values=pixel_values) if hasattr(outputs, 'last_hidden_state'): embeddings = outputs.last_hidden_state.mean(dim=1) else: embeddings = outputs[0].mean(dim=1) logits = self.classifier(embeddings) probs = torch.sigmoid(logits).cpu().numpy()[0].tolist() return {label: prob for label, prob in zip(self.label_cols, probs)} except Exception as e: print(f"Prediction Error: {e}") return {} def create_classifier(checkpoint_path, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", device=None): device = device or ('cuda' if torch.cuda.is_available() else 'cpu') print(f"Loading Classifier from {checkpoint_path}...") checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) label_cols = checkpoint.get('label_cols', [ "Cardiomegaly", "Consolidation", "Edema", "Atelectasis", 
"Pleural Effusion", "No Findings" ]) base_model = AutoModel.from_pretrained(model_id).to(device) if 'base_model_state_dict' in checkpoint: base_model.load_state_dict(checkpoint['base_model_state_dict']) processor = AutoImageProcessor.from_pretrained(model_id) # Detect dims with torch.no_grad(): dummy = torch.randn(1, 3, 224, 224).to(device) out = base_model(pixel_values=dummy) embedding_dim = out.last_hidden_state.shape[-1] # Rebuild MLP model_state = checkpoint['model_state_dict'] linear_layers = [] for key, val in model_state.items(): if 'classifier' in key and key.endswith('.weight') and len(val.shape) == 2: match = re.search(r'classifier\.(\d+)\.weight', key) if match: linear_layers.append((int(match.group(1)), val.shape[1], val.shape[0])) linear_layers.sort(key=lambda x: x[0]) num_classes = linear_layers[-1][2] hidden_dims = tuple([x[2] for x in linear_layers[:-1]]) uses_bn = any('running_mean' in k for k in model_state.keys()) has_norm = any(k.endswith('.weight') and len(model_state[k].shape) == 1 for k in model_state.keys() if 'classifier' in k) classifier = EmbeddingClassifier(embedding_dim, num_classes, custom_dims=hidden_dims, bn=uses_bn, use_layernorm=(has_norm and not uses_bn)) classifier.load_state_dict(model_state) classifier.to(device) return ChestXrayPredictor(base_model, classifier, processor, label_cols, device) # ============================================================================== # 2. LOAD MODELS # ============================================================================== # A. Load Generator print("Loading Generation Model...") model = create_complete_model(device=DEVICE, attention_implementation="eager") SAFETENSOR_PATH = "complete_model.safetensor" try: load_model(model, SAFETENSOR_PATH) except Exception as e: print(f"Error loading generation model: {e}") model.eval() # B. 
# B. Load Classifier
print("Loading Classification Model...")
CLASSIFIER_PATH = "classification.pth"
classifier_model = None
try:
    if os.path.exists(CLASSIFIER_PATH):
        classifier_model = create_classifier(CLASSIFIER_PATH, device=DEVICE)
        print("✅ Classifier loaded.")
    else:
        print(f"⚠️ Classifier not found at {CLASSIFIER_PATH}")
except Exception as e:
    print(f"⚠️ Error loading classifier: {e}")

# --- Tokenizer setup ---
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is None:
    raise ValueError("Expected `model.tokenizer` to exist.")

pad_id = getattr(tokenizer, "pad_token_id", None)
eos_id = getattr(tokenizer, "eos_token_id", None)
needs_resize = False
if pad_id is None or (eos_id is not None and pad_id == eos_id):
    tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    needs_resize = True
if needs_resize:
    resize_fns = [
        getattr(getattr(model, "decoder", None), "resize_token_embeddings", None),
        getattr(model, "resize_token_embeddings", None),
    ]
    for fn in resize_fns:
        if callable(fn):
            try:
                fn(len(tokenizer))
                break
            except Exception:
                pass

WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")


# ========= Logic =========
def model_heads_layers():
    def _get(obj, *names, default=None):
        for n in names:
            if obj is None:
                return default
            if hasattr(obj, n):
                return int(getattr(obj, n))
        return default

    cfg_candidates = [
        getattr(model, "config", None),
        getattr(getattr(model, "decoder", None), "config", None),
        getattr(getattr(model, "lm_head", None), "config", None),
    ]
    L = H = None
    for cfg in cfg_candidates:
        if L is None:
            L = _get(cfg, "num_hidden_layers", "n_layer")
        if H is None:
            H = _get(cfg, "num_attention_heads", "n_head")
    return max(1, L or 12), max(1, H or 12)


def get_attention_for_token_layer(attentions, token_index, layer_index, batch_index=0,
                                  head_index=0, mean_across_layers=True, mean_across_heads=True):
    token_attention = attentions[token_index]
    if mean_across_layers:
        layer_attention = torch.stack(token_attention).mean(dim=0)
    else:
        layer_attention = token_attention[int(layer_index)]
    batch_attention = layer_attention[int(batch_index)]
    if mean_across_heads:
        head_attention = batch_attention.mean(dim=0)
    else:
        head_attention = batch_attention[int(head_index)]
    return head_attention.squeeze(0)


def _words_and_map_from_tokens_simple(token_ids: List[int]) -> Tuple[List[str], List[int]]:
    if not token_ids:
        return [], []
    toks = tokenizer.convert_ids_to_tokens(token_ids)
    detok = tokenizer.convert_tokens_to_string(toks)
    matches = list(re.finditer(WORD_RE, detok))
    words = [m.group(0) for m in matches]
    ends = [m.span()[1] for m in matches]
    word2tok: List[int] = []
    for we in ends:
        prefix_ids = tokenizer.encode(detok[:we], add_special_tokens=False)
        if not prefix_ids:
            word2tok.append(0)
            continue
        last_idx = len(prefix_ids) - 1
        last_idx = max(0, min(last_idx, len(token_ids) - 1))
        word2tok.append(last_idx)
    return words, word2tok


def _strip_trailing_special(ids: List[int]) -> List[int]:
    specials = set(getattr(tokenizer, "all_special_ids", []) or [])
    j = len(ids)
    while j > 0 and ids[j - 1] in specials:
        j -= 1
    return ids[:j]


def generate_word_visualization_gen_only(words_gen, word_ends_rel, gen_attn_values,
                                         selected_token_rel_idx):
    if not words_gen or gen_attn_values is None or len(gen_attn_values) == 0:
        return "<div>No text attention values.</div>"
    starts = []
    for i, end in enumerate(word_ends_rel):
        if i == 0:
            starts.append(0)
        else:
            starts.append(min(word_ends_rel[i - 1] + 1, end))
    word_scores = []
    T = len(gen_attn_values)
    for i, end in enumerate(word_ends_rel):
        start = starts[i]
        if start > end:
            start = end
        s = max(0, min(start, T - 1))
        e = max(0, min(end, T - 1))
        if e < s:
            s, e = e, s
        word_scores.append(float(gen_attn_values[s:e + 1].sum()))
    max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)
    selected_word_idx = None
    for i, end in enumerate(word_ends_rel):
        if selected_token_rel_idx <= end:
            selected_word_idx = i
            break
    if selected_word_idx is None and word_ends_rel:
        selected_word_idx = len(word_ends_rel) - 1
    spans = []
    for i, w in enumerate(words_gen):
        alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
        bg = f"rgba(66,133,244,{alpha:.3f})"
        border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
        # The span markup was garbled in the source; this reconstruction applies
        # the computed `bg` and `border` styles, which are otherwise unused.
        spans.append(
            f"<span style='background:{bg}; border:{border}; "
            f"border-radius:4px; padding:1px 3px; margin:1px; display:inline-block;'>{w}</span>"
        )
    return f"<div style='line-height:2; font-family:sans-serif;'>{''.join(spans)}</div>"
def _attention_to_heatmap_uint8(attn_1d: np.ndarray, img_token_len: int = 1024,
                                side: int = 32) -> np.ndarray:
    if attn_1d.shape[0] < img_token_len:
        img_part = np.zeros(img_token_len, dtype=float)
        img_part[: attn_1d.shape[0]] = attn_1d
    else:
        img_part = attn_1d[:img_token_len]
    mn, mx = float(img_part.min()), float(img_part.max())
    denom = (mx - mn) if (mx - mn) > 1e-12 else 1.0
    norm = (img_part - mn) / denom
    return (norm.reshape(side, side) * 255.0).astype(np.uint8)


def _colorize_heatmap(heatmap_u8: np.ndarray) -> Image.Image:
    if _HAS_MPL and _COLORMAP is not None:
        colored = (_COLORMAP(heatmap_u8.astype(np.float32) / 255.0)[:, :, :3] * 255.0).astype(np.uint8)
        return Image.fromarray(colored)
    else:
        # Fallback ramp (no matplotlib): red/green channels built from the map.
        g = heatmap_u8.astype(np.float32) / 255.0
        r = (g * 255.0).clip(0, 255).astype(np.uint8)
        g2 = (np.sqrt(g) * 255.0).clip(0, 255).astype(np.uint8)
        b = np.zeros_like(r, dtype=np.uint8)
        rgb = np.stack([r, g2, b], axis=-1)
        return Image.fromarray(rgb)


def _resize_like(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
    return img.resize(target_size, resample=Image.BILINEAR)


def _make_overlay(orig: Image.Image, heatmap_rgb: Image.Image, alpha: float = 0.35) -> Image.Image:
    if heatmap_rgb.size != orig.size:
        heatmap_rgb = _resize_like(heatmap_rgb, orig.size)
    base = orig.convert("RGBA")
    overlay = heatmap_rgb.convert("RGBA")
    r, g, b = overlay.split()[:3]
    a = Image.new("L", overlay.size, int(alpha * 255))
    overlay = Image.merge("RGBA", (r, g, b, a))
    return Image.alpha_composite(base, overlay).convert("RGB")


def _prepare_image_tensor(pil_img, img_size=512):
    tfm = image_transform(img_size=img_size)
    tens = tfm(pil_img).unsqueeze(0).to(DEVICE, non_blocking=True)
    return tens


def run_generation(pil_image, max_new_tokens, layer, head, mean_layers, mean_heads):
    if pil_image is None:
        blank = Image.new("RGB", (256, 256), "black")
        return (None, None, 1024, None, None, gr.update(choices=[], value=None),
                blank, blank, np.zeros((256, 256, 3), dtype=np.uint8),
                "<div>Upload image first.</div>")
    pixel_values = _prepare_image_tensor(pil_image, img_size=512)
    with torch.no_grad():
        gen_ids, gen_text, attentions = model.generate(pixel_values=pixel_values,
                                                       max_new_tokens=int(max_new_tokens),
                                                       output_attentions=True)
    if isinstance(gen_ids, torch.Tensor):
        gen_ids = gen_ids[0].tolist()
    gen_ids = _strip_trailing_special(gen_ids)
    words_gen, gen_word2tok_rel = _words_and_map_from_tokens_simple(gen_ids)
    display_choices = [(w, i) for i, w in enumerate(words_gen)]
    if not display_choices:
        blank_hm = np.zeros((32, 32), dtype=np.uint8)
        hm_rgb = _colorize_heatmap(blank_hm).resize(pil_image.size, resample=Image.NEAREST)
        overlay = _make_overlay(pil_image, hm_rgb, alpha=0.35)
        return (attentions, gen_ids, 1024, words_gen, gen_word2tok_rel,
                gr.update(choices=[], value=None), pil_image, overlay,
                np.array(hm_rgb), "<div>No tokens.</div>")
    first_idx = 0
    hm_rgb_init, overlay_init, html_init = update_visualization(
        first_idx, attentions, gen_ids, layer, head, mean_layers, mean_heads,
        words_gen, gen_word2tok_rel, pil_image)
    return (attentions, gen_ids, 1024, words_gen, gen_word2tok_rel,
            gr.update(choices=display_choices, value=first_idx),
            pil_image, overlay_init, hm_rgb_init, html_init)


def update_visualization(selected_gen_index, attentions, gen_token_ids, layer, head,
                         mean_layers, mean_heads, words_gen, gen_word2tok_rel,
                         pil_image: Optional[Image.Image] = None):
    if selected_gen_index is None or attentions is None or gen_word2tok_rel is None:
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div>Generate first.</div>"
    gidx = int(selected_gen_index)
    if not (0 <= gidx < len(gen_word2tok_rel)):
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div>Invalid selection.</div>"
    # Map the selected word to the decoding step that produced its last token;
    # that step's attention is what we visualize.
    step_index = int(gen_word2tok_rel[gidx])
    if not attentions or step_index >= len(attentions):
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div>No attention.</div>"
    token_attn = get_attention_for_token_layer(attentions, token_index=step_index,
                                               layer_index=int(layer), head_index=int(head),
                                               mean_across_layers=bool(mean_layers),
                                               mean_across_heads=bool(mean_heads))
    attn_vals = token_attn.detach().cpu().numpy()
    if attn_vals.ndim == 2:
        attn_vals = attn_vals[-1]

    heatmap_u8 = _attention_to_heatmap_uint8(attn_1d=attn_vals, img_token_len=1024, side=32)
    hm_rgb_pil = _colorize_heatmap(heatmap_u8)
    if pil_image is None:
        pil_image = Image.new("RGB", (256, 256), "black")
    hm_rgb_pil_up = hm_rgb_pil.resize(pil_image.size, resample=Image.NEAREST)
    overlay_pil = _make_overlay(pil_image, hm_rgb_pil_up, alpha=0.35)

    # Attention over generated tokens sits after the 1024 image keys.
    k_len = int(attn_vals.shape[0])
    observed_gen = max(0, min(step_index + 1, max(0, k_len - 1024)))
    total_gen = len(gen_token_ids)
    gen_vec = np.zeros(total_gen, dtype=float)
    if observed_gen > 0:
        start = 1024
        end = min(1024 + observed_gen, k_len)
        gen_slice = attn_vals[start:end]
        gen_vec[: len(gen_slice)] = gen_slice
    html_words = generate_word_visualization_gen_only(words_gen, gen_word2tok_rel,
                                                      gen_vec, step_index)
    return np.array(hm_rgb_pil_up), overlay_pil, html_words


def toggle_slider(is_mean):
    return gr.update(interactive=not bool(is_mean))


# ========= Gradio UI =========
EXAMPLES_DIR = "examples"

with gr.Blocks() as demo:
    gr.Markdown("# 🖼️→📝 Chest X-ray Report Generation & Classification")

    # States
    state_attentions = gr.State(None)
    state_gen_token_ids = gr.State(None)
    state_img_token_len = gr.State(1024)
    state_words_gen = gr.State(None)
    state_gen_word2tok_rel = gr.State(None)
    state_last_image = gr.State(None)

    L, H = model_heads_layers()

    with gr.Row():
        # LEFT COLUMN
        with gr.Column(scale=1):
            gr.Markdown("### 1) Input")
            img_input = gr.Image(type="pil", label="Upload image", height=280)
            btn_load_sample = gr.Button("Load random sample", variant="secondary")
            sample_status = gr.Markdown("")

            gr.Markdown("### 2) Generation Settings")
            slider_max_tokens = gr.Slider(5, 200, value=100, step=5, label="Max New Tokens")
            btn_generate = gr.Button("GENERATE REPORT & CLASSIFY", variant="primary")

            gr.Markdown("### 3) Attention Visualization")
            check_mean_layers = gr.Checkbox(False, label="Mean Across Layers")
            check_mean_heads = gr.Checkbox(False, label="Mean Across Heads")
            slider_layer = gr.Slider(0, max(0, L - 1), value=0, step=1, label="Layer", interactive=True)
            slider_head = gr.Slider(0, max(0, H - 1), value=0, step=1, label="Head", interactive=True)

            # --- NEW CLASSIFICATION SECTION ---
            gr.Markdown("### 4) Disease Probability")
            classification_output = gr.Dataframe(
                headers=["Disease", "Probability"],
                datatype=["str", "str"],
                label="Predictions",
                interactive=False
            )
            # ----------------------------------

        # RIGHT COLUMN
        with gr.Column(scale=3):
            with gr.Row():
                img_original_view = gr.Image(label="Original", image_mode="RGB", height=256)
                img_overlay_view = gr.Image(label="Attention Overlay", image_mode="RGB", height=256)
                heatmap_view = gr.Image(label="Heatmap", image_mode="RGB", height=256)
            radio_word_selector = gr.Radio([], label="Select Generated Word",
                                           info="Shows attention for this word")
            html_visualization = gr.HTML("<div>Text attention visualization will appear here.</div>")
    # Sample loader
    def _load_sample_from_examples():
        try:
            files = [f for f in os.listdir(EXAMPLES_DIR) if not f.startswith(".")]
            if not files:
                return gr.update(), "No files."
            fp = os.path.join(EXAMPLES_DIR, random.choice(files))
            pil_img = pil_from_path(fp)
            return gr.update(value=pil_img), f"Loaded: {os.path.basename(fp)}"
        except Exception as e:
            return gr.update(), f"Error: {e}"

    btn_load_sample.click(_load_sample_from_examples, inputs=[], outputs=[img_input, sample_status])

    # MAIN RUN FUNCTION (GENERATION + CLASSIFICATION)
    def _run_all_logic(pil_image, *args):
        # 1. Run Generation (returns 10 items)
        gen_results = run_generation(pil_image, *args)

        # 2. Run Classification
        classification_data = []
        if pil_image is not None and classifier_model is not None:
            try:
                preds = classifier_model.predict(pil_image)
                # Sort by probability descending
                sorted_preds = sorted(preds.items(), key=lambda x: x[1], reverse=True)
                # Format as list of lists for the Gradio Dataframe: ["Name", "95.5%"].
                # Sigmoid outputs are in [0, 1], so scale by 100 before formatting.
                classification_data = [[k, f"{v * 100:.1f}%"] for k, v in sorted_preds]
            except Exception as e:
                print(f"Classification runtime error: {e}")
                classification_data = [["Error", str(e)]]

        # Combine: gen_results + original_image (for state) + classification_data
        return (*gen_results, pil_image, classification_data)

    btn_generate.click(
        fn=_run_all_logic,
        inputs=[img_input, slider_max_tokens, slider_layer, slider_head,
                check_mean_layers, check_mean_heads],
        outputs=[
            state_attentions, state_gen_token_ids, state_img_token_len,
            state_words_gen, state_gen_word2tok_rel,
            radio_word_selector, img_original_view, img_overlay_view,
            heatmap_view, html_visualization,
            state_last_image,        # Added to outputs
            classification_output    # Added to outputs
        ],
    )

    # UI updates for visualizer controls
    def _update_wrapper(selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH,
                        words, word2tok, last_img):
        hm_rgb, overlay, html = update_visualization(selected_gen_index, attn, gen_ids,
                                                     lyr, hed, meanL, meanH, words,
                                                     word2tok, pil_image=last_img)
        return overlay, hm_rgb, html

    for control in [radio_word_selector, slider_layer, slider_head,
                    check_mean_layers, check_mean_heads]:
        control.change(
            fn=_update_wrapper,
            inputs=[radio_word_selector, state_attentions, state_gen_token_ids,
                    slider_layer, slider_head, check_mean_layers, check_mean_heads,
                    state_words_gen, state_gen_word2tok_rel, state_last_image],
            outputs=[img_overlay_view, heatmap_view, html_visualization],
        )

    check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
    check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)


if __name__ == "__main__":
    print(f"Device: {DEVICE}")
    demo.launch(debug=True)