Daankular/models / Wan2GP /shared /utils /virtual_media.py
Daankular's picture
download
raw
8.33 kB
from __future__ import annotations
from dataclasses import dataclass
from threading import RLock
from typing import Any
_VIRTUAL_MEDIA_SOURCES: dict[str, dict[str, dict[str, Any]]] = {}
_VIRTUAL_MEDIA_LOCK = RLock()
@dataclass(frozen=True)
class VirtualMediaSpec:
source_path: str
start_frame: int = 0
end_frame: int | None = None
audio_track_no: int | None = None
extras: tuple[tuple[str, str], ...] = ()
def as_suffix_items(self) -> list[tuple[str, str]]:
items: list[tuple[str, str]] = []
if self.start_frame != 0:
items.append(("start_frame", str(int(self.start_frame))))
if self.end_frame is not None:
items.append(("end_frame", str(int(self.end_frame))))
if self.audio_track_no is not None:
items.append(("audio_track_no", str(int(self.audio_track_no))))
items.extend(list(self.extras))
return items
def parse_virtual_media_path(value: Any) -> VirtualMediaSpec | None:
if not isinstance(value, str):
return None
text = value.strip()
if "|" not in text:
return None
source_path, suffix = text.split("|", 1)
source_path = source_path.strip()
if len(source_path) == 0:
return None
values: dict[str, str] = {}
extras: list[tuple[str, str]] = []
for raw_item in suffix.split(","):
item = raw_item.strip()
if len(item) == 0:
continue
key, sep, raw_value = item.partition("=")
key = key.strip().lower()
if sep == "":
extras.append((item, ""))
continue
value_text = raw_value.strip()
if key in ("start_frame", "end_frame", "audio_track_no"):
values[key] = value_text
else:
extras.append((key, value_text))
return VirtualMediaSpec(
source_path=source_path,
start_frame=_parse_int(values.get("start_frame"), 0),
end_frame=_parse_optional_int(values.get("end_frame")),
audio_track_no=_parse_optional_int(values.get("audio_track_no")),
extras=tuple(extras),
)
def strip_virtual_media_suffix(value: Any) -> Any:
spec = parse_virtual_media_path(value)
return spec.source_path if spec is not None else value
def build_virtual_media_path(
source_path: str,
*,
start_frame: int | None = None,
end_frame: int | None = None,
audio_track_no: int | None = None,
extras: dict[str, Any] | None = None,
) -> str:
base_path = str(source_path or "").strip()
if len(base_path) == 0:
return base_path
parts: list[str] = []
if start_frame is not None:
parts.append(f"start_frame={int(start_frame)}")
if end_frame is not None:
parts.append(f"end_frame={int(end_frame)}")
if audio_track_no is not None:
parts.append(f"audio_track_no={int(audio_track_no)}")
for key, value in (extras or {}).items():
key_text = str(key or "").strip()
value_text = str(value or "").strip()
if len(key_text) == 0 or len(value_text) == 0:
continue
parts.append(f"{key_text}={value_text}")
return base_path if len(parts) == 0 else f"{base_path}|{','.join(parts)}"
def replace_virtual_media_source(value: Any, source_path: str) -> Any:
spec = parse_virtual_media_path(value)
if spec is None:
return source_path
extras = dict(spec.extras)
return build_virtual_media_path(
source_path,
start_frame=spec.start_frame if spec.start_frame != 0 else None,
end_frame=spec.end_frame,
audio_track_no=spec.audio_track_no,
extras=extras,
)
def clamp_virtual_frame_range(spec: VirtualMediaSpec | None, total_frames: int) -> tuple[int, int | None]:
if spec is None:
return 0, None
explicit_frame_count = _parse_int(_extras_dict(spec).get("frame_count"), 0)
total_frames = explicit_frame_count if explicit_frame_count > 0 else int(total_frames or 0)
if total_frames <= 0:
return 0, None
start_frame = _resolve_relative_frame_index(spec.start_frame, total_frames, default_to_end=False)
end_frame = _resolve_relative_frame_index(spec.end_frame, total_frames, default_to_end=True)
if end_frame is None:
return start_frame, total_frames - 1
end_frame = max(start_frame, end_frame)
return start_frame, end_frame
def get_virtual_media_vsource(value: Any) -> str | None:
spec = value if isinstance(value, VirtualMediaSpec) else parse_virtual_media_path(value)
if spec is None:
return None
vsource = _extras_dict(spec).get("vsource", "").strip()
return vsource or None
def get_virtual_media_entry(value: Any) -> dict[str, Any] | None:
spec = value if isinstance(value, VirtualMediaSpec) else parse_virtual_media_path(value)
if spec is None:
return None
vsource = get_virtual_media_vsource(spec)
if vsource is None:
return None
with _VIRTUAL_MEDIA_LOCK:
entry = _VIRTUAL_MEDIA_SOURCES.get(vsource, {}).get(str(spec.source_path or "").strip())
return None if entry is None else dict(entry)
def store_virtual_video(vsource: str, name: str, tensor: Any, fps: float, *, hdr: bool = False) -> None:
import torch
tensor = tensor.detach().cpu().to(dtype=torch.float32).contiguous().clone()
with _VIRTUAL_MEDIA_LOCK:
_VIRTUAL_MEDIA_SOURCES.setdefault(str(vsource).strip(), {})[str(name).strip()] = {
"kind": "video",
"tensor": tensor,
"fps": max(float(fps or 0.0), 1.0),
"hdr": bool(hdr),
}
_clear_virtual_media_caches()
def get_virtual_video(value: Any, name: str | None = None) -> Any:
entry = _get_virtual_media_entry_from_key(value, name)
if entry is None or entry.get("kind") != "video":
return None
tensor = entry.get("tensor")
return None if tensor is None else tensor.clone()
def store_virtual_image(vsource: str, name: str, image: Any) -> None:
with _VIRTUAL_MEDIA_LOCK:
_VIRTUAL_MEDIA_SOURCES.setdefault(str(vsource).strip(), {})[str(name).strip()] = {"kind": "image", "image": image.copy()}
_clear_virtual_media_caches()
def get_virtual_image(value: Any, name: str | None = None) -> Any:
entry = _get_virtual_media_entry_from_key(value, name)
if entry is None or entry.get("kind") != "image":
return None
image = entry.get("image")
return None if image is None else image.copy()
def clear_virtual_media_source(vsource: str) -> None:
with _VIRTUAL_MEDIA_LOCK:
_VIRTUAL_MEDIA_SOURCES.pop(str(vsource).strip(), None)
_clear_virtual_media_caches()
def _parse_int(value: Any, default: int) -> int:
try:
return int(str(value or "").strip())
except (TypeError, ValueError):
return default
def _parse_optional_int(value: Any) -> int | None:
text = str(value or "").strip()
if len(text) == 0:
return None
try:
return int(text)
except (TypeError, ValueError):
return None
def _resolve_relative_frame_index(value: int | None, total_frames: int, *, default_to_end: bool) -> int | None:
if total_frames <= 0:
return None if value is None else 0
if value is None:
return total_frames - 1 if default_to_end else 0
index = int(value)
if index < 0:
index = total_frames + index
return max(0, min(index, total_frames - 1))
def _extras_dict(spec: VirtualMediaSpec) -> dict[str, str]:
return {str(key or "").strip().lower(): str(value or "").strip() for key, value in spec.extras}
def _get_virtual_media_entry_from_key(value: Any, name: str | None = None) -> dict[str, Any] | None:
if name is None:
return get_virtual_media_entry(value)
with _VIRTUAL_MEDIA_LOCK:
entry = _VIRTUAL_MEDIA_SOURCES.get(str(value).strip(), {}).get(str(name).strip())
return None if entry is None else dict(entry)
def _clear_virtual_media_caches() -> None:
try:
from . import video_decode as _video_decode
_video_decode.probe_video_stream_metadata.cache_clear()
except Exception:
pass
try:
from . import utils as _utils
_utils._get_video_info_cached.cache_clear()
_utils._get_video_info_details_cached.cache_clear()
except Exception:
pass

Xet Storage Details

Size:
8.33 kB
·
Xet hash:
1052f1ef984c011645c56d9ea381bd240770a57999a632465aac7c3cfdabbdcb

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.