import asyncio
import base64
import io as python_io
import logging
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path
| |
|
| | from fastapi import FastAPI, UploadFile, File |
| | from fastapi.responses import StreamingResponse, JSONResponse |
| | from fastapi.middleware.cors import CORSMiddleware |
| | import torch |
| |
|
| | |
| | |
| | |
# Make the directory two levels above this file's folder importable, so the
# `sharp` imports below resolve when running from a source checkout (presumably
# the repo root that contains the `sharp` package — TODO confirm). The entry is
# appended, so a `sharp` found earlier on sys.path still takes precedence.
sys.path.append(str(Path(__file__).parent.parent.parent))
| |
|
| | from sharp.models import PredictorParams, RGBGaussianPredictor, create_predictor |
| | from sharp.utils import io as sharp_io |
| | from sharp.utils.gaussians import save_ply |
| | from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL |
| |
|
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("sharp.api")

app = FastAPI()

# Wide-open CORS: browsers on any origin may call the API with any method/header.
# NOTE(review): with allow_origins=["*"] plus allow_credentials=True, Starlette
# echoes the caller's Origin back, so credentialed cross-site requests are
# accepted from anywhere. The endpoints in this file use no cookies or auth, so
# allow_credentials=False (or an explicit origin list) may be preferable —
# confirm what the frontend sends before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Module-level model state; both are assigned by the startup hook.
predictor: RGBGaussianPredictor | None = None
device: torch.device | None = None
| |
|
| |
|
@app.on_event("startup")
async def startup_event():
    """Pick a torch device and load the SHARP predictor into the module globals.

    Errors are logged instead of raised so the server still starts. The global
    ``predictor`` is assigned only after the model is fully loaded. If loading
    fails, it keeps its previous value (``None`` at import), and ``/health``
    then reports ``model_loaded: false``.
    """
    global predictor, device
    try:
        # Prefer CUDA, then Apple MPS, then fall back to CPU.
        if torch.cuda.is_available():
            device_str = "cuda"
        elif torch.backends.mps.is_available():
            device_str = "mps"
        else:
            device_str = "cpu"
        device = torch.device(device_str)
        LOGGER.info("Using device: %s", device)

        LOGGER.info("Loading SHARP model state dict...")
        state_dict = torch.hub.load_state_dict_from_url(
            DEFAULT_MODEL_URL, progress=True, map_location=device
        )

        # Build into a local and publish it only once fully initialised.
        # Assigning the global first meant a failing load_state_dict()/to()
        # left a model without the checkpoint weights visible to /predict.
        model = create_predictor(PredictorParams())
        model.load_state_dict(state_dict)
        model.eval()
        model.to(device)
        predictor = model
        LOGGER.info("Model loaded and ready.")
    except Exception as e:
        LOGGER.exception("Failed during startup/model init: %s", e)
| | |
| |
|
| |
|
@app.get("/health")
async def health():
    """Report liveness, the selected torch device, and whether a model is loaded."""
    device_name = None
    if device:
        device_name = str(device)
    model_loaded = predictor is not None
    return {"status": "ok", "device": device_name, "model_loaded": model_loaded}
| |
|
| |
|
# Serialises per-upload processing across concurrent requests. The original
# handlers ran inference directly on the event loop, which implicitly
# serialised it. The work now runs in a worker thread, and this lock keeps
# that one-at-a-time behaviour (e.g. to bound GPU memory).
# NOTE: created at import time. This relies on Python 3.10+, where
# asyncio.Lock binds to the running loop lazily; the module's runtime
# `X | None` annotations already require 3.10+.
_INFERENCE_LOCK = asyncio.Lock()


def _safe_upload_name(filename: str | None, index: int) -> str:
    """Return a bare file name for an upload that is safe to join onto a work dir.

    Client-supplied names are untrusted. Names like ``"../../x"`` or
    ``"/etc/x"`` would otherwise escape the temporary directory. Only the final
    path component is kept. A missing, empty, ``.`` or ``..`` name falls back
    to ``upload_<index>``.
    """
    # Treat backslashes as separators too, so Windows-style paths are stripped.
    name = Path((filename or "").replace("\\", "/")).name
    if name in ("", ".", ".."):
        name = f"upload_{index}"
    return name


def _unique_name(name: str, used: set[str]) -> str:
    """Return *name*, or ``<stem>_<n><suffix>`` if already in *used*; record the result."""
    stem, suffix = Path(name).stem, Path(name).suffix
    candidate = name
    counter = 1
    while candidate in used:
        candidate = f"{stem}_{counter}{suffix}"
        counter += 1
    used.add(candidate)
    return candidate


def _generate_ply(upload: UploadFile, work_dir: Path, index: int, model, torch_device):
    """Run SHARP on one uploaded image and save the result as a PLY file.

    The function blocks on disk I/O and model inference, so call it through
    ``asyncio.to_thread``. Each upload gets its own sub-directory, so uploads
    that share a name within one request cannot overwrite each other.

    Returns ``(ply_path, width, height, f_px)``, where ``f_px`` is the focal
    length exactly as returned by ``sharp_io.load_rgb``.
    """
    item_dir = work_dir / str(index)
    item_dir.mkdir()
    file_path = item_dir / _safe_upload_name(upload.filename, index)
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(upload.file, buffer)

    image, _, f_px = sharp_io.load_rgb(file_path)
    gaussians = predict_image(model, image, f_px, torch_device)

    # Same (height, width) convention as the original handlers.
    height, width = image.shape[:2]
    ply_path = item_dir / f"{file_path.stem}.ply"
    save_ply(gaussians, f_px, (height, width), ply_path)
    return ply_path, width, height, f_px


@app.post("/predict")
async def predict(files: list[UploadFile] = File(...)):
    """Accept images and return JSON with per-image metadata and PLY as base64.

    Each entry in ``results`` takes one of two forms:

    - ``{filename, ply_filename, ply_data, width, height, focal_length}`` on success.
    - ``{filename, error}`` on failure.

    One bad image does not fail the whole request.
    """
    # Snapshot the globals so the whole request uses one model/device pair.
    model, torch_device = predictor, device
    if model is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    results = []
    with tempfile.TemporaryDirectory() as temp_dir:
        work_dir = Path(temp_dir)
        for index, file in enumerate(files):
            try:
                # Keep blocking work off the event loop so /health and other
                # requests stay responsive while inference runs.
                async with _INFERENCE_LOCK:
                    ply_path, width, height, f_px = await asyncio.to_thread(
                        _generate_ply, file, work_dir, index, model, torch_device
                    )
                ply_data = base64.b64encode(ply_path.read_bytes()).decode("utf-8")
            except Exception as e:
                LOGGER.exception("Error processing %s: %s", file.filename, e)
                results.append({"filename": file.filename, "error": str(e)})
                continue

            results.append(
                {
                    "filename": file.filename,
                    "ply_filename": ply_path.name,
                    "ply_data": ply_data,
                    "width": width,
                    "height": height,
                    "focal_length": f_px,
                }
            )

    return {"results": results}


@app.post("/predict/download")
async def predict_download(files: list[UploadFile] = File(...)):
    """Accept images and return a ZIP of generated PLY files.

    Images that fail are logged and left out of the archive. PLY names that
    would collide inside the archive get a ``_<n>`` suffix.
    """
    # Snapshot the globals so the whole request uses one model/device pair.
    model, torch_device = predictor, device
    if model is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    output_zip = python_io.BytesIO()
    with tempfile.TemporaryDirectory() as temp_dir:
        work_dir = Path(temp_dir)
        with zipfile.ZipFile(output_zip, "w") as zf:
            used_names: set[str] = set()
            for index, file in enumerate(files):
                try:
                    async with _INFERENCE_LOCK:
                        ply_path, *_ = await asyncio.to_thread(
                            _generate_ply, file, work_dir, index, model, torch_device
                        )
                    zf.write(ply_path, _unique_name(ply_path.name, used_names))
                except Exception as e:
                    LOGGER.exception("Error processing %s: %s", file.filename, e)

    output_zip.seek(0)
    return StreamingResponse(
        output_zip,
        media_type="application/zip",
        headers={"Content-Disposition": "attachment; filename=gaussians.zip"},
    )
| |
|
| |
|
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module runs as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)
| |
|