# save as app.py
import threading
import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

MODEL_ID = "EpistemeAI/VibeCoder-20B-alpha-0.001"

# --------- Model load (do this once at startup) ----------
# Adjust dtype / device_map to your environment.
# If you have limited GPU memory, consider: device_map="auto", load_in_8bit=True (requires bitsandbytes)
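# For example, a quantized load might look like this (sketch only; assumes
# `bitsandbytes` and `accelerate` are installed and your transformers version
# supports BitsAndBytesConfig):
#
#   from transformers import BitsAndBytesConfig
#   quant_config = BitsAndBytesConfig(load_in_8bit=True)
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID,
#       quantization_config=quant_config,
#       device_map="auto",
#   )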
print("Loading tokenizer and model (this may take a while)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
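
# Some causal-LM tokenizers ship without a pad token; reuse EOS as a precaution so
# generate() doesn't warn (assumption: harmless here if a pad token already exists).
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token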

# Recommended: try device_map="auto" (requires `accelerate`); fall back to a plain CPU load if that fails.
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
    )
except Exception as e:
    print("Automatic device_map load failed, falling back to CPU. Error:", e)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
    )

model.eval()
print("Model loaded. Device:", next(model.parameters()).device)

# --------- Helper: build prompt ----------
def build_prompt(system_message: str, history: list[dict], user_message: str) -> str:
    # Simple role-tagged prompt format; adapt it to the model's preferred chat format
    # if needed (see the apply_chat_template sketch after this function).
    pieces = []
    if system_message:
        pieces.append(f"<|system|>\n{system_message}\n")
    for turn in history:
        role = turn.get("role", "user")
        content = turn.get("content", "")
        pieces.append(f"<|{role}|>\n{content}\n")
    pieces.append(f"<|user|>\n{user_message}\n<|assistant|>\n")
    return "\n".join(pieces)
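
# If the model ships a chat template, tokenizer.apply_chat_template is usually the
# safer way to build the prompt. A minimal sketch (assumption: this tokenizer
# actually defines a chat template):
#
#   def build_prompt_with_template(system_message, history, user_message):
#       messages = ([{"role": "system", "content": system_message}] if system_message else [])
#       messages += history + [{"role": "user", "content": user_message}]
#       return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)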

# --------- Gradio respond function (streams tokens) ----------
def respond(
    message,
    history: list[dict],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token=None,  # kept for compatibility with UI; not used for local pipeline
):
    """
    Streams tokens as they are generated using TextIteratorStreamer.
    Gradio will accept a generator yielding partial response strings.
    """
    prompt = build_prompt(system_message, history or [], message)

    # Prepare inputs (keep the attention mask) and move them to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Create streamer to yield token-chunks as they are generated
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        **inputs,  # input_ids + attention_mask
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        streamer=streamer,
    )

    # Start generation in background thread
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    # The streamer yields decoded text chunks as generation progresses
    for token_str in streamer:
        partial += token_str
        yield partial
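
    # The streamer is exhausted only after generate() finishes, so joining here
    # simply reaps the background thread cleanly.
    thread.join()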

# --------- Build Gradio UI ----------
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.Markdown("Model: " + MODEL_ID)
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()
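
# Usage sketch (assumptions: `gradio`, `transformers`, `torch`, and `accelerate` are
# installed, and the model fits in your available memory):
#   python app.py
# Then open the local URL Gradio prints (typically http://127.0.0.1:7860).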