/** * Prompt Diff Analyzer Component * * Visualizes and compares attention patterns between different prompts * to show how prompt changes affect model behavior. * * CURRENT STATUS: Demo Mode * - Uses existing traces from working_demo.py * - Simulates differences between prompts * - TODO: Integrate with real LLM models for actual prompt comparison * * @component */ "use client"; import { useState, useEffect, useRef } from "react"; import * as d3 from "d3"; import { useWebSocket } from "@/lib/websocket-client"; import { PromptServiceClient } from "@/lib/prompt-service-client"; import { getApiUrl } from "@/lib/config"; import { ArrowRight, AlertTriangle, CheckCircle, Minus, Plus, GitCompare, Activity, Download, HelpCircle, X, Info, Zap } from "lucide-react"; // Attention data structure for a single layer interface AttentionData { layer: string; weights: number[][]; tokens?: string[]; max_weight: number; entropy?: number; timestamp: number; } // Comparison data structure for two prompts interface PromptComparison { promptA: string; promptB: string; attentionA: AttentionData[]; attentionB: AttentionData[]; timestamp: number; } export default function PromptDiff() { const { traces, isConnected } = useWebSocket(); const [prompt1, setPrompt1] = useState("def fibonacci(n):\n '''Calculate fibonacci number'''"); const [prompt2, setPrompt2] = useState("def fibonacci(n):\n '''Calculate fibonacci number with memoization'''"); const [shouldAutoAnalyze, setShouldAutoAnalyze] = useState(false); // Listen for demo selections from LocalControlPanel useEffect(() => { const handleDemoPromptsSelected = (event: Event) => { const customEvent = event as CustomEvent; const { promptA, promptB } = customEvent.detail; console.log('PromptDiff: Demo prompts received -', { promptA, promptB }); if (promptA && promptB) { setPrompt1(promptA); setPrompt2(promptB); setShouldAutoAnalyze(true); } }; window.addEventListener('demo-prompts-selected', handleDemoPromptsSelected); return () => window.removeEventListener('demo-prompts-selected', handleDemoPromptsSelected); }, []); // We'll add the auto-analyze effect after analyzePrompts is defined const [comparison, setComparison] = useState(null); const [selectedLayer, setSelectedLayer] = useState(""); const [isAnalyzing, setIsAnalyzing] = useState(false); const [savedComparisons, setSavedComparisons] = useState([]); const [isPromptServiceConnected, setIsPromptServiceConnected] = useState(false); const [useRealModel, setUseRealModel] = useState(true); // Default to using real model const [generatedTexts, setGeneratedTexts] = useState<{a?: string, b?: string}>({}); const [showExplanation, setShowExplanation] = useState(false); const diffSvgRef = useRef(null); const heatmapARef = useRef(null); const heatmapBRef = useRef(null); const promptServiceRef = useRef(null); // Initialize prompt service connection useEffect(() => { const client = new PromptServiceClient(); promptServiceRef.current = client; // Try to connect to the prompt service client.connect() .then(() => { console.log('✅ Connected to prompt service'); setIsPromptServiceConnected(true); setUseRealModel(true); }) .catch((error) => { // This is expected if the service isn't running - no need to log error console.log('ℹ️ Using demo mode (start prompt_service.py for real model)'); setIsPromptServiceConnected(false); setUseRealModel(false); }); // Set up message handler client.onMessage('prompt-diff', (data) => { handlePromptServiceMessage(data as unknown as PromptServiceMessage); }); return () => { client.offMessage('prompt-diff'); client.disconnect(); }; }, []); // Handle messages from prompt service interface PromptServiceMessage { type: string; comparison_group?: string; layer?: string; weights?: number[][]; max_weight?: number; entropy?: number; timestamp?: number; tokens?: string[]; } const handlePromptServiceMessage = (data: PromptServiceMessage) => { if (data.type === 'attention') { // Store attention traces temporarily // eslint-disable-next-line @typescript-eslint/no-explicit-any if (!(window as any).tempAttentionStorage) { // eslint-disable-next-line @typescript-eslint/no-explicit-any (window as any).tempAttentionStorage = { prompt_a: [], prompt_b: [] }; } if (data.comparison_group === 'prompt_a') { // eslint-disable-next-line @typescript-eslint/no-explicit-any (window as any).tempAttentionStorage.prompt_a.push(data); } else if (data.comparison_group === 'prompt_b') { // eslint-disable-next-line @typescript-eslint/no-explicit-any (window as any).tempAttentionStorage.prompt_b.push(data); } } else if (data.type === 'prompt_comparison') { // Handle comparison summary // eslint-disable-next-line @typescript-eslint/no-explicit-any const storage = (window as any).tempAttentionStorage || { prompt_a: [], prompt_b: [] }; // Create comparison from received data // eslint-disable-next-line @typescript-eslint/no-explicit-any const comparisonData = data as any; // Type assertion for comparison-specific fields const newComparison: PromptComparison = { promptA: comparisonData.prompt_a?.text || prompt1, promptB: comparisonData.prompt_b?.text || prompt2, attentionA: storage.prompt_a.map((t: PromptServiceMessage) => ({ layer: t.layer || '', weights: t.weights || [], max_weight: t.max_weight || 0, entropy: t.entropy || 0, timestamp: t.timestamp || 0, tokens: t.tokens || [] })), attentionB: storage.prompt_b.map((t: PromptServiceMessage) => ({ layer: t.layer || '', weights: t.weights || [], max_weight: t.max_weight || 0, entropy: t.entropy || 0, timestamp: t.timestamp || 0, tokens: t.tokens || [] })), timestamp: Date.now() }; // Store generated texts setGeneratedTexts({ a: comparisonData.prompt_a?.generated || '', b: comparisonData.prompt_b?.generated || '' }); setComparison(newComparison); setSavedComparisons(prev => [...prev, newComparison]); // Auto-select first layer if (storage.prompt_a.length > 0) { const layers = Array.from(new Set(storage.prompt_a.map((a: PromptServiceMessage) => a.layer))); if (layers[0]) { setSelectedLayer(layers[0] as string); } } // Clear temporary storage // eslint-disable-next-line @typescript-eslint/no-explicit-any (window as any).tempAttentionStorage = null; setIsAnalyzing(false); } }; // Window type declarations removed - using (window as any) for tempAttentionStorage instead /** * Analyzes attention patterns for two prompts * * Uses real model service if available, otherwise falls back to demo mode */ const analyzePrompts = async () => { setIsAnalyzing(true); try { // Generate traces for both prompts using the unified backend console.log('Generating traces for Prompt A...'); const responseA = await fetch(`${getApiUrl()}/generate`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ prompt: prompt1, max_tokens: 50, temperature: 0.7, sampling_rate: 0.5 // Higher sampling for better comparison }) }); if (!responseA.ok) throw new Error(`HTTP error! status: ${responseA.status}`); const dataA = await responseA.json(); console.log('Generating traces for Prompt B...'); const responseB = await fetch(`${getApiUrl()}/generate`, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ prompt: prompt2, max_tokens: 50, temperature: 0.7, sampling_rate: 0.5 }) }); if (!responseB.ok) throw new Error(`HTTP error! status: ${responseB.status}`); const dataB = await responseB.json(); // Store generated texts setGeneratedTexts({ a: dataA.generated_text, b: dataB.generated_text }); // Extract attention traces from both responses const attentionA = dataA.traces ?.filter((t: PromptServiceMessage) => t.type === 'attention' && t.weights) .map((t: PromptServiceMessage) => ({ layer: t.layer || 'unknown', weights: t.weights || [], max_weight: t.max_weight || 1, entropy: t.entropy, timestamp: t.timestamp || Date.now(), tokens: t.tokens })) || []; const attentionB = dataB.traces ?.filter((t: PromptServiceMessage) => t.type === 'attention' && t.weights) .map((t: PromptServiceMessage) => ({ layer: t.layer || 'unknown', weights: t.weights || [], max_weight: t.max_weight || 1, entropy: t.entropy, timestamp: t.timestamp || Date.now(), tokens: t.tokens })) || []; if (attentionA.length === 0 || attentionB.length === 0) { alert('No attention traces captured. Try adjusting the prompts or sampling rate.'); setIsAnalyzing(false); return; } // Create comparison const newComparison: PromptComparison = { promptA: prompt1, promptB: prompt2, attentionA, attentionB, timestamp: Date.now() }; setComparison(newComparison); // Find common layers between both sets const layersA = new Set(attentionA.map((a: AttentionData) => a.layer)); const layersB = new Set(attentionB.map((a: AttentionData) => a.layer)); const commonLayers = Array.from(layersA).filter(l => layersB.has(l)).sort(); if (commonLayers.length > 0) { setSelectedLayer(commonLayers[0] as string); } else if (attentionA.length > 0) { // No common layers, just use first layer from A setSelectedLayer(attentionA[0].layer); } console.log(`Comparison complete. Found ${commonLayers.length} common layers.`); } catch (error) { console.error('Error comparing prompts:', error); alert(`Failed to compare prompts: ${error}`); } finally { setIsAnalyzing(false); } }; // Auto-analyze when triggered by demo selection useEffect(() => { if (shouldAutoAnalyze && !isAnalyzing) { setShouldAutoAnalyze(false); // Small delay to ensure state is updated setTimeout(() => { analyzePrompts(); }, 100); } }, [shouldAutoAnalyze, isAnalyzing, prompt1, prompt2]); const analyzeDemoMode = () => { // Original demo mode implementation const attentionTraces = traces.filter(t => t.type === 'attention' && t.weights); console.log('[PromptDiff] Available attention traces:', attentionTraces.length); if (attentionTraces.length === 0) { // No traces available - show error message alert('No attention data available. Please run a model first to capture attention patterns.\n\nTry running: python python-sdk/working_demo.py'); setIsAnalyzing(false); return; } if (attentionTraces.length < 2) { // Not enough traces for comparison alert(`Only ${attentionTraces.length} attention trace(s) available. Need at least 2 for comparison.\n\nTry running the demo again to generate more traces.`); setIsAnalyzing(false); return; } // Simulate having different attention patterns for each prompt const halfPoint = Math.floor(attentionTraces.length / 2); const attentionA = attentionTraces.slice(0, halfPoint).map(t => ({ layer: t.layer || 'unknown', weights: t.weights || [], max_weight: t.max_weight || 1, entropy: t.entropy, timestamp: t.timestamp || Date.now(), tokens: t.tokens })); const attentionB = attentionTraces.slice(halfPoint).map(t => ({ layer: t.layer || 'unknown', weights: t.weights || [], max_weight: t.max_weight || 1, entropy: t.entropy, timestamp: t.timestamp || Date.now(), tokens: t.tokens })); // Add some variation to attentionB to simulate different prompt attentionB.forEach(attention => { attention.weights = attention.weights.map(row => row.map(val => Math.min(1, Math.max(0, val + (Math.random() - 0.5) * 0.2))) ); attention.max_weight = Math.max(...attention.weights.flat()); attention.entropy = (attention.entropy || 0) + (Math.random() - 0.5) * 0.1; }); const newComparison: PromptComparison = { promptA: prompt1, promptB: prompt2, attentionA, attentionB, timestamp: Date.now() }; setComparison(newComparison); setSavedComparisons(prev => [...prev, newComparison]); // Auto-select first layer if (attentionA.length > 0) { const layers = Array.from(new Set(attentionA.map(a => a.layer))); setSelectedLayer(layers[0]); } setTimeout(() => setIsAnalyzing(false), 1000); }; /** * Render attention heatmap for Prompt A */ useEffect(() => { if (!comparison || !selectedLayer || !heatmapARef.current) return; const attention = comparison.attentionA.find(a => a.layer === selectedLayer); if (!attention || !attention.weights) return; const margin = { top: 40, right: 40, bottom: 60, left: 60 }; const cellSize = 12; const weights = attention.weights; const numRows = weights.length; const numCols = weights[0]?.length || 0; if (numRows === 0 || numCols === 0) return; const width = numCols * cellSize; const height = numRows * cellSize; // Clear previous d3.select(heatmapARef.current).selectAll("*").remove(); const svg = d3.select(heatmapARef.current) .attr("width", width + margin.left + margin.right) .attr("height", height + margin.top + margin.bottom); const g = svg.append("g") .attr("transform", `translate(${margin.left},${margin.top})`); // Color scale const colorScale = d3.scaleSequential(d3.interpolateBlues) .domain([0, attention.max_weight || 1]); // Draw cells g.selectAll(".cell") .data(weights.flatMap((row, i) => row.map((value, j) => ({ row: i, col: j, value })) )) .enter().append("rect") .attr("class", "cell") .attr("x", d => d.col * cellSize) .attr("y", d => d.row * cellSize) .attr("width", cellSize - 1) .attr("height", cellSize - 1) .attr("fill", d => colorScale(d.value)) .attr("stroke", "#1f2937") .attr("stroke-width", 0.5); // Title svg.append("text") .attr("x", (width + margin.left + margin.right) / 2) .attr("y", 20) .attr("text-anchor", "middle") .style("font-size", "14px") .style("font-weight", "bold") .style("fill", "#fff") .text("Prompt A"); }, [comparison, selectedLayer]); /** * Render attention heatmap for Prompt B */ useEffect(() => { if (!comparison || !selectedLayer || !heatmapBRef.current) return; const attention = comparison.attentionB.find(a => a.layer === selectedLayer); if (!attention || !attention.weights) return; const margin = { top: 40, right: 40, bottom: 60, left: 60 }; const cellSize = 12; const weights = attention.weights; const numRows = weights.length; const numCols = weights[0]?.length || 0; if (numRows === 0 || numCols === 0) return; const width = numCols * cellSize; const height = numRows * cellSize; // Clear previous d3.select(heatmapBRef.current).selectAll("*").remove(); const svg = d3.select(heatmapBRef.current) .attr("width", width + margin.left + margin.right) .attr("height", height + margin.top + margin.bottom); const g = svg.append("g") .attr("transform", `translate(${margin.left},${margin.top})`); // Color scale const colorScale = d3.scaleSequential(d3.interpolateBlues) .domain([0, attention.max_weight || 1]); // Draw cells g.selectAll(".cell") .data(weights.flatMap((row, i) => row.map((value, j) => ({ row: i, col: j, value })) )) .enter().append("rect") .attr("class", "cell") .attr("x", d => d.col * cellSize) .attr("y", d => d.row * cellSize) .attr("width", cellSize - 1) .attr("height", cellSize - 1) .attr("fill", d => colorScale(d.value)) .attr("stroke", "#1f2937") .attr("stroke-width", 0.5); // Title svg.append("text") .attr("x", (width + margin.left + margin.right) / 2) .attr("y", 20) .attr("text-anchor", "middle") .style("font-size", "14px") .style("font-weight", "bold") .style("fill", "#fff") .text("Prompt B"); }, [comparison, selectedLayer]); /** * Creates D3.js visualization of attention differences * Uses diverging color scale (RdBu) to show increases/decreases * KNOWN ISSUE: Title text may be cut off at top of visualization */ useEffect(() => { if (!comparison || !selectedLayer || !diffSvgRef.current) return; const attentionA = comparison.attentionA.find(a => a.layer === selectedLayer); const attentionB = comparison.attentionB.find(a => a.layer === selectedLayer); if (!attentionA || !attentionB) return; const margin = { top: 130, right: 60, bottom: 40, left: 60 }; const cellSize = 18; // Calculate difference matrix const diffMatrix: number[][] = []; const minRows = Math.min(attentionA.weights.length, attentionB.weights.length); const minCols = Math.min(attentionA.weights[0]?.length || 0, attentionB.weights[0]?.length || 0); for (let i = 0; i < minRows; i++) { diffMatrix[i] = []; for (let j = 0; j < minCols; j++) { diffMatrix[i][j] = (attentionB.weights[i]?.[j] || 0) - (attentionA.weights[i]?.[j] || 0); } } const width = minCols * cellSize; const height = minRows * cellSize; // Clear previous visualization d3.select(diffSvgRef.current).selectAll("*").remove(); const totalWidth = width + margin.left + margin.right; const totalHeight = height + margin.top + margin.bottom; const svg = d3.select(diffSvgRef.current) .attr("width", totalWidth) .attr("height", totalHeight) .attr("viewBox", `0 0 ${totalWidth} ${totalHeight}`) .style("display", "block"); const g = svg.append("g") .attr("transform", `translate(${margin.left},${margin.top})`); // Create diverging color scale for differences const maxDiff = Math.max(...diffMatrix.flat().map(Math.abs)); const colorScale = d3.scaleDiverging() .domain([-maxDiff, 0, maxDiff]) // eslint-disable-next-line @typescript-eslint/no-explicit-any .interpolator(d3.interpolateRdBu as any); // Create scales const xScale = d3.scaleBand() .domain(d3.range(minCols).map(String)) .range([0, width]) .padding(0.01); const yScale = d3.scaleBand() .domain(d3.range(minRows).map(String)) .range([0, height]) .padding(0.01); // Create tooltip const tooltip = d3.select("body").append("div") .attr("class", "diff-tooltip") .style("opacity", 0) .style("position", "absolute") .style("background", "rgba(0, 0, 0, 0.9)") .style("color", "white") .style("padding", "10px") .style("border-radius", "6px") .style("font-size", "12px") .style("pointer-events", "none") .style("z-index", "1000"); // Draw cells g.selectAll(".diff-cell") .data(diffMatrix.flatMap((row, i) => row.map((value, j) => ({ row: i, col: j, value })) )) .enter().append("rect") .attr("class", "diff-cell") .attr("x", d => xScale(String(d.col))!) .attr("y", d => yScale(String(d.row))!) .attr("width", xScale.bandwidth()) .attr("height", yScale.bandwidth()) .attr("fill", d => colorScale(d.value)) .attr("stroke", "#1f2937") .attr("stroke-width", 0.5) .style("cursor", "pointer") .on("mouseover", function(event, d) { tooltip.transition().duration(200).style("opacity", .95); const change = d.value > 0 ? '+' : ''; const tokenFrom = attentionA.tokens?.[d.row] || `T${d.row}`; const tokenTo = attentionA.tokens?.[d.col] || `T${d.col}`; tooltip.html(`
Attention Change
From: ${tokenFrom} → To: ${tokenTo}
Prompt A: ${(attentionA.weights[d.row]?.[d.col] || 0).toFixed(4)}
Prompt B: ${(attentionB.weights[d.row]?.[d.col] || 0).toFixed(4)}
Change: ${change}${d.value.toFixed(4)}
`) .style("left", (event.pageX + 10) + "px") .style("top", (event.pageY - 28) + "px"); d3.select(this) .attr("stroke", "#3b82f6") .attr("stroke-width", 2); }) .on("mouseout", function() { tooltip.transition().duration(500).style("opacity", 0); d3.select(this) .attr("stroke", "#1f2937") .attr("stroke-width", 0.5); }); // Add title svg.append("text") .attr("x", totalWidth / 2) .attr("y", 25) .attr("text-anchor", "middle") .style("font-size", "14px") .style("font-weight", "bold") .style("fill", "#fff") .text(`Attention Difference - ${selectedLayer}`); // Add color legend const legendWidth = 150; const legendHeight = 15; const legendScale = d3.scaleLinear() .domain([-maxDiff, maxDiff]) .range([0, legendWidth]); const legendAxis = d3.axisBottom(legendScale) .ticks(5) .tickFormat(d => (d as number).toFixed(2)); const legend = svg.append("g") .attr("transform", `translate(${(totalWidth - legendWidth) / 2}, ${50})`); // Create gradient for legend const gradientId = `diff-gradient-${Date.now()}`; const gradient = svg.append("defs") .append("linearGradient") .attr("id", gradientId) .attr("x1", "0%") .attr("x2", "100%"); for (let i = 0; i <= 20; i++) { const t = i / 20; const value = -maxDiff + (2 * maxDiff * t); gradient.append("stop") .attr("offset", `${t * 100}%`) .attr("stop-color", colorScale(value)); } legend.append("rect") .attr("width", legendWidth) .attr("height", legendHeight) .style("fill", `url(#${gradientId})`); legend.append("g") .attr("transform", `translate(0, ${legendHeight})`) .call(legendAxis) .selectAll("text") .style("fill", "#9ca3af") .style("font-size", "10px"); legend.append("text") .attr("x", legendWidth / 2) .attr("y", -5) .attr("text-anchor", "middle") .style("font-size", "11px") .style("fill", "#9ca3af") .text("← Less Attention | More Attention →"); // Cleanup return () => { tooltip.remove(); }; }, [comparison, selectedLayer]); /** * Calculates statistical metrics for attention comparison * Returns average change, max increase/decrease, and entropy metrics */ const calculateStats = () => { if (!comparison || !selectedLayer) return null; const attentionA = comparison.attentionA.find(a => a.layer === selectedLayer); const attentionB = comparison.attentionB.find(a => a.layer === selectedLayer); if (!attentionA || !attentionB) return null; // Calculate average attention change let totalChange = 0; let count = 0; let maxIncrease = 0; let maxDecrease = 0; const minRows = Math.min(attentionA.weights.length, attentionB.weights.length); const minCols = Math.min(attentionA.weights[0]?.length || 0, attentionB.weights[0]?.length || 0); for (let i = 0; i < minRows; i++) { for (let j = 0; j < minCols; j++) { const diff = (attentionB.weights[i]?.[j] || 0) - (attentionA.weights[i]?.[j] || 0); totalChange += Math.abs(diff); count++; if (diff > maxIncrease) maxIncrease = diff; if (diff < maxDecrease) maxDecrease = diff; } } return { avgChange: count > 0 ? totalChange / count : 0, maxIncrease, maxDecrease: Math.abs(maxDecrease), entropyChangeA: attentionA.entropy || 0, entropyChangeB: attentionB.entropy || 0 }; }; const stats = calculateStats(); const uniqueLayers = comparison ? Array.from(new Set(comparison.attentionA.map(a => a.layer))) : []; // Generate contextual explanation for current visualization const generateExplanation = () => { if (!comparison) { return { title: "No Comparison Data", description: "Enter two different prompts and analyze to see how attention patterns differ.", details: [] }; } const numLayers = uniqueLayers.length; const hasRealModel = useRealModel; const avgChange = stats?.avgChange ? (stats.avgChange * 100).toFixed(1) : "0"; const maxInc = stats?.maxIncrease ? (stats.maxIncrease * 100).toFixed(1) : "0"; const maxDec = stats?.maxDecrease ? (stats.maxDecrease * 100).toFixed(1) : "0"; return { title: `Prompt Comparison: ${numLayers} layers analyzed`, description: `Comparing attention patterns between two prompts to identify behavioral differences.`, details: [ { heading: "What is Prompt Diff?", content: `This tool compares how the model's attention mechanism responds to different prompts. Three visualizations show: raw attention for each prompt (left/right) and the difference between them (center).` }, { heading: "Reading the Three Maps", content: `Left: Attention patterns for Prompt A. Center: Difference map (blue = increased in B, red = decreased in B). Right: Attention patterns for Prompt B. Compare side panels to see the full context of changes.` }, { heading: "Current Analysis", content: `Average attention change: ${avgChange}%. Maximum increase: +${maxInc}%. Maximum decrease: -${maxDec}%. ${hasRealModel ? 'Using real model for authentic patterns.' : 'Demo mode with simulated differences.'}` }, { heading: "Why Compare Prompts?", content: `Different prompts can dramatically change model behavior. Adding words like 'detailed' or 'concise' shifts attention patterns. This visualization reveals these hidden changes.` }, { heading: "Entropy Changes", content: `Entropy measures uncertainty in attention distribution. Higher entropy means more scattered attention, lower means more focused. Compare entropy values to see which prompt creates clearer attention patterns.` }, { heading: "Practical Applications", content: `Use this to: Optimize prompts for better results, understand why certain prompts work better, debug unexpected model behavior, and design more effective prompt templates.` } ] }; }; const explanation = generateExplanation(); // Export comparison const exportComparison = () => { if (!comparison) return; const dataStr = JSON.stringify(comparison, null, 2); const dataUri = 'data:application/json;charset=utf-8,'+ encodeURIComponent(dataStr); const exportFileDefaultName = `prompt_diff_${Date.now()}.json`; const linkElement = document.createElement('a'); linkElement.setAttribute('href', dataUri); linkElement.setAttribute('download', exportFileDefaultName); linkElement.click(); }; return (

Prompt Diff Analyzer

Compare how different prompts affect attention patterns and model behavior

{isConnected ? 'Connected' : 'Disconnected'}
{/* Prompt Input Section */}