|
|
| from datetime import datetime |
|
|
| import gradio as gr |
| import spaces |
| import torch |
| from diffusers import FluxPipeline |
| from fa3 import FlashFluxAttnProcessor3_0 |
| from aoti import aoti_load_ |
|
|
| |
# Module-level model setup: runs once at import time so that every request
# served by `generate_image` reuses the same loaded pipeline.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Use the `dtype` constant rather than repeating the literal, so the weight
# precision is controlled from a single place.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/Flux.1-Dev", torch_dtype=dtype
).to(device)

# Replace the transformer's attention processors with the one provided by the
# local `fa3` module.
pipeline.transformer.set_attn_processor(FlashFluxAttnProcessor3_0())

# NOTE(review): presumably loads an ahead-of-time compiled artifact
# ("flux-dev-aot.pt2" from the "zerogpu-aoti/flux-dev-aot" repo) into the
# transformer in place -- confirm against the `aoti` module.
aoti_load_(pipeline.transformer, "zerogpu-aoti/flux-dev-aot", "flux-dev-aot.pt2")
|
|
|
|
@spaces.GPU
def generate_image(prompt: str, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for *prompt* and report how long generation took.

    Args:
        prompt: Text prompt forwarded to the module-level ``pipeline``.
        progress: Gradio progress tracker created with ``track_tqdm=True``.
            It is not referenced in the body; presumably its presence in the
            signature is what lets Gradio mirror tqdm progress -- confirm
            against the Gradio documentation.

    Returns:
        A one-element list holding an ``(image, caption)`` tuple, where
        ``image`` is ``output.images[0]`` and ``caption`` is the elapsed
        time formatted as seconds with two decimals (for example
        ``"3.21s"``).
    """
    # Function-scope import keeps this block self-contained; perf_counter is
    # monotonic, so the measurement is unaffected by wall-clock adjustments.
    from time import perf_counter

    # Create the generator on the same device the pipeline was moved to.
    # A hard-coded "cuda" fails on hosts where `device` fell back to "cpu".
    # The fixed seed makes repeated runs of the same prompt reproducible.
    generator = torch.Generator(device=device).manual_seed(42)

    started = perf_counter()
    output = pipeline(
        prompt=prompt,
        num_inference_steps=28,
        generator=generator,
    )
    elapsed = perf_counter() - started

    return [(output.images[0], f"{elapsed:.2f}s")]
|
|
|
|
# Build the UI: a single text box feeding `generate_image`, whose result is
# shown in a gallery. Example caching is turned off via `cache_examples`.
prompt_box = gr.Text(label="Prompt")
result_gallery = gr.Gallery()
example_prompts = ["A cat playing with a ball of yarn"]

demo = gr.Interface(
    fn=generate_image,
    inputs=prompt_box,
    outputs=result_gallery,
    examples=example_prompts,
    cache_examples=False,
)
demo.launch()