VLAlert / training /pretrain /model_loader.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
5.43 kB
"""
VLM模型加载和LoRA配置
支持多种VLM架构
"""
import torch
from transformers import (
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer
)
from peft import LoraConfig, get_peft_model, TaskType
from config import ModelConfig
def load_qwen25_vl_model(config: ModelConfig):
    """Load a Qwen2.5-VL model and processor, optionally quantized and LoRA-wrapped.

    Args:
        config: Model settings. This function reads ``model_path``,
            ``load_in_4bit``, ``load_in_8bit`` and ``use_lora``; when LoRA is
            enabled it also reads ``lora_r``, ``lora_alpha``,
            ``lora_target_modules`` and ``lora_dropout``.

    Returns:
        A ``(model, processor)`` tuple. ``model`` is wrapped by
        ``get_peft_model`` when ``config.use_lora`` is truthy.
    """
    print(f"加载模型: {config.model_path}")

    # Pixel bounds forwarded to the processor as min_pixels / max_pixels.
    # NOTE(review): presumably expressed in units of 28x28 patches -- confirm
    # against the Qwen2.5-VL processor documentation.
    min_pixels = 256 * 28 * 28
    max_pixels = 768 * 28 * 28

    processor = AutoProcessor.from_pretrained(
        config.model_path,
        trust_remote_code=True,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )

    # The Auto class is used instead of a concrete model class so the
    # architecture is resolved from the checkpoint.
    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
    }

    # Both quantization modes go through ``quantization_config``; the bare
    # ``load_in_8bit=True`` keyword to ``from_pretrained`` is deprecated.
    # 4-bit takes precedence when both flags are set, as before.
    if config.load_in_4bit:
        from transformers import BitsAndBytesConfig

        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    elif config.load_in_8bit:
        from transformers import BitsAndBytesConfig

        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
        )

    model = AutoModelForVision2Seq.from_pretrained(
        config.model_path,
        **model_kwargs,
    )

    # Best effort: disable the generation cache for training. A failure here
    # is deliberately non-fatal, but it is reported instead of being hidden.
    try:
        model.config.use_cache = False
    except Exception as exc:
        print(f"警告: 无法设置 use_cache=False: {exc}")

    # Both hooks are optional on the loaded model class, hence the guards.
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    if config.use_lora:
        print("应用LoRA配置...")
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    return model, processor
def prepare_qwen25_vl_inputs(processor, images, text_prompts, device):
    """Build batched Qwen2.5-VL model inputs from images and text prompts.

    Args:
        processor: Qwen2.5-VL processor. It must provide
            ``apply_chat_template``, be callable, and expose ``tokenizer``.
        images: One entry per sample. Each entry is either a single image or
            a list of images (a frame sequence).
        text_prompts: One prompt string per sample, aligned with ``images``.
        device: Target passed to ``.to()`` on every processor output.

    Returns:
        A dict of the processor outputs moved to ``device``, plus the extra
        key ``"__prompt_texts__"`` holding the rendered prompt strings (no
        answer included). That key is for the caller to build aligned labels
        and must be removed before the dict is passed to ``model.forward``.

    Raises:
        ValueError: If ``images`` and ``text_prompts`` differ in length.
    """
    images = list(images)
    text_prompts = list(text_prompts)
    # zip() would silently drop the surplus entries and leave the rendered
    # prompts out of step with the image list, so fail loudly instead.
    if len(images) != len(text_prompts):
        raise ValueError(
            f"images 与 text_prompts 数量不一致: "
            f"{len(images)} != {len(text_prompts)}"
        )

    # Images are passed per sample: a single frame becomes a one-element
    # list, a multi-frame sequence is kept as its own list.
    images_nested = [
        img if isinstance(img, list) else [img]
        for img in images
    ]

    # One single-turn user message per sample: all frames, then the prompt.
    messages_batch = []
    for frames, prompt in zip(images_nested, text_prompts):
        content = [{"type": "image", "image": frame} for frame in frames]
        content.append({"type": "text", "text": prompt})
        messages_batch.append([{"role": "user", "content": content}])

    # Render the prompt only (no answer) so labels can be aligned later.
    texts = [
        processor.apply_chat_template(
            msg, tokenize=False, add_generation_prompt=True
        )
        for msg in messages_batch
    ]

    # A pad token must exist BEFORE the padded processor call below;
    # fall back to the EOS token when the tokenizer defines none.
    tok = processor.tokenizer
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    inputs = processor(
        text=texts,
        images=images_nested,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    inputs = {k: v.to(device) for k, v in inputs.items()}
    # For the caller only; not a model.forward argument.
    inputs["__prompt_texts__"] = texts
    return inputs
# # 使用processor处理
# texts = [
# processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
# for msg in messages_batch
# ]
# # 准备所有图像
# all_images = []
# for img in images:
# if isinstance(img, list):
# all_images.extend(img)
# else:
# all_images.append(img)
# # 处理输入
# inputs = processor(
# text=texts,
# images=all_images if all_images else None,
# return_tensors="pt",
# padding=True
# )
# return {k: v.to(device) for k, v in inputs.items()}
def load_model_and_processor(config: ModelConfig):
    """Load the model and processor for the architecture named in ``config``.

    Dispatches on ``config.model_type``; only ``"qwen2.5-vl"`` is handled.

    Args:
        config: Model settings; ``model_type`` selects the loader.

    Returns:
        Whatever ``load_qwen25_vl_model(config)`` returns.

    Raises:
        ValueError: If ``config.model_type`` is not a supported type.
    """
    # Guard clause: reject unknown model types up front.
    if config.model_type != "qwen2.5-vl":
        raise ValueError(f"不支持的模型类型: {config.model_type}")
    return load_qwen25_vl_model(config)
def prepare_model_inputs(processor, model_type, images, text_prompts, device):
    """Prepare model inputs for the architecture named by ``model_type``.

    Dispatches on ``model_type``; only ``"qwen2.5-vl"`` is handled.

    Args:
        processor: Processor forwarded to the model-specific helper.
        model_type: Architecture identifier string.
        images: Per-sample images, forwarded unchanged.
        text_prompts: Per-sample prompts, forwarded unchanged.
        device: Target device, forwarded unchanged.

    Returns:
        Whatever ``prepare_qwen25_vl_inputs`` returns for these arguments.

    Raises:
        ValueError: If ``model_type`` is not a supported type.
    """
    # Guard clause: reject unknown model types up front.
    if model_type != "qwen2.5-vl":
        raise ValueError(f"不支持的模型类型: {model_type}")
    return prepare_qwen25_vl_inputs(processor, images, text_prompts, device)