# An official reimplemented version of Marigold training script
# Last modified: 2024-05-17
#
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------

import argparse
import logging
import os
import shutil
from datetime import datetime, timedelta
from typing import List

import torch
from omegaconf import OmegaConf
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm

from marigold.marigold_pipeline import MarigoldPipeline
from src.dataset import BaseDepthDataset, DatasetMode, get_dataset
from src.dataset.mixed_sampler import MixedBatchSampler
from src.trainer import get_trainer_cls
from src.util.config_util import (
    find_value_in_omegaconf,
    recursive_load_config,
)
from src.util.depth_transform import (
    DepthNormalizerBase,
    get_depth_normalizer,
)
from src.util.logging_util import (
    config_logging,
    init_wandb,
    load_wandb_job_id,
    log_slurm_job_id,
    save_wandb_job_id,
    tb_logger,
)
from src.util.slurm_util import get_local_scratch_dir, is_on_slurm

if "__main__" == __name__:
    t_start = datetime.now()
    print(f"start at {t_start}")

    # -------------------- Arguments --------------------
    parser = argparse.ArgumentParser(description="Train your cute model!")
    parser.add_argument(
        "--config",
        type=str,
        default="config/train_marigold.yaml",
        help="Path to config file.",
    )
    parser.add_argument(
        "--resume_run",
        action="store",
        default=None,
        help="Path of the checkpoint to resume. If given, --config and the checkpoint path in the config are ignored.",
    )
    parser.add_argument(
        "--output_dir", type=str, default=None, help="directory to save checkpoints"
    )
    parser.add_argument("--no_cuda", action="store_true", help="Do not use cuda.")
    parser.add_argument(
        "--exit_after",
        type=int,
        default=-1,
        help="Save checkpoint and exit after X minutes.",
    )
    parser.add_argument("--no_wandb", action="store_true", help="run without wandb")
    parser.add_argument(
        "--do_not_copy_data",
        action="store_true",
        help="On Slurm cluster, do not copy data to local scratch",
    )
    parser.add_argument(
        "--base_data_dir", type=str, default=None, help="directory of training data"
    )
    parser.add_argument(
        "--base_ckpt_dir",
        type=str,
        default=None,
        help="directory of pretrained checkpoint",
    )
    parser.add_argument(
        "--add_datetime_prefix",
        action="store_true",
        help="Add datetime to the output folder name",
    )

    args = parser.parse_args()
    resume_run = args.resume_run
    output_dir = args.output_dir
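    # Fall back to the BASE_DATA_DIR / BASE_CKPT_DIR environment variables when the
    # corresponding CLI flags are not given (raises KeyError if neither is set).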
    base_data_dir = (
        args.base_data_dir
        if args.base_data_dir is not None
        else os.environ["BASE_DATA_DIR"]
    )
    base_ckpt_dir = (
        args.base_ckpt_dir
        if args.base_ckpt_dir is not None
        else os.environ["BASE_CKPT_DIR"]
    )

    # -------------------- Initialization --------------------
    # Resume previous run
    if resume_run is not None:
        print(f"Resume run: {resume_run}")
        out_dir_run = os.path.dirname(os.path.dirname(resume_run))
        job_name = os.path.basename(out_dir_run)
        # Resume config file
        cfg = OmegaConf.load(os.path.join(out_dir_run, "config.yaml"))
    else:
        # Run from start
        cfg = recursive_load_config(args.config)
        # Full job name
        pure_job_name = os.path.basename(args.config).split(".")[0]
        # Add time prefix
        if args.add_datetime_prefix:
            job_name = f"{t_start.strftime('%y_%m_%d-%H_%M_%S')}-{pure_job_name}"
        else:
            job_name = pure_job_name
        # Output dir
        if output_dir is not None:
            out_dir_run = os.path.join(output_dir, job_name)
        else:
            out_dir_run = os.path.join("./output", job_name)
        os.makedirs(out_dir_run, exist_ok=False)

    cfg_data = cfg.dataset

    # Other directories
    out_dir_ckpt = os.path.join(out_dir_run, "checkpoint")
    if not os.path.exists(out_dir_ckpt):
        os.makedirs(out_dir_ckpt)
    out_dir_tb = os.path.join(out_dir_run, "tensorboard")
    if not os.path.exists(out_dir_tb):
        os.makedirs(out_dir_tb)
    out_dir_eval = os.path.join(out_dir_run, "evaluation")
    if not os.path.exists(out_dir_eval):
        os.makedirs(out_dir_eval)
    out_dir_vis = os.path.join(out_dir_run, "visualization")
    if not os.path.exists(out_dir_vis):
        os.makedirs(out_dir_vis)

    # -------------------- Logging settings --------------------
    config_logging(cfg.logging, out_dir=out_dir_run)
    logging.debug(f"config: {cfg}")

    # Initialize wandb
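    # On resume, reuse the wandb run id saved in the output directory so logging
    # continues in the same wandb run; otherwise start a fresh run named after the job.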
    if not args.no_wandb:
        if resume_run is not None:
            wandb_id = load_wandb_job_id(out_dir_run)
            wandb_cfg_dic = {
                "id": wandb_id,
                "resume": "must",
                **cfg.wandb,
            }
        else:
            wandb_cfg_dic = {
                "config": dict(cfg),
                "name": job_name,
                "mode": "online",
                **cfg.wandb,
            }
        wandb_cfg_dic.update({"dir": out_dir_run})
        wandb_run = init_wandb(enable=True, **wandb_cfg_dic)
        save_wandb_job_id(wandb_run, out_dir_run)
    else:
        init_wandb(enable=False)

    # Tensorboard (should be initialized after wandb)
    tb_logger.set_dir(out_dir_tb)

    log_slurm_job_id(step=0)

    # -------------------- Device --------------------
    cuda_avail = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if cuda_avail else "cpu")
    logging.info(f"device = {device}")

    # -------------------- Snapshot of code and config --------------------
    if resume_run is None:
        _output_path = os.path.join(out_dir_run, "config.yaml")
        with open(_output_path, "w+") as f:
            OmegaConf.save(config=cfg, f=f)
            logging.info(f"Config saved to {_output_path}")
        # Copy and tar code on the first run
        _temp_code_dir = os.path.join(out_dir_run, "code_tar")
        _code_snapshot_path = os.path.join(out_dir_run, "code_snapshot.tar")
        os.system(
            f"rsync --relative -arhvz --quiet --filter=':- .gitignore' --exclude '.git' . '{_temp_code_dir}'"
        )
        os.system(f"tar -cf {_code_snapshot_path} {_temp_code_dir}")
        os.system(f"rm -rf {_temp_code_dir}")
        logging.info(f"Code snapshot saved to: {_code_snapshot_path}")

    # -------------------- Copy data to local scratch (Slurm) --------------------
    if is_on_slurm() and (not args.do_not_copy_data):
        # local scratch dir
        original_data_dir = base_data_dir
        base_data_dir = os.path.join(get_local_scratch_dir(), "Marigold_data")
        # copy data
        required_data_list = find_value_in_omegaconf("dir", cfg_data)
        # if cfg_train.visualize.init_latent_path is not None:
        #     required_data_list.append(cfg_train.visualize.init_latent_path)
        required_data_list = list(set(required_data_list))
        logging.info(f"Required_data_list: {required_data_list}")
        for d in tqdm(required_data_list, desc="Copy data to local scratch"):
            ori_dir = os.path.join(original_data_dir, d)
            dst_dir = os.path.join(base_data_dir, d)
            os.makedirs(os.path.dirname(dst_dir), exist_ok=True)
            if os.path.isfile(ori_dir):
                shutil.copyfile(ori_dir, dst_dir)
            elif os.path.isdir(ori_dir):
                shutil.copytree(ori_dir, dst_dir)
        logging.info(f"Data copied to: {base_data_dir}")

    # -------------------- Gradient accumulation steps --------------------
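    # The effective batch size is realized as max_train_batch_size * accumulation_steps,
    # so effective_batch_size must be an integer multiple of max_train_batch_size
    # (enforced by the assert below).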
    eff_bs = cfg.dataloader.effective_batch_size
    accumulation_steps = eff_bs / cfg.dataloader.max_train_batch_size
    assert int(accumulation_steps) == accumulation_steps
    accumulation_steps = int(accumulation_steps)

    logging.info(
        f"Effective batch size: {eff_bs}, accumulation steps: {accumulation_steps}"
    )

    # -------------------- Data --------------------
    loader_seed = cfg.dataloader.seed
    if loader_seed is None:
        loader_generator = None
    else:
        loader_generator = torch.Generator().manual_seed(loader_seed)

    # Training dataset
    depth_transform: DepthNormalizerBase = get_depth_normalizer(
        cfg_normalizer=cfg.depth_normalization
    )
    train_dataset: BaseDepthDataset = get_dataset(
        cfg_data.train,
        base_data_dir=base_data_dir,
        mode=DatasetMode.TRAIN,
        augmentation_args=cfg.augmentation,
        depth_transform=depth_transform,
    )
    logging.debug(f"Augmentation: {cfg.augmentation}")
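    # For the "mixed" training set, `get_dataset` returns a list of datasets; they are
    # concatenated, and MixedBatchSampler draws each batch from a single source dataset
    # chosen according to the probabilities in `prob_ls`.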
    if "mixed" == cfg_data.train.name:
        dataset_ls = train_dataset
        assert len(cfg_data.train.prob_ls) == len(
            dataset_ls
        ), "Lengths don't match: `prob_ls` and `dataset_list`"
        concat_dataset = ConcatDataset(dataset_ls)
        mixed_sampler = MixedBatchSampler(
            src_dataset_ls=dataset_ls,
            batch_size=cfg.dataloader.max_train_batch_size,
            drop_last=True,
            prob=cfg_data.train.prob_ls,
            shuffle=True,
            generator=loader_generator,
        )
        train_loader = DataLoader(
            concat_dataset,
            batch_sampler=mixed_sampler,
            num_workers=cfg.dataloader.num_workers,
        )
    else:
        train_loader = DataLoader(
            dataset=train_dataset,
            batch_size=cfg.dataloader.max_train_batch_size,
            num_workers=cfg.dataloader.num_workers,
            shuffle=True,
            generator=loader_generator,
        )
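    # Validation and visualization loaders use batch size 1 without shuffling,
    # so samples are iterated in a fixed, reproducible order.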
    # Validation dataset
    val_loaders: List[DataLoader] = []
    for _val_dic in cfg_data.val:
        _val_dataset = get_dataset(
            _val_dic,
            base_data_dir=base_data_dir,
            mode=DatasetMode.EVAL,
        )
        _val_loader = DataLoader(
            dataset=_val_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=cfg.dataloader.num_workers,
        )
        val_loaders.append(_val_loader)

    # Visualization dataset
    vis_loaders: List[DataLoader] = []
    for _vis_dic in cfg_data.vis:
        _vis_dataset = get_dataset(
            _vis_dic,
            base_data_dir=base_data_dir,
            mode=DatasetMode.EVAL,
        )
        _vis_loader = DataLoader(
            dataset=_vis_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=cfg.dataloader.num_workers,
        )
        vis_loaders.append(_vis_loader)

    # -------------------- Model --------------------
    _pipeline_kwargs = cfg.pipeline.kwargs if cfg.pipeline.kwargs is not None else {}
    model = MarigoldPipeline.from_pretrained(
        os.path.join(base_ckpt_dir, cfg.model.pretrained_path), **_pipeline_kwargs
    )

    # -------------------- Trainer --------------------
    # Exit time
    if args.exit_after > 0:
        t_end = t_start + timedelta(minutes=args.exit_after)
        logging.info(f"Will exit at {t_end}")
    else:
        t_end = None

    trainer_cls = get_trainer_cls(cfg.trainer.name)
    logging.debug(f"Trainer: {trainer_cls}")
    trainer = trainer_cls(
        cfg=cfg,
        model=model,
        train_dataloader=train_loader,
        device=device,
        base_ckpt_dir=base_ckpt_dir,
        out_dir_ckpt=out_dir_ckpt,
        out_dir_eval=out_dir_eval,
        out_dir_vis=out_dir_vis,
        accumulation_steps=accumulation_steps,
        val_dataloaders=val_loaders,
        vis_dataloaders=vis_loaders,
    )

    # -------------------- Checkpoint --------------------
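    # When resuming, restore the model weights together with the trainer state and the
    # LR scheduler state (see `load_trainer_state` and `resume_lr_scheduler`).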
    if resume_run is not None:
        trainer.load_checkpoint(
            resume_run, load_trainer_state=True, resume_lr_scheduler=True
        )

    # -------------------- Training & Evaluation Loop --------------------
    try:
        trainer.train(t_end=t_end)
    except Exception as e:
        logging.exception(e)