#!/usr/bin/env python3
"""
Detailed analysis of TagTransformer embeddings for different token types
"""
import torch

from transformer import TagTransformer, DEVICE


def analyze_token_embeddings():
"""Analyze how different token types are embedded"""
# Create a simple vocabulary for analysis
regular_chars = ['a', 'b', 'c', 'd', 'e']
special_tokens = ['<TAG1>', '<TAG2>', '<TAG3>']
# Create mappings
char2idx = {char: idx for idx, char in enumerate(regular_chars)}
special2idx = {token: idx + len(regular_chars) for idx, token in enumerate(special_tokens)}
token2idx = {**char2idx, **special2idx}
vocab_size = len(token2idx)
num_special = len(special_tokens)
print("=== Token Embedding Analysis ===")
print(f"Regular characters: {regular_chars}")
print(f"Special tokens: {special_tokens}")
print(f"Vocabulary size: {vocab_size}")
print(f"Number of special tokens: {num_special}")
print()
    # Create model
    model = TagTransformer(
        src_vocab_size=vocab_size,
        trg_vocab_size=vocab_size,
        embed_dim=32,  # small for analysis
        nb_heads=4,
        src_hid_size=64,
        src_nb_layers=1,
        trg_hid_size=64,
        trg_nb_layers=1,
        dropout_p=0.0,
        tie_trg_embed=True,
        label_smooth=0.0,
        nb_attr=num_special,
        src_c2i=token2idx,
        trg_c2i=token2idx,
        attr_c2i={},
    )
    model = model.to(DEVICE)
    model.eval()
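    # NOTE (assumption): the analysis below manually recombines the embedding
    # components the way the TagTransformer encoder is expected to, i.e.
    #     embed_scale * src_embed(x) + position_embed(x)
    #     + embed_scale * special_embeddings(char_mask),
    # where char_mask flags positions holding tags/special tokens. If the local
    # transformer.py combines them differently, the per-component norms printed
    # here remain valid, but the "combined" values may not match the model's
    # internal embeddings exactly.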
    # Test different token types
    test_tokens = regular_chars + special_tokens
    test_indices = [token2idx[token] for token in test_tokens]

    print("=== Individual Token Analysis ===")
    for token, idx in zip(test_tokens, test_indices):
        print(f"\nToken: {token:10s} (ID: {idx:2d})")

        # Create input tensor
        input_tensor = torch.tensor([[idx]], dtype=torch.long).to(DEVICE)  # [1, 1]

        with torch.no_grad():
            # Get different embedding components
            word_embed = model.src_embed(input_tensor)      # [1, 1, embed_dim]
            pos_embed = model.position_embed(input_tensor)  # [1, 1, embed_dim]

            # Check if this is a special token (special IDs come last)
            is_special = idx >= vocab_size - num_special
            char_mask = torch.tensor([[is_special]], dtype=torch.long).to(DEVICE)
            special_embed = model.special_embeddings(char_mask)  # [1, 1, embed_dim]

            # Scale embeddings
            word_embed_scaled = model.embed_scale * word_embed
            special_embed_scaled = model.embed_scale * special_embed

            # Combined embedding
            combined = word_embed_scaled + pos_embed + special_embed_scaled

            print(f"  Is special token: {is_special}")
            print(f"  Word embedding norm: {word_embed_scaled.norm().item():.4f}")
            print(f"  Positional embedding norm: {pos_embed.norm().item():.4f}")
            print(f"  Special embedding norm: {special_embed_scaled.norm().item():.4f}")
            print(f"  Combined embedding norm: {combined.norm().item():.4f}")

            # Show some embedding values
            print(f"  Word embedding sample: {word_embed_scaled[0, 0, :5].tolist()}")
            print(f"  Special embedding sample: {special_embed_scaled[0, 0, :5].tolist()}")
print("\n=== Sequence Analysis ===")
# Test with a sequence containing mixed tokens
sequence = "a <TAG1> b <TAG2> c"
sequence_tokens = sequence.split()
sequence_indices = [token2idx[token] for token in sequence_tokens]
print(f"Test sequence: {sequence}")
print(f"Token indices: {sequence_indices}")
# Create sequence tensor
seq_tensor = torch.tensor([sequence_indices], dtype=torch.long).to(DEVICE) # [batch_size=1, seq_len]
seq_tensor = seq_tensor.t() # Transpose to [seq_len, batch_size]
with torch.no_grad():
# Get embeddings for the sequence
word_embed = model.embed_scale * model.src_embed(seq_tensor)
pos_embed = model.position_embed(seq_tensor)
# Get special embeddings
char_mask = (seq_tensor >= (vocab_size - num_special)).long()
special_embed = model.embed_scale * model.special_embeddings(char_mask)
# Combined embeddings
combined = word_embed + pos_embed + special_embed
print(f"\nSequence embedding shapes:")
print(f" Word embeddings: {word_embed.shape}")
print(f" Positional embeddings: {pos_embed.shape}")
print(f" Special embeddings: {special_embed.shape}")
print(f" Combined embeddings: {combined.shape}")
print(f"\nToken-by-token analysis:")
for i, (token, idx) in enumerate(zip(sequence_tokens, sequence_indices)):
is_special = (idx >= vocab_size - num_special)
print(f" Position {i}: {token:10s} | Special: {is_special}")
print(f" Word emb norm: {word_embed[i, 0].norm().item():.4f}")
print(f" Pos emb norm: {pos_embed[i, 0].norm().item():.4f}")
print(f" Special emb norm: {special_embed[i, 0].norm().item():.4f}")
print(f" Combined norm: {combined[i, 0].norm().item():.4f}")
        # Test encoder output
        print("\nTesting encoder...")
        # Key-padding mask in the PyTorch src_key_padding_mask convention,
        # shape [batch_size, seq_len]; nothing is padded here, so all False.
        src_mask = torch.zeros(seq_tensor.size(1), seq_tensor.size(0), dtype=torch.bool).to(DEVICE)
        encoded = model.encode(seq_tensor, src_mask)
        print(f"  Encoder output shape: {encoded.shape}")  # [seq_len, batch_size, embed_dim]
        print(f"  Encoder output norms: {encoded.squeeze(1).norm(dim=1).tolist()}")
print("\n=== Summary ===")
print("✓ Regular characters get standard word embeddings + positional embeddings")
print("✓ Special tokens get standard word embeddings + positional embeddings + special embeddings")
print("✓ The special embeddings provide additional feature information for tags and special tokens")
print("✓ This allows the model to distinguish between regular characters and linguistic tags")
if __name__ == '__main__':
analyze_token_embeddings()