# diffusion_model/diffusion_sampler.py
# Author: leowajda — "minor improvements" (commit b7f0ee7)
import math
from collections.abc import Callable

import tensorflow as tf
from keras.models import load_model
from tensorflow import keras
def as_float32(t: tf.Tensor) -> tf.Tensor:
    """Return *t* cast to the float32 dtype."""
    casted = tf.cast(t, tf.float32)
    return casted
def batch_reshape(t: tf.Tensor, x: tf.Tensor) -> Callable[[tf.Tensor], tf.Tensor]:
    """Return a closure that gathers per-sample schedule coefficients.

    The returned function looks up ``coeff[t]`` for each element of the
    batch and reshapes the result to ``(batch, 1, 1, 1)`` so it broadcasts
    against image tensors shaped like *x*.

    Args:
        t: integer timesteps, one per batch element, shape ``(batch,)``.
        x: reference tensor whose leading dimension is the batch size.

    Returns:
        A function mapping a 1-D coefficient tensor to a
        ``(batch, 1, 1, 1)`` tensor of the coefficients at timesteps *t*.
        (The original annotated the return as ``tf.Tensor``, which was
        incorrect — a callable is returned.)
    """
    def inner_function(coeff: tf.Tensor) -> tf.Tensor:
        batch_dim = tf.shape(x)[0]
        return tf.reshape(tf.gather(coeff, t), [batch_dim, 1, 1, 1])
    return inner_function
class DiffusionSampler:
    """Samples images from trained DDPM/DDIM noise-prediction models.

    The constructor precomputes the beta schedule and every derived
    alpha-cumprod quantity in float64 (for numerical stability of the
    cumulative products), then stores float32 copies for use during
    sampling.
    """

    def __init__(
        self,
        model: keras.Model | str,
        ema_model: keras.Model | str,
        timesteps: int | None = 1_000,
        beta_start: float | None = 1e-4,
        beta_end: float | None = 0.02,
        noise_scheduler: str = "linear",
        ema: float = 0.999,
    ):
        """Build the sampler and precompute the diffusion schedule.

        Args:
            model: noise-prediction model, or a filepath to load one from.
            ema_model: EMA-averaged counterpart, or a filepath to load it.
            timesteps: length of the diffusion chain.
            beta_start: first beta of the linear schedule.
            beta_end: last beta of the linear schedule.
            noise_scheduler: "linear" or "cosine".
            ema: EMA decay rate (stored; not used within this class).
        """
        self.noise_predictor = load_model(filepath=model, safe_mode=False) if isinstance(model, str) else model
        # BUG FIX: the original tested isinstance(model, str) here, so a
        # string ema_model was never loaded whenever model was a keras.Model.
        self.ema_noise_predictor = (
            load_model(filepath=ema_model, safe_mode=False) if isinstance(ema_model, str) else ema_model
        )
        self.ema = ema
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        betas = self.noise_scheduler(noise_scheduler)
        alphas = 1.0 - betas
        alphas_cum_prod = tf.math.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
        alphas_cum_prod_prev = tf.concat([tf.constant([1.0], dtype=tf.float64), alphas_cum_prod[:-1]], axis=0)
        posterior_variances = betas * (1.0 - alphas_cum_prod_prev) / (1.0 - alphas_cum_prod)
        # float32 views of every schedule-derived quantity used at sample time.
        self.betas = as_float32(betas)
        self.posterior_variances = as_float32(posterior_variances)
        self.alphas_cum_prod_prev = as_float32(alphas_cum_prod_prev)
        self.one_minus_alphas_cum_prod = as_float32(1.0 - alphas_cum_prod)
        self.one_minus_alphas_cum_prod_prev = as_float32(1.0 - alphas_cum_prod_prev)
        self.sqrt_one_minus_alphas_cum_prod = as_float32(tf.sqrt(1.0 - alphas_cum_prod))
        self.sqrt_alphas_cum_prod_prev = as_float32(tf.sqrt(alphas_cum_prod_prev))
        self.sqrt_alphas_cum_prod = as_float32(tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas_cum_prod = as_float32(1.0 / tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas = as_float32(tf.sqrt(1.0 / alphas))

    def ddpm_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor) -> tf.Tensor:
        """One ancestral DDPM reverse step: draw x_{t-1} given x_t.

        Args:
            pred_noise: model-predicted noise eps(x_t, t), same shape as x_t.
            x_t: current noisy batch, shape (batch, H, W, C).
            t: per-sample integer timesteps, shape (batch,).

        Returns:
            The denoised batch x_{t-1}, same shape and dtype as x_t.
        """
        batch_dim = tf.shape(x_t)[0]
        at_timestep = batch_reshape(t, x_t)
        beta = at_timestep(self.betas)
        rev_sqrt_alpha = at_timestep(self.rev_sqrt_alphas)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        posterior_variance = at_timestep(self.posterior_variances)
        # Posterior mean: (1/sqrt(alpha_t)) * (x_t - beta_t/sqrt(1-abar_t) * eps).
        mean = rev_sqrt_alpha * (
            x_t - (beta / sqrt_one_minus_alpha_cum_prod) * pred_noise
        )
        # No noise is injected on the final step (t == 0).
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), dtype=tf.float32), [batch_dim, 1, 1, 1]
        )
        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return mean + nonzero_mask * tf.sqrt(posterior_variance) * random_noise

    def ddim_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor, eta: float = 0.0) -> tf.Tensor:
        """One DDIM reverse step (deterministic when eta == 0).

        Args:
            pred_noise: model-predicted noise eps(x_t, t), same shape as x_t.
            x_t: current noisy batch, shape (batch, H, W, C).
            t: per-sample integer timesteps, shape (batch,).
            eta: stochasticity knob; 0 gives the deterministic DDIM update.

        Returns:
            The denoised batch, same shape and dtype as x_t.
        """
        at_timestep = batch_reshape(t, x_t)
        sqrt_alpha_cum_prod_prev = at_timestep(self.sqrt_alphas_cum_prod_prev)
        rev_sqrt_alpha_cum_prod = at_timestep(self.rev_sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        alpha_cum_prod_prev = at_timestep(self.alphas_cum_prod_prev)
        one_minus_alpha_cum_prod = at_timestep(self.one_minus_alphas_cum_prod)
        one_minus_alpha_cum_prod_prev = at_timestep(self.one_minus_alphas_cum_prod_prev)
        # Predicted clean image x0 from the eps-parameterization.
        x0_t = (
            (x_t - (sqrt_one_minus_alpha_cum_prod * pred_noise)) * rev_sqrt_alpha_cum_prod
        )
        # NOTE(review): standard DDIM uses sigma = eta * sqrt((1-abar_{t-1})/(1-abar_t))
        # * sqrt(1 - abar_t/abar_{t-1}); the second factor here is
        # (1-abar_t)/abar_{t-1} instead — confirm this deviation is intended.
        c1 = eta * tf.sqrt(
            (one_minus_alpha_cum_prod_prev / one_minus_alpha_cum_prod) * (
                one_minus_alpha_cum_prod / alpha_cum_prod_prev)
        )
        # Coefficient of the "direction pointing to x_t" term.
        x_t_dir = tf.sqrt(one_minus_alpha_cum_prod_prev - tf.square(c1))
        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return sqrt_alpha_cum_prod_prev * x0_t + x_t_dir * pred_noise + c1 * random_noise

    def noise_scheduler(self, scheduler: str, max_beta: float = 0.02) -> tf.Tensor:
        """Build the float64 beta schedule of length ``self.timesteps``.

        Args:
            scheduler: "linear" or "cosine".
            max_beta: clipping ceiling for cosine-schedule betas.
                (Fixed annotation: the original declared ``int`` for a
                float default.)

        Returns:
            A float64 tensor of shape (timesteps,).

        Raises:
            ValueError: if *scheduler* is not a recognized schedule name
                (the original silently returned None).
        """
        if scheduler == "linear":
            x = tf.linspace(start=self.beta_start, stop=self.beta_end, num=self.timesteps)
            return tf.cast(x, dtype=tf.float64)
        if scheduler == "cosine":
            pi, T = [tf.constant(num, dtype=tf.float64) for num in (math.pi, self.timesteps)]
            # Cosine schedule: alpha_bar(s) = cos((s + 0.008)/1.008 * pi/2)^2,
            # beta_t = min(1 - alpha_bar((t+1)/T)/alpha_bar(t/T), max_beta).
            alpha_bar = lambda i: tf.math.cos((i + 0.008) / 1.008 * pi / 2) ** 2
            cosine_scheduler = lambda t: tf.minimum(1 - alpha_bar((t + 1) / T) / alpha_bar(t / T), max_beta)
            x = tf.vectorized_map(fn=cosine_scheduler, elems=tf.range(self.timesteps, dtype=tf.float64))
            return tf.cast(x, dtype=tf.float64)
        raise ValueError(f"unknown noise scheduler: {scheduler!r}")

    def x_t(self, x_start: tf.Tensor, t: tf.Tensor, noise: tf.Tensor) -> tf.Tensor:
        """Forward-diffuse x_start to timestep t in closed form.

        Returns sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise.
        """
        at_timestep = batch_reshape(t, x_start)
        sqrt_alpha_cum_prod = at_timestep(self.sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise

    @tf.function
    def generate_images(
        self,
        num_images: int,
        steps: int,
        sample_strategy: str = "ddim",
        step_strategy: str = "uniform",
        ema: bool = True,
    ):
        """Generate ``num_images`` 64x64 RGB images by iterative denoising.

        Args:
            num_images: batch size of generated samples.
            steps: number of denoising steps (used by DDIM; DDPM always
                runs the full ``timesteps`` chain).
            sample_strategy: "ddpm" or "ddim".
            step_strategy: timestep spacing — "uniform"/"linear" or
                "quadratic".
            ema: use the EMA-averaged model when True.

        Returns:
            A float32 tensor of shape (num_images, 64, 64, 3) with pixel
            values clipped to [0, 255].
        """
        uniform_ddpm = (self.ddpm_sample, tf.range(self.timesteps, dtype=tf.float64))
        uniform_ddim = (self.ddim_sample, tf.range(steps, dtype=tf.float64))
        sampling_strategies = {
            # BUG FIX: the default step_strategy "uniform" had no table entry,
            # so calling with defaults raised KeyError. "uniform" is evenly
            # spaced, i.e. the same as "linear"; both keys are kept.
            ("ddpm", "uniform"): uniform_ddpm,
            ("ddpm", "linear"): uniform_ddpm,
            ("ddpm", "quadratic"): uniform_ddpm,
            ("ddim", "uniform"): uniform_ddim,
            ("ddim", "linear"): uniform_ddim,
            ("ddim", "quadratic"): (
                self.ddim_sample,
                tf.cast(tf.linspace(start=0.0, stop=tf.sqrt(self.timesteps * 0.8), num=steps) ** 2, dtype=tf.float64),
            ),
        }
        noise_predictor = self.ema_noise_predictor if ema else self.noise_predictor
        sampler, seq = sampling_strategies[(sample_strategy, step_strategy)]
        # Start from pure Gaussian noise and denoise from the largest t down to 0.
        samples = tf.random.normal(shape=(num_images, 64, 64, 3), dtype=tf.float32)
        for t in tf.reverse(seq, axis=[0]):
            tt = tf.cast(tf.fill(dims=(num_images,), value=t), dtype=tf.int64)
            pred_noise = noise_predictor([samples, tt], training=False)
            samples = sampler(pred_noise, samples, tt)
        # Map from the model's [-1, 1] range to [0, 255] pixel values.
        return tf.clip_by_value(samples * 127.5 + 127.5, 0.0, 255.0)