File size: 2,044 Bytes
1e05592
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#!/usr/bin/env bash
# Binary Ablation: 2-class Policy (SILENT=0 / ALERT=1)
#
# 目的:论文 Ablation Study
#   证明 3-class 的 OBSERVE 类比 2-class 更好
#   如果移除 OBSERVE 后指标下降 → 说明 OBSERVE 有价值
#
# 实现方式:
#   将所有 OBSERVE 标签重新映射为 ALERT (更保守) 或 SILENT (更激进)
#   然后用 2-class head 训练
#   这通过 --n_actions 2 和 --merge_observe {alert|silent} 参数控制
#
# 用法:
#   bash training/Policy/train_policy_binary.sh
set -euo pipefail

ROOT=PROJECT_ROOT
SFT_CHECKPOINT="$ROOT/checkpoints/SFT/sft_v2/best"
LABEL_DIR="$ROOT/data/policy_labels"
CACHE_DIR="$ROOT/data/belief_cache"
OUTPUT_DIR="$ROOT/checkpoints/Policy"

cd "$ROOT"

# ── Ablation A: merge OBSERVE→ALERT (2-class, conservative) ─────────────────
echo "=== Binary Ablation A: OBSERVE→ALERT (2-class conservative) ==="
python -m training.Policy.warm_start_trainer \
    --sft_checkpoint        "$SFT_CHECKPOINT"        \
    --label_dir             "$LABEL_DIR"              \
    --belief_cache_dir      "$CACHE_DIR"              \
    --output_dir            "$OUTPUT_DIR"             \
    --experiment_name       "policy_binary_obs2alert" \
    --num_epochs            15                        \
    --batch_size            256                       \
    --learning_rate         2e-4                      \
    --lr_min                1e-6                      \
    --val_every_n_steps     200                       \
    --focal_gamma           2.0                       \
    --focal_alpha           0.3 0.0 0.7               \
    --belief_noise_std      0.01                      \
    --label_smoothing       0.05                      \
    --early_stop_patience   7                         \
    --score_weights         0.6 0.0 0.4               \
    --merge_observe         alert                     \
    --use_balanced_sampler                            \
    --use_wandb

echo ""
echo "✅ Binary ablation (obs→alert) complete."