KAT-Dev / models.py
```python
import spaces
import torch
from threading import Thread
from typing import Generator

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE
# Module-level cache for the model and tokenizer.
# They are loaded once at startup so subsequent calls avoid reloading.
tokenizer = None
model = None


def initialize_model():
    """Load the model and tokenizer once and cache them in the module globals."""
    global tokenizer, model
    if model is None:
        try:
            print(f"Loading model {MODEL_NAME}...")
            # Use bfloat16 for efficiency on modern GPUs (e.g., H100, A100).
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=dtype,
                device_map="auto",  # Let accelerate handle device placement (GPU when available)
            )
            model.eval()
            # Set a padding token if the tokenizer does not define one (common for causal LMs).
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
    return tokenizer, model


# Attempt initialization at import time so the model is ready when the worker starts.
# This runs in the global scope; if it fails (e.g., the GPU is not yet attached),
# loading is retried on the first inference call.
try:
    initialize_model()
except Exception as e:
    print(f"Warning: global model initialization failed: {e}. "
          "It will be re-attempted during the first inference call.")


@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
    """
    Generate a response from the KAT model, streaming the output as it is produced.

    Args:
        prompt: The current user input.
        history: The accumulated chat history (list of [user_msg, bot_msg] pairs).

    Yields:
        str: The accumulated response text so far.
    """
    global tokenizer, model

    # Fallback initialization in case loading at import time failed.
    if model is None or tokenizer is None:
        initialize_model()

    # Convert Gradio history format to the model's chat template format.
    messages = []
    for human, bot in history:
        # Add past exchanges.
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    # Add the current prompt.
    messages.append({"role": "user", "content": prompt})

    # Apply the chat template to build the prompt string.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize and move inputs to the model's device.
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # TextIteratorStreamer exposes generated text through a blocking iterator, so
    # generation can run in a background thread while this generator yields chunks
    # synchronously to Gradio from within the @spaces.GPU context.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        min_new_tokens=1,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread; the streamer feeds decoded text back here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate the decoded chunks and yield the running response after each one.
    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response
    thread.join()

    # Final cleanup: yield the stripped response once generation has finished.
    if full_response:
        yield full_response.strip()
```
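
The constants imported from `config` are not shown on this page. The sketch below illustrates what `config.py` might look like; the model id and generation values are illustrative placeholders, not necessarily the Space's actual settings.

```python
# config.py (illustrative sketch; the Space's actual values may differ)
MODEL_NAME = "Kwaipilot/KAT-Dev"  # assumed model id; replace with the repo actually deployed
MAX_NEW_TOKENS = 1024             # illustrative generation budget
TEMPERATURE = 0.7                 # illustrative sampling temperature
DO_SAMPLE = True                  # enable sampling so TEMPERATURE takes effect
```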
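
Since `app.py` is not included on this page, the following is a minimal sketch of one way the generator could be wired into a Gradio chat UI. It assumes the pair-based (`type="tuples"`) history format, which matches the `[user_msg, bot_msg]` pairs that `stream_generate_response` iterates over.

```python
# app.py sketch (assumed wiring; the Space's actual app.py may differ)
import gradio as gr

from models import stream_generate_response


def chat_fn(message, history):
    # stream_generate_response yields the accumulated reply so far, which is
    # what gr.ChatInterface expects from a streaming generator function.
    yield from stream_generate_response(message, history)


demo = gr.ChatInterface(
    chat_fn,
    type="tuples",  # pair-based history, matching the format models.py unpacks
    title="KAT-Dev",
)

if __name__ == "__main__":
    demo.launch()
```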