classify-ecomm
a custom transformer built from scratch to classify e-commerce products into categories. no hugging face wrappers, just raw pytorch and math.
what is this?
this is a text classification model that takes a product name (e.g., "heavy-duty nylon gym bag") and predicts its category (e.g., "sports").
it uses a standard transformer architecture (self-attention, feed-forward networks, residual connections) trained on a synthetic dataset generated by llama 3. we use word-level tokenization because character-level models tend to struggle with the long, adjective-heavy phrasing of product names.
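for a sense of the math, the heart of the self-attention in model.py is scaled dot-product attention. here is a minimal standalone sketch of that equation (illustrative, not the repo's exact code):

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # softmax(q @ k^T / sqrt(d_k)) @ v -- the core self-attention equation
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v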
requirements
you need python and pytorch. if you want to generate your own data, you need ollama.
pip install torch tqdm ollama
if you have an nvidia gpu, pytorch should pick it up automatically. if you are on mac, it tries to use mps. otherwise, it falls back to cpu, which will be slow.
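the device pick follows the usual pytorch pattern, roughly like this (a sketch, not necessarily the exact code in the scripts):

import torch

# prefer cuda, then apple's mps, then fall back to cpu
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")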
quick start
1. get the data
you can generate the dataset yourself. it uses ollama (llama3.2:3b) to hallucinate 200k product names.
python generate_dataset_expanded.py
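the script drives a local model through ollama's python client. a minimal sketch of that kind of loop is below; the categories, prompt, and output handling here are made up, the real ones live in generate_dataset_expanded.py.

import ollama

# assumed example categories and prompt, for illustration only
categories = ["sports", "electronics", "home", "toys"]
for category in categories:
    response = ollama.chat(
        model="llama3.2:3b",
        messages=[{
            "role": "user",
            "content": f"list 20 realistic product names in the '{category}' category, one per line",
        }],
    )
    print(response["message"]["content"])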
2. train the model
this trains the transformer. it saves the model, config, and vocabulary to classifier.pth.
python train.py
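classifier.pth is a single torch checkpoint bundling weights, config, and vocabulary. a self-contained sketch of that pattern (the key names and hyperparameters here are assumptions; the real ones are whatever train.py writes):

import torch
import torch.nn as nn

# toy stand-ins so the sketch runs on its own; the real model and vocab
# come from model.py and dataset.py
model = nn.Linear(8, 4)
vocab = {"<pad>": 0, "<unk>": 1, "gym": 2, "bag": 3}

torch.save({
    "model_state_dict": model.state_dict(),
    "config": {"d_model": 128, "num_layers": 2, "num_heads": 4},  # assumed hyperparameters
    "vocab": vocab,
}, "classifier.pth")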
3. run inference
once you have the .pth file, you can test it.
interactive mode:
python cli.py
single prediction:
python cli.py "vintage wooden chess set"
4. evaluate
check how well the model actually works. evaluate.py tests on the training distribution. evaluate_holdout.py tests on completely unseen product names and phrasings to check for overfitting.
python evaluate.py
python evaluate_holdout.py
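both scripts presumably boil down to an accuracy number over batches of labeled examples; a generic sketch of that loop (not their actual code) looks like this:

import torch

@torch.no_grad()
def accuracy(model, batches):
    # batches yields (tokens, labels): tokens (batch, seq_len) ids, labels (batch,) ints
    correct = total = 0
    for tokens, labels in batches:
        predictions = model(tokens).argmax(dim=-1)
        correct += (predictions == labels).sum().item()
        total += labels.numel()
    return correct / max(total, 1)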
file structure
model.py: the actual neural network. multi-head attention, layer norms, the works.
train.py: the training loop. handles batching and backprop.
dataset.py: handles word tokenization and vocabulary building.
predict.py: inference logic. loads the checkpoint and reconstructs the model.
cli.py: simple interface to run predictions from your terminal.
generate_*.py: scripts to create the synthetic datasets using local llms.
license
licensed under the apache license, version 2.0.