Causal GPT-RL

GPT-style transformers (GPT-2, Llama) running as RL policies in continuous-control environments.

Both LLM generation and RL interaction are autoregressive:

token           → next token                           (LLM generation)
(state, action) → (next state from env, next action)   (RL rollout)

Causal GPT-RL policies remain stable when acting on their own rollouts. This enables long-horizon control without the drift that has historically kept transformers from being usable as RL agents.

A single autoregressive model drives full-episode rollouts through its KV cache, with no critic and no auxiliary networks at inference.
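
A minimal plain-Python sketch of this rollout loop follows. It rests on stated assumptions: `rollout`, its callable arguments, and the scalar states and actions are hypothetical and are not part of the causal_gpt_rl API. A `deque` with `maxlen=ctx` stands in for a KV cache capped at the context length. Once the deque is full, appending a new (state, action) pair drops the oldest one.

```python
from collections import deque
from typing import Callable, Deque, List, Sequence, Tuple

# Hypothetical scalar (state, action) pair; real policies use observation/action vectors.
Step = Tuple[float, float]


def rollout(
    reset: Callable[[], float],
    step: Callable[[float, float], Tuple[float, bool]],
    policy: Callable[[Sequence[Step], float], float],
    ctx: int,
    max_steps: int,
) -> List[Step]:
    """Autoregressive rollout: each action is predicted from the current state
    plus at most `ctx` previous (state, action) pairs (a stand-in for a capped KV cache)."""
    history: Deque[Step] = deque(maxlen=ctx)  # oldest pairs are evicted automatically
    trajectory: List[Step] = []
    state = reset()
    for _ in range(max_steps):
        action = policy(tuple(history), state)  # model predicts the next action
        history.append((state, action))
        trajectory.append((state, action))
        state, done = step(state, action)  # environment supplies the next state
        if done:
            break
    return trajectory


if __name__ == "__main__":
    # Toy 1-D system: the "policy" pushes the state halfway toward zero each step.
    traj = rollout(
        reset=lambda: 1.0,
        step=lambda s, a: (s + a, abs(s + a) < 1e-3),
        policy=lambda hist, s: -0.5 * s,
        ctx=32,
        max_steps=1000,
    )
    print(len(traj))  # prints 10
```

In the library, the policy is a transformer and the window is held in its KV cache rather than re-fed as a tuple. The control flow (predict an action, let the environment produce the next state, repeat) is the same autoregressive pattern described above.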

This repository is the public inference runtime. It loads policy bundles, runs Gymnasium/MuJoCo rollouts, and provides small evaluation helpers.

Install

For Hub loading and MuJoCo environments:

pip install "causal-gpt-rl[hub,mujoco]"

For local development:

git clone https://github.com/ccnets-team/causal-gpt-rl.git
cd causal-gpt-rl
python -m pip install -e ".[hub,mujoco]"

For private bundles, authenticate first:

hf auth login

Quick Start

import gymnasium as gym

from causal_gpt_rl.inference import load_runner_from_hub, run_episodes

env = gym.make("Ant-v5")
runner = load_runner_from_hub(
    repo_id="ccnets/causal-gpt-rl",
    subfolder="ant-v5",
)

stats = run_episodes(env, runner, num_episodes=5, seed=0)
env.close()
print(stats["return_mean"], stats["return_std"])

Notebook version: examples/hub_quickstart.ipynb

Supported Environments

Env             Bundle          Ctx  Return            Norm.         Medium Ref.
Ant-v5          ant-v5          32   3339.51±1115.40   50.56±16.54   86.54
HalfCheetah-v5  halfcheetah-v5  32   5989.04±1902.22   37.86±11.53   74.83
Hopper-v5       hopper-v5       32   2836.28±987.67    73.40±25.72   72.91
Walker2d-v5     walker2d-v5     32   3883.30±684.09    56.69±9.99    83.26
Humanoid-v5     humanoid-v5     32   6089.64±2512.73   70.41±29.58   81.30

Training is expert-free: bundles are trained only on the Minari simple and medium datasets, and no expert trajectories are used.

Return and Norm. are mean ± std over 50 episodes (seeds 0 to 49). Ctx is the context length. Each episode runs for at most max_steps=1000 steps, and the KV cache length is capped at Ctx.

Normalized scores map the random reference return to 0 and the expert reference return to 100:

100 * (return - random_ref) / (expert_ref - random_ref)

Medium reference scores are shown for context and are not the normalization baseline.
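
The normalization formula can be sketched directly in Python. `normalized_score` and `normalized_std` are hypothetical helpers, and the reference values below are illustrative placeholders, not the Minari reference scores used for the table. The sketch also shows why the Norm. std column needs no separate evaluation: because the map is affine, the random_ref offset cancels and the spread is only rescaled.

```python
import statistics


def normalized_score(ret: float, random_ref: float, expert_ref: float) -> float:
    """Affine map that sends random_ref to 0 and expert_ref to 100."""
    if expert_ref == random_ref:
        raise ValueError("expert_ref and random_ref must differ")
    return 100.0 * (ret - random_ref) / (expert_ref - random_ref)


def normalized_std(ret_std: float, random_ref: float, expert_ref: float) -> float:
    """Std of normalized scores: the offset cancels and only the scale factor remains."""
    return 100.0 * ret_std / abs(expert_ref - random_ref)


# Illustrative placeholder references (NOT the actual Minari reference scores).
RANDOM_REF, EXPERT_REF = -10.0, 190.0
returns = [10.0, 20.0, 30.0]
scores = [normalized_score(r, RANDOM_REF, EXPERT_REF) for r in returns]
print(scores)  # [10.0, 15.0, 20.0]
print(statistics.pstdev(scores))
print(normalized_std(statistics.pstdev(returns), RANDOM_REF, EXPERT_REF))  # same value, up to float rounding
```

The scaling relation holds whether the std is computed with the population formula (`pstdev`) or the sample formula (`stdev`). This README does not state which one the evaluation uses.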

Evaluation runtime:

causal-gpt-rl 0.2.1
torch 2.12.0+cu132
gymnasium 1.2.2
mujoco 3.8.1
minari 0.5.3

Bundle Format

All public bundles include:

bundle/
  model.safetensors
  config.json
  state_normalizer.safetensors

• model.safetensors: model state dict for inference.
• config.json: model config, observation specs, action specs, context length, and optional env_id.
• state_normalizer.safetensors: state normalization statistics used by the policy.
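
The helper below is one way to sanity-check a local bundle before loading it. `check_bundle` is hypothetical and not part of causal_gpt_rl. It only confirms that the three files exist and that config.json parses as a JSON object. Reading the key named `env_id` is an assumption based on the description above, and tensor contents are not validated.

```python
import json
from pathlib import Path
from typing import Any, Dict, Union

# The three files every public bundle is described as containing.
REQUIRED_FILES = ("model.safetensors", "config.json", "state_normalizer.safetensors")


def check_bundle(path: Union[str, Path]) -> Dict[str, Any]:
    """Verify that a bundle directory has the required files and return its parsed config."""
    root = Path(path)
    missing = [name for name in REQUIRED_FILES if not (root / name).is_file()]
    if missing:
        raise FileNotFoundError(f"{root}: missing {', '.join(missing)}")
    with open(root / "config.json", encoding="utf-8") as f:
        config = json.load(f)
    if not isinstance(config, dict):
        raise ValueError(f"{root / 'config.json'}: expected a JSON object")
    return config
```

For example, `check_bundle("path/to/bundle").get("env_id")` returns the env_id if the (assumed) key is present and None otherwise. Inspecting tensor names as well would require the safetensors package, for example `safetensors.safe_open`.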

Hugging Face Layout

Recommended layout:

ccnets/causal-gpt-rl/
  ant-v5/
    model.safetensors
    config.json
    state_normalizer.safetensors
    README.md

For local bundles, use load_runner("path/to/bundle").

API

from causal_gpt_rl.inference import (
    PolicyRunner,                          # step-wise rollout policy with KV cache
    load_runner,                           # load runner from a local bundle directory
    load_runner_from_hub,                  # load runner from a Hugging Face Hub repo
    run_episodes,                          # evaluate over N episodes; returns stats dict
    export_bundle,                         # write a bundle directory from a runner
    convert_legacy_bundle_to_safetensors,  # migrate legacy bundles to the safetensors format
)

Development Checks

python -m compileall -q causal_gpt_rl
python -m unittest discover -s tests
python -m build
python -m twine check dist/*

License

Released under the PolyForm Noncommercial License 1.0.0. See LICENSE for details. For commercial licensing, contact the maintainers via ccnets.org.
