| import contextlib | |
| import copy | |
| import asyncio | |
| import functools | |
| import inspect | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Sequence | |
| from shared.api import GeneratedArtifact, GenerationError, GenerationResult, PreviewUpdate, SessionJob, WanGPSession, _pushd | |
# Sentinel returned by _WrappedCallState.pop_yielded_result when its buffer is empty,
# so a buffered ``None`` can be told apart from "nothing buffered".
_NO_YIELDED_RESULT = object()
# Serialises the one-time monkey-patch of gradio.helpers.log_message.
_GRADIO_LOG_PATCH_LOCK = threading.Lock()
# The original gradio.helpers.log_message, saved when the patch is installed (None until then).
_ORIGINAL_GRADIO_LOG_MESSAGE: Any = None
# Per-thread ``owner`` / ``call_id`` slots consulted by _buffered_gradio_log_message.
_WRAPPED_LOG_LOCAL = threading.local()
def _buffered_gradio_log_message(message: str, title: str, level: str = "info", duration: float | None = 10, visible: bool = True):
    """Stand-in for ``gradio.helpers.log_message`` that can buffer messages per wrapped call.

    The message is handled in one of three ways:

    * Inside an active Gradio event (both ``LocalContext.blocks`` and
      ``LocalContext.event_id`` are set), it is forwarded to the original function.
    * Otherwise, the current thread may carry an ``owner`` and a non-blank ``call_id``
      on ``_WRAPPED_LOG_LOCAL``, and ``owner._get_wrapped_call(call_id)`` may return
      a state. In that case the message is queued on that state and ``None`` is returned.
    * In every other case, it falls back to the original function.
    """
    from gradio.context import LocalContext

    forward_kwargs = {"title": title, "level": level, "duration": duration, "visible": visible}
    # Both context variables are read up front, exactly as before.
    blocks = LocalContext.blocks.get()
    event_id = LocalContext.event_id.get()
    if blocks is None or event_id is None:
        owner = getattr(_WRAPPED_LOG_LOCAL, "owner", None)
        call_id = str(getattr(_WRAPPED_LOG_LOCAL, "call_id", "") or "").strip()
        state = owner._get_wrapped_call(call_id) if owner is not None and call_id else None
        if state is not None:
            state.queue_log_message(message=message, **forward_kwargs)
            return None
    return _ORIGINAL_GRADIO_LOG_MESSAGE(message, **forward_kwargs)
def _ensure_gradio_log_message_patch() -> None:
    """Install ``_buffered_gradio_log_message`` as ``gradio.helpers.log_message`` once.

    The replaced function is kept in ``_ORIGINAL_GRADIO_LOG_MESSAGE``. Once it is set,
    further calls do nothing. The check and the swap both happen while holding
    ``_GRADIO_LOG_PATCH_LOCK``.
    """
    global _ORIGINAL_GRADIO_LOG_MESSAGE
    with _GRADIO_LOG_PATCH_LOCK:
        if _ORIGINAL_GRADIO_LOG_MESSAGE is None:
            import gradio.helpers as gradio_helpers

            _ORIGINAL_GRADIO_LOG_MESSAGE = gradio_helpers.log_message
            gradio_helpers.log_message = _buffered_gradio_log_message
def _normalize_queue_request(request):
    """Return a request object that Gradio's queue can accept for ``queue/join``.

    Handling, in order:

    * ``None`` is passed through.
    * If ``gradio.route_utils.get_api_call_path`` accepts *request*, it is returned
      unchanged.
    * Otherwise, for an HTTP request, a FastAPI ``Request`` is rebuilt from a copy
      of its ASGI scope. The path is rewritten to ``{API_PREFIX}/queue/join`` and
      the query string is emptied.

    Every step is best-effort: if Gradio or FastAPI cannot be imported, the scope
    is not HTTP, or the rebuild fails, the original *request* is returned.
    """
    if request is None:
        return None
    # Import separately from the probe below. Previously a failed import left
    # ``route_utils`` unbound and the rebuild path raised NameError.
    try:
        from gradio import route_utils
    except Exception:
        return request
    try:
        route_utils.get_api_call_path(request)
        return request
    except Exception:
        pass  # Not usable as-is; try rebuilding it below.
    try:
        from fastapi import Request as FastAPIRequest
    except Exception:
        return request
    scope = dict(getattr(request, "scope", {}) or {})
    if scope.get("type") != "http":
        return request
    queue_path = f"{route_utils.API_PREFIX}/queue/join"
    scope["path"] = queue_path
    scope["raw_path"] = queue_path.encode("utf-8")
    scope["query_string"] = b""
    try:
        return FastAPIRequest(scope, request.receive)
    except Exception:
        return request
class GradioProgressCallbacks:
    """Adapt WanGP status/progress callbacks to a progress callable.

    ``progress`` is invoked as ``progress(ratio, desc=text)`` with ``ratio`` clamped
    to ``[0.0, 1.0]``. The name suggests a Gradio progress tracker; that is
    presumed, not enforced. Status-only updates reuse the last ratio seen by
    :meth:`on_progress`.
    """

    def __init__(self, progress) -> None:
        self._progress = progress
        self._ratio = 0.0

    def on_status(self, status) -> None:
        """Report the stripped *status* at the current ratio; blank or ``None`` is ignored."""
        status = str(status or "").strip()
        if status:
            self._progress(self._ratio, desc=status)

    def on_progress(self, update) -> None:
        """Convert ``update.progress`` (a percentage) into a clamped ratio and report it.

        A missing or ``None`` ``progress`` counts as 0. A missing or empty ``status``
        falls back to ``"Generating..."``.
        """
        # ``progress`` may be present but None; float(None) would raise TypeError.
        percent = getattr(update, "progress", None)
        self._ratio = max(0.0, min(1.0, float(percent or 0) / 100.0))
        self._progress(self._ratio, desc=str(getattr(update, "status", "") or "Generating..."))
class _WrappedCallState:
    """Bookkeeping shared between one wrapped Gradio call and the WanGP job it submits.

    Holds:

    * the final result or error, with completion signalled through ``done``;
    * the primary job;
    * the Gradio callback context, published once via :meth:`set_callback_context`;
    * four FIFO buffers, each guarded by its own lock: abort client ids, follow-up
      jobs, yielded results and log messages.
    """

    def __init__(self, output_count: int) -> None:
        self.output_count = output_count
        # Outcome: ``done`` is set by set_result / set_completed / set_error.
        self.done = threading.Event()
        self.result: Any = None
        self.has_result = False
        self.error: BaseException | None = None
        self.job: SessionJob | None = None
        # Last abort client id handed out by pop_abort_client_id, plus the pending FIFO.
        self.abort_client_id = ""
        self._abort_client_ids: list[str] = []
        self._abort_client_ids_lock = threading.Lock()
        # Callback context; waiters block on ``callback_context_ready``.
        self.callback_context_ready = threading.Event()
        self.callback_context: dict[str, Any] | None = None
        # Follow-up jobs are only recorded after enable_followup_queue_triggers().
        self._followup_jobs: list[SessionJob] = []
        self._followup_lock = threading.Lock()
        self._followup_enabled = False
        self._primary_job_forwarded = False
        self._yielded_results: list[Any] = []
        self._yielded_results_lock = threading.Lock()
        self._log_messages: list[dict[str, Any]] = []
        self._log_messages_lock = threading.Lock()

    def set_result(self, result: Any) -> None:
        """Store *result*, mark it present and release waiters on ``done``."""
        self.result, self.has_result = result, True
        self.done.set()

    def set_completed(self) -> None:
        """Release waiters on ``done`` without recording a result."""
        self.done.set()

    def set_error(self, error: BaseException) -> None:
        """Store *error* and release waiters on ``done``."""
        self.error = error
        self.done.set()

    def set_callback_context(self, context: dict[str, Any]) -> None:
        """Publish a shallow copy of *context* and wake anyone waiting for it."""
        self.callback_context = dict(context)
        self.callback_context_ready.set()

    def enable_followup_queue_triggers(self) -> None:
        """Start accepting follow-up jobs in :meth:`add_followup_job`."""
        self._followup_enabled = True

    def add_followup_job(self, job: SessionJob) -> None:
        """Append *job* to the follow-up FIFO; ignored until follow-ups are enabled."""
        if self._followup_enabled:
            with self._followup_lock:
                self._followup_jobs.append(job)

    def pop_ready_followup_load_queue_token(self) -> str:
        """Remove the oldest follow-up job that is submission-ready and return its token ("" if none)."""
        with self._followup_lock:
            ready_at = next(
                (position for position, pending in enumerate(self._followup_jobs) if pending.webui_submission_ready),
                None,
            )
            if ready_at is None:
                return ""
            return self._followup_jobs.pop(ready_at).webui_load_queue_token

    def pop_primary_load_queue_token(self) -> str:
        """Return the primary job's load-queue token the first time it is ready, else "".

        Returning the token also enables follow-up queue triggers.
        """
        primary = self.job
        if self._primary_job_forwarded or primary is None or not primary.webui_submission_ready:
            return ""
        self._primary_job_forwarded = True
        self.enable_followup_queue_triggers()
        return primary.webui_load_queue_token

    def queue_abort_client_id(self, client_id: str) -> None:
        """Enqueue the stripped *client_id* for abort; blank ids are dropped."""
        normalized = str(client_id or "").strip()
        if not normalized:
            return
        with self._abort_client_ids_lock:
            self._abort_client_ids.append(normalized)

    def pop_abort_client_id(self) -> str:
        """Dequeue the oldest abort client id, remembering it in ``abort_client_id`` ("" if none)."""
        with self._abort_client_ids_lock:
            if self._abort_client_ids:
                self.abort_client_id = self._abort_client_ids.pop(0)
                return self.abort_client_id
            return ""

    def push_yielded_result(self, result: Any) -> None:
        """Buffer an intermediate *result* for later retrieval."""
        with self._yielded_results_lock:
            self._yielded_results.append(result)

    def pop_yielded_result(self) -> Any:
        """Return the oldest buffered result, or ``_NO_YIELDED_RESULT`` when the buffer is empty."""
        with self._yielded_results_lock:
            return self._yielded_results.pop(0) if self._yielded_results else _NO_YIELDED_RESULT

    def queue_log_message(self, *, message: str, title: str, level: str, duration: float | None, visible: bool) -> None:
        """Buffer a Gradio log message as a normalised dict."""
        entry = {
            "log": str(message or ""),
            "title": str(title or ""),
            "level": str(level or "info"),
            "duration": duration,
            "visible": bool(visible),
        }
        with self._log_messages_lock:
            self._log_messages.append(entry)

    def pop_log_messages(self) -> list[dict[str, Any]]:
        """Return all buffered log messages (oldest first) and clear the buffer."""
        with self._log_messages_lock:
            drained = self._log_messages[:]
            self._log_messages.clear()
            return drained
class _BoundGradioCallbacks:
    """Proxy that re-enters a wrapped call's Gradio callback context around each callback.

    Attribute access is forwarded to the wrapped ``callbacks`` object:

    * Non-callable attributes are returned as-is.
    * Callable attributes are returned wrapped. The wrapper first waits up to
      ``_CONTEXT_WAIT_SECONDS`` for ``state.callback_context_ready``. It then runs
      the target inside ``owner._push_callback_context(context)`` when the
      published context is a dict, and runs it directly otherwise.
    """

    _CONTEXT_WAIT_SECONDS = 30.0
    # The proxy's own slots. If one is looked up through __getattr__, it is missing
    # (e.g. an instance built via __new__ by copy/pickle). Forwarding that lookup
    # would re-enter __getattr__ forever.
    _OWN_ATTRIBUTES = frozenset({"_callbacks", "_state", "_owner"})

    def __init__(self, callbacks: object, state: _WrappedCallState, owner: "GradioWanGPSession") -> None:
        self._callbacks = callbacks
        self._state = state
        self._owner = owner

    def __getattr__(self, name: str) -> Any:
        if name in self._OWN_ATTRIBUTES:
            raise AttributeError(name)
        target = getattr(self._callbacks, name)
        if not callable(target):
            return target

        @functools.wraps(target)
        def wrapped(*args, **kwargs):
            self._state.callback_context_ready.wait(timeout=self._CONTEXT_WAIT_SECONDS)
            context = self._state.callback_context
            if not isinstance(context, dict):
                return target(*args, **kwargs)
            with self._owner._push_callback_context(context):
                return target(*args, **kwargs)

        return wrapped
class WebUIQueueProbe:
    """Run a WanGP job through the WebUI inline queue and poll it until every task resolves.

    Workflow:

    1. The manifest is placed in ``gen["inline_queue"]``.
    2. The Gradio "load queue" event is triggered, unless the bound Gradio context
       defers it.
    3. The shared ``gen`` dict is polled. Each task with a ``client_id`` ends with
       either an output path or a ``GenerationError``.

    Status, progress, preview, output and error updates are pushed to the job's
    event stream and to the matching session callbacks as they are observed.
    """

    _POLL_INTERVAL_SECONDS = 0.2
    # Grace period before an admitted task that left the queue without output is failed.
    _MISSING_OUTPUT_TIMEOUT_SECONDS = 5.0
    # Idle time after submission before a "waiting for browser focus" notice is published.
    _QUEUE_ADMISSION_SUSPEND_NOTICE_SECONDS = 10.0
    # How long to wait for a previous inline manifest to be consumed before giving up.
    _INLINE_QUEUE_SLOT_TIMEOUT_SECONDS = 10.0
    # Delay after dispatching a cancel before unresolved tasks are marked cancelled.
    _CANCEL_GRACE_SECONDS = 1.0

    def __init__(self, session: WanGPSession, runtime, tasks: list[dict[str, Any]], job: SessionJob) -> None:
        self._session = session
        self._runtime = runtime
        self._tasks = tasks
        self._job = job
        self._wgp = runtime.module
        self._gen = session._state["gen"]
        self._manifest = job.webui_manifest or self._build_manifest(tasks)
        # Tasks are tracked by client_id in task order; tasks without one are not tracked.
        self._client_ids: list[str] = []
        self._task_index_by_client_id: dict[str, int] = {}
        self._task_id_by_client_id: dict[str, Any] = {}
        # Terminal state: at most one of output / error is recorded per client_id.
        self._outputs_by_client_id: dict[str, str] = {}
        self._artifacts_by_client_id: dict[str, GeneratedArtifact] = {}
        self._errors_by_client_id: dict[str, GenerationError] = {}
        # client_ids seen in gen["queue"] at least once.
        self._admitted_client_ids: set[str] = set()
        self._missing_output_since: dict[str, float] = {}
        # De-duplication of live status/progress/preview events.
        self._last_status_text = ""
        self._last_active_client_id = ""
        self._last_progress_key: tuple[Any, ...] | None = None
        self._last_preview_key: tuple[Any, ...] | None = None
        self._active_progress_seed: Any = None
        # Cancellation bookkeeping.
        self._cancel_issued = False
        self._cancel_requested_at: float | None = None
        self._cancel_dispatched_client_ids: set[str] = set()
        self._submitted_at = 0.0
        self._queue_wait_suspended = False
        # Console-log de-duplication.
        self._logged_admitted_client_ids: set[str] = set()
        self._logged_missing_output_client_ids: set[str] = set()
        # client_ids observed as the live, running task at least once.
        self._live_started_client_ids: set[str] = set()
        for index, task in enumerate(self._tasks, start=1):
            params = self._session._get_task_settings(task)
            client_id = str(params.get("client_id", "") or "").strip()
            if not client_id:
                continue
            self._client_ids.append(client_id)
            self._task_index_by_client_id[client_id] = index
            self._task_id_by_client_id[client_id] = task.get("id")

    def run(self) -> GenerationResult:
        """Submit the manifest, poll until every tracked client_id is resolved, and summarise.

        Returns:
            A ``GenerationResult`` whose files, errors and artifacts follow task order.
        """
        self._submit_inline_manifest()
        while not self._all_clients_finished():
            self._poll_once()
            if self._all_clients_finished():
                break
            time.sleep(self._POLL_INTERVAL_SECONDS)
        generated_files = [self._outputs_by_client_id[cid] for cid in self._client_ids if cid in self._outputs_by_client_id]
        errors = [self._errors_by_client_id[cid] for cid in self._client_ids if cid in self._errors_by_client_id]
        artifacts = tuple(self._artifacts_by_client_id[cid] for cid in self._client_ids if cid in self._artifacts_by_client_id)
        successful_tasks = len(generated_files)
        failed_tasks = len(self._client_ids) - successful_tasks
        return GenerationResult(
            success=len(errors) == 0 and failed_tasks == 0,
            generated_files=generated_files,
            errors=errors,
            total_tasks=len(self._client_ids),
            successful_tasks=successful_tasks,
            failed_tasks=failed_tasks,
            artifacts=artifacts,
        )

    def _submit_inline_manifest(self) -> None:
        """Place the manifest in ``gen["inline_queue"]`` and trigger a WebUI queue load.

        If the job is cancelled once the slot wait ends, every client_id is recorded
        as cancelled and nothing is submitted. The load-queue event is skipped when
        the bound Gradio context sets ``defer_load_queue_trigger``.

        Raises:
            RuntimeError: if the inline slot stays busy or the load-queue trigger fails.
        """
        self._reset_idle_state()
        self._wait_for_inline_queue_slot()
        if self._job.cancel_requested:
            for client_id in self._client_ids:
                self._register_error(client_id, "Generation was cancelled", stage="cancelled")
            return
        self._gen.setdefault("queue_errors", {})
        self._gen["inline_queue"] = copy.deepcopy(self._manifest)
        self._job._mark_webui_submission_ready()
        print(f"WanGP API queued client_ids={self._client_ids}")
        gradio_context = getattr(self._session, "_gradio_webui_context", None)
        if not isinstance(gradio_context, dict) or not gradio_context.get("defer_load_queue_trigger", False):
            self._trigger_load_queue_event()
        self._submitted_at = time.time()
        self._publish("status", "Queued in WanGP...", "on_status")

    def _trigger_load_queue_event(self) -> None:
        """Push the WebUI "load queue" event onto the Gradio queue for the bound session.

        Raises:
            RuntimeError: if no usable Gradio context is bound, or if the queue push
                reports failure.
        """
        gradio_context = getattr(self._session, "_gradio_webui_context", None)
        if not isinstance(gradio_context, dict):
            raise RuntimeError("WanGP WebUI queue submission requires an active Gradio session context.")
        fn_index = gradio_context.get("load_queue_fn_index")
        blocks = gradio_context.get("blocks")
        request = gradio_context.get("request")
        session_hash = gradio_context.get("session_hash")
        if not isinstance(fn_index, int) or blocks is None or request is None or not session_hash:
            raise RuntimeError("WanGP WebUI queue trigger is unavailable for the current Gradio session.")
        from gradio.data_classes import PredictBodyInternal

        request = _normalize_queue_request(request)
        if getattr(blocks._queue, "server_app", None) is None and getattr(blocks, "app", None) is not None:
            blocks._queue.set_server_app(blocks.app)
        body = PredictBodyInternal(session_hash=session_hash, fn_index=fn_index, data=[None, None], request=request)
        # asyncio.run creates a fresh event loop, so this must not run on a thread
        # that already has a running loop.
        success, error_or_event_id = asyncio.run(blocks._queue.push(body=body, request=request, username=getattr(request, "username", None)))
        if not success:
            raise RuntimeError(str(error_or_event_id))

    def _wait_for_inline_queue_slot(self) -> None:
        """Block until ``gen["inline_queue"]`` is free or the job is cancelled.

        Raises:
            RuntimeError: if the slot is still occupied after
                ``_INLINE_QUEUE_SLOT_TIMEOUT_SECONDS``.
        """
        deadline = time.time() + self._INLINE_QUEUE_SLOT_TIMEOUT_SECONDS
        while self._gen.get("inline_queue") is not None:
            if self._job.cancel_requested:
                return
            if time.time() >= deadline:
                raise RuntimeError("WanGP inline queue bridge is busy")
            time.sleep(0.05)

    def _reset_idle_state(self) -> None:
        """Clear stale abort/progress/preview flags in ``gen``, but only when nothing is running or queued."""
        if self._gen.get("in_progress", False) or list(self._gen.get("queue", []) or []):
            return
        self._gen["abort"] = False
        self._gen["resume"] = False
        self._gen["early_stop"] = False
        self._gen["early_stop_forwarded"] = False
        self._gen["status"] = ""
        self._gen["status_display"] = False
        self._gen["last_progress_args"] = None
        self._gen["progress_args"] = None
        self._gen["preview"] = None

    def _poll_once(self) -> None:
        """Run one polling pass.

        Steps, in order: record admissions, dispatch cancellation if requested, then
        process errors, outputs, live updates, admission timeouts and cancellation
        finalisation.
        """
        queue_client_ids, active_client_id = self._get_queue_snapshot()
        for client_id in queue_client_ids:
            if client_id in self._client_ids:
                self._admitted_client_ids.add(client_id)
                if client_id not in self._logged_admitted_client_ids:
                    print(f"WanGP API admitted client_id={client_id}")
                    self._logged_admitted_client_ids.add(client_id)
        if self._job.cancel_requested:
            self._request_cancel()
        if self._queue_wait_suspended and any(client_id in self._admitted_client_ids for client_id in self._client_ids):
            print("WanGP back in focus API queue resumed")
            self._queue_wait_suspended = False
        self._check_queue_errors()
        self._check_outputs(queue_client_ids)
        self._emit_live_updates(queue_client_ids, active_client_id)
        self._check_queue_admission_timeout()
        self._finalize_cancelled_clients(queue_client_ids)

    def _get_queue_snapshot(self) -> tuple[list[str], str]:
        """Return ``(queue_client_ids, active_client_id)`` for ``gen["queue"]``.

        ``queue_client_ids`` lists every non-blank client_id currently queued, tracked
        or not. ``active_client_id`` is the client_id of the first dict entry, which
        may be "".
        """
        queue_client_ids: list[str] = []
        active_client_id = ""
        first_queue_task = True
        for task in list(self._gen.get("queue", []) or []):
            if not isinstance(task, dict):
                continue
            params = self._session._get_task_settings(task)
            client_id = str(params.get("client_id", "") or "").strip()
            if first_queue_task:
                active_client_id = client_id
                first_queue_task = False
            if not client_id:
                continue
            queue_client_ids.append(client_id)
        return queue_client_ids, active_client_id

    def _check_queue_errors(self) -> None:
        """Turn ``gen["queue_errors"]`` entries for unresolved client_ids into errors.

        Each entry is indexed as ``(message, aborted)``. Aborted entries are removed
        from ``gen["queue"]`` and reported as cancelled.
        """
        queue_errors = self._gen.get("queue_errors", {}) or {}
        for client_id in self._client_ids:
            if self._is_finished(client_id):
                continue
            error_tuple = queue_errors.get(client_id)
            if error_tuple is None:
                continue
            error_text = str(error_tuple[0] if len(error_tuple) > 0 else "WanGP queue error")
            aborted = bool(error_tuple[1]) if len(error_tuple) > 1 else False
            print(f"WanGP API queue error client_id={client_id} aborted={aborted} error={error_text}")
            if aborted:
                self._remove_queue_client_id(client_id)
                self._register_error(client_id, "Generation was cancelled", stage="cancelled")
            else:
                self._register_error(client_id, error_text or "WanGP queue error", stage="generation")

    def _check_outputs(self, queue_client_ids: list[str]) -> None:
        """Resolve outputs for tracked client_ids that have left the WebUI queue.

        A client_id is completed from a pending API artifact first, otherwise from a
        matching gallery entry in ``get_processed_queue``. An admitted client_id that
        is no longer queued and has neither is failed after
        ``_MISSING_OUTPUT_TIMEOUT_SECONDS``.
        """
        processed = self._wgp.get_processed_queue(self._gen)
        if not isinstance(processed, tuple) or len(processed) != 4:
            return
        file_list, file_settings_list, audio_file_list, audio_file_settings_list = processed
        for client_id in self._client_ids:
            if self._is_finished(client_id):
                continue
            still_queued = client_id in queue_client_ids
            output_path = self._find_output_for_client(client_id, file_list, file_settings_list, audio_file_list, audio_file_settings_list)
            pending_artifact = self._session._peek_output_artifact(client_id)
            if pending_artifact is not None and not still_queued:
                if self._defer_until_queue_settles(client_id, queue_client_ids, "completion"):
                    continue
                artifact = self._session._consume_output_artifact(client_id)
                resolved_output_path = str(output_path or (artifact.path if artifact is not None else "") or "").strip()
                if not resolved_output_path:
                    self._register_error(client_id, f"Generation produced an API artifact for client_id '{client_id}' without an output path.", stage="generation")
                    continue
                self._record_output(client_id, resolved_output_path, artifact, "artifact")
                continue
            if output_path is not None and not still_queued:
                if self._defer_until_queue_settles(client_id, queue_client_ids, "gallery completion"):
                    continue
                self._record_output(client_id, output_path, self._session._consume_output_artifact(client_id), "gallery")
                continue
            if still_queued:
                self._missing_output_since.pop(client_id, None)
                self._logged_missing_output_client_ids.discard(client_id)
                continue
            if client_id not in self._admitted_client_ids:
                continue
            started_missing_at = self._missing_output_since.setdefault(client_id, time.time())
            if client_id not in self._logged_missing_output_client_ids:
                print(f"WanGP API waiting for output client_id={client_id} queue_empty={not queue_client_ids} artifact_ready={pending_artifact is not None}")
                self._logged_missing_output_client_ids.add(client_id)
            if time.time() - started_missing_at >= self._MISSING_OUTPUT_TIMEOUT_SECONDS:
                self._register_error(
                    client_id,
                    f"Generation finished queue processing but no output with client_id '{client_id}' was found in the gallery.",
                    stage="generation",
                )

    def _defer_until_queue_settles(self, client_id: str, queue_client_ids: list[str], what: str) -> bool:
        """Return True (logging once per client_id) while the queue is empty but generation still runs.

        ``what`` only fills in the console message.
        """
        if queue_client_ids or not self._gen.get("in_progress", False):
            return False
        if client_id not in self._logged_missing_output_client_ids:
            print(f"WanGP API delaying {what} for client_id={client_id} until main queue settles")
            self._logged_missing_output_client_ids.add(client_id)
        return True

    def _record_output(self, client_id: str, output_path: str, artifact: Any, source: str) -> None:
        """Mark *client_id* completed with *output_path* and publish an ``output`` event.

        Args:
            client_id: the task being completed.
            output_path: the resolved output path to record.
            artifact: value returned by ``_consume_output_artifact`` (may be ``None``).
                When present, its payload is copied into a ``GeneratedArtifact`` whose
                ``path`` is *output_path*.
            source: label used only in the console log line.
        """
        self._outputs_by_client_id[client_id] = output_path
        if artifact is not None:
            self._artifacts_by_client_id[client_id] = GeneratedArtifact(
                path=output_path,
                media_type=artifact.media_type,
                client_id=artifact.client_id,
                video_tensor_uint8=artifact.video_tensor_uint8,
                video_tensor_hdr=artifact.video_tensor_hdr,
                hdr=artifact.hdr,
                audio_tensor=artifact.audio_tensor,
                audio_sampling_rate=artifact.audio_sampling_rate,
                fps=artifact.fps,
                flashvsr_continue_cache=artifact.flashvsr_continue_cache,
            )
        self._missing_output_since.pop(client_id, None)
        self._logged_missing_output_client_ids.discard(client_id)
        print(f"WanGP API completed client_id={client_id} via {source} path={output_path}")
        self._publish("output", {"client_id": client_id, "path": output_path}, "on_output")

    def _emit_live_updates(self, queue_client_ids: list[str], active_client_id: str) -> None:
        """Forward status/progress/preview for the running task, or a "waiting" status.

        When the active client_id changes, the current ``gen["last_progress_args"]``
        is captured as a seed. Progress is only reported once the live value differs
        from that seed, which keeps progress left over from before the switch from
        being reported. Repeated identical status/progress/preview updates are
        suppressed.
        """
        if active_client_id != self._last_active_client_id:
            self._last_active_client_id = active_client_id
            self._last_progress_key = None
            self._last_preview_key = None
            self._last_status_text = ""
            self._active_progress_seed = copy.deepcopy(self._gen.get("last_progress_args"))
        live_generation_running = bool(self._gen.get("in_progress", False))
        active_client_is_live = live_generation_running and active_client_id in self._client_ids and not self._is_finished(active_client_id)
        if active_client_is_live:
            self._live_started_client_ids.add(active_client_id)
            progress_args = self._gen.get("last_progress_args")
            progress_update = None
            if progress_args != self._active_progress_seed:
                progress_update = self._session._build_progress_update(progress_args, include_state_fallback=False)
            if progress_update is not None:
                if len(progress_update.status) > 0 and progress_update.status != self._last_status_text:
                    self._last_status_text = progress_update.status
                    self._publish("status", progress_update.status, "on_status")
                progress_key = (
                    active_client_id,
                    progress_update.phase,
                    progress_update.progress,
                    progress_update.current_step,
                    progress_update.total_steps,
                    progress_update.status,
                    progress_update.unit,
                )
                if progress_key != self._last_progress_key:
                    self._last_progress_key = progress_key
                    self._publish("progress", progress_update, "on_progress")
            preview_image = self._gen.get("preview")
            if preview_image is not None and progress_update is not None:
                preview_key = (active_client_id, id(preview_image), getattr(preview_image, "size", None), progress_update.progress)
                if preview_key != self._last_preview_key:
                    self._last_preview_key = preview_key
                    self._publish(
                        "preview",
                        PreviewUpdate(
                            image=preview_image,
                            phase=progress_update.phase,
                            status=progress_update.status,
                            progress=progress_update.progress,
                            current_step=progress_update.current_step,
                            total_steps=progress_update.total_steps,
                        ),
                        "on_preview",
                    )
            return
        queued_client_ids = [cid for cid in queue_client_ids if cid in self._client_ids and not self._is_finished(cid)]
        if queued_client_ids and any(cid not in self._live_started_client_ids for cid in queued_client_ids):
            status_text = "Waiting in WanGP queue..."
            if status_text != self._last_status_text:
                self._last_status_text = status_text
                self._publish("status", status_text, "on_status")

    def _check_queue_admission_timeout(self) -> None:
        """Publish a one-time "waiting for browser focus" notice when admission stalls.

        The notice fires when the WebUI is idle (nothing running, empty queue)
        ``_QUEUE_ADMISSION_SUSPEND_NOTICE_SECONDS`` after submission while some
        client_ids are still unadmitted. Any WebUI activity restarts that timer.
        """
        pending_client_ids = [cid for cid in self._client_ids if not self._is_finished(cid) and cid not in self._admitted_client_ids]
        if not pending_client_ids:
            self._queue_wait_suspended = False
            return
        if self._gen.get("in_progress", False) or list(self._gen.get("queue", []) or []):
            self._submitted_at = time.time()
            return
        if self._submitted_at <= 0 or time.time() - self._submitted_at < self._QUEUE_ADMISSION_SUSPEND_NOTICE_SECONDS or self._queue_wait_suspended:
            return
        print("WanGP API queue suspended while waiting for Video Generator to get browser focus")
        self._publish("status", "Waiting for WanGP Video Generator to get browser focus...", "on_status")
        self._queue_wait_suspended = True

    def _finalize_cancelled_clients(self, queue_client_ids: list[str]) -> None:
        """After the cancel grace period, mark every unresolved client_id that is no longer queued as cancelled."""
        if not self._cancel_issued or self._cancel_requested_at is None:
            return
        if time.time() - self._cancel_requested_at < self._CANCEL_GRACE_SECONDS:
            return
        for client_id in self._client_ids:
            if self._is_finished(client_id):
                continue
            if client_id in queue_client_ids or self._inline_queue_contains_client_id(client_id):
                continue
            self._register_error(client_id, "Generation was cancelled", stage="cancelled")

    def _request_cancel(self) -> None:
        """Dispatch cancellation once for each unresolved client_id.

        A client_id still in ``gen["inline_queue"]`` is removed from it directly. An
        admitted client_id gets an abort request instead. Dispatching anything
        (re)starts the cancel grace timer.
        """
        dispatched_any = False
        for client_id in self._client_ids:
            if self._is_finished(client_id) or client_id in self._cancel_dispatched_client_ids:
                continue
            if self._remove_inline_queue_client_id(client_id):
                pass
            elif client_id in self._admitted_client_ids:
                self._queue_abort_client_id(client_id)
            else:
                continue
            self._cancel_dispatched_client_ids.add(client_id)
            dispatched_any = True
        if dispatched_any:
            self._cancel_issued = True
            self._cancel_requested_at = time.time()

    def _queue_abort_client_id(self, client_id: str) -> None:
        """Ask the Gradio session proxy to abort *client_id*; fall back to ``gen["abort"] = True``."""
        owner = getattr(self._session, "_gradio_session_proxy", None)
        enqueue = getattr(owner, "_enqueue_abort_client_id", None)
        if not callable(enqueue) or not enqueue(self._job, client_id):
            self._gen["abort"] = True
            print("WanGP API set direct abort flag because the WebUI abort trigger was unavailable.")

    def _remove_inline_queue_client_id(self, client_id: str) -> bool:
        """Remove *client_id* from ``gen["inline_queue"]``; return True if anything was removed.

        The key is popped entirely when the matching item was the whole value or the
        list becomes empty.
        """
        inline_queue = self._gen.get("inline_queue")
        if inline_queue is None:
            return False
        if self._inline_item_matches_client_id(inline_queue, client_id):
            self._gen.pop("inline_queue", None)
            return True
        if isinstance(inline_queue, list):
            remaining = [item for item in inline_queue if not self._inline_item_matches_client_id(item, client_id)]
            if len(remaining) != len(inline_queue):
                if remaining:
                    self._gen["inline_queue"] = remaining
                else:
                    self._gen.pop("inline_queue", None)
                return True
        return False

    def _remove_queue_client_id(self, client_id: str) -> bool:
        """Remove every ``gen["queue"]`` entry matching *client_id* in place; return True if any was removed."""
        queue = self._gen.get("queue")
        if not isinstance(queue, list):
            return False
        snapshot = list(queue)
        remaining = [item for item in snapshot if not self._inline_item_matches_client_id(item, client_id)]
        if len(remaining) == len(snapshot):
            return False
        queue[:] = remaining
        self._gen["queue"] = queue
        return True

    def _inline_queue_contains_client_id(self, client_id: str) -> bool:
        """Return True if ``gen["inline_queue"]`` (a single item or a list) still holds *client_id*."""
        inline_queue = self._gen.get("inline_queue")
        if inline_queue is None:
            return False
        if isinstance(inline_queue, list):
            return any(self._inline_item_matches_client_id(item, client_id) for item in inline_queue)
        return self._inline_item_matches_client_id(inline_queue, client_id)

    @staticmethod
    def _inline_item_matches_client_id(item: Any, client_id: str) -> bool:
        """Return True if dict *item* carries *client_id* in ``params["client_id"]`` or ``client_id``."""
        if not isinstance(item, dict):
            return False
        params = item.get("params")
        if isinstance(params, dict) and str(params.get("client_id", "") or "").strip() == client_id:
            return True
        return str(item.get("client_id", "") or "").strip() == client_id

    @staticmethod
    def _find_output_for_client(
        client_id: str,
        file_list: Sequence[Any],
        file_settings_list: Sequence[Any],
        audio_file_list: Sequence[Any],
        audio_file_settings_list: Sequence[Any],
    ) -> str | None:
        """Return the resolved path of the last file whose settings carry *client_id*, else None.

        Media files are searched before audio files. Paths and settings are paired
        from the end of each list; presumably the most recent entries are last.
        """
        for paths, settings_list in ((file_list, file_settings_list), (audio_file_list, audio_file_settings_list)):
            for path, settings in zip(reversed(list(paths or [])), reversed(list(settings_list or []))):
                if not isinstance(settings, dict):
                    continue
                if str(settings.get("client_id", "") or "").strip() == client_id:
                    return str(Path(path).resolve())
        return None

    def _is_finished(self, client_id: str) -> bool:
        """Return True once *client_id* has a recorded output or error."""
        return client_id in self._outputs_by_client_id or client_id in self._errors_by_client_id

    def _register_error(self, client_id: str, message: str, *, stage: str) -> None:
        """Record and publish a ``GenerationError`` for *client_id* unless it is already resolved."""
        if self._is_finished(client_id):
            return
        failure = GenerationError(
            message=message,
            task_index=self._task_index_by_client_id.get(client_id),
            task_id=self._task_id_by_client_id.get(client_id),
            stage=stage,
        )
        self._errors_by_client_id[client_id] = failure
        self._publish("error", failure, "on_error")

    def _publish(self, kind: str, payload: Any, callback_name: str | None = None) -> None:
        """Put ``(kind, payload)`` on the job's event stream and, if named, fire the session callback."""
        self._job.events.put(kind, payload)
        if callback_name is not None:
            self._session._emit_callback(callback_name, payload, job=self._job)

    def _all_clients_finished(self) -> bool:
        """Return True when the number of outputs plus errors covers every tracked client_id."""
        completed_count = len(self._outputs_by_client_id) + len(self._errors_by_client_id)
        return completed_count >= len(self._client_ids)

    @staticmethod
    def _build_manifest(tasks: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
        """Build the inline-queue manifest with deep copies of each task's settings and plugin data.

        A task without an ``id`` gets its 1-based position as the id.
        """
        manifest = []
        for index, task in enumerate(tasks, start=1):
            params = copy.deepcopy(WanGPSession._get_task_settings(task))
            manifest.append({"id": task.get("id", index), "params": params, "plugin_data": copy.deepcopy(task.get("plugin_data", {}))})
        return manifest
def run_webui_job(session, job: SessionJob, tasks: list[dict[str, Any]]) -> None:
    """Execute *job* through the WanGP WebUI queue and publish its final result.

    On success it emits ``started`` and ``completed`` events and fires
    ``on_complete``. It then stores the result on the job.

    Any exception raised on the way is caught, including ``BaseException``
    subclasses. The exception becomes a failed ``GenerationResult`` carrying one
    ``runtime`` error. That error is reported through an ``error`` event and
    ``on_error`` before the normal completion path runs.

    In all cases the job's event stream is closed. The session's active job is
    cleared if it is still this job.
    """

    def publish_completion(result):
        # "completed" event, on_complete callback, then the stored job result.
        job.events.put("completed", result)
        session._emit_callback("on_complete", result, job=job)
        job._set_result(result)

    try:
        runtime = session._ensure_runtime()
        job.events.put("started", {"tasks": len(tasks), "backend": "webui_queue"})
        publish_completion(WebUIQueueProbe(session, runtime, tasks, job).run())
    except BaseException as exc:
        failure = session._make_generation_error(exc, task_index=None, task_id=None, stage="runtime")
        failed_result = GenerationResult(
            success=False,
            generated_files=[],
            errors=[failure],
            total_tasks=len(tasks),
            successful_tasks=0,
            failed_tasks=max(1, len(tasks)),
            artifacts=(),
        )
        job.events.put("error", failure)
        session._emit_callback("on_error", failure, job=job)
        publish_completion(failed_result)
    finally:
        job.events.close()
        with session._job_lock:
            if session._active_job is job:
                session._active_job = None
class GradioWanGPSession:
    """Gradio-aware proxy around a lazily created ``WanGPSession``.

    The underlying session is (re)created whenever the live Gradio ``state``
    dict changes. Before every submit/run call, the current Gradio request
    context is bound onto it.

    Inside :meth:`plugin_ui_context`, plugin button clicks whose callbacks use
    this session are rerouted. The callback runs on a background worker
    thread, and a follow-up event streams its updates.

    Public attributes not defined here are delegated to the underlying
    session.
    """

    def __init__(self, *, init_fn, plugin=None, state_component: Any = None, session_kwargs: dict[str, Any] | None = None) -> None:
        """Store configuration; the ``WanGPSession`` is created on first use.

        Args:
            init_fn: Kept as ``self._init_fn``; not read by this class
                (NOTE(review): presumably used elsewhere — confirm).
            plugin: Optional plugin whose ``state`` attribute is used when
                ``state_component`` is not given.
            state_component: Gradio component whose value is the WebUI state
                dict.
            session_kwargs: Extra ``WanGPSession`` keyword arguments. They are
                copied, and ``console_output`` defaults to ``False``.
        """
        self._init_fn = init_fn
        self._plugin = plugin
        self._state_component = state_component
        self._session_kwargs = dict(session_kwargs or {})
        self._session_kwargs.setdefault("console_output", False)
        self._session: WanGPSession | None = None
        self._defer_load_queue_trigger = False
        # Per-thread call bookkeeping, set by _run_wrapped_click_worker.
        self._ui_local = threading.local()
        # Maps call_id to the state of an in-flight wrapped button click.
        self._wrapped_calls: dict[str, _WrappedCallState] = {}
        self._wrapped_calls_lock = threading.Lock()
        # Hidden gr.State that carries the call id from the start event to the wait event.
        self._ui_call_component = None

    @classmethod
    def for_plugin(cls, plugin, *, init_fn, session_kwargs: dict[str, Any] | None = None):
        """Create a proxy whose state is read from ``plugin.state``.

        Requests the plugin's ``"state"`` component first; presumably this is
        what populates ``plugin.state``, which :meth:`_resolve_state` reads.
        """
        plugin.request_component("state")
        return cls(init_fn=init_fn, plugin=plugin, session_kwargs=session_kwargs)

    def _bound_session(self) -> WanGPSession:
        """Return the current session after binding the active Gradio context to it."""
        session = self._ensure_session()
        self._bind_gradio_context(session)
        return session

    def submit(self, source, callbacks: object | None = None) -> SessionJob:
        """Submit ``source`` and return the job without waiting for it.

        Inside a wrapped click, the job is attached to that click's call state.
        """
        session = self._bound_session()
        job = session.submit(source, callbacks=self._wrap_callbacks_for_current_call(callbacks))
        self._capture_job_for_current_call(job)
        return job

    def submit_task(self, settings: dict[str, Any], callbacks: object | None = None) -> SessionJob:
        """Submit a single task's ``settings``; see :meth:`submit`."""
        session = self._bound_session()
        job = session.submit_task(settings, callbacks=self._wrap_callbacks_for_current_call(callbacks))
        self._capture_job_for_current_call(job)
        return job

    def submit_manifest(self, settings_list: list[dict[str, Any]], callbacks: object | None = None) -> SessionJob:
        """Submit several tasks' settings as one job; see :meth:`submit`."""
        session = self._bound_session()
        job = session.submit_manifest(settings_list, callbacks=self._wrap_callbacks_for_current_call(callbacks))
        self._capture_job_for_current_call(job)
        return job

    def run(self, source, callbacks: object | None = None) -> GenerationResult:
        """Run ``source`` on the bound session and return its result."""
        session = self._bound_session()
        return session.run(source, callbacks=self._wrap_callbacks_for_current_call(callbacks))

    def run_task(self, settings: dict[str, Any], callbacks: object | None = None) -> GenerationResult:
        """Run a single task's ``settings`` and return its result."""
        session = self._bound_session()
        return session.run_task(settings, callbacks=self._wrap_callbacks_for_current_call(callbacks))

    def run_manifest(self, settings_list: list[dict[str, Any]], callbacks: object | None = None) -> GenerationResult:
        """Run several tasks' settings and return the combined result."""
        session = self._bound_session()
        return session.run_manifest(settings_list, callbacks=self._wrap_callbacks_for_current_call(callbacks))

    def ensure_ready(self):
        """Call ``ensure_ready`` on the underlying session and return ``self`` for chaining."""
        self._ensure_session().ensure_ready()
        return self

    def close(self) -> None:
        """Close and drop the underlying session, if one was created."""
        if self._session is None:
            return
        self._session.close()
        self._session = None

    def cancel(self) -> None:
        """Forward ``cancel`` to the underlying session, if one exists."""
        if self._session is not None:
            self._session.cancel()

    @contextlib.contextmanager
    def plugin_ui_context(self):
        """Patch ``gr.Button.click`` while a plugin builds its UI.

        Clicks whose callback references this session (see
        :meth:`_callback_uses_api_session`) are rewired by
        :meth:`_wrap_button_click`. Every other click is registered unchanged.
        The original ``gr.Button.click`` is restored on exit.
        """
        import gradio as gr

        original_click = gr.Button.click
        if self._ui_call_component is None:
            self._ui_call_component = gr.State("")

        def patched_click(button, *args, **kwargs):
            fn = kwargs.get("fn")
            if fn is None and args:
                fn = args[0]
            if not callable(fn) or not self._callback_uses_api_session(fn):
                return original_click(button, *args, **kwargs)
            return self._wrap_button_click(original_click, button, *args, **kwargs)

        gr.Button.click = patched_click
        try:
            yield
        finally:
            gr.Button.click = original_click

    def __getattr__(self, name: str) -> Any:
        """Delegate public attribute lookups to the underlying session.

        Private names raise ``AttributeError`` so that missing internals never
        trigger session creation.
        """
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._ensure_session(), name)

    def _wrap_button_click(self, original_click, button, *args, **kwargs):
        """Register a click as a quick start event plus a streaming ``.then`` wait event.

        Both events write the original outputs followed by three bridge slots:
        the main UI's load-queue trigger, its abort client id, and the hidden
        call-id state.
        """
        fn = kwargs.get("fn")
        if fn is None and args:
            fn = args[0]
        inputs = kwargs.get("inputs")
        if inputs is None and len(args) > 1:
            inputs = args[1]
        outputs = kwargs.get("outputs")
        if outputs is None and len(args) > 2:
            outputs = args[2]
        original_outputs = self._normalize_outputs(outputs)
        # Capture the caller's choices before show_progress is overridden below.
        explicit_show_progress = kwargs.get("show_progress")
        explicit_progress_targets = self._normalize_outputs(kwargs.get("show_progress_on")) if "show_progress_on" in kwargs else None
        load_queue_trigger = self._resolve_main_bridge_component("wangp_main_load_queue_trigger")
        abort_client_id = self._resolve_main_bridge_component("wangp_main_abort_client_id")
        call_state = self._ui_call_component
        kwargs["fn"] = self._make_wrapped_click_start(fn, len(original_outputs))
        # Positional args are not forwarded (fn/outputs now travel as keywords),
        # so positional inputs must be moved into kwargs or they would be lost.
        # NOTE(review): positional arguments after ``outputs`` are still dropped.
        if inputs is not None:
            kwargs["inputs"] = inputs
        kwargs["outputs"] = [*original_outputs, load_queue_trigger, abort_client_id, call_state]
        kwargs["show_progress"] = "hidden"
        dependency = original_click(button, **kwargs)
        wait_outputs = [*original_outputs, load_queue_trigger, abort_client_id, call_state]
        if explicit_progress_targets is not None:
            progress_targets = explicit_progress_targets
        else:
            progress_targets = [component for component in original_outputs if hasattr(component, "_id")]

        def wait_wrapped_call(call_id):
            yield from self._wait_wrapped_call(call_id, len(original_outputs))

        then_kwargs = {"fn": wait_wrapped_call, "inputs": [call_state], "outputs": wait_outputs, "show_progress": explicit_show_progress or "full"}
        if progress_targets:
            then_kwargs["show_progress_on"] = progress_targets
        dependency.then(**then_kwargs)
        return dependency

    def _make_wrapped_click_start(self, fn, output_count: int):
        """Build the start-event handler that runs ``fn`` on a worker thread.

        The handler returns ``[*outputs, load_queue_trigger, abort_client_id,
        call_id]`` as soon as it has something to show. It proceeds in three
        stages:

        1. Poll for up to 0.5 s for a yielded value, a primary load-queue
           token, or completion.
        2. If a job was submitted, wait up to 5 s for the submission to become
           ready.
        3. Block until the worker finishes.

        A non-empty ``call_id`` tells the follow-up wait event to keep
        streaming. ``wrapped`` carries ``fn``'s signature so Gradio maps inputs
        the same way it would for ``fn``.
        """
        import uuid

        import gradio as gr

        def wrapped(*args, **kwargs):
            # Resolve everything that can raise before registering the call, so
            # a failure here cannot leave an orphaned entry in _wrapped_calls.
            bound_state = self._resolve_state()
            bound_context = self._capture_current_gradio_context()
            # Concurrent clicks must never share an id; time_ns() values can
            # repeat when the clock resolution is coarse.
            call_id = uuid.uuid4().hex
            state = _WrappedCallState(output_count)
            state.set_callback_context(bound_context)
            self._remember_wrapped_call(call_id, state)
            worker = threading.Thread(target=self._run_wrapped_click_worker, args=(call_id, fn, args, kwargs, bound_state, bound_context), daemon=True, name="wangp-plugin-click")
            worker.start()
            # monotonic() is immune to wall-clock adjustments during the poll.
            deadline = time.monotonic() + 0.5
            while time.monotonic() < deadline:
                yielded_result = state.pop_yielded_result()
                if yielded_result is not _NO_YIELDED_RESULT:
                    return [*self._normalize_callback_result(yielded_result, output_count), gr.skip(), gr.skip(), call_id]
                load_queue_token = state.pop_primary_load_queue_token()
                if load_queue_token:
                    return [*self._blank_outputs(output_count), load_queue_token, gr.skip(), call_id]
                if state.job is not None:
                    state.job._webui_submission_ready.wait(timeout=0.05)
                if state.done.wait(timeout=0.05):
                    return self._finish_wrapped_click(call_id, state, output_count)
            if state.job is not None and not state.job._webui_submission_ready.wait(timeout=5):
                self._forget_wrapped_call(call_id)
                raise gr.Error("WanGP WebUI submission did not become ready in time.")
            load_queue_token = state.pop_primary_load_queue_token()
            if load_queue_token:
                return [*self._blank_outputs(output_count), load_queue_token, gr.skip(), call_id]
            state.done.wait()
            return self._finish_wrapped_click(call_id, state, output_count)

        wrapped.__signature__ = inspect.signature(fn)
        return wrapped

    def _finish_wrapped_click(self, call_id: str, state: _WrappedCallState, output_count: int) -> list[Any]:
        """Build the start event's return value once the worker has finished.

        Re-raises the worker's error as ``gr.Error``. If a yielded value is
        still pending, it is returned with ``call_id`` so that the wait event
        delivers the final result. Otherwise the call is forgotten, and the
        final result (or blank outputs) is returned with an empty call id.
        """
        import gradio as gr

        if state.error is not None:
            self._forget_wrapped_call(call_id)
            raise self._as_gradio_error(state.error)
        yielded_result = state.pop_yielded_result()
        if yielded_result is not _NO_YIELDED_RESULT:
            return [*self._normalize_callback_result(yielded_result, output_count), gr.skip(), gr.skip(), call_id]
        self._forget_wrapped_call(call_id)
        final_outputs = self._normalize_callback_result(state.result, output_count) if state.has_result else self._blank_outputs(output_count)
        return [*final_outputs, gr.skip(), gr.skip(), ""]

    def _wait_wrapped_call(self, call_id: str, output_count: int):
        """Stream updates for ``call_id`` from the follow-up ``.then`` event.

        Yields ``[*outputs, load_queue_trigger, abort_client_id, call_id]``
        lists while the worker runs: forwarded load-queue tokens, abort client
        ids and yielded values. It then yields the final result with an empty
        call id, or re-raises the worker's error as ``gr.Error``. The call
        state is always forgotten on exit.
        """
        import gradio as gr

        call_id = str(call_id or "").strip()
        state = self._get_wrapped_call(call_id) if call_id else None
        if state is None:
            yield [*self._blank_outputs(output_count), gr.skip(), gr.skip(), ""]
            return
        try:
            # Progress and log output now target this (visible) event.
            state.set_callback_context(self._capture_progress_callback_context())
            while True:
                # Sample completion *before* draining, so anything the worker
                # queued right before finishing is still forwarded.
                finished = state.done.is_set()
                self._flush_buffered_log_messages(state)
                load_queue_token = state.pop_primary_load_queue_token()
                if load_queue_token:
                    yield [*self._blank_outputs(output_count), load_queue_token, gr.skip(), call_id]
                    continue
                load_queue_token = state.pop_ready_followup_load_queue_token()
                if load_queue_token:
                    print(f"WanGP API forwarding follow-up load_queue_trigger token={load_queue_token}")
                    yield [*self._blank_outputs(output_count), load_queue_token, gr.skip(), call_id]
                    continue
                abort_client_id = state.pop_abort_client_id()
                if abort_client_id:
                    print(f"WanGP API forwarding abort_client_id={abort_client_id}")
                    yield [*self._blank_outputs(output_count), gr.skip(), abort_client_id, call_id]
                    continue
                yielded_result = state.pop_yielded_result()
                if yielded_result is not _NO_YIELDED_RESULT:
                    yield [*self._normalize_callback_result(yielded_result, output_count), gr.skip(), gr.skip(), call_id]
                    continue
                if finished:
                    if state.error is not None:
                        raise self._as_gradio_error(state.error)
                    final_outputs = self._normalize_callback_result(state.result, output_count) if state.has_result else self._blank_outputs(output_count)
                    yield [*final_outputs, gr.skip(), gr.skip(), ""]
                    return
                time.sleep(0.05)
        finally:
            self._forget_wrapped_call(call_id)

    @staticmethod
    def _flush_buffered_log_messages(state: _WrappedCallState) -> None:
        """Send buffered Gradio log messages to the call's event, if one is bound.

        This is a no-op until ``state.callback_context`` holds both ``blocks``
        and an ``event_id``.
        """
        context = state.callback_context
        if not isinstance(context, dict):
            return
        blocks = context.get("blocks")
        event_id = context.get("event_id")
        if blocks is None or event_id is None:
            return
        for message in state.pop_log_messages():
            blocks._queue.log_message(event_id=event_id, **message)

    def _run_wrapped_click_worker(self, call_id: str, fn, args, kwargs, bound_state: dict[str, Any], bound_context: dict[str, Any]) -> None:
        """Worker-thread body: run ``fn`` inside the captured Gradio context.

        Thread-locals make ``_resolve_state`` and ``_bind_gradio_context`` use
        the captured state and context. They also route buffered Gradio log
        messages to this call.

        Generator results are pushed one value at a time. The outcome is then
        recorded on the call state: the final result, plain completion, or the
        error.
        """
        state = self._get_wrapped_call(call_id)
        if state is None:
            return
        _ensure_gradio_log_message_patch()
        self._ui_local.call_id = call_id
        self._ui_local.defer_load_queue_trigger = True
        self._ui_local.bound_state = bound_state
        _WRAPPED_LOG_LOCAL.owner = self
        _WRAPPED_LOG_LOCAL.call_id = call_id
        try:
            exec_context = dict(bound_context)
            if isinstance(state.callback_context, dict):
                exec_context.update(state.callback_context)
            self._ui_local.bound_gradio_context = exec_context
            with self._push_callback_context(exec_context):
                result = fn(*args, **kwargs)
                if inspect.isgenerator(result):
                    while True:
                        try:
                            item = next(result)
                        except StopIteration as stop:
                            # A generator's ``return value`` becomes the final result.
                            if stop.value is not None:
                                state.set_result(stop.value)
                            else:
                                state.set_completed()
                            break
                        state.push_yielded_result(item)
                else:
                    state.set_result(result)
        except BaseException as exc:
            state.set_error(exc)
        finally:
            _WRAPPED_LOG_LOCAL.owner = None
            _WRAPPED_LOG_LOCAL.call_id = ""
            self._ui_local.call_id = ""
            self._ui_local.defer_load_queue_trigger = False
            self._ui_local.bound_state = None
            self._ui_local.bound_gradio_context = None

    def _capture_job_for_current_call(self, job: SessionJob) -> None:
        """Attach ``job`` to the wrapped click running on this thread, if any.

        The first job becomes ``state.job``; later ones are recorded as
        follow-up jobs.
        """
        call_id = str(getattr(self._ui_local, "call_id", "") or "").strip()
        if not call_id:
            return
        job._bind_webui_owner_call(call_id)
        state = self._get_wrapped_call(call_id)
        if state is None:
            return
        if state.job is None:
            state.job = job
        else:
            state.add_followup_job(job)

    def _capture_cancelled_job(self, job: SessionJob) -> None:
        """Hook for cancelled jobs; currently a no-op."""
        return

    def _enqueue_abort_client_id(self, job: SessionJob, client_id: str) -> bool:
        """Queue ``client_id`` for forwarding to the abort bridge of ``job``'s call.

        The call is taken from ``job.webui_owner_call_id``, falling back to this
        thread's current call. Returns False when no live call is found.
        """
        call_id = str(getattr(job, "webui_owner_call_id", "") or getattr(self._ui_local, "call_id", "") or "").strip()
        if not call_id:
            return False
        state = self._get_wrapped_call(call_id)
        if state is None:
            return False
        state.queue_abort_client_id(client_id)
        return True

    def _remember_wrapped_call(self, call_id: str, state: _WrappedCallState) -> None:
        """Register ``state`` under ``call_id`` (thread-safe)."""
        with self._wrapped_calls_lock:
            self._wrapped_calls[call_id] = state

    def _get_wrapped_call(self, call_id: str) -> _WrappedCallState | None:
        """Return the state registered under ``call_id``, or None."""
        with self._wrapped_calls_lock:
            return self._wrapped_calls.get(call_id)

    def _forget_wrapped_call(self, call_id: str) -> None:
        """Drop ``call_id`` from the registry; missing ids are ignored."""
        with self._wrapped_calls_lock:
            self._wrapped_calls.pop(call_id, None)

    def _callback_uses_api_session(self, fn) -> bool:
        """Heuristically decide whether click callback ``fn`` uses this session.

        The scan covers the names ``fn`` references (closure nonlocals and
        globals), its default argument values and its closure cells. Returns
        True if any of them is:

        - this proxy or its current session;
        - any ``GradioWanGPSession`` or ``WanGPSession``;
        - a method bound to this proxy or its session.
        """
        candidates: list[Any] = []
        try:
            closure_vars = inspect.getclosurevars(fn)
        except Exception:
            # getclosurevars rejects non-function callables such as partials.
            closure_vars = None
        if closure_vars is not None:
            candidates.extend(closure_vars.nonlocals.values())
            candidates.extend(closure_vars.globals.values())
        candidates.extend(getattr(fn, "__defaults__", None) or ())
        candidates.extend((getattr(fn, "__kwdefaults__", None) or {}).values())
        for cell in getattr(fn, "__closure__", None) or ():
            try:
                candidates.append(cell.cell_contents)
            except ValueError:
                # Empty cell: the closed-over variable is not assigned yet.
                continue
        session = self._session

        def is_owner(obj) -> bool:
            # Identity only. While no session exists yet, a None candidate
            # (a None default, a None global) must not count as a match.
            return obj is self or (session is not None and obj is session)

        for candidate in candidates:
            if is_owner(candidate) or isinstance(candidate, (GradioWanGPSession, WanGPSession)):
                return True
            if inspect.ismethod(candidate) and is_owner(candidate.__self__):
                return True
        return False

    def _wrap_callbacks_for_current_call(self, callbacks: object | None) -> object | None:
        """Wrap ``callbacks`` in ``_BoundGradioCallbacks`` when inside a wrapped click.

        Returns ``callbacks`` unchanged when it is None, already wrapped, or
        when no wrapped call is active on this thread.
        """
        if callbacks is None:
            return None
        call_id = str(getattr(self._ui_local, "call_id", "") or "").strip()
        if not call_id:
            return callbacks
        state = self._get_wrapped_call(call_id)
        if state is None or isinstance(callbacks, _BoundGradioCallbacks):
            return callbacks
        return _BoundGradioCallbacks(callbacks, state, self)

    @staticmethod
    def _normalize_outputs(outputs: Any) -> list[Any]:
        """Return ``outputs`` as a list: None -> [], list/tuple -> list, else [outputs]."""
        if outputs is None:
            return []
        if isinstance(outputs, (list, tuple)):
            return list(outputs)
        return [outputs]

    @staticmethod
    def _blank_outputs(output_count: int) -> list[Any]:
        """Return ``output_count`` ``gr.skip()`` values (leave every output unchanged)."""
        import gradio as gr

        return [gr.skip()] * output_count

    @staticmethod
    def _normalize_callback_result(result: Any, output_count: int) -> list[Any]:
        """Fit a callback result to exactly ``output_count`` output values.

        A single output takes ``result`` as-is. For several outputs, a list or
        tuple is padded with ``gr.skip()`` or truncated; any other value fills
        the first slot.
        """
        import gradio as gr

        if output_count <= 0:
            return []
        if output_count == 1:
            return [result]
        normalized = list(result) if isinstance(result, (list, tuple)) else [result]
        if len(normalized) < output_count:
            normalized.extend([gr.skip()] * (output_count - len(normalized)))
        return normalized[:output_count]

    @staticmethod
    def _as_gradio_error(error: BaseException):
        """Return ``error`` if it is already a ``gr.Error``; otherwise wrap its message in one."""
        import gradio as gr

        return error if isinstance(error, gr.Error) else gr.Error(str(error))

    def _ensure_session(self) -> WanGPSession:
        """Return the session, recreating it when the live state dict changed.

        NOTE(review): a replaced session is not closed here.
        """
        state = self._resolve_state()
        if self._session is None or self._session._state is not state:
            session_kwargs = copy.deepcopy(self._session_kwargs)
            session_kwargs["webui_state"] = state
            self._session = WanGPSession(**session_kwargs)
            self._session._gradio_session_proxy = self
        return self._session

    def _resolve_state(self) -> dict[str, Any]:
        """Return the WebUI state dict for the current call.

        Lookup order:

        1. The state bound to this worker thread.
        2. The live per-session value of the state component.
        3. The component's static ``value``.

        Raises:
            RuntimeError: if none of these is a dict.
        """
        bound_state = getattr(self._ui_local, "bound_state", None)
        if isinstance(bound_state, dict):
            return bound_state
        component = self._state_component
        if component is None and self._plugin is not None:
            component = getattr(self._plugin, "state", None)
        state = self._resolve_live_session_state(component)
        if not isinstance(state, dict):
            state = getattr(component, "value", None) if component is not None else None
        if not isinstance(state, dict):
            raise RuntimeError("WanGP WebUI session requires access to the live Gradio state component.")
        return state

    @staticmethod
    def _resolve_live_session_state(component: Any) -> dict[str, Any] | None:
        """Read ``component``'s value from the current request's session state.

        Returns None when there is no component id, no active request or
        session hash, or when the value is not a dict.
        """
        component_id = getattr(component, "_id", None)
        if component_id is None:
            return None
        try:
            from gradio.context import LocalContext
        except Exception:
            return None
        try:
            blocks = LocalContext.blocks.get(None)
            request = LocalContext.request.get(None)
        except LookupError:
            return None
        session_hash = getattr(request, "session_hash", None) if request is not None else None
        state_holder = getattr(blocks, "state_holder", None) if blocks is not None else None
        if not session_hash or state_holder is None:
            return None
        try:
            session_state = state_holder[session_hash]
            state = session_state[component_id]
        except Exception:
            return None
        return state if isinstance(state, dict) else None

    def _defer_requested(self) -> bool:
        """True if load-queue triggers should be deferred, either globally or on this thread."""
        return self._defer_load_queue_trigger or bool(getattr(self._ui_local, "defer_load_queue_trigger", False))

    def _bind_gradio_context(self, session: WanGPSession) -> None:
        """Store the Gradio request context the session needs on ``session._gradio_webui_context``.

        A context bound to this worker thread is preferred. Otherwise the
        context is read from Gradio's ``LocalContext``.

        Raises:
            RuntimeError: if no live request with a session hash is available.
        """
        bound_context = getattr(self._ui_local, "bound_gradio_context", None)
        if isinstance(bound_context, dict):
            session._gradio_webui_context = dict(bound_context)
            session._gradio_webui_context["defer_load_queue_trigger"] = self._defer_requested()
            return
        try:
            from gradio.context import LocalContext
        except Exception as exc:
            raise RuntimeError("WanGP WebUI session requires an active Gradio callback context.") from exc
        try:
            blocks = LocalContext.blocks.get(None)
            request_wrapper = LocalContext.request.get(None)
        except LookupError as exc:
            raise RuntimeError("WanGP WebUI session requires an active Gradio callback context.") from exc
        session_hash = getattr(request_wrapper, "session_hash", None) if request_wrapper is not None else None
        request = getattr(request_wrapper, "request", request_wrapper)
        if blocks is None or request is None or not session_hash:
            raise RuntimeError("WanGP WebUI session requires a live Gradio request with a session hash.")
        session._gradio_webui_context = {
            "blocks": blocks,
            "request": request,
            "session_hash": session_hash,
            "load_queue_fn_index": self._resolve_trigger_fn_index(blocks, session_hash, "load_queue_action", "change"),
            "abort_fn_index": self._resolve_abort_fn_index(blocks, session_hash),
            "defer_load_queue_trigger": self._defer_requested(),
        }

    def _set_defer_load_queue_trigger(self, value: bool) -> None:
        """Set the proxy-wide flag that defers load-queue triggers."""
        self._defer_load_queue_trigger = bool(value)

    def _capture_current_gradio_context(self) -> dict[str, Any]:
        """Snapshot the active Gradio ``LocalContext`` so a worker thread can restore it.

        Besides the context variables, the snapshot records the session hash,
        the resolved load-queue and abort function indices, and the defer flag.

        NOTE(review): ``"request"`` holds the unwrapped ``request_wrapper.request``,
        whereas ``_capture_progress_callback_context`` stores the wrapper itself.

        Raises:
            RuntimeError: outside a live Gradio request with a session hash.
        """
        try:
            from gradio.context import LocalContext
        except Exception as exc:
            raise RuntimeError("WanGP WebUI session requires an active Gradio callback context.") from exc
        try:
            blocks = LocalContext.blocks.get(None)
            blocks_config = LocalContext.blocks_config.get(None)
            renderable = LocalContext.renderable.get(None)
            render_block = LocalContext.render_block.get(None)
            in_event_listener = LocalContext.in_event_listener.get(False)
            event_id = LocalContext.event_id.get(None)
            request_wrapper = LocalContext.request.get(None)
            progress = LocalContext.progress.get(None)
        except LookupError as exc:
            raise RuntimeError("WanGP WebUI session requires an active Gradio callback context.") from exc
        session_hash = getattr(request_wrapper, "session_hash", None) if request_wrapper is not None else None
        request = getattr(request_wrapper, "request", request_wrapper)
        if blocks is None or request is None or not session_hash:
            raise RuntimeError("WanGP WebUI session requires a live Gradio request with a session hash.")
        return {
            "blocks": blocks,
            "blocks_config": blocks_config,
            "renderable": renderable,
            "render_block": render_block,
            "in_event_listener": in_event_listener,
            "event_id": event_id,
            "request": request,
            "progress": progress,
            "session_hash": session_hash,
            "load_queue_fn_index": self._resolve_trigger_fn_index(blocks, session_hash, "load_queue_action", "change"),
            "abort_fn_index": self._resolve_abort_fn_index(blocks, session_hash),
            "defer_load_queue_trigger": self._defer_requested(),
        }

    @staticmethod
    def _capture_progress_callback_context() -> dict[str, Any]:
        """Snapshot the raw ``LocalContext`` values of the current (wait) event."""
        try:
            from gradio.context import LocalContext
        except Exception as exc:
            raise RuntimeError("WanGP progress callbacks require an active Gradio callback context.") from exc
        return {
            "blocks": LocalContext.blocks.get(None),
            "blocks_config": LocalContext.blocks_config.get(None),
            "renderable": LocalContext.renderable.get(None),
            "render_block": LocalContext.render_block.get(None),
            "in_event_listener": LocalContext.in_event_listener.get(False),
            "event_id": LocalContext.event_id.get(None),
            "request": LocalContext.request.get(None),
            "progress": LocalContext.progress.get(None),
        }

    @staticmethod
    @contextlib.contextmanager
    def _push_callback_context(context: dict[str, Any]):
        """Set Gradio's ``LocalContext`` variables from ``context`` for the ``with`` body.

        The previous values are reset on exit. This is a no-op if Gradio
        cannot be imported.
        """
        try:
            from gradio.context import LocalContext
        except Exception:
            yield
            return
        tokens = []
        mapping = {
            LocalContext.blocks: context.get("blocks"),
            LocalContext.blocks_config: context.get("blocks_config"),
            LocalContext.renderable: context.get("renderable"),
            LocalContext.render_block: context.get("render_block"),
            LocalContext.in_event_listener: context.get("in_event_listener", False),
            LocalContext.event_id: context.get("event_id"),
            LocalContext.request: context.get("request"),
            LocalContext.progress: context.get("progress"),
        }
        try:
            for var, value in mapping.items():
                tokens.append((var, var.set(value)))
            yield
        finally:
            for var, token in reversed(tokens):
                var.reset(token)

    @staticmethod
    def _resolve_main_bridge_component(elem_id: str):
        """Find the component with ``elem_id`` in the Blocks tree being built.

        Falls back to the root block's default config when no build context is
        active.

        Raises:
            RuntimeError: if Gradio is unavailable or no component matches.
        """
        try:
            from gradio.context import Context, get_blocks_context
        except Exception as exc:
            raise RuntimeError(f"WanGP WebUI bridge component '{elem_id}' is unavailable outside the Gradio build context.") from exc
        blocks_context = get_blocks_context()
        if blocks_context is None and getattr(Context, "root_block", None) is not None:
            blocks_context = Context.root_block.default_config
        blocks = getattr(blocks_context, "blocks", None)
        if not isinstance(blocks, dict):
            raise RuntimeError(f"WanGP WebUI bridge component '{elem_id}' was not found in the current Blocks tree.")
        for component in blocks.values():
            if getattr(component, "elem_id", None) == elem_id:
                return component
        raise RuntimeError(f"WanGP WebUI bridge component '{elem_id}' could not be resolved.")

    @staticmethod
    def _resolve_trigger_fn_index(blocks, session_hash: str, api_name: str, event_name: str) -> int:
        """Return the ``_id`` of the session's function named ``api_name`` that listens to ``event_name``."""
        session_state = blocks.state_holder[session_hash]
        for block_fn in session_state.blocks_config.fns.values():
            targets = getattr(block_fn, "targets", ()) or ()
            if getattr(block_fn, "api_name", None) == api_name and any(target[1] == event_name for target in targets):
                return int(getattr(block_fn, "_id"))
        raise RuntimeError(f"WanGP WebUI trigger '{api_name}' was not found.")

    @staticmethod
    def _resolve_abort_fn_index(blocks, session_hash: str) -> int:
        """Return the ``_id`` of the ``abort_generation*`` function that listens to a ``change`` event."""
        session_state = blocks.state_holder[session_hash]
        for block_fn in session_state.blocks_config.fns.values():
            targets = getattr(block_fn, "targets", ()) or ()
            api_name = str(getattr(block_fn, "api_name", "") or "")
            if api_name.startswith("abort_generation") and any(target[1] == "change" for target in targets):
                return int(getattr(block_fn, "_id"))
        raise RuntimeError("WanGP WebUI abort trigger was not found.")
def create_gradio_webui_session(plugin, *, init_fn, session_kwargs: dict[str, Any] | None = None) -> GradioWanGPSession:
    """Return a ``GradioWanGPSession`` bound to ``plugin``.

    This is a convenience wrapper around ``GradioWanGPSession.for_plugin``;
    all arguments are forwarded unchanged.
    """
    factory = GradioWanGPSession.for_plugin
    return factory(plugin, init_fn=init_fn, session_kwargs=session_kwargs)
def create_gradio_progress_callbacks(progress) -> GradioProgressCallbacks:
    """Wrap ``progress`` in a ``GradioProgressCallbacks`` instance.

    ``progress`` is presumably a Gradio progress tracker (e.g. ``gr.Progress``);
    confirm against callers.
    """
    callbacks = GradioProgressCallbacks(progress)
    return callbacks
Xet Storage Details
- Size:
- 63.2 kB
- Xet hash:
- 525ffdf86a092a3f9e5c5e8187328efb491b96578c2a4a632c46a50edf164fba
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.