from typing import Any, List, Tuple, Union

import numpy as np
import torch
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Florence2ForConditionalGeneration


class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
    """Annotate images with Florence-2 and optionally render the predictions.

    The block runs one Florence-2 task (segmentation, detection, captioning,
    OCR, ...) on one or more images and stores the post-processed predictions
    in ``annotations``. Depending on ``annotation_output_type`` it also renders
    them into ``images`` as a black-and-white mask, a mask drawn over the input
    image, or bounding boxes drawn on the input image.
    """

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Florence-2 model and its processor, loaded from the same repo."""
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=Florence2ForConditionalGeneration,
                repo="florence-community/Florence-2-base-ft",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="florence-community/Florence-2-base-ft",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Block inputs.

        ``metadata["mellon"]`` presumably selects the widget used by the Mellon
        UI for each input.
        """
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
                metadata={"mellon": "image"},
            ),
            # NOTE(review): __call__ and get_annotations only handle a single
            # str task; a list would fail at `task + prompt`.
            InputParam(
                "annotation_task",
                type_hint=Union[str, List[str]],
                default="<REFERRING_EXPRESSION_SEGMENTATION>",
                metadata={"mellon": "dropdown"},
                description="""Annotation Task to perform on the image.
                Supported Tasks:

                <OD>
                <REFERRING_EXPRESSION_SEGMENTATION>
                <CAPTION>
                <DETAILED_CAPTION>
                <MORE_DETAILED_CAPTION>
                <DENSE_REGION_CAPTION>
                <REGION_PROPOSAL>
                <CAPTION_TO_PHRASE_GROUNDING>
                <OPEN_VOCABULARY_DETECTION>
                <OCR>
                <OCR_WITH_REGION>

                """,
            ),
            InputParam(
                "annotation_prompt",
                type_hint=Union[str, List[str]],
                required=True,
                metadata={"mellon": "textbox"},
                description="""Annotation Prompt to provide more context to the task.
                Can be used to detect or segment out specific elements in the image.
                A single prompt is applied to every input image.
                """,
            ),
            InputParam(
                "annotation_output_type",
                type_hint=str,
                default="mask_image",
                metadata={"mellon": "dropdown"},
                description="""Output type from annotation predictions. Available options are
                annotation:
                    - raw annotation predictions from the model based on task type.
                mask_image:
                    - black and white mask image for the given image based on the task type
                mask_overlay:
                    - mask drawn in the `fill` color (white by default) over the original image
                bounding_box:
                    - bounding boxes drawn on the original image
                Region tasks (<OD>, <DENSE_REGION_CAPTION>, <REGION_PROPOSAL>,
                <OCR_WITH_REGION>) always use bounding_box; caption tasks and <OCR>
                produce no image.
                """,
            ),
            # NOTE(review): this input is not read by __call__; overlay output
            # is selected with annotation_output_type="mask_overlay".
            InputParam(
                "annotation_overlay",
                type_hint=bool,
                required=True,
                default=False,
                description="",
                metadata={"mellon": "checkbox"},
            ),
            InputParam(
                "fill",
                type_hint=str,
                default="white",
                description=(
                    "Fill color (any PIL color) of the mask drawn for the "
                    "mask_overlay output. mask_image output is always white."
                ),
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Raw predictions plus the optionally rendered images."""
        return [
            OutputParam(
                "annotations",
                type_hint=List[dict],
                description="Annotations Predictions for input Image(s)",
            ),
            OutputParam(
                "images",
                # Rendered images, or None when no image output is produced.
                type_hint=List[Image.Image],
                description="Annotated input Image(s)",
                metadata={"mellon": "image"},
            ),
        ]

    def get_annotations(self, components, images, prompts, task):
        """Run Florence-2 on ``images`` and return one parsed result per image.

        Args:
            components: Provides ``image_annotator`` (the Florence-2 model) and
                ``image_annotator_processor``.
            images: List of PIL images.
            prompts: List of text prompts, one per image; each is appended to
                ``task`` to build the model prompt.
            task: Florence-2 task token, e.g. ``"<OD>"``.

        Returns:
            List with the ``post_process_generation`` output for each image.
        """
        model = components.image_annotator
        processor = components.image_annotator_processor
        task_prompts = [task + prompt for prompt in prompts]

        # Move the processed inputs to the model's device and dtype.
        inputs = processor(text=task_prompts, images=images, return_tensors="pt").to(
            model.device, model.dtype
        )

        # Deterministic beam search (no sampling, 3 beams).
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Special tokens are kept; presumably post_process_generation parses
        # the task/location tokens out of the raw text — verify before changing.
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=False)

        return [
            processor.post_process_generation(
                text, task=task, image_size=(image.width, image.height)
            )
            for image, text in zip(images, decoded)
        ]

    def _iter_polygon_point_sets(self, poly):
        """Yield lists of ``(x, y)`` points for every simple polygon in ``poly``.

        Supported layouts:
          - ``[x1, y1, x2, y2, ...]`` (flat, even length)
          - ``[[x, y], [x, y], ...]`` (list of coordinate pairs)
          - ``[xs, ys]`` (two equal-length numeric coordinate lists)
          - ``{'x': xs, 'y': ys}``
          - numpy arrays holding any of the list layouts
          - nested lists/tuples containing any of the above

        Only polygons with at least 3 points are yielded.
        """
        if poly is None:
            return

        if isinstance(poly, np.ndarray):
            poly = poly.tolist()

        def is_num(v):
            return isinstance(v, (int, float, np.number))

        def is_num_seq(v):
            return isinstance(v, (list, tuple)) and all(is_num(n) for n in v)

        # {'x': xs, 'y': ys}
        if isinstance(poly, dict) and "x" in poly and "y" in poly:
            xs, ys = poly["x"], poly["y"]
            if (
                isinstance(xs, (list, tuple))
                and isinstance(ys, (list, tuple))
                and len(xs) == len(ys)
            ):
                pts = list(zip(xs, ys))
                if len(pts) >= 3:
                    yield pts
            return

        if not isinstance(poly, (list, tuple)):
            return

        # Flat [x1, y1, x2, y2, ...]
        if all(is_num(v) for v in poly):
            if len(poly) >= 6 and len(poly) % 2 == 0:
                yield list(zip(poly[0::2], poly[1::2]))
            return

        # [[x, y], [x, y], ...]
        if all(is_num_seq(v) and len(v) == 2 for v in poly):
            if len(poly) >= 3:
                yield [tuple(v) for v in poly]
            return

        # [xs, ys] — both members must be numeric coordinate lists; otherwise
        # two nested polygons of equal length would be zipped into garbage.
        # NOTE(review): two equal-length flat numeric lists are ambiguous
        # between [xs, ys] and a pair of flat polygons; they are read as [xs, ys].
        if len(poly) == 2 and all(is_num_seq(v) for v in poly):
            xs, ys = poly
            if len(xs) == len(ys) and len(xs) >= 3:
                yield list(zip(xs, ys))
                return

        # Anything else: treat as a container of polygons.
        for part in poly:
            yield from self._iter_polygon_point_sets(part)

    @staticmethod
    def _clamp_points_flat(pts, width, height):
        """Round ``pts`` to pixels inside a ``width`` x ``height`` image and flatten to ``[x1, y1, ...]``."""
        flat = []
        for x, y in pts:
            flat.append(int(round(max(0, min(width - 1, x)))))
            flat.append(int(round(max(0, min(height - 1, y)))))
        return flat

    @staticmethod
    def _bbox_to_rect(bbox):
        """Return ``bbox`` as integer ``(x0, y0, x1, y1)`` with x0 <= x1 and y0 <= y1.

        Returns ``None`` when ``bbox`` does not hold exactly 4 values. Corners
        are reordered because Pillow's ``rectangle`` rejects x1 < x0 / y1 < y0.
        """
        flat = np.array(bbox).flatten().tolist()
        if len(flat) != 4:
            return None
        x0, y0, x1, y1 = (int(round(v)) for v in flat)
        return min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)

    @staticmethod
    def _prediction_items(prediction, key):
        """Return ``prediction[key]``, or an empty list when missing or None."""
        value = prediction.get(key)
        return [] if value is None else value

    def prepare_mask(self, images, annotations, overlay=False, fill="white"):
        """Render predictions as filled shapes, one output image per input image.

        Every ``polygons``, ``bboxes`` and ``quad_boxes`` entry present in a
        prediction is drawn (not only the first kind found).

        Args:
            images: Input PIL images.
            annotations: Per-image results of ``get_annotations``.
            overlay: If True, draw on a copy of the input image; otherwise on a
                black single-channel ("L") canvas of the same size.
            fill: Fill color. On the "L" canvas, any string color is drawn as
                255 (white); non-string values are used as given.

        Returns:
            List of PIL images.
        """
        masks = []
        for image, annotation in zip(images, annotations):
            if overlay:
                mask_image = image.copy()
                mask_fill = fill
            else:
                mask_image = Image.new("L", image.size, 0)
                mask_fill = 255 if isinstance(fill, str) else fill
            draw = ImageDraw.Draw(mask_image)
            width, height = image.width, image.height

            for prediction in annotation.values():
                # Text-only results (captions, OCR) carry nothing to draw.
                if not isinstance(prediction, dict):
                    continue

                for poly in self._prediction_items(prediction, "polygons"):
                    for pts in self._iter_polygon_point_sets(poly):
                        draw.polygon(
                            self._clamp_points_flat(pts, width, height),
                            fill=mask_fill,
                        )

                for bbox in self._prediction_items(prediction, "bboxes"):
                    rect = self._bbox_to_rect(bbox)
                    if rect is not None:
                        draw.rectangle(rect, fill=mask_fill)

                for quad in self._prediction_items(prediction, "quad_boxes"):
                    for pts in self._iter_polygon_point_sets(quad):
                        draw.polygon(
                            self._clamp_points_flat(pts, width, height),
                            fill=mask_fill,
                        )

            masks.append(mask_image)

        return masks

    def prepare_bounding_boxes(self, images, annotations):
        """Draw red boxes (and their labels) on copies of the input images.

        ``bboxes`` are labelled from ``labels``, falling back to
        ``bboxes_labels`` when ``labels`` is empty; ``quad_boxes`` are labelled
        from ``labels`` with the text placed at the quad's centroid.

        Returns:
            List of PIL images, one per input image.
        """
        outputs = []
        for image, annotation in zip(images, annotations):
            canvas = image.copy()
            draw = ImageDraw.Draw(canvas)
            width, height = image.width, image.height

            for prediction in annotation.values():
                # Text-only results (captions, OCR) carry nothing to draw.
                if not isinstance(prediction, dict):
                    continue

                bboxes = self._prediction_items(prediction, "bboxes")
                labels = self._prediction_items(prediction, "labels")
                if len(labels) == 0:
                    labels = self._prediction_items(prediction, "bboxes_labels")

                for i, bbox in enumerate(bboxes):
                    rect = self._bbox_to_rect(bbox)
                    if rect is None:
                        continue
                    draw.rectangle(rect, outline="red", width=3)
                    label = labels[i] if i < len(labels) else ""
                    if label:
                        # Place the label just above the box, kept inside the image.
                        draw.text((rect[0], max(0, rect[1] - 20)), label, fill="red")

                quad_boxes = self._prediction_items(prediction, "quad_boxes")
                quad_labels = self._prediction_items(prediction, "labels")
                for i, quad in enumerate(quad_boxes):
                    for pts in self._iter_polygon_point_sets(quad):
                        flat = self._clamp_points_flat(pts, width, height)
                        try:
                            draw.polygon(flat, outline="red", width=3)
                        except TypeError:
                            # Older Pillow versions reject `width` for polygons.
                            draw.polygon(flat, outline="red")

                        label = quad_labels[i] if i < len(quad_labels) else ""
                        if label:
                            xs, ys = flat[0::2], flat[1::2]
                            cx = int(round(sum(xs) / len(xs)))
                            cy = int(round(sum(ys) / len(ys)))
                            cx = max(0, min(width - 1, cx))
                            cy = max(0, min(height - 1, cy))
                            draw.text((cx, cy), label, fill="red")

            outputs.append(canvas)

        return outputs

    def prepare_inputs(self, images, prompts):
        """Normalize ``images`` and ``prompts`` to sequences of equal length.

        A single image is wrapped in a list; a missing prompt becomes ``""``.
        A single prompt is repeated for every image, so prompt-less tasks work
        with multiple images.

        Raises:
            ValueError: if several prompts are given and their number differs
                from the number of images.
        """
        prompts = prompts or ""

        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(prompts, str):
            prompts = [prompts]

        if len(prompts) == 1 and len(images) > 1:
            prompts = list(prompts) * len(images)

        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match.")

        return images, prompts

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> Tuple[Any, PipelineState]:
        """Run the annotation task and store ``annotations`` and ``images``.

        Region tasks take no text prompt and are always rendered as bounding
        boxes. Text-only tasks take no text prompt and leave ``images`` as None.

        Returns:
            ``(components, state)`` with the block outputs written to ``state``.
        """
        block_state = self.get_block_state(state)
        task = block_state.annotation_task

        box_tasks = (
            "<OD>",
            "<DENSE_REGION_CAPTION>",
            "<REGION_PROPOSAL>",
            "<OCR_WITH_REGION>",
        )
        text_tasks = (
            "<CAPTION>",
            "<DETAILED_CAPTION>",
            "<MORE_DETAILED_CAPTION>",
            "<OCR>",
        )

        render_image = True
        if task in box_tasks:
            block_state.annotation_prompt = ""
            block_state.annotation_output_type = "bounding_box"
        elif task in text_tasks:
            block_state.annotation_prompt = ""
            render_image = False

        images, prompts = self.prepare_inputs(
            block_state.image, block_state.annotation_prompt
        )
        annotations = self.get_annotations(components, images, prompts, task)

        block_state.annotations = annotations
        block_state.images = None

        if render_image:
            output_type = block_state.annotation_output_type
            if output_type == "mask_image":
                block_state.images = self.prepare_mask(images, annotations)
            elif output_type == "mask_overlay":
                block_state.images = self.prepare_mask(
                    images, annotations, overlay=True, fill=block_state.fill
                )
            elif output_type == "bounding_box":
                block_state.images = self.prepare_bounding_boxes(images, annotations)
            # "annotation" (or any other value): only raw predictions are kept.

        self.set_block_state(state, block_state)

        return components, state