import spaces
import torch
from threading import Thread
from typing import Generator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE
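
# The constants above live in config.py. A minimal sketch of what that module
# might contain is shown below; the values (and the model id in particular)
# are illustrative assumptions, not the project's actual settings.
#
#     MODEL_NAME = "Kwaipilot/KAT-Dev"  # assumed model id; use the Space's real one
#     MAX_NEW_TOKENS = 1024             # cap on generated tokens per reply
#     TEMPERATURE = 0.7                 # sampling temperature
#     DO_SAMPLE = True                  # False would make generation greedy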

# Global variables holding the model and tokenizer.
# They are loaded once at module import so repeated inference calls reuse them.
tokenizer = None
model = None

def initialize_model():
    """Initializes and loads the model and tokenizer once onto the GPU."""
    global tokenizer, model
    if model is None:
        try:
            print(f"Loading model {MODEL_NAME}...")
            
            # Use bfloat16 for efficiency on modern GPUs (e.g., H100, A100)
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
            
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=dtype,
                device_map="auto" # Automatically handles device placement (GPU)
            )
            model.eval()
            
            # Set padding token if not defined (common for Causal LMs)
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
                
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
    return tokenizer, model

# Call initialization immediately to ensure the model is ready when the worker starts up
# Note: this runs at import time (global scope); if it fails here, the first inference call re-attempts it.
try:
    initialize_model()
except Exception as e:
    print(f"Warning: Global model initialization failed: {e}. It will be re-attempted during the first inference call.")


@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
    """
    Generates a response from the KAT model, streaming output token by token.
    
    Args:
        prompt: The current user input.
        history: The accumulated chat history (list of [user_msg, bot_msg] tuples).
    
    Yields:
        str: The accumulated response text generated so far.
    """
    global tokenizer, model
    
    # Fallback initialization in case global loading failed
    if model is None or tokenizer is None:
        initialize_model()

    # Convert Gradio history format to the model's chat template format
    messages = []
    for human, bot in history:
        # Add past exchanges
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})

    # Add the current prompt
    messages.append({"role": "user", "content": prompt})

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    
    # Prepare inputs and move to model device
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    
    # Stream tokens as they are produced: model.generate() blocks, so it runs in
    # a background thread while this generator drains the TextIteratorStreamer
    # and yields progressively longer text to the UI.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,          # do not echo the prompt back into the output
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.1,
        streamer=streamer,
    )

    generation_thread = Thread(target=model.generate, kwargs=generation_kwargs)
    generation_thread.start()

    # Yield the accumulated text as new chunks arrive from the streamer
    accumulated_text = ""
    for new_chunk in streamer:
        accumulated_text += new_chunk
        yield accumulated_text

    generation_thread.join()

    # Final yield to ensure the complete, trimmed text is sent
    if accumulated_text:
        yield accumulated_text.strip()
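

# Usage sketch (illustrative, not part of this module): in the Space's app.py the
# generator above would typically be wired to a Gradio chat UI roughly like the
# snippet below. The file name, title, and the tuple-style history format are
# assumptions about how the rest of the project is set up.
#
#     import gradio as gr
#     from model import stream_generate_response
#
#     demo = gr.ChatInterface(
#         fn=stream_generate_response,  # yields accumulated text, so the UI streams
#         title="KAT Chat",
#     )
#     demo.queue().launch()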