Improve confidence guided noising and show number of tokens generated
Browse files
app.py
CHANGED
|
@@ -110,35 +110,48 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, clustering=0.5, noise
|
|
| 110 |
|
| 111 |
|
| 112 |
# Add new noising function
|
| 113 |
-
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start
|
| 114 |
noised = input_ids.copy()
|
| 115 |
answer_len = len(input_ids) - answer_start
|
| 116 |
num_to_noise = int(threshold * answer_len * noise_start)
|
| 117 |
-
|
| 118 |
if num_to_noise == 0:
|
| 119 |
return noised
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
# Avoid zero-probability weights for selection
|
| 125 |
-
# If noise clipping == 1, all tokens have equal chance to be noised.
|
| 126 |
-
# If noise_clipping == 0.00001, all tokens are noised according to the confidence of the past prediction
|
| 127 |
-
raw_weights = np.clip(raw_weights, a_min = noise_clipping, a_max = None)
|
| 128 |
|
| 129 |
-
|
|
|
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
size=
|
| 137 |
replace=False,
|
| 138 |
-
p=
|
| 139 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
-
for idx in
|
| 142 |
noised[idx] = mask_token_id
|
| 143 |
|
| 144 |
return noised
|
|
@@ -256,11 +269,19 @@ def diffusion_chat(question, max_it, pause_length, sharpness, clustering, noise_
|
|
| 256 |
time.sleep(pause_length)
|
| 257 |
|
| 258 |
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
print(final_output)
|
| 263 |
-
yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
|
|
|
|
| 264 |
|
| 265 |
# --- Gradio Interface ---
|
| 266 |
print("Loading model...")
|
|
@@ -271,11 +292,11 @@ demo = gr.Interface(
|
|
| 271 |
fn=diffusion_chat,
|
| 272 |
inputs=[
|
| 273 |
gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
|
| 274 |
-
gr.Slider(1, 512, value=
|
| 275 |
gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
|
| 276 |
-
gr.Slider(1.0, 20.0, value=
|
| 277 |
gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
|
| 278 |
-
gr.Slider(0.0, 1.0, value=0.
|
| 279 |
gr.Checkbox(value=False, label="Use confidence-guided noising"),
|
| 280 |
gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
|
| 281 |
|
|
|
|
| 110 |
|
| 111 |
|
| 112 |
# Add new noising function
|
| 113 |
+
def confidence_guided_noising(input_ids, answer_start, confidences, noise_clipping, threshold=1.0, noise_start=1.0):
|
| 114 |
noised = input_ids.copy()
|
| 115 |
answer_len = len(input_ids) - answer_start
|
| 116 |
num_to_noise = int(threshold * answer_len * noise_start)
|
|
|
|
| 117 |
if num_to_noise == 0:
|
| 118 |
return noised
|
| 119 |
|
| 120 |
+
all_indices = np.arange(answer_start, len(input_ids))
|
| 121 |
+
eos_indices = [i for i in all_indices if input_ids[i] == eos_token_id]
|
| 122 |
+
non_eos_indices = [i for i in all_indices if input_ids[i] != eos_token_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
num_non_eos_to_noise = int(num_to_noise * (len(non_eos_indices) / (len(non_eos_indices) + len(eos_indices) + 1e-5)))
|
| 125 |
+
num_eos_to_noise = num_to_noise - num_non_eos_to_noise
|
| 126 |
|
| 127 |
+
# === Non-EOS sampling ===
|
| 128 |
+
raw_weights_non_eos = 1.0 - np.array([confidences[i - answer_start] for i in non_eos_indices])
|
| 129 |
+
raw_weights_non_eos = np.clip(raw_weights_non_eos, a_min=noise_clipping, a_max=None)
|
| 130 |
+
weights_non_eos = raw_weights_non_eos / raw_weights_non_eos.sum() if raw_weights_non_eos.sum() > 0 else None
|
| 131 |
|
| 132 |
+
chosen_non_eos = rng.choice(
|
| 133 |
+
non_eos_indices,
|
| 134 |
+
size=min(num_non_eos_to_noise, len(non_eos_indices)),
|
| 135 |
replace=False,
|
| 136 |
+
p=weights_non_eos
|
| 137 |
+
) if weights_non_eos is not None else []
|
| 138 |
+
|
| 139 |
+
# === EOS sampling ===
|
| 140 |
+
if eos_indices:
|
| 141 |
+
raw_weights_eos = 1.0 - np.array([confidences[i - answer_start] for i in eos_indices])
|
| 142 |
+
raw_weights_eos = np.clip(raw_weights_eos, a_min=noise_clipping, a_max=None)
|
| 143 |
+
weights_eos = raw_weights_eos / raw_weights_eos.sum() if raw_weights_eos.sum() > 0 else None
|
| 144 |
+
|
| 145 |
+
chosen_eos = rng.choice(
|
| 146 |
+
eos_indices,
|
| 147 |
+
size=min(num_eos_to_noise, len(eos_indices)),
|
| 148 |
+
replace=False,
|
| 149 |
+
p=weights_eos
|
| 150 |
+
) if weights_eos is not None else []
|
| 151 |
+
else:
|
| 152 |
+
chosen_eos = []
|
| 153 |
|
| 154 |
+
for idx in list(chosen_non_eos) + list(chosen_eos):
|
| 155 |
noised[idx] = mask_token_id
|
| 156 |
|
| 157 |
return noised
|
|
|
|
| 269 |
time.sleep(pause_length)
|
| 270 |
|
| 271 |
|
| 272 |
+
answer_ids = current_tokens[answer_start:]
|
| 273 |
+
try:
|
| 274 |
+
eos_index = answer_ids.index(eos_token_id)
|
| 275 |
+
final_ids = answer_ids[:eos_index]
|
| 276 |
+
except ValueError:
|
| 277 |
+
final_ids = answer_ids
|
| 278 |
+
|
| 279 |
+
num_tokens = len(final_ids)
|
| 280 |
+
final_output = tokenizer.decode(final_ids, skip_special_tokens=True)
|
| 281 |
+
|
| 282 |
print(final_output)
|
| 283 |
+
yield f"<b>Final Output ({num_tokens} tokens after {i+1} iterations):</b><br>" + final_output.replace('\n', '<br>')
|
| 284 |
+
|
| 285 |
|
| 286 |
# --- Gradio Interface ---
|
| 287 |
print("Loading model...")
|
|
|
|
| 292 |
fn=diffusion_chat,
|
| 293 |
inputs=[
|
| 294 |
gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of Amsterdam?"),
|
| 295 |
+
gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
|
| 296 |
gr.Slider(0.01, 5, value=0.01, step=0.01, label="↑ = longer pause (for visualization)"),
|
| 297 |
+
gr.Slider(1.0, 20.0, value=1.0, step=0.5, label="↓ = more noising (sharpness)"),
|
| 298 |
gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="↑ = more clustered noising (fewer, larger edits)"),
|
| 299 |
+
gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="↑ = more noise (noise start)"),
|
| 300 |
gr.Checkbox(value=False, label="Use confidence-guided noising"),
|
| 301 |
gr.Slider(0.01, 1.0, value=0.01, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
|
| 302 |
|