| | import os |
| | import io |
| | import logging |
| | from typing import Optional, Dict, Any, Union |
| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | from PIL import Image |
| | import base64 |
| |
|
| | |
| | from handler import EndpointHandler |
| |
|
| | |
# Configure root logging once for the whole service: INFO and above, with
# timestamp, logger name and level in front of every message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the statements and route handlers in this file.
logger = logging.getLogger(__name__)
| |
|
| | |
# The ASGI application object; the route decorators in this file register
# their endpoints on it.
app = FastAPI(
    title="diffsketcher_edit API",
    description="API for diffsketcher_edit text-to-SVG generation",
)
| |
|
| | |
# Location of the model weights; overridable through the MODEL_DIR
# environment variable, otherwise the in-container default path is used.
model_dir = os.getenv("MODEL_DIR", "/code/model_weights")

# Built once at import time; the POST "/" route calls this same instance for
# every request.
handler = EndpointHandler(model_dir)
logger.info(f"Initialized handler with model_dir: {model_dir}")
| |
|
# Accepted shapes for the ``inputs`` field: a bare string, or a dict with
# string keys and arbitrary values. The member order of the Union is kept
# as-is because validation may try the members in order.
_InputsPayload = Union[str, Dict[str, Any]]


class TextToImageRequest(BaseModel):
    # Request body for POST "/". The route converts it to a plain dict and
    # passes that dict to the handler.
    inputs: _InputsPayload
| |
|
@app.post("/")
async def generate_image(request: TextToImageRequest):
    """Run the model handler on the request and return a base64-encoded PNG.

    The validated request body is converted to a plain dict and passed to the
    module-level ``handler``. The handler's result is written out as PNG via
    its ``save`` method (presumably a PIL image -- TODO confirm against
    ``EndpointHandler``) and the PNG bytes are returned base64-encoded.

    Args:
        request: Request body whose ``inputs`` field is a string or a dict.

    Returns:
        A dict ``{"image": <base64 text of the PNG bytes>}``.

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` for
            any unexpected failure. An ``HTTPException`` raised while the
            request is being processed is propagated unchanged.
    """
    try:
        logger.info("Received request: %s", request)

        # NOTE(review): this call is synchronous inside a coroutine, so the
        # event loop is blocked while the handler runs. Offloading it to a
        # worker thread would change how requests overlap; confirm the
        # handler tolerates concurrent calls before doing so.
        # NOTE(review): ``.dict()`` is the pydantic v1 spelling and is
        # deprecated under pydantic v2 in favour of ``model_dump()`` --
        # confirm the installed version before switching.
        image = handler(request.dict())

        # Serialise to PNG in memory, then base64-encode so the bytes can be
        # carried inside a JSON response.
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    except HTTPException:
        # A deliberately chosen status code must not be re-wrapped as a 500.
        raise
    except Exception as e:
        # ``logger.exception`` records the traceback along with the message.
        logger.exception("Error processing request: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"image": encoded}
| |
|
@app.get("/health")
async def health_check():
    """Answer health probes with a fixed ``{"status": "ok"}`` payload."""
    status_payload = {"status": "ok"}
    return status_payload
| |
|
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is executed
    # directly as a script.
    from uvicorn import run as run_server

    # Listen on all interfaces, port 8000.
    run_server(app, host="0.0.0.0", port=8000)
| |
|