""" """ import os from typing import Any from typing import Callable from typing import ParamSpec from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights from torchao.core.config import AOBaseConfig from torchao.quantization import ( quantize_, Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig, PerRow, ) import spaces import torch from torch.utils._pytree import tree_map from torchao.utils import get_model_size_in_bytes from qwenimage.datamodels import QuantOptions from qwenimage.debug import ftimed, print_first_param P = ParamSpec('P') TRANSFORMER_IMAGE_SEQ_LENGTH_DIM = torch.export.Dim('image_seq_length') TRANSFORMER_TEXT_SEQ_LENGTH_DIM = torch.export.Dim('text_seq_length') TRANSFORMER_DYNAMIC_SHAPES = { 'hidden_states': { 1: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM, }, 'encoder_hidden_states': { 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM, }, 'encoder_hidden_states_mask': { 1: TRANSFORMER_TEXT_SEQ_LENGTH_DIM, }, 'image_rotary_emb': ({ 0: TRANSFORMER_IMAGE_SEQ_LENGTH_DIM, }, { 0: TRANSFORMER_TEXT_SEQ_LENGTH_DIM, }), } INDUCTOR_CONFIGS = { 'conv_1x1_as_mm': True, 'epilogue_fusion': False, 'coordinate_descent_tuning': True, 'coordinate_descent_check_all_directions': True, 'max_autotune': True, 'triton.cudagraphs': True, } def aoti_apply( compiled: ZeroGPUCompiledModel, module: torch.nn.Module, call_method: str = 'forward', ): setattr(module, call_method, compiled) drain_module_parameters(module) def drain_module_parameters(module: torch.nn.Module): state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()} state_dict = {name: torch.nn.Parameter(torch.empty(tensor.size(), device='cpu')) for name, tensor in module.state_dict().items()} module.load_state_dict(state_dict, assign=True) for name, param in state_dict.items(): meta = state_dict_meta[name] param.data = torch.Tensor([]).to(**meta) # @ftimed # @spaces.GPU(duration=1500) def optimize_pipeline_( pipeline: Callable[P, Any], cache_compiled=True, quantize=True, quantize_config:AOBaseConfig=None, inductor_config=None, suffix="", pipe_kwargs={} ): if quantize_config is not None: transformer_pt2_cache_path = f"checkpoints/transformer{suffix}_archive.pt2" transformer_weights_cache_path = f"checkpoints/transformer{suffix}_weights.pt" print(f"original model size: {get_model_size_in_bytes(pipeline.transformer) / 1024 / 1024} MB") quantize_(pipeline.transformer, quantize_config) print_first_param(pipeline.transformer) print(f"quantized model size: {get_model_size_in_bytes(pipeline.transformer) / 1024 / 1024} MB") elif quantize: transformer_pt2_cache_path = f"checkpoints/transformer_int8{suffix}_archive.pt2" transformer_weights_cache_path = f"checkpoints/transformer_int8{suffix}_weights.pt" print(f"original model size: {get_model_size_in_bytes(pipeline.transformer) / 1024 / 1024} MB") quantize_(pipeline.transformer, Int8WeightOnlyConfig()) print_first_param(pipeline.transformer) print(f"quantized model size: {get_model_size_in_bytes(pipeline.transformer) / 1024 / 1024} MB") else: transformer_pt2_cache_path = f"checkpoints/transformer{suffix}_archive.pt2" transformer_weights_cache_path = f"checkpoints/transformer{suffix}_weights.pt" if inductor_config is None: inductor_config = INDUCTOR_CONFIGS if os.path.isfile(transformer_pt2_cache_path) and cache_compiled: drain_module_parameters(pipeline.transformer) zerogpu_weights = torch.load(transformer_weights_cache_path, weights_only=False) compiled_transformer = 
ZeroGPUCompiledModel(transformer_pt2_cache_path, zerogpu_weights) else: with spaces.aoti_capture(pipeline.transformer) as call: pipeline(**pipe_kwargs) dynamic_shapes = tree_map(lambda t: None, call.kwargs) dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES exported = torch.export.export( mod=pipeline.transformer, args=call.args, kwargs=call.kwargs, dynamic_shapes=dynamic_shapes, ) compiled_transformer = spaces.aoti_compile(exported, inductor_config) with open(transformer_pt2_cache_path, "wb") as f: f.write(compiled_transformer.archive_file.getvalue()) torch.save(compiled_transformer.weights, transformer_weights_cache_path) aoti_apply(compiled_transformer, pipeline.transformer) def simple_quantize_model(model, quant_option: QuantOptions): if quant_option == QuantOptions.INT8WO: aoconfig = Int8WeightOnlyConfig() elif quant_option == QuantOptions.INT4WO: aoconfig = Int4WeightOnlyConfig() elif quant_option == QuantOptions.FP8ROW: aoconfig = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) else: raise ValueError() print(f"original model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB") quantize_(model, aoconfig) print_first_param(model) print(f"quantized model size: {get_model_size_in_bytes(model) / 1024 / 1024} MB") return model
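

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how the helpers above fit together.
# `DiffusionPipeline`, the "Qwen/Qwen-Image" model id, and the capture prompt
# below are assumptions made for this example, not part of the module;
# substitute whatever pipeline this repo actually constructs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import DiffusionPipeline  # assumed dependency for the demo

    pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Quantize to fp8 with per-row scales, export + AOT-compile the
    # transformer (or load a cached archive), and swap the compiled forward
    # into the pipeline in place.
    optimize_pipeline_(
        pipe,
        quantize_config=Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
        suffix="_fp8row",
        pipe_kwargs={"prompt": "a cat", "num_inference_steps": 1},
    )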