import base64
import io
import os
import tempfile
import traceback
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
|
|
app = FastAPI()

# Process-wide handles to the diffusers pipeline and its video-export helper.
# Both stay None until the startup hook `load_model` assigns them.
pipe = export_to_video = None
|
|
class InferenceRequest(BaseModel):
    """Request schema mirroring the fields and defaults that ``predict`` reads.

    NOTE(review): nothing in this file references this model -- ``predict``
    accepts a raw ``dict`` and re-implements these defaults by hand. Confirm
    whether it is still needed, or wire it into the handler so validation and
    defaults live in one place.
    """

    # Input image; presumably the same formats ``predict`` accepts (URL or base64) -- confirm.
    image: str
    # Text prompt describing the video to generate (required).
    prompt: str
    # Default matches the fallback ``predict`` uses when the field is absent.
    negative_prompt: str = "ugly, static, blurry, low quality"
    # Number of frames passed to the pipeline's ``num_frames`` argument.
    num_frames: int = 93
    # Number of denoising steps passed to the pipeline.
    num_inference_steps: int = 35
    # Classifier-free guidance scale passed to the pipeline.
    guidance_scale: float = 7.0
    # Optional RNG seed; None means no explicit generator is created.
    seed: Optional[int] = None
|
|
# NOTE(review): FastAPI deprecates ``on_event`` in favour of lifespan handlers;
# migrating requires passing ``lifespan=`` to the ``FastAPI(...)`` constructor.
@app.on_event("startup")
async def load_model():
    """Load the Cosmos-Predict2 Video2World pipeline onto the GPU at startup.

    Assigns the module globals ``pipe`` (the pipeline, moved to CUDA in
    bfloat16) and ``export_to_video`` (diffusers' frame-to-mp4 helper).
    diffusers is imported here rather than at module top, presumably to keep
    module import light.

    Environment:
        MODEL_ID: Hugging Face repo id to load; defaults to
            ``nvidia/Cosmos-Predict2-2B-Video2World``.
        HF_TOKEN: optional Hugging Face access token forwarded to
            ``from_pretrained``.
    """
    global pipe, export_to_video
    from diffusers import Cosmos2VideoToWorldPipeline
    from diffusers.utils import export_to_video as _export_to_video

    # Overridable so the service can point at another checkpoint without a code change.
    model_id = os.environ.get("MODEL_ID", "nvidia/Cosmos-Predict2-2B-Video2World")

    print("Loading model...")
    loaded_pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        token=os.environ.get("HF_TOKEN"),
    )
    loaded_pipe.to("cuda")

    # Publish the globals only once the pipeline is fully initialised on the GPU,
    # so a failure part-way through never leaves a half-ready ``pipe`` behind.
    pipe = loaded_pipe
    export_to_video = _export_to_video
    print("Model loaded successfully!")
|
|
def _load_input_image(image_data) -> Image.Image:
    """Decode the request's ``image`` field into a 1280x704 RGB PIL image.

    Accepts an ``http://``/``https://`` URL, a ``data:...;base64,`` URI, or a
    bare base64 string.

    Raises:
        HTTPException: 400 if the image cannot be fetched, decoded, or resized.
    """
    try:
        if image_data.startswith(("http://", "https://")):
            # NOTE(review): this fetches a client-supplied URL from the server
            # (SSRF risk). Restrict hosts if the endpoint is publicly exposed.
            from diffusers.utils import load_image

            image = load_image(image_data)
        else:
            # Drop an optional "data:image/...;base64," prefix; its letters would
            # otherwise be decoded as payload and corrupt the image bytes.
            if image_data.startswith("data:"):
                _, _, image_data = image_data.partition(",")
            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        return image.resize((1280, 704))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to load image: {e}") from e


def _parse_generation_params(inputs: dict):
    """Read and type-check the optional sampling parameters from *inputs*.

    Values are coerced to the types the pipeline expects, so malformed input
    (e.g. ``"seed": "abc"``) surfaces as a 400 rather than a 500 raised deep
    inside the pipeline.

    Returns:
        ``(pipe_kwargs, seed)``: keyword arguments for the pipeline call, and
        the integer seed or None.

    Raises:
        HTTPException: 400 if a parameter cannot be converted or is not positive.
    """
    try:
        num_frames = int(inputs.get("num_frames", 93))
        num_inference_steps = int(inputs.get("num_inference_steps", 35))
        guidance_scale = float(inputs.get("guidance_scale", 7.0))
        seed = inputs.get("seed")
        if seed is not None:
            seed = int(seed)
    except (TypeError, ValueError, OverflowError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid generation parameter: {e}") from e

    if num_frames < 1 or num_inference_steps < 1:
        raise HTTPException(
            status_code=400,
            detail="num_frames and num_inference_steps must be positive",
        )

    pipe_kwargs = {
        "negative_prompt": inputs.get("negative_prompt", "ugly, static, blurry, low quality"),
        "num_frames": num_frames,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    return pipe_kwargs, seed


@app.post("/predict")
@app.post("/")
async def predict(request: dict):
    """Generate an mp4 video from an input image and a text prompt.

    The JSON body is either the parameters themselves or the same parameters
    wrapped as ``{"inputs": {...}}``. Required keys are ``image`` and
    ``prompt``. Optional keys are ``negative_prompt``, ``num_frames``,
    ``num_inference_steps``, ``guidance_scale`` and ``seed``.

    Returns:
        ``{"video": <base64-encoded mp4>, "content_type": "video/mp4"}``

    Raises:
        HTTPException: 400 for missing or invalid input, 503 if the model is
            not loaded, and 500 if generation or encoding fails.
    """
    # NOTE(review): the pipeline call below is synchronous inside an async
    # handler, so it blocks the event loop (including /health) for the whole
    # generation. Offloading to a thread would also allow concurrent GPU jobs,
    # so serialize with a lock if you change this.
    inputs = request.get("inputs", request)
    if not isinstance(inputs, dict):
        raise HTTPException(status_code=400, detail="'inputs' must be a JSON object")

    image_data = inputs.get("image")
    if not image_data:
        raise HTTPException(status_code=400, detail="No image provided")

    prompt = inputs.get("prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="No prompt provided")

    image = _load_input_image(image_data)
    pipe_kwargs, seed = _parse_generation_params(inputs)

    if pipe is None or export_to_video is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    generator = None
    if seed is not None:
        generator = torch.Generator(device="cuda").manual_seed(seed)

    try:
        # A private temp directory per request avoids a shared fixed path, and
        # the file is removed automatically once it has been read back.
        with tempfile.TemporaryDirectory() as tmp_dir:
            video_path = os.path.join(tmp_dir, "output.mp4")
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                output = pipe(
                    image=image,
                    prompt=prompt,
                    generator=generator,
                    **pipe_kwargs,
                )

            export_to_video(output.frames[0], video_path, fps=16)

            with open(video_path, "rb") as f:
                video_b64 = base64.b64encode(f.read()).decode("utf-8")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}") from e

    return {"video": video_b64, "content_type": "video/mp4"}
|
|
@app.get("/health")
@app.get("/")
async def health():
    """Return a static liveness payload (served on GET /health and GET /)."""
    status_payload = dict(status="healthy", message="Cosmos-Predict2 Video2World API")
    return status_payload
|
|