Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
/**
 * Decision Path 3D Visualization
 *
 * Shows the exact path through the neural network when generating tokens,
 * highlighting critical layers, attention patterns, and decision factors.
 * This is the core "Glass Box" visualization for the PhD thesis.
 *
 * @component
 */
| "use client"; | |
| import { useRef, useState, useEffect, useMemo } from "react"; | |
| import { Canvas, useFrame, useThree } from "@react-three/fiber"; | |
| import { | |
| OrbitControls, | |
| Text, | |
| Box, | |
| Sphere, | |
| Line, | |
| Billboard | |
| } from "@react-three/drei"; | |
| import * as THREE from "three"; | |
| import { | |
| Brain, | |
| Zap, | |
| Eye, | |
| GitBranch, | |
| Activity, | |
| Sparkles, | |
| AlertCircle, | |
| TrendingUp, | |
| Layers | |
| } from "lucide-react"; | |
/**
 * Per-layer activation summary carried inside a decision path.
 *
 * Only `confidence` and `top_attention_heads` are read by the rendering code
 * in this file; the remaining fields are carried along unchanged.
 */
interface LayerActivation {
  /** Index of the layer this record describes. */
  layer_index: number;
  /** Two-dimensional attention weight matrix (axis meaning not visible here — TODO confirm). */
  attention_weights: number[][];
  /** Norm of the layer's hidden state (not read in this file). */
  hidden_state_norm: number;
  /** Feed-forward activation magnitude (not read in this file). */
  ffn_activation: number;
  /** Head indices; each entry is drawn as one small marker on the layer. */
  top_attention_heads: number[];
  /** Drives glow size/opacity and is displayed as a percentage; presumably in [0, 1] — TODO confirm. */
  confidence: number;
}
/**
 * One token-generation decision as delivered by the decision path service.
 *
 * The nested member types are intentionally kept inline so that indexed
 * access such as `DecisionPath['attention_flow']` keeps working.
 */
interface DecisionPath {
  /** Text of the chosen token; shown on the output slab and in the info panel. */
  token: string;
  /** Vocabulary id of the chosen token (not read in this file). */
  token_id: number;
  /** Probability of the chosen token; displayed multiplied by 100. */
  probability: number;
  /** Activation records, looked up by array position when drawing each layer. */
  layer_activations: LayerActivation[];
  /** Connections drawn as vertical lines between layers. */
  attention_flow: {
    from_layer: number;
    /** A layer index, or the string "output" for the output slab. */
    to_layer: number | string;
    /** Scales the line width; values above 0.7 are drawn in the highlight colour. */
    strength: number;
    top_heads: number[];
  }[];
  /** Competing tokens; the first three are listed next to the layer stack. */
  alternatives: {
    token: string;
    token_id: number;
    probability: number;
  }[];
  /** Named scores rendered as labelled bars; presumably each in [0, 1] — TODO confirm. */
  decision_factors: {
    attention_focus: number;
    semantic_alignment: number;
    syntactic_correctness: number;
    context_relevance: number;
    confidence: number;
  };
  /** Layer indices that are highlighted and used as waypoints of the particle animation. */
  critical_layers: number[];
  /** Overall confidence shown in the info panel as a percentage. */
  confidence_score: number;
  /** Timestamp supplied by the service (unit not visible here; not read in this file). */
  timestamp: number;
}
/**
 * Renders a single transformer layer as a flat slab with an optional glow,
 * a label, a confidence read-out and one small marker per top attention head.
 *
 * - Critical layers pulse; when a layer stops being critical its scale is
 *   reset so it does not stay frozen at the last pulse size.
 * - Head markers are centred on the slab for any number of heads.
 *
 * @param position    World position of the layer group.
 * @param layerIndex  Number shown in the "Layer N" label.
 * @param activation  Optional activation record; without it only the slab and label are drawn.
 * @param isCritical  Highlights and pulses the slab.
 * @param isActive    Uses the "active" colour when the layer is not critical.
 */
function EnhancedTransformerLayer({
  position,
  layerIndex,
  activation,
  isCritical,
  isActive
}: {
  position: [number, number, number];
  layerIndex: number;
  activation?: LayerActivation;
  isCritical?: boolean;
  isActive?: boolean;
}) {
  const meshRef = useRef<THREE.Mesh>(null);
  const glowRef = useRef<THREE.Mesh>(null);

  useFrame((state) => {
    const mesh = meshRef.current;
    if (mesh) {
      if (isCritical) {
        // Pulse critical layers
        const scale = 1 + Math.sin(state.clock.elapsedTime * 3) * 0.1;
        mesh.scale.set(scale, scale, scale);
      } else if (mesh.scale.x !== 1) {
        // The layer was critical for a previous decision: undo the leftover pulse scale.
        mesh.scale.set(1, 1, 1);
      }
    }

    const glow = glowRef.current;
    if (glow && activation) {
      // Glow size follows the activation confidence
      const glowScale = 1 + activation.confidence * 0.3;
      glow.scale.set(glowScale, glowScale, glowScale);
    }
  });

  const accentColor = isCritical ? "#ff6b6b" : "#4ecdc4";
  const baseColor = isCritical ? "#ff6b6b" : isActive ? "#4ecdc4" : "#2d3748";
  const emissiveIntensity = activation ? activation.confidence : 0;
  const heads = activation ? activation.top_attention_heads : [];
  // Offset that centres the row of head markers regardless of how many there are.
  const headCenterOffset = (heads.length - 1) / 2;

  return (
    <group position={position}>
      {/* Glow effect for active layers */}
      {activation && (
        <Sphere ref={glowRef} args={[2.5, 16, 16]}>
          <meshBasicMaterial
            color={accentColor}
            transparent
            opacity={activation.confidence * 0.3}
          />
        </Sphere>
      )}

      {/* Main layer box */}
      <Box ref={meshRef} args={[4, 0.3, 3]}>
        <meshStandardMaterial
          color={baseColor}
          emissive={baseColor}
          emissiveIntensity={emissiveIntensity}
          metalness={0.8}
          roughness={0.2}
        />
      </Box>

      {/* Layer label */}
      <Text
        position={[0, 0.3, 0]}
        fontSize={0.15}
        color="white"
        anchorX="center"
      >
        Layer {layerIndex}
      </Text>

      {/* Confidence indicator */}
      {activation && (
        <Text
          position={[0, -0.3, 0]}
          fontSize={0.08}
          color={accentColor}
          anchorX="center"
        >
          Confidence: {(activation.confidence * 100).toFixed(0)}%
        </Text>
      )}

      {/* Attention head activation visualization */}
      {heads.map((headIdx, i) => (
        <Box
          // Include the slot so a repeated head index cannot produce duplicate keys.
          key={`${headIdx}-${i}`}
          position={[(i - headCenterOffset) * 0.4, 0, 1.8]}
          args={[0.3, 0.15, 0.3]}
        >
          <meshStandardMaterial
            color="#ffd93d"
            emissive="#ffd93d"
            emissiveIntensity={0.5}
          />
        </Box>
      ))}
    </group>
  );
}
/**
 * Animated particle that travels once along `path` and then reports completion.
 *
 * Progress is kept in a ref rather than React state so the animation does not
 * trigger a component re-render on every frame. Progress restarts whenever a
 * new `path` array is supplied.
 *
 * @param path        Waypoints visited in order; fewer than two points means there is nothing to interpolate.
 * @param onComplete  Called once when the end of the path is reached.
 */
function DecisionParticle({
  path,
  onComplete
}: {
  path: THREE.Vector3[];
  onComplete?: () => void;
}) {
  const particleRef = useRef<THREE.Mesh>(null);
  const progressRef = useRef(0);
  const completedRef = useRef(false);

  // A new path means a new run: start again from the first waypoint.
  useEffect(() => {
    progressRef.current = 0;
    completedRef.current = false;
  }, [path]);

  useFrame((_state, delta) => {
    const particle = particleRef.current;
    if (!particle || completedRef.current) return;

    const progress = Math.min(progressRef.current + delta * 0.3, 1);
    progressRef.current = progress;

    // Interpolate position along path
    const segmentCount = path.length - 1;
    if (segmentCount > 0) {
      const scaled = progress * segmentCount;
      // Clamp so that progress === 1 maps to the end of the last segment
      // instead of a non-existent segment past it.
      const currentSegment = Math.min(Math.floor(scaled), segmentCount - 1);
      particle.position.lerpVectors(
        path[currentSegment],
        path[currentSegment + 1],
        scaled - currentSegment
      );
    } else if (path.length === 1) {
      particle.position.copy(path[0]);
    }

    if (progress >= 1) {
      completedRef.current = true;
      onComplete?.();
    }
  });

  return (
    <Sphere ref={particleRef} args={[0.2, 16, 16]}>
      <meshStandardMaterial
        color="#ffd93d"
        emissive="#ffd93d"
        emissiveIntensity={1}
      />
    </Sphere>
  );
}
/**
 * Draws every attention-flow connection as a vertical line between two layers.
 *
 * @param flow          Connections to draw.
 * @param layerSpacing  Vertical distance between consecutive layers.
 * @param numLayers     Number of layers in the stack, used to place the
 *                      "output" end point. Defaults to 20, the value that was
 *                      previously hard-coded here.
 */
function AttentionFlowVisualization({
  flow,
  layerSpacing,
  numLayers = 20
}: {
  flow: DecisionPath['attention_flow'];
  layerSpacing: number;
  numLayers?: number;
}) {
  const outputY = numLayers * layerSpacing + 5;

  return (
    <>
      {flow.map((connection, idx) => {
        const { from_layer, to_layer, strength } = connection;

        let endY: number;
        if (to_layer === "output") {
          endY = outputY;
        } else {
          // Numeric strings are still accepted; anything else cannot be placed.
          const targetLayer =
            typeof to_layer === "number" ? to_layer : Number.parseFloat(to_layer);
          if (!Number.isFinite(targetLayer)) return null;
          endY = targetLayer * layerSpacing;
        }

        const points = [
          new THREE.Vector3(0, from_layer * layerSpacing, 0),
          new THREE.Vector3(0, endY, 0)
        ];

        return (
          <Line
            key={idx}
            points={points}
            color={strength > 0.7 ? "#ff6b6b" : "#4ecdc4"}
            lineWidth={strength * 5}
            transparent
            opacity={0.6}
          />
        );
      })}
    </>
  );
}
/**
 * Lists up to three alternative tokens with their probabilities.
 *
 * The first entry is drawn in the accent colour, the others in grey.
 *
 * @param alternatives  Candidate tokens; only the first three are shown.
 * @param position      World position of the list.
 */
function AlternativesDisplay({
  alternatives,
  position
}: {
  alternatives: DecisionPath['alternatives'];
  position: [number, number, number];
}) {
  const visible = alternatives.slice(0, 3);

  const rows = visible.map((candidate, row) => {
    const percent = (candidate.probability * 100).toFixed(1);
    const rowColor = row === 0 ? "#4ecdc4" : "#6b7280";
    return (
      <group key={row} position={[0, -row * 0.3, 0]}>
        <Text fontSize={0.1} color={rowColor} anchorX="center">
          {`${candidate.token}: ${percent}%`}
        </Text>
      </group>
    );
  });

  return (
    <group position={position}>
      <Text fontSize={0.12} color="#9ca3af" position={[0, 0.5, 0]}>
        Alternatives Considered:
      </Text>
      {rows}
    </group>
  );
}
/**
 * Renders each decision factor as a horizontal bar with a text label.
 *
 * - Bars share a common left edge at x = -2 and grow to the right; a value of
 *   1 fills the span [-2, 0].
 * - The bar length uses the value clamped to [0, 1]; non-finite values count
 *   as 0. The label still shows the raw value.
 * - The emissive colour follows the same three tiers as the base colour.
 *
 * @param factors   Named scores to display.
 * @param position  World position of the panel.
 */
function DecisionFactorsDisplay({
  factors,
  position
}: {
  factors: DecisionPath['decision_factors'];
  position: [number, number, number];
}) {
  const factorList = Object.entries(factors);
  const trackLeft = -2;
  const trackWidth = 2;

  return (
    <group position={position}>
      <Text fontSize={0.12} color="#ffd93d" position={[0, 0.8, 0]}>
        Decision Factors:
      </Text>
      {factorList.map(([key, value], idx) => {
        const clamped = Number.isFinite(value) ? Math.min(Math.max(value, 0), 1) : 0;
        const barWidth = clamped * trackWidth;
        const tierColor = value > 0.7 ? "#4ecdc4" : value > 0.4 ? "#ffd93d" : "#ff6b6b";

        return (
          <group key={key} position={[0, -idx * 0.25, 0]}>
            {/* A zero-width box would be degenerate geometry, so it is skipped. */}
            {barWidth > 0 && (
              <Box
                position={[trackLeft + barWidth / 2, 0, 0]}
                args={[barWidth, 0.15, 0.1]}
              >
                <meshStandardMaterial
                  color={tierColor}
                  emissive={tierColor}
                  emissiveIntensity={0.3}
                />
              </Box>
            )}
            <Text
              position={[1, 0, 0]}
              fontSize={0.08}
              color="#9ca3af"
              anchorX="left"
            >
              {key.replace(/_/g, ' ')}: {(value * 100).toFixed(0)}%
            </Text>
          </group>
        );
      })}
    </group>
  );
}
/**
 * Main 3D scene: lights, input slab, the stack of transformer layers, output
 * slab, attention-flow lines, the animated decision particle and the two side
 * panels (alternatives and decision factors).
 *
 * Every new `decisionPath` restarts the particle animation from the input
 * slab, even if the previous animation was still in flight.
 *
 * @param decisionPath  Decision to visualise, or null to show the idle stack.
 */
function DecisionPathScene({ decisionPath }: { decisionPath: DecisionPath | null }) {
  const numLayers = 20;
  const layerSpacing = 3.5;
  const outputY = numLayers * layerSpacing + 5;
  const panelY = numLayers * layerSpacing / 2;

  const [showParticle, setShowParticle] = useState(false);
  // Bumped for every new decision; used as the particle's key so that it is
  // remounted and therefore starts from the beginning of the new path.
  const [particleRun, setParticleRun] = useState(0);

  // Create path for particle animation
  const particlePath = useMemo(() => {
    if (!decisionPath) return [];

    const path: THREE.Vector3[] = [
      new THREE.Vector3(0, -5, 0), // Start at input
    ];
    // Add critical layers
    decisionPath.critical_layers.forEach(layerIdx => {
      path.push(new THREE.Vector3(0, layerIdx * layerSpacing, 0));
    });
    // End at output
    path.push(new THREE.Vector3(0, numLayers * layerSpacing + 5, 0));
    return path;
  }, [decisionPath, numLayers, layerSpacing]);

  useEffect(() => {
    if (decisionPath) {
      setShowParticle(true);
      setParticleRun(run => run + 1);
    }
  }, [decisionPath]);

  return (
    <>
      {/* Lighting */}
      <ambientLight intensity={0.3} />
      <pointLight position={[10, 10, 10]} intensity={1} />
      <spotLight position={[0, 50, 0]} angle={0.3} penumbra={1} intensity={1} />

      {/* Input Layer */}
      <group position={[0, -5, 0]}>
        <Box args={[5, 0.2, 2]}>
          <meshStandardMaterial color="#10b981" emissive="#10b981" emissiveIntensity={0.3} />
        </Box>
        <Text position={[0, 0.3, 0]} fontSize={0.15} color="white">
          Input Embeddings
        </Text>
      </group>

      {/* Transformer Layers */}
      {Array.from({ length: numLayers }).map((_, i) => {
        const activation = decisionPath?.layer_activations[i];
        const isCritical = decisionPath?.critical_layers.includes(i);
        return (
          <EnhancedTransformerLayer
            key={i}
            position={[0, i * layerSpacing, 0]}
            layerIndex={i}
            activation={activation}
            isCritical={isCritical}
            isActive={!!activation}
          />
        );
      })}

      {/* Output Layer */}
      <group position={[0, outputY, 0]}>
        <Box args={[5, 0.2, 2]}>
          <meshStandardMaterial color="#f59e0b" emissive="#f59e0b" emissiveIntensity={0.3} />
        </Box>
        <Text position={[0, 0.3, 0]} fontSize={0.15} color="white">
          Output: {decisionPath?.token || "..."}
        </Text>
        <Text position={[0, -0.3, 0]} fontSize={0.1} color="#9ca3af">
          Probability: {decisionPath ? (decisionPath.probability * 100).toFixed(1) : "0"}%
        </Text>
      </group>

      {/* Attention Flow Visualization */}
      {decisionPath && (
        <AttentionFlowVisualization
          flow={decisionPath.attention_flow}
          layerSpacing={layerSpacing}
        />
      )}

      {/* Decision Particle Animation */}
      {showParticle && particlePath.length > 0 && (
        <DecisionParticle
          key={particleRun}
          path={particlePath}
          onComplete={() => setShowParticle(false)}
        />
      )}

      {/* Alternatives Display */}
      {decisionPath && (
        <AlternativesDisplay
          alternatives={decisionPath.alternatives}
          position={[8, panelY, 0]}
        />
      )}

      {/* Decision Factors Display */}
      {decisionPath && (
        <DecisionFactorsDisplay
          factors={decisionPath.decision_factors}
          position={[-8, panelY, 0]}
        />
      )}

      {/* Grid */}
      <gridHelper args={[150, 150, 0x444444, 0x222222]} />
    </>
  );
}
/**
 * Top-level "Decision Path" panel: prompt input, connection status, the 3D
 * canvas and its overlays.
 *
 * Opens a WebSocket to the decision path service once the component has
 * mounted on the client. The socket is always closed on cleanup — including
 * while it is still connecting — and its handlers are detached first so that
 * no state update is attempted after unmount. Malformed or unexpected
 * messages are ignored instead of throwing inside the message handler.
 */
export default function DecisionPath3D() {
  const [decisionPath, setDecisionPath] = useState<DecisionPath | null>(null);
  const [isConnected, setIsConnected] = useState(false);
  const [isAnalyzing, setIsAnalyzing] = useState(false);
  const [isClient, setIsClient] = useState(false);
  const wsRef = useRef<WebSocket | null>(null);
  const [prompt, setPrompt] = useState("def quicksort(arr):");

  // Ensure client-side only rendering
  useEffect(() => {
    setIsClient(true);
  }, []);

  // Connect to decision path service
  useEffect(() => {
    if (!isClient) return;

    let ws: WebSocket;
    try {
      ws = new WebSocket('ws://localhost:8769');
    } catch (error: unknown) {
      console.log('[DecisionPath3D] Connection failed');
      setIsConnected(false);
      return;
    }
    // Track the socket from creation so cleanup can close it even before it opens.
    wsRef.current = ws;

    ws.onopen = () => {
      console.log('[DecisionPath3D] Connected to service');
      setIsConnected(true);
    };

    ws.onmessage = (event: MessageEvent) => {
      if (typeof event.data !== 'string') return;

      let message: unknown;
      try {
        message = JSON.parse(event.data);
      } catch {
        console.log('[DecisionPath3D] Ignoring malformed message');
        return;
      }
      if (typeof message !== 'object' || message === null) return;

      const { type, data } = message as { type?: unknown; data?: unknown };
      console.log('[DecisionPath3D] Received:', type);

      if (type === 'decision_path') {
        // Only the outer shape is checked here; the payload itself is trusted
        // to follow the DecisionPath schema of the service.
        if (typeof data === 'object' && data !== null) {
          setDecisionPath(data as DecisionPath);
        }
      } else if (type === 'analysis_complete') {
        setIsAnalyzing(false);
      }
    };

    ws.onerror = () => {
      console.log('[DecisionPath3D] Service not available');
      setIsConnected(false);
    };

    ws.onclose = () => {
      console.log('[DecisionPath3D] Disconnected');
      setIsConnected(false);
      // No 'analysis_complete' can arrive any more; release the button state.
      setIsAnalyzing(false);
      if (wsRef.current === ws) {
        wsRef.current = null;
      }
    };

    return () => {
      // Detach first so closing does not trigger state updates during teardown.
      ws.onopen = null;
      ws.onmessage = null;
      ws.onerror = null;
      ws.onclose = null;
      if (wsRef.current === ws) {
        wsRef.current = null;
      }
      ws.close();
    };
  }, [isClient]);

  /** Sends the current prompt to the service if the socket is open. */
  const startAnalysis = () => {
    const ws = wsRef.current;
    if (ws && ws.readyState === WebSocket.OPEN) {
      setIsAnalyzing(true);
      ws.send(JSON.stringify({
        type: 'analyze',
        prompt: prompt
      }));
    }
  };

  return (
    <div className="bg-gray-900 rounded-xl p-6 h-[900px]">
      {/* Header */}
      <div className="flex items-center justify-between mb-4">
        <div>
          <h2 className="text-2xl font-bold flex items-center gap-2">
            <GitBranch className="w-6 h-6 text-yellow-400" />
            Decision Path Visualization
          </h2>
          <p className="text-gray-400 mt-1">
            See exactly how the model makes its decisions - the Glass Box view
          </p>
        </div>
        <div className="flex items-center gap-4">
          <div className={`flex items-center gap-2 px-3 py-1 rounded-full ${
            isConnected ? 'bg-green-900/30 text-green-400' : 'bg-yellow-900/30 text-yellow-400'
          }`}>
            <Activity className={`w-4 h-4 ${isConnected ? 'animate-pulse' : ''}`} />
            {isConnected ? 'Connected' : 'Disconnected'}
          </div>
        </div>
      </div>

      {/* Controls */}
      <div className="bg-gray-800 rounded-lg p-4 mb-4">
        <div className="flex items-center gap-4">
          <input
            type="text"
            value={prompt}
            onChange={(e) => setPrompt(e.target.value)}
            className="flex-1 px-3 py-2 bg-gray-900 text-white rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none font-mono text-sm"
            placeholder="Enter code to analyze..."
          />
          <button
            onClick={startAnalysis}
            disabled={!isConnected || isAnalyzing}
            className="px-6 py-2 bg-yellow-600 text-white rounded-lg hover:bg-yellow-700 transition-colors disabled:opacity-50 flex items-center gap-2"
          >
            {isAnalyzing ? (
              <>
                <Activity className="w-4 h-4 animate-spin" />
                Analyzing...
              </>
            ) : (
              <>
                <Sparkles className="w-4 h-4" />
                Analyze Decision Path
              </>
            )}
          </button>
        </div>
      </div>

      {/* 3D Canvas */}
      <div className="h-[700px] bg-black rounded-lg relative">
        {isClient ? (
          <Canvas camera={{ position: [30, 40, 50], fov: 50 }}>
            <DecisionPathScene decisionPath={decisionPath} />
            <OrbitControls
              enablePan={true}
              enableZoom={true}
              enableRotate={true}
              target={[0, 35, 0]}
            />
          </Canvas>
        ) : (
          <div className="flex items-center justify-center h-full">
            <div className="text-gray-400">Loading 3D visualization...</div>
          </div>
        )}

        {/* Legend */}
        <div className="absolute top-4 right-4 bg-gray-800/90 backdrop-blur rounded-lg p-3 text-xs">
          <div className="font-semibold text-white mb-2">Decision Path</div>
          <div className="space-y-1">
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-red-500 rounded"></div>
              <span className="text-gray-300">Critical Layers</span>
            </div>
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-teal-500 rounded"></div>
              <span className="text-gray-300">Active Layers</span>
            </div>
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-yellow-500 rounded"></div>
              <span className="text-gray-300">Top Attention Heads</span>
            </div>
            <div className="flex items-center gap-2">
              <Zap className="w-3 h-3 text-yellow-400" />
              <span className="text-gray-300">Decision Flow</span>
            </div>
          </div>
        </div>

        {/* Info Panel */}
        {decisionPath && (
          <div className="absolute bottom-4 left-4 bg-gray-800/90 backdrop-blur rounded-lg p-3 text-xs max-w-xs">
            <div className="font-semibold text-white mb-2">Current Decision</div>
            <div className="space-y-1 text-gray-300">
              <div>Token: <span className="text-yellow-400">{decisionPath.token}</span></div>
              <div>Confidence: <span className="text-green-400">{(decisionPath.confidence_score * 100).toFixed(0)}%</span></div>
              <div>Critical Layers: <span className="text-red-400">{decisionPath.critical_layers.join(", ")}</span></div>
            </div>
          </div>
        )}
      </div>
    </div>
  );
}