Spaces:
Runtime error
Runtime error
| import torch | |
| import spaces | |
| import gradio as gr | |
| import time | |
| import re | |
| import random | |
| import os | |
| from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler | |
| import warnings | |
# Silence every Python warning category for the whole process (keeps the
# Space logs free of library deprecation noise).
warnings.simplefilter("ignore")

# ==================== 1. Resolution presets ====================
# Output-size presets grouped by base resolution ("1024", "1280", "1536").
# Every label has the form "WIDTHxHEIGHT (W:H)"; get_resolution() parses the two
# numbers back out, and the UI uses the first label of a group as its default.
RES_CHOICES = {
    base: [f"{w}x{h} ({ratio})" for w, h, ratio in sizes]
    for base, sizes in (
        ("1024", (
            (720, 1280, "9:16"),
            (1024, 1024, "1:1"),
            (1152, 896, "9:7"),
            (896, 1152, "7:9"),
            (1152, 864, "4:3"),
            (864, 1152, "3:4"),
            (1248, 832, "3:2"),
            (832, 1248, "2:3"),
            (1280, 720, "16:9"),
            (1344, 576, "21:9"),
            (576, 1344, "9:21"),
        )),
        ("1280", (
            (864, 1536, "9:16"),
            (1280, 1280, "1:1"),
            (1440, 1120, "9:7"),
            (1120, 1440, "7:9"),
            (1472, 1104, "4:3"),
            (1104, 1472, "3:4"),
            (1536, 1024, "3:2"),
            (1024, 1536, "2:3"),
            (1536, 864, "16:9"),
            (1680, 720, "21:9"),
            (720, 1680, "9:21"),
        )),
        ("1536", (
            (1152, 2048, "9:16"),
            (1536, 1536, "1:1"),
            (1728, 1344, "9:7"),
            (1344, 1728, "7:9"),
            (1728, 1296, "4:3"),
            (1296, 1728, "3:4"),
            (1872, 1248, "3:2"),
            (1248, 1872, "2:3"),
            (2048, 1152, "16:9"),
            (2016, 864, "21:9"),
            (864, 2016, "9:21"),
        )),
    )
}
def get_resolution(resolution_str):
    """Parse a "WIDTHxHEIGHT" label into (width, height) rounded down to multiples of 8.

    The separator may be "x" or "×", optionally surrounded by whitespace, and
    the numbers may appear anywhere in the string (e.g. "1024x1024 (1:1)").

    Args:
        resolution_str: Label to parse; may be empty or None.

    Returns:
        A (width, height) tuple of ints, each floored to a multiple of 8, or
        (1024, 1024) when the input is empty or contains no size.
    """
    fallback = (1024, 1024)
    found = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution_str) if resolution_str else None
    if found is None:
        return fallback
    w, h = (int(group) for group in found.groups())
    return (w // 8) * 8, (h // 8) * 8
# ==================== 2. Model loading & core optimizations ====================
print("🚀 Loading Z-Image-Turbo pipeline...")

# trust_remote_code=True lets the Hub repository supply its own pipeline /
# transformer classes. NOTE(review): the original author states this is
# required for the transformer to expose `set_attention_backend` — confirm.
pipe = DiffusionPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image-Turbo",
    use_safetensors=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Swap in FlowMatchEulerDiscreteScheduler with shift=3.0, reusing the loaded
# scheduler's config minus "algorithm_type". Best effort: on any failure the
# original scheduler stays in place and only a warning is printed.
try:
    scheduler_config = {
        key: value
        for key, value in pipe.scheduler.config.items()
        if key != "algorithm_type"
    }
    pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config, shift=3.0)
    print("✅ Scheduler optimized with shift=3.0")
except Exception as e:
    print(f"⚠️ Scheduler config warning: {e}")

# Move all pipeline weights to the GPU.
pipe.to("cuda")

print("Enabling torch.compile optimizations...")
# Process-wide TorchInductor flags; they only influence code that is compiled
# with torch.compile.
_inductor_cfg = torch._inductor.config
_inductor_cfg.conv_1x1_as_mm = True
_inductor_cfg.coordinate_descent_tuning = True
_inductor_cfg.coordinate_descent_check_all_directions = True
_inductor_cfg.epilogue_fusion = False
_inductor_cfg.max_autotune_gemm = True
_inductor_cfg.max_autotune_gemm_backends = "TRITON,ATEN"
_inductor_cfg.triton.cudagraphs = False
# Attention backends to try, fastest first (presumably diffusers attention-backend
# names). NOTE(review): the tier ordering reflects the original author's
# performance judgement — benchmark on the target GPU to confirm.
_ATTENTION_BACKEND_PREFERENCE = (
    # ===== S tier: current best =====
    "flash_varlen",  # FlashAttention-2 varlen: stable and fast
    "_flash_3_varlen_hub",  # FlashAttention-3 varlen (Hub kernel), strong on SM90
    "_flash_varlen_3",  # FlashAttention-3 varlen (local install)
    "_flash_3",  # FlashAttention-3, non-varlen
    "flash",  # FlashAttention-2, non-varlen
    # ===== A tier: acceptable / high-performance alternatives =====
    "flash_varlen_hub",
    "flash_hub",
    "xformers",  # mature, but somewhat slower than FlashAttention
    "_native_flash",
    # ===== B tier: framework-native / compatibility first =====
    "native",
    "_native_efficient",
    "_native_cudnn",
    # ===== C tier: device- or scenario-specific =====
    "flex",
    "_native_xla",
    "_native_npu",
    "aiter",
    # ===== D tier: SageAttention / experimental quantized kernels =====
    "sage",
    "sage_hub",
    "sage_varlen",
    "_sage_qk_int8_pv_fp16_cuda",
    "_sage_qk_int8_pv_fp16_triton",
    "_sage_qk_int8_pv_fp8_cuda",
    "_sage_qk_int8_pv_fp8_cuda_sm90",
    # ===== Fallback =====
    "_native_math",
)


def enable_best_attention_backend(pipeline, backends=None):
    """Enable the first attention backend the pipeline's transformer accepts.

    Candidates are passed in order to
    ``pipeline.transformer.set_attention_backend``; the first call that does
    not raise wins. If the transformer has no such method, or every candidate
    is rejected, ``pipeline.enable_xformers_memory_efficient_attention()`` is
    attempted as a best-effort fallback.

    Args:
        pipeline: Loaded pipeline object; its ``transformer`` attribute is
            configured.
        backends: Optional iterable of backend names in priority order.
            Defaults to ``_ATTENTION_BACKEND_PREFERENCE``.

    Returns:
        The name of the backend that was enabled, or ``None`` when none was
        (including the case where only the xFormers fallback was applied).
    """
    candidates = _ATTENTION_BACKEND_PREFERENCE if backends is None else backends
    # Look the method up once instead of letting every attempt raise
    # AttributeError when the custom transformer class was not loaded.
    transformer = getattr(pipeline, "transformer", None)
    set_backend = getattr(transformer, "set_attention_backend", None)

    if set_backend is None:
        print("⚠️ Warning: Transformer model does not support 'set_attention_backend'. Custom code might not be loaded.")
    else:
        last_error = None
        for backend in candidates:
            try:
                set_backend(backend)
            except Exception as exc:  # backend unavailable/unsupported here; try the next one
                last_error = exc
                continue
            print(f"✅ Attention backend set to: {backend}")
            return backend
        print(f"⚠️ Warning: no attention backend could be enabled (last error: {last_error}).")

    # Best-effort fallback: report failure instead of silently ignoring it.
    try:
        pipeline.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        print(f"⚠️ xFormers fallback unavailable: {exc}")
    else:
        print("✅ Standard xFormers enabled as fallback")
    return None
# Apply the attention-backend selection to the loaded pipeline.
enable_best_attention_backend(pipe)

# VAE memory optimization: sliced VAE decoding (per the diffusers docs, the
# batch is decoded in slices to lower peak memory). Best effort — report the
# reason instead of silently swallowing it (the bare `except:` also hid
# KeyboardInterrupt/SystemExit).
try:
    pipe.vae.enable_slicing()
except Exception as exc:
    print(f"⚠️ VAE slicing unavailable: {exc}")

# Disabled: uncomment to compile the transformer with torch.compile.
# print("Compiling transformer...")
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
# ==================== 3. Image generation ====================
# `spaces` is imported at the top of this file but was never used. On Hugging
# Face ZeroGPU hardware a GPU is only attached inside @spaces.GPU-decorated
# functions, and startup fails when no such function exists. Outside ZeroGPU
# the decorator is documented to have no effect.
@spaces.GPU
def generate_image(
    prompt,
    resolution_choice,
    use_custom_res,
    custom_width,
    custom_height,
    num_inference_steps,
    seed,
    randomize_seed,
    negative_prompt,
    gallery_history,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the global pipeline and prepend it to the gallery.

    Args:
        prompt: Text prompt; must have at least 2 characters after stripping.
        resolution_choice: Preset label such as "1024x1024 (1:1)", used when
            ``use_custom_res`` is false.
        use_custom_res: If true, use ``custom_width``/``custom_height`` instead.
        custom_width, custom_height: Custom size, floored to a multiple of 8.
        num_inference_steps: Number of denoising steps.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: If true, draw a fresh seed in [0, 2**32 - 1].
        negative_prompt: Optional negative prompt; blank/whitespace means none.
        gallery_history: Current gallery value (presumably a list of
            (image, caption) pairs) or None. It is not modified.
        progress: Gradio progress tracker that mirrors tqdm bars.

    Returns:
        (history, seed): a new list with the (image, caption) entry for this run
        first, followed by the previous entries, and the integer seed used.

    Raises:
        gr.Error: Input-validation errors are raised with their own message;
            any other failure is re-raised as "生成错误: <reason>".
    """
    history = list(gallery_history) if gallery_history else []
    try:
        if not prompt or len(prompt.strip()) < 2:
            raise gr.Error("请输入提示词 (Prompt)")
        prompt = prompt.strip()
        # Whitespace-only input counts as "no negative prompt".
        neg_prompt = (negative_prompt or "").strip() or None

        if use_custom_res:
            # Floor the custom size to a multiple of 8, matching get_resolution().
            width = int(custom_width) - int(custom_width) % 8
            height = int(custom_height) - int(custom_height) % 8
        else:
            width, height = get_resolution(resolution_choice)

        if randomize_seed:
            seed = random.randint(0, 2**32 - 1)
        seed = int(seed)
        steps = int(num_inference_steps)

        start_time = time.perf_counter()
        generator = torch.Generator("cuda").manual_seed(seed)
        # Release cached, unused CUDA memory before the run.
        torch.cuda.empty_cache()
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            image = pipe(
                prompt=prompt,
                height=height,
                width=width,
                num_inference_steps=steps,
                # NOTE(review): guidance_scale=0.0 presumably disables
                # classifier-free guidance, in which case negative_prompt has
                # no effect — confirm against the Z-Image pipeline.
                guidance_scale=0.0,
                generator=generator,
                negative_prompt=neg_prompt,
                max_sequence_length=512,
            ).images[0]
        gen_time = time.perf_counter() - start_time

        # Caption shown under the image in the history gallery.
        info_label = f"{width}x{height} | Steps: {steps} | Seed: {seed} | {gen_time:.2f}s"
        return [(image, info_label), *history], seed
    except gr.Error:
        # Already user-facing; do not wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"生成错误: {str(e)}") from e
# ==================== 4. UI styles ====================
# Custom stylesheet passed to gr.Blocks(css=css). It loads the Inter web font,
# renders the header title as orange gradient-filled text, styles elements with
# elem_classes="primary-btn", and gives "panel-container" groups a card look
# with a dark-theme (.dark) variant.
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600;800&display=swap');
body, .gradio-container { font-family: 'Inter', sans-serif !important; }
.header-container { text-align: center; margin-bottom: 20px; }
.header-title {
font-size: 2.5rem; font-weight: 800; margin: 0;
background: linear-gradient(135deg, #f59e0b, #ea580c);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
}
.header-subtitle { font-size: 1rem; color: #6b7280; font-weight: 500; }
.primary-btn {
background: linear-gradient(90deg, #f59e0b 0%, #d97706 100%) !important;
border: none !important;
color: white !important;
font-weight: 600 !important;
font-size: 1.1rem !important;
box-shadow: 0 4px 6px -1px rgba(245, 158, 11, 0.2) !important;
}
.primary-btn:hover { transform: translateY(-2px); box-shadow: 0 10px 15px -3px rgba(245, 158, 11, 0.3) !important; }
.panel-container {
background: #ffffff; border: 1px solid #e5e7eb; border-radius: 12px; padding: 15px;
}
.dark .panel-container { background: #1f2937; border-color: #374151; }
"""
# ==================== 5. Gradio interface ====================
# Two-column layout: controls on the left (prompt, resolution, settings) and a
# history gallery on the right. The gallery is also fed back into
# generate_image as an input, so each run prepends to the existing history.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange"), css=css, title="Z-Image-Turbo") as demo:
    gr.HTML("""
<div class="header-container">
<h1 class="header-title">⚡ Z-Image-Turbo</h1>
<p class="header-subtitle">Optimized Backend • 8 Steps • Gallery History</p>
</div>
""")
    with gr.Row():
        # --- Control panel ---
        with gr.Column(scale=4, min_width=320):
            with gr.Group(elem_classes="panel-container"):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Low quality, blurry...",
                    lines=1
                )
                generate_btn = gr.Button("🚀 Generate", elem_classes="primary-btn")
            with gr.Group(elem_classes="panel-container"):
                gr.Markdown("### 📐 Resolution")
                # Selecting a base resolution repopulates the ratio dropdown
                # (see update_resolution_list below).
                res_category = gr.Radio(
                    choices=["1024", "1280", "1536"],
                    value="1024",
                    label="Resolution Base",
                    container=False
                )
                resolution_dropdown = gr.Dropdown(
                    choices=RES_CHOICES["1024"],
                    value=RES_CHOICES["1024"][0],
                    label="Select Ratio",
                    show_label=False
                )
                with gr.Accordion("Custom Size", open=False):
                    use_custom_res = gr.Checkbox(label="Enable Custom", value=False)
                    # Hidden until "Enable Custom" is ticked.
                    with gr.Row(visible=False) as custom_res_row:
                        width_slider = gr.Slider(512, 1536, value=1024, step=64, label="W")
                        height_slider = gr.Slider(512, 1536, value=1024, step=64, label="H")
            with gr.Accordion("⚙️ Settings", open=False):
                with gr.Group(elem_classes="panel-container"):
                    steps_slider = gr.Slider(4, 20, value=8, step=1, label="Steps")
                    with gr.Row():
                        random_seed = gr.Checkbox(label="Random Seed", value=True)
                        # Hidden while "Random Seed" is checked (its initial state).
                        seed_input = gr.Number(label="Seed", value=42, visible=False, precision=0)
        # --- Gallery ---
        with gr.Column(scale=6, min_width=500):
            output_gallery = gr.Gallery(
                label="History",
                value=[],
                columns=[2],
                rows=[2],
                object_fit="contain",
                height="auto",
                show_share_button=True,
                show_download_button=True,
                interactive=False
            )
            with gr.Row():
                last_seed_display = gr.Textbox(label="Last Seed", interactive=False, scale=3)
                clear_btn = gr.Button("🗑️ Clear", scale=1, variant="secondary")

    # Event wiring
    def update_resolution_list(category: str):
        """Return a dropdown update listing the presets for `category`, defaulting to its first entry."""
        return gr.Dropdown(choices=RES_CHOICES[category], value=RES_CHOICES[category][0])

    res_category.change(update_resolution_list, inputs=res_category, outputs=resolution_dropdown)
    # Custom size on: show the W/H sliders and make the preset dropdown non-interactive.
    use_custom_res.change(
        lambda x: (gr.Row(visible=x), gr.Dropdown(interactive=not x)),
        inputs=use_custom_res, outputs=[custom_res_row, resolution_dropdown]
    )
    # Only show the manual seed field when random seeding is off.
    random_seed.change(lambda x: gr.Number(visible=not x), inputs=random_seed, outputs=seed_input)
    # Input order must match generate_image's positional parameters.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, resolution_dropdown, use_custom_res, width_slider, height_slider, steps_slider, seed_input, random_seed, negative_prompt, output_gallery],
        outputs=[output_gallery, last_seed_display]
    )
    # Reset both the history gallery and the last-seed display.
    clear_btn.click(lambda: ([], ""), outputs=[output_gallery, last_seed_display])

if __name__ == "__main__":
    demo.launch()