Train PreDiff on SEVIR-LR dataset
Configurations for training and inference
Change the configurations in the corresponding `cfg.yaml`. You might consider modifying the following configurations according to your specific requirements:
- `trainer.check_val_every_n_epoch`: Run validation every *n* training epochs. Set a larger value if you want to reduce the time spent on validation.
- `vis.eval_example_only`: If `true`, only data with indices in `example_data_idx_list` will be evaluated. Set it to `false` if you want to evaluate the whole val/test set.
- `vis.eval_aligned`: If `true`, PreDiff-KA will be evaluated.
- `vis.eval_unaligned`: If `true`, PreDiff without knowledge alignment will be evaluated.
- `vis.num_samples_per_context`: Generate *n* samples for each context sequence.
- `model.align.alignment_type`: `null` by default, which means the knowledge alignment module is not loaded. Set it to `avg_x` for knowledge alignment with anticipated future average intensity.
- `model.align.model_ckpt_path`: Point it to your own pretrained checkpoint if you want a custom knowledge alignment network.
- `model.vae.pretrained_ckpt_path`: Point it to your own pretrained checkpoint if you want a custom VAE.
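To make these options concrete, here is a minimal sketch of how they might be laid out in `cfg.yaml`. The nesting follows the dotted key names above; the values and checkpoint paths shown are illustrative placeholders, not the repository defaults, so check your own `cfg.yaml` for the actual structure.

```yaml
trainer:
  check_val_every_n_epoch: 10            # placeholder: validate every 10 training epochs
vis:
  eval_example_only: true                # only evaluate indices in example_data_idx_list
  eval_aligned: true                     # evaluate PreDiff-KA
  eval_unaligned: true                   # evaluate PreDiff without knowledge alignment
  num_samples_per_context: 8             # placeholder: samples generated per context sequence
model:
  align:
    alignment_type: avg_x                # null (default) disables the knowledge alignment module
    model_ckpt_path: ./ckpts/alignment.ckpt   # hypothetical path to a custom alignment checkpoint
  vae:
    pretrained_ckpt_path: ./ckpts/vae.ckpt    # hypothetical path to a custom VAE checkpoint
```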
Commands for training and inference
Run the following command to train PreDiff on the SEVIR-LR dataset:
```bash
cd ROOT_DIR/PreDiff
MASTER_ADDR=localhost MASTER_PORT=10001 python ./scripts/prediff/sevirlr/train_sevirlr_prediff.py --gpus 2 --cfg ./scripts/prediff/sevirlr/cfg.yaml --save tmp_sevirlr_prediff
```
Alternatively, run the following command to directly load the pretrained checkpoint for testing:
```bash
cd ROOT_DIR/PreDiff
MASTER_ADDR=localhost MASTER_PORT=10001 python ./scripts/prediff/sevirlr/train_sevirlr_prediff.py --gpus 2 --pretrained --save tmp_sevirlr_prediff
```
Run the following commands to train PreDiff using multi-node DDP:
```bash
# On the master node (NODE_RANK=0)
MASTER_ADDR=localhost MASTER_PORT=10001 WORLD_SIZE=4 NODE_RANK=0 python ./scripts/prediff/sevirlr/train_sevirlr_prediff.py --nodes 2 --gpus 2 --cfg ./scripts/prediff/sevirlr/cfg.yaml --save tmp_sevirlr_prediff
# On the worker node (NODE_RANK=1)
MASTER_ADDR=$master_ip MASTER_PORT=10001 WORLD_SIZE=4 NODE_RANK=1 python ./scripts/prediff/sevirlr/train_sevirlr_prediff.py --nodes 2 --gpus 2 --cfg ./scripts/prediff/sevirlr/cfg.yaml --save tmp_sevirlr_prediff
```
Run the following TensorBoard command to visualize the experiment records:
```bash
cd ROOT_DIR/PreDiff
tensorboard --logdir ./experiments/tmp_sevirlr_prediff/lightning_logs
```