"""VAE model for WorldEngine frame encoding/decoding."""
|
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .dcae import Encoder, Decoder
| |
|
| |
|
@dataclass
class EncoderDecoderConfig:
    """Config object for Encoder/Decoder initialization.

    One instance is built per sub-network (see ``WorldEngineVAE.__init__``),
    so the encoder and decoder can carry different channel schedules while
    sharing the image/latent geometry fields.
    """

    # (height, width) of the input frames.
    sample_size: Tuple[int, int]
    # Number of image channels (3 for RGB in the default VAE config).
    channels: int
    # Number of channels in the latent representation.
    latent_channels: int
    # Base channel count (presumably of the first stage — naming suggests so;
    # confirm against dcae.Encoder/Decoder).
    ch_0: int
    # Cap on channel count as stages widen.
    ch_max: int
    # Blocks per encoder stage (one entry per stage).
    encoder_blocks_per_stage: List[int]
    # Blocks per decoder stage (one entry per stage).
    decoder_blocks_per_stage: List[int]
    # Whether to include a middle block between the stages and the latent.
    use_middle_block: bool
    # NOTE(review): the three flags below default to False; the VAE wrapper
    # only forwards skip_logvar — skip_residuals/normalize_mu are always left
    # at their defaults by WorldEngineVAE. Semantics live in dcae; confirm there.
    skip_logvar: bool = False
    skip_residuals: bool = False
    normalize_mu: bool = False
| |
|
| |
|
class WorldEngineVAE(ModelMixin, ConfigMixin):
    """
    VAE for encoding/decoding video frames using DCAE architecture.

    Encodes RGB uint8 images to latent space and decodes latents back to RGB.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        # Shared image/latent geometry.
        sample_size: Tuple[int, int] = (360, 640),
        channels: int = 3,
        latent_channels: int = 16,
        # Encoder channel schedule.
        encoder_ch_0: int = 64,
        encoder_ch_max: int = 256,
        encoder_blocks_per_stage: Optional[List[int]] = None,
        # Decoder channel schedule (wider than the encoder by default).
        decoder_ch_0: int = 128,
        decoder_ch_max: int = 1024,
        decoder_blocks_per_stage: Optional[List[int]] = None,
        use_middle_block: bool = False,
        skip_logvar: bool = False,
        # NOTE(review): scale_factor/shift_factor are captured into the
        # diffusers config by @register_to_config but are never applied in
        # encode()/decode() below — confirm whether callers are expected to
        # scale/shift latents themselves.
        scale_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        """Build encoder/decoder sub-networks from the given hyperparameters.

        ``None`` for either blocks-per-stage list selects the default of one
        block per stage across four stages.
        """
        super().__init__()

        # Materialize per-call defaults here rather than using mutable
        # default arguments in the signature.
        if encoder_blocks_per_stage is None:
            encoder_blocks_per_stage = [1, 1, 1, 1]
        if decoder_blocks_per_stage is None:
            decoder_blocks_per_stage = [1, 1, 1, 1]

        # Both configs carry both stage lists; each sub-network presumably
        # reads only its own (confirm against dcae).
        encoder_config = EncoderDecoderConfig(
            sample_size=tuple(sample_size),
            channels=channels,
            latent_channels=latent_channels,
            ch_0=encoder_ch_0,
            ch_max=encoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            use_middle_block=use_middle_block,
            skip_logvar=skip_logvar,
        )

        decoder_config = EncoderDecoderConfig(
            sample_size=tuple(sample_size),
            channels=channels,
            latent_channels=latent_channels,
            ch_0=decoder_ch_0,
            ch_max=decoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            use_middle_block=use_middle_block,
            skip_logvar=skip_logvar,
        )

        self.encoder = Encoder(encoder_config)
        self.decoder = Decoder(decoder_config)

    def encode(self, img: Tensor) -> Tensor:
        """Encode a single [H, W, C] uint8 image to a latent tensor.

        The docstring in the original read "RGB -> RGB+D -> latent"; only the
        RGB normalization is visible here — any depth handling would live in
        the encoder itself (TODO confirm).

        Args:
            img: [H, W, C] image tensor with values in [0, 255].

        Returns:
            Latent tensor produced by the DCAE encoder.

        Raises:
            ValueError: if ``img`` is not a 3-D tensor.
        """
        # Explicit validation instead of assert (asserts vanish under -O).
        if img.dim() != 3:
            raise ValueError(f"Expected [H, W, C] image tensor, got {img.dim()} dims")
        img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        # [1, H, W, C] -> [1, C, H, W], rescaled from [0, 255] to [-1, 1].
        rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
        return self.encoder(rgb)

    @torch.compile
    def decode(self, latent: Tensor) -> Tensor:
        """Decode a latent to an [H, W, 3] uint8 RGB image.

        Args:
            latent: Latent tensor accepted by the DCAE decoder (batched).

        Returns:
            [H, W, 3] uint8 tensor with values in [0, 255].
        """
        decoded = self.decoder(latent)
        # [-1, 1] -> [0, 1] -> uint8 [0, 255].
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        decoded = (decoded * 255).round().to(torch.uint8)
        # Keep only the first 3 channels — the decoder may emit extra
        # channels (e.g. depth); TODO confirm channel layout.
        return decoded.squeeze(0).permute(1, 2, 0)[..., :3]

    def forward(self, x: Tensor, encode: bool = True) -> Tensor:
        """
        Forward pass - encode or decode based on flag.

        Args:
            x: Input tensor (image for encode, latent for decode)
            encode: If True, encode; if False, decode

        Returns:
            Encoded latent or decoded image
        """
        if encode:
            return self.encode(x)
        else:
            return self.decode(x)
| |
|