| """ |
| Custom Inference Handler for SigLIP2-base-patch16-512 |
| Supports: zero_shot, image_embedding, text_embedding, similarity |
| Returns 768D embeddings. |
| """ |
import base64
from io import BytesIO
from typing import Any, Dict, List, Union

import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
|
|
|
|
class EndpointHandler:
    """Custom inference handler around a SigLIP2 checkpoint.

    Supported modes (chosen via ``parameters["mode"]``, default ``"auto"``):
      - ``"zero_shot"``:       score image(s) against candidate text labels
      - ``"image_embedding"``: L2-normalised image embeddings
      - ``"text_embedding"``:  L2-normalised text embeddings
      - ``"similarity"``:      image x text cosine-similarity matrix
    """

    def __init__(self, path: str = ""):
        """Load the model/processor from ``path`` onto GPU when available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.eval()

    def _load_image(self, image_data: Any) -> Image.Image:
        """Decode an HTTP(S) URL, base64/data-URI string, or raw bytes to RGB.

        Raises:
            ValueError: if ``image_data`` is neither str nor bytes.
        """
        if isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                response = requests.get(image_data, timeout=10)
                response.raise_for_status()
                return Image.open(BytesIO(response.content)).convert("RGB")
            # Assume base64; strip an optional "data:image/...;base64," prefix.
            if "," in image_data:
                image_data = image_data.split(",")[1]
            return Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
        if isinstance(image_data, bytes):
            return Image.open(BytesIO(image_data)).convert("RGB")
        raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _load_images(self, inputs: Any) -> List[Image.Image]:
        """Normalise a single image spec or a list of them into PIL images."""
        items = inputs if isinstance(inputs, list) else [inputs]
        return [self._load_image(item) for item in items]

    def _get_image_embeddings(self, images: List[Image.Image]) -> torch.Tensor:
        """Return L2-normalised image features, one row per image."""
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def _get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Return L2-normalised text features, one row per text.

        ``padding="max_length"`` is kept because SigLIP checkpoints are
        trained with fixed-length padding.
        """
        inputs = self.processor(text=texts, padding="max_length", truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def _detect_mode(self, inputs: Any, parameters: Dict[str, Any]) -> str:
        """Pick an inference mode heuristically when the caller sent "auto".

        ``candidate_labels`` is an unambiguous zero-shot signal, so it is
        checked first -- previously a dict payload carrying both an image and
        candidate labels was mis-routed to similarity mode.
        """
        if "candidate_labels" in parameters:
            return "zero_shot"
        if isinstance(inputs, dict) and ("image" in inputs or "images" in inputs):
            return "similarity"
        # Short non-URL / non-data-URI strings are treated as text.
        # NOTE(review): raw base64 without a "data:" prefix and shorter than
        # 500 chars is misclassified as text here -- pass mode explicitly.
        if isinstance(inputs, str) and not inputs.startswith(("http", "data:")) and len(inputs) < 500:
            return "text_embedding"
        if isinstance(inputs, list) and all(
            isinstance(i, str) and not i.startswith(("http", "data:")) and len(i) < 500 for i in inputs
        ):
            return "text_embedding"
        return "image_embedding"

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Dispatch a request payload to the selected inference mode.

        Raises:
            ValueError: for an unrecognised ``mode``.
        """
        inputs = data.get("inputs", data)
        parameters = data.get("parameters", {})
        mode = parameters.get("mode", "auto")

        if mode == "auto":
            mode = self._detect_mode(inputs, parameters)

        if mode == "zero_shot":
            return self._zero_shot(inputs, parameters)
        if mode == "image_embedding":
            return self._image_embedding(inputs)
        if mode == "text_embedding":
            return self._text_embedding(inputs)
        if mode == "similarity":
            return self._similarity(inputs)
        raise ValueError(f"Unknown mode: {mode}")

    def _zero_shot(self, inputs, parameters):
        """Score image(s) against candidate labels.

        Returns a list of ``{"label", "score"}`` dicts sorted by descending
        score (a list of such lists when several images are given).
        """
        candidate_labels = parameters.get("candidate_labels", ["photo", "illustration", "diagram"])
        if isinstance(candidate_labels, str):
            candidate_labels = [label.strip() for label in candidate_labels.split(",")]

        # Accept dict payloads ({"image": ...}) as well as bare image specs.
        if isinstance(inputs, dict):
            inputs = inputs.get("image") or inputs.get("images")

        image_embeds = self._get_image_embeddings(self._load_images(inputs))
        text_embeds = self._get_text_embeddings(candidate_labels)

        raw = image_embeds @ text_embeds.T
        # SigLIP is trained with a sigmoid loss on scaled + shifted logits;
        # softmax over raw cosine similarities yields near-uniform scores.
        # Apply the learned temperature/bias when the model exposes them
        # (matches the HF zero-shot-image-classification pipeline for SigLIP);
        # fall back to the old softmax behaviour otherwise.
        logit_scale = getattr(self.model, "logit_scale", None)
        logit_bias = getattr(self.model, "logit_bias", None)
        if logit_scale is not None and logit_bias is not None:
            probs = torch.sigmoid(raw * logit_scale.exp() + logit_bias)
        else:
            probs = torch.softmax(raw, dim=-1)

        results = []
        for prob in probs:
            scores = prob.cpu().tolist()
            results.append(
                [{"label": l, "score": s} for l, s in sorted(zip(candidate_labels, scores), key=lambda x: -x[1])]
            )
        return results[0] if len(results) == 1 else results

    def _image_embedding(self, inputs):
        """Return one ``{"embedding": [...]}`` dict per input image."""
        embeddings = self._get_image_embeddings(self._load_images(inputs))
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _text_embedding(self, inputs):
        """Return one ``{"embedding": [...]}`` dict per input text."""
        texts = [inputs] if isinstance(inputs, str) else inputs
        embeddings = self._get_text_embeddings(texts)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _similarity(self, inputs):
        """Cosine-similarity matrix between image(s) and text(s).

        Raises:
            ValueError: when the payload lacks an image or a text entry.
        """
        image_input = inputs.get("image") or inputs.get("images")
        text_input = inputs.get("text") or inputs.get("texts")
        if image_input is None or text_input is None:
            # Fail fast with a clear message instead of a cryptic error later.
            raise ValueError('similarity mode requires both "image"/"images" and "text"/"texts" entries')

        images = self._load_images(image_input)
        texts = [text_input] if isinstance(text_input, str) else text_input

        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(texts)

        similarity = (image_embeds @ text_embeds.T).cpu().tolist()
        return {"similarity_scores": similarity, "image_count": len(images), "text_count": len(texts)}