from typing import List, Union

import numpy as np
import torch
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Florence2ForConditionalGeneration


class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
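    """Modular pipeline block that annotates images with Florence-2.

    Depending on `annotation_output_type`, the block returns the raw model
    predictions, a black-and-white inpainting mask, a mask overlaid on the
    original image, or the image with bounding boxes drawn on it.
    """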
    @property
    def expected_components(self):
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=Florence2ForConditionalGeneration,
                repo="florence-community/Florence-2-base-ft",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="florence-community/Florence-2-base-ft",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
            ),
            InputParam(
                "annotation_task",
                type_hint=Union[str, List[str]],
                default="<REFERRING_EXPRESSION_SEGMENTATION>",
                description="""Annotation task to perform on the image.
                Supported tasks:

                <OD>
                <REFERRING_EXPRESSION_SEGMENTATION>
                <CAPTION>
                <DETAILED_CAPTION>
                <MORE_DETAILED_CAPTION>
                <DENSE_REGION_CAPTION>
                <CAPTION_TO_PHRASE_GROUNDING>
                <OPEN_VOCABULARY_DETECTION>
                """,
            ),
            InputParam(
                "annotation_prompt",
                type_hint=Union[str, List[str]],
                description="""Annotation prompt providing more context for the task.
                Can be used to detect or segment out specific elements in the image.
                """,
            ),
            InputParam(
                "annotation_output_type",
                type_hint=str,
                default="mask_image",
                description="""Output type derived from the annotation predictions. Available options:
                annotation:
                    - raw annotation predictions from the model, based on the task type
                mask_image:
                    - black-and-white mask image for the given image, based on the task type
                mask_overlay:
                    - white mask overlaid on the original image
                bounding_box:
                    - bounding boxes drawn on the original image
                """,
            ),
            InputParam(
                "annotation_overlay",
                type_hint=bool,
                default=False,
                description="Whether to draw the annotation mask over a copy of the input image instead of a blank canvas.",
            ),
            InputParam(
                "fill",
                type_hint=str,
                default="white",
                description='Fill color used when drawing annotation regions (any PIL color, e.g. "white").',
            ),
        ]

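    # Illustrative input combination (not exhaustive): segment "the cat" out of
    # an image and return a binary inpainting mask:
    #   image=<PIL image>, annotation_task="<REFERRING_EXPRESSION_SEGMENTATION>",
    #   annotation_prompt="the cat", annotation_output_type="mask_image"
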
    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "mask_image",
                type_hint=Image.Image,
                description="Inpainting mask for the input image(s)",
            ),
            OutputParam(
                "annotations",
                type_hint=dict,
                description="Annotation predictions for the input image(s)",
            ),
            OutputParam(
                "image",
                type_hint=Image.Image,
                description="Annotated input image(s)",
            ),
        ]

    def get_annotations(self, components, images, prompts, task):
        # Florence-2 expects the task token to be prepended to the text prompt.
        task_prompts = [task + prompt for prompt in prompts]

        inputs = components.image_annotator_processor(
            text=task_prompts, images=images, return_tensors="pt"
        ).to(components.image_annotator.device, components.image_annotator.dtype)

        generated_ids = components.image_annotator.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Keep special tokens: the post-processor parses the task and location
        # tokens out of the raw generated text.
        annotations = components.image_annotator_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )
        outputs = []
        for image, annotation in zip(images, annotations):
            outputs.append(
                components.image_annotator_processor.post_process_generation(
                    annotation, task=task, image_size=(image.width, image.height)
                )
            )
        return outputs
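
    # For reference, post-processed annotations are dicts keyed by the task
    # token; the values below are illustrative of the Florence-2 processor's
    # output format:
    #   {"<OD>": {"bboxes": [[x1, y1, x2, y2], ...], "labels": ["cat", ...]}}
    #   {"<REFERRING_EXPRESSION_SEGMENTATION>": {"polygons": [...], "labels": [...]}}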

    def prepare_mask(self, images, annotations, overlay=False, fill="white"):
        masks = []
        for image, annotation in zip(images, annotations):
            # Draw either on a copy of the image (overlay) or on a blank
            # single-channel canvas (binary mask).
            mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)

            for _, _annotation in annotation.items():
                if "polygons" in _annotation:
                    for polygon in _annotation["polygons"]:
                        # Reshape to an (N, 2) vertex array to validate, then
                        # flatten back to [x1, y1, x2, y2, ...] for PIL.
                        polygon = np.array(polygon).reshape(-1, 2)
                        if len(polygon) < 3:
                            continue
                        polygon = polygon.reshape(-1).tolist()
                        draw.polygon(polygon, fill=fill)

                elif "bboxes" in _annotation:
                    for bbox in _annotation["bboxes"]:
                        draw.rectangle(bbox, fill=fill)

            masks.append(mask_image)

        return masks

    def prepare_bounding_boxes(self, images, annotations):
        outputs = []
        for image, annotation in zip(images, annotations):
            image_copy = image.copy()
            draw = ImageDraw.Draw(image_copy)
            for _, _annotation in annotation.items():
                # Detection tasks return parallel lists of boxes and labels.
                for bbox, label in zip(_annotation["bboxes"], _annotation["labels"]):
                    draw.rectangle(bbox, outline="red", width=3)
                    # Place the label just above the box's top-left corner.
                    draw.text((bbox[0], bbox[1] - 20), label, fill="red")

            outputs.append(image_copy)

        return outputs

    def prepare_inputs(self, images, prompts):
        prompts = prompts or ""

        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(prompts, str):
            prompts = [prompts]

        # Broadcast a single prompt across a batch of images.
        if len(prompts) == 1 and len(images) > 1:
            prompts = prompts * len(images)

        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match.")

        return images, prompts

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        images, annotation_task_prompt = self.prepare_inputs(
            block_state.image, block_state.annotation_prompt
        )
        task = block_state.annotation_task
        fill = block_state.fill

        annotations = self.get_annotations(
            components, images, annotation_task_prompt, task
        )
        block_state.annotations = annotations

        # Build an inpainting mask, optionally drawn over the original image.
        if block_state.annotation_output_type == "mask_image":
            block_state.mask_image = self.prepare_mask(
                images, annotations, overlay=block_state.annotation_overlay, fill=fill
            )
        else:
            block_state.mask_image = None

        if block_state.annotation_output_type == "mask_overlay":
            block_state.image = self.prepare_mask(
                images, annotations, overlay=True, fill=fill
            )
        elif block_state.annotation_output_type == "bounding_box":
            block_state.image = self.prepare_bounding_boxes(images, annotations)

        self.set_block_state(state, block_state)

        return components, state
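

# A minimal usage sketch exercising the block's helpers directly, bypassing
# PipelineState. This is an illustration, not part of the block: the file path
# "example.jpg" and the SimpleNamespace stand-in for the components object are
# assumptions; model loading follows standard transformers usage with the same
# checkpoint as the ComponentSpec defaults above.
if __name__ == "__main__":
    from types import SimpleNamespace

    repo = "florence-community/Florence-2-base-ft"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    components = SimpleNamespace(
        image_annotator=Florence2ForConditionalGeneration.from_pretrained(repo).to(device),
        image_annotator_processor=AutoProcessor.from_pretrained(repo),
    )

    block = Florence2ImageAnnotatorBlock()
    image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
    annotations = block.get_annotations(
        components, [image], ["the main subject"], "<REFERRING_EXPRESSION_SEGMENTATION>"
    )
    block.prepare_mask([image], annotations)[0].save("mask.png")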