Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| "use client"; | |
| import { useEffect, useRef, useState } from "react"; | |
| import * as d3 from "d3"; | |
| import { useWebSocket } from "@/lib/websocket-client"; | |
| import { getApiUrl } from "@/lib/config"; | |
| import { Download, Layers, RefreshCw, ZoomIn, ZoomOut, Eye, Activity, HelpCircle, X, Info, Zap } from "lucide-react"; | |
| import { TraceData } from "@/lib/types"; | |
/**
 * One captured attention map for a single layer, as rendered by the heatmap.
 */
interface AttentionData {
  /** Layer identifier, e.g. "layer.0". */
  layer: string;
  /** Capture time in epoch milliseconds (rendered with `new Date(timestamp)`). */
  timestamp: number;
  /**
   * Attention matrix indexed as `weights[row][col]`; rows are plotted as
   * source tokens and columns as target tokens.
   */
  weights: number[][];
  /** Reported maximum weight; used as the upper bound of the colour scale. */
  max_weight: number;
  /** Optional entropy value supplied by the trace producer. */
  entropy?: number;
  /** Optional token strings used to label rows and columns. */
  tokens?: string[];
}
/** A single titled section of the explanation side panel. */
interface ExplanationSection {
  heading: string;
  content: string;
}

/** Content rendered by the "What am I seeing?" side panel. */
interface Explanation {
  title: string;
  description: string;
  details: ExplanationSection[];
}

/**
 * Extracts the first integer embedded in a layer name (3 for "layer.3"),
 * or NaN when the name contains no digits.
 */
function layerIndex(layer: string): number {
  const match = /\d+/.exec(layer);
  return match ? Number(match[0]) : Number.NaN;
}

/**
 * Sort comparator for layer names: numbered layers ascend by index, and names
 * without an index sort after them alphabetically. A bare `parseInt`
 * difference returns NaN for names such as "unknown", which makes
 * `Array.prototype.sort` inconsistent; this comparator always returns a number.
 */
function compareLayers(a: string, b: string): number {
  const indexA = layerIndex(a);
  const indexB = layerIndex(b);
  const hasA = !Number.isNaN(indexA);
  const hasB = !Number.isNaN(indexB);
  if (hasA && hasB && indexA !== indexB) return indexA - indexB;
  if (hasA !== hasB) return hasA ? -1 : 1;
  return a.localeCompare(b);
}

/** Distinct layer names present in `data`, in display order. */
function uniqueSortedLayers(data: readonly AttentionData[]): string[] {
  return Array.from(new Set(data.map(d => d.layer))).sort(compareLayers);
}

/**
 * Converts raw traces (WebSocket stream, demo event, or HTTP response) into
 * attention entries. Non-attention traces and traces without weights are
 * dropped; non-array input yields an empty list.
 */
function toAttentionData(traces: readonly TraceData[] | null | undefined): AttentionData[] {
  const list: readonly TraceData[] = Array.isArray(traces) ? traces : [];
  return list
    .filter(t => t.type === 'attention' && t.weights)
    .map(t => ({
      layer: t.layer || 'unknown',
      weights: t.weights || [],
      max_weight: t.max_weight || 1,
      entropy: t.entropy,
      timestamp: t.timestamp || Date.now(),
      tokens: t.tokens,
    }));
}

/** Escapes text for safe interpolation into an HTML string (e.g. `selection.html`). */
function escapeHtml(text: string): string {
  return text
    .replace(/&/g, '&amp;')
    .replace(/</g, '&lt;')
    .replace(/>/g, '&gt;')
    .replace(/"/g, '&quot;')
    .replace(/'/g, '&#39;');
}

/**
 * Describes a layer's relative depth from its position among the captured
 * layers. The model's total depth is not known here, so this is only an
 * approximation based on what was captured.
 */
function describeLayerDepth(layer: string, capturedLayers: readonly string[]): string {
  const position = capturedLayers.indexOf(layer);
  if (position < 0 || capturedLayers.length < 2) {
    return 'Relative depth cannot be determined from the captured layers.';
  }
  const fraction = position / (capturedLayers.length - 1);
  if (fraction < 1 / 3) return 'Early layers capture local patterns and syntax.';
  if (fraction < 2 / 3) return 'Middle layers build semantic relationships.';
  return 'Later layers form high-level representations.';
}

/**
 * Interactive D3 heatmap of transformer attention weights.
 *
 * Attention maps arrive from three sources, all normalised via `toAttentionData`:
 *  - the shared WebSocket trace stream (`useWebSocket().traces`);
 *  - `demo-completed` window events whose `detail.traces` holds traces
 *    (dispatched, per the original comment, by LocalControlPanel);
 *  - the `/generate` HTTP endpoint triggered from the prompt box.
 * The most recent map for the selected layer is rendered.
 */
export default function AttentionExplorer() {
  const { traces, isConnected } = useWebSocket();
  const svgRef = useRef<SVGSVGElement>(null);
  const containerRef = useRef<HTMLDivElement>(null);
  const [selectedLayer, setSelectedLayer] = useState<string>("");
  const [attentionData, setAttentionData] = useState<AttentionData[]>([]);
  const [currentAttention, setCurrentAttention] = useState<AttentionData | null>(null);
  const [availableLayers, setAvailableLayers] = useState<string[]>([]);
  const [zoom, setZoom] = useState(1);
  const [hoveredCell, setHoveredCell] = useState<{ row: number; col: number; value: number } | null>(null);
  const [colorScheme, setColorScheme] = useState<'blues' | 'viridis' | 'plasma'>('blues');
  const [showExplanation, setShowExplanation] = useState(false);
  const [prompt, setPrompt] = useState("def fibonacci(n):");
  const [isGenerating, setIsGenerating] = useState(false);

  // Listen for demo completions from LocalControlPanel.
  useEffect(() => {
    const handleDemoCompleted = (event: Event) => {
      const data = (event as CustomEvent).detail;
      console.log('Demo completed event received:', data);
      const demoAttention = toAttentionData(data?.traces);
      if (demoAttention.length === 0) return;
      console.log('Processing demo attention traces:', demoAttention.length);
      const layers = uniqueSortedLayers(demoAttention);
      setAttentionData(demoAttention);
      setAvailableLayers(layers);
      setSelectedLayer(layers[0] ?? '');
    };
    window.addEventListener('demo-completed', handleDemoCompleted);
    return () => window.removeEventListener('demo-completed', handleDemoCompleted);
  }, []);

  // Collect attention traces from the WebSocket stream (ignored while it has none).
  useEffect(() => {
    const incoming = toAttentionData(traces);
    if (incoming.length === 0) return;
    const layers = uniqueSortedLayers(incoming);
    setAttentionData(incoming);
    setAvailableLayers(layers);
    // Functional update avoids reading a stale `selectedLayer`. The user's
    // choice is kept when the layer still exists; otherwise fall back to the first.
    setSelectedLayer(prev => (prev && layers.includes(prev) ? prev : layers[0] ?? ''));
  }, [traces]);

  // Show the most recent attention map for the selected layer.
  useEffect(() => {
    if (!selectedLayer || attentionData.length === 0) return;
    // On timestamp ties, the earliest entry wins (same as a stable descending sort).
    const latest = attentionData
      .filter(a => a.layer === selectedLayer)
      .reduce<AttentionData | null>(
        (best, a) => (best === null || a.timestamp > best.timestamp ? a : best),
        null,
      );
    setCurrentAttention(latest);
  }, [selectedLayer, attentionData]);

  // D3 heatmap visualization.
  useEffect(() => {
    if (!currentAttention || !svgRef.current || !containerRef.current) return;

    const margin = { top: 120, right: 120, bottom: 100, left: 100 };
    const cellSize = 20; // Size of each cell in the heatmap
    const { weights, tokens, layer } = currentAttention;
    const numRows = weights.length;
    const numCols = weights[0]?.length ?? 0;

    // Clear first so an empty matrix never leaves the previous heatmap on screen.
    d3.select(svgRef.current).selectAll("*").remove();
    if (numRows === 0 || numCols === 0) return;

    const width = numCols * cellSize;
    const height = numRows * cellSize;
    const totalWidth = width + margin.left + margin.right;
    const totalHeight = height + margin.top + margin.bottom;

    const svg = d3.select(svgRef.current)
      .attr("width", totalWidth * zoom)
      .attr("height", totalHeight * zoom)
      .attr("viewBox", `0 0 ${totalWidth} ${totalHeight}`);

    const g = svg.append("g")
      .attr("transform", `translate(${margin.left},${margin.top})`);

    const xScale = d3.scaleBand()
      .domain(d3.range(numCols).map(String))
      .range([0, width])
      .padding(0.01);
    const yScale = d3.scaleBand()
      .domain(d3.range(numRows).map(String))
      .range([0, height])
      .padding(0.01);

    // Guarded upper bound shared by the cells AND the legend so they always agree.
    const maxWeight = currentAttention.max_weight > 0 ? currentAttention.max_weight : 1.0;
    const interpolator =
      colorScheme === 'viridis' ? d3.interpolateViridis
      : colorScheme === 'plasma' ? d3.interpolatePlasma
      : d3.interpolateBlues;
    const colorScale = d3.scaleSequential(interpolator).domain([0, maxWeight]);

    const axisLabel = (index: number) => tokens?.[index] || `T${index}`;

    const tooltip = d3.select("body").append("div")
      .attr("class", "attention-tooltip")
      .style("opacity", 0)
      .style("position", "absolute")
      .style("background", "rgba(0, 0, 0, 0.9)")
      .style("color", "white")
      .style("padding", "10px")
      .style("border-radius", "6px")
      .style("font-size", "12px")
      .style("pointer-events", "none")
      .style("z-index", "1000");

    const highlightHeaders = (row: number, col: number, active: boolean) => {
      const fill = active ? "#3b82f6" : "#9ca3af";
      const weight = active ? "bold" : "normal";
      // Scoped to this chart's group rather than the whole document.
      g.selectAll(`.row-label-${row}`).style("fill", fill).style("font-weight", weight);
      g.selectAll(`.col-label-${col}`).style("fill", fill).style("font-weight", weight);
    };

    // Draw cells
    g.selectAll(".cell")
      .data(weights.flatMap((row, i) => row.map((value, j) => ({ row: i, col: j, value }))))
      .enter().append("rect")
      .attr("class", "cell")
      .attr("x", d => xScale(String(d.col))!)
      .attr("y", d => yScale(String(d.row))!)
      .attr("width", xScale.bandwidth())
      .attr("height", yScale.bandwidth())
      .attr("fill", d => colorScale(d.value))
      .attr("stroke", "#1f2937")
      .attr("stroke-width", 0.5)
      .style("cursor", "pointer")
      .on("mouseover", function(event, d) {
        tooltip.transition().duration(200).style("opacity", .95);
        // Token text comes from model output; escape it before using innerHTML.
        const tokenFrom = escapeHtml(tokens?.[d.row] || `Token ${d.row}`);
        const tokenTo = escapeHtml(tokens?.[d.col] || `Token ${d.col}`);
        tooltip.html(`
          <div style="font-weight: bold; margin-bottom: 5px;">Attention Weight</div>
          <div>From: ${tokenFrom}</div>
          <div>To: ${tokenTo}</div>
          <div style="margin-top: 5px; color: #60a5fa;">Weight: ${d.value.toFixed(4)}</div>
        `)
          .style("left", (event.pageX + 10) + "px")
          .style("top", (event.pageY - 28) + "px");
        setHoveredCell({ row: d.row, col: d.col, value: d.value });
        d3.select(this)
          .attr("stroke", "#3b82f6")
          .attr("stroke-width", 2);
        highlightHeaders(d.row, d.col, true);
      })
      .on("mouseout", function(event, d) {
        tooltip.transition().duration(500).style("opacity", 0);
        setHoveredCell(null);
        d3.select(this)
          .attr("stroke", "#1f2937")
          .attr("stroke-width", 0.5);
        highlightHeaders(d.row, d.col, false);
      });

    // Row labels (source tokens)
    g.selectAll(".row-label")
      .data(d3.range(numRows))
      .enter().append("text")
      .attr("class", d => `row-label row-label-${d}`)
      .attr("x", -10)
      .attr("y", d => yScale(String(d))! + yScale.bandwidth() / 2)
      .attr("text-anchor", "end")
      .attr("dominant-baseline", "middle")
      .style("font-size", "11px")
      .style("fill", "#9ca3af")
      .style("font-family", "monospace")
      .text(d => axisLabel(d));

    // Column labels (target tokens)
    g.selectAll(".col-label")
      .data(d3.range(numCols))
      .enter().append("text")
      .attr("class", d => `col-label col-label-${d}`)
      .attr("x", d => xScale(String(d))! + xScale.bandwidth() / 2)
      .attr("y", -10)
      .attr("text-anchor", "middle")
      .style("font-size", "11px")
      .style("fill", "#9ca3af")
      .style("font-family", "monospace")
      .text(d => axisLabel(d));

    // Title - positioned high to avoid overlapping the column labels
    svg.append("text")
      .attr("x", totalWidth / 2)
      .attr("y", 30)
      .attr("text-anchor", "middle")
      .style("font-size", "18px")
      .style("font-weight", "bold")
      .style("fill", "#fff")
      .text(`Attention Weights - ${layer}`);

    // Axis labels
    svg.append("text")
      .attr("x", totalWidth / 2)
      .attr("y", totalHeight - 20)
      .attr("text-anchor", "middle")
      .style("font-size", "14px")
      .style("fill", "#9ca3af")
      .text("Target Tokens →");
    svg.append("text")
      .attr("transform", "rotate(-90)")
      .attr("x", -totalHeight / 2)
      .attr("y", 20)
      .attr("text-anchor", "middle")
      .style("font-size", "14px")
      .style("fill", "#9ca3af")
      .text("Source Tokens →");

    // Colour legend in the top right corner, using the same domain as the cells.
    const legendWidth = 150;
    const legendHeight = 15;
    const legendScale = d3.scaleLinear()
      .domain([0, maxWeight])
      .range([0, legendWidth]);
    const legendAxis = d3.axisBottom(legendScale)
      .ticks(4)
      .tickFormat(d => Number(d).toFixed(2));
    const legend = svg.append("g")
      .attr("transform", `translate(${width + margin.left - legendWidth}, ${60})`);

    const gradientId = `attention-gradient-${Date.now()}`;
    const gradient = svg.append("defs")
      .append("linearGradient")
      .attr("id", gradientId)
      .attr("x1", "0%")
      .attr("x2", "100%");
    for (let i = 0; i <= 20; i++) {
      const t = i / 20;
      gradient.append("stop")
        .attr("offset", `${t * 100}%`)
        .attr("stop-color", colorScale(t * maxWeight));
    }

    legend.append("rect")
      .attr("width", legendWidth)
      .attr("height", legendHeight)
      .style("fill", `url(#${gradientId})`)
      .style("stroke", "#4b5563")
      .style("stroke-width", 0.5);
    legend.append("g")
      .attr("transform", `translate(0, ${legendHeight})`)
      .call(legendAxis)
      .selectAll("text")
      .style("fill", "#9ca3af")
      .style("font-size", "10px");
    legend.append("text")
      .attr("x", legendWidth / 2)
      .attr("y", -5)
      .attr("text-anchor", "middle")
      .style("font-size", "11px")
      .style("fill", "#9ca3af")
      .text("Attention Weight");

    return () => {
      tooltip.remove();
      // The cells are rebuilt on the next run, so a pending mouseout for the
      // hovered cell would never fire; clear the hover state here instead.
      setHoveredCell(null);
    };
  }, [currentAttention, zoom, colorScheme]);

  // Keep the layer dropdown in sync with whatever data is loaded.
  useEffect(() => {
    if (attentionData.length === 0) return;
    const layers = uniqueSortedLayers(attentionData);
    console.log('Updating available layers from attentionData:', layers);
    setAvailableLayers(layers);
  }, [attentionData]);

  /** Downloads the current attention map as pretty-printed JSON. */
  const exportAttentionMap = () => {
    if (!currentAttention) return;
    const dataStr = JSON.stringify(currentAttention, null, 2);
    const dataUri = 'data:application/json;charset=utf-8,' + encodeURIComponent(dataStr);
    const exportFileDefaultName = `attention_${currentAttention.layer.replace(/\./g, '_')}_${Date.now()}.json`;
    const linkElement = document.createElement('a');
    linkElement.setAttribute('href', dataUri);
    linkElement.setAttribute('download', exportFileDefaultName);
    linkElement.click();
  };

  /**
   * POSTs the prompt to `/generate` (sampling 30% of tokens for attention)
   * and loads any attention traces from the response. Failures are logged
   * and surfaced with an alert.
   */
  const handleGenerate = async () => {
    if (!prompt || isGenerating) return;
    setIsGenerating(true);
    try {
      const response = await fetch(`${getApiUrl()}/generate`, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          prompt,
          max_tokens: 100,
          temperature: 0.7,
          sampling_rate: 0.3, // Sample 30% of tokens for attention
        }),
      });
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      const data = await response.json();
      console.log('Generation completed:', data);
      const generated = toAttentionData(data?.traces);
      if (generated.length > 0) {
        const layers = uniqueSortedLayers(generated);
        console.log('Unique layers found from HTTP response:', layers);
        setAttentionData(generated);
        setAvailableLayers(layers);
        setSelectedLayer(layers[0] ?? '');
      }
    } catch (error) {
      console.error('Generation failed:', error);
      alert(`Generation failed: ${error}`);
    } finally {
      setIsGenerating(false);
    }
  };

  /** Rasterises the rendered SVG heatmap to a PNG download. */
  const exportAsImage = () => {
    const svgElement = svgRef.current;
    if (!svgElement || !currentAttention) return;
    const fileName = `attention_${currentAttention.layer.replace(/\./g, '_')}_${Date.now()}.png`;
    const svgData = new XMLSerializer().serializeToString(svgElement);
    const img = new Image();
    img.onload = () => {
      const canvas = document.createElement("canvas");
      canvas.width = img.width;
      canvas.height = img.height;
      const ctx = canvas.getContext("2d");
      if (!ctx) return;
      // Paint the panel background: the chart text is light and would be invisible on transparent/white.
      ctx.fillStyle = "#1f2937";
      ctx.fillRect(0, 0, canvas.width, canvas.height);
      ctx.drawImage(img, 0, 0);
      canvas.toBlob((blob) => {
        if (!blob) return;
        const url = URL.createObjectURL(blob);
        const link = document.createElement('a');
        link.download = fileName;
        link.href = url;
        link.click();
        // Release the blob once the download has been triggered.
        setTimeout(() => URL.revokeObjectURL(url), 0);
      });
    };
    // encodeURIComponent (unlike btoa) copes with non-Latin-1 characters in token labels.
    img.src = 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svgData);
  };

  /** Builds the contextual explanation shown in the side panel for the current map. */
  const generateExplanation = (): Explanation => {
    if (!currentAttention) {
      return {
        title: "No Attention Data",
        description: "Run a model to see attention patterns between tokens.",
        details: [],
      };
    }
    const { weights, entropy, layer } = currentAttention;
    const maxWeight = currentAttention.max_weight > 0 ? currentAttention.max_weight : 1.0;
    const numRows = weights.length;
    const numCols = weights[0]?.length ?? 0;
    const allWeights = weights.flat();
    const strongAttentions = allWeights.filter(w => w > maxWeight * 0.5).length;
    // Guard against an empty matrix, which would otherwise print "NaN%".
    const focusPercentage = allWeights.length > 0
      ? ((strongAttentions / allWeights.length) * 100).toFixed(1)
      : '0.0';
    // `!= null` so that an entropy of exactly 0 is still reported.
    const entropyText = entropy != null
      ? `Entropy: ${entropy.toFixed(2)} - ${entropy < 2 ? 'Focused attention' : entropy < 4 ? 'Distributed attention' : 'Scattered attention'}`
      : 'Pattern shows attention distribution across tokens.';

    return {
      title: `Attention Pattern: ${layer}`,
      description: `Viewing ${numRows}×${numCols} attention matrix showing how tokens attend to each other.`,
      details: [
        {
          heading: "What is an Attention Heatmap?",
          content: `This heatmap shows attention weights between tokens. Each cell represents how much one token "pays attention" to another. Brighter cells (higher values) mean stronger attention relationships.`,
        },
        {
          heading: "Reading the Matrix",
          content: `Rows represent source tokens (what's attending), columns represent target tokens (what's being attended to). The diagonal often shows self-attention. Patterns reveal how the model processes relationships between words.`,
        },
        {
          heading: "Current Pattern Analysis",
          content: `${focusPercentage}% of attention weights are strong (>50% of max). ${entropyText}`,
        },
        {
          heading: "Color Intensity Meaning",
          content: `Colors range from dark (0.00) to bright (${maxWeight.toFixed(2)}). Bright spots indicate strong dependencies. The model uses these weights to combine information from different positions.`,
        },
        {
          heading: "Common Patterns",
          content: `Look for: Vertical lines (tokens everyone attends to), horizontal lines (tokens attending broadly), diagonal patterns (local/sequential attention), and block patterns (grouped attention).`,
        },
        {
          heading: "Layer Significance",
          content: `${describeLayerDepth(layer, availableLayers)} Each layer builds on previous ones.`,
        },
      ],
    };
  };

  const explanation = generateExplanation();

  return (
    <div className="bg-gray-900 rounded-xl p-6">
      <div className="flex items-center justify-between mb-6">
        <div>
          <h2 className="text-2xl font-bold flex items-center gap-2">
            <Eye className="w-6 h-6 text-blue-400" />
            Attention Explorer
          </h2>
          <p className="text-gray-400 mt-1">
            Visualize attention patterns across transformer layers
          </p>
        </div>
        <div className="flex items-center gap-4">
          {/* Connection Status */}
          <div className={`flex items-center gap-2 px-3 py-1 rounded-full ${
            isConnected ? 'bg-green-900/30 text-green-400' : 'bg-red-900/30 text-red-400'
          }`}>
            <Activity className={`w-4 h-4 ${isConnected ? 'animate-pulse' : ''}`} />
            {isConnected ? 'Connected' : 'Disconnected'}
          </div>
        </div>
      </div>

      {/* Prompt Input and Generate Button */}
      <div className="flex gap-4 mb-4">
        <div className="flex-1">
          <textarea
            value={prompt}
            onChange={(e) => setPrompt(e.target.value)}
            placeholder="Enter a prompt to generate code (e.g., 'def fibonacci(n):')"
            className="w-full px-4 py-2 bg-gray-800 text-white rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none resize-none"
            rows={2}
            disabled={isGenerating}
          />
        </div>
        <button
          onClick={handleGenerate}
          disabled={!prompt || isGenerating || !isConnected}
          className="px-6 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors disabled:opacity-50 disabled:cursor-not-allowed flex items-center gap-2 h-fit"
        >
          {isGenerating ? (
            <>
              <RefreshCw className="w-5 h-5 animate-spin" />
              Generating...
            </>
          ) : (
            <>
              <Zap className="w-5 h-5" />
              Generate & Trace
            </>
          )}
        </button>
      </div>

      {/* Controls */}
      <div className="flex flex-wrap items-center gap-4 mb-4">
        {/* Layer Selector */}
        <div className="flex items-center gap-2">
          <Layers className="w-5 h-5 text-gray-400" />
          <select
            value={selectedLayer}
            onChange={(e) => setSelectedLayer(e.target.value)}
            className="bg-gray-800 text-white px-3 py-1.5 rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none min-w-[200px]"
            disabled={availableLayers.length === 0}
          >
            {availableLayers.length === 0 ? (
              <option value="">No layers available</option>
            ) : (
              availableLayers.map(layer => (
                <option key={layer} value={layer}>{layer}</option>
              ))
            )}
          </select>
        </div>

        {/* Color Scheme Selector */}
        <div className="flex items-center gap-2">
          <span className="text-gray-400 text-sm">Color:</span>
          <div className="flex gap-1">
            <button
              onClick={() => setColorScheme('blues')}
              className={`px-2 py-1 text-xs rounded ${colorScheme === 'blues' ? 'bg-blue-600' : 'bg-gray-700'}`}
            >
              Blues
            </button>
            <button
              onClick={() => setColorScheme('viridis')}
              className={`px-2 py-1 text-xs rounded ${colorScheme === 'viridis' ? 'bg-blue-600' : 'bg-gray-700'}`}
            >
              Viridis
            </button>
            <button
              onClick={() => setColorScheme('plasma')}
              className={`px-2 py-1 text-xs rounded ${colorScheme === 'plasma' ? 'bg-blue-600' : 'bg-gray-700'}`}
            >
              Plasma
            </button>
          </div>
        </div>

        {/* Zoom Controls (rounded to one decimal to avoid floating-point drift) */}
        <div className="flex items-center gap-2">
          <button
            onClick={() => setZoom(z => Math.max(0.5, Math.round((z - 0.1) * 10) / 10))}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title="Zoom Out"
          >
            <ZoomOut className="w-4 h-4" />
          </button>
          <span className="text-sm text-gray-400 min-w-[50px] text-center">{(zoom * 100).toFixed(0)}%</span>
          <button
            onClick={() => setZoom(z => Math.min(2, Math.round((z + 0.1) * 10) / 10))}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title="Zoom In"
          >
            <ZoomIn className="w-4 h-4" />
          </button>
          <button
            onClick={() => setZoom(1)}
            className="p-2 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors"
            title="Reset Zoom"
          >
            <RefreshCw className="w-4 h-4" />
          </button>
        </div>

        {/* Export Buttons */}
        <div className="flex gap-2 ml-auto">
          <button
            onClick={exportAttentionMap}
            disabled={!currentAttention}
            className="px-3 py-1.5 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors disabled:opacity-50 flex items-center gap-2"
            title="Export as JSON"
          >
            <Download className="w-4 h-4" />
            JSON
          </button>
          <button
            onClick={exportAsImage}
            disabled={!currentAttention}
            className="px-3 py-1.5 bg-gray-800 text-white rounded-lg hover:bg-gray-700 transition-colors disabled:opacity-50 flex items-center gap-2"
            title="Export as PNG"
          >
            <Download className="w-4 h-4" />
            PNG
          </button>
        </div>
      </div>

      {/* Status Bar */}
      <div className="flex items-center gap-4 mb-4 text-sm">
        <div className="text-gray-400">
          <span className="font-semibold">{attentionData.length}</span> attention maps captured
        </div>
        {hoveredCell && (
          <div className="text-gray-400">
            Cell [{hoveredCell.row}, {hoveredCell.col}]: <span className="text-blue-400 font-mono">{hoveredCell.value.toFixed(4)}</span>
          </div>
        )}
        {currentAttention?.entropy != null && (
          <div className="text-gray-400">
            Entropy: <span className="text-green-400 font-mono">{currentAttention.entropy.toFixed(3)}</span>
          </div>
        )}
      </div>

      {/* Main Content Area with Side Panel */}
      <div className="flex gap-4">
        {/* Visualization Container */}
        <div className="flex-1 transition-all duration-500 ease-in-out">
          <div ref={containerRef} className="bg-gray-800 rounded-lg p-4 overflow-auto max-h-[600px] relative">
            {/* Help Toggle Button */}
            <button
              onClick={() => setShowExplanation(shown => !shown)}
              className="absolute top-4 right-4 z-10 p-2 bg-blue-600/90 hover:bg-blue-700 text-white rounded-lg transition-colors flex items-center gap-2 backdrop-blur"
            >
              {showExplanation ? <X className="w-5 h-5" /> : <HelpCircle className="w-5 h-5" />}
              <span className="text-sm font-medium">
                {showExplanation ? 'Hide Info' : 'What am I seeing?'}
              </span>
            </button>
            {currentAttention ? (
              <svg ref={svgRef}></svg>
            ) : (
              <div className="flex items-center justify-center h-96 text-gray-500">
                <div className="text-center">
                  <Eye className="w-12 h-12 mx-auto mb-4 opacity-50" />
                  <p className="text-lg mb-2">No Attention Data Available</p>
                  <p className="text-sm mb-4">Run a model to capture attention patterns</p>
                  <p className="text-xs text-gray-600">
                    Attention maps will appear here when traces are received
                  </p>
                </div>
              </div>
            )}
          </div>
        </div>

        {/* Explanation Side Panel */}
        <div className={`${showExplanation ? 'w-96' : 'w-0'} transition-all duration-500 ease-in-out overflow-hidden`}>
          <div className="w-96 h-[600px] bg-gray-900 rounded-lg border border-gray-700">
            {/* Panel Header */}
            <div className="bg-gray-800 px-4 py-3 border-b border-gray-700">
              <div className="flex items-center gap-2">
                <Info className="w-5 h-5 text-blue-400" />
                <h3 className="text-lg font-semibold text-white">Understanding Attention</h3>
              </div>
            </div>
            {/* Panel Content */}
            <div className="px-4 py-4 overflow-y-auto h-[calc(600px-60px)]">
              {/* Main Description */}
              <div className="mb-4 p-3 bg-blue-900/20 border border-blue-800 rounded-lg">
                <h4 className="text-sm font-semibold text-blue-400 mb-1">{explanation.title}</h4>
                <p className="text-xs text-gray-300">{explanation.description}</p>
              </div>
              {/* Explanation Sections */}
              <div className="space-y-3">
                {explanation.details.map((section, idx) => (
                  <div key={idx} className="bg-gray-800 rounded-lg p-3">
                    <h5 className="font-medium text-sm text-white mb-1 flex items-center gap-1">
                      <Zap className="w-3 h-3 text-yellow-400" />
                      {section.heading}
                    </h5>
                    <p className="text-xs text-gray-300 leading-relaxed">{section.content}</p>
                  </div>
                ))}
              </div>
              {/* Visual Guide */}
              <div className="mt-4 p-3 bg-purple-900/20 border border-purple-800 rounded-lg">
                <h4 className="font-medium text-sm text-purple-400 mb-2">Visual Guide</h4>
                <div className="space-y-2 text-xs">
                  <div className="flex items-start gap-2">
                    <span className="text-purple-300">•</span>
                    <span className="text-gray-300">Rows = Source tokens (what's attending)</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <span className="text-purple-300">•</span>
                    <span className="text-gray-300">Columns = Target tokens (what's being attended to)</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <span className="text-purple-300">•</span>
                    <span className="text-gray-300">Brightness = Attention strength (0 to {currentAttention?.max_weight.toFixed(2) || '1.00'})</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <span className="text-purple-300">•</span>
                    <span className="text-gray-300">Hover over cells for detailed weights</span>
                  </div>
                </div>
              </div>
              {/* Current Metrics */}
              {currentAttention && (
                <div className="mt-4 p-3 bg-gray-800 rounded-lg">
                  <h4 className="font-medium text-sm text-gray-300 mb-2">Current Metrics</h4>
                  <div className="space-y-1 text-xs">
                    <div className="flex justify-between">
                      <span className="text-gray-400">Layer:</span>
                      <span className="text-white font-mono">{currentAttention.layer}</span>
                    </div>
                    <div className="flex justify-between">
                      <span className="text-gray-400">Matrix Size:</span>
                      <span className="text-white">{currentAttention.weights.length} × {currentAttention.weights[0]?.length || 0}</span>
                    </div>
                    <div className="flex justify-between">
                      <span className="text-gray-400">Max Weight:</span>
                      <span className="text-blue-400">{currentAttention.max_weight.toFixed(4)}</span>
                    </div>
                    {currentAttention.entropy != null && (
                      <div className="flex justify-between">
                        <span className="text-gray-400">Entropy:</span>
                        <span className="text-green-400">{currentAttention.entropy.toFixed(3)}</span>
                      </div>
                    )}
                    <div className="flex justify-between">
                      <span className="text-gray-400">Timestamp:</span>
                      <span className="text-white">{new Date(currentAttention.timestamp).toLocaleTimeString()}</span>
                    </div>
                  </div>
                </div>
              )}
              {/* Tips */}
              <div className="mt-4 p-3 bg-gray-800 rounded-lg">
                <h4 className="font-medium text-sm text-gray-300 mb-2">💡 Tips</h4>
                <ul className="text-xs text-gray-400 space-y-1">
                  <li>• Use zoom controls to explore large matrices</li>
                  <li>• Switch color schemes for better contrast</li>
                  <li>• Compare patterns across different layers</li>
                  <li>• Look for diagonal patterns (self-attention)</li>
                </ul>
              </div>
            </div>
          </div>
        </div>
      </div>

      {/* Info Panel */}
      {currentAttention && (
        <div className="mt-4 p-4 bg-gray-800 rounded-lg">
          <h3 className="text-lg font-semibold mb-3">Attention Map Details</h3>
          <div className="grid grid-cols-2 md:grid-cols-4 gap-4 text-sm">
            <div>
              <span className="text-gray-400">Layer:</span>
              <div className="font-mono text-white mt-1">{currentAttention.layer}</div>
            </div>
            <div>
              <span className="text-gray-400">Shape:</span>
              <div className="font-mono text-white mt-1">
                {currentAttention.weights.length} × {currentAttention.weights[0]?.length || 0}
              </div>
            </div>
            <div>
              <span className="text-gray-400">Max Weight:</span>
              <div className="font-mono text-blue-400 mt-1">{currentAttention.max_weight.toFixed(4)}</div>
            </div>
            <div>
              <span className="text-gray-400">Timestamp:</span>
              <div className="font-mono text-white mt-1">
                {new Date(currentAttention.timestamp).toLocaleTimeString()}
              </div>
            </div>
          </div>
        </div>
      )}
    </div>
  );
}