import os
import io
import json
import time
import uuid
import random
import tempfile
import zipfile
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import gradio as gr
import spaces
from PIL import Image
from pptx import Presentation
from diffusers import QwenImageLayeredPipeline
from huggingface_hub import HfApi, login
from huggingface_hub.utils import HfHubHTTPError

LOG_DIR = "/tmp/local"
MAX_SEED = np.iinfo(np.int32).max

# -------------------------
# HF auth (Spaces secrets)
# -------------------------
def _get_hf_token() -> Optional[str]:
    # priority: HF_TOKEN -> hf -> HUGGINGFACEHUB_API_TOKEN
    return (
        os.environ.get("HF_TOKEN")
        or os.environ.get("hf")
        or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    )

def _get_dataset_repo() -> Optional[str]:
    # priority: DATASET_REPO -> HF_DATASET_REPO
    return os.environ.get("DATASET_REPO") or os.environ.get("HF_DATASET_REPO")

HF_TOKEN = _get_hf_token()
DATASET_REPO = _get_dataset_repo()

if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as e:
        print("HF login failed:", repr(e))

# -------------------------
# Pipeline init
# -------------------------
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = QwenImageLayeredPipeline.from_pretrained(
    "Qwen/Qwen-Image-Layered", torch_dtype=dtype
).to(device)

def ensure_dirname(path: str):
    if path and not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

def px_to_emu(px, dpi=96):
    inch = px / dpi
    emu = inch * 914400
    return int(emu)
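
# Example (sketch): at the default 96 dpi, a 1280x720 px image maps to
#   px_to_emu(1280) -> 12_192_000 EMU and px_to_emu(720) -> 6_858_000 EMU,
# which is exactly PowerPoint's standard 16:9 slide size (13.333 x 7.5 inches).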

def imagelist_to_pptx_from_pils(images: List[Image.Image]) -> str:
    if not images:
        raise ValueError("No images to export")
    w, h = images[0].size
    prs = Presentation()
    prs.slide_width = px_to_emu(w)
    prs.slide_height = px_to_emu(h)
    slide = prs.slides.add_slide(prs.slide_layouts[6])
    left = top = 0
    # each layer as a picture on the same slide (stacked)
    for img in images:
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        img.save(tmp.name)
        slide.shapes.add_picture(
            tmp.name,
            left,
            top,
            width=px_to_emu(w),
            height=px_to_emu(h),
        )
    out = tempfile.NamedTemporaryFile(suffix=".pptx", delete=False)
    prs.save(out.name)
    return out.name
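
# Z-order note (sketch): python-pptx renders shapes in insertion order, so
# images[0] ends up at the bottom of the stack and the last image on top.
# This assumes the pipeline returns layers background-first.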

def imagelist_to_zip_from_pils(images: List[Image.Image], prefix: str = "layer") -> str:
    outzip = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    with zipfile.ZipFile(outzip.name, "w", zipfile.ZIP_DEFLATED) as zipf:
        for i, img in enumerate(images):
            buf = io.BytesIO()
            img.save(buf, format="PNG")
            zipf.writestr(f"{prefix}_{i+1}.png", buf.getvalue())
    return outzip.name

def _clamp_int(x, default: int, lo: int, hi: int) -> int:
    try:
        v = int(x)
    except Exception:
        v = default
    return max(lo, min(hi, v))

def _normalize_resolution(resolution: Any) -> int:
    resolution = _clamp_int(resolution, default=640, lo=640, hi=1024)
    if resolution not in (640, 1024):
        resolution = 640
    return resolution
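
# Examples (sketch):
#   _clamp_int("120", default=1000, lo=20, hi=1500)  -> 120
#   _clamp_int("oops", default=1000, lo=20, hi=1500) -> 1000
#   _normalize_resolution("800")                     -> 640  (only 640 and 1024 are valid)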

def _normalize_input_image(input_image: Any) -> Image.Image:
    # Normalize image input
    if isinstance(input_image, list):
        input_image = input_image[0]
    if isinstance(input_image, str):
        pil_image = Image.open(input_image).convert("RGB").convert("RGBA")
    elif isinstance(input_image, Image.Image):
        pil_image = input_image.convert("RGB").convert("RGBA")
    elif isinstance(input_image, np.ndarray):
        pil_image = Image.fromarray(input_image).convert("RGB").convert("RGBA")
    else:
        raise ValueError(f"Unsupported input_image type: {type(input_image)}")
    return pil_image

# -------------------------
# Dataset persistence helpers
# -------------------------
def ds_enabled() -> bool:
    return bool(_get_hf_token()) and bool(_get_dataset_repo())

def ds_api() -> HfApi:
    token = _get_hf_token()
    if not token:
        raise RuntimeError("HF token missing")
    return HfApi(token=token)

def ds_repo_id() -> str:
    repo = _get_dataset_repo()
    if not repo:
        raise RuntimeError("DATASET_REPO missing")
    return repo

def ds_ensure_repo() -> Tuple[bool, str]:
    """
    Try to create the dataset repo if missing.
    Returns (ok, message).
    """
    if not ds_enabled():
        return False, "Dataset persistence disabled: missing HF_TOKEN/hf or DATASET_REPO"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)
        return True, f"Dataset repo ready: {repo_id}"
    except HfHubHTTPError as e:
        return False, f"Failed to create/ensure dataset repo: {e}"
    except Exception as e:
        return False, f"Failed to create/ensure dataset repo: {repr(e)}"

def ds_upload_bytes(path_in_repo: str, data: bytes, commit_message: str) -> Tuple[bool, str]:
    if not ds_enabled():
        return False, "Dataset persistence disabled: missing HF_TOKEN/hf or DATASET_REPO"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp.write(data)
            tmp.flush()
            api.upload_file(
                path_or_fileobj=tmp.name,
                path_in_repo=path_in_repo,
                repo_id=repo_id,
                repo_type="dataset",
                commit_message=commit_message,
            )
        return True, f"Uploaded: {path_in_repo}"
    except HfHubHTTPError as e:
        return False, f"Upload failed (HTTP): {e}"
    except Exception as e:
        return False, f"Upload failed: {repr(e)}"

def ds_download_bytes(path_in_repo: str) -> Tuple[Optional[bytes], str]:
    if not ds_enabled():
        return None, "Dataset persistence disabled"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        # Download to a temp dir
        tmpdir = tempfile.mkdtemp()
        local_path = api.hf_hub_download(
            repo_id=repo_id,
            repo_type="dataset",
            filename=path_in_repo,
            local_dir=tmpdir,
        )
        with open(local_path, "rb") as f:
            return f.read(), "OK"
    except HfHubHTTPError as e:
        return None, f"Download failed (HTTP): {e}"
    except Exception as e:
        return None, f"Download failed: {repr(e)}"

def ds_list_sessions(max_sessions: int = 50) -> Tuple[List[str], str]:
    if not ds_enabled():
        return [], "Dataset persistence disabled"
    api = ds_api()
    repo_id = ds_repo_id()
    try:
        files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        sess = set()
        for p in files:
            if p.startswith("sessions/") and p.endswith("session.json"):
                parts = p.split("/")
                if len(parts) >= 3:
                    sess.add(parts[1])
        out = sorted(sess, reverse=True)[:max_sessions]
        return out, f"Found {len(out)} session(s)"
    except Exception as e:
        return [], f"List sessions failed: {repr(e)}"

def ds_get_root_index() -> Tuple[Optional[Dict[str, Any]], str]:
    """
    Read the dataset root index.json (used for "last session").
    Expected keys: id, last_session_id, or session_id.
    """
    b, msg = ds_download_bytes("index.json")
    if b is None:
        return None, msg
    try:
        obj = json.loads(b.decode("utf-8"))
        if not isinstance(obj, dict):
            return None, "index.json is not an object"
        return obj, "OK"
    except Exception as e:
        return None, f"Failed to parse index.json: {repr(e)}"

def ds_get_last_session_id() -> Tuple[Optional[str], str]:
    idx, msg = ds_get_root_index()
    if not idx:
        return None, msg
    for k in ("id", "last_session_id", "session_id"):
        v = idx.get(k)
        if isinstance(v, str) and v.strip():
            return v.strip(), "OK"
    return None, "index.json missing id/last_session_id/session_id"

def ds_set_root_index(session_id: str) -> Tuple[bool, str]:
    """
    Update the dataset root index.json so the last session can be auto-loaded
    after a refresh/restart.
    """
    payload = {
        "id": session_id,
        "last_session_id": session_id,
        "updated_at": time.time(),
    }
    b = json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8")
    return ds_upload_bytes("index.json", b, f"update index.json {session_id}")
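
# Resulting index.json (sketch; the session id shown is a made-up example):
# {
#   "id": "sess_0123456789",
#   "last_session_id": "sess_0123456789",
#   "updated_at": 1700000000.0
# }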

# -------------------------
# Node / History model
# -------------------------
# @dataclass is required here: NodeMeta is constructed with keyword arguments
# and serialized via asdict() below.
@dataclass
class NodeMeta:
    node_id: str
    name: str
    parent_id: Optional[str]
    children: List[str]
    op: str  # "root" | "decompose" | "refine" | "duplicate"
    created_at: float
    # for redo refine
    source_node_id: Optional[str] = None
    source_layer_idx: Optional[int] = None
    sub_layers: Optional[int] = None
    # settings snapshot (store as dict for export/debug)
    settings: Optional[Dict[str, Any]] = None

def _new_id(prefix: str) -> str:
    return f"{prefix}_{uuid.uuid4().hex[:10]}"

def _make_chips(state: Dict[str, Any]) -> str:
    node_id = state.get("selected_node_id")
    nodes: Dict[str, Any] = state.get("nodes", {})
    if not node_id or node_id not in nodes:
        return "[root] [parent:-] [children:0]"
    meta = nodes[node_id]["meta"]
    parent = meta.get("parent_id") or "-"
    children = meta.get("children") or []
    root = state.get("root_node_id") or "-"
    parts = []
    parts.append(f"[root:{root}]")
    parts.append(f"[parent:{parent}]")
    parts.append(f"[children:{len(children)}]")
    return " ".join(parts)

def _history_choices(state: Dict[str, Any]) -> List[Tuple[str, str]]:
    # return list of (label, value)
    nodes: Dict[str, Any] = state.get("nodes", {})
    # stable-ish order by created_at
    items = []
    for nid, obj in nodes.items():
        meta = obj["meta"]
        items.append((meta["created_at"], nid, meta["name"]))
    items.sort(key=lambda x: x[0])
    return [(f"{name} — {nid}", nid) for _, nid, name in items]

def _get_node_images(state: Dict[str, Any], node_id: str) -> List[Image.Image]:
    nodes: Dict[str, Any] = state.get("nodes", {})
    if node_id not in nodes:
        return []
    return nodes[node_id].get("images", []) or []

def _set_selected_node(state: Dict[str, Any], node_id: str):
    state["selected_node_id"] = node_id
    return state

def _add_node(
    state: Dict[str, Any],
    *,
    name: str,
    parent_id: Optional[str],
    op: str,
    images: List[Image.Image],
    settings: Optional[Dict[str, Any]] = None,
    source_node_id: Optional[str] = None,
    source_layer_idx: Optional[int] = None,
    sub_layers: Optional[int] = None,
) -> str:
    node_id = _new_id("node")
    meta = NodeMeta(
        node_id=node_id,
        name=name,
        parent_id=parent_id,
        children=[],
        op=op,
        created_at=time.time(),
        source_node_id=source_node_id,
        source_layer_idx=source_layer_idx,
        sub_layers=sub_layers,
        settings=settings or {},
    )
    if "nodes" not in state:
        state["nodes"] = {}
    state["nodes"][node_id] = {
        "meta": asdict(meta),
        "images": images,
    }
    if parent_id and parent_id in state["nodes"]:
        state["nodes"][parent_id]["meta"]["children"].append(node_id)
    return node_id
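
# In-memory shape (sketch):
#   state["nodes"][node_id] == {
#       "meta":   {...asdict(NodeMeta)...},
#       "images": [PIL.Image, ...],  # one PIL image per layer
#   }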

def _rename_node(state: Dict[str, Any], node_id: str, new_name: str):
    if not new_name:
        return
    if "nodes" in state and node_id in state["nodes"]:
        state["nodes"][node_id]["meta"]["name"] = new_name

def _duplicate_node(state: Dict[str, Any], node_id: str) -> Optional[str]:
    if "nodes" not in state or node_id not in state["nodes"]:
        return None
    src = state["nodes"][node_id]
    meta = src["meta"]
    parent_id = meta.get("parent_id")
    images = src.get("images", [])
    name = f"{meta.get('name', 'node')} (copy)"
    return _add_node(
        state,
        name=name,
        parent_id=parent_id,
        op="duplicate",
        images=images,
        settings=meta.get("settings") or {},
    )

# -------------------------
# GPU duration + pipeline runner
# -------------------------
def get_duration(*args, **kwargs):
    # Any kwargs may come from the gradio/spaces wrapper; we only care about gpu_duration
    gpu_duration = kwargs.get("gpu_duration", 1000)
    return _clamp_int(gpu_duration, default=1000, lo=20, hi=1500)
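
# Examples (sketch):
#   get_duration(gpu_duration="120") -> 120
#   get_duration()                   -> 1000
#   get_duration(gpu_duration=9999)  -> 1500  (clamped)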

# ZeroGPU: request a GPU for get_duration(...) seconds (reads the gpu_duration
# kwarg of the call). This decorator was evidently intended here, given the
# `import spaces` and the get_duration helper above; it is a no-op off Spaces.
@spaces.GPU(duration=get_duration)
def gpu_run_pipeline(
    pil_image_rgba: Image.Image,
    seed=777,
    randomize_seed=False,
    prompt=None,
    neg_prompt=" ",
    true_guidance_scale=4.0,
    num_inference_steps=50,
    layer=4,
    cfg_norm=True,
    use_en_prompt=True,
    resolution=640,
    gpu_duration=1000,
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    resolution = _normalize_resolution(resolution)
    gen_device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=gen_device).manual_seed(int(seed))
    inputs = {
        "image": pil_image_rgba,
        "generator": generator,
        "true_cfg_scale": true_guidance_scale,
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "num_inference_steps": int(num_inference_steps),
        "num_images_per_prompt": 1,
        "layers": int(layer),
        "resolution": int(resolution),  # 640 or 1024
        "cfg_normalize": bool(cfg_norm),
        "use_en_prompt": bool(use_en_prompt),
    }
    # best-effort to reduce weird allocator hiccups on ZeroGPU
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass
    with torch.inference_mode():
        out = pipeline(**inputs)
    output_images = out.images[0]  # list of PIL images (layers)
    return output_images, int(seed), inputs
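
# Hypothetical call (sketch; "input.png" is a made-up path, kwargs mirror the UI defaults):
#   layers, used_seed, _ = gpu_run_pipeline(
#       pil_image_rgba=Image.open("input.png").convert("RGBA"),
#       layer=7, resolution=640, gpu_duration=120,
#   )
#   # `layers` is the list of per-layer PIL images returned by the pipeline.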

# -------------------------
# Dataset persistence: save/load nodes + session
# -------------------------
def _pil_to_png_bytes(img: Image.Image) -> bytes:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()

def _png_bytes_to_pil(b: bytes) -> Image.Image:
    return Image.open(io.BytesIO(b)).convert("RGBA")

def _session_base(session_id: str) -> str:
    return f"sessions/{session_id}"

def _node_base(session_id: str, node_id: str) -> str:
    return f"{_session_base(session_id)}/nodes/{node_id}"

def _persist_node_to_dataset(state: Dict[str, Any], node_id: str) -> Tuple[bool, str]:
    if not ds_enabled():
        return False, "Dataset persistence disabled. Set DATASET_REPO env var and provide HF_TOKEN/hf."
    ok, msg = ds_ensure_repo()
    if not ok:
        return False, msg
    session_id = state.get("session_id")
    if not session_id:
        return False, "No session_id in state (run Decompose first)"
    nodes = state.get("nodes", {})
    if node_id not in nodes:
        return False, "Unknown node_id"
    node = nodes[node_id]
    meta = node["meta"]
    imgs: List[Image.Image] = node.get("images", []) or []
    # upload node.json
    node_json = json.dumps(meta, ensure_ascii=False, indent=2).encode("utf-8")
    path_node_json = f"{_node_base(session_id, node_id)}/node.json"
    ok1, msg1 = ds_upload_bytes(path_node_json, node_json, f"save node {node_id}")
    if not ok1:
        return False, msg1
    # upload images
    for i, img in enumerate(imgs):
        b = _pil_to_png_bytes(img)
        path_img = f"{_node_base(session_id, node_id)}/layer_{i+1}.png"
        ok2, msg2 = ds_upload_bytes(path_img, b, f"save node {node_id} layer {i+1}")
        if not ok2:
            return False, msg2
    return True, f"Saved node {node_id} to dataset"

def _persist_session_manifest(state: Dict[str, Any]) -> Tuple[bool, str]:
    if not ds_enabled():
        return False, "Dataset persistence disabled"
    ok, msg = ds_ensure_repo()
    if not ok:
        return False, msg
    session_id = state.get("session_id")
    if not session_id:
        return False, "No session_id"
    # Only manifest (no images)
    manifest = {
        "session_id": session_id,
        "created_at": state.get("created_at"),
        "root_node_id": state.get("root_node_id"),
        "selected_node_id": state.get("selected_node_id"),
        "nodes": {
            nid: {
                "meta": obj["meta"],
                "num_layers": len(obj.get("images", []) or []),
            }
            for nid, obj in (state.get("nodes", {}) or {}).items()
        },
    }
    b = json.dumps(manifest, ensure_ascii=False, indent=2).encode("utf-8")
    path = f"{_session_base(session_id)}/session.json"
    ok_m, msg_m = ds_upload_bytes(path, b, f"save session manifest {session_id}")
    if not ok_m:
        return False, msg_m
    # Update root index.json (best-effort; do not fail the save if it can't be updated)
    ok_i, msg_i = ds_set_root_index(session_id)
    if not ok_i:
        return True, f"{msg_m} (warning: index.json update failed: {msg_i})"
    return True, f"{msg_m} + updated index.json"

def _load_session_manifest(session_id: str) -> Tuple[Optional[Dict[str, Any]], str]:
    b, msg = ds_download_bytes(f"{_session_base(session_id)}/session.json")
    if b is None:
        return None, msg
    try:
        return json.loads(b.decode("utf-8")), "OK"
    except Exception as e:
        return None, f"Failed to parse manifest: {repr(e)}"

def _load_node_images(session_id: str, node_id: str, num_layers: int) -> Tuple[List[Image.Image], str]:
    imgs: List[Image.Image] = []
    for i in range(num_layers):
        b, msg = ds_download_bytes(f"{_node_base(session_id, node_id)}/layer_{i+1}.png")
        if b is None:
            return [], msg
        imgs.append(_png_bytes_to_pil(b))
    return imgs, "OK"

# -------------------------
# UI callbacks
# -------------------------
def _init_state() -> Dict[str, Any]:
    return {
        "session_id": None,
        "created_at": None,
        "root_node_id": None,
        "selected_node_id": None,
        "nodes": {},  # node_id -> {meta, images}
        "last_refined_node_id": None,
    }

def _persistence_status_text() -> str:
    tok = _get_hf_token()
    repo = _get_dataset_repo()
    if tok and repo:
        return f"✅ Dataset persistence enabled: `{repo}`"
    if repo and not tok:
        return "⚠️ Dataset repo set, but HF_TOKEN/hf missing"
    if tok and not repo:
        return "⚠️ HF_TOKEN/hf set, but DATASET_REPO missing"
    return "⚠️ Dataset persistence disabled (set HF_TOKEN + DATASET_REPO secrets to enable)"

def on_refresh_sessions():
    sessions, msg = ds_list_sessions()
    return gr.update(choices=sessions, value=(sessions[0] if sessions else None)), msg

def on_init_dataset():
    ok, msg = ds_ensure_repo()
    return msg

def _current_node_export(state: Dict[str, Any], node_id: str) -> Tuple[Optional[str], Optional[str], str]:
    imgs = _get_node_images(state, node_id)
    if not imgs:
        return None, None, "No images to export"
    pptx_path = imagelist_to_pptx_from_pils(imgs)
    zip_path = imagelist_to_zip_from_pils(imgs, prefix=f"{node_id}_layer")
    return pptx_path, zip_path, "OK"

def _build_layer_dropdown(n: int) -> Tuple[List[str], Optional[str]]:
    if n <= 0:
        return [], None
    choices = [f"Layer {i+1}" for i in range(n)]
    return choices, choices[0]

def _layer_label(idx: int, n: int) -> str:
    if n <= 0:
        return "Selected: -"
    idx = max(0, min(n - 1, idx))
    return f"Selected: Layer {idx+1} / {n}"

def on_decompose_click(
    state: Dict[str, Any],
    input_image,
    seed,
    randomize_seed,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    layer,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
):
    if state is None or not isinstance(state, dict):
        state = _init_state()
    pil_image = _normalize_input_image(input_image)
    # Create session on first run
    if not state.get("session_id"):
        state["session_id"] = _new_id("sess")
        state["created_at"] = time.time()
    # Run pipeline
    layers_out, used_seed, used_inputs = gpu_run_pipeline(
        pil_image_rgba=pil_image,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=layer,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )
    settings_snapshot = {
        "seed": used_seed,
        "randomize_seed": bool(randomize_seed),
        "prompt": prompt,
        "neg_prompt": neg_prompt,
        "true_guidance_scale": float(true_guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "layers": int(layer),
        "resolution": int(_normalize_resolution(resolution)),
        "cfg_norm": bool(cfg_norm),
        "use_en_prompt": bool(use_en_prompt),
        "gpu_duration": int(_clamp_int(gpu_duration, 1000, 20, 1500)),
    }
    # Reset history for new decompose
    state["nodes"] = {}
    state["last_refined_node_id"] = None
    root_id = _add_node(
        state,
        name="root (decompose)",
        parent_id=None,
        op="decompose",
        images=layers_out,
        settings=settings_snapshot,
    )
    state["root_node_id"] = root_id
    state["selected_node_id"] = root_id
    # UI outputs
    n_layers = len(layers_out)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    selected_label = _layer_label(0, n_layers)
    # refined block hidden on fresh decompose
    refined_visible = gr.update(visible=False)
    refined_gallery = []
    # exports for current node
    pptx_path, zip_path, exp_msg = _current_node_export(state, root_id)
    status = f"Decomposed into {n_layers} layer(s). Seed={used_seed}. {exp_msg}"
    return (
        state,
        layers_out,  # main layers gallery
        layers_out,  # mini picker gallery
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),  # selected_layer_idx state (as Number)
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=root_id),
        chips,
        refined_visible,
        refined_gallery,
        pptx_path,
        zip_path,
        status,
        str(used_seed),
    )

def on_layer_pick_from_dropdown(state: Dict[str, Any], layer_name: str):
    node_id = state.get("selected_node_id")
    imgs = _get_node_images(state, node_id) if node_id else []
    n = len(imgs)
    if not layer_name or not layer_name.startswith("Layer "):
        idx = 0
    else:
        try:
            idx = int(layer_name.replace("Layer ", "").strip()) - 1
        except Exception:
            idx = 0
    idx = max(0, min(n - 1, idx)) if n > 0 else 0
    return gr.update(value=idx), _layer_label(idx, n)

def on_layer_pick_from_gallery(state: Dict[str, Any], evt: gr.SelectData):
    node_id = state.get("selected_node_id")
    imgs = _get_node_images(state, node_id) if node_id else []
    n = len(imgs)
    idx = int(evt.index) if evt and evt.index is not None else 0
    idx = max(0, min(n - 1, idx)) if n > 0 else 0
    dd_value = f"Layer {idx+1}" if n > 0 else None
    return gr.update(value=idx), gr.update(value=dd_value), _layer_label(idx, n)

def _refine_from_source(
    state: Dict[str, Any],
    source_node_id: str,
    source_layer_idx: int,
    sub_layers: int,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    src_imgs = _get_node_images(state, source_node_id)
    if not src_imgs:
        raise ValueError("Source node has no images")
    if source_layer_idx < 0 or source_layer_idx >= len(src_imgs):
        raise ValueError("Invalid layer index")
    selected_layer_img = src_imgs[source_layer_idx]
    layers_out, used_seed, used_inputs = gpu_run_pipeline(
        pil_image_rgba=selected_layer_img,
        seed=seed,
        randomize_seed=randomize_seed,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        layer=sub_layers,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
    )
    settings_snapshot = {
        "seed": used_seed,
        "randomize_seed": bool(randomize_seed),
        "prompt": prompt,
        "neg_prompt": neg_prompt,
        "true_guidance_scale": float(true_guidance_scale),
        "num_inference_steps": int(num_inference_steps),
        "layers": int(sub_layers),
        "resolution": int(_normalize_resolution(resolution)),
        "cfg_norm": bool(cfg_norm),
        "use_en_prompt": bool(use_en_prompt),
        "gpu_duration": int(_clamp_int(gpu_duration, 1000, 20, 1500)),
        "refined_from": {
            "source_node_id": source_node_id,
            "source_layer_idx": int(source_layer_idx),
        },
    }
    return layers_out, used_seed, settings_snapshot
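
# Refinement is recursive decomposition: the selected layer image is fed back
# through the same pipeline, so a "refine" node's layers relate to their parent
# layer the same way the root node's layers relate to the original input image.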

def on_refine_click(
    state: Dict[str, Any],
    selected_layer_idx: int,
    sub_layers: int,
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    if not state.get("selected_node_id"):
        return (
            state,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(visible=False),
            [],
            None,
            None,
            "No selected node. Run Decompose first.",
            gr.update(),
        )
    source_node_id = state["selected_node_id"]
    src_imgs = _get_node_images(state, source_node_id)
    if not src_imgs:
        return (
            state,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(visible=False),
            [],
            None,
            None,
            "Selected node has no images.",
            gr.update(),
        )
    n = len(src_imgs)
    idx = int(selected_layer_idx) if selected_layer_idx is not None else 0
    idx = max(0, min(n - 1, idx))
    sub_layers = _clamp_int(sub_layers, default=3, lo=2, hi=10)
    layers_out, used_seed, settings_snapshot = _refine_from_source(
        state,
        source_node_id=source_node_id,
        source_layer_idx=idx,
        sub_layers=sub_layers,
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
        seed=seed,
        randomize_seed=randomize_seed,
    )
    child_name = f"refine ({state['nodes'][source_node_id]['meta']['name']}) L{idx+1}"
    child_id = _add_node(
        state,
        name=child_name,
        parent_id=source_node_id,
        op="refine",
        images=layers_out,
        settings=settings_snapshot,
        source_node_id=source_node_id,
        source_layer_idx=idx,
        sub_layers=sub_layers,
    )
    state["selected_node_id"] = child_id
    state["last_refined_node_id"] = child_id
    n_layers = len(layers_out)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    selected_label = _layer_label(0, n_layers)
    pptx_path, zip_path, exp_msg = _current_node_export(state, child_id)
    status = f"Refined into {n_layers} sub-layer(s). Seed={used_seed}. {exp_msg}"
    refined_visible = gr.update(visible=True)
    return (
        state,
        layers_out,
        layers_out,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=child_id),
        chips,
        refined_visible,
        layers_out,
        pptx_path,
        zip_path,
        status,
        gr.update(),
    )

def on_history_select(state: Dict[str, Any], node_id: str):
    if not node_id or not state.get("nodes") or node_id not in state["nodes"]:
        return (
            state,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(visible=False),
            [],
            None,
            None,
            "Unknown node.",
        )
    state["selected_node_id"] = node_id
    imgs = _get_node_images(state, node_id)
    n_layers = len(imgs)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    selected_label = _layer_label(0, n_layers)
    pptx_path, zip_path, exp_msg = _current_node_export(state, node_id)
    return (
        state,
        imgs,
        imgs,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=node_id),
        chips,
        gr.update(visible=False),
        [],
        pptx_path,
        zip_path,
        f"Selected node: {node_id}. {exp_msg}",
    )

def _history_noop_outputs(state: Dict[str, Any], msg: str):
    # Match the 13-output shape of on_history_select. The original error paths
    # returned 3-tuples, which breaks the Gradio output wiring below.
    return (
        state,
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(visible=False),
        [],
        None,
        None,
        msg,
    )

def on_back_to_parent(state: Dict[str, Any]):
    node_id = state.get("selected_node_id")
    if not node_id or node_id not in state.get("nodes", {}):
        return _history_noop_outputs(state, "No selected node.")
    parent = state["nodes"][node_id]["meta"].get("parent_id")
    if not parent:
        return _history_noop_outputs(state, "Already at root.")
    return on_history_select(state, parent)

def on_duplicate_node(state: Dict[str, Any]):
    node_id = state.get("selected_node_id")
    if not node_id:
        return _history_noop_outputs(state, "No selected node.")
    new_id = _duplicate_node(state, node_id)
    if not new_id:
        return _history_noop_outputs(state, "Duplicate failed.")
    return on_history_select(state, new_id)

def on_rename_node(state: Dict[str, Any], new_name: str):
    node_id = state.get("selected_node_id")
    if not node_id:
        # 4 outputs: state, history_dd, chips_md, status
        return state, gr.update(), gr.update(), "No selected node."
    _rename_node(state, node_id, new_name)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    return state, gr.update(choices=[c[1] for c in hist_choices], value=node_id), chips, "Renamed."

def on_redo_refine(
    state: Dict[str, Any],
    prompt,
    neg_prompt,
    true_guidance_scale,
    num_inference_steps,
    cfg_norm,
    use_en_prompt,
    resolution,
    gpu_duration,
    seed,
    randomize_seed,
):
    node_id = state.get("selected_node_id")
    if not node_id or node_id not in state.get("nodes", {}):
        return _history_noop_outputs(state, "No selected node.")
    meta = state["nodes"][node_id]["meta"]
    if meta.get("op") != "refine":
        return _history_noop_outputs(state, "Redo works only for refine nodes.")
    source_node_id = meta.get("source_node_id")
    source_layer_idx = meta.get("source_layer_idx")
    sub_layers = meta.get("sub_layers")
    if not source_node_id or source_layer_idx is None or not sub_layers:
        return _history_noop_outputs(state, "Missing refine metadata.")
    layers_out, used_seed, settings_snapshot = _refine_from_source(
        state,
        source_node_id=source_node_id,
        source_layer_idx=int(source_layer_idx),
        sub_layers=int(sub_layers),
        prompt=prompt,
        neg_prompt=neg_prompt,
        true_guidance_scale=true_guidance_scale,
        num_inference_steps=num_inference_steps,
        cfg_norm=cfg_norm,
        use_en_prompt=use_en_prompt,
        resolution=resolution,
        gpu_duration=gpu_duration,
        seed=seed,
        randomize_seed=randomize_seed,
    )
    child_name = f"redo refine ({state['nodes'][source_node_id]['meta']['name']}) L{int(source_layer_idx)+1}"
    child_id = _add_node(
        state,
        name=child_name,
        parent_id=source_node_id,
        op="refine",
        images=layers_out,
        settings=settings_snapshot,
        source_node_id=source_node_id,
        source_layer_idx=int(source_layer_idx),
        sub_layers=int(sub_layers),
    )
    state["selected_node_id"] = child_id
    state["last_refined_node_id"] = child_id
    n_layers = len(layers_out)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(state)
    chips = _make_chips(state)
    selected_label = _layer_label(0, n_layers)
    pptx_path, zip_path, exp_msg = _current_node_export(state, child_id)
    return (
        state,
        layers_out,
        layers_out,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=child_id),
        chips,
        gr.update(visible=True),
        layers_out,
        pptx_path,
        zip_path,
        f"Redo refine done. Seed={used_seed}. {exp_msg}",
    )

def on_export_selected(state: Dict[str, Any]):
    node_id = state.get("selected_node_id")
    if not node_id:
        return None, None, "No selected node."
    pptx_path, zip_path, msg = _current_node_export(state, node_id)
    return pptx_path, zip_path, msg

def on_save_current(state: Dict[str, Any]):
    node_id = state.get("selected_node_id")
    if not node_id:
        return "Nothing to save."
    ok1, msg1 = _persist_node_to_dataset(state, node_id)
    if not ok1:
        return msg1
    ok2, msg2 = _persist_session_manifest(state)
    if not ok2:
        return msg2
    return f"✅ Saved node + session manifest. {msg1} | {msg2}"

def on_load_session(state: Dict[str, Any], session_id: str):
    if not session_id:
        return _history_noop_outputs(state, "Pick a session id.")
    manifest, msg = _load_session_manifest(session_id)
    if manifest is None:
        return _history_noop_outputs(state, msg)
    new_state = _init_state()
    new_state["session_id"] = manifest.get("session_id") or session_id
    new_state["created_at"] = manifest.get("created_at")
    new_state["root_node_id"] = manifest.get("root_node_id")
    new_state["selected_node_id"] = manifest.get("selected_node_id") or manifest.get("root_node_id")
    nodes_meta = manifest.get("nodes", {}) or {}
    for nid, obj in nodes_meta.items():
        meta = obj.get("meta") or {}
        new_state["nodes"][nid] = {"meta": meta, "images": []}
    sel = new_state["selected_node_id"]
    if not sel or sel not in nodes_meta:
        sel = new_state["root_node_id"]
        new_state["selected_node_id"] = sel
    if sel and sel in nodes_meta:
        num_layers = int(nodes_meta[sel].get("num_layers", 0))
        imgs, msg2 = _load_node_images(session_id, sel, num_layers)
        if imgs:
            new_state["nodes"][sel]["images"] = imgs
        else:
            return _history_noop_outputs(
                new_state, f"Loaded manifest but failed to load node images: {msg2}"
            )
    root = new_state.get("root_node_id")
    if root and root != sel and root in nodes_meta and not new_state["nodes"][root]["images"]:
        rl = int(nodes_meta[root].get("num_layers", 0))
        rimgs, _ = _load_node_images(session_id, root, rl)
        if rimgs:
            new_state["nodes"][root]["images"] = rimgs
    imgs = _get_node_images(new_state, sel) if sel else []
    n_layers = len(imgs)
    layer_choices, layer_value = _build_layer_dropdown(n_layers)
    hist_choices = _history_choices(new_state)
    chips = _make_chips(new_state)
    selected_label = _layer_label(0, n_layers)
    pptx_path, zip_path, exp_msg = _current_node_export(new_state, sel) if sel else (None, None, "No node")
    return (
        new_state,
        imgs,
        imgs,
        gr.update(choices=layer_choices, value=layer_value),
        gr.update(value=0),
        selected_label,
        gr.update(choices=[c[1] for c in hist_choices], value=sel),
        chips,
        gr.update(visible=False),
        [],
        pptx_path,
        zip_path,
        f"Loaded session {session_id}. {exp_msg}",
    )

def on_history_need_images(state: Dict[str, Any], node_id: str):
    if not node_id or node_id not in state.get("nodes", {}):
        return state, "Unknown node."
    imgs = state["nodes"][node_id].get("images", [])
    if imgs:
        return state, "OK"
    session_id = state.get("session_id")
    if not session_id:
        return state, "No session_id."
    manifest, msg = _load_session_manifest(session_id)
    if not manifest:
        return state, f"Cannot load manifest: {msg}"
    node_obj = (manifest.get("nodes", {}) or {}).get(node_id, {})
    num_layers = int(node_obj.get("num_layers", 0))
    if num_layers <= 0:
        return state, "No layers in manifest for this node."
    imgs2, msg2 = _load_node_images(session_id, node_id, num_layers)
    if not imgs2:
        return state, f"Failed to load images: {msg2}"
    state["nodes"][node_id]["images"] = imgs2
    return state, "Loaded images."

def on_autoload_last_session(state: Dict[str, Any]):
    if not isinstance(state, dict):
        state = _init_state()
    if not ds_enabled():
        return _history_noop_outputs(state, "Dataset persistence disabled (no autoload).")
    sid, msg = ds_get_last_session_id()
    if not sid:
        return _history_noop_outputs(state, f"No last session to autoload: {msg}")
    return on_load_session(state, sid)

# -------------------------
# Build UI
# -------------------------
ensure_dirname(LOG_DIR)

examples = [
    "assets/test_images/1.png",
    "assets/test_images/2.png",
    "assets/test_images/3.png",
    "assets/test_images/4.png",
    "assets/test_images/5.png",
    "assets/test_images/6.png",
    "assets/test_images/7.png",
    "assets/test_images/8.png",
    "assets/test_images/9.png",
    "assets/test_images/10.png",
    "assets/test_images/11.png",
    "assets/test_images/12.png",
    "assets/test_images/13.png",
]

with gr.Blocks() as demo:
    state = gr.State(_init_state())
    gr.HTML(
        '<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/layered/qwen-image-layered-logo.png" '
        'alt="Qwen-Image-Layered Logo" width="600" style="display: block; margin: 0 auto;">'
    )
    persistence_banner = gr.Markdown(_persistence_status_text())
    with gr.Row():
        btn_init_ds = gr.Button("Init dataset repo", variant="secondary")
        btn_refresh_sessions = gr.Button("Refresh sessions", variant="secondary")
    ds_status = gr.Markdown("")
    with gr.Row():
        load_session_dd = gr.Dropdown(
            label="Load session (from dataset)",
            choices=[],
            value=None,
            allow_custom_value=True,
        )
        btn_load_session = gr.Button("Load session", variant="primary")
    gr.Markdown(
        """
The text prompt is intended to describe the overall content of the input image,
including elements that may be partially occluded.
It is not designed to control the semantic content of individual layers explicitly.
"""
    )
    with gr.Row():
        # LEFT
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", image_mode="RGBA")
            with gr.Accordion("Advanced Settings", open=False):
                prompt = gr.Textbox(
                    label="Prompt (Optional)",
                    placeholder="Please enter a prompt describing the image. (Optional)",
                    value="",
                    lines=2,
                )
                neg_prompt = gr.Textbox(
                    label="Negative Prompt (Optional)",
                    placeholder="Please enter the negative prompt",
                    value=" ",
                    lines=2,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                true_guidance_scale = gr.Slider(
                    label="True guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=4.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,  # DO NOT CHANGE
                )
                layer = gr.Slider(
                    label="Layers",
                    minimum=2,
                    maximum=10,
                    step=1,
                    value=7,  # DO NOT CHANGE
                )
                resolution = gr.Radio(
                    label="Processing resolution",
                    choices=[640, 1024],
                    value=640,  # DO NOT CHANGE
                )
                cfg_norm = gr.Checkbox(label="Enable CFG normalization", value=True)
                use_en_prompt = gr.Checkbox(
                    label="Auto-caption language when no prompt is provided (checked = English, unchecked = Chinese)",
                    value=True,
                )
                gpu_duration = gr.Textbox(
                    label="GPU duration override (seconds, 20..1500)",
                    value="1000",
                    lines=1,
                    placeholder="e.g. 60, 120, 300, 1000, 1500",
                )
            btn_decompose = gr.Button("Decompose!", variant="primary")
            with gr.Group():
                gr.Markdown("### Refine (Recursive Decomposition)")
                sub_layers = gr.Slider(
                    label="Sub-layers (Refine)",
                    minimum=2,
                    maximum=10,
                    step=1,
                    value=3,
                )
                btn_refine = gr.Button("Refine selected layer", variant="primary")
        # RIGHT
        with gr.Column(scale=2):
            gr.Markdown("### Current node layers")
            gallery = gr.Gallery(label="Layers", columns=4, rows=1, format="png")
            gr.Markdown("### Layer picker (Photoshop-style)")
            layer_picker = gr.Gallery(label="Pick a layer", columns=8, rows=1, format="png")
            with gr.Row():
                layer_dropdown = gr.Dropdown(label="Refine layer", choices=[], value=None)
                selected_layer_idx = gr.Number(
                    label="Selected layer index (0-based)", value=0, precision=0, interactive=False
                )
            selected_layer_label = gr.Markdown("Selected: -")
            with gr.Accordion("Refined layers (last refine)", open=True, visible=False) as refined_block:
                refined_gallery = gr.Gallery(label="Refined layers", columns=4, rows=1, format="png")
            gr.Markdown("### History (nodes)")
            with gr.Row():
                history_dd = gr.Dropdown(label="Node id", choices=[], value=None)
                chips_md = gr.Markdown("[root] [parent:-] [children:0]")
            with gr.Row():
                btn_back_parent = gr.Button("← back to parent", variant="secondary")
                btn_redo = gr.Button("↺ redo refine", variant="secondary")
                btn_duplicate = gr.Button("Duplicate node (branch)", variant="secondary")
            with gr.Row():
                rename_text = gr.Textbox(
                    label="Branch name", value="", lines=1, placeholder="Type new name and click Rename"
                )
                btn_rename = gr.Button("Rename", variant="secondary")
            with gr.Row():
                btn_export = gr.Button("Export selected node (ZIP/PPTX)", variant="primary")
                btn_save = gr.Button("Save selected node to dataset", variant="primary")
            with gr.Row():
                export_pptx = gr.File(label="Download PPTX")
                export_zip = gr.File(label="Download ZIP")
            status = gr.Markdown("")
            seed_used = gr.Textbox(label="Seed used", value="", interactive=False)

    # Examples
    gr.Examples(
        examples=examples,
        inputs=[input_image],
        outputs=[gallery, export_pptx, export_zip],
        fn=lambda img: ([], None, None),
        cache_examples=False,
        run_on_click=False,
    )

    # The 13 components updated by the history/navigation/load callbacks.
    history_outputs = [
        state,
        gallery,
        layer_picker,
        layer_dropdown,
        selected_layer_idx,
        selected_layer_label,
        history_dd,
        chips_md,
        refined_block,
        refined_gallery,
        export_pptx,
        export_zip,
        status,
    ]

    # Dataset buttons
    btn_init_ds.click(fn=on_init_dataset, outputs=[ds_status])
    btn_refresh_sessions.click(fn=on_refresh_sessions, outputs=[load_session_dd, ds_status])

    # Load session
    btn_load_session.click(
        fn=on_load_session,
        inputs=[state, load_session_dd],
        outputs=history_outputs,
    )

    # Auto-load last session (reads root index.json; supports id/last_session_id/session_id)
    demo.load(
        fn=on_autoload_last_session,
        inputs=[state],
        outputs=history_outputs,
    )

    # Decompose
    btn_decompose.click(
        fn=on_decompose_click,
        inputs=[
            state,
            input_image,
            seed,
            randomize_seed,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            layer,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
        ],
        outputs=history_outputs + [seed_used],
    )

    # Pick layer by clicking mini gallery
    layer_picker.select(
        fn=on_layer_pick_from_gallery,
        inputs=[state],
        outputs=[selected_layer_idx, layer_dropdown, selected_layer_label],
    )

    # Pick layer by dropdown
    layer_dropdown.change(
        fn=on_layer_pick_from_dropdown,
        inputs=[state, layer_dropdown],
        outputs=[selected_layer_idx, selected_layer_label],
    )

    # Refine
    btn_refine.click(
        fn=on_refine_click,
        inputs=[
            state,
            selected_layer_idx,
            sub_layers,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
            seed,
            randomize_seed,
        ],
        outputs=history_outputs + [seed_used],
    )

    # History select (lazy-load images first)
    def _history_select_with_lazy(state, node_id):
        state, _ = on_history_need_images(state, node_id)
        return on_history_select(state, node_id)

    history_dd.change(
        fn=_history_select_with_lazy,
        inputs=[state, history_dd],
        outputs=history_outputs,
    )

    # Back to parent
    btn_back_parent.click(
        fn=on_back_to_parent,
        inputs=[state],
        outputs=history_outputs,
    )

    # Duplicate
    btn_duplicate.click(
        fn=on_duplicate_node,
        inputs=[state],
        outputs=history_outputs,
    )

    # Redo refine
    btn_redo.click(
        fn=on_redo_refine,
        inputs=[
            state,
            prompt,
            neg_prompt,
            true_guidance_scale,
            num_inference_steps,
            cfg_norm,
            use_en_prompt,
            resolution,
            gpu_duration,
            seed,
            randomize_seed,
        ],
        outputs=history_outputs,
    )

    # Rename
    btn_rename.click(
        fn=on_rename_node,
        inputs=[state, rename_text],
        outputs=[state, history_dd, chips_md, status],
    )

    # Export
    btn_export.click(fn=on_export_selected, inputs=[state], outputs=[export_pptx, export_zip, status])

    # Save to dataset
    btn_save.click(fn=on_save_current, inputs=[state], outputs=[status])

# Queue (no unsupported args)
demo.queue()

if __name__ == "__main__":
    demo.launch()