Commit
·
4a4e1c1
1
Parent(s):
86e510b
feat: Add update colours buttom
Browse files
app.py
CHANGED
|
@@ -333,6 +333,11 @@ def main() -> None:
|
|
| 333 |
interactive=True,
|
| 334 |
scale=1,
|
| 335 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
with gr.Row():
|
| 337 |
plot = gr.Plot(
|
| 338 |
value=produce_radial_plot(
|
|
@@ -442,6 +447,21 @@ def main() -> None:
|
|
| 442 |
],
|
| 443 |
outputs=plot,
|
| 444 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
|
| 446 |
demo.launch()
|
| 447 |
|
|
@@ -782,12 +802,15 @@ def fetch_results() -> dict[Language, pd.DataFrame]:
|
|
| 782 |
return results_dfs
|
| 783 |
|
| 784 |
|
| 785 |
-
def update_colour_mapping(model_ids: list[str]) -> None:
|
| 786 |
"""Get a mapping from model ids to RGB triplets.
|
| 787 |
|
| 788 |
Args:
|
| 789 |
model_ids:
|
| 790 |
-
The model ids to update the colour
|
|
|
|
|
|
|
|
|
|
| 791 |
"""
|
| 792 |
if not model_ids:
|
| 793 |
return
|
|
@@ -796,6 +819,9 @@ def update_colour_mapping(model_ids: list[str]) -> None:
|
|
| 796 |
global seed
|
| 797 |
seed += 1
|
| 798 |
|
|
|
|
|
|
|
|
|
|
| 799 |
for i in it.count():
|
| 800 |
min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
|
| 801 |
retries_left = 10 * len(model_ids)
|
|
|
|
| 333 |
interactive=True,
|
| 334 |
scale=1,
|
| 335 |
)
|
| 336 |
+
update_colours_button = gr.Button(
|
| 337 |
+
value="Update colours",
|
| 338 |
+
interactive=True,
|
| 339 |
+
scale=1,
|
| 340 |
+
)
|
| 341 |
with gr.Row():
|
| 342 |
plot = gr.Plot(
|
| 343 |
value=produce_radial_plot(
|
|
|
|
| 447 |
],
|
| 448 |
outputs=plot,
|
| 449 |
)
|
| 450 |
+
update_colours_button.click(
|
| 451 |
+
fn=partial(update_colour_mapping, from_scratch=True),
|
| 452 |
+
inputs=model_ids_dropdown
|
| 453 |
+
).then(
|
| 454 |
+
fn=partial(produce_radial_plot, results_dfs=results_dfs),
|
| 455 |
+
inputs=[
|
| 456 |
+
model_ids_dropdown,
|
| 457 |
+
language_names_dropdown,
|
| 458 |
+
use_rank_score_checkbox,
|
| 459 |
+
show_scale_checkbox,
|
| 460 |
+
plot_width_slider,
|
| 461 |
+
plot_height_slider,
|
| 462 |
+
],
|
| 463 |
+
outputs=plot,
|
| 464 |
+
)
|
| 465 |
|
| 466 |
demo.launch()
|
| 467 |
|
|
|
|
| 802 |
return results_dfs
|
| 803 |
|
| 804 |
|
| 805 |
+
def update_colour_mapping(model_ids: list[str], from_scratch: bool = False) -> None:
|
| 806 |
"""Get a mapping from model ids to RGB triplets.
|
| 807 |
|
| 808 |
Args:
|
| 809 |
model_ids:
|
| 810 |
+
The model ids to update the colour.
|
| 811 |
+
from_scratch:
|
| 812 |
+
Whether to reset the existing colour mapping and build a new one from
|
| 813 |
+
scratch. Defaults to False.
|
| 814 |
"""
|
| 815 |
if not model_ids:
|
| 816 |
return
|
|
|
|
| 819 |
global seed
|
| 820 |
seed += 1
|
| 821 |
|
| 822 |
+
if from_scratch:
|
| 823 |
+
colour_mapping = dict()
|
| 824 |
+
|
| 825 |
for i in it.count():
|
| 826 |
min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
|
| 827 |
retries_left = 10 * len(model_ids)
|