import datetime as dt
import random
from pathlib import Path
import os
import hashlib
import requests
import json

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm
import torchvision.transforms as T
from PIL import Image
from torchcam.methods import GradCAM, GradCAMpp
from torchcam.utils import overlay_mask
from torchvision.datasets import CIFAR10, MNIST, FashionMNIST

# Persist selected checkpoint across reruns
if "ckpt_path" not in st.session_state:
    st.session_state["ckpt_path"] = None


@st.cache_data(show_spinner=True)
def download_release_asset(url: str, dest_dir: str = "saved_checkpoints") -> str:
    """Download a remote checkpoint to dest_dir and return its local path.
    Cached so subsequent reruns won't redownload.
    """
    Path(dest_dir).mkdir(parents=True, exist_ok=True)
    url_hash = hashlib.sha256(url.encode("utf-8")).hexdigest()[:16]
    fname = Path(url).name or f"asset_{url_hash}.ckpt"
    if not fname.endswith(".ckpt"):
        fname = f"{fname}.ckpt"
    local_path = Path(dest_dir) / f"{url_hash}_{fname}"
    if local_path.exists() and local_path.stat().st_size > 0:
        return str(local_path)
    with requests.get(url, stream=True, timeout=120) as r:
        r.raise_for_status()
        with open(local_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
    return str(local_path)


def load_release_presets() -> dict:
    """Load release preset URLs from multiple sources.
    Order: Streamlit secrets β†’ .streamlit/presets.json β†’ presets.json β†’ env var RELEASE_CKPTS_JSON.
    Returns a dict name -> url. Safe if nothing is configured.
    """
    # 1) Streamlit secrets
    try:
        if hasattr(st, "secrets") and "release_checkpoints" in st.secrets:
            # Convert to plain dict in case it's a Secrets object
            return dict(st.secrets["release_checkpoints"])  # type: ignore[index]
    except Exception:
        pass

    # 2) Local JSON files for dev
    for rel in (".streamlit/presets.json", "presets.json"):
        p = Path(rel)
        if p.exists():
            try:
                with open(p, "r", encoding="utf-8") as f:
                    data = json.load(f)
                # Either the file is a mapping directly, or has a top-level key
                if isinstance(data, dict) and data:
                    if "release_checkpoints" in data and isinstance(data["release_checkpoints"], dict):
                        return dict(data["release_checkpoints"])  # nested
                    return dict(data)  # flat mapping
            except Exception:
                pass

    # 3) Environment variable containing JSON mapping
    env_json = os.environ.get("RELEASE_CKPTS_JSON", "").strip()
    if env_json:
        try:
            data = json.loads(env_json)
            if isinstance(data, dict):
                return dict(data)
        except Exception:
            pass

    return {}


# ---------- Small utilities ----------
def get_device(choice="auto"):
    if choice == "cpu":
        return "cpu"
    if choice == "cuda":
        return "cuda"
    return "cuda" if torch.cuda.is_available() else "cpu"


def find_latest_best_ckpt():
    ckpts = sorted(
        Path("checkpoints").rglob("best.ckpt"), key=lambda p: p.stat().st_mtime
    )
    return ckpts[-1] if ckpts else None


def denorm_to_pil(x, mean, std):
    """
    x: torch.Tensor CxHxW (normalized), mean/std lists
    returns PIL.Image (RGB)
    """
    x = x.detach().cpu().clone()
    if len(mean) == 1:
        # grayscale
        m, s = float(mean[0]), float(std[0])
        x = x * s + m  # de-normalize
        x = x.clamp(0, 1)
        # convert to RGB for overlay convenience
        pil = T.ToPILImage()(x)
        pil = pil.convert("RGB")
        return pil
    else:
        mean = torch.tensor(mean)[:, None, None]
        std = torch.tensor(std)[:, None, None]
        x = x * std + mean
        x = x.clamp(0, 1)
        return T.ToPILImage()(x)


DATASET_CLASSES = {
    "fashion-mnist": [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ],
    "mnist": [str(i) for i in range(10)],
    "cifar10": [
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    ],
}


@st.cache_resource
def load_raw_dataset(name: str, root="data"):
    """Load the test split with ToTensor() only (for preview)."""
    tt = T.ToTensor()
    if name == "fashion-mnist":
        ds = FashionMNIST(root=root, train=False, download=True, transform=tt)
    elif name == "mnist":
        ds = MNIST(root=root, train=False, download=True, transform=tt)
    elif name == "cifar10":
        ds = CIFAR10(root=root, train=False, download=True, transform=tt)
    else:
        raise ValueError(f"Unknown dataset: {name}")
    classes = getattr(ds, "classes", None) or [str(i) for i in range(10)]
    return ds, classes


def pil_from_tensor(img_tensor, grayscale_to_rgb=True):
    pil = T.ToPILImage()(img_tensor)
    if grayscale_to_rgb and img_tensor.ndim == 3 and img_tensor.shape[0] == 1:
        pil = pil.convert("RGB")
    return pil


@st.cache_data(ttl=5, show_spinner=False)
def list_ckpts(root_dir: str, recursive: bool = True, name_filter: str = ""):
    """Return (labels, paths) sorted by mtime desc."""
    root = Path(root_dir)
    if not root.exists():
        return [], []
    files = sorted(
        (root.rglob("*.ckpt") if recursive else root.glob("*.ckpt")),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    files = [p for p in files if name_filter in str(p)]
    labels = []
    for p in files:
        rel = p.relative_to(root)
        mtime = dt.datetime.fromtimestamp(p.stat().st_mtime).strftime("%Y-%m-%d %H:%M")
        labels.append(f"{rel}  β€’  {mtime}")
    return labels, [str(p) for p in files]


# ---------- Your SmallCNN (for FMNIST) ----------
class SmallCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(64 * 7 * 7, num_classes)  # 28x28 input -> two 2x2 pools -> 7x7 maps

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


# ---------- Load model + meta from checkpoint ----------
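# Expected checkpoint layout, inferred from the keys read below (presumably
# produced by the companion training script):
#   {"model_state": <state_dict>,
#    "classes": ["T-shirt/top", ...],   # optional; defaults to 10 classes
#    "meta": {"model_name": "smallcnn" | "resnet18_cifar" | "resnet18_imagenet",
#             "img_size": int, "mean": [...], "std": [...], "dataset": str}}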
def load_model_from_ckpt(ckpt_path: Path, device: str):
    ckpt = torch.load(str(ckpt_path), map_location=device)  # full pickle load; only open trusted checkpoints
    classes = ckpt.get("classes", None)
    meta = ckpt.get("meta", {})
    num_classes = len(classes) if classes else 10
    model_name = meta.get("model_name", "smallcnn")

    if model_name == "smallcnn":
        model = SmallCNN(num_classes=num_classes).to(device)
        default_target_layer = "conv2"
    elif model_name == "resnet18_cifar":
        m = tvm.resnet18(weights=None)
        m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        m.maxpool = nn.Identity()
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        model = m.to(device)
        default_target_layer = "layer4"
    elif model_name == "resnet18_imagenet":
        try:
            w = tvm.ResNet18_Weights.IMAGENET1K_V1
        except Exception:
            w = None
        m = tvm.resnet18(weights=w)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        model = m.to(device)
        default_target_layer = "layer4"
    else:
        raise ValueError(f"Unknown model_name in ckpt: {model_name}")

    model.load_state_dict(ckpt["model_state"])
    model.eval()
    # ensure meta has defaults
    meta.setdefault("default_target_layer", default_target_layer)
    return model, classes, meta


def build_transform_from_meta(meta):
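    # Example meta for a CIFAR-10 checkpoint (illustrative values; the real
    # ones come from the checkpoint itself):
    #   {"img_size": 32, "mean": [0.4914, 0.4822, 0.4465],
    #    "std": [0.2470, 0.2435, 0.2616]}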
    img_size = int(meta.get("img_size", 28))
    mean = meta.get("mean", [0.2860])  # FMNIST fallback
    std = meta.get("std", [0.3530])
    if len(mean) == 1:
        return T.Compose(
            [
                T.Grayscale(num_output_channels=1),
                T.Resize((img_size, img_size)),
                T.ToTensor(),
                T.Normalize(mean, std),
            ]
        )
    else:
        return T.Compose(
            [
                T.Resize((img_size, img_size)),
                T.ToTensor(),
                T.Normalize(mean, std),
            ]
        )


def predict_and_cam(model, x, device, target_layer, topk=3, method="Grad-CAM"):
    """
    x: Tensor [1,C,H,W] normalized
    returns: list of dicts: {rank, class_index, prob, cam_tensor(H,W)}
    """
    cam_cls = GradCAM if method == "Grad-CAM" else GradCAMpp
    cam_extractor = cam_cls(model, target_layer=target_layer)

    logits = model(x.to(device))
    probs = torch.softmax(logits, dim=1)[0].detach().cpu()
    top_vals, top_idxs = probs.topk(topk)

    results = []
    for rank, (p, idx) in enumerate(zip(top_vals.tolist(), top_idxs.tolist())):
        # Retain the graph on all but the final pass; torchcam runs one
        # backward per requested class over the same forward graph.
        retain = rank < topk - 1
        cams = cam_extractor(idx, logits, retain_graph=retain)  # one CAM per target layer
        cam = cams[0].detach().cpu()  # heatmap at feature-map resolution
        results.append(
            {"rank": rank + 1, "class_index": int(idx), "prob": float(p), "cam": cam}
        )
    return results, probs


def overlay_pil(base_pil_rgb: Image.Image, cam_tensor, alpha=0.5):
    # cam_tensor: torch.Tensor HxW at feature-map resolution; min-max scale it
    # into [0,1] before converting to a PIL heatmap for torchcam's overlay_mask
    cam = cam_tensor.clone()
    cam -= cam.min()
    cam = cam / (cam.max() + 1e-8)
    heat = T.ToPILImage()(cam)  # single-channel PIL
    return overlay_mask(base_pil_rgb, heat, alpha=alpha)


# ---------- UI ----------
st.set_page_config(page_title="Grad-CAM Demo", page_icon="πŸ”", layout="wide")
st.title("πŸ” Grad-CAM Demo β€” upload an image, get top-k + heatmaps")

# Sidebar: checkpoint + options
with st.sidebar:
    st.header("Settings")

    ckpt_path = st.session_state.get("ckpt_path")

    st.subheader("Checkpoints")
    # Remote download (presets or URL), saved automatically to saved_checkpoints/
    presets = load_release_presets()
    preset_names = list(presets.keys())
    if preset_names:
        preset_sel = st.selectbox(
            "Preset (GitHub Releases)", options=["(none)"] + preset_names, index=0
        )
    else:
        preset_sel = "(none)"
    url_input = st.text_input("Or paste asset URL", value="")
    if st.button("Download checkpoint", use_container_width=True):
        url = presets.get(preset_sel, "") if preset_sel != "(none)" else url_input.strip()
        if not url:
            st.warning("Provide a preset or paste a URL")
        else:
            try:
                path_dl = download_release_asset(url, dest_dir="saved_checkpoints")
                st.success(f"Downloaded to: {path_dl}")
                ckpt_path = path_dl
                st.session_state["ckpt_path"] = ckpt_path
                st.cache_data.clear()
            except Exception as e:
                st.error(f"Download failed: {e}")

    # Upload a user-provided .ckpt directly in the online app
    uploaded_ckpt = st.file_uploader("Upload checkpoint (.ckpt)", type=["ckpt"], accept_multiple_files=False)
    if uploaded_ckpt is not None and st.button("Use uploaded checkpoint", use_container_width=True):
        try:
            Path("saved_checkpoints").mkdir(parents=True, exist_ok=True)
            raw = uploaded_ckpt.read()
            content_hash = hashlib.sha256(raw).hexdigest()[:16]
            base_name = Path(uploaded_ckpt.name).name
            if not base_name.endswith(".ckpt"):
                base_name = f"{base_name}.ckpt"
            local_path = Path("saved_checkpoints") / f"{content_hash}_{base_name}"
            with open(local_path, "wb") as f:
                f.write(raw)
            ckpt_path = str(local_path)
            st.session_state["ckpt_path"] = ckpt_path
            st.success(f"Uploaded to: {ckpt_path}")
            st.cache_data.clear()
        except Exception as e:
            st.error(f"Upload failed: {e}")

    st.caption(f"Selected: {ckpt_path}")

    with st.expander("Checkpoint meta preview", expanded=False):
        try:
            if ckpt_path:
                m, c, meta_preview = load_model_from_ckpt(Path(ckpt_path), device="cpu")
                st.json(
                    {
                        "dataset": meta_preview.get("dataset"),
                        "model_name": meta_preview.get("model_name"),
                        "img_size": meta_preview.get("img_size"),
                        "target_layer": meta_preview.get("default_target_layer"),
                    }
                )
            else:
                st.info("No checkpoint selected yet.")
        except Exception as e:
            st.info(f"Could not read meta: {e}")

    method = st.selectbox("CAM method", ["Grad-CAM", "Grad-CAM++"], index=0)
    topk = st.slider("Top-k classes", min_value=1, max_value=10, value=3, step=1)
    alpha = st.slider(
        "Overlay alpha", min_value=0.1, max_value=0.9, value=0.5, step=0.05
    )

# Load model/meta
if not ckpt_path or not Path(ckpt_path).exists():
    st.info(
        "First choose a checkpoint:\n"
        "- Preset: pick from the list and click 'Download checkpoint'\n"
        "- URL: paste a direct .ckpt URL and click 'Download checkpoint'\n"
        "- Upload: select a .ckpt and click 'Use uploaded checkpoint'\n\n"
        "After a checkpoint is selected, upload an image or use the sample picker to see predictions and Grad-CAM overlays."
    )
    st.stop()

device = "cpu"  # force CPU inference; get_device("auto") would prefer CUDA when available
model, classes, meta = load_model_from_ckpt(Path(ckpt_path), device)
tf = build_transform_from_meta(meta)
target_layer = meta.get("default_target_layer", "conv2")

# Main: uploader OR dataset sample
st.subheader("1) Provide an image")
uploaded = st.file_uploader(
    "Upload PNG/JPG (or pick a sample below)", type=["png", "jpg", "jpeg"]
)

sample_img = None  # set by the sample picker below
with st.expander("…or pick a sample from this model's dataset", expanded=False):
    ds_default = meta.get("dataset", "fashion-mnist")
    ds, ds_classes = load_raw_dataset(ds_default, root="data")
    targets = getattr(ds, "targets", None)
    if targets is None:
        # slow fallback for datasets without a .targets attribute
        targets = [ds[i][1] for i in range(len(ds))]
    targets = np.array(targets)

    # --- class filter (persisted) ---
    class_opts = ["(any)"] + list(ds_classes)
    class_sel = st.selectbox("Class filter", options=class_opts, index=0, key="class_sel")

    if class_sel == "(any)":
        filtered_idx = np.arange(len(ds))
    else:
        class_id = ds_classes.index(class_sel)
        filtered_idx = np.nonzero(targets == class_id)[0]

    # --- ensure we have a session index and keep it valid ---
    if "sample_idx" not in st.session_state:
        st.session_state["sample_idx"] = 0

    # clamp when filter changes or dataset length is small
    if len(filtered_idx) > 0:
        st.session_state["sample_idx"] = int(
            np.clip(st.session_state["sample_idx"], 0, len(filtered_idx) - 1)
        )

    if len(filtered_idx) == 0:
        st.info("No samples found for this class.")
        sample_img = None
    else:
        col_l, col_r = st.columns([2, 1])

        with col_r:
            picked = st.button("Pick random", use_container_width=True, key="btn_pick_random")
            if picked:
                # IMPORTANT: update session_state BEFORE creating the slider
                cur = st.session_state["sample_idx"]
                if len(filtered_idx) > 1:
                    new_idx = random.randrange(len(filtered_idx) - 1)
                    if new_idx >= cur:
                        new_idx += 1
                else:
                    new_idx = 0
                st.session_state["sample_idx"] = new_idx
                # no st.rerun() needed; the app will rerun after the button

        with col_l:
            # Now instantiate the slider (AFTER any state changes above)
            st.slider(
                "Pick index (within filtered samples)",
                0, max(0, len(filtered_idx) - 1),
                key="sample_idx",  # same key as the state we set above
            )

        raw_idx = int(filtered_idx[st.session_state["sample_idx"]])
        img_tensor, label = ds[raw_idx]
        sample_img = pil_from_tensor(img_tensor, grayscale_to_rgb=True)

        st.image(
            sample_img,
            caption=f"Sample β€’ {ds_default} β€’ class={ds_classes[label]} β€’ idx={raw_idx}",
            width=160,
            use_container_width=False,
        )

# Decide the input image used downstream
if uploaded is not None:
    pil = Image.open(uploaded).convert("RGB")
elif sample_img is not None:
    pil = sample_img
else:
    st.info("Upload an image or open the sample picker above.")
    st.stop()

col_in, col_cfg = st.columns([2, 1])

with col_in:
    st.image(pil, caption="Input", use_container_width=True)

with col_cfg:
    st.markdown("**Model meta**")
    st.json(
        {
            "dataset": meta.get("dataset"),
            "model_name": meta.get("model_name"),
            "img_size": meta.get("img_size"),
            "target_layer": target_layer,
            "mean": meta.get("mean"),
            "std": meta.get("std"),
            "classes": (
                classes
                if classes and len(classes) <= 10
                else f"{len(classes) if classes else 'N/A'} classes"
            ),
        }
    )

# Prepare tensor + denormalized PIL base for overlay
x = tf(pil)  # CxHxW normalized
x_batched = x.unsqueeze(0)  # 1xCxHxW
base_pil = denorm_to_pil(x, meta.get("mean", [0.2860]), meta.get("std", [0.3530]))

# Predict + CAM
with st.spinner("Running inference + Grad-CAM..."):
    try:
        cam_results, probs = predict_and_cam(
            model, x_batched, device, target_layer, topk=topk, method=method
        )
    except Exception as e:
        st.error(
            "Grad-CAM failed; the target layer is likely incorrect.\n"
            f"Layer: {target_layer}\nError: {e}"
        )
        st.stop()

# Top-k table
st.subheader("2) Top-k predictions")
rows = []
for r in cam_results:
    name = classes[r["class_index"]] if classes else str(r["class_index"])
    rows.append(
        {
            "rank": r["rank"],
            "class": name,
            "index": r["class_index"],
            "prob": round(r["prob"], 4),
        }
    )
st.dataframe(rows, use_container_width=True)

# Overlays
st.subheader("3) Grad-CAM overlays")
cols = st.columns(len(cam_results))
for c, r in zip(cols, cam_results):
    name = classes[r["class_index"]] if classes else str(r["class_index"])
    ov = overlay_pil(base_pil, r["cam"], alpha=alpha)
    with c:
        st.image(
            ov,
            caption=f"Top{r['rank']}: {name} ({r['prob']:.3f})",
            use_container_width=True,
        )