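"""Interactive inference script for TinyLLM.

Loads the trained weights and answers single-digit arithmetic questions
typed on stdin. Run train.py first to produce the weights file.
"""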
import torch
import torch.nn.functional as F
import os
import sys
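
# Put the project's src/ directory on sys.path so its modules can be imported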
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
from src.tokenizer import generate_v1_data, CharacterTokenizer
from src.model import TinyLLM, n_embed, n_head, n_layer, dropout
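
# Run on the GPU when available, otherwise fall back to the CPU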
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
WEIGHTS_PATH = 'data/tinyllm_v1_weights1.pt'


@torch.no_grad()
def generate(model, idx, max_new_tokens):
"""
Takes a sequence of indices (idx) and generates max_new_tokens new indices
using the model autoregressively.
"""
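    # Switch to eval mode (disables dropout)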
model.eval()

    block_size = model.block_size
    for _ in range(max_new_tokens):
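        # Crop the running context to the last block_size tokens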
idx_cond = idx[:, -block_size:]
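
        # Forward pass; only the logits are needed here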
logits, _ = model(idx_cond)
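
        # Keep only the logits for the final position in the sequence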
logits = logits[:, -1, :]
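
        # Convert the logits into a probability distribution over the vocabulary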
probs = F.softmax(logits, dim=-1)
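
        # Sample the next token id from that distribution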
idx_next = torch.multinomial(probs, num_samples=1)
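
        # Append the sampled token and continue autoregressively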
idx = torch.cat((idx, idx_next), dim=1)
return idx


def setup_inference():
"""Sets up the model, tokenizer, and loads weights for inference."""
try:
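        # Rebuild the dataset and tokenizer from scratch; the vocabulary must
        # match the one used at training time (this assumes generate_v1_data()
        # is deterministic)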
raw_data = generate_v1_data()
tokenizer = CharacterTokenizer(raw_data)
max_len = max(len(s) for s in raw_data)
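
        # The context window must match the block size used at training time,
        # otherwise load_state_dict() below will fail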
block_size = max_len
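
        # Instantiate the model with the hyperparameters imported from src.model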
        model = TinyLLM(
            vocab_size=tokenizer.vocab_size,
            n_embed=n_embed,
            n_head=n_head,
            n_layer=n_layer,
            block_size=block_size,
            dropout=dropout
        ).to(DEVICE)
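
        # map_location lets weights saved on a GPU load on a CPU-only machine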
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
print(f"\nSuccessfully loaded model weights from {WEIGHTS_PATH}")
return model, tokenizer, block_size
except FileNotFoundError:
print(f"Error: Weights file not found at {WEIGHTS_PATH}. Please run train.py first.")
return None, None, None
except RuntimeError as e:
print(f"Runtime Error during loading: {e}")
print("Please ensure your src/model.py hyperparameters match the saved weights.")
return None, None, None


def solve_problem(model, tokenizer, question_str, block_size):
"""Encodes a question, generates the answer, and prints the result."""
context_tokens = tokenizer.encode(question_str)
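
    # Append a trailing space as the question/answer separator; this assumes
    # the v1 training strings take the form "question answer"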
context_tokens.append(tokenizer.encode(' ')[0])
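
    # Batch of one sequence: tensor shape (1, T)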
idx = torch.tensor([context_tokens], dtype=torch.long, device=DEVICE)
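
    # Fill the remainder of the context window with generated tokens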
max_new_tokens = block_size - idx.shape[1]
if max_new_tokens <= 0:
print("Error: Input sequence is too long.")
return
generated_idx = generate(model, idx, max_new_tokens=max_new_tokens)
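
    # Decode the whole sequence: the prompt plus the generated answer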
generated_sequence = tokenizer.decode(generated_idx[0].tolist())
print(f"Question: '{question_str}'")
print(f"Model Output: '{generated_sequence}'")


if __name__ == '__main__':
model, tokenizer, block_size = setup_inference()
if model is not None:
print("\n--- TinyLLM Math Chatbot Initialized ---")
print("Enter a single-digit math problem (e.g., 4 + 5, 8 / 2).")
print("Type 'exit' to quit.")
while True:
            question_str = input("Input: ").strip()
if question_str.lower() == 'exit':
break
parts = question_str.split()
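
            # Accept only "N op N": two single-digit operands around +, -, * or /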
            is_valid = (
                len(parts) == 3 and
                parts[0].isdigit() and len(parts[0]) == 1 and
                parts[2].isdigit() and len(parts[2]) == 1 and
                parts[1] in ['+', '-', '*', '/']
            )
if not is_valid:
print("Error: Please enter a problem in the format 'N op N' with single-digit operands (e.g., 2 + 3).\n")
continue
solve_problem(model, tokenizer, question_str, block_size)
print("-" * 30)
print("\n--- Chatbot Shutting Down ---") |