import os
from pathlib import Path
import warnings

from PIL import Image
from diffusers.pipelines.qwenimage.pipeline_qwenimage import QwenImagePipeline
import lpips
import torch
from safetensors.torch import load_file, save_model
import torch.nn.functional as F
import torchvision.transforms.v2.functional as TF
from einops import rearrange

from qwenimage.datamodels import QwenConfig, QwenInputs
from qwenimage.debug import clear_cuda_memory, ctimed, ftimed, print_gpu_memory, texam
from qwenimage.experiments.quantize_text_encoder_experiments import quantize_text_encoder_int4wo_linear
from qwenimage.experiments.quantize_experiments import quantize_transformer_fp8darow_nolast
from qwenimage.loss import LossAccumulator
from qwenimage.models.pipeline_qwenimage_edit_plus import CONDITION_IMAGE_SIZE, QwenImageEditPlusPipeline, calculate_dimensions
from qwenimage.models.pipeline_qwenimage_edit_save_interm import QwenImageEditSaveIntermPipeline
from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage.optimization import simple_quantize_model
from qwenimage.sampling import TimestepDistUtils
from wandml import WandModel
from wandml.core.logger import wand_logger
from wandml.trainers.experiment_trainer import ExperimentTrainer


class QwenImageFoundation(WandModel):
    SOURCE = "Qwen/Qwen-Image-Edit-2509"
    INPUT_MODEL = QwenInputs
    CACHE_DIR = "qwen_image_edit_2509"
    PIPELINE = QwenImageEditPlusPipeline
    serialize_modules = ["transformer"]

    def __init__(self, config: QwenConfig, device=None):
        super().__init__()
        self.config: QwenConfig = config
        self.dtype = torch.bfloat16
        if device is None:
            default_device = "cuda" if torch.cuda.is_available() else "cpu"
            self.device = default_device
        else:
            self.device = device
        print(f"{self.device=}")

        pipe = self.PIPELINE.from_pretrained(
            "Qwen/Qwen-Image-Edit-2509",
            transformer=QwenImageTransformer2DModel.from_pretrained(
                "Qwen/Qwen-Image-Edit-2509",
                subfolder='transformer',
                torch_dtype=self.dtype,
                device_map=self.device,
            ),
            torch_dtype=self.dtype,
        )
        pipe = pipe.to(device=self.device, dtype=self.dtype)

        if config.load_multi_view_lora:
            # Fuse the multi-view LoRA into the transformer weights, then drop the adapter.
            pipe.load_lora_weights(
                "dx8152/Qwen-Edit-2509-Multiple-angles",
                weight_name="镜头转换.safetensors", adapter_name="angles"
            )
            pipe.set_adapters(["angles"], adapter_weights=[1.])
            pipe.fuse_lora(adapter_names=["angles"], lora_scale=1.25)
            pipe.unload_lora_weights()

        self.pipe = pipe
        self.vae = self.pipe.vae
        self.transformer = self.pipe.transformer
        self.text_encoder = self.pipe.text_encoder
        self.scheduler = self.pipe.scheduler

        self.vae.to(self.dtype)
        self.vae.eval()
        self.vae.requires_grad_(False)
        self.text_encoder.eval()
        self.text_encoder.requires_grad_(False)
        self.text_encoder_device = None
        self.transformer.eval()
        self.transformer.requires_grad_(False)

        if self.config.gradient_checkpointing:
            self.transformer.enable_gradient_checkpointing()
        if self.config.vae_tiling:
            self.vae.enable_tiling(
                576,
                576,
                512,
                512,
            )

        self.timestep_dist_utils = TimestepDistUtils(
            min_seq_len=self.scheduler.config.base_image_seq_len,
            max_seq_len=self.scheduler.config.max_image_seq_len,
            min_mu=self.scheduler.config.base_shift,
            max_mu=self.scheduler.config.max_shift,
            train_dist=self.config.train_dist,
            train_shift=self.config.train_shift,
            inference_dist=self.config.inference_dist,
            inference_shift=self.config.inference_shift,
            static_mu=self.config.static_mu,
            loss_weight_dist=self.config.loss_weight_dist,
        )

        if self.config.quantize_text_encoder:
            quantize_text_encoder_int4wo_linear(self.text_encoder)
        if self.config.quantize_transformer:
            quantize_transformer_fp8darow_nolast(self.transformer)

    def load(self, load_path):
        if not isinstance(load_path, Path):
            load_path = Path(load_path)
        if not load_path.is_dir():
            raise ValueError(f"Expected {load_path=} to be a directory")
        for module_name in self.serialize_modules:
            model_state_dict = load_file(load_path / f"{module_name}.safetensors")
            missing, unexpected = getattr(self, module_name).load_state_dict(model_state_dict, strict=False, assign=True)
            if missing:
                warnings.warn(f"{module_name} missing {missing}")
            if unexpected:
                warnings.warn(f"{module_name} unexpected {unexpected}")

    def save(self, save_path, skip=False):
        if skip:
            return
        if not isinstance(save_path, Path):
            save_path = Path(save_path)
        # Only reject paths that exist but are not directories; otherwise create the directory.
        if save_path.exists() and not save_path.is_dir():
            raise ValueError(f"Expected {save_path=} to be a directory")
        save_path.mkdir(parents=True, exist_ok=True)
        for module_name in self.serialize_modules:
            save_model(getattr(self, module_name), save_path / f"{module_name}.safetensors")
            print(f"Saved {module_name} to {save_path}")

    def get_train_params(self):
        return [{"params": [p for p in self.transformer.parameters() if p.requires_grad]}]

    def pil_to_latents(self, images):
        image = self.pipe.image_processor.preprocess(images)
        h, w = image.shape[-2:]
        h_r, w_r = calculate_dimensions(self.config.vae_image_size, h / w)
        image = TF.resize(image, (h_r, w_r))
        print("pil_to_latents.image")
        texam(image)
        image = image.unsqueeze(2)  # N, C, F=1, H, W
        image = image.to(device=self.device, dtype=self.dtype)
        latents = self.pipe.vae.encode(image).latent_dist.mode()  # argmax
        latents_mean = (
            torch.tensor(self.pipe.vae.config.latents_mean)
            .view(1, self.pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = (
            torch.tensor(self.pipe.vae.config.latents_std)
            .view(1, self.pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        # Normalize to the VAE's latent statistics.
        latents = (latents - latents_mean) / latents_std
        latents = latents.squeeze(2)
        print("pil_to_latents.latents")
        texam(latents)
        return latents.to(dtype=self.dtype)

    def latents_to_pil(self, latents, h=None, w=None, with_grad=False):
        if not with_grad:
            latents = latents.clone().detach()
        if latents.dim() == 3:  # 1d (packed) latent
            if h is None or w is None:
                raise ValueError(f"auto unpack needs h,w, got {h=}, {w=}")
            latents = self.unpack_latents(latents, h=h, w=w)
        latents = latents.unsqueeze(2)
        latents = latents.to(self.dtype)
        latents_mean = (
            torch.tensor(self.pipe.vae.config.latents_mean)
            .view(1, self.pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_std = (
            torch.tensor(self.pipe.vae.config.latents_std)
            .view(1, self.pipe.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        # Undo the latent normalization before decoding.
        latents = latents * latents_std + latents_mean
        latents = latents.to(device=self.device, dtype=self.dtype)
        image = self.pipe.vae.decode(latents, return_dict=False)[0][:, :, 0]  # F = 1
        if with_grad:
            texam(image, "latents_to_pil.image")
            return image
        image = self.pipe.image_processor.postprocess(image)
        return image

    @staticmethod
    def pack_latents(latents):
        packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        return packed

    @staticmethod
    def unpack_latents(packed, h, w):
        latents = rearrange(packed, "b (h w) (c ph pw) -> b c (h ph) (w pw)", ph=2, pw=2, h=h, w=w)
        return latents
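
    # Shape sketch for the packing above (illustrative numbers, not taken from
    # the source): a latent of shape (B, C, H, W) = (1, 16, 64, 64) packs to
    # (B, (H/2)*(W/2), C*2*2) = (1, 1024, 64), i.e. one token per 2x2 spatial
    # patch across all channels; unpack_latents inverts this given h = H // 2
    # and w = W // 2.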

    @ftimed
    def offload_text_encoder(self, device: str | torch.device):
        if self.text_encoder_device == device:
            return
        print(f"Moving text encoder to {device}")
        self.text_encoder_device = device
        self.text_encoder.to(device)
        if device == "cpu" or device == torch.device("cpu"):
            print_gpu_memory(clear_mem="pre")

    @ftimed
    def preprocess_batch(self, batch):
        prompts = batch["text"]
        references = batch["reference"]
        h, w = references.shape[-2:]
        h_r, w_r = calculate_dimensions(CONDITION_IMAGE_SIZE, h / w)
        references = TF.resize(references, (h_r, w_r))
        print("preprocess_batch.references")
        texam(references)
        self.offload_text_encoder("cuda")
        with torch.no_grad():
            prompt_embeds, prompt_embeds_mask = self.pipe.encode_prompt(
                prompts,
                references.mul(255),  # scale [0, 1] tensors to the 0-255 RGB range
                device="cuda",
                max_sequence_length=self.config.train_max_sequence_length,
            )
        prompt_embeds = prompt_embeds.cpu().clone().detach()
        prompt_embeds_mask = prompt_embeds_mask.cpu().clone().detach()
        batch["prompt_embeds"] = (prompt_embeds, prompt_embeds_mask)
        batch["reference"] = batch["reference"].cpu()
        batch["image"] = batch["image"].cpu()
        return batch

    @ftimed
    def single_step(self, batch) -> torch.Tensor:
        self.offload_text_encoder("cpu")
        if "prompt_embeds" not in batch:
            batch = self.preprocess_batch(batch)
        prompt_embeds, prompt_embeds_mask = batch["prompt_embeds"]
        prompt_embeds = prompt_embeds.to(device=self.device)
        prompt_embeds_mask = prompt_embeds_mask.to(device=self.device)

        images = batch["image"].to(device=self.device, dtype=self.dtype)
        x_0 = self.pil_to_latents(images).to(device=self.device, dtype=self.dtype)
        x_1 = torch.randn_like(x_0).to(device=self.device, dtype=self.dtype)
        seq_len = self.timestep_dist_utils.get_seq_len(x_0)
        batch_size = x_0.shape[0]
        t = self.timestep_dist_utils.get_train_t([batch_size], seq_len=seq_len).to(device=self.device, dtype=self.dtype)
        # Rectified-flow interpolation between data x_0 and noise x_1.
        x_t = (1.0 - t) * x_0 + t * x_1
        x_t_1d = self.pack_latents(x_t)

        references = batch["reference"].to(device=self.device, dtype=self.dtype)
        print("references")
        texam(references)
        assert references.shape[0] == 1
        refs = self.pil_to_latents(references).to(device=self.device, dtype=self.dtype)
        refs_1d = self.pack_latents(refs)
        print("refs refs_1d")
        texam(refs)
        texam(refs_1d)
        # Concatenate noisy-image tokens and reference-image tokens along the sequence dim.
        inp_1d = torch.cat([x_t_1d, refs_1d], dim=1)
        print("inp_1d")
        texam(inp_1d)

        l_height, l_width = x_0.shape[-2:]
        ref_height, ref_width = refs.shape[-2:]
        img_shapes = [
            [
                (1, l_height // 2, l_width // 2),
                (1, ref_height // 2, ref_width // 2),
            ]
        ] * batch_size
        txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
        image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=x_0.device)

        v_pred_1d = self.transformer(
            hidden_states=inp_1d,
            encoder_hidden_states=prompt_embeds,
            encoder_hidden_states_mask=prompt_embeds_mask,
            timestep=t,
            image_rotary_emb=image_rotary_emb,
            return_dict=False,
        )[0]
        # Keep only the predictions for the noisy-image tokens (drop reference tokens).
        v_pred_1d = v_pred_1d[:, : x_t_1d.size(1)]
        v_pred_2d = self.unpack_latents(v_pred_1d, h=l_height // 2, w=l_width // 2)
        v_gt_2d = x_1 - x_0

        if self.config.loss_weight_dist is not None:
            loss = F.mse_loss(v_pred_2d, v_gt_2d, reduction="none").mean(dim=[1, 2, 3])
            weights = self.timestep_dist_utils.get_loss_weighting(t)
            loss = torch.mean(loss * weights)
        else:
            loss = F.mse_loss(v_pred_2d, v_gt_2d, reduction="mean")
        return loss
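
    # Informal note on the objective above (my reading of the code, not an
    # authoritative statement of the training recipe): with
    # x_t = (1 - t) * x_0 + t * x_1, the constant velocity along the path is
    # dx_t/dt = x_1 - x_0, so the transformer is regressed onto
    # v_gt = x_1 - x_0 with an (optionally timestep-weighted) MSE, computed
    # only over the noisy-image tokens.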

    def base_pipe(self, inputs: QwenInputs) -> list[Image.Image]:
        print(inputs)
        self.offload_text_encoder("cuda")
        if inputs.vae_image_override is None:
            inputs.vae_image_override = self.config.vae_image_size
        if inputs.latent_size_override is None:
            inputs.latent_size_override = self.config.vae_image_size
        return self.pipe(**inputs.model_dump()).images


class QwenImageFoundationSaveInterm(QwenImageFoundation):
    PIPELINE = QwenImageEditSaveIntermPipeline


class QwenImageRegressionFoundation(QwenImageFoundation):
    def __init__(self, config: QwenConfig, device=None):
        super().__init__(config, device=device)
        self.lpips_fn = lpips.LPIPS(net='vgg').to(device=self.device)

    def preprocess_batch(self, batch):
        return batch

    @ftimed
    def single_step(self, batch) -> torch.Tensor:
        self.offload_text_encoder("cpu")
        out_dict = batch["data"]
        assert len(out_dict) == 1
        out_dict = out_dict[0]
        prompt_embeds = out_dict["prompt_embeds"]
        prompt_embeds_mask = out_dict["prompt_embeds_mask"]
        prompt_embeds = prompt_embeds.to(device=self.device, dtype=self.dtype)
        prompt_embeds_mask = prompt_embeds_mask.to(device=self.device, dtype=self.dtype)
        h_f16 = out_dict["height"] // 16
        w_f16 = out_dict["width"] // 16
        refs_1d = out_dict["image_latents"].to(device=self.device, dtype=self.dtype)
        t = out_dict["t"].to(device=self.device, dtype=self.dtype)
        x_0_1d = out_dict["output"].to(device=self.device, dtype=self.dtype)
        x_t_1d = out_dict["latents_start"].to(device=self.device, dtype=self.dtype)
        v_neg_1d = out_dict["noise_pred"].to(device=self.device, dtype=self.dtype)
        # Since x_t = x_0 + t * v, the target velocity is recovered as (x_t - x_0) / t.
        v_gt_1d = (x_t_1d - x_0_1d) / t

        inp_1d = torch.cat([x_t_1d, refs_1d], dim=1)
        print("inp_1d")
        texam(inp_1d)
        img_shapes = out_dict["img_shapes"]
        txt_seq_lens = out_dict["txt_seq_lens"]
        image_rotary_emb = self.transformer.pos_embed(img_shapes, txt_seq_lens, device=self.device)
        v_pred_1d = self.transformer(
            hidden_states=inp_1d,
            encoder_hidden_states=prompt_embeds,
            encoder_hidden_states_mask=prompt_embeds_mask,
            timestep=t,
            image_rotary_emb=image_rotary_emb,
            return_dict=False,
        )[0]
        v_pred_1d = v_pred_1d[:, : x_t_1d.size(1)]

        split = batch["split"]
        step = batch["step"]
        if split == "train":
            loss_terms = self.config.train_loss_terms
        elif split == "validation":
            loss_terms = self.config.validation_loss_terms
        else:
            raise ValueError(f"Unknown split {split!r}")
        loss_accumulator = LossAccumulator(
            terms=loss_terms.model_dump(),
            step=step,
            split=split,
            term_groups={"pixel": loss_terms.pixel_terms},
        )

        if loss_accumulator.has("mse"):
            if self.config.loss_weight_dist is not None:
                mse_loss = F.mse_loss(v_pred_1d, v_gt_1d, reduction="none").mean(dim=[1, 2, 3])
                weights = self.timestep_dist_utils.get_loss_weighting(t)
                mse_loss = torch.mean(mse_loss * weights)
            else:
                mse_loss = F.mse_loss(v_pred_1d, v_gt_1d, reduction="mean")
            loss_accumulator.accum("mse", mse_loss)
if loss_accumulator.has("triplet"):
# 1d, B,L,C
margin = loss_terms.triplet_margin
triplet_min_abs_diff = loss_terms.triplet_min_abs_diff
print(f"{triplet_min_abs_diff=}")
v_gt_neg_diff = (v_gt_1d - v_neg_1d).abs().mean(dim=2)
zero_weight = torch.zeros_like(v_gt_neg_diff)
v_weight = torch.where(v_gt_neg_diff > triplet_min_abs_diff, v_gt_neg_diff, zero_weight)
ones = torch.ones_like(v_gt_neg_diff)
filtered_nums = torch.sum(torch.where(v_gt_neg_diff > triplet_min_abs_diff, ones, zero_weight))
wand_logger.log({
"filtered_nums": filtered_nums,
}, commit=False)
diffv_gt_pred = (v_gt_1d - v_pred_1d).pow(2)
diffv_neg_pred = (v_neg_1d - v_pred_1d).pow(2)
per_tok_diff = (diffv_gt_pred - diffv_neg_pred).sum(dim=2)
triplet_loss = torch.mean(F.relu((per_tok_diff + margin) * v_weight))
ones = torch.ones_like(per_tok_diff)
zeros = torch.zeros_like(per_tok_diff)
loss_nonzero_nums = torch.sum(torch.where(((per_tok_diff + margin) * v_weight)>0, ones, zeros))
wand_logger.log({
"loss_nonzero_nums": loss_nonzero_nums,
}, commit=False)
loss_accumulator.accum("triplet", triplet_loss)
texam(v_gt_neg_diff, "v_gt_neg_diff")
texam(v_weight, "v_weight")
texam(diffv_gt_pred, "diffv_gt_pred")
texam(diffv_neg_pred, "diffv_neg_pred")
texam(per_tok_diff, "per_tok_diff")
if loss_accumulator.has("negative_mse"):
neg_mse_loss = -F.mse_loss(v_pred_1d, v_neg_1d, reduction="mean")
loss_accumulator.accum("negative_mse", neg_mse_loss)
if loss_accumulator.has("distribution_matching"):
dm_v = (v_pred_1d - v_neg_1d + v_gt_1d).detach()
dm_mse = F.mse_loss(v_pred_1d, dm_v, reduction="mean")
loss_accumulator.accum("distribution_matching", dm_mse)
if loss_accumulator.has("negative_exponential"):
raise NotImplementedError()
if loss_accumulator.has_group("pixel"):
x_0_pred = x_t_1d - t * v_pred_1d
with torch.no_grad():
pixel_values_x0_gt = self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16, with_grad=True).detach()
pixel_values_x0_pred = self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16, with_grad=True)
if loss_accumulator.has("pixel_lpips"):
lpips_loss = self.lpips_fn(pixel_values_x0_gt, pixel_values_x0_pred)
texam(lpips_loss, "lpips_loss")
lpips_loss = lpips_loss.mean()
texam(lpips_loss, "lpips_loss")
loss_accumulator.accum("pixel_lpips", lpips_loss)
if loss_accumulator.has("pixel_mse"):
pixel_mse_loss = F.mse_loss(pixel_values_x0_pred, pixel_values_x0_gt, reduction="mean")
loss_accumulator.accum("pixel_mse", pixel_mse_loss)
if loss_accumulator.has("pixel_triplet"):
raise NotImplementedError()
loss_accumulator.accum("pixel_triplet", pixel_triplet_loss)
if loss_accumulator.has("pixel_distribution_matching"):
raise NotImplementedError()
loss_accumulator.accum("pixel_distribution_matching", pixel_distribution_matching_loss)
if loss_accumulator.has("adversarial"):
raise NotImplementedError()

        loss = loss_accumulator.total
        logs = loss_accumulator.logs()
        wand_logger.log(logs, step=step, commit=False)
        wand_logger.log({
            "t": t.float().cpu().item()
        }, step=step, commit=False)
        if self.should_log_training(step):
            self.log_single_step_images(
                h_f16,
                w_f16,
                t,
                x_0_1d,
                x_t_1d,
                v_gt_1d,
                v_neg_1d,
                v_pred_1d,
                visualize_velocities=False,
            )
        return loss
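
    # Informal summary of the loss terms above (my reading of the code, not a
    # spec): "mse" regresses v_pred_1d onto v_gt_1d derived from the saved
    # pipeline output; "triplet", "negative_mse" and "distribution_matching"
    # additionally push the prediction away from the recorded negative velocity
    # v_neg_1d; the "pixel_*" terms decode the one-step x_0 estimate through
    # the VAE and compare it to the decoded target in pixel space (LPIPS / MSE).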

    def should_log_training(self, step) -> bool:
        return (
            self.training  # don't log when validating
            and ExperimentTrainer._is_step_trigger(step, self.config.log_batch_steps)
        )

    def log_single_step_images(
        self,
        h_f16,
        w_f16,
        t,
        x_0_1d,
        x_t_1d,
        v_gt_1d,
        v_neg_1d,
        v_pred_1d,
        visualize_velocities=False,
    ):
        t_float = t.float().cpu().item()
        x_0_pred = x_t_1d - t * v_pred_1d
        x_0_neg = x_t_1d - t * v_neg_1d
        x_0_recon = x_t_1d - t * v_gt_1d
        log_pils = {
            f"x_{t_float}_1d": self.latents_to_pil(x_t_1d, h=h_f16, w=w_f16),
            "x_0": self.latents_to_pil(x_0_1d, h=h_f16, w=w_f16),
            "x_0_recon": self.latents_to_pil(x_0_recon, h=h_f16, w=w_f16),
            "x_0_pred": self.latents_to_pil(x_0_pred, h=h_f16, w=w_f16),
            "x_0_neg": self.latents_to_pil(x_0_neg, h=h_f16, w=w_f16),
        }
        if visualize_velocities:  # naively visualize velocities through the VAE (works with Flux)
            log_pils.update({
                "v_gt_1d": self.latents_to_pil(v_gt_1d, h=h_f16, w=w_f16),
                "v_pred_1d": self.latents_to_pil(v_pred_1d, h=h_f16, w=w_f16),
                "v_neg_1d": self.latents_to_pil(v_neg_1d, h=h_f16, w=w_f16),
            })

        # Create gt/neg/pred velocity difference maps.
        v_pred_2d = self.unpack_latents(v_pred_1d, h_f16, w_f16)
        v_gt_2d = self.unpack_latents(v_gt_1d, h_f16, w_f16)
        v_neg_2d = self.unpack_latents(v_neg_1d, h_f16, w_f16)
        gt_neg_diff_map_2d = (v_gt_2d - v_neg_2d).pow(2).mean(dim=1, keepdim=True)
        gt_pred_diff_map_2d = (v_gt_2d - v_pred_2d).pow(2).mean(dim=1, keepdim=True)
        neg_pred_diff_map_2d = (v_neg_2d - v_pred_2d).pow(2).mean(dim=1, keepdim=True)
        diff_max = torch.max(torch.stack([gt_neg_diff_map_2d, gt_pred_diff_map_2d, neg_pred_diff_map_2d]))
        diff_min = torch.min(torch.stack([gt_neg_diff_map_2d, gt_pred_diff_map_2d, neg_pred_diff_map_2d]))
        print(f"{diff_min}, {diff_max}")
        # Normalize the maps to [0, 1] with a shared range.
        diff_span = diff_max - diff_min
        gt_neg_diff_map_2d = (gt_neg_diff_map_2d - diff_min) / diff_span
        gt_pred_diff_map_2d = (gt_pred_diff_map_2d - diff_min) / diff_span
        neg_pred_diff_map_2d = (neg_pred_diff_map_2d - diff_min) / diff_span
        log_pils.update({
            "gt-neg": gt_neg_diff_map_2d.float().cpu(),
            "gt-pred": gt_pred_diff_map_2d.float().cpu(),
            "neg-pred": neg_pred_diff_map_2d.float().cpu(),
        })
        wand_logger.log({
            "train_images": log_pils,
        }, commit=False)

    def base_pipe(self, inputs: QwenInputs) -> list[Image.Image]:
        inputs.num_inference_steps = self.config.regression_base_pipe_steps
        return super().base_pipe(inputs)
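

# Minimal usage sketch (illustrative only; assumes a populated QwenConfig and a
# batch dict carrying the keys consumed by single_step -- "text", "image",
# "reference" -- none of which are defined in this module):
#
#     model = QwenImageFoundation(QwenConfig(...))
#     loss = model.single_step(batch)
#     loss.backward()
#     images = model.base_pipe(QwenInputs(...))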