#!/usr/bin/env python3
"""
Detailed analysis of TagTransformer embeddings for different token types
"""
import torch
import torch.nn as nn
import numpy as np

from transformer import TagTransformer, PAD_IDX, DEVICE
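# What this script probes (an assumption that mirrors the computation below, not a
# statement about TagTransformer internals): each source-token embedding is assembled
# roughly as
#
#     embed(tok) = embed_scale * src_embed(tok)         # word/character embedding
#                + position_embed(tok)                  # positional embedding
#                + embed_scale * special_embeddings(m)  # m = 0 regular char, 1 tag
#
# so tag tokens carry an extra "type" signal on top of their word embedding.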


def analyze_token_embeddings():
    """Analyze how different token types are embedded"""
    # Create a simple vocabulary for analysis
    regular_chars = ['a', 'b', 'c', 'd', 'e']
    special_tokens = ['<TAG1>', '<TAG2>', '<TAG3>']

    # Create mappings
    char2idx = {char: idx for idx, char in enumerate(regular_chars)}
    special2idx = {token: idx + len(regular_chars) for idx, token in enumerate(special_tokens)}
    token2idx = {**char2idx, **special2idx}

    vocab_size = len(token2idx)
    num_special = len(special_tokens)

    print("=== Token Embedding Analysis ===")
    print(f"Regular characters: {regular_chars}")
    print(f"Special tokens: {special_tokens}")
    print(f"Vocabulary size: {vocab_size}")
    print(f"Number of special tokens: {num_special}")
    print()
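    # Resulting index layout: 'a'..'e' -> 0..4 and '<TAG1>'..'<TAG3>' -> 5..7, i.e. the
    # tag indices occupy the last num_special slots of the vocabulary.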
    # Create model
    model = TagTransformer(
        src_vocab_size=vocab_size,
        trg_vocab_size=vocab_size,
        embed_dim=32,  # Small for analysis
        nb_heads=4,
        src_hid_size=64,
        src_nb_layers=1,
        trg_hid_size=64,
        trg_nb_layers=1,
        dropout_p=0.0,
        tie_trg_embed=True,
        label_smooth=0.0,
        nb_attr=num_special,
        src_c2i=token2idx,
        trg_c2i=token2idx,
        attr_c2i={},
    )
    model = model.to(DEVICE)
    model.eval()
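    # Assumption: TagTransformer treats the last `nb_attr` vocabulary indices as
    # tag/attribute tokens, which is why the special tokens were appended to the
    # end of the vocabulary above.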

    # Test different token types
    test_tokens = regular_chars + special_tokens
    test_indices = [token2idx[token] for token in test_tokens]

    print("=== Individual Token Analysis ===")
    for token, idx in zip(test_tokens, test_indices):
        print(f"\nToken: {token:10s} (ID: {idx:2d})")

        # Create input tensor
        input_tensor = torch.tensor([[idx]], dtype=torch.long).to(DEVICE)  # [1, 1]

        with torch.no_grad():
            # Get different embedding components
            word_embed = model.src_embed(input_tensor)      # [1, 1, embed_dim]
            pos_embed = model.position_embed(input_tensor)  # [1, 1, embed_dim]

            # Check if this is a special token (tags occupy the last num_special indices)
            is_special = (idx >= vocab_size - num_special)
            # 0 = regular character, 1 = tag/special token; assumes the two-row
            # special_embeddings table is indexed this way
            special_mask = torch.tensor([[int(is_special)]], dtype=torch.long).to(DEVICE)
            special_embed = model.special_embeddings(special_mask)  # [1, 1, embed_dim]

            # Scale embeddings
            word_embed_scaled = model.embed_scale * word_embed
            special_embed_scaled = model.embed_scale * special_embed

            # Combined embedding
            combined = word_embed_scaled + pos_embed + special_embed_scaled

            print(f"  Is special token: {is_special}")
            print(f"  Word embedding norm: {word_embed_scaled.norm().item():.4f}")
            print(f"  Positional embedding norm: {pos_embed.norm().item():.4f}")
            print(f"  Special embedding norm: {special_embed_scaled.norm().item():.4f}")
            print(f"  Combined embedding norm: {combined.norm().item():.4f}")

            # Show some embedding values
            print(f"  Word embedding sample: {word_embed_scaled[0, 0, :5].tolist()}")
            print(f"  Special embedding sample: {special_embed_scaled[0, 0, :5].tolist()}")
| print("\n=== Sequence Analysis ===") | |
| # Test with a sequence containing mixed tokens | |
| sequence = "a <TAG1> b <TAG2> c" | |
| sequence_tokens = sequence.split() | |
| sequence_indices = [token2idx[token] for token in sequence_tokens] | |
| print(f"Test sequence: {sequence}") | |
| print(f"Token indices: {sequence_indices}") | |
| # Create sequence tensor | |
| seq_tensor = torch.tensor([sequence_indices], dtype=torch.long).to(DEVICE) # [batch_size=1, seq_len] | |
| seq_tensor = seq_tensor.t() # Transpose to [seq_len, batch_size] | |
| with torch.no_grad(): | |
| # Get embeddings for the sequence | |
| word_embed = model.embed_scale * model.src_embed(seq_tensor) | |
| pos_embed = model.position_embed(seq_tensor) | |
| # Get special embeddings | |
| char_mask = (seq_tensor >= (vocab_size - num_special)).long() | |
| special_embed = model.embed_scale * model.special_embeddings(char_mask) | |
| # Combined embeddings | |
| combined = word_embed + pos_embed + special_embed | |
| print(f"\nSequence embedding shapes:") | |
| print(f" Word embeddings: {word_embed.shape}") | |
| print(f" Positional embeddings: {pos_embed.shape}") | |
| print(f" Special embeddings: {special_embed.shape}") | |
| print(f" Combined embeddings: {combined.shape}") | |
| print(f"\nToken-by-token analysis:") | |
| for i, (token, idx) in enumerate(zip(sequence_tokens, sequence_indices)): | |
| is_special = (idx >= vocab_size - num_special) | |
| print(f" Position {i}: {token:10s} | Special: {is_special}") | |
| print(f" Word emb norm: {word_embed[i, 0].norm().item():.4f}") | |
| print(f" Pos emb norm: {pos_embed[i, 0].norm().item():.4f}") | |
| print(f" Special emb norm: {special_embed[i, 0].norm().item():.4f}") | |
| print(f" Combined norm: {combined[i, 0].norm().item():.4f}") | |

        # Test encoder output
        print(f"\nTesting encoder...")
        # Padding mask: True marks PAD positions (none here). Assumes model.encode
        # takes a key-padding mask of shape [batch_size, seq_len], as in the standard
        # PyTorch transformer encoder API.
        src_mask = (seq_tensor == PAD_IDX).t()
        encoded = model.encode(seq_tensor, src_mask)
        print(f"  Encoder output shape: {encoded.shape}")
        print(f"  Encoder output norms: {encoded.squeeze(1).norm(dim=1).tolist()}")
| print("\n=== Summary ===") | |
| print("✓ Regular characters get standard word embeddings + positional embeddings") | |
| print("✓ Special tokens get standard word embeddings + positional embeddings + special embeddings") | |
| print("✓ The special embeddings provide additional feature information for tags and special tokens") | |
| print("✓ This allows the model to distinguish between regular characters and linguistic tags") | |


if __name__ == '__main__':
    analyze_token_embeddings()