| """Processor class for MarkupDM.""" |
|
|
| import math |
| import re |
| import shutil |
| import subprocess |
| import tempfile |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from .fonts import FontManager |
| from PIL import Image, ImageDraw |
| from transformers import ( |
| ImageProcessingMixin, |
| PreTrainedModel, |
| PreTrainedTokenizerBase, |
| ProcessorMixin, |
| ) |
| from transformers.utils import logging |
|
|
| logger = logging.get_logger(__name__) |
|
|
# Upper bound (in pixels) for the long edge of any decoded image.
MAXIMUM_DECODE_IMAGE_SIZE = 4096
# Filename templates for images/fonts extracted alongside the SVG.
IMG_FORMAT = "{:03d}.png"
FONT_FORMAT = "{:03d}.ttf"
|
|
|
class MarkupDMProcessor(ProcessorMixin):
    """Processor for MarkupDM: converts SVG markup plus its referenced
    raster images into a single token sequence, and back.

    Each image is replaced in the token stream by a span of the form
    ``<begin_of_image>W<image_sep>H<image_sep><codes...><end_of_image>``,
    where the codes are discrete VQ ids produced by a vision model and
    shifted past the text vocabulary.
    """

    attributes = ["tokenizer", "image_processor"]

    # Resolved through the transformers Auto* machinery.
    tokenizer_class = "AutoTokenizer"
    tokenizer: PreTrainedTokenizerBase

    image_processor_class = "AutoImageProcessor"
    image_processor: ImageProcessingMixin

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        image_processor: ImageProcessingMixin,
    ):
        super().__init__(tokenizer, image_processor)

        # Add the image special tokens unless a previous run already did.
        if "<begin_of_image>" not in tokenizer.additional_special_tokens:
            self.extend_base_tokenizer(self.tokenizer)

        boi = "<begin_of_image>"
        img_sep = "<image_sep>"
        # Parses "<begin_of_image>W<image_sep>H<image_sep>" -> (W, H).
        self.re_img_size = re.compile(rf"{boi}(\d+){img_sep}(\d+){img_sep}")
        self.re_svg_width = re.compile(r'<svg[^>]*\bwidth="(\d+)"[^>]*>')
        self.re_svg_height = re.compile(r'<svg[^>]*\bheight="(\d+)"[^>]*>')

        # Set lazily via `set_font_manager`; only needed for rendering.
        self.font_manager = None

    def extend_base_tokenizer(self, tokenizer: PreTrainedTokenizerBase) -> None:
        """Register the image-delimiter special tokens on `tokenizer`."""
        logger.info("Extending tokenizer...")
        tokenizer.clean_up_tokenization_spaces = False

        additional_special_tokens = [
            "<begin_of_image>",
            "<end_of_image>",
            "<image_sep>",
            "<image_token>",
        ]
        logger.info(f"Add special tokens: {additional_special_tokens}")
        tokenizer.add_special_tokens(
            {"additional_special_tokens": additional_special_tokens},
            replace_additional_special_tokens=False,
        )

    def __call__(
        self,
        svg: str | None = None,
        images: list[Image.Image] | None = None,
        filenames: list[str] | None = None,
        vision_model: PreTrainedModel | None = None,
    ) -> dict:
        """Tokenize `svg`, inlining each image (matched by its filename in
        `filenames`) as a span of VQ codes produced by `vision_model`.

        Returns a dict with "input_ids", "image_mask" and "image_pos_ids".
        """
        if not isinstance(images, list):
            images = [images]

        if len(images) > 0 and images[0] is not None:
            output = self.preprocess_images(images)
            output = self.encode_images(output, vision_model)
        else:
            # No images: keep the structure `tokenize_example` expects.
            output = {"width": [], "height": [], "image_ids": []}

        output.update({"svg": svg, "filenames": filenames})
        output = self.tokenize_example(output)

        return output

    def preprocess_images(self, images: list[Image.Image]) -> dict:
        """Run the image processor on each image and batch the results.

        Returns {"image": stacked tensor, "width": [...], "height": [...]}
        where width/height are the original (pre-resize) sizes.
        """
        assert images is not None, "Images must be provided."
        output: dict = {"image": [], "width": [], "height": []}

        for image in images:
            processed = self.image_processor(image)
            for key, value in processed.items():
                output[key].append(value)

        output["image"] = torch.stack(output["image"])

        return output

    def encode_images(self, example: dict, vision_model: PreTrainedModel) -> dict:
        """Quantize preprocessed images into discrete ids ("image_ids")."""
        if "images" in example and "width" not in example:
            example = self.preprocess_images(example["images"])

        assert vision_model is not None, "Vision model must be provided."
        image = example.pop("image")
        image = image.to(dtype=vision_model.dtype, device=vision_model.device)
        with torch.inference_mode():
            _, _, (_, _, image_ids) = vision_model.model.encode(image)
        # One flat id sequence per image, moved to CPU.
        example["image_ids"] = list(image_ids.view(image.size(0), -1).cpu())

        return example

    def tokenize_example(self, example: dict) -> dict:
        """Build input ids for an SVG whose image filename references are
        replaced by image-token spans.

        VQ ids are shifted by `len(tokenizer)` so they occupy a disjoint id
        range; `image_mask` marks those positions and `image_pos_ids`
        numbers each position within its own image.
        """
        for key in ["svg", "filenames", "width", "height", "image_ids"]:
            msg = f"Missing key: {key}."
            if key in ["width", "height", "image_ids"]:
                msg += " Images must be encoded first using `encode_images`."
            assert example.get(key, None) is not None, msg

        tokenizer = self.tokenizer
        bos_id = tokenizer.bos_token_id
        eos_id = tokenizer.eos_token_id
        # Some tokenizers have no BOS token; fall back to EOS.
        bos_id = bos_id if bos_id is not None else eos_id
        boi_id = tokenizer.convert_tokens_to_ids("<begin_of_image>")
        eoi_id = tokenizer.convert_tokens_to_ids("<end_of_image>")
        img_sep_id = tokenizer.convert_tokens_to_ids("<image_sep>")

        # Precompute the token span that stands in for each image file.
        name2token = {}
        for filename, image_ids, width, height in zip(
            example["filenames"],
            example["image_ids"],
            example["width"],
            example["height"],
        ):
            # Shift VQ ids past the text vocabulary.
            _image_ids = (image_ids + len(tokenizer)).tolist()
            W_tokens = tokenizer.encode(str(width))
            H_tokens = tokenizer.encode(str(height))

            image_tokens = [
                boi_id,
                *W_tokens,
                img_sep_id,
                *H_tokens,
                img_sep_id,
                *_image_ids,
                eoi_id,
            ]

            name2token[filename] = image_tokens

        # Scan the SVG text, splicing in image spans at filename mentions.
        tokens = [bos_id]
        svg = example["svg"]
        while svg:
            # Earliest filename occurrence in the remaining text.
            start, end = len(svg), len(svg)
            for name in name2token.keys():
                _start = svg.find(name)
                if -1 < _start and _start < start:
                    start = _start
                    end = start + len(name)

            tokens += tokenizer.encode(svg[:start])

            if start < end:
                tokens += name2token[svg[start:end]]

            svg = svg[end:]

        tokens.append(eos_id)

        input_ids = torch.tensor(tokens)
        # Positions holding (shifted) image code ids.
        image_mask = input_ids >= len(tokenizer)

        # 0..length-1 inside each image span, 0 elsewhere. Assumes every
        # image contributes the same number of code ids.
        image_pos_ids = torch.zeros_like(input_ids)
        if len(example["image_ids"]) > 0:
            length = example["image_ids"][0].size(0)
            num_images = int(image_mask.sum().item()) // length
            image_pos_ids[image_mask] = torch.arange(length).repeat(num_images)

        return {
            "input_ids": input_ids,
            "image_mask": image_mask,
            "image_pos_ids": image_pos_ids,
        }

    def decode(
        self,
        tokens: torch.Tensor | np.ndarray,
        vision_model: PreTrainedModel | None = None,
    ) -> dict:
        """Inverse of `__call__`: recover the SVG text and its images.

        Each image-token span is decoded with `vision_model` (or replaced by
        a placeholder when no model is given) and written back into the SVG
        as a sequentially numbered PNG filename.
        """
        tokenizer = self.tokenizer
        bos = tokenizer.bos_token
        eos = tokenizer.eos_token
        bos = bos if bos is not None else eos

        # FIM sentinels must not survive into decoding.
        msg = "Should be reverted from FIM format before decoding."
        for fim_type in ["prefix", "middle", "suffix"]:
            token_id = tokenizer.convert_tokens_to_ids(f"<fim_{fim_type}>")
            if token_id is None:
                token_id = tokenizer.convert_tokens_to_ids(f"<|fim_{fim_type}|>")
            assert token_id is not None, f"{fim_type} token not found"
            assert token_id not in tokens, msg

        tokens = torch.asarray(tokens).detach().cpu()
        assert tokens.ndim == 1, "Tokens must be 1D."
        boi_id = tokenizer.convert_tokens_to_ids("<begin_of_image>")
        eoi_id = tokenizer.convert_tokens_to_ids("<end_of_image>")

        svg = ""
        images: list = []
        filenames: list = []
        while len(tokens) > 0:
            # Locate the next image span; an unterminated span (no
            # <end_of_image>) runs to the end of the sequence.
            boi_idx = torch.where(tokens == boi_id)[0]
            eoi_idx = torch.where(tokens == eoi_id)[0]
            if boi_idx.size(0) > 0:
                start = int(boi_idx[0].item())
                end = int(eoi_idx[0].item()) + 1 if eoi_idx.size(0) > 0 else len(tokens)
                assert start < end, "Invalid image tokens."
            else:
                start, end = len(tokens), len(tokens)

            # Text before the image span decodes as-is.
            svg += tokenizer.decode(tokens[:start])

            if start < end:
                image_tokens = tokens[start:end]
                image_text = tokenizer.decode(image_tokens)
                matched = self.re_img_size.match(image_text)
                if matched is not None:
                    width, height = map(int, matched.groups())
                else:
                    # Size header missing/garbled; fall back to model size.
                    width = self.image_processor.size
                    height = self.image_processor.size

                # Ids above the text vocab are the (shifted) VQ codes.
                image_mask = image_tokens >= len(tokenizer)
                image_ids = image_tokens[image_mask] - len(tokenizer)
                image = self.decode_image(vision_model, image_ids, width, height)
                filename = IMG_FORMAT.format(len(images))
                svg += filename

                images.append(image)
                filenames.append(filename)

            tokens = tokens[end:]

        # Collapse repeated BOS/EOS markers, then strip everything outside
        # the first BOS ... first EOS window.
        svg = re.sub(rf"({re.escape(bos)})+", bos, svg)
        svg = re.sub(rf"({re.escape(eos)})+", eos, svg)

        i_bos = svg.find(bos)
        if i_bos > -1:
            svg = svg[i_bos + len(bos) :]
        # Fix: search the already-sliced string from position 0. The old
        # code searched from `i_bos + 1`, an offset into the *pre-slice*
        # string, which could skip over an early EOS.
        i_eos = svg.find(eos)
        if i_eos > -1:
            svg = svg[:i_eos]

        return {"svg": svg, "images": images, "filenames": filenames}

    def decode_image(
        self,
        vision_model: PreTrainedModel | None = None,
        image_ids: torch.Tensor | np.ndarray | None = None,
        width: int | None = None,
        height: int | None = None,
        dummy_color: tuple[int, int, int, int] = (200,) * 4,
        pad_value: int = 0,
    ) -> Image.Image:
        """Decode one image's VQ codes into a PIL image of (width, height).

        Code sequences shorter than a full latent grid are padded with
        `pad_value` and the corresponding pixel area is masked gray.
        """
        width = width or self.image_processor.size
        height = height or self.image_processor.size
        width, height = self.compute_safe_image_size(width, height)

        # Fix: placeholder whenever either the model or the codes are
        # missing. The previous `and` crashed when `decode` passed real ids
        # with no vision model, despite the Optional signature.
        if vision_model is None or image_ids is None:
            return Image.new("RGBA", (width, height), dummy_color)

        # Spatial size of the latent grid the decoder expects.
        scale_factor = 2 ** (vision_model.model.encoder.num_resolutions - 1)
        latent_size = self.image_processor.size // scale_factor
        required_length = latent_size**2

        # Pad truncated code sequences up to a full latent grid.
        image_ids = torch.asarray(image_ids, device=vision_model.device)
        code_length = image_ids.shape[0]
        if code_length < required_length:
            pad_size = required_length - code_length
            pad = torch.full((pad_size,), pad_value).to(image_ids)
            image_ids = torch.cat([image_ids, pad])

        with torch.inference_mode():
            codebook_entry = vision_model.model.quantize.get_codebook_entry(
                image_ids, (1, latent_size, latent_size, -1)
            )
            recon = vision_model.model.decode(codebook_entry)[0].float()

        img = self.image_processor.postprocess(
            recon, self.image_processor.size, self.image_processor.size
        )

        # Gray out pixels reconstructed from padded (fake) codes.
        if code_length < required_length:
            img = self.mask_padded_area(img, code_length, scale_factor)

        img = img.resize((width, height), resample=self.image_processor.resample)

        return img

    def compute_safe_image_size(self, width: int, height: int) -> tuple[int, int]:
        """Shrink (width, height) so the long edge fits the decode cap,
        preserving aspect ratio with a minimum of 1px per side."""
        long_edge = max(width, height)
        if MAXIMUM_DECODE_IMAGE_SIZE < long_edge:
            scale = MAXIMUM_DECODE_IMAGE_SIZE / long_edge
            width = min(max(int(width * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
            height = min(max(int(height * scale), 1), MAXIMUM_DECODE_IMAGE_SIZE)
        return width, height

    def mask_padded_area(
        self,
        img: Image.Image,
        code_length: int,
        scale_factor: int,
        fill: tuple[int, int, int, int] = (200, 200, 200, 255),
    ) -> Image.Image:
        """Fill the trailing (padded) region of a decoded image with gray.

        Codes map row-major onto a grid of `scale_factor`-pixel cells; the
        polygon covers every cell from index `code_length` onward.
        """
        draw = ImageDraw.Draw(img, mode="RGBA")
        width, height = img.size
        zw = math.ceil(width / scale_factor)  # cells per row
        cw = code_length % zw  # column of the first padded cell
        ch = code_length // zw  # row of the first padded cell
        draw.polygon(
            [
                (cw * scale_factor, ch * scale_factor),
                (width, ch * scale_factor),
                (width, height),
                (0, height),
                (0, (ch + 1) * scale_factor),
                (cw * scale_factor, (ch + 1) * scale_factor),
            ],
            fill=fill,
        )
        return img

    def set_font_manager(self, fonts_path: str | None = None) -> None:
        """Create the FontManager used by `render_preprocess`."""
        self.font_manager = FontManager(fonts_path)

    def render_preprocess(self, example: dict, out_dir: str | Path) -> None:
        """Inline @font-face rules into the SVG, writing the matching font
        files into `out_dir` so the renderer can resolve them relatively.
        """
        msg = "Font manager is not set. Call `set_font_manager` first."
        assert self.font_manager is not None, msg

        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        svg = example["svg"]

        # Collect one @font-face per unique (family, weight, style).
        found = set()
        style_text = "text{dominant-baseline:text-before-edge}"
        for i, text_str in enumerate(re.findall("<text[^>]*>", svg)):
            matched = re.search('font-family="([^"]*)"', text_str)
            if matched is None:
                logger.warning(f"Font family not found in {text_str}")
                continue

            font_family = matched.group(1)
            is_bold = 'font-weight="bold"' in text_str
            is_italic = 'font-style="italic"' in text_str
            font_weight = "bold" if is_bold else "regular"
            if is_italic:
                font_style = "bolditalic" if is_bold else "italic"
            else:
                font_style = font_weight
            key = (font_family, font_weight, font_style)
            if key in found:
                continue

            font_bytes = self.font_manager.lookup(
                font_family=font_family,
                font_weight=font_weight,
                font_style=font_style,
            )

            # The font file is named after the <text> element's index and
            # referenced by relative URL from the generated stylesheet.
            font_path = FONT_FORMAT.format(i)
            font_face = "@font-face{"
            font_face += f"font-family:'{font_family}';"
            font_face += f"font-weight:{font_weight};"
            font_face += f"font-style:{font_style};"
            font_face += f"src:url('{font_path}');"
            font_face += "}"
            style_text += font_face

            Path(f"{out_dir}/{font_path}").write_bytes(font_bytes)
            found.add(key)

        # Inject the <style> right after the opening <svg> tag.
        matched = re.search("<svg[^>]*>", svg)
        assert matched is not None, "SVG tag not found"
        i = matched.span()[1]
        style = f"<style>{style_text}</style>"
        example["svg"] = svg[:i] + style + svg[i:]

    def render(self, example: dict, save_dir: str | Path | None = None) -> Image.Image:
        """Rasterize `example["svg"]` (with its images and fonts) through
        headless Chrome and return the screenshot as a PIL image.

        When `save_dir` is given, the working directory (HTML, images,
        fonts, screenshot) is copied there for inspection.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            self.render_preprocess(example, tmp_dir)

            # Canvas size comes from the <svg> width/height attributes.
            matched = self.re_svg_width.search(example["svg"])
            assert matched is not None, "Width not found in SVG."
            width = int(matched.group(1))
            matched = self.re_svg_height.search(example["svg"])
            assert matched is not None, "Height not found in SVG."
            height = int(matched.group(1))

            html = '<!DOCTYPE html><html><body style="margin: 0px">'
            html += f"{example['svg']}</body></html>"

            Path(f"{tmp_dir}/index.html").write_text(html, encoding="utf-8")

            # Fix: save each image under its referenced filename. The old
            # loop unpacked `filename` but never used it, writing every
            # image to one hard-coded path so the SVG references dangled.
            for img, filename in zip(example["images"], example["filenames"]):
                Path(f"{tmp_dir}/{filename}").parent.mkdir(parents=True, exist_ok=True)
                img.save(f"{tmp_dir}/{filename}")

            command = [
                "google-chrome",
                "--headless",
                "--disable-web-security",
                "--allow-running-insecure-content",
                "--no-sandbox",
                "--disable-infobars",
                "--hide-scrollbars",
                "--disable-dev-shm-usage",
                "--no-zygote",
                f"--window-size={width},{height}",
                f"--screenshot={tmp_dir}/screenshot.png",
                f"{tmp_dir}/index.html",
            ]
            subprocess.run(command, check=True, stderr=subprocess.DEVNULL)

            out = Image.open(f"{tmp_dir}/screenshot.png")
            size = (width, height)
            # resize() also forces a full load before tmp_dir is removed.
            out = out.resize(size, resample=Image.Resampling.LANCZOS)

            if save_dir is not None:
                shutil.copytree(tmp_dir, save_dir, dirs_exist_ok=True)

        return out
|
|