import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

# Model hyperparameters.
# vocab_size and block_size are supplied when TinyLLM is constructed.
n_embed = 64    # embedding dimension
n_head = 4      # attention heads per layer (head size = n_embed // n_head = 16)
n_layer = 4     # number of Transformer blocks
dropout = 0.1   # dropout probability


class CausalSelfAttention(nn.Module):
    """A multi-head masked self-attention module."""

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()

        self.n_embed = n_embed
        self.n_head = n_head
        self.head_size = n_embed // n_head

        # Single projection producing queries, keys, and values for all heads at once.
        self.c_attn = nn.Linear(n_embed, 3 * n_embed, bias=False)
        # Output projection back into the residual stream.
        self.c_proj = nn.Linear(n_embed, n_embed, bias=False)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Lower-triangular mask so each position can only attend to itself
        # and earlier positions.
        self.register_buffer(
            'tril',
            torch.tril(torch.ones(block_size, block_size))
                 .view(1, 1, block_size, block_size)
        )

    def forward(self, x):
        B, T, C = x.shape  # batch, sequence length, embedding dimension

        # Project once, then split into query, key, and value tensors.
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embed, dim=2)

        # Reshape to (B, n_head, T, head_size) so attention runs per head.
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        # Scaled dot-product attention scores: (B, n_head, T, T).
        wei = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))

        # Mask out future positions before the softmax.
        wei = wei.masked_fill(self.tril[:, :, :T, :T] == 0, float('-inf'))

        wei = F.softmax(wei, dim=-1)
        wei = self.attn_dropout(wei)

        # Weighted sum of the values.
        out = wei @ v

        # Re-assemble the heads into (B, T, C).
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        out = self.resid_dropout(self.c_proj(out))
        return out

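# Illustrative sketch (not part of the model): a quick shape check for
# CausalSelfAttention. The batch size, sequence length, and block_size used
# here are arbitrary assumptions chosen only for the demonstration.
def _demo_attention_shapes():
    attn = CausalSelfAttention(n_embed=n_embed, n_head=n_head, block_size=8, dropout=dropout)
    x = torch.randn(2, 8, n_embed)        # (batch, time, channels)
    out = attn(x)
    assert out.shape == (2, 8, n_embed)   # attention preserves the input shape
    return out.shape
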
class FeedForward(nn.Module):
    """A two-layer MLP for processing attention output."""

    def __init__(self, n_embed, dropout):
        super().__init__()
        # Expand to 4 * n_embed, apply a GELU non-linearity, then project back.
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.GELU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

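# Illustrative sketch (not part of the model): FeedForward expands each token
# to 4 * n_embed internally and projects back, so input and output shapes match.
def _demo_feedforward_shapes():
    ffn = FeedForward(n_embed=n_embed, dropout=dropout)
    x = torch.randn(2, 8, n_embed)
    assert ffn(x).shape == x.shape
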
class TransformerBlock(nn.Module):
    """A standard Transformer decoder block with Attention and FFN."""

    def __init__(self, n_embed, n_head, block_size, dropout):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embed)
        self.attn = CausalSelfAttention(n_embed, n_head, block_size, dropout)
        self.ln_2 = nn.LayerNorm(n_embed)
        self.ffn = FeedForward(n_embed, dropout)

    def forward(self, x):
        # Pre-norm residual connections: LayerNorm is applied before each sub-layer.
        x = x + self.attn(self.ln_1(x))
        x = x + self.ffn(self.ln_2(x))
        return x

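# Illustrative sketch (not part of the model): a TransformerBlock is shape-preserving,
# which is what allows the blocks to be stacked with nn.Sequential below. The
# block_size and input sizes here are assumptions for the demonstration only.
def _demo_block_shapes():
    block = TransformerBlock(n_embed=n_embed, n_head=n_head, block_size=8, dropout=dropout)
    x = torch.randn(2, 8, n_embed)
    assert block(x).shape == x.shape
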
class TinyLLM(nn.Module, PyTorchModelHubMixin):
    """The complete Decoder-Only Transformer model."""

    def __init__(self, vocab_size, n_embed, n_head, n_layer, block_size, dropout):
        super().__init__()

        self.block_size = block_size

        # Token embeddings and learned positional embeddings.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)

        # Stack of Transformer blocks.
        self.blocks = nn.Sequential(*[
            TransformerBlock(n_embed, n_head, block_size, dropout)
            for _ in range(n_layer)
        ])

        self.ln_f = nn.LayerNorm(n_embed)

        # Project the final hidden states back to vocabulary logits.
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # Sum token and positional embeddings: (B, T, n_embed).
        tok_emb = self.token_embedding_table(idx)
        pos = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(pos)
        x = tok_emb + pos_emb

        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Flatten the batch and time dimensions for cross-entropy.
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
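
# Illustrative usage sketch (not part of the model definition). The vocab_size and
# block_size values below are assumptions chosen only for this example; the file
# itself does not define them.
if __name__ == "__main__":
    vocab_size, block_size = 65, 64  # assumed values for the demo
    model = TinyLLM(vocab_size, n_embed, n_head, n_layer, block_size, dropout)

    # Random token indices standing in for a real training batch.
    idx = torch.randint(0, vocab_size, (4, block_size))
    targets = torch.randint(0, vocab_size, (4, block_size))

    logits, loss = model(idx, targets)
    print(logits.shape)  # (4 * 64, 65): logits are flattened when targets are given
    print(loss.item())   # roughly math.log(vocab_size) at random initialization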