#!/usr/bin/env python3
"""
Test script to verify the training setup works correctly
"""
import torch
from torch.utils.data import DataLoader

from morphological_dataset import MorphologicalDataset, build_vocabulary, collate_fn
from transformer import TagTransformer, PAD_IDX, DEVICE


def test_training_setup():
    """Test the complete training setup"""
    print("=== Testing Training Setup ===")

    # Data file paths
    train_src = '10L_90NL/train/run1/train.10L_90NL_1_1.src'
    train_tgt = '10L_90NL/train/run1/train.10L_90NL_1_1.tgt'

    # Build vocabularies
    print("Building vocabularies...")
    src_vocab = build_vocabulary([train_src])
    tgt_vocab = build_vocabulary([train_tgt])
    print(f"Source vocabulary size: {len(src_vocab)}")
    print(f"Target vocabulary size: {len(tgt_vocab)}")

    # Count feature tokens (morphological feature tags look like '<TAG>')
    feature_tokens = [token for token in src_vocab.keys()
                      if token.startswith('<') and token.endswith('>')]
    nb_attr = len(feature_tokens)
    print(f"Number of feature tokens: {nb_attr}")
    print(f"Feature examples: {feature_tokens[:5]}")

    # Create dataset
    print("\nCreating dataset...")
    dataset = MorphologicalDataset(train_src, train_tgt, src_vocab, tgt_vocab, max_length=50)
    print(f"Dataset size: {len(dataset)}")

    # Test data loading
    print("\nTesting data loading...")
    sample_src, sample_tgt = dataset[0]
    print(f"Sample source: {' '.join(sample_src)}")
    print(f"Sample target: {' '.join(sample_tgt)}")

    # Create dataloader
    print("\nCreating dataloader...")
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, src_vocab, tgt_vocab, 50),
        num_workers=0,
    )

    # Test batch loading
    print("\nTesting batch loading...")
    for batch_idx, (src, src_mask, tgt, tgt_mask) in enumerate(dataloader):
        print(f"Batch {batch_idx}:")
        print(f"  Source shape: {src.shape}")
        print(f"  Source mask shape: {src_mask.shape}")
        print(f"  Target shape: {tgt.shape}")
        print(f"  Target mask shape: {tgt_mask.shape}")
        if batch_idx >= 1:  # Only test the first 2 batches
            break

    # Create model
    print("\nCreating model...")
    config = {
        'embed_dim': 64,  # Small for testing
        'nb_heads': 4,
        'src_hid_size': 128,
        'src_nb_layers': 2,
        'trg_hid_size': 128,
        'trg_nb_layers': 2,
        'dropout_p': 0.0,
        'tie_trg_embed': True,
        'label_smooth': 0.0,
    }

    model = TagTransformer(
        src_vocab_size=len(src_vocab),
        trg_vocab_size=len(tgt_vocab),
        embed_dim=config['embed_dim'],
        nb_heads=config['nb_heads'],
        src_hid_size=config['src_hid_size'],
        src_nb_layers=config['src_nb_layers'],
        trg_hid_size=config['trg_hid_size'],
        trg_nb_layers=config['trg_nb_layers'],
        dropout_p=config['dropout_p'],
        tie_trg_embed=config['tie_trg_embed'],
        label_smooth=config['label_smooth'],
        nb_attr=nb_attr,
        src_c2i=src_vocab,
        trg_c2i=tgt_vocab,
        attr_c2i={},
    )
    model = model.to(DEVICE)
    print(f"Model created with {model.count_nb_params():,} parameters")

    # Test forward pass
    print("\nTesting forward pass...")
    model.eval()
    with torch.no_grad():
        # Get a batch
        src, src_mask, tgt, tgt_mask = next(iter(dataloader))
        src, src_mask, tgt, tgt_mask = (
            src.to(DEVICE), src_mask.to(DEVICE),
            tgt.to(DEVICE), tgt_mask.to(DEVICE),
        )

        # Forward pass
        output = model(src, src_mask, tgt, tgt_mask)
        print(f"  Output shape: {output.shape}")

        # Test loss computation: predictions at step t are scored against
        # the target token at step t + 1 (teacher-forcing shift)
        loss = model.loss(output[:-1], tgt[1:])
        print(f"  Loss: {loss.item():.4f}")

    print("\n✓ Training setup test completed successfully!")
    print("\nReady to start training with:")
    print("python scripts/train_morphological.py --output_dir ./models")


if __name__ == '__main__':
    test_training_setup()
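

# ---------------------------------------------------------------------------
# Optional extra check -- a hedged sketch, not part of the original test run
# above. It confirms that gradients flow end to end by taking one optimizer
# step. It assumes the same batch layout and model.loss convention exercised
# in test_training_setup(); the function name, the Adam optimizer, and the
# learning rate below are illustrative choices, not fixed by this repo.
# Invoke manually, e.g. check_backward_pass(model, dataloader).
# ---------------------------------------------------------------------------
def check_backward_pass(model, dataloader):
    """Run one forward/backward/optimizer step on a single batch."""
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Fetch one batch and move it to the model's device
    src, src_mask, tgt, tgt_mask = next(iter(dataloader))
    src, src_mask, tgt, tgt_mask = (
        src.to(DEVICE), src_mask.to(DEVICE),
        tgt.to(DEVICE), tgt_mask.to(DEVICE),
    )

    optimizer.zero_grad()
    output = model(src, src_mask, tgt, tgt_mask)
    # Same shifted alignment as the forward-pass test above
    loss = model.loss(output[:-1], tgt[1:])
    loss.backward()
    optimizer.step()
    print(f"Backward pass OK; loss on this batch: {loss.item():.4f}")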