Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import plotly.express as px | |
| import pandas as pd | |
| import random | |
| import logging | |
| from umap import UMAP | |
| from sentence_transformers import SentenceTransformer, util | |
| from datasets import load_dataset | |
@st.cache_resource
def load_model():
    """Return the multilingual MiniLM paraphrase sentence-embedding model.

    Streamlit re-executes this script on every widget interaction; the
    ``st.cache_resource`` decorator keeps a single model instance alive
    across reruns instead of constructing it again each time.
    """
    return SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
@st.cache_resource
def load_words_dataset():
    """Return the "Word" column of the WordNet definitions training split.

    Cached with ``st.cache_resource`` so the dataset is loaded once per
    server process rather than on every script rerun. The cached object is
    shared between reruns, so callers must treat it as read-only.
    """
    dataset = load_dataset("marksverdhei/wordnet-definitions-en-2021", split="train")
    return dataset["Word"]
| # @st.cache_resource | |
| # def prepare_umap(): | |
| # all_enc = model.encode(all_words) | |
| # umap_3d = UMAP(n_components=3, init='random', random_state=0) | |
| # proj_3d = umap_3d.fit_transform(random.sample(all_enc.tolist(), k=2000)) | |
| # return umap_3d | |
all_words = load_words_dataset()
model = load_model()
#umap_3d = prepare_umap()

# Streamlit re-runs the whole script on every interaction. Drawing the secret
# with random.choice() at module level would therefore pick a new word after
# each guess, while the similarities of earlier guesses (kept in session
# state) would still refer to the previous word. Pick the secret once per
# session and keep both the word and its embedding in session state.
if 'secret_word' not in st.session_state:
    st.session_state['secret_word'] = random.choice(all_words)
    st.session_state['secret_embedding'] = model.encode(
        st.session_state['secret_word']
    )
    # Printed to the server console only (debug aid), once per session.
    print("Secret word ", st.session_state['secret_word'])

secret_word = st.session_state['secret_word']
secret_embedding = st.session_state['secret_embedding']
# Guess history: a list of (word, similarity) tuples, created on first run.
if 'words' not in st.session_state:
    st.session_state['words'] = []

# Plot data: one row per plotted point. The first row is the secret word,
# drawn larger ("s" = 10) and with similarity 1.
if 'words_umap_df' not in st.session_state:
    # Start from an empty frame and append, mirroring how guesses are added.
    words_umap_df = pd.DataFrame(
        {column: [] for column in ("x", "y", "z", "similarity", "s", "l")}
    )

    #secret_embedding_3d = umap_3d.transform([secret_embedding])[0]
    # Placeholder coordinates while the UMAP projection is disabled.
    secret_embedding_3d = [0, 1, 2]

    secret_row = dict(zip(("x", "y", "z"), secret_embedding_3d))
    secret_row["similarity"] = 1
    secret_row["s"] = 10
    secret_row["l"] = "Secret word"
    words_umap_df.loc[len(words_umap_df)] = secret_row

    st.session_state['words_umap_df'] = words_umap_df
st.write('Try to guess a secret word by semantic similarity')
word = st.text_input("Input a word")

# Normalise the input once: surrounding whitespace must not turn an already
# used word into a "new" guess, and a blank entry must not be scored at all.
guess = word.strip()
used_words = {w for w, _ in st.session_state['words']}

if st.button("Guess") or guess:
    # Pressing the button with an empty box used to encode "" and record it
    # as a guess; require a non-empty, not-yet-seen word before scoring.
    if guess and guess not in used_words:
        word_embedding = model.encode(guess)
        # Cosine similarity between the secret and the guess, unwrapped from
        # the 1x1 tensor returned by pytorch_cos_sim.
        similarity = util.pytorch_cos_sim(
            secret_embedding,
            word_embedding
        ).cpu().numpy()[0][0]
        st.session_state['words'].append((guess, similarity))

        #pt = umap_3d.transform([word_embedding])[0]
        # Placeholder coordinates while the UMAP projection is disabled.
        pt = [0, 1, 2]
        words_umap_df = st.session_state['words_umap_df']
        words_umap_df.loc[len(words_umap_df)] = {
            "x": pt[0],
            "y": pt[1],
            "z": pt[2],
            "similarity": similarity,
            "s": 3,
            "l": guess
        }
        st.session_state['words_umap_df'] = words_umap_df

# NOTE(review): the original indentation of the rendering code below was lost,
# so it is unclear whether it was nested under the "Guess" condition. It is
# rendered unconditionally here so the table and plot stay visible on every
# rerun -- confirm this matches the intended layout.

# Guess table, best match first.
words_df = pd.DataFrame(
    st.session_state['words'],
    columns=["word", "similarity"]
).sort_values(by=["similarity"], ascending=False)
st.dataframe(words_df, use_container_width=True)

# 3D scatter of the secret word and all guesses, coloured by similarity on a
# fixed 0..1 scale; hover shows only the label and similarity.
words_umap_df = st.session_state['words_umap_df']
fig_3d = px.scatter_3d(
    words_umap_df,
    x="x",
    y="y",
    z="z",
    color="similarity",
    hover_name="l",
    hover_data={"x": False, "y": False, "z": False, "s": False},
    size="s",
    size_max=10,
    range_color=(0, 1),
)
st.plotly_chart(fig_3d, theme="streamlit", use_container_width=True)