Daankular's picture
download
raw
5.76 kB
from __future__ import annotations
import os
from pathlib import Path
from typing import Iterable
import numpy as np
import torch
# Luminance in cd/m^2 handed to zscale as ``npl`` by hdr10_zscale_filter,
# i.e. the brightness that a linear value of 1.0 is mapped to. 203 cd/m^2 is
# the HDR reference-white level recommended by ITU-R BT.2408.
HDR_REFERENCE_WHITE_NITS: float = 203.0

# SMPTE ST 2086 mastering-display metadata in x265 ``master-display`` syntax:
# chromaticities in 0.00002 steps, luminance as L(max,min) in 0.0001 cd/m^2
# steps. This is the x265 documentation's example for a P3-D65, 1000 cd/m^2
# mastering display.
HDR10_MASTER_DISPLAY: str = (
    "G(13250,34500)"
    "B(7500,3000)"
    "R(34000,16000)"
    "WP(15635,16450)"
    "L(10000000,1)"
)

# x265 ``max-cll`` value written as "MaxCLL,MaxFALL" in cd/m^2.
# NOTE(review): MaxCLL (10000) exceeds the 1000 cd/m^2 mastering peak above;
# confirm this is intentional.
HDR10_MAX_CLL: str = "10000,400"

# Marker string for requesting HDR output from a video prompt.
# NOTE(review): the code that parses this flag is not in this file -- confirm
# its exact semantics there.
VIDEO_PROMPT_HDR_OUTPUT_FLAG: str = "&"
def hdr10_zscale_filter(*, reference_white_nits: float = HDR_REFERENCE_WHITE_NITS) -> str:
    """Build an FFmpeg filter chain that converts linear RGB into HDR10 YUV.

    The ``zscale`` stage reads full-range linear input with BT.709 primaries
    and a GBR matrix. It writes limited-range BT.2020 non-constant-luminance
    YUV with the SMPTE ST 2084 (PQ) transfer. A trailing ``format`` stage
    selects 10-bit 4:2:0 output.

    Args:
        reference_white_nits: Value for zscale's ``npl`` option, rendered
            with up to 12 significant digits.

    Returns:
        The comma-separated filter chain string.
    """
    npl = format(float(reference_white_nits), ".12g")
    zscale_options = (
        "pin=709",
        "tin=linear",
        "min=gbr",
        "rin=full",
        "p=2020",
        "t=smpte2084",
        "m=2020_ncl",
        "r=limited",
        f"npl={npl}",
    )
    return "zscale=" + ":".join(zscale_options) + ",format=yuv420p10le"
def hdr10_x265_params() -> str:
    """Return the ``-x265-params`` option string used for HDR10 encodes.

    The string sets x265's ``hdr10`` and ``repeat-headers`` flags, attaches
    the module-level ``HDR10_MASTER_DISPLAY`` and ``HDR10_MAX_CLL`` metadata,
    and sets ``log-level=none``.
    """
    settings = {
        "hdr10": "1",
        "repeat-headers": "1",
        "master-display": HDR10_MASTER_DISPLAY,
        "max-cll": HDR10_MAX_CLL,
        "log-level": "none",
    }
    # Dict insertion order fixes the option order in the output.
    return ":".join(f"{key}={value}" for key, value in settings.items())
class LogC3:
    """Piecewise log/linear encoding between scene-linear values and [0, 1].

    At or above ``CUT`` (linear domain) the curve is
    ``C * log10(A * x + B) + D``; below it the straight segment ``E * x + F``
    is used. All coefficients are class attributes; instances carry no
    other state.
    """

    # Curve coefficients.
    # NOTE(review): these appear to be ARRI's published LogC3 parameters for
    # EI 800 -- confirm if other exposure indices are ever required.
    A = 5.555556
    B = 0.052272
    C = 0.247190
    D = 0.385537
    E = 5.367655
    F = 0.092809
    CUT = 0.010591  # linear-domain switch point between the two segments

    def compress(self, hdr: torch.Tensor) -> torch.Tensor:
        """Encode scene-linear *hdr* into code values clamped to [0, 1].

        Negative inputs are clamped to 0 before encoding.
        """
        linear = torch.clamp(hdr, min=0.0)
        log_segment = torch.log10(linear * self.A + self.B) * self.C + self.D
        linear_segment = linear * self.E + self.F
        in_log_range = linear >= self.CUT
        encoded = torch.where(in_log_range, log_segment, linear_segment)
        return encoded.clamp_(0.0, 1.0)

    def compress_ldr(self, ldr: torch.Tensor) -> torch.Tensor:
        """Return *ldr* clamped to [0, 1]; no log encoding is applied."""
        return torch.clamp(ldr, 0.0, 1.0)

    def decompress(self, logc: torch.Tensor) -> torch.Tensor:
        """Decode code values back to scene-linear, inverting :meth:`compress`.

        Inputs are clamped to [0, 1] first; outputs are clamped to be >= 0.
        """
        code = torch.clamp(logc, 0.0, 1.0)
        # Code value produced by the linear segment at x == CUT; codes at or
        # above it are decoded with the inverse of the log segment.
        code_cut = self.E * self.CUT + self.F
        from_log = (torch.pow(10.0, (code - self.D) / self.C) - self.B) / self.A
        from_linear = (code - self.F) / self.E
        decoded = torch.where(code >= code_cut, from_log, from_linear)
        return decoded.clamp_(min=0.0)
def hdr_linear_to_vae_range(frames: torch.Tensor, *, transform: str = "logc3") -> torch.Tensor:
    """Map scene-linear HDR frames into the [-1, 1] range.

    Frames are cast to float32, LogC3-encoded into [0, 1], then rescaled
    affinely to [-1, 1].

    Args:
        frames: Scene-linear HDR tensor of any shape.
        transform: Name of the log transform; only ``"logc3"`` is supported.

    Returns:
        A new float32 tensor with values in [-1, 1].

    Raises:
        ValueError: If *transform* is not ``"logc3"``.
    """
    as_float = frames.to(dtype=torch.float32)
    if transform == "logc3":
        encoded = LogC3().compress(as_float)
        # ``compress`` builds a fresh tensor, so rescaling it in place never
        # touches the caller's ``frames``.
        return encoded.mul_(2.0).sub_(1.0)
    raise ValueError(f"Unsupported HDR transform: {transform}")
def vae_range_to_hdr_linear(frames: torch.Tensor, *, transform: str = "logc3") -> torch.Tensor:
    """Map [-1, 1] frames back to scene-linear HDR values.

    Values are rescaled from [-1, 1] to [0, 1], clamped, and LogC3-decoded.

    The caller's tensor is never modified. ``Tensor.to`` returns ``frames``
    itself when it is already float32 on the same device, so the first
    arithmetic step is done out of place before any in-place ops run. This
    also keeps the function usable on leaf tensors that require grad.

    Args:
        frames: Tensor nominally in [-1, 1]; values outside are clamped.
        transform: Name of the log transform; only ``"logc3"`` is supported.

    Returns:
        A new float32 tensor of scene-linear values (>= 0).

    Raises:
        ValueError: If *transform* is not ``"logc3"``. This is raised before
            *frames* is read or converted.
    """
    if transform != "logc3":
        raise ValueError(f"Unsupported HDR transform: {transform}")
    # ``add`` (not ``add_``) allocates a new tensor; the following in-place
    # ops then only touch that private copy.
    unit = frames.to(dtype=torch.float32).add(1.0).mul_(0.5).clamp_(0.0, 1.0)
    return LogC3().decompress(unit)
def linear_to_srgb(linear: torch.Tensor) -> torch.Tensor:
    """Apply the sRGB transfer function to linear values clipped to [0, 1].

    Values at or below 0.0031308 use the linear toe (``x * 12.92``); the rest
    use ``1.055 * x ** (1 / 2.4) - 0.055``. The result is clamped to [0, 1].
    """
    clipped = torch.clamp(linear, 0.0, 1.0)
    in_toe = clipped <= 0.0031308
    toe = clipped * 12.92
    gamma = 1.055 * torch.pow(clipped, 1.0 / 2.4) - 0.055
    encoded = torch.where(in_toe, toe, gamma)
    return encoded.clamp_(0.0, 1.0)
def tonemap_hdr_tensor_to_uint8(video: torch.Tensor, *, exposure: float = 0.0) -> torch.Tensor:
    """Convert a linear HDR video tensor to 8-bit sRGB, e.g. for previews.

    Linear values are multiplied by ``2 ** exposure`` and passed through
    ``linear_to_srgb``, which hard-clips to [0, 1]. They are then scaled to
    0-255, rounded, and cast to ``uint8``.

    Args:
        video: ``[C, F, H, W]`` tensor, or ``[1, C, F, H, W]`` whose batch
            axis is dropped.
        exposure: Exposure adjustment in stops (a power-of-two gain).

    Returns:
        A ``uint8`` tensor shaped ``[C, F, H, W]`` on the input's device.

    Raises:
        ValueError: If *video* is not 4-D after dropping a size-1 batch axis.
    """
    has_unit_batch = video.ndim == 5 and video.shape[0] == 1
    clip = video[0] if has_unit_batch else video
    if clip.ndim != 4:
        raise ValueError(f"Expected [C,F,H,W] HDR tensor, got {tuple(clip.shape)}.")
    gain = float(2.0 ** float(exposure))
    exposed = clip.to(dtype=torch.float32).mul(gain)
    display = linear_to_srgb(exposed)
    return display.mul(255.0).round_().clamp_(0.0, 255.0).to(torch.uint8)
def iter_video_chunks(video: torch.Tensor | Iterable[torch.Tensor]):
if torch.is_tensor(video):
yield video
return
for chunk in video:
if chunk is not None:
yield chunk
def iter_hdr_gbrpf32_frames(video: torch.Tensor | Iterable[torch.Tensor]):
    """Yield one raw planar float32 buffer per frame, planes in G, B, R order.

    Input is split into chunks by ``iter_video_chunks``. Each chunk may be
    ``[C, F, H, W]`` or ``[1, C, F, H, W]``. It is detached, moved to CPU,
    and cast to float32. For every frame, channels 1, 2, 0 are emitted as
    consecutive planes via ``ndarray.tobytes`` (native byte order).

    NOTE(review): this assumes the input channels are R, G, B so the output
    matches FFmpeg's ``gbrpf32`` plane layout -- confirm with callers.

    Raises:
        ValueError: If a chunk is not 4-D after dropping a size-1 batch axis.
    """
    plane_order = [1, 2, 0]
    for chunk in iter_video_chunks(video):
        if chunk is None:
            continue
        if chunk.ndim == 5 and chunk.shape[0] == 1:
            chunk = chunk[0]
        if chunk.ndim != 4:
            raise ValueError(f"Expected [C,F,H,W] HDR tensor, got {tuple(chunk.shape)}.")
        cpu_frames = chunk.detach().cpu().to(dtype=torch.float32)
        for frame_idx in range(cpu_frames.shape[1]):
            # [C, H, W] slice for one frame, channels reordered to G, B, R.
            planes = cpu_frames[:, frame_idx][plane_order].contiguous().numpy()
            yield planes.astype(np.float32, copy=False).tobytes()
def write_hdr_exr_frames(
    video: torch.Tensor,
    output_dir: str | os.PathLike[str],
    *,
    start_index: int = 0,
    exr_half: bool = True,
) -> int:
    """Write every frame of an HDR video tensor to its own OpenEXR file.

    Files are named ``frame_XXXXXX.exr`` (zero-padded to six digits) inside
    *output_dir*, which is created if missing. Numbering starts at
    *start_index*. The channel axis is reversed before writing (RGB -> BGR
    for OpenCV). ``OPENCV_IO_ENABLE_OPENEXR`` is set to ``"1"`` in the
    process environment, unless it is already set, before ``cv2`` is imported.

    Args:
        video: ``[C, F, H, W]`` or ``[1, C, F, H, W]`` tensor. It is detached,
            moved to CPU, and cast to float32 before writing.
        output_dir: Destination directory.
        start_index: Index used in the first file name.
        exr_half: Request half-float EXR storage when this OpenCV build
            exposes ``IMWRITE_EXR_TYPE`` and ``IMWRITE_EXR_TYPE_HALF``.

    Returns:
        The number of frames in *video* (``F``).

    Raises:
        ValueError: If *video* is not 4-D after dropping a size-1 batch axis.
        RuntimeError: If ``cv2.imwrite`` reports failure for a frame.
    """
    os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")
    import cv2

    clip = video
    if clip.ndim == 5 and clip.shape[0] == 1:
        clip = clip[0]
    if clip.ndim != 4:
        raise ValueError(f"Expected [C,F,H,W] HDR tensor, got {tuple(clip.shape)}.")

    Path(output_dir).mkdir(parents=True, exist_ok=True)
    total = int(clip.shape[1])

    write_params: list[int] = []
    if exr_half and hasattr(cv2, "IMWRITE_EXR_TYPE") and hasattr(cv2, "IMWRITE_EXR_TYPE_HALF"):
        write_params = [int(cv2.IMWRITE_EXR_TYPE), int(cv2.IMWRITE_EXR_TYPE_HALF)]

    # [C, F, H, W] -> [F, H, W, C]: one interleaved HWC image per frame.
    hwc_frames = clip.detach().cpu().to(dtype=torch.float32).permute(1, 2, 3, 0).contiguous()
    base = os.fspath(output_dir)
    for index, hwc in enumerate(hwc_frames, start=int(start_index)):
        bgr = np.ascontiguousarray(hwc.numpy()[..., ::-1], dtype=np.float32)
        target = os.path.join(base, f"frame_{index:06d}.exr")
        if not cv2.imwrite(target, bgr, write_params):
            raise RuntimeError(f"Failed to write HDR EXR frame: {target}")
    return total
def read_hdr_exr_frames(
    output_dir: str | os.PathLike[str],
    *,
    start_index: int,
    frame_count: int,
) -> torch.Tensor | None:
    """Load consecutive ``frame_XXXXXX.exr`` files into one video tensor.

    Reads indices ``start_index`` .. ``start_index + frame_count - 1`` from
    *output_dir*. This uses the same naming and channel reversal as
    ``write_hdr_exr_frames``. ``OPENCV_IO_ENABLE_OPENEXR`` is set to ``"1"``
    in the process environment, unless it is already set, before ``cv2`` is
    imported.

    Single-channel EXRs are handled. OpenCV returns them as 2-D ``(H, W)``
    arrays; they get an explicit channel axis so they load as ``C == 1``
    instead of being flipped horizontally and then failing in ``permute``.

    Args:
        output_dir: Directory containing the EXR files.
        start_index: Index of the first frame to read.
        frame_count: Number of consecutive frames to read.

    Returns:
        A float32 tensor shaped ``[C, F, H, W]``. Returns ``None`` if
        *frame_count* is not positive, or if any requested file is missing or
        cannot be decoded.

    Raises:
        RuntimeError: From ``torch.stack`` if frames differ in size or
            channel count.
    """
    os.environ.setdefault("OPENCV_IO_ENABLE_OPENEXR", "1")
    import cv2

    frames = []
    for idx in range(int(start_index), int(start_index) + int(frame_count)):
        path = os.path.join(os.fspath(output_dir), f"frame_{idx:06d}.exr")
        if not os.path.isfile(path):
            return None
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            return None
        if image.ndim == 2:
            # Grayscale: add a channel axis so the reversal below acts on
            # channels, not on image columns.
            image = image[:, :, np.newaxis]
        # Undo the channel reversal applied when the frames were written.
        rgb = np.ascontiguousarray(image[..., ::-1]).astype(np.float32, copy=False)
        frames.append(torch.from_numpy(rgb))
    if not frames:
        return None
    # [F, H, W, C] -> [C, F, H, W]
    return torch.stack(frames, dim=0).permute(3, 0, 1, 2).contiguous()

Xet Storage Details

Size:
5.76 kB
·
Xet hash:
0edfeec9da1edb62d4d8cea763e892f61a3b58e378afbd30b74d366755c21cdc

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.