|
|
import json
import os
from typing import List

import google.generativeai as genai
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ValidationError

from prompts import PROMPTS
|
|
|
|
|
|
|
|
# Load variables from a local .env file (if one exists) into the environment
# before any configuration is read.
load_dotenv()

# Read the API key once. The model is configured from this value, and /health
# reports bool(google_api_key), so both must come from the same read.
# NOTE(review): a missing key is not rejected here; it only shows up as
# api_configured=False on /health.
google_api_key = os.getenv("GOOGLE_API_KEY")

genai.configure(api_key=google_api_key)

# Model name is configurable via LLM_MODEL and defaults to "gemini-pro".
model = genai.GenerativeModel(os.getenv("LLM_MODEL", "gemini-pro"))
|
|
|
|
|
|
|
|
class TopicRequest(BaseModel):
    """JSON request body accepted by ``POST /generate-questions``."""

    # Subject to generate questions about. The endpoint rejects blank values.
    topic: str
    # How many questions to ask for. The endpoint enforces the 1-50 range.
    num_questions: int = 10
|
|
|
|
|
class GeneratedQuestionModel(BaseModel):
    """Schema that the model's JSON reply is validated against."""

    # Language the questions are written in (the prompt requests "Odia").
    question_language: str
    # The generated questions, one string per question.
    question_list: List[str]
|
|
|
|
|
|
|
|
def chat_with_model(prompt: str) -> str:
    """Send ``prompt`` to the configured model and return its reply text.

    This function never raises. Any failure, including an empty reply, is
    returned as a string that starts with ``"Error:"``. Downstream parsing
    checks for that prefix.
    """
    try:
        reply = model.generate_content(prompt)
        if not reply.text:
            return "Error: Empty response"
        return reply.text
    except Exception as exc:
        return f"Error: {exc}"
|
|
|
|
|
def clean_json_text(text: str) -> str:
    """Extract the JSON object from a raw model reply.

    Surrounding whitespace and Markdown code fences (e.g. ```json ... ```) are
    removed first. The function then returns the span from the first "{" to the
    last "}". If no such span exists, the stripped text is returned as-is.
    Strings starting with "Error:" (the failure marker produced by
    chat_with_model) pass through unchanged.

    Args:
        text: Raw text returned by the model.

    Returns:
        The best-effort JSON object substring.
    """
    if text.startswith("Error:"):
        return text

    # Models often put whitespace or newlines around a fenced block, which
    # would otherwise hide the leading fence.
    text = text.strip()

    if text.startswith("```"):
        lines = text.split("\n")
        if len(lines) > 2:
            # Drop the opening fence line (e.g. "```json"). Drop the last line
            # only if it really is a closing fence, so a trailing "}```" line
            # keeps its closing brace.
            body = lines[1:]
            if body[-1].strip().startswith("```"):
                body = body[:-1]
            text = "\n".join(body)
        else:
            # The fence and the content share one or two lines. Remove the
            # backticks and an optional leading "json" language tag. The tag
            # is removed only as a prefix, never from inside the payload.
            text = text.strip("`")
            if text[:4].lower() == "json":
                text = text[4:]
            text = text.strip()

    first, last = text.find("{"), text.rfind("}")
    # The closing brace must come after the opening one. Otherwise there is
    # no object span, and slicing would produce an empty string.
    if first != -1 and last > first:
        return text[first:last + 1]
    return text
|
|
|
|
|
def validate_answer(raw_output: str):
    """Parse the model's raw reply into a plain dict that matches GeneratedQuestionModel.

    Args:
        raw_output: Text returned by chat_with_model. This may be an
            "Error: ..." failure string.

    Returns:
        dict: On success, ``{"question_language": ..., "question_list": [...]}``.
        On failure, the same keys with ``"Odia"`` and an empty list, plus an
        ``"error"`` entry. That entry holds either the upstream "Error: ..."
        message or "Invalid JSON" when the text cannot be parsed or validated.
    """
    cleaned = clean_json_text(raw_output)

    if cleaned.startswith("Error:"):
        return {"error": cleaned, "question_language": "Odia", "question_list": []}

    try:
        return GeneratedQuestionModel.model_validate_json(cleaned).model_dump()
    except ValidationError:
        pass

    # Fallback: parse with the stdlib and validate in Python mode. Only the
    # errors this path can raise are caught. JSONDecodeError and
    # ValidationError are ValueError subclasses, and TypeError occurs when the
    # JSON is not an object. The previous bare except also swallowed
    # KeyboardInterrupt and SystemExit.
    try:
        return GeneratedQuestionModel(**json.loads(cleaned)).model_dump()
    except (ValueError, TypeError):
        return {"error": "Invalid JSON", "question_language": "Odia", "question_list": []}
|
|
|
|
|
def final_pipeline(user_input: str, num_questions: int = 10):
    """Generate ``num_questions`` Odia questions about ``user_input``.

    The function builds the "questions_only" prompt, queries the model, and
    returns the dict produced by validate_answer. On failure that dict carries
    an "error" key.
    """
    template = PROMPTS["questions_only"]
    prompt = template.format(
        language="Odia",
        topic=user_input,
        num_questions=num_questions,
    )
    raw_output = chat_with_model(prompt)
    return validate_answer(raw_output)
|
|
|
|
|
|
|
|
# ASGI application object. The route handlers below register on it through
# the @app.get / @app.post decorators.
app = FastAPI()
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report whether the configured model can be reached.

    The check sends a tiny "Test" prompt to the model. ``generate_content`` is
    synchronous, so the call runs in a worker thread. Called directly inside
    this coroutine, it would stall the event loop, and every other request,
    while it waits on the network.

    Returns:
        dict: ``status`` is "healthy" (plus the model name) when the probe
        succeeds, or "unhealthy" (plus the exception text) when it raises.
        ``api_configured`` reports whether GOOGLE_API_KEY was set to a
        non-empty value at import time.
    """
    # NOTE(review): both outcomes are returned with HTTP 200. If a load
    # balancer probes this endpoint, consider returning 503 when unhealthy.
    try:
        # The reply content is irrelevant. Only the absence of an exception
        # matters.
        await run_in_threadpool(model.generate_content, "Test")
    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "api_configured": bool(google_api_key),
        }
    return {
        "status": "healthy",
        "model": os.getenv("LLM_MODEL", "gemini-pro"),
        "api_configured": bool(google_api_key),
    }
|
|
@app.get("/")
async def root():
    """Return a static banner confirming the service is up.

    This endpoint does not contact the model.
    """
    payload = {
        "message": "Odia Question Generating API is running",
        "status": "healthy",
    }
    return payload
|
|
|
|
|
|
|
|
@app.post("/generate-questions")
async def generate_questions(request: TopicRequest):
    """Generate Odia questions for ``request.topic``.

    Returns:
        dict: ``{"success": True, "data": {...}}``, where ``data`` holds
        ``question_language`` and ``question_list``.

    Raises:
        HTTPException(400): the topic is blank, or num_questions is outside 1-50.
        HTTPException(500): the model call failed, or its reply could not be
            parsed into the expected schema.
    """
    topic = request.topic.strip()
    if not topic:
        raise HTTPException(status_code=400, detail="Topic cannot be empty")

    if not 1 <= request.num_questions <= 50:
        raise HTTPException(status_code=400, detail="Questions must be between 1-50")

    # final_pipeline makes a blocking network call to the model. Run it in a
    # worker thread so the event loop keeps serving other requests.
    result = await run_in_threadpool(final_pipeline, topic, request.num_questions)

    # Any "error" key means there are no usable questions. The previous check
    # matched only "Error:" messages, so an "Invalid JSON" result was returned
    # as success=True with an empty question list.
    if "error" in result:
        raise HTTPException(status_code=500, detail=result["error"])

    return {"success": True, "data": result}
|
|
|
|
|
if __name__ == "__main__":
    # The bind address and port can be overridden from the environment. The
    # defaults are 0.0.0.0:8000. The previous literal `host=0.0.0.0` was a
    # SyntaxError, and the values read from the environment were ignored.
    host = os.getenv("QUESTION_SERVICE_HOST", "0.0.0.0")
    port = int(os.getenv("QUESTION_SERVICE_PORT", "8000"))
    uvicorn.run(app, host=host, port=port)