| """ |
| Data Loading Utilities for QwenIllustrious |
| 数据加载工具 - 处理训练数据的加载和预处理,支持预计算嵌入加速 |
| """ |
|
|
| import os |
| import json |
| import torch |
| from torch.utils.data import Dataset |
| from PIL import Image |
| import torchvision.transforms as transforms |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
| import pickle |
| from tqdm import tqdm |
|
|
|
|
class QwenIllustriousDataset(Dataset):
    """Dataset of image/caption records for QwenIllustrious training.

    Every ``*.json`` file under ``<dataset_path>/metadata`` describes one
    sample and is paired with ``<dataset_path>/<filename_hash>.png``.

    Features:
        - loads images and captions from the metadata files
        - resizes and normalises images
        - caches text-encoder outputs on disk (``<cache_dir>/text_embeddings``)
        - caches VAE latents on disk (``<cache_dir>/vae_latents``)
        - optionally serves fully precomputed samples from memory
          (see :meth:`precompute_all`)

    Args:
        dataset_path: Root directory holding the images and the ``metadata``
            sub-directory.
        qwen_text_encoder: Optional text encoder. It must provide
            ``parameters()``, ``to()`` and ``encode_prompts(list_of_str)``.
            When ``None``, zero placeholders are returned instead.
        vae: Optional VAE. It must provide ``parameters()``, ``to()``,
            ``dtype``, ``config.scaling_factor`` and
            ``encode(x).latent_dist.sample()``. When ``None``, zero
            placeholders are returned instead.
        image_size: Images are resized to ``image_size x image_size``.
        cache_dir: Optional directory for the on-disk caches.
        precompute_embeddings: When true, ``__getitem__`` returns samples from
            the in-memory store filled by :meth:`precompute_all` whenever the
            sample is present there.

    Raises:
        ValueError: If ``<dataset_path>/metadata`` does not exist.
    """

    def __init__(
        self,
        dataset_path: str,
        qwen_text_encoder=None,
        vae=None,
        image_size: int = 1024,
        cache_dir: Optional[str] = None,
        precompute_embeddings: bool = False
    ):
        self.dataset_path = Path(dataset_path)
        self.qwen_text_encoder = qwen_text_encoder
        self.vae = vae
        self.image_size = image_size
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.precompute_embeddings = precompute_embeddings

        # Square resize, then scale pixel values from [0, 1] to [-1, 1].
        self.image_transforms = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        self.metadata = self._load_metadata()

        if self.cache_dir is not None:
            self.text_cache_dir = self.cache_dir / "text_embeddings"
            self.vae_cache_dir = self.cache_dir / "vae_latents"
            # parents=True so a nested, not-yet-existing cache_dir works too.
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.text_cache_dir.mkdir(parents=True, exist_ok=True)
            self.vae_cache_dir.mkdir(parents=True, exist_ok=True)

        # filename_hash -> precomputed sample, filled by precompute_all().
        self.precomputed_data = {}

    def _load_metadata(self) -> List[Dict]:
        """Read every metadata JSON file and return the records as a list.

        Each record gets two extra keys: ``metadata_file`` (path of the JSON
        file) and ``image_file`` (``<dataset_path>/<filename_hash>.png``).
        Files that cannot be read or parsed, or that lack ``filename_hash``,
        are reported and skipped.

        Raises:
            ValueError: If the metadata directory does not exist.
        """
        metadata_dir = self.dataset_path / "metadata"
        if not metadata_dir.exists():
            raise ValueError(f"Metadata directory not found: {metadata_dir}")

        # glob() order is filesystem dependent; sorting keeps the index ->
        # sample mapping identical across runs and across worker processes.
        metadata_files = sorted(metadata_dir.glob("*.json"))

        metadata_list = []
        print(f"Loading metadata from {len(metadata_files)} files...")

        for file_path in tqdm(metadata_files, desc="Loading metadata"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                data['metadata_file'] = str(file_path)
                data['image_file'] = str(self.dataset_path / f"{data['filename_hash']}.png")
            except (OSError, ValueError, KeyError, TypeError) as e:
                # ValueError covers JSON and encoding errors; KeyError and
                # TypeError cover records without the expected dict layout.
                print(f"Error loading {file_path}: {e}")
                continue
            metadata_list.append(data)

        print(f"Successfully loaded {len(metadata_list)} metadata files")
        return metadata_list

    @staticmethod
    def _get_prompt(metadata: Dict) -> str:
        """Return the natural caption, else the original positive prompt, else ''."""
        # `or {}` tolerates records where the sub-dict is present but null.
        prompt = (metadata.get('natural_caption_data') or {}).get('natural_caption', '')
        if not prompt:
            prompt = (metadata.get('original_prompt_data') or {}).get('positive_prompt', '')
        return prompt or ''

    def _load_image_tensor(self, image_path: str) -> Optional[torch.Tensor]:
        """Load and transform an image; return ``None`` if that fails."""
        try:
            # The context manager closes the file handle; convert() forces the
            # pixel data to be decoded while the file is still open.
            with Image.open(image_path) as img:
                return self.image_transforms(img.convert('RGB'))
        except Exception as e:  # deliberate best-effort: one bad file must not stop training
            print(f"Error loading image {image_path}: {e}")
            return None

    def _blank_image(self) -> torch.Tensor:
        """Return the all-zero stand-in used when an image cannot be loaded."""
        return torch.zeros(3, self.image_size, self.image_size)

    def _get_text_cache_path(self, filename_hash: str) -> Path:
        """Return the cache file path for a sample's text embeddings."""
        return self.text_cache_dir / f"{filename_hash}_text.pt"

    def _get_vae_cache_path(self, filename_hash: str) -> Path:
        """Return the cache file path for a sample's VAE latents."""
        return self.vae_cache_dir / f"{filename_hash}_vae.pt"

    def _compute_text_embeddings(self, prompt: str, device='cpu') -> Dict[str, torch.Tensor]:
        """Encode ``prompt`` with the text encoder and return CPU tensors.

        Returns a dict with ``text_embeddings`` and ``pooled_embeddings``.
        Without an encoder, zero placeholders of shape ``(1, 2048)`` and
        ``(1, 1280)`` are returned.
        """
        if self.qwen_text_encoder is None:
            return {
                'text_embeddings': torch.zeros(1, 2048),
                'pooled_embeddings': torch.zeros(1, 1280)
            }

        with torch.no_grad():
            original_device = next(self.qwen_text_encoder.parameters()).device
            self.qwen_text_encoder.to(device)
            try:
                embeddings = self.qwen_text_encoder.encode_prompts([prompt])
            finally:
                # Put the encoder back even when encoding raises.
                self.qwen_text_encoder.to(original_device)

            return {
                'text_embeddings': embeddings[0].cpu(),
                # Fall back to the first output when no second one is returned.
                'pooled_embeddings': embeddings[1].cpu() if len(embeddings) > 1 else embeddings[0].cpu()
            }

    def _compute_vae_latents(self, image: torch.Tensor, device='cpu') -> torch.Tensor:
        """Encode an image tensor to scaled VAE latents on the CPU.

        A 3-D ``image`` gets a leading batch dimension. Without a VAE, a zero
        placeholder of shape ``(1, 4, image_size // 8, image_size // 8)`` is
        returned.
        """
        if self.vae is None:
            return torch.zeros(1, 4, self.image_size // 8, self.image_size // 8)

        with torch.no_grad():
            original_device = next(self.vae.parameters()).device
            self.vae.to(device)
            try:
                if image.dim() == 3:
                    image = image.unsqueeze(0)
                image = image.to(device).to(self.vae.dtype)
                latents = self.vae.encode(image).latent_dist.sample()
                latents = latents * self.vae.config.scaling_factor
            finally:
                # Put the VAE back even when encoding raises.
                self.vae.to(original_device)

            return latents.cpu()

    def _load_or_compute_text_embeddings(self, prompt: str, filename_hash: str, device='cpu') -> Dict[str, torch.Tensor]:
        """Return cached text embeddings, computing and caching them if needed.

        The cache is keyed by ``filename_hash`` only, so an existing cache
        entry is reused even if the prompt has changed since it was written.
        """
        cache_path = self._get_text_cache_path(filename_hash) if self.cache_dir is not None else None

        if cache_path is not None and cache_path.exists():
            try:
                return torch.load(cache_path, map_location='cpu')
            except Exception as e:  # unreadable cache entry: recompute below
                print(f"Error loading cached text embeddings {cache_path}: {e}")

        embeddings = self._compute_text_embeddings(prompt, device)

        # Placeholder zeros (no encoder) must never reach the cache, otherwise
        # a later run with a real encoder would silently train on them.
        if cache_path is not None and self.qwen_text_encoder is not None:
            try:
                torch.save(embeddings, cache_path)
            except Exception as e:  # caching is an optimisation only
                print(f"Error saving text embeddings cache {cache_path}: {e}")

        return embeddings

    def _load_or_compute_vae_latents(self, image_path: str, filename_hash: str, device='cpu') -> torch.Tensor:
        """Return cached VAE latents, computing and caching them if needed.

        If the image cannot be loaded, latents of an all-zero image are
        returned but not cached.
        """
        cache_path = self._get_vae_cache_path(filename_hash) if self.cache_dir is not None else None

        if cache_path is not None and cache_path.exists():
            try:
                return torch.load(cache_path, map_location='cpu')
            except Exception as e:  # unreadable cache entry: recompute below
                print(f"Error loading cached VAE latents {cache_path}: {e}")

        image = self._load_image_tensor(image_path)
        image_ok = image is not None
        if not image_ok:
            image = self._blank_image()

        latents = self._compute_vae_latents(image, device)

        # Only cache genuine results: not the no-VAE placeholder and not the
        # latents of the blank stand-in image.
        if cache_path is not None and image_ok and self.vae is not None:
            try:
                torch.save(latents, cache_path)
            except Exception as e:  # caching is an optimisation only
                print(f"Error saving VAE latents cache {cache_path}: {e}")

        return latents

    def precompute_all(self, device='cuda'):
        """Compute (or load from cache) embeddings and latents for every sample.

        Results are stored in ``self.precomputed_data`` keyed by
        ``filename_hash`` with the leading batch dimension squeezed away.
        """
        print("Precomputing all embeddings and latents...")

        for metadata in tqdm(self.metadata, desc="Precomputing"):
            filename_hash = metadata['filename_hash']
            prompt = self._get_prompt(metadata)

            text_embeddings = self._load_or_compute_text_embeddings(prompt, filename_hash, device)
            vae_latents = self._load_or_compute_vae_latents(metadata['image_file'], filename_hash, device)

            self.precomputed_data[filename_hash] = {
                'text_embeddings': text_embeddings['text_embeddings'].squeeze(0),
                'pooled_embeddings': text_embeddings['pooled_embeddings'].squeeze(0),
                'latents': vae_latents.squeeze(0),
                'prompt': prompt
            }

        print(f"Precomputation completed for {len(self.precomputed_data)} items")

    def __len__(self):
        """Return the number of successfully loaded metadata records."""
        return len(self.metadata)

    def __getitem__(self, idx) -> Dict[str, object]:
        """Return one sample.

        Precomputed samples carry ``latents`` and no ``images``; all other
        samples carry the transformed ``images`` tensor and no ``latents``.
        Both variants include ``text_embeddings``, ``pooled_embeddings``,
        ``prompts``, ``filename_hash`` and ``metadata``.
        """
        metadata = self.metadata[idx]
        filename_hash = metadata['filename_hash']

        if self.precompute_embeddings and filename_hash in self.precomputed_data:
            data = self.precomputed_data[filename_hash]
            return {
                'text_embeddings': data['text_embeddings'],
                'pooled_embeddings': data['pooled_embeddings'],
                'latents': data['latents'],
                'prompts': data['prompt'],
                'filename_hash': filename_hash,
                'metadata': metadata
            }

        # On-the-fly path: load the image now, text embeddings via the cache.
        image = self._load_image_tensor(metadata['image_file'])
        if image is None:
            image = self._blank_image()

        prompt = self._get_prompt(metadata)
        text_embeddings = self._load_or_compute_text_embeddings(prompt, filename_hash)

        return {
            'images': image,
            'prompts': prompt,
            'text_embeddings': text_embeddings['text_embeddings'].squeeze(0),
            'pooled_embeddings': text_embeddings['pooled_embeddings'].squeeze(0),
            'filename_hash': filename_hash,
            'metadata': metadata
        }
|
|
|
|
def collate_fn(examples: List[Dict]) -> Dict[str, object]:
    """Collate a list of sample dicts into one batch dict.

    The keys of the first sample define the batch layout. A key whose values
    are all tensors is combined with ``torch.stack``; every other key (prompts,
    captions, paths, hashes, metadata dicts, sizes, ...) becomes a plain list
    in sample order. This makes the function usable with every dataset in
    this module, not only with samples that carry ``text_embeddings``.

    Args:
        examples: Samples as returned by a dataset's ``__getitem__``.

    Returns:
        The batch dict, or an empty dict when ``examples`` is empty.

    Raises:
        KeyError: If a later sample lacks a key of the first sample.
        RuntimeError: From ``torch.stack`` if tensors under one key differ in
            shape.
    """
    if not examples:
        return {}

    batch = {}
    for key in examples[0]:
        values = [example[key] for example in examples]
        if all(isinstance(value, torch.Tensor) for value in values):
            batch[key] = torch.stack(values)
        else:
            batch[key] = values
    return batch
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from PIL import Image |
| import json |
| import os |
| from typing import List, Dict, Any, Optional, Tuple, Union |
| import torchvision.transforms as transforms |
| import random |
|
|
|
|
class ImageCaptionDataset(Dataset):
    """Dataset of image-caption pairs read from a JSON or JSONL annotation file.

    Args:
        data_root: Directory that the image paths in the annotations are
            relative to.
        annotations_file: Path to a ``.json`` file (a list of objects) or a
            ``.jsonl`` file (one object per line).
        image_size: Target side length of the square output image.
        center_crop: If true, resize the shorter side to ``image_size`` and
            center-crop; otherwise resize directly to a square.
        random_flip: If true, apply a random horizontal flip (p=0.5).
        caption_column: Annotation key holding the caption text.
        image_column: Annotation key holding the image path.
        max_caption_length: Captions longer than this many characters are
            truncated.

    Raises:
        ValueError: If the annotation file has an unsupported extension or a
            ``.json`` file does not contain a list.
    """

    def __init__(
        self,
        data_root: str,
        annotations_file: str,
        image_size: int = 1024,
        center_crop: bool = True,
        random_flip: bool = True,
        caption_column: str = "caption",
        image_column: str = "image",
        max_caption_length: int = 512
    ):
        self.data_root = data_root
        self.image_size = image_size
        self.caption_column = caption_column
        self.image_column = image_column
        self.max_caption_length = max_caption_length

        self.annotations = self._load_annotations(annotations_file)
        self.image_transforms = self._setup_transforms(image_size, center_crop, random_flip)

        print(f"📚 数据集加载完成: {len(self.annotations)} 个样本")

    def _is_valid_annotation(self, item) -> bool:
        """Return True for a dict with an image entry and a non-blank string caption."""
        if not isinstance(item, dict):
            return False
        if self.caption_column not in item or self.image_column not in item:
            return False
        caption = item[self.caption_column]
        return isinstance(caption, str) and bool(caption.strip())

    def _load_annotations(self, annotations_file: str) -> List[Dict]:
        """Load the annotation file and keep only the usable entries.

        An entry is usable when it is a dict that has both configured columns
        and a caption that is a non-blank string.

        Raises:
            ValueError: For an unsupported file extension or a ``.json`` file
                whose top-level value is not a list.
        """
        if annotations_file.endswith('.json'):
            with open(annotations_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, list):
                # Iterating a dict would silently walk its keys instead of records.
                raise ValueError(
                    f"Expected a JSON list of annotations in {annotations_file}, "
                    f"got {type(data).__name__}"
                )
        elif annotations_file.endswith('.jsonl'):
            with open(annotations_file, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError(f"Unsupported annotation file format: {annotations_file}")

        valid_data = [item for item in data if self._is_valid_annotation(item)]

        print(f"📋 有效样本数: {len(valid_data)} / {len(data)}")
        return valid_data

    def _setup_transforms(self, size: int, center_crop: bool, random_flip: bool):
        """Build the preprocessing pipeline: resize/crop, optional flip, tensor, normalise."""
        transform_list = []

        if center_crop:
            # Resize the shorter side to `size`, then crop the centre square.
            transform_list.extend([
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size)
            ])
        else:
            # Direct resize to a square; does not preserve the aspect ratio.
            transform_list.append(
                transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR)
            )

        if random_flip:
            transform_list.append(transforms.RandomHorizontalFlip(p=0.5))

        # Scale pixel values from [0, 1] to [-1, 1].
        transform_list.extend([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        return transforms.Compose(transform_list)

    def __len__(self):
        """Return the number of usable annotations."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one sample with keys ``images``, ``captions`` and ``image_paths``.

        If the image cannot be opened or decoded, a black image of
        ``image_size x image_size`` is used instead.
        """
        annotation = self.annotations[idx]

        image_path = os.path.join(self.data_root, annotation[self.image_column])
        try:
            # Image.open is lazy. convert() forces the decode inside the try,
            # so corrupt files reach the fallback, and returns a copy that
            # stays valid after the context manager closes the file.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
        except Exception as e:  # deliberate best-effort: one bad file must not stop training
            print(f"⚠️ 加载图像失败 {image_path}: {e}")
            image = Image.new('RGB', (self.image_size, self.image_size), (0, 0, 0))

        image = self.image_transforms(image)

        # Truncation is by character count, not by tokens.
        caption = annotation[self.caption_column][:self.max_caption_length]

        return {
            "images": image,
            "captions": caption,
            "image_paths": image_path
        }
|
|
|
|
class MultiAspectDataset(Dataset):
    """Image-caption dataset that assigns every image to an aspect-ratio bucket.

    At construction time each image is opened once to read its size and is
    assigned the ``(width, height)`` bucket whose ratio is closest to the
    image's ratio. ``__getitem__`` resizes the image to exactly that bucket.

    Args:
        data_root: Directory that the image paths in the annotations are
            relative to.
        annotations_file: Path to a ``.json`` (list of objects) or ``.jsonl``
            (one object per line) annotation file.
        base_size: Stored on the instance; not used by this class itself.
        aspect_ratios: ``(width, height)`` buckets. Defaults to nine buckets
            around 1024x1024.
        bucket_tolerance: Stored on the instance; not used by this class
            itself.
        caption_column: Annotation key holding the caption text.
        image_column: Annotation key holding the image path.
        max_caption_length: Captions longer than this many characters are
            truncated.

    Raises:
        ValueError: If ``aspect_ratios`` is empty, or the annotation file has
            an unsupported extension or wrong top-level type.
    """

    def __init__(
        self,
        data_root: str,
        annotations_file: str,
        base_size: int = 1024,
        aspect_ratios: Optional[List[Tuple[int, int]]] = None,
        bucket_tolerance: float = 0.1,
        caption_column: str = "caption",
        image_column: str = "image",
        max_caption_length: int = 512
    ):
        self.data_root = data_root
        self.base_size = base_size
        self.caption_column = caption_column
        self.image_column = image_column
        self.max_caption_length = max_caption_length

        if aspect_ratios is None:
            aspect_ratios = [
                (1024, 1024),
                (1152, 896),
                (896, 1152),
                (1216, 832),
                (832, 1216),
                (1344, 768),
                (768, 1344),
                (1536, 640),
                (640, 1536),
            ]
        if not aspect_ratios:
            # Fail here rather than with an IndexError on the first image.
            raise ValueError("aspect_ratios must contain at least one (width, height) bucket")

        self.aspect_ratios = aspect_ratios
        self.bucket_tolerance = bucket_tolerance

        self.annotations = self._load_and_bucket_annotations(annotations_file)

        print(f"📚 多长宽比数据集加载完成: {len(self.annotations)} 个样本")
        self._print_bucket_stats()

    def _load_and_bucket_annotations(self, annotations_file: str) -> List[Dict]:
        """Load annotations and attach bucket and original-size information.

        Entries that are not dicts, lack one of the configured columns or
        have a blank or non-string caption are skipped. Each kept entry is
        copied and extended with ``bucket_width``, ``bucket_height``,
        ``original_width`` and ``original_height``. If an image's size cannot
        be read, the entry is kept with the bucket closest to square.

        Raises:
            ValueError: For an unsupported file extension or a ``.json`` file
                whose top-level value is not a list.
        """
        if annotations_file.endswith('.json'):
            with open(annotations_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, list):
                raise ValueError(
                    f"Expected a JSON list of annotations in {annotations_file}, "
                    f"got {type(data).__name__}"
                )
        elif annotations_file.endswith('.jsonl'):
            with open(annotations_file, 'r', encoding='utf-8') as f:
                data = [json.loads(line) for line in f if line.strip()]
        else:
            # Without this branch `data` would be unbound below.
            raise ValueError(f"Unsupported annotation file format: {annotations_file}")

        # Use a configured bucket as fallback (1024x1024 with the defaults),
        # so the fallback is valid even with custom aspect_ratios.
        fallback_bucket = self._find_best_bucket(1.0)

        bucketed_data = []
        for item in data:
            if not isinstance(item, dict):
                continue
            if self.caption_column not in item or self.image_column not in item:
                continue

            caption = item[self.caption_column]
            if not isinstance(caption, str) or not caption.strip():
                continue

            image_path = os.path.join(self.data_root, item[self.image_column])
            try:
                # Reading .size only needs the image header.
                with Image.open(image_path) as img:
                    width, height = img.size
                bucket_w, bucket_h = self._find_best_bucket(width / height)
            except Exception as e:  # deliberate best-effort: keep the sample
                print(f"⚠️ 无法获取图像尺寸 {image_path}: {e}")
                bucket_w, bucket_h = fallback_bucket
                width, height = fallback_bucket

            item_copy = item.copy()
            item_copy['bucket_width'] = bucket_w
            item_copy['bucket_height'] = bucket_h
            item_copy['original_width'] = width
            item_copy['original_height'] = height
            bucketed_data.append(item_copy)

        return bucketed_data

    def _find_best_bucket(self, aspect_ratio: float) -> Tuple[int, int]:
        """Return the bucket whose width/height ratio is closest to ``aspect_ratio``.

        Closeness is the absolute difference of the two ratios; on ties the
        bucket listed first wins.
        """
        best_w, best_h = min(
            self.aspect_ratios,
            key=lambda bucket: abs(aspect_ratio - bucket[0] / bucket[1]),
        )
        return (best_w, best_h)

    def _print_bucket_stats(self):
        """Print how many samples fell into each bucket."""
        bucket_counts = {}
        for item in self.annotations:
            bucket = (item['bucket_width'], item['bucket_height'])
            bucket_counts[bucket] = bucket_counts.get(bucket, 0) + 1

        print("📊 长宽比分布:")
        for bucket, count in sorted(bucket_counts.items()):
            ratio = bucket[0] / bucket[1]
            print(f" {bucket[0]}×{bucket[1]} (比例 {ratio:.2f}): {count} 个样本")

    def _get_transforms(self, target_width: int, target_height: int):
        """Build the preprocessing pipeline for one bucket size.

        The image is resized directly to the bucket (no crop), randomly
        flipped horizontally and normalised from [0, 1] to [-1, 1].
        """
        return transforms.Compose([
            # torchvision's Resize takes (height, width).
            transforms.Resize((target_height, target_width), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        """Return the number of bucketed annotations."""
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one sample resized to its bucket.

        Keys: ``images``, ``captions``, ``image_paths``, ``width``, ``height``
        (the last two are the bucket dimensions). If the image cannot be
        opened or decoded, a black image of the bucket size is used.
        """
        annotation = self.annotations[idx]

        target_width = annotation['bucket_width']
        target_height = annotation['bucket_height']

        image_path = os.path.join(self.data_root, annotation[self.image_column])
        try:
            # Image.open is lazy. convert() forces the decode inside the try,
            # so corrupt files reach the fallback, and returns a copy that
            # stays valid after the context manager closes the file.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
        except Exception as e:  # deliberate best-effort: one bad file must not stop training
            print(f"⚠️ 加载图像失败 {image_path}: {e}")
            image = Image.new('RGB', (target_width, target_height), (0, 0, 0))

        image = self._get_transforms(target_width, target_height)(image)

        # Truncation is by character count, not by tokens.
        caption = annotation[self.caption_column][:self.max_caption_length]

        return {
            "images": image,
            "captions": caption,
            "image_paths": image_path,
            "width": target_width,
            "height": target_height
        }
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def create_dataloader(
    dataset: Dataset,
    batch_size: int = 4,
    shuffle: bool = True,
    num_workers: int = 4,
    pin_memory: bool = True,
    drop_last: bool = True
) -> DataLoader:
    """Wrap ``dataset`` in a ``DataLoader`` that batches with this module's ``collate_fn``.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Passed through to ``DataLoader``.
        num_workers: Passed through to ``DataLoader``.
        pin_memory: Passed through to ``DataLoader``.
        drop_last: Passed through to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "drop_last": drop_last,
        "collate_fn": collate_fn,
    }
    return DataLoader(dataset, **loader_options)
|
|