# nanochat / app.py — source: Hugging Face Space by axiilay, commit e2e0202 (verified), 2.28 kB
"""Gradio interface for nanochat model."""
from __future__ import annotations
import os
from collections.abc import Generator
from pathlib import Path
from typing import Any
import gradio as gr
from huggingface_hub import snapshot_download
from model import NanochatModel
# Hugging Face repo the weights are downloaded from; overridable via env var.
MODEL_REPO = os.environ.get("MODEL_REPO", "sdobson/nanochat")
# Local directory the snapshot is cached in; overridable via env var.
MODEL_DIR = os.environ.get("MODEL_DIR", "./model_cache")
# Lazily-initialised module-level singleton; populated by load_model().
_model: NanochatModel | None = None
def download_model() -> None:
    """Fetch the model snapshot from Hugging Face unless a local copy exists.

    The download is skipped when MODEL_DIR already exists and contains at
    least one entry, so restarts reuse the cached snapshot.
    """
    cache_dir = Path(MODEL_DIR)
    already_populated = cache_dir.exists() and any(cache_dir.iterdir())
    if already_populated:
        return
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR)
def load_model() -> None:
    """Initialise the module-level nanochat model singleton.

    Idempotent: a second call is a no-op once the model is loaded. Downloads
    the weights first if they are not already cached locally.
    """
    global _model
    if _model is not None:
        return
    download_model()
    _model = NanochatModel(model_dir=MODEL_DIR, device="cpu")
# Eagerly load the model at import time so the first chat request does not
# pay the download/initialisation cost.
load_model()
def respond(
    message: str,
    history: list[dict[str, str]],
    temperature: float,
    top_k: int,
) -> Generator[str, Any, None]:
    """Generate a response using the nanochat model.

    Args:
        message: User's input message.
        history: Chat history in Gradio "messages" format.
        temperature: Sampling temperature.
        top_k: Top-k sampling parameter.

    Yields:
        Incrementally generated response text. Each yield is the cumulative
        response so far, which is the shape Gradio's streaming API expects.
    """
    # Build the full conversation as a new list (PERF401: unpacking literal
    # instead of a manual append loop) without mutating Gradio's history.
    conversation = [*history, {"role": "user", "content": message}]
    response = ""
    for token in _model.generate(
        history=conversation,
        max_tokens=512,
        temperature=temperature,
        top_k=top_k,
    ):
        response += token
        yield response
# Streaming chat UI. `respond` yields cumulative text, so Gradio replaces
# the in-progress message on each yield. The two sliders are passed to
# respond() as its `temperature` and `top_k` arguments, in order.
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=1,
            maximum=200,
            value=50,
            step=1,
            label="Top-k sampling",
        ),
    ],
)
# Wrap the chat interface in a Blocks layout so headings can be added
# above it; chatbot.render() places the pre-built interface here.
with gr.Blocks(title="nanochat") as demo:
    gr.Markdown("# nanochat")
    gr.Markdown("Chat with an AI trained in 4 hours for $100")
    chatbot.render()
# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()