# UWColorKit / app.py
# budzynski's picture
# new color rendering
# 568914f
# Raw
# History Blame Contribute Delete
# 9.3 kB
"""UWColorKit demo: zero-shot underwater color recovery from location + depth."""
import json
import os
import gradio as gr
import client
import color
import exif_gps
# Initial values for the latitude / longitude inputs and the depth slider; also
# the fallback location for sample files that carry no GPS EXIF.
DEFAULT_LAT = 20.21
DEFAULT_LNG = -87.43
DEFAULT_DEPTH = 12

# File suffixes offered to the RAW upload widget (order preserved).
RAW_EXTENSIONS = [
    ".dng", ".cr2", ".cr3", ".arw", ".arq", ".nef",
    ".nrw", ".raf", ".rw2", ".orf", ".pef", ".srw",
    ".sr2", ".raw", ".rwl", ".dcr", ".kdc", ".3fr",
    ".mef", ".mos", ".iiq", ".erf", ".x3f",
]

# Empty container the Leaflet bootstrap script looks up by id ("scubai-map").
MAP_HTML = (
    '<div id="scubai-map" '
    'style="height:360px;border-radius:8px;overflow:hidden;'
    'margin:6px 0;border:1px solid #ddd;">'
    '</div>'
)
# Client-side bootstrap passed as `js=` to `demo.load` in build(). The script:
#   * injects the Leaflet 1.9.4 stylesheet and script from unpkg (once each,
#     tagged with a data-scubai attribute), then calls initMap;
#   * initMap retries every 200 ms until both the #scubai-map element and
#     window.L exist, and marks the element with data-ready="1" so the map is
#     only created once;
#   * a map click or marker drag writes the position, rounded to 6 decimals,
#     into the <input> inside #scubai_lat / #scubai_lng and dispatches a
#     bubbling "input" event on it;
#   * exposes window.scubaiSetPin(la, ln), which MOVE_PIN_JS calls to move the
#     marker and pan the map when the number inputs change.
# NOTE(review): the 20.21 / -87.43 fallbacks inside the script duplicate
# DEFAULT_LAT / DEFAULT_LNG by hand — keep them in sync if the defaults change.
LEAFLET_JS = """
() => {
const CSS = "https://unpkg.com/leaflet@1.9.4/dist/leaflet.css";
const JS = "https://unpkg.com/leaflet@1.9.4/dist/leaflet.js";
function setNumber(id, v) {
const el = document.getElementById(id);
const inp = el ? el.querySelector("input") : null;
if (!inp) return;
const set = Object.getOwnPropertyDescriptor(window.HTMLInputElement.prototype, "value").set;
set.call(inp, v);
inp.dispatchEvent(new Event("input", { bubbles: true }));
}
function readNumber(id, fb) {
const el = document.getElementById(id);
const inp = el ? el.querySelector("input") : null;
const v = inp ? parseFloat(inp.value) : NaN;
return isNaN(v) ? fb : v;
}
function initMap() {
const el = document.getElementById("scubai-map");
if (!el || !window.L) { setTimeout(initMap, 200); return; }
if (el.dataset.ready === "1") return;
el.dataset.ready = "1";
const map = L.map(el).setView([readNumber("scubai_lat", 20.21), readNumber("scubai_lng", -87.43)], 5);
L.tileLayer("https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png",
{ maxZoom: 19, attribution: "© OpenStreetMap" }).addTo(map);
const marker = L.marker(map.getCenter(), { draggable: true }).addTo(map);
function publish(ll) {
setNumber("scubai_lat", Math.round(ll.lat * 1e6) / 1e6);
setNumber("scubai_lng", Math.round(ll.lng * 1e6) / 1e6);
}
map.on("click", (e) => { marker.setLatLng(e.latlng); publish(e.latlng); });
marker.on("dragend", () => publish(marker.getLatLng()));
window.scubaiSetPin = (la, ln) => {
la = parseFloat(la); ln = parseFloat(ln);
if (!isNaN(la) && !isNaN(ln)) { marker.setLatLng([la, ln]); map.panTo([la, ln]); }
};
setTimeout(() => map.invalidateSize(), 300);
}
function ensure() {
if (!document.querySelector('link[data-scubai]')) {
const l = document.createElement("link");
l.rel = "stylesheet"; l.href = CSS; l.setAttribute("data-scubai", "1");
document.head.appendChild(l);
}
if (window.L) return initMap();
let s = document.querySelector('script[data-scubai]');
if (!s) {
s = document.createElement("script");
s.src = JS; s.async = true; s.setAttribute("data-scubai", "1");
s.onload = initMap; document.head.appendChild(s);
} else { s.addEventListener("load", initMap); }
}
ensure();
}
"""
# JS snippet attached to the lat/lng `.change` events in build(): forwards the
# two values to window.scubaiSetPin, if the map script has defined it yet.
MOVE_PIN_JS = "(la, ln) => { if (window.scubaiSetPin) window.scubaiSetPin(la, ln); }"

# Outbound links interpolated into the markdown snippets below.
APP_STORE_URL = "https://apps.apple.com/us/app/scubai/id6741768984"
SITE_URL = "https://scub.ai"

# Page title + tagline, rendered at the top of the Blocks layout.
HEADER_MD = """
# 🤿 UWColorKit — zero-shot underwater white balance
The color engine inside **ScubAI** — it brings back the reds and oranges the water strips away.
"""

# Call-to-action rendered after the before/after slider.
APP_CTA_MD = (
    f"🐠 Shooting underwater on an iPhone? **[ScubAI]({APP_STORE_URL})** "
    "does this automatically, every dive."
)

# Privacy note + attribution rendered after the main row.
FOOTER_MD = (
    "*Your RAW is decoded for this preview and discarded right after.*\n\n"
    f"*UWColorKit · patent-pending underwater color engine · [scub.ai]({SITE_URL})*"
)

# Sample files and their manifest live in an "examples" directory next to this module.
EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "examples")
def load_examples():
    """Read examples/examples.json -> ([[file, lat, lng, depth], ...], [label, ...]).

    GPS lives in each RAW's EXIF, not the manifest, so the sample dives stay in
    sync with the photo and the JSON only carries file + depth (+ optional label).
    Coordinates come from the same exif_gps path the upload handler uses; a file
    without GPS EXIF falls back to the default location (logged).

    Returns:
        Two parallel lists ``(rows, labels)``. Both are empty when the manifest
        is absent, unreadable, not valid JSON, or not a JSON list. Entries that
        are not objects, have no usable ``"file"`` name, or point at something
        that is not a regular file are skipped and counted in the final log line.
    """
    manifest = os.path.join(EXAMPLES_DIR, "examples.json")
    if not os.path.exists(manifest):
        return [], []
    try:
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        # UnicodeDecodeError is a ValueError, so it is covered below.
        with open(manifest, encoding="utf-8") as fh:
            entries = json.load(fh)
    except (OSError, ValueError) as exc:
        print(f"[uwcolorkit] examples.json could not be parsed: {exc}")
        return [], []
    if not isinstance(entries, list):
        # A top-level object/string would otherwise blow up on e.get() below,
        # and this function runs while the UI is being built.
        print(
            "[uwcolorkit] examples.json must be a list of entries, "
            f"got {type(entries).__name__}"
        )
        return [], []

    rows, labels, missing = [], [], 0
    for e in entries:
        filename = e.get("file") if isinstance(e, dict) else None
        if not isinstance(filename, str) or not filename:
            # An empty name would join to EXAMPLES_DIR itself, which exists.
            missing += 1
            continue
        path = os.path.join(EXAMPLES_DIR, filename)
        if not os.path.isfile(path):  # isfile, not exists: reject directories
            missing += 1
            continue
        gps = exif_gps.extract_gps(path)
        if gps is None:
            print(f"[uwcolorkit] {filename!r}: no GPS in EXIF, using default location")
            lat, lng = DEFAULT_LAT, DEFAULT_LNG
        else:
            lat, lng = gps
        rows.append([path, round(lat, 6), round(lng, 6), e.get("depth", DEFAULT_DEPTH)])
        labels.append(e.get("label") or os.path.splitext(filename)[0])
    if missing:
        print(f"[uwcolorkit] {missing} example(s) listed in examples.json have no file yet")
    return rows, labels
def _discard(path):
    """Best-effort removal of *path*; a falsy or missing path is a no-op.

    OS-level failures while checking or deleting are swallowed on purpose so
    that cleanup can never mask the result (or error) of the caller.
    """
    if not path:
        return
    try:
        still_there = os.path.exists(path)
        if still_there:
            os.remove(path)
    except OSError:
        pass
# In-memory store of (raw object, month) keyed by upload path. It exists so a
# re-run with a different depth can re-render without touching the file again,
# since the file is discarded right after it is opened.
_DECODE_CACHE = {}
_DECODE_CACHE_MAX = 6


def _cache_put(path, value):
    """Store *value* under *path*, then evict oldest-inserted entries over the cap.

    Each evicted entry has ``close()`` called on its first element to free the
    raw object's buffers; any failure while closing is ignored.
    """
    _DECODE_CACHE[path] = value
    surplus = len(_DECODE_CACHE) - _DECODE_CACHE_MAX
    for _ in range(max(surplus, 0)):
        oldest_key = next(iter(_DECODE_CACHE))  # dicts iterate in insertion order
        evicted = _DECODE_CACHE.pop(oldest_key)
        try:
            evicted[0].close()
        except Exception:
            pass
def on_upload(path):
    """Upload handler: pre-fill latitude / longitude from the file's GPS EXIF.

    Returns a pair of ``gr.update`` objects for the (lat, lng) outputs. When
    there is no path, or no GPS could be extracted, both updates are empty so
    the current field values are left alone.
    """
    if not path:
        return gr.update(), gr.update()

    coords = exif_gps.extract_gps(path)
    if not coords:
        return gr.update(), gr.update()

    lat_value = round(coords[0], 6)
    lng_value = round(coords[1], 6)
    return gr.update(value=lat_value), gr.update(value=lng_value)
def correct(path, lat, lng, depth):
    """Render the baseline and the corrected image for one uploaded RAW.

    Args:
        path: Filesystem path of the uploaded RAW (also the decode-cache key).
        lat, lng: Location used to look up the illuminant; must not be None.
        depth: Depth value forwarded to the illuminant lookup.

    Returns:
        ``(before, after)`` — the render with the baseline white balance and
        the render with the white balance derived from the fetched illuminant.

    Raises:
        gr.Error: no file, no location, the file is gone and not cached, the
            RAW cannot be decoded, or the illuminant lookup fails.
    """
    if not path:
        raise gr.Error("Upload a RAW file first.")
    if lat is None or lng is None:
        raise gr.Error("Set a location on the map.")

    # imread once per upload, cache the raw object, then discard the file. Re-runs
    # (e.g. a new depth) re-render from the in-memory raw — no second file read.
    cached = _DECODE_CACHE.get(path)
    if cached is None:
        if not os.path.exists(path):
            raise gr.Error("Please re-upload the RAW — the previous file was cleared.")
        try:
            # Metadata must be read before the file is discarded. It sits inside
            # the outer try so the upload is removed even if this read fails.
            month = exif_gps.extract_month(path)
            try:
                raw = color.open_raw(path)
            except Exception as exc:
                raise gr.Error(
                    f"Could not decode RAW ({type(exc).__name__}): {exc}"
                ) from exc
        finally:
            _discard(path)
        cached = (raw, month)
        _cache_put(path, cached)

    raw, month = cached  # month is None when no EXIF date -> engine uses yearly data
    try:
        illum = client.fetch_illuminant(lat, lng, depth, month)
    except Exception as exc:
        raise gr.Error(f"Parameter API error: {exc}") from exc
    print(f"[uwcolorkit] month={month} stub={client.using_stub()}")

    before = color.render(raw, color.baseline_wb(raw))
    after = color.render(raw, color.wb_from_chromaticity(raw, illum.as_tuple()))
    return (before, after)
def build():
    """Assemble the Gradio Blocks UI and wire its events; return the Blocks object.

    Layout, in creation order: header markdown; a row with an input column (RAW
    upload, map container, lat/lng numbers, depth slider, action button and,
    when load_examples() returns rows, a sample gallery) and an output column
    (before/after slider plus the app call-to-action); footer markdown.

    NOTE(review): this block was read with its indentation stripped; the
    nesting of the ``with`` scopes below is the conventional reading — confirm
    it against the deployed layout.
    """
    with gr.Blocks(title="UWColorKit — underwater color engine", theme=gr.themes.Soft()) as demo:
        gr.Markdown(HEADER_MD)
        with gr.Row():
            with gr.Column():
                raw_file = gr.File(
                    label="RAW photo",
                    file_count="single",
                    type="filepath",
                    file_types=RAW_EXTENSIONS,
                )
                # Empty div that LEAFLET_JS turns into the map.
                gr.HTML(MAP_HTML)
                with gr.Row():
                    # elem_ids must match the ids the map script reads and writes.
                    lat = gr.Number(label="Latitude", value=DEFAULT_LAT, elem_id="scubai_lat", precision=6)
                    lng = gr.Number(label="Longitude", value=DEFAULT_LNG, elem_id="scubai_lng", precision=6)
                depth = gr.Slider(0, 40, value=DEFAULT_DEPTH, step=1, label="Depth (m)")
                correct_btn = gr.Button("Correct color", variant="primary")
                # The sample gallery is only added when the manifest yielded rows.
                examples, example_labels = load_examples()
                if examples:
                    gr.Examples(
                        examples=examples,
                        inputs=[raw_file, lat, lng, depth],
                        label="Or try a sample dive",
                        example_labels=example_labels,
                    )
            with gr.Column():
                slider = gr.ImageSlider(label="RAW → Corrected", type="numpy", slider_position=50)
                gr.Markdown(APP_CTA_MD)
        gr.Markdown(FOOTER_MD)
        # Upload -> try to fill lat/lng from EXIF; button -> render the pair.
        raw_file.upload(on_upload, inputs=raw_file, outputs=[lat, lng])
        correct_btn.click(correct, inputs=[raw_file, lat, lng, depth], outputs=slider)
        # No Python handler here (fn is None): only the JS snippet is attached,
        # receiving the current lat/lng so the map pin can follow the inputs.
        lat.change(None, [lat, lng], None, js=MOVE_PIN_JS)
        lng.change(None, [lat, lng], None, js=MOVE_PIN_JS)
        # Map bootstrap script attached to the page-load event.
        demo.load(None, None, None, js=LEAFLET_JS)
    return demo
# Built at import time, so importing this module is enough to obtain the app
# object as `demo` (and to run load_examples() as a side effect of build()).
demo = build()
if __name__ == "__main__":
    # Direct execution: enable the request queue, then start the server.
    demo.queue().launch()