RNN Next-Token Classifier
Binary classifier that predicts whether a candidate token is the true next token in a sequence, using frozen Pythia-160M hidden states as features.
Architecture
- Feature extractor: EleutherAI/pythia-160m (frozen)
- Classifier: 2-layer GRU → linear head
- Input: concatenation of mean-pooled context embedding and candidate token embedding
- Output: probability that the candidate is the true next token
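The card does not include the `RNNClassifier` definition. Below is a minimal PyTorch sketch of a 2-layer GRU with a linear head that matches the constructor signature used in the usage snippet, `RNNClassifier(input_dim, hidden_dim, num_layers)`. It rests on several assumptions. Each example is a single feature vector fed to the GRU as a length-1 sequence. `forward` returns a raw logit, and the helper `predict_proba` (a hypothetical name) applies a sigmoid. `input_dim` is presumably 2 × 768 = 1536, since Pythia-160M has a hidden size of 768. Its parameter names may not match `rnn_classifier.pt`, so it is not guaranteed to load that checkpoint.

```python
import torch
import torch.nn as nn


class RNNClassifier(nn.Module):
    """Hypothetical sketch: 2-layer GRU -> linear head producing one logit per example.

    Assumption: each example is one feature vector (mean-pooled context embedding
    concatenated with the candidate embedding), treated as a length-1 sequence.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int = 2):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept (batch, input_dim) or (batch, seq_len, input_dim).
        if x.dim() == 2:
            x = x.unsqueeze(1)  # -> (batch, 1, input_dim)
        _, h_n = self.gru(x)  # h_n: (num_layers, batch, hidden_dim)
        logits = self.head(h_n[-1])  # final hidden state of the top GRU layer
        return logits.squeeze(-1)  # (batch,) raw logits

    @torch.no_grad()
    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: probability that the candidate is the true next token.
        return torch.sigmoid(self.forward(x))
```

Using the final hidden state of the top layer is a common way to get one vector per sequence from `nn.GRU`. With length-1 inputs the recurrence has only one step, so the GRU acts much like a gated feed-forward layer.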
Training
- Dataset: jordiclive/wikipedia-summary-dataset (1000 samples)
- Evaluation: 100 held-out samples (indices 1100–1199)
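The card does not describe how positive and negative examples were built for this binary task. One plausible scheme is sketched below in plain Python. It is an assumption, not the author's documented procedure. At each position the true next token is a positive (label 1), and one uniformly sampled different token is a negative (label 0). The function name `make_pairs` and the `min_context` parameter are hypothetical. `EVAL_INDICES` simply encodes the held-out range stated above, inclusive of 1199.

```python
import random
from typing import List, Tuple

# Held-out evaluation indices from the card: 1100-1199 inclusive.
EVAL_INDICES = range(1100, 1200)


def make_pairs(
    token_ids: List[int], vocab_size: int, rng: random.Random, min_context: int = 1
) -> List[Tuple[List[int], int, int]]:
    """Hypothetical negative-sampling scheme producing (context, candidate, label) triples."""
    pairs = []
    for t in range(min_context, len(token_ids)):
        context = token_ids[:t]
        true_next = token_ids[t]
        pairs.append((context, true_next, 1))  # positive: the actual next token
        # Sample uniformly from the vocabulary minus the true token:
        # draw from [0, vocab_size - 2], then shift values >= true_next up by one.
        negative = rng.randrange(vocab_size - 1)
        if negative >= true_next:
            negative += 1
        pairs.append((context, negative, 0))  # negative: a random wrong token
    return pairs
```

Sampling one negative per positive keeps the classes balanced. Harder negatives, such as tokens the language model itself ranks highly, would be a different design choice.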
Usage
import json
import torch
from rnn_classifier import RNNClassifier  # rnn_classifier.py must contain a copy of the class definition
with open("config.json") as f:
    cfg = json.load(f)
model = RNNClassifier(cfg["input_dim"], cfg["hidden_dim"], cfg["num_layers"])
model.load_state_dict(torch.load("rnn_classifier.pt", map_location="cpu"))
model.eval()
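To score a candidate, the loaded model needs the input described under Architecture: a mean-pooled context embedding concatenated with the candidate token's embedding. The sketch below builds that vector from precomputed tensors. The function name `build_features` is hypothetical. It assumes the padding mask should be excluded from the mean, and the card does not say which hidden layer or which candidate embedding the author used.

```python
import torch


def build_features(
    hidden_states: torch.Tensor,   # (batch, seq_len, d) frozen hidden states for the context
    attention_mask: torch.Tensor,  # (batch, seq_len), 1 for real tokens, 0 for padding
    candidate_emb: torch.Tensor,   # (batch, d) embedding of the candidate token
) -> torch.Tensor:
    """Hypothetical helper: masked mean-pool the context, then append the candidate embedding."""
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)                    # (batch, d)
    counts = mask.sum(dim=1).clamp(min=1.0)                       # avoid division by zero
    context = summed / counts                                     # mean over real tokens only
    return torch.cat([context, candidate_emb], dim=-1)            # (batch, 2 * d)
```

In practice, `hidden_states` could be the `last_hidden_state` returned by `transformers.AutoModel.from_pretrained("EleutherAI/pythia-160m")`, and `candidate_emb` could come from that model's `get_input_embeddings()` applied to the candidate token IDs. Both choices are assumptions. The resulting features are then passed to `model`. Whether its output is a logit or a probability depends on the actual class definition.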