classify-ecomm

a custom transformer built from scratch to classify e-commerce products into categories. no hugging face wrappers, just raw pytorch and math.

what is this?

this is a text classification model that takes a product name (e.g., "heavy-duty nylon gym bag") and predicts its category (e.g., "sports").

it uses a standard transformer architecture (self-attention, feed-forward networks, residual connections) trained on a synthetic dataset generated by llama 3.2. we use word-level tokenization because character-level models got confused by long adjectives: a word like "waterproof" is ten separate tokens to a character-level model but a single token at the word level.
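a minimal sketch of word-level tokenization and vocabulary building, the job dataset.py handles. the regex, the `<pad>`/`<unk>` special tokens, `max_len`, and the function names are all assumptions for illustration, not the repo's actual code.

```python
import re
from collections import Counter

# hypothetical special tokens; the real dataset.py may use different ones
PAD, UNK = "<pad>", "<unk>"


def tokenize(text):
    """lowercase and split into words, keeping hyphenated words like 'heavy-duty' whole."""
    return re.findall(r"[a-z0-9]+(?:-[a-z0-9]+)*", text.lower())


def build_vocab(texts, min_freq=1):
    """map every word seen at least min_freq times to an integer id."""
    counts = Counter(word for text in texts for word in tokenize(text))
    vocab = {PAD: 0, UNK: 1}
    for word, freq in counts.most_common():
        if freq >= min_freq:
            vocab[word] = len(vocab)
    return vocab


def encode(text, vocab, max_len=16):
    """turn text into a fixed-length list of ids: unknown words -> UNK, then truncate or pad."""
    ids = [vocab.get(word, vocab[UNK]) for word in tokenize(text)][:max_len]
    return ids + [vocab[PAD]] * (max_len - len(ids))


vocab = build_vocab(["heavy-duty nylon gym bag", "vintage wooden chess set"])
print(encode("nylon chess board", vocab, max_len=5))
```

each word costs one position in the sequence no matter how long it is, and words never seen in training ("board" above) fall back to the unknown token.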

requirements

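you need python and pytorch (plus tqdm for progress bars). if you want to generate your own data, you also need ollama installed and running locally with the llama3.2:3b model pulled (ollama pull llama3.2:3b). the ollama pip package is only the python client.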

pip install torch tqdm ollama

if you have an nvidia gpu, the scripts should pick it up automatically. on a mac, they try to use mps. otherwise, they fall back to cpu, which will be slow.
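the fallback order described above can be sketched like this. `pick_device` is a hypothetical helper name; the scripts may do the check inline.

```python
import torch


def pick_device():
    """prefer an nvidia gpu, then apple's mps backend, then fall back to cpu."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(f"using device: {device}")
```

the model and every batch then need `.to(device)` so they live on the same hardware.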

quick start

1. get the data

you can generate the dataset yourself. it uses ollama (llama3.2:3b) to hallucinate 200k product names.

python generate_dataset_expanded.py
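a minimal sketch of how names could be pulled from a local llm with the ollama python client. the category list, prompt wording, one-name-per-line reply format, and all function names are assumptions; generate_dataset_expanded.py may prompt and parse differently.

```python
import re

# hypothetical categories; the real script's category list may differ
CATEGORIES = ["sports", "electronics", "home", "toys", "clothing"]

# leading list markers an llm tends to add: "-", "*", "•", "1.", "2)"
LIST_MARKER = re.compile(r"^\s*(?:[-*•]|\d+[.)])\s*")


def build_prompt(category, n=20):
    """ask for bare product names, one per line, so the reply is easy to parse."""
    return (
        f"list {n} realistic e-commerce product names in the category '{category}'. "
        "one product name per line. no numbering, no extra text."
    )


def parse_names(reply):
    """clean the reply: strip list markers, drop blank lines, lowercase the rest."""
    names = []
    for line in reply.splitlines():
        name = LIST_MARKER.sub("", line).strip().lower()
        if name:
            names.append(name)
    return names


def generate_batch(category, n=20, model="llama3.2:3b"):
    """query a local ollama server; needs ollama running and the model pulled."""
    import ollama  # imported here so the helpers above work without the client installed

    response = ollama.chat(
        model=model,
        messages=[{"role": "user", "content": build_prompt(category, n)}],
    )
    return [(name, category) for name in parse_names(response["message"]["content"])]


def generate_dataset(per_category=20):
    """collect (name, category) pairs for every category, dropping exact duplicates."""
    pairs = []
    for category in CATEGORIES:
        pairs.extend(generate_batch(category, per_category))
    return list(dict.fromkeys(pairs))
```

small models often ignore "no numbering", so stripping markers like "1." in code is more reliable than trusting the prompt. note that the marker regex only removes digits followed by "." or ")", so a name like "4k action camera" survives intact.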

2. train the model

this trains the transformer and saves the model weights, config, and vocabulary together in a single file, classifier.pth.

python train.py
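a minimal sketch of the train-then-save pattern: a plain training loop, then one checkpoint holding weights, config, and vocabulary so inference can rebuild the model. `TinyClassifier` is a hypothetical stand-in (embeddings plus mean pooling, no attention) used only to keep the example short; the toy data, the checkpoint key names, and the extra `labels` entry are also assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn

# toy data (hypothetical): product names already encoded as word ids, padded with 0
vocab = {"<pad>": 0, "<unk>": 1, "gym": 2, "bag": 3, "chess": 4, "set": 5}
labels = ["sports", "toys"]
x = torch.tensor([[2, 3], [4, 5], [2, 0], [4, 0]])
y = torch.tensor([0, 1, 0, 1])
config = {"vocab_size": len(vocab), "d_model": 16, "num_classes": len(labels)}


class TinyClassifier(nn.Module):
    """stand-in for the transformer in model.py: embed, mean-pool real words, classify."""

    def __init__(self, vocab_size, d_model, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, ids):
        mask = (ids != 0).unsqueeze(-1).float()  # 1.0 for real words, 0.0 for padding
        pooled = (self.embed(ids) * mask).sum(1) / mask.sum(1).clamp(min=1)
        return self.head(pooled)


torch.manual_seed(0)
model = TinyClassifier(**config)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-2)
loss_fn = nn.CrossEntropyLoss()

losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)  # forward pass on the whole toy batch
    loss.backward()              # backprop
    optimizer.step()
    losses.append(loss.item())

# one file holds everything inference needs to rebuild the model
path = os.path.join(tempfile.gettempdir(), "classifier.pth")
torch.save({"model_state": model.state_dict(), "config": config,
            "vocab": vocab, "labels": labels}, path)

# reloading: rebuild the architecture from the stored config, then load the weights
checkpoint = torch.load(path)
restored = TinyClassifier(**checkpoint["config"])
restored.load_state_dict(checkpoint["model_state"])
restored.eval()
```

storing the config next to the weights is what lets a loader reconstruct the exact architecture without hard-coding layer sizes.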

3. run inference

once you have the .pth file, you can test it.

interactive mode:

python cli.py

single prediction:

python cli.py "vintage wooden chess set"
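a minimal sketch of what a single prediction involves: tokenize, look up ids, run the model without gradients, and take the softmax argmax as the category and its probability as the confidence. the whitespace tokenizer, the `<unk>` key, the label names, and the untrained `nn.EmbeddingBag` stand-in model are all hypothetical; the real model and vocabulary come from classifier.pth.

```python
import torch
import torch.nn as nn


def predict(model, vocab, labels, text):
    """return (category, confidence) for one product name."""
    # assumes the same word-level tokenization used in training (here: simple whitespace split)
    words = text.lower().split() or ["<unk>"]
    ids = torch.tensor([[vocab.get(w, vocab["<unk>"]) for w in words]])  # shape (1, num_words)
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(ids), dim=-1)[0]
    best = int(probs.argmax())
    return labels[best], float(probs[best])


# untrained stand-in so the sketch runs end to end
vocab = {"<pad>": 0, "<unk>": 1, "vintage": 2, "wooden": 3, "chess": 4, "set": 5}
labels = ["sports", "toys", "home"]
torch.manual_seed(0)
model = nn.Sequential(nn.EmbeddingBag(len(vocab), 8, mode="mean"), nn.Linear(8, len(labels)))

category, confidence = predict(model, vocab, labels, "vintage wooden chess set")
print(f"{category} ({confidence:.1%})")
```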

4. evaluate

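check how well the model actually works. evaluate.py measures accuracy on data from the same distribution as the training set. evaluate_holdout.py tests on product names whose structure the model never saw during training, which shows whether it has overfit to the training data.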

python evaluate.py
python evaluate_holdout.py
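a minimal sketch of the kind of report both scripts could produce: overall accuracy plus per-category accuracy, so weak categories stand out. the function name and the toy labels are hypothetical; the scripts may report different metrics.

```python
from collections import Counter


def accuracy_report(y_true, y_pred):
    """overall accuracy, plus the fraction of each true category that was predicted correctly."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty lists of equal length")
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    overall = sum(hits.values()) / len(y_true)
    per_category = {c: hits[c] / totals[c] for c in totals}
    return overall, per_category


# toy labels (hypothetical), e.g. from running the model over a held-out file
y_true = ["sports", "sports", "toys", "home"]
y_pred = ["sports", "toys", "toys", "home"]
overall, per_category = accuracy_report(y_true, y_pred)
print(overall)       # 0.75
print(per_category)  # {'sports': 0.5, 'toys': 1.0, 'home': 1.0}
```

running the same report on in-distribution and holdout data and comparing the two overall numbers is the overfitting check: high in-distribution accuracy with a large drop on the holdout set suggests the model memorized the training phrasings rather than learning the categories.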

file structure

  • model.py: the actual neural network. multi-head attention, layer norms, the works.
  • train.py: the training loop. handles batching and backprop.
  • dataset.py: handles word tokenization and vocabulary building.
  • predict.py: inference logic. loads the checkpoint and reconstructs the model.
  • cli.py: simple interface to run predictions from your terminal.
  • generate_*.py: scripts to create the synthetic datasets using local llms.
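a minimal sketch of a from-scratch transformer classifier in the spirit of model.py: hand-written multi-head self-attention, layer norms, feed-forward layers, and residual connections. the specific choices here are assumptions, not the repo's code: pre-norm blocks, gelu, learned positional embeddings, pad id 0, and mean pooling over non-padding tokens (the real model may use post-norm, sinusoidal positions, or a [cls] token instead).

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must split evenly across heads"
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)  # queries, keys, values in one projection
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, pad_mask):
        b, t, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, t, d) -> (b, heads, t, d_head)
        q, k, v = [z.reshape(b, t, self.n_heads, self.d_head).transpose(1, 2) for z in (q, k, v)]
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)  # (b, heads, t, t)
        # padding keys get the most negative float so softmax gives them ~0 weight
        scores = scores.masked_fill(pad_mask[:, None, None, :], torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)
        return self.out((attn @ v).transpose(1, 2).reshape(b, t, d))


class EncoderBlock(nn.Module):
    """pre-norm block: attention and feed-forward, each wrapped in a residual connection."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, n_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.drop = nn.Dropout(dropout)

    def forward(self, x, pad_mask):
        x = x + self.drop(self.attn(self.norm1(x), pad_mask))
        return x + self.drop(self.ff(self.norm2(x)))


class ProductClassifier(nn.Module):
    def __init__(self, vocab_size, num_classes, d_model=128, n_heads=4,
                 n_layers=2, d_ff=256, max_len=32, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        self.tok = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos = nn.Embedding(max_len, d_model)  # learned position embeddings
        self.blocks = nn.ModuleList([EncoderBlock(d_model, n_heads, d_ff) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, ids):
        pad_mask = ids == self.pad_id  # (b, t), True where padding
        positions = torch.arange(ids.size(1), device=ids.device)
        x = self.tok(ids) + self.pos(positions)
        for block in self.blocks:
            x = block(x, pad_mask)
        x = self.norm(x)
        keep = (~pad_mask).unsqueeze(-1).float()
        pooled = (x * keep).sum(1) / keep.sum(1).clamp(min=1)  # mean over real tokens only
        return self.head(pooled)  # one logit per category


torch.manual_seed(0)
model = ProductClassifier(vocab_size=100, num_classes=5).eval()
ids = torch.tensor([[5, 17, 42, 0, 0], [8, 9, 0, 0, 0]])
logits = model(ids)
print(logits.shape)  # torch.Size([2, 5])
```

because padding is masked out of both attention and pooling, extra padding at the end of a sequence should not change its prediction (with dropout off in eval mode).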

license

licensed under the apache license, version 2.0.
