# MedicalRAG-Demo: src/streamlit_app.py
import os
import streamlit as st
from dotenv import load_dotenv
from pinecone import Pinecone as PineconeClient
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain import hub
from langgraph.graph import START, END, StateGraph
from typing_extensions import List, TypedDict
from langchain.chat_models import init_chat_model
load_dotenv()
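# load_dotenv() reads a local .env during development; in a hosted deployment
# (e.g. a Hugging Face Space) the same values are usually supplied as secrets.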
os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")
os.environ["LANGSMITH_API_KEY"] = os.getenv("LANGSMITH_API_KEY")
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
os.environ["PINECONE_API_KEY"] = os.getenv("PINECONE_API_KEY")
os.environ["LANGSMITH_PROJECT"] = os.getenv("LANGSMITH_PROJECT")
os.environ["PINECONE_INDEX_NAME"] = os.getenv("PINECONE_INDEX_NAME")
os.environ["LANGSMITH_TRACING_V2"] = "true"
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
index_name = os.environ.get("PINECONE_INDEX_NAME")
pc = PineconeClient(api_key=pinecone_api_key)
index = pc.Index(index_name)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vector_store = PineconeVectorStore(index=index, embedding=embeddings)
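# Assumes the Pinecone index already exists and was populated with vectors from
# the same embedding model (text-embedding-3-small, 1536 dimensions); this app
# does no index creation or document ingestion.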
prompt = hub.pull("rlm/rag-prompt")
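# The hub prompt is filled with "question" and "context" keys in generate() below.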
llm = init_chat_model("gpt-4o-mini", model_provider="openai")
class State(TypedDict):
    is_medical: bool
    query: str
    context: List[Document]
    response: str
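# State flow: is_medical_query sets `is_medical` (and a refusal `response` for
# off-topic queries), retrieve fills `context`, generate writes the final `response`.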
def is_medical_query(state: State):
    """
    Ask the LLM whether the query is medical. Returns a partial state update:
    {"is_medical": True} for medical queries, otherwise {"is_medical": False}
    plus a refusal message as the response.
    """
    classifier_prompt = f"""Determine if the following query is related to medicine:
    Query: "{state['query']}"
    Answer only with 'True' or 'False'"""
    answer = llm.invoke(classifier_prompt).content.strip().upper()
    if answer == "TRUE":
        return {"is_medical": True}
    # Anything other than an explicit "TRUE" is treated as non-medical, so the
    # graph always receives a complete state update.
    # Message: "The topic you asked about is not considered medical. Please ask another question."
    return {"is_medical": False, "response": "Sorduğunuz konu medikal kapsamda değerlendirilmiyor. Lütfen başka bir soru yazınız."}
def retrieve(state: State):
    # Pull the top 7 most similar document chunks from the Pinecone index.
    retrieved_docs = vector_store.similarity_search(state["query"], k=7)
    return {"context": retrieved_docs}
def generate(state: State):
    # Stuff the retrieved chunks into the hub RAG prompt and answer with the LLM.
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])
    messages = prompt.invoke({
        "question": state["query"],
        "context": docs_content,
    })
    response = llm.invoke(messages)
    return {"response": response.content}
# Build the workflow graph: classify, then retrieve and generate (or stop early)
workflow = StateGraph(State)
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)
workflow.add_node("is_medical_query", is_medical_query)
workflow.add_edge(START, "is_medical_query")
workflow.add_conditional_edges(
    "is_medical_query",
    lambda state: "retrieve" if state["is_medical"] else END,
    {"retrieve": "retrieve", END: END},
)
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
app = workflow.compile()
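# Example of exercising the compiled graph directly (hypothetical query):
#   result = app.invoke({"query": "What are the common symptoms of anemia?"})
#   print(result["response"])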
# Streamlit app
st.set_page_config(page_title="MedicalRAG")
st.header("🩺 MedicalRAG (Demo)")
# UI strings are Turkish: "Soru" = "Question", "Sor" = "Ask",
# "Cevap oluşturuluyor..." = "Generating the answer...", "Cevap" = "Answer".
user_query = st.text_input("Soru", key="user_query")
submit = st.button("🤔 Sor")
if submit and user_query:
    with st.spinner("Cevap oluşturuluyor...", show_time=True):
        response = app.invoke({"query": user_query})
    st.subheader("Cevap:")
    st.write(response["response"])