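"""Gradio playground for DAT-Byte models, served with ONNX Runtime on CPU.

Provides a chat mode (ChatML-style <|im_start|>/<|im_end|> formatting) and a
raw-completion mode. Only the exported DAT-Byte Small model is wired up.
"""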
import os

import gradio as gr
import onnxruntime as ort

from inference.onnx_inference import generate_text, sequence_breaker_strings
from inference.model import ByteTokenizer

# --- Globals ---
MODEL_OPTIONS = [
    ("DAT-Byte Small (200M)", "small", True),
    ("DAT-Byte Medium", "medium", False),
    ("DAT-Byte Large", "large", False),
]
ONNX_PATH = "models/small.onnx"  # Exported ONNX model for DAT-Byte Small

# Cache for the ONNX session, keyed by model key
SESSION_CACHE = {}
TOKENIZER = ByteTokenizer()
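# ByteTokenizer is byte-level: each UTF-8 byte maps to one token ID, plus
# special chat-format tokens such as <|im_start|> / <|im_end|> (see
# inference.model for the exact vocabulary).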

# Prepare sequence breakers for DRY sampling
SEQUENCE_BREAKER_IDS = {TOKENIZER.im_start_id, TOKENIZER.im_end_id}
for s in sequence_breaker_strings:
    # These are single-byte tokens, so encode returns a list with one ID
    try:
        SEQUENCE_BREAKER_IDS.add(TOKENIZER.encode(s.encode("utf-8"))[0])
    except IndexError:
        print(f"Warning: could not encode sequence breaker string: {s}")


# --- Model Loading ---
def get_session(model_key):
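    """Return a cached onnxruntime.InferenceSession for `model_key`.

    Only the "small" model is available; the session is created on first
    use and reused afterwards.
    """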
    if model_key != "small":
        raise ValueError("Only DAT-Byte Small is available.")
    if model_key not in SESSION_CACHE:
        if not os.path.exists(ONNX_PATH):
            raise FileNotFoundError(f"ONNX model not found at {ONNX_PATH}")
        # CPUExecutionProvider matches the project's CPU-only deployment target
        SESSION_CACHE[model_key] = ort.InferenceSession(
            ONNX_PATH, providers=["CPUExecutionProvider"]
        )
    return SESSION_CACHE[model_key]


# --- Gradio Callbacks ---
def chat_respond(
    message,
    history,
    model_name,
    max_tokens,
    temperature,
    top_k,
    dry_range,
    dry_allowed_length,
    dry_base,
    dry_multiplier,
    user_role="user",
    assistant_role="assistant",
):
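    """Append one chat exchange to `history` (a list of Gradio "messages" dicts).

    Prior turns plus the new user message are rendered into a ChatML-style
    prompt (<|im_start|>{role} ... <|im_end|> blocks), ending with an open
    assistant turn for the model to complete.
    """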
    history = history or []
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"Model '{model_name}' is not available."}
        )
        return history
    try:
        session = get_session(model_key)
    except Exception as e:
        history.append({"role": "user", "content": message})
        history.append(
            {"role": "assistant", "content": f"[Model loading error: {e}]"}
        )
        return history
    # Render prior turns, then the new user message, then an open assistant turn
    prompt = ""
    for turn in history:
        prompt += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
    prompt += (
        f"<|im_start|>{user_role}\n{message}<|im_end|>\n<|im_start|>{assistant_role}\n"
    )
    generated_text, _ = generate_text(
        session=session,
        tokenizer=TOKENIZER,
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        stop_sequences=["<|im_end|>".encode("utf-8")],
        dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
        dry_range=dry_range,
        dry_allowed_length=dry_allowed_length,
        dry_base=dry_base,
        dry_multiplier=dry_multiplier,
    )
    # generate_text returns raw bytes; decode leniently for display
    generated_text = generated_text.decode("utf-8", "ignore")
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": generated_text})
    return history


def completion_respond(
    prompt,
    model_name,
    max_tokens,
    temperature,
    top_k,
    dry_range,
    dry_allowed_length,
    dry_base,
    dry_multiplier,
):
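    """Run a raw completion on `prompt`, with no chat template or stop sequence."""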
    model_key = next(
        (key for name, key, enabled in MODEL_OPTIONS if name == model_name and enabled),
        None,
    )
    if not model_key:
        return f"[Model '{model_name}' is not available or unknown.]"
    try:
        session = get_session(model_key)
    except Exception as e:
        return f"[Model loading error: {e}]"
    generated_text, _ = generate_text(
        session=session,
        tokenizer=TOKENIZER,
        prompt=prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        dry_sequence_breakers=SEQUENCE_BREAKER_IDS,
        dry_range=dry_range,
        dry_allowed_length=dry_allowed_length,
        dry_base=dry_base,
        dry_multiplier=dry_multiplier,
    )
    # Decode the raw bytes before handing the text back to Gradio
    return generated_text.decode("utf-8", "ignore")


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# DAT-Byte Playground (ONNX Accelerated)")
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Radio(
                [opt[0] for opt in MODEL_OPTIONS],
                value=MODEL_OPTIONS[0][0],
                label="Model",
                interactive=True,
            )
            gr.Markdown("**Note:** Only DAT-Byte Small is currently available.")
            mode_selector = gr.Radio(
                ["Chat", "Raw Completion"], value="Chat", label="Mode"
            )
            max_tokens = gr.Slider(
                minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
            )
            temperature = gr.Slider(
                minimum=0.05, maximum=2.0, value=0.5, step=0.05, label="Temperature"
            )
            top_k = gr.Slider(minimum=0, maximum=256, value=15, step=1, label="Top-k")
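            # DRY sampling (in the style of the text-generation-webui sampler):
            # once a repeated run of tokens exceeds Allowed Length, extending it
            # is penalized by roughly Multiplier * Base^(overshoot), looking back
            # over the last Range tokens. A Multiplier of 0 disables it; see
            # inference.onnx_inference for the exact formulation used here.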
            with gr.Accordion("DRY Sampling (Don't Repeat Yourself)", open=False):
                dry_range = gr.Slider(
                    minimum=0, maximum=2048, value=1024, step=32, label="Range"
                )
                dry_allowed_length = gr.Slider(
                    minimum=1, maximum=64, value=20, step=1, label="Allowed Length"
                )
                dry_base = gr.Slider(
                    minimum=1.0, maximum=5.0, value=2.0, step=0.1, label="Base"
                )
                dry_multiplier = gr.Slider(
                    minimum=0.0, maximum=2.0, value=0.0, step=0.05, label="Multiplier"
                )
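            # These role names are substituted into the <|im_start|>{role} tags
            # of the chat template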
            user_role_box = gr.Textbox("user", label="User Role", visible=True)
            assistant_role_box = gr.Textbox(
                "assistant", label="Assistant Role", visible=True
            )
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat", type="messages", height=600)
            # Name the row so mode switching can toggle its visibility directly
            with gr.Row() as chat_input_row:
                chat_input = gr.Textbox(
                    label="Message", placeholder="Type a message...", scale=4
                )
                send_button = gr.Button("Send", scale=1)
            completion_input = gr.Textbox(label="Prompt", visible=False)
            completion_output = gr.Textbox(label="Completion", visible=False)

    # UI logic: toggle between the chat layout and the raw-completion layout.
    # Updating the named chat_input_row directly avoids relying on the
    # non-public chat_input.parent attribute.
    def update_mode(mode):
        is_chat = mode == "Chat"
        return (
            gr.update(visible=is_chat),      # chatbot
            gr.update(visible=is_chat),      # chat_input_row
            gr.update(visible=not is_chat),  # completion_input
            gr.update(visible=not is_chat),  # completion_output
            gr.update(visible=is_chat),      # user_role_box
            gr.update(visible=is_chat),      # assistant_role_box
        )

    mode_selector.change(
        update_mode,
        [mode_selector],
        [
            chatbot,
            chat_input_row,
            completion_input,
            completion_output,
            user_role_box,
            assistant_role_box,
        ],
    )

    # Event handlers
    chat_inputs = [
        chat_input,
        chatbot,
        model_selector,
        max_tokens,
        temperature,
        top_k,
        dry_range,
        dry_allowed_length,
        dry_base,
        dry_multiplier,
        user_role_box,
        assistant_role_box,
    ]
    chat_args = {"fn": chat_respond, "inputs": chat_inputs, "outputs": [chatbot]}

    def clear_input():
        return ""

    # Chain with .then() so the message box clears after the response arrives
    clear_args = {"fn": clear_input, "inputs": [], "outputs": [chat_input]}
    send_button.click(**chat_args).then(**clear_args)
    chat_input.submit(**chat_args).then(**clear_args)
    completion_inputs = [
        completion_input,
        model_selector,
        max_tokens,
        temperature,
        top_k,
        dry_range,
        dry_allowed_length,
        dry_base,
        dry_multiplier,
    ]
    completion_input.submit(
        completion_respond,
        completion_inputs,
        [completion_output],
    )
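
# Launch the demo when run directly (e.g. `python app.py`; the filename is
# assumed from the Hugging Face Spaces convention).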
if __name__ == "__main__":
    demo.launch()