api / frontend /DecisionPath3D.tsx
gary-boon
Deploy Visualisable.ai backend with API protection
c6c8587
raw
history blame
18.2 kB
/**
* Decision Path 3D Visualization
*
* Shows the exact path through the neural network when generating tokens,
* highlighting critical layers, attention patterns, and decision factors.
* This is the core "Glass Box" visualization for the PhD thesis.
*
* @component
*/
"use client";
import { useRef, useState, useEffect, useMemo } from "react";
import { Canvas, useFrame, useThree } from "@react-three/fiber";
import {
OrbitControls,
Text,
Box,
Sphere,
Line,
Billboard
} from "@react-three/drei";
import * as THREE from "three";
import {
Brain,
Zap,
Eye,
GitBranch,
Activity,
Sparkles,
AlertCircle,
TrendingUp,
Layers
} from "lucide-react";
/**
 * Per-layer activation summary received from the decision-path service.
 *
 * NOTE(review): the meanings below are inferred from how this file renders
 * the fields; confirm against the backend schema.
 */
interface LayerActivation {
// Index of the layer this entry describes (presumably 0-based — TODO confirm).
// The scene looks activations up by array position, not by this field.
layer_index: number;
// Raw attention weights; not rendered by this file.
attention_weights: number[][];
// Norm of the layer's hidden state; not rendered by this file.
hidden_state_norm: number;
// Feed-forward activation magnitude; not rendered by this file.
ffn_activation: number;
// Head indices drawn as small yellow markers next to the layer slab.
top_attention_heads: number[];
// Shown as a percentage (value * 100) and used to size/fade the layer glow,
// so it is presumably in [0, 1] — TODO confirm.
confidence: number;
}
/**
 * One generated token plus the data used to explain how it was chosen,
 * as delivered in a `decision_path` WebSocket message.
 *
 * NOTE(review): field meanings are inferred from how this file renders
 * them; confirm against the backend schema.
 */
interface DecisionPath {
// The chosen token, shown on the output layer and in the info panel.
token: string;
// Vocabulary id of the chosen token.
token_id: number;
// Probability of the chosen token, displayed as a percentage (value * 100).
probability: number;
// Activation summaries; the scene treats entry i as the data for layer i.
layer_activations: LayerActivation[];
// Connections drawn as vertical lines between layers. `to_layer` is either a
// layer index or the literal string "output" (the output layer); other strings
// are not handled specially — TODO confirm the backend never sends them.
attention_flow: Array<{
from_layer: number;
to_layer: number | string;
// Values above 0.7 are drawn red; also scales the line width.
strength: number;
top_heads: number[];
}>;
// Competing tokens; the first three are listed, the first highlighted.
alternatives: Array<{
token: string;
token_id: number;
probability: number;
}>;
// Every entry is drawn as a bar with a percentage label, so values are
// presumably in [0, 1] — TODO confirm.
decision_factors: {
attention_focus: number;
semantic_alignment: number;
syntactic_correctness: number;
context_relevance: number;
confidence: number;
};
// Layer indices highlighted in red and visited by the flow particle.
critical_layers: number[];
// Overall confidence, displayed as a percentage in the info panel.
confidence_score: number;
// Time the decision was produced (units not shown here — TODO confirm).
timestamp: number;
}
// Enhanced Transformer Layer with activation visualization
/**
 * A single transformer layer in the decision-path stack.
 *
 * Draws the layer slab, a translucent glow (only when `activation` is given)
 * whose size and opacity follow `activation.confidence`, a "Layer N" label,
 * a confidence readout, and one small marker per entry of
 * `activation.top_attention_heads`, centred along the slab's front edge.
 *
 * @param position - World-space centre of the layer group.
 * @param layerIndex - Number shown in the "Layer N" label.
 * @param activation - Optional per-layer data; without it no glow, confidence
 *   text or head markers are drawn and the slab has no emissive glow.
 * @param isCritical - Draws the layer red and makes it pulse.
 * @param isActive - Draws a non-critical layer teal instead of grey.
 */
function EnhancedTransformerLayer({
  position,
  layerIndex,
  activation,
  isCritical,
  isActive
}: {
  position: [number, number, number];
  layerIndex: number;
  activation?: LayerActivation;
  isCritical?: boolean;
  isActive?: boolean;
}) {
  const meshRef = useRef<THREE.Mesh>(null);
  const glowRef = useRef<THREE.Mesh>(null);

  useFrame((state) => {
    if (meshRef.current) {
      // Critical layers pulse; every other layer is pinned back to unit scale
      // so a layer that stops being critical doesn't stay frozen mid-pulse.
      const scale = isCritical
        ? 1 + Math.sin(state.clock.elapsedTime * 3) * 0.1
        : 1;
      meshRef.current.scale.setScalar(scale);
    }
    if (glowRef.current && activation) {
      // Glow radius grows with activation confidence.
      glowRef.current.scale.setScalar(1 + activation.confidence * 0.3);
    }
  });

  const baseColor = isCritical ? "#ff6b6b" : isActive ? "#4ecdc4" : "#2d3748";
  const accentColor = isCritical ? "#ff6b6b" : "#4ecdc4";
  const emissiveIntensity = activation ? activation.confidence : 0;

  const heads = activation?.top_attention_heads ?? [];
  // Centre the head markers on the slab whatever their count
  // (a fixed `i - 1` offset only centres exactly three markers).
  const headCentre = (heads.length - 1) / 2;

  return (
    <group position={position}>
      {/* Glow effect for active layers */}
      {activation && (
        <Sphere ref={glowRef} args={[2.5, 16, 16]}>
          <meshBasicMaterial
            color={accentColor}
            transparent
            opacity={activation.confidence * 0.3}
          />
        </Sphere>
      )}

      {/* Main layer box */}
      <Box ref={meshRef} args={[4, 0.3, 3]}>
        <meshStandardMaterial
          color={baseColor}
          emissive={baseColor}
          emissiveIntensity={emissiveIntensity}
          metalness={0.8}
          roughness={0.2}
        />
      </Box>

      {/* Layer label */}
      <Text
        position={[0, 0.3, 0]}
        fontSize={0.15}
        color="white"
        anchorX="center"
      >
        Layer {layerIndex}
      </Text>

      {/* Confidence indicator */}
      {activation && (
        <Text
          position={[0, -0.3, 0]}
          fontSize={0.08}
          color={accentColor}
          anchorX="center"
        >
          Confidence: {(activation.confidence * 100).toFixed(0)}%
        </Text>
      )}

      {/* Attention head activation visualization */}
      {heads.map((headIdx, i) => (
        <Box
          key={headIdx}
          position={[(i - headCentre) * 0.4, 0, 1.8]}
          args={[0.3, 0.15, 0.3]}
        >
          <meshStandardMaterial
            color="#ffd93d"
            emissive="#ffd93d"
            emissiveIntensity={0.5}
          />
        </Box>
      ))}
    </group>
  );
}
// Animated decision flow particle
/**
 * A glowing sphere that travels once along `path`, then calls `onComplete`.
 *
 * The particle advances at a fixed rate (the whole path takes about
 * 1 / 0.3 ≈ 3.3 s regardless of how many points it has), linearly
 * interpolating between consecutive points. It restarts from the beginning
 * whenever a new `path` array is supplied.
 *
 * @param path - Waypoints to visit in order. With fewer than two points the
 *   particle does not move, but `onComplete` still fires when the timer ends.
 * @param onComplete - Called once when the particle reaches the end.
 */
function DecisionParticle({
  path,
  onComplete
}: {
  path: THREE.Vector3[];
  onComplete?: () => void;
}) {
  const particleRef = useRef<THREE.Mesh>(null);
  // Progress along the whole path in [0, 1]. Kept in a ref rather than React
  // state so advancing it every frame does not re-render the component.
  const progressRef = useRef(0);

  // Start (or restart) at the first waypoint whenever the path changes.
  useEffect(() => {
    progressRef.current = 0;
    if (particleRef.current && path.length > 0) {
      particleRef.current.position.copy(path[0]);
    }
  }, [path]);

  useFrame((_state, delta) => {
    const mesh = particleRef.current;
    if (!mesh || progressRef.current >= 1) return;

    const progress = Math.min(progressRef.current + delta * 0.3, 1);
    progressRef.current = progress;

    const segmentCount = path.length - 1;
    if (segmentCount > 0) {
      if (progress >= 1) {
        // Land exactly on the final waypoint instead of just short of it.
        mesh.position.copy(path[segmentCount]);
      } else {
        const scaled = progress * segmentCount;
        const segment = Math.floor(scaled);
        mesh.position.lerpVectors(path[segment], path[segment + 1], scaled - segment);
      }
    }

    if (progress >= 1) {
      onComplete?.();
    }
  });

  return (
    <Sphere ref={particleRef} args={[0.2, 16, 16]}>
      <meshStandardMaterial
        color="#ffd93d"
        emissive="#ffd93d"
        emissiveIntensity={1}
      />
    </Sphere>
  );
}
// Attention flow visualization
/**
 * Draws each attention-flow connection as a vertical line between layers.
 *
 * Strong connections (strength > 0.7) are red, others teal; line width is
 * proportional to strength. A `to_layer` of "output" ends at the output
 * layer. Connections whose endpoints cannot be turned into a finite height
 * are skipped rather than producing NaN geometry.
 *
 * @param flow - Connections from the current decision path.
 * @param layerSpacing - Vertical distance between consecutive layers.
 * @param numLayers - Number of rendered transformer layers; positions the
 *   "output" endpoint. Defaults to 20 to match the scene.
 */
function AttentionFlowVisualization({
  flow,
  layerSpacing,
  numLayers = 20
}: {
  flow: DecisionPath['attention_flow'];
  layerSpacing: number;
  numLayers?: number;
}) {
  // Must match the y of the output layer group in the scene.
  const outputY = numLayers * layerSpacing + 5;

  /** Maps a flow endpoint to a world-space y (NaN if unrecognised). */
  const endpointY = (layer: number | string): number => {
    if (layer === "output") return outputY;
    // Number() keeps numeric strings working exactly as JS coercion did.
    return Number(layer) * layerSpacing;
  };

  return (
    <>
      {flow.map((connection, idx) => {
        const startY = connection.from_layer * layerSpacing;
        const endY = endpointY(connection.to_layer);
        if (!Number.isFinite(startY) || !Number.isFinite(endY)) {
          return null;
        }

        const points = [
          new THREE.Vector3(0, startY, 0),
          new THREE.Vector3(0, endY, 0)
        ];

        return (
          <Line
            key={idx}
            points={points}
            color={connection.strength > 0.7 ? "#ff6b6b" : "#4ecdc4"}
            lineWidth={connection.strength * 5}
            transparent
            opacity={0.6}
          />
        );
      })}
    </>
  );
}
// Alternative tokens display
/**
 * Floating list of the top three alternative tokens and their probabilities.
 *
 * Rows stack downward below an "Alternatives Considered:" heading; the
 * first (most likely) alternative is highlighted in teal.
 *
 * @param alternatives - Candidate tokens; only the first three are shown.
 * @param position - World-space anchor of the list.
 */
function AlternativesDisplay({
  alternatives,
  position
}: {
  alternatives: DecisionPath['alternatives'];
  position: [number, number, number];
}) {
  const rowSpacing = 0.3;
  const topThree = alternatives.slice(0, 3);

  const rows = topThree.map((alt, rank) => {
    const isLeading = rank === 0;
    const percent = (alt.probability * 100).toFixed(1);
    return (
      <group key={rank} position={[0, -rank * rowSpacing, 0]}>
        <Text
          fontSize={0.1}
          color={isLeading ? "#4ecdc4" : "#6b7280"}
          anchorX="center"
        >
          {alt.token}: {percent}%
        </Text>
      </group>
    );
  });

  return (
    <group position={position}>
      <Text fontSize={0.12} color="#9ca3af" position={[0, 0.5, 0]}>
        Alternatives Considered:
      </Text>
      {rows}
    </group>
  );
}
// Decision factors visualization
/**
 * Horizontal bar chart of the decision factors.
 *
 * Each factor gets one row: a bar whose length is proportional to the value
 * (a value of 1 fills `maxBarWidth`), coloured teal (> 0.7), yellow (> 0.4)
 * or red, followed by a "factor name: NN%" label. Bars share a common left
 * edge so their lengths can be compared directly.
 *
 * @param factors - Factor scores, presumably in [0, 1]; bar lengths are
 *   clamped to that range, labels show the raw value.
 * @param position - World-space anchor of the chart.
 */
function DecisionFactorsDisplay({
  factors,
  position
}: {
  factors: DecisionPath['decision_factors'];
  position: [number, number, number];
}) {
  const maxBarWidth = 2;
  // Shared left edge; a full bar spans [-2, 0], leaving room for the label at x = 1.
  const barLeft = -2;

  const colorFor = (value: number) =>
    value > 0.7 ? "#4ecdc4" : value > 0.4 ? "#ffd93d" : "#ff6b6b";

  return (
    <group position={position}>
      <Text fontSize={0.12} color="#ffd93d" position={[0, 0.8, 0]}>
        Decision Factors:
      </Text>

      {Object.entries(factors).map(([key, value], idx) => {
        // Clamp so out-of-range scores cannot create negative or oversized boxes.
        const barWidth = Math.min(Math.max(value, 0), 1) * maxBarWidth;
        // Emissive matches the base colour so low scores glow red, not yellow.
        const color = colorFor(value);

        return (
          <group key={key} position={[0, -idx * 0.25, 0]}>
            <Box
              position={[barLeft + barWidth / 2, 0, 0]}
              args={[barWidth, 0.15, 0.1]}
            >
              <meshStandardMaterial
                color={color}
                emissive={color}
                emissiveIntensity={0.3}
              />
            </Box>
            <Text
              position={[1, 0, 0]}
              fontSize={0.08}
              color="#9ca3af"
              anchorX="left"
            >
              {key.replace(/_/g, ' ')}: {(value * 100).toFixed(0)}%
            </Text>
          </group>
        );
      })}
    </group>
  );
}
// Main 3D scene with decision path
/**
 * The full 3D scene: input embeddings, a stack of transformer layers,
 * the output layer, attention-flow lines, an animated decision particle,
 * and floating alternative-token / decision-factor panels.
 *
 * When `decisionPath` is null only the empty network skeleton is drawn.
 * Each new decision path launches the particle from the input layer through
 * its critical layers (bottom to top) to the output layer.
 *
 * @param decisionPath - Latest decision received from the service, or null.
 */
function DecisionPathScene({ decisionPath }: { decisionPath: DecisionPath | null }) {
  const numLayers = 20;
  const layerSpacing = 3.5;
  const [showParticle, setShowParticle] = useState(false);

  // Lookup set so each layer can check criticality in O(1).
  const criticalLayers = useMemo(
    () => new Set(decisionPath?.critical_layers ?? []),
    [decisionPath]
  );

  // Waypoints for the particle: input -> critical layers -> output.
  const particlePath = useMemo(() => {
    if (!decisionPath) return [];
    // Visit each rendered critical layer once, bottom to top, so the particle
    // never doubles back or overshoots the output layer.
    const stops = Array.from(new Set(decisionPath.critical_layers))
      .filter((layerIdx) => Number.isInteger(layerIdx) && layerIdx >= 0 && layerIdx < numLayers)
      .sort((a, b) => a - b);
    return [
      new THREE.Vector3(0, -5, 0), // Start at input
      ...stops.map((layerIdx) => new THREE.Vector3(0, layerIdx * layerSpacing, 0)),
      new THREE.Vector3(0, numLayers * layerSpacing + 5, 0) // End at output
    ];
  }, [decisionPath]);

  useEffect(() => {
    if (decisionPath) {
      setShowParticle(true);
    }
  }, [decisionPath]);

  return (
    <>
      {/* Lighting */}
      <ambientLight intensity={0.3} />
      <pointLight position={[10, 10, 10]} intensity={1} />
      <spotLight position={[0, 50, 0]} angle={0.3} penumbra={1} intensity={1} />

      {/* Input Layer */}
      <group position={[0, -5, 0]}>
        <Box args={[5, 0.2, 2]}>
          <meshStandardMaterial color="#10b981" emissive="#10b981" emissiveIntensity={0.3} />
        </Box>
        <Text position={[0, 0.3, 0]} fontSize={0.15} color="white">
          Input Embeddings
        </Text>
      </group>

      {/* Transformer Layers */}
      {Array.from({ length: numLayers }).map((_, i) => {
        const activation = decisionPath?.layer_activations[i];
        return (
          <EnhancedTransformerLayer
            key={i}
            position={[0, i * layerSpacing, 0]}
            layerIndex={i}
            activation={activation}
            isCritical={criticalLayers.has(i)}
            isActive={!!activation}
          />
        );
      })}

      {/* Output Layer */}
      <group position={[0, numLayers * layerSpacing + 5, 0]}>
        <Box args={[5, 0.2, 2]}>
          <meshStandardMaterial color="#f59e0b" emissive="#f59e0b" emissiveIntensity={0.3} />
        </Box>
        <Text position={[0, 0.3, 0]} fontSize={0.15} color="white">
          Output: {decisionPath?.token || "..."}
        </Text>
        <Text position={[0, -0.3, 0]} fontSize={0.1} color="#9ca3af">
          Probability: {decisionPath ? (decisionPath.probability * 100).toFixed(1) : "0"}%
        </Text>
      </group>

      {/* Attention Flow Visualization */}
      {decisionPath && (
        <AttentionFlowVisualization
          flow={decisionPath.attention_flow}
          layerSpacing={layerSpacing}
        />
      )}

      {/* Decision Particle Animation — keyed by timestamp so a decision that
          arrives mid-flight restarts the particle from the input layer. */}
      {showParticle && particlePath.length > 0 && (
        <DecisionParticle
          key={decisionPath?.timestamp}
          path={particlePath}
          onComplete={() => setShowParticle(false)}
        />
      )}

      {/* Alternatives Display */}
      {decisionPath && (
        <AlternativesDisplay
          alternatives={decisionPath.alternatives}
          position={[8, numLayers * layerSpacing / 2, 0]}
        />
      )}

      {/* Decision Factors Display */}
      {decisionPath && (
        <DecisionFactorsDisplay
          factors={decisionPath.decision_factors}
          position={[-8, numLayers * layerSpacing / 2, 0]}
        />
      )}

      {/* Grid */}
      <gridHelper args={[150, 150, 0x444444, 0x222222]} />
    </>
  );
}
/** Default endpoint of the decision-path WebSocket service. */
const DEFAULT_SERVICE_URL = 'ws://localhost:8769';

/**
 * Owns the WebSocket connection to the decision-path service.
 *
 * While `enabled` is true, opens one socket to `url` and:
 * - tracks whether it is connected;
 * - stores the latest `decision_path` payload;
 * - clears the analyzing flag on `analysis_complete` or when the socket closes.
 *
 * The socket is closed on unmount or when `url`/`enabled` change, even if it
 * never finished opening. Events that arrive from a socket after its cleanup
 * are ignored, so they cannot clobber state belonging to a newer socket.
 *
 * @returns The latest decision, connection/analysis flags, and
 *   `startAnalysis(prompt)`, which sends an `analyze` request. It is a no-op
 *   unless the socket is open.
 */
function useDecisionPathService(url: string, enabled: boolean) {
  const [decisionPath, setDecisionPath] = useState<DecisionPath | null>(null);
  const [isConnected, setIsConnected] = useState(false);
  const [isAnalyzing, setIsAnalyzing] = useState(false);
  const wsRef = useRef<WebSocket | null>(null);

  useEffect(() => {
    if (!enabled) return;

    let ws: WebSocket;
    try {
      ws = new WebSocket(url);
    } catch {
      console.log('[DecisionPath3D] Connection failed');
      setIsConnected(false);
      return;
    }
    // Set by cleanup; late events from this socket are then ignored.
    let disposed = false;

    ws.onopen = () => {
      if (disposed) return;
      console.log('[DecisionPath3D] Connected to service');
      setIsConnected(true);
      wsRef.current = ws;
    };

    ws.onmessage = (event) => {
      if (disposed) return;
      let message: unknown;
      try {
        message = JSON.parse(event.data);
      } catch {
        console.log('[DecisionPath3D] Ignoring malformed message');
        return;
      }
      if (typeof message !== 'object' || message === null) return;

      const { type, data } = message as { type?: unknown; data?: unknown };
      console.log('[DecisionPath3D] Received:', type);
      if (type === 'decision_path') {
        // NOTE(review): payload shape is trusted, not validated.
        setDecisionPath(data as DecisionPath);
      } else if (type === 'analysis_complete') {
        setIsAnalyzing(false);
      }
    };

    ws.onerror = () => {
      if (disposed) return;
      console.log('[DecisionPath3D] Service not available');
      setIsConnected(false);
    };

    ws.onclose = () => {
      if (disposed) return;
      console.log('[DecisionPath3D] Disconnected');
      setIsConnected(false);
      // No "analysis_complete" can arrive any more; don't leave the UI spinning.
      setIsAnalyzing(false);
      if (wsRef.current === ws) wsRef.current = null;
    };

    return () => {
      disposed = true;
      if (wsRef.current === ws) wsRef.current = null;
      setIsConnected(false);
      setIsAnalyzing(false);
      // Closing a still-CONNECTING socket is allowed and aborts the handshake.
      ws.close();
    };
  }, [url, enabled]);

  const startAnalysis = (prompt: string) => {
    const ws = wsRef.current;
    if (ws && ws.readyState === WebSocket.OPEN) {
      setIsAnalyzing(true);
      ws.send(JSON.stringify({
        type: 'analyze',
        prompt: prompt
      }));
    }
  };

  return { decisionPath, isConnected, isAnalyzing, startAnalysis };
}

/**
 * "Glass Box" panel: a prompt box that asks the decision-path service to
 * analyze code, and a 3D view of the resulting decision path.
 *
 * The 3D canvas and the WebSocket connection are only created on the client
 * (after the first effect runs).
 *
 * @param serviceUrl - WebSocket endpoint of the decision-path service;
 *   defaults to `ws://localhost:8769`.
 */
export default function DecisionPath3D({
  serviceUrl = DEFAULT_SERVICE_URL
}: {
  serviceUrl?: string;
} = {}) {
  const [isClient, setIsClient] = useState(false);
  const [prompt, setPrompt] = useState("def quicksort(arr):");

  // Ensure client-side only rendering
  useEffect(() => {
    setIsClient(true);
  }, []);

  const { decisionPath, isConnected, isAnalyzing, startAnalysis } =
    useDecisionPathService(serviceUrl, isClient);

  return (
    <div className="bg-gray-900 rounded-xl p-6 h-[900px]">
      {/* Header */}
      <div className="flex items-center justify-between mb-4">
        <div>
          <h2 className="text-2xl font-bold flex items-center gap-2">
            <GitBranch className="w-6 h-6 text-yellow-400" />
            Decision Path Visualization
          </h2>
          <p className="text-gray-400 mt-1">
            See exactly how the model makes its decisions - the Glass Box view
          </p>
        </div>
        <div className="flex items-center gap-4">
          <div className={`flex items-center gap-2 px-3 py-1 rounded-full ${
            isConnected ? 'bg-green-900/30 text-green-400' : 'bg-yellow-900/30 text-yellow-400'
          }`}>
            <Activity className={`w-4 h-4 ${isConnected ? 'animate-pulse' : ''}`} />
            {isConnected ? 'Connected' : 'Disconnected'}
          </div>
        </div>
      </div>

      {/* Controls */}
      <div className="bg-gray-800 rounded-lg p-4 mb-4">
        <div className="flex items-center gap-4">
          <input
            type="text"
            value={prompt}
            onChange={(e) => setPrompt(e.target.value)}
            className="flex-1 px-3 py-2 bg-gray-900 text-white rounded-lg border border-gray-700 focus:border-blue-500 focus:outline-none font-mono text-sm"
            placeholder="Enter code to analyze..."
          />
          <button
            onClick={() => startAnalysis(prompt)}
            disabled={!isConnected || isAnalyzing}
            className="px-6 py-2 bg-yellow-600 text-white rounded-lg hover:bg-yellow-700 transition-colors disabled:opacity-50 flex items-center gap-2"
          >
            {isAnalyzing ? (
              <>
                <Activity className="w-4 h-4 animate-spin" />
                Analyzing...
              </>
            ) : (
              <>
                <Sparkles className="w-4 h-4" />
                Analyze Decision Path
              </>
            )}
          </button>
        </div>
      </div>

      {/* 3D Canvas */}
      <div className="h-[700px] bg-black rounded-lg relative">
        {isClient ? (
          <Canvas camera={{ position: [30, 40, 50], fov: 50 }}>
            <DecisionPathScene decisionPath={decisionPath} />
            <OrbitControls
              enablePan={true}
              enableZoom={true}
              enableRotate={true}
              target={[0, 35, 0]}
            />
          </Canvas>
        ) : (
          <div className="flex items-center justify-center h-full">
            <div className="text-gray-400">Loading 3D visualization...</div>
          </div>
        )}

        {/* Legend */}
        <div className="absolute top-4 right-4 bg-gray-800/90 backdrop-blur rounded-lg p-3 text-xs">
          <div className="font-semibold text-white mb-2">Decision Path</div>
          <div className="space-y-1">
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-red-500 rounded"></div>
              <span className="text-gray-300">Critical Layers</span>
            </div>
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-teal-500 rounded"></div>
              <span className="text-gray-300">Active Layers</span>
            </div>
            <div className="flex items-center gap-2">
              <div className="w-3 h-3 bg-yellow-500 rounded"></div>
              <span className="text-gray-300">Top Attention Heads</span>
            </div>
            <div className="flex items-center gap-2">
              <Zap className="w-3 h-3 text-yellow-400" />
              <span className="text-gray-300">Decision Flow</span>
            </div>
          </div>
        </div>

        {/* Info Panel */}
        {decisionPath && (
          <div className="absolute bottom-4 left-4 bg-gray-800/90 backdrop-blur rounded-lg p-3 text-xs max-w-xs">
            <div className="font-semibold text-white mb-2">Current Decision</div>
            <div className="space-y-1 text-gray-300">
              <div>Token: <span className="text-yellow-400">{decisionPath.token}</span></div>
              <div>Confidence: <span className="text-green-400">{(decisionPath.confidence_score * 100).toFixed(0)}%</span></div>
              <div>Critical Layers: <span className="text-red-400">{decisionPath.critical_layers.join(", ")}</span></div>
            </div>
          </div>
        )}
      </div>
    </div>
  );
}