FinanceAuger / main.py
therickglenn's picture
Create main.py
70fc7f7 verified
raw
history blame
19.5 kB
#!/usr/bin/env python3
"""
Integrated Market Prediction Tool
This script does two things:
1. Downloads historical market data from yfinance based on YAML configuration files.
2. Provides an interactive dashboard (via Streamlit) to view the downloaded data.
Usage:
streamlit run main.py
"""
import os
import yaml
import yfinance as yf
import pandas as pd
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
from ta.volatility import BollingerBands
from ta.trend import SMAIndicator, EMAIndicator, MACD
from ta.momentum import RSIIndicator
from ta.volume import VolumeWeightedAveragePrice
from models.prediction_models import MarketPredictor
# -----------------------------
# Configuration Loader Function
# -----------------------------
def load_yaml(filepath):
    """
    Load a YAML file and return the parsed configuration.

    Parameters:
    filepath (str): Path of the YAML file to read.

    Returns:
    The object produced by ``yaml.safe_load`` (``None`` for an empty file).

    Raises:
    Any exception raised while opening or parsing the file. The error is
    shown in the Streamlit UI first and then re-raised unchanged so the
    caller can decide how to proceed.
    """
    try:
        # Read as UTF-8 explicitly: without it the locale's default encoding
        # is used, which can mis-decode non-ASCII configuration values.
        with open(filepath, 'r', encoding='utf-8') as file:
            return yaml.safe_load(file)
    except Exception as e:
        st.error(f"Error loading YAML file {filepath}: {e}")
        raise
# -----------------------------
# Data Download Function
# -----------------------------
def download_data(tickers, period, interval, output_dir):
    """
    Fetch price history for every ticker via yfinance and write one CSV each.

    Parameters:
    tickers (list): Ticker symbols to fetch.
    period (str): History length handed to yfinance (e.g., '500d').
    interval (str): Sampling interval handed to yfinance (e.g., '1d').
    output_dir (str): Target directory for the CSV files; created if absent.

    Returns:
    list: The tickers whose data was written to disk. Tickers that returned
    no rows, or whose download raised, are reported in the UI and left out.
    """
    os.makedirs(output_dir, exist_ok=True)
    saved_tickers = []
    for symbol in tickers:
        st.info(f"Downloading data for {symbol}...")
        try:
            history = yf.Ticker(symbol).history(period=period, interval=interval)
            if not history.empty:
                csv_path = os.path.join(output_dir, f"{symbol}.csv")
                history.to_csv(csv_path)
                saved_tickers.append(symbol)
                st.success(f"Saved data for {symbol} to {csv_path}")
            else:
                st.warning(f"No data returned for {symbol}. Skipping.")
        except Exception as exc:
            # Best effort: one failing ticker must not abort the remaining ones.
            st.error(f"Error downloading data for {symbol}: {exc}")
    return saved_tickers
# -----------------------------
# Load Ticker Data Function
# -----------------------------
@st.cache_data(show_spinner=False)
def _read_ticker_csv(file_path, mtime):
    """
    Read one ticker CSV into a DataFrame indexed by its date column.

    ``mtime`` is not used in the body; it only takes part in the cache key so
    that a file rewritten on disk is read again instead of served stale.
    It must NOT be renamed with a leading underscore, because Streamlit
    excludes underscore-prefixed parameters from the cache key.
    """
    try:
        return pd.read_csv(file_path, parse_dates=True, index_col='Date')
    except ValueError:
        # No usable 'Date' column (NOTE(review): intraday yfinance downloads
        # presumably label the index 'Datetime' - confirm). The CSVs written by
        # download_data always store the index first, so fall back to column 0.
        return pd.read_csv(file_path, parse_dates=True, index_col=0)


def load_ticker_data(data_dir, ticker):
    """
    Load CSV data for a given ticker.

    Parameters:
    data_dir (str): Directory containing the per-ticker CSV files.
    ticker (str): Ticker symbol; the file read is ``<data_dir>/<ticker>.csv``.

    Returns:
    pandas.DataFrame indexed by the file's date column, or ``None`` when the
    file does not exist or cannot be parsed (a warning/error is shown in the
    UI in those cases).

    Only successful reads are cached. Failures are re-evaluated on every call
    so a file that is downloaded later is picked up immediately.
    """
    file_path = os.path.join(data_dir, f"{ticker}.csv")
    if not os.path.exists(file_path):
        st.warning(f"Data file for {ticker} not found in {data_dir}.")
        return None
    try:
        return _read_ticker_csv(file_path, os.path.getmtime(file_path))
    except Exception as e:
        st.error(f"Error loading {ticker} data: {e}")
        return None
# -----------------------------
# Main Dashboard Function
# -----------------------------
# Landing-page copy shown while no ticker is selected.
_LANDING_MARKDOWN = """
Compare stocks, crypto, commodities and more with advanced technical indicators.
#### Getting Started 🚀
1. **Select Asset Groups** 👈
- Choose from Stocks, Crypto, Commodities, ETFs, and Bonds
- Mix multiple groups to compare different asset classes
2. **Pick Your Tickers** 📈
- Select assets from each group's tab
- Compare any combination (e.g., Bitcoin vs Gold vs S&P 500)
- Data downloads automatically
3. **Add Indicators** ⚙️
- Bollinger Bands for volatility analysis
- SMA/EMA with custom periods
- MACD for trend detection
- RSI for momentum
- VWAP for price action
#### Chart Features 🔍
- Zoom: Mouse wheel or toolbar
- Pan: Click and drag
- Data: Toggle assets in legend
- Details: Hover for exact values
"""

# Legend/margin settings shared by both charts.
_CHART_LEGEND = dict(yanchor="top", y=0.99, xanchor="left", x=0.01)
_CHART_MARGIN = dict(l=40, r=40, t=40, b=40)


def _default_groups(available_groups):
    """Return the groups to pre-select, limited to groups that really exist."""
    if 'crypto' in available_groups:
        # Only offer defaults that are valid options; a default missing from
        # the option list makes st.multiselect raise.
        return [g for g in ('crypto', 'commodities') if g in available_groups]
    # Empty slice (instead of [0]) keeps an empty configuration from crashing.
    return list(available_groups[:1])


def _render_sidebar(market_config):
    """Draw the sidebar controls and return the sorted list of selected tickers."""
    groups = market_config.get('groups') or {}
    with st.sidebar:
        st.header("Controls")
        available_groups = list(groups)
        selected_groups = st.multiselect(
            "Select Groups",
            available_groups,
            default=_default_groups(available_groups),
        )
        # The selection lives in session state so it survives Streamlit reruns.
        if 'selected_tickers' not in st.session_state:
            st.session_state.selected_tickers = set()
        selected_tickers = []
        if selected_groups:
            tabs = st.tabs(selected_groups)
            for group, tab in zip(selected_groups, tabs):
                with tab:
                    group_tickers = groups[group]
                    group_selections = st.multiselect(
                        f"Select {group} tickers",
                        sorted(group_tickers),
                        default=sorted(
                            st.session_state.selected_tickers.intersection(group_tickers)
                        ),
                    )
                    # Replace this group's previous picks with the current ones.
                    st.session_state.selected_tickers.difference_update(group_tickers)
                    st.session_state.selected_tickers.update(group_selections)
            # Sorted so trace order in the charts is stable between reruns.
            selected_tickers = sorted(st.session_state.selected_tickers)
        if selected_tickers:
            st.write("Currently selected:")
            st.write(", ".join(selected_tickers))
    return selected_tickers


def _download_missing(tickers, period, interval, output_dir):
    """Download data for every ticker that has no CSV in output_dir yet."""
    # splitext keeps dotted symbols (e.g. 'BRK.B') intact; split('.')[0] would
    # truncate them and trigger a re-download on every rerun.
    existing = {
        os.path.splitext(name)[0]
        for name in os.listdir(output_dir)
        if name.endswith('.csv')
    }
    missing = [t for t in tickers if t not in existing]
    if not missing:
        return
    download_status = st.empty()
    progress_bar = st.progress(0)
    for idx, ticker in enumerate(missing, 1):
        download_status.text(f"Downloading {ticker}...")
        try:
            download_data([ticker], period, interval, output_dir)
        except Exception as e:
            st.warning(f"Failed to download {ticker}: {str(e)}")
        # Advance even after a failure so the bar always reaches 100%.
        progress_bar.progress(idx / len(missing))
    download_status.empty()
    progress_bar.empty()


def _load_utc_frame(output_dir, ticker):
    """Load a ticker's CSV and return it indexed by UTC timestamps (None if unavailable)."""
    data = load_ticker_data(output_dir, ticker)
    if data is None:
        return None
    data = data.copy()  # never mutate an object the loader may hand out again
    if 'Date' in data.columns:
        data['Date'] = pd.to_datetime(data['Date'], utc=True)
        data = data.set_index('Date')
    else:
        data.index = pd.to_datetime(data.index, utc=True)
    return data


def _combine_closes(output_dir, tickers):
    """Outer-join the 'Close' column of every ticker into one frame (None if nothing loaded)."""
    combined = None
    for ticker in tickers:
        data = _load_utc_frame(output_dir, ticker)
        if data is None or 'Close' not in data.columns:
            continue
        closes = data[['Close']].rename(columns={'Close': ticker})
        combined = closes if combined is None else combined.join(closes, how='outer')
    return combined


def _future_dates(last_timestamp, periods):
    """Return `periods` timestamps following last_timestamp (pandas default daily frequency)."""
    return pd.date_range(start=last_timestamp, periods=periods + 1)[1:]


def _render_analysis_tab(combined_data, output_dir):
    """Draw the technical-analysis controls and the price/indicator chart."""
    st.write("Technical Analysis Tools:")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        show_bb = st.checkbox('Bollinger Bands')
        show_rsi = st.checkbox('RSI')
    with col2:
        show_sma = st.checkbox('SMA')
        sma_period = st.number_input('SMA Period', min_value=1, value=50, max_value=200) if show_sma else 50
    with col3:
        show_ema = st.checkbox('EMA')
        ema_period = st.number_input('EMA Period', min_value=1, value=20, max_value=200) if show_ema else 20
    with col4:
        show_vwap = st.checkbox('VWAP')
        show_macd = st.checkbox('MACD')

    fig = go.Figure()
    for ticker in combined_data.columns:
        # The outer join leaves NaN wherever another asset has a timestamp this
        # one lacks; rolling indicators over those gaps would be mostly NaN.
        close = combined_data[ticker].dropna()
        if close.empty:
            continue
        fig.add_trace(go.Scatter(x=close.index, y=close, name=ticker, mode='lines'))
        if show_bb:
            bb = BollingerBands(close=close, window=20, window_dev=2)
            fig.add_trace(go.Scatter(x=close.index, y=bb.bollinger_hband(), name=f'{ticker} BB Upper',
                                     line=dict(dash='dash'), opacity=0.7))
            fig.add_trace(go.Scatter(x=close.index, y=bb.bollinger_lband(), name=f'{ticker} BB Lower',
                                     line=dict(dash='dash'), opacity=0.7))
            fig.add_trace(go.Scatter(x=close.index, y=bb.bollinger_mavg(), name=f'{ticker} BB MA',
                                     line=dict(dash='dash'), opacity=0.7))
        if show_rsi:
            # Drawn on a secondary axis (configured below) because RSI is on a
            # 0-100 scale unrelated to price.
            rsi = RSIIndicator(close=close, window=14)
            fig.add_trace(go.Scatter(x=close.index, y=rsi.rsi(), name=f'{ticker} RSI',
                                     line=dict(dash='dot'), yaxis='y2'))
        if show_sma:
            sma = SMAIndicator(close=close, window=sma_period)
            fig.add_trace(go.Scatter(x=close.index, y=sma.sma_indicator(),
                                     name=f'{ticker} SMA{sma_period}', line=dict(dash='dot')))
        if show_ema:
            ema = EMAIndicator(close=close, window=ema_period)
            fig.add_trace(go.Scatter(x=close.index, y=ema.ema_indicator(),
                                     name=f'{ticker} EMA{ema_period}', line=dict(dash='dot')))
        if show_vwap:
            # VWAP needs high, low, close and volume, which the combined frame
            # does not carry, so reload the ticker's full data.
            ohlcv = _load_utc_frame(output_dir, ticker)
            needed = ('High', 'Low', 'Close', 'Volume')
            if ohlcv is not None and all(col in ohlcv.columns for col in needed):
                vwap = VolumeWeightedAveragePrice(
                    high=ohlcv['High'],
                    low=ohlcv['Low'],
                    close=ohlcv['Close'],
                    volume=ohlcv['Volume'],
                    window=14
                )
                # Plot against the ticker's own index; the combined index can
                # have a different length.
                fig.add_trace(go.Scatter(x=ohlcv.index, y=vwap.volume_weighted_average_price(),
                                         name=f'{ticker} VWAP', line=dict(dash='dashdot')))
        if show_macd:
            macd = MACD(close=close)
            fig.add_trace(go.Scatter(x=close.index, y=macd.macd(),
                                     name=f'{ticker} MACD', line=dict(dash='dot')))
            fig.add_trace(go.Scatter(x=close.index, y=macd.macd_signal(),
                                     name=f'{ticker} Signal', line=dict(dash='dot')))

    layout = dict(
        height=800,
        showlegend=True,
        legend=_CHART_LEGEND,
        margin=_CHART_MARGIN,
        hovermode='x unified'
    )
    if show_rsi:
        layout['yaxis2'] = dict(title='RSI', overlaying='y', side='right',
                                range=[0, 100], showgrid=False)
    fig.update_layout(**layout)
    st.plotly_chart(fig, use_container_width=True)


def _render_prediction_tab(combined_data, output_dir):
    """Draw the prediction/risk controls and, if any tool is enabled, its results."""
    st.write("Prediction & Risk Analysis Tools:")
    pred_col1, pred_col2 = st.columns(2)
    with pred_col1:
        prediction_days = st.slider("Prediction Days", 5, 60, 30)
        show_rf = st.checkbox("Random Forest Prediction")
        show_exp = st.checkbox("Exponential Smoothing")
        show_monte_carlo = st.checkbox("Monte Carlo Simulation")
    with pred_col2:
        show_var = st.checkbox("Value at Risk")
        show_patterns = st.checkbox("Pattern Detection")
        show_breakouts = st.checkbox("Breakout Prediction")
    if not any((show_rf, show_exp, show_monte_carlo, show_var, show_patterns, show_breakouts)):
        return

    predictor = MarketPredictor()
    pred_fig = go.Figure()
    for ticker in combined_data.columns:
        # Drop join-induced NaN before handing the series to the models.
        close = combined_data[ticker].dropna()
        if close.empty:
            continue
        pred_fig.add_trace(go.Scatter(
            x=close.index,
            y=close,
            name=f"{ticker} (Actual)",
            mode='lines'
        ))
        future_dates = _future_dates(close.index[-1], prediction_days)
        if show_rf:
            predictor.train_rf(close)
            rf_pred = predictor.predict_rf(close, days_ahead=prediction_days)
            pred_fig.add_trace(go.Scatter(
                x=future_dates,
                y=rf_pred.flatten(),
                name=f"{ticker} (Random Forest)",
                line=dict(dash='dash')
            ))
        if show_exp:
            predictor.train_exp(close)
            exp_pred = predictor.predict_exp(days_ahead=prediction_days)
            pred_fig.add_trace(go.Scatter(
                x=future_dates,
                y=exp_pred,
                name=f"{ticker} (Exp Smoothing)",
                line=dict(dash='dot')
            ))
        if show_monte_carlo:
            mc_sims = predictor.monte_carlo_simulation(close, days_ahead=prediction_days)
            # Shaded band between the 5% and 95% quantiles of the simulations.
            upper = mc_sims.quantile(0.95, axis=1)
            lower = mc_sims.quantile(0.05, axis=1)
            pred_fig.add_trace(go.Scatter(
                x=future_dates,
                y=upper,
                fill=None,
                mode='lines',
                line_color='rgba(0,100,80,0.2)',
                name=f"{ticker} MC (95% Upper)"
            ))
            pred_fig.add_trace(go.Scatter(
                x=future_dates,
                y=lower,
                fill='tonexty',  # fills up to the previously added (upper) trace
                mode='lines',
                line_color='rgba(0,100,80,0.2)',
                name=f"{ticker} MC (95% Lower)"
            ))
        if show_var:
            var = predictor.calculate_var(close)
            st.write(f"Value at Risk (95%) for {ticker}: {var:.2%}")
        if show_patterns or show_breakouts:
            # Pattern/breakout detection works on the ticker's full data.
            ticker_data = load_ticker_data(output_dir, ticker)
            if ticker_data is not None:
                if show_patterns:
                    patterns = predictor.detect_patterns(ticker_data)
                    st.write(f"\nDetected patterns for {ticker}:")
                    for pattern, signals in patterns.items():
                        if signals.any():
                            st.write(f"- {pattern}: {signals.sum()} occurrences")
                if show_breakouts:
                    breakouts = predictor.predict_breakouts(ticker_data)
                    high_prob = breakouts[breakouts['breakout_probability'] > 0.7]
                    if not high_prob.empty:
                        st.write(f"\nPotential breakout points for {ticker}:")
                        st.write(high_prob)

    pred_fig.update_layout(
        height=600,
        title="Price Predictions",
        showlegend=True,
        legend=_CHART_LEGEND,
        margin=_CHART_MARGIN,
        hovermode='x unified'
    )
    st.plotly_chart(pred_fig, use_container_width=True)


def main():
    """
    Run the Streamlit dashboard.

    Loads ``market_config.yml`` and ``project_config.yml`` from the current
    directory, lets the user pick asset groups and tickers in the sidebar,
    downloads data for tickers that have no CSV yet, and renders the
    technical-analysis and prediction tabs for the selected tickers. With no
    ticker selected, an introductory page is shown instead.
    """
    st.set_page_config(layout="wide")
    st.title("Market Data Simulation and Prediction Dashboard")

    # Define configuration file paths and load configurations
    market_config_path = os.path.join(".", "market_config.yml")
    project_config_path = os.path.join(".", "project_config.yml")
    try:
        market_config = load_yaml(market_config_path)
        project_config = load_yaml(project_config_path)
    except Exception:
        # load_yaml has already displayed the detailed error.
        st.error("Could not load configuration files.")
        return
    if market_config is None or project_config is None:
        st.error("Configuration files are invalid. Please check market_config.yml and project_config.yml")
        return

    # Setup output directory
    output_dir = project_config.get('project', {}).get('directory', 'data')
    os.makedirs(output_dir, exist_ok=True)

    selected_tickers = _render_sidebar(market_config)
    if not selected_tickers:
        st.markdown("### Market Analysis Dashboard 📊")
        st.markdown(_LANDING_MARKDOWN)
        return

    # Prepare download parameters
    market_settings = market_config.get('market', {})
    period_str = f"{market_settings.get('data_history', 500)}d"
    interval = market_settings.get('data_fractal', "1d")
    _download_missing(selected_tickers, period_str, interval, output_dir)

    combined_data = _combine_closes(output_dir, selected_tickers)
    if combined_data is None:
        st.warning("No price data available for the selected tickers.")
        return

    analysis_tab, prediction_tab = st.tabs(["Technical Analysis", "Predictions & Risk"])
    with analysis_tab:
        _render_analysis_tab(combined_data, output_dir)
    with prediction_tab:
        _render_prediction_tab(combined_data, output_dir)

    # Optional: Display the raw data in an expander
    with st.expander("View Raw Data"):
        st.dataframe(combined_data)
# -----------------------------
# Entry Point
# -----------------------------
# Start the dashboard only when this file is executed as a script
# (e.g. via `streamlit run main.py`), not when it is imported.
if __name__ == "__main__":
    main()