import streamlit as st
import os
import numpy as np
import pandas as pd
import faiss
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
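
# RAG pipeline implemented below:
#   1. Load the PDF and split it into small overlapping chunks.
#   2. Embed each chunk with a SentenceTransformer model.
#   3. Store the embeddings in a FAISS L2 index.
#   4. At query time, embed the question, retrieve the nearest chunks,
#      and pass them as context to a small causal LM to generate the answer.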

# -------------------- Models --------------------
@st.cache_resource
def load_embedding_model():
    """Load the sentence-embedding model once and reuse it across reruns."""
    model_name = "paraphrase-MiniLM-L6-v2"
    return SentenceTransformer(model_name)

embedding_model = load_embedding_model()

@st.cache_resource
def load_hf_model():
    """Load the generator model and wrap it in a text-generation pipeline."""
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # small-ish, but still heavy
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype="auto",
    )
    text_gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=256,  # smaller = faster
        do_sample=True,
        temperature=0.3,
        top_p=0.9,
    )
    return text_gen

text_gen = load_hf_model()
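
# Note: device_map="auto" relies on the `accelerate` package and places the model
# on a GPU when one is available, otherwise it falls back to CPU.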

# -------------------- Index building (cached) --------------------
@st.cache_resource
def build_index():
    """Load PDF, split, embed, and build FAISS index once."""
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    pdf_path = os.path.join(BASE_DIR, "About_india.pdf")
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"PDF not found at: {pdf_path}")

    loader = PyPDFLoader(pdf_path)
    documents = loader.load()

    splitter = RecursiveCharacterTextSplitter(chunk_size=150, chunk_overlap=30)
    # Split all docs at once
    chunks = splitter.split_documents(documents)

    embeddings = []
    documents_text = []
    sources = []
    for i, doc in enumerate(chunks):
        text = doc.page_content
        emb = embedding_model.encode(text)
        embeddings.append(emb)
        documents_text.append(text)
        # could store page/chunk info here; using a placeholder
        sources.append(f"chunk_{i+1}")

    if not embeddings:
        raise ValueError("No chunks were created from the PDF.")

    embedding_dimension = len(embeddings[0])
    index = faiss.IndexFlatL2(embedding_dimension)
    index.add(np.array(embeddings, dtype="float32"))

    df = pd.DataFrame({"documents": documents_text, "source": sources})
    return index, df
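
# Optional: a FAISS index can also be persisted to disk so it does not have to be
# rebuilt after every restart (file names below are illustrative examples only):
#   faiss.write_index(index, "about_india.index")    # save, e.g. inside build_index()
#   index = faiss.read_index("about_india.index")    # load it back on later runs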

# -------------------- Streamlit UI --------------------
st.title("Ask about the states and capitals of India and the USA (RAG Q&A)")
#st.write("Ask about the states and capitals of India and USA.")

# Build or load index (only the first run is heavy)
with st.spinner("Loading document and building index (first time only)..."):
    index, df = build_index()

query = st.text_input("Enter your query:")

if query:
    with st.spinner("Kaira Agent is fetching the answer, please wait..."):
        # Embed query
        query_embedding = embedding_model.encode(query).reshape(1, -1)
        distances, indices = index.search(query_embedding, k=5)

        # For L2 distance: smaller = closer. We answer regardless of the distance.
        top_dist = float(distances[0][0])
        #st.write(f"Debug: top distance = {top_dist:.2f}")

        combined_similar_documents_content = []
        similar_documents_sources = []
        for i in indices[0]:
            similar_document_content = df.loc[i, "documents"]
            combined_similar_documents_content.append(similar_document_content)
            similar_document_source = df.loc[i, "source"]
            similar_documents_sources.append(similar_document_source)

        context_text = " ".join(combined_similar_documents_content)
        prompt = f"""
You are a helpful assistant answering questions about countries, states, capital cities, population, and exports.

Context:
\"\"\"{context_text}\"\"\"

Question: {query}

Answer the question above in one short sentence:
"""
        # Generate answer
        result = text_gen(
            prompt,
            max_new_tokens=32,  # short answer
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
        )[0]["generated_text"]

        # Remove the prompt prefix from the output
        answer = result[len(prompt):].strip()

    st.subheader("Kaira Agent Response")
    st.write(answer)

    #st.subheader("Sources (chunks used)")
    #st.write(list(set(similar_documents_sources)))
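
# To run locally (assuming this script is saved as app.py):
#   streamlit run app.py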