# tiny-math-llm/src/model.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

# --- Hyperparameters (defaults; adjust as needed) ---
# For a "Tiny" LLM, we keep the model very small. vocab_size and block_size
# are supplied when TinyLLM is constructed.
n_embed = 64   # C: Embedding dimension (size of the vector representing a character)
n_head = 4     # H: Number of attention heads
n_layer = 4    # Number of repeating Transformer blocks
dropout = 0.1  # Dropout rate
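# Sanity check: the embedding dimension must split evenly across the attention heads
# (with these defaults, head_size = n_embed // n_head = 16).
assert n_embed % n_head == 0, "n_embed must be divisible by n_head"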

# --- 1. Causal Self-Attention (The "Attention is All You Need" Component) ---
class CausalSelfAttention(nn.Module):
    """A multi-head masked self-attention module."""

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()
        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head
        # Combined projection for Q, K, and V (more efficient)
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=False)
        # Output projection
        self.c_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        # Causal Mask (tril = lower triangular matrix)
        # This mask prevents a token from attending to future tokens (autoregressive)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))
                             .view(1, 1, block_size, block_size))
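        # For block_size = 3, for example, this mask is
        #   [[1, 0, 0],
        #    [1, 1, 0],
        #    [1, 1, 1]]
        # so position t can only attend to positions <= t.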

    def forward(self, x):
        B, T, C = x.shape  # Batch size, Sequence length (Time), Embedding dimension (Channel)
        # 1. Compute Q, K, V and split (efficiently)
        # q, k, v are (B, T, C)
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)
        # 2. Reshape for Multi-Head Attention: (B, T, C) -> (B, H, T, Head_size)
        # We prepare the tensors so that each head processes a smaller chunk of the dimension C
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        # 3. Scaled Dot-Product Attention: (B, H, T, T)
        # wei = (q @ k.transpose(-2, -1)) / sqrt(Head_size)
        wei = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))
        # 4. Apply Causal Mask
        # Set attention scores to -inf for future tokens (where tril == 0)
        wei = wei.masked_fill(self.tril[:, :, :T, :T] == 0, float('-inf'))
        # 5. Softmax and Dropout
        wei = F.softmax(wei, dim=-1)
        wei = self.attn_dropout(wei)
        # 6. Compute Weighted Sum of Values: (B, H, T, Head_size)
        out = wei @ v
        # 7. Re-assemble heads: (B, H, T, Head_size) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        # 8. Final Linear Projection
        out = self.resid_dropout(self.c_proj(out))
        return out
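
# Note: on PyTorch >= 2.0 the explicit mask/softmax above could, in principle, be replaced
# by the fused F.scaled_dot_product_attention(q, k, v, dropout_p=dropout, is_causal=True);
# the manual version is kept here so every step of the computation stays visible.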

# --- 2. Feed Forward Network (FFN) ---
class FeedForward(nn.Module):
    """A two-layer MLP for processing attention output."""

    def __init__(self, n_embed, dropout):
        super().__init__()
        self.net = nn.Sequential(
            # Standard ratio is 4x the embedding size
            nn.Linear(n_embed, 4 * n_embed),
            nn.GELU(),  # Modern activation function (smoother than ReLU)
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
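
# Shape flow through the FFN: (B, T, C) -> (B, T, 4C) -> (B, T, C). It operates on each
# token position independently, complementing the cross-position mixing done by attention.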

# --- 3. Transformer Block (The Repeating Unit) ---
class TransformerBlock(nn.Module):
    """A standard Transformer decoder block with Attention and FFN."""

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()
        # LayerNorm applied BEFORE the sub-layer (Pre-Norm style)
        self.ln_1 = nn.LayerNorm(n_embed)
        self.attn = CausalSelfAttention(n_embed, n_head, block_size, dropout)
        self.ln_2 = nn.LayerNorm(n_embed)
        self.ffn = FeedForward(n_embed, dropout)

    def forward(self, x):
        # 1. Attention with Residual Connection and LayerNorm
        x = x + self.attn(self.ln_1(x))
        # 2. FFN with Residual Connection and LayerNorm
        x = x + self.ffn(self.ln_2(x))
        return x
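
# The Pre-Norm pattern above, x = x + sublayer(LayerNorm(x)), follows GPT-2 and generally
# trains more stably for small models than the original Post-Norm arrangement.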

# --- 4. The Final TinyLLM Model ---
class TinyLLM(nn.Module, PyTorchModelHubMixin):
    """The complete Decoder-Only Transformer model."""

    def __init__(self, vocab_size, n_embed, n_head, n_layer, block_size, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        # Positional embedding: a learned table with one vector per position (up to block_size)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        # Stack of Transformer Blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embed, n_head, block_size, dropout)
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embed)  # Final LayerNorm
        # Linear layer to map the embedding vector back to the vocabulary space
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        # idx is the input tensor X of shape (B, T)
        B, T = idx.shape
        # 1. Token and Positional Embeddings
        # Token embedding: (B, T, C)
        tok_emb = self.token_embedding_table(idx)
        # Position embedding: (T, C) -> expanded to (B, T, C)
        pos = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(pos)
        # 2. Combine (Add) Embeddings
        x = tok_emb + pos_emb  # (B, T, C)
        # 3. Pass through Transformer Blocks
        x = self.blocks(x)  # (B, T, C)
        # 4. Final LayerNorm and Linear Head
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Reshape for CrossEntropyLoss: (B*T, vocab_size) and (B*T)
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            # Compute the negative log-likelihood loss
            loss = F.cross_entropy(logits, targets)
        return logits, loss
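
# --- Minimal usage sketch ---
# A quick smoke test of the model. The vocabulary size (96) and context window
# (block_size = 32) below are illustrative assumptions, not values taken from the
# actual training setup.
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, block_size = 96, 32  # assumed toy values for the demo
    model = TinyLLM(vocab_size, n_embed, n_head, n_layer, block_size, dropout)

    # Fake batch of token ids plus random targets (in real training the targets
    # are the input sequence shifted one position to the left).
    idx = torch.randint(0, vocab_size, (8, block_size))
    targets = torch.randint(0, vocab_size, (8, block_size))

    # Training-style call: logits come back flattened to (B*T, vocab_size)
    logits, loss = model(idx, targets)
    print(logits.shape)  # torch.Size([256, 96])
    print(loss.item())   # roughly ln(96) ~ 4.56 for an untrained model

    # Inference-style call (no targets): logits keep the (B, T, vocab_size) shape
    logits, loss = model(idx)
    print(logits.shape, loss)  # torch.Size([8, 32, 96]) None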