"""
🖼️→📝 Chest X-ray Report Generation + Attention Visualizer + Classification
- Loads generation model (complete_model.safetensor)
- Loads classification model (classification.pth)
- Generates report and visualizes attention.
- Lists disease probabilities.
"""
import os
import re
import random
from typing import List, Tuple, Optional
import logging
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from safetensors.torch import load_model
from transformers import AutoModel, AutoImageProcessor
# Optional: nicer colormap
try:
import matplotlib as mpl
_HAS_MPL = True
_COLORMAP = mpl.colormaps.get_cmap("magma")
except Exception:
_HAS_MPL = False
_COLORMAP = None
# ========= Your utilities & model =========
from utils.processing import image_transform, pil_from_path
from utils.complete_model import create_complete_model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ==============================================================================
# 1. CLASSIFIER LOGIC (Added)
# ==============================================================================
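# MLP head over pooled ViT embeddings. Per hidden stage: Linear -> LayerNorm
# (or BatchNorm) -> GELU/ReLU -> Dropout, followed by a final Linear to
# num_classes. Emits multi-label logits; sigmoid is applied at prediction time.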
class EmbeddingClassifier(nn.Module):
def __init__(self, embedding_dim, num_classes, custom_dims=(512, 256, 256),
activation="gelu", dropout=0.05, bn=False, use_layernorm=True):
super().__init__()
layers = []
layers.append(nn.Linear(embedding_dim, custom_dims[0]))
if use_layernorm: layers.append(nn.LayerNorm(custom_dims[0]))
elif bn: layers.append(nn.BatchNorm1d(custom_dims[0]))
layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU())
if dropout > 0: layers.append(nn.Dropout(dropout))
for i in range(len(custom_dims) - 1):
layers.append(nn.Linear(custom_dims[i], custom_dims[i + 1]))
if use_layernorm: layers.append(nn.LayerNorm(custom_dims[i + 1]))
elif bn: layers.append(nn.BatchNorm1d(custom_dims[i + 1]))
layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU())
if dropout > 0: layers.append(nn.Dropout(dropout))
layers.append(nn.Linear(custom_dims[-1], num_classes))
self.classifier = nn.Sequential(*layers)
def forward(self, embeddings):
return self.classifier(embeddings)
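# Wraps the vision backbone + MLP head: mean-pools the backbone's token
# embeddings into one vector per image, then scores each disease label
# independently (multi-label sigmoid, so probabilities need not sum to 1).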
class ChestXrayPredictor:
def __init__(self, base_model, classifier, processor, label_cols, device):
self.base_model = base_model
self.classifier = classifier
self.processor = processor
self.label_cols = label_cols
self.device = device
self.base_model.eval()
self.classifier.eval()
def predict(self, image_source):
try:
if isinstance(image_source, str):
image = Image.open(image_source).convert('RGB')
else:
image = image_source.convert('RGB')
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs['pixel_values'].to(self.device)
with torch.no_grad():
outputs = self.base_model(pixel_values=pixel_values)
if hasattr(outputs, 'last_hidden_state'):
embeddings = outputs.last_hidden_state.mean(dim=1)
else:
embeddings = outputs[0].mean(dim=1)
logits = self.classifier(embeddings)
probs = torch.sigmoid(logits).cpu().numpy()[0].tolist()
return {label: prob for label, prob in zip(self.label_cols, probs)}
except Exception as e:
print(f"Prediction Error: {e}")
return {}
def create_classifier(checkpoint_path, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", device=None):
device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Loading Classifier from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
label_cols = checkpoint.get('label_cols', [
"Cardiomegaly", "Consolidation", "Edema",
"Atelectasis", "Pleural Effusion", "No Findings"
])
base_model = AutoModel.from_pretrained(model_id).to(device)
if 'base_model_state_dict' in checkpoint:
base_model.load_state_dict(checkpoint['base_model_state_dict'])
processor = AutoImageProcessor.from_pretrained(model_id)
# Detect dims
with torch.no_grad():
dummy = torch.randn(1, 3, 224, 224).to(device)
out = base_model(pixel_values=dummy)
embedding_dim = out.last_hidden_state.shape[-1]
# Rebuild MLP
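    # The checkpoint stores only weights, not the architecture, so recover it
    # from the state dict itself: every nn.Linear inside the Sequential shows
    # up as 'classifier.<idx>.weight' with shape (out_features, in_features).
    # Sorting by <idx> gives layer order; the last Linear's out_features is
    # num_classes and the earlier ones are the hidden dims.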
model_state = checkpoint['model_state_dict']
linear_layers = []
for key, val in model_state.items():
if 'classifier' in key and key.endswith('.weight') and len(val.shape) == 2:
match = re.search(r'classifier\.(\d+)\.weight', key)
if match:
linear_layers.append((int(match.group(1)), val.shape[1], val.shape[0]))
linear_layers.sort(key=lambda x: x[0])
    num_classes = linear_layers[-1][2]
    hidden_dims = tuple(x[2] for x in linear_layers[:-1])  # output dims of all but the last Linear
uses_bn = any('running_mean' in k for k in model_state.keys())
has_norm = any(k.endswith('.weight') and len(model_state[k].shape) == 1 for k in model_state.keys() if 'classifier' in k)
classifier = EmbeddingClassifier(embedding_dim, num_classes, custom_dims=hidden_dims, bn=uses_bn, use_layernorm=(has_norm and not uses_bn))
classifier.load_state_dict(model_state)
classifier.to(device)
return ChestXrayPredictor(base_model, classifier, processor, label_cols, device)
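# Usage sketch (assumes classification.pth was saved with a 'model_state_dict'
# entry, plus optional 'label_cols' / 'base_model_state_dict'; the image path
# below is hypothetical):
#   predictor = create_classifier("classification.pth")
#   probs = predictor.predict("examples/sample.png")
#   # -> {"Cardiomegaly": 0.12, "Consolidation": 0.03, ...}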
# ==============================================================================
# 2. LOAD MODELS
# ==============================================================================
# A. Load Generator
print("Loading Generation Model...")
model = create_complete_model(device=DEVICE, attention_implementation="eager")
SAFETENSOR_PATH = "complete_model.safetensor"
try:
load_model(model, SAFETENSOR_PATH)
except Exception as e:
print(f"Error loading generation model: {e}")
model.eval()
# B. Load Classifier
print("Loading Classification Model...")
CLASSIFIER_PATH = "classification.pth"
classifier_model = None
try:
if os.path.exists(CLASSIFIER_PATH):
classifier_model = create_classifier(CLASSIFIER_PATH, device=DEVICE)
print("✅ Classifier loaded.")
else:
print(f"⚠️ Classifier not found at {CLASSIFIER_PATH}")
except Exception as e:
print(f"⚠️ Error loading classifier: {e}")
# --- Tokenizer setup ---
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is None:
raise ValueError("Expected `model.tokenizer` to exist.")
pad_id = getattr(tokenizer, "pad_token_id", None)
eos_id = getattr(tokenizer, "eos_token_id", None)
needs_resize = False
if pad_id is None or (eos_id is not None and pad_id == eos_id):
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
needs_resize = True
if needs_resize:
resize_fns = [
getattr(getattr(model, "decoder", None), "resize_token_embeddings", None),
getattr(model, "resize_token_embeddings", None),
]
for fn in resize_fns:
if callable(fn):
try:
fn(len(tokenizer))
break
except Exception:
pass
WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")
# ========= Logic =========
def model_heads_layers():
def _get(obj, *names, default=None):
for n in names:
if obj is None: return default
if hasattr(obj, n): return int(getattr(obj, n))
return default
cfg_candidates = [
getattr(model, "config", None),
getattr(getattr(model, "decoder", None), "config", None),
getattr(getattr(model, "lm_head", None), "config", None),
]
L = H = None
for cfg in cfg_candidates:
if L is None: L = _get(cfg, "num_hidden_layers", "n_layer")
if H is None: H = _get(cfg, "num_attention_heads", "n_head")
return max(1, L or 12), max(1, H or 12)
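# `attentions` below is assumed to be the per-step structure produced by
# model.generate(..., output_attentions=True): one entry per generated token,
# each a tuple over decoder layers of tensors shaped
# [batch, heads, query_len, key_len]. This helper reduces that to a single
# attention row for the chosen token/layer/head (or means across them).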
def get_attention_for_token_layer(attentions, token_index, layer_index, batch_index=0, head_index=0, mean_across_layers=True, mean_across_heads=True):
token_attention = attentions[token_index]
if mean_across_layers:
layer_attention = torch.stack(token_attention).mean(dim=0)
else:
layer_attention = token_attention[int(layer_index)]
batch_attention = layer_attention[int(batch_index)]
if mean_across_heads:
head_attention = batch_attention.mean(dim=0)
else:
head_attention = batch_attention[int(head_index)]
return head_attention.squeeze(0)
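# Map each detokenized word back to the index of the last subword token that
# produced it, by re-encoding successively longer prefixes of the detokenized
# string. O(n^2) in token count, but generated reports here are short.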
def _words_and_map_from_tokens_simple(token_ids: List[int]) -> Tuple[List[str], List[int]]:
if not token_ids: return [], []
toks = tokenizer.convert_ids_to_tokens(token_ids)
detok = tokenizer.convert_tokens_to_string(toks)
matches = list(re.finditer(WORD_RE, detok))
words = [m.group(0) for m in matches]
ends = [m.span()[1] for m in matches]
word2tok: List[int] = []
for we in ends:
prefix_ids = tokenizer.encode(detok[:we], add_special_tokens=False)
if not prefix_ids:
word2tok.append(0)
continue
last_idx = len(prefix_ids) - 1
last_idx = max(0, min(last_idx, len(token_ids) - 1))
word2tok.append(last_idx)
return words, word2tok
def _strip_trailing_special(ids: List[int]) -> List[int]:
specials = set(getattr(tokenizer, "all_special_ids", []) or [])
j = len(ids)
while j > 0 and ids[j - 1] in specials:
j -= 1
return ids[:j]
def generate_word_visualization_gen_only(words_gen, word_ends_rel, gen_attn_values, selected_token_rel_idx):
if not words_gen or gen_attn_values is None or len(gen_attn_values) == 0:
return "
No text attention values.
"
starts = []
for i, end in enumerate(word_ends_rel):
if i == 0: starts.append(0)
else: starts.append(min(word_ends_rel[i - 1] + 1, end))
word_scores = []
T = len(gen_attn_values)
for i, end in enumerate(word_ends_rel):
start = starts[i]
if start > end: start = end
s = max(0, min(start, T - 1))
e = max(0, min(end, T - 1))
if e < s: s, e = e, s
word_scores.append(float(gen_attn_values[s:e + 1].sum()))
max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)
selected_word_idx = None
for i, end in enumerate(word_ends_rel):
if selected_token_rel_idx <= end:
selected_word_idx = i
break
if selected_word_idx is None and word_ends_rel: selected_word_idx = len(word_ends_rel) - 1
    spans = []
    for i, w in enumerate(words_gen):
        alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
        bg = f"rgba(66,133,244,{alpha:.3f})"
        border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
        # Shade each word by its attention mass; outline the selected word.
        spans.append(
            f"<span style='background:{bg}; border:{border}; border-radius:4px; "
            f"padding:1px 4px; margin:1px; display:inline-block;'>{w}</span>"
        )
    return f"<div style='line-height:2.0;'>{' '.join(spans)}</div>"
def _attention_to_heatmap_uint8(attn_1d: np.ndarray, img_token_len: int = 1024, side: int = 32) -> np.ndarray:
if attn_1d.shape[0] < img_token_len:
img_part = np.zeros(img_token_len, dtype=float)
img_part[: attn_1d.shape[0]] = attn_1d
else:
img_part = attn_1d[:img_token_len]
mn, mx = float(img_part.min()), float(img_part.max())
denom = (mx - mn) if (mx - mn) > 1e-12 else 1.0
norm = (img_part - mn) / denom
return (norm.reshape(side, side) * 255.0).astype(np.uint8)
def _colorize_heatmap(heatmap_u8: np.ndarray) -> Image.Image:
if _HAS_MPL and _COLORMAP is not None:
colored = (_COLORMAP(heatmap_u8.astype(np.float32) / 255.0)[:, :, :3] * 255.0).astype(np.uint8)
return Image.fromarray(colored)
else:
g = heatmap_u8.astype(np.float32) / 255.0
r = (g * 255.0).clip(0, 255).astype(np.uint8)
g2 = (np.sqrt(g) * 255.0).clip(0, 255).astype(np.uint8)
b = np.zeros_like(r, dtype=np.uint8)
rgb = np.stack([r, g2, b], axis=-1)
return Image.fromarray(rgb)
def _resize_like(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
return img.resize(target_size, resample=Image.BILINEAR)
def _make_overlay(orig: Image.Image, heatmap_rgb: Image.Image, alpha: float = 0.35) -> Image.Image:
if heatmap_rgb.size != orig.size:
heatmap_rgb = _resize_like(heatmap_rgb, orig.size)
base = orig.convert("RGBA")
overlay = heatmap_rgb.convert("RGBA")
r, g, b = overlay.split()[:3]
a = Image.new("L", overlay.size, int(alpha * 255))
overlay = Image.merge("RGBA", (r, g, b, a))
return Image.alpha_composite(base, overlay).convert("RGB")
def _prepare_image_tensor(pil_img, img_size=512):
tfm = image_transform(img_size=img_size)
tens = tfm(pil_img).unsqueeze(0).to(DEVICE, non_blocking=True)
return tens
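# model.generate here is the custom wrapper from utils.complete_model and is
# expected to return (token_ids, decoded_text, per-step attentions); this is
# not the stock transformers generate() signature.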
def run_generation(pil_image, max_new_tokens, layer, head, mean_layers, mean_heads):
if pil_image is None:
blank = Image.new("RGB", (256, 256), "black")
        return (None, None, 1024, None, None, gr.update(choices=[], value=None), blank, blank, np.zeros((256, 256, 3), dtype=np.uint8), "Upload image first.")
pixel_values = _prepare_image_tensor(pil_image, img_size=512)
with torch.no_grad():
gen_ids, gen_text, attentions = model.generate(pixel_values=pixel_values, max_new_tokens=int(max_new_tokens), output_attentions=True)
if isinstance(gen_ids, torch.Tensor): gen_ids = gen_ids[0].tolist()
gen_ids = _strip_trailing_special(gen_ids)
words_gen, gen_word2tok_rel = _words_and_map_from_tokens_simple(gen_ids)
display_choices = [(w, i) for i, w in enumerate(words_gen)]
if not display_choices:
blank_hm = np.zeros((32, 32), dtype=np.uint8)
hm_rgb = _colorize_heatmap(blank_hm).resize(pil_image.size, resample=Image.NEAREST)
overlay = _make_overlay(pil_image, hm_rgb, alpha=0.35)
        return (attentions, gen_ids, 1024, words_gen, gen_word2tok_rel, gr.update(choices=[], value=None), pil_image, overlay, np.array(hm_rgb), "No tokens.")
first_idx = 0
hm_rgb_init, overlay_init, html_init = update_visualization(first_idx, attentions, gen_ids, layer, head, mean_layers, mean_heads, words_gen, gen_word2tok_rel, pil_image)
return (attentions, gen_ids, 1024, words_gen, gen_word2tok_rel, gr.update(choices=display_choices, value=first_idx), pil_image, overlay_init, hm_rgb_init, html_init)
def update_visualization(selected_gen_index, attentions, gen_token_ids, layer, head, mean_layers, mean_heads, words_gen, gen_word2tok_rel, pil_image: Optional[Image.Image] = None):
if selected_gen_index is None or attentions is None or gen_word2tok_rel is None:
blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "Generate first."
gidx = int(selected_gen_index)
if not (0 <= gidx < len(gen_word2tok_rel)):
blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "Invalid selection."
step_index = int(gen_word2tok_rel[gidx])
if not attentions or step_index >= len(attentions):
blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "No attention."
token_attn = get_attention_for_token_layer(attentions, token_index=step_index, layer_index=int(layer), head_index=int(head), mean_across_layers=bool(mean_layers), mean_across_heads=bool(mean_heads))
attn_vals = token_attn.detach().cpu().numpy()
if attn_vals.ndim == 2: attn_vals = attn_vals[-1]
heatmap_u8 = _attention_to_heatmap_uint8(attn_1d=attn_vals, img_token_len=1024, side=32)
hm_rgb_pil = _colorize_heatmap(heatmap_u8)
if pil_image is None: pil_image = Image.new("RGB", (256, 256), "black")
hm_rgb_pil_up = hm_rgb_pil.resize(pil_image.size, resample=Image.NEAREST)
overlay_pil = _make_overlay(pil_image, hm_rgb_pil_up, alpha=0.35)
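    # Key-axis layout: positions [0, 1024) attend over image patch tokens,
    # positions [1024, k_len) over previously generated text tokens.
    # observed_gen counts how many generated tokens are visible at this step.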
k_len = int(attn_vals.shape[0])
observed_gen = max(0, min(step_index + 1, max(0, k_len - 1024)))
total_gen = len(gen_token_ids)
gen_vec = np.zeros(total_gen, dtype=float)
if observed_gen > 0:
start = 1024
end = min(1024 + observed_gen, k_len)
gen_slice = attn_vals[start:end]
gen_vec[: len(gen_slice)] = gen_slice
html_words = generate_word_visualization_gen_only(words_gen, gen_word2tok_rel, gen_vec, step_index)
return np.array(hm_rgb_pil_up), overlay_pil, html_words
def toggle_slider(is_mean):
return gr.update(interactive=not bool(is_mean))
# ========= Gradio UI =========
EXAMPLES_DIR = "examples"
with gr.Blocks() as demo:
gr.Markdown("# 🖼️→📝 Chest X-ray Report Generation & Classification")
# States
state_attentions = gr.State(None)
state_gen_token_ids = gr.State(None)
state_img_token_len = gr.State(1024)
state_words_gen = gr.State(None)
state_gen_word2tok_rel = gr.State(None)
state_last_image = gr.State(None)
L, H = model_heads_layers()
with gr.Row():
# LEFT COLUMN
with gr.Column(scale=1):
gr.Markdown("### 1) Input")
img_input = gr.Image(type="pil", label="Upload image", height=280)
btn_load_sample = gr.Button("Load random sample", variant="secondary")
sample_status = gr.Markdown("")
gr.Markdown("### 2) Generation Settings")
slider_max_tokens = gr.Slider(5, 200, value=100, step=5, label="Max New Tokens")
btn_generate = gr.Button("GENERATE REPORT & CLASSIFY", variant="primary")
gr.Markdown("### 3) Attention Visualization")
check_mean_layers = gr.Checkbox(False, label="Mean Across Layers")
check_mean_heads = gr.Checkbox(False, label="Mean Across Heads")
slider_layer = gr.Slider(0, max(0, L - 1), value=0, step=1, label="Layer", interactive=True)
slider_head = gr.Slider(0, max(0, H - 1), value=0, step=1, label="Head", interactive=True)
# --- NEW CLASSIFICATION SECTION ---
gr.Markdown("### 4) Disease Probability")
classification_output = gr.Dataframe(
headers=["Disease", "Probability"],
datatype=["str", "str"],
label="Predictions",
interactive=False
)
# ----------------------------------
# RIGHT COLUMN
with gr.Column(scale=3):
with gr.Row():
img_original_view = gr.Image(label="Original", image_mode="RGB", height=256)
img_overlay_view = gr.Image(label="Attention Overlay", image_mode="RGB", height=256)
heatmap_view = gr.Image(label="Heatmap", image_mode="RGB", height=256)
radio_word_selector = gr.Radio([], label="Select Generated Word", info="Shows attention for this word")
html_visualization = gr.HTML("Text attention visualization will appear here.
")
# Sample loader
def _load_sample_from_examples():
try:
files = [f for f in os.listdir(EXAMPLES_DIR) if not f.startswith(".")]
if not files: return gr.update(), "No files."
fp = os.path.join(EXAMPLES_DIR, random.choice(files))
pil_img = pil_from_path(fp)
return gr.update(value=pil_img), f"Loaded: {os.path.basename(fp)}"
except Exception as e:
return gr.update(), f"Error: {e}"
btn_load_sample.click(_load_sample_from_examples, inputs=[], outputs=[img_input, sample_status])
# MAIN RUN FUNCTION (GENERATION + CLASSIFICATION)
def _run_all_logic(pil_image, *args):
# 1. Run Generation (returns 10 items)
gen_results = run_generation(pil_image, *args)
# 2. Run Classification
classification_data = []
if pil_image and classifier_model:
try:
preds = classifier_model.predict(pil_image)
# Sort by probability descending
sorted_preds = sorted(preds.items(), key=lambda x: x[1], reverse=True)
# Format as list of lists for Gradio Dataframe: ["Name", "95.5%"]
                classification_data = [[k, f"{v * 100:.1f}%"] for k, v in sorted_preds]
except Exception as e:
print(f"Classification runtime error: {e}")
classification_data = [["Error", str(e)]]
# Combine: gen_results + original_image (for state) + classification_data
return (*gen_results, pil_image, classification_data)
btn_generate.click(
fn=_run_all_logic,
inputs=[img_input, slider_max_tokens, slider_layer, slider_head, check_mean_layers, check_mean_heads],
outputs=[
state_attentions,
state_gen_token_ids,
state_img_token_len,
state_words_gen,
state_gen_word2tok_rel,
radio_word_selector,
img_original_view,
img_overlay_view,
heatmap_view,
html_visualization,
state_last_image, # Added to outputs
classification_output # Added to outputs
],
)
# UI updates for visualizer controls
def _update_wrapper(selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH, words, word2tok, last_img):
hm_rgb, overlay, html = update_visualization(selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH, words, word2tok, pil_image=last_img)
return overlay, hm_rgb, html
for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
control.change(
fn=_update_wrapper,
inputs=[radio_word_selector, state_attentions, state_gen_token_ids, slider_layer, slider_head, check_mean_layers, check_mean_heads, state_words_gen, state_gen_word2tok_rel, state_last_image],
outputs=[img_overlay_view, heatmap_view, html_visualization],
)
check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)
if __name__ == "__main__":
print(f"Device: {DEVICE}")
demo.launch(debug=True)