import math

import tensorflow as tf
from tensorflow import keras
from keras.models import load_model


def as_float32(t: tf.Tensor) -> tf.Tensor:
    """Cast *t* to float32 (schedule tables are computed in float64)."""
    return tf.cast(t, dtype=tf.float32)


def batch_reshape(t: tf.Tensor, x: tf.Tensor):
    """Build a gather-and-broadcast helper for per-sample schedule lookups.

    The returned closure maps a 1-D coefficient table ``coeff`` to
    ``coeff[t]`` reshaped to ``[batch, 1, 1, 1]`` so it broadcasts against
    image tensors shaped like *x* (``[batch, H, W, C]``).

    Args:
        t: Integer timestep indices, one per batch element.
        x: Reference batch tensor; only its leading (batch) dim is used.

    Returns:
        A function ``coeff -> tf.Tensor`` of shape ``[batch, 1, 1, 1]``.
    """
    def inner_function(coeff: tf.Tensor) -> tf.Tensor:
        batch_dim = tf.shape(x)[0]
        return tf.reshape(tf.gather(coeff, t), [batch_dim, 1, 1, 1])

    return inner_function


class DiffusionSampler:
    """DDPM/DDIM image sampler wrapping a trained noise-prediction network."""

    def __init__(
        self,
        model: keras.Model | str,
        ema_model: keras.Model | str,
        timesteps: int | None = 1_000,
        beta_start: float | None = 1e-4,
        beta_end: float | None = 0.02,
        noise_scheduler: str = "linear",
        ema: float = 0.999,
    ):
        """Load the predictors and precompute all sampling coefficients.

        Args:
            model: Trained noise predictor, or a filepath to a saved model.
            ema_model: EMA copy of the noise predictor, or a filepath to it.
            timesteps: Number of diffusion steps ``T``.
            beta_start: First beta of the linear schedule.
            beta_end: Last beta of the linear schedule.
            noise_scheduler: Schedule name, ``"linear"`` or ``"cosine"``.
            ema: EMA decay rate (stored for reference only).

        Raises:
            ValueError: If ``noise_scheduler`` is not a known schedule name.
        """
        self.noise_predictor = (
            load_model(filepath=model, safe_mode=False)
            if isinstance(model, str)
            else model
        )
        # BUG FIX: the original guarded this with isinstance(model, str), so
        # a string `ema_model` was never loaded (and was stored as a raw str).
        self.ema_noise_predictor = (
            load_model(filepath=ema_model, safe_mode=False)
            if isinstance(ema_model, str)
            else ema_model
        )
        self.ema = ema
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps

        betas = self.noise_scheduler(noise_scheduler)
        alphas = 1.0 - betas
        alphas_cum_prod = tf.math.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
        alphas_cum_prod_prev = tf.concat(
            [tf.constant([1.0], dtype=tf.float64), alphas_cum_prod[:-1]], axis=0
        )
        # Variance of the DDPM posterior q(x_{t-1} | x_t, x_0).
        posterior_variances = (
            betas * (1.0 - alphas_cum_prod_prev) / (1.0 - alphas_cum_prod)
        )

        # Everything below is a float32 lookup table indexed by timestep.
        self.betas = as_float32(betas)
        self.posterior_variances = as_float32(posterior_variances)
        self.alphas_cum_prod_prev = as_float32(alphas_cum_prod_prev)
        self.one_minus_alphas_cum_prod = as_float32(1.0 - alphas_cum_prod)
        self.one_minus_alphas_cum_prod_prev = as_float32(1.0 - alphas_cum_prod_prev)
        self.sqrt_one_minus_alphas_cum_prod = as_float32(tf.sqrt(1.0 - alphas_cum_prod))
        self.sqrt_alphas_cum_prod_prev = as_float32(tf.sqrt(alphas_cum_prod_prev))
        self.sqrt_alphas_cum_prod = as_float32(tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas_cum_prod = as_float32(1.0 / tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas = as_float32(tf.sqrt(1.0 / alphas))

    def ddpm_sample(
        self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor
    ) -> tf.Tensor:
        """One ancestral DDPM step: sample x_{t-1} from x_t and the predicted noise.

        Args:
            pred_noise: Network output epsilon_theta(x_t, t).
            x_t: Current noisy batch, shape [batch, H, W, C].
            t: Per-sample integer timesteps, shape [batch].

        Returns:
            The denoised sample x_{t-1}; no noise is added when t == 0.
        """
        batch_dim = tf.shape(x_t)[0]
        at_timestep = batch_reshape(t, x_t)
        beta = at_timestep(self.betas)
        rev_sqrt_alpha = at_timestep(self.rev_sqrt_alphas)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        posterior_variance = at_timestep(self.posterior_variances)

        # Posterior mean (DDPM eq. 11).
        mean = rev_sqrt_alpha * (
            x_t - (beta / sqrt_one_minus_alpha_cum_prod) * pred_noise
        )
        # Suppress the noise term on the final (t == 0) step.
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), dtype=tf.float32), [batch_dim, 1, 1, 1]
        )
        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return mean + nonzero_mask * tf.sqrt(posterior_variance) * random_noise

    def ddim_sample(
        self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor, eta: float = 0.0
    ) -> tf.Tensor:
        """One DDIM step: deterministic when eta == 0, stochastic otherwise.

        Args:
            pred_noise: Network output epsilon_theta(x_t, t).
            x_t: Current noisy batch, shape [batch, H, W, C].
            t: Per-sample integer timesteps, shape [batch].
            eta: Interpolation between DDIM (0.0) and DDPM-like (1.0) noise.

        Returns:
            The sample at the previous timestep of the sampling trajectory.
        """
        at_timestep = batch_reshape(t, x_t)
        sqrt_alpha_cum_prod_prev = at_timestep(self.sqrt_alphas_cum_prod_prev)
        rev_sqrt_alpha_cum_prod = at_timestep(self.rev_sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        alpha_cum_prod_prev = at_timestep(self.alphas_cum_prod_prev)
        one_minus_alpha_cum_prod = at_timestep(self.one_minus_alphas_cum_prod)
        one_minus_alpha_cum_prod_prev = at_timestep(self.one_minus_alphas_cum_prod_prev)

        # Predicted x_0 recovered from x_t and the noise estimate.
        x0_t = (
            x_t - (sqrt_one_minus_alpha_cum_prod * pred_noise)
        ) * rev_sqrt_alpha_cum_prod
        # sigma_t of DDIM eq. 16, scaled by eta.
        c1 = eta * tf.sqrt(
            (one_minus_alpha_cum_prod_prev / one_minus_alpha_cum_prod)
            * (one_minus_alpha_cum_prod / alpha_cum_prod_prev)
        )
        # "Direction pointing to x_t" coefficient.
        x_t_dir = tf.sqrt(one_minus_alpha_cum_prod_prev - tf.square(c1))
        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return (
            sqrt_alpha_cum_prod_prev * x0_t
            + x_t_dir * pred_noise
            + c1 * random_noise
        )

    def noise_scheduler(self, scheduler: str, max_beta: float = 0.02) -> tf.Tensor:
        """Return the float64 beta schedule of length ``self.timesteps``.

        Args:
            scheduler: ``"linear"`` or ``"cosine"``.
            max_beta: Clipping ceiling for the cosine schedule's betas.

        Returns:
            A 1-D float64 tensor of betas.

        Raises:
            ValueError: If *scheduler* names an unknown schedule.
        """
        pi, T = [
            tf.constant(num, dtype=tf.float64) for num in (math.pi, self.timesteps)
        ]
        # alpha_bar from the improved-DDPM cosine schedule (s = 0.008).
        alpha_bar = lambda i: tf.math.cos((i + 0.008) / 1.008 * pi / 2) ** 2
        cosine_scheduler = lambda t: tf.minimum(
            1 - alpha_bar((t + 1) / T) / alpha_bar(t / T), max_beta
        )

        if scheduler == "linear":
            x = tf.linspace(start=self.beta_start, stop=self.beta_end, num=self.timesteps)
            return tf.cast(x, dtype=tf.float64)
        elif scheduler == "cosine":
            x = tf.vectorized_map(
                fn=cosine_scheduler, elems=tf.range(self.timesteps, dtype=tf.float64)
            )
            return tf.cast(x, dtype=tf.float64)
        # BUG FIX: the original fell through and returned None for an unknown
        # name, producing an opaque failure later in __init__.
        raise ValueError(f"Unknown noise scheduler: {scheduler!r}")

    def x_t(self, x_start: tf.Tensor, t: tf.Tensor, noise: tf.Tensor) -> tf.Tensor:
        """Forward-diffuse x_0 to timestep t: sqrt(ab_t)*x_0 + sqrt(1-ab_t)*noise."""
        at_timestep = batch_reshape(t, x_start)
        sqrt_alpha_cum_prod = at_timestep(self.sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise

    @tf.function
    def generate_images(
        self,
        num_images: int,
        steps: int,
        sample_strategy: str = "ddim",
        step_strategy: str = "uniform",
        ema: bool = True,
    ):
        """Sample images from pure noise by iterating the reverse process.

        Args:
            num_images: Batch size of generated images.
            steps: Number of sampling steps (DDIM only; DDPM always runs T).
            sample_strategy: ``"ddpm"`` or ``"ddim"``.
            step_strategy: ``"linear"`` or ``"quadratic"`` timestep spacing.
            ema: Use the EMA predictor instead of the raw one.

        Returns:
            A [num_images, 64, 64, 3] float tensor of pixel values in [0, 255].

        Raises:
            KeyError: If the (sample_strategy, step_strategy) pair is unknown.
        """
        # NOTE(review): the DDIM "linear" sequence tf.range(steps) only covers
        # timesteps 0..steps-1 rather than a subsequence spanning all T
        # (e.g. range(0, T, T // steps)) — confirm this is intentional.
        sampling_strategies = {
            ("ddpm", "linear"): (
                self.ddpm_sample,
                tf.range(self.timesteps, dtype=tf.float64),
            ),
            ("ddpm", "quadratic"): (
                self.ddpm_sample,
                tf.range(self.timesteps, dtype=tf.float64),
            ),
            ("ddim", "linear"): (
                self.ddim_sample,
                tf.range(steps, dtype=tf.float64),
            ),
            ("ddim", "quadratic"): (
                self.ddim_sample,
                tf.cast(
                    tf.linspace(start=0.0, stop=tf.sqrt(self.timesteps * 0.8), num=steps)
                    ** 2,
                    dtype=tf.float64,
                ),
            ),
        }
        noise_predictor = self.ema_noise_predictor if ema else self.noise_predictor
        sampler, seq = sampling_strategies[(sample_strategy, step_strategy)]

        # Start from isotropic Gaussian noise and denoise from t = T-1 down to 0.
        samples = tf.random.normal(shape=(num_images, 64, 64, 3), dtype=tf.float32)
        for t in tf.reverse(seq, axis=[0]):
            tt = tf.cast(tf.fill(dims=(num_images,), value=t), dtype=tf.int64)
            pred_noise = noise_predictor([samples, tt], training=False)
            samples = sampler(pred_noise, samples, tt)
        # Map from model range [-1, 1] to pixel range [0, 255].
        return tf.clip_by_value(samples * 127.5 + 127.5, 0.0, 255.0)