policy123 committed (verified)
Commit e8f876f · Parent(s): 96970be

Update app.py

Files changed (1):
  1. app.py +53 -46
app.py CHANGED
@@ -7,39 +7,35 @@ from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from transformers import pipeline
 import torch
-import uvicorn # Add this import
+import uvicorn
 
 # 2. Initialize the FastAPI application
 app = FastAPI()
 
 # 3. Add CORS middleware
-# This allows our frontend (running on a different domain) to communicate with this backend.
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"], # Allows all origins
+    allow_origins=["*"],
     allow_credentials=True,
-    allow_methods=["*"], # Allows all methods
-    allow_headers=["*"], # Allows all headers
+    allow_methods=["*"],
+    allow_headers=["*"],
 )
 
 # 4. Load the Language Model
-# We use a small, efficient model that runs well on free CPU hardware.
-# The pipeline will be created only once when the application starts.
+# **UPDATED:** Using a smaller, more efficient model to ensure it loads on free hardware.
 try:
-    # Using a distilled, smaller, but capable model from the community.
-    # It's specifically designed for summarization and Q&A tasks.
-    summarizer = pipeline(
-        "summarization",
-        model="Xenova/LaMini-Flan-T5-783M",
-        torch_dtype=torch.bfloat16 # Use a memory-efficient data type
+    generator = pipeline(
+        "text-generation",
+        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", # Switched to TinyLlama
+        torch_dtype=torch.bfloat16,
+        device_map="auto" # Automatically select device (CPU in this case)
     )
     print("Model loaded successfully.")
 except Exception as e:
     print(f"Error loading model: {e}")
-    summarizer = None
+    generator = None
 
 # 5. Define the data model for the incoming request
-# This ensures the data we receive is in the correct format.
 class QueryRequest(BaseModel):
     document_text: str
     query_text: str
@@ -47,55 +43,66 @@ class QueryRequest(BaseModel):
 # 6. Define the API endpoint
 @app.post("/analyze")
 async def analyze_document(request: QueryRequest):
-    """
-    This endpoint receives a document and a query, constructs a prompt,
-    and uses the LLM to generate a structured JSON response.
-    """
-    if summarizer is None:
-        return {"error": "Model is not available."}
+    if generator is None:
+        return {"error": "Model is not available. It may have failed to load due to resource constraints."}
 
-    # This is the same prompt engineering we did before.
-    # We are asking the model to perform a specific task and return a JSON.
-    prompt = f"""
-    **CONTEXT:**
-    You are an expert AI assistant for a claims processing department. Your task is to analyze an insurance policy document and a user's query to make a decision. Based ONLY on the information in the Policy Document, determine if the request should be approved or rejected. Provide your final answer in a strict JSON format. The JSON object must contain three keys: "decision" (string, "Approved" or "Rejected"), "amount" (number, 0 if not applicable), and "justification" (string, explaining your reasoning and citing the policy). Do not use any information outside of the provided Policy Document.
+    # **UPDATED:** Using a chat-based prompt format suitable for TinyLlama.
+    # This structure helps the model understand its role and the task better.
+    messages = [
+        {
+            "role": "system",
+            "content": """You are an expert AI assistant for a claims processing department. Your task is to analyze an insurance policy document and a user's query to make a decision. Based ONLY on the information in the Policy Document, determine if the request should be approved or rejected. Provide your final answer in a strict JSON format. The JSON object must contain three keys: "decision" (string, "Approved" or "Rejected"), "amount" (number, 0 if not applicable), and "justification" (string, explaining your reasoning and citing the policy). Do not use any information outside of the provided Policy Document."""
+        },
+        {
+            "role": "user",
+            "content": f"""
+            **Policy Document (Source of Truth):**
+            ---
+            {request.document_text}
+            ---
 
-    **Policy Document (Source of Truth):**
-    ---
-    {request.document_text}
-    ---
+            **User Query:**
+            ---
+            {request.query_text}
+            ---
 
-    **User Query:**
-    ---
-    {request.query_text}
-    ---
+            **JSON Response:**
+            """
+        }
+    ]
 
-    **JSON Response:**
-    """
+    # The prompt template for the model
+    prompt = generator.tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
 
     try:
         # Generate the response from the LLM
-        # We set max_length to get a reasonably sized response.
-        result = summarizer(prompt, max_length=512, clean_up_tokenization_spaces=True)
-        generated_text = result[0]['summary_text']
+        outputs = generator(
+            prompt,
+            max_new_tokens=256, # Max tokens for the generated response
+            do_sample=True,
+            temperature=0.7,
+            top_k=50,
+            top_p=0.95
+        )
+        generated_text = outputs[0]["generated_text"]
 
-        # The model might not return perfect JSON, so we clean it up.
-        # Find the start and end of the JSON object.
+        # The model's output will include our prompt, so we find the JSON part.
        json_start = generated_text.find('{')
        json_end = generated_text.rfind('}') + 1
 
-        if json_start != -1 and json_end != 0:
+        if json_start != -1 and json_end > json_start:
             cleaned_json = generated_text[json_start:json_end]
-            # The backend should return the JSON string directly, not a Python dict
-            # The frontend will parse it.
             return cleaned_json
         else:
-            # If no JSON is found, return the raw text with an error flag.
             return {"error": "Failed to generate valid JSON.", "raw_output": generated_text}
 
     except Exception as e:
         print(f"Error during analysis: {e}")
-        return {"error": f"An error occurred: {str(e)}"}
+        return {"error": f"An error occurred during analysis: {str(e)}"}
 
 # A simple root endpoint to confirm the server is running.
 @app.get("/")
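
Note: the prompt string that apply_chat_template builds for TinyLlama can be inspected offline with just the tokenizer. This is a standalone sketch, not part of the commit; it assumes the transformers library can fetch the TinyLlama tokenizer locally.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
messages = [
    {"role": "system", "content": "You are an expert AI assistant for a claims processing department."},
    {"role": "user", "content": "Policy document and query would go here."},
]

# tokenize=False returns the fully formatted prompt string instead of token IDs;
# add_generation_prompt=True appends the assistant-turn marker so the model
# continues with its answer rather than another user turn.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)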
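
Note: the commit keeps the import uvicorn line, but the server entry point sits outside the changed lines. A typical pattern for a FastAPI app is sketched below; the host and port are assumptions (7860 is the port Hugging Face Spaces conventionally expects), not something this diff shows.

# Assumed entry point (not shown in this diff): run the API when app.py is executed directly.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)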
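
Note: a minimal client-side call to the /analyze endpoint might look like the sketch below. The base URL and the use of the requests library are assumptions (they are not part of this commit); because the endpoint returns the extracted JSON object as a string, the client parses the body a second time with json.loads.

import json
import requests  # assumption: the client uses the requests library

API_URL = "http://localhost:8000"  # hypothetical URL; replace with the deployed Space's API URL

payload = {
    "document_text": "…full policy text…",
    "query_text": "…user's claim question…",
}

resp = requests.post(f"{API_URL}/analyze", json=payload, timeout=120)
resp.raise_for_status()
body = resp.json()

if isinstance(body, dict) and "error" in body:
    print("Backend reported an error:", body)
else:
    # /analyze returns the decision JSON as a string, so parse it again.
    decision = json.loads(body)
    print(decision["decision"], decision["amount"], decision["justification"])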