#!/usr/bin/env python3
"""
Detailed analysis of TagTransformer embeddings for different token types
"""
import torch
from transformer import TagTransformer, PAD_IDX, DEVICE
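# TagTransformer and the PAD_IDX / DEVICE constants come from this repo's transformer.py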
def analyze_token_embeddings():
"""Analyze how different token types are embedded"""
# Create a simple vocabulary for analysis
regular_chars = ['a', 'b', 'c', 'd', 'e']
special_tokens = ['<TAG1>', '<TAG2>', '<TAG3>']
# Create mappings
char2idx = {char: idx for idx, char in enumerate(regular_chars)}
special2idx = {token: idx + len(regular_chars) for idx, token in enumerate(special_tokens)}
token2idx = {**char2idx, **special2idx}
vocab_size = len(token2idx)
num_special = len(special_tokens)
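    # Special tokens are appended after the regular characters, so they occupy the
    # last `num_special` vocabulary indices; the `idx >= vocab_size - num_special`
    # checks below rely on this layout.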
print("=== Token Embedding Analysis ===")
print(f"Regular characters: {regular_chars}")
print(f"Special tokens: {special_tokens}")
print(f"Vocabulary size: {vocab_size}")
print(f"Number of special tokens: {num_special}")
print()
# Create model
model = TagTransformer(
src_vocab_size=vocab_size,
trg_vocab_size=vocab_size,
embed_dim=32, # Small for analysis
nb_heads=4,
src_hid_size=64,
src_nb_layers=1,
trg_hid_size=64,
trg_nb_layers=1,
dropout_p=0.0,
tie_trg_embed=True,
label_smooth=0.0,
nb_attr=num_special,
src_c2i=token2idx,
trg_c2i=token2idx,
attr_c2i={},
)
model = model.to(DEVICE)
model.eval()
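    # The model is freshly initialized (untrained), so the norms printed below
    # reflect random initialization rather than learned structure; eval() simply
    # disables dropout (already set to 0 here).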
# Test different token types
test_tokens = regular_chars + special_tokens
test_indices = [token2idx[token] for token in test_tokens]
print("=== Individual Token Analysis ===")
    for token, idx in zip(test_tokens, test_indices):
print(f"\nToken: {token:10s} (ID: {idx:2d})")
# Create input tensor
        input_tensor = torch.tensor([[idx]], dtype=torch.long).to(DEVICE)  # [seq_len=1, batch_size=1]
with torch.no_grad():
# Get different embedding components
word_embed = model.src_embed(input_tensor) # [1, 1, embed_dim]
pos_embed = model.position_embed(input_tensor) # [1, 1, embed_dim]
            # Check whether this is a special token (they occupy the last num_special indices)
            is_special = (idx >= vocab_size - num_special)
            # Index into special_embeddings: row 1 for special tokens, row 0 for regular characters
            special_mask = torch.tensor([[is_special]], dtype=torch.long).to(DEVICE)
            special_embed = model.special_embeddings(special_mask)  # [1, 1, embed_dim]
# Scale embeddings
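            # embed_scale is the model's embedding scale factor (sqrt(embed_dim) in the
            # usual Transformer convention); it is applied to the word and special
            # embeddings before they are summed with the positional embedding.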
word_embed_scaled = model.embed_scale * word_embed
special_embed_scaled = model.embed_scale * special_embed
# Combined embedding
combined = word_embed_scaled + pos_embed + special_embed_scaled
print(f" Is special token: {is_special}")
print(f" Word embedding norm: {word_embed_scaled.norm().item():.4f}")
print(f" Positional embedding norm: {pos_embed.norm().item():.4f}")
print(f" Special embedding norm: {special_embed_scaled.norm().item():.4f}")
print(f" Combined embedding norm: {combined.norm().item():.4f}")
# Show some embedding values
print(f" Word embedding sample: {word_embed_scaled[0, 0, :5].tolist()}")
print(f" Special embedding sample: {special_embed_scaled[0, 0, :5].tolist()}")
print("\n=== Sequence Analysis ===")
# Test with a sequence containing mixed tokens
sequence = "a <TAG1> b <TAG2> c"
sequence_tokens = sequence.split()
sequence_indices = [token2idx[token] for token in sequence_tokens]
print(f"Test sequence: {sequence}")
print(f"Token indices: {sequence_indices}")
# Create sequence tensor
seq_tensor = torch.tensor([sequence_indices], dtype=torch.long).to(DEVICE) # [batch_size=1, seq_len]
seq_tensor = seq_tensor.t() # Transpose to [seq_len, batch_size]
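    # The embedding and encoder calls below assume time-major input,
    # i.e. [seq_len, batch_size], hence the transpose above.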
with torch.no_grad():
# Get embeddings for the sequence
word_embed = model.embed_scale * model.src_embed(seq_tensor)
pos_embed = model.position_embed(seq_tensor)
        # Type-embedding indices: 1 for special tokens, 0 for regular characters
        special_mask = (seq_tensor >= (vocab_size - num_special)).long()
        special_embed = model.embed_scale * model.special_embeddings(special_mask)
# Combined embeddings
combined = word_embed + pos_embed + special_embed
print(f"\nSequence embedding shapes:")
print(f" Word embeddings: {word_embed.shape}")
print(f" Positional embeddings: {pos_embed.shape}")
print(f" Special embeddings: {special_embed.shape}")
print(f" Combined embeddings: {combined.shape}")
print(f"\nToken-by-token analysis:")
for i, (token, idx) in enumerate(zip(sequence_tokens, sequence_indices)):
is_special = (idx >= vocab_size - num_special)
print(f" Position {i}: {token:10s} | Special: {is_special}")
print(f" Word emb norm: {word_embed[i, 0].norm().item():.4f}")
print(f" Pos emb norm: {pos_embed[i, 0].norm().item():.4f}")
print(f" Special emb norm: {special_embed[i, 0].norm().item():.4f}")
print(f" Combined norm: {combined[i, 0].norm().item():.4f}")
# Test encoder output
print(f"\nTesting encoder...")
src_mask = torch.zeros(seq_tensor.size(0), seq_tensor.size(1), dtype=torch.bool).to(DEVICE)
encoded = model.encode(seq_tensor, src_mask)
print(f" Encoder output shape: {encoded.shape}")
print(f" Encoder output norms: {encoded.squeeze(1).norm(dim=1).tolist()}")
print("\n=== Summary ===")
print("✓ Regular characters get standard word embeddings + positional embeddings")
print("✓ Special tokens get standard word embeddings + positional embeddings + special embeddings")
print("✓ The special embeddings provide additional feature information for tags and special tokens")
print("✓ This allows the model to distinguish between regular characters and linguistic tags")
if __name__ == '__main__':
analyze_token_embeddings()