|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import torch |
|
|
import yaml |
|
|
from yacs.config import CfgNode as CN |
|
|
|
|
|
|
|
|
# Major version of the installed PyTorch (e.g. 1 or 2), taken from the text
# before the first '.' in torch.__version__. update_config() uses it to decide
# where the local rank comes from.
PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
|
|
|
|
|
# Root node of the default configuration tree; every option in this module is
# attached to it and get_config() hands out clones of it.
_C = CN()

# Base config files. Each non-empty entry is resolved relative to the directory
# of the YAML file that lists it and is merged before that file itself
# (see _update_config_from_file).
_C.BASE = ['']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()

# Batch size; overwritten by the `batch_size` command-line argument when set.
# NOTE(review): presumably per process/GPU rather than global — confirm
# against the data loader.
_C.DATA.BATCH_SIZE = 128

# Dataset location; overwritten by the `data_path` command-line argument.
_C.DATA.DATA_PATH = ''

# Dataset name.
_C.DATA.DATASET = 'imagenet'

# Input image size.
_C.DATA.IMG_SIZE = 224

# Interpolation used when resizing images.
# NOTE(review): the accepted values are defined by the data pipeline, which is
# not visible in this file.
_C.DATA.INTERPOLATION = 'bicubic'

# Switched to True by the `zip` command-line flag.
# NOTE(review): presumably selects a zipped dataset instead of a folder
# dataset — confirm against the dataset builder.
_C.DATA.ZIP_MODE = False

# Dataset cache mode; overwritten by the `cache_mode` command-line argument.
# NOTE(review): the set of valid modes is defined by the consumer — confirm.
_C.DATA.CACHE_MODE = 'part'

# NOTE(review): presumably forwarded to DataLoader(pin_memory=...) — confirm.
_C.DATA.PIN_MEMORY = True

# Number of data loading workers.
# NOTE(review): presumably DataLoader(num_workers=...) — confirm.
_C.DATA.NUM_WORKERS = 8

# Mask generator settings.
# NOTE(review): these look like masked-image-modeling (SimMIM) options; their
# consumer is not visible in this file.
_C.DATA.MASK_PATCH_SIZE = 32
_C.DATA.MASK_RATIO = 0.6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()

# Model type.
# NOTE(review): presumably selects which MODEL.<TYPE> sub-node is read by the
# model builder — confirm.
_C.MODEL.TYPE = 'swin'

# Model name; update_config() also uses it as a path component of the final
# OUTPUT directory.
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'

# Pretrained weights; overwritten by the `pretrained` command-line argument.
_C.MODEL.PRETRAINED = ''

# Checkpoint to resume from; overwritten by the `resume` command-line argument.
_C.MODEL.RESUME = ''

# Number of classes.
_C.MODEL.NUM_CLASSES = 1000

# Dropout rate.
_C.MODEL.DROP_RATE = 0.0

# Drop path rate.
_C.MODEL.DROP_PATH_RATE = 0.1

# Label smoothing.
_C.MODEL.LABEL_SMOOTHING = 0.1
|
|
|
|
|
|
|
|
# Swin Transformer parameters.
# NOTE(review): DEPTHS and NUM_HEADS both have four entries here, presumably
# one per stage — confirm against the model constructor.
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 3
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 7
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
# None by default.
# NOTE(review): presumably lets the model pick its own attention scale —
# confirm.
_C.MODEL.SWIN.QK_SCALE = None
# NOTE(review): presumably "absolute position embedding" — confirm.
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True
|
|
|
|
|
|
|
|
# Swin Transformer V2 parameters.
# Compared with MODEL.SWIN this node has no QK_SCALE and adds
# PRETRAINED_WINDOW_SIZES.
_C.MODEL.SWINV2 = CN()
_C.MODEL.SWINV2.PATCH_SIZE = 4
_C.MODEL.SWINV2.IN_CHANS = 3
_C.MODEL.SWINV2.EMBED_DIM = 96
_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWINV2.WINDOW_SIZE = 7
_C.MODEL.SWINV2.MLP_RATIO = 4.
_C.MODEL.SWINV2.QKV_BIAS = True
# NOTE(review): presumably "absolute position embedding" — confirm.
_C.MODEL.SWINV2.APE = False
_C.MODEL.SWINV2.PATCH_NORM = True
# NOTE(review): presumably the window size used during pre-training, one entry
# per stage — confirm against the model constructor.
_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]
|
|
|
|
|
|
|
|
# Swin Transformer MoE (mixture-of-experts) parameters.
# The first group mirrors MODEL.SWIN; the remaining keys configure the MoE
# layers.
# NOTE(review): the exact meaning of each MoE key is defined by the model and
# MoE library, which are not visible in this file.
_C.MODEL.SWIN_MOE = CN()
_C.MODEL.SWIN_MOE.PATCH_SIZE = 4
_C.MODEL.SWIN_MOE.IN_CHANS = 3
_C.MODEL.SWIN_MOE.EMBED_DIM = 96
_C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MOE.WINDOW_SIZE = 7
_C.MODEL.SWIN_MOE.MLP_RATIO = 4.
_C.MODEL.SWIN_MOE.QKV_BIAS = True
_C.MODEL.SWIN_MOE.QK_SCALE = None
_C.MODEL.SWIN_MOE.APE = False
_C.MODEL.SWIN_MOE.PATCH_NORM = True
_C.MODEL.SWIN_MOE.MLP_FC2_BIAS = True
_C.MODEL.SWIN_MOE.INIT_STD = 0.02
_C.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]
# One list per stage.
# NOTE(review): presumably the block indices that become MoE blocks, with -1
# meaning "none" — confirm.
_C.MODEL.SWIN_MOE.MOE_BLOCKS = [[-1], [-1], [-1], [-1]]
_C.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS = 1
_C.MODEL.SWIN_MOE.TOP_VALUE = 1
_C.MODEL.SWIN_MOE.CAPACITY_FACTOR = 1.25
_C.MODEL.SWIN_MOE.COSINE_ROUTER = False
_C.MODEL.SWIN_MOE.NORMALIZE_GATE = False
_C.MODEL.SWIN_MOE.USE_BPR = True
_C.MODEL.SWIN_MOE.IS_GSHARD_LOSS = False
_C.MODEL.SWIN_MOE.GATE_NOISE = 1.0
_C.MODEL.SWIN_MOE.COSINE_ROUTER_DIM = 256
_C.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T = 0.5
_C.MODEL.SWIN_MOE.MOE_DROP = 0.0
_C.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT = 0.01
|
|
|
|
|
|
|
|
# Swin MLP parameters.
# Same keys as MODEL.SWIN except that QKV_BIAS and QK_SCALE are absent.
_C.MODEL.SWIN_MLP = CN()
_C.MODEL.SWIN_MLP.PATCH_SIZE = 4
_C.MODEL.SWIN_MLP.IN_CHANS = 3
_C.MODEL.SWIN_MLP.EMBED_DIM = 96
_C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN_MLP.WINDOW_SIZE = 7
_C.MODEL.SWIN_MLP.MLP_RATIO = 4.
# NOTE(review): presumably "absolute position embedding" — confirm.
_C.MODEL.SWIN_MLP.APE = False
_C.MODEL.SWIN_MLP.PATCH_NORM = True
|
|
|
|
|
|
|
|
# SimMIM parameters.
# NOTE(review): NORM_TARGET presumably controls normalisation of the
# reconstruction target over PATCH_SIZE-sized neighbourhoods — confirm against
# the SimMIM model code, which is not visible in this file.
_C.MODEL.SIMMIM = CN()
_C.MODEL.SIMMIM.NORM_TARGET = CN()
_C.MODEL.SIMMIM.NORM_TARGET.ENABLE = False
_C.MODEL.SIMMIM.NORM_TARGET.PATCH_SIZE = 47
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
# Learning rates.
# NOTE(review): whether these are rescaled by the effective batch size is
# decided by the training script, not here.
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6

# Gradient clipping value.
# NOTE(review): presumably a gradient-norm threshold — confirm.
_C.TRAIN.CLIP_GRAD = 5.0

# NOTE(review): presumably resumes automatically from the latest checkpoint
# found in the output folder — confirm against the training script.
_C.TRAIN.AUTO_RESUME = True

# Gradient accumulation steps; overwritten by the `accumulation_steps`
# command-line argument when set.
_C.TRAIN.ACCUMULATION_STEPS = 1

# Switched to True by the `use_checkpoint` command-line flag.
# NOTE(review): presumably enables activation (gradient) checkpointing to save
# memory — confirm.
_C.TRAIN.USE_CHECKPOINT = False
|
|
|
|
|
|
|
|
# LR scheduler.
# NOTE(review): which of the keys below apply depends on NAME; the scheduler
# builder that interprets them is not visible in this file.
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'

# Epoch interval to decay LR (presumably used by a step scheduler).
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30

# LR decay rate (presumably used by a step scheduler).
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1

# NOTE(review): presumably makes warm-up epochs a prefix that is not counted
# in the main schedule — confirm.
_C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = True

# Gamma and milestone list (presumably used by a multi-step scheduler).
_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []
|
|
|
|
|
|
|
|
# Optimizer.
_C.TRAIN.OPTIMIZER = CN()
# Optimizer name; overwritten by the `optim` command-line argument when set.
_C.TRAIN.OPTIMIZER.NAME = 'adamw'

# Optimizer epsilon.
_C.TRAIN.OPTIMIZER.EPS = 1e-8

# Optimizer betas.
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)

# Momentum (presumably only read by SGD-style optimizers).
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
|
|
|
|
|
|
|
|
# Layer decay.
# NOTE(review): presumably the layer-wise learning-rate decay factor used for
# fine-tuning, with 1.0 meaning no decay — confirm.
_C.TRAIN.LAYER_DECAY = 1.0

# MoE training options.
_C.TRAIN.MOE = CN()

# NOTE(review): presumably restricts checkpoint saving to the master rank —
# confirm against the checkpoint code.
_C.TRAIN.MOE.SAVE_MASTER = False
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
# NOTE(review): the key names match timm-style transform/mixup options; the
# code that consumes them is not visible in this file.
_C.AUG = CN()

# Color jitter factor.
_C.AUG.COLOR_JITTER = 0.4

# Auto-augment policy string.
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'

# Random erase probability.
_C.AUG.REPROB = 0.25

# Random erase mode.
_C.AUG.REMODE = 'pixel'

# Random erase count.
_C.AUG.RECOUNT = 1

# Mixup alpha.
# NOTE(review): presumably a value > 0 enables mixup — confirm.
_C.AUG.MIXUP = 0.8

# Cutmix alpha.
# NOTE(review): presumably a value > 0 enables cutmix — confirm.
_C.AUG.CUTMIX = 1.0

# Cutmix min/max ratio; None by default.
_C.AUG.CUTMIX_MINMAX = None

# Probability of applying mixup or cutmix when either is enabled.
_C.AUG.MIXUP_PROB = 1.0

# Probability of switching to cutmix when both mixup and cutmix are enabled.
_C.AUG.MIXUP_SWITCH_PROB = 0.5

# How mixup/cutmix parameters are applied.
# NOTE(review): valid values are defined by the consumer — confirm.
_C.AUG.MIXUP_MODE = 'batch'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()

# NOTE(review): presumably toggles center-cropping at evaluation time —
# confirm.
_C.TEST.CROP = True

# NOTE(review): SEQUENTIAL and SHUFFLE presumably choose the validation
# sampler/ordering — confirm against the data loader builder.
_C.TEST.SEQUENTIAL = False
_C.TEST.SHUFFLE = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Misc settings
# -----------------------------------------------------------------------------
# Overwritten by the `enable_amp` command-line argument when it is truthy.
# NOTE(review): this coexists with AMP_ENABLE below; which consumer reads
# which flag is not visible in this file.
_C.ENABLE_AMP = False

# Automatic mixed precision switch. update_config() forces it to False when
# `disable_amp` is passed or when `amp_opt_level` is 'O0'.
_C.AMP_ENABLE = True

# Deprecated Apex opt level. update_config() only prints a deprecation warning
# for the matching command-line argument and never copies it into this key.
_C.AMP_OPT_LEVEL = ''

# Output root; update_config() overwrites it from `output` and then rewrites it
# to <OUTPUT>/<MODEL.NAME>/<TAG>.
_C.OUTPUT = ''

# Experiment tag; overwritten by the `tag` command-line argument.
_C.TAG = 'default'

# Frequency to save checkpoints.
_C.SAVE_FREQ = 1

# Frequency of logging.
_C.PRINT_FREQ = 10

# Random seed.
_C.SEED = 0

# Switched to True by the `eval` command-line flag.
_C.EVAL_MODE = False

# Switched to True by the `throughput` command-line flag.
_C.THROUGHPUT_MODE = False

# Local rank; always overwritten by update_config().
_C.LOCAL_RANK = 0

# Switched to True by the `fused_window_process` / `fused_layernorm`
# command-line flags.
_C.FUSED_WINDOW_PROCESS = False
_C.FUSED_LAYERNORM = False
|
|
|
|
|
|
|
|
def _update_config_from_file(config, cfg_file):
    """Merge ``cfg_file`` and, recursively, its ``BASE`` files into ``config``.

    Base files are merged first so that values in ``cfg_file`` override them.
    Each non-empty ``BASE`` entry is resolved relative to the directory that
    contains ``cfg_file``. ``config`` is frozen when this function returns.

    Args:
        config: yacs ``CfgNode`` to update in place.
        cfg_file: path of the YAML config file to merge.
    """
    with open(cfg_file, 'r') as f:
        # Only the BASE list is needed here; merge_from_file() below re-parses
        # the file itself (yacs uses yaml.safe_load for that), so safe_load is
        # sufficient. An empty file parses to None, hence the `or {}`.
        yaml_cfg = yaml.safe_load(f) or {}

    base_dir = os.path.dirname(cfg_file)
    for cfg in yaml_cfg.get('BASE') or []:
        if cfg:
            _update_config_from_file(config, os.path.join(base_dir, cfg))

    print('=> merge config from {}'.format(cfg_file))
    # Defrost immediately before merging: the recursive calls above leave the
    # config frozen.
    config.defrost()
    config.merge_from_file(cfg_file)
    config.freeze()
|
|
|
|
|
|
|
|
def update_config(config, args):
    """Overlay the YAML file and command-line arguments onto ``config``.

    Precedence, lowest to highest: the values already in ``config``, the YAML
    file ``args.cfg`` (with its ``BASE`` files), the ``args.opts`` key/value
    list, and finally the individual command-line arguments handled below.
    ``config`` is modified in place and frozen on return.

    Args:
        config: yacs ``CfgNode`` to update.
        args: parsed command-line namespace. ``cfg`` and ``opts`` are required
            attributes; every other attribute is optional and is ignored when
            missing or falsy.

    Raises:
        KeyError: on PyTorch >= 2 when the ``LOCAL_RANK`` environment variable
            is not set.
        AttributeError: on PyTorch 1.x when ``args`` has no ``local_rank``.
    """
    _update_config_from_file(config, args.cfg)

    config.defrost()
    if args.opts:
        config.merge_from_list(args.opts)

    def _check_args(name):
        """Return True when ``args`` has a truthy attribute called ``name``."""
        # getattr replaces the former eval(f'args.{name}'): same result, no
        # dynamic code execution.
        return bool(getattr(args, name, None))

    # merge from specific arguments
    if _check_args('batch_size'):
        config.DATA.BATCH_SIZE = args.batch_size
    if _check_args('data_path'):
        config.DATA.DATA_PATH = args.data_path
    if _check_args('zip'):
        config.DATA.ZIP_MODE = True
    if _check_args('cache_mode'):
        config.DATA.CACHE_MODE = args.cache_mode
    if _check_args('pretrained'):
        config.MODEL.PRETRAINED = args.pretrained
    if _check_args('resume'):
        config.MODEL.RESUME = args.resume
    if _check_args('accumulation_steps'):
        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
    if _check_args('use_checkpoint'):
        config.TRAIN.USE_CHECKPOINT = True
    if _check_args('amp_opt_level'):
        # The Apex opt level itself is not stored; only 'O0' still has an
        # effect, by turning AMP off.
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")
        if args.amp_opt_level == 'O0':
            config.AMP_ENABLE = False
    if _check_args('disable_amp'):
        config.AMP_ENABLE = False
    if _check_args('output'):
        config.OUTPUT = args.output
    if _check_args('tag'):
        config.TAG = args.tag
    if _check_args('eval'):
        config.EVAL_MODE = True
    if _check_args('throughput'):
        config.THROUGHPUT_MODE = True

    # Separate flag from AMP_ENABLE above; only a truthy value is copied.
    if _check_args('enable_amp'):
        config.ENABLE_AMP = args.enable_amp

    # for acceleration
    if _check_args('fused_window_process'):
        config.FUSED_WINDOW_PROCESS = True
    if _check_args('fused_layernorm'):
        config.FUSED_LAYERNORM = True

    # overwrite the optimizer name when one is given on the command line
    if _check_args('optim'):
        config.TRAIN.OPTIMIZER.NAME = args.optim

    # Local rank: from the command-line argument on PyTorch 1.x, otherwise
    # from the LOCAL_RANK environment variable.
    if PYTORCH_MAJOR_VERSION == 1:
        config.LOCAL_RANK = args.local_rank
    else:
        config.LOCAL_RANK = int(os.environ['LOCAL_RANK'])

    # output folder: <OUTPUT>/<MODEL.NAME>/<TAG>
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()
|
|
|
|
|
|
|
|
def get_config(args):
    """Build the run configuration for ``args``.

    Starts from a clone of the module-level defaults and applies the YAML file
    and command-line overrides via ``update_config``.

    Args:
        args: parsed command-line namespace, passed through to
            ``update_config``.

    Returns:
        The updated yacs ``CfgNode`` clone.
    """
    # Clone so that the shared defaults in _C are never altered by a caller.
    config = _C.clone()
    update_config(config, args)

    return config
|
|
|