| from typing import AnyStr |
| import pathlib |
| from collections import OrderedDict |
| from packaging import version |
|
|
| import torch |
| from diffusers import StableDiffusionPipeline, SchedulerMixin, DiffusionPipeline |
| from diffusers.utils import is_torch_version, is_xformers_available |
|
|
# Short model aliases mapped to their checkpoint locations (local NFS paths
# for the v1 models, Hugging Face Hub repo ids for the rest). Built from an
# explicit pair list so the alias order is unambiguous.
huggingface_model_dict = OrderedDict([
    ("sd14", "/nfs/StableDiffusionModels/CompVis/stable-diffusion-v1-4"),
    ("sd15", "/nfs/StableDiffusionModels/runwayml/stable-diffusion-v1-5"),
    ("sd21b", "stabilityai/stable-diffusion-2-1-base"),
    ("sd21", "stabilityai/stable-diffusion-2-1"),
    ("sdxl", "stabilityai/stable-diffusion-xl-base-1.0"),
])
|
|
# Native training resolution for each known model alias.
_model2resolution = dict(
    sd14=512,
    sd15=512,
    sd21b=512,
    sd21=768,
    sdxl=1024,
)


def model2res(model_id: str):
    """Return the native image resolution for *model_id*, defaulting to 512 for unknown ids."""
    try:
        return _model2resolution[model_id]
    except KeyError:
        return 512
|
|
|
|
def init_diffusion_pipeline(model_id: AnyStr,
                            custom_pipeline: StableDiffusionPipeline,
                            custom_scheduler: SchedulerMixin = None,
                            device: torch.device = "cuda",
                            torch_dtype: torch.dtype = torch.float32,
                            local_files_only: bool = True,
                            force_download: bool = False,
                            resume_download: bool = False,
                            ldm_speed_up: bool = False,
                            enable_xformers: bool = True,
                            gradient_checkpoint: bool = False,
                            lora_path: AnyStr = None,
                            unet_path: AnyStr = None) -> StableDiffusionPipeline:
    """
    A tool for initializing a diffusers pipeline.

    Args:
        model_id (`str` or `os.PathLike`, *optional*): pretrained_model_name_or_path,
            or a short alias resolved through ``huggingface_model_dict``
        custom_pipeline: any StableDiffusionPipeline pipeline class
            (its ``from_pretrained`` classmethod is called)
        custom_scheduler: any scheduler class; when given, it is loaded from the
            model's "scheduler" subfolder and installed on the pipeline
        device: set device
        torch_dtype: dtype for the loaded weights
        local_files_only: prohibit downloading the model
        force_download: force model download
        resume_download: resume an interrupted download
        ldm_speed_up: use the `torch.compile` api to speed up unet
        enable_xformers: enable memory efficient attention from [xFormers]
        gradient_checkpoint: activates gradient checkpointing for the current model
        lora_path: load LoRA checkpoint
        unet_path: load unet checkpoint

    Returns:
        diffusers.StableDiffusionPipeline
    """
    # Resolve short aliases (e.g. "sd15") to a repo id / local path;
    # unknown ids pass through unchanged.
    model_id = huggingface_model_dict.get(model_id, model_id)

    # Build the pipeline once — only the scheduler kwarg differs between the
    # custom-scheduler and default cases.
    pipeline_kwargs = dict(
        torch_dtype=torch_dtype,
        local_files_only=local_files_only,
        force_download=force_download,
        resume_download=resume_download,
    )
    if custom_scheduler is not None:
        pipeline_kwargs["scheduler"] = custom_scheduler.from_pretrained(
            model_id,
            subfolder="scheduler",
            local_files_only=local_files_only,
        )
    pipeline = custom_pipeline.from_pretrained(model_id, **pipeline_kwargs).to(device)

    # Optionally replace the U-Net with the checkpoint at `unet_path`.
    # Fix: the original called `pipeline.unet.from_pretrained(model_id, subfolder="unet")`,
    # which (a) ignored `unet_path` entirely and (b) discarded the returned model
    # (`from_pretrained` is a classmethod that returns a new instance), so nothing
    # was ever loaded. Assumes `unet_path` is a diffusers-format checkpoint
    # directory — TODO confirm against callers.
    if unet_path is not None and pathlib.Path(unet_path).exists():
        print(f"=> load u-net from {unet_path}")
        pipeline.unet = pipeline.unet.__class__.from_pretrained(
            unet_path, torch_dtype=torch_dtype
        ).to(device)

    # Optionally inject LoRA attention-processor weights into the U-Net.
    if lora_path is not None and pathlib.Path(lora_path).exists():
        pipeline.unet.load_attn_procs(lora_path)
        print(f"=> load lora layers into U-Net from {lora_path} ...")

    # torch.compile requires PyTorch 2.x.
    if ldm_speed_up:
        if is_torch_version(">=", "2.0.0"):
            pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
            print("=> enable torch.compile on U-Net")
        else:
            # Fix: the original message claimed "torch version <= 2.0.0", but this
            # branch runs exactly when torch < 2.0.0.
            print("=> warning: calling torch.compile speed-up failed, since torch version < 2.0.0")

    # Memory-efficient attention via xFormers, when the package is importable.
    if enable_xformers:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                print(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. "
                    "If you observe problems during training, please update xFormers to at least 0.0.17. "
                    "See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            print("=> enable xformers")
            pipeline.unet.enable_xformers_memory_efficient_attention()
        else:
            print("=> warning: calling xformers failed")

    # Fix: the original condition was inverted — it only enabled gradient
    # checkpointing when `is_gradient_checkpointing` was already True, so the
    # `gradient_checkpoint` flag never had any effect.
    if gradient_checkpoint:
        if not pipeline.unet.is_gradient_checkpointing:
            pipeline.unet.enable_gradient_checkpointing()
            print("=> enable gradient checkpointing")
        else:
            print("=> gradient checkpointing is already activated for this model.")

    print(f"Diffusion Model: {model_id}")
    print(pipeline.scheduler)
    return pipeline
|
|