Highlight noised tokens
app.py
CHANGED
@@ -62,14 +62,13 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, clust
     num_to_noise = int(threshold * answer_len)
 
     if num_to_noise == 0:
-        return noised
+        return noised, []
 
     mixed_probs = token_probabilities.copy()
     mixed_probs[eot_token_id] *= eot_weight
     mixed_probs /= mixed_probs.sum()
 
-
-    num_clusters = max(1, int((1 - clustering) * num_to_noise))  # fewer clusters if more intensity
+    num_clusters = max(1, int((1 - clustering) * num_to_noise))
     cluster_size = max(1, int(num_to_noise / num_clusters))
 
     noised_indices = set()
@@ -79,15 +78,13 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, clust
         span_end = min(len(noised), span_start + cluster_size)
         noised_indices.update(range(span_start, span_end))
 
-    # Trim in case we overshot due to overlapping clusters
     noised_indices = sorted(list(noised_indices))[:num_to_noise]
 
     noise = rng.choice(np.arange(vocab_size), size=len(noised_indices), p=mixed_probs)
     for idx, val in zip(noised_indices, noise):
        noised[idx] = val
 
-    return noised
-
+    return noised, noised_indices
 
 
 # Add new noising function
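For reference, a minimal, self-contained sketch of what `noisify_answer` computes after this change. The span-start sampling falls between the two hunks and is not visible in this diff, so the `rng.integers` draw, the toy `vocab_size`, the uniform `token_probabilities`, and `eot_token_id=0` here are illustrative assumptions, not app.py's actual code:

import numpy as np

# Sketch of the clustered noising; the span-start selection is an assumption.
def noisify_answer_sketch(input_ids, answer_start, token_probabilities, rng,
                          threshold=1.0, eot_weight=1.0, clustering=0.5,
                          eot_token_id=0, vocab_size=100):
    noised = list(input_ids)
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len)
    if num_to_noise == 0:
        return noised, []

    mixed_probs = token_probabilities.copy()
    mixed_probs[eot_token_id] *= eot_weight  # re-weight the EOT token
    mixed_probs /= mixed_probs.sum()

    # Higher clustering -> fewer clusters, each spanning more contiguous tokens.
    num_clusters = max(1, int((1 - clustering) * num_to_noise))
    cluster_size = max(1, int(num_to_noise / num_clusters))

    noised_indices = set()
    for _ in range(num_clusters):
        # Assumed selection: random cluster start inside the answer region.
        span_start = int(rng.integers(answer_start, len(noised)))
        span_end = min(len(noised), span_start + cluster_size)
        noised_indices.update(range(span_start, span_end))

    # Trim in case overlapping clusters overshot the budget.
    noised_indices = sorted(noised_indices)[:num_to_noise]

    noise = rng.choice(np.arange(vocab_size), size=len(noised_indices), p=mixed_probs)
    for idx, val in zip(noised_indices, noise):
        noised[idx] = int(val)
    return noised, noised_indices

rng = np.random.default_rng(0)
probs = np.full(100, 0.01)  # uniform over the toy vocabulary
tokens, idxs = noisify_answer_sketch(list(range(12)), answer_start=4,
                                     token_probabilities=probs, rng=rng,
                                     threshold=0.5, clustering=0.75)
print(idxs)  # e.g. [9, 10, 11] -- the positions that were replaced

The `clustering` knob trades many scattered single-token replacements (near 0) for a few long contiguous spans (near 1), while `sorted(...)[:num_to_noise]` keeps the total replacement budget fixed when spans overlap.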
@@ -165,7 +162,9 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
     input_ids = input_ids[:256]
 
     ori_input_tokens = input_ids
-    current_tokens = noisify_answer(ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight, clustering=clustering)
+    current_tokens, just_noised_indices = noisify_answer(
+        ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight, clustering=clustering
+    )
     prev_decoded_tokens = []
     last_tokens = []
 
@@ -178,14 +177,19 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
         decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
         filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
         filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []
-
+        just_noised_indices = []
         if filtered_prev_tokens:
             highlighted = []
-            for …
-            …
-            …
+            for i, tok in enumerate(decoded_tokens):
+                token_str = tokenizer.convert_tokens_to_string([tok])
+
+                abs_idx = answer_start + i
+                if abs_idx in just_noised_indices:
+                    highlighted.append(f'<span style="color:red">{token_str}</span>')
+                elif prev_decoded_tokens and i < len(prev_decoded_tokens) and tok != prev_decoded_tokens[i]:
+                    highlighted.append(f'<span style="color:green">{token_str}</span>')
                 else:
-                    highlighted.append(…
+                    highlighted.append(token_str)
         else:
             highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]
 
@@ -203,7 +207,7 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
         if use_confidence_noising:
             current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
         else:
-            current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering)
+            current_tokens, just_noised_indices = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering)
 
         time.sleep(0.01)
 
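Taken together, the commit threads the second return value of `noisify_answer` through `diffusion_chat` so the UI can color tokens per step. Below is a hedged sketch of that rendering logic, with the hypothetical `to_str` standing in for `tokenizer.convert_tokens_to_string([tok])`: positions in `just_noised_indices` render red, tokens that differ from the previous decode render green. Note that as committed, `just_noised_indices` is reset to `[]` immediately before the highlight loop, so the red branch only takes effect once the indices returned by `noisify_answer` actually reach it.

def highlight_tokens(decoded_tokens, prev_decoded_tokens, just_noised_indices,
                     answer_start, to_str=lambda tok: tok):
    """Red = just re-noised this step; green = changed since the previous step."""
    highlighted = []
    for i, tok in enumerate(decoded_tokens):
        token_str = to_str(tok)
        abs_idx = answer_start + i
        if abs_idx in just_noised_indices:
            highlighted.append(f'<span style="color:red">{token_str}</span>')
        elif prev_decoded_tokens and i < len(prev_decoded_tokens) and tok != prev_decoded_tokens[i]:
            highlighted.append(f'<span style="color:green">{token_str}</span>')
        else:
            highlighted.append(token_str)
    return highlighted

# Position 5 was just re-noised (red); position 1 changed between steps (green).
print(highlight_tokens(["The", "cat", "sat", "on", "a", "xq"],
                       ["The", "dog", "sat", "on", "a", "mat"],
                       just_noised_indices=[5], answer_start=0))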