import sys

import yaml

from datamodule import WeatherForecastDataModuleOld
def describe_batch(batch):
    """Return a printable summary of the shape(s) contained in a dataloader batch.

    If ``batch`` has a ``shape`` attribute (e.g. a tensor or array), this is
    ``str(batch.shape)`` -- identical to what ``print(batch.shape)`` shows.
    Lists/tuples and dicts are summarised element-wise (recursively), so a
    batch like ``(inputs, targets)`` no longer crashes with ``AttributeError``.
    Anything else is summarised by its type name.

    Args:
        batch: One item yielded by a dataloader.

    Returns:
        A human-readable string describing the batch's shape structure.
    """
    shape = getattr(batch, "shape", None)
    if shape is not None:
        return str(shape)
    if isinstance(batch, (list, tuple)):
        return "[" + ", ".join(describe_batch(part) for part in batch) + "]"
    if isinstance(batch, dict):
        inner = ", ".join(f"{key!r}: {describe_batch(value)}" for key, value in batch.items())
        return "{" + inner + "}"
    return type(batch).__name__


if __name__ == "__main__":
    # Smoke test: build the data module from the YAML config and print the
    # shape of the first training batch.
    # The config path can be overridden as the first CLI argument; the default
    # keeps the original behaviour.
    config_path = sys.argv[1] if len(sys.argv) > 1 else "MambaUnet_servir.yaml"
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    dm = WeatherForecastDataModuleOld(**config["data"])
    dm.prepare_data()
    dm.setup(stage=None)

    train_loader = dm.train_dataloader()
    for item in train_loader:
        # NOTE(review): original author noted tensor batches are laid out as
        # N,C,T,H,W -- not verified here.
        print(describe_batch(item))
        break  # only the first batch is inspected
    else:
        # Loop body never ran: report it instead of silently printing nothing.
        print("train_dataloader() yielded no batches")