Causal GPT-RL
GPT-style transformers (GPT-2, Llama) running as RL policies in continuous-control environments.
Both LLM generation and RL interaction are autoregressive:
- LLM generation: token → next token
- RL rollout: (state, action) → (next state from env, next action)
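The analogy can be made concrete with a minimal, framework-free sketch of the RL side of that loop. The names `rollout`, `policy`, and `env_step`, the float-valued toy states, and the toy dynamics are all hypothetical illustrations, not part of the causal_gpt_rl API. The sketch only shows the loop structure: the model emits the next action from a bounded (state, action) context, and the environment, not the model, supplies the next state.

```python
from collections import deque
from typing import Callable, Deque, List, Tuple

# Hypothetical toy (state, action) pair; real environments use vectors.
Step = Tuple[float, float]


def rollout(
    policy: Callable[[List[Step], float], float],
    env_step: Callable[[float, float], float],
    initial_state: float,
    num_steps: int,
    ctx: int,
) -> List[Step]:
    """Autoregressive rollout over a bounded (state, action) context.

    The policy sees at most the last `ctx` (state, action) pairs plus the
    current state and predicts the next action. The environment, not the
    model, produces the next state.
    """
    history: Deque[Step] = deque(maxlen=ctx)  # oldest pairs fall off the front
    trajectory: List[Step] = []
    state = initial_state
    for _ in range(num_steps):
        action = policy(list(history), state)  # model: next "token" is an action
        history.append((state, action))
        trajectory.append((state, action))
        state = env_step(state, action)  # env: next state
    return trajectory


# Usage with a hypothetical toy policy (a = -0.5 * s) and toy dynamics (s' = s + a).
traj = rollout(lambda h, s: -0.5 * s, lambda s, a: s + a, 8.0, 3, ctx=2)
print(traj)  # [(8.0, -4.0), (4.0, -2.0), (2.0, -1.0)]
```

Because the context is a fixed-size window, the per-step input stays bounded no matter how long the episode runs.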
Causal GPT-RL policies act stably under their own rollouts. They sustain long-horizon control without the drift that has historically kept transformers from being usable as RL agents.
A single autoregressive model drives full-episode rollouts through its KV cache, with no critic and no auxiliary networks at inference time.
This repository is the public inference runtime. It loads policy bundles, runs Gymnasium/MuJoCo rollouts, and provides small evaluation helpers.
- Code (GitHub): ccnets-team/causal-gpt-rl
- Run logs (W&B, public): wandb.ai/junhopark/Causal GPT-RL
- Hugging Face org: https://huggingface.co/ccnets
- Website: https://ccnets.org
- LinkedIn: https://www.linkedin.com/company/ccnets
Install
For Hub loading and MuJoCo environments:
```bash
pip install "causal-gpt-rl[hub,mujoco]"
```
For local development:
```bash
git clone https://github.com/ccnets-team/causal-gpt-rl.git
cd causal-gpt-rl
python -m pip install -e ".[hub,mujoco]"
```
For private bundles, authenticate first:
```bash
hf auth login
```
Quick Start
```python
import gymnasium as gym
from causal_gpt_rl.inference import load_runner_from_hub, run_episodes

env = gym.make("Ant-v5")

# Load the Ant-v5 policy bundle from the Hub subfolder "ant-v5".
runner = load_runner_from_hub(
    repo_id="ccnets/causal-gpt-rl",
    subfolder="ant-v5",
)

# Evaluate over 5 episodes; returns a stats dict.
stats = run_episodes(env, runner, num_episodes=5, seed=0)
env.close()
print(stats["return_mean"], stats["return_std"])
```
Notebook version: examples/hub_quickstart.ipynb
Supported Environments
| Env | Bundle | Ctx | Return | Norm. | Medium Ref. |
|---|---|---|---|---|---|
| Ant-v5 | ant-v5 | 32 | 3339.51 ± 1115.40 | 50.56 ± 16.54 | 86.54 |
| HalfCheetah-v5 | halfcheetah-v5 | 32 | 5989.04 ± 1902.22 | 37.86 ± 11.53 | 74.83 |
| Hopper-v5 | hopper-v5 | 32 | 2836.28 ± 987.67 | 73.40 ± 25.72 | 72.91 |
| Walker2d-v5 | walker2d-v5 | 32 | 3883.30 ± 684.09 | 56.69 ± 9.99 | 83.26 |
| Humanoid-v5 | humanoid-v5 | 32 | 6089.64 ± 2512.73 | 70.41 ± 29.58 | 81.30 |
Training is expert-free: bundles are trained only on the Minari simple and medium datasets, and no expert trajectories are used.
Return and Norm. are mean ± std over 50 episodes with seeds 0..49. Ctx is the context length. Episodes run for at most max_steps=1000 steps, and the KV cache length is capped at Ctx.
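The cap on the KV cache can be pictured with a small sliding-window sketch. `CappedKVCache` is a hypothetical class, not causal_gpt_rl code. It makes three assumptions: the oldest position is evicted first once Ctx entries are cached, there is a single attention head, and positional encodings are ignored (a real model has to handle them when positions are dropped). With max_steps=1000 and Ctx=32, only the 32 most recent positions remain visible to attention.

```python
import math
from collections import deque
from typing import Deque, List, Sequence


class CappedKVCache:
    """Hypothetical single-head KV cache capped at `ctx` positions.

    Assumption: once full, the oldest position is evicted first (sliding
    window). This illustrates the cap only; it is not the library's code.
    """

    def __init__(self, ctx: int) -> None:
        self.keys: Deque[Sequence[float]] = deque(maxlen=ctx)
        self.values: Deque[Sequence[float]] = deque(maxlen=ctx)

    def append(self, key: Sequence[float], value: Sequence[float]) -> None:
        # deque(maxlen=ctx) silently drops the oldest entry when full.
        self.keys.append(key)
        self.values.append(value)

    def __len__(self) -> int:
        return len(self.keys)

    def attend(self, query: Sequence[float]) -> List[float]:
        """Scaled dot-product attention of one query over cached positions."""
        if not self.keys:
            raise ValueError("cache is empty")
        scale = math.sqrt(len(query))
        scores = [sum(q * k for q, k in zip(query, key)) / scale for key in self.keys]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        dim = len(self.values[0])
        return [
            sum(w * v[i] for w, v in zip(weights, self.values)) / total
            for i in range(dim)
        ]


# 1000 steps (max_steps) with Ctx = 32: only the last 32 positions remain.
cache = CappedKVCache(ctx=32)
for t in range(1000):
    cache.append([float(t)], [float(t)])
print(len(cache), cache.keys[0][0])  # 32 968.0
```

Memory and per-step attention cost therefore stay constant for the whole episode, at the price of forgetting anything older than the window.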
Normalized scores map the random reference return to 0 and the expert reference return to 100:
Norm. = 100 * (return - random_ref) / (expert_ref - random_ref)
Medium reference scores are shown for context and are not the normalization baseline.
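The normalization is an affine map of the return. The normalized mean is therefore the normalized raw mean, and the normalized std is the raw std scaled by 100 / (expert_ref - random_ref). The minimal sketch below makes three assumptions. The reference values in the usage line are placeholders, not the actual Minari reference scores. The helpers `normalize` and `summarize` are hypothetical. Population std (`pstdev`) is also an assumption, since the evaluation does not state which estimator it uses.

```python
from statistics import mean, pstdev
from typing import Dict, Sequence


def normalize(ret: float, random_ref: float, expert_ref: float) -> float:
    """100 * (return - random_ref) / (expert_ref - random_ref)."""
    return 100.0 * (ret - random_ref) / (expert_ref - random_ref)


def summarize(
    returns: Sequence[float], random_ref: float, expert_ref: float
) -> Dict[str, float]:
    """Mean and std of raw and normalized returns over a set of episodes."""
    norm = [normalize(r, random_ref, expert_ref) for r in returns]
    return {
        "mean": mean(returns),
        "std": pstdev(returns),  # population std: an assumption
        "norm_mean": mean(norm),
        "norm_std": pstdev(norm),
    }


# Placeholder references, NOT the actual Minari reference scores.
print(summarize([400.0, 600.0], random_ref=0.0, expert_ref=1000.0))
# {'mean': 500.0, 'std': 100.0, 'norm_mean': 50.0, 'norm_std': 10.0}
```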
Evaluation runtime:
- causal-gpt-rl 0.2.1
- torch 2.12.0+cu132
- gymnasium 1.2.2
- mujoco 3.8.1
- minari 0.5.3
Bundle Format
All public bundles include:
```
bundle/
  model.safetensors
  config.json
  state_normalizer.safetensors
```
- `model.safetensors`: model state dict for inference.
- `config.json`: model config, observation specs, action specs, context length, and optional `env_id`.
- `state_normalizer.safetensors`: state normalization statistics used by the policy.
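Before loading a local bundle, a quick pre-flight check can confirm that the three files listed above are present. `missing_bundle_files` is a hypothetical helper, not part of causal_gpt_rl. It only tests that the files exist and does not validate their contents or the fields inside `config.json`.

```python
from pathlib import Path
from typing import List

# File names listed in the bundle format above.
REQUIRED_FILES = ("model.safetensors", "config.json", "state_normalizer.safetensors")


def missing_bundle_files(bundle_dir: str) -> List[str]:
    """Return the required bundle files that are absent from `bundle_dir`.

    Hypothetical pre-flight check: existence only, no content validation.
    """
    root = Path(bundle_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


# Usage: an empty result means all three files are present.
print(missing_bundle_files("path/to/bundle"))
```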
Hugging Face Layout
Recommended layout:
```
ccnets/causal-gpt-rl/
  ant-v5/
    model.safetensors
    config.json
    state_normalizer.safetensors
  README.md
```
For local bundles, use `load_runner("path/to/bundle")`.
API
```python
from causal_gpt_rl.inference import (
    PolicyRunner,                          # step-wise rollout policy with KV cache
    load_runner,                           # load runner from a local bundle directory
    load_runner_from_hub,                  # load runner from a Hugging Face Hub repo
    run_episodes,                          # evaluate over N episodes; returns stats dict
    export_bundle,                         # write a bundle directory from a runner
    convert_legacy_bundle_to_safetensors,  # migrate legacy bundles to the safetensors format
)
```
Development Checks
```bash
python -m compileall -q causal_gpt_rl
python -m unittest discover -s tests
python -m build
python -m twine check dist/*
```
License
Released under PolyForm Noncommercial License 1.0.0. See LICENSE for details. For commercial licensing, contact the maintainers via ccnets.org.