#!/usr/bin/env python3
"""
Test script to verify the training setup works correctly
"""
import torch
from torch.utils.data import DataLoader
from morphological_dataset import MorphologicalDataset, build_vocabulary, collate_fn
from transformer import TagTransformer, DEVICE


def test_training_setup():
"""Test the complete training setup"""
print("=== Testing Training Setup ===")
# Data file paths
train_src = '10L_90NL/train/run1/train.10L_90NL_1_1.src'
train_tgt = '10L_90NL/train/run1/train.10L_90NL_1_1.tgt'
# Build vocabularies
print("Building vocabularies...")
src_vocab = build_vocabulary([train_src])
tgt_vocab = build_vocabulary([train_tgt])
print(f"Source vocabulary size: {len(src_vocab)}")
print(f"Target vocabulary size: {len(tgt_vocab)}")

    # Count feature tokens: anything of the form <...> in the source
    # vocabulary (note that this also matches any special symbols that
    # happen to use angle brackets)
    feature_tokens = [token for token in src_vocab.keys()
                      if token.startswith('<') and token.endswith('>')]
    nb_attr = len(feature_tokens)
    print(f"Number of feature tokens: {nb_attr}")
    print(f"Feature examples: {feature_tokens[:5]}")

    # Create dataset
    print("\nCreating dataset...")
    dataset = MorphologicalDataset(train_src, train_tgt, src_vocab, tgt_vocab, max_length=50)
    print(f"Dataset size: {len(dataset)}")

    # Test data loading
    print("\nTesting data loading...")
    sample_src, sample_tgt = dataset[0]
    print(f"Sample source: {' '.join(sample_src)}")
    print(f"Sample target: {' '.join(sample_tgt)}")

    # Create dataloader
    print("\nCreating dataloader...")
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: collate_fn(batch, src_vocab, tgt_vocab, 50),
        num_workers=0,
    )
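    # collate_fn is wrapped in a lambda because DataLoader passes only the
    # batch to its collate_fn; the vocabularies and the max length of 50
    # are bound through the closure.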

    # Test batch loading
    print("\nTesting batch loading...")
    for batch_idx, (src, src_mask, tgt, tgt_mask) in enumerate(dataloader):
        print(f"Batch {batch_idx}:")
        print(f" Source shape: {src.shape}")
        print(f" Source mask shape: {src_mask.shape}")
        print(f" Target shape: {tgt.shape}")
        print(f" Target mask shape: {tgt_mask.shape}")
        if batch_idx >= 1:  # Only test the first 2 batches
            break
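
    # The shapes printed above are expected to be sequence-first, i.e.
    # (seq_len, batch_size): the loss computation at the end of this
    # script slices on dim 0 to align predictions with gold next tokens,
    # which only works under that convention (an assumption about
    # collate_fn, not something this script enforces).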

    # Create model
    print("\nCreating model...")
    config = {
        'embed_dim': 64,  # Small for testing
        'nb_heads': 4,
        'src_hid_size': 128,
        'src_nb_layers': 2,
        'trg_hid_size': 128,
        'trg_nb_layers': 2,
        'dropout_p': 0.0,
        'tie_trg_embed': True,
        'label_smooth': 0.0,
    }
    model = TagTransformer(
        src_vocab_size=len(src_vocab),
        trg_vocab_size=len(tgt_vocab),
        embed_dim=config['embed_dim'],
        nb_heads=config['nb_heads'],
        src_hid_size=config['src_hid_size'],
        src_nb_layers=config['src_nb_layers'],
        trg_hid_size=config['trg_hid_size'],
        trg_nb_layers=config['trg_nb_layers'],
        dropout_p=config['dropout_p'],
        tie_trg_embed=config['tie_trg_embed'],
        label_smooth=config['label_smooth'],
        nb_attr=nb_attr,
        src_c2i=src_vocab,
        trg_c2i=tgt_vocab,
        attr_c2i={},
    )
    model = model.to(DEVICE)
    print(f"Model created with {model.count_nb_params():,} parameters")

    # Test forward pass
    print("\nTesting forward pass...")
    model.eval()
    with torch.no_grad():
        # Get a batch
        src, src_mask, tgt, tgt_mask = next(iter(dataloader))
        src, src_mask, tgt, tgt_mask = (
            src.to(DEVICE), src_mask.to(DEVICE),
            tgt.to(DEVICE), tgt_mask.to(DEVICE),
        )

        # Forward pass
        output = model(src, src_mask, tgt, tgt_mask)
        print(f" Output shape: {output.shape}")

        # Test loss computation: drop the prediction after the last token
        # (output[:-1]) and the initial target symbol (tgt[1:]) so that
        # each prediction lines up with its gold next token
        loss = model.loss(output[:-1], tgt[1:])
        print(f" Loss: {loss.item():.4f}")
print("\n✓ Training setup test completed successfully!")
print("\nReady to start training with:")
print(f"python scripts/train_morphological.py --output_dir ./models")
if __name__ == '__main__':
test_training_setup()
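

# Note: run this from the repository root so the relative 10L_90NL/...
# data paths resolve, e.g.:
#   python scripts/test_training_setup.py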