Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
/**
 * Decision Path 3D Visualization - Fixed Version
 *
 * Shows the exact path through the neural network when generating tokens,
 * highlighting critical layers, attention patterns, and decision factors.
 * This is the core "Glass Box" visualization for the PhD thesis.
 *
 * @component
 */
| "use client"; | |
| import { useRef, useState, useEffect, useMemo } from "react"; | |
| import { Canvas, useFrame } from "@react-three/fiber"; | |
| import { OrbitControls } from "@react-three/drei"; | |
| import * as THREE from "three"; | |
| import { | |
| GitBranch, | |
| Activity, | |
| Sparkles, | |
| Zap, | |
| Brain | |
| } from "lucide-react"; | |
/**
 * Measurements attached to a single layer of a decision path.
 *
 * Keys are snake_case and are read as-is from the parsed service message,
 * so they must not be renamed on the client side.
 */
interface LayerActivation {
  /** Layer number this record refers to (presumably 0-based — confirm with the service). */
  layer_index: number;
  /** Two-dimensional attention weight table; row/column meaning is defined by the service. */
  attention_weights: Array<Array<number>>;
  /** Norm of the hidden state reported for this layer. */
  hidden_state_norm: number;
  /** FFN activation strength; the Layer component uses it to stretch the FFN slab. */
  ffn_activation: number;
  /** Head indices to highlight; the Layer component lights the cube with a matching index. */
  top_attention_heads: Array<number>;
  /** Confidence value reported for this layer. */
  confidence: number;
}
/**
 * Payload describing how one token was chosen.
 *
 * The component stores the `data` field of a `decision_path` message in
 * state under this type without validating it, so every field should be
 * treated as "expected" rather than "guaranteed" at runtime.
 */
interface DecisionPath {
  /** Text of the chosen token (shown in the info panel). */
  token: string;
  /** Numeric id of the chosen token. */
  token_id: number;
  /** Probability reported for the chosen token. */
  probability: number;
  /** Per-layer records; the scene looks them up by array position. */
  layer_activations: LayerActivation[];
  /** Edges between layers. `to_layer` may be a number or a string label. */
  attention_flow: {
    from_layer: number;
    to_layer: number | string;
    strength: number;
    top_heads: number[];
  }[];
  /** Other candidate tokens that were not chosen. */
  alternatives: {
    token: string;
    token_id: number;
    probability: number;
  }[];
  /** Named scores contributing to the decision; scale is defined by the service. */
  decision_factors: {
    attention_focus: number;
    semantic_alignment: number;
    syntactic_correctness: number;
    context_relevance: number;
    confidence: number;
  };
  /** Layer indices rendered as critical (red, pulsing) and joined by the path line. */
  critical_layers: number[];
  /** Overall confidence; the info panel multiplies it by 100 for display. */
  confidence_score: number;
  /** Timestamp supplied by the service (unit not visible here — confirm). */
  timestamp: number;
}
/**
 * Props accepted by the Layer component (attention slab, FFN slab and
 * attention-head cubes for one transformer layer).
 */
interface LayerProps {
  /** World-space position of the layer group. */
  position: [x: number, y: number, z: number];
  /** Index of the layer inside the stack. */
  layerIndex: number;
  /** When true the layer is drawn in the critical colour and pulses. */
  isCritical: boolean;
  /** When true the layer is drawn as active and its head cubes are shown. */
  isActive: boolean;
  /** Activation record for this layer, if one was received. */
  activation?: LayerActivation;
}
/**
 * Renders one transformer layer as three pieces inside a positioned group:
 * a wide attention slab, a smaller FFN slab placed behind it, and (only when
 * the layer is active) a 4x4 grid of small cubes standing for attention heads.
 *
 * @param position    Where the group is placed in the scene.
 * @param layerIndex  Accepted for API compatibility; not used for rendering.
 * @param isCritical  Critical layers are red, emissive and pulse over time.
 * @param isActive    Active layers use the teal/purple palette and show head cubes.
 * @param activation  Optional activation record; drives FFN height and head highlights.
 */
function Layer({ position, isCritical, isActive, activation }: LayerProps) {
  const meshRef = useRef<THREE.Mesh>(null);
  const ffnRef = useRef<THREE.Mesh>(null);

  useFrame((state) => {
    const slab = meshRef.current;
    if (slab) {
      // Always write the scale: if the layer stops being critical it must
      // return to its rest size instead of freezing at the last pulse value.
      const pulse = isCritical ? 1 + Math.sin(state.clock.elapsedTime * 3) * 0.1 : 1;
      slab.scale.setScalar(pulse);
    }

    const ffn = ffnRef.current;
    if (ffn) {
      // Stretch the FFN slab vertically in proportion to its activation.
      // A missing or non-finite value falls back to the rest height so a bad
      // payload cannot put NaN into the transform.
      const strength = activation?.ffn_activation;
      const stretch =
        typeof strength === "number" && Number.isFinite(strength) ? 1 + strength * 0.2 : 1;
      ffn.scale.set(1, stretch, 1);
    }
  });

  const baseColor = isCritical ? "#ff6b6b" : isActive ? "#4ecdc4" : "#2d3748";
  const ffnColor = isCritical ? "#e91e63" : isActive ? "#9c27b0" : "#6b46c1";
  // Resolve once per render instead of three optional-chained lookups per cube.
  const topHeads = activation?.top_attention_heads ?? [];

  return (
    <group position={position}>
      {/* Main attention layer */}
      <mesh ref={meshRef}>
        <boxGeometry args={[4, 0.3, 2]} />
        <meshStandardMaterial
          color={baseColor}
          emissive={isCritical ? baseColor : "#000000"}
          emissiveIntensity={isCritical ? 0.3 : 0}
          metalness={0.6}
          roughness={0.3}
        />
      </mesh>

      {/* FFN Component - positioned behind */}
      <group position={[0, 0, -1.5]}>
        <mesh ref={ffnRef}>
          <boxGeometry args={[3, 0.2, 0.8]} />
          <meshStandardMaterial
            color={ffnColor}
            emissive={ffnColor}
            emissiveIntensity={isActive ? 0.2 : 0.1}
            metalness={0.7}
            roughness={0.3}
          />
        </mesh>
      </group>

      {/* Attention heads visualization - 16 small cubes in a 4x4 grid */}
      {isActive && (
        <group position={[0, 0, 1.2]}>
          {Array.from({ length: 16 }, (_, i) => {
            const isTopHead = topHeads.includes(i);
            return (
              <mesh key={i} position={[(i % 4 - 1.5) * 0.3, 0, Math.floor(i / 4) * 0.2 - 0.3]}>
                <boxGeometry args={[0.15, 0.1, 0.15]} />
                <meshStandardMaterial
                  color={isTopHead ? "#ffd93d" : "#4a5568"}
                  emissive={isTopHead ? "#ffd93d" : "#000000"}
                  emissiveIntensity={isTopHead ? 0.5 : 0}
                />
              </mesh>
            );
          })}
        </group>
      )}
    </group>
  );
}
/**
 * Scene contents for the decision-path canvas: lights, an input slab, a
 * vertical stack of 20 Layer components, an output slab, red connection
 * lines between consecutive critical layers, and a ground grid.
 *
 * @param decisionPath  Latest decision path, or null before any analysis.
 *                      Layers listed in `critical_layers` are highlighted and
 *                      joined in the order given; the last one is joined to
 *                      the output slab.
 */
function DecisionPathScene({ decisionPath }: { decisionPath: DecisionPath | null }) {
  const numLayers = 20;
  const layerSpacing = 3.5;
  const outputY = numLayers * layerSpacing + 5;

  // The payload is not validated upstream, so tolerate a missing/non-array field.
  const criticalLayers = useMemo<number[]>(
    () => (Array.isArray(decisionPath?.critical_layers) ? decisionPath.critical_layers : []),
    [decisionPath]
  );

  // Build the THREE line objects only when the critical layers change.
  // Creating them inline during render would allocate fresh geometry and
  // material on every re-render with nothing ever releasing the old ones.
  const connections = useMemo(
    () =>
      criticalLayers.map((layerIdx, idx) => {
        const nextLayer = criticalLayers[idx + 1];
        const startY = layerIdx * layerSpacing;
        const endY = nextLayer !== undefined ? nextLayer * layerSpacing : outputY;
        const geometry = new THREE.BufferGeometry().setFromPoints([
          new THREE.Vector3(0, startY, 0),
          new THREE.Vector3(0, endY, 0),
        ]);
        // NOTE: most WebGL implementations render lines 1px wide regardless of linewidth.
        const material = new THREE.LineBasicMaterial({ color: 0xff6b6b, linewidth: 3 });
        return { geometry, material, line: new THREE.Line(geometry, material) };
      }),
    [criticalLayers, layerSpacing, outputY]
  );

  // Objects handed to <primitive> are owned by the caller, so free their GPU
  // resources when they are replaced or when the scene unmounts.
  useEffect(() => {
    return () => {
      for (const { geometry, material } of connections) {
        geometry.dispose();
        material.dispose();
      }
    };
  }, [connections]);

  return (
    <>
      <ambientLight intensity={0.5} />
      <pointLight position={[10, 10, 10]} intensity={1} />
      <directionalLight position={[0, 10, 5]} intensity={0.5} />

      {/* Input Layer */}
      <mesh position={[0, -5, 0]}>
        <boxGeometry args={[5, 0.2, 2]} />
        <meshStandardMaterial color="#10b981" />
      </mesh>

      {/* Transformer Layers */}
      {Array.from({ length: numLayers }, (_, i) => {
        const activation = decisionPath?.layer_activations?.[i];
        return (
          <Layer
            key={i}
            position={[0, i * layerSpacing, 0]}
            layerIndex={i}
            isCritical={criticalLayers.includes(i)}
            isActive={!!activation}
            activation={activation}
          />
        );
      })}

      {/* Output Layer */}
      <mesh position={[0, outputY, 0]}>
        <boxGeometry args={[5, 0.2, 2]} />
        <meshStandardMaterial color="#f59e0b" />
      </mesh>

      {/* Connection lines between consecutive critical layers */}
      {connections.map(({ line }, idx) => (
        <primitive key={`line-${idx}`} object={line} />
      ))}

      <gridHelper args={[100, 100, 0x444444, 0x222222]} />
    </>
  );
}
/**
 * Top-level "Glass Box" panel.
 *
 * After the first client-side mount it opens a WebSocket to the local
 * analysis service, tracks model-loading progress, lets the user submit a
 * prompt for analysis, and renders the most recent decision path in a 3D
 * canvas together with a legend and an info panel.
 *
 * Messages handled (by their `type` field): `decision_path`,
 * `analysis_complete`, `loading_progress`, `model_ready`, `loading_error`.
 * Anything else, including frames that are not valid JSON objects, is ignored.
 */
export default function DecisionPath3DFixed() {
  const [decisionPath, setDecisionPath] = useState<DecisionPath | null>(null);
  const [isConnected, setIsConnected] = useState(false);
  const [isAnalyzing, setIsAnalyzing] = useState(false);
  const [mounted, setMounted] = useState(false);
  const [modelLoading, setModelLoading] = useState(true);
  const [loadingProgress, setLoadingProgress] = useState(0);
  const [loadingMessage, setLoadingMessage] = useState("Initializing...");
  const wsRef = useRef<WebSocket | null>(null);
  const [prompt, setPrompt] = useState("def quicksort(arr):");

  // Flip to true only on the client so the canvas is never rendered during SSR.
  useEffect(() => {
    setMounted(true);
  }, []);

  useEffect(() => {
    if (!mounted) return;

    // Keep a local handle: wsRef is only populated once the socket opens, so
    // the cleanup below must not rely on it to find a still-connecting socket.
    let ws: WebSocket | null = null;

    const handleMessage = (event: MessageEvent) => {
      let parsed: unknown;
      try {
        parsed = JSON.parse(event.data);
      } catch {
        // A malformed frame must not throw out of the socket handler.
        console.log('[DecisionPath3D] Ignoring malformed message');
        return;
      }
      if (typeof parsed !== 'object' || parsed === null) return;

      const msg = parsed as Record<string, unknown>;
      console.log('[DecisionPath3D] Parsed message type:', msg.type);

      switch (msg.type) {
        case 'decision_path':
          // The payload shape is trusted here (no schema validation is available);
          // consumers guard the fields they read.
          setDecisionPath((msg.data as DecisionPath | null | undefined) ?? null);
          break;
        case 'analysis_complete':
          console.log('[DecisionPath3D] Analysis complete');
          setIsAnalyzing(false);
          break;
        case 'loading_progress': {
          const { progress, message } = msg;
          if (typeof progress === 'number' && Number.isFinite(progress)) {
            // Clamp so the value is always usable as a CSS percentage width.
            setLoadingProgress(Math.min(100, Math.max(0, progress)));
            if (progress >= 100) {
              setModelLoading(false);
            }
          }
          if (typeof message === 'string') {
            setLoadingMessage(message);
          }
          break;
        }
        case 'model_ready':
          setModelLoading(false);
          setLoadingProgress(100);
          setLoadingMessage("Model ready!");
          break;
        case 'loading_error':
          setModelLoading(false);
          setLoadingMessage(`Error: ${msg.message}`);
          break;
        default:
          break;
      }
    };

    try {
      const socket = new WebSocket('ws://localhost:8769');
      ws = socket;

      socket.onopen = () => {
        console.log('[DecisionPath3D] Connected to service');
        setIsConnected(true);
        wsRef.current = socket;
        // Don't immediately set as ready - wait for model_ready or loading_progress messages
      };
      socket.onmessage = handleMessage;
      socket.onerror = () => {
        console.log('[DecisionPath3D] Service not available');
        setIsConnected(false);
      };
      socket.onclose = () => {
        console.log('[DecisionPath3D] Disconnected');
        setIsConnected(false);
        // No completion message can arrive any more, so release the button state.
        setIsAnalyzing(false);
        if (wsRef.current === socket) {
          wsRef.current = null;
        }
      };
    } catch {
      console.log('[DecisionPath3D] Connection failed');
      setIsConnected(false);
    }

    return () => {
      if (ws) {
        // Detach handlers first so a late event cannot update state after
        // unmount, then close even if the socket never reached OPEN.
        ws.onopen = null;
        ws.onmessage = null;
        ws.onerror = null;
        ws.onclose = null;
        ws.close();
      }
      if (wsRef.current === ws) {
        wsRef.current = null;
      }
    };
  }, [mounted]);

  /** Sends the current prompt to the service if the socket is open. */
  const startAnalysis = () => {
    console.log('[DecisionPath3D] Start analysis clicked');
    console.log('[DecisionPath3D] WebSocket state:', wsRef.current?.readyState);
    console.log('[DecisionPath3D] Is connected:', isConnected);

    const socket = wsRef.current;
    if (socket && socket.readyState === WebSocket.OPEN) {
      console.log('[DecisionPath3D] Sending analyze request with prompt:', prompt);
      setIsAnalyzing(true);
      socket.send(JSON.stringify({
        type: 'analyze',
        prompt: prompt
      }));
    } else {
      console.log('[DecisionPath3D] WebSocket not ready, state:', socket?.readyState);
    }
  };

  if (!mounted) {
    return (
      <div className="bg-gray-900 rounded-xl p-6 h-[900px]">
        <div className="flex items-center justify-center h-full">
          <div className="text-gray-400">Loading 3D visualization...</div>
        </div>
      </div>
    );
  }

  // Display values derived defensively: the payload comes straight from the
  // socket, and a missing field must not crash the render.
  const confidencePercent =
    typeof decisionPath?.confidence_score === 'number'
      ? (decisionPath.confidence_score * 100).toFixed(0)
      : '0';
  const criticalLayerList = Array.isArray(decisionPath?.critical_layers)
    ? decisionPath.critical_layers.join(", ")
    : '';

  return (
    <div className="bg-gray-900 rounded-xl p-6 h-[900px]">
      {/* Header */}
      <div className="flex items-center justify-between mb-4">
        <div>
          <h2 className="text-2xl font-bold flex items-center gap-2">
            <GitBranch className="w-6 h-6 text-yellow-400" />
            Decision Path Visualization
          </h2>
          <p className="text-gray-400 mt-1">
            See exactly how the model makes its decisions - the Glass Box view
          </p>
        </div>
        <div className="flex items-center gap-4">
          <div className={`flex items-center gap-2 px-3 py-1 rounded-full ${
            isConnected ? 'bg-green-900/30 text-green-400' : 'bg-yellow-900/30 text-yellow-400'
          }`}>
            <Activity className={`w-4 h-4 ${isConnected ? 'animate-pulse' : ''}`} />
            {isConnected ? 'Connected' : 'Disconnected'}
          </div>
        </div>
      </div>

      {/* Controls */}
      <div className="bg-gray-800 rounded-lg p-4 mb-4">
        <div className="flex items-center gap-4">
          <input
            type="text"
            value={prompt}
            onChange={(e) => setPrompt(e.target.value)}
            className="flex-1 px-3 py-2 bg-gray-900 text-white rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none font-mono text-sm"
            placeholder="Enter code to analyze..."
          />
          <button
            onClick={startAnalysis}
            disabled={!isConnected || isAnalyzing}
            className="px-6 py-2 bg-yellow-600 text-white rounded-lg hover:bg-yellow-700 transition-colors disabled:opacity-50 flex items-center gap-2"
          >
            {isAnalyzing ? (
              <>
                <Activity className="w-4 h-4 animate-spin" />
                Analyzing...
              </>
            ) : (
              <>
                <Sparkles className="w-4 h-4" />
                Analyze Decision Path
              </>
            )}
          </button>
        </div>
      </div>

      {/* 3D Canvas */}
      <div className="h-[700px] bg-black rounded-lg relative">
        {modelLoading ? (
          <div className="flex flex-col items-center justify-center h-full">
            <div className="text-white mb-4">
              <Brain className="w-16 h-16 animate-pulse" />
            </div>
            <div className="text-xl text-white mb-2">Loading Model</div>
            <div className="text-sm text-gray-400 mb-4">{loadingMessage}</div>
            <div className="w-64 h-2 bg-gray-700 rounded-full overflow-hidden">
              <div
                className="h-full bg-gradient-to-r from-blue-500 to-purple-500 transition-all duration-500"
                style={{ width: `${loadingProgress}%` }}
              />
            </div>
            <div className="text-xs text-gray-500 mt-2">{loadingProgress}%</div>
            <div className="text-xs text-gray-500 mt-4">356M parameters • 20 layers • 16 attention heads</div>
          </div>
        ) : (
          <Canvas camera={{ position: [-40, 50, 40], fov: 50 }}>
            <DecisionPathScene decisionPath={decisionPath} />
            <OrbitControls
              enablePan={true}
              enableZoom={true}
              enableRotate={true}
              target={[0, 35, 0]}
            />
          </Canvas>
        )}

        {/* Legend */}
        <div className="absolute top-4 right-4 bg-gray-800/90 backdrop-blur rounded-lg p-3 text-xs">
          <div className="font-semibold text-white mb-2">Decision Path</div>
          <div className="space-y-1">
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-red-500 rounded"></div>
              <span className="text-gray-300">Critical Layers</span>
            </div>
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-teal-500 rounded"></div>
              <span className="text-gray-300">Active Layers</span>
            </div>
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-yellow-500 rounded"></div>
              <span className="text-gray-300">Top Attention Heads</span>
            </div>
            <div className="flex items-center gap-2">
              <Zap className="w-3 h-3 text-yellow-400" />
              <span className="text-gray-300">Decision Flow</span>
            </div>
          </div>
        </div>

        {/* Info Panel */}
        {decisionPath && (
          <div className="absolute bottom-4 left-4 bg-gray-800/90 backdrop-blur rounded-lg p-3 text-xs max-w-xs">
            <div className="font-semibold text-white mb-2">Current Decision</div>
            <div className="space-y-1 text-gray-300">
              <div>Token: <span className="text-yellow-400">{decisionPath.token}</span></div>
              <div>Confidence: <span className="text-green-400">{confidencePercent}%</span></div>
              <div>Critical Layers: <span className="text-red-400">{criticalLayerList}</span></div>
            </div>
          </div>
        )}
      </div>
    </div>
  );
}