Unconditional Image Generation
Diffusers
Safetensors
English
bitdance
imagenet
class-conditional
custom-pipeline
Instructions to use BiliSakura/BitDance-ImageNet-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/BitDance-ImageNet-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/BitDance-ImageNet-diffusers", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]

- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| import torch | |
| from safetensors.torch import load_file as load_safetensors | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_utils import ModelMixin | |
# NOTE: Diffusers dynamic module loader only copies directly-referenced relative imports.
# These guarded imports are intentionally never executed — `if False:` makes the whole
# branch dead code — but their mere presence forces the loader to copy the referenced
# sibling files (and their siblings) into the dynamic module cache alongside this file.
# Do not remove or "clean up" these imports: the aliases are unused on purpose.
if False:  # pragma: no cover
    from .model import BitDance_B as _BD_B_STD
    from .model import BitDance_H as _BD_H_STD
    from .model import BitDance_L as _BD_L_STD
    from .model_parallel import BitDance_B as _BD_B_PAR
    from .model_parallel import BitDance_H as _BD_H_PAR
    from .model_parallel import BitDance_L as _BD_L_PAR
    from .diff_head import DiffHead as _DiffHead
    from .diff_head_parallel import DiffHead as _DiffHeadParallel
    from .layers import TransformerBlock as _TB
    from .layers_parallel import TransformerBlock as _TBP
    from .qae import VQModel as _VQ
    from .gfq import GFQ as _GFQ
    from .sampling import euler_maruyama as _EM
    from .sampling_parallel import euler_maruyama as _EMP
    from .utils import patchify_raster as _PR
class BitDanceImageNetTransformer(ModelMixin, ConfigMixin):
    """Diffusers ``ModelMixin`` wrapper around the BitDance ImageNet transformer.

    Selects between the standard (``model.py``) and parallel
    (``model_parallel.py``) BitDance implementations based on the stored
    config, then delegates ``forward`` and ``sample`` to the wrapped model.
    """

    def __init__(
        self,
        architecture: str,
        parallel_num: int,
        resolution: int,
        down_size: int,
        latent_dim: int,
        num_classes: int,
        runtime_impl: str,
        parallel_mode: str = "patch",
        time_schedule: str = "logit_normal",
        time_shift: float = 1.0,
        p_std: float = 1.0,
        p_mean: float = 0.0,
    ):
        """Build the runtime BitDance model.

        Args:
            architecture: One of ``"BitDance-B"``, ``"BitDance-L"``,
                ``"BitDance-H"``; any other value raises ``KeyError``.
            parallel_num: Parallelism degree; values > 1 force the parallel
                implementation.
            resolution: Image resolution the checkpoint was trained at.
            down_size: Tokenizer downsampling factor.
            latent_dim: Latent channel dimension.
            num_classes: Number of class labels (ImageNet: 1000).
            runtime_impl: ``"model_parallel.py"`` selects the parallel
                implementation regardless of ``parallel_num``.
            parallel_mode: Decoding mode, forwarded only to the parallel model.
            time_schedule: Diffusion time sampling schedule name.
            time_shift: Shift applied to the time schedule.
            p_std: Forwarded to the model as ``P_std``.
            p_mean: Forwarded to the model as ``P_mean``.
        """
        super().__init__()
        # Arguments shared by both implementations. The hard-coded values
        # (patch_size, diff_batch_mul, cls_token_num, ...) pin the settings
        # the released checkpoints were trained with.
        kwargs = dict(
            resolution=resolution,
            down_size=down_size,
            patch_size=1,
            latent_dim=latent_dim,
            diff_batch_mul=4,
            cls_token_num=64,
            num_classes=num_classes,
            grad_checkpointing=False,
            trained_vae="",
            drop_rate=0.0,
            perturb_schedule="constant",
            perturb_rate=0.0,
            perturb_rate_max=0.3,
            time_schedule=time_schedule,
            time_shift=time_shift,
            P_std=p_std,
            P_mean=p_mean,
        )
        if runtime_impl == "model_parallel.py" or parallel_num > 1:
            # Local import so only the selected implementation is loaded.
            from .model_parallel import BitDance_B, BitDance_H, BitDance_L

            ctors = {"BitDance-B": BitDance_B, "BitDance-L": BitDance_L, "BitDance-H": BitDance_H}
            kwargs.update(parallel_num=parallel_num, parallel_mode=parallel_mode)
        else:
            from .model import BitDance_B, BitDance_H, BitDance_L

            ctors = {"BitDance-B": BitDance_B, "BitDance-L": BitDance_L, "BitDance-H": BitDance_H}
        self.runtime_model = ctors[architecture](**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """Load a checkpoint from a local directory.

        FIX: the original definition took ``cls`` but lacked ``@classmethod``,
        so it overrode the classmethod ``ModelMixin.from_pretrained`` with a
        plain function — ``BitDanceImageNetTransformer.from_pretrained(path)``
        would bind the path string to ``cls`` and crash at ``cls(...)``.

        Args:
            pretrained_model_name_or_path: Local directory containing
                ``config.json`` and ``diffusion_pytorch_model.safetensors``.
            *args, **kwargs: Accepted for signature compatibility with the
                base ``ModelMixin.from_pretrained``; ignored.

        Returns:
            The model in eval mode with weights loaded (``strict=True``).
        """
        del args, kwargs  # unused; kept for base-class signature compatibility
        model_dir = Path(pretrained_model_name_or_path)
        config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))
        model = cls(
            architecture=config["architecture"],
            parallel_num=int(config["parallel_num"]),
            resolution=int(config["resolution"]),
            down_size=int(config["down_size"]),
            latent_dim=int(config["latent_dim"]),
            num_classes=int(config["num_classes"]),
            runtime_impl=config["runtime_impl"],
            parallel_mode=config.get("parallel_mode", "patch"),
            time_schedule=config.get("time_schedule", "logit_normal"),
            time_shift=float(config.get("time_shift", 1.0)),
            p_std=float(config.get("p_std", 1.0)),
            p_mean=float(config.get("p_mean", 0.0)),
        )
        state = load_safetensors(model_dir / "diffusion_pytorch_model.safetensors")
        model.runtime_model.load_state_dict(state, strict=True)
        model.eval()
        return model

    def sample(
        self,
        class_ids: torch.Tensor,
        sample_steps: int = 100,
        cfg_scale: float = 4.6,
        chunk_size: int = 0,
    ) -> torch.Tensor:
        """Generate images conditioned on class labels.

        Args:
            class_ids: Tensor of class ids to condition on (passed to the
                runtime model as ``cond``).
            sample_steps: Number of sampling steps.
            cfg_scale: Classifier-free guidance scale.
            chunk_size: Batch chunking hint for the runtime model
                (0 = no chunking).

        Returns:
            Whatever the wrapped model's ``sample`` returns (a tensor).
        """
        return self.runtime_model.sample(
            cond=class_ids,
            sample_steps=sample_steps,
            cfg_scale=cfg_scale,
            chunk_size=chunk_size,
        )

    def forward(self, *args, **kwargs):
        """Delegate directly to the wrapped runtime model."""
        return self.runtime_model(*args, **kwargs)