| """The App module is the main entry point for the application. |
| |
| Run `streamlit run app.py` to start the app. |
| """ |
|
|
| import pandas as pd |
| import streamlit as st |
| from streamlit_option_menu import option_menu |
|
|
| from src.load import load_context |
| from src.subpages import ( |
| DebugPage, |
| FindDuplicatesPage, |
| HomePage, |
| LossesPage, |
| LossySamplesPage, |
| MetricsPage, |
| MisclassifiedPage, |
| Page, |
| ProbingPage, |
| RandomSamplesPage, |
| RawDataPage, |
| ) |
| from src.subpages.attention import AttentionPage |
| from src.subpages.hidden_states import HiddenStatesPage |
| from src.subpages.inspect import InspectPage |
| from src.utils import classmap |
|
|
# Page-wide settings: browser tab title, tab icon and the wide layout.
st.set_page_config(
    page_title="Error Analysis",
    page_icon="🏷️",
    layout="wide",
)
# Module-level shorthand for the sidebar container.
sts = st.sidebar
|
|
|
|
def _show_menu(pages: list[Page]) -> int:
    """Render the sidebar navigation menu and return the selected page's index.

    The selected page name is also stored in ``st.session_state.active_page``.

    Args:
        pages: Pages to list in the menu, in display order. Each page's
            ``name`` and ``icon`` attributes supply the menu entry.

    Returns:
        The position within ``pages`` of the first page whose name equals the
        selected menu item.

    Raises:
        ValueError: If the value returned by ``option_menu`` is not one of
            the page names.
    """
    page_names = [page.name for page in pages]
    page_icons = [page.icon for page in pages]

    with st.sidebar:
        selected_name = option_menu(
            menu_title="ExplaiNER",
            options=page_names,
            icons=page_icons,
            menu_icon="layout-wtf",
            default_index=0,
        )

    st.session_state.active_page = selected_name
    # Return outside the ``with`` block so the function has exactly one,
    # unconditional exit path.
    return page_names.index(selected_name)
|
|
|
|
def _initialize_session_state(pages: list[Page]):
    """Seed the Streamlit session state with the pages' widget defaults.

    Defaults are only written while ``"active_page"`` is absent from the
    session state. Independently of that, every call re-assigns all current
    session-state entries to themselves.

    Args:
        pages: Pages whose ``_get_widget_defaults()`` results are merged into
            ``st.session_state`` as keyword arguments.
    """
    state = st.session_state
    needs_defaults = "active_page" not in state

    if needs_defaults:
        for page in pages:
            defaults = page._get_widget_defaults()
            state.update(**defaults)

    # Self-update of the whole session state on every call.
    # NOTE(review): presumably this keeps widget-bound values from being
    # discarded when switching between pages -- confirm.
    state.update(state)
|
|
|
|
def _write_color_legend(context):
    """Write a color-coded legend of the entity types to the sidebar.

    Labels containing a hyphen (e.g. ``B-PER``) are reduced to their second
    hyphen-separated field (``PER``); labels without a hyphen are used as-is.
    Each distinct entity type becomes one legend row showing its ``classmap``
    entry on the background color stored in the session state under
    ``color_<type>`` (``#000000`` when no color is stored).

    Args:
        context: Object exposing a ``labels`` iterable of label strings.

    Raises:
        KeyError: Propagated from the ``classmap`` lookup if it is a mapping
            and has no entry for an entity type.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # legend rows appear in the same order on every run (set iteration order
    # of strings differs between interpreter runs).
    entity_types = list(
        dict.fromkeys(
            label.split("-")[1] if "-" in label else label
            for label in context.labels
        )
    )
    colors = [
        st.session_state.get(f"color_{entity}", "#000000")
        for entity in entity_types
    ]

    def style(_column):
        # One CSS declaration per legend row; the cell values themselves are
        # not needed, only the row order shared with ``colors``.
        return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]

    # One row per entity type, with a single "label" column.
    legend_df = pd.DataFrame(
        [classmap[entity] for entity in entity_types],
        columns=["label"],
        index=entity_types,
    )
    st.sidebar.write(
        legend_df.style.apply(style, axis=0).set_properties(
            **{"color": "white", "text-align": "center"}
        )
    )
|
|
|
|
def main():
    """Run the app: show the menu and render the page the user selected.

    The home page is rendered without a context. Any other page requires
    ``"model_name"`` to be present in the session state; when it is missing,
    an error message is shown instead of the page. Otherwise the context is
    loaded from the session state, the color legend is written, and the page
    is rendered with that context.
    """
    # The list order determines the order of the entries in the menu.
    pages: list[Page] = [
        HomePage(),
        AttentionPage(),
        HiddenStatesPage(),
        ProbingPage(),
        MetricsPage(),
        LossySamplesPage(),
        LossesPage(),
        MisclassifiedPage(),
        RandomSamplesPage(),
        FindDuplicatesPage(),
        InspectPage(),
        RawDataPage(),
        DebugPage(),
    ]

    _initialize_session_state(pages)
    current_page = pages[_show_menu(pages)]

    if isinstance(current_page, HomePage):
        # The home page does not take a context.
        current_page.render()
    elif "model_name" not in st.session_state:
        st.error("Setup not complete. Please click on 'Home / Setup in left menu bar'")
    else:
        # The entire session state is forwarded as keyword arguments.
        context = load_context(**st.session_state)
        _write_color_legend(context)
        current_page.render(context)


if __name__ == "__main__":
    main()
|
|