# KAT-Dev / models.py
import spaces

import threading
from typing import Generator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE
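
# For reference, config.py is expected to supply constants along these lines
# (illustrative values only; the actual Space configuration may differ):
#
#   MODEL_NAME = "Kwaipilot/KAT-Dev"  # assumed checkpoint id
#   MAX_NEW_TOKENS = 1024
#   TEMPERATURE = 0.7
#   DO_SAMPLE = True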

# Module-level cache for the model and tokenizer.
# They are loaded once at worker startup so subsequent calls avoid reload overhead.
tokenizer = None
model = None

def initialize_model():
    """Initializes and loads the model and tokenizer once onto the GPU."""
    global tokenizer, model
    if model is None:
        try:
            print(f"Loading model {MODEL_NAME}...")
            # Use bfloat16 for efficiency on modern GPUs (e.g., H100, A100)
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=dtype,
                device_map="auto",  # Automatically handles device placement (GPU)
            )
            model.eval()
            # Set the padding token if not defined (common for causal LMs)
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
    return tokenizer, model


# Call initialization immediately so the model is ready when the worker starts up.
# Note: this runs in the global scope, relying on the worker environment to manage the GPU context.
try:
    initialize_model()
except Exception as e:
    print(f"Warning: global model initialization failed: {e}. It will be re-attempted during the first inference call.")


@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
    """
    Generates a response from the KAT model, streaming output token by token.

    Args:
        prompt: The current user input.
        history: The accumulated chat history (list of [user_msg, bot_msg] pairs).

    Yields:
        str: The accumulated response text generated so far.
    """
    global tokenizer, model

    # Fallback initialization in case the global loading at import time failed
    if model is None or tokenizer is None:
        initialize_model()

    # Convert the Gradio history format into the model's chat template format
    messages = []
    for human, bot in history:
        # Add past exchanges
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})

    # Add the current prompt
    messages.append({"role": "user", "content": prompt})
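
    # The templated prompt string built below depends on the model's chat template.
    # For a ChatML-style template (illustrative only, not necessarily KAT's exact
    # format) it comes out roughly as:
    #   <|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n
    # with the trailing assistant header added by add_generation_prompt=True.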

    # Apply the chat template to build the full prompt string
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize and move the inputs to the model's device
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Stream with transformers' TextIteratorStreamer: generate() blocks, so it runs
    # in a background thread while this generator consumes decoded text as it arrives.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,        # do not echo the prompt back to the UI
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    generation_thread.start()

    # Yield the accumulated text each time a new chunk arrives
    accumulated_text = ""
    for new_chunk in streamer:
        accumulated_text += new_chunk
        yield accumulated_text

    generation_thread.join()

    # Final yield to ensure the complete, trimmed text is sent
    if accumulated_text:
        yield accumulated_text.strip()
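

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the Space's real UI lives in app.py, but the
# generator can also be consumed directly as below. Each yield is the accumulated
# reply so far, which matches what a Gradio ChatInterface streaming callback
# expects for the [user_msg, bot_msg] tuple history format assumed above.
# Running this locally assumes enough memory for the checkpoint; @spaces.GPU is
# intended for ZeroGPU Spaces and should simply pass through elsewhere.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    reply = ""
    for partial in stream_generate_response("Briefly introduce yourself.", []):
        reply = partial  # keep only the latest accumulated text
    print(reply)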