from collections import OrderedDict
from PIL import Image
from torchao import quantize_
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    ModuleFqnToConfig,
    PerRow,
)
from torchao.utils import get_model_size_in_bytes
from qwenimage.debug import ftimed, print_first_param
from qwenimage.experiments.experiments_qwen import ExperimentRegistry, QwenBaseExperiment
from qwenimage.models.attention_processors import QwenDoubleStreamAttnProcessorFA3
from qwenimage.optimization import optimize_pipeline_
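
# Quantization sweep for the Qwen image pipeline: every experiment swaps in
# the FA3 attention processor, then AoT-compiles the pipeline via
# optimize_pipeline_ with a different torchao config. Naming: "wo" =
# weight-only, "da" = dynamic activation + weight quantization, "row" =
# per-row scale granularity.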
# Disabled baseline: FA3 attention + AoT compilation, no quantization.
# @ExperimentRegistry.register(name="qwen_fa3_aot")
# class Qwen_FA3_AoT(QwenBaseExperiment):
#     @ftimed
#     def optimize(self):
#         self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
#         optimize_pipeline_(
#             self.pipe,
#             cache_compiled=self.config.cache_compiled,
#             quantize=False,
#             suffix="_fa3",
#             pipe_kwargs={
#                 "image": [Image.new("RGB", (1024, 1024))],
#                 "prompt": "prompt",
#                 "num_inference_steps": 4,
#             },
#         )
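
# FP8 weight-only quantization.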
@ExperimentRegistry.register(name="qwen_fa3_aot_fp8wo")
class Qwen_FA3_AoT_fp8wo(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
quantize_config=Float8WeightOnlyConfig(),
suffix="_fp8wo_fa3",
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
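
# INT8 weight-only quantization.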
@ExperimentRegistry.register(name="qwen_fa3_aot_int8wo")
class Qwen_FA3_AoT_int8wo(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
quantize_config=Int8WeightOnlyConfig(),
suffix="_int8wo_fa3",
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
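
# FP8 dynamic activation + FP8 weight quantization (per-tensor scales).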
@ExperimentRegistry.register(name="qwen_fa3_aot_fp8da")
class Qwen_FA3_AoT_fp8da(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
quantize_config=Float8DynamicActivationFloat8WeightConfig(),
suffix="_fp8da_fa3",
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
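
# INT8 dynamic activation + INT8 weight quantization.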
@ExperimentRegistry.register(name="qwen_fa3_aot_int8da")
class Qwen_FA3_AoT_int8da(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_int8da_fa3",
quantize_config=Int8DynamicActivationInt8WeightConfig(),
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
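
# FP8 dynamic activation + weight quantization with per-row scales.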
@ExperimentRegistry.register(name="qwen_fa3_aot_fp8darow")
class Qwen_FA3_AoT_fp8darow(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fp8dqrow_fa3",
quantize_config=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
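
# Module-FQN regex keys for per-module quantization via ModuleFqnToConfig.
# Keys prefixed "re:" are treated by torchao as regexes over module
# fully-qualified names; "_default" applies to any module no other key matches.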
ATTENTION_QKV_REGEX = "re:^transformer_blocks\\.\\d+\\.attn\\.(to_q|to_k|to_v|to_qkv|to_added_qkv|add_q_proj|add_k_proj|add_v_proj)$"
ATTENTION_QKV_REGEX = r"re:^transformer_blocks\.\d+\.attn\.(to_q|to_k|to_v|to_qkv|to_added_qkv|add_q_proj|add_k_proj|add_v_proj)$"
# Attention QKV projections (all Linear)
# Attention output projections (Linear)
ATTENTION_OUT_REGEX = r"re:^transformer_blocks\.\d+\.attn\.to_out\.0$"
ATTENTION_ADD_OUT_REGEX = r"re:^transformer_blocks\.\d+\.attn\.to_add_out$"
# Image modulation Linear layer
IMG_MOD_LINEAR_REGEX = r"re:^transformer_blocks\.\d+\.img_mod\.1$"
# Image MLP Linear layers
IMG_MLP_LINEAR1_REGEX = r"re:^transformer_blocks\.\d+\.img_mlp\.net\.0\.proj$"
IMG_MLP_LINEAR2_REGEX = r"re:^transformer_blocks\.\d+\.img_mlp\.net\.2$"
# Text modulation Linear layer
TXT_MOD_LINEAR_REGEX = r"re:^transformer_blocks\.\d+\.txt_mod\.1$"
# Text MLP Linear layers
TXT_MLP_LINEAR1_REGEX = r"re:^transformer_blocks\.\d+\.txt_mlp\.net\.0\.proj$"
TXT_MLP_LINEAR2_REGEX = r"re:^transformer_blocks\.\d+\.txt_mlp\.net\.2$"
# Top-level Linear layers (img_in / txt_in / proj_out)
IMG_IN_REGEX = r"re:^img_in$"
TXT_IN_REGEX = r"re:^txt_in$"
PROJ_OUT_REGEX = r"re:^proj_out$"
ATTN_LAST_LAYER = r"re:^transformer_blocks\.59\..*$"
ATTN_FIRST_LAYER = r"re:^transformer_blocks\.0\..*$"
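
# Mixed precision: INT4 weight-only for the attention QKV projections,
# INT8 weight-only for everything else.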
@ExperimentRegistry.register(name="qwen_fa3_aot_qkvint4oint8")
class Qwen_FA3_AoT_qkvint4oint8(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        module_fqn_to_config = ModuleFqnToConfig(
            OrderedDict([
                (ATTENTION_QKV_REGEX, Int4WeightOnlyConfig()),
                ("_default", Int8WeightOnlyConfig()),
            ])
        )
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_qkvint4oint8_fa3",
quantize_config=module_fqn_to_config,
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
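
# Mixed precision: FP8 dynamic activation/weight for the attention QKV
# projections, INT8 weight-only for everything else.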
@ExperimentRegistry.register(name="qwen_fa3_aot_qkvfp8oint8")
class Qwen_FA3_AoT_qkvfp8oint8(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        module_fqn_to_config = ModuleFqnToConfig(
            OrderedDict([
                (ATTENTION_QKV_REGEX, Float8DynamicActivationFloat8WeightConfig()),
                ("_default", Int8WeightOnlyConfig()),
            ])
        )
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_qkvfp8oint8_fa3",
quantize_config=module_fqn_to_config,
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
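
# FP8 per-row everywhere except the last transformer block, which is left
# unquantized (a None config skips matched modules).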
@ExperimentRegistry.register(name="qwen_fa3_aot_fp8darow_nolast")
class Qwen_FA3_AoT_fp8darow_nolast(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        module_fqn_to_config = ModuleFqnToConfig(
            OrderedDict([
                (ATTN_LAST_LAYER, None),
                ("_default", Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())),
            ])
        )
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fp8darow_nolast_fa3",
quantize_config=module_fqn_to_config,
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
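
# Standalone helper mirroring the "nolast" recipe: quantizes a bare
# transformer in place and reports model size before and after.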
def quantize_transformer_fp8darow_nolast(model):
    """Quantize a transformer in place, leaving the last block unquantized."""
    module_fqn_to_config = ModuleFqnToConfig(
        OrderedDict([
            (ATTN_LAST_LAYER, None),
            # NOTE: despite the "row" in the name, per-row granularity is
            # currently disabled here in favor of the per-tensor default.
            # ("_default", Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())),
            ("_default", Float8DynamicActivationFloat8WeightConfig()),
        ])
    )
    print(f"original model size: {get_model_size_in_bytes(model) / 1024 / 1024:.1f} MB")
    quantize_(model, module_fqn_to_config)
    print_first_param(model)
    print(f"quantized model size: {get_model_size_in_bytes(model) / 1024 / 1024:.1f} MB")
@ExperimentRegistry.register(name="qwen_fa3_aot_fp8darow_nofirstlast")
class Qwen_FA3_AoT_fp8darow_nofirstlast(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        module_fqn_to_config = ModuleFqnToConfig(
            OrderedDict([
                (ATTN_LAST_LAYER, None),
                (ATTN_FIRST_LAYER, None),
                ("_default", Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())),
            ])
        )
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fp8darow_nofirstlast_fa3",
quantize_config=module_fqn_to_config,
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
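
# FP8 per-row for the transformer blocks, INT8 weight-only for the top-level
# img_in / txt_in / proj_out projections, last block skipped.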
@ExperimentRegistry.register(name="qwen_fa3_aot_fp8darow_nolast_cint8")
class Qwen_FA3_AoT_fp8darow_nolast_cint8(QwenBaseExperiment):
@ftimed
def optimize(self):
self.pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
        module_fqn_to_config = ModuleFqnToConfig(
            OrderedDict([
                (ATTN_LAST_LAYER, None),
                (IMG_IN_REGEX, Int8WeightOnlyConfig()),
                (TXT_IN_REGEX, Int8WeightOnlyConfig()),
                (PROJ_OUT_REGEX, Int8WeightOnlyConfig()),
                ("_default", Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())),
            ])
        )
optimize_pipeline_(
self.pipe,
cache_compiled=self.config.cache_compiled,
quantize=True,
suffix="_fp8darow_nolast_cint8_fa3",
quantize_config=module_fqn_to_config,
            pipe_kwargs={
                "image": [Image.new("RGB", (1024, 1024))],
                "prompt": "prompt",
                "num_inference_steps": 4,
            },
        )
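
# Minimal sketch (illustrative only, not one of the registered experiments):
# how ModuleFqnToConfig resolves per-module configs. Keys prefixed "re:" are
# matched as regexes against module FQNs, a None value skips quantization,
# and "_default" covers every module no earlier key matched. The toy model
# and its FQNs below are assumptions for demonstration, not project code.
def _demo_module_fqn_to_config():
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(64, 64),  # FQN "0": matched by the regex key below, skipped
        nn.Linear(64, 64),  # FQN "1": falls through to "_default"
    )
    config = ModuleFqnToConfig(
        OrderedDict([
            (r"re:^0$", None),                     # leave module "0" unquantized
            ("_default", Int8WeightOnlyConfig()),  # INT8 weight-only elsewhere
        ])
    )
    quantize_(model, config)
    return model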