Spaces:
Sleeping
Sleeping
File size: 35,521 Bytes
d4d1ca8 cc2e1db d4d1ca8 74baf5f d4d1ca8 cc2e1db d4d1ca8 cc2e1db 74baf5f cc2e1db 74baf5f d4d1ca8 cc2e1db d4d1ca8 74baf5f d4d1ca8 74baf5f d4d1ca8 cc2e1db d4d1ca8 cc2e1db d4d1ca8 cc2e1db d4d1ca8 41a1090 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 |
import gradio as gr
import pandas as pd
from pathlib import Path
from scripts.recommendation import summarize_events
from scripts.data_cleansing import cleanse_data
from dotenv import load_dotenv
import os
import numpy as np
import joblib
# Directory containing this script; all data/output paths are anchored here.
ROOT = Path(__file__).resolve().parent
# Pull environment configuration (e.g. HF_TOKEN) from a local .env file.
load_dotenv(ROOT / '.env')
def preview_csv(file_obj):
    """Return an HTML table previewing the first 10 rows of an uploaded CSV.

    Args:
        file_obj: a Gradio file object exposing a ``.name`` path attribute,
            or a plain path string (newer Gradio versions may pass either).

    Returns:
        HTML string of the first 10 rows, or an error message string if the
        file cannot be read/parsed.
    """
    try:
        # Support both a tempfile-like wrapper (has .name) and a raw path.
        path = getattr(file_obj, 'name', file_obj)
        # dtype=str keeps identifiers (event numbers, device codes) verbatim.
        df = pd.read_csv(path, dtype=str)
        return df.head(10).to_html(index=False)
    except Exception as e:
        # Surface any I/O or parse problem to the UI instead of crashing.
        return f"Error reading file: {e}"
def parse_row_selection(df, rows_text: str):
    """Select rows of ``df`` by a comma-separated list of integer indexes.

    Empty or whitespace-only ``rows_text`` means "all rows" (matching the
    UI placeholder text). Non-numeric tokens and out-of-range indexes are
    silently ignored so a stray typo cannot crash the handler with an
    IndexError from ``.iloc``.

    Args:
        df: source DataFrame.
        rows_text: e.g. ``"0, 2, 5"``.

    Returns:
        The selected rows (positional), or ``df`` itself when no selection
        text was given.
    """
    if not rows_text or not rows_text.strip():
        return df
    n_rows = len(df)
    idx = [int(tok)
           for tok in (t.strip() for t in rows_text.split(','))
           if tok.isdigit() and int(tok) < n_rows]
    return df.iloc[idx]
# NOTE(review): indentation appears to have been stripped from this file copy —
# the Blocks/Tabs nesting below must be restored before this script can run.
# Root Blocks context: the entire OMS Analyze UI lives inside `demo`.
with gr.Blocks() as demo:
gr.Markdown("# OMS Analyze — Prototype")
gr.Markdown("> Created by PEACE, Powered by AI, Version 0.0.1")
with gr.Tabs():
# Upload & Preview tab
with gr.TabItem('Upload & Preview'):
gr.Markdown("**Usecase Scenario — Upload & Preview**: อัปโหลดไฟล์ CSV เพื่อตรวจสอบข้อมูลต้นฉบับ ทำความสะอาดข้อมูล (ลบข้อมูลซ้ำ, จัดการค่าที่หายไป) เปรียบเทียบตัวอย่างก่อน/หลัง และดาวน์โหลดไฟล์ที่ทำความสะอาดแล้ว")
# Upload control plus cleansing options (duplicate removal, missing-value strategy).
csv_up = gr.File(label='Upload CSV (data.csv)')
with gr.Row():
remove_dup = gr.Checkbox(label='Remove Duplicates', value=False)
missing_handling = gr.Radio(choices=['drop','impute_mean','impute_median','impute_mode'], value='drop', label='Missing Values Handling')
apply_clean = gr.Button('Apply Cleansing')
# Side-by-side previews of the raw and cleansed data.
with gr.Tabs():
with gr.TabItem('Original Data'):
original_preview = gr.Dataframe(label='Original Data Preview')
with gr.TabItem('Cleansed Data'):
cleansed_preview = gr.Dataframe(label='Cleansed Data Preview')
download_cleansed = gr.File(label='Download Cleansed CSV')
clean_status = gr.Textbox(label='Cleansing Status', interactive=False)
def initial_preview(file):
    """Show the first 100 rows of a freshly uploaded CSV.

    Args:
        file: Gradio file object (``.name`` holds the temp path) or None.

    Returns:
        (original_preview_df, empty_cleansed_df, status_message) matching the
        Upload & Preview tab's three output components.
    """
    if file is None:
        return pd.DataFrame(), pd.DataFrame(), "Upload a file"
    try:
        # dtype=str preserves codes/IDs exactly as typed in the file.
        df = pd.read_csv(file.name, dtype=str)
    except Exception as e:
        # Report malformed/unreadable uploads instead of crashing the handler
        # (consistent with apply_cleansing's error handling).
        return pd.DataFrame(), pd.DataFrame(), f"Error reading file: {e}"
    return df.head(100), pd.DataFrame(), "File uploaded, apply cleansing if needed"
def apply_cleansing(file, remove_duplicates, missing_strategy):
    """Run the cleansing pipeline on the uploaded CSV and persist the result.

    Args:
        file: Gradio file object or None.
        remove_duplicates: drop duplicate rows when True.
        missing_strategy: one of 'drop'/'impute_mean'/'impute_median'/'impute_mode'.

    Returns:
        (preview_df_first_100_rows, status_message, download_path_or_None).
    """
    if file is None:
        return pd.DataFrame(), "No file", None
    try:
        raw = pd.read_csv(file.name, dtype=str)
        cleaned, before, after = cleanse_data(raw, remove_duplicates, missing_strategy)
        status = (
            f"Original: {before[0]} rows, {before[1]} cols → "
            f"Cleaned: {after[0]} rows, {after[1]} cols"
        )
        # Persist the cleansed data so the download widget can serve it.
        target = ROOT / 'outputs' / 'cleansed_data.csv'
        target.parent.mkdir(exist_ok=True)
        cleaned.to_csv(target, index=False, encoding='utf-8-sig')
        return cleaned.head(100), status, str(target)
    except Exception as e:
        return pd.DataFrame(), f"Error: {e}", None
# Wire the Upload & Preview tab: preview on upload, cleanse on button click.
csv_up.change(fn=initial_preview, inputs=csv_up, outputs=[original_preview, cleansed_preview, clean_status])
apply_clean.click(fn=apply_cleansing, inputs=[csv_up, remove_dup, missing_handling], outputs=[cleansed_preview, clean_status, download_cleansed])
# Summary tab
with gr.TabItem('Summary'):
gr.Markdown("**Usecase Scenario — Summary**: สร้างสรุปภาพรวมของชุดข้อมูลทั้งหมด รวมสถิติพื้นฐาน และคำนวณดัชนีความน่าเชื่อถือ (เช่น SAIFI, SAIDI, CAIDI) พร้อมตัวเลือกใช้ Generative AI ในการขยายความ")
csv_in_sum = gr.File(label='Upload CSV for Overall Summary')
with gr.Row():
use_hf_sum = gr.Checkbox(label='Use Generative AI for Summary', value=False)
# Customer base used as the denominator for SAIFI/SAIDI-style indices.
total_customers = gr.Number(label='Total Customers (for reliability calculation)', value=500000, precision=0)
run_sum = gr.Button('Generate Overall Summary')
with gr.Row():
# Hidden until the "Use Generative AI" checkbox is ticked (see wiring below).
model_selector_sum = gr.Dropdown(
choices=[
'meta-llama/Llama-3.1-8B-Instruct:novita',
'meta-llama/Llama-4-Scout-17B-16E-Instruct:novita',
'Qwen/Qwen3-VL-235B-A22B-Instruct:novita',
'deepseek-ai/DeepSeek-R1:novita',
'moonshotai/Kimi-K2-Instruct-0905:novita'
],
value='meta-llama/Llama-3.1-8B-Instruct:novita',
label='GenAI Model',
interactive=True,
visible=False
)
# Three result views: AI text, raw statistics, and the reliability table.
with gr.Tabs():
with gr.TabItem('AI Summary'):
ai_summary_out = gr.Textbox(label='AI Generated Summary', lines=10)
with gr.TabItem('Basic Statistics'):
basic_stats_out = gr.JSON(label='Basic Statistics')
with gr.TabItem('Reliability Indices'):
reliability_out = gr.Dataframe(label='Reliability Metrics')
sum_status = gr.Textbox(label='Summary Status', interactive=False)
def run_overall_summary(file, use_hf_flag, total_cust, model):
    """Build an overall dataset summary: stats, reliability indices, optional AI text.

    Args:
        file: Gradio file object or None.
        use_hf_flag: when True, ask the selected GenAI model for a narrative summary.
        total_cust: total customer count used by the reliability-index calculation.
        model: GenAI model identifier from the dropdown.

    Returns:
        (ai_summary_text, basic_stats_dict, reliability_df, status_message)
        matching the four Summary-tab output components.
    """
    if file is None:
        # First output feeds a Textbox, so return a string (original code
        # returned {} here, which is the wrong type for that component).
        return '', {}, pd.DataFrame(), 'No file provided'
    try:
        from scripts.summary import summarize_overall
        df = pd.read_csv(file.name, dtype=str)
        result = summarize_overall(df, use_hf=use_hf_flag, model=model, total_customers=total_cust)
        ai_summary = result.get('ai_summary', 'ไม่สามารถสร้างสรุปด้วย AI ได้')
        basic_stats = {
            'total_events': result.get('total_events'),
            'date_range': result.get('date_range'),
            'event_types': result.get('event_types'),
            'total_affected_customers': result.get('total_affected_customers')
        }
        # Reliability indices (SAIFI/SAIDI/CAIDI per the tab description)
        # come back as a DataFrame; default to empty if absent.
        reliability_df = result.get('reliability_df', pd.DataFrame())
        status = f"Summary generated for {len(df)} events. AI used: {use_hf_flag}"
        return ai_summary, basic_stats, reliability_df, status
    except Exception as e:
        return f"Error: {str(e)}", {}, pd.DataFrame(), f'Summary failed: {e}'
def update_model_visibility_sum(use_hf_flag):
    """Reveal/enable the Summary-tab model dropdown only when AI is enabled."""
    flag = bool(use_hf_flag)
    return gr.update(visible=flag, interactive=flag)
# Wire the Summary tab: toggle model dropdown visibility, run the summary.
use_hf_sum.change(fn=update_model_visibility_sum, inputs=use_hf_sum, outputs=model_selector_sum)
run_sum.click(fn=run_overall_summary, inputs=[csv_in_sum, use_hf_sum, total_customers, model_selector_sum], outputs=[ai_summary_out, basic_stats_out, reliability_out, sum_status])
# Recommendation tab
with gr.TabItem('Recommendation'):
gr.Markdown("**Usecase Scenario — Recommendation**: สร้างสรุปเหตุการณ์ (เช่น สรุปเหตุการณ์ไฟฟ้าขัอข้องหรือบำรุงรักษา) สำหรับแถวที่เลือก ปรับระดับรายละเอียด และเลือกใช้ Generative AI เพื่อเพิ่มความชัดเจน พร้อมดาวน์โหลดไฟล์สรุป")
csv_in = gr.File(label='Upload CSV (data.csv)')
with gr.Row():
# Row selection text is parsed by parse_row_selection(); empty = all rows.
rows = gr.Textbox(label='Rows (comma-separated indexes) or empty = all', placeholder='e.g. 0,1,2')
use_hf = gr.Checkbox(label='Use Generative AI', value=False)
verbosity = gr.Radio(choices=['analyze','recommend'], value='analyze', label='Summary Type', interactive=True)
run_btn = gr.Button('Generate Summaries', interactive=True)
with gr.Row():
# Hidden until "Use Generative AI" is ticked (see wiring below).
model_selector = gr.Dropdown(
choices=[
'meta-llama/Llama-3.1-8B-Instruct:novita',
'meta-llama/Llama-4-Scout-17B-16E-Instruct:novita',
'Qwen/Qwen3-VL-235B-A22B-Instruct:novita',
'deepseek-ai/DeepSeek-R1:novita',
'moonshotai/Kimi-K2-Instruct-0905:novita'
],
value='meta-llama/Llama-3.1-8B-Instruct:novita',
label='GenAI Model',
interactive=True,
visible=False
)
out = gr.Dataframe(headers=['EventNumber','OutageDateTime','Summary'])
status = gr.Textbox(label='Status', interactive=False)
download = gr.File(label='Download summaries')
def run_summarize(file, rows_text, use_hf_flag, verbosity_level, model):
    """Generate per-event summaries for the selected rows and save them to CSV.

    Args:
        file: Gradio file object or None.
        rows_text: comma-separated row indexes; empty means all rows.
        use_hf_flag: when True, summarize_events uses the GenAI model.
        verbosity_level: 'analyze' or 'recommend'.
        model: GenAI model identifier.

    Returns:
        (summaries_df, status_message, download_path_or_None).
    """
    empty = pd.DataFrame([], columns=['EventNumber','OutageDateTime','Summary'])
    if file is None:
        return empty, 'No file provided', None
    try:
        df = pd.read_csv(file.name, dtype=str)
        df_sel = parse_row_selection(df, rows_text)
        res = summarize_events(df_sel, use_hf=use_hf_flag, verbosity=verbosity_level, model=model)
        out_df = pd.DataFrame(res)
        # Persist for the download widget.
        out_file = ROOT / 'outputs' / 'summaries_from_ui.csv'
        out_file.parent.mkdir(exist_ok=True)
        out_df.to_csv(out_file, index=False, encoding='utf-8-sig')
        status_text = f"Summaries generated: {len(out_df)} rows. HF used: {use_hf_flag}"
        return out_df, status_text, str(out_file)
    except Exception as e:
        # Report failures in the status box instead of crashing the handler
        # (matches the other tab handlers; also replaces the old debug print).
        return empty, f'Error: {e}', None
def update_model_visibility(use_hf_flag):
    """Show/enable the model dropdown only while 'Use Generative AI' is ticked."""
    flag = bool(use_hf_flag)
    return gr.update(visible=flag, interactive=flag)
# Wire the Recommendation tab: toggle model dropdown, generate summaries.
use_hf.change(fn=update_model_visibility, inputs=use_hf, outputs=model_selector)
run_btn.click(fn=run_summarize, inputs=[csv_in, rows, use_hf, verbosity, model_selector], outputs=[out, status, download])
# Anomaly Detection tab
with gr.TabItem('Anomaly Detection'):
gr.Markdown("**Usecase Scenario — Anomaly Detection**: ตรวจจับเหตุการณ์ที่มีพฤติกรรมผิดปกติในชุดข้อมูล (เช่น เหตุการณ์ที่มีค่าสูง/ต่ำผิดปกติ) โดยใช้หลาย algorithm ปรับระดับ contamination และส่งออกผลลัพธ์พร้อมธงความผิดปกติ")
csv_in_anom = gr.File(label='Upload CSV for Anomaly')
with gr.Row():
# 'iso+lof' runs both IsolationForest and LocalOutlierFactor as an ensemble.
alg = gr.Radio(choices=['iso+lof','iso','lof','autoencoder'], value='iso+lof', label='Algorithm')
# Expected fraction of anomalous rows, passed to the detector.
contamination = gr.Slider(minimum=0.01, maximum=0.2, value=0.05, step=0.01, label='Contamination')
run_anom = gr.Button('Run Anomaly Detection')
anom_out = gr.Dataframe()
anom_status = gr.Textbox(label='Anomaly Status', interactive=False)
anom_download = gr.File(label='Download anomalies CSV')
def run_anomaly_ui(file, algorithm, contamination):
    """Run anomaly detection over the uploaded CSV and export the results.

    Args:
        file: Gradio file object or None.
        algorithm: 'iso+lof', 'iso', 'lof' or 'autoencoder'.
        contamination: expected anomaly fraction (0.01–0.2 from the slider).

    Returns:
        (result_df, status_message, download_path_or_None).
    """
    if file is None:
        return pd.DataFrame(), 'No file provided', None
    try:
        from scripts.anomaly import detect_anomalies
        df = pd.read_csv(file.name, dtype=str)
        res = detect_anomalies(df, contamination=contamination, algorithm=algorithm)
        # Move the flag columns to the end for readability — but only if
        # present, so a detector variant without them cannot raise KeyError.
        flag_cols = [c for c in ('ensemble_flag', 'final_flag') if c in res.columns]
        res = res[[c for c in res.columns if c not in flag_cols] + flag_cols]
        out_file = ROOT / 'outputs' / 'anomalies_from_ui.csv'
        out_file.parent.mkdir(exist_ok=True)
        res.to_csv(out_file, index=False, encoding='utf-8-sig')
        n_flags = res['final_flag'].sum() if 'final_flag' in res.columns else 0
        status = f"Anomaly detection done. Rows: {len(res)}. Flags: {n_flags}"
        return res, status, str(out_file)
    except Exception as e:
        # Keep failures in the status box (consistent with other handlers).
        return pd.DataFrame(), f'Anomaly detection failed: {e}', None
# Wire the Anomaly Detection tab.
run_anom.click(fn=run_anomaly_ui, inputs=[csv_in_anom, alg, contamination], outputs=[anom_out, anom_status, anom_download])
# Classification tab
with gr.TabItem('Classification'):
gr.Markdown("**Usecase Scenario — Classification**: ฝึกและทดสอบโมเดลเพื่อจำแนกสาเหตุของเหตุการณ์ กำหนดคอลัมน์เป้าหมาย ปรับ hyperparameters, เปิดใช้งาน weak-labeling และดาวน์โหลดโมเดล/ผลการทำนาย")
csv_in_cls = gr.File(label='Upload CSV for Classification')
with gr.Row():
label_col = gr.Dropdown(choices=['CauseType','SubCauseType'], value='CauseType', label='Target Column')
do_weak = gr.Checkbox(label='Run weak-labeling using HF (requires HF_TOKEN)', value=False)
# Model family: random forest, gradient boosting, or MLP.
model_type = gr.Radio(choices=['rf','gb','mlp'], value='rf', label='Model Type')
run_cls = gr.Button('Train Classifier')
def update_hyperparams_visibility(model_choice):
    """Show only the hyperparameter controls for the chosen model type.

    Returns 10 gr.update objects matching the outputs list order:
    4 RF sliders, then 3 GB sliders, then 3 MLP controls.
    """
    control_groups = (('rf', 4), ('gb', 3), ('mlp', 3))
    updates = []
    for model_name, n_controls in control_groups:
        show = model_choice == model_name
        updates.extend(gr.update(visible=show) for _ in range(n_controls))
    return updates
# Advanced hyperparameter controls; visibility follows the Model Type radio
# via update_hyperparams_visibility (RF controls visible by default).
with gr.Accordion("Hyperparameters (Advanced)", open=False):
gr.Markdown("Adjust hyperparameters for the selected model. Defaults are set for good performance.")
rf_n_estimators = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="RF: n_estimators", visible=True)
rf_max_depth = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="RF: max_depth", visible=True)
rf_min_samples_split = gr.Slider(minimum=2, maximum=10, value=2, step=1, label="RF: min_samples_split", visible=True)
rf_min_samples_leaf = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="RF: min_samples_leaf", visible=True)
gb_n_estimators = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="GB: n_estimators", visible=False)
gb_max_depth = gr.Slider(minimum=3, maximum=20, value=3, step=1, label="GB: max_depth", visible=False)
gb_learning_rate = gr.Slider(minimum=0.01, maximum=0.3, value=0.1, step=0.01, label="GB: learning_rate", visible=False)
# Free-text tuple, parsed with ast.literal_eval in run_classify_ui.
mlp_hidden_layer_sizes = gr.Textbox(value="(100,)", label="MLP: hidden_layer_sizes (tuple)", visible=False)
mlp_alpha = gr.Slider(minimum=0.0001, maximum=0.01, value=0.0001, step=0.0001, label="MLP: alpha", visible=False)
mlp_max_iter = gr.Slider(minimum=100, maximum=4000, value=500, step=50, label="MLP: max_iter", visible=False)
# Output order here must match update_hyperparams_visibility's return order.
model_type.change(fn=update_hyperparams_visibility, inputs=model_type, outputs=[rf_n_estimators, rf_max_depth, rf_min_samples_split, rf_min_samples_leaf, gb_n_estimators, gb_max_depth, gb_learning_rate, mlp_hidden_layer_sizes, mlp_alpha, mlp_max_iter])
cls_out = gr.Textbox(label='Classification Report')
# Holds the saved model path so the Test section below can reuse the model.
model_path_state = gr.State()
cls_download_model = gr.File(label='Download saved model')
cls_download_preds = gr.File(label='Download predictions CSV')
# Test section
gr.Markdown("---")
gr.Markdown("**ทดสอบโมเดล**: อัปโหลดไฟล์ CSV ใหม่เพื่อทดสอบโมเดลที่ฝึกแล้ว")
test_csv = gr.File(label='Upload CSV for Testing')
run_test = gr.Button('Test Model')
test_out = gr.Dataframe(label='Test Predictions')
test_status = gr.Textbox(label='Test Status', interactive=False)
test_download = gr.File(label='Download Test Predictions')
def run_classify_ui(file, label_col_choice, use_weak, model_choice, rf_n_est, rf_max_d, rf_min_ss, rf_min_sl, gb_n_est, gb_max_d, gb_lr, mlp_hls, mlp_a, mlp_mi):
    """Train a cause-classification model with the chosen hyperparameters.

    Args:
        file: Gradio file object or None.
        label_col_choice: target column ('CauseType' or 'SubCauseType').
        use_weak: weak-labeling checkbox value.
        model_choice: 'rf', 'gb' or 'mlp'; the matching hyperparameter
            arguments are collected into a dict, the rest are ignored.

    Returns:
        (report_text, model_file, predictions_file, model_path_state) — the
        model file path is returned twice: once for the download widget and
        once for the State reused by the Test section.
    """
    if file is None:
        return 'No file provided', None, None, None
    try:
        from scripts.classify import train_classifier
        # Read inside the try block so a malformed CSV produces a friendly
        # 'Training failed' message instead of an unhandled exception.
        df = pd.read_csv(file.name, dtype=str)
        hyperparams = {}
        if model_choice == 'rf':
            hyperparams = {'n_estimators': int(rf_n_est), 'max_depth': int(rf_max_d), 'min_samples_split': int(rf_min_ss), 'min_samples_leaf': int(rf_min_sl)}
        elif model_choice == 'gb':
            hyperparams = {'n_estimators': int(gb_n_est), 'max_depth': int(gb_max_d), 'learning_rate': gb_lr}
        elif model_choice == 'mlp':
            import ast
            # literal_eval safely parses the "(100,)" tuple text (never eval).
            hyperparams = {'hidden_layer_sizes': ast.literal_eval(mlp_hls), 'alpha': mlp_a, 'max_iter': int(mlp_mi)}
        # NOTE(review): `use_weak` is accepted but never forwarded to
        # train_classifier — confirm whether weak-labeling is wired elsewhere.
        res = train_classifier(df, label_col=label_col_choice, model_type=model_choice, hyperparams=hyperparams)
        report = res.get('report', '')
        model_file = res.get('model_file')
        preds_file = res.get('predictions_file')
        # Returned file paths feed gr.File widgets and the model State.
        return report, model_file, preds_file, model_file
    except Exception as e:
        return f'Training failed: {e}', None, None, None
def run_test_ui(test_file, model_path):
    """Apply the most recently trained classifier to a new CSV.

    Args:
        test_file: Gradio file object or None.
        model_path: path stored in model_path_state by run_classify_ui.

    Returns:
        (predictions_preview_df, status_message, download_path_or_None).
    """
    if test_file is None:
        return pd.DataFrame(), 'No test file provided', None
    if model_path is None:
        return pd.DataFrame(), 'No trained model available. Please train a model first.', None
    try:
        from scripts.classify import parse_and_features
        # Restore the fitted pipeline and its label encoder.
        bundle = joblib.load(model_path)
        pipeline = bundle['pipeline']
        encoder = bundle['label_encoder']
        # Rebuild the same engineered features used at training time.
        frame = parse_and_features(pd.read_csv(test_file.name, dtype=str))
        feature_cols = ['duration_min','Load(MW)_num','Capacity(kVA)_num','AffectedCustomer_num','hour','weekday','device_freq','OpDeviceType','Owner','Weather','EventType']
        predictions = encoder.inverse_transform(pipeline.predict(frame[feature_cols]))
        # NOTE(review): the output column is always 'Predicted_CauseType' even
        # if the model was trained on SubCauseType — confirm this is intended.
        result = frame.copy()
        result['Predicted_CauseType'] = predictions
        target = ROOT / 'outputs' / 'test_predictions.csv'
        target.parent.mkdir(exist_ok=True)
        result.to_csv(target, index=False, encoding='utf-8-sig')
        status = f"Test completed. Predictions for {len(result)} rows."
        return result.head(100), status, str(target)
    except Exception as e:
        return pd.DataFrame(), f'Test failed: {e}', None
# Wire the Classification tab: train, then test against the saved model state.
run_cls.click(fn=run_classify_ui, inputs=[csv_in_cls, label_col, do_weak, model_type, rf_n_estimators, rf_max_depth, rf_min_samples_split, rf_min_samples_leaf, gb_n_estimators, gb_max_depth, gb_learning_rate, mlp_hidden_layer_sizes, mlp_alpha, mlp_max_iter], outputs=[cls_out, cls_download_model, cls_download_preds, model_path_state])
run_test.click(fn=run_test_ui, inputs=[test_csv, model_path_state], outputs=[test_out, test_status, test_download])
# Label Suggestion tab
with gr.TabItem('Label Suggestion'):
gr.Markdown("**Usecase Scenario — Label Suggestion**: ให้คำแนะนำป้ายกำกับสาเหตุที่เป็นไปได้สำหรับเหตุการณ์ที่ไม่มีฉลาก โดยเทียบความคล้ายกับตัวอย่างที่มีฉลาก ปรับจำนวนคำแนะนำสูงสุด และส่งออกเป็นไฟล์ CSV")
csv_in_ls = gr.File(label='Upload CSV (defaults to data/data_3.csv)')
with gr.Row():
# Maximum number of candidate labels to suggest per unlabeled row.
top_k = gr.Slider(minimum=1, maximum=5, value=1, step=1, label='Top K suggestions')
run_ls = gr.Button('Run Label Suggestion')
ls_out = gr.Dataframe()
ls_status = gr.Textbox(label='Label Suggestion Status', interactive=False)
ls_download = gr.File(label='Download label suggestions')
def run_label_suggestion(file, top_k_suggest):
    """Suggest cause labels for unlabeled rows via scripts.label_suggestion.

    Falls back to the bundled data/data_3.csv when no file is uploaded.

    Args:
        file: Gradio file object or None.
        top_k_suggest: maximum number of suggestions per row (slider value).

    Returns:
        (suggestions_df, status_message, download_path_or_None).
    """
    from scripts.label_suggestion import suggest_labels_to_file
    if file is None:
        default = ROOT / 'data' / 'data_3.csv'
        if not default.exists():
            return pd.DataFrame(), 'No file provided and default data/data_3.csv not found', None
        df = pd.read_csv(default, dtype=str)
    else:
        df = pd.read_csv(file.name, dtype=str)
    out_file = ROOT / 'outputs' / 'label_suggestions.csv'
    # Ensure the outputs directory exists before the helper writes to it
    # (every other tab handler does this mkdir; this one was missing it).
    out_file.parent.mkdir(exist_ok=True)
    out_df = suggest_labels_to_file(df, out_path=str(out_file), top_k=int(top_k_suggest))
    status = f"Label suggestion done. Unknown rows processed: {len(out_df)}. Output: {out_file}"
    # Only offer a download when something was actually produced.
    return out_df, status, str(out_file) if len(out_df) > 0 else None
# Wire the Label Suggestion tab.
run_ls.click(fn=run_label_suggestion, inputs=[csv_in_ls, top_k], outputs=[ls_out, ls_status, ls_download])
# Forecasting tab
with gr.TabItem('Forecasting'):
gr.Markdown("**Usecase Scenario — Forecasting**: พยากรณ์จำนวนเหตุการณ์หรือเวลาหยุดทำงานในอนาคตโดยเลือกโมเดล (Prophet, LSTM, Bi-LSTM, GRU, Naive) ปรับพารามิเตอร์ และส่งออกผลการพยากรณ์")
gr.Markdown("*Multivariate forecasting (ใช้หลายฟีเจอร์) รองรับเฉพาะโมเดล LSTM, Bi-LSTM, GRU เท่านั้น*")
csv_in_fc = gr.File(label='Upload CSV for Forecasting')
with gr.Row():
metric_fc = gr.Radio(choices=['count','downtime_minutes'], value='count', label='Metric to Forecast')
model_type_fc = gr.Radio(choices=['prophet','lstm','bilstm','gru','naive'], value='lstm', label='Forecasting Model', elem_id='forecast_model_radio')
periods_fc = gr.Slider(minimum=1, maximum=30, value=7, step=1, label='Forecast Periods (days)')
# Starts non-interactive; enabled only for sequence models (see wiring below).
multivariate_fc = gr.Checkbox(value=False, label='Use Multivariate (Multiple Features)', interactive=False)
run_fc = gr.Button('Run Forecasting')
# Track the currently selected forecasting model across callbacks.
current_model_state = gr.State(value='lstm')

def update_multivariate_visibility(model_choice):
    """Enable the multivariate checkbox only for sequence models.

    Multivariate forecasting is only supported for LSTM, Bi-LSTM and GRU;
    the checkbox is reset to unchecked whenever the model changes.
    """
    sequence_models = ('lstm', 'bilstm', 'gru')
    return gr.update(interactive=model_choice in sequence_models, value=False)

def update_model_state(model_choice):
    """Mirror the radio selection into current_model_state."""
    return model_choice
# Hyperparameter controls for forecasting
# Visibility follows the model radio via update_forecast_hyperparams_visibility;
# DL controls are visible by default because 'lstm' is the default model.
with gr.Accordion("Hyperparameters (Advanced)", open=False):
gr.Markdown("Adjust hyperparameters for the selected forecasting model. Defaults are set for good performance.")
# Prophet hyperparameters
prophet_changepoint_prior = gr.Slider(minimum=0.001, maximum=0.5, value=0.05, step=0.001, label="Prophet: changepoint_prior_scale", visible=False)
prophet_seasonality_prior = gr.Slider(minimum=0.01, maximum=10.0, value=10.0, step=0.1, label="Prophet: seasonality_prior_scale", visible=False)
prophet_seasonality_mode = gr.Radio(choices=['additive', 'multiplicative'], value='additive', label="Prophet: seasonality_mode", visible=False)
# Deep learning hyperparameters (LSTM, Bi-LSTM, GRU)
dl_seq_length = gr.Slider(minimum=3, maximum=30, value=7, step=1, label="DL: sequence_length (lag/input length)", visible=True)
dl_epochs = gr.Slider(minimum=10, maximum=200, value=100, step=10, label="DL: epochs", visible=True)
dl_batch_size = gr.Slider(minimum=4, maximum=64, value=16, step=4, label="DL: batch_size", visible=True)
dl_learning_rate = gr.Slider(minimum=0.0001, maximum=0.01, value=0.001, step=0.0001, label="DL: learning_rate", visible=True)
dl_units = gr.Slider(minimum=32, maximum=256, value=100, step=16, label="DL: units (neurons)", visible=True)
dl_dropout = gr.Slider(minimum=0.0, maximum=0.5, value=0.2, step=0.05, label="DL: dropout_rate", visible=True)
# Naive has no hyperparameters
def update_forecast_hyperparams_visibility(model_choice):
    """Toggle Prophet vs deep-learning hyperparameter controls.

    Returns 9 gr.update objects matching the outputs list order:
    3 Prophet controls followed by 6 deep-learning controls. The naive
    model shows neither group.
    """
    show_prophet = model_choice == 'prophet'
    show_dl = model_choice in ('lstm', 'bilstm', 'gru')
    return [gr.update(visible=show_prophet)] * 3 + [gr.update(visible=show_dl)] * 6
# Forecast result views: raw history, forecast table, and the combined plot.
with gr.Tabs():
with gr.TabItem('Historical Data'):
hist_out = gr.Dataframe(label='Historical Time Series Data')
with gr.TabItem('Forecast Results'):
fcst_out = gr.Dataframe(label='Forecast Results')
with gr.TabItem('Time Series Plot'):
plot_out = gr.Plot(label='Historical + Forecast Plot')
fc_status = gr.Textbox(label='Forecast Status', interactive=False)
fc_download = gr.File(label='Download forecast CSV')
def run_forecast_ui(file, metric, model_type, periods, multivariate, current_model, prophet_cp, prophet_sp, prophet_sm, dl_sl, dl_e, dl_bs, dl_lr, dl_u, dl_d):
    """Run the forecasting pipeline and render history + forecast in one plot.

    Args:
        file: Gradio file object or None.
        metric: 'count' or 'downtime_minutes'.
        model_type: raw radio value; only used as a fallback.
        periods: number of days to forecast.
        multivariate: use multiple features (sequence models only).
        current_model: model tracked in current_model_state (preferred).
        prophet_*/dl_*: hyperparameter widget values for the two model families.

    Returns:
        (historical_df, forecast_df, figure, status_message, download_path).
    """
    # Prefer the tracked state; fall back to the raw radio value. The model
    # that actually runs is `actual_model`, so it is used consistently below
    # for the plot title, fallback checks, output filename and status text
    # (the original mixed actual_model and model_type, which could disagree).
    actual_model = current_model if current_model else model_type
    if file is None:
        return pd.DataFrame(), pd.DataFrame(), None, 'No file provided', None
    try:
        from scripts.forecast import run_forecast
        import matplotlib.pyplot as plt
        df = pd.read_csv(file.name, dtype=str)
        # Collect only the hyperparameters relevant to the chosen model.
        hyperparams = {}
        if actual_model == 'prophet':
            hyperparams = {
                'changepoint_prior_scale': prophet_cp,
                'seasonality_prior_scale': prophet_sp,
                'seasonality_mode': prophet_sm
            }
        elif actual_model in ['lstm', 'bilstm', 'gru']:
            hyperparams = {
                'seq_length': int(dl_sl),
                'epochs': int(dl_e),
                'batch_size': int(dl_bs),
                'learning_rate': dl_lr,
                'units': int(dl_u),
                'dropout_rate': dl_d
            }
        ts, fcst = run_forecast(df, metric=metric, periods=periods, model_type=actual_model, multivariate=multivariate, hyperparams=hyperparams)
        # --- Plot: historical series, forecast line, confidence band ---
        fig, ax = plt.subplots(figsize=(14, 7))
        if len(ts) > 0 and 'y' in ts.columns:
            ax.plot(ts['ds'], ts['y'], 'b-', label='Historical Data', linewidth=2, marker='o', markersize=4)
        if len(fcst) > 0 and 'yhat' in fcst.columns:
            ax.plot(fcst['ds'], fcst['yhat'], 'r--', label='Forecast', linewidth=3, marker='s', markersize=5)
            if 'yhat_lower' in fcst.columns and 'yhat_upper' in fcst.columns:
                ax.fill_between(fcst['ds'], fcst['yhat_lower'], fcst['yhat_upper'],
                                color='red', alpha=0.3, label='Confidence Interval')
        if len(ts) > 0 and len(fcst) > 0:
            # Vertical marker where history ends and the forecast begins.
            last_hist_date = ts['ds'].max()
            ax.axvline(x=last_hist_date, color='gray', linestyle='--', alpha=0.7, label='Forecast Start')
        ax.set_title(f'Time Series Forecast: {actual_model.upper()} ({metric.replace("_", " ").title()})',
                     fontsize=16, fontweight='bold', pad=20)
        ax.set_xlabel('Date', fontsize=14)
        ax.set_ylabel(metric.replace('_', ' ').title(), fontsize=14)
        ax.legend(loc='upper left', fontsize=12)
        ax.grid(True, alpha=0.3)
        # Keep the x-axis readable regardless of the history length.
        import matplotlib.dates as mdates
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
        ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(ts) // 10)))
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        # --- Persist the forecast for download ---
        mode = 'multivariate' if multivariate else 'univariate'
        if multivariate and actual_model not in ['lstm', 'bilstm', 'gru']:
            mode += ' (fallback: model does not support multivariate)'
        out_file = ROOT / 'outputs' / f'forecast_{metric}_{actual_model}_{mode.replace(" ", "_")}.csv'
        out_file.parent.mkdir(exist_ok=True)
        fcst.to_csv(out_file, index=False)
        status = f"Forecasting completed using {actual_model.upper()} ({mode}). Historical data: {len(ts)} days, Forecast: {len(fcst)} days."
        if multivariate and actual_model not in ['lstm', 'bilstm', 'gru']:
            status += " Note: Model does not support multivariate - used univariate instead."
        return ts, fcst, fig, status, str(out_file)
    except Exception as e:
        # Render the error inside the plot area so the UI still shows something.
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(figsize=(14, 7))
        ax.text(0.5, 0.5, f'Forecasting Error:\n{str(e)}',
                transform=ax.transAxes, ha='center', va='center',
                fontsize=14, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
        ax.set_title('Time Series Forecast - Error Occurred', fontsize=16, fontweight='bold')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        plt.axis('off')
        return pd.DataFrame(), pd.DataFrame(), fig, f'Forecasting failed: {e}', None
# Wire the Forecasting tab: three listeners on the model radio (checkbox
# enablement, state mirror, hyperparameter visibility), then the run button.
model_type_fc.change(fn=update_multivariate_visibility, inputs=[model_type_fc], outputs=[multivariate_fc])
model_type_fc.change(fn=update_model_state, inputs=[model_type_fc], outputs=[current_model_state])
model_type_fc.change(fn=update_forecast_hyperparams_visibility, inputs=[model_type_fc], outputs=[prophet_changepoint_prior, prophet_seasonality_prior, prophet_seasonality_mode, dl_seq_length, dl_epochs, dl_batch_size, dl_learning_rate, dl_units, dl_dropout])
run_fc.click(fn=run_forecast_ui, inputs=[csv_in_fc, metric_fc, model_type_fc, periods_fc, multivariate_fc, current_model_state, prophet_changepoint_prior, prophet_seasonality_prior, prophet_seasonality_mode, dl_seq_length, dl_epochs, dl_batch_size, dl_learning_rate, dl_units, dl_dropout], outputs=[hist_out, fcst_out, plot_out, fc_status, fc_download])
# Launch the app when run directly; 0.0.0.0 makes it reachable from outside
# a container (the standard setup for a hosted Space).
if __name__ == '__main__':
demo.launch(server_name="0.0.0.0", server_port=7860)
|