File size: 51,624 Bytes
7667a87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 |
from omegaconf import OmegaConf
import os
from shutil import copyfile
import warnings
from typing import Dict,Sequence,Union
import inspect
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
import torchmetrics
from lightning.pytorch import Trainer, loggers as pl_loggers
from lightning.pytorch.profilers import PyTorchProfiler
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.callbacks import (
Callback, LearningRateMonitor, DeviceStatsMonitor,
EarlyStopping, ModelCheckpoint
)
from lightning.pytorch.utilities import grad_norm
from einops import rearrange
from models.vae import AutoencoderKL
from models.knowledge_alignment import SEVIRAvgIntensityAlignment,get_alignment_kwargs_avg_x
from models.diffusion import LatentDiffusion
from models.core_model.cuboid_transformer import CuboidTransformerUNet
from datamodule import SEVIRLightningDataModule,vis_sevir_seq
from utils.path import (
default_exps_dir,
default_pretrained_vae_dir,default_pretrained_alignment_dir
)
from utils.optim import disable_train,warmup_lambda
from utils.layout import step_layout_to_in_out_slice
from evaluation import FrechetVideoDistance,SEVIRSkillScore
class PreDiffSEVIRPLModule(LatentDiffusion):
def __init__(self,
             total_num_steps: int,
             oc_file: str = None,
             save_dir: str = None):
    r"""
    Build the latent denoiser, the first-stage VAE and the optional knowledge
    alignment network from an OmegaConf config, initialise the parent
    `LatentDiffusion`, create the evaluation metrics and prepare the save
    directories.

    Parameters
    ----------
    total_num_steps: int
        Total number of optimisation steps. Read by `configure_optimizers`
        (warmup / cosine lengths) and `set_trainer_kwargs` (logging interval).
    oc_file: str
        Optional path to a YAML config. Its content is merged on top of the
        defaults returned by `get_base_config`, and the file is copied into
        the save directory by `configure_save`.
    save_dir: str
        Experiment directory, joined onto `default_exps_dir` by `configure_save`.
    """
    self.total_num_steps = total_num_steps
    if oc_file is not None:
        # Context manager so the handle is closed once the config is parsed.
        with open(oc_file, "r") as cfg_stream:
            oc_from_file = OmegaConf.load(cfg_stream)
    else:
        oc_from_file = None
    oc = self.get_base_config(oc_from_file=oc_from_file)
    self.save_hyperparameters(oc)
    self.oc = oc

    # ---- latent denoising network ----
    latent_model_cfg = OmegaConf.to_object(oc.model.latent_model)
    num_blocks = len(latent_model_cfg["depth"])
    self_pattern = latent_model_cfg["self_pattern"]
    if isinstance(self_pattern, str):
        block_attn_patterns = [self_pattern] * num_blocks
    else:
        # `to_object` already produced plain Python containers, and
        # `OmegaConf.to_container` rejects anything that is not a config node.
        block_attn_patterns = list(self_pattern)
    latent_passthrough_keys = (
        "input_shape", "target_shape", "base_units", "scale_alpha",
        "num_heads", "attn_drop", "proj_drop", "ffn_drop",
        # inter-attn downsample/upsample
        "downsample", "downsample_type", "upsample_type", "upsample_kernel_size",
        # attention
        "depth",
        # global vectors
        "num_global_vectors", "use_global_vector_ffn", "use_global_self_attn",
        "separate_global_qkv", "global_dim_ratio",
        # misc
        "ffn_activation", "gated_ffn", "norm_layer", "padding_type",
        "checkpoint_level", "pos_embed_type", "use_relative_pos",
        "self_attn_use_final_proj",
        # initialization
        "attn_linear_init_mode", "ffn_linear_init_mode", "ffn2_linear_init_mode",
        "attn_proj_linear_init_mode", "conv_init_mode",
        "global_proj_linear_init_mode", "norm_init_mode",
        # timestep embedding for diffusion
        "time_embed_channels_mult", "time_embed_use_scale_shift_norm",
        "time_embed_dropout", "unet_res_connect",
    )
    latent_model = CuboidTransformerUNet(
        block_attn_patterns=block_attn_patterns,
        # a single config entry drives both the down- and up-sampling init mode
        down_linear_init_mode=latent_model_cfg["down_up_linear_init_mode"],
        up_linear_init_mode=latent_model_cfg["down_up_linear_init_mode"],
        **{key: latent_model_cfg[key] for key in latent_passthrough_keys}
    )

    # ---- first-stage VAE ----
    vae_cfg = OmegaConf.to_object(oc.model.vae)
    vae_passthrough_keys = (
        "down_block_types", "in_channels", "block_out_channels", "act_fn",
        "latent_channels", "up_block_types", "norm_num_groups",
        "layers_per_block", "out_channels",
    )
    first_stage_model = AutoencoderKL(**{key: vae_cfg[key] for key in vae_passthrough_keys})
    # `.get` because the built-in defaults may not define this key at all;
    # a missing key is treated the same as an explicit `None`.
    pretrained_ckpt_path = vae_cfg.get("pretrained_ckpt_path")
    if pretrained_ckpt_path is not None:
        state_dict = torch.load(os.path.join(default_pretrained_vae_dir, pretrained_ckpt_path),
                                map_location=torch.device("cpu"))
        first_stage_model.load_state_dict(state_dict=state_dict)
    else:
        warnings.warn("Pretrained weights for `AutoencoderKL` not set. Run for sanity check only.")

    # ---- parent LatentDiffusion ----
    diffusion_cfg = OmegaConf.to_object(oc.model.diffusion)
    diffusion_passthrough_keys = (
        "data_shape", "timesteps", "beta_schedule", "use_ema", "log_every_t",
        "clip_denoised", "linear_start", "linear_end", "cosine_s", "given_betas",
        "original_elbo_weight", "v_posterior", "l_simple_weight",
        "parameterization", "learn_logvar", "logvar_init",
        # latent diffusion
        "latent_shape", "cond_stage_model", "num_timesteps_cond",
        "cond_stage_trainable", "cond_stage_forward", "scale_by_std",
        "scale_factor",
    )
    super(PreDiffSEVIRPLModule, self).__init__(
        torch_nn_module=latent_model,
        layout=oc.layout.layout,
        loss_type=self.oc.optim.loss_type,
        monitor=self.oc.optim.monitor,
        first_stage_model=first_stage_model,
        **{key: diffusion_cfg[key] for key in diffusion_passthrough_keys}
    )

    # ---- knowledge alignment ----
    knowledge_alignment_cfg = OmegaConf.to_object(oc.model.align)
    self.alignment_type = knowledge_alignment_cfg["alignment_type"]
    self.use_alignment = self.alignment_type is not None
    if self.use_alignment:
        alignment_ckpt_path = os.path.join(default_pretrained_alignment_dir,
                                           knowledge_alignment_cfg["model_ckpt_path"])
        self.alignment_obj = SEVIRAvgIntensityAlignment(
            alignment_type=knowledge_alignment_cfg["alignment_type"],
            guide_scale=knowledge_alignment_cfg["guide_scale"],
            model_type=knowledge_alignment_cfg["model_type"],
            model_args=knowledge_alignment_cfg["model_args"],
            model_ckpt_path=alignment_ckpt_path
        )
        disable_train(self.alignment_obj.model)
        self.alignment_model = self.alignment_obj.model
        alignment_fn = self.alignment_obj.get_mean_shift
    else:
        alignment_fn = None
    self.set_alignment(alignment_fn=alignment_fn)

    # lr_scheduler
    # Re-assigned after the parent constructor, as the original code did.
    self.total_num_steps = total_num_steps
    # logging
    self.save_dir = save_dir
    self.logging_prefix = oc.logging.logging_prefix
    # visualization
    self.train_example_data_idx_list = list(oc.eval.train_example_data_idx_list)
    self.val_example_data_idx_list = list(oc.eval.val_example_data_idx_list)
    self.test_example_data_idx_list = list(oc.eval.test_example_data_idx_list)
    self.eval_example_only = oc.eval.eval_example_only

    # ---- evaluation metrics ----
    # NOTE(review): `self.layout` is presumably set by the parent constructor
    # from the `layout` argument above - confirm in `LatentDiffusion`.
    def make_skill_score():
        return SEVIRSkillScore(
            mode=self.oc.dataset.metrics_mode,
            seq_len=self.oc.layout.out_len,
            layout=self.layout,
            threshold_list=self.oc.dataset.threshold_list,
            metrics_list=self.oc.dataset.metrics_list,
            eps=1e-4
        )

    def make_fvd():
        return FrechetVideoDistance(
            feature=self.oc.eval.fvd_features,
            layout=self.layout,
            reset_real_features=False,
            normalize=False,
            auto_t=True
        )

    if self.oc.eval.eval_unaligned:
        self.valid_mse = torchmetrics.MeanSquaredError()
        self.valid_mae = torchmetrics.MeanAbsoluteError()
        self.valid_score = make_skill_score()
        self.test_mse = torchmetrics.MeanSquaredError()
        self.test_mae = torchmetrics.MeanAbsoluteError()
        self.test_ssim = torchmetrics.image.StructuralSimilarityIndexMeasure()
        self.test_score = make_skill_score()
        self.test_fvd = make_fvd()
    if self.oc.eval.eval_aligned:
        self.valid_aligned_mse = torchmetrics.MeanSquaredError()
        self.valid_aligned_mae = torchmetrics.MeanAbsoluteError()
        self.valid_aligned_score = make_skill_score()
        self.test_aligned_mse = torchmetrics.MeanSquaredError()
        self.test_aligned_mae = torchmetrics.MeanAbsoluteError()
        self.test_aligned_ssim = torchmetrics.image.StructuralSimilarityIndexMeasure()
        self.test_aligned_score = make_skill_score()
        self.test_aligned_fvd = make_fvd()

    self.configure_save(cfg_file_path=oc_file)
def configure_save(self, cfg_file_path=None):
    r"""
    Resolve the experiment directory under `default_exps_dir`, create it
    together with its `examples` and `npy` sub-directories, and optionally
    copy the config file into it as `cfg.yaml`.

    Parameters
    ----------
    cfg_file_path: str
        Path of the config file to copy. Nothing is copied when `None`, or
        when the target already exists and is the very same file.
    """
    exp_dir = os.path.join(default_exps_dir, self.save_dir)
    self.save_dir = exp_dir
    os.makedirs(exp_dir, exist_ok=True)
    if cfg_file_path is not None:
        target_path = os.path.join(exp_dir, "cfg.yaml")
        # `samefile` is only consulted when the target exists (short-circuit).
        already_in_place = os.path.exists(target_path) and \
            os.path.samefile(cfg_file_path, target_path)
        if not already_in_place:
            copyfile(cfg_file_path, target_path)
    for attr_name, sub_dir in (("example_save_dir", "examples"),
                               ("npy_save_dir", "npy")):
        sub_path = os.path.join(exp_dir, sub_dir)
        setattr(self, attr_name, sub_path)
        os.makedirs(sub_path, exist_ok=True)
# region Get Default Config
def get_base_config(self, oc_from_file=None):
    r"""
    Assemble the default config from the `get_*_config` builders and, when
    given, merge `oc_from_file` on top of it with `OmegaConf.merge`.

    Parameters
    ----------
    oc_from_file: DictConfig
        Optional overrides, e.g. loaded from a YAML file.

    Returns
    -------
    oc: DictConfig
        Config with the sections `layout`, `optim`, `logging`, `trainer`,
        `eval`, `model` and `dataset`.
    """
    section_builders = (
        ("layout", self.get_layout_config),
        ("optim", self.get_optim_config),
        ("logging", self.get_logging_config),
        ("trainer", self.get_trainer_config),
        ("eval", self.get_eval_config),
        ("model", self.get_model_config),
        ("dataset", self.get_dataset_config),
    )
    oc = OmegaConf.create()
    for section_name, build_section in section_builders:
        oc[section_name] = build_section()
    if oc_from_file is None:
        return oc
    return OmegaConf.merge(oc, oc_from_file)
@staticmethod
def get_layout_config():
    r"""
    Default layout section: input/output sequence lengths and steps, frame
    size, channel count and the axis layout string.
    """
    return OmegaConf.create({
        "in_len": 7,
        "out_len": 6,
        "in_step": 1,
        "out_step": 1,
        "in_out_diff": 1,
        "img_height": 128,
        "img_width": 128,
        "data_channels": 4,
        "layout": "NTHWC",
    })
@staticmethod
def get_model_config():
    r"""
    Default model section with four sub-sections:

    - `diffusion`: arguments forwarded to the parent diffusion constructor,
    - `align`: knowledge alignment settings (`alignment_type = None` disables it),
    - `latent_model`: arguments of the latent `CuboidTransformerUNet`,
    - `vae`: arguments of the first-stage `AutoencoderKL`.

    `vae.pretrained_ckpt_path` defaults to `None`; `__init__` reads this key
    and only warns when it is `None`, so the default config stays usable
    without a config file.
    """
    layout_cfg = PreDiffSEVIRPLModule.get_layout_config()
    data_channels = layout_cfg.data_channels
    cfg = OmegaConf.create()

    cfg.diffusion = OmegaConf.create({
        "data_shape": (layout_cfg.out_len,
                       layout_cfg.img_height,
                       layout_cfg.img_width,
                       data_channels),
        "timesteps": 1000,
        "beta_schedule": "linear",
        "use_ema": True,
        # log every `log_every_t` timesteps. Must be smaller than `timesteps`.
        "log_every_t": 100,
        "clip_denoised": False,
        "linear_start": 1e-4,
        "linear_end": 2e-2,
        "cosine_s": 8e-3,
        "given_betas": None,
        "original_elbo_weight": 0.,
        "v_posterior": 0.,
        "l_simple_weight": 1.,
        "parameterization": "eps",
        "learn_logvar": None,
        "logvar_init": 0.,
        # latent diffusion
        "latent_shape": [10, 16, 16, 4],
        "cond_stage_model": "__is_first_stage__",
        "num_timesteps_cond": None,
        "cond_stage_trainable": False,
        "cond_stage_forward": None,
        "scale_by_std": False,
        "scale_factor": 1.0,
        "latent_cond_shape": [10, 16, 16, 4],
    })

    # knowledge alignment
    cfg.align = OmegaConf.create({
        "alignment_type": None,
        "guide_scale": 1.0,
        "model_type": "cuboid",
        "model_ckpt_path": "tmp.pt",
        "model_args": {
            # Earthformer
            "input_shape": [6, 16, 16, 4],
            "out_channels": 2,
            "base_units": 16,
            "block_units": None,
            "scale_alpha": 1.0,
            "depth": [1, 1],
            "downsample": 2,
            "downsample_type": "patch_merge",
            "block_attn_patterns": "axial",
            "num_heads": 4,
            "attn_drop": 0.0,
            "proj_drop": 0.0,
            "ffn_drop": 0.0,
            "ffn_activation": "gelu",
            "gated_ffn": False,
            "norm_layer": "layer_norm",
            "use_inter_ffn": True,
            "hierarchical_pos_embed": False,
            "pos_embed_type": 't+h+w',
            "padding_type": "zero",
            "checkpoint_level": 0,
            "use_relative_pos": True,
            "self_attn_use_final_proj": True,
            # global vectors
            "num_global_vectors": 0,
            "use_global_vector_ffn": True,
            "use_global_self_attn": False,
            "separate_global_qkv": False,
            "global_dim_ratio": 1,
            # initialization
            "attn_linear_init_mode": "0",
            "ffn_linear_init_mode": "0",
            "ffn2_linear_init_mode": "2",
            "attn_proj_linear_init_mode": "2",
            "conv_init_mode": "0",
            "down_linear_init_mode": "0",
            "global_proj_linear_init_mode": "2",
            "norm_init_mode": "0",
            # timestep embedding for diffusion
            "time_embed_channels_mult": 4,
            "time_embed_use_scale_shift_norm": False,
            "time_embed_dropout": 0.0,
            # readout
            "pool": "attention",
            "readout_seq": True,
            "out_len": 6,
        },
    })

    cfg.latent_model = OmegaConf.create({
        "input_shape": [10, 16, 16, 4],
        "target_shape": [10, 16, 16, 4],
        "base_units": 4,
        # block_units = null
        "scale_alpha": 1.0,
        "num_heads": 4,
        "attn_drop": 0.1,
        "proj_drop": 0.1,
        "ffn_drop": 0.1,
        # inter-attn downsample/upsample
        "downsample": 2,
        "downsample_type": "patch_merge",
        "upsample_type": "upsample",
        "upsample_kernel_size": 3,
        # cuboid attention
        "depth": [1, 1],
        "self_pattern": "axial",
        # global vectors
        "num_global_vectors": 0,
        "use_dec_self_global": False,
        "dec_self_update_global": True,
        "use_dec_cross_global": False,
        "use_global_vector_ffn": False,
        "use_global_self_attn": True,
        "separate_global_qkv": True,
        "global_dim_ratio": 1,
        # misc
        "ffn_activation": "gelu",
        "gated_ffn": False,
        "norm_layer": "layer_norm",
        "padding_type": "zeros",
        "pos_embed_type": "t+h+w",
        "checkpoint_level": 0,
        "use_relative_pos": True,
        "self_attn_use_final_proj": True,
        # initialization
        "attn_linear_init_mode": "0",
        "ffn_linear_init_mode": "0",
        "ffn2_linear_init_mode": "2",
        "attn_proj_linear_init_mode": "2",
        "conv_init_mode": "0",
        "down_up_linear_init_mode": "0",
        "global_proj_linear_init_mode": "2",
        "norm_init_mode": "0",
        # timestep embedding for diffusion
        "time_embed_channels_mult": 4,
        "time_embed_use_scale_shift_norm": False,
        "time_embed_dropout": 0.0,
        "unet_res_connect": True,
    })

    cfg.vae = OmegaConf.create({
        "data_channels": data_channels,
        # from stable-diffusion-v1-5
        "down_block_types": ['DownEncoderBlock2D', 'DownEncoderBlock2D',
                             'DownEncoderBlock2D', 'DownEncoderBlock2D'],
        "in_channels": data_channels,
        "block_out_channels": [128, 256, 512, 512],
        "act_fn": 'silu',
        "latent_channels": 4,
        "up_block_types": ['UpDecoderBlock2D', 'UpDecoderBlock2D',
                           'UpDecoderBlock2D', 'UpDecoderBlock2D'],
        "norm_num_groups": 32,
        "layers_per_block": 2,
        "out_channels": data_channels,
        # `None` means "no pretrained VAE weights"; `__init__` then only warns.
        "pretrained_ckpt_path": None,
    })
    return cfg
@staticmethod
def get_dataset_config():
    r"""
    Default dataset section: dataset name, frame size, sequence lengths,
    sampling mode/stride, date splits, metric settings and augmentation mode.
    `stride` defaults to the value of `out_len`.
    """
    out_len = 6
    return OmegaConf.create({
        "dataset_name": "sevir_lr",
        "img_height": 128,
        "img_width": 128,
        "in_len": 7,
        "out_len": out_len,
        "seq_len": 13,
        "plot_stride": 1,
        "interval_real_time": 10,
        "sample_mode": "sequent",
        "stride": out_len,
        "layout": "NTHWC",
        "start_date": None,
        "train_val_split_date": (2019, 1, 1),
        "train_test_split_date": (2019, 6, 1),
        "end_date": None,
        "metrics_mode": "0",
        "metrics_list": ('csi', 'pod', 'sucr', 'bias'),
        "threshold_list": (16, 74, 133, 160, 181, 219),
        "aug_mode": "1",
    })
@staticmethod
def get_optim_config():
    r"""
    Default optimisation section: batch sizes, optimizer hyper-parameters,
    learning-rate schedule, and checkpoint / early-stopping monitoring.
    """
    return OmegaConf.create({
        "seed": None,
        "total_batch_size": 32,
        "micro_batch_size": 8,
        "float32_matmul_precision": "high",
        "method": "adamw",
        "lr": 1.0E-6,
        "wd": 1.0E-2,
        "betas": (0.9, 0.999),
        "gradient_clip_val": 1.0,
        "max_epochs": 50,
        "loss_type": "l2",
        # scheduler
        "warmup_percentage": 0.2,
        # `configure_optimizers` handles 'cosine' and 'none'
        "lr_scheduler_mode": "cosine",
        "min_lr_ratio": 1.0E-3,
        "warmup_min_lr_ratio": 0.0,
        # early stopping
        "monitor": "valid_loss_epoch",
        "early_stop": False,
        "early_stop_mode": "min",
        "early_stop_patience": 5,
        "save_top_k": 1,
    })
@staticmethod
def get_logging_config():
    r"""
    Default logging section.

    `logging_name` and `run_id` are read by `set_trainer_kwargs` when
    `use_wandb` is enabled, so they get explicit `None` defaults here instead
    of being missing keys.
    """
    cfg = OmegaConf.create()
    cfg.logging_prefix = "PreDiff"
    cfg.monitor_lr = True
    cfg.monitor_device = False
    cfg.track_grad_norm = -1
    cfg.use_wandb = False
    # forwarded to `WandbLogger` as `name` / `id`
    cfg.logging_name = None
    cfg.run_id = None
    cfg.profiler = None
    cfg.save_npy = False
    return cfg
@staticmethod
def get_trainer_config():
    r"""
    Default trainer section. Keys that match `Trainer.__init__` parameters
    are forwarded to the trainer by `set_trainer_kwargs`.
    """
    return OmegaConf.create({
        "check_val_every_n_epoch": 1,
        # `set_trainer_kwargs` logs every `log_step_ratio * total_num_steps`
        # steps (at least 1), i.e. every 0.1% of the total training steps.
        "log_step_ratio": 0.001,
        "precision": 32,
        "find_unused_parameters": True,
        "num_sanity_val_steps": 2,
    })
@staticmethod
def get_eval_config():
    r"""
    Default evaluation section: which data indices to visualise per split,
    whether aligned / unaligned sampling is evaluated, the number of samples
    drawn per context, and plotting / FVD options.
    """
    return OmegaConf.create({
        "train_example_data_idx_list": [0, ],
        "val_example_data_idx_list": [0, ],
        "test_example_data_idx_list": [0, ],
        "eval_example_only": False,
        "eval_aligned": True,
        "eval_unaligned": True,
        "num_samples_per_context": 1,
        "font_size": 20,
        "label_offset": (-0.5, 0.5),
        "label_avg_int": False,
        "fvd_features": 400,
    })
# endregion
# region Trainer and Optimizer Config
def configure_optimizers(self):
    r"""
    Create the optimizer and, unless `optim.lr_scheduler_mode == 'none'`,
    a per-step warmup + cosine learning-rate schedule.

    Returns
    -------
    out: Dict
        `{'optimizer': ...}` or `{'optimizer': ..., 'lr_scheduler': ...}`.

    Raises
    ------
    NotImplementedError
        If `optim.method` is not "adamw" or `optim.lr_scheduler_mode` is
        neither 'none' nor 'cosine'.
    """
    optim_cfg = self.oc.optim
    params = list(self.torch_nn_module.parameters())
    if self.cond_stage_trainable:
        print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
        params = params + list(self.cond_stage_model.parameters())
    if self.learn_logvar:
        print('Diffusion model optimizing logvar')
        params.append(self.logvar)

    if optim_cfg.method != "adamw":
        raise NotImplementedError(f"opimization method {optim_cfg.method} not supported.")
    optimizer = torch.optim.AdamW(
        params,
        lr=optim_cfg.lr,
        # plain tuple so the optimizer state does not hold a config node
        betas=tuple(optim_cfg.betas),
        # honour the configured weight decay; the default config value (1e-2)
        # equals AdamW's own default, so default runs are unaffected
        weight_decay=optim_cfg.wd)

    if optim_cfg.lr_scheduler_mode == 'none':
        return {'optimizer': optimizer}
    if optim_cfg.lr_scheduler_mode != 'cosine':
        raise NotImplementedError(
            f"lr_scheduler_mode {optim_cfg.lr_scheduler_mode} not supported.")

    warmup_iter = int(np.round(optim_cfg.warmup_percentage * self.total_num_steps))
    warmup_scheduler = LambdaLR(
        optimizer,
        lr_lambda=warmup_lambda(warmup_steps=warmup_iter,
                                min_lr_ratio=optim_cfg.warmup_min_lr_ratio))
    cosine_scheduler = CosineAnnealingLR(
        optimizer,
        # keep T_max positive even when warmup covers the whole run
        T_max=max(1, self.total_num_steps - warmup_iter),
        eta_min=optim_cfg.min_lr_ratio * optim_cfg.lr)
    lr_scheduler = SequentialLR(optimizer,
                                schedulers=[warmup_scheduler, cosine_scheduler],
                                milestones=[warmup_iter])
    lr_scheduler_config = {
        'scheduler': lr_scheduler,
        'interval': 'step',
        'frequency': 1,
    }
    return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler_config}
def set_trainer_kwargs(self, **kwargs):
    r"""
    Default kwargs used when initializing pl.Trainer

    Parameters
    ----------
    kwargs
        Extra `Trainer` arguments. `callbacks` (list of `Callback`) and
        `logger` (list of loggers) are extended with the defaults built here;
        the lists passed in are not modified. Every other entry overrides the
        defaults.

    Returns
    -------
    ret: Dict
        Keyword arguments for `Trainer`.

    Raises
    ------
    NotImplementedError
        If `logging.profiler` is neither `None` nor "pytorch".
    """
    profiler_name = self.oc.logging.profiler
    if profiler_name is None:
        profiler = None
    elif profiler_name == "pytorch":
        profiler = PyTorchProfiler(filename=f"{self.oc.logging.logging_prefix}_PyTorchProfiler.log")
    else:
        raise NotImplementedError(f"profiler {profiler_name} not supported.")

    checkpoint_callback = ModelCheckpoint(
        monitor=self.oc.optim.monitor,
        dirpath=os.path.join(self.save_dir, "checkpoints"),
        filename="{epoch:03d}_{val/loss:.4f}",
        auto_insert_metric_name=False,
        save_top_k=self.oc.optim.save_top_k,
        save_last=True,
        mode="min",
    )
    user_callbacks = kwargs.pop("callbacks", [])
    assert isinstance(user_callbacks, list)
    for ele in user_callbacks:
        assert isinstance(ele, Callback)
    # Build a new list so the caller's list is left untouched.
    callbacks = user_callbacks + [checkpoint_callback, ]
    if self.oc.logging.monitor_lr:
        callbacks.append(LearningRateMonitor(logging_interval='step'))
    if self.oc.logging.monitor_device:
        callbacks.append(DeviceStatsMonitor())
    if self.oc.optim.early_stop:
        callbacks.append(EarlyStopping(monitor=self.oc.optim.monitor,
                                       min_delta=0.0,
                                       patience=self.oc.optim.early_stop_patience,
                                       verbose=False,
                                       mode=self.oc.optim.early_stop_mode))

    # Copy for the same reason as `callbacks`.
    logger = list(kwargs.pop("logger", []))
    logger.append(pl_loggers.TensorBoardLogger(save_dir=self.save_dir))
    logger.append(pl_loggers.CSVLogger(save_dir=self.save_dir))
    if self.oc.logging.use_wandb:
        # `.get` because the default logging config may not define these
        # keys; `None` lets `WandbLogger` fall back to its own defaults.
        wandb_logger = pl_loggers.WandbLogger(
            name=self.oc.logging.get("logging_name", None),
            id=self.oc.logging.get("run_id", None),
            project=self.oc.logging.logging_prefix,
            save_dir=self.save_dir
        )
        logger.append(wandb_logger)

    log_every_n_steps = max(1, int(self.oc.trainer.log_step_ratio * self.total_num_steps))
    trainer_init_keys = inspect.signature(Trainer).parameters.keys()
    ret = dict(
        callbacks=callbacks,
        # log
        logger=logger,
        log_every_n_steps=log_every_n_steps,
        profiler=profiler,
        # save
        default_root_dir=self.save_dir,
        # ddp
        accelerator="gpu",
        strategy=DDPStrategy(find_unused_parameters=self.oc.trainer.find_unused_parameters),
        # optimization
        max_epochs=self.oc.optim.max_epochs,
        check_val_every_n_epoch=self.oc.trainer.check_val_every_n_epoch,
        gradient_clip_val=self.oc.optim.gradient_clip_val,
        # NVIDIA amp
        precision=self.oc.trainer.precision,
        # misc
        num_sanity_val_steps=self.oc.trainer.num_sanity_val_steps,
        inference_mode=False,
    )
    # Any `trainer` config entry that is a real `Trainer.__init__` parameter
    # overrides the defaults above; explicit kwargs override both.
    oc_trainer_kwargs = OmegaConf.to_object(self.oc.trainer)
    oc_trainer_kwargs = {key: val for key, val in oc_trainer_kwargs.items() if key in trainer_init_keys}
    ret.update(oc_trainer_kwargs)
    ret.update(kwargs)
    return ret
# endregion
# region Properties Extraction and Misc Calc
@classmethod
def get_total_num_steps(
cls,
num_samples: int,
total_batch_size: int,
epoch: int = None):
r"""
Parameters
----------
num_samples: int
The number of samples of the datasets. `num_samples / micro_batch_size` is the number of steps per epoch.
total_batch_size: int
`total_batch_size == micro_batch_size * world_size * grad_accum`
epoch: int
"""
if epoch is None:
epoch = cls.get_optim_config().max_epochs
return int(epoch * num_samples / total_batch_size)
@staticmethod
def get_sevir_datamodule(dataset_cfg,
                         micro_batch_size: int = 1,
                         num_workers: int = 4):
    r"""
    Build a `SEVIRLightningDataModule` from a dataset config.

    Parameters
    ----------
    dataset_cfg
        Mapping that provides the keys listed in `cfg_keys` below.
    micro_batch_size: int
        Passed to the datamodule as `batch_size`.
    num_workers: int
        Passed to the datamodule unchanged.

    Returns
    -------
    dm: SEVIRLightningDataModule
    """
    # NOTE(review): `val_ratio` has no default in `get_dataset_config`
    # (which defines `train_val_split_date` instead), so it presumably has to
    # come from the config file - confirm.
    cfg_keys = (
        "seq_len", "sample_mode", "stride", "layout", "aug_mode",
        # datamodule_only
        "dataset_name", "start_date", "train_test_split_date", "end_date",
        "val_ratio",
    )
    from_cfg = {key: dataset_cfg[key] for key in cfg_keys}
    return SEVIRLightningDataModule(
        batch_size=micro_batch_size,
        output_type=np.float32,
        preprocess=True,
        rescale_method="01",
        verbose=False,
        ret_contiguous=False,
        num_workers=num_workers,
        **from_cfg
    )
@property
def in_slice(self):
if not hasattr(self, "_in_slice"):
in_slice, out_slice = step_layout_to_in_out_slice(
layout=self.oc.layout.layout,
in_len=self.oc.layout.in_len, in_step= self.oc.layout.in_step,
out_len=self.oc.layout.out_len, out_step = self.oc.layout.out_step,
in_out_diff= self.oc.layout.in_out_diff
)
self._in_slice = in_slice
self._out_slice = out_slice
return self._in_slice
@property
def out_slice(self):
if not hasattr(self, "_out_slice"):
in_slice, out_slice = step_layout_to_in_out_slice(
layout=self.oc.layout.layout,
in_len=self.oc.layout.in_len, in_step= self.oc.layout.in_step,
out_len=self.oc.layout.out_len, out_step = self.oc.layout.out_step,
in_out_diff= self.oc.layout.in_out_diff
)
self._in_slice = in_slice
self._out_slice = out_slice
return self._out_slice
@torch.no_grad()
def get_input(self, batch, **kwargs):
r"""
dataset dependent
re-implement it for each specific dataset
Parameters
----------
batch: Any
raw data batch from specific dataloader
Returns
-------
out: Sequence[torch.Tensor, Dict[str, Any]]
out[0] should be a torch.Tensor which is the target to generate
out[1] should be a dict consists of several key-value pairs for conditioning
"""
return self._get_input_sevirlr(batch=batch, return_verbose=kwargs.get("return_verbose", False))
@torch.no_grad()
def _get_input_sevirlr(self, batch, return_verbose=False):
seq = batch
in_seq = seq[self.in_slice]
out_seq = seq[self.out_slice].contiguous()
if return_verbose:
return out_seq, {"y": in_seq}, in_seq
else:
return out_seq, {"y": in_seq}
# endregion
# region Operation Step
def training_step(self, batch, batch_idx):
    r"""
    Compute and log the training loss. On rank 0, in epochs that are a
    multiple of `trainer.check_val_every_n_epoch`, additionally draw samples
    for the batches listed in `train_example_data_idx_list` and hand them to
    `save_vis_step_end` for visualisation.

    Parameters
    ----------
    batch: torch.Tensor
        Sequence batch; its size along `self.batch_axis` is the micro batch size.
    batch_idx: int
        Index of the batch within the epoch.

    Returns
    -------
    loss
        The first element returned by `self(batch)`.
    """
    loss, loss_dict = self(batch)
    self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=False)
    micro_batch_size = batch.shape[self.batch_axis]
    data_idx = int(batch_idx * micro_batch_size)

    # Visualisation is optional; leave early whenever it does not apply.
    on_vis_epoch = self.current_epoch % self.oc.trainer.check_val_every_n_epoch == 0
    if not (on_vis_epoch and self.local_rank == 0):
        return loss
    if data_idx not in self.train_example_data_idx_list:
        return loss

    def to_numpy(tensor):
        return tensor.detach().float().cpu().numpy()

    target_seq, cond, context_seq = self.get_input(batch, return_verbose=True)
    aligned_samples, aligned_labels = [], []
    plain_samples, plain_labels = [], []
    for sample_idx in range(self.oc.eval.num_samples_per_context):
        # aligned sampling
        if self.use_alignment and self.oc.eval.eval_aligned:
            if self.alignment_type != "avg_x":
                raise NotImplementedError
            alignment_kwargs = get_alignment_kwargs_avg_x(context_seq=context_seq,
                                                          target_seq=target_seq)
            guided_seq = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                use_alignment=True,
                alignment_kwargs=alignment_kwargs,
                verbose=False).contiguous()
            # only the first element of the micro batch is visualised
            aligned_samples.append(to_numpy(guided_seq[0]))
            aligned_labels.append(f"{self.oc.logging.logging_prefix}_aligned_pred_{sample_idx}")
        # no alignment
        if self.oc.eval.eval_unaligned:
            plain_seq = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                verbose=False).contiguous()
            plain_samples.append(to_numpy(plain_seq[0]))
            plain_labels.append(f"{self.oc.logging.logging_prefix}_pred_{sample_idx}")

    # aligned samples come first, followed by the unaligned ones
    self.save_vis_step_end(
        data_idx=data_idx,
        context_seq=to_numpy(context_seq[0]),
        target_seq=to_numpy(target_seq[0]),
        pred_seq=aligned_samples + plain_samples,
        pred_label=aligned_labels + plain_labels,
        mode="train")
    return loss
def validation_step(self, batch, batch_idx):
    r"""
    Log validation losses (with and without EMA weights) and evaluate sampled forecasts.

    Parameters
    ----------
    batch
        Batch forwarded to `self(batch)` and `self.get_input`; the batch size is
        read from `batch.shape[self.batch_axis]`.
    batch_idx: int
        Index of the batch. `batch_idx * micro_batch_size` is used as the data index.

    Notes
    -----
    When `self.eval_example_only` is set, sampling and metric updates are limited
    to data indices listed in `self.val_example_data_idx_list`.
    """
    epoch_log_kwargs = dict(prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
    _, loss_dict_no_ema = self(batch)
    with self.ema_scope():
        _, ema_losses = self(batch)
    ema_losses = {name + '_ema': value for name, value in ema_losses.items()}
    self.log_dict(loss_dict_no_ema, **epoch_log_kwargs)
    self.log_dict(ema_losses, **epoch_log_kwargs)

    micro_batch_size = batch.shape[self.batch_axis]
    data_idx = int(batch_idx * micro_batch_size)
    if self.eval_example_only and data_idx not in self.val_example_data_idx_list:
        return

    def _first_as_numpy(seq):
        # Only the first sample of the batch is visualized.
        return seq[0].detach().float().cpu().numpy()

    def _as_float32(seq):
        # The metrics are fed float32 predictions.
        return seq if seq.dtype is torch.float else seq.float()

    target_seq, cond, context_seq = self.get_input(batch, return_verbose=True)
    aligned_samples = []  # (array, label) pairs sampled with alignment
    plain_samples = []  # (array, label) pairs sampled without alignment
    for sample_idx in range(self.oc.eval.num_samples_per_context):
        # aligned sampling
        if self.use_alignment and self.oc.eval.eval_aligned:
            if self.alignment_type != "avg_x":
                raise NotImplementedError
            alignment_kwargs = get_alignment_kwargs_avg_x(context_seq=context_seq,
                                                          target_seq=target_seq)
            aligned_pred = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                use_alignment=True,
                alignment_kwargs=alignment_kwargs,
                verbose=False, ).contiguous()
            aligned_samples.append((
                _first_as_numpy(aligned_pred),
                f"{self.oc.logging.logging_prefix}_aligned_pred_{sample_idx}"))
            aligned_pred = _as_float32(aligned_pred)
            self.valid_aligned_mse(aligned_pred, target_seq)
            self.valid_aligned_mae(aligned_pred, target_seq)
            self.valid_aligned_score.update(aligned_pred, target_seq)
        # no alignment
        if self.oc.eval.eval_unaligned:
            plain_pred = self.sample(
                cond=cond,
                batch_size=micro_batch_size,
                return_intermediates=False,
                verbose=False, ).contiguous()
            plain_samples.append((
                _first_as_numpy(plain_pred),
                f"{self.oc.logging.logging_prefix}_pred_{sample_idx}"))
            plain_pred = _as_float32(plain_pred)
            self.valid_mse(plain_pred, target_seq)
            self.valid_mae(plain_pred, target_seq)
            self.valid_score.update(plain_pred, target_seq)

    # Aligned samples are listed before the unaligned ones.
    ordered_samples = aligned_samples + plain_samples
    self.save_vis_step_end(
        data_idx=data_idx,
        context_seq=_first_as_numpy(context_seq),
        target_seq=_first_as_numpy(target_seq),
        pred_seq=[array for array, _ in ordered_samples],
        pred_label=[label for _, label in ordered_samples],
        mode="val",
        suffix=f"_rank{self.local_rank}", )
def on_validation_epoch_end(self):
    r"""
    Compute, log and reset the validation metrics accumulated during the epoch.

    The unaligned metrics are handled when `self.oc.eval.eval_unaligned` is set.
    The aligned metrics are handled only when both `self.use_alignment` and
    `self.oc.eval.eval_aligned` are set, which is the same condition under which
    `validation_step` updates them. Without the `use_alignment` check, metrics
    that were never updated would be computed and logged.

    The logged `*_loss_epoch` value is the negated average CSI, so that a lower
    value is better.
    """
    log_kwargs = dict(prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

    def _log_and_reset(mse_metric, mae_metric, score_metric, prefix):
        # Compute everything first, then log, then clear the state for the next epoch.
        mse_value = mse_metric.compute()
        mae_value = mae_metric.compute()
        score_dict = score_metric.compute()
        loss_value = -score_dict["avg"]["csi"]
        self.log(f"{prefix}_loss_epoch", loss_value, **log_kwargs)
        self.log(f"{prefix}_mse_epoch", mse_value, **log_kwargs)
        self.log(f"{prefix}_mae_epoch", mae_value, **log_kwargs)
        self.log_score_epoch_end(score_dict=score_dict, prefix=prefix)
        mse_metric.reset()
        mae_metric.reset()
        score_metric.reset()

    if self.oc.eval.eval_unaligned:
        _log_and_reset(self.valid_mse, self.valid_mae, self.valid_score, "valid")
    if self.use_alignment and self.oc.eval.eval_aligned:
        _log_and_reset(self.valid_aligned_mse, self.valid_aligned_mae,
                       self.valid_aligned_score, "valid_aligned")
def test_step(self, batch, batch_idx):
micro_batch_size = batch.shape[self.batch_axis]
data_idx = int(batch_idx * micro_batch_size)
if not self.eval_example_only or data_idx in self.val_example_data_idx_list:
target_seq, cond, context_seq = \
self.get_input(batch, return_verbose=True)
target_seq_bchw = rearrange(target_seq, "b t h w c -> (b t) c h w")
aligned_pred_seq_list = []
aligned_pred_label_list = []
pred_seq_list = []
pred_label_list = []
for i in range(self.oc.eval.num_samples_per_context):
# aligned sampling
if self.use_alignment and self.oc.eval.eval_aligned:
if self.alignment_type == "avg_x":
alignment_kwargs = get_alignment_kwargs_avg_x(context_seq=context_seq,
target_seq=target_seq)
else:
raise NotImplementedError
pred_seq = self.sample(
cond=cond,
batch_size=micro_batch_size,
return_intermediates=False,
use_alignment=True,
alignment_kwargs=alignment_kwargs,
verbose=False, ).contiguous()
if self.oc.logging.save_npy:
npy_path = os.path.join(self.npy_save_dir,
f"batch{batch_idx}_rank{self.local_rank}_sample{i}_aligned.npy")
np.save(npy_path, pred_seq.detach().float().cpu().numpy())
aligned_pred_seq_list.append(pred_seq[0].detach().float().cpu().numpy())
aligned_pred_label_list.append(f"{self.oc.logging.logging_prefix}_aligned_pred_{i}")
if pred_seq.dtype is not torch.float:
pred_seq = pred_seq.float()
self.test_aligned_mse(pred_seq, target_seq)
self.test_aligned_mae(pred_seq, target_seq)
self.test_aligned_score.update(pred_seq, target_seq)
# self.test_aligned_fvd.update(pred_seq, real=False)
pred_seq_bchw = rearrange(pred_seq, "b t h w c -> (b t) c h w")
self.test_aligned_ssim(pred_seq_bchw, target_seq_bchw)
# no alignment
if self.oc.eval.eval_unaligned:
pred_seq = self.sample(
cond=cond,
batch_size=micro_batch_size,
return_intermediates=False,
verbose=False, ).contiguous()
if self.oc.logging.save_npy:
npy_path = os.path.join(self.npy_save_dir,
f"batch{batch_idx}_rank{self.local_rank}_sample{i}.npy")
np.save(npy_path, pred_seq.detach().float().cpu().numpy())
pred_seq_list.append(pred_seq[0].detach().float().cpu().numpy())
pred_label_list.append(f"{self.oc.logging.logging_prefix}_pred_{i}")
if pred_seq.dtype is not torch.float:
pred_seq = pred_seq.float()
self.test_mse(pred_seq, target_seq)
self.test_mae(pred_seq, target_seq)
self.test_score.update(pred_seq, target_seq)
# self.test_fvd.update(pred_seq, real=False)
pred_seq_bchw = rearrange(pred_seq, "b t h w c -> (b t) c h w")
self.test_ssim(pred_seq_bchw, target_seq_bchw)
# if self.use_alignment and self.oc.eval.eval_aligned:
# self.test_aligned_fvd.update(target_seq, real=True)
# if self.oc.eval.eval_unaligned:
# self.test_fvd.update(target_seq, real=True)
pred_seq_list = aligned_pred_seq_list + pred_seq_list
pred_label_list = aligned_pred_label_list + pred_label_list
self.save_vis_step_end(
data_idx=data_idx,
context_seq=context_seq[0].detach().float().cpu().numpy(),
target_seq=target_seq[0].detach().float().cpu().numpy(),
pred_seq=pred_seq_list,
pred_label=pred_label_list,
mode="test",
suffix=f"_rank{self.local_rank}", )
def on_test_epoch_end(self):
    r"""
    Compute, log and reset the test metrics accumulated during the test run.

    The unaligned metrics are handled when `self.oc.eval.eval_unaligned` is set.
    The aligned metrics are handled only when both `self.use_alignment` and
    `self.oc.eval.eval_aligned` are set, which is the same condition under which
    `test_step` updates them. Without the `use_alignment` check, metrics that
    were never updated would be computed and logged.

    FVD evaluation is currently disabled.
    """
    log_kwargs = dict(prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

    def _log_and_reset(mse_metric, mae_metric, ssim_metric, score_metric, prefix):
        # Compute everything first, then log, then clear the accumulated state.
        mse_value = mse_metric.compute()
        mae_value = mae_metric.compute()
        ssim_value = ssim_metric.compute()
        score_dict = score_metric.compute()
        self.log(f"{prefix}_mse_epoch", mse_value, **log_kwargs)
        self.log(f"{prefix}_mae_epoch", mae_value, **log_kwargs)
        self.log(f"{prefix}_ssim_epoch", ssim_value, **log_kwargs)
        self.log_score_epoch_end(score_dict=score_dict, prefix=prefix)
        mse_metric.reset()
        mae_metric.reset()
        ssim_metric.reset()
        score_metric.reset()

    if self.oc.eval.eval_unaligned:
        _log_and_reset(self.test_mse, self.test_mae, self.test_ssim,
                       self.test_score, "test")
    if self.use_alignment and self.oc.eval.eval_aligned:
        _log_and_reset(self.test_aligned_mse, self.test_aligned_mae, self.test_aligned_ssim,
                       self.test_aligned_score, "test_aligned")
# endregion
def save_vis_step_end(
        self,
        data_idx: int,
        context_seq: np.ndarray,
        target_seq: np.ndarray,
        pred_seq: Union[np.ndarray, Sequence[np.ndarray]],
        pred_label: Union[str, Sequence[str]] = None,
        label_mode: str = "name",
        mode: str = "train",
        prefix: str = "",
        suffix: str = "", ):
    r"""
    Save a visualization of context, target and predicted sequences for an example.

    Nothing is saved unless `data_idx` is in the example index list of `mode`.

    Parameters
    ----------
    data_idx: int
        Index of the data sample; also part of the saved file name.
    context_seq, target_seq: np.ndarray
        layout should not include batch
    pred_seq: np.ndarray or Sequence[np.ndarray]
        A single prediction or a sequence of predictions; layout should not
        include batch.
    pred_label: str or Sequence[str]
        Label(s) of the prediction(s). Use a sequence of labels when `pred_seq`
        is a sequence, and a single string when it is a single array.
    label_mode: str
        "name" uses the labels as given; "avg_int" appends the mean value of
        each sequence to its label.
    mode: str
        One of "train", "val", "test". Selects the example index list and is
        part of the saved file name.
    prefix, suffix: str
        Prepended / appended to the saved file name (before the extension).

    Raises
    ------
    ValueError
        If `mode` is not one of "train", "val", "test".
    TypeError
        If `label_mode` is "avg_int" and `pred_label` is neither a str nor a sequence.
    NotImplementedError
        If `label_mode` is not supported.
    """
    if mode == "train":
        example_data_idx_list = self.train_example_data_idx_list
    elif mode == "val":
        example_data_idx_list = self.val_example_data_idx_list
    elif mode == "test":
        example_data_idx_list = self.test_example_data_idx_list
    else:
        raise ValueError(f"Wrong mode {mode}! Must be in ['train', 'val', 'test'].")
    if label_mode == "name":
        # use the given label
        context_label = "context"
        target_label = "target"
    elif label_mode == "avg_int":
        context_label = f"context\navg_int={np.mean(context_seq):.4f}"
        target_label = f"target\navg_int={np.mean(target_seq):.4f}"
        # `str` is itself a Sequence, so it has to be checked first; otherwise a
        # single label would be split into characters.
        if isinstance(pred_label, str):
            pred_label = f"{pred_label}\navg_int={np.mean(pred_seq):.4f}"
        elif isinstance(pred_label, Sequence):
            pred_label = [f"{label}\navg_int={np.mean(seq):.4f}" for label, seq in zip(pred_label, pred_seq)]
        else:
            raise TypeError(f"Wrong pred_label type {type(pred_label)}! must be in [str, Sequence[str]].")
    else:
        raise NotImplementedError
    if isinstance(pred_seq, Sequence):
        seq_list = [context_seq, target_seq] + list(pred_seq)
        # `list(...)` so that tuples (or other sequences) of labels are accepted too.
        label_list = [context_label, target_label] + list(pred_label)
    else:
        seq_list = [context_seq, target_seq, pred_seq]
        label_list = [context_label, target_label, pred_label]
    if data_idx in example_data_idx_list:
        png_save_name = f"{prefix}{mode}_epoch_{self.current_epoch}_data_{data_idx}{suffix}.png"
        vis_sevir_seq(
            save_path=os.path.join(self.example_save_dir, png_save_name),
            seq=seq_list,
            label=label_list,
            interval_real_time=10,
            plot_stride=1, fs=self.oc.eval.fs,
            label_offset=self.oc.eval.label_offset,
            label_avg_int=self.oc.eval.label_avg_int, )
def log_score_epoch_end(self, score_dict: Dict, prefix: str = "valid"):
    r"""
    Log the mean of every configured metric, per threshold and (if present) averaged.

    Parameters
    ----------
    score_dict: Dict
        Indexed as `score_dict[thresh][metrics]` for each threshold in
        `self.oc.dataset.threshold_list`; may hold an "avg" entry indexed as
        `score_dict["avg"][metrics]`.
    prefix: str
        Leading part of every logged name.
    """
    log_kwargs = dict(prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
    avg_scores = score_dict.get("avg", None)
    for metric_name in self.oc.dataset.metrics_list:
        for threshold in self.oc.dataset.threshold_list:
            threshold_mean = np.mean(score_dict[threshold][metric_name]).item()
            self.log(f"{prefix}_{metric_name}_{threshold}_epoch", threshold_mean, **log_kwargs)
        if avg_scores is None:
            continue
        avg_mean = np.mean(avg_scores[metric_name]).item()
        self.log(f"{prefix}_{metric_name}_avg_epoch", avg_mean, **log_kwargs)
def on_before_optimizer_step(self, optimizer):
    r"""
    Log gradient norms of `self.torch_nn_module` before the optimizer step.

    Logging is skipped when `self.oc.logging.track_grad_norm` is -1; otherwise
    that value is passed to `grad_norm` as `norm_type`.

    If using mixed precision, the gradients are already unscaled here.
    reference: https://lightning.ai/docs/pytorch/2.0.9/debug/debugging_intermediate.html#look-out-for-exploding-gradients
    """
    norm_type = self.oc.logging.track_grad_norm
    if norm_type == -1:
        return
    self.log_dict(grad_norm(self.torch_nn_module, norm_type=norm_type))
|