#!/usr/bin/env bash
#
# Binary ablation A: train the warm-start policy with OBSERVE merged into
# ALERT ("2-class conservative" variant), starting from the SFT checkpoint.
#
# Usage:
#   PROJECT_ROOT=/path/to/project ./<this script> [extra trainer flags...]
#
# Environment (optional):
#   PROJECT_ROOT  project root; the trainer is launched as a module from here.
#                 Defaults to the literal "PROJECT_ROOT" (the original value).
#   PYTHON        Python interpreter to use (default: python).
#
# Any command-line arguments are appended after the built-in trainer flags,
# so they can add flags or override defaults if the trainer's argument
# parser lets the last occurrence win (argparse does for plain options —
# confirm for this trainer).

set -euo pipefail

# NOTE(review): the original assigned the literal string PROJECT_ROOT (no `$`),
# which looks like an unfilled placeholder. The env var of the same name now
# takes precedence; without it, behavior is unchanged.
ROOT="${PROJECT_ROOT:-PROJECT_ROOT}"
SFT_CHECKPOINT="$ROOT/checkpoints/SFT/sft_v2/best"
LABEL_DIR="$ROOT/data/policy_labels"
CACHE_DIR="$ROOT/data/belief_cache"
OUTPUT_DIR="$ROOT/checkpoints/Policy"
PYTHON="${PYTHON:-python}"
readonly ROOT SFT_CHECKPOINT LABEL_DIR CACHE_DIR OUTPUT_DIR PYTHON

# `python -m` puts the current directory on sys.path, so the
# `training.Policy...` module must be launched from the project root.
cd -- "$ROOT" || {
  printf 'error: cannot cd to project root: %s\n' "$ROOT" >&2
  exit 1
}

# Built as an array so each flag/value stays a separate word regardless of
# spaces in paths, and so individual options can carry comments.
trainer_args=(
  --sft_checkpoint "$SFT_CHECKPOINT"
  --label_dir "$LABEL_DIR"
  --belief_cache_dir "$CACHE_DIR"
  --output_dir "$OUTPUT_DIR"
  --experiment_name "policy_binary_obs2alert"
  --num_epochs 15
  --batch_size 256
  --learning_rate 2e-4
  --lr_min 1e-6
  --val_every_n_steps 200
  --focal_gamma 2.0
  # Three per-class values with the middle one zeroed; presumably the middle
  # class is OBSERVE, which --merge_observe alert folds away — confirm
  # against the trainer's class order.
  --focal_alpha 0.3 0.0 0.7
  --belief_noise_std 0.01
  --label_smoothing 0.05
  --early_stop_patience 7
  # Same three-slot layout as --focal_alpha, middle slot zeroed.
  --score_weights 0.6 0.0 0.4
  --merge_observe alert
  --use_balanced_sampler
  --use_wandb
)

echo "=== Binary Ablation A: OBSERVE→ALERT (2-class conservative) ==="
# Under `set -e` a non-zero trainer exit aborts the script here, so the
# completion message below is printed only on success.
"$PYTHON" -m training.Policy.warm_start_trainer "${trainer_args[@]}" "$@"

echo ""
echo "✅ Binary ablation (obs→alert) complete."