# Author: matanzig
# Commit 49e6cce ("Update app.py", verified)
import gradio as gr
import pandas as pd
import joblib
from huggingface_hub import hf_hub_download
from datetime import datetime
# --- 1. Download and Load Models ---
print("Downloading models from Hugging Face Hub...")
REPO_ID = "matanzig/flight-price-prediction"


def _fetch_model(filename):
    """Download ``filename`` from the ``REPO_ID`` repo and load it with joblib.

    SECURITY NOTE(review): ``joblib.load`` unpickles the file, which can run
    arbitrary code. Only point ``REPO_ID`` at a repository you trust; consider
    pinning a ``revision=`` in ``hf_hub_download`` for reproducibility.
    """
    local_path = hf_hub_download(repo_id=REPO_ID, filename=filename)
    return joblib.load(local_path)


# The regressor yields the dollar estimate; the classifier yields the price tier.
reg_model = _fetch_model("flight_price_rf_model.pkl")
cls_model = _fetch_model("flight_price_classifier_rf.pkl")
# Feature order shared by both models; the prediction row is filled
# positionally in exactly this order (27 entries).
# NOTE(review): the three ``cluster_group_*`` dummies appear twice. The row
# built for prediction supplies 27 values to match, so this presumably mirrors
# a duplicated block in the training frame -- do not de-duplicate without
# confirming against the saved models.
COLUMNS = (
    [
        "startingAirport", "destinationAirport", "isBasicEconomy", "isRefundable",
        "isNonStop", "seatsRemaining", "totalTravelDistance", "month",
        "days_until_flight", "travelDuration_mins", "primary_airline",
        "primary_cabin", "departure_hour",
    ]
    # One-hot weekday dummies; Friday has no column of its own.
    + [
        f"day_of_week_{day}"
        for day in ("Monday", "Saturday", "Sunday", "Thursday", "Tuesday", "Wednesday")
    ]
    + ["is_weekend", "days_until_flight_squared"]
    + [f"cluster_group_{k}" for k in (1, 2, 3)] * 2
)
# Airline display name -> integer code used for the ``primary_airline``
# feature. Codes are the positions (0..13) in this alphabetical list; the
# dict's insertion order is also the order of the UI dropdown choices.
# NOTE(review): presumably matches the label encoding used at training time.
AIRLINE_MAPPING = {
    airline: code
    for code, airline in enumerate(
        [
            "Alaska Airlines",
            "American Airlines",
            "Boutique Air",
            "Cape Air",
            "Contour Airlines",
            "Delta",
            "Frontier Airlines",
            "Hawaiian Airlines",
            "JetBlue Airways",
            "Key Lime Air",
            "Southern Airways Express",
            "Spirit Airlines",
            "Sun Country Airlines",
            "United",
        ]
    )
}
# --- 2. Prediction Engine ---
def _parse_flight_date(flight_date):
    """Return ``(month_number, weekday_name)`` for a ``YYYY-MM-DD`` date string.

    Raises:
        gr.Error: If the value is missing, blank, or not a valid ``YYYY-MM-DD`` date.
    """
    invalid_msg = "❌ Invalid Date Format! Please use exactly YYYY-MM-DD (e.g., 2026-07-15)."
    text = "" if flight_date is None else str(flight_date).strip()
    # Reject blanks up front: pandas maps "" to NaT instead of raising, which
    # would otherwise feed NaN month/weekday features into the models.
    if not text:
        raise gr.Error(invalid_msg)
    try:
        dt = pd.to_datetime(text, format="%Y-%m-%d")
    except ValueError:
        raise gr.Error(invalid_msg) from None
    if pd.isna(dt):
        raise gr.Error(invalid_msg)
    return dt.month, dt.day_name()


def _describe_deal(seen_price, predicted_price, margin=25.0):
    """Return a one-line verdict comparing ``seen_price`` with the model estimate.

    More than ``margin`` dollars below the estimate is a deal, more than
    ``margin`` above is overpriced, anything in between is fair. A missing
    (``None``) or non-positive ``seen_price`` yields a prompt to enter one.
    """
    # NOTE(review): a cleared gr.Number presumably arrives as None; guard it so
    # the comparison below cannot raise TypeError. `not x > 0` (rather than
    # `x <= 0`) also keeps NaN on the prompt branch.
    if seen_price is None or not seen_price > 0:
        return "Enter a 'Seen Price' above to get an instant deal analysis."
    diff = seen_price - predicted_price
    if diff < -margin:
        return f"🔥 Amazing Deal! This ticket is ${abs(diff):.0f} CHEAPER than the AI average. Book it now!"
    if diff > margin:
        return f"⚠️ Overpriced! This ticket is ${diff:.0f} MORE EXPENSIVE than it should be. Wait or find another flight."
    return "⚖️ Fair Market Price. The price you found perfectly matches our algorithm's baseline."


def predict_flight_price(flight_date, distance, duration, days_until, seats, airline_name, seen_price, is_nonstop, is_basic_economy):
    """Predict a fair fare, a price tier and a deal verdict for one flight.

    Builds a single-row DataFrame laid out like ``COLUMNS``, runs ``reg_model``
    for a dollar estimate and ``cls_model`` for a tier, then compares the
    estimate with the price the user found.

    Args:
        flight_date: Departure date as a ``YYYY-MM-DD`` string.
        distance: Total travel distance (the UI labels it in miles).
        duration: Total travel duration (the UI labels it in minutes).
        days_until: Days between booking and departure.
        seats: Seats remaining on the flight.
        airline_name: Key of ``AIRLINE_MAPPING``; unknown names fall back to code 5 ("Delta").
        seen_price: Price found online; ``None``, 0 or negative means "not provided".
        is_nonstop: Whether the flight is direct.
        is_basic_economy: Whether the fare is Basic Economy.

    Returns:
        A tuple of three display strings: the predicted price as ``$X.XX``,
        the expected tier label, and the deal-analysis message.

    Raises:
        gr.Error: If ``flight_date`` is missing or not a valid ``YYYY-MM-DD`` date.
    """
    month_val, day_name = _parse_flight_date(flight_date)
    is_weekend = 1 if day_name in ("Saturday", "Sunday") else 0

    # Hand-applied scaling. NOTE(review): the centres/scales (month 7 / 2.0,
    # days 30 / 15.0) presumably reproduce the training preprocessing -- confirm.
    scaled_month = (month_val - 7.0) / 2.0
    scaled_days = (days_until - 30) / 15.0
    scaled_days_sq = scaled_days ** 2
    airline_id = AIRLINE_MAPPING.get(airline_name, 5)

    # One value per entry of COLUMNS, in the same order (27 values).
    row_data = [
        # startingAirport, destinationAirport: fixed codes not exposed in the UI.
        # NOTE(review): presumably label-encoded airports -- confirm which ones.
        4, 5,
        # isBasicEconomy, isRefundable (always 0), isNonStop, seatsRemaining, totalTravelDistance
        int(is_basic_economy), 0, int(is_nonstop), int(seats), float(distance),
        # month, days_until_flight, travelDuration_mins, primary_airline
        float(scaled_month), float(scaled_days), float(duration), int(airline_id),
        # primary_cabin and departure_hour are held constant.
        # NOTE(review): -0.32 looks like an already-scaled hour value -- confirm.
        1, -0.32,
        # day_of_week_* dummies (a Friday leaves all six False).
        day_name == "Monday", day_name == "Saturday", day_name == "Sunday",
        day_name == "Thursday", day_name == "Tuesday", day_name == "Wednesday",
        is_weekend, float(scaled_days_sq),
        # cluster_group_1..3, which COLUMNS lists twice; never set here.
        False, False, False, False, False, False,
    ]
    input_df = pd.DataFrame([row_data], columns=COLUMNS)

    reg_prediction = reg_model.predict(input_df)[0]
    cls_prediction = cls_model.predict(input_df)[0]
    tier_mapping = {0: "Budget Expected 🟢", 1: "Standard Expected 🟡", 2: "Premium Expected 🔴"}
    expected_tier = tier_mapping.get(cls_prediction, "Unknown")

    deal_analysis = _describe_deal(seen_price, reg_prediction)
    return f"${reg_prediction:.2f}", expected_tier, deal_analysis
# --- 3. Gradio Interface (UI) ---
# Markdown shown under the title. Paragraphs are separated by blank lines:
# in Markdown a single newline does not start a new paragraph, so the
# limitations warning would otherwise run into the welcome text.
DESCRIPTION = """
**Welcome to the US Flight Price Predictor AI!**

Enter your flight details below to get an AI-powered baseline fare and deal analysis.

⚠️ **Model Limitations:** This algorithm was trained exclusively on Expedia data spanning from **April to October 2022**. It is highly optimized for summer and early fall travel dynamics, but does *not* account for major winter holiday price surges (e.g., Thanksgiving, Christmas) or extreme macroeconomic inflation events beyond that window.
"""

interface = gr.Interface(
    fn=predict_flight_price,
    # Listed in the same order as predict_flight_price's parameters.
    inputs=[
        gr.Textbox(value="2026-07-15", label="Flight Date (YYYY-MM-DD)", placeholder="e.g., 2026-08-01"),
        gr.Slider(minimum=100, maximum=3000, step=50, value=1300, label="Travel Distance (Miles)"),
        gr.Slider(minimum=60, maximum=800, step=10, value=400, label="Travel Duration (Minutes)"),
        gr.Slider(minimum=0, maximum=90, step=1, value=45, label="How many days in advance are you booking?"),
        gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Seats Remaining"),
        gr.Dropdown(choices=list(AIRLINE_MAPPING), value="Delta", label="Airline"),
        gr.Number(value=0, label="Price you found online ($) - Optional"),
        gr.Checkbox(label="Is Non-Stop (Direct Flight)?"),
        gr.Checkbox(label="Is Basic Economy?"),
    ],
    # Matches the (price, tier, deal analysis) tuple returned by predict_flight_price.
    outputs=[
        gr.Textbox(label="🤖 AI Predicted Fair Price"),
        gr.Textbox(label="📊 Expected Market Tier (Classifier)"),
        gr.Textbox(label="💡 Personal Deal Analysis"),
    ],
    title="✈️ Smart Flight Price Predictor",
    description=DESCRIPTION,
    theme="default",
)

# Start the web UI only when run as a script, not when the module is imported.
if __name__ == "__main__":
    interface.launch()