KAT-Dev / app.py
I see the issues! The error is happening because the custom streamer isn't handling the input correctly, and we're not properly setting the attention mask. Let me fix the streaming implementation:
=== models.py ===
import spaces
import torch
from typing import Generator
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE

# Global variables to store the model and tokenizer
tokenizer = None
model = None


def initialize_model():
    """Initializes and loads the model and tokenizer once onto the GPU."""
    global tokenizer, model
    if model is None:
        try:
            print(f"Loading model {MODEL_NAME}...")
            # Use bfloat16 for efficiency on modern GPUs
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=dtype,
                device_map="auto",
            )
            model.eval()
            # Set the padding token if the tokenizer does not define one
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
    return tokenizer, model


# Load at import time so the first request does not pay the loading cost
try:
    initialize_model()
except Exception as e:
    print(f"Warning: Global model initialization failed: {e}")


@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
    """
    Generates a response from the KAT model with proper streaming.
    """
    global tokenizer, model

    # Fallback initialization in case the import-time load failed
    if model is None or tokenizer is None:
        initialize_model()

    # Convert Gradio (user, assistant) tuple history to the chat template format
    messages = []
    for human, bot in history:
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})

    # Add the current prompt
    messages.append({"role": "user", "content": prompt})

    # Apply the model's chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize with an explicit attention mask
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)

    accumulated_text = ""

    # Generate tokens one at a time; the loop bound enforces MAX_NEW_TOKENS.
    # Each step re-runs the full forward pass (no KV cache), which is simple
    # but slower than model.generate() with a streamer.
    for _ in range(MAX_NEW_TOKENS):
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )

        # Logits for the next token position
        next_token_logits = outputs.logits[:, -1, :]

        # Apply temperature scaling
        if TEMPERATURE > 0:
            next_token_logits = next_token_logits / TEMPERATURE

        # Sample from the softmax distribution, or take the argmax greedily
        if DO_SAMPLE:
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        # Stop at the end-of-sequence token
        if next_token.item() == tokenizer.eos_token_id:
            break

        # Decode only the new token and append it
        new_token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
        accumulated_text += new_token_text

        # Yield the text accumulated so far
        yield accumulated_text

        # Feed the new token back in for the next iteration
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

    # Final yield to ensure the complete, trimmed text is delivered
    yield accumulated_text.strip()
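For reference, config.py is imported above but not shown. A minimal sketch of the constants it is assumed to provide; the specific values, including the model id, are placeholders and not taken from the original:

=== config.py (assumed) ===
# Placeholder values for the constants models.py imports above; adjust to the
# actual Space configuration. None of these values are confirmed by the original.
MODEL_NAME = "Kwaipilot/KAT-Dev"  # assumed Hugging Face model id
MAX_NEW_TOKENS = 1024             # assumed cap on generated tokens
TEMPERATURE = 0.7                 # assumed sampling temperature
DO_SAMPLE = True                  # assumed sampling flag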
The key changes I made:
1. **Fixed attention mask**: Now properly sets `attention_mask` when tokenizing to avoid the warning
2. **Simplified streaming**: Uses a manual token-by-token generation loop instead of the complex custom streamer (a streamer-based alternative is sketched right after this list)
3. **Proper tensor handling**: Correctly handles token tensors and decoding
4. **EOS handling**: Properly stops generation when end-of-sequence token is encountered
5. **Memory efficiency**: Uses `torch.no_grad()` for inference to save memory
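As an aside, the stock transformers streaming helper can achieve the same effect while keeping the KV cache, because it lets `model.generate()` handle decoding. A minimal sketch of the generation portion of `stream_generate_response` rewritten around `TextIteratorStreamer`, assuming the same `tokenizer`, `model`, `input_ids`, `attention_mask`, and config constants as above:

from threading import Thread
from transformers import TextIteratorStreamer

# Streamer that yields decoded text fragments as generate() produces them
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=DO_SAMPLE,
    temperature=TEMPERATURE,
    pad_token_id=tokenizer.pad_token_id,
    streamer=streamer,
)

# Run generation in a background thread and iterate over the streamer here
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

accumulated_text = ""
for new_text in streamer:
    accumulated_text += new_text
    yield accumulated_text
thread.join()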
This implementation (or the streamer-based variant sketched above) should now properly stream tokens one by one and yield the accumulated text to the Gradio interface for real-time display.
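For completeness, a minimal sketch of how the generator might be wired into the Space's UI. `stream_generate_response` and its (user, assistant) tuple history format come from models.py above; the interface details here are assumptions, not the original app code:

=== app.py (sketch) ===
import gradio as gr

from models import stream_generate_response

# gr.ChatInterface calls fn(message, history) and re-renders the bot reply on
# every value the generator yields, producing the token-by-token effect.
# models.py unpacks history as (user, assistant) tuples, i.e. the tuple-style
# ChatInterface history format.
demo = gr.ChatInterface(
    fn=stream_generate_response,
    title="KAT-Dev",
)

if __name__ == "__main__":
    demo.launch()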