| import torch |
| import gradio as gr |
| from torchvision.utils import save_image |
| from torchvision.transforms import ToPILImage |
| from model import Generator |
|
|
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Length of the latent vector passed to the generator.
# NOTE(review): presumably must equal the latent size used at training time — confirm.
NOISE_DIM = 256

# Build the generator on DEVICE and restore its trained weights.
# map_location remaps the checkpoint's tensors onto DEVICE, so a checkpoint
# saved from a GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; consider weights_only=True
# (torch >= 1.13) if generator.pth could ever come from an untrusted source.
G = Generator()
G = G.to(DEVICE)
G.load_state_dict(torch.load("generator.pth", map_location=DEVICE))
# Put the model in evaluation mode for inference.
G.eval()

# Shared tensor -> PIL image converter used when returning results.
to_pil = ToPILImage()
|
|
def generate_image():
    """Sample one image from the generator and return it as a PIL image.

    A fresh latent vector is drawn from a standard normal distribution on
    every call, so each invocation yields a different image.

    Returns:
        PIL.Image.Image: the generated image.
    """
    noise = torch.randn(1, NOISE_DIM).to(DEVICE)

    with torch.no_grad():
        image = G(noise)

    # Rescale from [-1, 1] to [0, 1].
    # NOTE(review): assumes the generator's output lies in [-1, 1]
    # (e.g. a tanh head) — confirm against model.Generator.
    image = (image + 1) / 2
    # ToPILImage multiplies float tensors by 255 and casts to uint8; values
    # outside [0, 1] would wrap around rather than saturate, so clamp first.
    image = image.clamp(0, 1)
    # Drop the leading batch dimension (batch size is 1 above) and hand a CPU
    # tensor to the PIL converter.
    return to_pil(image.squeeze(0).cpu())
|
|
# Gradio UI with no input components: each submit calls generate_image()
# and displays the returned image.
demo = gr.Interface(
    fn=generate_image,
    inputs=None,
    outputs=gr.Image(),
    title="GAN Image Generator API",
)

# Start the server only when this file is executed as a script, so importing
# the module (for tests or from another app) does not launch and block.
if __name__ == "__main__":
    demo.launch()