YAML Metadata Warning: empty or missing YAML metadata in repo card

Check out the documentation for more information.

sid-klein-lora-gan-patch-lpips-sid-anchor-20x-v4-step-0004500

Student UNet checkpoint with a full copy-paste inference example.

Files

Full inference code

import requests
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from diffusers import UNet2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import AutoencoderTiny
from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput
from diffusers.models.modeling_utils import ModelMixin

# Hub repository id; the student UNet is loaded from its "unet" subfolder further down.
REPO_ID = "dim/sid-klein-lora-gan-patch-lpips-sid-anchor-20x-v4-step-0004500"
# Example input image stored in the same repository; downloaded at the end of the script.
EXAMPLE_IMAGE_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/assets/example_input.jpg"
# File the generated image is saved to (relative to the current working directory).
OUTPUT_PATH = "result.png"
# Target size in pixels used by both the Resize and CenterCrop preprocessing steps.
RESOLUTION = 512
# Seed for the global torch/numpy RNGs and the default seed of generate_image().
SEED = 0
# Default cutoff radius passed to generate_structured_noise_batch(); that function
# clamps it to half of the padded latent size.
MINIMAL_NOISE_R = 100.0
# Default number of sigma values (and therefore denoising steps) in generate_image().
NUM_STEPS = 4
# Use the GPU when one is available; half precision is only selected on CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32

# Seed the global RNGs at import time. generate_image() additionally uses its own
# torch.Generator seeded per call.
torch.manual_seed(SEED)
np.random.seed(SEED)


class DotDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes.

    Reading an attribute that is not a key yields ``None`` (same as
    ``dict.get``); deleting an attribute that is not a key raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so real dict
        # methods (``keys``, ``items``, ...) are never shadowed by items.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]


class Flux2TinyAutoEncoderModel(ModelMixin, ConfigMixin):
    """``AutoencoderTiny`` wrapped with one extra learned 2x resampling stage.

    The wrapped ``AutoencoderTiny`` operates on ``latent_channels // 4``
    channels. On the encoder side a stride-2 convolution expands that to
    ``latent_channels`` channels; on the decoder side a stride-2 transposed
    convolution undoes it. Each side is followed by a small residual
    refinement stack (conv -> group norm -> SiLU -> conv) added back onto
    its input.

    All constructor arguments other than ``latent_channels`` are forwarded
    unchanged to ``AutoencoderTiny``.
    """

    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        latent_channels=128,
        encoder_block_out_channels=(64, 64, 64, 64),
        decoder_block_out_channels=(64, 64, 64, 64),
        act_fn="silu",
        upsampling_scaling_factor=2,
        num_encoder_blocks=(1, 3, 3, 3),
        num_decoder_blocks=(3, 3, 3, 1),
        latent_magnitude=3.0,
        latent_shift=0.5,
        force_upcast=False,
        scaling_factor=0.13025,
    ):
        super().__init__()
        # Channel count used by the wrapped tiny VAE and by the decoder-side refinement.
        inner_channels = latent_channels // 4

        # Submodule names and creation order are kept stable so that
        # checkpoint keys keep matching.
        self.tiny_vae = AutoencoderTiny(
            in_channels=in_channels,
            out_channels=out_channels,
            encoder_block_out_channels=list(encoder_block_out_channels),
            decoder_block_out_channels=list(decoder_block_out_channels),
            act_fn=act_fn,
            latent_channels=inner_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            num_encoder_blocks=list(num_encoder_blocks),
            num_decoder_blocks=list(num_decoder_blocks),
            latent_magnitude=latent_magnitude,
            latent_shift=latent_shift,
            force_upcast=force_upcast,
            scaling_factor=scaling_factor,
        )
        self.extra_encoder = nn.Conv2d(
            inner_channels, latent_channels, kernel_size=4, stride=2, padding=1
        )
        self.extra_decoder = nn.ConvTranspose2d(
            latent_channels, inner_channels, kernel_size=4, stride=2, padding=1
        )
        self.residual_encoder = self._refinement_stack(latent_channels)
        self.residual_decoder = self._refinement_stack(inner_channels)

    @staticmethod
    def _refinement_stack(channels):
        """Build the conv -> GroupNorm(8) -> SiLU -> conv stack for ``channels`` channels."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, channels),
            nn.SiLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def encode(self, x, return_dict=True):
        """Encode ``x`` into the wide latent.

        Returns an ``EncoderOutput`` holding the latent when ``return_dict``
        is true, otherwise the latent tensor itself.
        """
        tiny_latent = self.tiny_vae.encode(x, return_dict=False)[0]
        widened = self.extra_encoder(tiny_latent)
        refined = self.residual_encoder(widened) + widened
        if not return_dict:
            return refined
        return EncoderOutput(latent=refined)

    def decode(self, z, return_dict=True):
        """Decode the wide latent ``z`` back to an image tensor.

        Returns a ``DecoderOutput`` holding the image when ``return_dict``
        is true, otherwise the image tensor itself.
        """
        narrowed = self.extra_decoder(z)
        refined = self.residual_decoder(narrowed) + narrowed
        image = self.tiny_vae.decode(refined, return_dict=False)[0]
        if not return_dict:
            return image
        return DecoderOutput(sample=image)


class Flux2TinyAutoEncoder(nn.Module):
    """Adapter exposing the pretrained tiny autoencoder through a VAE-like API.

    The wrapped model's latents are pixel-shuffled by a factor of 2 on the
    way out of ``encode`` and pixel-unshuffled by 2 on the way into
    ``decode``. ``config.scaling_factor`` is fixed to ``1.0``.
    """

    def __init__(self):
        super().__init__()
        pretrained = Flux2TinyAutoEncoderModel.from_pretrained(
            "dim/fal_FLUX.2-Tiny-AutoEncoder_v6_2x_flux_klein_4B_lora_v2"
        )
        self.vae = pretrained.to(device=DEVICE, dtype=DTYPE)
        self.config = DotDict(scaling_factor=1.0)

    def encode(self, x, return_dict=True):
        """Encode ``x`` and wrap the deterministic latent in a distribution-like object.

        With ``return_dict`` true the result has ``.latent_dist.mode()``;
        otherwise a one-element list is returned whose item has
        ``.sample(generator)``. Both callables return the same latent; the
        generator argument is ignored.
        """
        packed = self.vae.encode(x).latent
        latent = torch.nn.functional.pixel_shuffle(packed, 2)

        def mode():
            return latent

        def sample(generator):
            return latent

        if not return_dict:
            return [DotDict(sample=sample)]
        return DotDict(latent_dist=DotDict(mode=mode))

    def decode(self, x, return_dict=True):
        """Decode latent ``x``; the decoded tensor gets an extra leading dimension.

        Returns an object with ``.sample`` when ``return_dict`` is true,
        otherwise the tensor itself.
        """
        packed = torch.nn.functional.pixel_unshuffle(x, 2)
        images = self.vae.decode(packed).sample.unsqueeze(0)
        return DotDict(sample=images) if return_dict else images


def create_frequency_soft_cutoff_mask(height, width, cutoff_radius, transition_width=2.0, device=None):
    """Build a radial low-pass mask with a Gaussian roll-off.

    The mask is centred on ``(height // 2, width // 2)``. It is ``1.0``
    wherever the distance ``r`` from the centre is at most ``cutoff_radius``
    and ``exp(-(r - cutoff_radius)**2 / (2 * transition_width**2))`` elsewhere.

    Args:
        height: Number of rows in the mask.
        width: Number of columns in the mask.
        cutoff_radius: Radius of the fully-passed disc.
        transition_width: Standard deviation of the Gaussian roll-off.
        device: Device for the result; defaults to CPU when ``None``.

    Returns:
        A ``(height, width)`` floating-point tensor.
    """
    target_device = torch.device("cpu") if device is None else device

    row_index = torch.arange(height, device=target_device)
    col_index = torch.arange(width, device=target_device)
    row_grid, col_grid = torch.meshgrid(row_index, col_index, indexing="ij")

    row_offset = row_grid - height // 2
    col_offset = col_grid - width // 2
    distance = torch.sqrt(row_offset ** 2 + col_offset ** 2)

    falloff = torch.exp(-((distance - cutoff_radius) ** 2) / (2 * transition_width**2))
    # Everything inside the cutoff disc is passed unattenuated.
    return falloff.masked_fill(distance <= cutoff_radius, 1.0)


def generate_structured_noise_batch(image_batch, cutoff_radius, input_noise, transition_width=2.0, pad_factor=1.5):
    """Create noise that keeps the low-frequency phase of ``image_batch``.

    Both inputs are reflect-padded and transformed with a 2-D FFT. The output
    spectrum uses the magnitude of the noise and a phase that is a per-bin
    linear blend of image phase and noise phase, weighted by the mask from
    ``create_frequency_soft_cutoff_mask``: image phase where the mask is 1,
    noise phase where it is 0. The result is transformed back and cropped to
    the original spatial size.

    Args:
        image_batch: 4-D tensor; the last two dimensions are treated as
            height and width.
        cutoff_radius: Mask radius; clamped to half of the smaller padded
            spatial size.
        input_noise: Noise tensor, padded and transformed like ``image_batch``.
        transition_width: Roll-off width forwarded to the mask builder.
        pad_factor: Padded size relative to the original size (rounded so
            padding is symmetric).

    Returns:
        A float32 tensor with the same spatial size as ``image_batch``.
    """
    _, _, height, width = image_batch.shape
    images = image_batch.float()
    noise = input_noise.float()

    # Padding added to each side; the total padding per axis is therefore even.
    half_pad_h = int(height * (pad_factor - 1)) // 2
    half_pad_w = int(width * (pad_factor - 1)) // 2
    padding = (half_pad_w, half_pad_w, half_pad_h, half_pad_h)
    padded_height = height + 2 * half_pad_h
    padded_width = width + 2 * half_pad_w

    reflect_pad = torch.nn.functional.pad
    padded_images = reflect_pad(images, padding, mode="reflect")
    padded_noise = reflect_pad(noise, padding, mode="reflect")

    # The radius cannot usefully exceed half of the smaller padded dimension.
    largest_radius = min(padded_height / 2, padded_width / 2)
    cutoff_radius = min(largest_radius, cutoff_radius)
    freq_mask = create_frequency_soft_cutoff_mask(
        padded_height,
        padded_width,
        cutoff_radius,
        transition_width=transition_width,
        device=images.device,
    )

    spatial_dims = (-2, -1)

    def centred_spectrum(tensor):
        # Shift so the zero frequency sits where the mask is centred.
        return torch.fft.fftshift(torch.fft.fft2(tensor, dim=spatial_dims), dim=spatial_dims)

    image_fft = centred_spectrum(padded_images)
    noise_fft = centred_spectrum(padded_noise)

    image_phase = torch.angle(image_fft)
    noise_phase = torch.angle(noise_fft)
    noise_magnitude = torch.abs(noise_fft)

    image_weight = freq_mask.unsqueeze(0).unsqueeze(0)
    noise_weight = 1 - freq_mask.unsqueeze(0).unsqueeze(0)
    mixed_phase = image_weight * image_phase + noise_weight * noise_phase
    mixed_fft = noise_magnitude * torch.exp(1j * mixed_phase)

    unshifted = torch.fft.ifftshift(mixed_fft, dim=spatial_dims)
    structured_noise = torch.fft.ifft2(unshifted, dim=spatial_dims).real

    rows = slice(half_pad_h, half_pad_h + height)
    cols = slice(half_pad_w, half_pad_w + width)
    return structured_noise[:, :, rows, cols].to(dtype=images.dtype)


# Input pipeline: resize with Lanczos resampling, centre-crop to a
# RESOLUTION x RESOLUTION square, convert to a tensor and normalise each of the
# three channels with mean 0.5 / std 0.5.
preprocess = transforms.Compose([
    transforms.Resize(RESOLUTION, interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.CenterCrop(RESOLUTION),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Converts the final image tensor back to a PIL image in generate_image().
to_pil = transforms.ToPILImage()
# Autoencoder wrapper, with gradients disabled and in eval mode (inference only).
vae = Flux2TinyAutoEncoder().requires_grad_(False).eval()
# Student UNet from the "unet" subfolder of REPO_ID, loaded in DTYPE and moved to DEVICE.
unet = UNet2DModel.from_pretrained(REPO_ID, subfolder="unet", torch_dtype=DTYPE).to(DEVICE).eval()
# Scheduler built with its default configuration; generate_image() sets the
# sigma schedule on it for every call.
noise_scheduler = FlowMatchEulerDiscreteScheduler()


@torch.no_grad()
def generate_image(source_pil, num_steps=NUM_STEPS, minimal_noise_r=MINIMAL_NOISE_R, seed=SEED):
    """Run the few-step image-to-image sampler on one PIL image.

    The source image is encoded to a latent, structured noise derived from
    that latent is used as the starting sample, and the UNet is applied for
    ``num_steps`` scheduler steps with the source latent concatenated to the
    sample along the channel dimension. The final sample is decoded, clamped
    to ``[-1, 1]`` and mapped to ``[0, 1]`` before conversion to PIL.

    Args:
        source_pil: Input image accepted by the ``preprocess`` pipeline.
        num_steps: Number of sigma values, spaced linearly from ``1.0`` down
            to ``1.0 / num_steps``.
        minimal_noise_r: Cutoff radius passed to
            ``generate_structured_noise_batch``.
        seed: Seed for the per-call random generator.

    Returns:
        The generated image as a PIL image.
    """
    pixels = preprocess(source_pil).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)

    rng_device = "cuda" if DEVICE.type == "cuda" else "cpu"
    rng = torch.Generator(device=rng_device)
    rng.manual_seed(seed)

    latent_scale = vae.config.scaling_factor
    posterior = vae.encode(pixels, return_dict=False)[0]
    z_source = posterior.sample(generator=rng) * latent_scale

    # Noise is drawn and shaped in float32, then cast back to the latent dtype.
    gaussian = torch.randn(z_source.shape, generator=rng, device=z_source.device, dtype=torch.float32)
    structured = generate_structured_noise_batch(
        z_source.float(),
        cutoff_radius=minimal_noise_r,
        input_noise=gaussian,
    )
    sample = structured.to(dtype=z_source.dtype, device=z_source.device)

    sigma_schedule = np.linspace(1.0, 1.0 / num_steps, num_steps)
    noise_scheduler.set_timesteps(sigmas=sigma_schedule, device=DEVICE)

    # Not every scheduler provides input scaling; fall back to the raw sample.
    scale_input = getattr(noise_scheduler, "scale_model_input", None) if hasattr(
        noise_scheduler, "scale_model_input"
    ) else None

    for t in noise_scheduler.timesteps:
        current = sample if scale_input is None else scale_input(sample, t)
        # The UNet is conditioned on the source latent via channel concatenation.
        unet_input = torch.cat([current, z_source], dim=1)
        batch_timesteps = t.to(z_source.device).repeat(unet_input.shape[0])

        prediction = unet(unet_input, batch_timesteps, return_dict=False)[0]
        sample = noise_scheduler.step(prediction, t, sample, return_dict=False)[0]

    images = vae.decode(sample / latent_scale, return_dict=False)[0]
    images = images.clamp(-1, 1)
    first_image = images[0].cpu().float()
    return to_pil(first_image * 0.5 + 0.5)


# Download the example input. A timeout keeps the script from hanging on a
# stalled connection, and the status check turns an HTTP error into a clear
# exception instead of a confusing image-decoding failure on an error page.
with requests.get(EXAMPLE_IMAGE_URL, stream=True, timeout=30) as response:
    response.raise_for_status()
    # The raw stream is not content-decoded (gzip/deflate) unless asked to be.
    response.raw.decode_content = True
    # convert() forces the image data to be read before the connection closes.
    image = Image.open(response.raw).convert("RGB")

result = generate_image(image)
result.save(OUTPUT_PATH)
print(OUTPUT_PATH)
Downloads last month
-
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support