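# app.py: Gradio chat interface for WeiboAI/VibeThinker-1.5B.
# Assumed dependencies (not pinned here): gradio, torch, transformers,
# plus accelerate for device_map="auto". Typical setup:
#   pip install gradio torch transformers accelerate
#   python app.py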
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class VibeThinkerChat:
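    """Loads the VibeThinker-1.5B checkpoint and generates single-turn replies."""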
    def __init__(self, model_path="WeiboAI/VibeThinker-1.5B"):
        print("Loading model and tokenizer...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True
        )
        print("Model loaded successfully!")
    
    def generate_response(self, prompt, temperature=0.6, max_tokens=40960, top_p=0.95):
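        """Apply the model's chat template to a single prompt and sample a reply."""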
        messages = [
            {"role": "user", "content": prompt}
        ]
        
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        
        generation_config = dict(
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=0  # disable top-k filtering so temperature/top_p control sampling
        )
        
        generated_ids = self.model.generate(
            **model_inputs,  # forwards input_ids and attention_mask together
            **generation_config
        )
        
        generated_ids = [
            output_ids[len(input_ids):] 
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response

# Initialize the model once at startup so every request reuses the loaded weights
chat_model = VibeThinkerChat()

def chat_interface(message, history, temperature, max_tokens):
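    """Bridge between the Gradio UI and the model; returns the reply or an error string."""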
    try:
        response = chat_model.generate_response(
            message, 
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="VibeThinker-1.5B Chat") as demo:
    gr.Markdown("# 🧠 VibeThinker-1.5B Chat Interface")
    gr.Markdown("A 1.5B parameter reasoning model optimized for math and coding problems.")
    
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500)
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Ask a math or coding question...",
                lines=3
            )
            with gr.Row():
                submit = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear")
        
        with gr.Column(scale=1):
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.6,
                step=0.1,
                label="Temperature",
                info="Recommended: 0.6 or 1.0"
            )
            max_tokens = gr.Slider(
                minimum=512,
                maximum=40960,
                value=4096,
                step=512,
                label="Max Tokens",
                info="Maximum response length"
            )
    
    def user_message(user_msg, history):
        return "", history + [[user_msg, None]]
    
    def bot_response(history, temp, max_tok):
        user_msg = history[-1][0]
        bot_msg = chat_interface(user_msg, history, temp, max_tok)
        history[-1][1] = bot_msg
        return history
    
    msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temperature, max_tokens], chatbot
    )
    submit.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, temperature, max_tokens], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue()
    demo.launch()