#!/usr/bin/env bash
#
# Binary ablation: 2-class policy (SILENT=0 / ALERT=1).
#
# Purpose (paper ablation study):
#   Show that the 3-class policy's OBSERVE class beats a 2-class policy.
#   If metrics drop once OBSERVE is removed, OBSERVE is carrying real value.
#
# How it works:
#   Every OBSERVE label is remapped to either ALERT (more conservative) or
#   SILENT (more aggressive) through the trainer's
#   --merge_observe {alert|silent} flag, then a 2-class policy is trained.
#   The middle entry of --focal_alpha and --score_weights is set to 0.0.
#   NOTE(review): this presumes index 1 is the OBSERVE slot, so the merged
#   class contributes nothing to the loss/selection score; confirm against
#   warm_start_trainer.
#   NOTE(review): an earlier version of this header said the 2-class head is
#   selected with --n_actions 2, but that flag is not passed here; confirm
#   whether warm_start_trainer needs it.
#
# Usage:
#   bash training/Policy/train_policy_binary.sh [alert|silent] [extra trainer args...]
#
#   alert   (default) OBSERVE -> ALERT,  experiment "policy_binary_obs2alert"
#   silent            OBSERVE -> SILENT, experiment "policy_binary_obs2silent"
#   Any further arguments are appended to the trainer command line after the
#   defaults below (e.g. `--num_epochs 5`).
#
# Environment:
#   PROJECT_ROOT  project root directory. Defaults to the literal placeholder
#                 "PROJECT_ROOT", which must be replaced or overridden.

set -euo pipefail

die() { printf '%s\n' "$*" >&2; exit 1; }

# Optional first argument selects which class OBSERVE is merged into.
merge_mode="${1:-alert}"
if (( $# > 0 )); then
  shift
fi

case "$merge_mode" in
  alert)
    ablation_tag="A"
    ablation_desc="OBSERVE→ALERT (2-class conservative)"
    ;;
  silent)
    ablation_tag="B"
    ablation_desc="OBSERVE→SILENT (2-class aggressive)"
    ;;
  *)
    die "usage: ${0##*/} [alert|silent] [extra trainer args...] (got '${merge_mode}')"
    ;;
esac

readonly ROOT="${PROJECT_ROOT:-PROJECT_ROOT}"
readonly SFT_CHECKPOINT="$ROOT/checkpoints/SFT/sft_v2/best"
readonly LABEL_DIR="$ROOT/data/policy_labels"
readonly CACHE_DIR="$ROOT/data/belief_cache"
readonly OUTPUT_DIR="$ROOT/checkpoints/Policy"

cd "$ROOT" || die "cannot cd to project root: $ROOT"

# ── Ablation: merge OBSERVE into ALERT or SILENT (2-class) ──────────────────
trainer_args=(
  --sft_checkpoint "$SFT_CHECKPOINT"
  --label_dir "$LABEL_DIR"
  --belief_cache_dir "$CACHE_DIR"
  --output_dir "$OUTPUT_DIR"
  --experiment_name "policy_binary_obs2${merge_mode}"
  --num_epochs 15
  --batch_size 256
  --learning_rate 2e-4
  --lr_min 1e-6
  --val_every_n_steps 200
  --focal_gamma 2.0
  # Middle entry zeroed: presumably the OBSERVE slot (see header note).
  --focal_alpha 0.3 0.0 0.7
  --belief_noise_std 0.01
  --label_smoothing 0.05
  --early_stop_patience 7
  # Middle entry zeroed: presumably the OBSERVE slot (see header note).
  --score_weights 0.6 0.0 0.4
  --merge_observe "$merge_mode"
  --use_balanced_sampler
  --use_wandb
)

echo "=== Binary Ablation ${ablation_tag}: ${ablation_desc} ==="

# Caller-supplied args come last so they can override the defaults above.
python -m training.Policy.warm_start_trainer "${trainer_args[@]}" "$@"

echo ""
echo "✅ Binary ablation (obs→${merge_mode}) complete."