| """ |
| VLM模型加载和LoRA配置 |
| 支持多种VLM架构 |
| """ |
|
|
| import torch |
| from transformers import ( |
| AutoModelForVision2Seq, |
| AutoProcessor, |
| AutoTokenizer |
| ) |
| from peft import LoraConfig, get_peft_model, TaskType |
| from config import ModelConfig |
|
|
|
|
def load_qwen25_vl_model(config: ModelConfig):
    """Load a Qwen2.5-VL model and its processor, optionally wrapped with LoRA.

    Args:
        config: Model configuration. Attributes read here: ``model_path``,
            ``load_in_4bit``, ``load_in_8bit``, ``use_lora`` and, when LoRA is
            enabled, ``lora_r``, ``lora_alpha``, ``lora_target_modules`` and
            ``lora_dropout``. Optional ``min_pixels`` / ``max_pixels``
            attributes, if present and truthy, override the processor's
            default image pixel bounds.

    Returns:
        A ``(model, processor)`` tuple. ``model`` is wrapped by
        ``peft.get_peft_model`` when ``config.use_lora`` is true, otherwise it
        is the plain transformers model. Gradient checkpointing is enabled
        when the model supports it.
    """
    print(f"加载模型: {config.model_path}")

    # Pixel bounds forwarded to the processor. Defaults keep the previously
    # hard-coded values; a config may override them via optional fields.
    min_pixels = getattr(config, "min_pixels", None) or 256 * 28 * 28
    max_pixels = getattr(config, "max_pixels", None) or 768 * 28 * 28

    processor = AutoProcessor.from_pretrained(
        config.model_path,
        trust_remote_code=True,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )

    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
    }

    # 4-bit takes precedence if both quantization flags are set.
    if config.load_in_4bit:
        from transformers import BitsAndBytesConfig

        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif config.load_in_8bit:
        from transformers import BitsAndBytesConfig

        # Passing ``load_in_8bit`` directly to ``from_pretrained`` is
        # deprecated in transformers; route it through BitsAndBytesConfig,
        # exactly like the 4-bit branch.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForVision2Seq.from_pretrained(config.model_path, **model_kwargs)

    # Disable the KV cache for training (gradient checkpointing is enabled
    # just below). Checked explicitly instead of silently swallowing errors.
    model_config = getattr(model, "config", None)
    if model_config is not None:
        model_config.use_cache = False

    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
    if hasattr(model, "enable_input_require_grads"):
        # Typically required so gradients reach checkpointed layers when the
        # input embeddings are frozen (as they are under LoRA).
        model.enable_input_require_grads()

    if config.use_lora:
        print("应用LoRA配置...")
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    return model, processor
|
|
|
|
def prepare_qwen25_vl_inputs(processor, images, text_prompts, device):
    """Build batched Qwen2.5-VL model inputs from images and text prompts.

    Each sample becomes a single-turn user message. Its content is the
    sample's image(s) followed by its prompt. The processor's chat template
    is applied with ``add_generation_prompt=True``.

    Args:
        processor: Qwen2.5-VL processor. Must provide ``apply_chat_template``,
            be callable, and expose a ``tokenizer``.
        images: One entry per sample. Each entry is either a single image or a
            list of images (e.g. the frames of a sequence).
        text_prompts: One text prompt per sample.
        device: Target passed to ``.to()`` for every processor output that
            supports it.

    Returns:
        dict: Processor outputs moved to ``device``, plus the extra key
        ``"__prompt_texts__"`` holding the templated prompt strings.

    Raises:
        ValueError: If ``images`` and ``text_prompts`` differ in length.
    """
    # Materialize once: the inputs are used twice (messages and processor
    # images), which would silently break for one-shot iterables.
    images = list(images)
    text_prompts = list(text_prompts)
    if len(images) != len(text_prompts):
        # zip() would silently drop samples and misalign texts vs. images.
        raise ValueError(
            "images and text_prompts must have the same length, "
            f"got {len(images)} and {len(text_prompts)}"
        )

    # Normalize every sample to a list of images, so single images and
    # multi-frame sequences share one code path.
    images_nested = [img if isinstance(img, list) else [img] for img in images]

    texts = []
    for frames, prompt in zip(images_nested, text_prompts):
        content = [{"type": "image", "image": frame} for frame in frames]
        content.append({"type": "text", "text": prompt})
        messages = [{"role": "user", "content": content}]
        texts.append(
            processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        )

    # The pad-token fallback must be in place before the processor pads the
    # batch (padding=True below).
    tok = processor.tokenizer
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    # NOTE(review): truncation=True could in principle cut image placeholder
    # tokens for very long prompts. Confirm this is intended.
    inputs = processor(
        text=texts,
        images=images_nested,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Move tensors to the target device; leave any non-tensor entries as-is.
    inputs = {k: (v.to(device) if hasattr(v, "to") else v) for k, v in inputs.items()}
    inputs["__prompt_texts__"] = texts
    return inputs
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def load_model_and_processor(config: ModelConfig):
    """Load the model and processor matching ``config.model_type``.

    Only ``"qwen2.5-vl"`` is currently supported.

    Returns:
        The ``(model, processor)`` tuple produced by the matching loader.

    Raises:
        ValueError: If ``config.model_type`` is not supported.
    """
    if config.model_type == "qwen2.5-vl":
        return load_qwen25_vl_model(config)

    raise ValueError(f"不支持的模型类型: {config.model_type}")
|
|
|
|
def prepare_model_inputs(processor, model_type, images, text_prompts, device):
    """Build model inputs using the input builder for ``model_type``.

    Only ``"qwen2.5-vl"`` is currently supported. All other arguments are
    forwarded unchanged to that model's input builder.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    is_supported = model_type == "qwen2.5-vl"
    if not is_supported:
        raise ValueError(f"不支持的模型类型: {model_type}")

    return prepare_qwen25_vl_inputs(processor, images, text_prompts, device)