File size: 35,521 Bytes
d4d1ca8
 
 
cc2e1db
d4d1ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74baf5f
d4d1ca8
 
 
cc2e1db
d4d1ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2e1db
 
 
 
 
 
 
 
 
 
 
 
 
 
74baf5f
 
cc2e1db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74baf5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4d1ca8
cc2e1db
d4d1ca8
 
74baf5f
d4d1ca8
 
 
 
 
 
 
 
 
 
 
 
74baf5f
 
 
d4d1ca8
 
 
 
 
 
 
 
 
 
cc2e1db
d4d1ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2e1db
d4d1ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2e1db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4d1ca8
41a1090
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
import gradio as gr
import pandas as pd
from pathlib import Path
from scripts.recommendation import summarize_events
from scripts.data_cleansing import cleanse_data
from dotenv import load_dotenv
import os
import numpy as np
import joblib

# Directory containing this script; data/, outputs/ and .env are resolved relative to it.
ROOT = Path(__file__).resolve().parent
# Pull environment variables (e.g. HF_TOKEN for the GenAI features) from the local .env file.
load_dotenv(ROOT / '.env')


def preview_csv(file_obj):
    """Render the first 10 rows of an uploaded CSV as an HTML table.

    Every column is read as a string. Any failure (missing file, parse error,
    an object without a ``.name`` attribute) is returned as a message string
    rather than raised.
    """
    try:
        head = pd.read_csv(file_obj.name, dtype=str).head(10)
        html = head.to_html(index=False)
    except Exception as exc:
        return f"Error reading file: {exc}"
    return html


def parse_row_selection(df, rows_text: str):
    """Select rows of ``df`` by comma-separated 0-based positional indexes.

    Args:
        df: DataFrame to select from.
        rows_text: Text such as ``"0, 2, 5"``. Empty, ``None`` or
            whitespace-only text selects every row (``df`` is returned as is).

    Returns:
        ``df`` itself when no selection was given, otherwise ``df.iloc`` of
        the requested positions, in the order given (duplicates kept).

    Tokens that are not non-negative integers are ignored, and indexes past
    the end of ``df`` are dropped instead of raising ``IndexError``.
    """
    # Whitespace-only text would otherwise produce an empty selection.
    if not rows_text or not rows_text.strip():
        return df
    n_rows = len(df)
    idx = []
    for token in rows_text.split(','):
        token = token.strip()
        # isdecimal (unlike isdigit) only accepts characters int() can parse.
        if token.isdecimal():
            pos = int(token)
            if pos < n_rows:
                idx.append(pos)
    return df.iloc[idx]


with gr.Blocks() as demo:
    gr.Markdown("# OMS Analyze — Prototype")
    gr.Markdown("> Created by PEACE, Powered by AI, Version 0.0.1")
    with gr.Tabs():
        # Upload & Preview tab
        with gr.TabItem('Upload & Preview'):
            gr.Markdown("**Usecase Scenario — Upload & Preview**: อัปโหลดไฟล์ CSV เพื่อตรวจสอบข้อมูลต้นฉบับ ทำความสะอาดข้อมูล (ลบข้อมูลซ้ำ, จัดการค่าที่หายไป) เปรียบเทียบตัวอย่างก่อน/หลัง และดาวน์โหลดไฟล์ที่ทำความสะอาดแล้ว")
            csv_up = gr.File(label='Upload CSV (data.csv)')
            with gr.Row():
                remove_dup = gr.Checkbox(label='Remove Duplicates', value=False)
                missing_handling = gr.Radio(choices=['drop','impute_mean','impute_median','impute_mode'], value='drop', label='Missing Values Handling')
                apply_clean = gr.Button('Apply Cleansing')
            with gr.Tabs():
                with gr.TabItem('Original Data'):
                    original_preview = gr.Dataframe(label='Original Data Preview')
                with gr.TabItem('Cleansed Data'):
                    cleansed_preview = gr.Dataframe(label='Cleansed Data Preview')
                    download_cleansed = gr.File(label='Download Cleansed CSV')
            clean_status = gr.Textbox(label='Cleansing Status', interactive=False)
            
            def initial_preview(file):
                """Show the raw upload and clear the cleansed preview.

                Returns (original preview, cleansed preview, status text). The
                original preview is capped at the first 100 rows and the
                cleansed preview is reset to an empty frame. A CSV that cannot be
                read is reported in the status instead of being raised, matching
                how ``apply_cleansing`` reports errors.
                """
                if file is None:
                    return pd.DataFrame(), pd.DataFrame(), "Upload a file"
                try:
                    df = pd.read_csv(file.name, dtype=str)
                except Exception as e:
                    return pd.DataFrame(), pd.DataFrame(), f"Error: {e}"
                return df.head(100), pd.DataFrame(), "File uploaded, apply cleansing if needed"
            
            def apply_cleansing(file, remove_duplicates, missing_strategy):
                """Cleanse the uploaded CSV and save the result for download.

                Delegates to ``cleanse_data``, which returns the cleansed frame
                plus the before/after shapes. The full result is written to
                outputs/cleansed_data.csv under ROOT.

                Returns (first 100 cleansed rows, status text, saved path).
                Any error while reading, cleansing or writing yields an empty
                frame, an "Error: ..." status and no path.
                """
                if file is None:
                    return pd.DataFrame(), "No file", None
                try:
                    raw = pd.read_csv(file.name, dtype=str)
                    cleaned, shape_before, shape_after = cleanse_data(raw, remove_duplicates, missing_strategy)
                    summary = (
                        f"Original: {shape_before[0]} rows, {shape_before[1]} cols → "
                        f"Cleaned: {shape_after[0]} rows, {shape_after[1]} cols"
                    )
                    target = ROOT / 'outputs' / 'cleansed_data.csv'
                    target.parent.mkdir(exist_ok=True)
                    cleaned.to_csv(target, index=False, encoding='utf-8-sig')
                    return cleaned.head(100), summary, str(target)
                except Exception as err:
                    return pd.DataFrame(), f"Error: {err}", None
            
            csv_up.change(fn=initial_preview, inputs=csv_up, outputs=[original_preview, cleansed_preview, clean_status])
            apply_clean.click(fn=apply_cleansing, inputs=[csv_up, remove_dup, missing_handling], outputs=[cleansed_preview, clean_status, download_cleansed])

        # Summary tab
        with gr.TabItem('Summary'):
            gr.Markdown("**Usecase Scenario — Summary**: สร้างสรุปภาพรวมของชุดข้อมูลทั้งหมด รวมสถิติพื้นฐาน และคำนวณดัชนีความน่าเชื่อถือ (เช่น SAIFI, SAIDI, CAIDI) พร้อมตัวเลือกใช้ Generative AI ในการขยายความ")
            csv_in_sum = gr.File(label='Upload CSV for Overall Summary')
            with gr.Row():
                use_hf_sum = gr.Checkbox(label='Use Generative AI for Summary', value=False)
                total_customers = gr.Number(label='Total Customers (for reliability calculation)', value=500000, precision=0)
                run_sum = gr.Button('Generate Overall Summary')
            with gr.Row():
                model_selector_sum = gr.Dropdown(
                    choices=[
                        'meta-llama/Llama-3.1-8B-Instruct:novita',
                        'meta-llama/Llama-4-Scout-17B-16E-Instruct:novita',
                        'Qwen/Qwen3-VL-235B-A22B-Instruct:novita',
                        'deepseek-ai/DeepSeek-R1:novita',
                        'moonshotai/Kimi-K2-Instruct-0905:novita'
                    ],
                    value='meta-llama/Llama-3.1-8B-Instruct:novita',
                    label='GenAI Model',
                    interactive=True,
                    visible=False
                )
            
            with gr.Tabs():
                with gr.TabItem('AI Summary'):
                    ai_summary_out = gr.Textbox(label='AI Generated Summary', lines=10)
                with gr.TabItem('Basic Statistics'):
                    basic_stats_out = gr.JSON(label='Basic Statistics')
                with gr.TabItem('Reliability Indices'):
                    reliability_out = gr.Dataframe(label='Reliability Metrics')
            
            sum_status = gr.Textbox(label='Summary Status', interactive=False)

            def run_overall_summary(file, use_hf_flag, total_cust, model):
                """Build the dataset-wide summary shown in the Summary tab.

                Delegates to ``scripts.summary.summarize_overall`` and returns a
                4-tuple matching the wired outputs: (AI summary text for a
                Textbox, basic-statistics dict for a JSON view, reliability
                DataFrame, status text). Errors are reported in the text outputs
                rather than raised.
                """
                if file is None:
                    # The first output feeds a Textbox, so return text (not a dict).
                    return '', {}, pd.DataFrame(), 'No file provided'
                try:
                    from scripts.summary import summarize_overall
                    df = pd.read_csv(file.name, dtype=str)

                    result = summarize_overall(df, use_hf=use_hf_flag, model=model, total_customers=total_cust)

                    # Fallback text (Thai): "Unable to generate a summary with AI".
                    ai_summary = result.get('ai_summary', 'ไม่สามารถสร้างสรุปด้วย AI ได้')
                    stat_keys = ('total_events', 'date_range', 'event_types', 'total_affected_customers')
                    basic_stats = {key: result.get(key) for key in stat_keys}
                    reliability_df = result.get('reliability_df', pd.DataFrame())

                    status = f"Summary generated for {len(df)} events. AI used: {use_hf_flag}"
                    return ai_summary, basic_stats, reliability_df, status
                except Exception as e:
                    return f"Error: {str(e)}", {}, pd.DataFrame(), f'Summary failed: {e}'

            def update_model_visibility_sum(use_hf_flag):
                """Show and enable the Summary-tab model dropdown only when GenAI is selected."""
                enabled = use_hf_flag
                return gr.update(interactive=enabled, visible=enabled)
            
            use_hf_sum.change(fn=update_model_visibility_sum, inputs=use_hf_sum, outputs=model_selector_sum)
            
            run_sum.click(fn=run_overall_summary, inputs=[csv_in_sum, use_hf_sum, total_customers, model_selector_sum], outputs=[ai_summary_out, basic_stats_out, reliability_out, sum_status])

        # Recommendation tab
        with gr.TabItem('Recommendation'):
            gr.Markdown("**Usecase Scenario — Recommendation**: สร้างสรุปเหตุการณ์ (เช่น สรุปเหตุการณ์ไฟฟ้าขัอข้องหรือบำรุงรักษา) สำหรับแถวที่เลือก ปรับระดับรายละเอียด และเลือกใช้ Generative AI เพื่อเพิ่มความชัดเจน พร้อมดาวน์โหลดไฟล์สรุป")
            csv_in = gr.File(label='Upload CSV (data.csv)')
            with gr.Row():
                rows = gr.Textbox(label='Rows (comma-separated indexes) or empty = all', placeholder='e.g. 0,1,2')
                use_hf = gr.Checkbox(label='Use Generative AI', value=False)
                verbosity = gr.Radio(choices=['analyze','recommend'], value='analyze', label='Summary Type', interactive=True)
                run_btn = gr.Button('Generate Summaries', interactive=True)
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=[
                        'meta-llama/Llama-3.1-8B-Instruct:novita',
                        'meta-llama/Llama-4-Scout-17B-16E-Instruct:novita',
                        'Qwen/Qwen3-VL-235B-A22B-Instruct:novita',
                        'deepseek-ai/DeepSeek-R1:novita',
                        'moonshotai/Kimi-K2-Instruct-0905:novita'
                    ],
                    value='meta-llama/Llama-3.1-8B-Instruct:novita',
                    label='GenAI Model',
                    interactive=True,
                    visible=False
                )
            out = gr.Dataframe(headers=['EventNumber','OutageDateTime','Summary'])
            status = gr.Textbox(label='Status', interactive=False)
            download = gr.File(label='Download summaries')

            def run_summarize(file, rows_text, use_hf_flag, verbosity_level, model):
                """Generate per-event summaries for the selected rows.

                ``rows_text`` is parsed by ``parse_row_selection`` (empty selects
                every row) and the selection is passed to ``summarize_events``.
                The results are written to outputs/summaries_from_ui.csv under
                ROOT.

                Returns (summaries DataFrame, status text, saved path). Failures
                are reported in the status with an empty table instead of
                propagating to Gradio, consistent with the other tabs.
                """
                empty = pd.DataFrame([], columns=['EventNumber', 'OutageDateTime', 'Summary'])
                if file is None:
                    return empty, 'No file provided', None
                try:
                    df = pd.read_csv(file.name, dtype=str)
                    df_sel = parse_row_selection(df, rows_text)
                    res = summarize_events(df_sel, use_hf=use_hf_flag, verbosity=verbosity_level, model=model)
                    out_df = pd.DataFrame(res)
                    out_file = ROOT / 'outputs' / 'summaries_from_ui.csv'
                    out_file.parent.mkdir(exist_ok=True)
                    out_df.to_csv(out_file, index=False, encoding='utf-8-sig')
                except Exception as e:
                    return empty, f'Summarization failed: {e}', None
                status_text = f"Summaries generated: {len(out_df)} rows. HF used: {use_hf_flag}"
                return out_df, status_text, str(out_file)

            def update_model_visibility(use_hf_flag):
                """Show and enable the Recommendation-tab model dropdown only when GenAI is selected."""
                enabled = use_hf_flag
                return gr.update(interactive=enabled, visible=enabled)
            
            use_hf.change(fn=update_model_visibility, inputs=use_hf, outputs=model_selector)
            
            run_btn.click(fn=run_summarize, inputs=[csv_in, rows, use_hf, verbosity, model_selector], outputs=[out, status, download])

        # Anomaly Detection tab
        with gr.TabItem('Anomaly Detection'):
            gr.Markdown("**Usecase Scenario — Anomaly Detection**: ตรวจจับเหตุการณ์ที่มีพฤติกรรมผิดปกติในชุดข้อมูล (เช่น เหตุการณ์ที่มีค่าสูง/ต่ำผิดปกติ) โดยใช้หลาย algorithm ปรับระดับ contamination และส่งออกผลลัพธ์พร้อมธงความผิดปกติ")
            csv_in_anom = gr.File(label='Upload CSV for Anomaly')
            with gr.Row():
                alg = gr.Radio(choices=['iso+lof','iso','lof','autoencoder'], value='iso+lof', label='Algorithm')
                contamination = gr.Slider(minimum=0.01, maximum=0.2, value=0.05, step=0.01, label='Contamination')
                run_anom = gr.Button('Run Anomaly Detection')
            anom_out = gr.Dataframe()
            anom_status = gr.Textbox(label='Anomaly Status', interactive=False)
            anom_download = gr.File(label='Download anomalies CSV')

            def run_anomaly_ui(file, algorithm, contamination):
                """Run anomaly detection on an uploaded CSV and save the flagged rows.

                Delegates to ``scripts.anomaly.detect_anomalies``. The
                ``ensemble_flag`` and ``final_flag`` columns, when present, are
                moved to the end, and the result is written to
                outputs/anomalies_from_ui.csv under ROOT.

                Returns (result DataFrame, status text, saved path). Errors are
                reported in the status instead of being raised.
                """
                if file is None:
                    return pd.DataFrame(), 'No file provided', None
                try:
                    from scripts.anomaly import detect_anomalies
                    df = pd.read_csv(file.name, dtype=str)
                    res = detect_anomalies(df, contamination=contamination, algorithm=algorithm)
                    # Put flag columns last. Only move the ones that exist.
                    # NOTE(review): it is not visible here whether every algorithm
                    # emits both flags; indexing a missing one would raise KeyError.
                    flag_cols = [c for c in ('ensemble_flag', 'final_flag') if c in res.columns]
                    res = res[[c for c in res.columns if c not in flag_cols] + flag_cols]
                    out_file = ROOT / 'outputs' / 'anomalies_from_ui.csv'
                    out_file.parent.mkdir(exist_ok=True)
                    res.to_csv(out_file, index=False, encoding='utf-8-sig')
                    n_flags = res['final_flag'].sum() if 'final_flag' in res.columns else 'n/a'
                    status = f"Anomaly detection done. Rows: {len(res)}. Flags: {n_flags}"
                except Exception as e:
                    return pd.DataFrame(), f'Anomaly detection failed: {e}', None
                return res, status, str(out_file)

            run_anom.click(fn=run_anomaly_ui, inputs=[csv_in_anom, alg, contamination], outputs=[anom_out, anom_status, anom_download])

        # Classification tab
        with gr.TabItem('Classification'):
            gr.Markdown("**Usecase Scenario — Classification**: ฝึกและทดสอบโมเดลเพื่อจำแนกสาเหตุของเหตุการณ์ กำหนดคอลัมน์เป้าหมาย ปรับ hyperparameters, เปิดใช้งาน weak-labeling และดาวน์โหลดโมเดล/ผลการทำนาย")
            csv_in_cls = gr.File(label='Upload CSV for Classification')
            with gr.Row():
                label_col = gr.Dropdown(choices=['CauseType','SubCauseType'], value='CauseType', label='Target Column')
                do_weak = gr.Checkbox(label='Run weak-labeling using HF (requires HF_TOKEN)', value=False)
                model_type = gr.Radio(choices=['rf','gb','mlp'], value='rf', label='Model Type')
                run_cls = gr.Button('Train Classifier')
            def update_hyperparams_visibility(model_choice):
                """Show only the hyperparameter controls of the chosen classifier.

                The returned list follows the outputs wired to
                ``model_type.change``: 4 RF controls, then 3 GB controls, then
                3 MLP controls.
                """
                visibility = (
                    [model_choice == 'rf'] * 4
                    + [model_choice == 'gb'] * 3
                    + [model_choice == 'mlp'] * 3
                )
                return [gr.update(visible=flag) for flag in visibility]

            with gr.Accordion("Hyperparameters (Advanced)", open=False):
                gr.Markdown("Adjust hyperparameters for the selected model. Defaults are set for good performance.")
                rf_n_estimators = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="RF: n_estimators", visible=True)
                rf_max_depth = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="RF: max_depth", visible=True)
                rf_min_samples_split = gr.Slider(minimum=2, maximum=10, value=2, step=1, label="RF: min_samples_split", visible=True)
                rf_min_samples_leaf = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="RF: min_samples_leaf", visible=True)
                gb_n_estimators = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="GB: n_estimators", visible=False)
                gb_max_depth = gr.Slider(minimum=3, maximum=20, value=3, step=1, label="GB: max_depth", visible=False)
                gb_learning_rate = gr.Slider(minimum=0.01, maximum=0.3, value=0.1, step=0.01, label="GB: learning_rate", visible=False)
                mlp_hidden_layer_sizes = gr.Textbox(value="(100,)", label="MLP: hidden_layer_sizes (tuple)", visible=False)
                mlp_alpha = gr.Slider(minimum=0.0001, maximum=0.01, value=0.0001, step=0.0001, label="MLP: alpha", visible=False)
                mlp_max_iter = gr.Slider(minimum=100, maximum=4000, value=500, step=50, label="MLP: max_iter", visible=False)

            model_type.change(fn=update_hyperparams_visibility, inputs=model_type, outputs=[rf_n_estimators, rf_max_depth, rf_min_samples_split, rf_min_samples_leaf, gb_n_estimators, gb_max_depth, gb_learning_rate, mlp_hidden_layer_sizes, mlp_alpha, mlp_max_iter])

            cls_out = gr.Textbox(label='Classification Report')
            model_path_state = gr.State()
            cls_download_model = gr.File(label='Download saved model')
            cls_download_preds = gr.File(label='Download predictions CSV')
            
            # Test section
            gr.Markdown("---")
            gr.Markdown("**ทดสอบโมเดล**: อัปโหลดไฟล์ CSV ใหม่เพื่อทดสอบโมเดลที่ฝึกแล้ว")
            test_csv = gr.File(label='Upload CSV for Testing')
            run_test = gr.Button('Test Model')
            test_out = gr.Dataframe(label='Test Predictions')
            test_status = gr.Textbox(label='Test Status', interactive=False)
            test_download = gr.File(label='Download Test Predictions')

            def run_classify_ui(file, label_col_choice, use_weak, model_choice, rf_n_est, rf_max_d, rf_min_ss, rf_min_sl, gb_n_est, gb_max_d, gb_lr, mlp_hls, mlp_a, mlp_mi):
                """Train a cause classifier from the uploaded CSV.

                Only the hyperparameters that belong to ``model_choice`` ('rf',
                'gb' or 'mlp') are forwarded to
                ``scripts.classify.train_classifier``.

                Returns (report text, model path, predictions path, model path
                again for ``model_path_state``). Paths are strings, or None when
                unavailable. Any failure (import, CSV parsing, an invalid
                hidden-layer tuple, training) is reported in the report text
                instead of being raised.
                """
                if file is None:
                    return 'No file provided', None, None, None
                try:
                    from scripts.classify import train_classifier
                    df = pd.read_csv(file.name, dtype=str)
                    hyperparams = {}
                    if model_choice == 'rf':
                        hyperparams = {'n_estimators': int(rf_n_est), 'max_depth': int(rf_max_d), 'min_samples_split': int(rf_min_ss), 'min_samples_leaf': int(rf_min_sl)}
                    elif model_choice == 'gb':
                        hyperparams = {'n_estimators': int(gb_n_est), 'max_depth': int(gb_max_d), 'learning_rate': gb_lr}
                    elif model_choice == 'mlp':
                        import ast
                        # literal_eval parses only Python literals, so text typed in the UI is never executed.
                        hyperparams = {'hidden_layer_sizes': ast.literal_eval(mlp_hls), 'alpha': mlp_a, 'max_iter': int(mlp_mi)}
                    # NOTE(review): ``use_weak`` (the weak-labeling checkbox) is accepted but
                    # never used here; confirm whether train_classifier should receive it.
                    res = train_classifier(df, label_col=label_col_choice, model_type=model_choice, hyperparams=hyperparams)
                    report = res.get('report', '')
                    model_file = res.get('model_file')
                    preds_file = res.get('predictions_file')
                except Exception as e:
                    return f'Training failed: {e}', None, None, None
                # Hand Gradio plain string paths (the helper may return Path objects).
                model_file = str(model_file) if model_file is not None else None
                preds_file = str(preds_file) if preds_file is not None else None
                return report, model_file, preds_file, model_file
            
            def run_test_ui(test_file, model_path):
                """Apply the last trained classifier to a new CSV.

                ``model_path`` is the joblib file stored in ``model_path_state``
                by training. The loaded object must provide 'pipeline' and
                'label_encoder'. The test CSV goes through
                ``parse_and_features`` before prediction, and all predicted rows
                are written to outputs/test_predictions.csv under ROOT.

                Returns (first 100 prediction rows, status text, saved path).
                Failures are reported in the status instead of being raised.
                """
                if test_file is None:
                    return pd.DataFrame(), 'No test file provided', None
                if model_path is None:
                    return pd.DataFrame(), 'No trained model available. Please train a model first.', None
                try:
                    from scripts.classify import parse_and_features

                    bundle = joblib.load(model_path)
                    pipeline, encoder = bundle['pipeline'], bundle['label_encoder']

                    featured = parse_and_features(pd.read_csv(test_file.name, dtype=str))
                    # Must mirror the feature list used when the model was trained.
                    feature_cols = ['duration_min','Load(MW)_num','Capacity(kVA)_num','AffectedCustomer_num','hour','weekday','device_freq','OpDeviceType','Owner','Weather','EventType']
                    labels = encoder.inverse_transform(pipeline.predict(featured[feature_cols]))

                    predictions = featured.copy()
                    # NOTE(review): the column name is fixed even if the model was trained on SubCauseType.
                    predictions['Predicted_CauseType'] = labels

                    target = ROOT / 'outputs' / 'test_predictions.csv'
                    target.parent.mkdir(exist_ok=True)
                    predictions.to_csv(target, index=False, encoding='utf-8-sig')

                    message = f"Test completed. Predictions for {len(predictions)} rows."
                    return predictions.head(100), message, str(target)
                except Exception as err:
                    return pd.DataFrame(), f'Test failed: {err}', None

            run_cls.click(fn=run_classify_ui, inputs=[csv_in_cls, label_col, do_weak, model_type, rf_n_estimators, rf_max_depth, rf_min_samples_split, rf_min_samples_leaf, gb_n_estimators, gb_max_depth, gb_learning_rate, mlp_hidden_layer_sizes, mlp_alpha, mlp_max_iter], outputs=[cls_out, cls_download_model, cls_download_preds, model_path_state])
            run_test.click(fn=run_test_ui, inputs=[test_csv, model_path_state], outputs=[test_out, test_status, test_download])

        # Label Suggestion tab
        with gr.TabItem('Label Suggestion'):
            gr.Markdown("**Usecase Scenario — Label Suggestion**: ให้คำแนะนำป้ายกำกับสาเหตุที่เป็นไปได้สำหรับเหตุการณ์ที่ไม่มีฉลาก โดยเทียบความคล้ายกับตัวอย่างที่มีฉลาก ปรับจำนวนคำแนะนำสูงสุด และส่งออกเป็นไฟล์ CSV")
            csv_in_ls = gr.File(label='Upload CSV (defaults to data/data_3.csv)')
            with gr.Row():
                top_k = gr.Slider(minimum=1, maximum=5, value=1, step=1, label='Top K suggestions')
                run_ls = gr.Button('Run Label Suggestion')
            ls_out = gr.Dataframe()
            ls_status = gr.Textbox(label='Label Suggestion Status', interactive=False)
            ls_download = gr.File(label='Download label suggestions')

            def run_label_suggestion(file, top_k_suggest):
                """Suggest cause labels for unlabeled events and save them to CSV.

                When nothing is uploaded, data/data_3.csv under ROOT is used if
                it exists. Work is delegated to
                ``scripts.label_suggestion.suggest_labels_to_file``, which writes
                outputs/label_suggestions.csv.

                Returns (suggestions DataFrame, status text, saved path or None
                when no rows were produced). Errors are reported in the status
                instead of being raised.
                """
                try:
                    from scripts.label_suggestion import suggest_labels_to_file
                    if file is None:
                        default = ROOT / 'data' / 'data_3.csv'
                        if not default.exists():
                            return pd.DataFrame(), 'No file provided and default data/data_3.csv not found', None
                        df = pd.read_csv(default, dtype=str)
                    else:
                        df = pd.read_csv(file.name, dtype=str)

                    out_file = ROOT / 'outputs' / 'label_suggestions.csv'
                    # Ensure outputs/ exists, as the other tabs do before writing.
                    out_file.parent.mkdir(exist_ok=True)
                    out_df = suggest_labels_to_file(df, out_path=str(out_file), top_k=int(top_k_suggest))
                    status = f"Label suggestion done. Unknown rows processed: {len(out_df)}. Output: {out_file}"
                except Exception as e:
                    return pd.DataFrame(), f'Label suggestion failed: {e}', None
                return out_df, status, str(out_file) if len(out_df) > 0 else None

            run_ls.click(fn=run_label_suggestion, inputs=[csv_in_ls, top_k], outputs=[ls_out, ls_status, ls_download])

        # Forecasting tab
        with gr.TabItem('Forecasting'):
            gr.Markdown("**Usecase Scenario — Forecasting**: พยากรณ์จำนวนเหตุการณ์หรือเวลาหยุดทำงานในอนาคตโดยเลือกโมเดล (Prophet, LSTM, Bi-LSTM, GRU, Naive) ปรับพารามิเตอร์ และส่งออกผลการพยากรณ์")
            gr.Markdown("*Multivariate forecasting (ใช้หลายฟีเจอร์) รองรับเฉพาะโมเดล LSTM, Bi-LSTM, GRU เท่านั้น*")
            csv_in_fc = gr.File(label='Upload CSV for Forecasting')
            with gr.Row():
                metric_fc = gr.Radio(choices=['count','downtime_minutes'], value='count', label='Metric to Forecast')
                model_type_fc = gr.Radio(choices=['prophet','lstm','bilstm','gru','naive'], value='lstm', label='Forecasting Model', elem_id='forecast_model_radio')
                periods_fc = gr.Slider(minimum=1, maximum=30, value=7, step=1, label='Forecast Periods (days)')
                multivariate_fc = gr.Checkbox(value=False, label='Use Multivariate (Multiple Features)', interactive=False)
                run_fc = gr.Button('Run Forecasting')
            
            # Add state to track current model
            current_model_state = gr.State(value='lstm')
            
            def update_multivariate_visibility(model_choice):
                """Enable the multivariate checkbox only for LSTM, Bi-LSTM and GRU.

                The checkbox is reset to unchecked whenever the model changes.
                """
                recurrent = ('lstm', 'bilstm', 'gru')
                return gr.update(interactive=model_choice in recurrent, value=False)
            
            def update_model_state(model_choice):
                """Return the selected forecasting model unchanged.

                Presumably wired (outside this view) to refresh
                ``current_model_state``; confirm against the event bindings.
                """
                selected_model = model_choice
                return selected_model
            
            # Hyperparameter controls for forecasting
            with gr.Accordion("Hyperparameters (Advanced)", open=False):
                gr.Markdown("Adjust hyperparameters for the selected forecasting model. Defaults are set for good performance.")
                
                # Prophet hyperparameters
                prophet_changepoint_prior = gr.Slider(minimum=0.001, maximum=0.5, value=0.05, step=0.001, label="Prophet: changepoint_prior_scale", visible=False)
                prophet_seasonality_prior = gr.Slider(minimum=0.01, maximum=10.0, value=10.0, step=0.1, label="Prophet: seasonality_prior_scale", visible=False)
                prophet_seasonality_mode = gr.Radio(choices=['additive', 'multiplicative'], value='additive', label="Prophet: seasonality_mode", visible=False)
                
                # Deep learning hyperparameters (LSTM, Bi-LSTM, GRU)
                dl_seq_length = gr.Slider(minimum=3, maximum=30, value=7, step=1, label="DL: sequence_length (lag/input length)", visible=True)
                dl_epochs = gr.Slider(minimum=10, maximum=200, value=100, step=10, label="DL: epochs", visible=True)
                dl_batch_size = gr.Slider(minimum=4, maximum=64, value=16, step=4, label="DL: batch_size", visible=True)
                dl_learning_rate = gr.Slider(minimum=0.0001, maximum=0.01, value=0.001, step=0.0001, label="DL: learning_rate", visible=True)
                dl_units = gr.Slider(minimum=32, maximum=256, value=100, step=16, label="DL: units (neurons)", visible=True)
                dl_dropout = gr.Slider(minimum=0.0, maximum=0.5, value=0.2, step=0.05, label="DL: dropout_rate", visible=True)
                
                # Naive has no hyperparameters
            
            def update_forecast_hyperparams_visibility(model_choice):
                """Show only the hyperparameter controls of the chosen forecasting model.

                Returns 9 updates: the 3 Prophet controls (changepoint prior,
                seasonality prior, seasonality mode) followed by the 6
                deep-learning controls (sequence length, epochs, batch size,
                learning rate, units, dropout). Naive shows none of them.
                """
                show_prophet = model_choice == 'prophet'
                show_dl = model_choice in ['lstm', 'bilstm', 'gru']
                prophet_updates = [gr.update(visible=show_prophet) for _ in range(3)]
                dl_updates = [gr.update(visible=show_dl) for _ in range(6)]
                return prophet_updates + dl_updates
            
            with gr.Tabs():
                with gr.TabItem('Historical Data'):
                    hist_out = gr.Dataframe(label='Historical Time Series Data')
                with gr.TabItem('Forecast Results'):
                    fcst_out = gr.Dataframe(label='Forecast Results')
                with gr.TabItem('Time Series Plot'):
                    plot_out = gr.Plot(label='Historical + Forecast Plot')
            fc_status = gr.Textbox(label='Forecast Status', interactive=False)
            fc_download = gr.File(label='Download forecast CSV')

            def run_forecast_ui(file, metric, model_type, periods, multivariate, current_model, prophet_cp, prophet_sp, prophet_sm, dl_sl, dl_e, dl_bs, dl_lr, dl_u, dl_d):
                """Run a forecast for the forecast tab and build every UI output.

                Args:
                    file: Uploaded CSV. Either an object exposing a ``.name`` path
                        or a plain path string (the form depends on the Gradio
                        version / ``gr.File`` ``type=`` setting).
                    metric: Metric name passed to ``run_forecast``; also used in
                        the plot labels and the output filename.
                    model_type: Model chosen in the dropdown.
                    periods: Forecast horizon, coerced to ``int``.
                    multivariate: Whether a multivariate forecast was requested.
                    current_model: Model name from session state; used instead
                        of ``model_type`` when truthy.
                    prophet_cp, prophet_sp, prophet_sm: Forwarded as Prophet's
                        ``changepoint_prior_scale``, ``seasonality_prior_scale``
                        and ``seasonality_mode``.
                    dl_sl, dl_e, dl_bs, dl_lr, dl_u, dl_d: Forwarded to the
                        LSTM/BiLSTM/GRU models as ``seq_length``, ``epochs``,
                        ``batch_size`` (ints), ``learning_rate``, ``units`` (int)
                        and ``dropout_rate``.

                Returns:
                    ``(history_df, forecast_df, figure, status_text, csv_path)``.
                    On failure both frames are empty, the figure shows the error
                    message and ``csv_path`` is ``None``. With no file, the figure
                    is ``None`` as well.
                """
                dl_models = ('lstm', 'bilstm', 'gru')
                # Session state takes precedence over the raw dropdown value.
                actual_model = current_model if current_model else model_type
                if file is None:
                    return pd.DataFrame(), pd.DataFrame(), None, 'No file provided', None
                try:
                    import re
                    import matplotlib.dates as mdates
                    from matplotlib.figure import Figure
                    from scripts.forecast import run_forecast

                    def _slug(text):
                        # Keep filename-safe characters only (':' is illegal on Windows,
                        # '/' would escape the outputs directory).
                        return re.sub(r'[^A-Za-z0-9._-]+', '_', str(text)).strip('_')

                    csv_path = getattr(file, 'name', file)
                    df = pd.read_csv(csv_path, dtype=str)

                    # Build hyperparams dict based on the model actually being run
                    hyperparams = {}
                    if actual_model == 'prophet':
                        hyperparams = {
                            'changepoint_prior_scale': prophet_cp,
                            'seasonality_prior_scale': prophet_sp,
                            'seasonality_mode': prophet_sm
                        }
                    elif actual_model in dl_models:
                        hyperparams = {
                            'seq_length': int(dl_sl),
                            'epochs': int(dl_e),
                            'batch_size': int(dl_bs),
                            'learning_rate': dl_lr,
                            'units': int(dl_u),
                            'dropout_rate': dl_d
                        }

                    ts, fcst = run_forecast(df, metric=metric, periods=int(periods), model_type=actual_model, multivariate=multivariate, hyperparams=hyperparams)

                    # Build the figure without pyplot. pyplot keeps every figure in a
                    # global registry (a leak per request), and its "current figure"
                    # state is shared across concurrent Gradio worker threads.
                    fig = Figure(figsize=(14, 7))
                    ax = fig.subplots()

                    if len(ts) > 0 and 'y' in ts.columns:
                        ax.plot(ts['ds'], ts['y'], 'b-', label='Historical Data', linewidth=2, marker='o', markersize=4)

                    if len(fcst) > 0 and 'yhat' in fcst.columns:
                        ax.plot(fcst['ds'], fcst['yhat'], 'r--', label='Forecast', linewidth=3, marker='s', markersize=5)
                        if 'yhat_lower' in fcst.columns and 'yhat_upper' in fcst.columns:
                            ax.fill_between(fcst['ds'], fcst['yhat_lower'], fcst['yhat_upper'],
                                            color='red', alpha=0.3, label='Confidence Interval')

                    # Vertical marker separating history from forecast
                    if len(ts) > 0 and len(fcst) > 0:
                        ax.axvline(x=ts['ds'].max(), color='gray', linestyle='--', alpha=0.7, label='Forecast Start')

                    metric_label = metric.replace('_', ' ').title()
                    ax.set_title(f'Time Series Forecast: {actual_model.upper()} ({metric_label})',
                                 fontsize=16, fontweight='bold', pad=20)
                    ax.set_xlabel('Date', fontsize=14)
                    ax.set_ylabel(metric_label, fontsize=14)
                    # Avoid matplotlib's "no artists with labels" warning on empty results.
                    if ax.get_legend_handles_labels()[0]:
                        ax.legend(loc='upper left', fontsize=12)
                    ax.grid(True, alpha=0.3)

                    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
                    # AutoDateLocator spaces ticks over the whole plotted span (history
                    # + forecast) regardless of sampling frequency.
                    ax.xaxis.set_major_locator(mdates.AutoDateLocator())
                    for tick_label in ax.get_xticklabels():
                        tick_label.set_rotation(45)
                        tick_label.set_horizontalalignment('right')
                    fig.tight_layout()

                    # Only the deep-learning models accept exogenous features.
                    fallback = multivariate and actual_model not in dl_models
                    mode = 'multivariate' if multivariate else 'univariate'
                    if fallback:
                        mode += ' (fallback: model does not support multivariate)'
                    out_file = ROOT / 'outputs' / f'forecast_{_slug(metric)}_{_slug(actual_model)}_{_slug(mode)}.csv'
                    out_file.parent.mkdir(parents=True, exist_ok=True)
                    fcst.to_csv(out_file, index=False)

                    status = f"Forecasting completed using {actual_model.upper()} ({mode}). Historical data: {len(ts)} days, Forecast: {len(fcst)} days."
                    if fallback:
                        status += " Note: Model does not support multivariate - used univariate instead."
                    return ts, fcst, fig, status, str(out_file)
                except Exception as e:
                    # UI boundary: surface the failure in the plot and status box
                    # instead of raising into Gradio.
                    from matplotlib.figure import Figure
                    err_fig = Figure(figsize=(14, 7))
                    err_ax = err_fig.subplots()
                    err_ax.text(0.5, 0.5, f'Forecasting Error:\n{str(e)}',
                                transform=err_ax.transAxes, ha='center', va='center',
                                fontsize=14, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
                    err_ax.set_title('Time Series Forecast - Error Occurred', fontsize=16, fontweight='bold')
                    err_ax.set_xlim(0, 1)
                    err_ax.set_ylim(0, 1)
                    err_ax.axis('off')
                    return pd.DataFrame(), pd.DataFrame(), err_fig, f'Forecasting failed: {e}', None

            # Handlers fired whenever the forecast model dropdown changes. Each one
            # receives the new dropdown value. The update_* helpers are defined
            # elsewhere in this file and respectively drive the multivariate
            # checkbox, the session-state model, and the hyperparameter widgets.
            model_type_fc.change(fn=update_multivariate_visibility, inputs=[model_type_fc], outputs=[multivariate_fc])
            model_type_fc.change(fn=update_model_state, inputs=[model_type_fc], outputs=[current_model_state])
            model_type_fc.change(fn=update_forecast_hyperparams_visibility, inputs=[model_type_fc], outputs=[prophet_changepoint_prior, prophet_seasonality_prior, prophet_seasonality_mode, dl_seq_length, dl_epochs, dl_batch_size, dl_learning_rate, dl_units, dl_dropout])

            # The inputs list is passed positionally, so its order must match
            # run_forecast_ui's parameters. The outputs receive its 5-tuple
            # return (history, forecast, plot, status, CSV path).
            run_fc.click(fn=run_forecast_ui, inputs=[csv_in_fc, metric_fc, model_type_fc, periods_fc, multivariate_fc, current_model_state, prophet_changepoint_prior, prophet_seasonality_prior, prophet_seasonality_mode, dl_seq_length, dl_epochs, dl_batch_size, dl_learning_rate, dl_units, dl_dropout], outputs=[hist_out, fcst_out, plot_out, fc_status, fc_download])

if __name__ == '__main__':
    # Serve the Gradio app on all network interfaces, port 7860.
    launch_options = dict(server_name="0.0.0.0", server_port=7860)
    demo.launch(**launch_options)