# app.py
# FINAL CPU VERSION using a quantized model for maximum reliability on free hardware.
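#
# Assumed dependencies for this Space (not pinned in this file; a requirements.txt
# would typically list them): gradio, torch, transformers, optimum, auto-gptq.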
# 1. Import necessary libraries
import gradio as gr
# Note: AutoModelForCausalLM is imported from the main transformers library
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import json

# 2. Load the Quantized Language Model
# This model is optimized to use less memory, making it stable on free CPUs.
try:
    model_name_or_path = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ"

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    # Load the quantized model using the standard transformers class.
    # The installed 'optimum' and 'auto-gptq' libraries will handle the GPTQ format automatically.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        use_safetensors=True,
        trust_remote_code=False,
        device_map="auto"  # Will automatically use CPU
    )

    # Create the text generation pipeline
    generator = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer
    )

    print("Quantized model loaded successfully on CPU.")
    MODEL_LOADED = True
except Exception as e:
    print(f"Error loading quantized model: {e}")
    generator = None
    MODEL_LOADED = False

# 3. Define the core analysis function
def analyze_document(document_text, query_text):
    """
    Analyzes the document based on the query using the loaded LLM.
    """
    if not MODEL_LOADED or generator is None:
        return {"error": "Model is not available. Please check the Space logs for errors."}

    # The chat-based prompt format for TinyLlama
    messages = [
        {
            "role": "system",
            "content": """You are an expert AI assistant for a claims processing department. Your task is to analyze an insurance policy document and a user's query to make a decision. Based ONLY on the information in the Policy Document, determine if the request should be approved or rejected. Provide your final answer in a strict JSON format. The JSON object must contain three keys: "decision" (string, "Approved" or "Rejected"), "amount" (number, 0 if not applicable), and "justification" (string, explaining your reasoning and citing the policy). Do not use any information outside of the provided Policy Document."""
        },
        {
            "role": "user",
            "content": f"""
**Policy Document (Source of Truth):**
---
{document_text}
---
**User Query:**
---
{query_text}
---
**JSON Response:**
"""
        }
    ]

    prompt = generator.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    try:
        # Generate the response from the LLM.
        # return_full_text=False makes the pipeline return only the newly generated
        # text, so the JSON extraction below cannot pick up braces from the prompt.
        outputs = generator(
            prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            return_full_text=False
        )
        generated_text = outputs[0]["generated_text"]

        # Extract the JSON part from the model's output
        json_start = generated_text.find('{')
        json_end = generated_text.rfind('}') + 1
        if json_start != -1 and json_end > json_start:
            cleaned_json_str = generated_text[json_start:json_end]
            return json.loads(cleaned_json_str)
        else:
            return {"error": "Failed to generate valid JSON.", "raw_output": generated_text}
    except Exception as e:
        print(f"Error during analysis: {e}")
        return {"error": f"An error occurred during analysis: {str(e)}"}

# 4. Create and launch the Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Policy Analysis API (CPU Version)")
    gr.Markdown("This Gradio app serves the backend for the RAG policy analysis system, optimized for CPU.")

    with gr.Row():
        doc_input = gr.Textbox(lines=5, label="Document Text", placeholder="Paste the document text here...")
        query_input = gr.Textbox(label="Query Text", placeholder="Enter your query here...")

    output_json = gr.JSON(label="Analysis Result")
    analyze_btn = gr.Button("Analyze")

    analyze_btn.click(
        fn=analyze_document,
        inputs=[doc_input, query_input],
        outputs=output_json,
        api_name="analyze"
    )

demo.launch()
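
# A minimal client-side sketch for calling the "analyze" endpoint once this Space is
# running (assumes the `gradio_client` package is installed; "user/policy-analysis-api"
# is a hypothetical Space name):
#
#   from gradio_client import Client
#
#   client = Client("user/policy-analysis-api")
#   result = client.predict(
#       "Full policy text ...",          # document_text
#       "Is knee surgery covered?",      # query_text
#       api_name="/analyze",
#   )
#   print(result)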