Spaces:
Runtime error
Runtime error
| import modal | |
| import torch | |
| from smolagents import AgentImage, Tool | |
| from diffusers import StableDiffusionUpscalePipeline | |
| from .app import app | |
| from .image import image | |
# NOTE(review): the GPU type is a guess; confirm it has enough memory for the
# x4 upscaler in fp16 at the chosen input size.
@app.cls(image=image, gpu="A10G")
class RemoteUpscalerModalApp:
    """Modal-hosted wrapper around the ``stabilityai/stable-diffusion-x4-upscaler`` pipeline.

    The class is registered on ``app`` (built with ``image``) so that, once deployed,
    ``modal.Cls.from_name(app.name, "RemoteUpscalerModalApp")`` can look it up.
    """

    @modal.enter()
    def setup(self):
        """Load the upscaling pipeline in fp16 and move it to CUDA once per container start."""
        model_id = "stabilityai/stable-diffusion-x4-upscaler"
        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
        self.pipeline = self.pipeline.to("cuda")

    @modal.method()
    def forward(self, low_res_imgs, prompts: list[str]):
        """Upscale a batch of images, guided by one text prompt per image.

        Args:
            low_res_imgs: A list of PIL images. A single PIL image is also
                accepted, so per-item ``forward.map(...)`` calls work too.
            prompts: One prompt per image. A single ``str`` is also accepted.

        Returns:
            The list of upscaled images produced by the pipeline, in input order.

        Raises:
            ValueError: If the number of prompts differs from the number of images.
        """
        # Longest side an input may have before it is handed to the pipeline.
        max_side = 512

        # Normalise single items into one-element batches.
        if not isinstance(low_res_imgs, (list, tuple)):
            low_res_imgs = [low_res_imgs]
        if isinstance(prompts, str):
            prompts = [prompts]
        if len(low_res_imgs) != len(prompts):
            raise ValueError(
                f"Expected one prompt per image, got {len(low_res_imgs)} images "
                f"and {len(prompts)} prompts"
            )
        if not low_res_imgs:
            return []

        prepared = []
        for img in low_res_imgs:
            # convert() returns a new image, so thumbnail() below never mutates the
            # caller's object. Forcing RGB also drops alpha/palette modes.
            rgb = img.convert("RGB")
            # Shrink so neither side exceeds max_side while keeping the aspect
            # ratio (the old per-axis clamp distorted non-square images).
            # thumbnail() never enlarges.
            rgb.thumbnail((max_side, max_side))
            prepared.append(rgb)

        print(f"Upscaling {len(prepared)} image(s)")
        return self.pipeline(prompt=list(prompts), image=prepared).images
class RemoteUpscalerTool(Tool):
    """smolagents tool that sends upscaling requests to the deployed ``RemoteUpscalerModalApp``."""

    name = "upscaler"
    description = """
    Perform upscaling on images.
    The "low_res_imgs" are PIL images.
    The "prompts" are strings, one per image.
    The output is a list of PIL images.
    You can upscale multiple images at once.
    """
    inputs = {
        "low_res_imgs": {
            "type": "array",
            "description": "The low resolution images to upscale",
        },
        "prompts": {
            "type": "array",
            "description": "The prompts to upscale the images",
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Look up the deployed Modal class by app name and class name, then
        # instantiate a handle to it.
        tool_class = modal.Cls.from_name(app.name, RemoteUpscalerModalApp.__name__)
        self.tool = tool_class()

    def forward(self, low_res_imgs: list[AgentImage], prompts: list[str]):
        """Upscale every image remotely and return a flat list of PIL images.

        Args:
            low_res_imgs: Images to upscale (``AgentImage`` or plain PIL images).
            prompts: One text prompt per image.

        Returns:
            A flat list of upscaled images in the same order as ``low_res_imgs``.

        Raises:
            ValueError: If the number of prompts differs from the number of images.
        """
        if len(low_res_imgs) != len(prompts):
            raise ValueError(
                f"Expected one prompt per image, got {len(low_res_imgs)} images "
                f"and {len(prompts)} prompts"
            )
        if not low_res_imgs:
            return []

        # Send the raw PIL objects rather than smolagents wrappers, so the remote
        # container does not need to unpickle AgentImage.
        raw_imgs = [
            img.to_raw() if isinstance(img, AgentImage) else img for img in low_res_imgs
        ]

        # .map() zips its iterables and calls the remote method once per pair.
        # Wrap each image/prompt in a one-element batch to match the remote
        # list-based signature. Each call returns a list, so flatten the results.
        batches = self.tool.forward.map(
            [[img] for img in raw_imgs],
            [[prompt] for prompt in prompts],
        )
        return [upscaled for batch in batches for upscaled in batch]