unet_code / src /train.py
weatherforecast1024's picture
Upload folder using huggingface_hub
f3b050a verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import os
from module import WeatherForecastModule
from datamodule import WeatherForecastDataModule
from pytorch_lightning.cli import LightningCLI
from arch import Unet, R2Unet, AttUnet,AttR2Unet
# Weights & Biases configuration.
# NOTE(security): a W&B API key used to be hard-coded on this line. Credentials
# must never be committed to source control -- that key should be revoked and
# rotated. Provide the key through the environment instead, e.g.
#     export WANDB_API_KEY=<your key>      (or run `wandb login` once)
# The entity is only defaulted, so a WANDB_ENTITY already set by the caller
# is respected rather than overwritten.
os.environ.setdefault("WANDB_ENTITY", "weatherforecast1024")
def main():
    """Train the weather-forecast model, then evaluate it on the test split.

    The model, datamodule and trainer are built from command-line arguments /
    config files by ``LightningCLI``. ``run=False`` stops the CLI from
    dispatching a subcommand on its own, so fitting and testing are driven
    explicitly below. The global seed defaults to 42 unless overridden on the
    command line.
    """
    cli = LightningCLI(
        model_class=WeatherForecastModule,
        datamodule_class=WeatherForecastDataModule,
        seed_everything_default=42,
        run=False,
    )
    model = cli.model
    hparams = cli.datamodule.hparams

    # Pass data-dependent settings from the datamodule's hyper-parameters to
    # the model before training. The setters are defined in module.py and are
    # called in the original order -- NOTE(review): confirm whether later
    # setters (e.g. set_normalize/set_denormalize) rely on state set by the
    # earlier ones before reordering anything here.
    model.set_path(hparams.dir_data)
    model.set_size(hparams.rad_size, hparams.sat_size)
    model.set_lat()
    model.set_clim()
    model.set_normalize()
    model.set_denormalize()

    cli.trainer.fit(model=model, datamodule=cli.datamodule)
    # ckpt_path="best" evaluates the best checkpoint recorded by the trainer's
    # checkpoint callback; pass an explicit .ckpt file path instead to
    # evaluate a specific saved run.
    cli.trainer.test(model, datamodule=cli.datamodule, ckpt_path="best")


if __name__ == "__main__":
    main()