// Source: api/frontend/DecisionPath3DFixed.tsx (15.2 kB)
// gary-boon · "Deploy Visualisable.ai backend with API protection" · c6c8587
/**
* Decision Path 3D Visualization - Fixed Version
*
* Shows the exact path through the neural network when generating tokens,
* highlighting critical layers, attention patterns, and decision factors.
* This is the core "Glass Box" visualization for the PhD thesis.
*
* @component
*/
"use client";
import { useRef, useState, useEffect, useMemo } from "react";
import { Canvas, useFrame } from "@react-three/fiber";
import { OrbitControls } from "@react-three/drei";
import * as THREE from "three";
import {
GitBranch,
Activity,
Sparkles,
Zap,
Brain
} from "lucide-react";
/**
 * One entry of `DecisionPath.layer_activations`.
 *
 * Within this file only `ffn_activation` and `top_attention_heads` are read
 * by the renderer; the remaining fields are carried through untouched.
 */
interface LayerActivation {
  /** Index of the layer this record describes. */
  layer_index: number;
  /** Nested numeric matrix; its dimensions are not established in this file. */
  attention_weights: Array<number[]>;
  hidden_state_norm: number;
  /** Drives the vertical stretch of the FFN box for this layer. */
  ffn_activation: number;
  /** Head indices that are highlighted in the per-layer head grid. */
  top_attention_heads: Array<number>;
  confidence: number;
}
/**
 * Payload stored by the component when a `decision_path` message arrives.
 *
 * The scene reads `critical_layers` and `layer_activations`; the info panel
 * reads `token`, `confidence_score` and `critical_layers`. All other fields
 * are declared here but not consumed in this file.
 */
interface DecisionPath {
  /** Token text shown in the info panel. */
  token: string;
  token_id: number;
  probability: number;
  /** Per-layer records; the scene looks them up by array position. */
  layer_activations: LayerActivation[];
  attention_flow: {
    from_layer: number;
    to_layer: number | string;
    strength: number;
    top_heads: number[];
  }[];
  alternatives: {
    token: string;
    token_id: number;
    probability: number;
  }[];
  decision_factors: {
    attention_focus: number;
    semantic_alignment: number;
    syntactic_correctness: number;
    context_relevance: number;
    confidence: number;
  };
  /** Layer indices rendered as critical and joined by connection lines. */
  critical_layers: number[];
  /** Multiplied by 100 and displayed as a percentage in the info panel. */
  confidence_score: number;
  timestamp: number;
}
/**
 * Props accepted by the `Layer` scene component (attention slab plus FFN box).
 */
interface LayerProps {
  /** World-space position of the layer group, as an [x, y, z] tuple. */
  position: [number, number, number];
  /** Index of the layer within the stack. */
  layerIndex: number;
  /** Whether the layer is drawn in the critical colour and animated. */
  isCritical: boolean;
  /** Whether the layer is drawn in the active colour with its head grid. */
  isActive: boolean;
  /** Activation record for this layer, when one is available. */
  activation?: LayerActivation;
}
/**
 * Renders a single transformer layer as a group of meshes:
 * a main slab (attention), a smaller box behind it (FFN), and, when the
 * layer is active, a 4x4 grid of small cubes standing in for attention heads.
 *
 * Per-frame behaviour:
 * - the main slab pulses uniformly while `isCritical` is true and returns to
 *   unit scale once it is not (previously it stayed frozen at the last
 *   pulsed scale);
 * - the FFN box is stretched along Y by `1 + ffn_activation * 0.2`, falling
 *   back to unit scale when there is no activation or the value is not a
 *   finite number (a NaN scale would make the mesh vanish).
 *
 * `layerIndex` is part of `LayerProps` but is not needed for rendering.
 *
 * @param props.position    Position of the layer group.
 * @param props.isCritical  Selects the critical colours, emissive glow and pulse.
 * @param props.isActive    Selects the active colours and shows the head grid.
 * @param props.activation  Optional activation record for this layer.
 */
function Layer({ position, isCritical, isActive, activation }: LayerProps) {
  const meshRef = useRef<THREE.Mesh>(null);
  const ffnRef = useRef<THREE.Mesh>(null);

  // The head grid is a fixed 16-cube layout, four cubes per row.
  const headCount = 16;
  const headsPerRow = 4;

  useFrame((state) => {
    const slab = meshRef.current;
    if (slab) {
      // Always write the scale so a layer that stops being critical is reset.
      const pulse = isCritical ? 1 + Math.sin(state.clock.elapsedTime * 3) * 0.1 : 1;
      slab.scale.set(pulse, pulse, pulse);
    }

    const ffn = ffnRef.current;
    if (ffn) {
      // Static stretch proportional to FFN activation (not time-dependent).
      const strength = activation?.ffn_activation;
      const ffnScale =
        typeof strength === "number" && Number.isFinite(strength) ? 1 + strength * 0.2 : 1;
      ffn.scale.set(1, ffnScale, 1);
    }
  });

  const baseColor = isCritical ? "#ff6b6b" : isActive ? "#4ecdc4" : "#2d3748";
  const ffnColor = isCritical ? "#e91e63" : isActive ? "#9c27b0" : "#6b46c1";

  return (
    <group position={position}>
      {/* Main attention layer */}
      <mesh ref={meshRef}>
        <boxGeometry args={[4, 0.3, 2]} />
        <meshStandardMaterial
          color={baseColor}
          emissive={isCritical ? baseColor : "#000000"}
          emissiveIntensity={isCritical ? 0.3 : 0}
          metalness={0.6}
          roughness={0.3}
        />
      </mesh>
      {/* FFN Component - positioned behind */}
      <group position={[0, 0, -1.5]}>
        <mesh ref={ffnRef}>
          <boxGeometry args={[3, 0.2, 0.8]} />
          <meshStandardMaterial
            color={ffnColor}
            emissive={ffnColor}
            emissiveIntensity={isActive ? 0.2 : 0.1}
            metalness={0.7}
            roughness={0.3}
          />
        </mesh>
      </group>
      {/* Attention heads visualization - small cubes */}
      {isActive && (
        <group position={[0, 0, 1.2]}>
          {Array.from({ length: headCount }, (_, i) => {
            // Evaluate the membership test once per cube instead of three times.
            const isTopHead = activation?.top_attention_heads?.includes(i) ?? false;
            return (
              <mesh
                key={i}
                position={[
                  ((i % headsPerRow) - 1.5) * 0.3,
                  0,
                  Math.floor(i / headsPerRow) * 0.2 - 0.3,
                ]}
              >
                <boxGeometry args={[0.15, 0.1, 0.15]} />
                <meshStandardMaterial
                  color={isTopHead ? "#ffd93d" : "#4a5568"}
                  emissive={isTopHead ? "#ffd93d" : "#000000"}
                  emissiveIntensity={isTopHead ? 0.5 : 0}
                />
              </mesh>
            );
          })}
        </group>
      )}
    </group>
  );
}
/**
 * Scene contents for the decision-path canvas: lights, an input slab, a
 * fixed stack of 20 `Layer` components, an output slab, red connection
 * lines between consecutive critical layers, and a ground grid.
 *
 * Connection lines are built once per `critical_layers` array and disposed
 * when that array changes or the scene unmounts. Previously a new
 * `THREE.Line`, geometry and material were allocated on every render and
 * never released; objects mounted through `<primitive>` are not disposed
 * automatically.
 *
 * @param props.decisionPath  Current decision path, or `null` before any
 *                            analysis result has been received.
 */
function DecisionPathScene({ decisionPath }: { decisionPath: DecisionPath | null }) {
  const numLayers = 20;
  const layerSpacing = 3.5;
  const criticalLayers = decisionPath?.critical_layers;

  const connectionLines = useMemo(() => {
    if (!criticalLayers) return [];
    // The last critical layer connects to the output slab.
    const outputY = numLayers * layerSpacing + 5;
    return criticalLayers.map((layerIdx, idx) => {
      const startY = layerIdx * layerSpacing;
      const endY =
        idx < criticalLayers.length - 1 ? criticalLayers[idx + 1] * layerSpacing : outputY;
      const geometry = new THREE.BufferGeometry().setFromPoints([
        new THREE.Vector3(0, startY, 0),
        new THREE.Vector3(0, endY, 0),
      ]);
      const material = new THREE.LineBasicMaterial({ color: 0xff6b6b, linewidth: 3 });
      return new THREE.Line(geometry, material);
    });
  }, [criticalLayers, numLayers, layerSpacing]);

  useEffect(() => {
    // Release GPU resources held by the previous set of lines.
    return () => {
      for (const line of connectionLines) {
        line.geometry.dispose();
        line.material.dispose();
      }
    };
  }, [connectionLines]);

  return (
    <>
      <ambientLight intensity={0.5} />
      <pointLight position={[10, 10, 10]} intensity={1} />
      <directionalLight position={[0, 10, 5]} intensity={0.5} />
      {/* Input Layer */}
      <mesh position={[0, -5, 0]}>
        <boxGeometry args={[5, 0.2, 2]} />
        <meshStandardMaterial color="#10b981" />
      </mesh>
      {/* Transformer Layers */}
      {Array.from({ length: numLayers }, (_, i) => {
        const isCritical = criticalLayers?.includes(i) ?? false;
        // NOTE(review): activations are looked up by array position, not by
        // their `layer_index` field; confirm the service sends a dense,
        // ordered list.
        const activation = decisionPath?.layer_activations?.[i];
        return (
          <Layer
            key={i}
            position={[0, i * layerSpacing, 0]}
            layerIndex={i}
            isCritical={isCritical}
            isActive={!!activation}
            activation={activation}
          />
        );
      })}
      {/* Output Layer */}
      <mesh position={[0, numLayers * layerSpacing + 5, 0]}>
        <boxGeometry args={[5, 0.2, 2]} />
        <meshStandardMaterial color="#f59e0b" />
      </mesh>
      {/* Connection lines - simplified for now */}
      {connectionLines.map((line, idx) => (
        <primitive key={`line-${idx}`} object={line} />
      ))}
      <gridHelper args={[100, 100, 0x444444, 0x222222]} />
    </>
  );
}
/** Fields this component reads from JSON messages sent by the analysis service. */
interface ServiceMessage {
  type?: string;
  data?: DecisionPath;
  progress?: number;
  message?: string;
}

/**
 * Parses one raw socket payload.
 * Returns `null` for non-string payloads, invalid JSON, or JSON that is not an object.
 */
function parseServiceMessage(raw: unknown): ServiceMessage | null {
  if (typeof raw !== "string") return null;
  try {
    const parsed: unknown = JSON.parse(raw);
    return typeof parsed === "object" && parsed !== null ? (parsed as ServiceMessage) : null;
  } catch {
    return null;
  }
}

/**
 * Top-level decision-path panel: prompt input, analyze button, connection
 * badge, model-loading progress, the 3D canvas, a legend and an info panel.
 *
 * After mount it opens a WebSocket to `ws://localhost:8769` and reacts to
 * these message types: `decision_path`, `analysis_complete`,
 * `loading_progress`, `model_ready` and `loading_error`. Clicking the button
 * sends `{ type: 'analyze', prompt }` over the open socket.
 *
 * Robustness notes:
 * - the socket created by the effect is always closed on cleanup, including
 *   while it is still connecting, and its handlers are detached first so no
 *   state is set after unmount;
 * - malformed or non-object payloads are ignored instead of throwing inside
 *   the message handler;
 * - losing the connection clears the "Analyzing..." state so the button is
 *   not left stuck.
 */
export default function DecisionPath3DFixed() {
  const [decisionPath, setDecisionPath] = useState<DecisionPath | null>(null);
  const [isConnected, setIsConnected] = useState(false);
  const [isAnalyzing, setIsAnalyzing] = useState(false);
  const [mounted, setMounted] = useState(false);
  const [modelLoading, setModelLoading] = useState(true);
  const [loadingProgress, setLoadingProgress] = useState(0);
  const [loadingMessage, setLoadingMessage] = useState("Initializing...");
  const wsRef = useRef<WebSocket | null>(null);
  const [prompt, setPrompt] = useState("def quicksort(arr):");

  // Defer anything browser-only until after the first client render.
  useEffect(() => {
    setMounted(true);
  }, []);

  useEffect(() => {
    if (!mounted) return;

    let socket: WebSocket;
    try {
      socket = new WebSocket('ws://localhost:8769');
    } catch {
      console.log('[DecisionPath3D] Connection failed');
      setIsConnected(false);
      return;
    }

    socket.onopen = () => {
      console.log('[DecisionPath3D] Connected to service');
      setIsConnected(true);
      wsRef.current = socket;
      // Don't immediately set as ready - wait for model_ready or loading_progress messages
    };

    socket.onmessage = (event) => {
      console.log('[DecisionPath3D] Raw message received:', event.data);
      const data = parseServiceMessage(event.data);
      if (!data) {
        console.log('[DecisionPath3D] Ignoring malformed message');
        return;
      }
      console.log('[DecisionPath3D] Parsed message type:', data.type);
      console.log('[DecisionPath3D] Message data:', data);

      if (data.type === 'decision_path') {
        console.log('[DecisionPath3D] Setting decision path with critical layers:', data.data?.critical_layers);
        setDecisionPath(data.data ?? null);
      } else if (data.type === 'analysis_complete') {
        console.log('[DecisionPath3D] Analysis complete');
        setIsAnalyzing(false);
      } else if (data.type === 'loading_progress') {
        // Only accept well-typed values so the UI never shows "undefined%".
        if (typeof data.progress === 'number') setLoadingProgress(data.progress);
        if (typeof data.message === 'string') setLoadingMessage(data.message);
        if (data.progress === 100) {
          setModelLoading(false);
        }
      } else if (data.type === 'model_ready') {
        setModelLoading(false);
        setLoadingProgress(100);
        setLoadingMessage("Model ready!");
      } else if (data.type === 'loading_error') {
        setModelLoading(false);
        setLoadingMessage(`Error: ${data.message ?? 'unknown error'}`);
      }
    };

    socket.onerror = () => {
      console.log('[DecisionPath3D] Service not available');
      setIsConnected(false);
    };

    socket.onclose = () => {
      console.log('[DecisionPath3D] Disconnected');
      setIsConnected(false);
      // No analysis_complete can arrive any more; release the button.
      setIsAnalyzing(false);
      if (wsRef.current === socket) {
        wsRef.current = null;
      }
    };

    return () => {
      // Detach handlers first so closing does not trigger state updates
      // on an unmounted component.
      socket.onopen = null;
      socket.onmessage = null;
      socket.onerror = null;
      socket.onclose = null;
      // Close the socket this effect created, even if it never opened.
      socket.close();
      if (wsRef.current === socket) {
        wsRef.current = null;
      }
    };
  }, [mounted]);

  /** Sends the current prompt to the service if the socket is open. */
  const startAnalysis = () => {
    console.log('[DecisionPath3D] Start analysis clicked');
    console.log('[DecisionPath3D] WebSocket state:', wsRef.current?.readyState);
    console.log('[DecisionPath3D] Is connected:', isConnected);
    const socket = wsRef.current;
    if (socket && socket.readyState === WebSocket.OPEN) {
      console.log('[DecisionPath3D] Sending analyze request with prompt:', prompt);
      setIsAnalyzing(true);
      socket.send(JSON.stringify({
        type: 'analyze',
        prompt: prompt
      }));
    } else {
      console.log('[DecisionPath3D] WebSocket not ready, state:', wsRef.current?.readyState);
    }
  };

  if (!mounted) {
    return (
      <div className="bg-gray-900 rounded-xl p-6 h-[900px]">
        <div className="flex items-center justify-center h-full">
          <div className="text-gray-400">Loading 3D visualization...</div>
        </div>
      </div>
    );
  }

  return (
    <div className="bg-gray-900 rounded-xl p-6 h-[900px]">
      {/* Header */}
      <div className="flex items-center justify-between mb-4">
        <div>
          <h2 className="text-2xl font-bold flex items-center gap-2">
            <GitBranch className="w-6 h-6 text-yellow-400" />
            Decision Path Visualization
          </h2>
          <p className="text-gray-400 mt-1">
            See exactly how the model makes its decisions - the Glass Box view
          </p>
        </div>
        <div className="flex items-center gap-4">
          <div className={`flex items-center gap-2 px-3 py-1 rounded-full ${
            isConnected ? 'bg-green-900/30 text-green-400' : 'bg-yellow-900/30 text-yellow-400'
          }`}>
            <Activity className={`w-4 h-4 ${isConnected ? 'animate-pulse' : ''}`} />
            {isConnected ? 'Connected' : 'Disconnected'}
          </div>
        </div>
      </div>
      {/* Controls */}
      <div className="bg-gray-800 rounded-lg p-4 mb-4">
        <div className="flex items-center gap-4">
          <input
            type="text"
            value={prompt}
            onChange={(e) => setPrompt(e.target.value)}
            className="flex-1 px-3 py-2 bg-gray-900 text-white rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none font-mono text-sm"
            placeholder="Enter code to analyze..."
          />
          <button
            onClick={startAnalysis}
            disabled={!isConnected || isAnalyzing}
            className="px-6 py-2 bg-yellow-600 text-white rounded-lg hover:bg-yellow-700 transition-colors disabled:opacity-50 flex items-center gap-2"
          >
            {isAnalyzing ? (
              <>
                <Activity className="w-4 h-4 animate-spin" />
                Analyzing...
              </>
            ) : (
              <>
                <Sparkles className="w-4 h-4" />
                Analyze Decision Path
              </>
            )}
          </button>
        </div>
      </div>
      {/* 3D Canvas */}
      <div className="h-[700px] bg-black rounded-lg relative">
        {modelLoading ? (
          <div className="flex flex-col items-center justify-center h-full">
            <div className="text-white mb-4">
              <Brain className="w-16 h-16 animate-pulse" />
            </div>
            <div className="text-xl text-white mb-2">Loading Model</div>
            <div className="text-sm text-gray-400 mb-4">{loadingMessage}</div>
            <div className="w-64 h-2 bg-gray-700 rounded-full overflow-hidden">
              <div
                className="h-full bg-gradient-to-r from-blue-500 to-purple-500 transition-all duration-500"
                style={{ width: `${loadingProgress}%` }}
              />
            </div>
            <div className="text-xs text-gray-500 mt-2">{loadingProgress}%</div>
            <div className="text-xs text-gray-500 mt-4">356M parameters • 20 layers • 16 attention heads</div>
          </div>
        ) : (
          <Canvas camera={{ position: [-40, 50, 40], fov: 50 }}>
            <DecisionPathScene decisionPath={decisionPath} />
            <OrbitControls
              enablePan={true}
              enableZoom={true}
              enableRotate={true}
              target={[0, 35, 0]}
            />
          </Canvas>
        )}
        {/* Legend */}
        <div className="absolute top-4 right-4 bg-gray-800/90 backdrop-blur rounded-lg p-3 text-xs">
          <div className="font-semibold text-white mb-2">Decision Path</div>
          <div className="space-y-1">
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-red-500 rounded"></div>
              <span className="text-gray-300">Critical Layers</span>
            </div>
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-teal-500 rounded"></div>
              <span className="text-gray-300">Active Layers</span>
            </div>
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-yellow-500 rounded"></div>
              <span className="text-gray-300">Top Attention Heads</span>
            </div>
            <div className="flex items-center gap-2">
              <Zap className="w-3 h-3 text-yellow-400" />
              <span className="text-gray-300">Decision Flow</span>
            </div>
          </div>
        </div>
        {/* Info Panel */}
        {decisionPath && (
          <div className="absolute bottom-4 left-4 bg-gray-800/90 backdrop-blur rounded-lg p-3 text-xs max-w-xs">
            <div className="font-semibold text-white mb-2">Current Decision</div>
            <div className="space-y-1 text-gray-300">
              <div>Token: <span className="text-yellow-400">{decisionPath.token}</span></div>
              <div>Confidence: <span className="text-green-400">{(decisionPath.confidence_score * 100).toFixed(0)}%</span></div>
              {/* Guarded like the scene's accesses: the payload is unvalidated JSON. */}
              <div>Critical Layers: <span className="text-red-400">{decisionPath.critical_layers?.join(", ")}</span></div>
            </div>
          </div>
        )}
      </div>
    </div>
  );
}