# tiny-math-llm/custom_run.py
import torch
import torch.nn.functional as F
import os
import sys
# --- Ensure the project root is on the path for imports ---
# This lets the `src` package (model.py, tokenizer.py, etc.) resolve
# even when the script is launched from another directory.
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
# --- Import all project components ---
from src.tokenizer import generate_v1_data, CharacterTokenizer
from src.model import TinyLLM, n_embed, n_head, n_layer, dropout # Also import hyperparams
# --- Configuration (CHECK THIS PATH!) ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Use the file name confirmed in your last successful training run
WEIGHTS_PATH = 'data/tinyllm_v1_weights1.pt'
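# A small convenience sketch (not part of the original script): allow the
# weights path to be overridden via an environment variable. The variable
# name TINYLLM_WEIGHTS is hypothetical, chosen here for illustration.
WEIGHTS_PATH = os.environ.get('TINYLLM_WEIGHTS', WEIGHTS_PATH)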
@torch.no_grad()
def generate(model, idx, max_new_tokens):
"""
Takes a sequence of indices (idx) and generates max_new_tokens new indices
using the model autoregressively.
"""
model.eval() # Set model to evaluation mode
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
        # Crop the context to the model's block size
        idx_cond = idx[:, -model.block_size:]
# Get predictions
logits, _ = model(idx_cond)
# Focus only on the last time step (the next token)
logits = logits[:, -1, :]
# Apply softmax to get probabilities
probs = F.softmax(logits, dim=-1)
# Sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1)
# Append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1)
return idx
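@torch.no_grad()
def generate_greedy(model, idx, max_new_tokens):
    """
    Illustrative greedy variant of generate() above (a sketch, not called
    anywhere below): picks the argmax token at each step instead of
    sampling, which makes the output deterministic. For a math task with
    exactly one correct completion, this is often the safer decoder.
    """
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -model.block_size:]
        logits, _ = model(idx_cond)
        # Take the most likely next token instead of sampling from the distribution
        idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx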
def setup_inference():
"""Sets up the model, tokenizer, and loads weights for inference."""
try:
# 1. Setup Data Pipeline to determine sequence lengths
raw_data = generate_v1_data()
tokenizer = CharacterTokenizer(raw_data)
max_len = max(len(s) for s in raw_data)
        # block_size is the maximum sequence length (T) the model can handle.
        # It must match the size used during training (14 for the V1 dataset),
        # which is exactly the longest sample, so use max_len directly.
        block_size = max_len
# 2. Initialize Model Architecture
model = TinyLLM(
vocab_size=tokenizer.vocab_size,
n_embed=n_embed,
n_head=n_head,
n_layer=n_layer,
block_size=block_size,
dropout=dropout
).to(DEVICE)
# 3. Load Trained Weights
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
print(f"\nSuccessfully loaded model weights from {WEIGHTS_PATH}")
return model, tokenizer, block_size
except FileNotFoundError:
print(f"Error: Weights file not found at {WEIGHTS_PATH}. Please run train.py first.")
return None, None, None
except RuntimeError as e:
print(f"Runtime Error during loading: {e}")
print("Please ensure your src/model.py hyperparameters match the saved weights.")
return None, None, None
def solve_problem(model, tokenizer, question_str, block_size):
"""Encodes a question, generates the answer, and prints the result."""
# 1. Encode the question string (e.g., "5 + 3")
context_tokens = tokenizer.encode(question_str)
    # Append a trailing space so the generated '=' is separated from the question
    context_tokens.append(tokenizer.encode(' ')[0])
# Convert list of token IDs to a PyTorch tensor (1, T)
idx = torch.tensor([context_tokens], dtype=torch.long, device=DEVICE)
    # 2. Generate the rest of the sequence (the "= ANS" part).
    # Fill the remaining slots in the block; the expected completion
    # (e.g., "= 9") fits comfortably within it.
    max_new_tokens = block_size - idx.shape[1]
if max_new_tokens <= 0:
print("Error: Input sequence is too long.")
return
# Generate tokens
generated_idx = generate(model, idx, max_new_tokens=max_new_tokens)
# 3. Decode the result and print
generated_sequence = tokenizer.decode(generated_idx[0].tolist())
print(f"Question: '{question_str}'")
print(f"Model Output: '{generated_sequence}'")
# --- Main Interactive User Loop ---
if __name__ == '__main__':
model, tokenizer, block_size = setup_inference()
if model is not None:
print("\n--- TinyLLM Math Chatbot Initialized ---")
print("Enter a single-digit math problem (e.g., 4 + 5, 8 / 2).")
print("Type 'exit' to quit.")
while True:
            # 1. Get user input (strip whitespace before the exit check)
            question_str = input("Input: ").strip()
            if question_str.lower() == 'exit':
                break
            # 2. Basic input validation
            parts = question_str.split()
# Simple check for format N op N and single digits
is_valid = (
len(parts) == 3 and
parts[0].isdigit() and len(parts[0]) == 1 and
parts[2].isdigit() and len(parts[2]) == 1 and
parts[1] in ['+', '-', '*', '/']
)
if not is_valid:
print("Error: Please enter a problem in the format 'N op N' with single-digit operands (e.g., 2 + 3).\n")
continue
# 3. Solve the problem using the trained model
solve_problem(model, tokenizer, question_str, block_size)
print("-" * 30)
print("\n--- Chatbot Shutting Down ---")