"""VLM model loading and LoRA configuration.

Supports multiple VLM architectures (only "qwen2.5-vl" is wired up here).
"""

import torch
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
)
from peft import LoraConfig, get_peft_model, TaskType

from config import ModelConfig

# Pixel bounds forwarded to ``AutoProcessor.from_pretrained`` as
# ``min_pixels`` / ``max_pixels``.
_MIN_PIXELS = 256 * 28 * 28
_MAX_PIXELS = 768 * 28 * 28


def _build_quantization_config(config: ModelConfig):
    """Return a ``BitsAndBytesConfig`` for the requested mode, or ``None``.

    4-bit takes precedence over 8-bit when both flags are set. The import is
    kept local so that ``BitsAndBytesConfig`` is only touched when
    quantization is actually requested.
    """
    if not (config.load_in_4bit or config.load_in_8bit):
        return None

    from transformers import BitsAndBytesConfig

    if config.load_in_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    # Passing ``load_in_8bit=True`` straight to ``from_pretrained`` is
    # deprecated; the flag has to travel inside a quantization config.
    return BitsAndBytesConfig(load_in_8bit=True)


def load_qwen25_vl_model(config: ModelConfig):
    """Load a Qwen2.5-VL model and its processor, optionally wrapped in LoRA.

    Args:
        config: Model configuration. The fields read here are ``model_path``,
            ``load_in_4bit``, ``load_in_8bit``, ``use_lora``, ``lora_r``,
            ``lora_alpha``, ``lora_target_modules`` and ``lora_dropout``.

    Returns:
        A ``(model, processor)`` tuple. ``model`` is a PEFT-wrapped model when
        ``config.use_lora`` is truthy, otherwise the plain loaded model.
    """
    print(f"加载模型: {config.model_path}")

    # NOTE: trust_remote_code=True runs code shipped with the checkpoint;
    # only point ``model_path`` at sources you trust.
    processor = AutoProcessor.from_pretrained(
        config.model_path,
        trust_remote_code=True,
        min_pixels=_MIN_PIXELS,
        max_pixels=_MAX_PIXELS,
    )

    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
    }
    quantization_config = _build_quantization_config(config)
    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config

    # The Auto class is used instead of a specific model class so the
    # concrete architecture is resolved from the checkpoint.
    model = AutoModelForVision2Seq.from_pretrained(
        config.model_path,
        **model_kwargs,
    )

    # Best effort: a model without a ``config`` attribute is tolerated, but
    # the failure is reported instead of being silently swallowed.
    try:
        model.config.use_cache = False
    except AttributeError as err:
        print(f"警告: 无法设置 use_cache=False: {err}")

    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    if config.use_lora:
        print("应用LoRA配置...")
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    return model, processor


def prepare_qwen25_vl_inputs(processor, images: list, text_prompts: list, device) -> dict:
    """Build the model input dict for Qwen2.5-VL from images and prompts.

    Args:
        processor: Processor returned by ``load_qwen25_vl_model``.
        images: One entry per sample; each entry is either a single image or
            a list of images (a frame sequence).
        text_prompts: One text prompt per sample.
        device: Target device every processor output is moved to.

    Returns:
        A dict of the processor outputs moved to ``device``, plus the extra
        key ``"__prompt_texts__"`` holding the chat-templated prompt strings.
        That key is for the caller only and must be removed before the dict
        is passed to ``model.forward``.

    Raises:
        ValueError: If ``images`` and ``text_prompts`` differ in length.
    """
    if len(images) != len(text_prompts):
        # zip() would silently drop the surplus entries while the image list
        # still carried all of them, so fail early with a clear message.
        raise ValueError(
            f"images 与 text_prompts 数量不一致: "
            f"{len(images)} != {len(text_prompts)}"
        )

    # Images must be grouped per sample: a frame sequence stays a list and a
    # single frame becomes a one-element list.
    images_nested = [img if isinstance(img, list) else [img] for img in images]

    # One single-turn user message per sample: all frames first, then the text.
    messages_batch = []
    for frames, prompt in zip(images_nested, text_prompts):
        content = [{"type": "image", "image": frame} for frame in frames]
        content.append({"type": "text", "text": prompt})
        messages_batch.append([{"role": "user", "content": content}])

    # Prompt only (no answer); the caller uses these to build aligned labels.
    texts = [
        processor.apply_chat_template(
            msg, tokenize=False, add_generation_prompt=True
        )
        for msg in messages_batch
    ]

    # A pad token has to exist BEFORE the padded batch is built.
    tok = processor.tokenizer
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    # NOTE(review): truncation=True may cut image placeholder tokens on very
    # long prompts - confirm this is intended.
    inputs = processor(
        text=texts,
        images=images_nested,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Returned for the caller's label construction; not a model input.
    inputs["__prompt_texts__"] = texts
    return inputs


def load_model_and_processor(config: ModelConfig):
    """Load the model and processor that match ``config.model_type``.

    Args:
        config: Model configuration; ``model_type`` selects the loader.

    Returns:
        The ``(model, processor)`` tuple produced by the selected loader.

    Raises:
        ValueError: If ``config.model_type`` is not a supported model type.
    """
    if config.model_type != "qwen2.5-vl":
        raise ValueError(f"不支持的模型类型: {config.model_type}")
    return load_qwen25_vl_model(config)


def prepare_model_inputs(processor, model_type, images, text_prompts, device):
    """Prepare model inputs using the routine that matches ``model_type``.

    Args:
        processor: Processor belonging to the loaded model.
        model_type: Model type identifier, e.g. ``"qwen2.5-vl"``.
        images: One entry per sample (a single image or a list of images).
        text_prompts: One text prompt per sample.
        device: Target device for the resulting inputs.

    Returns:
        The input dict produced by the model-specific preparation routine.

    Raises:
        ValueError: If ``model_type`` is not a supported model type.
    """
    if model_type != "qwen2.5-vl":
        raise ValueError(f"不支持的模型类型: {model_type}")
    return prepare_qwen25_vl_inputs(processor, images, text_prompts, device)