Unconditional Image Generation
Diffusers
Safetensors
English
bitdance
imagenet
class-conditional
custom-pipeline
Instructions to use BiliSakura/BitDance-ImageNet-diffusers with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use BiliSakura/BitDance-ImageNet-diffusers with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("BiliSakura/BitDance-ImageNet-diffusers", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| import argparse | |
| from functools import partial | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from torch.utils.checkpoint import checkpoint | |
| from .diff_head import DiffHead | |
| from .layers import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d | |
| from .qae import VQModel | |
def get_model_args():
    """Build the argument parser holding every BitDance model option.

    The parser is returned un-parsed so callers can extend it with their own
    options before calling ``parse_args``.

    Returns:
        argparse.ArgumentParser: parser with the model flags registered in the
        order listed below.
    """
    # (flag, add_argument keyword arguments), registered in this exact order.
    option_table = (
        (
            "--model",
            {
                "type": str,
                "choices": list(BitDance_models.keys()),
                "default": "BitDance-L",
            },
        ),
        ("--image-size", {"type": int, "choices": [256, 512], "default": 256}),
        ("--down-size", {"type": int, "default": 16, "choices": [16]}),
        ("--patch-size", {"type": int, "default": 1, "choices": [1, 2, 4]}),
        ("--num-classes", {"type": int, "default": 1000}),
        ("--cls-token-num", {"type": int, "default": 64}),
        ("--latent-dim", {"type": int, "default": 16}),
        ("--diff-batch-mul", {"type": int, "default": 4}),
        ("--grad-checkpointing", {"action": "store_true"}),
        ("--trained-vae", {"type": str, "default": ""}),
        ("--drop-rate", {"type": float, "default": 0.0}),
        ("--perturb-schedule", {"type": str, "default": "constant"}),
        ("--perturb-rate", {"type": float, "default": 0.0}),
        ("--perturb-rate-max", {"type": float, "default": 0.3}),
        ("--time-schedule", {"type": str, "default": 'logit_normal'}),
        ("--time-shift", {"type": float, "default": 1.}),
        ("--P-std", {"type": float, "default": 1.}),
        ("--P-mean", {"type": float, "default": 0.}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser
def create_model(args, device):
    """Instantiate the BitDance variant selected by ``args.model``.

    Args:
        args: namespace carrying the attributes registered by
            ``get_model_args`` (``model``, ``image_size``, ``down_size``, ...).
        device: target device passed to ``Module.to``.

    Returns:
        The constructed model, moved to ``device`` with
        ``torch.channels_last`` memory format.
    """
    builder = BitDance_models[args.model]

    # Constructor keyword -> value taken from the parsed command line.
    hyper_params = {
        "resolution": args.image_size,
        "down_size": args.down_size,
        "patch_size": args.patch_size,
        "latent_dim": args.latent_dim,
        "diff_batch_mul": args.diff_batch_mul,
        "cls_token_num": args.cls_token_num,
        "num_classes": args.num_classes,
        "grad_checkpointing": args.grad_checkpointing,
        "trained_vae": args.trained_vae,
        "drop_rate": args.drop_rate,
        "perturb_schedule": args.perturb_schedule,
        "perturb_rate": args.perturb_rate,
        "perturb_rate_max": args.perturb_rate_max,
        "time_schedule": args.time_schedule,
        "time_shift": args.time_shift,
        "P_std": args.P_std,
        "P_mean": args.P_mean,
    }

    network = builder(**hyper_params)
    return network.to(device, memory_format=torch.channels_last)
class MLPConnector(nn.Module):
    """Gated MLP that maps ``in_dim`` features to ``dim`` features.

    ``w1`` produces two halves of width ``int(dim * 1.5)``; the first half is
    passed through SiLU and multiplied element-wise with the second, then
    ``w2`` projects the product to ``dim`` and dropout is applied.
    """

    def __init__(self, in_dim, dim, dropout_p=0.0):
        """Create the two linear layers and the output dropout.

        Args:
            in_dim: size of the last dimension of the input.
            dim: size of the last dimension of the output.
            dropout_p: dropout probability applied to the output.
        """
        super().__init__()
        hidden_dim = int(dim * 1.5)
        # A single matmul yields both the gate half and the value half.
        fused_width = hidden_dim * 2
        self.w1 = nn.Linear(in_dim, fused_width, bias=True)
        self.w2 = nn.Linear(hidden_dim, dim, bias=True)
        self.ffn_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Apply the gated projection to ``x`` along its last dimension."""
        fused = self.w1(x)
        gate, value = fused.chunk(2, dim=-1)
        gated = F.silu(gate) * value
        projected = self.w2(gated)
        return self.ffn_dropout(projected)
def flip_tensor_elements_uniform_prob(tensor: torch.Tensor, p_max: float) -> torch.Tensor:
    """Randomly negate elements of ``tensor``.

    Two independent uniform samples are drawn per element; the element is
    negated when the first sample is smaller than ``p_max`` times the second.
    Each element therefore gets its own flip probability, itself uniformly
    distributed between 0 and ``p_max``.

    Args:
        tensor: floating-point tensor to perturb; it is not modified in place.
        p_max: upper bound of the per-element flip probability, in [0, 1].

    Returns:
        A new tensor with the same shape and dtype as ``tensor``.

    Raises:
        ValueError: if ``p_max`` is outside [0, 1] (NaN included).
    """
    if not (0.0 <= p_max <= 1.0):
        raise ValueError(f"p_max must be in [0.0, 1.0] range, but got: {p_max}")

    # Draw order matters for reproducibility under a fixed seed.
    draw = torch.rand_like(tensor)
    per_element_rate = torch.rand_like(tensor)
    should_flip = draw < p_max * per_element_rate

    signs = torch.where(should_flip, -1.0, 1.0).to(tensor.dtype)
    return tensor * signs
class BitDance(nn.Module):
    """Class-conditional autoregressive image generator over VQ latents.

    A class label is embedded into ``cls_token_num`` prefix tokens. A stack of
    causal ``TransformerBlock`` layers runs over that prefix followed by the
    projected latent tokens, and the hidden state at each position conditions
    a ``DiffHead`` that is trained (``forward``) or sampled (``sample``) to
    produce the next latent token. A frozen ``VQModel`` converts images to
    latents and sampled latents back to images.
    """

    def __init__(
        self,
        dim,
        n_layer,
        n_head,
        diff_layers,
        diff_dim,
        diff_adanln_layers,
        latent_dim,
        down_size,
        patch_size,
        resolution,
        diff_batch_mul,
        grad_checkpointing=False,
        cls_token_num=16,
        num_classes: int = 1000,
        class_dropout_prob: float = 0.1,
        trained_vae: str = "",
        drop_rate: float = 0.0,
        perturb_schedule: str = "constant",
        perturb_rate: float = 0.0,
        perturb_rate_max: float = 0.3,
        time_schedule: str = 'logit_normal',
        time_shift: float = 1.,
        P_std: float = 1.,
        P_mean: float = 0.,
    ):
        """Build the VAE, the transformer trunk and the diffusion head.

        Args:
            dim: transformer hidden size.
            n_layer: number of ``TransformerBlock`` layers.
            n_head: number of attention heads (rotary dim is ``dim // n_head``).
            diff_layers: forwarded to ``DiffHead`` as ``depth_latent``.
            diff_dim: forwarded to ``DiffHead`` as ``ch_latent``.
            diff_adanln_layers: forwarded to ``DiffHead`` as ``depth_adanln``.
            latent_dim: number of VAE latent channels (``z_channels``).
            down_size: spatial down-sampling factor assumed for the VAE.
            patch_size: side of the square latent patch grouped into one token.
            resolution: input image side length in pixels.
            diff_batch_mul: how many times each (condition, target) pair is
                repeated before being passed to the head in ``forward``.
            grad_checkpointing: checkpoint transformer layers during training;
                also forwarded to ``DiffHead``.
            cls_token_num: number of prefix tokens produced per class label.
            num_classes: number of real classes; index ``num_classes`` is the
                extra "dropped label" embedding.
            class_dropout_prob: probability of replacing a label with the
                dropped-label index during training.
            trained_vae: checkpoint path used by ``load_vae_weight``; it is
                only stored here, not loaded.
            drop_rate: dropout probability for ``proj_in`` and the residual
                dropout of every transformer block.
            perturb_schedule: only ``"constant"`` is implemented in ``forward``.
            perturb_rate: ``p_max`` used for the training-time sign flips.
            perturb_rate_max: stored on the instance; not read by this class.
            time_schedule, time_shift, P_std, P_mean: forwarded unchanged to
                ``DiffHead``.
        """
        super().__init__()

        self.n_layer = n_layer
        self.resolution = resolution
        self.down_size = down_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.cls_token_num = cls_token_num
        self.class_dropout_prob = class_dropout_prob
        self.latent_dim = latent_dim
        self.trained_vae = trained_vae
        self.perturb_schedule = perturb_schedule
        self.perturb_rate = perturb_rate
        self.perturb_rate_max = perturb_rate_max

        # define the vae and mar model
        ddconfig = {
            "double_z": False,
            "z_channels": latent_dim,
            "in_channels": 3,
            "out_ch": 3,
            "ch": 256,
            "ch_mult": [1, 1, 2, 2, 4],
            "num_res_blocks": 4,
        }
        num_codebooks = 4
        # The VAE is created with fresh weights; call ``load_vae_weight`` (or
        # load a full state dict) after construction to populate it.
        self.vae = VQModel(ddconfig, num_codebooks)
        self.grad_checkpointing = grad_checkpointing

        # One extra embedding row (index ``num_classes``) for dropped labels.
        self.cls_embedding = nn.Embedding(num_classes + 1, dim * self.cls_token_num)
        self.proj_in = MLPConnector(latent_dim * self.patch_size * self.patch_size, dim, drop_rate)
        self.emb_norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)

        # Token grid size and full sequence length (class prefix + image tokens).
        self.h, self.w = resolution // (down_size * patch_size), resolution // (down_size * patch_size)
        self.total_tokens = self.h * self.w + self.cls_token_num

        self.layers = torch.nn.ModuleList()
        for _ in range(n_layer):
            self.layers.append(
                TransformerBlock(
                    dim,
                    n_head,
                    resid_dropout_p=drop_rate,
                    causal=True,
                )
            )

        self.norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
        # Learned per-position offset added to the head's conditioning vector.
        self.pos_for_diff = nn.Embedding(self.h * self.w, dim)
        self.head = DiffHead(
            ch_target=latent_dim * self.patch_size * self.patch_size,
            ch_cond=dim,
            ch_latent=diff_dim,
            depth_latent=diff_layers,
            depth_adanln=diff_adanln_layers,
            grad_checkpointing=grad_checkpointing,
            time_shift=time_shift,
            time_schedule=time_schedule,
            P_std=P_std,
            P_mean=P_mean,
        )
        self.diff_batch_mul = diff_batch_mul

        patch_2d_pos = get_2d_pos(resolution, int(down_size * patch_size))

        # The last entry is dropped because the transformer input never
        # contains the final image token (see the ``[:, :-1]`` in ``forward``).
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis_2d(
                patch_2d_pos,
                dim // n_head,
                10000,
                cls_token_num=self.cls_token_num,
            )[:-1],
            persistent=False,
        )
        self.freeze_vae()
        self.initialize_weights()

    def load_vae_weight(self):
        """Load VAE weights from the checkpoint at ``self.trained_vae``.

        The checkpoint must contain a ``"state_dict"`` entry. Loading is
        non-strict; the missing keys are printed.
        """
        # NOTE(review): ``torch.load`` unpickles the file; only point
        # ``trained_vae`` at checkpoints from a trusted source (or consider
        # ``weights_only=True`` if the checkpoint format allows it).
        state = torch.load(
            self.trained_vae,
            map_location="cpu",
        )
        missing_keys, _unexpected_keys = self.vae.load_state_dict(state["state_dict"], strict=False)
        print(f"loading vae, missing_keys: {missing_keys}")
        del state

    def non_decay_keys(self):
        """Return the parameter-name fragments listed as non-decay keys."""
        return ["proj_in", "cls_embedding"]

    def freeze_module(self, module: nn.Module):
        """Disable gradients for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def freeze_vae(self):
        """Freeze the VAE parameters and put the VAE in eval mode."""
        self.freeze_module(self.vae)
        self.vae.eval()

    def initialize_weights(self):
        """Initialise all Linear/Embedding weights, then let the head re-init itself.

        NOTE: ``self.apply`` visits every submodule, so any ``nn.Linear`` or
        ``nn.Embedding`` inside ``self.vae`` is re-initialised as well; VAE
        weights have to be loaded after construction.
        """
        # Initialize nn.Linear and nn.Embedding
        self.apply(self.__init_weights)
        self.head.initialize_weights()

    def __init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def drop_label(self, class_id):
        """Randomly replace labels with the dropped-label index while training.

        Outside training mode, or when ``class_dropout_prob`` is not positive,
        ``class_id`` is returned unchanged.
        """
        if self.class_dropout_prob > 0.0 and self.training:
            is_drop = (
                torch.rand(class_id.shape, device=class_id.device)
                < self.class_dropout_prob
            )
            class_id = torch.where(is_drop, self.num_classes, class_id)
        return class_id

    def patchify(self, x):
        """Turn a ``[n, c, h, w]`` latent map into ``[n, l, d]`` tokens.

        ``l = (h // p) * (w // p)`` and ``d = c * p * p`` with
        ``p = self.patch_size``; features inside a token are ordered channel,
        then patch row, then patch column.
        """
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        """Inverse of ``patchify`` for the model's own ``h`` x ``w`` token grid."""
        bsz = x.shape[0]
        p = self.patch_size
        c = self.latent_dim
        h_, w_ = self.h, self.w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]

    def forward(
        self,
        images,
        class_id,
        cached=False
    ):
        """Compute the training loss returned by the diffusion head.

        Args:
            images: input images, or precomputed 4-D VAE latents when
                ``cached`` is true.
            class_id: integer class labels, one per sample.
            cached: skip the VAE encoder and treat ``images`` as latents.

        Returns:
            Whatever ``self.head(target, cond)`` returns (used as the loss).

        Raises:
            NotImplementedError: in training mode when ``perturb_schedule`` is
                not ``"constant"``.
        """
        if cached:
            vae_latent = images
        else:
            vae_latent, _, _, _ = self.vae.encode(images)  # b c h w
        vae_latent = self.patchify(vae_latent)

        # Transformer input is a (possibly perturbed) copy; the clean latent
        # is kept as the head's target.
        x = vae_latent.clone().detach()
        if self.training:
            if self.perturb_schedule == "constant":
                x = flip_tensor_elements_uniform_prob(x, self.perturb_rate)
            else:
                raise NotImplementedError(f"unknown perturb_schedule {self.perturb_schedule}")

        # Teacher forcing: the last token is never an input, only a target.
        x = self.proj_in(x[:, :-1, :])
        class_id = self.drop_label(class_id)
        bsz = x.shape[0]
        c = self.cls_embedding(class_id).view(bsz, self.cls_token_num, -1)
        x = torch.cat([c, x], dim=1)
        x = self.emb_norm(x)

        if self.grad_checkpointing and self.training:
            for layer in self.layers:
                block = partial(layer.forward, freqs_cis=self.freqs_cis)
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x, self.freqs_cis)

        # Keep the last h*w positions: one conditioning vector per image token.
        x = x[:, -self.h * self.w :, :]
        x = self.norm(x)
        x = x + self.pos_for_diff.weight

        target = vae_latent.clone().detach()
        x = x.view(-1, x.shape[-1])
        target = target.view(-1, target.shape[-1])

        # Repeat every pair so the head sees it ``diff_batch_mul`` times.
        x = x.repeat(self.diff_batch_mul, 1)
        target = target.repeat(self.diff_batch_mul, 1)
        loss = self.head(target, x)

        return loss

    def enable_kv_cache(self, bsz):
        """Enable the KV cache of every layer for ``bsz`` sequences of ``total_tokens``."""
        for layer in self.layers:
            layer.attention.enable_kv_cache(bsz, self.total_tokens)

    def forward_model(self, x, start_pos, end_pos):
        """Run the transformer trunk incrementally on positions ``[start_pos, end_pos)``."""
        x = self.emb_norm(x)
        for layer in self.layers:
            x = layer.forward_onestep(
                x, self.freqs_cis[start_pos:end_pos,], start_pos, end_pos
            )
        x = self.norm(x)
        return x

    def head_sample(self, x, diff_pos, sample_steps, cfg_scale, cfg_schedule="linear"):
        """Sample one latent token from the head and binarise it with ``sign``.

        Args:
            x: conditioning hidden state for the current position.
            diff_pos: index of the image token being generated.
            sample_steps: number of sampling steps passed to the head.
            cfg_scale: guidance scale; values ``<= 1.0`` disable guidance.
            cfg_schedule: ``"constant"`` uses ``cfg_scale`` at every position;
                ``"linear"`` ramps from 1.0 towards ``cfg_scale`` with
                ``diff_pos / (h * w)``.

        Returns:
            Tensor of shape ``[-1, 1, d]`` holding the sign of the prediction.

        Raises:
            NotImplementedError: for an unknown ``cfg_schedule`` when guidance
                is active.
        """
        x = x + self.pos_for_diff.weight[diff_pos : diff_pos + 1, :]
        x = x.view(-1, x.shape[-1])
        seq_len = self.h * self.w
        if cfg_scale > 1.0:
            if cfg_schedule == "constant":
                cfg_iter = cfg_scale
            elif cfg_schedule == "linear":
                start = 1.0
                cfg_iter = start + (cfg_scale - start) * diff_pos / seq_len
            else:
                raise NotImplementedError(f"unknown cfg_schedule {cfg_schedule}")
        else:
            cfg_iter = 1.0
        pred = self.head.sample(x, num_sampling_steps=sample_steps, cfg=cfg_iter)
        pred = pred.view(-1, 1, pred.shape[-1])
        # Important: LFQ here, sign the prediction
        pred = torch.sign(pred)
        return pred

    # Sampling never needs a graph: every prediction passes through
    # ``torch.sign`` (zero gradient) and the VAE is frozen, so tracking
    # autograd over the h*w-step loop would only accumulate memory.
    @torch.no_grad()
    def sample(self, cond, sample_steps, cfg_scale=1.0, cfg_schedule="linear", chunk_size=0):
        """Autoregressively generate images for the class labels in ``cond``.

        The module is switched to eval mode and left there. Runs without
        gradient tracking.

        Args:
            cond: integer class labels, one per image to generate.
            sample_steps: number of head sampling steps per token.
            cfg_scale: guidance scale; when ``> 1.0`` the batch is doubled with
                dropped-label conditions and only the first half is decoded.
            cfg_schedule: guidance schedule, see ``head_sample``.
            chunk_size: when ``> 0`` decode through ``decode_in_chunks`` (result
                on CPU); otherwise decode the whole batch at once.

        Returns:
            The VAE reconstruction of the sampled latents.
        """
        self.eval()
        if cfg_scale > 1.0:
            cond_null = torch.full_like(cond, self.num_classes)
            cond_combined = torch.cat([cond, cond_null])
        else:
            cond_combined = cond
        bsz = cond_combined.shape[0]
        act_bsz = bsz // 2 if cfg_scale > 1.0 else bsz
        self.enable_kv_cache(bsz)

        c = self.cls_embedding(cond_combined).view(bsz, self.cls_token_num, -1)
        last_pred = None
        all_preds = []
        for i in range(self.h * self.w):
            if i == 0:
                # Prefill the cache with the whole class prefix.
                x = self.forward_model(c, 0, self.cls_token_num)
            else:
                # Feed back the token sampled at the previous step.
                x = self.proj_in(last_pred)
                x = self.forward_model(
                    x, i + self.cls_token_num - 1, i + self.cls_token_num
                )
            last_pred = self.head_sample(
                x[:, -1:, :],
                i,
                sample_steps,
                cfg_scale,
                cfg_schedule,
            )
            all_preds.append(last_pred)

        # Drop the unconditional half of the batch when guidance was used.
        x = torch.cat(all_preds, dim=-2)[:act_bsz]
        if x.dim() == 3:  # b n c -> b c h w
            x = self.unpatchify(x)
        if chunk_size > 0:
            recon = self.decode_in_chunks(x, chunk_size)
        else:
            recon = self.vae.decode(x)
        return recon

    def decode_in_chunks(self, latent_tensor, chunk_size=64):
        """Decode latents ``chunk_size`` samples at a time, collecting results on CPU.

        Args:
            latent_tensor: batch of latents accepted by ``self.vae.decode``.
            chunk_size: maximum number of samples decoded per call; must be
                positive.

        Returns:
            CPU tensor with all reconstructions concatenated along dim 0.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, but got: {chunk_size}")

        total_bsz = latent_tensor.shape[0]
        recon_chunks_on_cpu = []
        with torch.no_grad():
            for i in range(0, total_bsz, chunk_size):
                end_idx = min(i + chunk_size, total_bsz)
                latent_chunk = latent_tensor[i:end_idx]
                recon_chunk = self.vae.decode(latent_chunk)
                # Move each chunk off the device right away to bound memory.
                recon_chunks_on_cpu.append(recon_chunk.cpu())
        return torch.cat(recon_chunks_on_cpu, dim=0)

    def get_fsdp_wrap_module_list(self):
        """Return the transformer layers as a list (one entry per layer)."""
        return list(self.layers)
def BitDance_H(**kwargs):
    """Build the "H" BitDance preset (40 layers, width 1280, 20 heads).

    ``kwargs`` supplies the remaining ``BitDance`` constructor arguments;
    passing any of the preset keys again raises ``TypeError``.
    """
    return BitDance(
        # transformer trunk
        dim=1280,
        n_head=20,
        n_layer=40,
        # diffusion head
        diff_dim=1280,
        diff_layers=12,
        diff_adanln_layers=3,
        **kwargs,
    )
def BitDance_L(**kwargs):
    """Build the "L" BitDance preset (32 layers, width 1024, 16 heads).

    ``kwargs`` supplies the remaining ``BitDance`` constructor arguments;
    passing any of the preset keys again raises ``TypeError``.
    """
    return BitDance(
        # transformer trunk
        dim=1024,
        n_head=16,
        n_layer=32,
        # diffusion head
        diff_dim=1024,
        diff_layers=8,
        diff_adanln_layers=2,
        **kwargs,
    )
def BitDance_B(**kwargs):
    """Build the "B" BitDance preset (24 layers, width 768, 12 heads).

    ``kwargs`` supplies the remaining ``BitDance`` constructor arguments;
    passing any of the preset keys again raises ``TypeError``.
    """
    return BitDance(
        # transformer trunk
        dim=768,
        n_head=12,
        n_layer=24,
        # diffusion head
        diff_dim=768,
        diff_layers=6,
        diff_adanln_layers=2,
        **kwargs,
    )
# Registry of model presets: maps the name accepted by the ``--model``
# command-line option to the factory function that builds that preset.
BitDance_models = {
    "BitDance-B": BitDance_B,
    "BitDance-L": BitDance_L,
    "BitDance-H": BitDance_H,
}