```python
import spaces
import torch
from threading import Thread
from typing import Generator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE

# Global variables to store the model and tokenizer.
# These are loaded under the GPU context to minimize overhead on subsequent calls.
tokenizer = None
model = None


def initialize_model():
    """Initializes and loads the model and tokenizer once onto the GPU."""
    global tokenizer, model
    if model is None:
        try:
            print(f"Loading model {MODEL_NAME}...")
            # Use bfloat16 for efficiency on modern GPUs (e.g., H100, A100)
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=dtype,
                device_map="auto",  # Automatically handles device placement (GPU)
            )
            model.eval()

            # Set padding token if not defined (common for causal LMs)
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id

            print("Model loaded successfully.")
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
    return tokenizer, model


# Call initialization immediately so the model is ready when the worker starts up.
# Note: this runs in the global scope and relies on the worker environment
# managing the GPU context.
try:
    initialize_model()
except Exception as e:
    print(
        f"Warning: Global model initialization failed: {e}. "
        "It will be re-attempted during the first inference call."
    )


@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
    """
    Generates a response from the KAT model, streaming output token by token.

    Args:
        prompt: The current user input.
        history: The accumulated chat history (list of [user_msg, bot_msg] tuples).

    Yields:
        str: The accumulated response text generated so far.
    """
    global tokenizer, model

    # Fallback initialization in case global loading failed
    if model is None or tokenizer is None:
        initialize_model()

    # Convert Gradio history format to the model's chat template format
    messages = []
    for human, bot in history:
        # Add past exchanges
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})

    # Add the current prompt
    messages.append({"role": "user", "content": prompt})

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Prepare inputs and move to the model device
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # TextIteratorStreamer exposes generated text as a Python iterator, which is
    # what Gradio's generator interface needs (TextStreamer only prints to stdout).
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a background thread so we can yield text from this
    # generator while tokens are still being produced.
    generation_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
        pad_token_id=tokenizer.eos_token_id,
        min_new_tokens=1,
        repetition_penalty=1.1,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate new text as it arrives and yield the running response
    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response

    thread.join()

    # Final cleanup (the raw output sometimes carries stray whitespace)
    if full_response:
        yield full_response.strip()
```
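
For context, here is a minimal sketch of how this generator could be plugged into a Gradio UI. The module name `app` and the `gr.ChatInterface` settings are assumptions, not part of the code above; only the `(prompt, history)` signature and the tuple-style history come from the function itself. Depending on the Gradio version, you may also need `type="tuples"` so the history matches the `[user_msg, bot_msg]` format the function expects.

```python
# Minimal sketch of the UI wiring (assumed, not part of the original app code).
import gradio as gr

# Hypothetical module name for the file containing stream_generate_response.
from app import stream_generate_response

demo = gr.ChatInterface(
    fn=stream_generate_response,  # yields the accumulated reply, so the chat updates live
    title="KAT Chat",  # placeholder title
)

if __name__ == "__main__":
    demo.queue().launch()
```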