# Z-Image-Turbo / app.py
# (Scraped page header: "lenML's picture" · "Update app.py" · commit babc3bc, verified)
import torch
import spaces
import gradio as gr
import time
import re
import random
import os
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
import warnings
# Silence every Python warning (library deprecation notices included) so the
# Space logs stay readable. Trade-off: genuinely useful warnings are hidden too.
warnings.simplefilter("ignore")
# ==================== 1. Resolution presets ====================
# Preset output sizes grouped by base resolution, as (width, height, aspect label).
# They are rendered into "WIDTHxHEIGHT (W:H)" strings for the UI dropdown, which is
# the format get_resolution() parses back. Every dimension listed is a multiple of 8.
_RESOLUTION_PRESETS = {
    "1024": (
        (720, 1280, "9:16"),
        (1024, 1024, "1:1"),
        (1152, 896, "9:7"),
        (896, 1152, "7:9"),
        (1152, 864, "4:3"),
        (864, 1152, "3:4"),
        (1248, 832, "3:2"),
        (832, 1248, "2:3"),
        (1280, 720, "16:9"),
        (1344, 576, "21:9"),
        (576, 1344, "9:21"),
    ),
    "1280": (
        (864, 1536, "9:16"),
        (1280, 1280, "1:1"),
        (1440, 1120, "9:7"),
        (1120, 1440, "7:9"),
        (1472, 1104, "4:3"),
        (1104, 1472, "3:4"),
        (1536, 1024, "3:2"),
        (1024, 1536, "2:3"),
        (1536, 864, "16:9"),
        (1680, 720, "21:9"),
        (720, 1680, "9:21"),
    ),
    "1536": (
        (1152, 2048, "9:16"),
        (1536, 1536, "1:1"),
        (1728, 1344, "9:7"),
        (1344, 1728, "7:9"),
        (1728, 1296, "4:3"),
        (1296, 1728, "3:4"),
        (1872, 1248, "3:2"),
        (1248, 1872, "2:3"),
        (2048, 1152, "16:9"),
        (2016, 864, "21:9"),
        (864, 2016, "9:21"),
    ),
}

# Base resolution -> list of dropdown labels, in the order listed above.
RES_CHOICES = {
    base: [f"{w}x{h} ({ratio})" for w, h, ratio in sizes]
    for base, sizes in _RESOLUTION_PRESETS.items()
}
# "WIDTH x HEIGHT" (either "x" or "×", optional surrounding spaces) anywhere in a label.
_SIZE_PATTERN = re.compile(r"(\d+)\s*[×x]\s*(\d+)")


def get_resolution(resolution_str):
    """Parse a label such as "1024x1024 (1:1)" into a (width, height) tuple.

    Both dimensions are rounded down to the nearest multiple of 8. Returns
    (1024, 1024) when the input is empty/None or contains no WIDTHxHEIGHT pair.
    """
    fallback = (1024, 1024)
    if not resolution_str:
        return fallback
    found = _SIZE_PATTERN.search(resolution_str)
    if found is None:
        return fallback
    w, h = (int(group) for group in found.groups())
    # Captured groups are non-negative, so floor-division rounding equals `n - n % 8`.
    return (w // 8) * 8, (h // 8) * 8
# ==================== 2. Model loading and core optimizations ====================
print("🚀 Loading Z-Image-Turbo pipeline...")
# trust_remote_code=True lets from_pretrained execute pipeline/model code shipped
# with the Hub repo. Per the original author this is required so the Z-Image
# custom Pipeline/Transformer classes (which expose set_attention_backend) load —
# NOTE(review): confirm against the installed diffusers version.
_pipeline_load_options = {
    "torch_dtype": torch.bfloat16,
    "low_cpu_mem_usage": True,
    "use_safetensors": True,
    "trust_remote_code": True,
}
pipe = DiffusionPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", **_pipeline_load_options)
# Rebuild the scheduler as FlowMatchEulerDiscreteScheduler with shift=3.0, reusing
# the current scheduler's config. Best-effort: if anything fails, the pipeline's
# existing scheduler is left in place and only a warning is printed.
try:
    _scheduler_cfg = {
        key: value
        for key, value in pipe.scheduler.config.items()
        # Dropped before rebuilding; presumably not a FlowMatch option — TODO confirm.
        if key != "algorithm_type"
    }
    pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(_scheduler_cfg, shift=3.0)
    print("✅ Scheduler optimized with shift=3.0")
except Exception as exc:
    print(f"⚠️ Scheduler config warning: {exc}")
# Move the pipeline to the GPU ("cuda" device).
pipe.to("cuda")
print("Enabling torch.compile optimizations...")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.max_autotune_gemm = True
torch._inductor.config.max_autotune_gemm_backends = "TRITON,ATEN"
torch._inductor.config.triton.cudagraphs = False
# Attention backends to try, highest priority first. The first name accepted by
# transformer.set_attention_backend() wins.
# NOTE(review): if "native" is accepted on every setup, every entry after it is
# effectively unreachable — confirm whether that ordering is intended.
_ATTENTION_BACKEND_PRIORITY = (
    # ===== S tier: best current options =====
    "flash_varlen",  # FlashAttention v2, varlen: stable and fast
    "_flash_3_varlen_hub",  # FlashAttention v3, varlen (Hub kernel); very strong on SM90
    "_flash_varlen_3",  # FlashAttention v3, varlen (local install)
    "_flash_3",  # FlashAttention v3, non-varlen
    "flash",  # FlashAttention v2, non-varlen
    # ===== A tier: acceptable / high-performance alternatives =====
    "flash_varlen_hub",
    "flash_hub",
    "xformers",  # mature, but slightly slower than FlashAttention
    "_native_flash",
    # ===== B tier: framework-native / compatibility first =====
    "native",
    "_native_efficient",
    "_native_cudnn",
    # ===== C tier: device-specific / limited scenarios =====
    "flex",
    "_native_xla",
    "_native_npu",
    "aiter",
    # ===== D tier: Sage / experimental quantized kernels =====
    "sage",
    "sage_hub",
    "sage_varlen",
    "_sage_qk_int8_pv_fp16_cuda",
    "_sage_qk_int8_pv_fp16_triton",
    "_sage_qk_int8_pv_fp8_cuda",
    "_sage_qk_int8_pv_fp8_cuda_sm90",
    # ===== Fallback =====
    "_native_math",
)


def enable_best_attention_backend(pipeline, backends=None):
    """Enable the first attention backend the pipeline's transformer accepts.

    Each name in ``backends`` is passed to
    ``pipeline.transformer.set_attention_backend``. Any exception is treated as
    "not available here", and the next candidate is tried. If the transformer has
    no ``set_attention_backend`` method, or no candidate is accepted, the legacy
    ``pipeline.enable_xformers_memory_efficient_attention()`` is tried last.

    Args:
        pipeline: Object with a ``transformer`` attribute (e.g. a diffusers pipeline).
        backends: Optional iterable of backend names, highest priority first.
            Defaults to ``_ATTENTION_BACKEND_PRIORITY``.

    Returns:
        The backend name that was enabled. Returns ``"xformers"`` if only the legacy
        fallback succeeded, or ``None`` if nothing could be enabled.
    """
    candidates = _ATTENTION_BACKEND_PRIORITY if backends is None else backends
    transformer = getattr(pipeline, "transformer", None)
    set_backend = getattr(transformer, "set_attention_backend", None)

    if set_backend is None:
        # Check once instead of failing with AttributeError for every candidate.
        print("⚠️ Warning: Transformer model does not support 'set_attention_backend'. Custom code might not be loaded.")
    else:
        for backend in candidates:
            try:
                set_backend(backend)
            except Exception:
                # Unknown name, missing kernel package or unsupported device: try the next one.
                continue
            print(f"✅ Attention backend set to: {backend}")
            return backend
        print("⚠️ Warning: none of the requested attention backends could be enabled.")

    # Last resort: the classic pipeline-level xFormers switch (best-effort).
    try:
        pipeline.enable_xformers_memory_efficient_attention()
    except Exception:
        return None
    print("✅ Standard xFormers enabled as fallback")
    return "xformers"
# Pick and enable the best available attention backend for the loaded pipeline.
enable_best_attention_backend(pipe)

# VAE memory optimization: sliced VAE decoding. Best-effort — if this VAE does not
# support it, log why and continue.
# NOTE(review): slicing splits the decoded batch, so with one image per call the
# saving is likely small — confirm it is still wanted.
try:
    pipe.vae.enable_slicing()
except Exception as e:
    print(f"⚠️ VAE slicing not enabled: {e}")

# Optional (currently disabled): compile the transformer with torch.compile.
# print("Compiling transformer...")
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
# ==================== 3. Generation logic ====================
@spaces.GPU
def generate_image(
    prompt,
    resolution_choice,
    use_custom_res,
    custom_width,
    custom_height,
    num_inference_steps,
    seed,
    randomize_seed,
    negative_prompt,
    gallery_history,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image and prepend it to the gallery history.

    Args:
        prompt: Text prompt. It must be at least 2 characters long after stripping whitespace.
        resolution_choice: Preset label such as "1024x1024 (1:1)". It is used when
            ``use_custom_res`` is false.
        use_custom_res: When true, ``custom_width``/``custom_height`` are used instead.
        custom_width: Custom width, rounded down to a multiple of 8.
        custom_height: Custom height, rounded down to a multiple of 8.
        num_inference_steps: Number of denoising steps (converted to int).
        seed: Seed used when ``randomize_seed`` is false. If it is empty (None),
            a random seed is drawn instead of failing.
        randomize_seed: Draw a fresh random seed in [0, 2**32 - 1].
        negative_prompt: Optional negative prompt. Blank or whitespace-only becomes None.
            NOTE(review): guidance_scale is fixed at 0.0, so the pipeline may
            ignore it — confirm.
        gallery_history: Current gallery items (list of (image, caption)) or None.
        progress: Gradio progress tracker; with track_tqdm=True it mirrors tqdm bars.

    Returns:
        A tuple ``(history, seed)``: a new history list with the
        ``(image, info_label)`` entry first, and the integer seed actually used.

    Raises:
        gr.Error: For an invalid prompt (shown as-is) or for any failure during
            generation (prefixed with "生成错误: ").
    """
    history = [] if gallery_history is None else list(gallery_history)
    try:
        if not prompt or len(prompt.strip()) < 2:
            raise gr.Error("请输入提示词 (Prompt)")
        prompt = prompt.strip()
        neg_prompt = (negative_prompt or "").strip() or None

        if use_custom_res:
            width = int(custom_width) - int(custom_width) % 8
            height = int(custom_height) - int(custom_height) % 8
        else:
            width, height = get_resolution(resolution_choice)

        if randomize_seed or seed is None:
            seed = random.randint(0, 2**32 - 1)
        seed = int(seed)
        steps = int(num_inference_steps)

        start_time = time.perf_counter()
        generator = torch.Generator("cuda").manual_seed(seed)
        # Release cached-but-unused CUDA memory so the pipeline has maximum headroom.
        torch.cuda.empty_cache()
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                guidance_scale=0.0,
                generator=generator,
                negative_prompt=neg_prompt,
                max_sequence_length=512,
            ).images[0]
        gen_time = time.perf_counter() - start_time

        # Caption shown under the image in the gallery history.
        info_label = f"{width}x{height} | Steps: {steps} | Seed: {seed} | {gen_time:.2f}s"
        return [(image, info_label)] + history, seed
    except gr.Error:
        # Already a user-facing message (e.g. the prompt check): don't wrap it again.
        raise
    except Exception as e:
        raise gr.Error(f"生成错误: {e}") from e
# ==================== 4. UI styles ====================
# Custom stylesheet passed to gr.Blocks(css=...): Inter web font, gradient header
# title text, orange gradient ".primary-btn" with a hover lift, and card-like
# ".panel-container" groups with a dark-mode variant under ".dark".
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
body, .gradio-container { font-family: 'Inter', sans-serif !important; }
.header-container { text-align: center; margin-bottom: 20px; }
.header-title {
font-size: 2.5rem; font-weight: 800; margin: 0;
background: linear-gradient(135deg, #f59e0b, #ea580c);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
}
.header-subtitle { font-size: 1rem; color: #6b7280; font-weight: 500; }
.primary-btn {
background: linear-gradient(90deg, #f59e0b 0%, #d97706 100%) !important;
border: none !important;
color: white !important;
font-weight: 600 !important;
font-size: 1.1rem !important;
box-shadow: 0 4px 6px -1px rgba(245, 158, 11, 0.2) !important;
}
.primary-btn:hover { transform: translateY(-2px); box-shadow: 0 10px 15px -3px rgba(245, 158, 11, 0.3) !important; }
.panel-container {
background: #ffffff; border: 1px solid #e5e7eb; border-radius: 12px; padding: 15px;
}
.dark .panel-container { background: #1f2937; border-color: #374151; }
"""
# ==================== 5. Gradio interface ====================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"), css=css, title="Z-Image-Turbo") as demo:
    # Page header; the markup relies on the .header-* classes defined in `css`.
    gr.HTML(
        "\n"
        '<div class="header-container">\n'
        '<h1 class="header-title">⚡ Z-Image-Turbo</h1>\n'
        '<p class="header-subtitle">Optimized Backend • 8 Steps • Gallery History</p>\n'
        "</div>\n"
    )

    with gr.Row():
        # ---- Control panel (left column) ----
        with gr.Column(scale=4, min_width=320):
            # Prompt inputs and the generate button.
            with gr.Group(elem_classes="panel-container"):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Low quality, blurry...",
                    lines=1,
                )
                generate_btn = gr.Button("🚀 Generate", elem_classes="primary-btn")

            # Resolution: base tier -> aspect-ratio preset, or a custom W/H.
            with gr.Group(elem_classes="panel-container"):
                gr.Markdown("### 📐 Resolution")
                res_category = gr.Radio(
                    choices=["1024", "1280", "1536"],
                    value="1024",
                    label="Resolution Base",
                    container=False,
                )
                initial_presets = RES_CHOICES["1024"]
                resolution_dropdown = gr.Dropdown(
                    choices=initial_presets,
                    value=initial_presets[0],
                    label="Select Ratio",
                    show_label=False,
                )
                with gr.Accordion("Custom Size", open=False):
                    use_custom_res = gr.Checkbox(label="Enable Custom", value=False)
                    # Hidden until "Enable Custom" is ticked.
                    with gr.Row(visible=False) as custom_res_row:
                        width_slider = gr.Slider(512, 1536, value=1024, step=64, label="W")
                        height_slider = gr.Slider(512, 1536, value=1024, step=64, label="H")

            # Sampler settings.
            with gr.Accordion("⚙️ Settings", open=False):
                with gr.Group(elem_classes="panel-container"):
                    steps_slider = gr.Slider(4, 20, value=8, step=1, label="Steps")
                    with gr.Row():
                        random_seed = gr.Checkbox(label="Random Seed", value=True)
                        # Only shown while "Random Seed" is unticked.
                        seed_input = gr.Number(label="Seed", value=42, visible=False, precision=0)

        # ---- Gallery (right column) ----
        with gr.Column(scale=6, min_width=500):
            output_gallery = gr.Gallery(
                label="History",
                value=[],
                columns=[2],
                rows=[2],
                object_fit="contain",
                height="auto",
                show_share_button=True,
                show_download_button=True,
                interactive=False,
            )
            with gr.Row():
                last_seed_display = gr.Textbox(label="Last Seed", interactive=False, scale=3)
                clear_btn = gr.Button("🗑️ Clear", scale=1, variant="secondary")

    # ---- Event wiring ----
    # Handler names are kept unchanged (lambdas stay lambdas) because Gradio derives
    # default API endpoint names from function names.
    def update_resolution_list(category):
        # Swap the preset dropdown to the presets of the chosen base resolution.
        presets = RES_CHOICES[category]
        return gr.Dropdown(choices=presets, value=presets[0])

    res_category.change(update_resolution_list, inputs=res_category, outputs=resolution_dropdown)
    # Custom size on -> show the W/H sliders and lock the preset dropdown.
    use_custom_res.change(
        lambda x: (gr.Row(visible=x), gr.Dropdown(interactive=not x)),
        inputs=use_custom_res,
        outputs=[custom_res_row, resolution_dropdown],
    )
    random_seed.change(lambda x: gr.Number(visible=not x), inputs=random_seed, outputs=seed_input)

    # Order must match generate_image's positional parameters.
    generation_inputs = [
        prompt,
        resolution_dropdown,
        use_custom_res,
        width_slider,
        height_slider,
        steps_slider,
        seed_input,
        random_seed,
        negative_prompt,
        output_gallery,
    ]
    generate_btn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=[output_gallery, last_seed_display],
    )
    clear_btn.click(lambda: ([], ""), outputs=[output_gallery, last_seed_display])
def main():
    """Start the Gradio server for the `demo` interface."""
    demo.launch()


if __name__ == "__main__":
    main()