#!/usr/bin/env python3
"""
Test script to verify TagTransformer embeddings for special characters and tags
"""
import torch
import torch.nn as nn

from transformer import TagTransformer, PAD_IDX, DEVICE


def create_test_vocabulary():
    """Create a test vocabulary that includes the example tokens"""
    # Regular characters from the example sequence
    chars = ['t', 'ɾ', 'a', 'd', 'ˈ', 'u', 's', 'e', 'n', 'k', 'm', 'o']
    # Special tokens and tags from the example sequence
    special_tokens = [
        '<V;IND;PRS;3;PL>',
        '<V;SBJV;PRS;1;PL>',
        '<V;SBJV;PRS;3;PL>',
        '#',
    ]
    # Create character-to-index mappings
    char2idx = {char: idx for idx, char in enumerate(chars)}
    special2idx = {token: idx + len(chars) for idx, token in enumerate(special_tokens)}
    # Combine mappings
    all_tokens = chars + special_tokens
    token2idx = {**char2idx, **special2idx}
    return token2idx, all_tokens, len(chars), len(special_tokens)
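
# With the lists above, create_test_vocabulary() lays out the vocabulary as:
#   indices 0-11  -> characters ('t' = 0, ..., 'o' = 11)
#   indices 12-15 -> tags and the '#' separator ('<V;IND;PRS;3;PL>' = 12, ..., '#' = 15)
# The tags and '#' therefore occupy the top len(special_tokens) indices, which is
# what the special-token mask further down relies on.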


def tokenize_sequence(sequence, token2idx):
    """Tokenize the example sequence"""
    # Split by spaces and map to indices
    tokens = sequence.split()
    indices = [token2idx[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.long)
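
# For reference (not part of the test), with the vocabulary built above the
# helper behaves as follows; the indices follow the layout sketched under
# create_test_vocabulary():
#
#     token2idx, all_tokens, _, _ = create_test_vocabulary()
#     tokenize_sequence("t a #", token2idx).tolist()  # -> [0, 2, 15]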


def test_embeddings():
    """Test the TagTransformer embeddings"""
    # Create vocabulary
    token2idx, all_tokens, num_chars, num_special = create_test_vocabulary()
    vocab_size = len(all_tokens)

    print("Vocabulary:")
    for idx, token in enumerate(all_tokens):
        print(f" {idx:2d}: {token}")
    print()

    # Example sequence
    example = "t ɾ a d ˈ u s e n <V;IND;PRS;3;PL> # t ɾ a d u s k ˈ a m o s <V;SBJV;PRS;1;PL> # <V;SBJV;PRS;3;PL>"
    print(f"Example sequence: {example}")
    print()

    # Tokenize
    tokenized = tokenize_sequence(example, token2idx)
    print(f"Tokenized sequence: {tokenized.tolist()}")
    print(f"Sequence length: {len(tokenized)}")
    print()

    # Create model
    model = TagTransformer(
        src_vocab_size=vocab_size,
        trg_vocab_size=vocab_size,
        embed_dim=64,        # Small for testing
        nb_heads=4,
        src_hid_size=128,
        src_nb_layers=2,
        trg_hid_size=128,
        trg_nb_layers=2,
        dropout_p=0.0,       # No dropout for testing
        tie_trg_embed=True,
        label_smooth=0.0,
        nb_attr=num_special,
        src_c2i=token2idx,
        trg_c2i=token2idx,
        attr_c2i={},         # Not used in this test
    )
    model = model.to(DEVICE)
    model.eval()

    print(f"Model created with {model.count_nb_params():,} parameters")
    print(f"Number of special tokens/attributes: {num_special}")
    print()

    # Prepare input
    src_batch = tokenized.unsqueeze(1).to(DEVICE)  # [seq_len, batch_size=1]
    src_mask = torch.zeros(len(tokenized), 1, dtype=torch.bool).to(DEVICE)
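    # The all-False mask is meant to mark no position as padding; this assumes the
    # usual key-padding convention (True = masked), which fits a batch holding a
    # single unpadded sequence.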
| print("Input shape:", src_batch.shape) | |
| print() | |

    # Test embedding
    with torch.no_grad():
        # Get embeddings
        word_embed = model.embed_scale * model.src_embed(src_batch)
        pos_embed = model.position_embed(src_batch)
        # Get special embeddings
        char_mask = (src_batch >= (model.src_vocab_size - model.nb_attr)).long()
        special_embed = model.embed_scale * model.special_embeddings(char_mask)
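        # char_mask is 1 for tokens in the top nb_attr vocabulary slots (the tags
        # and '#', given how create_test_vocabulary() orders tokens) and 0 for
        # ordinary characters. Feeding the 0/1 tensor to special_embeddings
        # presumably selects between two learned "token type" vectors; the model's
        # own encode() is assumed to do the equivalent internally.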
        # Combine embeddings
        combined_embed = word_embed + pos_embed + special_embed

        print("Embedding analysis:")
        print(f" Word embeddings shape: {word_embed.shape}")
        print(f" Positional embeddings shape: {pos_embed.shape}")
        print(f" Special embeddings shape: {special_embed.shape}")
        print(f" Combined embeddings shape: {combined_embed.shape}")
        print()

        # Show which tokens get special embeddings
        print("Special embedding analysis:")
        for i, token_idx in enumerate(tokenized):
            token = all_tokens[token_idx]
            is_special = char_mask[i, 0].item()
            special_emb = special_embed[i, 0]
            print(f" Position {i:2d}: {token:20s} | Special: {is_special} | Special emb norm: {special_emb.norm().item():.4f}")
        print()

        # Test full forward pass
        print("Testing full forward pass...")
        output = model.encode(src_batch, src_mask)
        print(f" Encoder output shape: {output.shape}")
        print(f" Encoder output norm: {output.norm().item():.4f}")

        # Check if special tokens have different representations
        special_positions = char_mask.squeeze(1).nonzero(as_tuple=True)[0]
        char_positions = (1 - char_mask.squeeze(1)).nonzero(as_tuple=True)[0]
        if len(special_positions) > 0 and len(char_positions) > 0:
            special_repr = output[special_positions].mean(dim=0)
            char_repr = output[char_positions].mean(dim=0)
            print(f" Average special token representation norm: {special_repr.norm().item():.4f}")
            print(f" Average character representation norm: {char_repr.norm().item():.4f}")
            print(f" Difference: {torch.dist(special_repr, char_repr).item():.4f}")

    print("✓ Embedding test completed successfully!")


if __name__ == '__main__':
    test_embeddings()