import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from src.tokenizer import generate_v1_data, CharacterTokenizer
from src.dataset import MathDataset
from src.model import TinyLLM, n_embed, n_head, n_layer, dropout

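# Training configuration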
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
EPOCHS = 100
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def setup_data_pipeline(batch_size=BATCH_SIZE):
    """Sets up the data generation, tokenization, and PyTorch DataLoaders."""
    raw_data = generate_v1_data()
    tokenizer = CharacterTokenizer(raw_data)
    max_len = max(len(s) for s in raw_data)

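    # Wrap the problems in a Dataset and serve shuffled, fixed-size batches.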
    train_dataset = MathDataset(raw_data, tokenizer, max_len)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )

print(f"Total problems: {len(raw_data)}")
|
|
|
print(f"Vocabulary Size: {tokenizer.vocab_size}")
|
|
|
print(f"Max Sequence Length (T): {max_len}")
|
|
|
print(f"Device: {DEVICE}")
|
|
|
|
|
|
return train_dataloader, tokenizer.vocab_size, max_len


def train(dataloader, vocab_size, block_size):
    """Initializes the model and runs the full training loop."""
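    # Instantiate the model with the imported architecture hyperparameters and move it to DEVICE.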
    model = TinyLLM(
        vocab_size=vocab_size,
        n_embed=n_embed,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size,
        dropout=dropout
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

print(f"TinyLLM Parameters: {sum(p.numel() for p in model.parameters())/1e3:.1f}K")
|
|
|
print(f"Starting training for {EPOCHS} epochs...")
|
|
|
|
|
|
|
|
|
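    # One epoch is a full pass over the shuffled training set.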
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0

        for batch_idx, (X, Y) in enumerate(dataloader):
            X, Y = X.to(DEVICE), Y.to(DEVICE)

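            # Forward pass: the model returns logits and the loss for the given targets.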
            logits, loss = model(X, targets=Y)
            total_loss += loss.item()

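            # Backward pass and parameter update.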
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

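            # Log progress every 100 batches (skipping batch 0).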
            if batch_idx % 100 == 0 and batch_idx > 0:
                print(f" Epoch {epoch+1}/{EPOCHS} | Batch {batch_idx}/{len(dataloader)} | Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(dataloader)
        print(f"--- Epoch {epoch+1} Complete --- Average Loss: {avg_loss:.4f}")

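        # Stop early once the average epoch loss drops below 0.01.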
        if avg_loss < 0.01:
            print("Loss is very low. Stopping training early.")
            break

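    # Persist the trained weights to disk.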
    torch.save(model.state_dict(), 'data/tinyllm_v1_weights1.pt')
    print("\nTraining complete! Model weights saved to data/tinyllm_v1_weights1.pt")


if __name__ == '__main__':
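    # Build the data pipeline, then train the model on it.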
    dataloader, vocab_size, max_len = setup_data_pipeline()
    train(dataloader, vocab_size, max_len)