Update app.py
app.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-import requests
+from huggingface_hub import InferenceClient
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
@@ -55,16 +55,12 @@ def get_huggingface_token():
 # ---------------------------
 def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
     """
-    Returns
+    Returns InferenceClient for HuggingFace models.
     """
     token = get_huggingface_token()
-
-    headers = {
-        "Authorization": f"Bearer {token}",
-        "Content-Type": "application/json"
-    }
+    client = InferenceClient(token=token)
 
-    return
+    return client, model_id, max_tokens, temperature
 
 
 # ---------------------------
@@ -147,43 +143,39 @@ def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_m
 
     try:
         selected_model = model_choice or MODEL_OPTIONS[0]
-
+        client, model_id, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
         retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
 
         # Get relevant documents
         docs = retriever_obj.invoke(query)
         context = "\n\n".join(doc.page_content for doc in docs)
 
-        # Create
-
+        # Create messages for chat completion
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant that answers questions based only on the provided context."
+            },
+            {
+                "role": "user",
+                "content": f"""Context:
 {context}
 
 Question: {query}
 
-
-
-        # Call HuggingFace Inference API directly
-        payload = {
-            "inputs": prompt,
-            "parameters": {
-                "max_new_tokens": max_tok,
-                "temperature": temp,
-                "return_full_text": False
+Please answer the question based only on the context provided above."""
             }
-
-
-        response = requests.post(api_url, headers=headers, json=payload)
-        response.raise_for_status()
+        ]
 
-
+        # Call chat completion API
+        response = client.chat_completion(
+            messages=messages,
+            model=model_id,
+            max_tokens=max_tok,
+            temperature=temp
+        )
 
-
-        if isinstance(result, list) and len(result) > 0:
-            return result[0].get("generated_text", str(result))
-        elif isinstance(result, dict):
-            return result.get("generated_text", str(result))
-        else:
-            return str(result)
+        return response.choices[0].message.content
     except Exception as e:
         import traceback
         error_details = traceback.format_exc()