| """ |
| Model Loader Utilities |
| 模型加载工具 - 用于加载各种模型组件 |
| """ |
|
|
| import torch |
| import json |
| import safetensors.torch |
| from typing import Optional |
|
|
|
|
def load_unet_from_safetensors(unet_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load a UNet2DConditionModel from a safetensors weights file.

    Args:
        unet_path: Path to the UNet safetensors weights file.
        config_path: Path to the UNet config JSON file.
        device: Device to move the model to.
        dtype: Data type for the model weights.

    Returns:
        The loaded UNet2DConditionModel, or None if loading fails.
    """
    try:
        from diffusers import UNet2DConditionModel

        # Build the model skeleton from its JSON config, then fill in weights.
        with open(config_path, 'r') as cfg_file:
            config = json.load(cfg_file)
        model = UNet2DConditionModel.from_config(config)

        weights = safetensors.torch.load_file(unet_path)
        model.load_state_dict(weights)
        model.to(device, dtype)
        return model
    except Exception as e:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"Error loading UNet: {e}")
        return None
|
|
|
|
def load_vae_from_safetensors(vae_path: str, config_path: str, device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """
    Load an AutoencoderKL VAE from a safetensors weights file.

    Args:
        vae_path: Path to the VAE safetensors weights file.
        config_path: Path to the VAE config JSON file.
        device: Device to move the model to.
        dtype: Data type for the model weights.

    Returns:
        The loaded AutoencoderKL, or None if loading fails.
    """
    try:
        from diffusers import AutoencoderKL

        # Build the model skeleton from its JSON config, then fill in weights.
        with open(config_path, 'r') as cfg_file:
            config = json.load(cfg_file)
        model = AutoencoderKL.from_config(config)

        weights = safetensors.torch.load_file(vae_path)
        model.load_state_dict(weights)
        model.to(device, dtype)
        return model
    except Exception as e:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"Error loading VAE: {e}")
        return None
|
|
|
|
def create_scheduler(scheduler_type: str = "EulerAncestral", model_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
    """
    Create a noise scheduler for the diffusion process.

    Args:
        scheduler_type: One of "DDPM", "DDIM", "DPMSolverMultistep",
            "EulerAncestral". Any other value falls back to DDPM with a warning.
        model_id: Model ID whose scheduler config is loaded (subfolder "scheduler").

    Returns:
        The scheduler instance, or None if creation fails.
    """
    # Map of supported names to the corresponding diffusers class names.
    scheduler_classes = {
        "DDPM": "DDPMScheduler",
        "DDIM": "DDIMScheduler",
        "DPMSolverMultistep": "DPMSolverMultistepScheduler",
        "EulerAncestral": "EulerAncestralDiscreteScheduler",
    }
    try:
        if scheduler_type not in scheduler_classes:
            print(f"Unsupported scheduler type: {scheduler_type}, using DDPM")
        import diffusers

        # Unknown types fall back to DDPM, matching the warning above.
        cls = getattr(diffusers, scheduler_classes.get(scheduler_type, "DDPMScheduler"))
        return cls.from_pretrained(model_id, subfolder="scheduler")
    except Exception as e:
        print(f"Error creating scheduler: {e}")
        return None
|
|
|
|
def load_qwen_model(model_path: str, device: str = "cuda"):
    """
    Load a Qwen3 embedding model via sentence-transformers.

    Args:
        model_path: Path (or model ID) of the Qwen embedding model.
        device: Device to move the model to.

    Returns:
        The SentenceTransformer model, or None if the dependency is missing
        or loading fails.
    """
    try:
        from sentence_transformers import SentenceTransformer

        embedder = SentenceTransformer(model_path)
        embedder.to(device)
        return embedder
    except ImportError:
        # Callers are expected to fall back to mock embeddings on None.
        print("Warning: sentence-transformers not available. Using mock embeddings.")
        return None
    except Exception as e:
        print(f"Error loading Qwen model: {e}")
        return None
|
|
|
|
def save_model_components(
    unet,
    vae,
    adapter,
    text_encoder,
    save_dir: str,
    save_format: str = "safetensors"
):
    """
    Save model components for training checkpoints.

    NOTE(review): `text_encoder` is accepted but never written to disk —
    presumably it stays frozen during training; confirm with callers.

    Args:
        unet: UNet model (skipped when None).
        vae: VAE model (skipped when None).
        adapter: Qwen embedding adapter (skipped when None).
        text_encoder: Qwen text encoder (currently not saved).
        save_dir: Directory to save components into (created if missing).
        save_format: "safetensors" for safetensors files; anything else
            saves torch ".pt" files.
    """
    import os
    os.makedirs(save_dir, exist_ok=True)

    # Only these three components are persisted; each is optional.
    components = (("unet", unet), ("vae", vae), ("adapter", adapter))

    try:
        if save_format == "safetensors":
            for name, module in components:
                if module is not None:
                    safetensors.torch.save_file(
                        module.state_dict(),
                        os.path.join(save_dir, f"{name}.safetensors")
                    )
        else:
            for name, module in components:
                if module is not None:
                    torch.save(module.state_dict(), os.path.join(save_dir, f"{name}.pt"))

        print(f"Model components saved to {save_dir}")

    except Exception as e:
        # Best-effort: a failed save is reported but never raises.
        print(f"Error saving model components: {e}")
|
|
|
|
def load_unet_with_lora(
    unet_path: str,
    unet_config_path: str,
    lora_weights_path: Optional[str] = None,
    lora_config_path: Optional[str] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet with optional LoRA weights.

    Args:
        unet_path: Path to the base UNet safetensors weights file.
        unet_config_path: Path to the UNet config JSON file.
        lora_weights_path: Optional path to LoRA weights (.safetensors or a
            torch checkpoint file).
        lora_config_path: Optional path to the LoRA config directory. Both
            this and lora_weights_path must be given to apply LoRA.
        device: Device to load the model on.
        dtype: Data type for model weights.

    Returns:
        The UNet (wrapped as a PEFT model when LoRA is applied), or None if
        loading fails.
    """
    try:
        from peft import LoraConfig, get_peft_model, set_peft_model_state_dict

        # Load the base UNet first and bail out early on failure; otherwise
        # the code below would raise a confusing AttributeError on None.
        unet = load_unet_from_safetensors(unet_path, unet_config_path, device, dtype)
        if unet is None:
            print(f"Error loading UNet with LoRA: base UNet failed to load from {unet_path}")
            return None

        # LoRA is optional: both weights and config must be provided.
        if lora_weights_path and lora_config_path:
            print(f"Loading LoRA weights from {lora_weights_path}")

            # safetensors.torch is already imported at module level.
            if lora_weights_path.endswith(".safetensors"):
                lora_state_dict = safetensors.torch.load_file(lora_weights_path)
            else:
                lora_state_dict = torch.load(lora_weights_path, map_location=device)

            lora_config = LoraConfig.from_pretrained(lora_config_path)

            # Wrap the base UNet with PEFT and inject the LoRA weights.
            unet = get_peft_model(unet, lora_config)
            set_peft_model_state_dict(unet, lora_state_dict)

            print("LoRA weights applied to UNet")

        unet.to(device, dtype)
        return unet

    except Exception as e:
        print(f"Error loading UNet with LoRA: {e}")
        return None
|
|
|
|
def load_fused_unet(
    fused_unet_path: str,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16
):
    """
    Load a UNet whose LoRA weights have already been fused into the base model.

    Args:
        fused_unet_path: Path to the fused UNet model directory.
        device: Device to move the model to.
        dtype: Data type for the model weights.

    Returns:
        The fused UNet2DConditionModel, or None if loading fails.
    """
    try:
        from diffusers import UNet2DConditionModel

        model = UNet2DConditionModel.from_pretrained(
            fused_unet_path,
            torch_dtype=dtype
        )
        # from_pretrained already casts the dtype; .to() also places the
        # model on the requested device.
        model.to(device, dtype)
        print(f"Fused UNet loaded from {fused_unet_path}")
        return model

    except Exception as e:
        print(f"Error loading fused UNet: {e}")
        return None
|
|
|
|
def load_checkpoint(checkpoint_path: str, device: str = "cuda"):
    """
    Load a training checkpoint from disk.

    Args:
        checkpoint_path: Path to the checkpoint (.safetensors or a torch file).
        device: Device to map the loaded tensors onto.

    Returns:
        The loaded checkpoint data (a dict-like object), or None on failure.
    """
    try:
        if not checkpoint_path.endswith(".safetensors"):
            # NOTE(security): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources (weights_only is not set).
            return torch.load(checkpoint_path, map_location=device)
        return safetensors.torch.load_file(checkpoint_path, device=device)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None
|
|