#!/usr/bin/env python3
"""
Test script to verify TagTransformer embeddings for special characters and tags.
"""

import torch

from transformer import TagTransformer, DEVICE


def create_test_vocabulary():
    """Create a test vocabulary that includes the example tokens."""
    # Regular characters from the example sequence
    chars = ['t', 'ɾ', 'a', 'd', 'ˈ', 'u', 's', 'e', 'n', 'k', 'm', 'o']

    # Special tokens and tags. The tag names below are placeholders, not a real
    # tagset; '#' is the separator that appears in the example sequence.
    special_tokens = ['<TAG1>', '<TAG2>', '<TAG3>', '#']

    # Create token-to-index mappings: characters first, then special tokens,
    # so the special tokens occupy the top of the vocabulary
    char2idx = {char: idx for idx, char in enumerate(chars)}
    special2idx = {token: idx + len(chars) for idx, token in enumerate(special_tokens)}

    # Combine mappings
    all_tokens = chars + special_tokens
    token2idx = {**char2idx, **special2idx}

    return token2idx, all_tokens, len(chars), len(special_tokens)


def tokenize_sequence(sequence, token2idx):
    """Tokenize the example sequence."""
    # Split by spaces and map each token to its vocabulary index
    tokens = sequence.split()
    indices = [token2idx[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.long)


def test_embeddings():
    """Test the TagTransformer embeddings."""
    # Create vocabulary
    token2idx, all_tokens, num_chars, num_special = create_test_vocabulary()
    vocab_size = len(all_tokens)

    print("Vocabulary:")
    for idx, token in enumerate(all_tokens):
        print(f"  {idx:2d}: {token}")
    print()

    # Example sequence
    example = "t ɾ a d ˈ u s e n # t ɾ a d u s k ˈ a m o s #"
    print(f"Example sequence: {example}")
    print()

    # Tokenize
    tokenized = tokenize_sequence(example, token2idx)
    print(f"Tokenized sequence: {tokenized.tolist()}")
    print(f"Sequence length: {len(tokenized)}")
    print()

    # Create model
    model = TagTransformer(
        src_vocab_size=vocab_size,
        trg_vocab_size=vocab_size,
        embed_dim=64,       # Small for testing
        nb_heads=4,
        src_hid_size=128,
        src_nb_layers=2,
        trg_hid_size=128,
        trg_nb_layers=2,
        dropout_p=0.0,      # No dropout for testing
        tie_trg_embed=True,
        label_smooth=0.0,
        nb_attr=num_special,
        src_c2i=token2idx,
        trg_c2i=token2idx,
        attr_c2i={},        # Not used in this test
    )
    model = model.to(DEVICE)
    model.eval()

    print(f"Model created with {model.count_nb_params():,} parameters")
    print(f"Number of special tokens/attributes: {num_special}")
    print()

    # Prepare input
    src_batch = tokenized.unsqueeze(1).to(DEVICE)  # [seq_len, batch_size=1]
    src_mask = torch.zeros(len(tokenized), 1, dtype=torch.bool).to(DEVICE)

    print("Input shape:", src_batch.shape)
    print()

    # Test embedding
    with torch.no_grad():
        # Get token and positional embeddings
        word_embed = model.embed_scale * model.src_embed(src_batch)
        pos_embed = model.position_embed(src_batch)

        # Get special embeddings: char_mask is 1 for indices in the special-token
        # region at the top of the vocabulary, 0 for regular characters
        char_mask = (src_batch >= (model.src_vocab_size - model.nb_attr)).long()
        special_embed = model.embed_scale * model.special_embeddings(char_mask)

        # Combine embeddings
        combined_embed = word_embed + pos_embed + special_embed

        print("Embedding analysis:")
        print(f"  Word embeddings shape: {word_embed.shape}")
        print(f"  Positional embeddings shape: {pos_embed.shape}")
        print(f"  Special embeddings shape: {special_embed.shape}")
        print(f"  Combined embeddings shape: {combined_embed.shape}")
        print()

        # Show which tokens get special embeddings
        print("Special embedding analysis:")
        for i, token_idx in enumerate(tokenized.tolist()):
            token = all_tokens[token_idx]
            is_special = char_mask[i, 0].item()
            special_emb = special_embed[i, 0]
            print(
                f"  Position {i:2d}: {token:20s} | Special: {is_special} | "
                f"Special emb norm: {special_emb.norm().item():.4f}"
            )
        print()

        # Test full forward pass
        print("Testing full forward pass...")
        output = model.encode(src_batch, src_mask)
        print(f"  Encoder output shape: {output.shape}")
        print(f"  Encoder output norm: {output.norm().item():.4f}")

        # Check whether special tokens end up with different representations
        special_positions = char_mask.squeeze(1).nonzero(as_tuple=True)[0]
        char_positions = (1 - char_mask.squeeze(1)).nonzero(as_tuple=True)[0]

        if len(special_positions) > 0 and len(char_positions) > 0:
            special_repr = output[special_positions].mean(dim=0)
            char_repr = output[char_positions].mean(dim=0)
            print(f"  Average special token representation norm: {special_repr.norm().item():.4f}")
            print(f"  Average character representation norm: {char_repr.norm().item():.4f}")
            print(f"  Difference: {torch.dist(special_repr, char_repr).item():.4f}")

    print("✓ Embedding test completed successfully!")


if __name__ == '__main__':
    test_embeddings()