| | import os.path as osp |
| | import warnings |
| | warnings.filterwarnings('ignore') |
| | from typing import Optional |
| | from pathlib import Path |
| | from models.maplocnet import MapLocNet |
| | import hydra |
| | import pytorch_lightning as pl |
| | import torch |
| | from omegaconf import DictConfig, OmegaConf |
| | from pytorch_lightning.utilities import rank_zero_only |
| | from module import GenericModule |
| | from logger import logger, pl_logger, EXPERIMENTS_PATH |
| | from module import GenericModule |
| | from dataset import UavMapDatasetModule |
| | from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
| | |
| |
|
| |
|
class CleanProgressBar(pl.callbacks.TQDMProgressBar):
    """Progress bar that hides noisy default entries (``v_num``, raw ``loss``)."""

    def get_metrics(self, trainer, model):
        """Return the default progress-bar metrics minus the hidden keys."""
        metrics = super().get_metrics(trainer, model)
        for hidden_key in ("v_num", "loss"):
            metrics.pop(hidden_key, None)
        return metrics
| |
|
| |
|
class SeedingCallback(pl.callbacks.Callback):
    """Re-seed all RNGs at the start of every train/val/test epoch.

    During normal training the base seed is offset by the current epoch so
    each epoch gets different (but reproducible) randomness; when
    ``overfit_batches`` is active the base seed is reused unchanged.
    """

    def on_epoch_start_(self, trainer, module):
        base_seed = module.cfg.experiment.seed
        overfitting = module.cfg.training.trainer.get("overfit_batches", 0) > 0
        if trainer.training and not overfitting:
            base_seed += trainer.current_epoch

        # Temporarily silence the Lightning logger so seed_everything's
        # "Global seed set to ..." message does not spam the console.
        pl_logger.disabled = True
        try:
            pl.seed_everything(base_seed, workers=True)
        finally:
            pl_logger.disabled = False

    def on_train_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_validation_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_test_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)
| |
|
| |
|
class ConsoleLogger(pl.callbacks.Callback):
    """Emit a short console message when a new training epoch begins (rank 0 only)."""

    @rank_zero_only
    def on_train_epoch_start(self, trainer, module):
        epoch = module.current_epoch
        experiment_name = module.cfg.experiment.name
        logger.info(
            "New training epoch %d for experiment '%s'.", epoch, experiment_name
        )
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def find_last_checkpoint_path(experiment_dir):
    """Return the path of the 'last' checkpoint in *experiment_dir*, or None.

    The filename is built from ModelCheckpoint's own naming constants so it
    stays in sync with what Lightning actually writes.
    """
    ckpt_cls = pl.callbacks.ModelCheckpoint
    candidate = osp.join(
        experiment_dir, ckpt_cls.CHECKPOINT_NAME_LAST + ckpt_cls.FILE_EXTENSION
    )
    return candidate if osp.exists(candidate) else None
| |
|
| |
|
def prepare_experiment_dir(experiment_dir, cfg, rank):
    """Prepare *experiment_dir* for a fresh or resumed run.

    If a 'last' checkpoint exists, verify that the relevant config sections
    match the config saved by the previous run (refusing to resume with a
    different config). On rank 0, create the directory and write the current
    config to ``config.yaml``.

    Returns the checkpoint path to resume from, or None for a fresh run.
    """
    config_path = osp.join(experiment_dir, "config.yaml")
    last_checkpoint_path = find_last_checkpoint_path(experiment_dir)
    if last_checkpoint_path is not None:
        if rank == 0:
            logger.info(
                "Resuming the training from checkpoint %s", last_checkpoint_path
            )
        if osp.exists(config_path):
            with open(config_path, "r") as fp:
                cfg_prev = OmegaConf.create(fp.read())
            # Only these sections must match for a resume to be valid.
            compare_keys = ["experiment", "data", "model", "training"]
            current = OmegaConf.masked_copy(cfg, compare_keys)
            previous = OmegaConf.masked_copy(cfg_prev, compare_keys)
            if current != previous:
                raise ValueError(
                    "Attempting to resume training with a different config: "
                    f"{current} vs "
                    f"{previous}"
                )
    if rank == 0:
        Path(experiment_dir).mkdir(exist_ok=True, parents=True)
        with open(config_path, "w") as fp:
            OmegaConf.save(cfg, fp)
    return last_checkpoint_path
| |
|
| |
|
def train(cfg: DictConfig) -> None:
    """Run a full training session described by *cfg*.

    Handles CPU fallback, model init (optionally fine-tuned from a
    checkpoint), resuming from the experiment directory, multi-GPU (DDP)
    batch/worker splitting, checkpointing, early stopping, and TensorBoard
    logging, then launches ``trainer.fit``.
    """
    torch.set_float32_matmul_precision("medium")
    OmegaConf.resolve(cfg)
    rank = rank_zero_only.rank

    if rank == 0:
        logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
    if cfg.experiment.gpus in (None, 0):
        logger.warning("Will train on CPU...")
        cfg.experiment.gpus = 0
    elif not torch.cuda.is_available():
        raise ValueError("Requested GPU but no NVIDIA drivers found.")
    pl.seed_everything(cfg.experiment.seed, workers=True)

    # Either fine-tune from an existing checkpoint or build a fresh model.
    init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
    if init_checkpoint_path is not None:
        logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path)
        model = GenericModule.load_from_checkpoint(
            init_checkpoint_path, strict=True, find_best=False, cfg=cfg
        )
    else:
        model = GenericModule(cfg)
    if rank == 0:
        logger.info("Network:\n%s", model.model)

    experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
    last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)
    # Checkpoint once per epoch, keyed on the validation loss.
    checkpointing_epoch = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
        verbose=True,
        **cfg.training.checkpointing,
    )
    # Also checkpoint every 1000 steps so long epochs can be resumed mid-way.
    checkpointing_step = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-step-{step}-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_train_steps=1000,
        verbose=True,
        **cfg.training.checkpointing,
    )
    # Give the step-based "last" checkpoint a distinct name so it does not
    # clobber the epoch-based one.
    checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"

    early_stopping_callback = EarlyStopping(
        monitor=cfg.training.checkpointing.monitor, patience=5
    )

    strategy = None
    if cfg.experiment.gpus > 1:
        strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
        # Under DDP, per-process batch size and workers are the configured
        # totals divided across GPUs (workers rounded up).
        for split in ["train", "val"]:
            cfg.data[split].batch_size = (
                cfg.data[split].batch_size // cfg.experiment.gpus
            )
            cfg.data[split].num_workers = int(
                (cfg.data[split].num_workers + cfg.experiment.gpus - 1)
                / cfg.experiment.gpus
            )

    datamodule = UavMapDatasetModule(cfg.data)

    tb_args = {"name": cfg.experiment.name, "version": ""}
    tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args)

    callbacks = [
        checkpointing_epoch,
        checkpointing_step,
        # Fix: the early-stopping callback was previously constructed but
        # never registered with the Trainer, so it had no effect.
        early_stopping_callback,
        pl.callbacks.LearningRateMonitor(),
        SeedingCallback(),
        CleanProgressBar(),
        ConsoleLogger(),
    ]
    if cfg.experiment.gpus > 0:
        callbacks.append(pl.callbacks.DeviceStatsMonitor())

    trainer = pl.Trainer(
        default_root_dir=experiment_dir,
        detect_anomaly=False,
        enable_model_summary=True,
        sync_batchnorm=True,
        enable_checkpointing=True,
        logger=tb,
        callbacks=callbacks,
        strategy=strategy,
        check_val_every_n_epoch=1,
        # Fix: the accelerator was hard-coded to "gpu", which crashes when
        # the CPU fallback above set cfg.experiment.gpus to 0.
        accelerator="gpu" if cfg.experiment.gpus > 0 else "cpu",
        num_nodes=1,
        **cfg.training.trainer,
    )
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)
| |
|
| |
|
@hydra.main(
    config_path=osp.join(osp.dirname(__file__), "conf"), config_name="maplocnet.yaml"
)
def main(cfg: DictConfig) -> None:
    """Hydra entry point: snapshot the composed config, then start training."""
    OmegaConf.save(cfg, 'maplocnet.yaml')
    train(cfg)
| |
|
| |
|
# Run the Hydra entry point only when executed as a script.
if __name__ == "__main__":
    main()
| |
|
| |
|