Spaces:
Running
Running
| """UWColorKit demo: zero-shot underwater color recovery from location + depth.""" | |
| import json | |
| import os | |
| import gradio as gr | |
| import client | |
| import color | |
| import exif_gps | |
# Initial pin position (decimal degrees) and initial depth-slider value (metres).
DEFAULT_LAT = 20.21
DEFAULT_LNG = -87.43
DEFAULT_DEPTH = 12

# File suffixes offered by the upload widget, built from bare suffix names so
# the list stays easy to scan and extend.
RAW_EXTENSIONS = [
    f".{suffix}"
    for suffix in (
        "dng cr2 cr3 arw arq nef nrw raf rw2 "
        "orf pef srw sr2 raw rwl dcr kdc 3fr "
        "mef mos iiq erf x3f"
    ).split()
]

# Empty container element that the Leaflet script looks up by id and fills in.
MAP_HTML = (
    '<div id="scubai-map" '
    'style="height:360px;border-radius:8px;overflow:hidden;'
    'margin:6px 0;border:1px solid #ddd;">'
    "</div>"
)
# JavaScript arrow function (kept as a string) that is handed to the page-load
# event in build(). In the browser it:
#   * injects the Leaflet 1.9.4 stylesheet and script from unpkg, each at most
#     once (guarded by a `data-scubai` attribute on the injected tag);
#   * initMap() re-schedules itself every 200 ms until both the #scubai-map
#     element and `window.L` exist, and uses `dataset.ready` so the map is only
#     created once;
#   * centres the map (zoom 5) on the values currently in the #scubai_lat /
#     #scubai_lng inputs, falling back to 20.21 / -87.43 when they cannot be
#     parsed;
#   * on map click or marker drag-end, writes the position (rounded to 6
#     decimals) back into those inputs through the native `value` setter and
#     dispatches a bubbling "input" event — presumably so the UI framework's
#     listeners observe the change (TODO confirm against the Gradio frontend);
#   * exposes `window.scubaiSetPin(lat, lng)` so Python-side changes can move
#     the marker and pan the map;
#   * calls `map.invalidateSize()` after 300 ms.
# NOTE(review): the 20.21 / -87.43 fallbacks duplicate DEFAULT_LAT / DEFAULT_LNG;
# keep them in sync by hand.
# NOTE(review): the script is loaded from a third-party CDN without an
# integrity attribute.
LEAFLET_JS = """
() => {
  const CSS = "https://unpkg.com/leaflet@1.9.4/dist/leaflet.css";
  const JS = "https://unpkg.com/leaflet@1.9.4/dist/leaflet.js";
  function setNumber(id, v) {
    const el = document.getElementById(id);
    const inp = el ? el.querySelector("input") : null;
    if (!inp) return;
    const set = Object.getOwnPropertyDescriptor(window.HTMLInputElement.prototype, "value").set;
    set.call(inp, v);
    inp.dispatchEvent(new Event("input", { bubbles: true }));
  }
  function readNumber(id, fb) {
    const el = document.getElementById(id);
    const inp = el ? el.querySelector("input") : null;
    const v = inp ? parseFloat(inp.value) : NaN;
    return isNaN(v) ? fb : v;
  }
  function initMap() {
    const el = document.getElementById("scubai-map");
    if (!el || !window.L) { setTimeout(initMap, 200); return; }
    if (el.dataset.ready === "1") return;
    el.dataset.ready = "1";
    const map = L.map(el).setView([readNumber("scubai_lat", 20.21), readNumber("scubai_lng", -87.43)], 5);
    L.tileLayer("https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png",
      { maxZoom: 19, attribution: "© OpenStreetMap" }).addTo(map);
    const marker = L.marker(map.getCenter(), { draggable: true }).addTo(map);
    function publish(ll) {
      setNumber("scubai_lat", Math.round(ll.lat * 1e6) / 1e6);
      setNumber("scubai_lng", Math.round(ll.lng * 1e6) / 1e6);
    }
    map.on("click", (e) => { marker.setLatLng(e.latlng); publish(e.latlng); });
    marker.on("dragend", () => publish(marker.getLatLng()));
    window.scubaiSetPin = (la, ln) => {
      la = parseFloat(la); ln = parseFloat(ln);
      if (!isNaN(la) && !isNaN(ln)) { marker.setLatLng([la, ln]); map.panTo([la, ln]); }
    };
    setTimeout(() => map.invalidateSize(), 300);
  }
  function ensure() {
    if (!document.querySelector('link[data-scubai]')) {
      const l = document.createElement("link");
      l.rel = "stylesheet"; l.href = CSS; l.setAttribute("data-scubai", "1");
      document.head.appendChild(l);
    }
    if (window.L) return initMap();
    let s = document.querySelector('script[data-scubai]');
    if (!s) {
      s = document.createElement("script");
      s.src = JS; s.async = true; s.setAttribute("data-scubai", "1");
      s.onload = initMap; document.head.appendChild(s);
    } else { s.addEventListener("load", initMap); }
  }
  ensure();
}
"""

# Client-side handler for latitude/longitude changes: forwards the two values
# to window.scubaiSetPin when the map script has defined it, otherwise no-op.
MOVE_PIN_JS = "(la, ln) => { if (window.scubaiSetPin) window.scubaiSetPin(la, ln); }"
# Outbound links used in the page copy below.
APP_STORE_URL = "https://apps.apple.com/us/app/scubai/id6741768984"
SITE_URL = "https://scub.ai"

# Markdown shown at the top of the page.
HEADER_MD = """
# 🤿 UWColorKit — zero-shot underwater white balance
The color engine inside **ScubAI** — it brings back the reds and oranges the water strips away.
"""

# Call-to-action pointing at the App Store listing.
APP_CTA_MD = " ".join(
    (
        f"🐠 Shooting underwater on an iPhone? **[ScubAI]({APP_STORE_URL})**",
        "does this automatically, every dive.",
    )
)

# Two Markdown paragraphs: the privacy note, then the attribution line.
FOOTER_MD = "\n\n".join(
    (
        "*Your RAW is decoded for this preview and discarded right after.*",
        f"*UWColorKit · patent-pending underwater color engine · [scub.ai]({SITE_URL})*",
    )
)

# Sample dives live in an "examples" folder next to this module.
EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "examples")
def load_examples():
    """Read examples/examples.json -> ([[file, lat, lng, depth], ...], [label, ...]).

    GPS lives in each RAW's EXIF, not the manifest, so the sample dives stay in
    sync with the photo and the JSON only carries file + depth (+ optional label).
    Coordinates come from the same exif_gps path the upload handler uses; a file
    without GPS EXIF falls back to the default location (logged).

    Returns:
        A ``(rows, labels)`` pair of equal length. Both are empty when the
        manifest is absent, unreadable, not valid JSON, or not a JSON list.
        Entries that are malformed or whose file is not present are skipped
        (and logged) instead of aborting the whole load.
    """
    manifest = os.path.join(EXAMPLES_DIR, "examples.json")
    if not os.path.exists(manifest):
        return [], []
    try:
        # Explicit encoding: labels may be non-ASCII and the platform default
        # is locale-dependent. UnicodeDecodeError is a ValueError subclass.
        with open(manifest, encoding="utf-8") as fh:
            entries = json.load(fh)
    except (OSError, ValueError) as exc:
        print(f"[uwcolorkit] examples.json could not be parsed: {exc}")
        return [], []
    if not isinstance(entries, list):
        print("[uwcolorkit] examples.json must contain a list of entries")
        return [], []

    rows, labels, missing = [], [], 0
    for entry in entries:
        filename = entry.get("file") if isinstance(entry, dict) else None
        if not isinstance(filename, str) or not filename:
            # An empty name would resolve to EXAMPLES_DIR itself, which exists,
            # and a directory would then be handed to the EXIF reader.
            print(f"[uwcolorkit] skipping malformed examples.json entry: {entry!r}")
            continue
        path = os.path.join(EXAMPLES_DIR, filename)
        if not os.path.isfile(path):  # isfile, not exists: reject directories
            missing += 1
            continue
        gps = exif_gps.extract_gps(path)
        if gps is None:
            print(f"[uwcolorkit] {filename!r}: no GPS in EXIF, using default location")
            lat, lng = DEFAULT_LAT, DEFAULT_LNG
        else:
            lat, lng = gps
        rows.append([path, round(lat, 6), round(lng, 6), entry.get("depth", DEFAULT_DEPTH)])
        labels.append(entry.get("label") or os.path.splitext(filename)[0])
    if missing:
        print(f"[uwcolorkit] {missing} example(s) listed in examples.json have no file yet")
    return rows, labels
def _discard(path):
    """Delete *path* if it exists; never raise on filesystem errors.

    A falsy *path* (``None`` or ``""``) is a no-op. Any ``OSError`` from the
    existence check or the removal is deliberately ignored: cleanup here is
    best-effort and must not break the request that triggered it.
    """
    if not path:
        return
    try:
        target_present = os.path.exists(path)
        if target_present:
            os.remove(path)
    except OSError:
        # Best-effort cleanup: a file we cannot remove is not worth failing for.
        return
# In-memory store of (raw object, month) keyed by upload path. It exists so a
# re-run with a different depth can reuse the decoded raw without touching the
# file again — the file itself is discarded right after it is opened.
_DECODE_CACHE = {}
_DECODE_CACHE_MAX = 6


def _cache_put(path, value):
    """Store *value* under *path*, then trim the cache to its size limit.

    Entries are dropped oldest-inserted first. For each dropped entry its first
    element is closed to release the raw decoder's buffers; a failing or
    missing ``close`` is ignored.
    """
    _DECODE_CACHE[path] = value
    while len(_DECODE_CACHE) > _DECODE_CACHE_MAX:
        oldest_key = next(iter(_DECODE_CACHE))  # dicts iterate in insertion order
        evicted = _DECODE_CACHE.pop(oldest_key)
        try:
            raw_handle = evicted[0]
            raw_handle.close()  # free the evicted RawPy's buffers
        except Exception:
            # Closing is a courtesy; never let it break the caller.
            pass
def on_upload(path):
    """Upload handler: pre-fill latitude/longitude from the file's EXIF GPS.

    Returns a pair of ``gr.update`` objects for the latitude and longitude
    fields. When there is no file, or no GPS could be extracted, both updates
    are empty so the fields keep whatever they currently show.
    """
    if not path:
        return gr.update(), gr.update()

    gps = exif_gps.extract_gps(path)
    if not gps:
        return gr.update(), gr.update()

    lat_update = gr.update(value=round(gps[0], 6))
    lng_update = gr.update(value=round(gps[1], 6))
    return lat_update, lng_update
def correct(path, lat, lng, depth):
    """Render the before/after pair for one RAW at the given location and depth.

    Args:
        path: Filesystem path of the uploaded RAW (also the cache key).
        lat: Latitude from the UI; ``None`` is rejected.
        lng: Longitude from the UI; ``None`` is rejected.
        depth: Depth value from the slider, passed through to the parameter API.

    Returns:
        ``(before, after)``: the baseline white-balance render and the render
        using the illuminant fetched for this location, depth and month.

    Raises:
        gr.Error: no file, no location, the file was already cleared, the RAW
            (or its metadata) could not be read, or the parameter API failed.
    """
    if not path:
        raise gr.Error("Upload a RAW file first.")
    if lat is None or lng is None:
        raise gr.Error("Set a location on the map.")
    # imread once per upload, cache the raw object, then discard the file. Re-runs
    # (e.g. a new depth) re-render from the in-memory raw — no second file read.
    cached = _DECODE_CACHE.get(path)
    if cached is None:
        if not os.path.exists(path):
            raise gr.Error("Please re-upload the RAW — the previous file was cleared.")
        try:
            # Metadata is read inside the try so the finally below removes the
            # upload even when the EXIF read fails, not only when decoding fails.
            month = exif_gps.extract_month(path)  # read metadata before discarding
            raw = color.open_raw(path)
        except Exception as exc:
            raise gr.Error(f"Could not decode RAW ({type(exc).__name__}): {exc}") from exc
        finally:
            _discard(path)
        cached = (raw, month)
        _cache_put(path, cached)
    raw, month = cached  # month is None when no EXIF date -> engine uses yearly data
    try:
        illum = client.fetch_illuminant(lat, lng, depth, month)
    except Exception as exc:
        raise gr.Error(f"Parameter API error: {exc}") from exc
    print(f"[uwcolorkit] month={month} stub={client.using_stub()}")
    before = color.render(raw, color.baseline_wb(raw))
    after = color.render(raw, color.wb_from_chromaticity(raw, illum.as_tuple()))
    return (before, after)
def build():
    """Assemble the Gradio page, wire its events, and return the Blocks app.

    Layout: header, then a row with an input column (file picker, map, lat/lng,
    depth, action button, optional sample dives) and an output column
    (before/after slider plus the app call-to-action), then the footer.
    """
    page_title = "UWColorKit — underwater color engine"
    with gr.Blocks(title=page_title, theme=gr.themes.Soft()) as demo:
        gr.Markdown(HEADER_MD)

        with gr.Row():
            # Input side.
            with gr.Column():
                raw_input = gr.File(
                    label="RAW photo",
                    file_count="single",
                    type="filepath",
                    file_types=RAW_EXTENSIONS,
                )
                gr.HTML(MAP_HTML)
                with gr.Row():
                    lat_box = gr.Number(
                        label="Latitude", value=DEFAULT_LAT, elem_id="scubai_lat", precision=6
                    )
                    lng_box = gr.Number(
                        label="Longitude", value=DEFAULT_LNG, elem_id="scubai_lng", precision=6
                    )
                depth_slider = gr.Slider(0, 40, value=DEFAULT_DEPTH, step=1, label="Depth (m)")
                run_button = gr.Button("Correct color", variant="primary")

                sample_rows, sample_labels = load_examples()
                if sample_rows:
                    gr.Examples(
                        examples=sample_rows,
                        inputs=[raw_input, lat_box, lng_box, depth_slider],
                        label="Or try a sample dive",
                        example_labels=sample_labels,
                    )

            # Output side.
            with gr.Column():
                compare_view = gr.ImageSlider(
                    label="RAW → Corrected", type="numpy", slider_position=50
                )
                gr.Markdown(APP_CTA_MD)

        gr.Markdown(FOOTER_MD)

        # Server-side handlers.
        raw_input.upload(on_upload, inputs=raw_input, outputs=[lat_box, lng_box])
        run_button.click(
            correct,
            inputs=[raw_input, lat_box, lng_box, depth_slider],
            outputs=compare_view,
        )

        # Browser-only handlers: keep the map pin in step with the number
        # fields, and boot the map when the page loads.
        for coord_box in (lat_box, lng_box):
            coord_box.change(None, [lat_box, lng_box], None, js=MOVE_PIN_JS)
        demo.load(None, None, None, js=LEAFLET_JS)

    return demo
# Built at import time so `demo` is available as a module attribute.
demo = build()

if __name__ == "__main__":
    # Direct execution: enable the request queue, then start the server.
    queued_app = demo.queue()
    queued_app.launch()