api / frontend /AttentionExplorer.tsx
gary-boon
Deploy Visualisable.ai backend with API protection
c6c8587
raw
history blame
35.3 kB
"use client";
import { useEffect, useMemo, useRef, useState } from "react";
import * as d3 from "d3";
import { Download, Layers, RefreshCw, ZoomIn, ZoomOut, Eye, Activity, HelpCircle, X, Info, Zap } from "lucide-react";
import { useWebSocket } from "@/lib/websocket-client";
import { getApiUrl } from "@/lib/config";
import { TraceData } from "@/lib/types";
/**
 * One captured attention map for a single transformer layer, as rendered by the
 * heatmap in this component.
 */
interface AttentionData {
  /** Layer identifier taken from the trace (e.g. "layer.0"); "unknown" when absent. */
  layer: string;
  /** Capture time; rendered via `new Date(...)`, presumably epoch milliseconds. */
  timestamp: number;
  /** Attention matrix: rows are source tokens, columns are target tokens. */
  weights: number[][];
  /** Reported maximum weight; used as the upper bound of the color scale. */
  max_weight: number;
  /** Optional token strings used to label the matrix rows and columns. */
  tokens?: string[];
  /** Optional entropy value reported alongside the attention map. */
  entropy?: number;
}
/** Natural-order collator so "layer.2" sorts before "layer.10" and non-numeric names never yield NaN. */
const layerNameCollator = new Intl.Collator("en", { numeric: true });

/** Zoom bounds and step for the heatmap zoom controls. */
const MIN_ZOOM = 0.5;
const MAX_ZOOM = 2;
const ZOOM_STEP = 0.1;

/**
 * Returns the unique layer names present in `maps`, in natural (numeric-aware) order.
 * Replaces a `parseInt(name.replace('layer.', ''))` comparator that produced NaN
 * comparisons for names such as the "unknown" fallback.
 */
function sortedUniqueLayers(maps: readonly AttentionData[]): string[] {
  return Array.from(new Set(maps.map((a) => a.layer))).sort(layerNameCollator.compare);
}

/**
 * Converts raw traces into attention maps, keeping only `attention` traces that carry weights.
 * Missing fields get the same fallbacks the component has always used.
 */
function toAttentionData(traces: readonly TraceData[]): AttentionData[] {
  return traces
    .filter((t) => t.type === "attention" && t.weights)
    .map((t) => ({
      layer: t.layer || "unknown",
      weights: t.weights || [],
      max_weight: t.max_weight || 1,
      entropy: t.entropy,
      timestamp: t.timestamp || Date.now(),
      tokens: t.tokens,
    }));
}

/** Escapes text for safe interpolation into HTML (tokens are model output, i.e. untrusted). */
function escapeHtml(text: string): string {
  return text
    .replace(/&/g, "&amp;")
    .replace(/</g, "&lt;")
    .replace(/>/g, "&gt;")
    .replace(/"/g, "&quot;")
    .replace(/'/g, "&#39;");
}

/** Triggers a browser download of `blob` as `filename`, releasing the object URL afterwards. */
function downloadBlob(blob: Blob, filename: string): void {
  const url = URL.createObjectURL(blob);
  const link = document.createElement("a");
  link.href = url;
  link.download = filename;
  link.click();
  // Revoke on the next tick so the browser has already started the download.
  setTimeout(() => URL.revokeObjectURL(url), 0);
}

/**
 * Interactive explorer for transformer attention maps.
 *
 * Attention maps arrive from three sources:
 * - WebSocket traces (`useWebSocket().traces`);
 * - `demo-completed` window events (`event.detail.traces`);
 * - the HTTP `/generate` endpoint.
 *
 * The most recent map for the selected layer is rendered as a D3 heatmap. Controls
 * cover colour scheme, zoom, and JSON/PNG export.
 */
export default function AttentionExplorer() {
  const { lastMessage, traces, isConnected } = useWebSocket();
  const svgRef = useRef<SVGSVGElement>(null);
  const containerRef = useRef<HTMLDivElement>(null);
  const [selectedLayer, setSelectedLayer] = useState<string>("");
  const [attentionData, setAttentionData] = useState<AttentionData[]>([]);
  const [currentAttention, setCurrentAttention] = useState<AttentionData | null>(null);
  const [zoom, setZoom] = useState(1);
  const [hoveredCell, setHoveredCell] = useState<{row: number, col: number, value: number} | null>(null);
  const [colorScheme, setColorScheme] = useState<'blues' | 'viridis' | 'plasma'>('blues');
  const [showExplanation, setShowExplanation] = useState(false);
  const [prompt, setPrompt] = useState("def fibonacci(n):");
  const [isGenerating, setIsGenerating] = useState(false);

  // Layer list is pure derived data; computing it here keeps it in sync with attentionData
  // without a separate state variable and effect.
  const availableLayers = useMemo(() => sortedUniqueLayers(attentionData), [attentionData]);

  // Listen for demo completions from LocalControlPanel.
  useEffect(() => {
    const handleDemoCompleted = (event: CustomEvent) => {
      const data = event.detail;
      console.log('Demo completed event received:', data);
      if (!data || !Array.isArray(data.traces)) return;
      const maps = toAttentionData(data.traces);
      if (maps.length === 0) return;
      console.log('Processing demo attention traces:', maps.length);
      // New demo data replaces what is shown and jumps to its first layer.
      setAttentionData(maps);
      setSelectedLayer(sortedUniqueLayers(maps)[0] ?? "");
    };
    window.addEventListener('demo-completed', handleDemoCompleted as EventListener);
    return () => window.removeEventListener('demo-completed', handleDemoCompleted as EventListener);
  }, []);

  // Collect attention traces from the WebSocket (only when some have arrived).
  useEffect(() => {
    if (!traces || traces.length === 0) return;
    const maps = toAttentionData(traces);
    if (maps.length === 0) return;
    setAttentionData(maps);
    const layers = sortedUniqueLayers(maps);
    // Functional update avoids a stale closure over selectedLayer: keep the user's choice
    // if it still exists, otherwise fall back to the first layer.
    setSelectedLayer((prev) => (prev && layers.includes(prev) ? prev : layers[0] ?? ""));
  }, [traces]);

  // Show the most recent attention map for the selected layer.
  useEffect(() => {
    if (!selectedLayer || attentionData.length === 0) return;
    // Strict ">" keeps the earliest entry on timestamp ties, matching a stable descending sort.
    const latest = attentionData
      .filter((a) => a.layer === selectedLayer)
      .reduce<AttentionData | null>(
        (best, a) => (best === null || a.timestamp > best.timestamp ? a : best),
        null
      );
    if (!latest) {
      console.log('No data found for layer:', selectedLayer);
    }
    setCurrentAttention(latest);
  }, [selectedLayer, attentionData]);

  // D3 heatmap visualization.
  useEffect(() => {
    const svgElement = svgRef.current;
    if (!currentAttention || !svgElement || !containerRef.current) return;

    // Clear first so an empty matrix doesn't leave the previous heatmap on screen.
    const svg = d3.select(svgElement);
    svg.selectAll("*").remove();

    const { weights, tokens, layer } = currentAttention;
    const numRows = weights.length;
    const numCols = weights[0]?.length ?? 0;
    if (numRows === 0 || numCols === 0) return;

    const margin = { top: 120, right: 120, bottom: 100, left: 100 };
    const cellSize = 20; // Size of each heatmap cell in user units
    const width = numCols * cellSize;
    const height = numRows * cellSize;
    const outerWidth = width + margin.left + margin.right;
    const outerHeight = height + margin.top + margin.bottom;

    // Zoom scales the rendered size; the viewBox keeps the drawing coordinates fixed.
    svg
      .attr("width", outerWidth * zoom)
      .attr("height", outerHeight * zoom)
      .attr("viewBox", `0 0 ${outerWidth} ${outerHeight}`);

    const g = svg.append("g")
      .attr("transform", `translate(${margin.left},${margin.top})`);

    const xScale = d3.scaleBand()
      .domain(d3.range(numCols).map(String))
      .range([0, width])
      .padding(0.01);
    const yScale = d3.scaleBand()
      .domain(d3.range(numRows).map(String))
      .range([0, height])
      .padding(0.01);

    // Guard against a non-positive max so the color domain is never degenerate.
    // The cells and the legend both use this value so they always agree.
    const maxWeight = currentAttention.max_weight > 0 ? currentAttention.max_weight : 1.0;
    const interpolator = {
      blues: d3.interpolateBlues,
      viridis: d3.interpolateViridis,
      plasma: d3.interpolatePlasma,
    }[colorScheme];
    const colorScale = d3.scaleSequential(interpolator).domain([0, maxWeight]);

    const tooltip = d3.select("body").append("div")
      .attr("class", "attention-tooltip")
      .style("opacity", 0)
      .style("position", "absolute")
      .style("background", "rgba(0, 0, 0, 0.9)")
      .style("color", "white")
      .style("padding", "10px")
      .style("border-radius", "6px")
      .style("font-size", "12px")
      .style("pointer-events", "none")
      .style("z-index", "1000");

    // Highlight/reset the header labels for a cell. The selection is scoped to this chart
    // so it can't touch other instances on the page.
    const setLabelHighlight = (row: number, col: number, active: boolean) => {
      const fill = active ? "#3b82f6" : "#9ca3af";
      const fontWeight = active ? "bold" : "normal";
      g.selectAll(`.row-label-${row}`).style("fill", fill).style("font-weight", fontWeight);
      g.selectAll(`.col-label-${col}`).style("fill", fill).style("font-weight", fontWeight);
    };

    g.selectAll(".cell")
      .data(weights.flatMap((row, i) =>
        row.map((value, j) => ({ row: i, col: j, value }))
      ))
      .enter().append("rect")
      .attr("class", "cell")
      .attr("x", d => xScale(String(d.col))!)
      .attr("y", d => yScale(String(d.row))!)
      .attr("width", xScale.bandwidth())
      .attr("height", yScale.bandwidth())
      .attr("fill", d => colorScale(d.value))
      .attr("stroke", "#1f2937")
      .attr("stroke-width", 0.5)
      .style("cursor", "pointer")
      .on("mouseover", function(event, d) {
        tooltip.transition().duration(200).style("opacity", .95);
        // Tokens are model output and may contain markup, so escape before using .html().
        const tokenFrom = escapeHtml(tokens?.[d.row] || `Token ${d.row}`);
        const tokenTo = escapeHtml(tokens?.[d.col] || `Token ${d.col}`);
        tooltip.html(`
          <div style="font-weight: bold; margin-bottom: 5px;">Attention Weight</div>
          <div>From: ${tokenFrom}</div>
          <div>To: ${tokenTo}</div>
          <div style="margin-top: 5px; color: #60a5fa;">Weight: ${d.value.toFixed(4)}</div>
        `)
          .style("left", (event.pageX + 10) + "px")
          .style("top", (event.pageY - 28) + "px");
        setHoveredCell({ row: d.row, col: d.col, value: d.value });
        d3.select(this)
          .attr("stroke", "#3b82f6")
          .attr("stroke-width", 2);
        setLabelHighlight(d.row, d.col, true);
      })
      .on("mouseout", function(event, d) {
        tooltip.transition().duration(500).style("opacity", 0);
        setHoveredCell(null);
        d3.select(this)
          .attr("stroke", "#1f2937")
          .attr("stroke-width", 0.5);
        setLabelHighlight(d.row, d.col, false);
      });

    // Row labels (source tokens).
    g.selectAll(".row-label")
      .data(d3.range(numRows))
      .enter().append("text")
      .attr("class", d => `row-label row-label-${d}`)
      .attr("x", -10)
      .attr("y", d => yScale(String(d))! + yScale.bandwidth() / 2)
      .attr("text-anchor", "end")
      .attr("dominant-baseline", "middle")
      .style("font-size", "11px")
      .style("fill", "#9ca3af")
      .style("font-family", "monospace")
      .text(d => tokens?.[d] || `T${d}`);

    // Column labels (target tokens).
    g.selectAll(".col-label")
      .data(d3.range(numCols))
      .enter().append("text")
      .attr("class", d => `col-label col-label-${d}`)
      .attr("x", d => xScale(String(d))! + xScale.bandwidth() / 2)
      .attr("y", -10)
      .attr("text-anchor", "middle")
      .style("font-size", "11px")
      .style("fill", "#9ca3af")
      .style("font-family", "monospace")
      .text(d => tokens?.[d] || `T${d}`);

    // Title, positioned above the column labels to avoid overlap.
    svg.append("text")
      .attr("x", outerWidth / 2)
      .attr("y", 30)
      .attr("text-anchor", "middle")
      .style("font-size", "18px")
      .style("font-weight", "bold")
      .style("fill", "#fff")
      .text(`Attention Weights - ${layer}`);

    // Axis labels.
    svg.append("text")
      .attr("x", outerWidth / 2)
      .attr("y", outerHeight - 20)
      .attr("text-anchor", "middle")
      .style("font-size", "14px")
      .style("fill", "#9ca3af")
      .text("Target Tokens →");
    svg.append("text")
      .attr("transform", "rotate(-90)")
      .attr("x", -outerHeight / 2)
      .attr("y", 20)
      .attr("text-anchor", "middle")
      .style("font-size", "14px")
      .style("fill", "#9ca3af")
      .text("Source Tokens →");

    // Color legend in the top-right corner, using the same domain as the cells.
    const legendWidth = 150;
    const legendHeight = 15;
    const legendScale = d3.scaleLinear()
      .domain([0, maxWeight])
      .range([0, legendWidth]);
    const legendAxis = d3.axisBottom(legendScale)
      .ticks(4)
      .tickFormat(d => (d as number).toFixed(2));
    const legend = svg.append("g")
      .attr("transform", `translate(${width + margin.left - legendWidth}, ${60})`);

    const gradientId = `attention-gradient-${Date.now()}`;
    const gradient = svg.append("defs")
      .append("linearGradient")
      .attr("id", gradientId)
      .attr("x1", "0%")
      .attr("x2", "100%");
    for (let i = 0; i <= 20; i++) {
      const t = i / 20;
      gradient.append("stop")
        .attr("offset", `${t * 100}%`)
        .attr("stop-color", colorScale(t * maxWeight));
    }

    legend.append("rect")
      .attr("width", legendWidth)
      .attr("height", legendHeight)
      .style("fill", `url(#${gradientId})`)
      .style("stroke", "#4b5563")
      .style("stroke-width", 0.5);
    legend.append("g")
      .attr("transform", `translate(0, ${legendHeight})`)
      .call(legendAxis)
      .selectAll("text")
      .style("fill", "#9ca3af")
      .style("font-size", "10px");
    legend.append("text")
      .attr("x", legendWidth / 2)
      .attr("y", -5)
      .attr("text-anchor", "middle")
      .style("font-size", "11px")
      .style("fill", "#9ca3af")
      .text("Attention Weight");

    // The tooltip lives on <body>, outside React's tree, so remove it explicitly.
    return () => {
      tooltip.remove();
    };
  }, [currentAttention, zoom, colorScheme]);

  /** Downloads the currently displayed attention map as pretty-printed JSON. */
  const exportAttentionMap = () => {
    if (!currentAttention) return;
    const dataStr = JSON.stringify(currentAttention, null, 2);
    const fileName = `attention_${currentAttention.layer.replace(/\./g, '_')}_${Date.now()}.json`;
    // A Blob avoids building a potentially huge percent-encoded data URI.
    downloadBlob(new Blob([dataStr], { type: 'application/json;charset=utf-8' }), fileName);
  };

  /** Requests a generation from the backend and loads any attention traces it returns. */
  const handleGenerate = async () => {
    if (!prompt.trim() || isGenerating) return;
    setIsGenerating(true);
    try {
      const response = await fetch(`${getApiUrl()}/generate`, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          prompt,
          max_tokens: 100,
          temperature: 0.7,
          sampling_rate: 0.3 // Sample 30% of tokens for attention
        })
      });
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      const data = await response.json();
      console.log('Generation completed:', data);
      if (data && Array.isArray(data.traces)) {
        const maps = toAttentionData(data.traces);
        if (maps.length > 0) {
          const layers = sortedUniqueLayers(maps);
          console.log('Unique layers found from HTTP response:', layers);
          setAttentionData(maps);
          setSelectedLayer(layers[0] ?? "");
        }
      }
    } catch (error) {
      console.error('Generation failed:', error);
      alert(`Generation failed: ${error}`);
    } finally {
      setIsGenerating(false);
    }
  };

  /** Rasterizes the rendered heatmap SVG to a PNG and downloads it. */
  const exportAsImage = () => {
    const svgElement = svgRef.current;
    if (!svgElement || !currentAttention) return;
    const svgData = new XMLSerializer().serializeToString(svgElement);
    const fileName = `attention_${currentAttention.layer.replace(/\./g, '_')}_${Date.now()}.png`;
    const img = new Image();
    img.onload = () => {
      const canvas = document.createElement("canvas");
      canvas.width = img.width;
      canvas.height = img.height;
      const ctx = canvas.getContext("2d");
      if (!ctx) return;
      // The chart uses light text meant for a dark page; paint a matching background
      // so the PNG stays readable outside the app.
      ctx.fillStyle = "#1f2937";
      ctx.fillRect(0, 0, canvas.width, canvas.height);
      ctx.drawImage(img, 0, 0);
      canvas.toBlob((blob) => {
        if (blob) downloadBlob(blob, fileName);
      });
    };
    img.onerror = () => console.error('Failed to rasterize attention map SVG');
    // btoa() throws on non-Latin-1 characters (the "→" axis labels, unicode tokens),
    // so embed the SVG as a percent-encoded UTF-8 data URI instead of base64.
    img.src = 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svgData);
  };

  /** Adjusts zoom by `delta`, clamped to [MIN_ZOOM, MAX_ZOOM] and rounded to avoid float drift. */
  const stepZoom = (delta: number) => {
    setZoom((z) => Math.min(MAX_ZOOM, Math.max(MIN_ZOOM, Math.round((z + delta) * 10) / 10)));
  };

  /**
   * Describes a layer's depth from its position among the captured layers. This is only
   * relative to what was captured, not necessarily the model's true depth.
   */
  const describeLayerDepth = (layer: string): string => {
    const index = availableLayers.indexOf(layer);
    if (index < 0 || availableLayers.length < 2) {
      return 'Only one layer is captured, so its relative depth is unknown.';
    }
    const position = index / (availableLayers.length - 1); // 0 = first captured, 1 = last
    if (position < 1 / 3) return 'Early layers capture local patterns and syntax.';
    if (position < 2 / 3) return 'Middle layers build semantic relationships.';
    return 'Later layers form high-level representations.';
  };

  /** Builds the side-panel explanation text for the currently displayed map. */
  const generateExplanation = () => {
    if (!currentAttention) {
      return {
        title: "No Attention Data",
        description: "Run a model to see attention patterns between tokens.",
        details: []
      };
    }
    const { layer, weights, entropy, max_weight: maxWeight } = currentAttention;
    const numRows = weights.length;
    const numCols = weights[0]?.length ?? 0;
    const flatWeights = weights.flat();
    const strongAttentions = flatWeights.filter(w => w > maxWeight * 0.5).length;
    // Avoid "NaN%" for an empty matrix.
    const focusPercentage = flatWeights.length > 0
      ? ((strongAttentions / flatWeights.length) * 100).toFixed(1)
      : "0.0";
    // entropy may legitimately be 0 (perfectly focused), so test for presence, not truthiness.
    const entropyText = entropy !== undefined
      ? `Entropy: ${entropy.toFixed(2)} - ${entropy < 2 ? 'Focused attention' : entropy < 4 ? 'Distributed attention' : 'Scattered attention'}`
      : 'Pattern shows attention distribution across tokens.';
    // interpolateBlues runs pale -> deep blue; viridis/plasma run dark -> bright.
    const isBlues = colorScheme === 'blues';
    const strongCells = isBlues ? 'Deeper blue cells' : 'Brighter cells';
    const lowColor = isBlues ? 'pale' : 'dark';
    const highColor = isBlues ? 'deep blue' : 'bright';
    const strongSpots = isBlues ? 'Deep blue' : 'Bright';

    return {
      title: `Attention Pattern: ${layer}`,
      description: `Viewing ${numRows}×${numCols} attention matrix showing how tokens attend to each other.`,
      details: [
        {
          heading: "What is an Attention Heatmap?",
          content: `This heatmap shows attention weights between tokens. Each cell represents how much one token "pays attention" to another. ${strongCells} (higher values) mean stronger attention relationships.`
        },
        {
          heading: "Reading the Matrix",
          content: `Rows represent source tokens (what's attending), columns represent target tokens (what's being attended to). The diagonal often shows self-attention. Patterns reveal how the model processes relationships between words.`
        },
        {
          heading: "Current Pattern Analysis",
          content: `${focusPercentage}% of attention weights are strong (>50% of max). ${entropyText}`
        },
        {
          heading: "Color Intensity Meaning",
          content: `Colors range from ${lowColor} (0.00) to ${highColor} (${maxWeight.toFixed(2)}). ${strongSpots} spots indicate strong dependencies. The model uses these weights to combine information from different positions.`
        },
        {
          heading: "Common Patterns",
          content: `Look for: Vertical lines (tokens everyone attends to), horizontal lines (tokens attending broadly), diagonal patterns (local/sequential attention), and block patterns (grouped attention).`
        },
        {
          heading: "Layer Significance",
          content: `${describeLayerDepth(layer)} Each layer builds on previous ones.`
        }
      ]
    };
  };

  const explanation = generateExplanation();

  return (
    <div className="bg-gray-900 rounded-xl p-6">
      <div className="flex items-center justify-between mb-6">
        <div>
          <h2 className="text-2xl font-bold flex items-center gap-2">
            <Eye className="w-6 h-6 text-blue-400" />
            Attention Explorer
          </h2>
          <p className="text-gray-400 mt-1">
            Visualize attention patterns across transformer layers
          </p>
        </div>
        <div className="flex items-center gap-4">
          {/* Connection Status */}
          <div className={`flex items-center gap-2 px-3 py-1 rounded-full ${
            isConnected ? 'bg-green-900/30 text-green-400' : 'bg-red-900/30 text-red-400'
          }`}>
            <Activity className={`w-4 h-4 ${isConnected ? 'animate-pulse' : ''}`} />
            {isConnected ? 'Connected' : 'Disconnected'}
          </div>
        </div>
      </div>

      {/* Prompt Input and Generate Button */}
      <div className="flex gap-4 mb-4">
        <div className="flex-1">
          <textarea
            value={prompt}
            onChange={(e) => setPrompt(e.target.value)}
            placeholder="Enter a prompt to generate code (e.g., 'def fibonacci(n):')"
            className="w-full px-4 py-2 bg-gray-800 text-white rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none resize-none"
            rows={2}
            disabled={isGenerating}
          />
        </div>
        <button
          onClick={handleGenerate}
          disabled={!prompt.trim() || isGenerating || !isConnected}
          className="px-6 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2 h-fit"
        >
          {isGenerating ? (
            <>
              <RefreshCw className="w-5 h-5 animate-spin" />
              Generating...
            </>
          ) : (
            <>
              <Zap className="w-5 h-5" />
              Generate & Trace
            </>
          )}
        </button>
      </div>

      {/* Controls */}
      <div className="flex flex-wrap items-center gap-4 mb-4">
        {/* Layer Selector */}
        <div className="flex items-center gap-2">
          <Layers className="w-5 h-5 text-gray-400" />
          <select
            value={selectedLayer}
            onChange={(e) => setSelectedLayer(e.target.value)}
            className="bg-gray-800 text-white px-3 py-1.5 rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none min-w-[200px]"
            disabled={availableLayers.length === 0}
          >
            {availableLayers.length === 0 ? (
              <option value="">No layers available</option>
            ) : (
              availableLayers.map(layer => (
                <option key={layer} value={layer}>{layer}</option>
              ))
            )}
          </select>
        </div>

        {/* Color Scheme Selector */}
        <div className="flex items-center gap-2">
          <span className="text-gray-400 text-sm">Color:</span>
          <div className="flex gap-1">
            <button
              onClick={() => setColorScheme('blues')}
              className={`px-2 py-1 text-xs rounded ${colorScheme === 'blues' ? 'bg-blue-600' : 'bg-gray-700'}`}
            >
              Blues
            </button>
            <button
              onClick={() => setColorScheme('viridis')}
              className={`px-2 py-1 text-xs rounded ${colorScheme === 'viridis' ? 'bg-blue-600' : 'bg-gray-700'}`}
            >
              Viridis
            </button>
            <button
              onClick={() => setColorScheme('plasma')}
              className={`px-2 py-1 text-xs rounded ${colorScheme === 'plasma' ? 'bg-blue-600' : 'bg-gray-700'}`}
            >
              Plasma
            </button>
          </div>
        </div>

        {/* Zoom Controls */}
        <div className="flex items-center gap-2">
          <button
            onClick={() => stepZoom(-ZOOM_STEP)}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title="Zoom Out"
          >
            <ZoomOut className="w-4 h-4" />
          </button>
          <span className="text-sm text-gray-400 min-w-[50px] text-center">{(zoom * 100).toFixed(0)}%</span>
          <button
            onClick={() => stepZoom(ZOOM_STEP)}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title="Zoom In"
          >
            <ZoomIn className="w-4 h-4" />
          </button>
          <button
            onClick={() => setZoom(1)}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title="Reset Zoom"
          >
            <RefreshCw className="w-4 h-4" />
          </button>
        </div>

        {/* Export Buttons */}
        <div className="flex gap-2 ml-auto">
          <button
            onClick={exportAttentionMap}
            disabled={!currentAttention}
            className="px-3 py-1.5 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors disabled:opacity-50 flex items-center gap-2"
            title="Export as JSON"
          >
            <Download className="w-4 h-4" />
            JSON
          </button>
          <button
            onClick={exportAsImage}
            disabled={!currentAttention}
            className="px-3 py-1.5 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors disabled:opacity-50 flex items-center gap-2"
            title="Export as PNG"
          >
            <Download className="w-4 h-4" />
            PNG
          </button>
        </div>
      </div>

      {/* Status Bar */}
      <div className="flex items-center gap-4 mb-4 text-sm">
        <div className="text-gray-400">
          <span className="font-semibold">{attentionData.length}</span> attention maps captured
        </div>
        {hoveredCell && (
          <div className="text-gray-400">
            Cell [{hoveredCell.row}, {hoveredCell.col}]: <span className="text-blue-400 font-mono">{hoveredCell.value.toFixed(4)}</span>
          </div>
        )}
        {currentAttention?.entropy !== undefined && (
          <div className="text-gray-400">
            Entropy: <span className="text-green-400 font-mono">{currentAttention.entropy.toFixed(3)}</span>
          </div>
        )}
      </div>

      {/* Main Content Area with Side Panel */}
      <div className="flex gap-4">
        {/* Visualization Container */}
        <div className="flex-1 transition-all duration-500 ease-in-out">
          <div ref={containerRef} className="bg-gray-800 rounded-lg p-4 overflow-auto max-h-[600px] relative">
            {/* Help Toggle Button */}
            <button
              onClick={() => setShowExplanation(!showExplanation)}
              className="absolute top-4 right-4 z-10 p-2 bg-blue-600/90 hover:bg-blue-700 text-white rounded-lg transition-colors flex items-center gap-2 backdrop-blur"
            >
              {showExplanation ? <X className="w-5 h-5" /> : <HelpCircle className="w-5 h-5" />}
              <span className="text-sm font-medium">
                {showExplanation ? 'Hide Info' : 'What am I seeing?'}
              </span>
            </button>
            {currentAttention ? (
              <svg ref={svgRef}></svg>
            ) : (
              <div className="flex items-center justify-center h-96 text-gray-500">
                <div className="text-center">
                  <Eye className="w-12 h-12 mx-auto mb-4 opacity-50" />
                  <p className="text-lg mb-2">No Attention Data Available</p>
                  <p className="text-sm mb-4">Run a model to capture attention patterns</p>
                  <p className="text-xs text-gray-600">
                    Attention maps will appear here when traces are received
                  </p>
                </div>
              </div>
            )}
          </div>
        </div>

        {/* Explanation Side Panel */}
        <div className={`${showExplanation ? 'w-96' : 'w-0'} transition-all duration-500 ease-in-out overflow-hidden`}>
          <div className="w-96 h-[600px] bg-gray-900 rounded-lg border border-gray-700">
            {/* Panel Header */}
            <div className="bg-gray-800 px-4 py-3 border-b border-gray-700">
              <div className="flex items-center gap-2">
                <Info className="w-5 h-5 text-blue-400" />
                <h3 className="text-lg font-semibold text-white">Understanding Attention</h3>
              </div>
            </div>
            {/* Panel Content */}
            <div className="px-4 py-4 overflow-y-auto h-[calc(600px-60px)]">
              {/* Main Description */}
              <div className="mb-4 p-3 bg-blue-900/20 border border-blue-800 rounded-lg">
                <h4 className="text-sm font-semibold text-blue-400 mb-1">{explanation.title}</h4>
                <p className="text-xs text-gray-300">{explanation.description}</p>
              </div>
              {/* Explanation Sections */}
              <div className="space-y-3">
                {explanation.details.map((section, idx) => (
                  <div key={idx} className="bg-gray-800 rounded-lg p-3">
                    <h5 className="font-medium text-sm text-white mb-1 flex items-center gap-1">
                      <Zap className="w-3 h-3 text-yellow-400" />
                      {section.heading}
                    </h5>
                    <p className="text-xs text-gray-300 leading-relaxed">{section.content}</p>
                  </div>
                ))}
              </div>
              {/* Visual Guide */}
              <div className="mt-4 p-3 bg-purple-900/20 border border-purple-800 rounded-lg">
                <h4 className="font-medium text-sm text-purple-400 mb-2">Visual Guide</h4>
                <div className="space-y-2 text-xs">
                  <div className="flex items-start gap-2">
                    <span className="text-purple-300">•</span>
                    <span className="text-gray-300">Rows = Source tokens (what&apos;s attending)</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <span className="text-purple-300">•</span>
                    <span className="text-gray-300">Columns = Target tokens (what&apos;s being attended to)</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <span className="text-purple-300">•</span>
                    <span className="text-gray-300">Color intensity = Attention strength (0 to {currentAttention?.max_weight.toFixed(2) || '1.00'})</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <span className="text-purple-300">•</span>
                    <span className="text-gray-300">Hover over cells for detailed weights</span>
                  </div>
                </div>
              </div>
              {/* Current Metrics */}
              {currentAttention && (
                <div className="mt-4 p-3 bg-gray-800 rounded-lg">
                  <h4 className="font-medium text-sm text-gray-300 mb-2">Current Metrics</h4>
                  <div className="space-y-1 text-xs">
                    <div className="flex justify-between">
                      <span className="text-gray-400">Layer:</span>
                      <span className="text-white font-mono">{currentAttention.layer}</span>
                    </div>
                    <div className="flex justify-between">
                      <span className="text-gray-400">Matrix Size:</span>
                      <span className="text-white">{currentAttention.weights.length} × {currentAttention.weights[0]?.length || 0}</span>
                    </div>
                    <div className="flex justify-between">
                      <span className="text-gray-400">Max Weight:</span>
                      <span className="text-blue-400">{currentAttention.max_weight.toFixed(4)}</span>
                    </div>
                    {currentAttention.entropy !== undefined && (
                      <div className="flex justify-between">
                        <span className="text-gray-400">Entropy:</span>
                        <span className="text-green-400">{currentAttention.entropy.toFixed(3)}</span>
                      </div>
                    )}
                    <div className="flex justify-between">
                      <span className="text-gray-400">Timestamp:</span>
                      <span className="text-white">{new Date(currentAttention.timestamp).toLocaleTimeString()}</span>
                    </div>
                  </div>
                </div>
              )}
              {/* Tips */}
              <div className="mt-4 p-3 bg-gray-800 rounded-lg">
                <h4 className="font-medium text-sm text-gray-300 mb-2">💡 Tips</h4>
                <ul className="text-xs text-gray-400 space-y-1">
                  <li>• Use zoom controls to explore large matrices</li>
                  <li>• Switch color schemes for better contrast</li>
                  <li>• Compare patterns across different layers</li>
                  <li>• Look for diagonal patterns (self-attention)</li>
                </ul>
              </div>
            </div>
          </div>
        </div>
      </div>

      {/* Info Panel */}
      {currentAttention && (
        <div className="mt-4 p-4 bg-gray-800 rounded-lg">
          <h3 className="text-lg font-semibold mb-3">Attention Map Details</h3>
          <div className="grid grid-cols-2 md:grid-cols-4 gap-4 text-sm">
            <div>
              <span className="text-gray-400">Layer:</span>
              <div className="font-mono text-white mt-1">{currentAttention.layer}</div>
            </div>
            <div>
              <span className="text-gray-400">Shape:</span>
              <div className="font-mono text-white mt-1">
                {currentAttention.weights.length} × {currentAttention.weights[0]?.length || 0}
              </div>
            </div>
            <div>
              <span className="text-gray-400">Max Weight:</span>
              <div className="font-mono text-blue-400 mt-1">{currentAttention.max_weight.toFixed(4)}</div>
            </div>
            <div>
              <span className="text-gray-400">Timestamp:</span>
              <div className="font-mono text-white mt-1">
                {new Date(currentAttention.timestamp).toLocaleTimeString()}
              </div>
            </div>
          </div>
        </div>
      )}
    </div>
  );
}