#!/usr/bin/env python3
"""
Test script to verify TagTransformer embeddings for special characters and tags
"""
import torch
import torch.nn as nn
from transformer import TagTransformer, PAD_IDX, DEVICE
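
# This script exercises the TagTransformer pieces used below: src_embed,
# position_embed, special_embeddings, embed_scale, nb_attr, count_nb_params()
# and encode(), all assumed to be provided by the project's transformer.py.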


def create_test_vocabulary():
    """Create a test vocabulary that includes the example tokens"""
    # Regular characters from your example
    chars = ['t', 'ɾ', 'a', 'd', 'ˈ', 'u', 's', 'e', 'n', 'k', 'm', 'o']

    # Special tokens and tags from your example
    special_tokens = [
        '<V;IND;PRS;3;PL>',
        '<V;SBJV;PRS;1;PL>',
        '<V;SBJV;PRS;3;PL>',
        '#'
    ]

    # Create character-to-index mappings
    char2idx = {char: idx for idx, char in enumerate(chars)}
    special2idx = {token: idx + len(chars) for idx, token in enumerate(special_tokens)}

    # Combine mappings
    all_tokens = chars + special_tokens
    token2idx = {**char2idx, **special2idx}

    return token2idx, all_tokens, len(chars), len(special_tokens)
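
# With the vocabulary above, characters occupy indices 0-11 and the special
# tokens indices 12-15, i.e. the top len(special_tokens) slots of the vocab;
# the special-embedding check in test_embeddings() relies on this layout.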


def tokenize_sequence(sequence, token2idx):
    """Tokenize the example sequence"""
    # Split by spaces and map to indices
    tokens = sequence.split()
    indices = [token2idx[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.long)
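
# Illustrative example: with the test vocabulary above,
#   tokenize_sequence("t a #", token2idx)  ->  tensor([0, 2, 15])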


def test_embeddings():
    """Test the TagTransformer embeddings"""
    # Create vocabulary
    token2idx, all_tokens, num_chars, num_special = create_test_vocabulary()
    vocab_size = len(all_tokens)

    print("Vocabulary:")
    for idx, token in enumerate(all_tokens):
        print(f" {idx:2d}: {token}")
    print()

    # Your example sequence
    example = "t ɾ a d ˈ u s e n <V;IND;PRS;3;PL> # t ɾ a d u s k ˈ a m o s <V;SBJV;PRS;1;PL> # <V;SBJV;PRS;3;PL>"
    print(f"Example sequence: {example}")
    print()

    # Tokenize
    tokenized = tokenize_sequence(example, token2idx)
    print(f"Tokenized sequence: {tokenized.tolist()}")
    print(f"Sequence length: {len(tokenized)}")
    print()
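
    # nb_attr is set to the number of special tokens; the embedding check below
    # assumes the model treats the top nb_attr vocabulary indices as tag tokens.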
    # Create model
    model = TagTransformer(
        src_vocab_size=vocab_size,
        trg_vocab_size=vocab_size,
        embed_dim=64,  # Small for testing
        nb_heads=4,
        src_hid_size=128,
        src_nb_layers=2,
        trg_hid_size=128,
        trg_nb_layers=2,
        dropout_p=0.0,  # No dropout for testing
        tie_trg_embed=True,
        label_smooth=0.0,
        nb_attr=num_special,
        src_c2i=token2idx,
        trg_c2i=token2idx,
        attr_c2i={},  # Not used in this test
    )
    model = model.to(DEVICE)
    model.eval()

    print(f"Model created with {model.count_nb_params():,} parameters")
    print(f"Number of special tokens/attributes: {num_special}")
    print()

    # Prepare input
    src_batch = tokenized.unsqueeze(1).to(DEVICE)  # [seq_len, batch_size=1]
    # All-False mask: the single test sequence contains no padding positions
    src_mask = torch.zeros(len(tokenized), 1, dtype=torch.bool).to(DEVICE)

    print("Input shape:", src_batch.shape)
    print()

    # Test embedding
    with torch.no_grad():
        # Get embeddings
        word_embed = model.embed_scale * model.src_embed(src_batch)
        pos_embed = model.position_embed(src_batch)

        # Get special embeddings
        char_mask = (src_batch >= (model.src_vocab_size - model.nb_attr)).long()
        special_embed = model.embed_scale * model.special_embeddings(char_mask)
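        # char_mask is 1 exactly at positions whose token index falls in the top
        # nb_attr slots of the vocabulary, i.e. the tag/special tokens appended
        # after the characters in create_test_vocabulary(); special_embeddings is
        # expected to map that 0/1 mask to a learned "character vs. tag" offset.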

        # Combine embeddings
        combined_embed = word_embed + pos_embed + special_embed

        print("Embedding analysis:")
        print(f" Word embeddings shape: {word_embed.shape}")
        print(f" Positional embeddings shape: {pos_embed.shape}")
        print(f" Special embeddings shape: {special_embed.shape}")
        print(f" Combined embeddings shape: {combined_embed.shape}")
        print()

        # Show which tokens get special embeddings
        print("Special embedding analysis:")
        for i, token_idx in enumerate(tokenized):
            token = all_tokens[token_idx]
            is_special = char_mask[i, 0].item()
            special_emb = special_embed[i, 0]
            print(f" Position {i:2d}: {token:20s} | Special: {is_special} | Special emb norm: {special_emb.norm().item():.4f}")
        print()

        # Test full forward pass
        print("Testing full forward pass...")
        output = model.encode(src_batch, src_mask)
        print(f" Encoder output shape: {output.shape}")
        print(f" Encoder output norm: {output.norm().item():.4f}")

        # Check if special tokens have different representations
        special_positions = char_mask.squeeze(1).nonzero(as_tuple=True)[0]
        char_positions = (1 - char_mask.squeeze(1)).nonzero(as_tuple=True)[0]
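        # Mean-pool the encoder states at tag positions and at character
        # positions so the two groups can be compared by a single L2 distance.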
        if len(special_positions) > 0 and len(char_positions) > 0:
            special_repr = output[special_positions].mean(dim=0)
            char_repr = output[char_positions].mean(dim=0)
            print(f" Average special token representation norm: {special_repr.norm().item():.4f}")
            print(f" Average character representation norm: {char_repr.norm().item():.4f}")
            print(f" Difference: {torch.dist(special_repr, char_repr).item():.4f}")

    print("✓ Embedding test completed successfully!")


if __name__ == '__main__':
    test_embeddings()