Text-to-Image
Diffusers
English
stable-diffusion-xl
huggingface-inference-endpoints
custom-inference
Instructions to use msgxai/msgxai-hg-api with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use msgxai/msgxai-hg-api with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("msgxai/msgxai-hg-api", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- Draw Things
- DiffusionBee
| import os | |
| import json | |
| import random | |
| import re | |
| import base64 | |
| from io import BytesIO | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from diffusers import ( | |
| AutoencoderKL, | |
| StableDiffusionXLPipeline, | |
| EulerAncestralDiscreteScheduler, | |
| DPMSolverSDEScheduler | |
| ) | |
| from diffusers.models.attention_processor import AttnProcessor2_0 | |
| from PIL import Image | |
# Module-wide settings shared by the handler below.

# Inclusive upper bound used when a request does not supply its own seed.
MAX_SEED = 12211231

# How many images the pipeline is asked to produce for each prompt.
NUM_IMAGES_PER_PROMPT = 1

# Opt-in switch: torch.compile is applied only when the environment variable
# USE_TORCH_COMPILE is set to exactly "1".
USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
# --- Child-Content Filtering Functions ---
# Case-insensitive pattern listing the terms that are stripped from prompts.
#
# NOTE(review): the pattern is intentionally left untouched here, but two
# properties are worth knowing when maintaining it:
#   * there are no word boundaries, so terms also match inside longer words
#     (e.g. "son" inside "person", or the "5 years old" tail of "25 years old");
#   * in the last alternative only the "new ... born" phrase requires the
#     trailing (boy|girl|...) group; "little", "small", "tiny", "short" and
#     "young" match on their own.
# Tightening either of these changes what the filter blocks and should be a
# deliberate decision.
child_related_regex = re.compile(
    r'(child|children|kid|kids|baby|babies|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|son|daughter|young|kindergarten|preschool|'
    r'([1-9]|1[0-7])[\s_\-|\.\,]*year(s)?[\s_\-|\.\,]*old|'
    r'little|small|tiny|short|young|new[\s_\-|\.\,]*born[\s_\-|\.\,]*(boy|girl|male|man|bro|brother|sis|sister))',
    re.IGNORECASE
)


def remove_child_related_content(prompt: str) -> str:
    """Remove every child-related reference from the prompt.

    Deleting a match can splice the surrounding characters into a new match
    (for example "chchildild" becomes "child" after one pass), so the
    substitution is repeated until the text no longer changes.

    Args:
        prompt: The raw prompt text.

    Returns:
        The prompt with all matches of ``child_related_regex`` removed and
        leading/trailing whitespace stripped. Whitespace left behind in the
        middle of the prompt is not collapsed.
    """
    cleaned_prompt = prompt
    while True:
        # Every match is non-empty, so each pass that changes the text makes
        # it strictly shorter and the loop is guaranteed to terminate.
        reduced = child_related_regex.sub('', cleaned_prompt)
        if reduced == cleaned_prompt:
            break
        cleaned_prompt = reduced
    return cleaned_prompt.strip()
def contains_child_related_content(prompt: str) -> bool:
    """Report whether any child-related term occurs anywhere in the prompt."""
    # A successful search yields a match object; no hit yields None.
    first_hit = child_related_regex.search(prompt)
    return first_hit is not None
# --- Utility Function: Convert PIL Image to Base64 ---
def pil_image_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as a base64 text string.

    The image is converted to RGB, serialized as WEBP at quality 90 into an
    in-memory buffer, and the resulting bytes are base64-encoded.
    """
    rgb_image = img.convert("RGB")
    with BytesIO() as webp_buffer:
        rgb_image.save(webp_buffer, format="WEBP", quality=90)
        webp_bytes = webp_buffer.getvalue()
    encoded = base64.b64encode(webp_bytes)
    return encoded.decode("utf-8")
class EndpointHandler:
    """
    Custom handler for Hugging Face Inference Endpoints.

    The class exposes the two entry points the endpoint runtime uses:
    ``__init__`` loads the configuration and the model pipeline, and
    ``__call__`` serves a single inference request.
    """

    # Sampler used when the configuration names none, or names an unknown one.
    _DEFAULT_SAMPLER = "DPM++ SDE Karras"

    def __init__(self, path="", config=None):
        """
        Initialize the handler with model path and configurations.

        Args:
            path (str): Directory that may contain ``app.conf``; when empty the
                file is looked up relative to the current working directory.
            config (dict, optional): Ready-made configuration. When truthy it
                is used as-is and ``app.conf`` is not read.

        Raises:
            KeyError: If the resulting configuration has no ``model_id``
                (raised while loading the pipeline).
        """
        # Load configuration from app.conf or use provided config.
        # Failure is deliberately non-fatal here: the handler falls back to an
        # empty configuration and lets pipeline loading report what is missing.
        try:
            if config:
                self.cfg = config
            else:
                config_path = os.path.join(path, "app.conf") if path else "app.conf"
                with open(config_path, "r") as f:
                    self.cfg = json.load(f)
            print("Configuration loaded successfully")
        except Exception as e:
            print(f"Error loading configuration: {e}")
            self.cfg = {}

        # Everything below reads the configuration with dict methods, so a
        # file that parses to a list/string/number is treated as unusable.
        if not isinstance(self.cfg, dict):
            print(f"Error loading configuration: expected a JSON object, got {type(self.cfg).__name__}")
            self.cfg = {}

        # Load the model pipeline
        print("Loading the model pipeline...")
        self.pipe = self._load_pipeline_and_scheduler()
        print("Model loaded successfully!")

    def _load_pipeline_and_scheduler(self):
        """
        Load the Stable Diffusion XL pipeline and attach the configured scheduler.

        Reads ``model_id``, ``use_safetensors``, ``sampler`` and ``clip_skip``
        from ``self.cfg``.

        Returns:
            The pipeline, moved to CUDA, with its scheduler replaced.

        Raises:
            KeyError: If ``model_id`` is absent from the configuration.
        """
        # Get clip_skip from configuration, default to 0
        clip_skip = self.cfg.get("clip_skip", 0)

        # Download model files from Hugging Face Hub
        ckpt_dir = snapshot_download(repo_id=self.cfg["model_id"])

        # Load the VAE model (for decoding latents)
        vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)

        # Load the Stable Diffusion XL pipeline
        pipe = StableDiffusionXLPipeline.from_pretrained(
            ckpt_dir,
            vae=vae,
            torch_dtype=torch.float16,
            use_safetensors=self.cfg.get("use_safetensors", True)
        )

        # Move model to GPU
        pipe = pipe.to("cuda")

        # Use efficient attention processor
        pipe.unet.set_attn_processor(AttnProcessor2_0())

        # Schedulers are built lazily so that only the selected one is
        # instantiated.
        sampler_factories = {
            "Euler a": lambda scheduler_config: EulerAncestralDiscreteScheduler.from_config(scheduler_config),
            "DPM++ SDE Karras": lambda scheduler_config: DPMSolverSDEScheduler.from_config(
                scheduler_config, use_karras_sigmas=True
            ),
        }
        sampler_name = self.cfg.get("sampler", self._DEFAULT_SAMPLER)
        build_scheduler = sampler_factories.get(sampler_name)
        if build_scheduler is None:
            # An unknown name used to leave the pipeline with scheduler=None,
            # which only surfaced as a failure on the first request.
            print(f"Unknown sampler {sampler_name!r}; falling back to {self._DEFAULT_SAMPLER!r}")
            build_scheduler = sampler_factories[self._DEFAULT_SAMPLER]
        pipe.scheduler = build_scheduler(pipe.scheduler.config)

        # Adjust the text encoder layers if needed using clip_skip
        # NOTE(review): this edits the config of an already-loaded text encoder;
        # confirm it actually changes which layers run, or pass clip_skip to the
        # pipeline call instead.
        if clip_skip > 0:
            pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)

        # Compile model if environment variable is set
        if USE_TORCH_COMPILE:
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
            print("Model Compiled!")

        return pipe

    @staticmethod
    def _first_present(key, fallback, *sources):
        """Return ``source[key]`` from the first mapping that contains ``key``, else ``fallback``."""
        for source in sources:
            if key in source:
                return source[key]
        return fallback

    def __call__(self, data):
        """
        Process the inference request.

        Args:
            data: Either a dict or a JSON document (str/bytes) decoding to a
                dict. Recognized keys: ``inputs`` (or ``prompt``) for the
                prompt text, and ``parameters`` holding optional
                ``negative_prompt``, ``inference_steps``, ``guidance_scale``
                and ``seed``. The same optional keys are also accepted at the
                top level of the payload.

        Returns:
            list: ``[{"generated_image": <base64 WEBP>, "seed": <int>}]`` on
            success.
            dict: ``{"error": <message>}`` for any failure; request-level
            problems are reported this way instead of being raised.
        """
        # Validate that the model is loaded
        if getattr(self, "pipe", None) is None:
            return {"error": "Model not loaded. Please check initialization logs."}

        # Parse the request payload
        try:
            payload = data if isinstance(data, dict) else json.loads(data)
        except Exception as e:
            return {"error": f"Failed to parse request data: {str(e)}"}
        if not isinstance(payload, dict):
            # Valid JSON that is not an object (list, string, number) has no
            # fields to read from.
            return {"error": "Failed to parse request data: expected a JSON object"}

        # HF Inference Endpoints format: {"inputs": "prompt", "parameters": {...}}
        parameters = payload.get("parameters")
        if not isinstance(parameters, dict):
            parameters = {}

        # Get the prompt, accepting "prompt" as a compatibility alias
        prompt_text = payload.get("inputs", "") or payload.get("prompt", "")
        if not prompt_text:
            return {"error": "No prompt provided. Please include 'inputs' or 'prompt' field."}
        if not isinstance(prompt_text, str):
            return {"error": "Invalid prompt. The 'inputs' or 'prompt' field must be a string."}

        # Apply child-content filtering to the prompt. Removing one match can
        # join its neighbours into a new match, so filter until none remain;
        # each pass deletes at least one character, so the loop terminates.
        while contains_child_related_content(prompt_text):
            prompt_text = remove_child_related_content(prompt_text)
        if not prompt_text:
            return {"error": "Prompt is empty after content filtering. Please provide a different prompt."}

        # Replace placeholder in the prompt template from config
        combined_prompt = self.cfg.get("prompt", "{prompt}").replace("{prompt}", prompt_text)

        # Use negative_prompt from parameters or payload, fall back to config
        negative_prompt = self._first_present("negative_prompt", "", parameters, payload, self.cfg)

        # Coerce numeric settings up front so malformed values produce an error
        # response instead of an unhandled exception.
        try:
            # Dimensions come from config only (default 1024x768)
            width = int(self.cfg.get("width", 1024))
            height = int(self.cfg.get("height", 768))
            inference_steps = int(self._first_present("inference_steps", 30, parameters, payload, self.cfg))
            guidance_scale = float(self._first_present("guidance_scale", 7, parameters, payload, self.cfg))
            # Use provided seed or generate a random one; an explicit null is
            # treated the same as an absent seed.
            seed_value = self._first_present("seed", None, parameters, payload)
            if seed_value is None:
                seed_value = random.randint(0, MAX_SEED)
            seed = int(seed_value)
        except (TypeError, ValueError) as e:
            return {"error": f"Invalid generation parameter: {e}"}
        if inference_steps < 1:
            return {"error": "Invalid generation parameter: inference_steps must be at least 1"}

        try:
            # Seeding is inside the try so an out-of-range seed is reported
            # like any other generation failure.
            generator = torch.Generator(self.pipe.device).manual_seed(seed)

            # Generate the image using the pipeline
            outputs = self.pipe(
                prompt=combined_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                guidance_scale=guidance_scale,
                num_inference_steps=inference_steps,
                generator=generator,
                num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
                output_type="pil"
            )

            # Convert the first generated image to base64
            img_base64 = pil_image_to_base64(outputs.images[0])

            # Return the response formatted for Hugging Face Inference Endpoints
            return [{"generated_image": img_base64, "seed": seed}]
        except Exception as e:
            # Log the error and return an error response
            error_message = f"Image generation failed: {str(e)}"
            print(error_message)
            return {"error": error_message}
# For local testing without HF Inference Endpoints
if __name__ == "__main__":
    import argparse
    import uvicorn
    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Run the text-to-image API locally")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
    args = parser.parse_args()

    # Create FastAPI app
    app = FastAPI(title="Text-to-Image API with Content Filtering")

    # Initialize the handler (reads app.conf from the working directory and
    # loads the model before the server starts accepting requests)
    handler = EndpointHandler()

    # The two coroutines must be registered on the app; without the route
    # decorators the server starts but answers every request with 404.
    # NOTE(review): both routes are mounted at "/" — confirm these are the
    # paths local clients expect.
    @app.get("/")
    async def read_root():
        """Health check endpoint."""
        return {"status": "ok", "message": "Text-to-Image API is running"}

    @app.post("/")
    async def generate_image(request: Request):
        """Main inference endpoint: forwards the JSON body to the handler."""
        try:
            body = await request.json()
            result = handler(body)
            # The handler returns a list on success and a dict carrying
            # "error" on failure; check the type explicitly rather than relying
            # on membership testing against the success list.
            if isinstance(result, dict) and "error" in result:
                return JSONResponse(status_code=500, content={"error": result["error"]})
            return result
        except Exception as e:
            return JSONResponse(
                status_code=500,
                content={"error": f"Failed to process request: {str(e)}"}
            )

    # Run the server
    print(f"Starting server on http://{args.host}:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port)