"""
Zen Oracle - Gradio App for HuggingFace Spaces
Provides both a web UI and JSON API endpoint.
API: POST /api/predict with {"data": ["question", "style"]}
"""
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
import os
# Interpretation prompt
INTERPRETATION_PROMPT = "You are an old zen master reading a koan. A student asked: \"{question}\" The master replied: \"{answer}\". Write a short capping verse for this koan, 4-8 lines, and nothing else."
# Global models (loaded once)
oracle_model = None
oracle_tokenizer = None
interpreter_model = None
interpreter_tokenizer = None
def load_models():
"""Load models on startup."""
global oracle_model, oracle_tokenizer, interpreter_model, interpreter_tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load oracle model (fine-tuned Flan-T5-small)
print("Loading oracle model...")
checkpoint_path = os.environ.get("ORACLE_CHECKPOINT", "checkpoints/best_model.pt")
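    # Expected checkpoint layout, inferred from the loading code below (the exact
    # keys written at training time are an assumption):
    #   {"model_state_dict": <state_dict>,
    #    "config": {"model": {"name": "google/flan-t5-small"}}}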
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        config = checkpoint.get('config', {})
        model_name = config.get('model', {}).get('name', 'google/flan-t5-small')
        oracle_tokenizer = AutoTokenizer.from_pretrained(model_name)
        oracle_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        oracle_model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # Fallback to base model if no checkpoint
        print("No checkpoint found, using base Flan-T5-small")
        oracle_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')
        oracle_model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small')
    oracle_model.to(device)
    oracle_model.eval()
    # Load interpreter model (Qwen2.5-1.5B)
    print("Loading interpreter model...")
    interpreter_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    interpreter_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-1.5B-Instruct",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    interpreter_model.to(device)
    interpreter_model.eval()
    print("Models loaded!")
def generate_answer(question: str) -> str:
"""Generate zen answer from oracle."""
device = next(oracle_model.parameters()).device
inputs = oracle_tokenizer(question, return_tensors="pt").to(device)
with torch.no_grad():
outputs = oracle_model.generate(
**inputs,
max_new_tokens=150,
temperature=0.8,
top_p=0.85,
top_k=40,
repetition_penalty=1.3,
do_sample=True,
pad_token_id=oracle_tokenizer.pad_token_id,
eos_token_id=oracle_tokenizer.eos_token_id
)
return oracle_tokenizer.decode(outputs[0], skip_special_tokens=True)
def generate_interpretation(question: str, answer: str) -> str:
"""Generate interpretation using Qwen2.5."""
device = next(interpreter_model.parameters()).device
prompt = INTERPRETATION_PROMPT.format(question=question, answer=answer)
messages = [{"role": "user", "content": prompt}]
formatted_prompt = interpreter_tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = interpreter_tokenizer(
formatted_prompt, return_tensors="pt", max_length=512, truncation=True
).to(device)
with torch.no_grad():
outputs = interpreter_model.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=interpreter_tokenizer.pad_token_id,
eos_token_id=interpreter_tokenizer.eos_token_id
)
response = interpreter_tokenizer.decode(
outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True
)
return response.strip()
def consult_oracle(question: str) -> dict:
"""
Consult the zen oracle.
Returns a dict with question, answer, and interpretation.
"""
if not question.strip():
return {"error": "Please enter a question"}
answer = generate_answer(question)
interpretation = generate_interpretation(question, answer)
return {
"question": question,
"answer": answer,
"interpretation": interpretation
}
def gradio_consult(question: str) -> tuple:
"""Gradio interface function."""
result = consult_oracle(question)
if "error" in result:
return result["error"], ""
return result["answer"], result["interpretation"]
# Load models on import
load_models()
# Create Gradio interface
demo = gr.Interface(
    fn=gradio_consult,
    inputs=[
        gr.Textbox(
            label="Your Question",
            placeholder="What is the mind of no mind?",
            lines=2
        )
    ],
    outputs=[
        gr.Textbox(label="Kaku-ora's Answer"),
        gr.Textbox(label="Sage Interpretation", lines=6)
    ],
    title="Kaku-ora",
    description="Ask the oracle. Receive sage advice.",
    examples=[
        ["What is the meaning of life?"],
        ["What is Buddha?"],
        ["How do I find peace?"],
    ],
    api_name="consult"  # API endpoint: /api/consult
)
if __name__ == "__main__":
    demo.launch()