import tensorflow as tf
import math
from typing import Callable

from tensorflow import keras
from keras.models import load_model
def as_float32(t: tf.Tensor) -> tf.Tensor:
    """Return *t* converted to a float32 tensor."""
    return tf.cast(t, tf.float32)
def batch_reshape(t: tf.Tensor, x: tf.Tensor) -> Callable[[tf.Tensor], tf.Tensor]:
    """Build a gather-and-broadcast helper for per-timestep coefficients.

    Args:
        t: Per-example timestep indices, shape (batch,).
        x: Reference batch tensor whose leading dimension sets the batch size.

    Returns:
        A function that maps a 1-D coefficient table to its entries at ``t``,
        reshaped to (batch, 1, 1, 1) so they broadcast against 4-D image
        batches.  (BUG FIX: the original annotated the return type as
        ``tf.Tensor``, but a closure is returned.)
    """
    def inner_function(coeff: tf.Tensor) -> tf.Tensor:
        batch_dim = tf.shape(x)[0]
        return tf.reshape(tf.gather(coeff, t), [batch_dim, 1, 1, 1])
    return inner_function
class DiffusionSampler:
    """DDPM / DDIM sampling for a trained diffusion noise-prediction model.

    All schedule-derived coefficient tables are precomputed in float64 for
    accuracy and stored as float32 tensors, so the per-step samplers only
    need gathers and elementwise arithmetic.
    """

    def __init__(
        self,
        model: keras.Model | str,
        ema_model: keras.Model | str,
        timesteps: int | None = 1_000,
        beta_start: float | None = 1e-4,
        beta_end: float | None = 0.02,
        noise_scheduler: str = "linear",
        ema: float = 0.999,
    ):
        """Load (or accept) the noise predictors and precompute the schedule.

        Args:
            model: Noise-prediction model, or a filepath to load it from.
            ema_model: EMA-averaged counterpart (model or filepath).
            timesteps: Number of diffusion timesteps T.
            beta_start: First beta of the linear schedule.
            beta_end: Last beta of the linear schedule.
            noise_scheduler: "linear" or "cosine".
            ema: EMA decay rate (stored; not used by the samplers themselves).
        """
        self.noise_predictor = load_model(filepath=model, safe_mode=False) if isinstance(model, str) else model
        # BUG FIX: the original checked isinstance(model, str) here, so an
        # ema_model passed as a *path* alongside an in-memory model was never
        # loaded (the raw string was stored instead of a keras.Model).
        self.ema_noise_predictor = load_model(filepath=ema_model, safe_mode=False) if isinstance(ema_model, str) else ema_model
        self.ema = ema
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        betas = self.noise_scheduler(noise_scheduler)
        alphas = 1.0 - betas
        alphas_cum_prod = tf.math.cumprod(alphas, axis=0)
        # alpha-bar_{t-1}, with alpha-bar_{-1} := 1 so the t == 0 step is defined.
        alphas_cum_prod_prev = tf.concat([tf.constant([1.0], dtype=tf.float64), alphas_cum_prod[:-1]], axis=0)
        # Posterior variance: beta_t * (1 - abar_{t-1}) / (1 - abar_t)  (DDPM eq. 7).
        posterior_variances = betas * (1.0 - alphas_cum_prod_prev) / (1.0 - alphas_cum_prod)
        self.betas = as_float32(betas)
        self.posterior_variances = as_float32(posterior_variances)
        self.alphas_cum_prod_prev = as_float32(alphas_cum_prod_prev)
        self.one_minus_alphas_cum_prod = as_float32(1.0 - alphas_cum_prod)
        self.one_minus_alphas_cum_prod_prev = as_float32(1.0 - alphas_cum_prod_prev)
        self.sqrt_one_minus_alphas_cum_prod = as_float32(tf.sqrt(1.0 - alphas_cum_prod))
        self.sqrt_alphas_cum_prod_prev = as_float32(tf.sqrt(alphas_cum_prod_prev))
        self.sqrt_alphas_cum_prod = as_float32(tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas_cum_prod = as_float32(1.0 / tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas = as_float32(tf.sqrt(1.0 / alphas))

    def ddpm_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor) -> tf.Tensor:
        """One ancestral (DDPM) reverse step: sample x_{t-1} given x_t.

        Args:
            pred_noise: Predicted noise eps_theta(x_t, t), same shape as x_t.
            x_t: Current noisy batch — assumed (batch, H, W, C); confirmed by
                the (batch, 1, 1, 1) broadcast reshapes below.
            t: Per-example integer timesteps, shape (batch,).
        """
        batch_dim = tf.shape(x_t)[0]
        at_timestep = batch_reshape(t, x_t)
        beta = at_timestep(self.betas)
        rev_sqrt_alpha = at_timestep(self.rev_sqrt_alphas)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        posterior_variance = at_timestep(self.posterior_variances)
        # Posterior mean of x_{t-1} (DDPM eq. 11).
        mean = rev_sqrt_alpha * (
            x_t - (beta / sqrt_one_minus_alpha_cum_prod) * pred_noise
        )
        # No noise is added on the final (t == 0) step.
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), dtype=tf.float32), [batch_dim, 1, 1, 1]
        )
        # BUG FIX: use the dynamic shape; x_t.shape may contain None for the
        # batch dimension under tf.function, which tf.random.normal rejects.
        random_noise = tf.random.normal(shape=tf.shape(x_t), dtype=x_t.dtype)
        return mean + nonzero_mask * tf.sqrt(posterior_variance) * random_noise

    def ddim_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor, eta: float = 0.0) -> tf.Tensor:
        """One DDIM reverse step (deterministic when ``eta == 0``).

        Args:
            pred_noise: Predicted noise eps_theta(x_t, t), same shape as x_t.
            x_t: Current noisy batch.
            t: Per-example integer timesteps, shape (batch,).
            eta: Stochasticity knob; 0 gives deterministic DDIM, 1 ~ DDPM.
        """
        at_timestep = batch_reshape(t, x_t)
        sqrt_alpha_cum_prod_prev = at_timestep(self.sqrt_alphas_cum_prod_prev)
        rev_sqrt_alpha_cum_prod = at_timestep(self.rev_sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        alpha_cum_prod_prev = at_timestep(self.alphas_cum_prod_prev)
        one_minus_alpha_cum_prod = at_timestep(self.one_minus_alphas_cum_prod)
        one_minus_alpha_cum_prod_prev = at_timestep(self.one_minus_alphas_cum_prod_prev)
        # Predicted x_0 from the noise estimate (DDIM eq. 9).
        x0_t = (
            (x_t - (sqrt_one_minus_alpha_cum_prod * pred_noise)) * rev_sqrt_alpha_cum_prod
        )
        alpha_cum_prod = 1.0 - one_minus_alpha_cum_prod
        # BUG FIX: sigma_t per DDIM (Song et al., eq. 16) is
        #   eta * sqrt((1 - abar_{t-1}) / (1 - abar_t)) * sqrt(1 - abar_t / abar_{t-1});
        # the original used (1 - abar_t) / abar_{t-1} as the second factor.
        # Latent until now because the default eta == 0 zeroes c1 entirely.
        c1 = eta * tf.sqrt(
            (one_minus_alpha_cum_prod_prev / one_minus_alpha_cum_prod)
            * (1.0 - alpha_cum_prod / alpha_cum_prod_prev)
        )
        x_t_dir = tf.sqrt(one_minus_alpha_cum_prod_prev - tf.square(c1))
        # BUG FIX: dynamic shape, as in ddpm_sample.
        random_noise = tf.random.normal(shape=tf.shape(x_t), dtype=x_t.dtype)
        return sqrt_alpha_cum_prod_prev * x0_t + x_t_dir * pred_noise + c1 * random_noise

    def noise_scheduler(self, scheduler: str, max_beta: float = 0.02) -> tf.Tensor:
        """Build the per-timestep beta schedule as a float64 tensor.

        Args:
            scheduler: "linear" or "cosine".
            max_beta: Clip value for the cosine schedule's betas.
                (BUG FIX: was annotated ``int`` with a float default.)

        Raises:
            ValueError: For an unknown scheduler name (the original silently
                returned None, which crashed later in __init__).
        """
        if scheduler == "linear":
            x = tf.linspace(start=self.beta_start, stop=self.beta_end, num=self.timesteps)
            return tf.cast(x, dtype=tf.float64)
        if scheduler == "cosine":
            pi, T = [tf.constant(num, dtype=tf.float64) for num in (math.pi, self.timesteps)]
            # alpha-bar(i) of the Nichol & Dhariwal cosine schedule (s = 0.008).
            alpha_bar = lambda i: tf.math.cos((i + 0.008) / 1.008 * pi / 2) ** 2
            cosine_scheduler = lambda t: tf.minimum(1 - alpha_bar((t + 1) / T) / alpha_bar(t / T), max_beta)
            x = tf.vectorized_map(fn=cosine_scheduler, elems=tf.range(self.timesteps, dtype=tf.float64))
            return tf.cast(x, dtype=tf.float64)
        raise ValueError(f"unknown noise scheduler: {scheduler!r}")

    def x_t(self, x_start: tf.Tensor, t: tf.Tensor, noise: tf.Tensor) -> tf.Tensor:
        """Forward-diffuse x_0 to x_t in closed form (DDPM eq. 4).

        Args:
            x_start: Clean batch x_0.
            t: Per-example integer timesteps, shape (batch,).
            noise: Gaussian noise with the same shape as ``x_start``.
        """
        at_timestep = batch_reshape(t, x_start)
        sqrt_alpha_cum_prod = at_timestep(self.sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise

    @tf.function
    def generate_images(
        self,
        num_images: int,
        steps: int,
        sample_strategy: str = "ddim",
        step_strategy: str = "uniform",
        ema: bool = True,
    ):
        """Generate ``num_images`` 64x64 RGB images by iterative denoising.

        Args:
            num_images: Batch size to sample.
            steps: Number of DDIM steps (ignored by DDPM, which uses all T).
            sample_strategy: "ddpm" or "ddim".
            step_strategy: "uniform" (alias "linear") or "quadratic".
            ema: Use the EMA model when True.

        Returns:
            float32 tensor of shape (num_images, 64, 64, 3) in [0, 255].
        """
        full_range = tf.range(self.timesteps, dtype=tf.float64)
        uniform_steps = tf.range(steps, dtype=tf.float64)
        quadratic_steps = tf.cast(
            tf.linspace(start=0.0, stop=tf.sqrt(self.timesteps * 0.8), num=steps) ** 2,
            dtype=tf.float64,
        )
        # BUG FIX: the default step_strategy is "uniform", but the original
        # table only had "linear"/"quadratic" keys, so calling with the
        # default arguments raised KeyError. "linear" is kept as a
        # backward-compatible alias of "uniform".
        sampling_strategies = {
            ("ddpm", "uniform"): (self.ddpm_sample, full_range),
            ("ddpm", "linear"): (self.ddpm_sample, full_range),
            ("ddpm", "quadratic"): (self.ddpm_sample, full_range),
            ("ddim", "uniform"): (self.ddim_sample, uniform_steps),
            ("ddim", "linear"): (self.ddim_sample, uniform_steps),
            ("ddim", "quadratic"): (self.ddim_sample, quadratic_steps),
        }
        noise_predictor = self.ema_noise_predictor if ema else self.noise_predictor
        sampler, seq = sampling_strategies[(sample_strategy, step_strategy)]
        samples = tf.random.normal(shape=(num_images, 64, 64, 3), dtype=tf.float32)
        # Denoise from t = T-1 down to t = 0.
        for t in tf.reverse(seq, axis=[0]):
            tt = tf.cast(tf.fill(dims=(num_images,), value=t), dtype=tf.int64)
            pred_noise = noise_predictor([samples, tt], training=False)
            samples = sampler(pred_noise, samples, tt)
        # Map from the model's [-1, 1] range to [0, 255] pixel values.
        return tf.clip_by_value(samples * 127.5 + 127.5, 0.0, 255.0)