gworley3 committed · verified
Commit a64e29c · 1 Parent(s): e4ccc65

Upload 2 files

Files changed (2):
  1. app.py +172 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,172 @@
+ """
+ Zen Oracle - Gradio App for HuggingFace Spaces
+
+ Provides both a web UI and a JSON API endpoint.
+ API (Gradio 4): POST /call/consult with {"data": ["question"]}
+ """
+
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
+ import os
+
+ # Interpretation prompt
+ INTERPRETATION_PROMPT = "You are an old zen master reading a koan. A student asked: \"{question}\" The master replied: \"{answer}\". Write a short capping verse for this koan, 4-8 lines, and nothing else."
+
+ # Global models (loaded once)
+ oracle_model = None
+ oracle_tokenizer = None
+ interpreter_model = None
+ interpreter_tokenizer = None
+
+
+ def load_models():
+     """Load models on startup."""
+     global oracle_model, oracle_tokenizer, interpreter_model, interpreter_tokenizer
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Using device: {device}")
+
+     # Load oracle model (fine-tuned Flan-T5-small)
+     print("Loading oracle model...")
+     checkpoint_path = os.environ.get("ORACLE_CHECKPOINT", "checkpoints/best_model.pt")
+
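+     # Assumed checkpoint layout, inferred from the keys read below:
+     #   {"model_state_dict": <state dict>,
+     #    "config": {"model": {"name": "<HF model id>"}}}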
+     if os.path.exists(checkpoint_path):
+         checkpoint = torch.load(checkpoint_path, map_location='cpu')
+         config = checkpoint.get('config', {})
+         model_name = config.get('model', {}).get('name', 'google/flan-t5-small')
+
+         oracle_tokenizer = AutoTokenizer.from_pretrained(model_name)
+         oracle_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+         oracle_model.load_state_dict(checkpoint['model_state_dict'])
+     else:
+         # Fallback to base model if no checkpoint
+         print("No checkpoint found, using base Flan-T5-small")
+         oracle_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')
+         oracle_model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small')
+
+     oracle_model.to(device)
+     oracle_model.eval()
+
+     # Load interpreter model (Qwen2.5-1.5B)
+     print("Loading interpreter model...")
+     interpreter_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
+     interpreter_model = AutoModelForCausalLM.from_pretrained(
+         "Qwen/Qwen2.5-1.5B-Instruct",
+         torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+     )
+     interpreter_model.to(device)
+     interpreter_model.eval()
+
+     print("Models loaded!")
+
+
+ def generate_answer(question: str) -> str:
+     """Generate zen answer from oracle."""
+     device = next(oracle_model.parameters()).device
+     inputs = oracle_tokenizer(question, return_tensors="pt").to(device)
+
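+     # Sampled decoding: moderate temperature plus top-p/top-k filtering and a
+     # repetition penalty keep the short replies varied without looping.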
+     with torch.no_grad():
+         outputs = oracle_model.generate(
+             **inputs,
+             max_new_tokens=150,
+             temperature=0.8,
+             top_p=0.85,
+             top_k=40,
+             repetition_penalty=1.3,
+             do_sample=True,
+             pad_token_id=oracle_tokenizer.pad_token_id,
+             eos_token_id=oracle_tokenizer.eos_token_id
+         )
+
+     return oracle_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+
+ def generate_interpretation(question: str, answer: str) -> str:
+     """Generate interpretation using Qwen2.5."""
+     device = next(interpreter_model.parameters()).device
+
+     prompt = INTERPRETATION_PROMPT.format(question=question, answer=answer)
+
+     messages = [{"role": "user", "content": prompt}]
+     formatted_prompt = interpreter_tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+
+     inputs = interpreter_tokenizer(
+         formatted_prompt, return_tensors="pt", max_length=512, truncation=True
+     ).to(device)
+
+     with torch.no_grad():
+         outputs = interpreter_model.generate(
+             **inputs,
+             max_new_tokens=256,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True,
+             pad_token_id=interpreter_tokenizer.pad_token_id,
+             eos_token_id=interpreter_tokenizer.eos_token_id
+         )
+
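+     # Slice off the prompt tokens so only newly generated text is decoded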
+     response = interpreter_tokenizer.decode(
+         outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True
+     )
+     return response.strip()
+
+
+ def consult_oracle(question: str) -> dict:
+     """
+     Consult the zen oracle.
+
+     Returns a dict with question, answer, and interpretation.
+     """
+     if not question.strip():
+         return {"error": "Please enter a question"}
+
+     answer = generate_answer(question)
+     interpretation = generate_interpretation(question, answer)
+
+     return {
+         "question": question,
+         "answer": answer,
+         "interpretation": interpretation
+     }
+
+
+ def gradio_consult(question: str) -> tuple:
+     """Gradio interface function."""
+     result = consult_oracle(question)
+     if "error" in result:
+         return result["error"], ""
+     return result["answer"], result["interpretation"]
+
+
+ # Load models on import (runs once at startup, before serving)
+ load_models()
+
+ # Create Gradio interface
+ demo = gr.Interface(
+     fn=gradio_consult,
+     inputs=[
+         gr.Textbox(
+             label="Your Question",
+             placeholder="What is the mind of no mind?",
+             lines=2
+         )
+     ],
+     outputs=[
+         gr.Textbox(label="Kaku-ora's Answer"),
+         gr.Textbox(label="Sage Interpretation", lines=6)
+     ],
+     title="Kaku-ora",
+     description="Ask the oracle. Receive sage advice.",
+     examples=[
+         ["What is the meaning of life?"],
+         ["What is Buddha?"],
+         ["How do I find peace?"],
+     ],
+     api_name="consult"  # exposed to API clients as api_name="/consult" (REST: /call/consult)
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
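
For reference, a minimal client-side call against the /consult endpoint could look like the sketch below. This is an illustration, not part of the commit: it assumes the gradio_client package (installed separately, not in requirements.txt) and uses a placeholder URL for the running Space.

    from gradio_client import Client

    # Placeholder: use the local dev server URL or the "user/space-name" id
    client = Client("http://127.0.0.1:7860")

    # The Interface has two outputs, so predict() returns a 2-tuple
    answer, interpretation = client.predict(
        "What is Buddha?",
        api_name="/consult",
    )
    print(answer)
    print(interpretation)
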
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch>=2.0.0
+ transformers>=4.45.0
+ gradio>=4.0.0
+ tqdm>=4.65.0
+ PyYAML>=6.0