import torch
import torch.nn.functional as F
import os
import sys

# --- Ensure src folder is in the path for imports ---
# This helps the script find model.py, tokenizer.py, etc.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))

# --- Import all project components ---
from src.tokenizer import generate_v1_data, CharacterTokenizer
from src.model import TinyLLM, n_embed, n_head, n_layer, dropout # Also import hyperparams

# --- Configuration (CHECK THIS PATH!) ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Use the file name confirmed in your last successful training run
WEIGHTS_PATH = 'data/tinyllm_v1_weights1.pt' 
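# Optional override: point at different weights without editing this file.
# This is an illustrative convenience; TINYLLM_WEIGHTS is a hypothetical
# environment variable name chosen for this sketch, not a project convention.
WEIGHTS_PATH = os.environ.get('TINYLLM_WEIGHTS', WEIGHTS_PATH)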


@torch.no_grad()
def generate(model, idx, max_new_tokens):
    """

    Takes a sequence of indices (idx) and generates max_new_tokens new indices 

    using the model autoregressively.

    """
    model.eval() # Set model to evaluation mode
    
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
        # Crop the context to the model's block size (max sequence length T)
        block_size = model.block_size
        idx_cond = idx[:, -block_size:]
        
        # Get predictions
        logits, _ = model(idx_cond)
        
        # Focus only on the last time step (the next token)
        logits = logits[:, -1, :] 
        
        # Apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1)
        
        # Sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1) 
        
        # Append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1) 
    
    return idx
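
# --- Optional: temperature / top-k sampling (illustrative sketch) ---
# generate() above samples directly from the softmax distribution. A common
# extension is to scale the logits by a temperature and (optionally) restrict
# sampling to the top-k candidates. This variant is a standard technique, not
# part of the original pipeline; the parameter defaults are assumptions.
@torch.no_grad()
def generate_tempered(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Like generate(), but with temperature scaling and optional top-k filtering."""
    model.eval()
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -model.block_size:]
        logits, _ = model(idx_cond)
        # Scale the last-step logits: <1.0 sharpens, >1.0 flattens the distribution
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # Mask out everything below the k-th largest logit in each row
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('inf')
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx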


def setup_inference():
    """Sets up the model, tokenizer, and loads weights for inference."""
    try:
        # 1. Setup Data Pipeline to determine sequence lengths
        raw_data = generate_v1_data()
        tokenizer = CharacterTokenizer(raw_data)
        max_len = max(len(s) for s in raw_data)
        
        # block_size is the maximum sequence length (T) the model can handle.
        # It must match the value used at training time; for the V1 dataset that
        # is the longest sample (14 characters), so derive it from the data.
        block_size = max_len
        
        # 2. Initialize Model Architecture
        model = TinyLLM(
            vocab_size=tokenizer.vocab_size,
            n_embed=n_embed,
            n_head=n_head,
            n_layer=n_layer,
            block_size=block_size,
            dropout=dropout
        ).to(DEVICE)
        
        # 3. Load Trained Weights
        model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
        print(f"\nSuccessfully loaded model weights from {WEIGHTS_PATH}")
        
        return model, tokenizer, block_size
    
    except FileNotFoundError:
        print(f"Error: Weights file not found at {WEIGHTS_PATH}. Please run train.py first.")
        return None, None, None
    except RuntimeError as e:
        print(f"Runtime Error during loading: {e}")
        print("Please ensure your src/model.py hyperparameters match the saved weights.")
        return None, None, None
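
# --- Optional sanity check (illustrative sketch) ---
# Uses only the encode()/decode() methods already relied on elsewhere in this
# script; the sample string is an arbitrary V1-style problem.
def check_tokenizer_roundtrip(tokenizer, sample='5 + 3 = 8'):
    """Verify that decode(encode(s)) reproduces s for in-vocabulary characters."""
    ids = tokenizer.encode(sample)
    decoded = tokenizer.decode(ids)
    assert decoded == sample, f"Tokenizer round-trip failed: {decoded!r}"
    print(f"Tokenizer round-trip OK for '{sample}' ({len(ids)} tokens)")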


def solve_problem(model, tokenizer, question_str, block_size):
    """Encodes a question, generates the answer, and prints the result."""
    
    # 1. Encode the question string (e.g., "5 + 3")
    context_tokens = tokenizer.encode(question_str)
    # Append a trailing space so the generated "= ANS" part is cleanly separated
    context_tokens.append(tokenizer.encode(' ')[0])
    
    # Convert list of token IDs to a PyTorch tensor (1, T)
    idx = torch.tensor([context_tokens], dtype=torch.long, device=DEVICE)
    
    # 2. Generate the rest of the sequence (the "= ANS" part) by filling
    # the remainder of the context window
    max_new_tokens = block_size - idx.shape[1]
    
    if max_new_tokens <= 0:
        print("Error: Input sequence is too long.")
        return

    # Generate tokens
    generated_idx = generate(model, idx, max_new_tokens=max_new_tokens)
    
    # 3. Decode the result and print
    generated_sequence = tokenizer.decode(generated_idx[0].tolist())
    
    print(f"Question: '{question_str}'")
    print(f"Model Output: '{generated_sequence}'")


# --- Main Interactive User Loop ---
if __name__ == '__main__':
    model, tokenizer, block_size = setup_inference()
    
    if model is not None:
        print("\n--- TinyLLM Math Chatbot Initialized ---")
        print("Enter a single-digit math problem (e.g., 4 + 5, 8 / 2).")
        print("Type 'exit' to quit.")
        
        while True:
            # 1. Get user input
            question_str = input("Input: ").strip()

            if question_str.lower() == 'exit':
                break

            # 2. Basic Input Validation
            parts = question_str.split()
            
            # Simple check for format N op N and single digits
            is_valid = (
                len(parts) == 3 and 
                parts[0].isdigit() and len(parts[0]) == 1 and
                parts[2].isdigit() and len(parts[2]) == 1 and
                parts[1] in ['+', '-', '*', '/']
            )

            if not is_valid:
                print("Error: Please enter a problem in the format 'N op N' with single-digit operands (e.g., 2 + 3).\n")
                continue

            # 3. Solve the problem using the trained model
            solve_problem(model, tokenizer, question_str, block_size)
            print("-" * 30)
        
        print("\n--- Chatbot Shutting Down ---")