#!/usr/bin/env python3
"""
Test script to verify the training setup works correctly.
"""
import torch
from torch.utils.data import DataLoader

from morphological_dataset import MorphologicalDataset, build_vocabulary, collate_fn
from transformer import TagTransformer, PAD_IDX, DEVICE
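# NOTE: morphological_dataset and transformer are local modules; this script
# assumes it is run from the directory that contains them.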


def test_training_setup():
    """Test the complete training setup"""
    print("=== Testing Training Setup ===")

    # Data file paths
    train_src = '10L_90NL/train/run1/train.10L_90NL_1_1.src'
    train_tgt = '10L_90NL/train/run1/train.10L_90NL_1_1.tgt'

    # Build vocabularies
    print("Building vocabularies...")
    src_vocab = build_vocabulary([train_src])
    tgt_vocab = build_vocabulary([train_tgt])
    print(f"Source vocabulary size: {len(src_vocab)}")
    print(f"Target vocabulary size: {len(tgt_vocab)}")
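    # build_vocabulary is assumed to return a token-to-index dict: len() gives
    # the vocabulary size here, and the same dicts are passed as src_c2i/trg_c2i
    # to TagTransformer below.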

    # Count feature tokens
    feature_tokens = [token for token in src_vocab.keys()
                      if token.startswith('<') and token.endswith('>')]
    nb_attr = len(feature_tokens)
    print(f"Number of feature tokens: {nb_attr}")
    print(f"Feature examples: {feature_tokens[:5]}")
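    # nb_attr is forwarded to TagTransformer below; presumably it tells the
    # model how many morphological-tag tokens to treat as attributes rather
    # than ordinary characters.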

    # Create dataset
    print("\nCreating dataset...")
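    # max_length=50 is assumed to bound the token length of each example;
    # how over-long pairs are handled (truncated or skipped) is up to
    # MorphologicalDataset.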
    dataset = MorphologicalDataset(train_src, train_tgt, src_vocab, tgt_vocab, max_length=50)
    print(f"Dataset size: {len(dataset)}")

    # Test data loading
    print("\nTesting data loading...")
    sample_src, sample_tgt = dataset[0]
    print(f"Sample source: {' '.join(sample_src)}")
    print(f"Sample target: {' '.join(sample_tgt)}")

    # Create dataloader
    print("\nCreating dataloader...")
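    # The lambda binds the vocabularies and max length to DataLoader's collate
    # hook, since collate_fn here takes (batch, src_vocab, tgt_vocab, max_len)
    # rather than just the batch.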
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, src_vocab, tgt_vocab, 50),
        num_workers=0,
    )

    # Test batch loading
    print("\nTesting batch loading...")
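    # The exact tensor layout depends on collate_fn; based on the loss slicing
    # further down, the tensors are assumed to be time-major (seq_len, batch).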
    for batch_idx, (src, src_mask, tgt, tgt_mask) in enumerate(dataloader):
        print(f"Batch {batch_idx}:")
        print(f"  Source shape: {src.shape}")
        print(f"  Source mask shape: {src_mask.shape}")
        print(f"  Target shape: {tgt.shape}")
        print(f"  Target mask shape: {tgt_mask.shape}")
        if batch_idx >= 1:  # Only test the first 2 batches
            break

    # Create model
    print("\nCreating model...")
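    # Deliberately tiny hyperparameters so the test runs fast; embed_dim must
    # be divisible by nb_heads for multi-head attention (64 / 4 = 16 here).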
    config = {
        'embed_dim': 64,  # Small for testing
        'nb_heads': 4,
        'src_hid_size': 128,
        'src_nb_layers': 2,
        'trg_hid_size': 128,
        'trg_nb_layers': 2,
        'dropout_p': 0.0,
        'tie_trg_embed': True,
        'label_smooth': 0.0,
    }
    model = TagTransformer(
        src_vocab_size=len(src_vocab),
        trg_vocab_size=len(tgt_vocab),
        embed_dim=config['embed_dim'],
        nb_heads=config['nb_heads'],
        src_hid_size=config['src_hid_size'],
        src_nb_layers=config['src_nb_layers'],
        trg_hid_size=config['trg_hid_size'],
        trg_nb_layers=config['trg_nb_layers'],
        dropout_p=config['dropout_p'],
        tie_trg_embed=config['tie_trg_embed'],
        label_smooth=config['label_smooth'],
        nb_attr=nb_attr,
        src_c2i=src_vocab,
        trg_c2i=tgt_vocab,
        attr_c2i={},
    )
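    # DEVICE is imported from the transformer module; it is assumed to resolve
    # to CUDA when available, otherwise CPU.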
    model = model.to(DEVICE)
    print(f"Model created with {model.count_nb_params():,} parameters")

    # Test forward pass
    print("\nTesting forward pass...")
    model.eval()
    with torch.no_grad():
        # Get a batch
        src, src_mask, tgt, tgt_mask = next(iter(dataloader))
        src, src_mask, tgt, tgt_mask = (
            src.to(DEVICE), src_mask.to(DEVICE),
            tgt.to(DEVICE), tgt_mask.to(DEVICE),
        )

        # Forward pass
        output = model(src, src_mask, tgt, tgt_mask)
        print(f"  Output shape: {output.shape}")

        # Test loss computation
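        # Shift for teacher forcing: the prediction at position t is scored
        # against the gold token at t + 1; the [:-1] / [1:] slicing assumes
        # time-major tensors (seq_len first).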
        loss = model.loss(output[:-1], tgt[1:])
        print(f"  Loss: {loss.item():.4f}")

    print("\n✓ Training setup test completed successfully!")
    print("\nReady to start training with:")
    print("python scripts/train_morphological.py --output_dir ./models")


if __name__ == '__main__':
    test_training_setup()