devjas1 committed on
Commit
21823a6
Β·
1 Parent(s): 95c5f93

FEAT(dashboard): Implement interactive batch analysis dashboard page

Browse files

Introduces a new dedicated dashboard for in-depth batch result analysis,
significantly enhancing the application's analytical capabilities.

- Creates `pages/2_πŸ“Š_Dashboard.py` to house the new feature.
- Implements the `BatchAnalysis` class in `modules/analyzer.py`,
encapsulating all dashboard logic, including KPI metrics, diagnostic
visualizations (Confusion Matrix, Confidence Plot), and an interactive
data grid.
- This establishes the foundational entity for all future advanced
analysis features.

Files changed (3) hide show
  1. app.py +0 -41
  2. modules/analyzer.py +309 -0
  3. pages/2_πŸ“Š_Dashboard.py +36 -0
app.py DELETED
@@ -1,41 +0,0 @@
1
- """Streamlit main entrance; modularized for clarity"""
2
-
3
- import streamlit as st
4
-
5
- from modules.callbacks import init_session_state
6
-
7
- from modules.ui_components import (
8
- render_sidebar,
9
- render_results_column,
10
- render_input_column,
11
- load_css,
12
- )
13
-
14
-
15
- # --- Page Setup (Called only ONCE) ---
16
- st.set_page_config(
17
- page_title="ML Polymer Classification",
18
- page_icon="πŸ”¬",
19
- layout="wide",
20
- initial_sidebar_state="expanded",
21
- menu_items={"Get help": "https://github.com/KLab-AI3/ml-polymer-recycling"},
22
- )
23
-
24
-
25
- def main():
26
- """Modularized main content to other scripts to clean the main app"""
27
- load_css("static/style.css")
28
- init_session_state()
29
-
30
- # Render UI components
31
- render_sidebar()
32
-
33
- col1, col2 = st.columns([1, 1.35], gap="small")
34
- with col1:
35
- render_input_column()
36
- with col2:
37
- render_results_column()
38
-
39
-
40
- if __name__ == "__main__":
41
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/analyzer.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # In modules/analyzer.py
2
+
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import numpy as np
6
+ import seaborn as sns
7
+ from sklearn.metrics import confusion_matrix
8
+ import matplotlib.pyplot as plt
9
+ from datetime import datetime
10
+
11
+ from config import LABEL_MAP # Assuming LABEL_MAP is correctly defined in config.py
12
+
13
+
14
class BatchAnalysis:
    """Interactive dashboard for a batch of classification results.

    Encapsulates all dashboard logic: KPI metrics, diagnostic plots
    (confusion matrix and confidence box plot), an interactive drill-down
    button grid, and a filterable results table.

    Expected DataFrame columns (from the batch run): "Filename",
    "Predicted Class" (label string), "Prediction" (numeric class id),
    "Confidence", and optionally "Ground Truth" (numeric class id).
    """

    def __init__(self, df: pd.DataFrame):
        """Initializes the analysis object with the results DataFrame."""
        self.df = df
        if self.df.empty:
            # render() short-circuits on an empty frame, so skip all
            # derived state (total_files, kpis, ...) in that case.
            return

        self.total_files = len(self.df)
        # Ground truth is optional: accuracy and the diagnostic plots are
        # only enabled when the column exists and has at least one value.
        self.has_ground_truth = (
            "Ground Truth" in self.df.columns
            and not self.df["Ground Truth"].isnull().all()
        )
        self._prepare_data()
        self.kpis = self._calculate_kpis()

    def _prepare_data(self):
        """Ensures data types are correct for analysis."""
        self.df["Confidence"] = pd.to_numeric(self.df["Confidence"], errors="coerce")
        if self.has_ground_truth:
            self.df["Ground Truth"] = pd.to_numeric(
                self.df["Ground Truth"], errors="coerce"
            )

    def _calculate_kpis(self) -> dict:
        """Computes the key performance indicators shown at the top.

        Returns a dict with "Total Files", "Avg. Confidence",
        "Stable/Weathered" (formatted count string), and "Accuracy"
        (float, or the string "N/A" when it cannot be computed).
        """
        stable_count = self.df[
            self.df["Predicted Class"] == "Stable (Unweathered)"
        ].shape[0]
        accuracy = "N/A"
        if self.has_ground_truth:
            valid_gt = self.df.dropna(subset=["Ground Truth", "Prediction"])
            # FIX: mean() over an empty slice is NaN, which the KPI metric
            # would render as "nan"; keep "N/A" when no row has both values.
            if not valid_gt.empty:
                accuracy = (valid_gt["Prediction"] == valid_gt["Ground Truth"]).mean()

        return {
            "Total Files": self.total_files,
            "Avg. Confidence": self.df["Confidence"].mean(),
            "Stable/Weathered": f"{stable_count}/{self.total_files - stable_count}",
            "Accuracy": accuracy,
        }

    def render_kpis(self):
        """Renders the top-level KPI metrics."""
        kpi_cols = st.columns(4)
        kpi_cols[0].metric("Total Files", f"{self.kpis['Total Files']}")
        kpi_cols[1].metric("Avg. Confidence", f"{self.kpis['Avg. Confidence']:.3f}")
        kpi_cols[2].metric("Stable/Weathered", self.kpis["Stable/Weathered"])
        kpi_cols[3].metric(
            "Accuracy",
            (
                f"{self.kpis['Accuracy']:.3f}"
                if isinstance(self.kpis["Accuracy"], float)
                else "N/A"
            ),
        )

    def render_visual_diagnostics(self):
        """
        Renders the main diagnostic plots with improved aesthetics and layout.

        Requires ground-truth data; shows an informational message and
        returns early otherwise.
        """
        st.markdown("##### Visual Analysis")
        if not self.has_ground_truth:
            st.info(
                "Visual analysis requires Ground Truth data, which is not available for this batch."
            )
            return

        # FIX: copy() so the "Result" column assignment below writes to an
        # independent frame instead of a slice of self.df
        # (SettingWithCopyWarning / silently lost assignment).
        valid_gt_df = self.df.dropna(subset=["Ground Truth"]).copy()

        viz_cols = st.columns(2)

        # --- Chart 1: Confusion Matrix ---
        with viz_cols[0]:
            st.markdown("**Confusion Matrix**")
            cm = confusion_matrix(
                valid_gt_df["Ground Truth"],
                valid_gt_df["Prediction"],
                labels=list(LABEL_MAP.keys()),
            )

            # constrained_layout keeps labels inside the figure bounds.
            fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)

            sns.heatmap(
                cm,
                annot=True,
                fmt="g",
                ax=ax,
                cmap="Blues",
                xticklabels=list(LABEL_MAP.values()),
                yticklabels=list(LABEL_MAP.values()),
            )

            ax.set_ylabel("Actual Class", fontsize=12)
            ax.set_xlabel("Predicted Class", fontsize=12)
            # Rotate x labels to prevent overlap.
            ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

            st.pyplot(fig, use_container_width=True)
            # FIX: close the figure so it is not kept alive across Streamlit
            # reruns (matplotlib accumulates open figures otherwise).
            plt.close(fig)

        # --- Chart 2: Confidence vs. Correctness Box Plot ---
        with viz_cols[1]:
            st.markdown("**Confidence Analysis**")
            valid_gt_df["Result"] = np.where(
                valid_gt_df["Prediction"] == valid_gt_df["Ground Truth"],
                "Correct",
                "Incorrect",
            )

            fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True)

            # NOTE(review): seaborn >= 0.14 deprecates `palette` without
            # `hue`; revisit if the warning appears.
            sns.boxplot(
                x="Result",
                y="Confidence",
                data=valid_gt_df,
                ax=ax,
                palette={"Correct": "#64C764", "Incorrect": "#E57373"},
            )

            ax.set_ylabel("Model Confidence", fontsize=12)
            ax.set_xlabel("Prediction Result", fontsize=12)

            st.pyplot(fig, use_container_width=True)
            plt.close(fig)  # FIX: same figure-leak prevention as above

        # Interactive button grid: one button per confusion-matrix cell,
        # used to drill down into the results grid below.
        st.markdown("Click on a cell below to filter the results grid:")
        cm_labels = list(LABEL_MAP.values())
        for i, actual_label in enumerate(cm_labels):
            cols = st.columns(len(cm_labels))
            for j, predicted_label in enumerate(cm_labels):
                cell_value = cm[i, j]
                cols[j].button(
                    f"Actual: {actual_label}\nPred: {predicted_label} ({cell_value})",
                    key=f"cm_cell_{i}_{j}",
                    on_click=self._set_cm_filter,
                    args=(i, j, actual_label, predicted_label),
                    use_container_width=True,
                )

    def _set_cm_filter(
        self,
        actual_idx: int,
        predicted_idx: int,
        actual_label: str,
        predicted_label: str,
    ):
        """Callback to set the confusion matrix filter in session state."""
        st.session_state["cm_actual_filter"] = actual_idx
        st.session_state["cm_predicted_filter"] = predicted_idx
        st.session_state["cm_filter_label"] = (
            f"Actual: {actual_label}, Predicted: {predicted_label}"
        )
        st.session_state["cm_filter_active"] = True
        # Streamlit reruns automatically after an on_click callback.

    def _clear_cm_filter(self):
        """Callback to clear the confusion matrix filter from session state."""
        # pop(..., None) is a no-op when the key is absent, replacing the
        # original "check then del" pairs.
        st.session_state.pop("cm_actual_filter", None)
        st.session_state.pop("cm_predicted_filter", None)
        st.session_state.pop("cm_filter_label", None)
        st.session_state.pop("cm_filter_active", None)

    def render_interactive_grid(self):
        """
        Renders the filterable, detailed data grid with robust handling for
        row selection to prevent KeyError.
        """
        st.markdown("##### Detailed Results Explorer")

        # Start with a full copy of the dataframe to apply filters to.
        filtered_df = self.df.copy()

        # --- Filter Section ---
        st.markdown("**Filters**")
        filter_cols = st.columns([2, 2, 3])  # more space for the slider

        # Filter 1: By Predicted Class.
        selected_classes = filter_cols[0].multiselect(
            "Filter by Prediction:",
            options=self.df["Predicted Class"].unique(),
            default=self.df["Predicted Class"].unique(),
        )
        filtered_df = filtered_df[filtered_df["Predicted Class"].isin(selected_classes)]

        # Filter 2: By Ground Truth correctness (if available).
        if self.has_ground_truth:
            filtered_df["Correct"] = (
                filtered_df["Prediction"] == filtered_df["Ground Truth"]
            )
            correctness_options = ["✅ Correct", "❌ Incorrect"]

            # Temporary display-only column mirroring the boolean.
            filtered_df["Result_Display"] = np.where(
                filtered_df["Correct"], "✅ Correct", "❌ Incorrect"
            )

            selected_correctness = filter_cols[1].multiselect(
                "Filter by Result:",
                options=correctness_options,
                default=correctness_options,
            )
            # Map the display strings back to booleans for filtering.
            filter_correctness_bools = [
                c == "✅ Correct" for c in selected_correctness
            ]
            filtered_df = filtered_df[
                filtered_df["Correct"].isin(filter_correctness_bools)
            ]

        # Filter 3: By confidence range.
        min_conf, max_conf = filter_cols[2].slider(
            "Filter by Confidence Range:",
            min_value=0.0,
            max_value=1.0,
            value=(0.0, 1.0),  # default to the full range
            step=0.01,
        )
        filtered_df = filtered_df[
            (filtered_df["Confidence"] >= min_conf)
            & (filtered_df["Confidence"] <= max_conf)
        ]

        # Apply the confusion-matrix drill-down filter (if active).
        if st.session_state.get("cm_filter_active", False):
            actual_idx = st.session_state["cm_actual_filter"]
            predicted_idx = st.session_state["cm_predicted_filter"]
            filter_label = st.session_state["cm_filter_label"]

            st.info(f"Filtering results for: **{filter_label}**")
            filtered_df = filtered_df[
                (filtered_df["Ground Truth"] == actual_idx)
                & (filtered_df["Prediction"] == predicted_idx)
            ]

        # --- Display the Filtered Data Table ---
        if filtered_df.empty:
            st.warning("No files match the current filter criteria.")
            st.session_state.selected_spectrum_file = None
        else:
            # Drop helper columns before display; errors="ignore" covers the
            # no-ground-truth case where they were never created.
            display_df = filtered_df.drop(
                columns=["Correct", "Result_Display"], errors="ignore"
            )

            st.dataframe(
                display_df,
                use_container_width=True,
                hide_index=True,
                on_select="rerun",
                selection_mode="single-row",
                key="results_grid_selection",
            )

            # Robust selection handling: the widget state may be missing or
            # shaped unexpectedly, so validate before indexing.
            selection_state = st.session_state.get("results_grid_selection")

            if (
                isinstance(selection_state, dict)
                and "rows" in selection_state
                and selection_state["rows"]
            ):
                selected_index = selection_state["rows"][0]

                if selected_index < len(filtered_df):
                    st.session_state.selected_spectrum_file = filtered_df.iloc[
                        selected_index
                    ]["Filename"]
                else:
                    # Re-filtering can leave a stale, out-of-bounds index.
                    st.session_state.selected_spectrum_file = None
            else:
                # Empty or unexpected selection: clear it.
                st.session_state.selected_spectrum_file = None

    def render(self):
        """The main public method to render the entire dashboard."""
        if self.df.empty:
            st.info(
                "The results table is empty. Please run an analysis on the 'Upload and Run' page."
            )
            return

        self.render_kpis()
        st.divider()
        self.render_visual_diagnostics()
        st.divider()
        self.render_interactive_grid()
pages/2_πŸ“Š_Dashboard.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# pages/2_📊_Dashboard.py — interactive batch-analysis dashboard page.

import streamlit as st

from utils.results_manager import ResultsManager
from modules.analyzer import BatchAnalysis

st.set_page_config(page_title="Analysis Dashboard", layout="wide")

# --- Initialize session state used by this page ---
# FIX: the original initialized "cm_filter_active" in two separate blocks;
# a single initialization is sufficient.
if "cm_filter_active" not in st.session_state:
    st.session_state["cm_filter_active"] = False
if "selected_spectrum_file" not in st.session_state:
    # Stores the filename of the row most recently clicked in the grid.
    st.session_state["selected_spectrum_file"] = None
# --- END INITIALIZATION ---

st.title("📊 Interactive Analysis Dashboard")
st.markdown(
    "Dive deeper into your batch results. Use the charts below to analyze model performance."
)
st.divider()

# Get the accumulated batch results from session state.
results_df = ResultsManager.get_results_dataframe()

# Create an analyzer instance and render the entire dashboard.
analyzer = BatchAnalysis(results_df)
analyzer.render()