/**
 * Token Flow Visualizer Component
 *
 * Visualizes how tokens flow through transformer layers,
 * showing attention paths and information propagation.
 *
 * Data sources:
 * - WebSocket messages from `getWsUrl()`: `attention` / `activation`
 *   messages are collected as traces, `generated_token` messages are
 *   appended as tokens.
 * - Window events `demo-prompt-selected`, `demo-starting` and
 *   `demo-completed` (from LocalControlPanel).
 *
 * @component
 */

"use client";

import { useState, useEffect, useRef, useMemo, useCallback } from "react";
import * as d3 from "d3";
import { getApiUrl, getWsUrl } from "@/lib/config";
import {
  GitBranch,
  Activity,
  Layers,
  Play,
  Pause,
  RotateCcw,
  ZoomIn,
  ZoomOut,
  Download,
  Info,
  HelpCircle,
  X,
  Zap,
  RefreshCw,
} from "lucide-react";

// Token data structure
interface Token {
  id: string;
  text: string;
  position: number;
  embedding?: number[];
}

// Layer data structure
interface LayerData {
  layerIndex: number;
  layerName: string;
  tokens: TokenState[];
  attention: number[][];
  timestamp: number;
}

// Token state at a specific layer
interface TokenState {
  tokenId: string;
  text: string;
  position: number;
  activation: number;
  attention_received: number;
  attention_given: number;
  importance: number;
}

// Flow connection between tokens across layers
interface FlowConnection {
  source: { layer: number; token: number };
  target: { layer: number; token: number };
  strength: number;
  type: "attention" | "residual" | "feedforward";
}

/**
 * Loosely-typed trace payload received over the WebSocket or in a
 * `demo-completed` event. Fields read here: `type`, `layer`, `weights`,
 * `tokens`, `timestamp`.
 */
type TraceMessage = Record<string, unknown>;

/** Attention weights at or below this value are not drawn. */
const ATTENTION_THRESHOLD = 0.1;
/** Fixed strength given to residual (same-position) connections. */
const RESIDUAL_STRENGTH = 0.5;
/**
 * Maximum number of traces kept in state. Every new trace triggers a full
 * reprocessing pass, so an unbounded list makes long streams quadratic and
 * grows memory without limit. Only the latest trace per layer is rendered.
 */
const MAX_TRACES = 5000;
const RECONNECT_DELAY_MS = 3000;
const ANIMATION_STEP_MS = 1000;
const MIN_ZOOM = 0.5;
const MAX_ZOOM = 2;
const ZOOM_STEP = 0.1;
// NOTE(review): the generation endpoint path and request body are assumed;
// confirm against the backend API served at getApiUrl().
const GENERATE_ENDPOINT = "/generate";
const NON_DIGITS = /[^0-9]/g;

/** Stroke colour per connection type (also used by the legend). */
const CONNECTION_COLORS: Record<FlowConnection["type"], string> = {
  attention: "#3b82f6",
  residual: "#10b981",
  feedforward: "#8b5cf6",
};

/** Numeric part of a layer name, used for ordering ("layer_10" -> 10); 0 when there is none. */
function layerNumber(layerName: string): number {
  return parseInt(layerName.replace(NON_DIGITS, ""), 10) || 0;
}

/** Appends a trace, dropping the oldest entries beyond MAX_TRACES. */
function appendTrace(traces: TraceMessage[], trace: TraceMessage): TraceMessage[] {
  const next = [...traces, trace];
  return next.length > MAX_TRACES ? next.slice(next.length - MAX_TRACES) : next;
}

/**
 * Groups attention traces by layer and builds one LayerData per layer from
 * the most recent trace of that layer.
 *
 * @param attentionTraces traces with `type === "attention"` and a `weights` field
 * @param streamedTokens token texts received so far; when empty, the latest
 *   trace's own `tokens` array is used instead
 * @returns layers ordered by the number embedded in their name, with
 *   `layerIndex` equal to the position in the returned array
 */
function buildLayerData(attentionTraces: TraceMessage[], streamedTokens: string[]): LayerData[] {
  const tracesByLayer = new Map<string, TraceMessage[]>();
  for (const trace of attentionTraces) {
    // `??` rather than `||`: layer 0 is a real layer, not "unknown".
    const key = String(trace.layer ?? "unknown");
    const bucket = tracesByLayer.get(key);
    if (bucket) {
      bucket.push(trace);
    } else {
      tracesByLayer.set(key, [trace]);
    }
  }

  // Sort before assigning indices so layerIndex matches the rendered column.
  const orderedLayers = Array.from(tracesByLayer.entries()).sort(
    ([nameA], [nameB]) => layerNumber(nameA) - layerNumber(nameB),
  );

  return orderedLayers.map(([layerName, layerTraces], layerIndex) => {
    const latestTrace = layerTraces[layerTraces.length - 1];
    const weights = (Array.isArray(latestTrace.weights) ? latestTrace.weights : []) as number[][];
    const tokenTexts: string[] =
      streamedTokens.length > 0
        ? streamedTokens
        : (Array.isArray(latestTrace.tokens) ? latestTrace.tokens : []).map(String);

    const tokens: TokenState[] = tokenTexts.map((text, tokenIdx) => {
      // Column sum: attention paid to this token by every row.
      const attentionReceived = weights.reduce(
        (sum: number, row: number[]) => sum + (row[tokenIdx] || 0),
        0,
      );
      // Row sum: attention this token pays out.
      const attentionGiven =
        weights[tokenIdx]?.reduce((sum: number, value: number) => sum + value, 0) || 0;

      return {
        tokenId: `token_${tokenIdx}`,
        text,
        position: tokenIdx,
        // TODO: derive from activation traces; this component does not render it yet.
        activation: 0,
        attention_received: attentionReceived,
        attention_given: attentionGiven,
        // Importance = mean of attention received and attention given.
        importance: (attentionReceived + attentionGiven) / 2,
      };
    });

    return {
      layerIndex,
      layerName,
      tokens,
      attention: weights,
      timestamp: typeof latestTrace.timestamp === "number" ? latestTrace.timestamp : Date.now(),
    };
  });
}

/**
 * Builds the connections drawn between each pair of adjacent layers.
 *
 * Attention connections come from layer i's attention matrix (row = source
 * token, column = target token in layer i + 1). Only weights above
 * ATTENTION_THRESHOLD are kept. Residual connections link the same token
 * position in adjacent layers.
 */
function buildFlowConnections(
  layerData: LayerData[],
  showAttention: boolean,
  showResidual: boolean,
): FlowConnection[] {
  const connections: FlowConnection[] = [];

  for (let i = 0; i < layerData.length - 1; i++) {
    const currentLayer = layerData[i];
    const nextLayer = layerData[i + 1];

    if (showAttention) {
      currentLayer.attention.forEach((row, srcToken) => {
        row.forEach((weight, tgtToken) => {
          if (weight > ATTENTION_THRESHOLD) {
            connections.push({
              source: { layer: i, token: srcToken },
              target: { layer: i + 1, token: tgtToken },
              strength: weight,
              type: "attention",
            });
          }
        });
      });
    }

    if (showResidual) {
      const sharedCount = Math.min(currentLayer.tokens.length, nextLayer.tokens.length);
      for (let idx = 0; idx < sharedCount; idx++) {
        connections.push({
          source: { layer: i, token: idx },
          target: { layer: i + 1, token: idx },
          strength: RESIDUAL_STRENGTH,
          type: "residual",
        });
      }
    }
  }

  return connections;
}

export default function TokenFlowVisualizer() {
  const [tokens, setTokens] = useState<Token[]>([]);
  const [layers, setLayers] = useState<LayerData[]>([]);
  const [selectedToken, setSelectedToken] = useState<number | null>(null);
  const [isPlaying, setIsPlaying] = useState(false);
  const [currentStep, setCurrentStep] = useState(0);
  const [zoom, setZoom] = useState(1);
  const [showResidual, setShowResidual] = useState(true);
  const [showAttention, setShowAttention] = useState(true);
  const [showExplanation, setShowExplanation] = useState(false);
  const [isConnected, setIsConnected] = useState(false);
  const [prompt, setPrompt] = useState("def fibonacci(n):\n '''Calculate fibonacci number'''");
  const [isGenerating, setIsGenerating] = useState(false);
  const [traces, setTraces] = useState<TraceMessage[]>([]);

  const svgRef = useRef<SVGSVGElement>(null);
  const wsRef = useRef<WebSocket | null>(null);

  // Derived from layers and the view toggles, so switching a connection type
  // on or off is reflected immediately.
  const flowConnections = useMemo(
    () => buildFlowConnections(layers, showAttention, showResidual),
    [layers, showAttention, showResidual],
  );

  /** Drops all collected data, e.g. before a new demo or generation starts. */
  const clearVisualization = useCallback(() => {
    setTokens([]);
    setLayers([]);
    setTraces([]);
    setSelectedToken(null);
  }, []);

  // Connect to WebSocket for real-time updates; reconnect after a fixed delay
  // whenever the connection closes or cannot be opened.
  useEffect(() => {
    let mounted = true;
    let reconnectTimer: ReturnType<typeof setTimeout> | undefined;

    const handleMessage = (event: MessageEvent) => {
      if (!mounted) return;

      let payload: unknown;
      try {
        payload = JSON.parse(event.data as string);
      } catch {
        // Skip non-JSON messages
        return;
      }
      if (typeof payload !== "object" || payload === null) return;
      const message = payload as TraceMessage;

      if (message.type === "attention" || message.type === "activation") {
        setTraces((prev) => appendTrace(prev, message));
      } else if (message.type === "generated_token" && typeof message.token === "string") {
        const text = message.token;
        setTokens((prev) => [
          ...prev,
          { id: `token_${prev.length}`, text, position: prev.length },
        ]);
      }
    };

    const scheduleReconnect = () => {
      reconnectTimer = setTimeout(() => {
        if (mounted) connect();
      }, RECONNECT_DELAY_MS);
    };

    const connect = () => {
      if (!mounted) return;
      try {
        const ws = new WebSocket(getWsUrl());
        ws.onopen = () => {
          if (!mounted) return;
          console.log("TokenFlow: WebSocket connected");
          setIsConnected(true);
        };
        ws.onmessage = handleMessage;
        ws.onerror = () => {
          if (mounted) setIsConnected(false);
        };
        ws.onclose = () => {
          if (!mounted) return;
          console.log("TokenFlow: WebSocket disconnected, will reconnect...");
          setIsConnected(false);
          scheduleReconnect();
        };
        wsRef.current = ws;
      } catch {
        console.log("WebSocket connection attempt failed, will retry...");
        if (mounted) {
          setIsConnected(false);
          scheduleReconnect();
        }
      }
    };

    connect();

    return () => {
      // Flip the flag first so the onclose triggered by close() does not reconnect.
      mounted = false;
      if (reconnectTimer !== undefined) clearTimeout(reconnectTimer);
      wsRef.current?.close();
      wsRef.current = null;
    };
  }, []);

  // Listen for demo events from LocalControlPanel
  useEffect(() => {
    const handleDemoPromptSelected = (event: Event) => {
      const detail = (event as CustomEvent<{ prompt?: string; demoId?: string }>).detail;
      console.log("TokenFlow: Demo prompt selected -", detail?.demoId);
      if (detail?.prompt) {
        setPrompt(detail.prompt);
      }
    };

    const handleDemoStarting = (event: Event) => {
      const detail = (event as CustomEvent<{ demoId?: string }>).detail;
      console.log("TokenFlow: Demo starting, clearing data -", detail?.demoId);
      clearVisualization();
    };

    const handleDemoCompleted = (event: Event) => {
      const detail = (event as CustomEvent<{ traces?: unknown }>).detail;
      console.log("TokenFlow: Demo completed", detail);
      if (detail && Array.isArray(detail.traces)) {
        setTraces((detail.traces as TraceMessage[]).slice(-MAX_TRACES));
      }
    };

    window.addEventListener("demo-prompt-selected", handleDemoPromptSelected);
    window.addEventListener("demo-starting", handleDemoStarting);
    window.addEventListener("demo-completed", handleDemoCompleted);

    return () => {
      window.removeEventListener("demo-prompt-selected", handleDemoPromptSelected);
      window.removeEventListener("demo-starting", handleDemoStarting);
      window.removeEventListener("demo-completed", handleDemoCompleted);
    };
  }, [clearVisualization]);

  // Process traces to extract token flow data
  useEffect(() => {
    const attentionTraces = traces.filter((t) => t.type === "attention" && t.weights);
    if (attentionTraces.length === 0 && tokens.length === 0) return;

    // Prefer streamed tokens; otherwise seed the token list from the traces.
    // Setting tokens re-runs this effect with the streamed-token path.
    if (tokens.length === 0) {
      const traceWithTokens = attentionTraces.find((t) => Array.isArray(t.tokens));
      if (traceWithTokens && Array.isArray(traceWithTokens.tokens)) {
        setTokens(
          traceWithTokens.tokens.map((text: unknown, idx: number) => ({
            id: `token_${idx}`,
            text: String(text),
            position: idx,
          })),
        );
      }
    }

    setLayers(buildLayerData(attentionTraces, tokens.map((t) => t.text)));
  }, [traces, tokens]);

  // Advance the displayed layer step once per interval while playing.
  useEffect(() => {
    if (!isPlaying) return;
    const stepCount = Math.max(layers.length, 1);
    const timer = setInterval(() => {
      setCurrentStep((prev) => (prev + 1) % stepCount);
    }, ANIMATION_STEP_MS);
    return () => clearInterval(timer);
  }, [isPlaying, layers.length]);

  // D3 Visualization: redraws the whole SVG whenever its inputs change.
  useEffect(() => {
    const svgElement = svgRef.current;
    if (!svgElement || layers.length === 0) return;

    const margin = { top: 60, right: 200, bottom: 60, left: 100 };
    const width = 1300;
    const height = 600;
    const tokenHeight = 40;
    const tokenWidth = 80;
    const tokenGap = 10;
    const layerWidth = (width - margin.left - margin.right) / layers.length;

    const svg = d3.select(svgElement);
    // Clear previous visualization
    svg.selectAll("*").remove();
    svg.attr("width", width).attr("height", height).attr("viewBox", `0 0 ${width} ${height}`);

    const root = svg
      .append("g")
      .attr("transform", `translate(${margin.left},${margin.top}) scale(${zoom})`);

    // Centre of a token box, in root coordinates.
    const tokenCenter = (layerIdx: number, tokenIdx: number) => ({
      x: layerIdx * layerWidth + layerWidth / 2,
      y: tokenIdx * (tokenHeight + tokenGap) + tokenHeight / 2,
    });

    // Connections go in their own group appended first, so token boxes paint on top.
    const connectionPaths = root
      .append("g")
      .attr("class", "flow-connections")
      .selectAll<SVGPathElement, FlowConnection>("path")
      .data(flowConnections)
      .enter()
      .append("path")
      .attr("class", (d) => `flow-connection flow-${d.type}`)
      .attr("d", (d) => {
        const source = tokenCenter(d.source.layer, d.source.token);
        const target = tokenCenter(d.target.layer, d.target.token);
        const midX = (source.x + target.x) / 2;
        // Quadratic curve to the midpoint, mirrored (T) to the target.
        return `M ${source.x} ${source.y} Q ${midX} ${source.y} ${midX} ${(source.y + target.y) / 2} T ${target.x} ${target.y}`;
      })
      .attr("stroke", (d) => CONNECTION_COLORS[d.type])
      .attr("stroke-width", (d) => Math.max(0.5, d.strength * 3))
      .attr("stroke-opacity", (d) => d.strength * 0.6)
      .attr("fill", "none");

    if (isPlaying) {
      connectionPaths
        .attr("stroke-dasharray", "5,5")
        .append("animate")
        .attr("attributeName", "stroke-dashoffset")
        .attr("from", "10")
        .attr("to", "0")
        .attr("dur", "1s")
        .attr("repeatCount", "indefinite");
    }

    // One column group per layer
    const layerGroups = root
      .selectAll<SVGGElement, LayerData>(".layer-group")
      .data(layers)
      .enter()
      .append("g")
      .attr("class", "layer-group")
      .attr("transform", (_d, i) => `translate(${i * layerWidth}, 0)`);

    layerGroups
      .append("text")
      .attr("x", layerWidth / 2)
      .attr("y", -20)
      .attr("text-anchor", "middle")
      .attr("fill", "#9ca3af")
      .attr("font-size", "12px")
      .attr("font-weight", "bold")
      .text((d) => d.layerName);

    // Token boxes, stacked vertically within each layer column
    const tokenGroups = layerGroups
      .selectAll<SVGGElement, TokenState>(".token-group")
      .data((d) => d.tokens)
      .enter()
      .append("g")
      .attr("class", "token-group")
      .attr(
        "transform",
        (_d, i) => `translate(${layerWidth / 2 - tokenWidth / 2}, ${i * (tokenHeight + tokenGap)})`,
      );

    tokenGroups
      .append("rect")
      .attr("width", tokenWidth)
      .attr("height", tokenHeight)
      .attr("rx", 6)
      .attr("fill", (d) => d3.interpolateYlOrRd(d.importance || 0))
      .attr("stroke", (d) => (selectedToken === d.position ? "#3b82f6" : "#4b5563"))
      .attr("stroke-width", (d) => (selectedToken === d.position ? 2 : 1))
      .style("cursor", "pointer")
      .on("click", (_event, d) => {
        // Functional update: never toggles against a stale selection.
        setSelectedToken((prev) => (prev === d.position ? null : d.position));
      });

    tokenGroups
      .append("text")
      .attr("x", tokenWidth / 2)
      .attr("y", tokenHeight / 2)
      .attr("text-anchor", "middle")
      .attr("dominant-baseline", "middle")
      .attr("fill", (d) => (d.importance > 0.5 ? "#fff" : "#1f2937"))
      .attr("font-size", "11px")
      .attr("font-family", "monospace")
      .attr("pointer-events", "none")
      .text((d) => d.text.substring(0, 8));

    // Importance indicator; importance can exceed 1 (column sums), so clamp the radius.
    tokenGroups
      .append("circle")
      .attr("cx", tokenWidth - 10)
      .attr("cy", 10)
      .attr("r", (d) => Math.max(2, Math.min(d.importance, 1) * 6))
      .attr("fill", "#fbbf24")
      .attr("opacity", 0.8);

    svg
      .append("text")
      .attr("x", width / 2)
      .attr("y", 30)
      .attr("text-anchor", "middle")
      .attr("font-size", "16px")
      .attr("font-weight", "bold")
      .attr("fill", "#fff")
      .text("Token Flow Through Transformer Layers");

    // Legend in the right margin, clear of the layer columns.
    const legend = svg.append("g").attr("transform", `translate(${width - 180}, 100)`);
    const legendItems = [
      { color: CONNECTION_COLORS.attention, label: "Attention Flow", type: "attention" },
      { color: CONNECTION_COLORS.residual, label: "Residual Connection", type: "residual" },
      { color: "#fbbf24", label: "Token Importance", type: "importance" },
    ];

    legendItems.forEach((item, i) => {
      const legendItem = legend.append("g").attr("transform", `translate(0, ${i * 25})`);

      if (item.type === "importance") {
        legendItem.append("circle").attr("cx", 10).attr("cy", 10).attr("r", 6).attr("fill", item.color);
      } else {
        legendItem
          .append("line")
          .attr("x1", 0)
          .attr("y1", 10)
          .attr("x2", 20)
          .attr("y2", 10)
          .attr("stroke", item.color)
          .attr("stroke-width", 2);
      }

      legendItem
        .append("text")
        .attr("x", 30)
        .attr("y", 10)
        .attr("dominant-baseline", "middle")
        .attr("fill", "#9ca3af")
        .attr("font-size", "11px")
        .text(item.label);
    });
  }, [layers, flowConnections, selectedToken, zoom, isPlaying]);

  // Animation control
  const toggleAnimation = () => setIsPlaying((prev) => !prev);

  const reset = () => {
    setCurrentStep(0);
    setSelectedToken(null);
    setIsPlaying(false);
  };

  const zoomIn = () => setZoom((z) => Math.min(MAX_ZOOM, Math.round((z + ZOOM_STEP) * 10) / 10));
  const zoomOut = () => setZoom((z) => Math.max(MIN_ZOOM, Math.round((z - ZOOM_STEP) * 10) / 10));

  /**
   * Sends the prompt for generation and clears the current data first.
   * Assumes results stream back over the WebSocket handled above.
   */
  const handleGenerate = async () => {
    if (isGenerating || !prompt.trim()) return;
    setIsGenerating(true);
    clearVisualization();
    try {
      const response = await fetch(`${getApiUrl()}${GENERATE_ENDPOINT}`, {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ prompt }),
      });
      if (!response.ok) {
        throw new Error(`Generation request failed with status ${response.status}`);
      }
    } catch (error: unknown) {
      console.error("TokenFlow: generation request failed", error);
    } finally {
      setIsGenerating(false);
    }
  };

  // Generate contextual explanation for current visualization
  const generateExplanation = () => {
    if (layers.length === 0) {
      return {
        title: "No Token Flow Data",
        description: "Run a model to see how tokens flow through transformer layers.",
        details: [] as { heading: string; content: string }[],
      };
    }

    const numLayers = layers.length;
    const numTokens = tokens.length;
    // Density = drawn attention links / all token pairs between adjacent layers.
    // Every stored connection already passed the threshold, so comparing against
    // flowConnections.length would always report 100%.
    const attentionLinks = flowConnections.filter((c) => c.type === "attention").length;
    const possibleAttentionLinks = layers
      .slice(0, -1)
      .reduce((sum, layer) => sum + layer.attention.reduce((rowSum, row) => rowSum + row.length, 0), 0);
    const connectionDensity =
      possibleAttentionLinks > 0 ? ((attentionLinks / possibleAttentionLinks) * 100).toFixed(1) : "0";

    return {
      title: `Token Flow Analysis: ${numTokens} tokens, ${numLayers} layers`,
      description: `Visualizing information flow through the transformer's attention mechanism.`,
      details: [
        {
          heading: "What is Token Flow?",
          content: `This visualization shows how tokens are processed through transformer layers in real-time. Each column represents a layer, each box is a token at that layer. The visualization builds progressively as tokens are generated.`,
        },
        {
          heading: "Reading the Flow",
          content: `Tokens flow from left (layer 0, input) to right (final layer, output). Each column is a transformer layer processing all tokens. Color intensity (yellow→orange→red) shows token importance/activation strength.`,
        },
        {
          heading: "Real-time Generation",
          content: `Columns are added as traces arrive for each layer, and rows are added as new tokens are generated. You're watching the model build its understanding token by token, layer by layer.`,
        },
        {
          heading: "Current Network Stats",
          content: `${numTokens} tokens × ${numLayers} layers = ${numTokens * numLayers} nodes. ${connectionDensity}% of possible attention connections are active (strength > 0.1). Blue lines show attention flow between tokens.`,
        },
        {
          heading: "Connection Types",
          content: `Blue lines: Attention connections showing which tokens attend to which. Green lines: Residual connections (when enabled). Line thickness indicates connection strength.`,
        },
        {
          heading: "Color Meaning",
          content: `Token color represents importance (the average of attention received and given): Light yellow (low), Orange (medium), Red (high). This shows which tokens are most important at each processing stage.`,
        },
      ],
    };
  };

  const explanation = generateExplanation();

  // Export functionality
  const exportVisualization = () => {
    if (!svgRef.current) return;

    const svgData = new XMLSerializer().serializeToString(svgRef.current);
    const svgBlob = new Blob([svgData], { type: "image/svg+xml;charset=utf-8" });
    const svgUrl = URL.createObjectURL(svgBlob);

    const link = document.createElement("a");
    link.href = svgUrl;
    link.download = `token_flow_${Date.now()}.svg`;
    link.click();
    // Release the blob once the click has been dispatched; otherwise each export leaks it.
    setTimeout(() => URL.revokeObjectURL(svgUrl), 0);
  };

  const activeConnectionCount = flowConnections.filter((c) => c.strength > ATTENTION_THRESHOLD).length;
  const selectedTokenData = selectedToken !== null ? tokens[selectedToken] : undefined;

  const iconButtonClass =
    "p-2 rounded-lg bg-gray-800 text-gray-300 hover:bg-gray-700 disabled:opacity-50 disabled:cursor-not-allowed";
  const toggleButtonClass = (active: boolean) =>
    `p-2 rounded-lg ${active ? "bg-blue-600 text-white" : "bg-gray-800 text-gray-400"} hover:bg-gray-700`;

  return (
    <div className="w-full p-6 bg-gray-900 rounded-xl text-white">
      <div className="flex items-center justify-between mb-6">
        <div>
          <h2 className="text-2xl font-bold flex items-center gap-2">
            <GitBranch className="w-6 h-6 text-blue-400" />
            Token Flow Visualizer
          </h2>
          <p className="text-gray-400 mt-1">Track how information flows through transformer layers</p>
        </div>
        <div className="flex items-center gap-2 text-sm">
          <span className={`w-2 h-2 rounded-full ${isConnected ? "bg-green-500" : "bg-red-500"}`} />
          <span className="text-gray-400">{isConnected ? 'Connected' : 'Disconnected'}</span>
        </div>
      </div>

      {/* Generation Controls */}
      <div className="flex gap-2 mb-4">
        <input
          type="text"
          value={prompt}
          onChange={(e) => setPrompt(e.target.value)}
          className="flex-1 px-4 py-2 bg-gray-800 text-white rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none font-mono text-sm"
          placeholder="Enter prompt to analyze token flow..."
        />
        <button
          type="button"
          onClick={() => void handleGenerate()}
          disabled={isGenerating || !prompt.trim()}
          title="Generate"
          aria-label="Generate"
          className="px-4 py-2 rounded-lg bg-blue-600 text-white hover:bg-blue-700 disabled:opacity-50 disabled:cursor-not-allowed"
        >
          {isGenerating ? <RefreshCw className="w-4 h-4 animate-spin" /> : <Zap className="w-4 h-4" />}
        </button>
      </div>

      {/* Controls */}
      <div className="flex flex-wrap items-center gap-4 mb-4">
        {/* Playback Controls */}
        <div className="flex items-center gap-2">
          <button
            type="button"
            onClick={toggleAnimation}
            title={isPlaying ? "Pause" : "Play"}
            aria-label={isPlaying ? "Pause" : "Play"}
            className={iconButtonClass}
          >
            {isPlaying ? <Pause className="w-4 h-4" /> : <Play className="w-4 h-4" />}
          </button>
          <button type="button" onClick={reset} title="Reset" aria-label="Reset" className={iconButtonClass}>
            <RotateCcw className="w-4 h-4" />
          </button>
          <span className="text-sm text-gray-400">
            Layer {currentStep + 1} / {layers.length || 1}
          </span>
        </div>

        {/* View Options */}
        <div className="flex items-center gap-2">
          <button
            type="button"
            onClick={() => setShowAttention((v) => !v)}
            aria-pressed={showAttention}
            title="Toggle attention connections"
            aria-label="Toggle attention connections"
            className={toggleButtonClass(showAttention)}
          >
            <Activity className="w-4 h-4" />
          </button>
          <button
            type="button"
            onClick={() => setShowResidual((v) => !v)}
            aria-pressed={showResidual}
            title="Toggle residual connections"
            aria-label="Toggle residual connections"
            className={toggleButtonClass(showResidual)}
          >
            <Layers className="w-4 h-4" />
          </button>
        </div>

        {/* Zoom Controls */}
        <div className="flex items-center gap-2">
          <button
            type="button"
            onClick={zoomOut}
            disabled={zoom <= MIN_ZOOM}
            title="Zoom out"
            aria-label="Zoom out"
            className={iconButtonClass}
          >
            <ZoomOut className="w-4 h-4" />
          </button>
          <span className="text-sm text-gray-400 w-12 text-center">{(zoom * 100).toFixed(0)}%</span>
          <button
            type="button"
            onClick={zoomIn}
            disabled={zoom >= MAX_ZOOM}
            title="Zoom in"
            aria-label="Zoom in"
            className={iconButtonClass}
          >
            <ZoomIn className="w-4 h-4" />
          </button>
        </div>

        {/* Export */}
        <button
          type="button"
          onClick={exportVisualization}
          disabled={layers.length === 0}
          title="Export as SVG"
          aria-label="Export as SVG"
          className={iconButtonClass}
        >
          <Download className="w-4 h-4" />
        </button>
      </div>

      {/* Main Content Area with Side Panel */}
      <div className="flex gap-4">
        {/* Visualization Container */}
        <div className="relative flex-1 bg-gray-800 rounded-lg p-4 overflow-auto">
          {/* Help Toggle Button */}
          <button
            type="button"
            onClick={() => setShowExplanation((v) => !v)}
            title="Understanding Token Flow"
            aria-label="Toggle explanation panel"
            className="absolute top-4 right-4 p-2 rounded-lg bg-gray-700 text-gray-300 hover:bg-gray-600"
          >
            <HelpCircle className="w-4 h-4" />
          </button>
          {layers.length > 0 ? (
            <svg ref={svgRef} className="w-full h-auto" />
          ) : (
            <div className="flex flex-col items-center justify-center h-96 text-gray-500">
              <Layers className="w-12 h-12 mb-4" />
              <p className="text-lg font-medium">No Token Flow Data</p>
              <p className="text-sm mt-1">Run a model to visualize token flow through layers</p>
            </div>
          )}
        </div>

        {/* Explanation Side Panel */}
        {showExplanation && (
          <aside className="w-96 bg-gray-800 rounded-lg flex flex-col max-h-[700px]">
            {/* Panel Header */}
            <div className="flex items-center justify-between p-4 border-b border-gray-700">
              <h3 className="font-semibold flex items-center gap-2">
                <Info className="w-4 h-4 text-blue-400" />
                Understanding Token Flow
              </h3>
              <button
                type="button"
                onClick={() => setShowExplanation(false)}
                aria-label="Close explanation panel"
                className="p-1 rounded hover:bg-gray-700"
              >
                <X className="w-4 h-4" />
              </button>
            </div>

            {/* Panel Content */}
            <div className="p-4 space-y-4 overflow-y-auto">
              {/* Main Description */}
              <div>
                <h4 className="font-semibold text-blue-400">{explanation.title}</h4>
                <p className="text-sm text-gray-300 mt-1">{explanation.description}</p>
              </div>

              {/* Explanation Sections */}
              {explanation.details.map((section) => (
                <div key={section.heading}>
                  <h5 className="text-sm font-semibold text-white mb-1">{section.heading}</h5>
                  <p className="text-sm text-gray-300">{section.content}</p>
                </div>
              ))}

              {/* Visual Guide */}
              <div>
                <h5 className="text-sm font-semibold text-white mb-1">Visual Elements</h5>
                <ul className="text-sm text-gray-300 space-y-1">
                  <li><span className="font-medium text-white">Nodes</span> = Tokens at each layer</li>
                  <li><span className="font-medium text-white">Lines</span> = Attention connections</li>
                  <li><span className="font-medium text-white">Thickness</span> = Connection strength</li>
                  <li><span className="font-medium text-white">Color intensity</span> = Token importance</li>
                </ul>
              </div>

              {/* Current Metrics */}
              {layers.length > 0 && (
                <div>
                  <h5 className="text-sm font-semibold text-white mb-1">Current Metrics</h5>
                  <div className="grid grid-cols-2 gap-2 text-sm text-gray-300">
                    <div>Tokens: {tokens.length}</div>
                    <div>Layers: {layers.length}</div>
                    <div>Total Nodes: {tokens.length * layers.length}</div>
                    <div>Active Connections: {activeConnectionCount}</div>
                  </div>
                </div>
              )}

              {/* Tips */}
              <div>
                <h5 className="text-sm font-semibold text-white mb-1">💡 Tips</h5>
                <ul className="text-sm text-gray-300 space-y-1">
                  <li>• Click tokens to trace their path</li>
                  <li>• Use animation to see flow evolution</li>
                  <li>• Zoom for different perspectives</li>
                  <li>• Toggle connection types with controls</li>
                </ul>
              </div>
            </div>
          </aside>
        )}
      </div>

      {/* Info Panel */}
      {selectedToken !== null && selectedTokenData && (
        <div className="mt-4 p-4 bg-gray-800 rounded-lg">
          <h3 className="font-semibold mb-2">Selected Token: &quot;{selectedTokenData.text}&quot;</h3>
          <div className="grid grid-cols-2 md:grid-cols-4 gap-4 text-sm">
            <div>
              <span className="text-gray-400">Position:</span>
              <div className="font-mono">{selectedToken}</div>
            </div>
            <div>
              <span className="text-gray-400">Layers Processed:</span>
              <div className="font-mono">{layers.length}</div>
            </div>
            <div>
              <span className="text-gray-400">Connections:</span>
              <div className="font-mono">
                {flowConnections.filter(
                  (c) => c.source.token === selectedToken || c.target.token === selectedToken,
                ).length}
              </div>
            </div>
            <div>
              <span className="text-gray-400">Max Importance:</span>
              <div className="font-mono">
                {Math.max(...layers.map((l) => l.tokens[selectedToken]?.importance || 0)).toFixed(3)}
              </div>
            </div>
          </div>
        </div>
      )}

      {/* Instructions */}
      {layers.length === 0 && (
        <div className="mt-4 p-4 bg-gray-800 rounded-lg">
          <h3 className="font-semibold mb-2">How to Use</h3>
          <ol className="list-decimal list-inside text-sm text-gray-300 space-y-1">
            <li>Run a model to generate attention traces</li>
            <li>Token flow will automatically visualize</li>
            <li>Click tokens to see their flow details</li>
            <li>Use controls to animate the flow</li>
          </ol>
        </div>
      )}
    </div>
  );
}