hexware's picture
Update app.py
578fc7f verified
import os
import io
import json
import time
import uuid
import random
import tempfile
import zipfile
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import gradio as gr
import spaces
from PIL import Image
from pptx import Presentation
from diffusers import QwenImageLayeredPipeline
from huggingface_hub import HfApi, login
from huggingface_hub.utils import HfHubHTTPError
# Local directory created before the UI is built (see ensure_dirname call below).
LOG_DIR = "/tmp/local"
# Upper bound for random seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo("int32").max
# -------------------------
# HF auth (Spaces secrets)
# -------------------------
def _get_hf_token() -> Optional[str]:
    """Return the Hugging Face token from the environment.

    Variables are consulted in order: ``HF_TOKEN``, ``hf``,
    ``HUGGINGFACEHUB_API_TOKEN``; the first non-empty one wins. When none is
    non-empty, the value of the last variable is returned (``None`` if unset,
    or ``""`` if set but empty).
    """
    candidates = [
        os.environ.get(var_name)
        for var_name in ("HF_TOKEN", "hf", "HUGGINGFACEHUB_API_TOKEN")
    ]
    return next((value for value in candidates if value), candidates[-1])
def _get_dataset_repo() -> Optional[str]:
    """Return the dataset repo id used for persistence.

    ``DATASET_REPO`` takes priority when non-empty; otherwise the value of
    ``HF_DATASET_REPO`` is returned as-is (possibly ``None``).
    """
    primary = os.environ.get("DATASET_REPO")
    if primary:
        return primary
    return os.environ.get("HF_DATASET_REPO")
# Resolve Hub credentials and the persistence repo once, at import time.
HF_TOKEN = _get_hf_token()
DATASET_REPO = _get_dataset_repo()

# Log in process-wide when a token exists. A failure is only printed: the
# dataset helpers build their own HfApi(token=...) and do not rely on this.
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as login_error:
        print("HF login failed:", repr(login_error))
# -------------------------
# Pipeline init
# -------------------------
# Load the layered-decomposition pipeline once at import time, in bfloat16,
# and move it to CUDA when available (CPU otherwise).
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = QwenImageLayeredPipeline.from_pretrained(
    "Qwen/Qwen-Image-Layered",
    torch_dtype=dtype,
)
pipeline = pipeline.to(device)
def ensure_dirname(path: str):
    """Create directory *path* (including parents) unless it already exists.

    An empty or falsy *path* is ignored.
    """
    if not path or os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)
def px_to_emu(px, dpi=96):
    """Convert a pixel length to PowerPoint EMUs (914400 EMU per inch).

    The result is truncated toward zero by ``int``.
    """
    return int(px / dpi * 914400)
def imagelist_to_pptx_from_pils(images: List[Image.Image]) -> str:
    """Export *images* as a one-slide PPTX with all layers stacked at the origin.

    The slide is sized from the first image (pixels converted at 96 DPI via
    px_to_emu) and every picture is stretched to that same size. Later images
    are added after earlier ones on the same slide.

    Returns:
        Path of a temporary ``.pptx`` file; the caller owns it (it is not
        deleted automatically).

    Raises:
        ValueError: if *images* is empty.
    """
    if not images:
        raise ValueError("No images to export")
    w, h = images[0].size
    width_emu = px_to_emu(w)
    height_emu = px_to_emu(h)

    prs = Presentation()
    prs.slide_width = width_emu
    prs.slide_height = height_emu
    # Index 6 is the blank layout in python-pptx's default template.
    slide = prs.slides.add_slide(prs.slide_layouts[6])

    for img in images:
        # add_picture accepts a file-like object, so each layer is encoded in
        # memory instead of leaving an open, never-deleted temp PNG behind.
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        buf.seek(0)
        slide.shapes.add_picture(buf, 0, 0, width=width_emu, height=height_emu)

    # mkstemp returns an open descriptor; close it before python-pptx
    # reopens the path for writing.
    fd, out_path = tempfile.mkstemp(suffix=".pptx")
    os.close(fd)
    prs.save(out_path)
    return out_path
def imagelist_to_zip_from_pils(images: List[Image.Image], prefix: str = "layer") -> str:
    """Write *images* as PNGs into a temporary deflate-compressed zip archive.

    Entries are named ``{prefix}_1.png``, ``{prefix}_2.png``, ... in list
    order. Returns the zip path; the caller owns the file.
    """
    # Write through the temp file's own handle and close it before returning,
    # instead of leaving it open while reopening the path by name.
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as handle:
        with zipfile.ZipFile(handle, "w", zipfile.ZIP_DEFLATED) as zipf:
            for number, img in enumerate(images, start=1):
                buf = io.BytesIO()
                img.save(buf, format="PNG")
                zipf.writestr(f"{prefix}_{number}.png", buf.getvalue())
        return handle.name
def _clamp_int(x, default: int, lo: int, hi: int) -> int:
    """Coerce *x* with ``int()`` and clamp it into ``[lo, hi]``.

    If the conversion raises anything, *default* is used instead (and is
    clamped too). When ``lo > hi`` the result is ``lo``.
    """
    try:
        value = int(x)
    except Exception:
        value = default
    upper_bounded = min(hi, value)
    return upper_bounded if upper_bounded >= lo else lo
def _normalize_resolution(resolution: Any) -> int:
    """Map *resolution* onto one of the two supported sizes, 640 or 1024.

    Exactly 1024 (after clamping to [640, 1024]) stays 1024; every other
    value, including unparsable input, becomes 640.
    """
    value = _clamp_int(resolution, default=640, lo=640, hi=1024)
    return value if value in (640, 1024) else 640
def _normalize_input_image(input_image: Any) -> Image.Image:
    """Coerce a Gradio image payload into an RGBA PIL image.

    Accepted inputs are a file path (``str`` or path-like), a PIL image, a
    numpy array, or a list/tuple wrapping one of those. Gallery values and
    ``(image, caption)`` pairs are such wrappers; their first element is used.
    Any alpha channel in the input is dropped (conversion to RGB first) and an
    opaque alpha channel is then added.

    Raises:
        ValueError: if no image was supplied or its type is unsupported.
    """
    # Unwrap list/tuple containers until we reach the actual image.
    while isinstance(input_image, (list, tuple)):
        if not input_image:
            raise ValueError("No input image provided")
        input_image = input_image[0]
    if input_image is None:
        raise ValueError("No input image provided")

    if isinstance(input_image, (str, os.PathLike)):
        # Use a context manager so the underlying file handle is closed.
        with Image.open(input_image) as opened:
            return opened.convert("RGB").convert("RGBA")
    if isinstance(input_image, Image.Image):
        return input_image.convert("RGB").convert("RGBA")
    if isinstance(input_image, np.ndarray):
        return Image.fromarray(input_image).convert("RGB").convert("RGBA")
    raise ValueError(f"Unsupported input_image type: {type(input_image)}")
# -------------------------
# Dataset persistence helpers
# -------------------------
def ds_enabled() -> bool:
    """Return True when both an HF token and a dataset repo id are configured."""
    has_token = bool(_get_hf_token())
    return has_token and bool(_get_dataset_repo())
def ds_api() -> HfApi:
    """Build an HfApi client authenticated with the environment token.

    Raises:
        RuntimeError: if no token is configured.
    """
    token = _get_hf_token()
    if token:
        return HfApi(token=token)
    raise RuntimeError("HF token missing")
def ds_repo_id() -> str:
    """Return the configured dataset repo id.

    Raises:
        RuntimeError: if neither DATASET_REPO nor HF_DATASET_REPO is set.
    """
    repo = _get_dataset_repo()
    if repo:
        return repo
    raise RuntimeError("DATASET_REPO missing")
def ds_ensure_repo() -> Tuple[bool, str]:
    """Create the dataset repo (private) if it does not exist yet.

    Returns:
        ``(ok, message)``. Hub and other errors are reported through the
        message rather than raised.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled: missing HF_TOKEN/hf or DATASET_REPO"
    api, repo_id = ds_api(), ds_repo_id()
    try:
        api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)
    except HfHubHTTPError as err:
        return False, f"Failed to create/ensure dataset repo: {err}"
    except Exception as err:
        return False, f"Failed to create/ensure dataset repo: {err!r}"
    return True, f"Dataset repo ready: {repo_id}"
def ds_upload_bytes(path_in_repo: str, data: bytes, commit_message: str) -> Tuple[bool, str]:
    """Upload *data* to ``path_in_repo`` in the dataset repo as one commit.

    Returns:
        ``(ok, message)``; upload errors are reported, not raised.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled: missing HF_TOKEN/hf or DATASET_REPO"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        # upload_file accepts raw bytes directly. This avoids the previous
        # delete=False temp file, which was never removed.
        api.upload_file(
            path_or_fileobj=data,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
        return True, f"Uploaded: {path_in_repo}"
    except HfHubHTTPError as e:
        return False, f"Upload failed (HTTP): {e}"
    except Exception as e:
        return False, f"Upload failed: {repr(e)}"
def ds_download_bytes(path_in_repo: str) -> Tuple[Optional[bytes], str]:
    """Download one file from the dataset repo and return its content.

    Returns:
        ``(data, "OK")`` on success, or ``(None, error message)`` on failure.
    """
    if not ds_enabled():
        return None, "Dataset persistence disabled"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        # The download directory is removed once the bytes have been read
        # (mkdtemp previously leaked one directory per call).
        with tempfile.TemporaryDirectory() as tmpdir:
            local_path = api.hf_hub_download(
                repo_id=repo_id,
                repo_type="dataset",
                filename=path_in_repo,
                local_dir=tmpdir,
            )
            with open(local_path, "rb") as f:
                data = f.read()
        return data, "OK"
    except HfHubHTTPError as e:
        return None, f"Download failed (HTTP): {e}"
    except Exception as e:
        return None, f"Download failed: {repr(e)}"
def ds_list_sessions(max_sessions: int = 50) -> Tuple[List[str], str]:
    """List session ids that have a ``session.json`` under ``sessions/<id>/``.

    Ids are sorted in descending string order (not by time) and truncated to
    *max_sessions*. Errors are reported in the message.
    """
    if not ds_enabled():
        return [], "Dataset persistence disabled"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        repo_paths = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        session_ids = {
            path.split("/")[1]
            for path in repo_paths
            if path.startswith("sessions/")
            and path.endswith("session.json")
            and len(path.split("/")) >= 3
        }
        ordered = sorted(session_ids, reverse=True)[:max_sessions]
        return ordered, f"Found {len(ordered)} session(s)"
    except Exception as e:
        return [], f"List sessions failed: {repr(e)}"
def ds_get_root_index() -> Tuple[Optional[Dict[str, Any]], str]:
    """Download and parse the dataset's root ``index.json``.

    The index points at the last saved session; readers look for the keys
    ``id``, ``last_session_id`` or ``session_id``.

    Returns:
        ``(dict, "OK")``, or ``(None, reason)`` if the file is missing,
        unparsable, or not a JSON object.
    """
    raw, msg = ds_download_bytes("index.json")
    if raw is None:
        return None, msg
    try:
        parsed = json.loads(raw.decode("utf-8"))
    except Exception as e:
        return None, f"Failed to parse index.json: {repr(e)}"
    if isinstance(parsed, dict):
        return parsed, "OK"
    return None, "index.json is not an object"
def ds_get_last_session_id() -> Tuple[Optional[str], str]:
    """Return the last saved session id from ``index.json``.

    Keys are tried in order ``id``, ``last_session_id``, ``session_id``. The
    first non-blank string is returned stripped; otherwise ``(None, reason)``.
    """
    idx, msg = ds_get_root_index()
    if not idx:
        return None, msg
    candidates = (idx.get(key) for key in ("id", "last_session_id", "session_id"))
    found = next(
        (value.strip() for value in candidates if isinstance(value, str) and value.strip()),
        None,
    )
    if found is None:
        return None, "index.json missing id/last_session_id/session_id"
    return found, "OK"
def ds_set_root_index(session_id: str) -> Tuple[bool, str]:
    """Point the dataset root ``index.json`` at *session_id*.

    This lets the app auto-load the most recent session after a refresh or
    restart. Returns the ``(ok, message)`` pair from ds_upload_bytes.
    """
    now = time.time()
    body = json.dumps(
        {"id": session_id, "last_session_id": session_id, "updated_at": now},
        ensure_ascii=False,
        indent=2,
    )
    return ds_upload_bytes("index.json", body.encode("utf-8"), f"update index.json {session_id}")
# -------------------------
# Node / History model
# -------------------------
@dataclass
class NodeMeta:
    """Metadata for one node in the layer-history tree.

    _add_node converts instances with ``asdict`` and stores the resulting
    dict under ``state["nodes"][node_id]["meta"]``. That dict is what gets
    written to node.json / session.json when persisting.
    """

    node_id: str
    name: str  # display name used in the history dropdown label
    parent_id: Optional[str]  # None for the decompose root node
    children: List[str]  # ids of child nodes; _add_node appends to this
    op: str  # "root" | "decompose" | "refine" | "duplicate" (visible code creates decompose/refine/duplicate)
    created_at: float  # time.time() at creation; history is ordered by this
    # for redo refine: which node/layer was refined and into how many sub-layers
    source_node_id: Optional[str] = None
    source_layer_idx: Optional[int] = None
    sub_layers: Optional[int] = None
    # settings snapshot (store as dict for export/debug)
    settings: Optional[Dict[str, Any]] = None
def _new_id(prefix: str) -> str:
    """Return ``<prefix>_`` followed by 10 random lowercase hex characters."""
    suffix = uuid.uuid4().hex[:10]
    return "{}_{}".format(prefix, suffix)
def _make_chips(state: Dict[str, Any]) -> str:
    """Summarize the selected node as ``[root:…] [parent:…] [children:N]``.

    Missing ids render as ``-``. With no valid selection the parent is ``-``
    and the child count is 0; the same ``[key:value]`` format is used in
    every case. A missing or None ``nodes`` entry is treated as empty.
    """
    nodes: Dict[str, Any] = state.get("nodes") or {}
    root = state.get("root_node_id") or "-"
    node_id = state.get("selected_node_id")
    if not node_id or node_id not in nodes:
        return f"[root:{root}] [parent:-] [children:0]"
    meta = nodes[node_id]["meta"]
    parent = meta.get("parent_id") or "-"
    n_children = len(meta.get("children") or [])
    return f"[root:{root}] [parent:{parent}] [children:{n_children}]"
def _history_choices(state: Dict[str, Any]) -> List[Tuple[str, str]]:
    """Return history entries as ``(label, node_id)`` pairs, oldest first.

    Nodes are ordered by ``meta["created_at"]``; ties keep dict insertion
    order because the sort is stable. Labels are ``"<name> — <node_id>"``; a
    separator keeps the name and the id readable. A missing or None ``nodes``
    entry yields an empty list.
    """
    nodes: Dict[str, Any] = state.get("nodes") or {}
    ordered = sorted(nodes.items(), key=lambda item: item[1]["meta"]["created_at"])
    return [(f"{obj['meta']['name']} — {nid}", nid) for nid, obj in ordered]
def _get_node_images(state: Dict[str, Any], node_id: str) -> List[Image.Image]:
    """Return the in-memory layer images of *node_id*, or ``[]`` if unknown/empty."""
    try:
        node = state.get("nodes", {})[node_id]
    except KeyError:
        return []
    return node.get("images", []) or []
def _set_selected_node(state: Dict[str, Any], node_id: str):
    """Mark *node_id* as selected (in place) and return the same state dict."""
    state.update(selected_node_id=node_id)
    return state
def _add_node(
    state: Dict[str, Any],
    *,
    name: str,
    parent_id: Optional[str],
    op: str,
    images: List[Image.Image],
    settings: Optional[Dict[str, Any]] = None,
    source_node_id: Optional[str] = None,
    source_layer_idx: Optional[int] = None,
    sub_layers: Optional[int] = None,
) -> str:
    """Insert a new history node into ``state["nodes"]`` and return its id.

    The metadata is stored as a plain dict (``asdict(NodeMeta)``) next to the
    image list. If *parent_id* names an existing node, the new id is appended
    to that parent's ``children``. A missing ``settings`` becomes ``{}``.
    """
    node_id = _new_id("node")
    record = asdict(
        NodeMeta(
            node_id=node_id,
            name=name,
            parent_id=parent_id,
            children=[],
            op=op,
            created_at=time.time(),
            source_node_id=source_node_id,
            source_layer_idx=source_layer_idx,
            sub_layers=sub_layers,
            settings=settings or {},
        )
    )
    registry = state.setdefault("nodes", {})
    registry[node_id] = {"meta": record, "images": images}
    if parent_id and parent_id in registry:
        registry[parent_id]["meta"]["children"].append(node_id)
    return node_id
def _rename_node(state: Dict[str, Any], node_id: str, new_name: str):
    """Rename *node_id* in place.

    Surrounding whitespace is stripped. Empty, whitespace-only or None names
    are ignored, and so are unknown node ids.
    """
    cleaned = (new_name or "").strip()
    if not cleaned:
        return
    if "nodes" in state and node_id in state["nodes"]:
        state["nodes"][node_id]["meta"]["name"] = cleaned
def _duplicate_node(state: Dict[str, Any], node_id: str) -> Optional[str]:
    """Copy *node_id* as a sibling (same parent) and return the new id.

    The copy is named ``"<name> (copy)"``, has op ``"duplicate"``, and carries
    its own copy of the image list, so later edits to one node's list do not
    affect the other. Returns None if *node_id* is unknown.
    """
    if "nodes" not in state or node_id not in state["nodes"]:
        return None
    src = state["nodes"][node_id]
    meta = src["meta"]
    name = f"{meta.get('name','node')} (copy)"
    return _add_node(
        state,
        name=name,
        parent_id=meta.get("parent_id"),
        op="duplicate",
        # Shallow copy: the PIL images are shared, but the list is not.
        images=list(src.get("images") or []),
        settings=meta.get("settings") or {},
    )
# -------------------------
# GPU duration + pipeline runner
# -------------------------
def get_duration(*args, **kwargs):
    """Compute the GPU time budget for a gpu_run_pipeline call.

    Used as the ``duration`` callable of ``spaces.GPU``. Only a
    ``gpu_duration`` keyword argument is consulted; it defaults to 1000 and is
    clamped to [20, 1500]. Positional arguments are ignored.
    """
    requested = kwargs.get("gpu_duration", 1000)
    return _clamp_int(requested, default=1000, lo=20, hi=1500)
@spaces.GPU(duration=get_duration)
def gpu_run_pipeline(
    pil_image_rgba: Image.Image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
):
    """Run the layered-decomposition pipeline on one image.

    ``gpu_duration`` is not used in the body; get_duration reads it to size
    the GPU allocation. When ``randomize_seed`` is set, a seed in
    [0, MAX_SEED] replaces ``seed``.

    Returns:
        ``(layers, seed_used, inputs)``. ``layers`` is ``out.images[0]`` from
        the pipeline (the list of layer images), ``seed_used`` is the int
        seed, and ``inputs`` is the exact kwargs dict passed to the pipeline.
    """
    seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    resolution = _normalize_resolution(resolution)

    generator_device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=generator_device).manual_seed(int(seed))

    inputs = dict(
        image=pil_image_rgba,
        generator=generator,
        true_cfg_scale=true_guidance_scale,
        prompt=prompt,
        negative_prompt=neg_prompt,
        num_inference_steps=int(num_inference_steps),
        num_images_per_prompt=1,
        layers=int(layer),
        resolution=int(resolution),  # 640 or 1024
        cfg_normalize=bool(cfg_norm),
        use_en_prompt=bool(use_en_prompt),
    )

    # Best-effort cache release before the run; failures are deliberately ignored.
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass

    with torch.inference_mode():
        result = pipeline(**inputs)
    return result.images[0], int(seed), inputs
# -------------------------
# Dataset persistence: save/load nodes + session
# -------------------------
def _pil_to_png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG and return the raw bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        return sink.getvalue()
def _png_bytes_to_pil(b: bytes) -> Image.Image:
    """Decode encoded image bytes (PNG here) into an RGBA PIL image."""
    source = io.BytesIO(b)
    return Image.open(source).convert("RGBA")
def _session_base(session_id: str) -> str:
    """Return the dataset folder of a session: ``sessions/<session_id>``."""
    return "sessions/{}".format(session_id)
def _node_base(session_id: str, node_id: str) -> str:
    """Return the dataset folder of a node: ``sessions/<sid>/nodes/<node_id>``."""
    return _session_base(session_id) + f"/nodes/{node_id}"
def _persist_node_to_dataset(state: Dict[str, Any], node_id: str) -> Tuple[bool, str]:
    """Upload one node's ``node.json`` and its ``layer_<i>.png`` files.

    Files go under ``sessions/<sid>/nodes/<node_id>/``, and each upload is a
    separate commit (via ds_upload_bytes). The first failure aborts the save.

    Returns:
        ``(ok, message)``.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled. Set DATASET_REPO env var and provide HF_TOKEN/hf."
    ready, reason = ds_ensure_repo()
    if not ready:
        return False, reason

    session_id = state.get("session_id")
    if not session_id:
        return False, "No session_id in state (run Decompose first)"
    nodes = state.get("nodes", {})
    if node_id not in nodes:
        return False, "Unknown node_id"

    node = nodes[node_id]
    base = _node_base(session_id, node_id)

    meta_blob = json.dumps(node["meta"], ensure_ascii=False, indent=2).encode("utf-8")
    uploaded, upload_msg = ds_upload_bytes(f"{base}/node.json", meta_blob, f"save node {node_id}")
    if not uploaded:
        return False, upload_msg

    for number, img in enumerate(node.get("images", []) or [], start=1):
        uploaded, upload_msg = ds_upload_bytes(
            f"{base}/layer_{number}.png",
            _pil_to_png_bytes(img),
            f"save node {node_id} layer {number}",
        )
        if not uploaded:
            return False, upload_msg
    return True, f"Saved node {node_id} to dataset"
def _saved_layer_counts(session_id: str) -> Dict[str, int]:
    """Return ``{node_id: num_layers}`` from the session.json already in the dataset.

    This is a best-effort fallback: a missing or malformed manifest yields ``{}``.
    """
    raw, _ = ds_download_bytes(f"{_session_base(session_id)}/session.json")
    if raw is None:
        return {}
    try:
        saved = json.loads(raw.decode("utf-8"))
        return {
            nid: int((obj or {}).get("num_layers", 0))
            for nid, obj in (saved.get("nodes") or {}).items()
        }
    except Exception:
        # Fallback data only; a bad previous manifest must not block saving.
        return {}


def _persist_session_manifest(state: Dict[str, Any]) -> Tuple[bool, str]:
    """Upload ``sessions/<sid>/session.json`` and refresh the root ``index.json``.

    The manifest contains the session fields plus every node's metadata and
    ``num_layers``, but no images. Some nodes may have no images in memory;
    on_load_session only loads the selected and root nodes. For those nodes
    the ``num_layers`` from the previously saved manifest is kept, so they
    stay loadable instead of being recorded with 0 layers.

    Returns:
        ``(ok, message)``. An index.json failure is only a warning and does
        not fail the save.
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled"
    ok, msg = ds_ensure_repo()
    if not ok:
        return False, msg
    session_id = state.get("session_id")
    if not session_id:
        return False, "No session_id"

    nodes = state.get("nodes", {}) or {}
    counts = {nid: len(obj.get("images", []) or []) for nid, obj in nodes.items()}
    if any(n == 0 for n in counts.values()):
        saved_counts = _saved_layer_counts(session_id)
        counts = {nid: n or saved_counts.get(nid, 0) for nid, n in counts.items()}

    # Only the manifest is written here (no images).
    manifest = {
        "session_id": session_id,
        "created_at": state.get("created_at"),
        "root_node_id": state.get("root_node_id"),
        "selected_node_id": state.get("selected_node_id"),
        "nodes": {
            nid: {"meta": obj["meta"], "num_layers": counts[nid]}
            for nid, obj in nodes.items()
        },
    }
    b = json.dumps(manifest, ensure_ascii=False, indent=2).encode("utf-8")
    path = f"{_session_base(session_id)}/session.json"
    ok_m, msg_m = ds_upload_bytes(path, b, f"save session manifest {session_id}")
    if not ok_m:
        return False, msg_m

    # Update root index.json (best-effort, do not fail save if it can't be updated)
    ok_i, msg_i = ds_set_root_index(session_id)
    if not ok_i:
        return True, f"{msg_m} (warning: index.json update failed: {msg_i})"
    return True, f"{msg_m} + updated index.json"
def _load_session_manifest(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]:
    """Download and parse ``sessions/<session_id>/session.json``.

    Returns:
        ``(manifest, "OK")``, or ``(None, reason)`` when the download fails,
        the JSON is invalid, or the top-level value is not an object. Callers
        call ``.get`` on the result, so non-objects are rejected here, the
        same way ds_get_root_index does.
    """
    raw, msg = ds_download_bytes(f"{_session_base(session_id)}/session.json")
    if raw is None:
        return None, msg
    try:
        manifest = json.loads(raw.decode("utf-8"))
    except Exception as e:
        return None, f"Failed to parse manifest: {repr(e)}"
    if not isinstance(manifest, dict):
        return None, "session.json is not an object"
    return manifest, "OK"
def _load_node_images(session_id: str, node_id: str, num_layers: int) -> Tuple[List[Image.Image], str]:
    """Download ``layer_1.png`` … ``layer_<num_layers>.png`` of a node as RGBA images.

    This is all-or-nothing: the first failed download returns ``([], reason)``.
    A non-positive *num_layers* returns ``([], "OK")``.
    """
    base = _node_base(session_id, node_id)
    loaded: List[Image.Image] = []
    for number in range(1, num_layers + 1):
        raw, msg = ds_download_bytes(f"{base}/layer_{number}.png")
        if raw is None:
            return [], msg
        loaded.append(_png_bytes_to_pil(raw))
    return loaded, "OK"
# -------------------------
# UI callbacks
# -------------------------
def _init_state() -> Dict[str, Any]:
    """Return a fresh, empty UI state with its own ``nodes`` dict.

    ``nodes`` maps node_id -> {"meta": dict, "images": list}; every other key
    starts as None.
    """
    state: Dict[str, Any] = dict.fromkeys(
        (
            "session_id",
            "created_at",
            "root_node_id",
            "selected_node_id",
            "nodes",
            "last_refined_node_id",
        )
    )
    state["nodes"] = {}
    return state
def _persistence_status_text() -> str:
    """Return a Markdown banner describing whether dataset persistence is configured."""
    tok, repo = _get_hf_token(), _get_dataset_repo()
    if tok and repo:
        return f"✅ Dataset persistence enabled: `{repo}`"
    # From here on, at most one of tok/repo is set.
    if repo:
        return "⚠️ Dataset repo set, but HF_TOKEN/hf missing"
    if tok:
        return "⚠️ HF_TOKEN/hf set, but DATASET_REPO missing"
    return "⚠️ Dataset persistence disabled (set HF_TOKEN + DATASET_REPO secrets to enable)"
def on_refresh_sessions():
    """Reload the session dropdown choices; preselect the first listed id."""
    sessions, msg = ds_list_sessions()
    first = sessions[0] if sessions else None
    return gr.update(choices=sessions, value=first), msg
def on_init_dataset():
    """Ensure the dataset repo exists and return the status message only."""
    return ds_ensure_repo()[1]
def _current_node_export(state: Dict[str, Any], node_id: str) -> Tuple[Optional[str], Optional[str], str]:
    """Build PPTX and ZIP exports for a node's in-memory layers.

    Returns ``(pptx_path, zip_path, "OK")``, or ``(None, None, reason)`` when
    the node has no images. New temp files are created on every call.
    """
    imgs = _get_node_images(state, node_id)
    if imgs:
        return (
            imagelist_to_pptx_from_pils(imgs),
            imagelist_to_zip_from_pils(imgs, prefix=f"{node_id}_layer"),
            "OK",
        )
    return None, None, "No images to export"
def _build_layer_dropdown(n: int) -> Tuple[List[str], Optional[str]]:
    """Return ``(["Layer 1", …, "Layer n"], "Layer 1")`` for the layer picker.

    For ``n <= 0`` this returns ``([], None)``. The return annotation now
    reflects that the selected value may be None.
    """
    if n <= 0:
        return [], None
    choices = [f"Layer {i}" for i in range(1, n + 1)]
    return choices, choices[0]
def _layer_label(idx: int, n: int) -> str:
    """Return "Selected: Layer i / n" with *idx* clamped, or "Selected: -" when n <= 0."""
    if n > 0:
        shown = min(max(idx, 0), n - 1) + 1
        return f"Selected: Layer {shown} / {n}"
    return "Selected: -"
def on_decompose_click(
    state: Dict[str, Any],
    input_image,
    seed,
    randomize_seed,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    layer,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
):
    """Decompose the input image and start a new history tree.

    A session id is created only on the first call. Every call discards the
    existing nodes and installs a single root node holding the new layers.

    Returns 14 UI outputs, in order: state, layers gallery, picker gallery,
    layer-dropdown update, selected-index update, selected-layer label,
    history-dropdown update, chips text, refined-block visibility update,
    refined gallery, pptx path, zip path, status text, seed text.
    """
    if not isinstance(state, dict):
        state = _init_state()

    rgba = _normalize_input_image(input_image)

    if not state.get("session_id"):
        state["session_id"] = _new_id("sess")
        state["created_at"] = time.time()

    layers_out, used_seed, _unused_inputs = gpu_run_pipeline(
        pil_image_rgba=rgba,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=layer,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )

    # Effective settings, stored on the node for export/debug and redo.
    settings_snapshot = dict(
        seed=used_seed,
        randomize_seed=bool(randomize_seed),
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=float(true_guidance_scale),
        num_inference_steps=int(num_inference_steps),
        layers=int(layer),
        resolution=int(_normalize_resolution(resolution)),
        cfg_norm=bool(cfg_norm),
        use_en_prompt=bool(use_en_prompt),
        gpu_duration=int(_clamp_int(gpu_duration, 1000, 20, 1500)),
    )

    # A new decompose replaces the whole history; the session id is kept.
    state["nodes"] = {}
    state["last_refined_node_id"] = None
    root_id = _add_node(
        state,
        name="root (decompose)",
        parent_id=None,
        op="decompose",
        images=layers_out,
        settings=settings_snapshot,
    )
    state["root_node_id"] = root_id
    state["selected_node_id"] = root_id

    n_layers = len(layers_out)
    layer_choices, first_layer = _build_layer_dropdown(n_layers)
    history_values = [value for _label, value in _history_choices(state)]
    pptx_path, zip_path, export_msg = _current_node_export(state, root_id)
    return (
        state,
        layers_out,  # main layers gallery
        layers_out,  # mini picker gallery
        gr.update(choices=layer_choices, value=first_layer),
        gr.update(value=0),  # selected layer index
        _layer_label(0, n_layers),
        gr.update(choices=history_values, value=root_id),
        _make_chips(state),
        gr.update(visible=False),  # refined block is hidden after a fresh decompose
        [],
        pptx_path,
        zip_path,
        f"Decomposed into {n_layers} layer(s). Seed={used_seed}. {export_msg}",
        str(used_seed),
    )
def on_layer_pick_from_dropdown(state: Dict[str, Any], layer_name: str):
    """Translate a "Layer N" dropdown choice into a clamped 0-based index.

    Returns ``(index update, selected-layer label)``. Unparsable names map to 0.
    """
    node_id = state.get("selected_node_id")
    count = len(_get_node_images(state, node_id)) if node_id else 0

    picked = 0
    if layer_name and layer_name.startswith("Layer "):
        try:
            picked = int(layer_name.replace("Layer ", "").strip()) - 1
        except Exception:
            picked = 0

    picked = min(max(picked, 0), count - 1) if count > 0 else 0
    return gr.update(value=picked), _layer_label(picked, count)
def on_layer_pick_from_gallery(state: Dict[str, Any], evt: gr.SelectData):
    """Sync the layer index, dropdown and label with a gallery click.

    The ``gr.SelectData`` annotation is what makes Gradio inject the event,
    so it must stay. Returns ``(index update, dropdown update, label)``.
    """
    node_id = state.get("selected_node_id")
    count = len(_get_node_images(state, node_id)) if node_id else 0
    picked = int(evt.index) if evt and evt.index is not None else 0

    if count > 0:
        picked = min(max(picked, 0), count - 1)
        dropdown_value = f"Layer {picked+1}"
    else:
        picked = 0
        dropdown_value = None
    return gr.update(value=picked), gr.update(value=dropdown_value), _layer_label(picked, count)
def _refine_from_source(
    state: Dict[str, Any],
    source_node_id: str,
    source_layer_idx: int,
    sub_layers: int,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    """Run the pipeline on one layer of a node, splitting it into *sub_layers*.

    The selected layer image is passed to the pipeline as-is (it does not go
    through _normalize_input_image).

    Returns:
        ``(layers, seed_used, settings_snapshot)``. The snapshot includes a
        ``refined_from`` entry recording the source node and layer.

    Raises:
        ValueError: if the source node has no images or the index is out of range.
    """
    source_layers = _get_node_images(state, source_node_id)
    if not source_layers:
        raise ValueError("Source node has no images")
    if not 0 <= source_layer_idx < len(source_layers):
        raise ValueError("Invalid layer index")

    layers_out, used_seed, _unused_inputs = gpu_run_pipeline(
        pil_image_rgba=source_layers[source_layer_idx],
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=sub_layers,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )

    settings_snapshot = dict(
        seed=used_seed,
        randomize_seed=bool(randomize_seed),
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=float(true_guidance_scale),
        num_inference_steps=int(num_inference_steps),
        layers=int(sub_layers),
        resolution=int(_normalize_resolution(resolution)),
        cfg_norm=bool(cfg_norm),
        use_en_prompt=bool(use_en_prompt),
        gpu_duration=int(_clamp_int(gpu_duration, 1000, 20, 1500)),
        refined_from={
            "source_node_id": source_node_id,
            "source_layer_idx": int(source_layer_idx),
        },
    )
    return layers_out, used_seed, settings_snapshot
def on_refine_click(
    state: Dict[str, Any],
    selected_layer_idx: int,
    sub_layers: int,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    """Split one layer of the selected node into sub-layers and select the result.

    The layer index is clamped to the node's range, and ``sub_layers`` is
    clamped to [2, 10] (default 3). The result becomes a new ``refine`` child
    of the selected node.

    Returns 14 UI outputs in the same order as on_decompose_click. The last
    one (seed text) is always left unchanged.
    """

    def _abort(message: str):
        # Leave most components untouched; hide the refined block, clear its
        # gallery and the export files, and show *message*.
        return (
            state,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(visible=False),
            [],
            None,
            None,
            message,
            gr.update(),
        )

    source_node_id = state.get("selected_node_id")
    if not source_node_id:
        return _abort("No selected node. Run Decompose first.")
    source_layers = _get_node_images(state, source_node_id)
    if not source_layers:
        return _abort("Selected node has no images.")

    layer_idx = 0 if selected_layer_idx is None else int(selected_layer_idx)
    layer_idx = min(max(layer_idx, 0), len(source_layers) - 1)
    sub_layers = _clamp_int(sub_layers, default=3, lo=2, hi=10)

    layers_out, used_seed, settings_snapshot = _refine_from_source(
        state,
        source_node_id=source_node_id,
        source_layer_idx=layer_idx,
        sub_layers=sub_layers,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
        seed=seed,
        randomize_seed=randomize_seed,
    )

    source_name = state["nodes"][source_node_id]["meta"]["name"]
    child_id = _add_node(
        state,
        name=f"refine ({source_name}) L{layer_idx+1}",
        parent_id=source_node_id,
        op="refine",
        images=layers_out,
        settings=settings_snapshot,
        source_node_id=source_node_id,
        source_layer_idx=layer_idx,
        sub_layers=sub_layers,
    )
    state["selected_node_id"] = child_id
    state["last_refined_node_id"] = child_id

    n_layers = len(layers_out)
    layer_choices, first_layer = _build_layer_dropdown(n_layers)
    history_values = [value for _label, value in _history_choices(state)]
    pptx_path, zip_path, export_msg = _current_node_export(state, child_id)
    return (
        state,
        layers_out,
        layers_out,
        gr.update(choices=layer_choices, value=first_layer),
        gr.update(value=0),
        _layer_label(0, n_layers),
        gr.update(choices=history_values, value=child_id),
        _make_chips(state),
        gr.update(visible=True),
        layers_out,
        pptx_path,
        zip_path,
        f"Refined into {n_layers} sub-layer(s). Seed={used_seed}. {export_msg}",
        gr.update(),
    )
def on_history_select(state: Dict[str, Any], node_id: str):
    """Make *node_id* the selected node and refresh every view of it.

    Only images already in memory are shown; nothing is downloaded here.
    Returns 13 UI outputs: state, layers gallery, picker gallery, layer
    dropdown, selected index, selected label, history dropdown, chips,
    refined-block visibility (always hidden), refined gallery, pptx path,
    zip path, status.
    """
    known = state.get("nodes")
    if not (node_id and known and node_id in known):
        return (
            state,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(visible=False),
            [],
            None,
            None,
            "Unknown node.",
        )

    state["selected_node_id"] = node_id
    layers = _get_node_images(state, node_id)
    count = len(layers)
    choices, first_choice = _build_layer_dropdown(count)
    history_values = [value for _label, value in _history_choices(state)]
    pptx_path, zip_path, export_msg = _current_node_export(state, node_id)
    return (
        state,
        layers,
        layers,
        gr.update(choices=choices, value=first_choice),
        gr.update(value=0),
        _layer_label(0, count),
        gr.update(choices=history_values, value=node_id),
        _make_chips(state),
        gr.update(visible=False),
        [],
        pptx_path,
        zip_path,
        f"Selected node: {node_id}. {export_msg}",
    )
def on_back_to_parent(state: Dict[str, Any]):
    """Select the parent of the currently selected node.

    Every path returns the same 13 outputs as on_history_select, because a
    Gradio handler must always return one value per wired output. The error
    paths used to return only 3 values. When nothing can be done, every
    component is left unchanged and only the status message is updated.
    """

    def _no_change(message: str):
        # state + 11 untouched components + status text = 13 outputs.
        return (state,) + tuple(gr.update() for _ in range(11)) + (message,)

    node_id = state.get("selected_node_id")
    if not node_id or node_id not in state.get("nodes", {}):
        return _no_change("No selected node.")
    parent = state["nodes"][node_id]["meta"].get("parent_id")
    if not parent:
        return _no_change("Already at root.")
    return on_history_select(state, parent)
def on_duplicate_node(state: Dict[str, Any]):
    """Duplicate the selected node (as a sibling) and select the copy.

    Every path returns the same 13 outputs as on_history_select, so the
    Gradio output count always matches. The error paths used to return only
    3 values. On failure, components are left unchanged and only the status
    message is updated.
    """

    def _no_change(message: str):
        # state + 11 untouched components + status text = 13 outputs.
        return (state,) + tuple(gr.update() for _ in range(11)) + (message,)

    node_id = state.get("selected_node_id")
    if not node_id:
        return _no_change("No selected node.")
    new_id = _duplicate_node(state, node_id)
    if not new_id:
        return _no_change("Duplicate failed.")
    return on_history_select(state, new_id)
def on_rename_node(state: Dict[str, Any], new_name: str):
    """Rename the selected node and refresh the history dropdown and chips.

    Returns 4 outputs on every path: ``(state, history dropdown update,
    chips, status)``. The no-selection path used to return only 3 values.
    """
    node_id = state.get("selected_node_id")
    if not node_id:
        return state, gr.update(), gr.update(), "No selected node."
    _rename_node(state, node_id, new_name)
    history_values = [value for _label, value in _history_choices(state)]
    chips = _make_chips(state)
    return state, gr.update(choices=history_values, value=node_id), chips, "Renamed."
def on_redo_refine(
    state: Dict[str, Any],
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    """Re-run the refine that produced the selected node, using the current settings.

    The source node, layer index and sub-layer count come from the selected
    node's metadata. The new result is added as another child of the original
    source node, which makes it a sibling of the selected node, and is then
    selected. Returns 13 outputs in the same order as on_history_select.
    """

    def _reject(reason: str):
        return (
            state,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(visible=False),
            [],
            None,
            None,
            reason,
        )

    node_id = state.get("selected_node_id")
    if not node_id or node_id not in state.get("nodes", {}):
        return _reject("No selected node.")
    meta = state["nodes"][node_id]["meta"]
    if meta.get("op") != "refine":
        return _reject("Redo works only for refine nodes.")

    source_node_id = meta.get("source_node_id")
    source_layer_idx = meta.get("source_layer_idx")
    sub_layers = meta.get("sub_layers")
    if not source_node_id or source_layer_idx is None or not sub_layers:
        return _reject("Missing refine metadata.")
    layer_idx = int(source_layer_idx)
    n_sub = int(sub_layers)

    layers_out, used_seed, settings_snapshot = _refine_from_source(
        state,
        source_node_id=source_node_id,
        source_layer_idx=layer_idx,
        sub_layers=n_sub,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
        seed=seed,
        randomize_seed=randomize_seed,
    )

    source_name = state["nodes"][source_node_id]["meta"]["name"]
    child_id = _add_node(
        state,
        name=f"redo refine ({source_name}) L{layer_idx+1}",
        parent_id=source_node_id,
        op="refine",
        images=layers_out,
        settings=settings_snapshot,
        source_node_id=source_node_id,
        source_layer_idx=layer_idx,
        sub_layers=n_sub,
    )
    state["selected_node_id"] = child_id
    state["last_refined_node_id"] = child_id

    n_layers = len(layers_out)
    layer_choices, first_layer = _build_layer_dropdown(n_layers)
    history_values = [value for _label, value in _history_choices(state)]
    pptx_path, zip_path, export_msg = _current_node_export(state, child_id)
    return (
        state,
        layers_out,
        layers_out,
        gr.update(choices=layer_choices, value=first_layer),
        gr.update(value=0),
        _layer_label(0, n_layers),
        gr.update(choices=history_values, value=child_id),
        _make_chips(state),
        gr.update(visible=True),
        layers_out,
        pptx_path,
        zip_path,
        f"Redo refine done. Seed={used_seed}. {export_msg}",
    )
def on_export_selected(state: Dict[str, Any]):
    """Return ``(pptx_path, zip_path, message)`` for the selected node."""
    node_id = state.get("selected_node_id")
    if node_id:
        return _current_node_export(state, node_id)
    return None, None, "No selected node."
def on_save_current(state: Dict[str, Any]):
    """Persist the selected node's files, then the session manifest.

    Only the selected node's images are uploaded; other nodes appear in the
    manifest with their metadata only. Returns a status string.
    """
    node_id = state.get("selected_node_id")
    if not node_id:
        return "Nothing to save."
    node_ok, node_msg = _persist_node_to_dataset(state, node_id)
    if not node_ok:
        return node_msg
    manifest_ok, manifest_msg = _persist_session_manifest(state)
    if manifest_ok:
        return f"✅ Saved node + session manifest. {node_msg} | {manifest_msg}"
    return manifest_msg
def on_load_session(state: Dict[str, Any], session_id: str):
    """Replace the in-memory history with a session saved in the dataset repo.

    Reads ``sessions/<session_id>/session.json`` and rebuilds every node's
    metadata. It eagerly downloads the images of the selected node (falling
    back to the root when the saved selection is unknown) and of the root
    node; every other node starts with an empty image list. A selected node
    whose manifest records 0 layers now loads as an empty node. Previously it
    was reported as "failed to load node images: OK".

    Returns 13 outputs in the same order as on_history_select.
    """

    def _fail(message: str, current_state: Dict[str, Any]):
        return (
            current_state,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(visible=False),
            [],
            None,
            None,
            message,
        )

    if not session_id:
        return _fail("Pick a session id.", state)
    manifest, msg = _load_session_manifest(session_id)
    if manifest is None:
        return _fail(msg, state)

    new_state = _init_state()
    new_state["session_id"] = manifest.get("session_id") or session_id
    new_state["created_at"] = manifest.get("created_at")
    new_state["root_node_id"] = manifest.get("root_node_id")
    nodes_meta = manifest.get("nodes", {}) or {}
    for nid, obj in nodes_meta.items():
        new_state["nodes"][nid] = {"meta": obj.get("meta") or {}, "images": []}

    # Prefer the saved selection; fall back to the root when it is missing/unknown.
    sel = manifest.get("selected_node_id") or manifest.get("root_node_id")
    if not sel or sel not in nodes_meta:
        sel = new_state["root_node_id"]
    new_state["selected_node_id"] = sel

    if sel and sel in nodes_meta:
        num_layers = int(nodes_meta[sel].get("num_layers", 0))
        # Only a positive layer count can fail to load; 0 layers is an empty node.
        if num_layers > 0:
            imgs, msg2 = _load_node_images(session_id, sel, num_layers)
            if not imgs:
                return _fail(f"Loaded manifest but failed to load node images: {msg2}", new_state)
            new_state["nodes"][sel]["images"] = imgs

    # Also preload the root (best-effort) so "back to root" has something to show.
    root = new_state.get("root_node_id")
    if root and root != sel and root in nodes_meta and not new_state["nodes"][root]["images"]:
        rl = int(nodes_meta[root].get("num_layers", 0))
        rimgs, _ = _load_node_images(session_id, root, rl)
        if rimgs:
            new_state["nodes"][root]["images"] = rimgs

    imgs = _get_node_images(new_state, sel) if sel else []
    n_layers = len(imgs)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    history_values = [value for _label, value in _history_choices(new_state)]
    pptx_path, zip_path, exp_msg = _current_node_export(new_state, sel) if sel else (None, None, "No node")
    return (
        new_state,
        imgs,
        imgs,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        _layer_label(0, n_layers),
        gr.update(choices=history_values, value=sel),
        _make_chips(new_state),
        gr.update(visible=False),
        [],
        pptx_path,
        zip_path,
        f"Loaded session {session_id}. {exp_msg}",
    )
def on_history_need_images(state: Dict[str, Any], node_id: str):
    """Make sure *node_id*'s images are in memory, downloading them if needed.

    The layer count is read from the session manifest in the dataset. The
    state is mutated in place and returned with a status message.
    """
    nodes = state.get("nodes", {})
    if not node_id or node_id not in nodes:
        return state, "Unknown node."
    if nodes[node_id].get("images", []):
        return state, "OK"

    session_id = state.get("session_id")
    if not session_id:
        return state, "No session_id."
    manifest, msg = _load_session_manifest(session_id)
    if not manifest:
        return state, f"Cannot load manifest: {msg}"

    entry = (manifest.get("nodes", {}) or {}).get(node_id, {})
    num_layers = int(entry.get("num_layers", 0))
    if num_layers <= 0:
        return state, "No layers in manifest for this node."

    loaded, load_msg = _load_node_images(session_id, node_id, num_layers)
    if not loaded:
        return state, f"Failed to load images: {load_msg}"
    nodes[node_id]["images"] = loaded
    return state, "Loaded images."
def on_autoload_last_session(state: Dict[str, Any]):
    """Auto-load the last dataset session when the page opens.

    It falls back to a fresh state from ``_init_state()`` if ``state`` is not
    a dict. If dataset persistence is disabled, or no last session id can be
    found, the view components are left untouched and only a status message
    is produced. Otherwise it delegates to ``on_load_session``.

    Returns:
        The 13-tuple wired to the node-view outputs (same shape as
        ``on_load_session``).
    """

    def _view_unchanged(current_state, message):
        # Keep the seven node-view widgets as they are, hide the refine panel,
        # clear its gallery and both download slots, and report `message`.
        return (
            current_state,
            *(gr.update() for _ in range(7)),
            gr.update(visible=False),
            [],
            None,
            None,
            message,
        )

    if not isinstance(state, dict):
        state = _init_state()

    if not ds_enabled():
        return _view_unchanged(state, "Dataset persistence disabled (no autoload).")

    last_sid, reason = ds_get_last_session_id()
    if last_sid:
        return on_load_session(state, last_sid)
    return _view_unchanged(state, f"No last session to autoload: {reason}")
# -------------------------
# Build UI
# -------------------------
ensure_dirname(LOG_DIR)

# Bundled sample inputs: assets/test_images/1.png through 13.png (in order).
examples = [f"assets/test_images/{idx}.png" for idx in range(1, 14)]
with gr.Blocks() as demo:
    # Per-user workspace state (node dict, root node id, dataset session id, ...).
    state = gr.State(_init_state())

    gr.HTML(
        '<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/layered/qwen-image-layered-logo.png" '
        'alt="Qwen-Image-Layered Logo" width="600" style="display: block; margin: 0 auto;">'
    )
    persistence_banner = gr.Markdown(_persistence_status_text())

    # ---- Dataset persistence controls ----
    with gr.Row():
        btn_init_ds = gr.Button("Init dataset repo", variant="secondary")
        btn_refresh_sessions = gr.Button("Refresh sessions", variant="secondary")
    ds_status = gr.Markdown("")
    with gr.Row():
        load_session_dd = gr.Dropdown(
            label="Load session (from dataset)",
            choices=[],
            value=None,
            allow_custom_value=True,  # lets a session id be typed in directly
        )
        btn_load_session = gr.Button("Load session", variant="primary")

    gr.Markdown(
        """
The text prompt is intended to describe the overall content of the input image—including elements that may be partially occluded.
It is not designed to control the semantic content of individual layers explicitly.
"""
    )

    with gr.Row():
        # LEFT: input image, generation settings, decompose / refine actions
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", image_mode="RGBA")
            with gr.Accordion("Advanced Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt (Optional)",
                    placeholder="Please enter the prompt to describe the image. (Optional)",
                    value="",
                    lines=2,
                )
                neg_prompt = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="Please enter the negative prompt",
                    value=" ",
                    lines=2,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                true_guidance_scale = gr.Slider(
                    label="True guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=4.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,  # DO NOT CHANGE
                )
                layer = gr.Slider(
                    label="Layers",
                    minimum=2,
                    maximum=10,
                    step=1,
                    value=7,  # DO NOT CHANGE
                )
                resolution = gr.Radio(
                    label="Processing resolution",
                    choices=[640, 1024],
                    value=640,  # DO NOT CHANGE
                )
                cfg_norm = gr.Checkbox(
                    label="Whether enable CFG normalization", value=True
                )
                use_en_prompt = gr.Checkbox(
                    label="Automatic caption language if no prompt provided, True for EN, False for ZH",
                    value=True,
                )
                gpu_duration = gr.Textbox(
                    label="GPU duration override (seconds, 20..1500)",
                    value="1000",
                    lines=1,
                    placeholder="e.g. 60, 120, 300, 1000, 1500",
                )
            btn_decompose = gr.Button("Decompose!", variant="primary")
            with gr.Group():
                gr.Markdown("### Refine (Recursive Decomposition)")
                sub_layers = gr.Slider(
                    label="Sub-layers (Refine)",
                    minimum=2,
                    maximum=10,
                    step=1,
                    value=3,
                )
                btn_refine = gr.Button("Refine selected layer", variant="primary")

        # RIGHT: current node view, history navigation, export / save
        with gr.Column(scale=2):
            gr.Markdown("### Current node layers")
            gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")
            gr.Markdown("### Layer picker (Photoshop-style)")
            layer_picker = gr.Gallery(label="Pick a layer", columns=8, rows=1, format="png")
            with gr.Row():
                layer_dropdown = gr.Dropdown(label="Refine layer", choices=[], value=None)
                selected_layer_idx = gr.Number(label="Selected layer index (0-based)", value=0, precision=0, interactive=False)
            selected_layer_label = gr.Markdown("Selected: -")
            with gr.Accordion("Refined layers (last refine)", open=True, visible=False) as refined_block:
                refined_gallery = gr.Gallery(label="Refined layers", columns=4, rows=1, format="png")

            gr.Markdown("### History (nodes)")
            with gr.Row():
                history_dd = gr.Dropdown(label="Node id", choices=[], value=None)
                chips_md = gr.Markdown("[root] [parent:-] [children:0]")
            with gr.Row():
                btn_back_parent = gr.Button("← back to parent", variant="secondary")
                btn_redo = gr.Button("↺ redo refine", variant="secondary")
                btn_duplicate = gr.Button("Duplicate node (branch)", variant="secondary")
            with gr.Row():
                rename_text = gr.Textbox(label="Branch name", value="", lines=1, placeholder="Type new name and click Rename")
                btn_rename = gr.Button("Rename", variant="secondary")
            with gr.Row():
                btn_export = gr.Button("Export selected node (ZIP/PPTX)", variant="primary")
                btn_save = gr.Button("Save selected node to dataset", variant="primary")
            with gr.Row():
                export_pptx = gr.File(label="Download PPTX")
                export_zip = gr.File(label="Download ZIP")
            status = gr.Markdown("")
            seed_used = gr.Textbox(label="Seed used", value="", interactive=False)

    # Examples: only fill the input image; the fn is never run because
    # run_on_click=False and cache_examples=False.
    gr.Examples(
        examples=examples,
        inputs=[input_image],
        outputs=[gallery, export_pptx, export_zip],
        fn=lambda img: ([], None, None),
        cache_examples=False,
        run_on_click=False,
    )

    # Shared output list for every handler that re-renders the selected node.
    # The handlers below return exactly these 13 values in this order. Keeping
    # the list in one place keeps the handlers' return arity and the wiring in
    # sync. Each event gets its own copy.
    node_view_outputs = [
        state,
        gallery,
        layer_picker,
        layer_dropdown,
        selected_layer_idx,
        selected_layer_label,
        history_dd,
        chips_md,
        refined_block,
        refined_gallery,
        export_pptx,
        export_zip,
        status,
    ]
    # Generation parameters shared by refine and redo, in their handler order.
    refine_params = [
        prompt,
        neg_prompt,
        true_guidance_scale,
        num_inference_steps,
        cfg_norm,
        use_en_prompt,
        resolution,
        gpu_duration,
        seed,
        randomize_seed,
    ]

    # Dataset buttons
    btn_init_ds.click(fn=on_init_dataset, outputs=[ds_status])
    btn_refresh_sessions.click(fn=on_refresh_sessions, outputs=[load_session_dd, ds_status])

    # Load session
    btn_load_session.click(
        fn=on_load_session,
        inputs=[state, load_session_dd],
        outputs=list(node_view_outputs),
    )

    # Auto-load last session (reads root index.json; supports id/last_session_id/session_id)
    demo.load(
        fn=on_autoload_last_session,
        inputs=[state],
        outputs=list(node_view_outputs),
    )

    # Decompose (also reports the seed that was actually used)
    btn_decompose.click(
        fn=on_decompose_click,
        inputs=[
            state,
            input_image,
            seed,
            randomize_seed,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            layer,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
        ],
        outputs=[*node_view_outputs, seed_used],
    )

    # Pick layer by clicking mini gallery
    layer_picker.select(
        fn=on_layer_pick_from_gallery,
        inputs=[state],
        outputs=[selected_layer_idx, layer_dropdown, selected_layer_label],
    )

    # Pick layer by dropdown
    layer_dropdown.change(
        fn=on_layer_pick_from_dropdown,
        inputs=[state, layer_dropdown],
        outputs=[selected_layer_idx, selected_layer_label],
    )

    # Refine (also reports the seed that was actually used)
    btn_refine.click(
        fn=on_refine_click,
        inputs=[state, selected_layer_idx, sub_layers, *refine_params],
        outputs=[*node_view_outputs, seed_used],
    )

    # History select (lazy load images first)
    def _history_select_with_lazy(state, node_id):
        """Load the node's images if needed, then render it via on_history_select."""
        # The lazy-load status message is intentionally not surfaced here;
        # on_history_select produces the status shown to the user.
        state, _ = on_history_need_images(state, node_id)
        return on_history_select(state, node_id)

    history_dd.change(
        fn=_history_select_with_lazy,
        inputs=[state, history_dd],
        outputs=list(node_view_outputs),
    )

    # Back to parent
    btn_back_parent.click(
        fn=on_back_to_parent,
        inputs=[state],
        outputs=list(node_view_outputs),
    )

    # Duplicate
    btn_duplicate.click(
        fn=on_duplicate_node,
        inputs=[state],
        outputs=list(node_view_outputs),
    )

    # Redo refine
    btn_redo.click(
        fn=on_redo_refine,
        inputs=[state, *refine_params],
        outputs=list(node_view_outputs),
    )

    # Rename
    btn_rename.click(
        fn=on_rename_node,
        inputs=[state, rename_text],
        outputs=[state, history_dd, chips_md, status],
    )

    # Export
    btn_export.click(fn=on_export_selected, inputs=[state], outputs=[export_pptx, export_zip, status])

    # Save to dataset
    btn_save.click(fn=on_save_current, inputs=[state], outputs=[status])

# Queue with default settings (no version-specific / unsupported args).
demo.queue()

if __name__ == "__main__":
    demo.launch()