import itertools
import json
import math
import os
import random
import statistics
from pathlib import Path

import pandas as pd
import spaces
import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from PIL import Image
from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights
from torch.utils._pytree import tree_map
from torchao import autoquant
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    PerRow,
    quantize_,
)
from torchao.utils import get_model_size_in_bytes

from qwenimage.debug import ProfileSession, ctimed, ftimed, print_first_param
from qwenimage.experiment import AbstractExperiment, ExperimentConfig
from qwenimage.models.attention_processors import (
    QwenDoubleStreamAttnProcessorFA3,
    QwenDoubleStreamAttnProcessorSageAttn2,
    sageattn_qk_int8_pv_fp16_cuda_wrapper,
    sageattn_qk_int8_pv_fp16_triton_wrapper,
    sageattn_qk_int8_pv_fp8_cuda_sm90_wrapper,
    sageattn_qk_int8_pv_fp8_cuda_wrapper,
    sageattn_wrapper,
)
from qwenimage.models.first_block_cache import apply_cache_on_pipe
from qwenimage.models.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline, calculate_dimensions
from qwenimage.models.transformer_qwenimage import QwenImageTransformer2DModel
from qwenimage.optimization import INDUCTOR_CONFIGS, TRANSFORMER_DYNAMIC_SHAPES, aoti_apply, drain_module_parameters, optimize_pipeline_
from qwenimage.prompt import build_camera_prompt
class ExperimentRegistry:
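    """Global name -> class registry for experiments.

    Experiments self-register via the decorator:

        @ExperimentRegistry.register(name="my_experiment")
        class MyExperiment(AbstractExperiment): ...

    and are looked up with get(); filter() selects registered names by
    prefix/suffix/substring.
    """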
registry = {}
@classmethod
def register(cls, name: str):
def decorator(experiment_class):
if name in cls.registry:
raise ValueError(f"Experiment '{name}' is already registered")
cls.registry[name] = experiment_class
experiment_class.registry_name = name
return experiment_class
return decorator
@classmethod
def get(cls, name: str):
if name not in cls.registry:
raise KeyError(f"{name} not in {list(cls.registry.keys())}")
return cls.registry[name]
@classmethod
def filter(cls, startswith=None, endswith=None, contains=None, not_contain=None):
keys = list(cls.registry.keys())
if startswith is not None:
keys = [k for k in keys if k.startswith(startswith)]
if endswith is not None:
keys = [k for k in keys if k.endswith(endswith)]
if contains is not None:
keys = [k for k in keys if contains in k]
if not_contain is not None:
keys = [k for k in keys if not_contain not in k]
return keys
@classmethod
def keys(cls):
return list(cls.registry.keys())
@classmethod
def dict(cls):
return dict(cls.registry)
class PipeInputs:
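    """Deterministic pipeline-input sampler.

    Builds the Cartesian product of every camera parameter with every test
    image, shuffles it once with a fixed seed, and yields ready-to-call
    pipeline kwargs (prompt from build_camera_prompt, 4 steps, cfg 1.0).
    """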
images = list(Path("scripts/assets/test_images_v1/").glob("*.jpg"))
camera_params = {
"rotate_deg": [-90,-45,0,45,90],
"move_forward": [0,5,10],
"vertical_tilt": [-1,0,1],
"wideangle": [True, False]
}
def __init__(self, seed=42):
self.seed=seed
param_keys = list(self.camera_params.keys())
param_values = list(self.camera_params.values())
param_keys.append("image")
param_values.append(self.images)
self.total_inputs = []
for comb in itertools.product(*param_values):
inp = {key:val for key,val in zip(param_keys, comb)}
self.total_inputs.append(inp)
print(f"{len(self.total_inputs)} input combinations")
random.seed(seed)
random.shuffle(self.total_inputs)
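        # The generator is created lazily in __getitem__ so the CUDA/CPU
        # device choice is made at first use, not at construction time.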
self.generator = None
def __len__(self):
return len(self.total_inputs)
def __getitem__(self, ind):
if self.generator is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.generator = torch.Generator(device=device).manual_seed(self.seed)
inputs = self.total_inputs[ind]
cam_prompt_params = {k:v for k,v in inputs.items() if k in self.camera_params}
prompt = build_camera_prompt(**cam_prompt_params)
image = [Image.open(inputs["image"]).convert("RGB")]
return {
"image": image,
"prompt": prompt,
"generator": self.generator,
"num_inference_steps": 4,
"true_cfg_scale": 1.0,
"num_images_per_prompt": 1,
}
@ExperimentRegistry.register(name="qwen_base")
class QwenBaseExperiment(AbstractExperiment):
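    """Baseline experiment: Rapid-AIO distilled transformer with the
    camera-angles LoRA fused in.

    Lifecycle: load() -> optimize() (a no-op here, overridden by subclasses)
    -> run() (one saved image per input) -> report() (timing stats as
    JSON/CSV) -> cleanup().
    """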
def __init__(self, config: ExperimentConfig | None = None, pipe_inputs: PipeInputs | None = None):
self.config = config if config is not None else ExperimentConfig()
self.config.report_dir.mkdir(parents=True, exist_ok=True)
self.profile_report = ProfileSession().start()
self.pipe_inputs = pipe_inputs if pipe_inputs is not None else PipeInputs()
@ftimed
def load(self):
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"experiment load cuda: {torch.cuda.is_available()=}")
pipe = QwenImageEditPlusPipeline.from_pretrained(
"Qwen/Qwen-Image-Edit-2509",
transformer=QwenImageTransformer2DModel.from_pretrained(
"linoyts/Qwen-Image-Edit-Rapid-AIO",
subfolder='transformer',
torch_dtype=dtype,
device_map=device),
torch_dtype=dtype,
).to(device)
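        # Fuse the camera-angles LoRA ("镜头转换" ~ "camera/lens transition")
        # into the base weights and drop the adapter, so downstream
        # quantization/compilation sees plain Linear modules.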
pipe.load_lora_weights(
"dx8152/Qwen-Edit-2509-Multiple-angles",
weight_name="镜头转换.safetensors", adapter_name="angles"
)
pipe.set_adapters(["angles"], adapter_weights=[1.])
pipe.fuse_lora(adapter_names=["angles"], lora_scale=1.25)
pipe.unload_lora_weights()
self.pipe = pipe
@ftimed
def optimize(self):
pass
def run_once(self, *args, **kwargs):
return self.pipe(*args, **kwargs).images[0]
def run(self):
output_save_dir = self.config.report_dir / f"{self.config.name}_outputs"
output_save_dir.mkdir(parents=True, exist_ok=True)
for i in range(self.config.iterations):
inputs = self.pipe_inputs[i]
with ctimed("run_once"):
output = self.run_once(**inputs)
output.save(output_save_dir / f"{i:03d}.jpg")
def report(self):
print(self.profile_report)
raw_data = dict(self.profile_report.recorded_times)
with open(self.config.report_dir/ f"{self.config.name}_raw.json", "w") as f:
json.dump(raw_data, f, indent=2)
data = []
for name, times in self.profile_report.recorded_times.items():
mean = statistics.mean(times)
std = statistics.stdev(times) if len(times) > 1 else 0
size = len(times)
data.append({
'name': name,
'mean': mean,
'std': std,
'len': size
})
df = pd.DataFrame(data)
df.to_csv(self.config.report_dir/f"{self.config.name}.csv")
return df, raw_data
def cleanup(self):
del self.pipe.transformer
del self.pipe
@ExperimentRegistry.register(name="qwen_50step")
class Qwen_50Step(QwenBaseExperiment):
@ftimed
def load(self):
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"experiment load cuda: {torch.cuda.is_available()=}")
pipe = QwenImageEditPlusPipeline.from_pretrained(
"Qwen/Qwen-Image-Edit-2509",
        transformer=QwenImageTransformer2DModel.from_pretrained(  # our custom transformer class, stock 2509 weights
"Qwen/Qwen-Image-Edit-2509",
subfolder='transformer',
torch_dtype=dtype,
device_map=device
),
torch_dtype=dtype,
).to(device)
pipe.load_lora_weights(
"dx8152/Qwen-Edit-2509-Multiple-angles",
weight_name="镜头转换.safetensors", adapter_name="angles"
)
pipe.set_adapters(["angles"], adapter_weights=[1.])
pipe.fuse_lora(adapter_names=["angles"], lora_scale=1.25)
pipe.unload_lora_weights()
self.pipe = pipe
def run_once(self, *args, **kwargs):
kwargs["num_inference_steps"] = 50
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_lightning_lora")
class Qwen_Lightning_Lora(QwenBaseExperiment):
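    """Stock 2509 transformer with the camera-angles and Lightning 4-step
    LoRAs fused in. The FlowMatch scheduler is configured for Lightning
    distillation: base_shift = max_shift = ln(3) with exponential dynamic
    shifting, so the effective shift resolves to 3 at every resolution.
    """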
@ftimed
def load(self):
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
# Scheduler configuration for Lightning
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": math.log(3), # We use shift=3 in distillation
"invert_sigmas": False,
"max_image_seq_len": 8192,
"max_shift": math.log(3), # We use shift=3 in distillation
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None, # set shift_terminal to None
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config) # TODO: check scheduler sync issue mentioned by https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/
pipe = QwenImageEditPlusPipeline.from_pretrained(
"Qwen/Qwen-Image-Edit-2509",
            transformer=QwenImageTransformer2DModel.from_pretrained(  # our custom transformer class, stock 2509 weights
"Qwen/Qwen-Image-Edit-2509",
subfolder='transformer',
torch_dtype=dtype,
device_map=device
),
scheduler=scheduler,
torch_dtype=dtype,
).to(device)
pipe.load_lora_weights(
"dx8152/Qwen-Edit-2509-Multiple-angles",
weight_name="镜头转换.safetensors",
adapter_name="angles"
)
pipe.load_lora_weights(
"lightx2v/Qwen-Image-Lightning",
weight_name="Qwen-Image-Edit-2509/Qwen-Image-Edit-2509-Lightning-4steps-V1.0-bf16.safetensors",
adapter_name="lightning",
)
pipe.set_adapters(["angles", "lightning"], adapter_weights=[1.25, 1.])
pipe.fuse_lora(adapter_names=["angles", "lightning"], lora_scale=1.0)
pipe.unload_lora_weights()
self.pipe = pipe
@ExperimentRegistry.register(name="qwen_lightning_lora_3step")
class Qwen_Lightning_Lora_3step(Qwen_Lightning_Lora):
def run_once(self, *args, **kwargs):
kwargs["num_inference_steps"] = 3
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_base_3step")
class Qwen_Base_3step(QwenBaseExperiment):
def run_once(self, *args, **kwargs):
kwargs["num_inference_steps"] = 3
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_lightning_lora_2step")
class Qwen_Lightning_Lora_2step(Qwen_Lightning_Lora):
def run_once(self, *args, **kwargs):
kwargs["num_inference_steps"] = 2
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_base_2step")
class Qwen_Base_2step(QwenBaseExperiment):
def run_once(self, *args, **kwargs):
kwargs["num_inference_steps"] = 2
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_fuse")
class Qwen_Fuse(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
@ExperimentRegistry.register(name="qwen_fuse_aot")
class Qwen_Fuse_AoT(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=False,
suffix="_fuse",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
@ExperimentRegistry.register(name="qwen_fa3_fuse")
class Qwen_FA3_Fuse(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
@ExperimentRegistry.register(name="qwen_fa3")
class Qwen_FA3(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
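# SageAttention 2 kernel sweep: each experiment below picks a different
# quantized attention kernel (INT8 QK^T with PV accumulated in FP16 or FP8,
# CUDA or Triton backends; the sm90 variant targets Hopper GPUs).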
@ExperimentRegistry.register(name="qwen_sageattn")
class Qwen_Sageattn(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2(sageattn_wrapper))
@ExperimentRegistry.register(name="qwen_sageattn_qk_int8_pv_fp16_cuda")
class Qwen_Sageattn_qk_int8_pv_fp16_cuda(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2(sageattn_qk_int8_pv_fp16_cuda_wrapper))
@ExperimentRegistry.register(name="qwen_sageattn_qk_int8_pv_fp16_triton")
class Qwen_Sageattn_qk_int8_pv_fp16_triton(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2(sageattn_qk_int8_pv_fp16_triton_wrapper))
@ExperimentRegistry.register(name="qwen_sageattn_qk_int8_pv_fp8_cuda")
class Qwen_Sageattn_qk_int8_pv_fp8_cuda(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2(sageattn_qk_int8_pv_fp8_cuda_wrapper))
@ExperimentRegistry.register(name="qwen_sageattn_qk_int8_pv_fp8_cuda_sm90")
class Qwen_Sageattn_qk_int8_pv_fp8_cuda_sm90(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2(sageattn_qk_int8_pv_fp8_cuda_sm90_wrapper))
@ExperimentRegistry.register(name="qwen_aot")
class Qwen_AoT(QwenBaseExperiment):
@ftimed
def optimize(self):
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=False,
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
@ExperimentRegistry.register(name="qwen_fa3_aot")
class Qwen_FA3_AoT(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=False,
suffix="_fa3",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
@ExperimentRegistry.register(name="qwen_sage_aot")
class Qwen_Sage_AoT(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=False,
suffix="_sage",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
@ExperimentRegistry.register(name="qwen_fa3_aot_int8")
class Qwen_FA3_AoT_int8(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fa3",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
@ExperimentRegistry.register(name="qwen_sage_aot_int8")
class Qwen_Sage_AoT_int8(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_sage",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
@ExperimentRegistry.register(name="qwen_sage_aot_int8da")
class Qwen_Sage_AoT_int8da(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2(sageattn_qk_int8_pv_fp8_cuda_sm90_wrapper))
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
quantize_config=Int8DynamicActivationInt8WeightConfig(),
suffix="_int8da_sage",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
@ExperimentRegistry.register(name="qwen_sagefp8cuda_aot_int8da")
class Qwen_Sagefp8cuda_AoT_int8da(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2(sageattn_qk_int8_pv_fp8_cuda_wrapper))
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
quantize_config=Int8DynamicActivationInt8WeightConfig(),
suffix="_int8da_sage",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
@ExperimentRegistry.register(name="qwen_sagefp8cuda_aot_int8wo")
class Qwen_Sagefp8cuda_AoT_int8wo(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2(sageattn_qk_int8_pv_fp8_cuda_wrapper))
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
quantize_config=Int8WeightOnlyConfig(),
suffix="_int8wo_sage",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
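# Module-level quantization baselines (no AoT compile): torchao's quantize_()
# swaps nn.Linear weights in place. Weight-only configs mainly cut memory
# traffic, while the dynamic-activation configs also quantize activations at
# runtime to hit low-precision matmul kernels.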
@ExperimentRegistry.register(name="qwen_fp8_weightonly")
class Qwen_fp8_Weightonly(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
quantize_(self.pipe.transformer, Float8WeightOnlyConfig())
@ExperimentRegistry.register(name="qwen_int8_weightonly")
class Qwen_int8_Weightonly(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
quantize_(self.pipe.transformer, Int8WeightOnlyConfig())
@ExperimentRegistry.register(name="qwen_fp8")
class Qwen_fp8(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
quantize_(self.pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
@ExperimentRegistry.register(name="qwen_int8")
class Qwen_int8(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
quantize_(self.pipe.transformer, Int8DynamicActivationInt8WeightConfig())
@ExperimentRegistry.register(name="qwen_fa3_aot_fp8")
class Qwen_FA3_AoT_fp8(QwenBaseExperiment):
@ftimed
# @spaces.GPU()
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
suffix="_fa3"
cache_compiled=self.config.cache_compiled
transformer_pt2_cache_path = f"checkpoints/transformer_fp8{suffix}_archive.pt2"
transformer_weights_cache_path = f"checkpoints/transformer_fp8{suffix}_weights.pt"
print(f"original model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
quantize_(self.pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
print_first_param(self.pipe.transformer)
print(f"quantized model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
inductor_config = INDUCTOR_CONFIGS
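        # Either reuse a cached AOTInductor archive, or capture one pipeline
        # call to get example transformer inputs, export with dynamic shapes,
        # compile ahead of time, and persist the archive plus its weights.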
if os.path.isfile(transformer_pt2_cache_path) and cache_compiled:
drain_module_parameters(self.pipe.transformer)
zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
else:
with spaces.aoti_capture(self.pipe.transformer) as call:
self.pipe(**pipe_kwargs)
dynamic_shapes = tree_map(lambda t: None, call.kwargs)
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
exported = torch.export.export(
mod=self.pipe.transformer,
args=call.args,
kwargs=call.kwargs,
dynamic_shapes=dynamic_shapes,
)
compiled_transformer = spaces.aoti_compile(exported, inductor_config)
with open(transformer_pt2_cache_path, "wb") as f:
f.write(compiled_transformer.archive_file.getvalue())
torch.save(compiled_transformer.weights, transformer_weights_cache_path)
aoti_apply(compiled_transformer, self.pipe.transformer)
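# The quantize-then-AoT-compile block above is repeated nearly verbatim in
# several experiments below. A possible consolidation (a sketch only: it
# mirrors the inline logic, and the cache path is keyed entirely by `suffix`,
# which should therefore encode the quantization scheme):
def _quantize_and_aot_compile(pipe, quantize_config, suffix, pipe_kwargs, cache_compiled):
    pt2_path = f"checkpoints/transformer{suffix}_archive.pt2"
    weights_path = f"checkpoints/transformer{suffix}_weights.pt"
    quantize_(pipe.transformer, quantize_config)
    if os.path.isfile(pt2_path) and cache_compiled:
        # Reload a previously compiled archive and its weights.
        drain_module_parameters(pipe.transformer)
        zerogpu_weights = torch.load(weights_path, weights_only=False)
        compiled = ZeroGPUCompiledModel(pt2_path, zerogpu_weights)
    else:
        # Capture example inputs with one pipeline call, then export and compile.
        with spaces.aoti_capture(pipe.transformer) as call:
            pipe(**pipe_kwargs)
        dynamic_shapes = tree_map(lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
        exported = torch.export.export(
            mod=pipe.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        compiled = spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
        with open(pt2_path, "wb") as f:
            f.write(compiled.archive_file.getvalue())
        torch.save(compiled.weights, weights_path)
    aoti_apply(compiled, pipe.transformer)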
@ExperimentRegistry.register(name="qwen_sage_aot_fp8")
class Qwen_Sage_AoT_fp8(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorSageAttn2())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
quantize_config=Float8DynamicActivationFloat8WeightConfig(),
suffix="_fp8_sage",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
# FA3_AoT_fp8_fuse
@ExperimentRegistry.register(name="qwen_fa3_aot_fp8_fuse")
class Qwen_FA3_AoT_fp8_fuse(QwenBaseExperiment):
@ftimed
# @spaces.GPU()
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
suffix="_fa3_fuse"
cache_compiled=self.config.cache_compiled
transformer_pt2_cache_path = f"checkpoints/transformer_fp8{suffix}_archive.pt2"
transformer_weights_cache_path = f"checkpoints/transformer_fp8{suffix}_weights.pt"
print(f"original model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
quantize_(self.pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
print_first_param(self.pipe.transformer)
print(f"quantized model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
inductor_config = INDUCTOR_CONFIGS
if os.path.isfile(transformer_pt2_cache_path) and cache_compiled:
drain_module_parameters(self.pipe.transformer)
zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
else:
with spaces.aoti_capture(self.pipe.transformer) as call:
self.pipe(**pipe_kwargs)
dynamic_shapes = tree_map(lambda t: None, call.kwargs)
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
exported = torch.export.export(
mod=self.pipe.transformer,
args=call.args,
kwargs=call.kwargs,
dynamic_shapes=dynamic_shapes,
)
compiled_transformer = spaces.aoti_compile(exported, inductor_config)
with open(transformer_pt2_cache_path, "wb") as f:
f.write(compiled_transformer.archive_file.getvalue())
torch.save(compiled_transformer.weights, transformer_weights_cache_path)
aoti_apply(compiled_transformer, self.pipe.transformer)
# FA3_AoT_int8_fuse
@ExperimentRegistry.register(name="qwen_fa3_aot_int8_fuse")
class Qwen_FA3_AoT_int8_fuse(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fa3_fuse",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
# lightning_FA3_AoT_fp8_fuse
@ExperimentRegistry.register(name="qwen_lightning_fa3_aot_fp8_fuse")
class Qwen_Lightning_FA3_AoT_fp8_fuse(Qwen_Lightning_Lora):
@ftimed
# @spaces.GPU()
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
suffix="_fa3_fuse"
cache_compiled=self.config.cache_compiled
transformer_pt2_cache_path = f"checkpoints/transformer_fp8{suffix}_archive.pt2"
transformer_weights_cache_path = f"checkpoints/transformer_fp8{suffix}_weights.pt"
print(f"original model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
quantize_(self.pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
print_first_param(self.pipe.transformer)
print(f"quantized model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
inductor_config = INDUCTOR_CONFIGS
if os.path.isfile(transformer_pt2_cache_path) and cache_compiled:
drain_module_parameters(self.pipe.transformer)
zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
else:
with spaces.aoti_capture(self.pipe.transformer) as call:
self.pipe(**pipe_kwargs)
dynamic_shapes = tree_map(lambda t: None, call.kwargs)
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
exported = torch.export.export(
mod=self.pipe.transformer,
args=call.args,
kwargs=call.kwargs,
dynamic_shapes=dynamic_shapes,
)
compiled_transformer = spaces.aoti_compile(exported, inductor_config)
with open(transformer_pt2_cache_path, "wb") as f:
f.write(compiled_transformer.archive_file.getvalue())
torch.save(compiled_transformer.weights, transformer_weights_cache_path)
aoti_apply(compiled_transformer, self.pipe.transformer)
# lightning_FA3_AoT_int8_fuse
@ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse")
class Qwen_Lightning_FA3_AoT_int8_fuse(Qwen_Lightning_Lora):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fa3_fuse",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
@ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse_2step")
class Qwen_Lightning_FA3_AoT_int8_fuse_2step(Qwen_Lightning_FA3_AoT_int8_fuse):
def run_once(self, *args, **kwargs):
kwargs["num_inference_steps"] = 2
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse_3step")
class Qwen_Lightning_FA3_AoT_int8_fuse_3step(Qwen_Lightning_FA3_AoT_int8_fuse):
def run_once(self, *args, **kwargs):
kwargs["num_inference_steps"] = 3
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_fa3_aot_int8_fuse_2step")
class Qwen_FA3_AoT_int8_fuse_2step(Qwen_FA3_AoT_int8_fuse):
def run_once(self, *args, **kwargs):
kwargs["num_inference_steps"] = 2
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_fa3_aot_int8_fuse_3step")
class Qwen_FA3_AoT_int8_fuse_3step(Qwen_FA3_AoT_int8_fuse):
def run_once(self, *args, **kwargs):
kwargs["num_inference_steps"] = 3
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_channels_last")
class Qwen_Channels_Last(QwenBaseExperiment):
"""
This experiment may be useless: channels last format only works with NCHW tensors,
i.e. 2D CNNs, transformer is 1D and vae is 3D, plus, for it to work the inputs need to
be converted in-pipe as well. left for reference.
"""
@ftimed
def optimize(self):
# self.pipe.vae = self.pipe.vae.to(memory_format=torch.channels_last_3d)
self.pipe.transformer = self.pipe.transformer.to(memory_format=torch.channels_last)
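# First-block-cache threshold sweep: apply_cache_on_pipe skips the remaining
# transformer blocks and reuses the previous step's residual whenever the
# first block's output changes by less than residual_diff_threshold.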
@ExperimentRegistry.register(name="qwen_fbcache_05")
class Qwen_FBCache_05(QwenBaseExperiment):
@ftimed
def optimize(self):
apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.5,)
@ExperimentRegistry.register(name="qwen_fbcache_055")
class Qwen_FBCache_055(QwenBaseExperiment):
@ftimed
def optimize(self):
apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.55,)
@ExperimentRegistry.register(name="qwen_fbcache_054")
class Qwen_FBCache_054(QwenBaseExperiment):
@ftimed
def optimize(self):
apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.54,)
@ExperimentRegistry.register(name="qwen_fbcache_053")
class Qwen_FBCache_053(QwenBaseExperiment):
@ftimed
def optimize(self):
apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.53,)
@ExperimentRegistry.register(name="qwen_fbcache_052")
class Qwen_FBCache_052(QwenBaseExperiment):
@ftimed
def optimize(self):
apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.52,)
@ExperimentRegistry.register(name="qwen_fbcache_051")
class Qwen_FBCache_051(QwenBaseExperiment):
@ftimed
def optimize(self):
apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.51,)
# @ExperimentRegistry.register(name="qwen_lightning_fa3_aot_autoquant_fuse")
class Qwen_Lightning_FA3_AoT_autoquant_fuse(Qwen_Lightning_Lora):
"""
Seemingly not working with AoT export
"""
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
suffix="_autoquant_fa3_fuse"
cache_compiled=self.config.cache_compiled
transformer_pt2_cache_path = f"checkpoints/transformer_{suffix}_archive.pt2"
transformer_weights_cache_path = f"checkpoints/transformer_{suffix}_weights.pt"
print(f"original model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
        self.pipe.transformer = autoquant(self.pipe.transformer, error_on_unseen=False)  # autoquant returns the prepared model; assign it back
print_first_param(self.pipe.transformer)
print(f"quantized model size: {get_model_size_in_bytes(self.pipe.transformer) / 1024 / 1024} MB")
inductor_config = INDUCTOR_CONFIGS
if os.path.isfile(transformer_pt2_cache_path) and cache_compiled:
drain_module_parameters(self.pipe.transformer)
zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False)
compiled_transformer = ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights)
else:
with spaces.aoti_capture(self.pipe.transformer) as call:
self.pipe(**pipe_kwargs)
dynamic_shapes = tree_map(lambda t: None, call.kwargs)
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
exported = torch.export.export(
mod=self.pipe.transformer,
args=call.args,
kwargs=call.kwargs,
dynamic_shapes=dynamic_shapes,
)
compiled_transformer = spaces.aoti_compile(exported, inductor_config)
with open(transformer_pt2_cache_path, "wb") as f:
f.write(compiled_transformer.archive_file.getvalue())
torch.save(compiled_transformer.weights, transformer_weights_cache_path)
aoti_apply(compiled_transformer, self.pipe.transformer)
class Qwen_Downsize_Mixin:
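    """Resize the conditioning image down to roughly size*size pixels before
    VAE encoding while keeping the generated output at the native ~1MP
    resolution via the width/height overrides. Subclasses set `size`.
    """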
def run_once(self, *args, **kwargs):
image = kwargs["image"][0]
w,h = image.size
real_w, real_h = calculate_dimensions(1024 * 1024, w / h)
down_w, down_h = calculate_dimensions(self.size * self.size, w / h)
image = image.resize((down_w, down_h))
kwargs["image"] = [image]
kwargs["vae_image_override"] = self.size * self.size
kwargs["width"] = real_w
kwargs["height"] = real_h
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_downsize768")
class Qwen_Downsize768(Qwen_Downsize_Mixin, QwenBaseExperiment):
size = 768
@ExperimentRegistry.register(name="qwen_downsize512")
class Qwen_Downsize512(Qwen_Downsize_Mixin, QwenBaseExperiment):
size = 512
@ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse_2step_fbcache_055_downsize512")
class Qwen_Lightning_FA3_AoT_int8_fuse_2step_FBCache055_Downsize512(Qwen_Lightning_Lora):
size = 512
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.55,)
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fa3_fuse",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
def run_once(self, *args, **kwargs):
image = kwargs["image"][0]
w,h = image.size
real_w, real_h = calculate_dimensions(1024 * 1024, w / h)
down_w, down_h = calculate_dimensions(self.size * self.size, w / h)
image = image.resize((down_w, down_h))
kwargs["image"] = [image]
kwargs["vae_image_override"] = self.size * self.size
kwargs["width"] = real_w
kwargs["height"] = real_h
kwargs["num_inference_steps"] = 2
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse_downsize512")
class Qwen_Lightning_FA3_AoT_int8_fuse_Downsize512(Qwen_Lightning_Lora):
size = 512
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fa3_fuse",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
def run_once(self, *args, **kwargs):
image = kwargs["image"][0]
w,h = image.size
real_w, real_h = calculate_dimensions(1024 * 1024, w / h)
down_w, down_h = calculate_dimensions(self.size * self.size, w / h)
image = image.resize((down_w, down_h))
kwargs["image"] = [image]
kwargs["vae_image_override"] = self.size * self.size
kwargs["width"] = real_w
kwargs["height"] = real_h
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse_1step_fbcache_055_downsize512")
class Qwen_Lightning_FA3_AoT_int8_fuse_1step_FBCache055_Downsize512(Qwen_Lightning_Lora):
size = 512
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.55,)
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fa3_fuse",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
def run_once(self, *args, **kwargs):
image = kwargs["image"][0]
w,h = image.size
real_w, real_h = calculate_dimensions(1024 * 1024, w / h)
down_w, down_h = calculate_dimensions(self.size * self.size, w / h)
image = image.resize((down_w, down_h))
kwargs["image"] = [image]
kwargs["vae_image_override"] = self.size * self.size
kwargs["width"] = real_w
kwargs["height"] = real_h
kwargs["num_inference_steps"] = 1
return self.pipe(*args, **kwargs).images[0]
@ExperimentRegistry.register(name="qwen_lightning_fa3_aot_int8_fuse_4step_fbcache_055_downsize512")
class Qwen_Lightning_FA3_AoT_int8_fuse_4step_FBCache055_Downsize512(Qwen_Lightning_Lora):
size = 512
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
self.pipe.transformer.fuse_qkv_projections()
self.pipe.transformer.check_fused_qkv()
apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.55,)
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fa3_fuse",
pipe_kwargs={
"image": [Image.new("RGB", (1024, 1024))],
"prompt":"prompt",
"num_inference_steps":4
}
)
def run_once(self, *args, **kwargs):
image = kwargs["image"][0]
w,h = image.size
real_w, real_h = calculate_dimensions(1024 * 1024, w / h)
down_w, down_h = calculate_dimensions(self.size * self.size, w / h)
image = image.resize((down_w, down_h))
kwargs["image"] = [image]
kwargs["vae_image_override"] = self.size * self.size
kwargs["width"] = real_w
kwargs["height"] = real_h
kwargs["num_inference_steps"] = 4
return self.pipe(*args, **kwargs).images[0]