| | import os |
| | import sys |
| | import torch |
| | import base64 |
| | import io |
| | from PIL import Image, ImageDraw, ImageFont |
| | import tempfile |
| | import shutil |
| | from typing import Dict, Any, List |
| | import json |
| | import numpy as np |
| |
|
| | |
# Put this file's own directory at the front of the import search path so
# modules located next to it take precedence when imported.
current_dir = os.path.split(os.path.abspath(__file__))[0]
sys.path[:0] = [current_dir]
| |
|
# Keyword groups in priority order: the first group with a keyword present in
# the prompt selects the sketch style.
_SKETCH_STYLE_KEYWORDS = (
    ("portrait", ("portrait", "face", "person", "man", "woman")),
    ("landscape", ("landscape", "mountain", "tree", "nature")),
    ("building", ("architectural", "building", "house")),
    ("mandala", ("mandala", "pattern", "geometric")),
    ("technical", ("technical", "mechanical", "device")),
)

_DEJAVU_SANS_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"


def _load_label_font(size=12):
    """Return a small label font, or None when no font can be loaded."""
    try:
        return ImageFont.truetype(_DEJAVU_SANS_PATH, size)
    except (OSError, ImportError):
        # Font file missing/unreadable or FreeType support unavailable:
        # fall through to Pillow's built-in default font.
        pass
    try:
        return ImageFont.load_default()
    except Exception:
        # Best effort: without any font the sketch is drawn without text.
        return None


def _detect_sketch_style(prompt):
    """Return the first style whose keyword occurs as a whole word in *prompt*.

    Matching is done on whole words (a trailing plural "s" is accepted) rather
    than on substrings, so that e.g. "mandala" is not mistaken for "man" and
    "street" is not mistaken for "tree". Returns None when nothing matches.
    """
    import re

    words = set(re.findall(r"[a-z]+", prompt.lower()))
    for style, keywords in _SKETCH_STYLE_KEYWORDS:
        if any(kw in words or kw + "s" in words for kw in keywords):
            return style
    return None


def _draw_grid(draw, width, height, spacing=20):
    """Draw a light grey background grid with lines every *spacing* pixels."""
    grid_color = (240, 240, 240)
    for x in range(0, width, spacing):
        draw.line([(x, 0), (x, height)], fill=grid_color, width=1)
    for y in range(0, height, spacing):
        draw.line([(0, y), (width, y)], fill=grid_color, width=1)


def _draw_portrait(draw, width, height, label_font):
    """Draw a simple face (head, eyes, nose, mouth) centred on the canvas."""
    cx, cy = width // 2, height // 2
    # Head outline
    draw.ellipse([cx - 60, cy - 80, cx + 60, cy + 80], outline='black', width=3)
    # Eyes
    draw.ellipse([cx - 30, cy - 30, cx - 15, cy - 15], outline='black', width=2)
    draw.ellipse([cx + 15, cy - 30, cx + 30, cy - 15], outline='black', width=2)
    # Nose
    draw.line([cx, cy - 10, cx - 5, cy + 10], fill='black', width=2)
    # Mouth
    draw.arc([cx - 20, cy + 10, cx + 20, cy + 40], 0, 180, fill='black', width=2)


def _draw_landscape(draw, width, height, label_font):
    """Draw a mountain ridge line and two trees."""
    ridge = [
        (0, height * 0.7),
        (width * 0.3, height * 0.4),
        (width * 0.6, height * 0.5),
        (width, height * 0.6),
    ]
    for start, end in zip(ridge, ridge[1:]):
        draw.line([start, end], fill='black', width=3)

    for x in (width * 0.2, width * 0.8):
        # Trunk
        draw.rectangle([x - 5, height * 0.7, x + 5, height * 0.9], outline='black', width=2)
        # Crown
        draw.ellipse([x - 20, height * 0.5, x + 20, height * 0.7], outline='black', width=2)


def _draw_building(draw, width, height, label_font):
    """Draw a building outline with four windows and a door."""
    draw.rectangle([width * 0.2, height * 0.3, width * 0.8, height * 0.8], outline='black', width=3)
    # Windows (2 x 2)
    for x in (width * 0.35, width * 0.65):
        for y in (height * 0.45, height * 0.65):
            draw.rectangle([x - 15, y - 10, x + 15, y + 10], outline='black', width=2)
    # Door
    draw.rectangle([width * 0.45, height * 0.65, width * 0.55, height * 0.8], outline='black', width=2)


def _draw_mandala(draw, width, height, label_font):
    """Draw three concentric circles joined by twelve radial spokes."""
    import math

    cx, cy = width // 2, height // 2
    for radius in (30, 60, 90):
        draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], outline='black', width=2)

    # Spokes run from the inner circle (r=30) to the outer circle (r=90).
    for angle in range(0, 360, 30):
        theta = math.radians(angle)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        draw.line(
            [cx + 30 * cos_t, cy + 30 * sin_t, cx + 90 * cos_t, cy + 90 * sin_t],
            fill='black',
            width=2,
        )


def _draw_technical(draw, width, height, label_font):
    """Draw a boxed device with two linked circular parts and optional labels."""
    draw.rectangle([width * 0.3, height * 0.4, width * 0.7, height * 0.7], outline='black', width=3)

    # Circles are drawn through ellipse() with an explicit bounding box so the
    # code does not depend on ImageDraw.circle(), which older Pillow lacks.
    for cx, cy, radius in ((width * 0.4, height * 0.5, 15), (width * 0.6, height * 0.6, 10)):
        draw.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], outline='black', width=2)

    # Connector between the two parts
    draw.line([width * 0.4, height * 0.5, width * 0.6, height * 0.6], fill='black', width=2)

    if label_font is not None:
        draw.text((width * 0.3, height * 0.3), "Component A", fill='black', font=label_font)
        draw.text((width * 0.5, height * 0.75), "Component B", fill='black', font=label_font)


def _draw_generic(draw, width, height, label_font):
    """Draw a zig-zag polyline with circles on every other vertex."""
    vertices = [
        (width * (0.2 + 0.6 * i / 4), height * (0.3 + 0.4 * (i % 2)))
        for i in range(5)
    ]
    for start, end in zip(vertices, vertices[1:]):
        draw.line([start, end], fill='black', width=3)

    for x, y in vertices[::2]:
        draw.ellipse([x - 10, y - 10, x + 10, y + 10], outline='black', width=2)


def _draw_caption(draw, prompt, width, height, label_font):
    """Write the (possibly truncated) prompt centred near the bottom edge."""
    display_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
    bbox = draw.textbbox((0, 0), display_prompt, font=label_font)
    text_width = bbox[2] - bbox[0]
    text_x = (width - text_width) // 2
    draw.text((text_x, height - 25), display_prompt, fill='gray', font=label_font)


_SKETCH_STYLE_RENDERERS = {
    "portrait": _draw_portrait,
    "landscape": _draw_landscape,
    "building": _draw_building,
    "mandala": _draw_mandala,
    "technical": _draw_technical,
}


def create_sketch_image(prompt: str, width: int = 256, height: int = 256) -> Image.Image:
    """Create a sketch-style placeholder image based on the prompt.

    The prompt is scanned for keywords that select one of several hand-drawn
    motifs (portrait, landscape, building, mandala, technical device); if no
    keyword matches, a generic zig-zag figure is drawn. The drawing sits on a
    white canvas with a light grid, and the prompt (truncated to 40 characters
    plus "...") is written near the bottom when a font is available.

    Args:
        prompt: Text describing the desired sketch.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(width, height)``.
    """
    img = Image.new('RGB', (width, height), color='white')
    draw = ImageDraw.Draw(img)
    label_font = _load_label_font()

    _draw_grid(draw, width, height)

    renderer = _SKETCH_STYLE_RENDERERS.get(_detect_sketch_style(prompt), _draw_generic)
    renderer(draw, width, height, label_font)

    if label_font is not None:
        _draw_caption(draw, prompt, width, height, label_font)

    return img
| |
|
class EndpointHandler:
    """Callable request handler that renders a prompt as a sketch image.

    Despite the "DiffSketchEdit" name used in its log messages, this handler
    does not run a model: it delegates to ``create_sketch_image`` and returns
    the result as a base64-encoded PNG.
    """

    DEFAULT_CANVAS_SIZE = 256
    # Upper bound on the requested canvas edge. The value comes from the
    # request payload, so it is capped to keep a single request from
    # allocating an arbitrarily large image.
    MAX_CANVAS_SIZE = 4096

    def __init__(self, path=""):
        """
        Initialize the handler for DiffSketchEdit model.

        Args:
            path: Accepted for interface compatibility; not used.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"DiffSketchEdit handler initialized on device: {self.device}")

    @staticmethod
    def _encode_png(img) -> str:
        """Serialize *img* as PNG and return it as a base64 text string."""
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode()

    @classmethod
    def _message_image(cls, message: str, fill: str) -> str:
        """Return a base64 PNG of a 256x256 white canvas showing *message*."""
        img = Image.new('RGB', (256, 256), color='white')
        ImageDraw.Draw(img).text((10, 128), message, fill=fill)
        return cls._encode_png(img)

    @classmethod
    def _resolve_canvas_size(cls, raw) -> int:
        """Validate the requested canvas size and cap it at MAX_CANVAS_SIZE.

        ``None`` selects the default. Values that cannot be converted with
        ``int()`` or that are smaller than 1 raise, which the caller turns
        into an error image.
        """
        if raw is None:
            return cls.DEFAULT_CANVAS_SIZE
        size = int(raw)
        if size < 1:
            raise ValueError("canvas_size must be >= 1")
        return min(size, cls.MAX_CANVAS_SIZE)

    def __call__(self, data: Dict[str, Any]) -> str:
        """
        Render the request's prompt as a sketch and return it as base64 PNG.

        Args:
            data: Dictionary containing:
                - inputs: Text prompt describing the sketch
                - parameters: Optional dict; ``canvas_size`` sets the edge
                  length in pixels of the square output (default 256,
                  capped at ``MAX_CANVAS_SIZE``)

        Returns:
            Base64 encoded PNG image. When no prompt is given, or when any
            error occurs, a 256x256 image containing a short message is
            returned instead of raising.
        """
        try:
            prompt = data.get("inputs", "")
            if not prompt:
                return self._message_image("No prompt provided", 'black')

            # `or {}` also covers an explicit null "parameters" entry.
            parameters = data.get("parameters") or {}
            canvas_size = self._resolve_canvas_size(parameters.get("canvas_size"))

            print(f"Generating sketch for prompt: '{prompt}' with canvas size: {canvas_size}")

            img_str = self._encode_png(create_sketch_image(prompt, canvas_size, canvas_size))

            print(f"Successfully generated {canvas_size}x{canvas_size} sketch image")
            return img_str

        except Exception as e:
            # Request boundary: report the failure as an image so the caller
            # always receives a decodable PNG.
            print(f"Error in DiffSketchEdit handler: {e}")
            return self._message_image(f"Error: {str(e)[:30]}", 'red')
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Manual smoke test: run one request through the handler, then decode the
    # returned base64 payload and write it to disk for visual inspection.
    request = {
        "inputs": "a detailed portrait of an elderly man",
        "parameters": {"canvas_size": 256},
    }
    encoded = EndpointHandler()(request)
    print(f"Generated base64 image of length: {len(encoded)}")

    # Round-trip the payload to confirm it is a valid image.
    png_stream = io.BytesIO(base64.b64decode(encoded))
    decoded = Image.open(png_stream)
    print(f"Decoded image size: {decoded.size}")
    decoded.save("test_diffsketchedit_output.png")
    print("Saved test image as test_diffsketchedit_output.png")