| import contextlib | |
| import inspect | |
| import sys | |
| import threading | |
| import time | |
| from typing import Any | |
| from shared.api import GenerationError, GenerationResult, SessionJob, _GENERATION_LOCK, _OutputCapture, _pushd | |
| from shared.utils.thread_utils import AsyncStream | |
def run_cli_job(session, job: SessionJob, tasks: list[dict[str, Any]]) -> None:
    """Execute ``tasks`` for ``job`` and publish the outcome on the job.

    Generation runs on a daemon worker thread (``_run_tasks_worker``) whose
    stdout/stderr are captured and forwarded through ``session._emit_stream``.
    The calling thread holds ``_GENERATION_LOCK`` with the working directory
    switched to ``runtime.root``. It relays worker commands to ``job.events``
    and the session callbacks, and forwards ``job.cancel_requested`` to the
    runtime.

    On normal completion a ``GenerationResult`` summarising the run is emitted
    as a ``"completed"`` event and ``on_complete`` callback, then stored with
    ``job._set_result``. If anything raises, a failure ``GenerationResult`` is
    published instead, preceded by an ``"error"`` event and ``on_error``
    callback. Such exceptions are not propagated to the caller.
    ``job.events`` is always closed and the session's active job cleared
    before returning.
    """
    stream = AsyncStream()
    gen = session._state["gen"]
    worker_done = threading.Event()
    # Remember how many outputs already exist so only this run's files are collected.
    base_file_count = len(gen["file_list"])
    base_audio_count = len(gen["audio_file_list"])
    total_tasks = len(tasks)
    runtime = None
    # Shared with the worker thread, which fills it in as tasks finish.
    task_summary: dict[str, Any] = {
        "errors": [],
        "successful_tasks": 0,
        "failed_tasks": 0,
        "total_tasks": total_tasks,
    }
    try:
        runtime = session._ensure_runtime()
        with _GENERATION_LOCK, _pushd(runtime.root):
            session._configure_runtime(runtime)
            session._prepare_state_for_run(tasks)
            job.events.put("started", {"tasks": len(tasks)})

            def worker() -> None:
                # Background thread body: run all tasks with output captured and
                # always post "worker_exit" / set worker_done, even on failure.
                stdout_capture = _OutputCapture(
                    "stdout",
                    lambda stream_name, line: session._emit_stream(job, stream_name, line),
                    console=sys.__stdout__ if session._console_output else None,
                    console_isatty=session._console_isatty,
                )
                stderr_capture = _OutputCapture(
                    "stderr",
                    lambda stream_name, line: session._emit_stream(job, stream_name, line),
                    console=sys.__stderr__ if session._console_output else None,
                    console_isatty=session._console_isatty,
                )
                try:
                    # NOTE: redirect_stdout/redirect_stderr swap the process-wide
                    # sys.stdout/sys.stderr, so output from other threads is captured
                    # too while the worker runs.
                    with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
                        _run_tasks_worker(session, runtime.module, tasks, stream, job, task_summary)
                except BaseException as exc:
                    failure = session._make_generation_error(exc, task_index=None, task_id=None, stage="runtime")
                    task_summary["errors"].append(failure)
                    stream.output_queue.push("error", failure)
                finally:
                    stdout_capture.flush()
                    stderr_capture.flush()
                    stream.output_queue.push("worker_exit", None)
                    worker_done.set()

            worker_thread = threading.Thread(target=worker, daemon=True, name="wangp-session-worker")
            worker_thread.start()
            try:
                _pump_worker_events(session, job, runtime.module, tasks, stream, worker_thread, worker_done)
            except BaseException:
                # Relaying failed (e.g. KeyboardInterrupt, or an exception from
                # _handle_command). Request cancellation exactly as a user cancel does,
                # and wait for the worker before the lock and working directory are
                # released and the session state is reset in ``finally``. Otherwise
                # generation would keep running unsupervised on the daemon thread.
                session._request_cancel_unlocked(runtime.module)
                worker_done.wait()
                raise
            worker_thread.join(timeout=0.1)
            outputs = session._collect_outputs(base_file_count, base_audio_count)
            artifacts = session._consume_output_artifacts(tasks)
            if job.cancel_requested and not task_summary["errors"]:
                # A cancelled run must never be reported as a clean success.
                task_summary["errors"].append(GenerationError(message="Generation was cancelled", stage="cancelled"))
                task_summary["failed_tasks"] = max(task_summary["failed_tasks"], 1)
            result = GenerationResult(
                success=not task_summary["errors"],
                generated_files=outputs,
                errors=list(task_summary["errors"]),
                total_tasks=task_summary["total_tasks"],
                successful_tasks=task_summary["successful_tasks"],
                failed_tasks=task_summary["failed_tasks"],
                artifacts=artifacts,
            )
            job.events.put("completed", result)
            session._emit_callback("on_complete", result, job=job)
            job._set_result(result)
    except BaseException as exc:
        failure = session._make_generation_error(exc, task_index=None, task_id=None, stage="runtime")
        result = GenerationResult(
            success=False,
            generated_files=[],
            # Keep the per-task errors behind ``failed_tasks`` alongside the fatal one.
            errors=[*task_summary["errors"], failure],
            total_tasks=total_tasks,
            successful_tasks=task_summary["successful_tasks"],
            failed_tasks=max(task_summary["failed_tasks"], 1 if total_tasks > 0 else 0),
            artifacts=(),
        )
        job.events.put("error", failure)
        session._emit_callback("on_error", failure, job=job)
        job.events.put("completed", result)
        session._emit_callback("on_complete", result, job=job)
        job._set_result(result)
    finally:
        job.events.close()
        if runtime is not None:
            session._reset_state_after_run()
        with session._job_lock:
            if session._active_job is job:
                session._active_job = None


def _pump_worker_events(session, job: SessionJob, wgp, tasks: list[dict[str, Any]], stream: AsyncStream, worker_thread: threading.Thread, worker_done: threading.Event) -> None:
    """Relay queued worker commands to ``_handle_command`` until the worker exits.

    Returns when the ``"worker_exit"`` marker is read. It also returns when the
    queue is empty and the worker has already finished. While pumping, a pending
    ``job.cancel_requested`` is forwarded to the runtime on every iteration.
    """
    while True:
        if job.cancel_requested:
            session._request_cancel_unlocked(wgp)
        item = stream.output_queue.pop()
        if item is None:
            # Safety net: stop once the worker has finished and nothing is left to read.
            if worker_done.is_set() and not worker_thread.is_alive():
                return
            time.sleep(0.01)
            continue
        command, data = item
        if command == "worker_exit":
            return
        _handle_command(session, job, wgp, tasks, command, data)
def _run_tasks_worker(session, wgp, tasks: list[dict[str, Any]], stream: AsyncStream, job: SessionJob, task_summary: dict[str, Any]) -> None:
    """Run ``tasks`` one after another through ``wgp.generate_video``.

    For each task this function:

    - updates the progress counters in ``session._state["gen"]``;
    - validates the task with ``wgp.validate_task``;
    - calls ``wgp.generate_video`` with the validated settings it accepts.

    Commands the generator sends through ``send_cmd`` are forwarded to
    ``stream.output_queue``.

    Outcomes accumulate in ``task_summary``. ``errors`` receives every
    ``GenerationError``, each of which is also pushed to the queue as an
    ``"error"`` command. ``successful_tasks`` and ``failed_tasks`` are
    incremented per task.

    Processing stops before the next task once ``job.cancel_requested`` is set.
    It also stops after the current task when the generation state reports
    ``abort`` or the job was cancelled. Tasks that are not started are not
    counted.
    """
    parameters = list(inspect.signature(wgp.generate_video).parameters)
    # ``task`` and ``send_cmd`` bind positionally to the first two parameters and
    # ``plugin_data`` is passed by keyword below. A task setting with one of those
    # names must not be forwarded again: that raises "got multiple values for argument".
    explicit_args = {*parameters[:2], "plugin_data"}
    forwardable_args = set(parameters) - explicit_args
    total_tasks = len(tasks)
    for task_index, task in enumerate(tasks, start=1):
        if job.cancel_requested:
            break
        session._state["gen"]["prompt_no"] = task_index
        session._state["gen"]["prompts_max"] = total_tasks
        session._state["gen"]["queue"] = tasks
        task_id = task.get("id")
        task_errors: list[GenerationError] = []

        def send_cmd(command: str, data: Any = None) -> None:
            # Callback handed to generate_video: "error" payloads become task-level
            # GenerationErrors (remembered for this task); other commands pass through.
            if command == "error":
                failure = session._make_generation_error(data, task_index=task_index, task_id=task_id, stage="generation")
                task_errors.append(failure)
                stream.output_queue.push("error", failure)
                return
            stream.output_queue.push(command, data)

        validated_settings, validation_error = wgp.validate_task(task, session._state)
        if validated_settings is None:
            failure = GenerationError(
                message=validation_error or f"Task {task_index} failed validation",
                task_index=task_index,
                task_id=task_id,
                stage="validation",
            )
            task_summary["errors"].append(failure)
            task_summary["failed_tasks"] += 1
            stream.output_queue.push("error", failure)
            continue
        task_settings = validated_settings.copy()
        task_settings["state"] = session._state
        filtered_params = {key: value for key, value in task_settings.items() if key in forwardable_args}
        plugin_data = task.get("plugin_data", {})
        try:
            success = wgp.generate_video(task, send_cmd, plugin_data=plugin_data, **filtered_params)
        except BaseException as exc:
            # An error already reported through send_cmd (and pushed there) stands for
            # this task. Otherwise record the exception and push it exactly once.
            if not task_errors:
                failure = session._make_generation_error(exc, task_index=task_index, task_id=task_id, stage="generation")
                task_errors.append(failure)
                stream.output_queue.push("error", failure)
            success = False
        if session._state["gen"].get("abort", False) or job.cancel_requested:
            # Aborted/cancelled by the time this task returned: count it as failed,
            # report the cancellation and stop processing further tasks.
            cancelled = GenerationError(message="Generation was cancelled", task_index=task_index, task_id=task_id, stage="cancelled")
            task_errors.append(cancelled)
            stream.output_queue.push("error", cancelled)
            task_summary["errors"].extend(task_errors)
            task_summary["failed_tasks"] += 1
            break
        if task_errors:
            task_summary["errors"].extend(task_errors)
            task_summary["failed_tasks"] += 1
            continue
        if not success:
            failure = GenerationError(
                message=f"Task {task_index} did not complete successfully",
                task_index=task_index,
                task_id=task_id,
                stage="generation",
            )
            task_summary["errors"].append(failure)
            task_summary["failed_tasks"] += 1
            stream.output_queue.push("error", failure)
            continue
        task_summary["successful_tasks"] += 1
def _handle_command(session, job: SessionJob, wgp, tasks: list[dict[str, Any]], command: str, data: Any) -> None:
    """Forward one worker command to ``job.events`` and the matching session callback.

    Payload handling per command:

    - ``progress``: built by ``session._build_progress_update``.
    - ``preview``: built by ``session._build_preview_update``; dropped when that
      returns ``None``.
    - ``status`` / ``info``: coerced to text.
    - ``output``: passed through unchanged.
    - ``error``: normalised to a ``GenerationError``.
    - ``refresh_models``: published as an event only, with no callback.

    Any other command is ignored.
    """
    if command == "refresh_models":
        job.events.put("refresh_models", data)
        return

    if command == "progress":
        event_name, callback_name = "progress", "on_progress"
        payload = session._build_progress_update(data)
    elif command == "preview":
        payload = session._build_preview_update(wgp, tasks, data)
        if payload is None:
            return
        event_name, callback_name = "preview", "on_preview"
    elif command == "status":
        event_name, callback_name = "status", "on_status"
        payload = str(data or "")
    elif command == "info":
        event_name, callback_name = "info", "on_info"
        payload = str(data or "")
    elif command == "output":
        event_name, callback_name = "output", "on_output"
        payload = data
    elif command == "error":
        event_name, callback_name = "error", "on_error"
        payload = data if isinstance(data, GenerationError) else session._make_generation_error(data)
    else:
        return

    # Every recognised command (except refresh_models) is published the same way.
    job.events.put(event_name, payload)
    session._emit_callback(callback_name, payload, job=job)
Xet Storage Details
- Size: 9.78 kB
- Xet hash: e9be678058d71d46e61a7049b2b78b001f9d2eebe31c095882e7fa54fabc3e29

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.