Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| from datetime import datetime, timedelta | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| # Import các module của bạn | |
| from SRC.utils.preprocessing import Preproces | |
| from SRC.utils.CEEMDAN import CEEMDANWrapper | |
| from SRC.utils.data_preparation import DataPreparation | |
| from SRC.optimization.Optuna_opts import Optimizers | |
| from SRC.utils.plot_figure import plot_figure | |
| # from SRC.utils.insert_pd import insert_pd | |
| from SRC.config.constant import const # nếu bạn để trong SRC/config | |
| # ====================== GRADIO APP ====================== | |
| with gr.Blocks(title="🚀 CEEMDAN-Hybrid Stock Predictor") as demo: # ← xóa theme ở đây gr.Markdown("# CEEMDAN + Hybrid ML Stock Forecast\nHugging Face Space") | |
| with gr.Tab("1. Cấu hình & Tải dữ liệu"): | |
| with gr.Row(): | |
| ticker = gr.Textbox(value="AAPL", label="Ticker") | |
| start_date = gr.Textbox(value="2019-01-01", label="Ngày bắt đầu") | |
| end_date = gr.Textbox(value="2024-04-26", label="Ngày kết thúc") | |
| target_col = gr.Dropdown(["Close", "Adj Close"], value="Close", label="Cột dự đoán") | |
| window_size = gr.Slider(3, 20, value=5, step=1, label="Window size (lag)") | |
| split_ratio = gr.Slider(0.6, 0.95, value=0.80, step=0.05, label="Tỷ lệ train/test") | |
| ceemdan_trials = gr.Slider(10, 100, value=20, step=10, label="Số trials CEEMDAN (càng ít càng nhanh)") | |
| btn_load = gr.Button("📥 Tải & Chuẩn hóa dữ liệu", variant="primary") | |
| data_preview = gr.DataFrame(label="Dữ liệu giá") | |
| price_plot = gr.Plot(label="Biểu đồ giá lịch sử") | |
| with gr.Tab("2. Phân tích CEEMDAN"): | |
| btn_decompose = gr.Button("🔬 Chạy CEEMDAN Decomposition") | |
| ceemdan_overview = gr.DataFrame(label="Tổng quan các thành phần") | |
| ceemdan_plot = gr.Plot(label="Các IMF + Residue") | |
| with gr.Tab("3. Backtest (Đánh giá mô hình)"): | |
| models = gr.CheckboxGroup(["SVR", "DT", "RF", "KNN", "ANN"], value=["RF"], label="Chọn mô hình") | |
| n_trials = gr.Slider(5, 30, value=8, step=1, label="Số trial Optuna") | |
| btn_backtest = gr.Button("🚀 Chạy Backtest", variant="primary", size="large") | |
| metric_table = gr.DataFrame(label="Bảng so sánh Metric") | |
| tuning_table = gr.DataFrame(label="Best params & CV MSE") | |
| commentary = gr.Markdown() | |
| backtest_plot = gr.Plot() | |
| with gr.Tab("4. Dự đoán tương lai"): | |
| horizon = gr.Slider(1, 90, value=30, step=1, label="Số ngày dự đoán tương lai") | |
| models_future = gr.CheckboxGroup(["SVR", "DT", "RF", "KNN", "ANN"], value=["RF"], label="Mô hình dùng để dự báo") | |
| btn_future = gr.Button("🔮 Dự đoán tương lai", variant="primary", size="large") | |
| future_table = gr.DataFrame(label="Bảng dự đoán tương lai") | |
| future_plot = gr.Plot(label="Dự báo giá tương lai") | |
| # ====================== LOGIC ====================== | |
| config = {} | |
| df_global = None | |
| components_global = None | |
| component_names_global = None | |
| def load_data(ticker_val, start, end, target, trials_val): | |
| global df_global, config | |
| config = const().__dict__ # load default | |
| config["ticker"] = ticker_val | |
| config["start_date"] = start | |
| config["end_date"] = end | |
| config["target_col"] = target | |
| config["window_size"] = window_size.value | |
| config["split_ratio"] = split_ratio.value | |
| config["ceemdan_trials"] = int(trials_val) | |
| raw = Preproces.download_price_data(config) | |
| df_global = Preproces.preprocess_price_data(raw, target) | |
| # Plot giá | |
| fig, ax = plt.subplots(figsize=(12, 5)) | |
| ax.plot(df_global["Date"], df_global[target], label="Giá thực tế") | |
| ax.set_title(f"{ticker_val} - Giá lịch sử") | |
| ax.legend() | |
| return df_global.head(10), fig | |
| btn_load.click( | |
| load_data, | |
| inputs=[ticker, start_date, end_date, target_col, ceemdan_trials], | |
| outputs=[data_preview, price_plot] | |
| ) | |
| def run_ceemdan(): | |
| global components_global, component_names_global | |
| if df_global is None: | |
| return None, None # hoặc raise gr.Error("Vui lòng tải dữ liệu trước") | |
| signal = df_global[config["target_col"]].values | |
| components, names, comp_df, overview = CEEMDANWrapper.run_ceemdan( | |
| signal, config, df_global["Date"] | |
| ) | |
| components_global = components | |
| component_names_global = names | |
| # Tạo plot | |
| fig, axes = plt.subplots(len(names), 1, figsize=(12, 2.5 * len(names)), sharex=True) | |
| if len(names) == 1: | |
| axes = [axes] | |
| for i, (ax, name) in enumerate(zip(axes, names)): | |
| ax.plot(df_global["Date"], components[i], label=name) | |
| ax.set_title(name) | |
| ax.legend() | |
| plt.tight_layout() | |
| return overview, fig | |
| btn_decompose.click(run_ceemdan, outputs=[ceemdan_overview, ceemdan_plot]) # bạn có thể mở rộng plot | |
| def run_backtest(selected_models): | |
| if df_global is None or components_global is None: | |
| raise gr.Error("Vui lòng tải dữ liệu và chạy CEEMDAN trước!") | |
| # Tính test_target_indices CHÍNH XÁC theo window_size | |
| N = len(df_global) | |
| ws = config["window_size"] | |
| X_len = N - ws | |
| split_idx = int(X_len * config["split_ratio"]) | |
| test_start_idx = ws + split_idx # index gốc trong chuỗi | |
| test_target_indices = np.arange(test_start_idx, N) | |
| prepared_data = {} | |
| for i, name in enumerate(component_names_global): | |
| prepared_data[name] = DataPreparation.prepare_for_backtest( | |
| components_global[i], config["window_size"], config["split_ratio"] | |
| ) | |
| results = [] | |
| for model_key in selected_models: | |
| model_name = {"SVR": "SVR", "DT": "Decision Tree", "RF": "Random Forest", | |
| "KNN": "KNN", "ANN": "Neural Network"}[model_key] | |
| exp = Optimizers.run_model_experiment( | |
| model_key, | |
| model_name, | |
| prepared_data, | |
| df_global[config["target_col"]].values, | |
| df_global, | |
| test_target_indices, # ← THAY ĐỔI: dùng biến đã tính đúng | |
| config | |
| ) | |
| results.append(exp) | |
| # Plot cho model đầu tiên làm ví dụ | |
| if len(results) == 1: | |
| fig = plot_figure.plot_prediction_report(exp["prediction_df"], model_name) # sửa plot_figure để return fig nếu cần | |
| metric_df = pd.concat([r["metric_df"] for r in results], ignore_index=True) | |
| tuning_df = pd.concat([r["tuning_df"] for r in results], ignore_index=True) | |
| comment = f""" | |
| **Backtest hoàn tất với {len(selected_models)} mô hình** | |
| - Số trial Optuna: {config.get('n_trials', 8)} | |
| - CEEMDAN trials: {config.get('ceemdan_trials', 20)} | |
| - Model tốt nhất hiện tại: {results[0]['model_name']} | |
| """ | |
| return metric_df, tuning_df, comment, fig # bạn có thể trả fig nếu muốn | |
| btn_backtest.click(run_backtest, inputs=models, outputs=[metric_table, tuning_table, commentary, backtest_plot]) | |
| def run_future_forecast(selected_models, days): | |
| # Train trên toàn bộ dữ liệu | |
| prepared_data = {} | |
| for i, name in enumerate(component_names_global): | |
| prepared_data[name] = DataPreparation.prepare_for_forecast(components_global[i], config["window_size"]) | |
| component_results, _ = Optimizers.train_hybrid_model(selected_models[0], prepared_data, config) # chỉ lấy 1 model để demo, bạn có thể loop | |
| final_forecast, _ = Optimizers.forecast_hybrid_model(component_results, days, config) | |
| # Tạo bảng tương lai | |
| last_date = df_global["Date"].iloc[-1] | |
| future_dates = [last_date + timedelta(days=i+1) for i in range(days)] | |
| future_df = pd.DataFrame({ | |
| "Date": future_dates, | |
| "Predicted_Price": final_forecast.round(2) | |
| }) | |
| # Plot | |
| fig, ax = plt.subplots(figsize=(12, 6)) | |
| ax.plot(df_global["Date"], df_global[config["target_col"]], label="Lịch sử") | |
| ax.plot(future_dates, final_forecast, label="Dự báo tương lai", linestyle="--", color="red") | |
| ax.legend() | |
| ax.set_title(f"Dự báo {days} ngày tới - {config['ticker']}") | |
| return future_df, fig | |
| btn_future.click(run_future_forecast, [models_future, horizon], [future_table, future_plot]) | |
| if __name__ == "__main__": | |
| demo.launch( | |
| theme=gr.themes.Soft(), # ← chuyển theme vào đây | |
| server_port=7860, | |
| share=False, # set True nếu muốn public link | |
| debug=False | |
| ) |