Spaces:
Runtime error
Runtime error
| import altair as alt | |
| import gradio as gr | |
| import pandas as pd | |
| from functools import partial | |
| from datasets import load_dataset | |
def get_data(model_id="ybelkada/model_cards_correct_tag"):
    """Load the model-card tag statistics and derive the two plotting frames.

    Args:
        model_id: Hub dataset repository passed to ``load_dataset`` (the
            ``train`` split is used). Defaults to the dataset this app plots.

    Returns:
        A ``(ratio_df, melted_df)`` tuple, both ordered by ``commit_dates``:

        * ``ratio_df`` -- columns ``commit_dates`` and ``ratio``, where
          ``ratio = (1 - missing_library_name / total_transformers_model) * 100``.
          Rows whose ``total_transformers_model`` is 0 get NaN rather than an
          infinite or undefined value.
        * ``melted_df`` -- long-form frame with columns ``commit_dates``,
          ``type`` and ``value``, holding the ``total_transformers_model`` and
          ``missing_library_name`` counts.
    """
    # ``to_pandas()`` already yields a DataFrame; no extra wrapping is needed.
    df = load_dataset(model_id, split="train").to_pandas()
    df["commit_dates"] = pd.to_datetime(df["commit_dates"])
    df = df.sort_values(by="commit_dates")

    melted_df = pd.melt(
        df,
        id_vars=["commit_dates"],
        value_vars=["total_transformers_model", "missing_library_name"],
        var_name="type",
    )

    # Mask zero totals so the division produces NaN instead of inf/-inf, which
    # would otherwise end up in the chart's y-axis domain.
    totals = df["total_transformers_model"].where(df["total_transformers_model"] != 0)
    ratio_df = df[["commit_dates"]].copy()
    ratio_df["ratio"] = (1 - df["missing_library_name"] / totals) * 100
    return ratio_df, melted_df
# Fetch the data once at import time. make_plot() reads these module-level
# frames and only replaces them when it is called with refresh=True.
(ratio_df, melted_df) = get_data()
def _hover_highlight(field):
    """Return a mouseover, nearest-point selection keyed on *field*."""
    # Altair 5 deprecated ``alt.selection(type='single')`` in favour of
    # ``alt.selection_point``; fall back to the old API on Altair 4 installs.
    if hasattr(alt, "selection_point"):
        return alt.selection_point(on="mouseover", fields=[field], nearest=True)
    return alt.selection(type="single", on="mouseover", fields=[field], nearest=True)


def _attach_param(chart, param):
    """Register *param* on *chart*, preferring Altair 5's ``add_params``."""
    if hasattr(chart, "add_params"):
        return chart.add_params(param)
    return chart.add_selection(param)


def _finite_domain(series):
    """Return ``(min, max)`` over the finite values of *series*, or ``alt.Undefined``.

    NaN and +/-inf are not valid JSON numbers, so they must not be used as scale
    bounds. When no finite value exists, Altair picks the domain automatically.
    """
    finite = series[~series.isin([float("inf"), float("-inf")])].dropna()
    if finite.empty:
        return alt.Undefined
    return (finite.min(), finite.max())


def _layered_chart(data, y_field, y_title, highlight_field, color=None):
    """Build the shared "points + hover-thickened line" chart for one frame.

    Args:
        data: DataFrame with a ``commit_dates`` column and the *y_field* column.
        y_field: Quantitative column plotted on the y axis.
        y_title: Y-axis title.
        highlight_field: Field the hover selection is keyed on.
        color: Optional color encoding (for example ``"type:N"``).
    """
    highlight = _hover_highlight(highlight_field)
    encoding = {
        "x": alt.X("commit_dates:T", title="Date"),
        "y": alt.Y(
            f"{y_field}:Q",
            scale=alt.Scale(domain=_finite_domain(data[y_field])),
            title=y_title,
        ),
    }
    if color is not None:
        encoding["color"] = color
    base = alt.Chart(data).encode(**encoding)

    points = _attach_param(
        base.mark_circle().encode(opacity=alt.value(1)),
        highlight,
    ).properties(width=1200, height=800)
    # Lines get thicker (size 3) when matched by the hover selection.
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3))
    )
    return points + lines


def make_plot(plot_type, refresh=False):
    """Return the Altair chart for the plot type selected in the UI.

    Args:
        plot_type: The radio label. ``"Total models with missing 'transformers'
            tag"`` selects the per-type count chart (from ``melted_df``). Any
            other value selects the ratio chart (from ``ratio_df``).
        refresh: When True, re-fetch via ``get_data()`` and replace the
            module-level ``ratio_df`` / ``melted_df`` before plotting.

    Returns:
        A layered Altair chart (points + lines).
    """
    global ratio_df, melted_df
    if refresh:
        ratio_df, melted_df = get_data()

    if plot_type == "Total models with missing 'transformers' tag":
        return _layered_chart(
            melted_df,
            "value",
            "Count",
            highlight_field="type",
            color="type:N",
        )

    # NOTE(review): ratio_df has a single series, so keying the hover
    # selection on the continuous ``ratio`` field is of limited use; kept
    # unchanged to preserve the original interaction.
    return _layered_chart(
        ratio_df,
        "ratio",
        "(1 - missing_library_name / total_transformers_model) * 100 - Higher is better",
        highlight_field="ratio",
    )
with gr.Blocks() as demo:
    # Radio that chooses which chart make_plot() renders.
    button = gr.Radio(
        choices=[
            "Total models with missing 'transformers' tag",
            "Proportion of models correctly tagged with 'transformers' tag",
        ],
        value="Total models with missing 'transformers' tag",
        label="Plot type",
    )
    refresh_button = gr.Button(value="Fetch latest data")
    plot = gr.Plot(label="Plot")

    # Redraw the chart on three events:
    # - when the radio selection changes;
    # - when refresh is clicked (make_plot re-fetches the data first);
    # - once when the page loads.
    button.change(fn=make_plot, inputs=[button], outputs=[plot])
    refresh_button.click(
        fn=partial(make_plot, refresh=True),
        inputs=[button],
        outputs=[plot],
    )
    demo.load(fn=make_plot, inputs=[button], outputs=[plot])

if __name__ == "__main__":
    demo.launch()