VLAlert / training /Policy /train_policy_binary.sh
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
2.04 kB
#!/usr/bin/env bash
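# Binary Ablation: 2-class Policy (SILENT=0 / ALERT=1)
#
# Purpose: ablation study for the paper.
#   Show that the 3-class policy (with the OBSERVE class) outperforms a
#   2-class policy: if the metrics drop once OBSERVE is removed, OBSERVE
#   carries real value.
#
# How it works:
#   Every OBSERVE label is remapped either to ALERT (more conservative) or to
#   SILENT (more aggressive), and a 2-class head is then trained.
#   This is controlled by the trainer options --n_actions 2 and
#   --merge_observe {alert|silent}.
#   NOTE(review): the command in this script passes only --merge_observe alert
#   and never --n_actions; confirm the trainer selects the 2-class head itself.
#
# Usage:
#   bash training/Policy/train_policy_binary.sh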
set -euo pipefail

# Print an error message to stderr and abort the script.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# ── Locate the project root ──────────────────────────────────────────────────
# Honour the PROJECT_ROOT environment variable when it is set; otherwise fall
# back to the directory two levels above this script, which is the project root
# when the script sits at training/Policy/train_policy_binary.sh.
if [[ -n "${PROJECT_ROOT:-}" ]]; then
  ROOT="$PROJECT_ROOT"
else
  ROOT="$(dirname -- "${BASH_SOURCE[0]}")/../.."
fi

# Normalise to an absolute path: the paths below are built from $ROOT and are
# used *after* the cd, so a relative root would resolve against the wrong
# directory. CDPATH is cleared so cd neither searches elsewhere nor prints.
ROOT="$(CDPATH='' cd -- "$ROOT" 2>/dev/null && pwd)" \
  || die "project root not found; set PROJECT_ROOT to the repository directory"
readonly ROOT

readonly SFT_CHECKPOINT="$ROOT/checkpoints/SFT/sft_v2/best"
readonly LABEL_DIR="$ROOT/data/policy_labels"
readonly CACHE_DIR="$ROOT/data/belief_cache"
readonly OUTPUT_DIR="$ROOT/checkpoints/Policy"

# `python -m training.Policy...` is resolved relative to the working directory,
# so the trainer has to be launched from the project root.
cd -- "$ROOT" || die "cannot cd to $ROOT"

# ── Ablation A: merge OBSERVE→ALERT (2-class, conservative) ─────────────────
echo "=== Binary Ablation A: OBSERVE→ALERT (2-class conservative) ==="

# Trainer flags are kept in an array so each group can carry a comment and no
# value is subject to word-splitting.
trainer_args=(
  # Inputs and outputs.
  --sft_checkpoint "$SFT_CHECKPOINT"
  --label_dir "$LABEL_DIR"
  --belief_cache_dir "$CACHE_DIR"
  --output_dir "$OUTPUT_DIR"
  --experiment_name "policy_binary_obs2alert"

  # Optimisation schedule.
  --num_epochs 15
  --batch_size 256
  --learning_rate 2e-4
  --lr_min 1e-6
  --val_every_n_steps 200
  --early_stop_patience 7

  # Loss and regularisation.
  # NOTE(review): --focal_alpha and --score_weights still take three values
  # with the middle one set to 0.0 — presumably the OBSERVE slot, zeroed
  # because OBSERVE is merged away; confirm against the trainer.
  --focal_gamma 2.0
  --focal_alpha 0.3 0.0 0.7
  --belief_noise_std 0.01
  --label_smoothing 0.05
  --score_weights 0.6 0.0 0.4

  # Ablation switch: fold OBSERVE labels into ALERT.
  --merge_observe alert

  --use_balanced_sampler
  --use_wandb
)

python -m training.Policy.warm_start_trainer "${trainer_args[@]}"

echo ""
echo "✅ Binary ablation (obs→alert) complete."