#!/usr/bin/env python3
"""
Test script to verify the training setup works correctly
"""
import torch
from torch.utils.data import DataLoader

from morphological_dataset import MorphologicalDataset, build_vocabulary, collate_fn
from transformer import TagTransformer, PAD_IDX, DEVICE


def test_training_setup():
    """Test the complete training setup"""
    print("=== Testing Training Setup ===")

    # Data file paths
    train_src = '10L_90NL/train/run1/train.10L_90NL_1_1.src'
    train_tgt = '10L_90NL/train/run1/train.10L_90NL_1_1.tgt'

    # Build vocabularies
    print("Building vocabularies...")
    src_vocab = build_vocabulary([train_src])
    tgt_vocab = build_vocabulary([train_tgt])
    print(f"Source vocabulary size: {len(src_vocab)}")
    print(f"Target vocabulary size: {len(tgt_vocab)}")

    # Count feature tokens
    feature_tokens = [token for token in src_vocab.keys()
                      if token.startswith('<') and token.endswith('>')]
    nb_attr = len(feature_tokens)
    print(f"Number of feature tokens: {nb_attr}")
    print(f"Feature examples: {feature_tokens[:5]}")

    # Create dataset
    print("\nCreating dataset...")
    dataset = MorphologicalDataset(train_src, train_tgt, src_vocab, tgt_vocab, max_length=50)
    print(f"Dataset size: {len(dataset)}")

    # Test data loading
    print("\nTesting data loading...")
    sample_src, sample_tgt = dataset[0]
    print(f"Sample source: {' '.join(sample_src)}")
    print(f"Sample target: {' '.join(sample_tgt)}")

    # Create dataloader
    print("\nCreating dataloader...")
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, src_vocab, tgt_vocab, 50),
        num_workers=0
    )
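
    # Note: DataLoader only passes the raw list of samples to collate_fn, so the
    # lambda above closes over src_vocab, tgt_vocab and the max length (50) and
    # forwards them to collate_fn, keeping max_length consistent with the
    # MorphologicalDataset created above.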

    # Test batch loading
    print("\nTesting batch loading...")
    for batch_idx, (src, src_mask, tgt, tgt_mask) in enumerate(dataloader):
        print(f"Batch {batch_idx}:")
        print(f" Source shape: {src.shape}")
        print(f" Source mask shape: {src_mask.shape}")
        print(f" Target shape: {tgt.shape}")
        print(f" Target mask shape: {tgt_mask.shape}")
        if batch_idx >= 1:  # Only test first 2 batches
            break
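
    # Note: the exact tensor layout is defined by collate_fn; the [:-1]/[1:] slicing
    # used for the loss further below suggests time-major (max_len, batch_size)
    # tensors, but that is an inference from this script, not something it checks.
    # PAD_IDX is imported above yet unused here, so padding is presumably handled
    # inside collate_fn and the model's loss.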

    # Create model
    print("\nCreating model...")
    config = {
        'embed_dim': 64,  # Small for testing
        'nb_heads': 4,
        'src_hid_size': 128,
        'src_nb_layers': 2,
        'trg_hid_size': 128,
        'trg_nb_layers': 2,
        'dropout_p': 0.0,
        'tie_trg_embed': True,
        'label_smooth': 0.0,
    }
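
    # Note: these values are deliberately tiny so the smoke test runs quickly; they
    # are not meant as training hyperparameters. If the model uses standard
    # multi-head attention, embed_dim must be divisible by nb_heads
    # (64 / 4 = 16 per head), which these values satisfy.
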
    model = TagTransformer(
        src_vocab_size=len(src_vocab),
        trg_vocab_size=len(tgt_vocab),
        embed_dim=config['embed_dim'],
        nb_heads=config['nb_heads'],
        src_hid_size=config['src_hid_size'],
        src_nb_layers=config['src_nb_layers'],
        trg_hid_size=config['trg_hid_size'],
        trg_nb_layers=config['trg_nb_layers'],
        dropout_p=config['dropout_p'],
        tie_trg_embed=config['tie_trg_embed'],
        label_smooth=config['label_smooth'],
        nb_attr=nb_attr,
        src_c2i=src_vocab,
        trg_c2i=tgt_vocab,
        attr_c2i={},
    )
    model = model.to(DEVICE)
    print(f"Model created with {model.count_nb_params():,} parameters")

    # Test forward pass
    print("\nTesting forward pass...")
    model.eval()
    with torch.no_grad():
        # Get a batch
        src, src_mask, tgt, tgt_mask = next(iter(dataloader))
        src, src_mask, tgt, tgt_mask = (
            src.to(DEVICE), src_mask.to(DEVICE),
            tgt.to(DEVICE), tgt_mask.to(DEVICE)
        )

        # Forward pass
        output = model(src, src_mask, tgt, tgt_mask)
        print(f" Output shape: {output.shape}")

        # Test loss computation
        loss = model.loss(output[:-1], tgt[1:])
        print(f" Loss: {loss.item():.4f}")
| print("\n✓ Training setup test completed successfully!") | |
| print("\nReady to start training with:") | |
| print(f"python scripts/train_morphological.py --output_dir ./models") | |
| if __name__ == '__main__': | |
| test_training_setup() | |