| import torch |
| from diffusers import ( |
| StableDiffusionControlNetPipeline, |
| ControlNetModel, |
| EulerAncestralDiscreteScheduler, |
| ) |
| from typing import Dict, List, Any |
| from xformers.ops import MemoryEfficientAttentionFlashAttentionOp |
| import qrcode |
| import os |
| import base64 |
| from io import BytesIO |
| import json |
| from PIL import Image |
|
|
# Checkpoint identifier handed to StableDiffusionControlNetPipeline.from_pretrained.
MODEL_ID = "simdi/colorful_qr"

# Output size in pixels, passed to the pipeline as ``width`` / ``height``.
WIDTH, HEIGHT = 768, 768

# Candidate ``controlnet_conditioning_scale`` pairs, ordered by the first
# component and then by the second. Presumably each pair is
# (tile scale, brightness scale), matching the ControlNet order used in
# ``load_models`` -- confirm before reordering either list.
WEIGHT_PAIRS = [
    (first_scale, second_scale)
    for first_scale in (0.25, 0.35, 0.45)
    for second_scale in (0.20, 0.25)
]
|
|
|
|
def float_to_pair_index(f: float) -> int:
    """Map a position on the unit interval to an index into ``WEIGHT_PAIRS``.

    The interval [0, 1] is divided into ``len(WEIGHT_PAIRS)`` equal-width
    buckets and the index of the bucket containing ``f`` is returned.
    Values outside [0, 1] are clamped to the nearest end first.

    Args:
        f: Requested position; expected to be a number.

    Returns:
        An integer in ``range(len(WEIGHT_PAIRS))``.
    """
    length = len(WEIGHT_PAIRS)

    # Clamp so that out-of-range input still selects a valid bucket.
    clamped = max(0.0, min(f, 1.0))

    # ``clamped == 1.0`` would yield ``length`` itself; fold it into the
    # last bucket so the result is always a valid list index.
    return min(int(clamped * length), length - 1)
|
|
|
|
def select_weight_pair(f: float):
    """Return the ``WEIGHT_PAIRS`` entry chosen by ``float_to_pair_index(f)``."""
    chosen_index = float_to_pair_index(f)
    chosen_pair = WEIGHT_PAIRS[chosen_index]
    return chosen_pair
|
|
|
|
def load_models():
    """Build the two-ControlNet Stable Diffusion pipeline and move it to CUDA.

    Loads the tile and brightness ControlNets in float16, attaches both (in
    that order) to the ``MODEL_ID`` pipeline, moves the pipeline to the
    ``"cuda"`` device, replaces its scheduler with an Euler-ancestral one
    built from the existing scheduler config, and enables VAE slicing.

    Returns:
        The configured ``StableDiffusionControlNetPipeline``.
    """
    half_precision = torch.float16

    tile_net = ControlNetModel.from_pretrained(
        "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=half_precision
    )
    brightness_net = ControlNetModel.from_pretrained(
        "ioclab/control_v1p_sd15_brightness", torch_dtype=half_precision
    )

    # The list order here fixes which conditioning scale / image index
    # applies to which ControlNet at inference time.
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        MODEL_ID,
        controlnet=[tile_net, brightness_net],
        torch_dtype=half_precision,
        cache_dir="cache",
    )
    pipeline.to("cuda")

    # Reuse the checkpoint's scheduler settings with a different sampler.
    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    pipeline.enable_vae_slicing()

    return pipeline
|
|
|
|
def resize_for_condition_image(input_image, resolution: int):
    """Resize an image so its shorter side targets ``resolution``.

    The image is converted to RGB, both sides are scaled by
    ``resolution / min(height, width)``, and each scaled side is rounded to
    the nearest multiple of 64 before resampling with LANCZOS.

    Args:
        input_image: Image object exposing ``convert``, ``size`` and
            ``resize`` (a PIL image in this file's usage).
        resolution: Target length of the shorter side before rounding.

    Returns:
        The resized RGB image.
    """
    rgb_image = input_image.convert("RGB")
    width, height = rgb_image.size
    scale = float(resolution) / min(height, width)

    def _snap_to_64(side):
        # Scale first, then round to the nearest multiple of 64.
        return int(round(side * scale / 64.0)) * 64

    target_height = _snap_to_64(height)
    target_width = _snap_to_64(width)
    return rgb_image.resize((target_width, target_height), resample=Image.LANCZOS)
|
|
|
|
def generate_qr_code(content: str):
    """Render ``content`` as a black-on-white QR code image.

    The code is built with high error correction, a box size of 10 and a
    border of 2, then passed through ``resize_for_condition_image`` with a
    resolution of 768.

    Args:
        content: Data to encode in the QR code.

    Returns:
        The resized QR code image.
    """
    qr_settings = {
        "version": 1,
        "error_correction": qrcode.ERROR_CORRECT_H,
        "box_size": 10,
        "border": 2,
    }
    builder = qrcode.QRCode(**qr_settings)
    builder.clear()
    builder.add_data(content)
    # fit=True lets the library pick a larger version than the one
    # requested above when the data does not fit.
    builder.make(fit=True)

    raw_code = builder.make_image(fill_color="black", back_color="white")
    return resize_for_condition_image(raw_code, 768)
|
|
|
|
def image_to_base64(image):
    """Encode ``image`` as PNG and return the bytes as a base64 string.

    Args:
        image: Object with a ``save(fp, format=...)`` method (a PIL image in
            this file's usage).

    Returns:
        The base64 text (UTF-8 decoded) of the PNG data written by ``save``.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()

    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
|
|
|
|
def generate_image_with_conditioning_scale(**inputs):
    """Run the pipeline once and wrap the images in the response payload.

    Required keyword arguments (a missing one raises ``KeyError``):
        styles: Sequence of prompts; an empty negative prompt is supplied
            for each one.
        pair: Value forwarded as ``controlnet_conditioning_scale``.
        pipe: Callable pipeline whose result exposes ``.images``.
        qr_image: Conditioning image; it is supplied twice, presumably once
            per ControlNet attached to the pipeline.
        generator: Random generator forwarded to the pipeline.

    Returns:
        A dict with a single ``"fields"`` entry named ``"output"`` whose
        value is a list of PNG data URIs, one per generated image.
    """
    prompts = inputs["styles"]
    scales = inputs["pair"]
    pipeline = inputs["pipe"]
    condition_image = inputs["qr_image"]
    rng = inputs["generator"]

    empty_negatives = [""] * len(prompts)
    result = pipeline(
        prompt=prompts,
        negative_prompt=empty_negatives,
        width=WIDTH,
        height=HEIGHT,
        guidance_scale=7.0,
        generator=rng,
        num_inference_steps=25,
        num_images_per_prompt=2,
        controlnet_conditioning_scale=scales,
        image=[condition_image, condition_image],
    )

    data_uris = []
    for picture in result.images:
        data_uris.append(f"data:image/png;base64,{image_to_base64(picture)}")

    output_field = {
        "name": "output",
        "type": "Image",
        "value": data_uris,
    }
    return {"fields": [output_field]}
|
|
|
|
def generate_image(pipe, inputs):
    """Generate QR-code art for one request.

    Args:
        pipe: Pipeline forwarded to ``generate_image_with_conditioning_scale``.
        inputs: Mapping with the keys ``"styles"`` (a prompt string or a
            sequence of prompts), ``"content"`` (data for the QR code) and
            ``"art_scale"`` (value used to pick a weight pair).

    Returns:
        The payload produced by ``generate_image_with_conditioning_scale``.

    Raises:
        KeyError: If any of the three keys is missing from ``inputs``.
    """
    prompts = inputs["styles"]
    if isinstance(prompts, str):
        prompts = [prompts]
    # A lone prompt is repeated so that the batch holds five prompts.
    if len(prompts) == 1:
        prompts = prompts * 5

    qr_content = inputs["content"]
    art_scale = inputs["art_scale"]

    with torch.inference_mode(), torch.autocast("cuda"):
        request = {
            "styles": prompts,
            "qr_image": generate_qr_code(qr_content),
            # Unseeded generator on the default device.
            "generator": torch.Generator(),
            "pair": select_weight_pair(art_scale),
            "pipe": pipe,
        }
        return generate_image_with_conditioning_scale(**request)
|
|
|
|
class EndpointHandler:
    """Callable wrapper that loads the pipeline once and serves requests."""

    def __init__(self, path=""):
        """Load the models; ``path`` is accepted but not used here."""
        self._model = load_models()

    def __call__(self, model_input: Dict[str, Any]):
        """Return the ``generate_image`` payload for ``model_input``."""
        return generate_image(self._model, model_input)
|
|