RNN Next-Token Classifier

Binary classifier that predicts whether a candidate token is the true next token in a sequence, using frozen Pythia-160M hidden states as features.

Architecture

  • Feature extractor: EleutherAI/pythia-160m (frozen)
  • Classifier: 2-layer GRU → linear head
  • Input: the concatenation of the mean-pooled context embedding and the candidate token's embedding
  • Output: probability that the candidate is the true next token
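The card does not include the RNNClassifier source, and the Usage snippet expects you to supply it. Below is a minimal, hypothetical sketch that follows the Architecture list: a 2-layer GRU whose final top-layer state feeds a linear head, with a sigmoid producing the probability. It makes two assumptions. First, the concatenated feature vector is fed to the GRU as a length-1 sequence. Second, the submodule names `gru` and `head` are invented here, so `load_state_dict` will only accept `rnn_classifier.pt` if they happen to match the author's class.

```python
import torch
import torch.nn as nn


class RNNClassifier(nn.Module):
    """Hypothetical reconstruction: 2-layer GRU followed by a linear head.

    Argument order mirrors the Usage snippet (input_dim, hidden_dim, num_layers).
    Submodule names (gru, head) are assumptions and may not match the checkpoint.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int = 2):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept either (batch, input_dim) or (batch, seq_len, input_dim).
        if x.dim() == 2:
            x = x.unsqueeze(1)  # treat each feature vector as a length-1 sequence
        _, h_n = self.gru(x)  # h_n: (num_layers, batch, hidden_dim)
        logits = self.head(h_n[-1]).squeeze(-1)  # top layer's final state -> (batch,)
        return torch.sigmoid(logits)  # probability that the candidate is the true next token
```

If you train with this sketch, returning raw logits and using `nn.BCEWithLogitsLoss` is numerically safer than applying the sigmoid inside `forward`.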

Training

  • Dataset: jordiclive/wikipedia-summary-dataset (1000 samples)
  • Evaluation: 100 held-out samples (indices 1100–1199)
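The card does not say how candidate tokens were chosen during training. One common way to turn next-token prediction into binary classification is to pair each context with its true next token (label 1) and with a uniformly sampled different token (label 0). This scheme is assumed here, not taken from the training code. The helper `make_pairs` and its `context_len` window are hypothetical.

```python
import random
from typing import List, Tuple


def make_pairs(
    token_ids: List[int], vocab_size: int, context_len: int, rng: random.Random
) -> List[Tuple[List[int], int, int]]:
    """Build (context, candidate, label) triples from one tokenized sequence.

    For every position t >= 1, the true token at t is a positive example and a
    uniformly sampled token other than it is a negative example.
    """
    pairs = []
    for t in range(1, len(token_ids)):
        context = token_ids[max(0, t - context_len):t]  # up to context_len preceding tokens
        true_next = token_ids[t]
        pairs.append((context, true_next, 1))
        # Sample from vocab_size - 1 ids, then shift past the true token so the
        # negative is uniform over every id except true_next.
        negative = rng.randrange(vocab_size - 1)
        if negative >= true_next:
            negative += 1
        pairs.append((context, negative, 0))
    return pairs
```

Balancing one negative per positive keeps the label distribution at 50/50. Harder negatives, such as the base model's own top-k predictions, are another option the card neither confirms nor rules out.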

Usage

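import json

import torch

# rnn_classifier.py is not shipped with the checkpoint: copy the RNNClassifier class definition into it first
from rnn_classifier import RNNClassifier

# config.json holds the constructor arguments used at training time
with open("config.json") as f:
    cfg = json.load(f)

model = RNNClassifier(cfg["input_dim"], cfg["hidden_dim"], cfg["num_layers"])
# The checkpoint is a plain state_dict, so weights_only=True is sufficient and safer
model.load_state_dict(torch.load("rnn_classifier.pt", map_location="cpu", weights_only=True))
model.eval()  # switch to inference mode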
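Scoring a candidate also requires the input features. The sketch below builds them as the Architecture section describes: hidden states are mean-pooled over the non-padding context positions and concatenated with the candidate token's embedding. The function name `build_features`, the use of an attention mask, and the choice of hidden layer and embedding are assumptions. The card does not specify them.

```python
import torch


def build_features(
    hidden_states: torch.Tensor, attention_mask: torch.Tensor, candidate_emb: torch.Tensor
) -> torch.Tensor:
    """Concatenate a mean-pooled context embedding with a candidate token embedding.

    hidden_states:  (batch, seq_len, d) hidden states from the frozen feature extractor
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    candidate_emb:  (batch, d) embedding of the candidate token
    returns:        (batch, 2 * d)
    """
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(dim=1)  # padded positions contribute zero
    counts = mask.sum(dim=1).clamp(min=1.0)  # avoid division by zero on empty rows
    context_emb = summed / counts  # mean over real context tokens only
    return torch.cat([context_emb, candidate_emb], dim=-1)
```

With Hugging Face `transformers`, `hidden_states` could come from the `last_hidden_state` output of `AutoModel.from_pretrained("EleutherAI/pythia-160m")`, and `candidate_emb` from that model's `get_input_embeddings()` applied to the candidate token IDs. Pythia-160M's hidden size is 768, so under these assumptions the features have 1536 dimensions, which should match `cfg["input_dim"]`. Passing them to the loaded classifier then gives the score.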