Spaces:
Sleeping
Sleeping
| import faiss | |
| import numpy as np | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| import streamlit as st | |
| from streamlit.column_config import LinkColumn | |
| import os | |
| os.environ['KMP_DUPLICATE_LIB_OK']='True' | |
@st.cache_resource
def load_model():
    """Load the sentence-embedding model used to encode search queries.

    Streamlit re-executes the whole script on every widget interaction, so
    the loader is wrapped in ``st.cache_resource``: the model is built once
    and the same instance is reused on later reruns instead of being
    re-created each time.

    Returns:
        The ``SentenceTransformer`` instance for
        ``sbintuitions/sarashina-embedding-v1-1b``.
    """
    return SentenceTransformer("sbintuitions/sarashina-embedding-v1-1b")
def load_title_data():
    """Read the paper list from ``anlp2025.tsv``.

    The file is parsed as tab-separated text without a header row; its two
    columns are named ``pid`` and ``title``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``pid`` and ``title``.
    """
    return pd.read_csv(
        "anlp2025.tsv",
        sep="\t",
        names=["pid", "title"],
    )
def load_title_embeddings():
    """Load the title embedding matrix stored in ``anlp2025.npz``.

    The archive is opened with a context manager so its underlying file
    handle is closed as soon as the array has been read, rather than being
    left open for the lifetime of the process.

    Returns:
        The array saved under the key ``arr_0`` (the name NumPy gives to the
        first positional array passed to ``np.savez``).
    """
    with np.load("anlp2025.npz") as archive:
        # Indexing the archive reads the member into memory, so the returned
        # array remains valid after the archive is closed.
        return archive["arr_0"]
def get_retrieval_results(index, input_text, top_k, model, title_df):
    """Search the index for the titles closest to ``input_text``.

    Args:
        index: Object exposing ``search(x=..., k=...)`` that returns a
            ``(distances, ids)`` pair, as the FAISS index built by the
            caller does.
        input_text: Query string to embed.
        top_k: Maximum number of results to return.
        model: Object exposing ``encode(list_of_texts)``; used to embed the
            query.
        title_df: DataFrame with ``pid`` and ``title`` columns whose index
            labels correspond to the ids returned by ``index.search``.

    Returns:
        A DataFrame with the columns ``pid``, ``paper`` (the title) and
        ``pdf`` (the proceedings PDF URL built from the pid), ordered as
        returned by the index. It holds fewer than ``top_k`` rows when the
        index cannot supply that many hits.
    """
    query_embeddings = model.encode([input_text])
    _, ids = index.search(x=query_embeddings, k=top_k)

    # FAISS fills the id row with -1 when it has fewer than top_k vectors;
    # looking those up in title_df would raise KeyError, so drop them.
    hit_ids = [int(hit) for hit in ids[0] if hit >= 0]

    hits = title_df.loc[hit_ids]
    retrieved_pids = hits["pid"].tolist()
    return pd.DataFrame({
        "pid": retrieved_pids,
        "paper": hits["title"].tolist(),
        "pdf": [f'https://www.anlp.jp/proceedings/annual_meeting/2025/pdf_dir/{pid}.pdf' for pid in retrieved_pids],
    })
if __name__ == "__main__":
    # Streamlit entry point: load the search assets, build the index, then
    # render the query form and (on button press) the result table.
    model = load_model()
    title_df = load_title_data()
    title_embeddings = load_title_embeddings()

    # Size the index from the embeddings themselves instead of hard-coding
    # the model's output width, so a different embedding file still loads.
    index = faiss.IndexFlatL2(title_embeddings.shape[1])
    index.add(title_embeddings)

    st.markdown("## NLP2025 論文検索")
    st.html("大会公式ページは<a href='https://www.anlp.jp/proceedings/annual_meeting/2025/' target='_blank'>こちら</a>")
    input_text = st.text_input('query', '', placeholder='')
    top_k = st.number_input('top_k', min_value=1, value=10, step=1)
    # Render the "pdf" column as a clickable link icon instead of a raw URL.
    column_config = {
        "pdf": LinkColumn(
            display_text="🔗"
        )
    }

    if st.button('検索'):
        stripped_input_text = input_text.strip()
        if stripped_input_text:
            # Never ask for more neighbours than the index holds; FAISS would
            # otherwise pad the result with invalid (-1) ids.
            effective_top_k = min(int(top_k), index.ntotal)
            df = get_retrieval_results(index, stripped_input_text, effective_top_k, model, title_df)
            st.dataframe(df, column_config=column_config, width=720)
        else:
            # An empty query has no meaningful nearest neighbours.
            st.warning("検索クエリを入力してください。")