import pandas as pd
from datasets import load_dataset
from pathlib import Path
import numpy as np
import faiss
import bm25s

from src.fireworks.inference import create_client
from src.config import EMBEDDING_MODEL

_FILE_PATH = Path(__file__).parents[2]


def load_amazon_raw_product_data() -> pd.DataFrame:
    ds = load_dataset("ckandemir/amazon-products")
    df = ds["train"].to_pandas()
    return df


def load_clean_amazon_product_data() -> pd.DataFrame:
    return pd.read_parquet(_FILE_PATH / "data" / "amazon_products.parquet")


def prepare_amazon_product_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Data preparation for Amazon products.

    Args:
        df: DataFrame with 'Product Name', 'Category', 'Description' columns

    Returns:
        Cleaned DataFrame with derived category columns and a combined
        'FullText' column, restricted to the top 5 main categories
    """
    # FullText combines Product Name, Category, and Description
    df.loc[:, "FullText"] = (
        df["Product Name"] + " | " + df["Category"] + " | " + df["Description"]
    )
    df.loc[:, "FullText"] = df.FullText.str.lower().str.strip().str.replace("\n", " ")

    # Split the pipe-delimited Category string into up to three levels
    df[["MainCategory", "SecondaryCategory", "TertiaryCategory"]] = df[
        "Category"
    ].str.split(r" \| ", n=2, expand=True, regex=True)
    df = df.dropna(subset=["MainCategory", "SecondaryCategory"])

    # Drop duplicates
    df = df.drop_duplicates(subset=["FullText"])

    # Downsample 'Toys & Games' to 650 rows, since in the raw data it makes up
    # over 70% of all rows
    df_non_toys = df[df["MainCategory"] != "Toys & Games"]
    df_toys = df[df["MainCategory"] == "Toys & Games"]
    df_toys = df_toys.sample(n=650, random_state=42)
    df = pd.concat([df_non_toys, df_toys])

    # Keep only the top 5 MainCategories
    df = df[df["MainCategory"].isin(df["MainCategory"].value_counts().index[:5])]

    print(
        f"Prepared dataset with {len(df)} products.\n"
        f"Count of MainCategories:\n{df['MainCategory'].value_counts()}"
    )
    return df.loc[
        :,
        [
            "Product Name",
            "Description",
            "MainCategory",
            "SecondaryCategory",
            "TertiaryCategory",
            "FullText",
        ],
    ]


def save_as_parquet(df: pd.DataFrame):
    """
    Save DataFrame to parquet file.
    """
    df.to_parquet(_FILE_PATH / "data" / "amazon_products.parquet", index=False)
    print(f"Saved to {_FILE_PATH / 'data' / 'amazon_products.parquet'}")


def create_faiss_index(df: pd.DataFrame, batch_size: int = 100):
    """
    Create FAISS index from product data using Fireworks AI embeddings.

    Args:
        df: DataFrame with 'FullText' column to embed
        batch_size: Number of texts to embed in each API call

    Returns:
        Tuple of (faiss_index, embeddings_array)
    """
    client = create_client()

    print(f"Generating embeddings for {len(df)} products...")
    all_embeddings = []
    texts = df["FullText"].tolist()

    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        print(
            f"Processing batch {i // batch_size + 1}/"
            f"{(len(texts) + batch_size - 1) // batch_size}"
        )
        response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
        batch_embeddings = [item.embedding for item in response.data]
        all_embeddings.extend(batch_embeddings)

    embeddings_array = np.array(all_embeddings, dtype=np.float32)
    dimension = embeddings_array.shape[1]

    # Over unit-normalized vectors, L2 distance is a monotonic transform of
    # cosine similarity, so this flat L2 index ranks by cosine similarity.
    index = faiss.IndexFlatL2(dimension)
    faiss.normalize_L2(embeddings_array)
    index.add(embeddings_array)

    print(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")

    faiss.write_index(index, str(_FILE_PATH / "data" / "faiss_index.bin"))
    np.save(_FILE_PATH / "data" / "embeddings.npy", embeddings_array)
    print(f"Saved FAISS index to {_FILE_PATH / 'data' / 'faiss_index.bin'}")
    print(f"Saved embeddings to {_FILE_PATH / 'data' / 'embeddings.npy'}")

    return index, embeddings_array


def load_faiss_index():
    """
    Load pre-computed FAISS index and embeddings from disk.

    Returns:
        Tuple of (faiss_index, embeddings_array)
    """
    index = faiss.read_index(str(_FILE_PATH / "data" / "faiss_index.bin"))
    embeddings = np.load(_FILE_PATH / "data" / "embeddings.npy")
    print(f"Loaded FAISS index with {index.ntotal} vectors")
    return index, embeddings


def create_bm25_index(df: pd.DataFrame):
    """
    Create BM25 index from product data for lexical search.

    Args:
        df: DataFrame with 'FullText' column to index

    Returns:
        BM25 index object
    """
    print(f"Creating BM25 index for {len(df)} products...")
    corpus = df["FullText"].tolist()
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en")

    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    retriever.save(_FILE_PATH / "data" / "bm25_index")
    print(f"Saved BM25 index to {_FILE_PATH / 'data' / 'bm25_index'}")
    return retriever


def load_bm25_index():
    """
    Load pre-computed BM25 index from disk.

    Returns:
        BM25 index object
    """
    retriever = bm25s.BM25.load(_FILE_PATH / "data" / "bm25_index", load_corpus=False)
    print("Loaded BM25 index")
    return retriever
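# The two helpers below are a minimal sketch of how the saved indexes can be
# queried; they are not part of the build pipeline above. search_semantic()
# assumes the client returned by create_client() exposes the same
# embeddings.create(model=..., input=...) call used in create_faiss_index().
# Both expect `df` to be the same frame (same row order) the indexes were
# built from, since they look rows up by position.


def search_semantic(query: str, df: pd.DataFrame, index, k: int = 5) -> pd.DataFrame:
    """Embed `query` and return the k nearest products from the FAISS index."""
    client = create_client()
    response = client.embeddings.create(model=EMBEDDING_MODEL, input=[query])
    query_vec = np.array([response.data[0].embedding], dtype=np.float32)
    # Normalize the query exactly as the corpus was normalized, so the L2
    # search over unit vectors ranks by cosine similarity.
    faiss.normalize_L2(query_vec)
    _distances, indices = index.search(query_vec, k)
    return df.iloc[indices[0]]


def search_lexical(query: str, df: pd.DataFrame, retriever, k: int = 5) -> pd.DataFrame:
    """Tokenize `query` and return the k best BM25 matches."""
    query_tokens = bm25s.tokenize(query, stopwords="en")
    # retrieve() returns (doc_ids, scores) arrays of shape (n_queries, k);
    # with no corpus attached, doc_ids are positions in the indexed corpus.
    doc_ids, _scores = retriever.retrieve(query_tokens, k=k)
    return df.iloc[doc_ids[0]]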
if __name__ == "__main__":
    _df = load_amazon_raw_product_data()
    _df = prepare_amazon_product_data(_df)
    save_as_parquet(_df)
    create_bm25_index(_df)
    create_faiss_index(_df)
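# Running this module as a script rebuilds amazon_products.parquet plus the
# BM25 and FAISS artifacts under data/; downstream code can then call
# load_faiss_index() / load_bm25_index() without re-embedding anything.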