CatoG committed (verified)
Commit 3462e1b · Parent(s): caaf0e8

Update app.py

Files changed (1): app.py (+24 −32)
app.py CHANGED
@@ -1,5 +1,5 @@
 import os
-import requests
+from huggingface_hub import InferenceClient
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
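(Note: dropping requests in favor of huggingface_hub means the Space must declare that package as a dependency. A minimal sketch of the corresponding requirements.txt line; the version pin is an assumption, not taken from this repository, chosen because chat_completion with a typed .choices output needs a reasonably recent release:

# requirements.txt (hypothetical excerpt)
huggingface_hub>=0.23
)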
@@ -55,16 +55,12 @@ def get_huggingface_token():
 # ---------------------------
 def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
     """
-    Returns API URL, headers, and parameters for HuggingFace Inference API.
+    Returns InferenceClient for HuggingFace models.
     """
     token = get_huggingface_token()
-    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
-    headers = {
-        "Authorization": f"Bearer {token}",
-        "Content-Type": "application/json"
-    }
+    client = InferenceClient(token=token)
 
-    return api_url, headers, max_tokens, temperature
+    return client, model_id, max_tokens, temperature
 
 
 # ---------------------------
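(The client is constructed without a bound model, which is why get_llm() also returns model_id so callers can pass it on every request. For reference, a minimal sketch of the equivalent alternative that binds the model at construction time; the model id and token here are placeholders, not values from this repository:

from huggingface_hub import InferenceClient

# Hypothetical alternative: bind the model once so later calls can omit it.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta", token="hf_xxx")
)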
@@ -147,43 +143,39 @@ def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_m
 
     try:
         selected_model = model_choice or MODEL_OPTIONS[0]
-        api_url, headers, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
+        client, model_id, max_tok, temp = get_llm(selected_model, int(max_tokens), float(temperature))
         retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
 
         # Get relevant documents
         docs = retriever_obj.invoke(query)
         context = "\n\n".join(doc.page_content for doc in docs)
 
-        # Create prompt
-        prompt = f"""Answer the question based only on the following context:
+        # Create messages for chat completion
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant that answers questions based only on the provided context."
+            },
+            {
+                "role": "user",
+                "content": f"""Context:
 {context}
 
 Question: {query}
 
-Answer:"""
-
-        # Call HuggingFace Inference API directly
-        payload = {
-            "inputs": prompt,
-            "parameters": {
-                "max_new_tokens": max_tok,
-                "temperature": temp,
-                "return_full_text": False
+Please answer the question based only on the context provided above."""
             }
-        }
-
-        response = requests.post(api_url, headers=headers, json=payload)
-        response.raise_for_status()
-
-        result = response.json()
-
-        # Handle different response formats
-        if isinstance(result, list) and len(result) > 0:
-            return result[0].get("generated_text", str(result))
-        elif isinstance(result, dict):
-            return result.get("generated_text", str(result))
-        else:
-            return str(result)
+        ]
+
+        # Call chat completion API
+        response = client.chat_completion(
+            messages=messages,
+            model=model_id,
+            max_tokens=max_tok,
+            temperature=temp
+        )
+
+        return response.choices[0].message.content
     except Exception as e:
         import traceback
         error_details = traceback.format_exc()
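(For reference, a self-contained sketch of the new call pattern this commit adopts, with the API-level error handling that the broad except-branch would otherwise swallow. The token and model id are placeholders, and HfHubHTTPError is the exception huggingface_hub raises on HTTP failures:

from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError

client = InferenceClient(token="hf_xxx")  # placeholder token
messages = [
    {"role": "system", "content": "Answer only from the provided context."},
    {"role": "user", "content": "Context:\n...\n\nQuestion: ..."},
]
try:
    response = client.chat_completion(
        messages=messages,
        model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
        max_tokens=256,
        temperature=0.8,
    )
    # chat_completion returns a typed ChatCompletionOutput, so the
    # list-vs-dict branching needed for the raw REST response goes away.
    print(response.choices[0].message.content)
except HfHubHTTPError as err:
    print(f"Inference API error: {err}")
)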