leowajda commited on
Commit
ea54e12
·
1 Parent(s): 1fcedcb

compile sampling step with autograph

Browse files
Files changed (2) hide show
  1. app.py +7 -4
  2. diffusion_sampler.py +4 -7
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import tensorflow as tf
 
3
  from huggingface_hub import from_pretrained_keras
4
  from diffusion_sampler import DiffusionSampler
5
 
@@ -71,7 +72,7 @@ step_button = gr.Slider(
71
  )
72
 
73
  gallery = gr.Gallery(
74
- columns=5,
75
  label="""
76
  Generated Flowers
77
  """
@@ -101,14 +102,16 @@ def call_model(
101
  progress=gr.Progress(track_tqdm=True),
102
  ):
103
  diffusion_model = linear_diffusion_model if model_to_call.lower() == "linear" else cosine_diffusion_model
104
- return diffusion_model.generate_images(
105
- num_images=int(num_images),
106
- steps=int(steps),
107
  sample_strategy=sample_strategy.lower(),
108
  step_strategy=step_strategy.lower(),
109
  ema=ema,
110
  )
111
 
 
 
112
 
113
  demo = gr.Interface(
114
  fn=call_model,
 
1
  import gradio as gr
2
  import tensorflow as tf
3
+ import numpy as np
4
  from huggingface_hub import from_pretrained_keras
5
  from diffusion_sampler import DiffusionSampler
6
 
 
72
  )
73
 
74
  gallery = gr.Gallery(
75
+ columns=4,
76
  label="""
77
  Generated Flowers
78
  """
 
102
  progress=gr.Progress(track_tqdm=True),
103
  ):
104
  diffusion_model = linear_diffusion_model if model_to_call.lower() == "linear" else cosine_diffusion_model
105
+ images = diffusion_model.generate_images(
106
+ num_images=num_images,
107
+ steps=steps,
108
  sample_strategy=sample_strategy.lower(),
109
  step_strategy=step_strategy.lower(),
110
  ema=ema,
111
  )
112
 
113
+ return images.numpy().astype(np.uint8)
114
+
115
 
116
  demo = gr.Interface(
117
  fn=call_model,
diffusion_sampler.py CHANGED
@@ -120,6 +120,7 @@ class DiffusionSampler(keras.Model):
120
  sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
121
  return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise
122
 
 
123
  def generate_images(
124
  self,
125
  num_images: int,
@@ -139,13 +140,9 @@ class DiffusionSampler(keras.Model):
139
  sampler, seq = sampling_stategies[(sample_strategy, step_strategy)]
140
  samples = tf.random.normal(shape=(num_images, 64, 64, 3), dtype=tf.float32)
141
 
142
- for t in tqdm.tqdm(tf.reverse(seq, axis=[0])):
143
  tt = tf.cast(tf.fill(dims=(num_images,), value=t), dtype=tf.int64)
144
- pred_noise = noise_predictor.predict([samples, tt], verbose=0, batch_size=num_images)
145
  samples = sampler(pred_noise, samples, tt, )
146
 
147
- return (
148
- tf.clip_by_value(samples * 127.5 + 127.5, 0.0, 255.0)
149
- .numpy()
150
- .astype(np.uint8)
151
- )
 
120
  sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
121
  return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise
122
 
123
+ @tf.function()
124
  def generate_images(
125
  self,
126
  num_images: int,
 
140
  sampler, seq = sampling_stategies[(sample_strategy, step_strategy)]
141
  samples = tf.random.normal(shape=(num_images, 64, 64, 3), dtype=tf.float32)
142
 
143
+ for t in tf.reverse(seq, axis=[0]):
144
  tt = tf.cast(tf.fill(dims=(num_images,), value=t), dtype=tf.int64)
145
+ pred_noise = noise_predictor([samples, tt], training=False)
146
  samples = sampler(pred_noise, samples, tt, )
147
 
148
+ return tf.clip_by_value(samples * 127.5 + 127.5, 0.0, 255.0)