# tiny-math-llm/train.py
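"""Training script for TinyLLM on the generated v1 math dataset.

Builds the character-level data pipeline, trains the model, and saves the
weights to data/tinyllm_v1_weights1.pt. Assumes it is run from the repository
root so that the `src` package is importable:

    python train.py
"""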

import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# --- Import all project components ---
from src.tokenizer import generate_v1_data, CharacterTokenizer
from src.dataset import MathDataset
from src.model import TinyLLM, n_embed, n_head, n_layer, dropout  # Also import hyperparams

# --- Hyperparameters for Training ---
BATCH_SIZE = 32
LEARNING_RATE = 1e-3  # Standard starting learning rate for AdamW
EPOCHS = 100  # Number of full passes over the dataset (adjust as needed)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
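
# Optional: fix the RNG seed for reproducible shuffling and weight
# initialization, e.g.
#   torch.manual_seed(0)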


def setup_data_pipeline(batch_size=BATCH_SIZE):
    """Sets up the data generation, tokenization, and PyTorch DataLoaders."""
    # 1. Generate data and initialize tokenizer
    raw_data = generate_v1_data()
    tokenizer = CharacterTokenizer(raw_data)
    max_len = max(len(s) for s in raw_data)

    # 2. Create Dataset and DataLoader
    train_dataset = MathDataset(raw_data, tokenizer, max_len)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )

    print(f"Total problems: {len(raw_data)}")
    print(f"Vocabulary Size: {tokenizer.vocab_size}")
    print(f"Max Sequence Length (T): {max_len}")
    print(f"Device: {DEVICE}")
    return train_dataloader, tokenizer.vocab_size, max_len
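
# Quick sanity check (optional sketch, not called anywhere in this script):
# assuming MathDataset pads every problem string to max_len and returns
# (input, target) index tensors, one batch from the pipeline should look like
#   dl, vocab, T = setup_data_pipeline(batch_size=4)
#   xb, yb = next(iter(dl))
#   print(xb.shape, yb.shape)   # expected: two tensors of shape (4, T)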


def train(dataloader, vocab_size, block_size):
    """Initializes the model and runs the full training loop."""
    # 1. Initialize Model, Optimizer, and move to Device
    model = TinyLLM(
        vocab_size=vocab_size,
        n_embed=n_embed,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size,
        dropout=dropout
    ).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    print(f"TinyLLM Parameters: {sum(p.numel() for p in model.parameters())/1e3:.1f}K")
    print(f"Starting training for {EPOCHS} epochs...")

    # 2. Training Loop
    for epoch in range(EPOCHS):
        model.train()  # Set model to training mode
        total_loss = 0
        for batch_idx, (X, Y) in enumerate(dataloader):
            # Move data to the selected device (CPU or CUDA)
            X, Y = X.to(DEVICE), Y.to(DEVICE)

            # Forward pass: calculate logits and loss
            logits, loss = model(X, targets=Y)
            total_loss += loss.item()

            # Backward pass: calculate gradients and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
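
            # Optional: for deeper configurations, clipping gradients between
            # backward() and step() can guard against loss spikes, e.g.
            #   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)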

            # Log progress every 100 batches (adjust frequency if needed)
            if batch_idx % 100 == 0 and batch_idx > 0:
                print(f" Epoch {epoch+1}/{EPOCHS} | Batch {batch_idx}/{len(dataloader)} | Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(dataloader)
        print(f"--- Epoch {epoch+1} Complete --- Average Loss: {avg_loss:.4f}")

        # If the loss is very low, the model has likely memorized the math.
        if avg_loss < 0.01:
            print("Loss is very low. Stopping training early.")
            break

    # 3. Save the trained model
    os.makedirs('data', exist_ok=True)  # create the output directory if it does not exist
    torch.save(model.state_dict(), 'data/tinyllm_v1_weights1.pt')
    print("\nTraining complete! Model weights saved to data/tinyllm_v1_weights1.pt")


if __name__ == '__main__':
    # 1. Setup the data
    dataloader, vocab_size, max_len = setup_data_pipeline()
    # 2. Start the training process
    train(dataloader, vocab_size, max_len)