File size: 6,562 Bytes
7578496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7f0ee7
7578496
 
 
 
 
 
 
 
 
 
 
e7c9be9
7578496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e671de5
d342810
e671de5
 
7578496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7f0ee7
7578496
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea54e12
7578496
ea54e12
7578496
 
ea54e12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import math
from collections.abc import Callable

import tensorflow as tf
from tensorflow import keras
from keras.models import load_model


def as_float32(t: tf.Tensor) -> tf.Tensor:
    """Cast *t* to float32 (schedule tables are built in float64 upstream)."""
    return tf.cast(t, tf.float32)


def batch_reshape(t: tf.Tensor, x: tf.Tensor) -> Callable[[tf.Tensor], tf.Tensor]:
    """Build a gatherer that broadcasts per-timestep coefficients over a batch.

    Args:
        t: integer tensor of shape (batch,) holding one timestep per sample.
        x: reference batch tensor; only its leading (batch) dimension is read.

    Returns:
        A function that maps a 1-D coefficient table (indexed by timestep) to
        a (batch, 1, 1, 1) tensor, ready to broadcast against image batches.
        (The original annotated the return as ``tf.Tensor``, but a closure is
        returned — the annotation is corrected here; behavior is unchanged.)
    """

    def gather_for_batch(coeff: tf.Tensor) -> tf.Tensor:
        batch_dim = tf.shape(x)[0]
        return tf.reshape(tf.gather(coeff, t), [batch_dim, 1, 1, 1])

    return gather_for_batch


class DiffusionSampler:
    """Samples images from trained DDPM/DDIM noise-prediction networks.

    All beta/alpha schedule tables are precomputed once at construction
    (in float64 for accuracy, then cast to float32 for sampling). Both
    ancestral (DDPM) and DDIM reverse steps are provided.
    """

    def __init__(
        self,
        model: keras.Model | str,
        ema_model: keras.Model | str,
        timesteps: int | None = 1_000,
        beta_start: float | None = 1e-4,
        beta_end: float | None = 0.02,
        noise_scheduler: str = "linear",
        ema: float = 0.999,
    ):
        """Load (or accept) the models and precompute the diffusion schedule.

        Args:
            model: noise-prediction network, or a filepath to a saved one.
            ema_model: EMA-averaged copy of the network, or a filepath to one.
            timesteps: length T of the diffusion chain.
            beta_start: first beta of the linear schedule.
            beta_end: last beta of the linear schedule.
            noise_scheduler: "linear" or "cosine".
            ema: EMA decay rate (stored; not used by the samplers here).
        """
        self.noise_predictor = load_model(filepath=model, safe_mode=False) if isinstance(model, str) else model
        # BUG FIX: the original tested isinstance(model, str) here, so passing
        # an in-memory `model` with a path `ema_model` (or vice versa) silently
        # used the wrong object for the EMA predictor.
        self.ema_noise_predictor = load_model(filepath=ema_model, safe_mode=False) if isinstance(ema_model, str) else ema_model
        self.ema = ema
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps

        # Derive all schedule tables in float64, then cast down to float32.
        betas = self.noise_scheduler(noise_scheduler)
        alphas = 1.0 - betas
        alphas_cum_prod = tf.math.cumprod(alphas, axis=0)
        # alpha-bar_{t-1}, with alpha-bar at t=0 defined as 1.
        alphas_cum_prod_prev = tf.concat([tf.constant([1.0], dtype=tf.float64), alphas_cum_prod[:-1]], axis=0)
        posterior_variances = betas * (1.0 - alphas_cum_prod_prev) / (1.0 - alphas_cum_prod)

        self.betas = as_float32(betas)
        self.posterior_variances = as_float32(posterior_variances)
        self.alphas_cum_prod_prev = as_float32(alphas_cum_prod_prev)
        self.one_minus_alphas_cum_prod = as_float32(1.0 - alphas_cum_prod)
        self.one_minus_alphas_cum_prod_prev = as_float32(1.0 - alphas_cum_prod_prev)

        self.sqrt_one_minus_alphas_cum_prod = as_float32(tf.sqrt(1.0 - alphas_cum_prod))
        self.sqrt_alphas_cum_prod_prev = as_float32(tf.sqrt(alphas_cum_prod_prev))
        self.sqrt_alphas_cum_prod = as_float32(tf.sqrt(alphas_cum_prod))

        self.rev_sqrt_alphas_cum_prod = as_float32(1.0 / tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas = as_float32(tf.sqrt(1.0 / alphas))

    def ddpm_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor) -> tf.Tensor:
        """One ancestral (DDPM) reverse step: x_t -> x_{t-1}.

        Args:
            pred_noise: model's noise prediction for x_t.
            t: integer tensor of shape (batch,) with one timestep per sample.

        Returns:
            The denoised sample; noise is only injected for samples with t > 0.
        """
        batch_dim = tf.shape(x_t)[0]
        at_timestep = batch_reshape(t, x_t)

        beta = at_timestep(self.betas)
        rev_sqrt_alpha = at_timestep(self.rev_sqrt_alphas)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        posterior_variance = at_timestep(self.posterior_variances)

        # Posterior mean of x_{t-1} given x_t and the predicted noise.
        mean = rev_sqrt_alpha * (
            x_t - (beta / sqrt_one_minus_alpha_cum_prod) * pred_noise
        )

        # Zero out the noise term at t == 0 so the final step is deterministic.
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), dtype=tf.float32), [batch_dim, 1, 1, 1]
        )

        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return mean + nonzero_mask * tf.sqrt(posterior_variance) * random_noise

    def ddim_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor, eta: float = 0.0) -> tf.Tensor:
        """One DDIM reverse step; eta=0 is fully deterministic.

        Args:
            pred_noise: model's noise prediction for x_t.
            t: integer tensor of shape (batch,) with one timestep per sample.
            eta: stochasticity knob; 0 gives deterministic DDIM sampling.
        """
        at_timestep = batch_reshape(t, x_t)

        sqrt_alpha_cum_prod_prev = at_timestep(self.sqrt_alphas_cum_prod_prev)
        rev_sqrt_alpha_cum_prod = at_timestep(self.rev_sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        alpha_cum_prod_prev = at_timestep(self.alphas_cum_prod_prev)
        one_minus_alpha_cum_prod = at_timestep(self.one_minus_alphas_cum_prod)
        one_minus_alpha_cum_prod_prev = at_timestep(self.one_minus_alphas_cum_prod_prev)

        # Predicted x_0 recovered from x_t and the noise estimate.
        x0_t = (
            (x_t - (sqrt_one_minus_alpha_cum_prod * pred_noise)) * rev_sqrt_alpha_cum_prod
        )
        # NOTE(review): the standard DDIM sigma uses (1 - abar_t/abar_{t-1}) in
        # the second factor, not (1 - abar_t)/abar_{t-1} as written here —
        # confirm against the reference; unchanged because eta defaults to 0.
        c1 = eta * tf.sqrt(
            (one_minus_alpha_cum_prod_prev / one_minus_alpha_cum_prod) * (
                    one_minus_alpha_cum_prod / alpha_cum_prod_prev)
        )

        x_t_dir = tf.sqrt(one_minus_alpha_cum_prod_prev - tf.square(c1))
        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return sqrt_alpha_cum_prod_prev * x0_t + x_t_dir * pred_noise + c1 * random_noise

    def noise_scheduler(self, scheduler: str, max_beta: float = 0.02) -> tf.Tensor:
        """Return the per-timestep beta schedule as a float64 tensor.

        Args:
            scheduler: "linear" or "cosine".
            max_beta: cap applied to each cosine-schedule beta.
                (Annotation fixed: the original declared ``int`` with a float
                default.)

        Raises:
            ValueError: for an unknown scheduler name; the original silently
                returned None, which only failed later with an opaque error.
        """
        pi, T = [tf.constant(num, dtype=tf.float64) for num in (math.pi, self.timesteps)]

        alpha_bar = lambda i: tf.math.cos((i + 0.008) / 1.008 * pi / 2) ** 2
        cosine_scheduler = lambda t: tf.minimum(1 - alpha_bar((t + 1) / T) / alpha_bar(t / T), max_beta)

        if scheduler == "linear":
            x = tf.linspace(start=self.beta_start, stop=self.beta_end, num=self.timesteps)
            return tf.cast(x, dtype=tf.float64)

        elif scheduler == "cosine":
            x = tf.vectorized_map(fn=cosine_scheduler, elems=tf.range(self.timesteps, dtype=tf.float64))
            return tf.cast(x, dtype=tf.float64)

        raise ValueError(f"unknown noise scheduler: {scheduler!r}")

    def x_t(self, x_start: tf.Tensor, t: tf.Tensor, noise: tf.Tensor) -> tf.Tensor:
        """Forward-diffuse x_start to timestep t with the given noise."""
        at_timestep = batch_reshape(t, x_start)

        sqrt_alpha_cum_prod = at_timestep(self.sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise

    @tf.function
    def generate_images(
        self,
        num_images: int,
        steps: int,
        sample_strategy: str = "ddim",
        step_strategy: str = "uniform",
        ema: bool = True,
    ):
        """Sample num_images 64x64x3 images from pure Gaussian noise.

        Args:
            num_images: batch size to generate.
            steps: number of DDIM steps (ignored by the DDPM strategies,
                which always walk the full chain).
            sample_strategy: "ddpm" or "ddim".
            step_strategy: "uniform"/"linear" or "quadratic".
            ema: use the EMA model (True) or the raw model (False).

        Returns:
            Tensor of generated images clipped to pixel range [0, 255].
        """
        # BUG FIX: the default step_strategy "uniform" had no table entry and
        # raised KeyError on every default call; treat it as an alias for the
        # evenly spaced ("linear") schedule.
        if step_strategy == "uniform":
            step_strategy = "linear"

        sampling_strategies = {
            ("ddpm", "linear"): (self.ddpm_sample, tf.range(self.timesteps, dtype=tf.float64)),
            ("ddpm", "quadratic"): (self.ddpm_sample, tf.range(self.timesteps, dtype=tf.float64)),
            # NOTE(review): this visits only the first `steps` timesteps rather
            # than `steps` indices spread across the full chain — confirm intent.
            ("ddim", "linear"): (self.ddim_sample, tf.range(steps, dtype=tf.float64)),
            ("ddim", "quadratic"): (self.ddim_sample, tf.cast(tf.linspace(start=0.0, stop=tf.sqrt(self.timesteps * 0.8), num=steps) ** 2, dtype=tf.float64)),
        }

        noise_predictor = self.ema_noise_predictor if ema else self.noise_predictor
        sampler, seq = sampling_strategies[(sample_strategy, step_strategy)]
        samples = tf.random.normal(shape=(num_images, 64, 64, 3), dtype=tf.float32)

        # Walk the timestep sequence from most-noised to least-noised.
        for t in tf.reverse(seq, axis=[0]):
            tt = tf.cast(tf.fill(dims=(num_images,), value=t), dtype=tf.int64)
            pred_noise = noise_predictor([samples, tt], training=False)
            samples = sampler(pred_noise, samples, tt)

        # Map from the model's [-1, 1] space to [0, 255] pixel values.
        return tf.clip_by_value(samples * 127.5 + 127.5, 0.0, 255.0)