| """Gradio interface for nanochat model.""" | |
| from __future__ import annotations | |
| import os | |
| from collections.abc import Generator | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| from model import NanochatModel | |
# Hugging Face Hub repository the weights are fetched from (override via $MODEL_REPO).
MODEL_REPO = os.getenv("MODEL_REPO", "sdobson/nanochat")
# Local directory the snapshot is downloaded into and loaded from (override via $MODEL_DIR).
MODEL_DIR = os.getenv("MODEL_DIR", "./model_cache")
# Shared model instance; stays None until load_model() assigns it.
_model: NanochatModel | None = None
def download_model() -> None:
    """Fetch the model snapshot from the Hugging Face Hub unless a local copy exists.

    A local copy is assumed present when ``MODEL_DIR`` exists and contains at
    least one entry; in that case nothing is downloaded. Otherwise the
    ``MODEL_REPO`` snapshot is downloaded into ``MODEL_DIR``.
    """
    cache_dir = Path(MODEL_DIR)
    # `and` short-circuits, so iterdir() is only called on an existing path.
    has_local_copy = cache_dir.exists() and any(cache_dir.iterdir())
    if has_local_copy:
        return
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR)
def load_model() -> None:
    """Create the shared ``_model`` instance on first use.

    Ensures the snapshot is available (via ``download_model``), then builds a
    ``NanochatModel`` from ``MODEL_DIR`` on the CPU and stores it in the
    module-level ``_model``. Once ``_model`` is set, further calls do nothing.
    """
    global _model
    if _model is not None:
        return
    download_model()
    _model = NanochatModel(model_dir=MODEL_DIR, device="cpu")


# Runs at import time, before the Gradio UI below is constructed.
load_model()
def respond(
    message: str,
    history: list[dict[str, str]],
    temperature: float,
    top_k: int,
    *,
    max_tokens: int = 512,
) -> Generator[str, Any, None]:
    """Stream a reply from the nanochat model for one chat turn.

    Args:
        message: The user's newest input message.
        history: Earlier turns in Gradio "messages" format (dicts with
            ``role`` and ``content`` keys). The list is not mutated.
        temperature: Sampling temperature forwarded to the model.
        top_k: Top-k sampling parameter forwarded to the model. Coerced to
            ``int`` because Gradio sliders document their value as a float.
        max_tokens: Maximum number of tokens to generate. Keyword-only so
            Gradio's positional inputs can never fill it; the default matches
            the previously hard-coded limit.

    Yields:
        The response text accumulated so far. Each yield extends the previous
        one by the next chunk produced by ``_model.generate``.
    """
    if _model is None:
        # Normally populated at import time; load on demand instead of
        # failing with an AttributeError on None.
        load_model()

    # New list: prior turns followed by the current user message.
    conversation = [*history, {"role": "user", "content": message}]

    response = ""
    for token in _model.generate(
        history=conversation,
        max_tokens=max_tokens,
        # Slider values may arrive as floats; samplers typically need an int k.
        temperature=float(temperature),
        top_k=int(top_k),
    ):
        response += token
        yield response
# Chat UI driven by `respond`. The sliders below are passed to it, in order,
# as its extra `temperature` and `top_k` arguments.
chatbot = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.7),
        gr.Slider(label="Top-k sampling", minimum=1, maximum=200, step=1, value=50),
    ],
    type="messages",
)
# Top-level page: a title, a one-line tagline, then the chat interface.
# Components inside the Blocks context are laid out in the order created here.
with gr.Blocks(title="nanochat") as demo:
    gr.Markdown("# nanochat")
    gr.Markdown("Chat with an AI trained in 4 hours for $100")
    # Embed the already-constructed ChatInterface inside this layout.
    chatbot.render()

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()