|
|
import os |
|
|
import json |
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
from about_content import about_markdown |
|
|
from submission_content import submission_markdown |
|
|
import plotly.express as px |
|
|
import plotly |
|
|
|
|
|
def load_data(data_dir):
    """Load every benchmark result file in ``data_dir`` into one DataFrame.

    Each ``*.json`` file must be named ``<stage>_<method>_<model>_<dataset>.json``
    and hold a JSON object (it is merged with ``dict.update``). That object
    becomes one row, extended with ``Stage``, ``Method``, ``Model`` and
    ``Dataset`` values parsed from the file name. Non-JSON files are ignored.
    Files are read in sorted name order so the row order is deterministic.

    Args:
        data_dir: Directory to scan (not recursive).

    Returns:
        A ``pandas.DataFrame`` with one row per result file. When no result
        files are found the frame is empty but still has the ``Method``,
        ``Model``, ``Dataset`` and ``Stage`` columns, so callers can filter on
        them without a ``KeyError``.

    Raises:
        ValueError: if a result file name does not split into exactly four
            ``_``-separated parts.
    """
    id_columns = ["Method", "Model", "Dataset", "Stage"]
    suffix = ".json"
    rows = []
    for file_name in sorted(os.listdir(data_dir)):
        if not file_name.endswith(suffix):
            continue
        # Strip only the trailing extension; str.replace would also remove a
        # ".json" that happens to appear elsewhere in the name.
        parts = file_name[: -len(suffix)].split("_")
        if len(parts) != 4:
            raise ValueError(
                f"Cannot parse result file name {file_name!r}: expected "
                "'<stage>_<method>_<model>_<dataset>.json'"
            )
        stage, method, model, dataset = parts
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(os.path.join(data_dir, file_name), "r", encoding="utf-8") as f:
            entry = json.load(f)
        entry.update({"Method": method, "Model": model, "Dataset": dataset, "Stage": stage})
        rows.append(entry)
    if not rows:
        # Keep the identifying columns so downstream filters still work.
        return pd.DataFrame(columns=id_columns)
    return pd.DataFrame(rows)
|
|
|
|
|
def filter_and_display(selected_columns, model_types, datasets, stage, source=None):
    """Return the leaderboard rows for one stage, filtered and aggregated.

    Rows are restricted to ``stage``. ``model_types`` and ``datasets`` further
    restrict rows only when they are non-empty; an empty (or ``None``) filter
    means "no filter on that field". The remaining rows are grouped by
    (Method, Model, Dataset): metrics are averaged and ``Link`` takes the first
    value in each group. Metrics absent from the data are skipped instead of
    raising ``KeyError``.

    Args:
        selected_columns: Metric columns to show after the three key columns.
            Names not present after aggregation are silently dropped.
        model_types: Model names to keep, or empty/None for all.
        datasets: Dataset names to keep, or empty/None for all.
        stage: ``"decode"`` aggregates throughput and quality. Any other value
            aggregates quality and TTFT.
        source: DataFrame to filter. Defaults to the module-level ``data``.

    Returns:
        A DataFrame with columns ``Method``, ``Model``, ``Dataset`` followed by
        the present ``selected_columns``. It is empty (but keeps those columns)
        when no rows match.
    """
    frame = data if source is None else source
    key_columns = ["Method", "Model", "Dataset"]
    if stage == "decode":
        aggregations = {"Throughput (token/s)": "mean", "Quality": "mean", "Link": "first"}
    else:
        aggregations = {"Quality": "mean", "TTFT (s)": "mean", "Link": "first"}

    if "Stage" not in frame.columns:
        # Nothing usable was loaded (e.g. an empty results directory).
        return pd.DataFrame(columns=key_columns)

    filtered = frame[frame["Stage"] == stage]
    if model_types:
        filtered = filtered[filtered["Model"].isin(model_types)]
    if datasets:
        filtered = filtered[filtered["Dataset"].isin(datasets)]

    if not filtered.empty:
        # Only aggregate metrics that actually exist; a result file without
        # e.g. a "Link" field must not break the whole table.
        present = {col: how for col, how in aggregations.items() if col in filtered.columns}
        if present:
            filtered = filtered.groupby(key_columns, as_index=False).agg(present)
        else:
            filtered = filtered[key_columns].drop_duplicates().reset_index(drop=True)

    display_columns = key_columns + [
        col for col in (selected_columns or []) if col in filtered.columns
    ]
    if filtered.empty:
        return pd.DataFrame(columns=display_columns)
    return filtered[display_columns]
|
|
|
|
|
def create_prefill_visualization(filtered_data):
    """Scatter Quality against TTFT for prefill-stage results, coloured by method.

    Args:
        filtered_data: A frame like the one ``filter_and_display`` returns for
            the prefill stage. Plotting requires the ``TTFT (s)``, ``Quality``,
            ``Method``, ``Model`` and ``Dataset`` columns.

    Returns:
        A Plotly figure. Returns ``None`` when there is nothing to plot: no
        frame, no rows, or a required column missing (e.g. the metric is absent
        from the data). ``px.scatter`` would raise in that last case.
    """
    required = ["TTFT (s)", "Quality", "Method", "Model", "Dataset"]
    if filtered_data is None or filtered_data.empty:
        return None
    if any(col not in filtered_data.columns for col in required):
        return None

    fig = px.scatter(
        filtered_data,
        x='TTFT (s)',
        y='Quality',
        color='Method',
        hover_data=['Model', 'Dataset'],
        title='Prefill Stage: Quality vs TTFT (s) by Method'
    )
    # The y-axis is pinned to 0-100 (presumably a percentage quality score;
    # confirm). The x-axis lower bound is 0 and the upper bound is left as
    # None, which is intended to be automatic.
    fig.update_layout(
        yaxis=dict(range=[0, 100]),
        xaxis=dict(range=[0, None])
    )
    return fig
|
|
|
|
|
def create_decode_visualization(filtered_data):
    """Scatter Quality against throughput for decode-stage results, coloured by method.

    Args:
        filtered_data: A frame like the one ``filter_and_display`` returns for
            the decode stage. Plotting requires the ``Throughput (token/s)``,
            ``Quality``, ``Method``, ``Model`` and ``Dataset`` columns.

    Returns:
        A Plotly figure. Returns ``None`` when there is nothing to plot: no
        frame, no rows, or a required column missing. ``px.scatter`` would
        raise in that last case.
    """
    required = ["Throughput (token/s)", "Quality", "Method", "Model", "Dataset"]
    if filtered_data is None or filtered_data.empty:
        return None
    if any(col not in filtered_data.columns for col in required):
        return None

    fig = px.scatter(
        filtered_data,
        x='Throughput (token/s)',
        y='Quality',
        color='Method',
        hover_data=['Model', 'Dataset'],
        title='Decode Stage: Quality vs Throughput by Method'
    )
    # The y-axis is pinned to 0-100 (presumably a percentage quality score;
    # confirm). The x-axis lower bound is 0 and the upper bound is left as
    # None, which is intended to be automatic.
    fig.update_layout(
        yaxis=dict(range=[0, 100]),
        xaxis=dict(range=[0, None])
    )
    return fig
|
|
|
|
|
|
|
|
# Directory holding the result files. The path is relative, so it resolves
# against the process's current working directory.
data_dir = "data"

# Loaded at import time; the functions in this module read this
# module-level frame.
data = load_data(data_dir)
|
|
|
|
|
|
|
|
def _stage_table_and_plot(stage, selected_columns, model_types, datasets, plot_columns, plot_fn):
    """Return ``(table, figure)`` for one stage under the given filter selections.

    An empty model or dataset selection yields an empty table and no figure.
    The table keeps the column layout ``filter_and_display`` uses: the key
    columns plus the selected metrics. The figure is built from
    ``plot_columns`` rather than ``selected_columns``, so hiding a metric in
    the table does not strip an axis the plot needs.
    """
    if not model_types or not datasets:
        placeholder = pd.DataFrame(
            columns=["Method", "Model", "Dataset"] + list(selected_columns or [])
        )
        return placeholder, None
    table = filter_and_display(selected_columns, model_types, datasets, stage)
    plot_data = filter_and_display(plot_columns, model_types, datasets, stage)
    return table, plot_fn(plot_data)


def _build_stage_section(stage, heading, column_choices, default_columns,
                         plot_columns, plot_heading, plot_fn):
    """Lay out one stage's filter controls, results table and plot, and wire them.

    Must be called inside an open ``gr.Blocks`` context. Checkbox choices and
    defaults are read from the module-level ``data`` at build time.

    Returns:
        ``(inputs, refresh, outputs)``:
        - ``inputs``: the three filter components.
        - ``refresh``: a function mapping their values to ``(table, figure)``.
        - ``outputs``: the ``[table, plot]`` components.
        These are used to wire the reload button.
    """
    models = list(data["Model"].unique())
    stage_datasets = list(data[data["Stage"] == stage]["Dataset"].unique())

    with gr.Row():
        gr.Markdown(heading)
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Select Columns to Display")
            columns_box = gr.CheckboxGroup(
                choices=column_choices,
                label="Columns",
                value=default_columns
            )
        with gr.Column():
            gr.Markdown("#### Model Types")
            models_box = gr.CheckboxGroup(
                choices=models,
                label="Model Types",
                value=models
            )
        with gr.Column():
            gr.Markdown("#### Datasets")
            datasets_box = gr.CheckboxGroup(
                choices=stage_datasets,
                label="Datasets",
                value=stage_datasets
            )

    with gr.Row():
        gr.Markdown("## Results")
    default_table, default_figure = _stage_table_and_plot(
        stage, default_columns, models, stage_datasets, plot_columns, plot_fn
    )
    table = gr.Dataframe(value=default_table)

    with gr.Row():
        gr.Markdown(plot_heading)
    with gr.Row():
        plot = gr.Plot(value=default_figure)

    def refresh(selected_columns, model_types, datasets):
        return _stage_table_and_plot(
            stage, selected_columns, model_types, datasets, plot_columns, plot_fn
        )

    inputs = [columns_box, models_box, datasets_box]
    outputs = [table, plot]
    # Every filter change refreshes both the table and the plot.
    for control in inputs:
        control.change(refresh, inputs=inputs, outputs=outputs)
    return inputs, refresh, outputs


def create_gradio_app():
    """Build the leaderboard UI: a benchmark tab plus About and Submission tabs.

    The benchmark tab has one section per stage (prefill, decode). Each section
    has column/model/dataset filters, a results table and a scatter plot, all
    of which update when a filter changes. The "Reload Data" button re-reads
    the result files from ``data_dir`` and refreshes both sections.

    Returns:
        The ``gr.Blocks`` app. It is not launched.
    """
    with gr.Blocks() as app:
        with gr.Row():
            gr.Markdown(
                """# KV Cache Benchmark
### Demo leaderboard
This demo leaderboard allows users to explore and compare different KV cache implementations across various models and datasets. It provides interactive filtering options and real-time updates of benchmark results, including visualization of Quality and TTFT (s) metrics.
""")

        with gr.Tabs():
            with gr.TabItem("KV Cache Benchmark"):
                prefill_inputs, prefill_refresh, prefill_outputs = _build_stage_section(
                    "prefill",
                    "## Prefill-Stage KV Cache Compression",
                    ["Quality", "TTFT (s)", "Link"],
                    ["Quality", "TTFT (s)"],
                    ["Quality", "TTFT (s)"],
                    "### Prefill-stage Visualization",
                    create_prefill_visualization,
                )
                decode_inputs, decode_refresh, decode_outputs = _build_stage_section(
                    "decode",
                    "## Decode-Stage KV Cache Compression",
                    ["Throughput (token/s)", "Quality", "Link"],
                    ["Throughput (token/s)", "Quality"],
                    ["Throughput (token/s)", "Quality"],
                    "### Decode-Stage Visualization",
                    create_decode_visualization,
                )

                def reload_data(p_columns, p_models, p_datasets,
                                d_columns, d_models, d_datasets):
                    # Re-read the results server-side. A browser page reload
                    # cannot do this, because ``data`` is loaded once in this
                    # process. Rebinding the module global makes every later
                    # filter_and_display call see the fresh frame.
                    # NOTE: checkbox choices are not rebuilt, so newly added
                    # models/datasets appear only after an app restart.
                    global data
                    data = load_data(data_dir)
                    return (
                        *prefill_refresh(p_columns, p_models, p_datasets),
                        *decode_refresh(d_columns, d_models, d_datasets),
                    )

                reload_button = gr.Button("Reload Data")
                reload_button.click(
                    reload_data,
                    inputs=prefill_inputs + decode_inputs,
                    outputs=prefill_outputs + decode_outputs
                )

            with gr.TabItem("About"):
                gr.Markdown(about_markdown)

            with gr.TabItem("Submission Instructions"):
                gr.Markdown(submission_markdown)

    return app
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Blocks UI, then start the Gradio server
    # with its default launch settings.
    app = create_gradio_app()
    app.launch()
|
|
|