akhaliq (HF Staff) committed
Commit a5660ec · verified · 1 Parent(s): 0763b5e

Update app.py

Files changed (1)
  1. app.py +38 -138
app.py CHANGED
@@ -1,141 +1,41 @@
- I see the issues! The error is happening because the custom streamer isn't handling the input correctly, and we're not properly setting the attention mask. Let me fix the streaming implementation:
-
- === models.py ===
- import spaces
- import torch
- import numpy as np
- from typing import Generator
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE
-
- # Global variables to store the model and tokenizer
- tokenizer = None
- model = None
-
- def initialize_model():
-     """Initializes and loads the model and tokenizer once onto the GPU."""
-     global tokenizer, model
-     if model is None:
-         try:
-             print(f"Loading model {MODEL_NAME}...")
-
-             # Use bfloat16 for efficiency on modern GPUs
-             dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
-             tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-             model = AutoModelForCausalLM.from_pretrained(
-                 MODEL_NAME,
-                 torch_dtype=dtype,
-                 device_map="auto"
-             )
-             model.eval()
-
-             # Set padding token if not defined
-             if tokenizer.pad_token_id is None:
-                 tokenizer.pad_token_id = tokenizer.eos_token_id
-
-             print("Model loaded successfully.")
-         except Exception as e:
-             print(f"Failed to load model: {e}")
-             raise
-     return tokenizer, model
-
- # Call initialization
- try:
-     initialize_model()
- except Exception as e:
-     print(f"Warning: Global model initialization failed: {e}")
-
- @spaces.GPU(duration=120)
- def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
-     """
-     Generates a response from the KAT model with proper streaming.
-     """
-     global tokenizer, model
-
-     # Fallback initialization
-     if model is None or tokenizer is None:
-         initialize_model()
-
-     # Convert Gradio history format to the model's chat template format
-     messages = []
-     for human, bot in history:
-         if human:
-             messages.append({"role": "user", "content": human})
-         if bot:
-             messages.append({"role": "assistant", "content": bot})
-
-     # Add the current prompt
-     messages.append({"role": "user", "content": prompt})
-
-     # Apply chat template
-     text = tokenizer.apply_chat_template(
-         messages,
-         tokenize=False,
-         add_generation_prompt=True,
    )
-
-     # Tokenize with attention mask
-     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
-     input_ids = inputs.input_ids.to(model.device)
-     attention_mask = inputs.attention_mask.to(model.device)
-
-     # Generate with streaming using yield-based approach
-     accumulated_text = ""
-
-     # Generate tokens incrementally
-     for _ in range(MAX_NEW_TOKENS):
-         with torch.no_grad():
-             outputs = model(
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 return_dict=True
-             )
-
-         # Get next token probabilities
-         next_token_logits = outputs.logits[:, -1, :]
-
-         # Apply temperature
-         if TEMPERATURE > 0:
-             next_token_logits = next_token_logits / TEMPERATURE
-
-         # Apply softmax and sample
-         probs = torch.softmax(next_token_logits, dim=-1)
-         if DO_SAMPLE:
-             next_token = torch.multinomial(probs, num_samples=1)
-         else:
-             next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
-
-         # Check for EOS token
-         if next_token.item() == tokenizer.eos_token_id:
-             break
-
-         # Decode the new token
-         new_token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
-
-         # Update accumulated text
-         accumulated_text += new_token_text
-
-         # Yield the current accumulated text
-         yield accumulated_text
-
-         # Prepare for next iteration
-         input_ids = torch.cat([input_ids, next_token], dim=-1)
-         attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
-
-         # Stop if we've reached max tokens
-         if input_ids.shape[-1] >= input_ids.shape[-1] + MAX_NEW_TOKENS:
-             break

-     # Final yield to ensure complete text
-     yield accumulated_text.strip()
-
- The key changes I made:
-
- 1. **Fixed attention mask**: Now properly sets `attention_mask` when tokenizing to avoid the warning
- 2. **Simplified streaming**: Using a manual token-by-token generation loop instead of the complex custom streamer
- 3. **Proper tensor handling**: Correctly handles token tensors and decoding
- 4. **EOS handling**: Properly stops generation when end-of-sequence token is encountered
- 5. **Memory efficiency**: Uses `torch.no_grad()` for inference to save memory

- This implementation should now properly stream tokens one by one and yield the accumulated text to the Gradio interface for real-time display.

+ import gradio as gr
+ from models import stream_generate_response
+
+ # Header Link
+ ANYCODER_LINK = "<a href='https://huggingface.co/spaces/akhaliq/anycoder' target='_blank'>Built with anycoder</a>"
+
+ with gr.Blocks(title="KAT-Dev Chat", theme=gr.themes.Soft()) as demo:
+     gr.HTML(
+         f"""
+         <div style="text-align: center; max-width: 800px; margin: 0 auto;">
+             <h1>💬 KAT-Dev LLM Chat</h1>
+             <p>Powered by Kwaipilot/KAT-Dev, a large language model. This application uses Hugging Face ZeroGPU for highly efficient inference.</p>
+             {ANYCODER_LINK}
+         </div>
+         """
    )

+     # ChatInterface handles the full conversational UI, streaming, and history management
+     chat_interface = gr.ChatInterface(
+         fn=stream_generate_response,
+         title="", # Title moved to HTML block
+         chatbot=gr.Chatbot(
+             height=500,
+             show_copy_button=True,
+             layout="bubble"
+         ),
+         textbox=gr.Textbox(
+             placeholder="Ask the KAT model anything...",
+             container=False,
+             scale=7
+         ),
+         # Disable the default submit button text since we have an icon
+         submit_btn=True,
+         stop_btn=True,
+
+         # Concurrency limit handled by @spaces.GPU
+         concurrency_limit=10,
+     )

+ demo.queue()
+ demo.launch()
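
Note that this commit only replaces app.py; the stream_generate_response generator that the new file imports lives in models.py, which is not part of this diff. For orientation, below is a minimal sketch of what such a generator could look like, assuming the same Kwaipilot/KAT-Dev checkpoint and the (user, assistant) tuple history the removed code expected, but using transformers' TextIteratorStreamer in place of the removed manual token-by-token loop. The hardcoded model name, the decoding parameters, and the absence of a config module are illustrative assumptions; the Space's actual models.py may differ.

# models.py (sketch; names and hyperparameters are assumptions)
import threading

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_NAME = "Kwaipilot/KAT-Dev"  # assumed; the real Space may read this from config.py

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)

@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list):
    # Rebuild the conversation from Gradio's (user, assistant) history pairs
    messages = []
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": prompt})

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    # TextIteratorStreamer emits decoded text while generate() runs in a background thread
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=1024, do_sample=True, temperature=0.7)
    threading.Thread(target=model.generate, kwargs=generation_kwargs).start()

    # gr.ChatInterface displays each yielded value, so yield the running text for live updates
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial

Because gr.ChatInterface re-renders the response with every yielded value, yielding the accumulated text rather than individual chunks gives the incremental display described in the removed explanation.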