#!/usr/bin/env python3
"""
Detailed analysis of TagTransformer embeddings for different token types
"""

import torch
from transformer import TagTransformer, PAD_IDX, DEVICE
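
# Note: this analysis assumes the local `transformer` module exposes
# TagTransformer with the attributes exercised below (src_embed,
# position_embed, special_embeddings, embed_scale, encode), plus the
# DEVICE and PAD_IDX constants.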

def analyze_token_embeddings():
    """Analyze how different token types are embedded"""
    
    # Create a simple vocabulary for analysis
    regular_chars = ['a', 'b', 'c', 'd', 'e']
    special_tokens = ['<TAG1>', '<TAG2>', '<TAG3>']
    
    # Create mappings
    char2idx = {char: idx for idx, char in enumerate(regular_chars)}
    special2idx = {token: idx + len(regular_chars) for idx, token in enumerate(special_tokens)}
    token2idx = {**char2idx, **special2idx}
    
    vocab_size = len(token2idx)
    num_special = len(special_tokens)
    
    print("=== Token Embedding Analysis ===")
    print(f"Regular characters: {regular_chars}")
    print(f"Special tokens: {special_tokens}")
    print(f"Vocabulary size: {vocab_size}")
    print(f"Number of special tokens: {num_special}")
    print()
    
    # Create model
    model = TagTransformer(
        src_vocab_size=vocab_size,
        trg_vocab_size=vocab_size,
        embed_dim=32,  # Small for analysis
        nb_heads=4,
        src_hid_size=64,
        src_nb_layers=1,
        trg_hid_size=64,
        trg_nb_layers=1,
        dropout_p=0.0,
        tie_trg_embed=True,
        label_smooth=0.0,
        nb_attr=num_special,
        src_c2i=token2idx,
        trg_c2i=token2idx,
        attr_c2i={},
    )
    
    model = model.to(DEVICE)
    model.eval()
    
    # Test different token types
    test_tokens = regular_chars + special_tokens
    test_indices = [token2idx[token] for token in test_tokens]
    
    print("=== Individual Token Analysis ===")
    for token, idx in zip(test_tokens, test_indices):
        print(f"\nToken: {token:10s} (ID: {idx:2d})")
        
        # Create input tensor
        input_tensor = torch.tensor([[idx]], dtype=torch.long).to(DEVICE)  # [1, 1]
        
        with torch.no_grad():
            # Get different embedding components
            word_embed = model.src_embed(input_tensor)  # [1, 1, embed_dim]
            pos_embed = model.position_embed(input_tensor)  # [1, 1, embed_dim]
            
            # Check if this is a special token
            is_special = (idx >= vocab_size - num_special)
            char_mask = torch.tensor([[is_special]], dtype=torch.long).to(DEVICE)
            special_embed = model.special_embeddings(char_mask)  # [1, 1, embed_dim]
            
            # Scale embeddings
            word_embed_scaled = model.embed_scale * word_embed
            special_embed_scaled = model.embed_scale * special_embed
            
            # Combined embedding
            combined = word_embed_scaled + pos_embed + special_embed_scaled
            
            print(f"  Is special token: {is_special}")
            print(f"  Word embedding norm: {word_embed_scaled.norm().item():.4f}")
            print(f"  Positional embedding norm: {pos_embed.norm().item():.4f}")
            print(f"  Special embedding norm: {special_embed_scaled.norm().item():.4f}")
            print(f"  Combined embedding norm: {combined.norm().item():.4f}")
            
            # Show some embedding values
            print(f"  Word embedding sample: {word_embed_scaled[0, 0, :5].tolist()}")
            print(f"  Special embedding sample: {special_embed_scaled[0, 0, :5].tolist()}")
    
    print("\n=== Sequence Analysis ===")
    
    # Test with a sequence containing mixed tokens
    sequence = "a <TAG1> b <TAG2> c"
    sequence_tokens = sequence.split()
    sequence_indices = [token2idx[token] for token in sequence_tokens]
    
    print(f"Test sequence: {sequence}")
    print(f"Token indices: {sequence_indices}")
    
    # Create sequence tensor
    seq_tensor = torch.tensor([sequence_indices], dtype=torch.long).to(DEVICE)  # [batch_size=1, seq_len]
    seq_tensor = seq_tensor.t()  # Transpose to [seq_len, batch_size]
    
    with torch.no_grad():
        # Get embeddings for the sequence
        word_embed = model.embed_scale * model.src_embed(seq_tensor)
        pos_embed = model.position_embed(seq_tensor)
        
        # Get special embeddings
        char_mask = (seq_tensor >= (vocab_size - num_special)).long()
        special_embed = model.embed_scale * model.special_embeddings(char_mask)
        
        # Combined embeddings
        combined = word_embed + pos_embed + special_embed
        
        print(f"\nSequence embedding shapes:")
        print(f"  Word embeddings: {word_embed.shape}")
        print(f"  Positional embeddings: {pos_embed.shape}")
        print(f"  Special embeddings: {special_embed.shape}")
        print(f"  Combined embeddings: {combined.shape}")
        
        print(f"\nToken-by-token analysis:")
        for i, (token, idx) in enumerate(zip(sequence_tokens, sequence_indices)):
            is_special = (idx >= vocab_size - num_special)
            print(f"  Position {i}: {token:10s} | Special: {is_special}")
            print(f"    Word emb norm: {word_embed[i, 0].norm().item():.4f}")
            print(f"    Pos emb norm: {pos_embed[i, 0].norm().item():.4f}")
            print(f"    Special emb norm: {special_embed[i, 0].norm().item():.4f}")
            print(f"    Combined norm: {combined[i, 0].norm().item():.4f}")
        
        # Test encoder output
        print(f"\nTesting encoder...")
        # All-False mask, i.e. no positions are treated as padding. Note the
        # shape here is [seq_len, batch_size]; adjust it if model.encode
        # expects a [batch_size, seq_len] key-padding mask instead.
        src_mask = torch.zeros(seq_tensor.size(0), seq_tensor.size(1), dtype=torch.bool).to(DEVICE)
        encoded = model.encode(seq_tensor, src_mask)
        print(f"  Encoder output shape: {encoded.shape}")
        print(f"  Encoder output norms: {encoded.squeeze(1).norm(dim=1).tolist()}")
    
    print("\n=== Summary ===")
    print("✓ Regular characters get standard word embeddings + positional embeddings")
    print("✓ Special tokens get standard word embeddings + positional embeddings + special embeddings")
    print("✓ The special embeddings provide additional feature information for tags and special tokens")
    print("✓ This allows the model to distinguish between regular characters and linguistic tags")

if __name__ == '__main__':
    analyze_token_embeddings()