odia-qa-generator-2 / answer_generation.py
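"""FastAPI service that answers Odia questions with step-by-step reasoning.

The service builds a prompt from PROMPTS["odia_reasoning_generation_prompt"],
sends it to a Google Gemini model via google-generativeai, and validates the
JSON reply against LLMResponseModel before returning it from /generate.

Environment variables: GOOGLE_API_KEY (required), LLM_MODEL,
ANSWER_SERVICE_HOST, ANSWER_SERVICE_PORT.
"""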
import os
import json
from dotenv import load_dotenv
from pydantic import BaseModel
import google.generativeai as genai
from fastapi import FastAPI, HTTPException
import uvicorn
from prompts import PROMPTS
from llm_pipeline import example_odia_answer_json, example_odia_question_json
# Setup
load_dotenv()
# Check for required environment variables
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
raise ValueError("GOOGLE_API_KEY not found in environment variables")
genai.configure(api_key=google_api_key)
model = genai.GenerativeModel(os.getenv("LLM_MODEL", "gemini-pro"))
LANGUAGE = "Odia"
# Models
class QuestionRequest(BaseModel):
question: str
class LLMResponseModel(BaseModel):
question_content: str
answer_language: str = LANGUAGE
reasoning_content: str
answer_content: str
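# Illustrative shape of the JSON the model is asked to return
# (field names mirror LLMResponseModel):
# {
#     "question_content": "<the user's Odia question>",
#     "answer_language": "Odia",
#     "reasoning_content": "<step-by-step reasoning>",
#     "answer_content": "<final answer in Odia>"
# }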
def create_prompt(user_odia_question: str) -> str:
SIMPLE_PROMPT = PROMPTS["odia_reasoning_generation_prompt"]
prompt = SIMPLE_PROMPT.format(
user_odia_question=user_odia_question,
example_odia_question_json=example_odia_question_json,
example_answer_json=example_odia_answer_json
)
return prompt
# Functions
def chat_with_model(prompt: str) -> str:
try:
response = model.generate_content(prompt)
return response.text if response.text else "Error: Empty response"
except Exception as e:
return f"Error: {str(e)}"
def clean_json_text(text: str) -> str:
if text.startswith("Error:"):
return text
# Remove markdown code blocks
text = text.strip()
if text.startswith("```"):
lines = text.split('\n')
if len(lines) > 2:
text = '\n'.join(lines[1:-1])
else:
text = text.strip("`").replace("json", "", 1).strip()
# Extract JSON content
first = text.find("{")
last = text.rfind("}")
if first != -1 and last != -1:
return text[first:last+1]
return text
def validate_output(raw_output: str, original_question: str):
cleaned = clean_json_text(raw_output)
if cleaned.startswith("Error:"):
return {
"question_content": original_question,
"answer_language": LANGUAGE,
"reasoning_content": f"Error occurred: {cleaned}",
"answer_content": "Unable to generate answer due to error",
"error": cleaned
}
try:
# Try to parse and validate JSON
parsed_data = json.loads(cleaned)
validated = LLMResponseModel(**parsed_data)
return validated.model_dump()
except json.JSONDecodeError as je:
return {
"question_content": original_question,
"answer_language": LANGUAGE,
"reasoning_content": f"JSON parsing failed: {str(je)}",
"answer_content": "Unable to parse model response",
"error": f"JSON Error: {str(je)}"
}
except Exception as e:
return {
"question_content": original_question,
"answer_language": LANGUAGE,
"reasoning_content": f"Validation failed: {str(e)}",
"answer_content": "Unable to validate model response",
"error": f"Validation Error: {str(e)}"
}
def run_pipeline(question: str):
try:
        # Build the prompt and send it to the model
        prompt = create_prompt(user_odia_question=question)
raw_output = chat_with_model(prompt)
return validate_output(raw_output, question)
except Exception as e:
return {
"question_content": question,
"answer_language": LANGUAGE,
"reasoning_content": f"Pipeline error: {str(e)}",
"answer_content": "Unable to process question",
"error": f"Pipeline Error: {str(e)}"
}
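# Illustrative standalone use of the pipeline (question text is a placeholder):
#   result = run_pipeline("<Odia question>")
#   result["answer_content"]  # Odia answer text, or a fallback error message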
# API
app = FastAPI(title="Odia Question Answering API", version="0.1.0")
@app.get("/")
async def root():
return {"message": "Odia Question Answering API is running", "status": "healthy"}
@app.get("/health")
async def health_check():
try:
# Test model connectivity
        model.generate_content("Test")
return {
"status": "healthy",
"model": os.getenv("LLM_MODEL", "gemini-pro"),
"api_configured": bool(google_api_key)
}
except Exception as e:
return {
"status": "unhealthy",
"error": str(e),
"api_configured": bool(google_api_key)
}
@app.post("/generate")
async def generate_answer(request: QuestionRequest):
try:
if not request.question.strip():
raise HTTPException(status_code=400, detail="Question cannot be empty")
result = run_pipeline(request.question.strip())
# Check for critical errors that should return 500
if "error" in result and any(err_type in result["error"] for err_type in ["Error: ", "Pipeline Error:"]):
raise HTTPException(status_code=500, detail=f"Processing failed: {result['error']}")
return {"success": True, "data": result}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
if __name__ == "__main__":
print("Starting Odia Question Answering API...")
print(f"Google API Key configured: {'Yes' if google_api_key else 'No'}")
print(f"Model: {os.getenv('LLM_MODEL', 'gemini-pro')}")
host = os.getenv("ANSWER_SERVICE_HOST", "0.0.0.0")
port = int(os.getenv("ANSWER_SERVICE_PORT", "9000"))
    uvicorn.run(app, host=host, port=port)