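"""Gradio leaderboard app for the KV Cache Benchmark.

Loads per-experiment benchmark results from JSON files in the data/ directory
and displays filterable tables and scatter plots for the prefill and decode
stages of KV cache compression.
"""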
import os
import json
import gradio as gr
import pandas as pd
from about_content import about_markdown # Import the about page content
from submission_content import submission_markdown # Import the submission page content
import plotly.express as px
# Helper function to load data from JSON files
def load_data(data_dir):
data = []
for file_name in os.listdir(data_dir):
if file_name.endswith(".json"):
            # File names are expected to follow <stage>_<method>_<model>_<dataset>.json;
            # parse those four fields out of the name
stage, method, model, dataset = file_name.replace(".json", "").split("_")
with open(os.path.join(data_dir, file_name), "r") as f:
entry = json.load(f)
entry.update({"Method": method, "Model": model, "Dataset": dataset, "Stage": stage})
data.append(entry)
return pd.DataFrame(data)
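# Hypothetical example of a result file this loader would pick up (the field names
# match what the app aggregates and displays; the file name and values are made up):
#   data/prefill_methodA_modelB_datasetC.json
#   {"Quality": 42.0, "TTFT (s)": 0.85, "Link": "https://example.com/run"}
# Decode-stage files carry "Throughput (token/s)" instead of "TTFT (s)".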
def filter_and_display(selected_columns, model_types, datasets, stage):
filtered = data.copy()
# Filter by stage
filtered = filtered[filtered["Stage"] == stage]
# Filter by model types
if model_types:
filtered = filtered[filtered["Model"].isin(model_types)]
# Filter by datasets
if datasets:
filtered = filtered[filtered["Dataset"].isin(datasets)]
if not filtered.empty:
# Adjust aggregation based on stage
if stage == "decode":
filtered = filtered.groupby(["Method", "Model", "Dataset"], as_index=False).agg({
"Throughput (token/s)": "mean",
"Quality": "mean",
"Link": "first"
})
else:
filtered = filtered.groupby(["Method", "Model", "Dataset"], as_index=False).agg({
"Quality": "mean",
"TTFT (s)": "mean",
"Link": "first"
})
# Select columns to display
        display_columns = ["Method", "Model", "Dataset"] + [col for col in selected_columns if col in filtered.columns]
        return filtered[display_columns]
    # No rows match the current filters: return an empty frame with the expected columns
    return pd.DataFrame(columns=["Method", "Model", "Dataset"] + list(selected_columns))
def create_prefill_visualization(filtered_data):
if filtered_data.empty:
return None
fig = px.scatter(
filtered_data,
x='TTFT (s)',
y='Quality',
color='Method',
hover_data=['Model', 'Dataset'],
title='Prefill Stage: Quality vs TTFT (s) by Method'
)
    fig.update_layout(
        yaxis=dict(range=[0, 100]),      # Quality is reported on a 0-100 scale
        xaxis=dict(rangemode="tozero")   # Make the TTFT (s) axis start at 0
    )
return fig
def create_decode_visualization(filtered_data):
if filtered_data.empty:
return None
fig = px.scatter(
filtered_data,
x='Throughput (token/s)',
y='Quality',
color='Method',
hover_data=['Model', 'Dataset'],
title='Decode Stage: Quality vs Throughput by Method'
)
    fig.update_layout(
        yaxis=dict(range=[0, 100]),      # Quality is reported on a 0-100 scale
        xaxis=dict(rangemode="tozero")   # Make the Throughput (token/s) axis start at 0
    )
return fig
# Load the data from the /data folder
data_dir = "data"
data = load_data(data_dir)
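# `data` is loaded once at startup and shared as a module-level DataFrame by
# filter_and_display() and the Gradio UI defaults below.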
# Gradio app UI and functionality
def create_gradio_app():
with gr.Blocks() as app:
with gr.Row():
gr.Markdown(
"""# KV Cache Benchmark
### Demo leaderboard
This demo leaderboard allows users to explore and compare different KV cache implementations across various models and datasets. It provides interactive filtering options and real-time updates of benchmark results, including visualization of Quality and TTFT (s) metrics.
""")
with gr.Tabs():
with gr.TabItem("KV Cache Benchmark"):
# Prefill-stage selection
with gr.Row():
gr.Markdown("## Prefill-Stage KV Cache Compression")
with gr.Row():
with gr.Column():
gr.Markdown("#### Select Columns to Display")
prefill_columns_to_display = gr.CheckboxGroup(
choices=["Quality", "TTFT (s)", "Link"],
label="Columns",
value=["Quality", "TTFT (s)"]
)
with gr.Column():
gr.Markdown("#### Model Types")
prefill_model_types = gr.CheckboxGroup(
choices=list(data["Model"].unique()),
label="Model Types",
value=list(data["Model"].unique()) # Default to all models
)
with gr.Column():
gr.Markdown("#### Datasets")
prefill_datasets = gr.CheckboxGroup(
choices=list(data[data["Stage"] == "prefill"]["Dataset"].unique()),
label="Datasets",
value=list(data[data["Stage"] == "prefill"]["Dataset"].unique()) # Default to all datasets for prefill
)
# Prefill-stage compression results
with gr.Row():
gr.Markdown("## Results")
# Initialize the Prefill Dataframe with default data
prefill_default = filter_and_display(
["Quality", "TTFT (s)"],
list(data["Model"].unique()),
list(data[data["Stage"] == "prefill"]["Dataset"].unique()),
"prefill"
)
prefill_results = gr.Dataframe(
value=prefill_default
)
# Prefill-stage visualization (Static initially)
with gr.Row():
gr.Markdown("### Prefill-stage Visualization")
with gr.Row():
prefill_plot = gr.Plot(
value=create_prefill_visualization(prefill_default)
)
# Decode-stage selection
with gr.Row():
gr.Markdown("## Decode-Stage KV Cache Compression")
with gr.Row():
with gr.Column():
gr.Markdown("#### Select Columns to Display")
decode_columns_to_display = gr.CheckboxGroup(
choices=["Throughput (token/s)", "Quality", "Link"],
label="Columns",
value=["Throughput (token/s)", "Quality"]
)
with gr.Column():
gr.Markdown("#### Model Types")
decode_model_types = gr.CheckboxGroup(
choices=list(data["Model"].unique()),
label="Model Types",
value=list(data["Model"].unique()) # Default to all models
)
with gr.Column():
gr.Markdown("#### Datasets")
decode_datasets = gr.CheckboxGroup(
choices=list(data[data["Stage"] == "decode"]["Dataset"].unique()),
label="Datasets",
value=list(data[data["Stage"] == "decode"]["Dataset"].unique()) # Default to all datasets for decode
)
# Decode-stage compression results
with gr.Row():
gr.Markdown("## Results")
# Initialize the Decode Dataframe with default data
decode_default = filter_and_display(
["Throughput (token/s)", "Quality"],
list(data["Model"].unique()),
list(data[data["Stage"] == "decode"]["Dataset"].unique()),
"decode"
)
decode_results = gr.Dataframe(
value=decode_default
)
# Decode-stage visualization (Static initially)
with gr.Row():
gr.Markdown("### Decode-Stage Visualization")
with gr.Row():
decode_plot = gr.Plot(
value=create_decode_visualization(decode_default)
)
# AUTO-UPDATE FUNCTIONS:
# (We only update the DataFrame, NOT the Plot)
def auto_update_prefill(selected_columns, model_types, datasets):
if not model_types or not datasets:
# Return an empty DataFrame if no selection is made
return pd.DataFrame(columns=["Method", "Model"] + selected_columns)
filtered_data = filter_and_display(selected_columns, model_types, datasets, "prefill")
return filtered_data
def auto_update_decode(selected_columns, model_types, datasets):
if not model_types or not datasets:
# Return an empty DataFrame if no selection is made
return pd.DataFrame(columns=["Method", "Model"] + selected_columns)
filtered_data = filter_and_display(selected_columns, model_types, datasets, "decode")
return filtered_data
# Only update the tables when filters change
prefill_columns_to_display.change(
auto_update_prefill,
inputs=[prefill_columns_to_display, prefill_model_types, prefill_datasets],
outputs=[prefill_results]
)
prefill_model_types.change(
auto_update_prefill,
inputs=[prefill_columns_to_display, prefill_model_types, prefill_datasets],
outputs=[prefill_results]
)
prefill_datasets.change(
auto_update_prefill,
inputs=[prefill_columns_to_display, prefill_model_types, prefill_datasets],
outputs=[prefill_results]
)
decode_columns_to_display.change(
auto_update_decode,
inputs=[decode_columns_to_display, decode_model_types, decode_datasets],
outputs=[decode_results]
)
decode_model_types.change(
auto_update_decode,
inputs=[decode_columns_to_display, decode_model_types, decode_datasets],
outputs=[decode_results]
)
decode_datasets.change(
auto_update_decode,
inputs=[decode_columns_to_display, decode_model_types, decode_datasets],
outputs=[decode_results]
)
                # Reload button: refresh the whole page via client-side JavaScript
                # (uses the `js` argument of the click event, available in Gradio 4+)
                reload_button = gr.Button("Reload Data")
                reload_button.click(
                    fn=None,
                    js="window.location.reload();"
                )
with gr.TabItem("About"):
gr.Markdown(about_markdown) # Use the imported about page content
with gr.TabItem("Submission Instructions"):
gr.Markdown(submission_markdown) # Use the imported submission page content
return app
if __name__ == "__main__":
app = create_gradio_app()
app.launch()