"""
Zen Oracle - Gradio App for HuggingFace Spaces
Provides both a web UI and a JSON API endpoint.
API: POST /api/consult with {"data": ["question"]}
"""
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
import os
# Interpretation prompt
INTERPRETATION_PROMPT = "You are an old zen master reading a koan. A student asked: \"{question}\" The master replied: \"{answer}\". Write a short capping verse for this koan, 4-8 lines, and nothing else."
# Global models (loaded once)
oracle_model = None
oracle_tokenizer = None
interpreter_model = None
interpreter_tokenizer = None
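
# Two-stage pipeline (as implemented below): the fine-tuned Flan-T5 "oracle" answers
# the question, then Qwen2.5-1.5B-Instruct writes a short capping verse interpreting it.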


def load_models():
    """Load models on startup."""
    global oracle_model, oracle_tokenizer, interpreter_model, interpreter_tokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load oracle model (fine-tuned Flan-T5-small)
    print("Loading oracle model...")
    checkpoint_path = os.environ.get("ORACLE_CHECKPOINT", "checkpoints/best_model.pt")
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        config = checkpoint.get('config', {})
        model_name = config.get('model', {}).get('name', 'google/flan-t5-small')
        oracle_tokenizer = AutoTokenizer.from_pretrained(model_name)
        oracle_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        oracle_model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # Fallback to base model if no checkpoint
        print("No checkpoint found, using base Flan-T5-small")
        oracle_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')
        oracle_model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small')
    oracle_model.to(device)
    oracle_model.eval()

    # Load interpreter model (Qwen2.5-1.5B)
    print("Loading interpreter model...")
    interpreter_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    interpreter_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-1.5B-Instruct",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    interpreter_model.to(device)
    interpreter_model.eval()

    print("Models loaded!")


def generate_answer(question: str) -> str:
    """Generate zen answer from oracle."""
    device = next(oracle_model.parameters()).device
    inputs = oracle_tokenizer(question, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = oracle_model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.8,
            top_p=0.85,
            top_k=40,
            repetition_penalty=1.3,
            do_sample=True,
            pad_token_id=oracle_tokenizer.pad_token_id,
            eos_token_id=oracle_tokenizer.eos_token_id
        )

    return oracle_tokenizer.decode(outputs[0], skip_special_tokens=True)


def generate_interpretation(question: str, answer: str) -> str:
    """Generate interpretation using Qwen2.5."""
    device = next(interpreter_model.parameters()).device

    prompt = INTERPRETATION_PROMPT.format(question=question, answer=answer)
    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = interpreter_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = interpreter_tokenizer(
        formatted_prompt, return_tensors="pt", max_length=512, truncation=True
    ).to(device)

    with torch.no_grad():
        outputs = interpreter_model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=interpreter_tokenizer.pad_token_id,
            eos_token_id=interpreter_tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens, skipping the prompt portion
    response = interpreter_tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True
    )
    return response.strip()


def consult_oracle(question: str) -> dict:
    """
    Consult the zen oracle.
    Returns a dict with question, answer, and interpretation.
    """
    if not question.strip():
        return {"error": "Please enter a question"}

    answer = generate_answer(question)
    interpretation = generate_interpretation(question, answer)

    return {
        "question": question,
        "answer": answer,
        "interpretation": interpretation
    }


def gradio_consult(question: str) -> tuple:
    """Gradio interface function."""
    result = consult_oracle(question)
    if "error" in result:
        return result["error"], ""
    return result["answer"], result["interpretation"]


# Load models on import
load_models()

# Create Gradio interface
demo = gr.Interface(
    fn=gradio_consult,
    inputs=[
        gr.Textbox(
            label="Your Question",
            placeholder="What is the mind of no mind?",
            lines=2
        )
    ],
    outputs=[
        gr.Textbox(label="Kaku-ora's Answer"),
        gr.Textbox(label="Sage Interpretation", lines=6)
    ],
    title="Kaku-ora",
    description="Ask the oracle. Receive sage advice.",
    examples=[
        ["What is the meaning of life?"],
        ["What is Buddha?"],
        ["How do I find peace?"],
    ],
    api_name="consult"  # API endpoint: /api/consult
)


if __name__ == "__main__":
    demo.launch()