# widget-RAG / RAG-demo.py
import spaces
import torch
import gradio as gr
from transformers import pipeline, BitsAndBytesConfig
from googlesearch import search

# Load the model weights in 4-bit via bitsandbytes to reduce GPU memory use
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
# Llama 3.2 3B Instruct setup
llama3_model_id = "meta-llama/Llama-3.2-3B-Instruct"
llama3_pipe = pipeline(
    "text-generation",
    model=llama3_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    model_kwargs={"quantization_config": quantization_config},
)
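# Note (assumption, not from the original Space): 4-bit loading relies on the
# bitsandbytes backend and expects a CUDA device to be available. If you run this
# script somewhere without a GPU, a rough fallback is to drop the quantization, e.g.:
#   llama3_pipe = pipeline("text-generation", model=llama3_model_id, device_map="cpu")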
# Google search setup
def google_search_results(input_question: str):
    """Return the description snippets of the top 3 Google results for the question."""
    outputGenerator = search(input_question, num_results=3, advanced=True)
    outputs = []
    for result in outputGenerator:
        outputs.append(result.description)
    return outputs
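# With advanced=True, the googlesearch-python package yields result objects exposing
# .title, .url and .description, so the list returned above holds three short text
# snippets, roughly:
#   google_search_results("Who is the current president of the United States?")
#   -> ["<snippet 1>", "<snippet 2>", "<snippet 3>"]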
# Adding RAG: keep the original question and append the retrieved snippets as context
def RAG_enrichment(input_question: str):
    enrichment = google_search_results(input_question)
    new_output = input_question + "\n\n Use the following information to help you respond: \n\n"
    for info in enrichment:
        new_output = new_output + info + "\n\n"
    return new_output
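# The enriched prompt therefore has the shape:
#   <original question>
#
#    Use the following information to help you respond:
#
#   <snippet 1>
#
#   <snippet 2>
#
#   <snippet 3>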
@spaces.GPU
def llama_QA(input_question, pipe):
    """
    Ask Llama a question and return its answer.
    inputs:
        - input_question [str]: question for Llama to answer
        - pipe: the text-generation pipeline to run the question through
    outputs:
        - response [str]: Llama's response
    """
    messages = [
        {"role": "system", "content": "You are a helpful chatbot assistant. Answer all questions in the language they are asked in. If you do not have real-time information, just provide the information you do have to answer the question."},
        {"role": "user", "content": input_question},
    ]
    outputs = pipe(
        messages,
        max_new_tokens=512,
    )
    response = outputs[0]["generated_text"][-1]["content"]
    return response
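# Quick sanity check outside Gradio (a sketch, assuming the pipeline above loaded):
#   print(llama_QA("Who is the current president of the United States?", llama3_pipe))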
@spaces.GPU
def gradio_func(input_question):
    """
    Wrapper function for Gradio that runs both sides of the app from a single call:
    the plain question on the left-hand side and the RAG-enriched question on the right.
    """
    input_1 = input_question
    input_2 = RAG_enrichment(input_question)
    output1 = llama_QA(input_question, llama3_pipe)
    output2 = llama_QA(input_2, llama3_pipe)
    return input_1, input_2, output1, output2
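# The four return values map one-to-one onto the four read-only textboxes wired up
# in create_interface(): plain prompt, enriched prompt, plain answer, RAG answer.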
# Create the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            question_input = gr.Textbox(label="Enter your question", interactive=True, value="""Who is the current president of the United States?""")
        with gr.Row():
            submit_btn = gr.Button("Ask")
        with gr.Row():
            input1 = gr.Textbox(label="Prompt without RAG", interactive=False)
            input2 = gr.Textbox(label="Prompt with RAG enrichment", interactive=False)
        with gr.Row():
            output1 = gr.Textbox(label="Llama 3.2 answer without RAG", interactive=False)
            output2 = gr.Textbox(label="Llama 3.2 answer with RAG", interactive=False)
        submit_btn.click(
            fn=gradio_func,
            inputs=[question_input],
            outputs=[
                input1,
                input2,
                output1,
                output2,
            ],
        )
    return demo
# Launch the app
demo = create_interface()
demo.launch()