#!/usr/bin/env python3
"""Batch training script for all morphological reinflection datasets on Hugging Face."""

import argparse
import os
import subprocess

DATASET_NAMES = ['10L_90NL', '50L_50NL', '90L_10NL']


def dataset_config(name):
    """Build the file patterns and Hub model name for one dataset.

    All three datasets share the same directory layout, so the patterns are
    derived from the dataset name instead of being spelled out three times.
    """
    short = name.replace('_', '')  # e.g. '10L_90NL' -> '10L90NL'
    config = {'model_name_pattern': '{username}/morphological-transformer-' + short + '-run{run}'}
    for split in ('train', 'dev', 'test'):
        config[f'{split}_pattern'] = f'./{name}/{split}/run{{run}}/{split}.{name}_{{run}}_1.src'
        config[f'{split}_tgt_pattern'] = f'./{name}/{split}/run{{run}}/{split}.{name}_{{run}}_1.tgt'
    return config


def model_name_of(cmd):
    """Return the value that follows '--model_name' in a command list."""
    return cmd[cmd.index('--model_name') + 1]


def run_training_command(cmd):
    """Run one training command, reporting success or failure."""
    print(f"Running: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
        print(f"✅ Success: {model_name_of(cmd)}")
        return True
    except subprocess.CalledProcessError as e:
        print(f"❌ Error training {model_name_of(cmd)}: {e}")
        print(f"STDOUT: {e.stdout}")
        print(f"STDERR: {e.stderr}")
        return False


def main():
    parser = argparse.ArgumentParser(description='Train all morphological transformer models on Hugging Face')
    parser.add_argument('--username', type=str, required=True, help='Your Hugging Face username')
    parser.add_argument('--wandb_project', type=str, default='morphological-transformer', help='Weights & Biases project name')
    parser.add_argument('--hf_token', type=str, help='Hugging Face token for model upload')
    parser.add_argument('--upload_models', action='store_true', help='Upload models to the Hugging Face Hub')
    parser.add_argument('--output_dir', type=str, default='./hf_models', help='Output directory')
    parser.add_argument('--datasets', nargs='+', default=DATASET_NAMES, help='Datasets to train (default: all)')
    parser.add_argument('--runs', nargs='+', default=['1', '2', '3'], help='Runs to train (default: all)')
    parser.add_argument('--dry_run', action='store_true', help='Print commands without executing')
    args = parser.parse_args()

    # Base command template; per-dataset arguments are appended below.
    base_cmd = [
        'python', 'scripts/train_huggingface.py',
        '--output_dir', args.output_dir,
        '--wandb_project', args.wandb_project
    ]
    if args.hf_token:
        base_cmd.extend(['--hf_token', args.hf_token])
    if args.upload_models:
        base_cmd.append('--upload_model')

    # Dataset configurations
    datasets = {name: dataset_config(name) for name in DATASET_NAMES}
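    # Expected on-disk layout, as implied by the patterns above (run 1 of
    # 10L_90NL shown; the dev/ and test/ splits follow the same scheme):
    #   ./10L_90NL/train/run1/train.10L_90NL_1_1.src
    #   ./10L_90NL/train/run1/train.10L_90NL_1_1.tgt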
    # Generate training commands
    commands = []
    for dataset in args.datasets:
        if dataset not in datasets:
            print(f"⚠️ Unknown dataset: {dataset}")
            continue
        config = datasets[dataset]
        for run in args.runs:
            train_src = config['train_pattern'].format(run=run)
            train_tgt = config['train_tgt_pattern'].format(run=run)
            dev_src = config['dev_pattern'].format(run=run)
            dev_tgt = config['dev_tgt_pattern'].format(run=run)
            test_src = config['test_pattern'].format(run=run)
            test_tgt = config['test_tgt_pattern'].format(run=run)

            # Skip this run if any of its data files are missing.
            missing_files = [p for p in (train_src, train_tgt, dev_src, dev_tgt, test_src, test_tgt)
                             if not os.path.exists(p)]
            if missing_files:
                print(f"⚠️ Skipping {dataset} run {run} - missing files: {missing_files}")
                continue

            # Build the full command for this dataset/run combination.
            model_name = config['model_name_pattern'].format(username=args.username, run=run)
            cmd = base_cmd + [
                '--model_name', model_name,
                '--train_src', train_src,
                '--train_tgt', train_tgt,
                '--dev_src', dev_src,
                '--dev_tgt', dev_tgt,
                '--test_src', test_src,
                '--test_tgt', test_tgt
            ]
            commands.append(cmd)

    print(f"🚀 Found {len(commands)} training jobs to run")
    if not commands:
        # Nothing to do; also avoids a divide-by-zero in the summary below.
        return

    if args.dry_run:
        print("\n📋 Commands that would be executed:")
        for i, cmd in enumerate(commands, 1):
            print(f"{i:2d}. {' '.join(cmd)}")
        return

    # Execute training commands sequentially.
    successful = 0
    failed = 0
    for i, cmd in enumerate(commands, 1):
        print(f"\n🔄 Training {i}/{len(commands)}: {model_name_of(cmd)}")
        if run_training_command(cmd):
            successful += 1
        else:
            failed += 1

    # Summary
    print("\n📊 Training Summary:")
    print(f"✅ Successful: {successful}")
    print(f"❌ Failed: {failed}")
    print(f"📈 Success Rate: {successful / (successful + failed) * 100:.1f}%")

    if successful > 0:
        print(f"\n🎉 Models saved to: {args.output_dir}")
        if args.upload_models:
            print(f"🌐 Models uploaded to: https://huggingface.co/{args.username}")


if __name__ == '__main__':
    main()
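# Example invocations (the filename scripts/train_all_hf.py is an assumption;
# the script itself does not fix its own name):
#   python scripts/train_all_hf.py --username <hf-username> --dry_run
#   python scripts/train_all_hf.py --username <hf-username> --upload_models --hf_token $HF_TOKEN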