| import os |
| import re |
| import torch |
| import torch.distributed as dist |
| from pathlib import Path |
| from diffusers import FluxPipeline |
| from diffusers import FluxTransformer2DModel |
| from torch.utils.data import Dataset, DistributedSampler |
|
|
class PromptDataset(Dataset):
    """Dataset of text prompts, one prompt per non-blank line of a text file."""

    def __init__(self, file_path):
        # Strip surrounding whitespace and drop empty lines.
        with open(file_path, 'r') as handle:
            stripped = (raw.strip() for raw in handle)
            self.prompts = [text for text in stripped if text]

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]
|
|
def sanitize_filename(text, max_length=200):
    """Make *text* safe to use as a file name.

    Characters that are illegal on common filesystems are replaced with
    '_', the result is truncated to *max_length* characters and stripped
    of trailing whitespace; an empty result falls back to "untitled".
    """
    safe = re.sub(r'[\\/:*?"<>|]', '_', text)
    truncated = safe[:max_length].rstrip()
    return truncated if truncated else "untitled"
|
|
def distributed_setup():
    """Initialize the NCCL process group and pin this process to its GPU.

    Reads the torchrun-provided environment variables RANK, LOCAL_RANK
    and WORLD_SIZE (raises KeyError if any is missing).

    Returns:
        tuple[int, int, int]: (rank, local_rank, world_size)
    """
    env = os.environ
    rank, local_rank, world_size = (
        int(env[key]) for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    )

    dist.init_process_group(backend="nccl")
    # Make "cuda" resolve to this process's GPU for all later allocations.
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size
|
|
def main():
    """Distributed Flux text-to-image generation over a prompt file.

    Each rank (launched via torchrun) generates the shard of prompts
    assigned to it by a DistributedSampler and saves one PNG per prompt.
    """
    rank, local_rank, world_size = distributed_setup()

    # TODO(review): placeholder paths — replace with real checkpoint /
    # output locations (or promote to CLI arguments).
    model_path = "CKPT_PATH"
    flux_path = "./ckpt/flux"

    # Load the transformer separately so a fine-tuned checkpoint can be
    # swapped into the base Flux pipeline (pipeline is loaded with
    # transformer=None to avoid loading the base transformer weights).
    transformer = FluxTransformer2DModel.from_pretrained(
        model_path, use_safetensors=True, torch_dtype=torch.float16
    ).to("cuda")
    pipe = FluxPipeline.from_pretrained(
        flux_path, transformer=None, torch_dtype=torch.float16
    ).to("cuda")
    pipe.transformer = transformer

    dataset = PromptDataset("scripts/evaluation/prompt_test.txt")
    # shuffle=False keeps the prompt -> rank assignment deterministic
    # across runs, so outputs are reproducible per rank.
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
    )

    output_dir = Path("IMAGE_SAVE_FOLDER")
    output_dir.mkdir(parents=True, exist_ok=True)

    for idx in sampler:
        prompt = dataset[idx]
        try:
            # Per-prompt seed: reproducible, yet distinct across prompts
            # and ranks.
            generator = torch.Generator(device=f"cuda:{local_rank}")
            generator.manual_seed(42 + idx + rank * 1000)

            image = pipe(
                prompt,
                guidance_scale=3.5,
                height=1024,
                width=1024,
                num_inference_steps=50,
                max_sequence_length=512,
                generator=generator,
            ).images[0]

            # BUG FIX: the original saved every image to the literal name
            # "(unknown).png" (an f-string with no placeholder), leaving
            # the sanitized prompt unused and overwriting a single file.
            # Use the sanitized prompt, prefixed with the dataset index so
            # truncated or duplicate prompts cannot collide.
            filename = sanitize_filename(prompt)
            save_path = output_dir / f"{idx}_{filename}.png"
            image.save(save_path)

            print(f"[Rank {rank}] Generated: {save_path.name}")

        except Exception as e:
            # Best-effort: one failing prompt must not kill the whole
            # distributed run; report and continue with the next prompt.
            print(f"[Rank {rank}] Error processing '{prompt[:20]}...': {str(e)}")

    dist.destroy_process_group()
|
|
| if __name__ == "__main__": |
| main() |