#!/usr/bin/env python3
"""
Test script to verify TagTransformer embeddings for regular characters and special tag tokens
"""

import torch
from transformer import TagTransformer, DEVICE
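# Assumption: `transformer.py` defines TagTransformer with the attributes used
# below (src_embed, position_embed, special_embeddings, embed_scale,
# src_vocab_size, nb_attr, encode, count_nb_params) and a DEVICE constant;
# adjust the import if your module is laid out differently.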

def create_test_vocabulary():
    """Create a test vocabulary that includes the example tokens"""
    
    # Regular characters from the example sequence
    chars = ['t', 'ɾ', 'a', 'd', 'ˈ', 'u', 's', 'e', 'n', 'k', 'm', 'o']
    
    # Special tokens and tags from the example sequence
    special_tokens = [
        '<V;IND;PRS;3;PL>',
        '<V;SBJV;PRS;1;PL>', 
        '<V;SBJV;PRS;3;PL>',
        '#'
    ]
    
    # Create character-to-index mappings
    char2idx = {char: idx for idx, char in enumerate(chars)}
    special2idx = {token: idx + len(chars) for idx, token in enumerate(special_tokens)}
    
    # Combine mappings
    all_tokens = chars + special_tokens
    token2idx = {**char2idx, **special2idx}
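    # Resulting layout (derived from the lists above): indices 0..11 hold the
    # characters and indices 12..15 hold the three tag strings and '#'. The
    # `>=` check in test_embeddings() treats the top `nb_attr` indices as
    # special tokens, so the characters must come first.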
    
    return token2idx, all_tokens, len(chars), len(special_tokens)

def tokenize_sequence(sequence, token2idx):
    """Tokenize the example sequence"""
    # Split by spaces and map to indices
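    # e.g. with the vocabulary built above: "t a #" -> tensor([0, 2, 15])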
    tokens = sequence.split()
    indices = [token2idx[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.long)

def test_embeddings():
    """Test the TagTransformer embeddings"""
    
    # Create vocabulary
    token2idx, all_tokens, num_chars, num_special = create_test_vocabulary()
    vocab_size = len(all_tokens)
    
    print("Vocabulary:")
    for idx, token in enumerate(all_tokens):
        print(f"  {idx:2d}: {token}")
    print()
    
    # The example sequence
    example = "t ɾ a d ˈ u s e n <V;IND;PRS;3;PL> # t ɾ a d u s k ˈ a m o s <V;SBJV;PRS;1;PL> # <V;SBJV;PRS;3;PL>"
    
    print(f"Example sequence: {example}")
    print()
    
    # Tokenize
    tokenized = tokenize_sequence(example, token2idx)
    print(f"Tokenized sequence: {tokenized.tolist()}")
    print(f"Sequence length: {len(tokenized)}")
    print()
    
    # Create model
    model = TagTransformer(
        src_vocab_size=vocab_size,
        trg_vocab_size=vocab_size,
        embed_dim=64,  # Small for testing
        nb_heads=4,
        src_hid_size=128,
        src_nb_layers=2,
        trg_hid_size=128,
        trg_nb_layers=2,
        dropout_p=0.0,  # No dropout for testing
        tie_trg_embed=True,
        label_smooth=0.0,
        nb_attr=num_special,
        src_c2i=token2idx,
        trg_c2i=token2idx,
        attr_c2i={},  # Not used in this test
    )
    
    model = model.to(DEVICE)
    model.eval()
    
    print(f"Model created with {model.count_nb_params():,} parameters")
    print(f"Number of special tokens/attributes: {num_special}")
    print()
    
    # Prepare input
    src_batch = tokenized.unsqueeze(1).to(DEVICE)  # [seq_len, batch_size=1]
    src_mask = torch.zeros(len(tokenized), 1, dtype=torch.bool).to(DEVICE)
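    # All-False mask: one unpadded sequence, so (assuming True marks padding
    # positions) nothing in the source is masked out.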
    
    print("Input shape:", src_batch.shape)
    print()
    
    # Test embedding
    with torch.no_grad():
        # Get embeddings
        word_embed = model.embed_scale * model.src_embed(src_batch)
        pos_embed = model.position_embed(src_batch)
        
        # Build a per-position flag (1 = tag/special token, 0 = regular
        # character) and look up the corresponding special embedding
        special_mask = (src_batch >= (model.src_vocab_size - model.nb_attr)).long()
        special_embed = model.embed_scale * model.special_embeddings(special_mask)
        
        # Combine embeddings
        combined_embed = word_embed + pos_embed + special_embed
        
        print("Embedding analysis:")
        print(f"  Word embeddings shape: {word_embed.shape}")
        print(f"  Positional embeddings shape: {pos_embed.shape}")
        print(f"  Special embeddings shape: {special_embed.shape}")
        print(f"  Combined embeddings shape: {combined_embed.shape}")
        print()
        
        # Show which tokens get special embeddings
        print("Special embedding analysis:")
        for i, token_idx in enumerate(tokenized):
            token = all_tokens[token_idx]
            is_special = special_mask[i, 0].item()
            special_emb = special_embed[i, 0]
            
            print(f"  Position {i:2d}: {token:20s} | Special: {is_special} | Special emb norm: {special_emb.norm().item():.4f}")
        
        print()
        
        # Run the encoder over the full sequence
        print("Testing encoder forward pass...")
        output = model.encode(src_batch, src_mask)
        print(f"  Encoder output shape: {output.shape}")
        print(f"  Encoder output norm: {output.norm().item():.4f}")
        
        # Check if special tokens have different representations
        special_positions = special_mask.squeeze(1).nonzero(as_tuple=True)[0]
        char_positions = (1 - special_mask.squeeze(1)).nonzero(as_tuple=True)[0]
        
        if len(special_positions) > 0 and len(char_positions) > 0:
            special_repr = output[special_positions].mean(dim=0)
            char_repr = output[char_positions].mean(dim=0)
            
            print(f"  Average special token representation norm: {special_repr.norm().item():.4f}")
            print(f"  Average character representation norm: {char_repr.norm().item():.4f}")
            print(f"  Difference: {torch.dist(special_repr, char_repr).item():.4f}")
        
        print("✓ Embedding test completed successfully!")

if __name__ == '__main__':
    test_embeddings()