Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| import joblib | |
| from huggingface_hub import hf_hub_download | |
| from datetime import datetime | |
# --- 1. Download and Load Models ---
print("Downloading models from Hugging Face Hub...")
REPO_ID = "matanzig/flight-price-prediction"


def _fetch_model(filename):
    """Download ``filename`` from the ``REPO_ID`` Hub repository and load it with joblib."""
    # NOTE(review): joblib.load unpickles the file, which can execute arbitrary
    # code. Only point REPO_ID at a repository you control and trust.
    local_path = hf_hub_download(repo_id=REPO_ID, filename=filename)
    return joblib.load(local_path)


# Regressor: produces the predicted fare. Classifier: produces the price tier.
reg_model = _fetch_model("flight_price_rf_model.pkl")
cls_model = _fetch_model("flight_price_classifier_rf.pkl")
# Feature names, in the exact order that predict_flight_price fills its row.
# pd.DataFrame(..., columns=COLUMNS) requires the same length (27) as that row.
COLUMNS = [
    # Route and fare flags
    "startingAirport",
    "destinationAirport",
    "isBasicEconomy",
    "isRefundable",
    "isNonStop",
    "seatsRemaining",
    "totalTravelDistance",
    # Timing
    "month",
    "days_until_flight",
    "travelDuration_mins",
    # Carrier and cabin
    "primary_airline",
    "primary_cabin",
    "departure_hour",
    # One-hot day of week (no Friday column)
    "day_of_week_Monday",
    "day_of_week_Saturday",
    "day_of_week_Sunday",
    "day_of_week_Thursday",
    "day_of_week_Tuesday",
    "day_of_week_Wednesday",
    "is_weekend",
    "days_until_flight_squared",
    # NOTE(review): cluster_group_1..3 appear twice. They are kept verbatim
    # because the models were presumably fit on a frame with this duplication;
    # verify against the models' feature_names_in_ before changing.
    "cluster_group_1",
    "cluster_group_2",
    "cluster_group_3",
    "cluster_group_1",
    "cluster_group_2",
    "cluster_group_3",
]
# Airline display name -> integer code used for the `primary_airline` feature.
# Codes run consecutively from 0 over the names in alphabetical order
# (presumably a label encoding from training time; confirm before editing).
# Key order also drives the order of the Airline dropdown.
AIRLINE_MAPPING = {
    name: code
    for code, name in enumerate(
        (
            "Alaska Airlines",
            "American Airlines",
            "Boutique Air",
            "Cape Air",
            "Contour Airlines",
            "Delta",
            "Frontier Airlines",
            "Hawaiian Airlines",
            "JetBlue Airways",
            "Key Lime Air",
            "Southern Airways Express",
            "Spirit Airlines",
            "Sun Country Airlines",
            "United",
        )
    )
}
# --- 2. Prediction Engine ---
# Dollar band around the model's estimate inside which a seen price counts as
# "fair". Outside the band the fare is flagged as a deal or as overpriced.
DEAL_MARGIN_USD = 25

# Classifier output -> user-facing tier label.
TIER_LABELS = {0: "Budget Expected 🟢", 1: "Standard Expected 🟡", 2: "Premium Expected 🔴"}


def _parse_flight_date(flight_date):
    """Parse a ``YYYY-MM-DD`` string into a pandas Timestamp.

    Leading and trailing whitespace is ignored.

    Raises:
        gr.Error: if the text is empty, malformed, or an out-of-range date.
            The message is shown in the Gradio UI.
    """
    text = "" if flight_date is None else str(flight_date).strip()
    try:
        dt = pd.to_datetime(text, format="%Y-%m-%d")
    except ValueError:
        # Malformed and out-of-bounds dates both raise ValueError subclasses.
        dt = pd.NaT
    # An empty string parses to NaT instead of raising. Without this check,
    # NaN month/day features would silently reach the models.
    if pd.isna(dt):
        raise gr.Error("❌ Invalid Date Format! Please use exactly YYYY-MM-DD (e.g., 2026-07-15).")
    return dt


def _deal_analysis(seen_price, predicted_price):
    """Return advice text comparing the user's fare with the predicted fare."""
    # gr.Number yields None when the field is cleared; treat that as "not entered".
    if seen_price is None or not seen_price > 0:
        return "Enter a 'Seen Price' above to get an instant deal analysis."
    diff = seen_price - predicted_price
    if diff < -DEAL_MARGIN_USD:
        return f"🔥 Amazing Deal! This ticket is ${abs(diff):.0f} CHEAPER than the AI average. Book it now!"
    if diff > DEAL_MARGIN_USD:
        return f"⚠️ Overpriced! This ticket is ${diff:.0f} MORE EXPENSIVE than it should be. Wait or find another flight."
    return "⚖️ Fair Market Price. The price you found perfectly matches our algorithm's baseline."


def predict_flight_price(flight_date, distance, duration, days_until, seats, airline_name, seen_price, is_nonstop, is_basic_economy):
    """Predict a fair fare, a price tier, and a deal verdict for one flight.

    Args:
        flight_date: Departure date as ``YYYY-MM-DD`` text.
        distance: Travel distance (entered in miles in the UI).
        duration: Travel duration (entered in minutes in the UI).
        days_until: Number of days between booking and departure.
        seats: Seats remaining on the flight.
        airline_name: Key of ``AIRLINE_MAPPING``. Unknown names fall back to code 5.
        seen_price: Fare the user found. ``None`` or a value <= 0 means "not provided".
        is_nonstop: Whether the flight is direct.
        is_basic_economy: Whether the fare is Basic Economy.

    Returns:
        A tuple of (predicted price formatted as ``"$X.XX"``, tier label,
        deal analysis text).

    Raises:
        gr.Error: if ``flight_date`` is not a valid ``YYYY-MM-DD`` date.
    """
    dt = _parse_flight_date(flight_date)
    day_name = dt.day_name()
    is_weekend = 1 if day_name in ("Saturday", "Sunday") else 0

    # Hard-coded standardisation. NOTE(review): these centres and scales
    # presumably mirror the training-time scaler; confirm.
    scaled_month = (dt.month - 7.0) / 2.0
    scaled_days = (days_until - 30) / 15.0
    scaled_days_sq = scaled_days ** 2

    airline_id = AIRLINE_MAPPING.get(airline_name, 5)

    # Values must line up one-to-one with COLUMNS (27 entries).
    row_data = [
        4, 5,  # startingAirport, destinationAirport: fixed codes
        int(is_basic_economy), 0, int(is_nonstop), int(seats), float(distance),  # isRefundable fixed to 0
        float(scaled_month), float(scaled_days), float(duration), int(airline_id),
        1, -0.32,  # primary_cabin, departure_hour: fixed values (-0.32 looks pre-scaled; TODO confirm)
        day_name == "Monday", day_name == "Saturday", day_name == "Sunday",
        day_name == "Thursday", day_name == "Tuesday", day_name == "Wednesday",
        is_weekend, float(scaled_days_sq),
        # All six cluster_group_* columns are switched off.
        False, False, False, False, False, False,
    ]
    input_df = pd.DataFrame([row_data], columns=COLUMNS)

    reg_prediction = reg_model.predict(input_df)[0]
    cls_prediction = cls_model.predict(input_df)[0]
    expected_tier = TIER_LABELS.get(cls_prediction, "Unknown")

    return f"${reg_prediction:.2f}", expected_tier, _deal_analysis(seen_price, reg_prediction)
# --- 3. Gradio Interface (UI) ---
# Markdown shown under the app title. Blank lines separate the paragraphs so
# Markdown renders them individually instead of running them together.
DESCRIPTION = """
**Welcome to the US Flight Price Predictor AI!**

Enter your flight details below to get an AI-powered baseline fare and deal analysis.

⚠️ **Model Limitations:** This algorithm was trained exclusively on Expedia data spanning from **April to October 2022**. It is highly optimized for summer and early fall travel dynamics, but does *not* account for major winter holiday price surges (e.g., Thanksgiving, Christmas) or extreme macroeconomic inflation events beyond that window.
"""
# Gradio passes the input components to predict_flight_price positionally,
# so their order must match that function's parameter order.
interface = gr.Interface(
    fn=predict_flight_price,
    inputs=[
        gr.Textbox(value="2026-07-15", label="Flight Date (YYYY-MM-DD)", placeholder="e.g., 2026-08-01"),
        gr.Slider(minimum=100, maximum=3000, step=50, value=1300, label="Travel Distance (Miles)"),
        gr.Slider(minimum=60, maximum=800, step=10, value=400, label="Travel Duration (Minutes)"),
        gr.Slider(minimum=0, maximum=90, step=1, value=45, label="How many days in advance are you booking?"),
        gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Seats Remaining"),
        gr.Dropdown(choices=list(AIRLINE_MAPPING.keys()), value="Delta", label="Airline"),
        # 0 (or an empty field) means "no price entered"; the deal analysis is skipped.
        gr.Number(value=0, label="Price you found online ($) - Optional"),
        gr.Checkbox(label="Is Non-Stop (Direct Flight)?"),
        gr.Checkbox(label="Is Basic Economy?"),
    ],
    # One output per element of predict_flight_price's returned tuple.
    outputs=[
        gr.Textbox(label="🤖 AI Predicted Fair Price"),
        gr.Textbox(label="📊 Expected Market Tier (Classifier)"),
        gr.Textbox(label="💡 Personal Deal Analysis"),
    ],
    title="✈️ Smart Flight Price Predictor",
    description=DESCRIPTION,
    theme="default",
)
def main():
    """Start the Gradio server for ``interface``."""
    interface.launch()


if __name__ == "__main__":
    main()