api / frontend /ModelArchitecture3D.tsx
gary-boon
Deploy Visualisable.ai backend with API protection
c6c8587
raw
history blame
25.3 kB
/**
* 3D Model Architecture Visualization
*
* Interactive 3D visualization of the transformer model architecture,
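* showing layers, attention heads, and data flow. Model dimensions are
* fetched from the backend's /model/info endpoint (with built-in defaults);
* the output probability bars are simulated, not real model output.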
* Inspired by neural network architecture diagrams.
*
* @component
*/
"use client";
import { useRef, useState, useEffect, Suspense } from "react";
import { Canvas, useFrame, useThree } from "@react-three/fiber";
import { getApiUrl } from "@/lib/config";
import {
OrbitControls,
Text,
Box,
Plane,
Line,
Billboard,
PerspectiveCamera,
Environment,
Float
} from "@react-three/drei";
import * as THREE from "three";
import {
Brain,
Layers,
Activity,
Zap,
Eye,
GitBranch,
Maximize2,
Move3D,
HelpCircle,
X,
Info
} from "lucide-react";
// Layer component representing a transformer layer
/**
 * A single transformer block in the 3D stack.
 *
 * Renders a clickable slab labelled with the layer index, a grid of
 * attention-head tiles on the +z side and a feed-forward (FFN) bar on the
 * -z side. While `isActive` is true the slab is highlighted and pulses in
 * scale; once it becomes inactive its scale is reset to 1.
 *
 * @param position - World-space position of the whole block.
 * @param layerIndex - Zero-based layer index shown in the label.
 * @param attentionValues - Accepted for API compatibility; not rendered yet.
 * @param onClick - Called when the main slab is clicked.
 * @param isActive - Highlights and pulses the slab when true.
 * @param numHeads - Number of attention-head tiles to draw (default 16).
 */
function TransformerLayer({
  position,
  layerIndex,
  onClick,
  isActive,
  numHeads = 16
}: {
  position: [number, number, number];
  layerIndex: number;
  attentionValues?: number[][];
  onClick?: () => void;
  isActive?: boolean;
  numHeads?: number;
}) {
  const meshRef = useRef<THREE.Mesh>(null);
  const [hovered, setHovered] = useState(false);

  // Lay the heads out in a near-square grid centred on the group origin
  // (16 heads -> 4 x 4 tiles spaced 0.25 apart).
  const headCount = Math.max(0, Math.floor(numHeads));
  const headCols = Math.max(1, Math.ceil(Math.sqrt(headCount)));
  const headRows = Math.max(1, Math.ceil(headCount / headCols));
  const headSpacing = 0.25;

  useFrame((state) => {
    const mesh = meshRef.current;
    if (!mesh) return;
    if (isActive) {
      // Layers never move; only the selected one pulses by +/-5%.
      const scale = 1 + Math.sin(state.clock.elapsedTime * 3) * 0.05;
      mesh.scale.setScalar(scale);
    } else if (mesh.scale.x !== 1) {
      // Without this reset a deselected layer stays frozen at whatever
      // scale the pulse had reached on its last active frame.
      mesh.scale.setScalar(1);
    }
  });

  return (
    <group position={position}>
      {/* Main layer box */}
      <Box
        ref={meshRef}
        args={[4, 0.3, 3]}
        onClick={onClick}
        onPointerOver={() => setHovered(true)}
        onPointerOut={() => setHovered(false)}
      >
        <meshStandardMaterial
          color={isActive ? "#3b82f6" : hovered ? "#4b5563" : "#1f2937"}
          emissive={isActive ? "#3b82f6" : "#000000"}
          emissiveIntensity={isActive ? 0.2 : 0}
          metalness={0.8}
          roughness={0.2}
          transparent
          opacity={0.9}
        />
      </Box>

      {/* Layer label */}
      <Text
        position={[0, 0.3, 0]}
        fontSize={0.15}
        color="white"
        anchorX="center"
        anchorY="middle"
      >
        Layer {layerIndex}
      </Text>

      {/* Debug: Show actual position */}
      <Text
        position={[2.5, 0, 0]}
        fontSize={0.08}
        color="#666"
        anchorX="center"
      >
        Y: {position[1].toFixed(1)}
      </Text>

      {/* Attention heads visualization */}
      <group position={[0, 0, 1.8]}>
        {Array.from({ length: headCount }, (_, i) => (
          <Box
            key={i}
            position={[
              ((i % headCols) - (headCols - 1) / 2) * headSpacing,
              0,
              (Math.floor(i / headCols) - (headRows - 1) / 2) * headSpacing
            ]}
            args={[0.2, 0.1, 0.2]}
          >
            <meshStandardMaterial
              color={`hsl(${120 + i * 10}, 70%, 50%)`}
              emissive={`hsl(${120 + i * 10}, 70%, 50%)`}
              emissiveIntensity={0.3}
            />
          </Box>
        ))}
        <Text
          position={[0, 0.2, 0]}
          fontSize={0.1}
          color="#9ca3af"
          anchorX="center"
        >
          {`${headCount} Attention Heads`}
        </Text>
      </group>

      {/* FFN visualization */}
      <group position={[0, 0, -1.8]}>
        <Box args={[3, 0.1, 0.5]}>
          <meshStandardMaterial
            color="#8b5cf6"
            emissive="#8b5cf6"
            emissiveIntensity={0.2}
            metalness={0.6}
            roughness={0.3}
          />
        </Box>
        <Text
          position={[0, 0.15, 0]}
          fontSize={0.1}
          color="#9ca3af"
          anchorX="center"
        >
          FFN (4096d)
        </Text>
      </group>
    </group>
  );
}
// Attention flow visualization
/**
 * Straight connector line between two points, used to show data flowing
 * between stacked components. The line is static, so no per-frame callback
 * is registered.
 *
 * @param startPos - Line start in world coordinates.
 * @param endPos - Line end in world coordinates.
 * @param intensity - Scales the rendered width (`lineWidth = intensity * 2`).
 * @param color - Line color.
 */
function AttentionFlow({
  startPos,
  endPos,
  intensity = 1,
  color = "#3b82f6"
}: {
  startPos: [number, number, number];
  endPos: [number, number, number];
  intensity?: number;
  color?: string;
}) {
  const points = [new THREE.Vector3(...startPos), new THREE.Vector3(...endPos)];

  return (
    <Line
      points={points}
      color={color}
      lineWidth={intensity * 2}
      transparent
      opacity={0.6}
    />
  );
}
// Token embedding visualization
/**
 * Input embedding slab at the bottom of the stack. It spins slowly about its
 * y axis and is labelled with the embedding matrix shape.
 *
 * @param position - World-space position of the slab group.
 * @param vocabSize - Rows of the embedding matrix shown in the label (default 51,200).
 * @param hiddenSize - Columns of the embedding matrix shown in the label (default 1,024).
 */
function TokenEmbedding({
  position,
  vocabSize = 51200,
  hiddenSize = 1024
}: {
  position: [number, number, number];
  vocabSize?: number;
  hiddenSize?: number;
}) {
  const meshRef = useRef<THREE.Mesh>(null);

  useFrame((state) => {
    if (meshRef.current) {
      meshRef.current.rotation.y = state.clock.elapsedTime * 0.5;
    }
  });

  // Fixed locale so the label reads "51,200 × 1,024" regardless of the browser.
  const shapeLabel = `${vocabSize.toLocaleString("en-US")} × ${hiddenSize.toLocaleString("en-US")}`;

  return (
    <group position={position}>
      <Box ref={meshRef} args={[5, 0.2, 2]}>
        <meshStandardMaterial
          color="#10b981"
          emissive="#10b981"
          emissiveIntensity={0.3}
          metalness={0.7}
          roughness={0.2}
        />
      </Box>
      <Text
        position={[0, 0.3, 0]}
        fontSize={0.15}
        color="white"
        anchorX="center"
      >
        Token Embeddings
      </Text>
      <Text
        position={[0, -0.3, 0]}
        fontSize={0.1}
        color="#9ca3af"
        anchorX="center"
      >
        {shapeLabel}
      </Text>
    </group>
  );
}
// Output layer visualization
/** Number of simulated probability bars drawn above the output slab. */
const OUTPUT_BAR_COUNT = 10;

/**
 * Output slab at the top of the stack. It spins slowly about its y axis, is
 * labelled with the vocabulary size and shows a row of probability bars.
 *
 * The bar heights are random placeholders generated once on mount; they are
 * not real model output. Their count is fixed (OUTPUT_BAR_COUNT) and does not
 * depend on the model's head count.
 *
 * @param position - World-space position of the slab group.
 * @param modelInfo - Model dimensions; only `vocabSize` is displayed here.
 */
function OutputLayer({ position, modelInfo }: { position: [number, number, number]; modelInfo: { layers: number; heads: number; vocabSize: number; hiddenSize: number; totalParams: number } }) {
  const meshRef = useRef<THREE.Mesh>(null);
  const [probabilities, setProbabilities] = useState<number[]>([]);

  useEffect(() => {
    // Simulated distribution: each bar gets a height in [0, 1).
    setProbabilities(Array.from({ length: OUTPUT_BAR_COUNT }, () => Math.random()));
  }, []);

  useFrame((state) => {
    if (meshRef.current) {
      meshRef.current.rotation.y = state.clock.elapsedTime * 0.3;
    }
  });

  // Centre the row of bars on x = 0.
  const barCentreOffset = (OUTPUT_BAR_COUNT - 1) / 2;

  return (
    <group position={position}>
      <Box ref={meshRef} args={[5, 0.2, 2]}>
        <meshStandardMaterial
          color="#f59e0b"
          emissive="#f59e0b"
          emissiveIntensity={0.3}
          metalness={0.7}
          roughness={0.2}
        />
      </Box>
      <Text
        position={[0, 0.3, 0]}
        fontSize={0.15}
        color="white"
        anchorX="center"
      >
        Output Probabilities
      </Text>
      <Text
        position={[0, -0.3, 0]}
        fontSize={0.1}
        color="#9ca3af"
        anchorX="center"
      >
        {`${modelInfo.vocabSize.toLocaleString("en-US")} tokens`}
      </Text>

      {/* Probability bars: each box is raised by half its height so it sits on y = 0 */}
      <group position={[0, 0.6, 0]}>
        {probabilities.map((prob, i) => (
          <Box
            key={i}
            position={[(i - barCentreOffset) * 0.5, prob * 0.5, 0]}
            args={[0.3, prob, 0.1]}
          >
            <meshStandardMaterial
              color={`hsl(${prob * 120}, 70%, 50%)`}
              emissive={`hsl(${prob * 120}, 70%, 50%)`}
              emissiveIntensity={0.3}
            />
          </Box>
        ))}
      </group>
    </group>
  );
}
// Main 3D scene
/**
 * The full 3D scene. It contains:
 * - lights;
 * - the input embedding slab;
 * - one TransformerLayer per model layer;
 * - the output slab;
 * - connector lines between the components;
 * - a reference grid;
 * - a details billboard for the selected layer;
 * - two debug position labels.
 *
 * Layers are stacked along +y every `layerSpacing` units starting at y = 0.
 * The input slab sits at y = -5 and the output slab 5 units above the slot
 * after the last layer. Clicking a layer selects it; clicking the selected
 * layer again clears the selection.
 *
 * @param modelInfo - Model dimensions. `layers` sets the stack height, and
 *   `heads` / `hiddenSize` are shown in the details billboard.
 */
function Scene({ modelInfo }: { modelInfo: { layers: number; heads: number; vocabSize: number; hiddenSize: number; totalParams: number } }) {
  const [selectedLayer, setSelectedLayer] = useState<number | null>(null);

  const numLayers = modelInfo.layers;
  const layerSpacing = 3.5; // Large spacing keeps the layers visually separate
  const layerY = (index: number) => index * layerSpacing;
  const lastLayerIndex = numLayers - 1;
  const inputYPosition = -5;
  const outputYPosition = numLayers * layerSpacing + 5;

  // Drop a stale selection if the layer count shrinks (e.g. after model info loads).
  const activeLayer =
    selectedLayer !== null && selectedLayer < numLayers ? selectedLayer : null;

  return (
    <>
      {/* Lighting */}
      <ambientLight intensity={0.3} />
      <pointLight position={[10, 10, 10]} intensity={1} />
      <pointLight position={[-10, -10, -10]} intensity={0.5} />
      <spotLight position={[0, 20, 0]} angle={0.3} penumbra={1} intensity={1} />

      {/* Token Embeddings (Input Layer) */}
      <TokenEmbedding position={[0, inputYPosition, 0]} />

      {/* Transformer Layers (0 .. numLayers - 1) */}
      {Array.from({ length: numLayers }, (_, i) => (
        <TransformerLayer
          key={`layer-${i}`}
          position={[0, layerY(i), 0]}
          layerIndex={i}
          onClick={() => setSelectedLayer((prev) => (prev === i ? null : i))}
          isActive={activeLayer === i}
        />
      ))}

      {/* Output Layer (after the last transformer layer) */}
      <OutputLayer position={[0, outputYPosition, 0]} modelInfo={modelInfo} />

      {/* Attention flows between consecutive layers */}
      {Array.from({ length: Math.max(0, numLayers - 1) }, (_, i) => (
        <AttentionFlow
          key={`flow-${i}`}
          startPos={[0, layerY(i) + 0.3, 0]}
          endPos={[0, layerY(i + 1) - 0.3, 0]}
          intensity={0.5}
        />
      ))}

      {/* Flow from input to first layer */}
      <AttentionFlow
        startPos={[0, inputYPosition + 0.5, 0]}
        endPos={[0, -0.3, 0]}
        intensity={0.5}
        color="#10b981"
      />

      {/* Flow from last layer to output */}
      <AttentionFlow
        startPos={[0, layerY(lastLayerIndex) + 0.3, 0]}
        endPos={[0, outputYPosition - 0.5, 0]}
        intensity={0.5}
        color="#f59e0b"
      />

      {/* Grid for reference */}
      <gridHelper args={[150, 150, 0x444444, 0x222222]} />

      {/* Layer info display */}
      {activeLayer !== null && (
        <Billboard position={[5, layerY(activeLayer), 0]}>
          <Text fontSize={0.2} color="white">
            Layer {activeLayer} Details
          </Text>
          <Text position={[0, -0.3, 0]} fontSize={0.1} color="#9ca3af">
            {`• ${modelInfo.heads} attention heads`}
          </Text>
          <Text position={[0, -0.5, 0]} fontSize={0.1} color="#9ca3af">
            {`• ${modelInfo.hiddenSize} hidden dimensions`}
          </Text>
          <Text position={[0, -0.7, 0]} fontSize={0.1} color="#9ca3af">
            • 4096 FFN dimensions
          </Text>
          <Text position={[0, -0.9, 0]} fontSize={0.1} color="#9ca3af">
            {`• Y Position: ${layerY(activeLayer).toFixed(1)}`}
          </Text>
        </Billboard>
      )}

      {/* Debug: Show output actual position */}
      <Billboard position={[8, outputYPosition, 0]}>
        <Text fontSize={0.15} color="#f59e0b">
          {`Output Y: ${outputYPosition.toFixed(1)}`}
        </Text>
      </Billboard>

      {/* Debug: Show highest layer position */}
      <Billboard position={[-8, layerY(lastLayerIndex), 0]}>
        <Text fontSize={0.15} color="#3b82f6">
          {`Layer ${lastLayerIndex} Y: ${layerY(lastLayerIndex).toFixed(1)}`}
        </Text>
      </Billboard>
    </>
  );
}
/** Model dimensions rendered by the visualization. */
type ModelInfo = {
  layers: number;
  heads: number;
  vocabSize: number;
  hiddenSize: number;
  totalParams: number;
};

/** Values shown until (or if) `/model/info` returns usable data. */
const DEFAULT_MODEL_INFO: ModelInfo = {
  layers: 20,
  heads: 16,
  vocabSize: 51200,
  hiddenSize: 1024,
  totalParams: 356712448
};

/** Formats a count with thousands separators, e.g. 51200 -> "51,200". */
const formatCount = (value: number): string => value.toLocaleString("en-US");

/**
 * Builds a ModelInfo from an untrusted API payload.
 *
 * Each field falls back to the matching field of `fallback` when it is
 * missing, not a number, non-finite or not positive. As a result, a partial
 * or malformed response can never put `undefined`/`NaN` into render code.
 */
function toModelInfo(data: unknown, fallback: ModelInfo): ModelInfo {
  const record: Record<string, unknown> =
    typeof data === "object" && data !== null ? (data as Record<string, unknown>) : {};
  const pick = (key: keyof ModelInfo): number => {
    const value = record[key];
    return typeof value === "number" && Number.isFinite(value) && value > 0
      ? value
      : fallback[key];
  };
  return {
    layers: pick("layers"),
    heads: pick("heads"),
    vocabSize: pick("vocabSize"),
    hiddenSize: pick("hiddenSize"),
    totalParams: pick("totalParams")
  };
}

/**
 * Interactive 3D view of the transformer architecture.
 *
 * Renders:
 * - a header with toggle buttons;
 * - the react-three-fiber canvas containing the layer stack;
 * - overlay instructions and a legend;
 * - a collapsible explanation side panel.
 *
 * Model dimensions are requested once from `${getApiUrl()}/model/info`.
 * DEFAULT_MODEL_INFO is shown until a response arrives, and is kept if the
 * request fails. Every count shown in the overlays and panel is derived from
 * the same `modelInfo`. The FFN size (4,096) is the one exception: the API
 * does not provide it.
 */
export default function ModelArchitecture3D() {
  // NOTE(review): `showLabels` is toggled by the "Labels" button but is not
  // passed to <Scene>, so it currently has no effect on the canvas.
  const [showLabels, setShowLabels] = useState(true);
  const [autoRotate, setAutoRotate] = useState(false); // Start without auto-rotate for better control
  const [showExplanation, setShowExplanation] = useState(false);
  const [modelInfo, setModelInfo] = useState<ModelInfo>(DEFAULT_MODEL_INFO);

  // Fetch real model data
  useEffect(() => {
    // Abort on unmount so a late response cannot update an unmounted component.
    const controller = new AbortController();
    fetch(`${getApiUrl()}/model/info`, { signal: controller.signal })
      .then((res) => {
        if (!res.ok) {
          throw new Error(`HTTP ${res.status}`);
        }
        return res.json();
      })
      .then((data: unknown) => setModelInfo(toModelInfo(data, DEFAULT_MODEL_INFO)))
      .catch((err: unknown) => {
        if (controller.signal.aborted) return;
        console.log('Using default model info', err);
      });
    return () => controller.abort();
  }, []);

  const paramsLabel = `~${Math.round(modelInfo.totalParams / 1e6)}M`;
  const lastLayerIndex = modelInfo.layers - 1;

  // Generate contextual explanation for current visualization
  const explanation = {
    title: "Transformer Architecture Visualization",
    description: `Interactive 3D view of a ${modelInfo.layers}-layer transformer model with ${modelInfo.heads} attention heads per layer.`,
    details: [
      {
        heading: "What is a Transformer?",
        content: `Transformers are the foundation of modern LLMs. They process text by passing it through multiple layers, each applying attention mechanisms to understand relationships between words. This model has ${modelInfo.layers} layers stacked vertically.`
      },
      {
        heading: "Reading the 3D Structure",
        content: `Bottom (green): Input embeddings convert text to vectors. Middle (blue): ${modelInfo.layers} transformer layers process information. Top (orange): Output layer predicts next tokens. The vertical flow shows how data moves through the network.`
      },
      {
        heading: "Attention Heads (Small Blocks)",
        content: `Each layer contains ${modelInfo.heads} attention heads (colored blocks). These heads learn different aspects of language: grammar, semantics, context, etc. They work in parallel, each focusing on different patterns.`
      },
      {
        heading: "Feed-Forward Networks (Purple)",
        content: `The purple blocks are FFN components with 4096 dimensions. After attention, these networks transform the data further, adding non-linearity and learning complex patterns.`
      },
      {
        heading: "Information Flow",
        content: `Data flows upward from input to output. Each layer refines the understanding, building from simple patterns (early layers) to complex concepts (later layers). The lines show this sequential processing.`
      },
      {
        heading: "Model Scale",
        content: `This architecture represents ${paramsLabel} parameters. With ${formatCount(modelInfo.vocabSize)} token vocabulary, ${modelInfo.hiddenSize} hidden dimensions, and ${modelInfo.layers} layers, it can generate coherent code by learning patterns from training data.`
      }
    ]
  };

  return (
    <div className="bg-gray-900 rounded-xl p-6 h-[800px]">
      {/* Header */}
      <div className="flex items-center justify-between mb-4">
        <div>
          <h2 className="text-2xl font-bold flex items-center gap-2">
            <Move3D className="w-6 h-6 text-blue-400" />
            3D Model Architecture
          </h2>
          <p className="text-gray-400 mt-1">
            Interactive 3D visualization of the transformer architecture
          </p>
        </div>

        {/* Controls */}
        <div className="flex items-center gap-2">
          <button
            onClick={() => setAutoRotate((prev) => !prev)}
            className={`px-3 py-1.5 rounded-lg text-sm ${
              autoRotate ? 'bg-blue-600 text-white' : 'bg-gray-800 text-gray-300'
            }`}
          >
            Auto-rotate
          </button>
          <button
            onClick={() => setShowLabels((prev) => !prev)}
            className={`px-3 py-1.5 rounded-lg text-sm ${
              showLabels ? 'bg-blue-600 text-white' : 'bg-gray-800 text-gray-300'
            }`}
          >
            Labels
          </button>
        </div>
      </div>

      {/* Main Content Area with Side Panel */}
      <div className="flex gap-4">
        {/* 3D Canvas */}
        <div className="flex-1 min-w-0 transition-all duration-500 ease-in-out">
          <div className="h-[700px] bg-black rounded-lg relative">
            {/* Camera much further back for full view */}
            <Canvas camera={{ position: [50, 40, 70], fov: 45 }}>
              <Suspense fallback={null}>
                <Scene modelInfo={modelInfo} />
                {/* NOTE(review): target y=35 appears to be the midpoint of the
                    default 20-layer stack (20 × 3.5 / 2); it is not recomputed
                    when modelInfo.layers changes. */}
                <OrbitControls
                  enablePan={true}
                  enableZoom={true}
                  enableRotate={true}
                  autoRotate={autoRotate}
                  autoRotateSpeed={0.5}
                  minDistance={20}
                  maxDistance={200}
                  target={[0, 35, 0]}
                />
                <Environment preset="city" />
              </Suspense>
            </Canvas>

            {/* Help Toggle Button */}
            <button
              onClick={() => setShowExplanation((prev) => !prev)}
              className="absolute top-4 left-4 z-10 p-2 bg-blue-600/90 hover:bg-blue-700 text-white rounded-lg transition-colors flex items-center gap-2 backdrop-blur"
            >
              {showExplanation ? <X className="w-5 h-5" /> : <HelpCircle className="w-5 h-5" />}
              <span className="text-sm font-medium">
                {showExplanation ? 'Hide Info' : 'What am I seeing?'}
              </span>
            </button>

            {/* Instructions */}
            <div className="absolute bottom-4 left-4 bg-gray-800/80 backdrop-blur rounded-lg p-3 text-xs text-gray-400">
              <div className="flex items-center gap-2 mb-1">
                <Eye className="w-3 h-3" />
                <span>Click layers to inspect</span>
              </div>
              <div className="flex items-center gap-2 mb-1">
                <Move3D className="w-3 h-3" />
                <span>Drag to rotate • Scroll to zoom</span>
              </div>
              <div className="flex items-center gap-2">
                <Layers className="w-3 h-3" />
                <span>{modelInfo.layers} layers × {modelInfo.heads} heads = {modelInfo.layers * modelInfo.heads} attention patterns</span>
              </div>
              <div className="mt-2 pt-2 border-t border-gray-600">
                <div className="font-semibold text-white">Architecture Stack:</div>
                {/* Colors match the legend and the 3D meshes (output = amber) */}
                <div className="text-amber-400">↑ Output Probabilities ({formatCount(modelInfo.vocabSize)} tokens)</div>
                <div className="text-blue-400">↑ Layers 0-{lastLayerIndex} ({modelInfo.layers} transformer blocks)</div>
                <div className="text-green-400">↑ Input Embeddings ({formatCount(modelInfo.vocabSize)} × {formatCount(modelInfo.hiddenSize)})</div>
              </div>
            </div>

            {/* Legend */}
            <div className="absolute top-4 right-4 bg-gray-800/80 backdrop-blur rounded-lg p-3 text-xs">
              <div className="font-semibold text-white mb-2">Components</div>
              <div className="space-y-1">
                <div className="flex items-center gap-2">
                  <div className="w-3 h-3 bg-green-500 rounded"></div>
                  <span className="text-gray-300">Token Embeddings</span>
                </div>
                <div className="flex items-center gap-2">
                  <div className="w-3 h-3 bg-blue-500 rounded"></div>
                  <span className="text-gray-300">Attention Layers</span>
                </div>
                <div className="flex items-center gap-2">
                  <div className="w-3 h-3 bg-purple-500 rounded"></div>
                  <span className="text-gray-300">Feed-Forward</span>
                </div>
                <div className="flex items-center gap-2">
                  <div className="w-3 h-3 bg-amber-500 rounded"></div>
                  <span className="text-gray-300">Output Layer</span>
                </div>
              </div>
            </div>
          </div>
        </div>

        {/* Explanation Side Panel */}
        <div className={`${showExplanation ? 'w-96' : 'w-0'} transition-all duration-500 ease-in-out overflow-hidden`}>
          <div className="w-96 h-[700px] bg-gray-900 rounded-lg border border-gray-700">
            {/* Panel Header */}
            <div className="bg-gray-800 px-4 py-3 border-b border-gray-700">
              <div className="flex items-center gap-2">
                <Info className="w-5 h-5 text-blue-400" />
                <h3 className="text-lg font-semibold text-white">Understanding the Architecture</h3>
              </div>
            </div>

            {/* Panel Content */}
            <div className="px-4 py-4 overflow-y-auto h-[calc(700px-60px)]">
              {/* Main Description */}
              <div className="mb-4 p-3 bg-blue-900/20 border border-blue-800 rounded-lg">
                <h4 className="text-sm font-semibold text-blue-400 mb-1">{explanation.title}</h4>
                <p className="text-xs text-gray-300">{explanation.description}</p>
              </div>

              {/* Explanation Sections */}
              <div className="space-y-3">
                {explanation.details.map((section, idx) => (
                  <div key={idx} className="bg-gray-800 rounded-lg p-3">
                    <h5 className="font-medium text-sm text-white mb-1 flex items-center gap-1">
                      <Zap className="w-3 h-3 text-yellow-400" />
                      {section.heading}
                    </h5>
                    <p className="text-xs text-gray-300 leading-relaxed">{section.content}</p>
                  </div>
                ))}
              </div>

              {/* Visual Guide */}
              <div className="mt-4 p-3 bg-purple-900/20 border border-purple-800 rounded-lg">
                <h4 className="font-medium text-sm text-purple-400 mb-2">Layer Components</h4>
                <div className="space-y-2 text-xs">
                  <div className="flex items-start gap-2">
                    <div className="w-3 h-3 bg-green-500 rounded mt-0.5"></div>
                    <span className="text-gray-300">Input: Token embeddings ({formatCount(modelInfo.vocabSize)} × {formatCount(modelInfo.hiddenSize)})</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <div className="w-3 h-3 bg-blue-500 rounded mt-0.5"></div>
                    <span className="text-gray-300">Attention: {modelInfo.layers} layers × {modelInfo.heads} heads</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <div className="w-3 h-3 bg-purple-500 rounded mt-0.5"></div>
                    <span className="text-gray-300">FFN: 4096 dimensional processing</span>
                  </div>
                  <div className="flex items-start gap-2">
                    <div className="w-3 h-3 bg-amber-500 rounded mt-0.5"></div>
                    <span className="text-gray-300">Output: Probability over {formatCount(modelInfo.vocabSize)} tokens</span>
                  </div>
                </div>
              </div>

              {/* Model Statistics */}
              <div className="mt-4 p-3 bg-gray-800 rounded-lg">
                <h4 className="font-medium text-sm text-gray-300 mb-2">Model Statistics</h4>
                <div className="space-y-1 text-xs">
                  <div className="flex justify-between">
                    <span className="text-gray-400">Total Layers:</span>
                    <span className="text-white">{modelInfo.layers}</span>
                  </div>
                  <div className="flex justify-between">
                    <span className="text-gray-400">Attention Heads:</span>
                    <span className="text-white">{modelInfo.heads} per layer</span>
                  </div>
                  <div className="flex justify-between">
                    <span className="text-gray-400">Hidden Dimensions:</span>
                    <span className="text-white">{formatCount(modelInfo.hiddenSize)}</span>
                  </div>
                  <div className="flex justify-between">
                    <span className="text-gray-400">FFN Dimensions:</span>
                    <span className="text-white">4,096</span>
                  </div>
                  <div className="flex justify-between">
                    <span className="text-gray-400">Vocabulary Size:</span>
                    <span className="text-white">{formatCount(modelInfo.vocabSize)} tokens</span>
                  </div>
                  <div className="flex justify-between">
                    <span className="text-gray-400">Parameters:</span>
                    <span className="text-blue-400">{paramsLabel}</span>
                  </div>
                </div>
              </div>

              {/* Interaction Guide */}
              <div className="mt-4 p-3 bg-gray-800 rounded-lg">
                <h4 className="font-medium text-sm text-gray-300 mb-2">💡 Interactive Features</h4>
                <ul className="text-xs text-gray-400 space-y-1">
                  <li>• Click layers to see details</li>
                  <li>• Drag to rotate the model</li>
                  <li>• Scroll to zoom in/out</li>
                  <li>• Enable auto-rotate for 360° view</li>
                  <li>• Each layer processes all tokens in parallel</li>
                </ul>
              </div>
            </div>
          </div>
        </div>
      </div>
    </div>
  );
}