# src/kaira-agent.py
import streamlit as st
import os
import numpy as np
import pandas as pd
import faiss
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
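# Assumed dependencies for these imports (not pinned in this file): streamlit, numpy,
# pandas, faiss-cpu (or faiss-gpu), langchain-community, langchain-text-splitters,
# sentence-transformers, transformers, and torch; device_map="auto" below typically
# also requires the accelerate package.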
# -------------------- Models --------------------
@st.cache_resource(show_spinner=False)
def load_embedding_model():
model_name = "paraphrase-MiniLM-L6-v2"
return SentenceTransformer(model_name)
embedding_model = load_embedding_model()
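# paraphrase-MiniLM-L6-v2 produces 384-dimensional sentence embeddings; the FAISS
# index dimension in build_index() is derived from the first embedding, so swapping
# in another SentenceTransformer model requires no other changes.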
@st.cache_resource(show_spinner=True)
def load_hf_model():
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # small-ish, but still heavy
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
torch_dtype="auto"
)
text_gen = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=256, # smaller = faster
do_sample=True,
temperature=0.3,
top_p=0.9,
)
return text_gen
text_gen = load_hf_model()
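# Both loaders above are wrapped in st.cache_resource, so the embedding model and
# the generation pipeline are built once per Streamlit server process and reused
# across the script reruns triggered by each user interaction.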
# -------------------- Index building (cached) --------------------
@st.cache_resource(show_spinner=True)
def build_index():
"""Load PDF, split, embed, and build FAISS index once."""
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
pdf_path = os.path.join(BASE_DIR, "About_india.pdf")
if not os.path.exists(pdf_path):
raise FileNotFoundError(f"PDF not found at: {pdf_path}")
loader = PyPDFLoader(pdf_path)
documents = loader.load()
splitter = RecursiveCharacterTextSplitter(chunk_size=150, chunk_overlap=30)
# Split all docs at once
chunks = splitter.split_documents(documents)
embeddings = []
documents_text = []
sources = []
for i, doc in enumerate(chunks):
text = doc.page_content
emb = embedding_model.encode(text)
embeddings.append(emb)
documents_text.append(text)
# could store page/chunk info here; using placeholder
sources.append(f"chunk_{i+1}")
if not embeddings:
raise ValueError("No chunks were created from the PDF.")
embedding_dimension = len(embeddings[0])
index = faiss.IndexFlatL2(embedding_dimension)
index.add(np.array(embeddings, dtype="float32"))
df = pd.DataFrame({"documents": documents_text, "source": sources})
return index, df
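# Note: build_index() encodes chunks one at a time. A faster variant (a sketch,
# not used here) would batch-encode them in a single call, since
# SentenceTransformer.encode also accepts a list of strings:
#   texts = [doc.page_content for doc in chunks]
#   embeddings = embedding_model.encode(texts)  # shape: (num_chunks, dim)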
# -------------------- Streamlit UI --------------------
st.title("πŸ“š Ask about the states and capitals of India and USA. Q&A (RAG)")
#st.write("Ask about the states and capitals of India and USA.")
# Build or load index (only first time is heavy)
with st.spinner("Loading document and building index (first time only)..."):
    index, df = build_index()
query = st.text_input("Enter your query:")
if query:
    with st.spinner("Kaira Agent is fetching the answer, please wait..."):
        # Embed the query and retrieve the 5 nearest chunks.
        query_embedding = embedding_model.encode(query).reshape(1, -1)
        distances, indices = index.search(query_embedding, k=5)

        # For L2 distance, smaller = closer. We always answer regardless of distance.
        top_dist = float(distances[0][0])
        # st.write(f"Debug: top distance = {top_dist:.2f}")

        combined_similar_documents_content = []
        similar_documents_sources = []
        for i in indices[0]:
            similar_document_content = df.loc[i, "documents"]
            combined_similar_documents_content.append(similar_document_content)
            similar_document_source = df.loc[i, "source"]
            similar_documents_sources.append(similar_document_source)

        context_text = " ".join(combined_similar_documents_content)
        prompt = f"""
You are a helpful assistant answering questions about countries, states, capital cities, population, and exports.
Context:
\"\"\"{context_text}\"\"\"
Question: {query}
Answer in one short sentence about the countries, states, and capital cities asked in the question above:
"""

        # Generate the answer.
        result = text_gen(
            prompt,
            max_new_tokens=32,  # short answer
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
        )[0]["generated_text"]

        # The pipeline returns the prompt plus the completion, so strip the prompt prefix.
        answer = result[len(prompt):].strip()

    st.subheader("🤖 Kaira Agent Response")
    st.write(answer)
    # st.subheader("📚 Sources (chunks used)")
    # st.write(list(set(similar_documents_sources)))
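
# To launch the app locally (assuming this file sits at src/kaira-agent.py as in the repo path):
#   streamlit run src/kaira-agent.py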