Tags: Unconditional Image Generation · Diffusers · Safetensors · English · bitdance · imagenet · class-conditional · custom-pipeline
How to use BiliSakura/BitDance-ImageNet-diffusers with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained(
    "BiliSakura/BitDance-ImageNet-diffusers",
    dtype=torch.bfloat16,
    device_map="cuda",
)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
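Note that the snippet above is the generic Diffusers text-to-image template. Since this repository is tagged class-conditional / imagenet and ships a custom pipeline, the pipeline may expect ImageNet class labels rather than a text prompt (as diffusers' DiTPipeline does), and loading a custom pipeline typically requires `trust_remote_code=True`. The sketch below is hypothetical on both counts (the `class_labels` argument and the label id are assumptions, not confirmed by this card); check the pipeline code in the repository for the actual interface.

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical interface: assumes the custom pipeline accepts ImageNet
# class ids the way diffusers' DiTPipeline does -- not confirmed here.
pipe = DiffusionPipeline.from_pretrained(
    "BiliSakura/BitDance-ImageNet-diffusers",
    dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,  # custom-pipeline repos generally need this
)
images = pipe(class_labels=[207]).images  # e.g. 207 = golden retriever
```

The model card also includes reference samplers for the flow-matching model: a stochastic Euler-Maruyama sampler and a deterministic Euler sampler, both of which integrate the model's predictions from noise at t = 0 to data at t = 1.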
```python
import torch


def time_shift_sana(t: torch.Tensor, flow_shift: float = 1.0, sigma: float = 1.0):
    # SANA-style time shift: remaps a uniform grid t in [0, 1] to
    # (1/s) / ((1/s) + (1/t - 1)**sigma). flow_shift = 1, sigma = 1 is the identity.
    return (1 / flow_shift) / ((1 / flow_shift) + (1 / t - 1) ** sigma)
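

# Convert a flow-matching velocity into a score. Assumes the linear
# interpolation path x_t = t * x1 + (1 - t) * x0 (alpha_t = t,
# sigma_t = 1 - t), matching the coefficients hard-coded below.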
def get_score_from_velocity(velocity, x, t):
    alpha_t, d_alpha_t = t, 1
    sigma_t, d_sigma_t = 1 - t, -1
    mean = x
    reverse_alpha_ratio = alpha_t / d_alpha_t
    var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
    score = (reverse_alpha_ratio * velocity - mean) / var
    return score
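

# Classifier-free guidance: when cfg_mult == 2 the batch holds the
# conditional branch in the first half and the unconditional branch
# in the second half.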
def get_velocity_from_cfg(velocity, cfg, cfg_mult):
    if cfg_mult == 2:
        cond_v, uncond_v = torch.chunk(velocity, 2, dim=0)
        velocity = uncond_v + cfg * (cond_v - uncond_v)
    return velocity


def euler_step(x, v, dt: float, cfg: float, cfg_mult: int):
    with torch.amp.autocast("cuda", enabled=False):
        v = v.to(torch.float32)
        v = get_velocity_from_cfg(v, cfg, cfg_mult)
        x = x + v * dt
    return x
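

# One step of the reverse-time SDE (Euler-Maruyama): a deterministic drift
# v + (1 - t) * score plus Gaussian noise with variance 2 * (1 - t) * dt.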
def euler_maruyama_step(x, v, t, dt: float, cfg: float, cfg_mult: int):
    with torch.amp.autocast("cuda", enabled=False):
        v = v.to(torch.float32)
        v = get_velocity_from_cfg(v, cfg, cfg_mult)
        score = get_score_from_velocity(v, x, t)
        drift = v + (1 - t) * score
        noise_scale = (2.0 * (1.0 - t) * dt) ** 0.5
        x = x + drift * dt + noise_scale * torch.randn_like(x)
    return x
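

# Stochastic sampler. forward_fn(x_t, t, c) is expected to predict the
# clean sample x1; the loop converts that prediction into a velocity via
# (x1_pred - x_t) / (1 - t), clamping the denominator away from zero.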
def euler_maruyama(
    input_dim,
    forward_fn,
    c: torch.Tensor,
    cfg: float = 1.0,
    num_sampling_steps: int = 20,
    last_step_size: float = 0.05,
    time_shift: float = 1.0,
):
    cfg_mult = 1
    if cfg > 1.0:
        cfg_mult += 1
    # With guidance, c is the doubled (cond + uncond) batch; x stays single.
    x_shape = list(c.shape)
    x_shape[0] = x_shape[0] // cfg_mult
    x_shape[-1] = input_dim
    x = torch.randn(x_shape, device=c.device)
    # Time grid on [0, 1 - last_step_size], warped by the SANA shift.
    t_all = torch.linspace(
        0, 1 - last_step_size, num_sampling_steps + 1, device=c.device, dtype=torch.float32
    )
    t_all = time_shift_sana(t_all, time_shift)
    dt = t_all[1:] - t_all[:-1]
    t = torch.tensor(0.0, device=c.device, dtype=torch.float32)  # tensor to avoid compile warnings
    t_batch = torch.zeros(c.shape[0], device=c.device)
    for i in range(num_sampling_steps):
        t_batch[:] = t
        combined = torch.cat([x] * cfg_mult, dim=0)
        output = forward_fn(combined, t_batch, c)
        v = (output - combined) / (1 - t_batch.view(-1, 1, 1)).clamp_min(0.05)
        x = euler_maruyama_step(x, v, t, dt[i], cfg, cfg_mult)
        t += dt[i]
    # Final deterministic Euler step over the remaining last_step_size.
    combined = torch.cat([x] * cfg_mult, dim=0)
    t_batch[:] = 1 - last_step_size
    output = forward_fn(combined, t_batch, c)
    v = (output - combined) / (1 - t_batch.view(-1, 1, 1)).clamp_min(0.05)
    x = euler_step(x, v, last_step_size, cfg, cfg_mult)
    return torch.cat([x] * cfg_mult, dim=0)
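

# Deterministic sampler on a uniform grid. Note that here forward_fn's
# output is used directly as the velocity (no x1-to-velocity conversion).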
def euler(
    input_dim,
    forward_fn,
    c,
    cfg: float = 1.0,
    num_sampling_steps: int = 50,
):
    cfg_mult = 1
    if cfg > 1.0:
        cfg_mult = 2
    x_shape = list(c.shape)
    x_shape[0] = x_shape[0] // cfg_mult
    x_shape[-1] = input_dim
    x = torch.randn(x_shape, device=c.device)
    dt = 1.0 / num_sampling_steps
    t = 0
    t_batch = torch.zeros(c.shape[0], device=c.device)
    for _ in range(num_sampling_steps):
        t_batch[:] = t
        combined = torch.cat([x] * cfg_mult, dim=0)
        v = forward_fn(combined, t_batch, c)
        x = euler_step(x, v, dt, cfg, cfg_mult)
        t += dt
    return torch.cat([x] * cfg_mult, dim=0)
```
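To make the calling convention concrete, here is a minimal sketch that drives `euler_maruyama` with a toy `forward_fn`. The shapes and the toy model are hypothetical; in practice `forward_fn` and the conditioning tensor `c` come from the BitDance model, and `c` must already be the doubled (conditional + unconditional) batch whenever `cfg > 1`.

```python
import torch

batch, seq_len, input_dim = 2, 16, 8
cfg = 4.0  # guidance > 1 doubles the batch: conditional half, then unconditional

# Hypothetical conditioning tensor standing in for the real class embeddings.
c = torch.randn(batch * 2, seq_len, input_dim)

def toy_forward_fn(x_t, t, c):
    # Stand-in for the real network; returns an "x1 prediction" shaped like x_t.
    return x_t + 0.1 * c

samples = euler_maruyama(input_dim, toy_forward_fn, c, cfg=cfg, num_sampling_steps=20)
print(samples.shape)  # torch.Size([4, 16, 8]); the two halves are identical copies
```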