""" Zen Oracle - Gradio App for HuggingFace Spaces Provides both a web UI and JSON API endpoint. API: POST /api/predict with {"data": ["question", "style"]} """ import torch import gradio as gr from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM import os # Interpretation prompt INTERPRETATION_PROMPT = "You are an old zen master reading a koan. A student asked: \"{question}\" The master replied: \"{answer}\". Write a short capping verse for this koan, 4-8 lines, and nothing else." # Global models (loaded once) oracle_model = None oracle_tokenizer = None interpreter_model = None interpreter_tokenizer = None def load_models(): """Load models on startup.""" global oracle_model, oracle_tokenizer, interpreter_model, interpreter_tokenizer device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load oracle model (fine-tuned Flan-T5-small) print("Loading oracle model...") checkpoint_path = os.environ.get("ORACLE_CHECKPOINT", "checkpoints/best_model.pt") if os.path.exists(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') config = checkpoint.get('config', {}) model_name = config.get('model', {}).get('name', 'google/flan-t5-small') oracle_tokenizer = AutoTokenizer.from_pretrained(model_name) oracle_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) oracle_model.load_state_dict(checkpoint['model_state_dict']) else: # Fallback to base model if no checkpoint print("No checkpoint found, using base Flan-T5-small") oracle_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small') oracle_model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small') oracle_model.to(device) oracle_model.eval() # Load interpreter model (Qwen2.5-1.5B) print("Loading interpreter model...") interpreter_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") interpreter_model = AutoModelForCausalLM.from_pretrained( "Qwen/Qwen2.5-1.5B-Instruct", torch_dtype=torch.float16 if device == "cuda" else torch.float32, ) interpreter_model.to(device) interpreter_model.eval() print("Models loaded!") def generate_answer(question: str) -> str: """Generate zen answer from oracle.""" device = next(oracle_model.parameters()).device inputs = oracle_tokenizer(question, return_tensors="pt").to(device) with torch.no_grad(): outputs = oracle_model.generate( **inputs, max_new_tokens=150, temperature=0.8, top_p=0.85, top_k=40, repetition_penalty=1.3, do_sample=True, pad_token_id=oracle_tokenizer.pad_token_id, eos_token_id=oracle_tokenizer.eos_token_id ) return oracle_tokenizer.decode(outputs[0], skip_special_tokens=True) def generate_interpretation(question: str, answer: str) -> str: """Generate interpretation using Qwen2.5.""" device = next(interpreter_model.parameters()).device prompt = INTERPRETATION_PROMPT.format(question=question, answer=answer) messages = [{"role": "user", "content": prompt}] formatted_prompt = interpreter_tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = interpreter_tokenizer( formatted_prompt, return_tensors="pt", max_length=512, truncation=True ).to(device) with torch.no_grad(): outputs = interpreter_model.generate( **inputs, max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True, pad_token_id=interpreter_tokenizer.pad_token_id, eos_token_id=interpreter_tokenizer.eos_token_id ) response = interpreter_tokenizer.decode( outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True ) return response.strip() def consult_oracle(question: str) 
-> dict: """ Consult the zen oracle. Returns a dict with question, answer, and interpretation. """ if not question.strip(): return {"error": "Please enter a question"} answer = generate_answer(question) interpretation = generate_interpretation(question, answer) return { "question": question, "answer": answer, "interpretation": interpretation } def gradio_consult(question: str) -> tuple: """Gradio interface function.""" result = consult_oracle(question) if "error" in result: return result["error"], "" return result["answer"], result["interpretation"] # Load models on import load_models() # Create Gradio interface demo = gr.Interface( fn=gradio_consult, inputs=[ gr.Textbox( label="Your Question", placeholder="What is the mind of no mind?", lines=2 ) ], outputs=[ gr.Textbox(label="Kaku-ora's Answer"), gr.Textbox(label="Sage Interpretation", lines=6) ], title="Kaku-ora", description="Ask the oracle. Receive sage advice.", examples=[ ["What is the meaning of life?"], ["What is Buddha?"], ["How do I find peace?"], ], api_name="consult" # API endpoint: /api/consult ) if __name__ == "__main__": demo.launch()
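
# Example client call (a minimal sketch, not part of this app): assuming the Space
# is deployed under a hypothetical repo id "your-username/zen-oracle", the
# `gradio_client` package can call the named endpoint programmatically:
#
#     from gradio_client import Client
#
#     client = Client("your-username/zen-oracle")  # hypothetical Space id
#     answer, verse = client.predict(
#         "What is Buddha?",     # the single "question" input
#         api_name="/consult",   # matches api_name="consult" above
#     )
#     print(answer)
#     print(verse)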