---
license: mit
tags:
- reinforcement-learning
- offline-rl
- decision-transformer
- unity-ml-agents
- robotics
- sim-to-real
datasets:
- DecisionTransformer-Unity-Sim/DTTrajectoryData.zip
---

# Decision Transformer for Dynamic 3D Environments via Strategic Data Curation

This repository contains the official implementation and pre-trained models for the paper "[Data-Centric Offline Reinforcement Learning: Strategic Data Curation via Unity ML-Agents and Decision Transformer]" (submitted to *Scientific Reports*).

We present a data-centric approach to offline reinforcement learning (Offline RL) using **Unity ML-Agents** and the **Decision Transformer (DT)**. Our research demonstrates that **strategic data curation** (specifically, fine-tuning on a small subset of high-quality "virtual expert" trajectories) is more critical to performance than sheer data volume.

## 🚀 Key Features
* **Sim-to-Data-to-Model:** A complete pipeline that generates synthetic data via Unity ML-Agents and uses it to train Transformer-based control agents.
* **Strategic Curation:** Fine-tuning on only **5-10%** of the data (the top-tier trajectories) significantly outperforms training on the massive mixed-quality dataset alone; a minimal curation sketch follows this list.
* **Robust Generalization:** The model maintains **96-100%** success rates even in zero-shot environments of increased complexity (e.g., 20 simultaneous targets).
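
As a concrete illustration, curation here amounts to ranking recorded episodes by total return and keeping only the top fraction. The sketch below is a minimal, hypothetical filter of this kind; the trajectory schema (a list of dicts, each holding a `rewards` array) is an assumption for illustration, not the repository's actual data format.

```python
import numpy as np

def curate_top_fraction(trajectories, fraction=0.05):
    """Keep the top `fraction` of episodes, ranked by total return.

    `trajectories` is assumed to be a list of dicts, each with a
    `rewards` array for one episode (illustrative schema only).
    """
    returns = np.array([traj["rewards"].sum() for traj in trajectories])
    k = max(1, int(len(trajectories) * fraction))
    top = np.argsort(returns)[-k:]  # indices of the k highest-return episodes
    return [trajectories[i] for i in top]

# e.g., a "Top 5% Expert Data" subset of the kind used to fine-tune DT_SC_5:
# expert_subset = curate_top_fraction(all_trajectories, fraction=0.05)
```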

## 📊 Model Zoo

| Model Name | Pre-training Data | Fine-tuning Data | Description |
| :--- | :--- | :--- | :--- |
| **DT_S_100** | 100% Mixed Data | None | Baseline model trained on the full dataset without curation. |
| **DT_C_5** | None | Top 5% Expert Data | Model trained *only* on a small, high-quality subset. |
| **DT_C_10** | None | Top 10% Expert Data | Model trained *only* on a larger high-quality subset. |
| **DT_SC_5** | 100% Mixed Data | Top 5% Expert Data | Pre-trained on mixed data, fine-tuned on the top 5% curated data. |
| **DT_SC_10** | 100% Mixed Data | Top 10% Expert Data | **(Best)** Pre-trained on mixed data, fine-tuned on the top 10% curated data; achieves over 4x firing stability vs. the baseline. |

## 🏗️ Methodology
1. **Data Generation:** We use **Unity ML-Agents** to train a PPO (Proximal Policy Optimization) agent that serves as the "Virtual Expert."
2. **Data Collection:** We collect step-wise interaction data (state, action, reward, return-to-go) from the PPO agent in a 3D projectile-interception task, supported by the scripts in `UnityScript/`; a sketch of the RTG computation follows this list.
3. **Offline Training:** We train a **Decision Transformer** (Chen et al., 2021), implemented in `model_dt.py`, to predict the next optimal action from the history of states and target returns.
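
The return-to-go (RTG) in step 2 is the sum of rewards from each timestep to the end of the episode; the Decision Transformer conditions on this value to steer the policy toward a desired return. A minimal sketch of the computation (undiscounted, as in the original DT formulation):

```python
import numpy as np

def returns_to_go(rewards: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """Suffix sums of rewards: rtg[t] = r[t] + gamma * rtg[t+1].

    Decision Transformer (Chen et al., 2021) conditions on
    undiscounted returns-to-go, i.e. gamma = 1.0.
    """
    rtg = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

# returns_to_go(np.array([0.0, 0.5, 1.0])) -> array([1.5, 1.5, 1.0])
```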

## 📈 Performance
* **Control Stability:** Improved by **3.5x** in the `DT_SC` models compared to the baseline.
* **Firing Stability:** Improved by over **4x**.
* **Success Rate:** Matches PPO-level performance (~98%) while operating strictly offline.
* **Metrics Visualization:** Use `chart_visualize.py` to reproduce the performance plots (Win Rate, Avg Steps, Smoothness); an illustrative smoothness metric is sketched below.
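
The exact metric definitions live in `chart_visualize.py`. As a rough illustration only, control smoothness is often measured as the average change between consecutive actions, where lower means steadier control; the definition below is an assumption for illustration, not necessarily the script's formula.

```python
import numpy as np

def action_smoothness(actions: np.ndarray) -> float:
    """Mean L2 norm of consecutive action deltas over one episode.

    `actions` has shape (T, act_dim); lower values mean smoother,
    more stable control. Illustrative definition only; see
    chart_visualize.py for the metrics actually used.
    """
    deltas = np.diff(actions, axis=0)  # (T-1, act_dim)
    return float(np.linalg.norm(deltas, axis=1).mean())
```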

## 💻 Usage

The following example demonstrates how to load a pre-trained model and run inference:

```python
import torch
from model_dt import DecisionTransformer

# Configuration (must match the training config)
OBS_DIM = 9
ACT_DIM = 3
HIDDEN_SIZE = 256
MAX_LEN = 1024  # Context (sequence) length

# 1. Load the pre-trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DecisionTransformer(
    obs_dim=OBS_DIM,
    act_dim=ACT_DIM,
    hidden=HIDDEN_SIZE,
    max_len=MAX_LEN
)

# Load weights (example: DT_SC_5.pth)
model_path = "DT_SC_5.pth"
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

print(f"Loaded model from {model_path}")

# 2. Inference loop (pseudo-code example)
# Note: requires a running environment 'env'
def get_action(model, states, actions, returns_to_go, timesteps):
    # Pad all inputs to the context length (MAX_LEN) if necessary
    # ... (padding logic here) ...

    with torch.no_grad():
        # Predict actions; the model is conditioned on the returns-to-go
        # (RTG) sequence, per the Decision Transformer formulation
        action_preds = model(
            states.unsqueeze(0),
            actions.unsqueeze(0),
            returns_to_go.unsqueeze(0),
            timesteps.unsqueeze(0)
        )
        action_pred = action_preds[0, -1]  # Take the last action prediction
    return action_pred

# Example usage within an episode
# state = env.reset()
# target_return = torch.tensor([1.0], device=device)  # Normalized expert return
# for t in range(max_steps):
#     # rtg_history holds target_return minus the rewards observed so far
#     action = get_action(model, state_history, action_history, rtg_history, timestep_history)
#     next_state, reward, done, _ = env.step(action.cpu().numpy())
#     ...
```

## 📁 File Structure
- `model_dt.py`: Decision Transformer model definition.
- `train_sequential.py`: Main training script.
- `dataset_dt.py`: Dataset loader for trajectory data.
- `chart_visualize.py`: Visualization tool for benchmark metrics.