Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Integrated Market Prediction Tool | |
| This script does two things: | |
| 1. Downloads historical market data from yfinance based on YAML configuration files. | |
| 2. Provides an interactive dashboard (via Streamlit) to view the downloaded data. | |
| Usage: | |
| streamlit run main.py | |
| """ | |
| import os | |
| import yaml | |
| import yfinance as yf | |
| import pandas as pd | |
| import streamlit as st | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from ta.volatility import BollingerBands | |
| from ta.trend import SMAIndicator, EMAIndicator, MACD | |
| from ta.momentum import RSIIndicator | |
| from ta.volume import VolumeWeightedAveragePrice | |
| from models.prediction_models import MarketPredictor | |
| # ----------------------------- | |
| # Configuration Loader Function | |
| # ----------------------------- | |
def load_yaml(filepath):
    """
    Parse the YAML file at ``filepath`` with ``yaml.safe_load``.

    Parameters:
        filepath (str): Path of the YAML file to read.

    Returns:
        The parsed document (``None`` for an empty file).

    Raises:
        Exception: Whatever opening or parsing raised. The error is first
        shown in the Streamlit UI, then re-raised unchanged.
    """
    try:
        with open(filepath, 'r') as handle:
            return yaml.safe_load(handle)
    except Exception as exc:
        st.error(f"Error loading YAML file {filepath}: {exc}")
        raise
| # ----------------------------- | |
| # Data Download Function | |
| # ----------------------------- | |
def download_data(tickers, period, interval, output_dir):
    """
    Fetch price history for each ticker from yfinance and save it as CSV.

    Parameters:
        tickers (list): Ticker symbols to fetch.
        period (str): yfinance history period (e.g. '500d' for 500 days).
        interval (str): yfinance bar interval (e.g. '1d' for daily bars).
        output_dir (str): Directory that receives one ``<ticker>.csv`` per
            symbol; it is created if missing.

    Returns:
        list: The tickers that were fetched and written, in input order.
        Tickers that returned no rows, or whose download raised, are
        reported in the Streamlit UI and left out.
    """
    os.makedirs(output_dir, exist_ok=True)
    saved = []
    for symbol in tickers:
        st.info(f"Downloading data for {symbol}...")
        try:
            history = yf.Ticker(symbol).history(period=period, interval=interval)
            if history.empty:
                st.warning(f"No data returned for {symbol}. Skipping.")
                continue
            csv_path = os.path.join(output_dir, f"{symbol}.csv")
            history.to_csv(csv_path)
            saved.append(symbol)
            st.success(f"Saved data for {symbol} to {csv_path}")
        except Exception as err:
            # One failing ticker must not abort the rest of the batch.
            st.error(f"Error downloading data for {symbol}: {err}")
    return saved
| # ----------------------------- | |
| # Load Ticker Data Function | |
| # ----------------------------- | |
def load_ticker_data(data_dir, ticker):
    """
    Load the cached CSV for ``ticker`` from ``data_dir``.

    The first CSV column becomes the index and is parsed as dates. It is
    selected by position, not by name: yfinance labels the index ``Date`` for
    daily bars but ``Datetime`` for intraday intervals. A hard-coded
    ``index_col='Date'`` therefore made intraday data unloadable.

    Parameters:
        data_dir (str): Directory holding ``<ticker>.csv`` files.
        ticker (str): Ticker symbol; also the CSV file stem.

    Returns:
        pandas.DataFrame or None: The loaded data, or None when the file is
        missing or cannot be parsed (the reason is shown in the Streamlit UI).
    """
    file_path = os.path.join(data_dir, f"{ticker}.csv")
    if not os.path.exists(file_path):
        st.warning(f"Data file for {ticker} not found in {data_dir}.")
        return None
    try:
        return pd.read_csv(file_path, index_col=0, parse_dates=True)
    except Exception as e:
        st.error(f"Error loading {ticker} data: {e}")
        return None
| # ----------------------------- | |
| # Main Dashboard Function | |
| # ----------------------------- | |
def main():
    """
    Render the Streamlit market data dashboard.

    Steps:
      1. Load ``market_config.yml`` and ``project_config.yml`` from the
         current directory; show an error and stop if either is unusable.
      2. Let the user pick ticker groups and tickers in the sidebar. Ticker
         picks persist across reruns in ``st.session_state.selected_tickers``.
      3. Download a CSV for every selected ticker not already cached in the
         project data directory.
      4. Show a "Technical Analysis" tab (Close prices plus optional
         indicators), a "Predictions & Risk" tab backed by ``MarketPredictor``,
         and an expander with the combined Close data.

    With no tickers selected, a getting-started guide is shown instead.
    """
    st.set_page_config(layout="wide")
    st.title("Market Data Simulation and Prediction Dashboard")

    configs = _load_configs()
    if configs is None:
        return
    market_config, project_config = configs

    # Directory holding one <ticker>.csv per symbol.
    output_dir = (project_config.get('project') or {}).get('directory', 'data')
    os.makedirs(output_dir, exist_ok=True)

    selected_tickers = _select_tickers(market_config)
    if not selected_tickers:
        _render_welcome()
        return

    _download_missing(selected_tickers, market_config, output_dir)

    combined_data = _build_combined_close(selected_tickers, output_dir)
    if combined_data is None:
        return

    analysis_tab, prediction_tab = st.tabs(["Technical Analysis", "Predictions & Risk"])
    with analysis_tab:
        _render_technical_tab(combined_data, output_dir)
    with prediction_tab:
        _render_prediction_tab(combined_data, output_dir)

    # Optional: display the raw data in an expander.
    with st.expander("View Raw Data"):
        st.dataframe(combined_data)


def _load_configs():
    """Return ``(market_config, project_config)``, or None after reporting an error."""
    market_config_path = os.path.join(".", "market_config.yml")
    project_config_path = os.path.join(".", "project_config.yml")
    try:
        market_config = load_yaml(market_config_path)
        project_config = load_yaml(project_config_path)
    except Exception:
        # load_yaml has already shown the specific failure in the UI.
        st.error("Could not load configuration files.")
        return None
    if market_config is None or project_config is None:
        st.error("Configuration files are invalid. Please check market_config.yml and project_config.yml")
        return None
    return market_config, project_config


def _select_tickers(market_config):
    """Render the sidebar group/ticker pickers and return the selected tickers, sorted."""
    groups = market_config.get('groups') or {}
    available_groups = list(groups.keys())
    with st.sidebar:
        st.header("Controls")
        # st.multiselect rejects defaults that are not among its options, so
        # keep only the preferred groups that exist, and cope with no groups.
        if 'crypto' in available_groups:
            default_groups = [g for g in ('crypto', 'commodities') if g in available_groups]
        else:
            default_groups = available_groups[:1]
        selected_groups = st.multiselect("Select Groups", available_groups, default=default_groups)

        if 'selected_tickers' not in st.session_state:
            st.session_state.selected_tickers = set()

        selected_tickers = []
        if selected_groups:
            for group, tab in zip(selected_groups, st.tabs(selected_groups)):
                with tab:
                    group_tickers = groups[group] or []
                    group_selections = st.multiselect(
                        f"Select {group} tickers",
                        sorted(group_tickers),
                        default=list(st.session_state.selected_tickers.intersection(group_tickers)),
                    )
                    # Replace this group's previous picks with the current ones.
                    st.session_state.selected_tickers.difference_update(group_tickers)
                    st.session_state.selected_tickers.update(group_selections)
            # Sorted so chart trace order (and colours) do not depend on
            # arbitrary set iteration order.
            selected_tickers = sorted(st.session_state.selected_tickers)

        if selected_tickers:
            st.write("Currently selected:")
            st.write(", ".join(selected_tickers))
    return selected_tickers


def _download_missing(selected_tickers, market_config, output_dir):
    """Download CSVs for selected tickers that are not yet cached in ``output_dir``."""
    market = market_config.get('market') or {}
    days = market.get('data_history', 500)
    period_str = f"{days}d"
    interval = market.get('data_fractal', "1d")

    # splitext keeps dotted tickers such as "BRK.B" intact; split('.')[0]
    # would map "BRK.B.csv" to "BRK" and re-download it on every rerun.
    cached = {os.path.splitext(name)[0] for name in os.listdir(output_dir) if name.endswith('.csv')}
    tickers_to_download = [t for t in selected_tickers if t not in cached]
    if not tickers_to_download:
        return

    download_status = st.empty()
    progress_bar = st.progress(0)
    total = len(tickers_to_download)
    for idx, ticker in enumerate(tickers_to_download, 1):
        download_status.text(f"Downloading {ticker}...")
        try:
            download_data([ticker], period_str, interval, output_dir)
        except Exception as e:
            st.warning(f"Failed to download {ticker}: {str(e)}")
        # Advance even on failure so the bar still reaches 100%.
        progress_bar.progress(idx / total)
    download_status.empty()
    progress_bar.empty()


def _load_utc_frame(output_dir, ticker):
    """Load a ticker CSV via load_ticker_data with its index converted to UTC timestamps.

    Returns None when load_ticker_data returns None (it reports the reason).
    """
    data = load_ticker_data(output_dir, ticker)
    if data is None:
        return None
    if 'Date' in data.columns:
        data['Date'] = pd.to_datetime(data['Date'], utc=True)
        data = data.set_index('Date')
    else:
        data.index = pd.to_datetime(data.index, utc=True)
    return data


def _build_combined_close(tickers, output_dir):
    """Outer-join every ticker's Close column (renamed to the ticker) on a UTC index.

    Returns None when no ticker produced usable data.
    """
    combined_data = None
    for ticker in tickers:
        data = _load_utc_frame(output_dir, ticker)
        if data is None:
            continue
        if 'Close' not in data.columns:
            st.warning(f"No 'Close' column in data for {ticker}. Skipping.")
            continue
        col_data = data[['Close']].rename(columns={'Close': ticker})
        if combined_data is None:
            combined_data = col_data
        else:
            combined_data = combined_data.join(col_data, how='outer')
    return combined_data


def _render_technical_tab(combined_data, output_dir):
    """Draw Close prices for every ticker plus whichever indicators are ticked."""
    st.write("Technical Analysis Tools:")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        show_bb = st.checkbox('Bollinger Bands')
        show_rsi = st.checkbox('RSI')
    with col2:
        show_sma = st.checkbox('SMA')
        sma_period = st.number_input('SMA Period', min_value=1, value=50, max_value=200) if show_sma else 50
    with col3:
        show_ema = st.checkbox('EMA')
        ema_period = st.number_input('EMA Period', min_value=1, value=20, max_value=200) if show_ema else 20
    with col4:
        show_vwap = st.checkbox('VWAP')
        show_macd = st.checkbox('MACD')

    fig = go.Figure()
    rsi_traces = []
    for ticker in combined_data.columns:
        # The outer join leaves NaN wherever this ticker has no bar (e.g.
        # weekends for equities charted next to crypto). A NaN inside a
        # rolling window makes the indicator NaN, so use this ticker's own rows.
        close = combined_data[ticker].dropna()
        if close.empty:
            continue
        fig.add_trace(go.Scatter(x=close.index, y=close, name=ticker, mode='lines'))

        if show_bb:
            bb = BollingerBands(close=close, window=20, window_dev=2)
            bands = (
                ('BB Upper', bb.bollinger_hband()),
                ('BB Lower', bb.bollinger_lband()),
                ('BB MA', bb.bollinger_mavg()),
            )
            for label, band in bands:
                fig.add_trace(go.Scatter(x=close.index, y=band, name=f'{ticker} {label}',
                                         line=dict(dash='dash'), opacity=0.7))
        if show_sma:
            sma = SMAIndicator(close=close, window=sma_period)
            fig.add_trace(go.Scatter(x=close.index, y=sma.sma_indicator(),
                                     name=f'{ticker} SMA{sma_period}', line=dict(dash='dot')))
        if show_ema:
            ema = EMAIndicator(close=close, window=ema_period)
            fig.add_trace(go.Scatter(x=close.index, y=ema.ema_indicator(),
                                     name=f'{ticker} EMA{ema_period}', line=dict(dash='dot')))
        if show_vwap:
            # VWAP needs High, Low, Close and Volume, so reload the full CSV.
            # Plot it against its own (UTC) index so the x values line up.
            ohlcv = _load_utc_frame(output_dir, ticker)
            if ohlcv is not None and {'High', 'Low', 'Close', 'Volume'}.issubset(ohlcv.columns):
                vwap = VolumeWeightedAveragePrice(
                    high=ohlcv['High'],
                    low=ohlcv['Low'],
                    close=ohlcv['Close'],
                    volume=ohlcv['Volume'],
                    window=14
                )
                fig.add_trace(go.Scatter(x=ohlcv.index, y=vwap.volume_weighted_average_price(),
                                         name=f'{ticker} VWAP', line=dict(dash='dashdot')))
        if show_macd:
            macd = MACD(close=close)
            fig.add_trace(go.Scatter(x=close.index, y=macd.macd(),
                                     name=f'{ticker} MACD', line=dict(dash='dot')))
            fig.add_trace(go.Scatter(x=close.index, y=macd.macd_signal(),
                                     name=f'{ticker} Signal', line=dict(dash='dot')))
        if show_rsi:
            rsi = RSIIndicator(close=close, window=14)
            rsi_traces.append(go.Scatter(x=close.index, y=rsi.rsi(), name=f'{ticker} RSI', mode='lines'))

    fig.update_layout(
        height=800,
        showlegend=True,
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
        margin=dict(l=40, r=40, t=40, b=40),
        hovermode='x unified'
    )
    st.plotly_chart(fig, use_container_width=True)

    if rsi_traces:
        # RSI lives on a 0-100 scale, so it gets its own panel rather than
        # sharing the price axis.
        rsi_fig = go.Figure(data=rsi_traces)
        rsi_fig.add_hline(y=70, line_dash='dash', line_color='red')
        rsi_fig.add_hline(y=30, line_dash='dash', line_color='green')
        rsi_fig.update_layout(
            height=300,
            title="RSI (14)",
            yaxis=dict(range=[0, 100]),
            showlegend=True,
            legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
            margin=dict(l=40, r=40, t=40, b=40),
            hovermode='x unified'
        )
        st.plotly_chart(rsi_fig, use_container_width=True)


def _render_prediction_tab(combined_data, output_dir):
    """Draw model forecasts and show risk/pattern output from MarketPredictor."""
    st.write("Prediction & Risk Analysis Tools:")
    pred_col1, pred_col2 = st.columns(2)
    with pred_col1:
        prediction_days = st.slider("Prediction Days", 5, 60, 30)
        show_rf = st.checkbox("Random Forest Prediction")
        show_exp = st.checkbox("Exponential Smoothing")
        show_monte_carlo = st.checkbox("Monte Carlo Simulation")
    with pred_col2:
        show_var = st.checkbox("Value at Risk")
        show_patterns = st.checkbox("Pattern Detection")
        show_breakouts = st.checkbox("Breakout Prediction")

    if not any([show_rf, show_exp, show_monte_carlo, show_var, show_patterns, show_breakouts]):
        return

    predictor = MarketPredictor()
    pred_fig = go.Figure()
    for ticker in combined_data.columns:
        # Drop the outer-join NaN padding so the models see only this ticker's
        # real observations, and forecasts start after its own last bar.
        series = combined_data[ticker].dropna()
        if series.empty:
            continue
        pred_fig.add_trace(go.Scatter(
            x=series.index,
            y=series,
            name=f"{ticker} (Actual)",
            mode='lines'
        ))
        # prediction_days consecutive calendar days after the last observation.
        future_dates = pd.date_range(start=series.index[-1], periods=prediction_days + 1)[1:]

        if show_rf:
            predictor.train_rf(series)
            rf_pred = predictor.predict_rf(series, days_ahead=prediction_days)
            pred_fig.add_trace(go.Scatter(
                x=future_dates,
                y=rf_pred.flatten(),
                name=f"{ticker} (Random Forest)",
                line=dict(dash='dash')
            ))
        if show_exp:
            predictor.train_exp(series)
            exp_pred = predictor.predict_exp(days_ahead=prediction_days)
            pred_fig.add_trace(go.Scatter(
                x=future_dates,
                y=exp_pred,
                name=f"{ticker} (Exp Smoothing)",
                line=dict(dash='dot')
            ))
        if show_monte_carlo:
            mc_sims = predictor.monte_carlo_simulation(series, days_ahead=prediction_days)
            # Per-row quantiles across columns, plotted against future_dates:
            # assumes mc_sims is (days x simulated paths) — confirm in MarketPredictor.
            upper = mc_sims.quantile(0.95, axis=1)
            lower = mc_sims.quantile(0.05, axis=1)
            pred_fig.add_trace(go.Scatter(
                x=future_dates,
                y=upper,
                fill=None,
                mode='lines',
                line_color='rgba(0,100,80,0.2)',
                name=f"{ticker} MC (95% Upper)"
            ))
            pred_fig.add_trace(go.Scatter(
                x=future_dates,
                y=lower,
                fill='tonexty',
                mode='lines',
                line_color='rgba(0,100,80,0.2)',
                name=f"{ticker} MC (95% Lower)"
            ))
        if show_var:
            var = predictor.calculate_var(series)
            st.write(f"Value at Risk (95%) for {ticker}: {var:.2%}")
        if show_patterns or show_breakouts:
            # Pattern/breakout detection works on the full OHLC frame.
            ticker_data = load_ticker_data(output_dir, ticker)
            if ticker_data is None:
                # load_ticker_data has already reported why.
                continue
            if show_patterns:
                patterns = predictor.detect_patterns(ticker_data)
                st.write(f"\nDetected patterns for {ticker}:")
                for pattern, signals in patterns.items():
                    if signals.any():
                        st.write(f"- {pattern}: {signals.sum()} occurrences")
            if show_breakouts:
                breakouts = predictor.predict_breakouts(ticker_data)
                high_prob = breakouts[breakouts['breakout_probability'] > 0.7]
                if not high_prob.empty:
                    st.write(f"\nPotential breakout points for {ticker}:")
                    st.write(high_prob)

    pred_fig.update_layout(
        height=600,
        title="Price Predictions",
        showlegend=True,
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
        margin=dict(l=40, r=40, t=40, b=40),
        hovermode='x unified'
    )
    st.plotly_chart(pred_fig, use_container_width=True)


def _render_welcome():
    """Show the getting-started guide displayed while no tickers are selected."""
    st.markdown("### Market Analysis Dashboard")
    st.markdown("""
        Compare stocks, crypto, commodities and more with advanced technical indicators.

        #### Getting Started

        1. **Select Asset Groups**
           - Choose from Stocks, Crypto, Commodities, ETFs, and Bonds
           - Mix multiple groups to compare different asset classes
        2. **Pick Your Tickers**
           - Select assets from each group's tab
           - Compare any combination (e.g., Bitcoin vs Gold vs S&P 500)
           - Data downloads automatically
        3. **Add Indicators**
           - Bollinger Bands for volatility analysis
           - SMA/EMA with custom periods
           - MACD for trend detection
           - RSI for momentum
           - VWAP for price action

        #### Chart Features

        - Zoom: Mouse wheel or toolbar
        - Pan: Click and drag
        - Data: Toggle assets in legend
        - Details: Hover for exact values
        """)
| # ----------------------------- | |
| # Entry Point | |
| # ----------------------------- | |
if __name__ == "__main__":
    # `streamlit run main.py` executes this file as __main__, so the
    # dashboard is rendered on every script run.
    main()