Spaces:
Paused
Paused
| import tensorflow as tf | |
| import math | |
| from tensorflow import keras | |
| from keras.models import load_model | |
def as_float32(t: tf.Tensor) -> tf.Tensor:
    """Return ``t`` cast to ``tf.float32``."""
    converted = tf.cast(t, tf.float32)
    return converted
def batch_reshape(t: tf.Tensor, x: tf.Tensor) -> tf.Tensor:
    """Build a per-sample coefficient lookup bound to ``t`` and ``x``.

    The returned object is a function (not a tensor, despite the return
    annotation): given a 1-D coefficient table it gathers the entries at
    indices ``t`` and reshapes them to ``[batch, 1, 1, 1]``, where ``batch``
    is the leading dimension of ``x``.

    Args:
        t: Integer indices into the coefficient table, one per batch element.
        x: Tensor whose first dimension supplies the batch size.

    Returns:
        A callable ``coeff -> tensor`` of shape ``[batch, 1, 1, 1]``.
    """

    def lookup(coeff: tf.Tensor) -> tf.Tensor:
        selected = tf.gather(coeff, t)
        # Read the batch size dynamically so unknown static shapes still work.
        n_batch = tf.shape(x)[0]
        target_shape = [n_batch, 1, 1, 1]
        return tf.reshape(selected, target_shape)

    return lookup
class DiffusionSampler:
    """Precompute diffusion coefficients and run DDPM / DDIM reverse sampling.

    The constructor derives the per-timestep coefficient tables (betas,
    cumulative alpha products, posterior variances, ...) from the selected
    noise schedule and stores them as ``tf.float32`` tensors indexed by
    timestep.
    """

    def __init__(
        self,
        model: keras.Model | str,
        ema_model: keras.Model | str,
        timesteps: int | None = 1_000,
        beta_start: float | None = 1e-4,
        beta_end: float | None = 0.02,
        noise_scheduler: str = "linear",
        ema: float = 0.999,
    ):
        """Load the noise predictors and precompute the coefficient tables.

        Args:
            model: Noise-prediction model, or a path passed to ``load_model``.
            ema_model: EMA noise-prediction model, or a path passed to
                ``load_model``.
            timesteps: Number of diffusion timesteps.
            beta_start: First beta of the ``"linear"`` schedule.
            beta_end: Last beta of the ``"linear"`` schedule.
            noise_scheduler: Schedule name, ``"linear"`` or ``"cosine"``.
            ema: Stored on the instance as ``self.ema``; not otherwise used
                by this class.

        Raises:
            ValueError: If ``noise_scheduler`` is not a known schedule name.
        """
        # NOTE(review): safe_mode=False lets Keras deserialize arbitrary code
        # (e.g. Lambda layers); only load model files from trusted sources.
        self.noise_predictor = (
            load_model(filepath=model, safe_mode=False)
            if isinstance(model, str)
            else model
        )
        # Each argument is checked on its own: the original tested ``model``
        # here, which broke mixed path/model arguments.
        self.ema_noise_predictor = (
            load_model(filepath=ema_model, safe_mode=False)
            if isinstance(ema_model, str)
            else ema_model
        )
        self.ema = ema
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps

        # Tables are built in float64 and only cast to float32 for storage.
        betas = self.noise_scheduler(noise_scheduler)
        alphas = 1.0 - betas
        alphas_cum_prod = tf.math.cumprod(alphas, axis=0)
        # Shifted by one step; the cumulative product "before" step 0 is 1.
        alphas_cum_prod_prev = tf.concat(
            [tf.constant([1.0], dtype=tf.float64), alphas_cum_prod[:-1]], axis=0
        )
        posterior_variances = (
            betas * (1.0 - alphas_cum_prod_prev) / (1.0 - alphas_cum_prod)
        )

        self.betas = as_float32(betas)
        self.posterior_variances = as_float32(posterior_variances)
        self.alphas_cum_prod = as_float32(alphas_cum_prod)
        self.alphas_cum_prod_prev = as_float32(alphas_cum_prod_prev)
        self.one_minus_alphas_cum_prod = as_float32(1.0 - alphas_cum_prod)
        self.one_minus_alphas_cum_prod_prev = as_float32(1.0 - alphas_cum_prod_prev)
        self.sqrt_one_minus_alphas_cum_prod = as_float32(tf.sqrt(1.0 - alphas_cum_prod))
        self.sqrt_alphas_cum_prod_prev = as_float32(tf.sqrt(alphas_cum_prod_prev))
        self.sqrt_alphas_cum_prod = as_float32(tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas_cum_prod = as_float32(1.0 / tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas = as_float32(tf.sqrt(1.0 / alphas))

    def ddpm_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor) -> tf.Tensor:
        """Take one DDPM reverse step from timestep ``t``.

        Args:
            pred_noise: Noise predicted by the model for ``x_t``.
            x_t: Current noisy samples; the leading dimension is the batch.
            t: Integer timestep index per batch element.

        Returns:
            The posterior mean plus noise scaled by the posterior standard
            deviation; the noise term is masked out where ``t == 0``.
        """
        batch_dim = tf.shape(x_t)[0]
        at_timestep = batch_reshape(t, x_t)
        beta = at_timestep(self.betas)
        rev_sqrt_alpha = at_timestep(self.rev_sqrt_alphas)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        posterior_variance = at_timestep(self.posterior_variances)

        mean = rev_sqrt_alpha * (
            x_t - (beta / sqrt_one_minus_alpha_cum_prod) * pred_noise
        )
        # No noise is added on the final step (t == 0).
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), dtype=tf.float32), [batch_dim, 1, 1, 1]
        )
        # Dynamic shape: the static batch dimension may be unknown.
        random_noise = tf.random.normal(shape=tf.shape(x_t), dtype=x_t.dtype)
        return mean + nonzero_mask * tf.sqrt(posterior_variance) * random_noise

    def ddim_sample(
        self,
        pred_noise: tf.Tensor,
        x_t: tf.Tensor,
        t: tf.Tensor,
        eta: float = 0.0,
        t_prev: tf.Tensor | None = None,
    ) -> tf.Tensor:
        """Take one DDIM step from timestep ``t`` to timestep ``t_prev``.

        Args:
            pred_noise: Noise predicted by the model for ``x_t``.
            x_t: Current noisy samples; the leading dimension is the batch.
            t: Integer timestep index per batch element.
            eta: Scale of the stochastic term; ``0.0`` is deterministic.
            t_prev: Integer timestep to step to, per batch element. A negative
                value means "before the first timestep" (cumulative alpha
                product of 1). Defaults to ``t - 1``.

        Returns:
            The samples at timestep ``t_prev``.
        """
        if t_prev is None:
            t_prev = t - 1
        t_prev = tf.convert_to_tensor(t_prev)

        at_timestep = batch_reshape(t, x_t)
        # Negative indices are clamped for the gather and patched afterwards.
        at_prev_timestep = batch_reshape(
            tf.maximum(t_prev, tf.zeros_like(t_prev)), x_t
        )
        before_start = tf.reshape(
            tf.less(t_prev, tf.zeros_like(t_prev)), [tf.shape(x_t)[0], 1, 1, 1]
        )

        def at_previous(coeff: tf.Tensor, at_start: float) -> tf.Tensor:
            # Coefficient at ``t_prev``, or ``at_start`` where ``t_prev < 0``.
            values = at_prev_timestep(coeff)
            boundary = tf.fill(tf.shape(values), tf.cast(at_start, values.dtype))
            return tf.where(before_start, boundary, values)

        rev_sqrt_alpha_cum_prod = at_timestep(self.rev_sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        alpha_cum_prod = at_timestep(self.alphas_cum_prod)
        one_minus_alpha_cum_prod = at_timestep(self.one_minus_alphas_cum_prod)

        sqrt_alpha_cum_prod_prev = at_previous(self.sqrt_alphas_cum_prod, 1.0)
        alpha_cum_prod_prev = at_previous(self.alphas_cum_prod, 1.0)
        one_minus_alpha_cum_prod_prev = at_previous(self.one_minus_alphas_cum_prod, 0.0)

        # Estimate of the clean sample implied by the predicted noise.
        x0_t = (
            (x_t - (sqrt_one_minus_alpha_cum_prod * pred_noise)) * rev_sqrt_alpha_cum_prod
        )
        # sigma = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)).
        # Clamped at zero so float32 rounding cannot produce NaN in the sqrt.
        c1 = eta * tf.sqrt(
            tf.maximum(
                (one_minus_alpha_cum_prod_prev / one_minus_alpha_cum_prod)
                * (1.0 - alpha_cum_prod / alpha_cum_prod_prev),
                0.0,
            )
        )
        x_t_dir = tf.sqrt(
            tf.maximum(one_minus_alpha_cum_prod_prev - tf.square(c1), 0.0)
        )
        random_noise = tf.random.normal(shape=tf.shape(x_t), dtype=x_t.dtype)
        return sqrt_alpha_cum_prod_prev * x0_t + x_t_dir * pred_noise + c1 * random_noise

    def noise_scheduler(self, scheduler: str, max_beta: float = 0.02) -> tf.Tensor:
        """Return the beta schedule as a float64 tensor of length ``timesteps``.

        Args:
            scheduler: ``"linear"`` (evenly spaced from ``beta_start`` to
                ``beta_end``) or ``"cosine"`` (squared-cosine cumulative
                alphas, each beta capped at ``max_beta``).
            max_beta: Upper bound on each beta of the cosine schedule.

        Raises:
            ValueError: If ``scheduler`` is not a known schedule name.
        """
        if scheduler == "linear":
            linear = tf.linspace(
                start=self.beta_start, stop=self.beta_end, num=self.timesteps
            )
            return tf.cast(linear, dtype=tf.float64)

        if scheduler == "cosine":
            half_pi = tf.constant(math.pi / 2, dtype=tf.float64)
            total = tf.constant(self.timesteps, dtype=tf.float64)

            def alpha_bar(fraction: tf.Tensor) -> tf.Tensor:
                return tf.math.cos((fraction + 0.008) / 1.008 * half_pi) ** 2

            # Every op is elementwise, so the whole range is computed at once.
            steps = tf.range(self.timesteps, dtype=tf.float64)
            betas = 1.0 - alpha_bar((steps + 1.0) / total) / alpha_bar(steps / total)
            return tf.minimum(betas, tf.constant(max_beta, dtype=tf.float64))

        raise ValueError(
            f"Unknown noise scheduler {scheduler!r}; expected 'linear' or 'cosine'."
        )

    def x_t(self, x_start: tf.Tensor, t: tf.Tensor, noise: tf.Tensor) -> tf.Tensor:
        """Mix ``x_start`` with ``noise`` at the noise level of timestep ``t``.

        Returns ``sqrt(a_bar_t) * x_start + sqrt(1 - a_bar_t) * noise`` where
        ``a_bar_t`` is the cumulative alpha product at ``t``.
        """
        at_timestep = batch_reshape(t, x_start)
        sqrt_alpha_cum_prod = at_timestep(self.sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise

    def _ddim_timesteps(self, steps: int, step_strategy: str) -> list[int]:
        """Return the ascending list of timesteps visited by DDIM sampling."""
        if steps < 1:
            raise ValueError(f"steps must be at least 1, got {steps}.")
        last = self.timesteps - 1
        # More steps than timesteps would only repeat indices.
        steps = min(steps, self.timesteps)

        if step_strategy in ("uniform", "linear"):
            if steps == 1:
                return [last]
            # Evenly spaced over [0, timesteps - 1], both ends included.
            return [round(i * last / (steps - 1)) for i in range(steps)]

        if step_strategy == "quadratic":
            # Squares of evenly spaced values up to sqrt(0.8 * timesteps).
            top = math.sqrt(self.timesteps * 0.8)
            if steps == 1:
                return [min(int(top ** 2), last)]
            return [
                min(int((top * i / (steps - 1)) ** 2), last) for i in range(steps)
            ]

        raise KeyError(
            f"Unknown step_strategy {step_strategy!r}; "
            "expected 'uniform', 'linear' or 'quadratic'."
        )

    def generate_images(
        self,
        num_images: int,
        steps: int,
        sample_strategy: str = "ddim",
        step_strategy: str = "uniform",
        ema: bool = True,
        image_shape: tuple[int, int, int] = (64, 64, 3),
    ):
        """Generate images by iteratively denoising Gaussian noise.

        Args:
            num_images: Number of images to generate (batch size).
            steps: Number of DDIM steps. Ignored for ``"ddpm"``, which always
                visits every timestep.
            sample_strategy: ``"ddim"`` or ``"ddpm"``.
            step_strategy: DDIM timestep spacing: ``"uniform"`` (alias
                ``"linear"``) or ``"quadratic"``.
            ema: Use the EMA noise predictor when true, else the raw one.
            image_shape: Per-image shape of the initial noise.

        Returns:
            A float32 tensor of shape ``(num_images, *image_shape)`` equal to
            ``samples * 127.5 + 127.5`` clipped to ``[0, 255]``.

        Raises:
            KeyError: If ``sample_strategy`` or ``step_strategy`` is unknown.
            ValueError: If ``steps`` is less than 1 for DDIM sampling.
        """
        if step_strategy not in ("uniform", "linear", "quadratic"):
            raise KeyError(
                f"Unknown step_strategy {step_strategy!r}; "
                "expected 'uniform', 'linear' or 'quadratic'."
            )
        if sample_strategy == "ddpm":
            seq = list(range(self.timesteps))
        elif sample_strategy == "ddim":
            seq = self._ddim_timesteps(steps, step_strategy)
        else:
            raise KeyError(
                f"Unknown sample_strategy {sample_strategy!r}; "
                "expected 'ddpm' or 'ddim'."
            )
        # Timestep each step moves to; -1 marks the fully denoised end.
        prev_seq = [-1] + seq[:-1]

        noise_predictor = self.ema_noise_predictor if ema else self.noise_predictor
        samples = tf.random.normal(shape=(num_images, *image_shape), dtype=tf.float32)

        for t, t_prev in zip(reversed(seq), reversed(prev_seq)):
            tt = tf.fill([num_images], tf.constant(t, dtype=tf.int64))
            pred_noise = noise_predictor([samples, tt], training=False)
            if sample_strategy == "ddim":
                tt_prev = tf.fill([num_images], tf.constant(t_prev, dtype=tf.int64))
                samples = self.ddim_sample(pred_noise, samples, tt, t_prev=tt_prev)
            else:
                samples = self.ddpm_sample(pred_noise, samples, tt)

        # NOTE(review): assumes the model works in a [-1, 1] pixel range -- confirm.
        return tf.clip_by_value(samples * 127.5 + 127.5, 0.0, 255.0)