import warnings
from typing import Optional, List

import torch
from torch import nn
from torch.nn import functional as F
import transformers
from transformers import PreTrainedModel, AutoModel, AutoModelForCausalLM, GenerationConfig
from transformers import AutoConfig
from transformers.feature_extraction_utils import BatchFeature

from .configuration_dmllm import DMLLMConfig
from .modeling_abstractor import PerceiverProjection
from .modeling_llada import LLaDAModelLM
from .cache import *
from .configuration_llada import LLaDAConfig


def build_vision_model(config, model=None):
    """Return the vision backbone, loading it from ``config.name_or_path`` if needed.

    Args:
        config: vision model config; must expose ``name_or_path``.
        model: optional pre-built vision model. When given it is returned as-is.

    Returns:
        The vision model instance.
    """
    assert hasattr(config, "name_or_path")
    if model is None:
        model = AutoModel.from_pretrained(
            config.name_or_path, config=config, trust_remote_code=True)
    return model


def forward_process(bsz, seq_len, device, eps=1e-3):
    """Sample a token-level random mask for masked-diffusion training.

    A per-sample masking ratio ``p_mask`` is drawn uniformly from ``[eps, 1)``.
    Each token is then masked independently with that probability. At least one
    token per row is always masked, because the smallest random draw of each
    row is forced under the cutoff.

    Callers apply the mask token themselves (the original notes id 126336 as
    the [MASK] token).

    Args:
        bsz: batch size.
        seq_len: sequence length.
        device: torch device for the returned tensors.
        eps: minimum masking probability.

    Returns:
        masked_indices: bool tensor of shape (bsz, seq_len).
        p_mask: float tensor of shape (bsz, 1) holding each sample's ratio.
    """
    b, l = bsz, seq_len
    t = torch.rand(b, device=device)
    p_mask = (1 - eps) * t + eps
    p_mask = p_mask[:, None]
    masked_indices = torch.rand((b, l), device=device)
    # Raise the cutoff to at least the row minimum so every row masks >= 1 token.
    mask_cutoff = torch.max(p_mask, masked_indices.min(-1, keepdim=True).values)
    masked_indices = masked_indices <= mask_cutoff
    return masked_indices, p_mask


def forward_process_blocks(bsz, seq_len, device, block_length=8, eps=1e-3,
                           within_block_randomness=0.2):
    """Sample a block-level random mask for SDAR-v2 style training.

    Tokens are grouped into contiguous blocks of ``block_length``; the last
    block may be shorter. Each block is masked as a unit with the per-sample
    probability ``p_mask``. Afterwards, every masked token is independently
    *unmasked* with probability ``within_block_randomness``. Tokens that were
    not masked are never switched to masked.

    Args:
        bsz: batch size.
        seq_len: sequence length.
        device: torch device for the returned tensors.
        block_length: number of tokens per block.
        eps: minimum masking probability.
        within_block_randomness: per-token probability of unmasking a token
            inside a masked block (default 0.2, the previous hard-coded value).

    Returns:
        masked_indices: bool tensor of shape (bsz, seq_len).
        p_mask: float tensor of shape (bsz,) with each sample's masking
            probability. Note this is 1-D, unlike ``forward_process``.
    """
    t = torch.rand(bsz, device=device)
    p_mask = (1 - eps) * t + eps

    num_blocks = (seq_len + block_length - 1) // block_length
    block_p_mask = p_mask[:, None].expand(bsz, num_blocks)  # (bsz, num_blocks)
    block_mask_decisions = torch.rand(bsz, num_blocks, device=device) < block_p_mask

    # Token j inherits the decision of block j // block_length; trim the
    # padding introduced by a partial final block.
    masked_indices = block_mask_decisions.repeat_interleave(block_length, dim=1)[:, :seq_len]

    # Randomly unmask some tokens inside masked blocks (m ^ (r & m) == m & ~r).
    random_flip = torch.rand(bsz, seq_len, device=device) < within_block_randomness
    masked_indices = masked_indices & ~random_flip
    return masked_indices, p_mask


def create_block_attention_mask(seq_len, block_length, device, batch_size=1, dtype=torch.float):
    """Build an additive block-causal attention mask for SDAR-v2.

    Token i may attend to token j iff ``j // block_length <= i // block_length``.
    Attention is full within a block and causal across blocks.

    Args:
        seq_len: sequence length.
        block_length: number of tokens per block.
        device: torch device for the mask.
        batch_size: leading batch dimension of the returned mask.
        dtype: floating dtype of the returned mask (default ``torch.float``).

    Returns:
        Tensor of shape (batch_size, 1, seq_len, seq_len). It holds 0 where
        attention is allowed and -inf where it is blocked.
    """
    num_blocks = (seq_len + block_length - 1) // block_length

    # Block-level lower-triangular mask, expanded to token level and cropped.
    block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device, dtype=torch.bool))
    token_mask = block_mask.repeat_interleave(block_length, dim=0) \
        .repeat_interleave(block_length, dim=1)
    token_mask = token_mask[:seq_len, :seq_len]

    allowed = token_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1)

    # Both branches get an explicit floating dtype. The previous
    # ``torch.full_like(<bool mask>, -inf)`` inherited the bool dtype, which
    # cannot represent -inf, so blocked positions never received -inf.
    return torch.where(
        allowed,
        torch.zeros_like(allowed, dtype=dtype),
        torch.full_like(allowed, float('-inf'), dtype=dtype),
    )


class DMLLM(PreTrainedModel):
    """Vision-language model: vision encoder -> PerceiverProjection -> LLaDA language model."""

    config_class = DMLLMConfig
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = False
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    accepts_loss_kwargs = False

    def __init__(self, config: DMLLMConfig, language_model=None, vision_model=None, processor=None):
        """Build the vision tower, the vision abstractor and the language model.

        Args:
            config: composite DMLLM config.
            language_model: optional pre-built language model. If omitted, an
                ``LLaDAModelLM`` is created (without weight init) when the
                language model path contains "llada".
            vision_model: optional pre-built vision model. If omitted, it is
                loaded via ``build_vision_model``.
            processor: accepted for interface compatibility; unused here.

        Raises:
            ValueError: if ``language_model`` is None and the configured
                language model is not a LLaDA model.
        """
        super().__init__(config)
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.downsample_ratio = config.downsample_ratio
        self.num_image_token = config.num_image_token
        self.vision_select_layer = config.vision_select_layer
        self.replacement_noise_mode = config.replacement_noise_mode

        try:
            vision_hidden_states = self.config.vision_model_config.hidden_size
        except AttributeError:
            # Some vision configs nest the width under ``vision_config``.
            vision_hidden_states = self.config.vision_model_config.vision_config.hidden_size
        self.config.vision_model_config.hidden_size = vision_hidden_states

        vision_model = build_vision_model(config.vision_model_config, vision_model)
        # pixel_shuffle folds (1/ratio)^2 neighbouring patches into the channel dim.
        vision_abstractor = PerceiverProjection(
            **config.vision_abstractor_config,
            in_dim=self.config.vision_model_config.hidden_size * (int(1 / self.downsample_ratio) ** 2),
            out_dim=self.config.language_model_config.hidden_size)

        if language_model is None:
            if 'llada' in config.language_model_config.name_or_path.lower():
                with transformers.modeling_utils.no_init_weights():
                    language_model = LLaDAModelLM(config.language_model_config)
            else:
                raise ValueError(f"Unsupported language model: {config.language_model_config.name_or_path}")

        self.vision_model = vision_model
        self.vision_abstractor = vision_abstractor
        self.language_model = language_model

    def forward_vision(self, pixel_values):
        """Encode images into language-model-sized vision embeddings.

        Args:
            pixel_values: tensor of shape (n, c, h, w) or (b, tiles, c, h, w),
                or a ``BatchFeature`` holding it under "pixel_values". Images
                whose pixels sum to 0 are treated as dummies and dropped.

        Returns:
            Output of ``self.vision_abstractor`` for the non-dummy images.
        """
        if isinstance(pixel_values, BatchFeature):
            pixel_values = pixel_values["pixel_values"]

        # Flatten tiles into the batch dimension: (b, tiles, c, h, w) -> (b*tiles, c, h, w).
        if pixel_values.dim() == 5:
            b, tiles, c, h, w = pixel_values.shape
            pixel_values = pixel_values.view(b * tiles, c, h, w)

        # All-zero images are padding/dummy images.
        image_flags = (torch.sum(pixel_values, dim=(1, 2, 3)) != 0).long()

        if self.vision_select_layer == -1:
            image_embeddings = self.vision_model.vision_model(
                pixel_values=pixel_values,
            ).last_hidden_state
        else:
            image_embeddings = self.vision_model.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True
            ).hidden_states[self.vision_select_layer]  # (B, N, C)

        vit_embeds = image_embeddings[image_flags == 1]

        if self.downsample_ratio != 1:
            # Assumes N == (image_size // patch_size) ** 2, i.e. no extra
            # CLS token in the sequence. TODO confirm for each backbone.
            patch_num = self.image_size // self.patch_size
            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], patch_num, patch_num, vit_embeds.shape[-1])
            vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
            vit_embeds = vit_embeds.flatten(1, 2)

        vit_embeds = self.vision_abstractor(vit_embeds)
        return vit_embeds

    def prepare_for_lm(self, input_ids, vision_embeds):
        """Embed ``input_ids`` and scatter vision embeddings into image-token slots.

        Positions where ``input_ids == config.image_token_id`` are filled, in
        order, with rows of ``vision_embeds`` flattened to
        (-1, hidden). Count mismatches are tolerated: surplus vision rows are
        truncated, and a shortfall is padded by re-appending the last rows.

        If the scatter still fails, the text embeddings are returned
        unchanged. A zero-weighted term is added so vision parameters stay
        in the autograd graph. A warning is emitted in that case.

        Args:
            input_ids: token ids.
            vision_embeds: output of ``forward_vision`` or None.

        Returns:
            Input embeddings for the language model.
        """
        inputs_embeds = self.get_input_embeddings()(input_ids)
        if vision_embeds is None:
            return inputs_embeds

        original_vision_embeds = vision_embeds
        try:
            vision_mask = input_ids == self.config.image_token_id
            num_slots = torch.count_nonzero(vision_mask).item()
            num_features = vision_embeds.shape[:-1].numel()
            flat_embeds = vision_embeds.reshape(-1, vision_embeds.size(-1))

            if num_slots < num_features:
                flat_embeds = flat_embeds[:num_slots]
            elif num_slots > num_features:
                shortfall = num_slots - num_features
                # NOTE: if the shortfall exceeds num_features this pads too
                # little, masked_scatter raises and the fallback below is used.
                flat_embeds = torch.cat([flat_embeds, flat_embeds[-shortfall:]], dim=0)

            inputs_embeds = torch.masked_scatter(
                inputs_embeds,
                vision_mask.unsqueeze(-1),
                flat_embeds.to(inputs_embeds.dtype).contiguous(),
            )
        except Exception as exc:
            warnings.warn(
                f"prepare_for_lm: could not merge vision embeddings ({exc!r}); "
                "continuing with text-only embeddings."
            )
            # Zero-weighted term keeps vision parameters in the graph.
            inputs_embeds = inputs_embeds + torch.sum(original_vision_embeds[0, 0, :]) * 0.0
        return inputs_embeds

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels.

        (N, W, H, C) -> (N, W*s, H*s, C/s^2), where s = ``scale_factor``.
        """
        x = x.contiguous()
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """Run the multimodal forward pass and return the language model output.

        Vision features, if ``pixel_values`` is given, are merged into the
        token embeddings. The result is passed to ``self.language_model`` with
        ``use_cache=False``, the configured ``replacement_noise_mode``, and
        ``p_mask``/``answer_length`` set to None. Extra ``kwargs`` are
        forwarded unchanged.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_embeds = self.forward_vision(pixel_values) if pixel_values is not None else None
        inputs_embeds = self.prepare_for_lm(input_ids, vision_embeds)

        p_mask = None
        answer_length = None

        # Gradient checkpointing needs an input that requires grad.
        if self.is_gradient_checkpointing and torch.is_grad_enabled():
            inputs_embeds.requires_grad_(True)

        outputs = self.language_model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            return_dict=return_dict,
            labels=labels,
            use_cache=False,
            conversation_ids=None,
            replacement_noise_mode=self.replacement_noise_mode,
            p_mask=p_mask,
            answer_length=answer_length,
            **kwargs,
        )
        return outputs

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing and make LM input embeddings require grad."""
        super().gradient_checkpointing_enable(gradient_checkpointing_kwargs)
        self.language_model.enable_input_require_grads()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate with ``language_model.generate_with_embeds`` from multimodal embeddings.

        Raises:
            NotImplementedError: if the language model is not a LLaDA model.
        """
        vision_embeds = self.forward_vision(pixel_values) if pixel_values is not None else None
        inputs_embeds = self.prepare_for_lm(input_ids, vision_embeds)

        if 'llada' in self.config.language_model_config.name_or_path.lower():
            outputs = self.language_model.generate_with_embeds(
                inputs_embeds=inputs_embeds,
                **generate_kwargs
            )
        else:
            raise NotImplementedError(f"Generation not implemented for model: {self.config.language_model_config.name_or_path}")
        return outputs

    @torch.no_grad()
    def generate_replace_noise(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate via ``language_model.generate_with_embeds_replace_noise``.

        Returns:
            ``(outputs, all_steps_response)`` as returned by the language model.
        """
        vision_embeds = self.forward_vision(pixel_values) if pixel_values is not None else None
        inputs_embeds = self.prepare_for_lm(input_ids, vision_embeds)
        outputs, all_steps_response = self.language_model.generate_with_embeds_replace_noise(
            inputs_embeds=inputs_embeds,
            **generate_kwargs
        )
        return outputs, all_steps_response

    def get_template(self):
        """Return the chat prompt template (Llama-3 style header tokens)."""
        template = dict(
            SYSTEM=("<|start_header_id|>system<|end_header_id|>\n{system}<|eot_id|>\n"),
            INSTRUCTION=("<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"),
            SUFFIX="<|eot_id|>",
            SUFFIX_AS_EOS=True,
            SEP="\n",
            STOP_WORDS=["<|eot_id|>"],
        )
        return template

    def _build_chat_inputs(self, tokenizer, pixel_values, question, history,
                           num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN,
                           IMG_CONTEXT_TOKEN, verbose):
        """Build the prompt and tokenized ``input_ids`` shared by ``chat`` and ``chat_replace_noise``.

        Side effect: sets ``self.img_context_token_id``.

        Returns:
            ``(template, eos_token_id, prompt, input_ids)``. ``prompt`` already
            includes the previous ``history`` text.
        """
        # NOTE(review): the empty-string literals below (the membership test
        # and the placeholder passed to ``replace``) look like image placeholder
        # tags (e.g. ``<image>``) that were stripped. As written, ``'' not in
        # question`` is always False, and ``replace('', ..., 1)`` inserts at the
        # start of the reversed prompt. Confirm against the upstream file.
        if history is None and pixel_values is not None and '' not in question:
            question = '\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = self.get_template()
        eos_token_id = tokenizer.convert_tokens_to_ids(template["SUFFIX"].strip())

        history = "" if history is None else history
        prompt = history + template["INSTRUCTION"].format(input=question)

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        # Replace placeholders from the end of the prompt backwards by working
        # on the reversed string with count=1.
        prompt = prompt[::-1]
        for num_patches in num_patches_list[::-1]:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
            prompt = prompt.replace(''[::-1], image_tokens[::-1], 1)
        prompt = prompt[::-1]

        model_inputs = tokenizer(prompt, return_tensors='pt')
        device = torch.device(self.language_model.device if torch.cuda.is_available() else 'cpu')
        input_ids = model_inputs['input_ids'].to(device)
        return template, eos_token_id, prompt, input_ids

    @torch.no_grad()
    def chat(
        self,
        tokenizer,
        pixel_values,
        question,
        generation_config,
        history=None,
        return_history=False,
        num_patches_list=None,
        IMG_START_TOKEN='',
        IMG_END_TOKEN='',
        IMG_CONTEXT_TOKEN='',
        verbose=False
    ):
        """Answer ``question`` about ``pixel_values`` with a single generation call.

        Args:
            tokenizer: tokenizer used to encode the prompt and decode the output.
            pixel_values: image tensor or None.
            question: user question text.
            generation_config: mapping of kwargs forwarded to ``generate``. It
                is copied, not mutated; ``eos_token_id`` is set on the copy.
            history: previous conversation text, or None for a new chat.
            return_history: if True, also return the updated history.
            num_patches_list: number of image tiles per image placeholder.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: image token strings.
            verbose: print debug information and the response.

        Returns:
            ``response``, or ``(response, history)`` if ``return_history``.
        """
        template, eos_token_id, prompt, input_ids = self._build_chat_inputs(
            tokenizer, pixel_values, question, history, num_patches_list,
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, verbose)

        generation_config = {**generation_config, 'eos_token_id': eos_token_id}
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=False)[0]

        # ``prompt`` already starts with the previous history; appending it to
        # ``history`` again would duplicate earlier turns.
        history = prompt + response
        response = response.split(template["SUFFIX"].strip())[0].strip()
        if return_history:
            return response, history
        if verbose:
            print(response)
        return response

    @torch.no_grad()
    def chat_replace_noise(
        self,
        tokenizer,
        pixel_values,
        question,
        generation_config,
        history=None,
        return_history=False,
        num_patches_list=None,
        IMG_START_TOKEN='',
        IMG_END_TOKEN='',
        IMG_CONTEXT_TOKEN='',
        verbose=False
    ):
        """Like ``chat`` but uses ``generate_replace_noise``.

        Every intermediate step is decoded and printed. Arguments and return
        value match ``chat``; ``generation_config`` is copied, not mutated.
        """
        template, eos_token_id, prompt, input_ids = self._build_chat_inputs(
            tokenizer, pixel_values, question, history, num_patches_list,
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, verbose)

        generation_config = {**generation_config, 'eos_token_id': eos_token_id}
        generation_output, all_steps_response = self.generate_replace_noise(
            pixel_values=pixel_values,
            input_ids=input_ids,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=False)[0]

        decoded_steps = [
            tokenizer.batch_decode(step_response, skip_special_tokens=False)[0]
            for step_response in all_steps_response
        ]
        for i, step_response in enumerate(decoded_steps):
            print(f"Step {i}: {step_response}\n")

        # ``prompt`` already starts with the previous history (see ``chat``).
        history = prompt + response
        response = response.split(template["SUFFIX"].strip())[0].strip()
        if return_history:
            return response, history
        if verbose:
            print(response)
        return response


AutoConfig.register("dmllm", DMLLMConfig)
AutoModel.register(DMLLMConfig, DMLLM)