#!/usr/bin/env python3
"""
Test script to verify TagTransformer embeddings for special characters and tags
"""
import torch
import torch.nn as nn

from transformer import TagTransformer, PAD_IDX, DEVICE


def create_test_vocabulary():
    """Create a test vocabulary that includes the example tokens"""
    # Regular characters from the example
    chars = ['t', 'ɾ', 'a', 'd', 'ˈ', 'u', 's', 'e', 'n', 'k', 'm', 'o']
    # Special tokens and tags from the example
    special_tokens = [
        '<V;IND;PRS;3;PL>',
        '<V;SBJV;PRS;1;PL>',
        '<V;SBJV;PRS;3;PL>',
        '#',
    ]
    # Create character-to-index mappings
    char2idx = {char: idx for idx, char in enumerate(chars)}
    special2idx = {token: idx + len(chars) for idx, token in enumerate(special_tokens)}
    # Combine mappings
    all_tokens = chars + special_tokens
    token2idx = {**char2idx, **special2idx}
    return token2idx, all_tokens, len(chars), len(special_tokens)
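

# Note on the index layout: with the vocabulary above, the 12 characters occupy
# indices 0-11 and the 4 special tokens occupy indices 12-15, so every special
# token satisfies idx >= vocab_size - num_special. The char_mask computation in
# test_embeddings() below relies on this "tags at the end of the vocabulary"
# layout. The helper below is an optional, illustrative sanity check of that
# assumption; it is not part of the original test flow.
def check_tags_at_end(token2idx, all_tokens, num_special):
    """Assert that the special tokens occupy the last `num_special` vocabulary indices."""
    boundary = len(all_tokens) - num_special
    tail = [token for token, idx in sorted(token2idx.items(), key=lambda kv: kv[1]) if idx >= boundary]
    assert tail == all_tokens[boundary:], "special tokens must sit at the end of the vocabulary"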


def tokenize_sequence(sequence, token2idx):
    """Tokenize the example sequence"""
    # Split by spaces and map to indices
    tokens = sequence.split()
    indices = [token2idx[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.long)
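

# Illustrative usage (not executed here): with the mappings from
# create_test_vocabulary(), tokenize_sequence("t ɾ a d <V;IND;PRS;3;PL>", token2idx)
# returns tensor([0, 1, 2, 3, 12]).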


def test_embeddings():
    """Test the TagTransformer embeddings"""
    # Create vocabulary
    token2idx, all_tokens, num_chars, num_special = create_test_vocabulary()
    vocab_size = len(all_tokens)

    print("Vocabulary:")
    for idx, token in enumerate(all_tokens):
        print(f" {idx:2d}: {token}")
    print()

    # Example sequence
    example = "t ɾ a d ˈ u s e n <V;IND;PRS;3;PL> # t ɾ a d u s k ˈ a m o s <V;SBJV;PRS;1;PL> # <V;SBJV;PRS;3;PL>"
    print(f"Example sequence: {example}")
    print()

    # Tokenize
    tokenized = tokenize_sequence(example, token2idx)
    print(f"Tokenized sequence: {tokenized.tolist()}")
    print(f"Sequence length: {len(tokenized)}")
    print()

    # Create model
    model = TagTransformer(
        src_vocab_size=vocab_size,
        trg_vocab_size=vocab_size,
        embed_dim=64,  # Small for testing
        nb_heads=4,
        src_hid_size=128,
        src_nb_layers=2,
        trg_hid_size=128,
        trg_nb_layers=2,
        dropout_p=0.0,  # No dropout for testing
        tie_trg_embed=True,
        label_smooth=0.0,
        nb_attr=num_special,
        src_c2i=token2idx,
        trg_c2i=token2idx,
        attr_c2i={},  # Not used in this test
    )
    model = model.to(DEVICE)
    model.eval()

    print(f"Model created with {model.count_nb_params():,} parameters")
    print(f"Number of special tokens/attributes: {num_special}")
    print()

    # Prepare input
    src_batch = tokenized.unsqueeze(1).to(DEVICE)  # [seq_len, batch_size=1]
    src_mask = torch.zeros(len(tokenized), 1, dtype=torch.bool).to(DEVICE)
    print("Input shape:", src_batch.shape)
    print()

    # Test embedding
    with torch.no_grad():
        # Get embeddings
        word_embed = model.embed_scale * model.src_embed(src_batch)
        pos_embed = model.position_embed(src_batch)
        # Get special embeddings: char_mask is 1 at tag/'#' positions
        # (index >= vocab_size - nb_attr) and 0 at plain character positions
        char_mask = (src_batch >= (model.src_vocab_size - model.nb_attr)).long()
        special_embed = model.embed_scale * model.special_embeddings(char_mask)
        # Combine embeddings
        combined_embed = word_embed + pos_embed + special_embed

        print("Embedding analysis:")
        print(f" Word embeddings shape: {word_embed.shape}")
        print(f" Positional embeddings shape: {pos_embed.shape}")
        print(f" Special embeddings shape: {special_embed.shape}")
        print(f" Combined embeddings shape: {combined_embed.shape}")
        print()

        # Show which tokens get special embeddings
        print("Special embedding analysis:")
        for i, token_idx in enumerate(tokenized):
            token = all_tokens[token_idx]
            is_special = char_mask[i, 0].item()
            special_emb = special_embed[i, 0]
            print(f" Position {i:2d}: {token:20s} | Special: {is_special} | Special emb norm: {special_emb.norm().item():.4f}")
        print()

        # Test full forward pass
        print("Testing full forward pass...")
        output = model.encode(src_batch, src_mask)
        print(f" Encoder output shape: {output.shape}")
        print(f" Encoder output norm: {output.norm().item():.4f}")

        # Check if special tokens have different representations
        special_positions = char_mask.squeeze(1).nonzero(as_tuple=True)[0]
        char_positions = (1 - char_mask.squeeze(1)).nonzero(as_tuple=True)[0]
        if len(special_positions) > 0 and len(char_positions) > 0:
            special_repr = output[special_positions].mean(dim=0)
            char_repr = output[char_positions].mean(dim=0)
            print(f" Average special token representation norm: {special_repr.norm().item():.4f}")
            print(f" Average character representation norm: {char_repr.norm().item():.4f}")
            print(f" Difference: {torch.dist(special_repr, char_repr).item():.4f}")

    print("✓ Embedding test completed successfully!")


if __name__ == '__main__':
    test_embeddings()