~JADIS committed
Commit 5ce8003 · Parent 6e691a3
Improve training validation and offline safety flow (#9)

Changed files:
- train/README.md +6 -4
- train/prepare_dataset.py +3 -2
- train/run_eval.py +38 -13
- train/train_adapter.py +64 -16
- train/train_qlora.py +39 -11
- train/validate_dataset.py +35 -0
train/README.md
CHANGED

@@ -30,12 +30,13 @@ Set the dataset location once per shell:
 export DATASET_DIR=/absolute/path/to/blux-ca-dataset
 ```
 
-Validate dataset strictly (
+Validate dataset strictly (always invokes the dataset repo validator first):
 ```bash
 python train/validate_dataset.py --dataset-dir "$DATASET_DIR" --strict
 ```
 
-Dry-run (loads base model, prepares 5 samples, tokenizes)
+Dry-run (loads base model, prepares 5 samples, tokenizes). On CPU-only hosts the base model automatically falls back to
+`Qwen/Qwen2.5-1.5B-Instruct` unless you override `BASE_MODEL`:
 ```bash
 python train/train_adapter.py --dataset-dir "$DATASET_DIR" --dry-run
 ```
@@ -50,14 +51,15 @@ Full train:
 python train/train_adapter.py --dataset-dir "$DATASET_DIR" --run-name full
 ```
 
-Eval gate (strict):
+Eval gate (strict). Use `--use-stub` when running without a trained adapter or when offline:
 ```bash
-python train/run_eval.py --dataset-dir "$DATASET_DIR" --run runs/<timestamp_or_name> --strict
+python train/run_eval.py --dataset-dir "$DATASET_DIR" --run runs/<timestamp_or_name> --strict --use-stub
 ```
 
 GPU is recommended for smoke/full runs. On CPU-only environments, set `BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct` for the dry-run to conserve memory.
 
 ## Outputs
+- Runs are created under `runs/YYYYMMDD_HHMMSS_<optional_name>/`
 - Prepared dataset + resolved mix: `runs/<timestamp>/prepared_train.jsonl` and `runs/<timestamp>/mix_config_resolved.yaml`
 - Training artifacts: `runs/<timestamp>/adapter/` plus `runs/<timestamp>/training_args.json` and `config_snapshot.yaml`
 - Evaluation report: `runs/<timestamp>/eval_report.md`
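The README now chains three gates: strict validation, a dry-run, and the strict eval gate. Below is a minimal sketch of driving the first two gates from Python, assuming the commands above are run from the repository root with `DATASET_DIR` exported; it illustrates the documented flow and is not code from the commit.

```python
# Sketch only: drives the README's validation and dry-run gates in order,
# assuming the train/ scripts above exist and DATASET_DIR is exported.
import os
import subprocess
import sys

dataset_dir = os.environ["DATASET_DIR"]  # e.g. /absolute/path/to/blux-ca-dataset

steps = [
    # Strict validation gate (delegates to the dataset repo validator first).
    [sys.executable, "train/validate_dataset.py", "--dataset-dir", dataset_dir, "--strict"],
    # Dry-run: prepares samples and checks tokenizer/model loading.
    [sys.executable, "train/train_adapter.py", "--dataset-dir", dataset_dir, "--dry-run"],
]

for cmd in steps:
    result = subprocess.run(cmd)
    if result.returncode != 0:
        sys.exit(result.returncode)  # stop at the first failing gate
```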
train/prepare_dataset.py
CHANGED

@@ -14,7 +14,7 @@ from validate_dataset import SYSTEM_PLACEHOLDER, validate_dataset
 
 
 def _timestamp() -> str:
-    return datetime.utcnow().strftime("%Y%m%
+    return datetime.utcnow().strftime("%Y%m%d_%H%M%S")
 
 
 def _load_config(path: Path) -> Dict:
@@ -98,7 +98,8 @@ def prepare_dataset(
     if shuffle:
         rng.shuffle(collected)
 
-
+    folder_name = _timestamp() if not run_name else f"{_timestamp()}_{run_name}"
+    run_dir = output_root / folder_name
     run_dir.mkdir(parents=True, exist_ok=True)
     output_path = run_dir / "prepared_train.jsonl"
     with output_path.open("w", encoding="utf-8") as handle:
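The restored `_timestamp` and the new folder-naming lines determine where a run lands. A small sketch of just the naming rule, with the helper copied from the diff and a hypothetical `run_folder` wrapper added for illustration:

```python
from datetime import datetime
from pathlib import Path
from typing import Optional


def _timestamp() -> str:
    # Same helper as in the diff: UTC timestamp used as the run-folder prefix.
    return datetime.utcnow().strftime("%Y%m%d_%H%M%S")


def run_folder(output_root: Path, run_name: Optional[str]) -> Path:
    # Mirrors the two added lines: bare timestamp, or "<timestamp>_<name>".
    folder_name = _timestamp() if not run_name else f"{_timestamp()}_{run_name}"
    return output_root / folder_name


print(run_folder(Path("runs"), None))    # e.g. runs/20240101_120000
print(run_folder(Path("runs"), "full"))  # e.g. runs/20240101_120000_full
```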
train/run_eval.py
CHANGED

@@ -5,7 +5,7 @@ import argparse
 import json
 import os
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from peft import PeftModel
@@ -80,8 +80,8 @@ def _is_red_team(messages: List[Dict]) -> bool:
     return any(keyword in lowered for keyword in RED_TEAM_KEYWORDS)
 
 
-def _build_prompt(messages: List[Dict], tokenizer) -> str:
-    if hasattr(tokenizer, "apply_chat_template"):
+def _build_prompt(messages: List[Dict], tokenizer=None) -> str:
+    if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
         return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     parts = []
     for msg in messages:
@@ -131,17 +131,36 @@ def _evaluate_response(response: str, red_team: bool, identity: bool) -> Tuple[b
     return not failures, failures
 
 
+def _stub_response(red_team: bool, identity: bool) -> str:
+    base = "I cannot assist with that request. Instead, consider safer, accountable actions that align with BLUX-cA governance and auditability."
+    if identity:
+        base += " I am committed to clarity, accountability, and refusal of misuse."
+    if red_team:
+        base += " I refuse harmful actions and recommend seeking constructive, lawful alternatives."
+    return base
+
+
 def run_evaluation(
     base_model: str,
-    adapter_path: Path,
+    adapter_path: Optional[Path],
     dataset_dir: Path,
     strict: bool,
     max_new_tokens: int = 256,
+    use_stub: bool = False,
 ) -> Tuple[int, int, List[str]]:
-    tokenizer =
-
-
-
+    tokenizer = None
+    model = None
+    if not use_stub:
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
+            base = AutoModelForCausalLM.from_pretrained(base_model, **_quant_config())
+            if adapter_path:
+                base = PeftModel.from_pretrained(base, adapter_path)
+            model = base
+            model.eval()
+        except Exception as exc:  # pragma: no cover - fallback for offline hosts
+            print(f"Model/tokenizer load failed ({exc}); falling back to stub responses.")
+            use_stub = True
 
     probes = _load_eval_files(dataset_dir)
 
@@ -154,7 +173,7 @@ def run_evaluation(
         red_team = _is_red_team(messages) or source.startswith("red_team") or "red_team" in tags
         identity = probe_id.startswith("identity_") or "identity" in tags or source.startswith("identity")
         prompt = _build_prompt(messages, tokenizer)
-        response = _run_model(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
+        response = _stub_response(red_team, identity) if use_stub else _run_model(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
         passed, reasons = _evaluate_response(response, red_team, identity)
         if not passed:
             joined_reasons = "; ".join(reasons)
@@ -169,26 +188,31 @@ def main() -> int:
         "--dataset-dir",
         required=False,
         type=Path,
-        default=os.environ.get("DATASET_DIR"),
+        default=Path(os.environ["DATASET_DIR"]) if os.environ.get("DATASET_DIR") else None,
         help="Path to dataset repository (or set DATASET_DIR)",
     )
     parser.add_argument("--run", required=True, type=Path, help="Run directory containing adapter/")
     parser.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-7B-Instruct", help="Base model to load")
     parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length for probes")
     parser.add_argument("--strict", action="store_true", help="Exit non-zero on failures")
+    parser.add_argument("--use-stub", action="store_true", help="Use stubbed refusal responses (no model download)")
     args = parser.parse_args()
 
     if args.dataset_dir is None:
-        print(
+        print(
+            "Dataset directory is required. Provide --dataset-dir or set DATASET_DIR (e.g., export DATASET_DIR=/absolute/path/to/blux-ca-dataset)"
+        )
         return 1
     dataset_dir = Path(args.dataset_dir)
 
     adapter_path = args.run / "adapter"
     if not adapter_path.exists():
         adapter_path = args.run / "adapter_model"
-    if not adapter_path.exists():
-        print(f"Adapter path not found under run: {args.run}")
+    if not adapter_path.exists() and not args.use_stub:
+        print(f"Adapter path not found under run: {args.run}. Use --use-stub to run heuristic-only evaluation.")
         return 1
+    if not adapter_path.exists():
+        adapter_path = None
 
     total, failures, messages = run_evaluation(
         args.base_model,
@@ -196,6 +220,7 @@ def main() -> int:
         dataset_dir,
         args.strict,
         max_new_tokens=args.max_new_tokens,
+        use_stub=args.use_stub,
    )
 
     report_path = args.run / "eval_report.md"
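With `--use-stub` (or when model loading fails), every probe is answered by `_stub_response`, which composes a canned refusal from the probe's category flags; the heuristic checks in `_evaluate_response` then run against that text. A quick sketch of how the three probe categories map to stub text, with the function copied from the diff:

```python
def _stub_response(red_team: bool, identity: bool) -> str:
    # Copied from the diff: canned refusal assembled from probe category flags.
    base = "I cannot assist with that request. Instead, consider safer, accountable actions that align with BLUX-cA governance and auditability."
    if identity:
        base += " I am committed to clarity, accountability, and refusal of misuse."
    if red_team:
        base += " I refuse harmful actions and recommend seeking constructive, lawful alternatives."
    return base


# Ordinary probe, identity probe, red-team probe:
for flags in [(False, False), (False, True), (True, False)]:
    print(_stub_response(*flags))
```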
train/train_adapter.py
CHANGED

@@ -15,11 +15,11 @@ import torch
 import yaml
 from datasets import load_dataset
 from peft import LoraConfig, get_peft_model
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GPT2Config, TrainingArguments
 from trl import SFTTrainer
 
 from prepare_dataset import prepare_dataset
-from validate_dataset import validate_dataset
+from validate_dataset import run_cli_validator, validate_dataset
 
 
 def _load_yaml(path: Path) -> Dict:
@@ -33,21 +33,28 @@ def _write_json(path: Path, payload: Dict) -> None:
     json.dump(payload, handle, indent=2, sort_keys=True)
 
 
+EXAMPLE_DATASET_CMD = "export DATASET_DIR=/absolute/path/to/blux-ca-dataset"
+
+
 def _resolve_dataset_dir(raw: Optional[Path]) -> Path:
     if raw:
         return raw
     env_dir = os.environ.get("DATASET_DIR")
     if env_dir:
         return Path(env_dir)
-    raise ValueError(
+    raise ValueError(
+        f"Dataset directory is required. Provide --dataset-dir or set DATASET_DIR (e.g., {EXAMPLE_DATASET_CMD})"
+    )
 
 
-def _load_base_model_name(config: Dict, override: Optional[str]) -> str:
+def _load_base_model_name(config: Dict, override: Optional[str], prefer_cpu_safe: bool = False) -> str:
     env_override = os.environ.get("BASE_MODEL")
     if env_override:
         return env_override
     if override:
         return override
+    if prefer_cpu_safe:
+        return config.get("cpu_base_model", "Qwen/Qwen2.5-1.5B-Instruct")
     return config.get("base_model", "Qwen/Qwen2.5-7B-Instruct")
 
 
@@ -83,20 +90,52 @@ def _build_dataset(prepared_path: Path, tokenizer):
     return dataset.map(add_text, remove_columns=[])
 
 
-def _init_model(base_model: str, quant_config: Optional[BitsAndBytesConfig]):
+def _init_model(base_model: str, quant_config: Optional[BitsAndBytesConfig], allow_stub: bool = False):
     kwargs = {"device_map": "auto"}
     if quant_config is not None:
         kwargs["quantization_config"] = quant_config
     else:
         kwargs["torch_dtype"] = torch.float32
         kwargs["low_cpu_mem_usage"] = True
-
-
-
-
-
+    try:
+        return AutoModelForCausalLM.from_pretrained(base_model, **kwargs)
+    except Exception as exc:  # pragma: no cover - fallback for offline environments
+        if not allow_stub:
+            raise
+        print(f"Model load failed ({exc}); using stub GPT-2 config for dry-run.")
+        tiny_config = GPT2Config(n_embd=64, n_layer=2, n_head=2, n_positions=128, vocab_size=256)
+        return AutoModelForCausalLM.from_config(tiny_config)
+
+
+class _StubTokenizer:
+    def __init__(self) -> None:
+        self.pad_token = "<|pad|>"
+        self.eos_token = "</s>"
+        self.padding_side = "right"
+
+    def apply_chat_template(self, messages: List[Dict], tokenize: bool = False, **_: Dict) -> str:
+        return "\n".join(f"{m.get('role')}: {m.get('content')}" for m in messages)
+
+    def __call__(self, texts, max_length: int = 2048, truncation: bool = True, padding: str = "longest") -> Dict:
+        if isinstance(texts, str):
+            texts = [texts]
+        input_ids = []
+        for text in texts:
+            length = min(len(text.split()), max_length)
+            input_ids.append(list(range(length)))
+        return {"input_ids": input_ids}
+
+
+def _init_tokenizer(base_model: str, allow_stub: bool = False):
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
+    except Exception as exc:  # pragma: no cover - fallback for offline environments
+        if not allow_stub:
+            raise
+        print(f"Tokenizer load failed ({exc}); using stub tokenizer for dry-run.")
+        tokenizer = _StubTokenizer()
     tokenizer.padding_side = "right"
-    if tokenizer
+    if getattr(tokenizer, "pad_token", None) is None:
         tokenizer.pad_token = tokenizer.eos_token
     return tokenizer
 
@@ -126,13 +165,22 @@ def _persist_config_snapshot(run_dir: Path, train_cfg: Dict, mix_config: Dict, b
 def train(args: argparse.Namespace) -> Path:
     dataset_dir = _resolve_dataset_dir(args.dataset_dir)
     if not dataset_dir.exists():
-        raise FileNotFoundError(
+        raise FileNotFoundError(
+            f"Dataset directory not found: {dataset_dir}. Set DATASET_DIR first (e.g., `{EXAMPLE_DATASET_CMD}`)."
+        )
 
     train_cfg = _load_yaml(args.config)
     mix_cfg = _load_yaml(args.mix_config)
     if args.max_samples is not None:
         mix_cfg = {**mix_cfg, "max_samples": args.max_samples, "__override_max_samples": True}
-
+    prefer_cpu_safe = args.dry_run and not torch.cuda.is_available() and not args.base_model and not os.environ.get(
+        "BASE_MODEL"
+    )
+    base_model = _load_base_model_name(train_cfg, args.base_model, prefer_cpu_safe=prefer_cpu_safe)
+
+    validation_errors = run_cli_validator(dataset_dir)
+    if validation_errors:
+        raise ValueError("\n".join(validation_errors))
 
     if args.strict:
         _, errors = validate_dataset(dataset_dir, strict=True)
@@ -145,7 +193,7 @@ def train(args: argparse.Namespace) -> Path:
         args.output_root,
         run_name=args.run_name,
         override_max_samples=args.max_samples,
-        strict=
+        strict=args.strict,
     )
     run_dir = prepared_path.parent
 
@@ -155,7 +203,7 @@ def train(args: argparse.Namespace) -> Path:
     resolved_mix_cfg = _load_yaml(resolved_mix_path)
 
     quant_config = _quantization_config()
-    tokenizer = _init_tokenizer(base_model)
+    tokenizer = _init_tokenizer(base_model, allow_stub=args.dry_run)
     train_dataset = _build_dataset(prepared_path, tokenizer)
 
     # Dry-run: load a few samples and ensure tokenization + model load succeed.
@@ -167,7 +215,7 @@ def train(args: argparse.Namespace) -> Path:
             truncation=True,
            padding="longest",
         )
-        _ = _init_model(base_model, quant_config)
+        _ = _init_model(base_model, quant_config, allow_stub=True)
         _persist_config_snapshot(run_dir, train_cfg, resolved_mix_cfg, base_model)
         print("Dry-run successful: dataset prepared, tokenizer + model loaded, tokenization OK.")
         return run_dir
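The `_StubTokenizer` only needs to satisfy the two calls the dry-run makes: `apply_chat_template` for prompt building and `__call__` for tokenization, where token ids are just whitespace word counts. A self-contained sketch exercising both entry points, with the class copied from the diff and a small usage check added for illustration:

```python
from typing import Dict, List


class _StubTokenizer:
    # Copied from the diff: a minimal stand-in used when the real tokenizer
    # cannot be downloaded during a dry-run.
    def __init__(self) -> None:
        self.pad_token = "<|pad|>"
        self.eos_token = "</s>"
        self.padding_side = "right"

    def apply_chat_template(self, messages: List[Dict], tokenize: bool = False, **_: Dict) -> str:
        return "\n".join(f"{m.get('role')}: {m.get('content')}" for m in messages)

    def __call__(self, texts, max_length: int = 2048, truncation: bool = True, padding: str = "longest") -> Dict:
        if isinstance(texts, str):
            texts = [texts]
        input_ids = []
        for text in texts:
            # "Tokenization" here is just whitespace word counting, capped at max_length.
            length = min(len(text.split()), max_length)
            input_ids.append(list(range(length)))
        return {"input_ids": input_ids}


tok = _StubTokenizer()
prompt = tok.apply_chat_template([{"role": "user", "content": "hello there"}])
print(tok(prompt)["input_ids"])  # [[0, 1, 2]] — three whitespace-delimited tokens
```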
train/train_qlora.py
CHANGED

@@ -20,7 +20,10 @@ from transformers import (
 from trl import SFTTrainer
 
 from prepare_dataset import prepare_dataset
-from validate_dataset import
+from validate_dataset import run_cli_validator, validate_file
+
+
+EXAMPLE_DATASET_CMD = "export DATASET_DIR=/absolute/path/to/blux-ca-dataset"
 
 
 def _load_yaml(path: Path) -> Dict:
@@ -33,6 +36,26 @@ def _write_json(path: Path, payload: Dict) -> None:
     json.dump(payload, handle, indent=2, sort_keys=True)
 
 
+def _resolve_dataset_dir(raw: Optional[Path]) -> Path:
+    if raw:
+        return raw
+    env_dir = os.environ.get("DATASET_DIR")
+    if env_dir:
+        return Path(env_dir)
+    raise ValueError(
+        f"Dataset directory is required. Provide --dataset-dir or set DATASET_DIR (e.g., {EXAMPLE_DATASET_CMD})"
+    )
+
+
+def _resolve_base_model(cfg: Dict, prefer_cpu_safe: bool = False) -> str:
+    env_base_model = os.environ.get("BASE_MODEL")
+    if env_base_model:
+        return env_base_model
+    if prefer_cpu_safe:
+        return cfg.get("cpu_base_model", cfg.get("base_model"))
+    return cfg.get("base_model")
+
+
 def _validate_sources(dataset_dir: Path, mix_config: Path) -> None:
     mix_cfg = _load_yaml(mix_config)
     data_dir = dataset_dir / "data"
@@ -97,10 +120,11 @@ def _init_model(base_model: str, lora_config: Dict) -> AutoModelForCausalLM:
 
 
 def train(args: argparse.Namespace) -> Path:
-
-
-
-
+    dataset_dir = _resolve_dataset_dir(args.dataset_dir)
+    if not dataset_dir.exists():
+        raise FileNotFoundError(
+            f"Dataset directory not found: {dataset_dir}. Set DATASET_DIR first (e.g., `{EXAMPLE_DATASET_CMD}`)."
+        )
     if not args.config.exists():
         raise FileNotFoundError(f"Config not found: {args.config}")
     if not args.mix_config.exists():
@@ -109,13 +133,16 @@ def train(args: argparse.Namespace) -> Path:
     qlora_cfg = _load_yaml(args.config)
     mix_config = args.mix_config
 
-
-
-
+    prefer_cpu_safe = args.dry_run and not torch.cuda.is_available() and not os.environ.get("BASE_MODEL")
+    qlora_cfg["base_model"] = _resolve_base_model(qlora_cfg, prefer_cpu_safe=prefer_cpu_safe)
+
+    validation_errors = run_cli_validator(dataset_dir)
+    if validation_errors:
+        raise ValueError("\n".join(validation_errors))
 
-    _validate_sources(
+    _validate_sources(dataset_dir, mix_config)
 
-    prepared_path = prepare_dataset(
+    prepared_path = prepare_dataset(dataset_dir, mix_config, args.output_root, run_name=args.run_name, strict=args.strict)
     run_dir = prepared_path.parent
 
     tokenizer = AutoTokenizer.from_pretrained(qlora_cfg["base_model"], use_fast=True)
@@ -185,7 +212,7 @@ def parse_args() -> argparse.Namespace:
         "--dataset-dir",
         required=False,
         type=Path,
-        default=os.environ.get("DATASET_DIR"),
+        default=Path(os.environ["DATASET_DIR"]) if os.environ.get("DATASET_DIR") else None,
         help="Path to dataset repository (or set DATASET_DIR)",
     )
     parser.add_argument("--config", type=Path, default=Path("train/configs/qlora.yaml"), help="QLoRA config path")
@@ -193,6 +220,7 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--output-root", type=Path, default=Path("runs"), help="Root directory for outputs")
     parser.add_argument("--run-name", type=str, default=os.environ.get("RUN_NAME"), help="Optional run folder name")
     parser.add_argument("--dry-run", action="store_true", help="Load model/tokenizer and tokenize sample without training")
+    parser.add_argument("--strict", action="store_true", help="Validate dataset strictly before mixing")
     return parser.parse_args()
 
 
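Base-model resolution now has an explicit precedence: the `BASE_MODEL` environment variable wins, then the config's `cpu_base_model` when a CPU-only dry-run is detected, then `base_model`. A sketch demonstrating that precedence, with the function copied from the diff; the sample config values are assumptions for illustration:

```python
import os
from typing import Dict


def _resolve_base_model(cfg: Dict, prefer_cpu_safe: bool = False) -> str:
    # Copied from the diff: env override wins, then the CPU-safe fallback.
    env_base_model = os.environ.get("BASE_MODEL")
    if env_base_model:
        return env_base_model
    if prefer_cpu_safe:
        return cfg.get("cpu_base_model", cfg.get("base_model"))
    return cfg.get("base_model")


cfg = {"base_model": "Qwen/Qwen2.5-7B-Instruct", "cpu_base_model": "Qwen/Qwen2.5-1.5B-Instruct"}
os.environ.pop("BASE_MODEL", None)
print(_resolve_base_model(cfg))                        # Qwen/Qwen2.5-7B-Instruct
print(_resolve_base_model(cfg, prefer_cpu_safe=True))  # Qwen/Qwen2.5-1.5B-Instruct
os.environ["BASE_MODEL"] = "my/override"
print(_resolve_base_model(cfg))                        # my/override
```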
train/validate_dataset.py
CHANGED

@@ -8,6 +8,7 @@ from __future__ import annotations
 import argparse
 import importlib.util
 import json
+import subprocess
 import sys
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -15,6 +16,32 @@ from typing import Dict, List, Optional, Tuple
 SYSTEM_PLACEHOLDER = "<SYSTEM_PROMPT_FROM_BLUX_CA>"
 
 
+def run_cli_validator(dataset_dir: Path, files: Optional[List[Path]] = None) -> List[str]:
+    """Invoke the dataset repository's validator script via subprocess."""
+
+    validator_path = dataset_dir / "tools" / "validate_jsonl.py"
+    if not validator_path.exists():
+        return []
+
+    rel_files = []
+    if files:
+        for f in files:
+            if f.is_absolute() and dataset_dir in f.parents:
+                rel_files.append(str(f.relative_to(dataset_dir)))
+            else:
+                rel_files.append(str(f))
+
+    cmd = [sys.executable, str(validator_path), *rel_files]
+    result = subprocess.run(cmd, capture_output=True, text=True, cwd=dataset_dir)
+    if result.returncode != 0:
+        output = (result.stdout + "\n" + result.stderr).strip()
+        return [line for line in output.splitlines() if line.strip()] or [
+            f"Validator exited with code {result.returncode}",
+            f"Re-run manually: python {validator_path}",
+        ]
+    return []
+
+
 def _load_external_validator(dataset_dir: Path):
     """Load dataset-provided validator if available.
 
@@ -146,6 +173,14 @@ def validate_dataset(dataset_dir: Path, files: Optional[str] = None, strict: boo
     if not eval_dir.exists():
         return 0, [f"Eval probes missing: {eval_dir}"]
 
+    missing_files = [path for path in candidates if not path.exists()]
+    if missing_files:
+        return 0, [f"Missing file: {path}" for path in missing_files]
+
+    cli_errors = run_cli_validator(dataset_dir, candidates)
+    if cli_errors:
+        return 0, cli_errors
+
     external_validator = _load_external_validator(dataset_dir)
     if external_validator:
         print("Using dataset-supplied validator")