# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Training / evaluation entry point for the weather-forecast model.

The script builds a ``LightningCLI`` for ``WeatherForecastModule`` and
``WeatherForecastDataModule``. It passes data-module settings into the model,
fits the model, and then tests the best checkpoint.

Weights & Biases credentials are read from the environment. Set
``WANDB_API_KEY`` (or run ``wandb login``) before launching. ``WANDB_ENTITY``
defaults to ``weatherforecast1024`` unless it is already set.
"""
import torch
import os
from module import WeatherForecastModule
from datamodule import WeatherForecastDataModule
from pytorch_lightning.cli import LightningCLI
from arch import Unet, R2Unet, AttUnet, AttR2Unet

# The W&B API key is a secret and must not live in source control.
# Supply it through the WANDB_API_KEY environment variable or `wandb login`.
# NOTE(review): a key was previously committed in this file; it should be
# revoked/rotated in the W&B account settings.
# setdefault keeps any entity the user already exported instead of clobbering it.
os.environ.setdefault("WANDB_ENTITY", "weatherforecast1024")


def main():
    """Parse the CLI, configure the model from the data module, then fit and test.

    With ``run=False``, LightningCLI only parses arguments and instantiates
    the model, data module and trainer. It does not dispatch a subcommand,
    so ``fit``/``test`` are called explicitly below.
    """
    cli = LightningCLI(
        model_class=WeatherForecastModule,
        datamodule_class=WeatherForecastDataModule,
        seed_everything_default=42,
        run=False,
    )

    data_hparams = cli.datamodule.hparams
    model = cli.model

    # Copy data-location and input-size settings from the data module into the
    # model, then run the model's remaining set-up helpers.
    # NOTE(review): the argument-less set_* calls presumably depend on the path
    # and sizes set first, so this order is kept unchanged; confirm in module.py.
    model.set_path(data_hparams.dir_data)
    model.set_size(data_hparams.rad_size, data_hparams.sat_size)
    model.set_lat()
    model.set_clim()
    model.set_normalize()
    model.set_denormalize()

    cli.trainer.fit(model=model, datamodule=cli.datamodule)
    # Evaluate the best checkpoint recorded by the trainer's checkpoint callback.
    cli.trainer.test(model, datamodule=cli.datamodule, ckpt_path="best")
    # To evaluate a specific checkpoint instead, pass its path, e.g.:
    # cli.trainer.test(model, datamodule=cli.datamodule, ckpt_path="/data/AI102024/checkpoints/last_version_12.ckpt")


if __name__ == "__main__":
    main()