""" |
|
|
Zen Oracle - Gradio App for HuggingFace Spaces |
|
|
|
|
|
Provides both a web UI and JSON API endpoint. |
|
|
API: POST /api/predict with {"data": ["question", "style"]} |
|
|
""" |
|
|
|
|
|
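# A sketch of a raw API call against the endpoint described above. The exact
# route depends on the Gradio version (newer releases also expose the named
# "consult" endpoint; see the gradio_client example near the bottom of this
# file), so treat this as illustrative rather than definitive:
#
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"data": ["What is Buddha?"]}'
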
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

INTERPRETATION_PROMPT = (
    "You are an old zen master reading a koan. A student asked: \"{question}\" "
    "The master replied: \"{answer}\". Write a short capping verse for this "
    "koan, 4-8 lines, and nothing else."
)

# Module-level model handles, populated once by load_models() at startup.
oracle_model = None
oracle_tokenizer = None
interpreter_model = None
interpreter_tokenizer = None


def load_models():
    """Load the oracle and interpreter models on startup."""
    global oracle_model, oracle_tokenizer, interpreter_model, interpreter_tokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    print("Loading oracle model...")
    checkpoint_path = os.environ.get("ORACLE_CHECKPOINT", "checkpoints/best_model.pt")

    if os.path.exists(checkpoint_path):
        # Rebuild the base model named in the saved config, then load the
        # fine-tuned weights on top of it.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        config = checkpoint.get("config", {})
        model_name = config.get("model", {}).get("name", "google/flan-t5-small")

        oracle_tokenizer = AutoTokenizer.from_pretrained(model_name)
        oracle_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        oracle_model.load_state_dict(checkpoint["model_state_dict"])
    else:
        print("No checkpoint found, using base Flan-T5-small")
        oracle_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
        oracle_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

    oracle_model.to(device)
    oracle_model.eval()

    print("Loading interpreter model...")
    interpreter_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    interpreter_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-1.5B-Instruct",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    interpreter_model.to(device)
    interpreter_model.eval()

    print("Models loaded!")


def generate_answer(question: str) -> str:
    """Generate a zen answer from the oracle."""
    device = next(oracle_model.parameters()).device
    inputs = oracle_tokenizer(question, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = oracle_model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.8,
            top_p=0.85,
            top_k=40,
            repetition_penalty=1.3,
            do_sample=True,
            pad_token_id=oracle_tokenizer.pad_token_id,
            eos_token_id=oracle_tokenizer.eos_token_id,
        )

    return oracle_tokenizer.decode(outputs[0], skip_special_tokens=True)


def generate_interpretation(question: str, answer: str) -> str:
    """Generate an interpretation of the oracle's answer using Qwen2.5."""
    device = next(interpreter_model.parameters()).device

    prompt = INTERPRETATION_PROMPT.format(question=question, answer=answer)

    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = interpreter_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = interpreter_tokenizer(
        formatted_prompt, return_tensors="pt", max_length=512, truncation=True
    ).to(device)

    with torch.no_grad():
        outputs = interpreter_model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=interpreter_tokenizer.pad_token_id,
            eos_token_id=interpreter_tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = interpreter_tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return response.strip()


def consult_oracle(question: str) -> dict:
    """
    Consult the zen oracle.

    Returns a dict with question, answer, and interpretation.
    """
    if not question.strip():
        return {"error": "Please enter a question"}

    answer = generate_answer(question)
    interpretation = generate_interpretation(question, answer)

    return {
        "question": question,
        "answer": answer,
        "interpretation": interpretation,
    }


def gradio_consult(question: str) -> tuple:
    """Gradio interface function."""
    result = consult_oracle(question)
    if "error" in result:
        return result["error"], ""
    return result["answer"], result["interpretation"]


# Load models at import time so the Space is ready before the first request.
load_models()


demo = gr.Interface(
    fn=gradio_consult,
    inputs=[
        gr.Textbox(
            label="Your Question",
            placeholder="What is the mind of no mind?",
            lines=2,
        )
    ],
    outputs=[
        gr.Textbox(label="Kaku-ora's Answer"),
        gr.Textbox(label="Sage Interpretation", lines=6),
    ],
    title="Kaku-ora",
    description="Ask the oracle. Receive sage advice.",
    examples=[
        ["What is the meaning of life?"],
        ["What is Buddha?"],
        ["How do I find peace?"],
    ],
    api_name="consult",
)

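# A sketch of calling the named "consult" endpoint with gradio_client (assumes
# the Space is reachable at this URL and that gradio_client is installed):
#
#   from gradio_client import Client
#
#   client = Client("http://localhost:7860")
#   answer, interpretation = client.predict(
#       "What is Buddha?", api_name="/consult"
#   )
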
if __name__ == "__main__":
    demo.launch()