| from typing import Literal, Optional, TYPE_CHECKING |
|
|
| import numpy as np |
| from fastapi import FastAPI |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from pydantic import BaseModel, Field |
| from platform import system |
|
|
| if TYPE_CHECKING: |
| from flux_pipeline import FluxPipeline |
|
|
# Exclusive upper bound for seed values: 2**16 - 1 when platform.system()
# reports "Windows", 2**32 - 1 on every other platform.
# NOTE(review): presumably the smaller Windows bound works around a narrower
# default NumPy integer type there -- confirm before changing.
MAX_RAND = (2**16 if system() == "Windows" else 2**32) - 1
|
|
|
|
class AppState:
    """Typing-only description of the attributes stored on the app's ``state``.

    Only declares the attribute; no value is assigned here, so ``model`` must
    be set elsewhere before the endpoints that read it are called.
    """

    # Forward reference: FluxPipeline is imported only under TYPE_CHECKING.
    model: "FluxPipeline"
|
|
|
|
class FastAPIApp(FastAPI):
    """FastAPI subclass that re-annotates ``state`` as ``AppState``.

    Adds no behaviour of its own; it only narrows the declared type of
    ``state`` so that accesses such as ``app.state.model`` are typed.
    """

    state: AppState
|
|
|
|
# Request body for the LoRA endpoint. Every field is optional and carries a
# default, so an empty JSON object is accepted and means "load" at scale 1.0
# with no path or name.
class LoraArgs(BaseModel):
    # Scaling factor passed along when loading the LoRA.
    scale: Optional[float] = Field(default=1.0)
    # Checkpoint path; also used as the unload key when `name` is falsy.
    path: Optional[str] = Field(default=None)
    # Name for the LoRA; preferred over `path` as the unload key.
    name: Optional[str] = Field(default=None)
    # Which operation to run; an explicit null is accepted by this type.
    action: Optional[Literal["load", "unload"]] = Field(default="load")
|
|
|
|
# Response schema declared for the LoRA endpoint.
class LoraLoadResponse(BaseModel):
    # Outcome of the requested action.
    status: Literal["success", "error"]
    # Optional human-readable detail; defaults to None.
    message: Optional[str] = Field(default=None)
|
|
|
|
# Request body for the /generate endpoint. Only `prompt` is required; every
# other field falls back to the default declared below.
class GenerateArgs(BaseModel):
    # Text prompt used for image generation.
    prompt: str
    # Output image size.
    width: Optional[int] = Field(default=720)
    height: Optional[int] = Field(default=1024)
    # Number of steps for the image generation.
    num_steps: Optional[int] = Field(default=24)
    # Guidance value for the image generation.
    guidance: Optional[float] = Field(default=3.5)
    # Client-supplied seeds must satisfy 0 < seed < MAX_RAND. When the field is
    # omitted a random seed is drawn from that same range: the lower bound is 1
    # (randint's low is inclusive, and 0 would contradict gt=0) and the upper
    # bound is exclusive. int() turns the NumPy scalar into a builtin int so
    # the dumped arguments contain only plain Python values.
    seed: Optional[int] = Field(
        default_factory=lambda: int(np.random.randint(1, MAX_RAND)),
        gt=0,
        lt=MAX_RAND,
    )
    # Strength for image generation, documented by the endpoint as 0.0 - 1.0.
    strength: Optional[float] = 1.0
    # Base64 encoded image, or a path to an image, to use as the init image.
    init_image: Optional[str] = None
|
|
|
|
# Module-level application instance; the route decorators below register
# their handlers on it and the handlers read the model from its `state`.
app = FastAPIApp()
|
|
|
|
@app.post("/generate")
def generate(args: GenerateArgs):
    """
    Generates an image from the Flux flow transformer.

    Args:
        args (GenerateArgs): Arguments for image generation:

            - `prompt`: The prompt used for image generation.

            - `width`: The width of the image.

            - `height`: The height of the image.

            - `num_steps`: The number of steps for the image generation.

            - `guidance`: The guidance for image generation, represents the
                influence of the prompt on the image generation.

            - `seed`: The seed for the image generation.

            - `strength`: strength for image generation, 0.0 - 1.0.
                Represents the percent of diffusion steps to run,
                setting the init_image as the noised latent at the
                given number of steps.

            - `init_image`: Base64 encoded image or path to image to use as the init image.

    Returns:
        StreamingResponse: The generated image as streaming jpeg bytes.
    """
    # NOTE: the docstring wording is kept as-is because FastAPI exposes a
    # handler's docstring as the endpoint description in the generated docs.

    # Every validated field (defaults included) is forwarded as a keyword
    # argument, so the model's generate() must accept all of these names.
    generation_kwargs = args.model_dump()
    pipeline = app.state.model
    image_stream = pipeline.generate(**generation_kwargs)
    return StreamingResponse(image_stream, media_type="image/jpeg")
|
|
|
|
@app.post("/lora", response_model=LoraLoadResponse)
def lora_action(args: LoraArgs):
    """
    Loads or unloads a LoRA checkpoint into / from the Flux flow transformer.

    Args:
        args (LoraArgs): Arguments for the LoRA action:

            - `scale`: The scaling factor for the LoRA weights.
            - `path`: The path to the LoRA checkpoint.
            - `name`: The name of the LoRA checkpoint.
            - `action`: The action to perform, either "load" or "unload".

    Returns:
        LoraLoadResponse: The status of the LoRA action.
    """
    requested_action = args.action

    # Guard clause: anything other than the two known actions (including an
    # explicit null) is rejected with a 400 before the model is touched.
    if requested_action not in ("load", "unload"):
        return JSONResponse(
            status_code=400,
            content={
                "status": "error",
                "message": f"Invalid action, expected 'load' or 'unload', got {args.action}",
            },
        )

    pipeline = app.state.model
    try:
        if requested_action == "load":
            pipeline.load_lora(args.path, args.scale, args.name)
        else:
            # Unload is keyed by name when one was given, otherwise by path.
            pipeline.unload_lora(args.name or args.path)
    except Exception as e:
        # Endpoint boundary: any failure from the model is reported to the
        # client as a 500 carrying the exception text.
        return JSONResponse(
            status_code=500, content={"status": "error", "message": str(e)}
        )

    return JSONResponse(status_code=200, content={"status": "success"})
|
|