# MNIST conditional GAN (cGAN)
Class-conditional synthesis of 28×28 grayscale MNIST-style digits. The generator maps noise z and digit label y to an image; the discriminator uses a projection discriminator (Miyato & Koyama, ICLR 2018) with spectral normalization.
## Files in this model repo
| File | Description |
|---|---|
| `mnist_cgan_generator.pth` | Generator state_dict for inference (matches `submission_digit_cgan.py`). |
| `training_<source>.pt` | Original checkpoint file (full training state when applicable). |
| `cgan_architecture.py` | Copy of `digit_cgan/model.py` (Generator + Discriminator definitions). |
| `generator_config.json` | Inferred constructor kwargs and metadata. |
## Weights
- Checkpoint: generator-only export (epoch not recorded in the file).
- Inferred architecture (from tensor shapes): `latent_dim=100`, `embed_dim=100`, `base_channels=384`, `num_classes=10`.
- Output shape: `(B, 1, 28, 28)`, values in `[-1, 1]` (tanh).
- Source file: `mnist_cgan_generator.pth`.
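The hyperparameters above were read off the tensor shapes stored in the checkpoint. A generic way to do that inspection on any state_dict (the layer names below are illustrative stand-ins, not the real names from `digit_cgan/model.py`):

```python
import torch
import torch.nn as nn

def dump_state_dict_shapes(state_dict):
    """Print each tensor's name and shape so hyperparameters can be read off."""
    for name, tensor in state_dict.items():
        print(f"{name}: {tuple(tensor.shape)}")

# Illustrative stand-in for a cGAN generator's input stage:
# num_classes and embed_dim appear directly in the embedding's weight shape,
# latent_dim + embed_dim and base_channels in the first linear layer's shape.
toy = nn.Sequential()
toy.add_module("label_embed", nn.Embedding(10, 100))  # (num_classes, embed_dim)
toy.add_module("fc", nn.Linear(200, 384 * 7 * 7))     # in = latent_dim + embed_dim

dump_state_dict_shapes(toy.state_dict())
```

For the real checkpoint, pass `torch.load("mnist_cgan_generator.pth", map_location="cpu", weights_only=True)` to the same helper.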
## Load the generator (example)
```python
import sys
import torch
from huggingface_hub import hf_hub_download

sys.path.insert(0, "/path/to/week-06")
from digit_cgan.model import Generator

repo_id = "<YOUR_REPO_ID>"
weights = hf_hub_download(repo_id, "mnist_cgan_generator.pth")

G = Generator(
    latent_dim=100,
    embed_dim=100,
    base_channels=384,
    num_classes=10,
)
G.load_state_dict(torch.load(weights, map_location="cpu", weights_only=True))
G.eval()

with torch.no_grad():
    z = torch.randn(4, 100)
    y = torch.tensor([0, 1, 2, 3])
    fake = G(z, y)  # (4, 1, 28, 28), values in [-1, 1]
```
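Because the generator ends in tanh, its samples live in `[-1, 1]` and need rescaling before viewing or saving. A minimal sketch of that conversion, shown on a dummy tensor so it is independent of the model weights:

```python
import torch

def to_uint8(images):
    """Map tanh outputs in [-1, 1] to uint8 pixel values in [0, 255]."""
    return ((images.clamp(-1, 1) + 1) * 127.5).round().to(torch.uint8)

# Dummy stand-in for `fake` from the loading example above.
fake = torch.tanh(torch.randn(4, 1, 28, 28))
pixels = to_uint8(fake)
print(pixels.shape, pixels.dtype)  # torch.Size([4, 1, 28, 28]) torch.uint8
```

The resulting `(B, 1, 28, 28)` uint8 tensor can be passed directly to e.g. `torchvision.utils.save_image` (after converting back to float) or `PIL.Image.fromarray` per sample.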
## Architecture (`cgan_architecture.py`)
- Generator: class embedding concatenated with z, linear reshape to 7×7 features, two ConvTranspose2d stages to 28×28, conv to 1 channel + tanh.
- Discriminator: convolutional backbone with spectral norm, global pool, linear map to a feature vector; score is unconditional linear term plus inner product between features and a class embedding (projection term).
See T. Miyato & M. Koyama, cGANs with Projection Discriminator, ICLR 2018.
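The projection scoring described above can be sketched in a few lines. This is a hedged illustration of the mechanism only, not the repo's actual `Discriminator`; `feature_dim=128` and the layer names are assumptions:

```python
import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    """Projection-discriminator scoring: an unconditional linear term plus
    the inner product between pooled features and a class embedding."""
    def __init__(self, feature_dim=128, num_classes=10):
        super().__init__()
        self.linear = nn.Linear(feature_dim, 1)              # unconditional score
        self.embed = nn.Embedding(num_classes, feature_dim)  # class projection

    def forward(self, features, y):
        # features: (B, feature_dim) from the conv backbone after global pooling
        uncond = self.linear(features)                              # (B, 1)
        proj = (self.embed(y) * features).sum(dim=1, keepdim=True)  # (B, 1)
        return uncond + proj

head = ProjectionHead()
scores = head(torch.randn(4, 128), torch.tensor([0, 1, 2, 3]))
print(scores.shape)  # torch.Size([4, 1])
```

The projection term is what makes the score class-conditional: instead of concatenating a one-hot label, the label enters only through the inner product with the feature vector.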
## Training (typical)
Run `python -m digit_cgan.train`. Training uses the hinge loss, Adam, and an
optional EMA on the generator for sampling; the best-FID checkpoints use the
EMA weights in `best_generator.pth`.

CLI defaults in `train.py` include `latent_dim=100` and `embed_dim=100`;
`base_channels_g` / `base_channels_d` / `feature_dim` may differ per run, so
always use `generator_config.json` or infer the values from the weights as above.
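The hinge loss mentioned above, as commonly paired with spectral normalization, can be written as follows (a generic sketch, not the exact code in `train.py`):

```python
import torch
import torch.nn.functional as F

def d_hinge_loss(d_real, d_fake):
    """Discriminator hinge loss: push real scores above +1, fake scores below -1."""
    return F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()

def g_hinge_loss(d_fake):
    """Generator hinge loss: raise the discriminator's score on fakes."""
    return -d_fake.mean()

# Dummy scores standing in for discriminator outputs on a batch of 2.
d_real = torch.tensor([1.5, 0.5])
d_fake = torch.tensor([-1.5, -0.5])
print(d_hinge_loss(d_real, d_fake).item())  # 0.5: only in-margin scores contribute
print(g_hinge_loss(d_fake).item())          # 1.0
```

Scores already past the margin (1.5 and -1.5 above) contribute zero gradient, which is the point of the hinge formulation.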
## Limitations
MNIST is a simple benchmark; generalization to out-of-distribution digit styles is not guaranteed.
## References
- Takeru Miyato and Masanori Koyama. *cGANs with Projection Discriminator*. ICLR 2018. https://arxiv.org/abs/1802.05637