File size: 3,144 Bytes
aa30c3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# pdf_airavata_qa.py

import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import requests
import os

# ----------------------------
# 1. Load and split PDF
# ----------------------------
def load_and_chunk(pdf_path):
    """Load a PDF and split its pages into overlapping text chunks.

    Args:
        pdf_path: Filesystem path to the PDF to ingest.

    Returns:
        List of plain-text chunk strings (1000 chars each, 200-char
        overlap) suitable for embedding.
    """
    documents = PyPDFLoader(pdf_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    # Keep only the raw text of each chunk; metadata is not needed downstream.
    return [chunk.page_content for chunk in splitter.split_documents(documents)]

# ----------------------------
# 2. Build embedding index
# ----------------------------
def build_index(texts):
    """Embed text chunks and build an exact L2 FAISS index over them.

    Args:
        texts: List of chunk strings to embed.

    Returns:
        Tuple of (faiss.IndexFlatL2, numpy embedding matrix).
    """
    # Any sentence-transformers model works here; this one is a solid default.
    model = SentenceTransformer('all-mpnet-base-v2')
    vectors = model.encode(texts, convert_to_numpy=True)
    dim = vectors.shape[1]
    idx = faiss.IndexFlatL2(dim)
    idx.add(vectors)
    return idx, vectors

# ----------------------------
# 3. Airavata API call
# ----------------------------
def call_airavata(prompt):
    """Generate a completion for *prompt* via the HuggingFace Inference API.

    Args:
        prompt: Full prompt string (context + question) to send to the model.

    Returns:
        The generated text on success, or a human-readable "Error: ..."
        string on any failure (network error, HTTP error, bad payload).
    """
    # Example: Using HuggingFace Inference API (replace with your key or local endpoint)
    API_URL = "https://api-inference.huggingface.co/models/ai4bharat/airavata"  # Check actual endpoint
    API_TOKEN = os.environ.get("HF_API_TOKEN")  # Set your token in environment
    headers = {"Authorization": f"Bearer {API_TOKEN}"}

    payload = {"inputs": prompt}
    try:
        # timeout prevents the UI from hanging forever if the API stalls.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=120)
    except requests.RequestException as exc:
        return f"Error: request failed - {exc}"

    if response.status_code == 200:
        result = response.json()
        # Error responses (e.g. model loading) come back as a dict, not a
        # list of generations — guard instead of crashing with KeyError.
        try:
            return result[0]['generated_text']
        except (KeyError, IndexError, TypeError):
            return f"Error: unexpected response format - {result!r}"
    else:
        return f"Error: {response.status_code} - {response.text}"

# ----------------------------
# 4. PDF Q&A function
# ----------------------------
# Module-level state shared across qa() calls: the current PDF's text
# chunks and the FAISS index built over them (index is None until a
# PDF has been uploaded).
texts, index = [], None

def qa(pdf_file, question):
    """Answer *question* using the uploaded PDF as retrieval context.

    Args:
        pdf_file: Gradio file object for a newly uploaded PDF, or None to
            reuse the previously indexed document.
        question: Natural-language question about the document.

    Returns:
        The model's answer string, or a prompt to upload a PDF first.
    """
    global texts, index

    # Cache the query embedding model on the function object: reloading
    # SentenceTransformer on every call is slow and memory-hungry.
    if not hasattr(qa, "_embed_model"):
        qa._embed_model = SentenceTransformer('all-mpnet-base-v2')

    if pdf_file is not None:
        # Load PDF and (re)build the retrieval index for the new document.
        texts = load_and_chunk(pdf_file.name)
        index, _ = build_index(texts)

    if not texts or index is None:
        return "Please upload a PDF first."

    # Embed the question with the same model used for the chunks.
    q_emb = qa._embed_model.encode([question], convert_to_numpy=True)

    # Retrieve the top relevant chunks. Clamp k to the chunk count:
    # FAISS pads results with -1 when k > ntotal, and texts[-1] would
    # silently duplicate the last chunk.
    k = min(5, len(texts))
    D, I = index.search(q_emb, k=k)
    context = "\n\n".join(texts[i] for i in I[0])

    # Build prompt for Airavata
    prompt = f"""
You are an AI assistant. Use the following document context to answer the question.

Context:
{context}

Question: {question}

Answer:
"""
    # Get answer from Airavata
    return call_airavata(prompt)

# ----------------------------
# 5. Gradio Interface
# ----------------------------
# Assemble the UI components, then wire them into the Q&A callback.
pdf_input = gr.File(label="Upload PDF")
question_input = gr.Textbox(
    label="Ask your question",
    placeholder="Type a question about the PDF...",
)
answer_output = gr.Textbox(label="Answer")

demo = gr.Interface(
    fn=qa,
    inputs=[pdf_input, question_input],
    outputs=answer_output,
    title="PDF Q&A with Airavata",
    description="Upload a PDF and ask questions. Airavata will answer based on the document.",
)

demo.launch()