Spaces:
Running
Running
| import logging | |
| import sqlite3 | |
| import requests | |
| import streamlit as st | |
| from langgraph.errors import GraphRecursionError | |
| from langchain_groq import ChatGroq | |
| from agent import SQLAgentRAG | |
| from tools import retriever | |
| from constant import GROQ_API_KEY, CONFIG | |
# Set up root logging at INFO level and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def download_sqlite_db(url, local_filename='travel.sqlite', timeout=30):
    """
    Download the file at `url` and store it as `local_filename`.

    The body is streamed in 8 KiB chunks into a temporary "<local_filename>.part"
    file, which is moved over `local_filename` only after the whole download
    succeeded. A failed or interrupted download therefore never leaves a
    truncated file at the destination.

    Parameters:
    - url (str): Address of the file to download.
    - local_filename (str): Destination path of the downloaded file.
    - timeout (float | tuple | None): Passed to `requests.get` as its `timeout`
      argument, so a stalled server cannot block the caller forever.

    Returns:
    - str: `local_filename`, unchanged.

    Raises:
    - requests.HTTPError: If the server answers with an error status.
    - requests.RequestException: For connection problems and timeouts.
    - OSError: If the file cannot be written or moved into place.
    """
    import os  # function-scope import; the file's import block is left untouched

    partial_path = f"{local_filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as out_file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip empty chunks
                        out_file.write(chunk)
        # Replace the destination in one step once the download is complete
        os.replace(partial_path, local_filename)
    except Exception:
        # Do not leave a half-written temporary file behind
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    return local_filename
# Chat model (served through Groq) used by the agent
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="openai/gpt-oss-20b",
    temperature=0.1,
    verbose=True,
)

# SQL agent built from the chat model and the retriever imported from `tools`
agent = SQLAgentRAG(tools=retriever, llm=llm)
def query_rag_agent(query: str):
    """
    Handle a query through the RAG Agent, producing an SQL response if applicable.

    Parameters:
    - query (str): The input query to process.

    Returns:
    - Tuple: `(response, sql_query)`.
      `response` is the content of the last message in the graph output.
      `sql_query` is the last entry of the output's "sql_query" value, or the
      placeholder "No SQL query generated" when that value is missing or empty.
      When the graph hits its recursion limit, a fixed error text and an empty
      string are returned instead.

    Notes:
    - GraphRecursionError raised by the graph is caught and logged here; it is
      not propagated to the caller.
    """
    try:
        # Only the graph invocation is expected to raise GraphRecursionError
        output = agent.graph.invoke({"messages": query}, CONFIG)
    except GraphRecursionError:
        logger.error("Graph recursion limit reached; query processing failed.")
        return "Graph recursion limit reached. No SQL result generated.", ""

    response = output["messages"][-1].content
    # `or` also covers a present-but-empty "sql_query", which would otherwise
    # raise IndexError on the [-1] lookup
    sql_candidates = output.get("sql_query") or ["No SQL query generated"]
    sql_query = sql_candidates[-1]
    logger.info("Query processed successfully: %s", query)
    return response, sql_query
def main():
    """
    Render the Streamlit page.

    Shows the project description and example questions in the sidebar,
    replays the chat history kept in `st.session_state.messages`, and, when the
    user submits a prompt, sends it to `query_rag_agent` and displays and
    stores the answer together with its SQL query.
    """

    def show_sql(sql_text):
        # Expanded code panel holding the SQL statement
        with st.expander("SQL Query", expanded=True):
            st.code(sql_text)

    # Sidebar: project description and sample questions
    with st.sidebar:
        st.header("About Project")
        st.markdown(
            """
            RAG (Retrieval-Augmented Generation) Agent SQL is an approach that combines retrieval techniques with text generation to create more relevant and contextualised answers from data,
            particularly in SQL databases. RAG-Agent SQL uses two main components:
            - Retrieval: Retrieving relevant information from the database based on a given question or input.
            - Augmented Generation: Using natural language models (e.g., LLMs such as GPT or LLaMA) to generate more detailed answers, using information from the retrieval results.
            to see the architecture can be seen here [Github](https://github.com/fahmiaziz98/sql_agent/tree/main/002sql-agent-ra)
            """
        )
        st.header("Example Question")
        st.markdown(
            """
            - How many different aircraft models are there? And what are the models?
            - What is the aircraft model with the longest range?
            - Which airports are located in the city of Basel?
            - Can you please provide information on what I asked before?
            - What are the fare conditions available on Boeing 777-300?
            - What is the total amount of bookings made in April 2024?
            - What is the scheduled arrival time of flight number QR0051?
            - Which car rental services are available in Basel?
            - Which seat was assigned to the boarding pass with ticket number 0060005435212351?
            - Which trip recommendations are related to history in Basel?
            - How many tickets were sold for Business class on flight 30625?
            - Which hotels are located in Zurich?
            """
        )

    # Page title
    st.title("RAG SQL-Agent")

    # The chat history lives in the session state; create it on first run
    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    # Replay every stored message in its own chat bubble
    for entry in st.session_state.messages:
        with st.chat_message(entry.get("role", "assistant")):
            if "output" in entry:
                st.markdown(entry["output"])
            if entry.get("sql_query"):
                show_sql(entry["sql_query"])

    # Nothing more to do until the user submits a prompt
    prompt = st.chat_input("What do you want to know?")
    if not prompt:
        return

    # Echo and store the user's prompt
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "output": prompt})

    # Ask the RAG agent while a spinner is shown
    with st.spinner("Searching for an answer..."):
        output_text, sql_query = query_rag_agent(prompt)

    # Show the assistant's answer, then the SQL query if there is one
    with st.chat_message("assistant"):
        st.markdown(output_text)
    if sql_query:
        show_sql(sql_query)

    # Store the assistant's answer in the history
    assistant_entry = {
        "role": "assistant",
        "output": output_text,
        "sql_query": sql_query,
    }
    st.session_state.messages.append(assistant_entry)
if __name__ == "__main__":
    # Fetch the travel database, then start the UI.
    # NOTE(review): Streamlit re-executes the script on each interaction, so
    # this download presumably repeats on every rerun — confirm this is intended.
    URL = "https://storage.googleapis.com/benchmarks-artifacts/travel-db/travel2.sqlite"
    download_sqlite_db(url=URL)
    main()