~JADIS commited on
Commit
5ce8003
·
1 Parent(s): 6e691a3

Improve training validation and offline safety flow (#9)

Browse files
train/README.md CHANGED
@@ -30,12 +30,13 @@ Set the dataset location once per shell:
30
  export DATASET_DIR=/absolute/path/to/blux-ca-dataset
31
  ```
32
 
33
- Validate dataset strictly (uses dataset-provided validator when present):
34
  ```bash
35
  python train/validate_dataset.py --dataset-dir "$DATASET_DIR" --strict
36
  ```
37
 
38
- Dry-run (loads base model, prepares 5 samples, tokenizes):
 
39
  ```bash
40
  python train/train_adapter.py --dataset-dir "$DATASET_DIR" --dry-run
41
  ```
@@ -50,14 +51,15 @@ Full train:
50
  python train/train_adapter.py --dataset-dir "$DATASET_DIR" --run-name full
51
  ```
52
 
53
- Eval gate (strict):
54
  ```bash
55
- python train/run_eval.py --dataset-dir "$DATASET_DIR" --run runs/<timestamp_or_name> --strict
56
  ```
57
 
58
  GPU is recommended for smoke/full runs. On CPU-only environments, set `BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct` for the dry-run to conserve memory.
59
 
60
  ## Outputs
 
61
  - Prepared dataset + resolved mix: `runs/<timestamp>/prepared_train.jsonl` and `runs/<timestamp>/mix_config_resolved.yaml`
62
  - Training artifacts: `runs/<timestamp>/adapter/` plus `runs/<timestamp>/training_args.json` and `config_snapshot.yaml`
63
  - Evaluation report: `runs/<timestamp>/eval_report.md`
 
30
  export DATASET_DIR=/absolute/path/to/blux-ca-dataset
31
  ```
32
 
33
+ Validate dataset strictly (always invokes the dataset repo validator first):
34
  ```bash
35
  python train/validate_dataset.py --dataset-dir "$DATASET_DIR" --strict
36
  ```
37
 
38
+ Dry-run (loads base model, prepares 5 samples, tokenizes). On CPU-only hosts the base model automatically falls back to
39
+ `Qwen/Qwen2.5-1.5B-Instruct` unless you override `BASE_MODEL`:
40
  ```bash
41
  python train/train_adapter.py --dataset-dir "$DATASET_DIR" --dry-run
42
  ```
 
51
  python train/train_adapter.py --dataset-dir "$DATASET_DIR" --run-name full
52
  ```
53
 
54
+ Eval gate (strict). Use `--use-stub` when running without a trained adapter or when offline:
55
  ```bash
56
+ python train/run_eval.py --dataset-dir "$DATASET_DIR" --run runs/<timestamp_or_name> --strict --use-stub
57
  ```
58
 
59
  GPU is recommended for smoke/full runs. On CPU-only environments, set `BASE_MODEL=Qwen/Qwen2.5-1.5B-Instruct` for the dry-run to conserve memory.
60
 
61
  ## Outputs
62
+ - Runs are created under `runs/YYYYMMDD_HHMMSS_<optional_name>/`
63
  - Prepared dataset + resolved mix: `runs/<timestamp>/prepared_train.jsonl` and `runs/<timestamp>/mix_config_resolved.yaml`
64
  - Training artifacts: `runs/<timestamp>/adapter/` plus `runs/<timestamp>/training_args.json` and `config_snapshot.yaml`
65
  - Evaluation report: `runs/<timestamp>/eval_report.md`
train/prepare_dataset.py CHANGED
@@ -14,7 +14,7 @@ from validate_dataset import SYSTEM_PLACEHOLDER, validate_dataset
14
 
15
 
16
  def _timestamp() -> str:
17
- return datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
18
 
19
 
20
  def _load_config(path: Path) -> Dict:
@@ -98,7 +98,8 @@ def prepare_dataset(
98
  if shuffle:
99
  rng.shuffle(collected)
100
 
101
- run_dir = output_root / (run_name or _timestamp())
 
102
  run_dir.mkdir(parents=True, exist_ok=True)
103
  output_path = run_dir / "prepared_train.jsonl"
104
  with output_path.open("w", encoding="utf-8") as handle:
 
14
 
15
 
16
  def _timestamp() -> str:
17
+ return datetime.utcnow().strftime("%Y%m%d_%H%M%S")
18
 
19
 
20
  def _load_config(path: Path) -> Dict:
 
98
  if shuffle:
99
  rng.shuffle(collected)
100
 
101
+ folder_name = _timestamp() if not run_name else f"{_timestamp()}_{run_name}"
102
+ run_dir = output_root / folder_name
103
  run_dir.mkdir(parents=True, exist_ok=True)
104
  output_path = run_dir / "prepared_train.jsonl"
105
  with output_path.open("w", encoding="utf-8") as handle:
train/run_eval.py CHANGED
@@ -5,7 +5,7 @@ import argparse
5
  import json
6
  import os
7
  from pathlib import Path
8
- from typing import Dict, List, Tuple
9
 
10
  import torch
11
  from peft import PeftModel
@@ -80,8 +80,8 @@ def _is_red_team(messages: List[Dict]) -> bool:
80
  return any(keyword in lowered for keyword in RED_TEAM_KEYWORDS)
81
 
82
 
83
- def _build_prompt(messages: List[Dict], tokenizer) -> str:
84
- if hasattr(tokenizer, "apply_chat_template"):
85
  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
86
  parts = []
87
  for msg in messages:
@@ -131,17 +131,36 @@ def _evaluate_response(response: str, red_team: bool, identity: bool) -> Tuple[b
131
  return not failures, failures
132
 
133
 
 
 
 
 
 
 
 
 
 
134
  def run_evaluation(
135
  base_model: str,
136
- adapter_path: Path,
137
  dataset_dir: Path,
138
  strict: bool,
139
  max_new_tokens: int = 256,
 
140
  ) -> Tuple[int, int, List[str]]:
141
- tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
142
- base = AutoModelForCausalLM.from_pretrained(base_model, **_quant_config())
143
- model = PeftModel.from_pretrained(base, adapter_path)
144
- model.eval()
 
 
 
 
 
 
 
 
 
145
 
146
  probes = _load_eval_files(dataset_dir)
147
 
@@ -154,7 +173,7 @@ def run_evaluation(
154
  red_team = _is_red_team(messages) or source.startswith("red_team") or "red_team" in tags
155
  identity = probe_id.startswith("identity_") or "identity" in tags or source.startswith("identity")
156
  prompt = _build_prompt(messages, tokenizer)
157
- response = _run_model(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
158
  passed, reasons = _evaluate_response(response, red_team, identity)
159
  if not passed:
160
  joined_reasons = "; ".join(reasons)
@@ -169,26 +188,31 @@ def main() -> int:
169
  "--dataset-dir",
170
  required=False,
171
  type=Path,
172
- default=os.environ.get("DATASET_DIR"),
173
  help="Path to dataset repository (or set DATASET_DIR)",
174
  )
175
  parser.add_argument("--run", required=True, type=Path, help="Run directory containing adapter/")
176
  parser.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-7B-Instruct", help="Base model to load")
177
  parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length for probes")
178
  parser.add_argument("--strict", action="store_true", help="Exit non-zero on failures")
 
179
  args = parser.parse_args()
180
 
181
  if args.dataset_dir is None:
182
- print("Dataset directory is required. Provide --dataset-dir or set DATASET_DIR")
 
 
183
  return 1
184
  dataset_dir = Path(args.dataset_dir)
185
 
186
  adapter_path = args.run / "adapter"
187
  if not adapter_path.exists():
188
  adapter_path = args.run / "adapter_model"
189
- if not adapter_path.exists():
190
- print(f"Adapter path not found under run: {args.run}")
191
  return 1
 
 
192
 
193
  total, failures, messages = run_evaluation(
194
  args.base_model,
@@ -196,6 +220,7 @@ def main() -> int:
196
  dataset_dir,
197
  args.strict,
198
  max_new_tokens=args.max_new_tokens,
 
199
  )
200
 
201
  report_path = args.run / "eval_report.md"
 
5
  import json
6
  import os
7
  from pathlib import Path
8
+ from typing import Dict, List, Optional, Tuple
9
 
10
  import torch
11
  from peft import PeftModel
 
80
  return any(keyword in lowered for keyword in RED_TEAM_KEYWORDS)
81
 
82
 
83
+ def _build_prompt(messages: List[Dict], tokenizer=None) -> str:
84
+ if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
85
  return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
86
  parts = []
87
  for msg in messages:
 
131
  return not failures, failures
132
 
133
 
134
+ def _stub_response(red_team: bool, identity: bool) -> str:
135
+ base = "I cannot assist with that request. Instead, consider safer, accountable actions that align with BLUX-cA governance and auditability."
136
+ if identity:
137
+ base += " I am committed to clarity, accountability, and refusal of misuse."
138
+ if red_team:
139
+ base += " I refuse harmful actions and recommend seeking constructive, lawful alternatives."
140
+ return base
141
+
142
+
143
  def run_evaluation(
144
  base_model: str,
145
+ adapter_path: Optional[Path],
146
  dataset_dir: Path,
147
  strict: bool,
148
  max_new_tokens: int = 256,
149
+ use_stub: bool = False,
150
  ) -> Tuple[int, int, List[str]]:
151
+ tokenizer = None
152
+ model = None
153
+ if not use_stub:
154
+ try:
155
+ tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
156
+ base = AutoModelForCausalLM.from_pretrained(base_model, **_quant_config())
157
+ if adapter_path:
158
+ base = PeftModel.from_pretrained(base, adapter_path)
159
+ model = base
160
+ model.eval()
161
+ except Exception as exc: # pragma: no cover - fallback for offline hosts
162
+ print(f"Model/tokenizer load failed ({exc}); falling back to stub responses.")
163
+ use_stub = True
164
 
165
  probes = _load_eval_files(dataset_dir)
166
 
 
173
  red_team = _is_red_team(messages) or source.startswith("red_team") or "red_team" in tags
174
  identity = probe_id.startswith("identity_") or "identity" in tags or source.startswith("identity")
175
  prompt = _build_prompt(messages, tokenizer)
176
+ response = _stub_response(red_team, identity) if use_stub else _run_model(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
177
  passed, reasons = _evaluate_response(response, red_team, identity)
178
  if not passed:
179
  joined_reasons = "; ".join(reasons)
 
188
  "--dataset-dir",
189
  required=False,
190
  type=Path,
191
+ default=Path(os.environ["DATASET_DIR"]) if os.environ.get("DATASET_DIR") else None,
192
  help="Path to dataset repository (or set DATASET_DIR)",
193
  )
194
  parser.add_argument("--run", required=True, type=Path, help="Run directory containing adapter/")
195
  parser.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-7B-Instruct", help="Base model to load")
196
  parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length for probes")
197
  parser.add_argument("--strict", action="store_true", help="Exit non-zero on failures")
198
+ parser.add_argument("--use-stub", action="store_true", help="Use stubbed refusal responses (no model download)")
199
  args = parser.parse_args()
200
 
201
  if args.dataset_dir is None:
202
+ print(
203
+ "Dataset directory is required. Provide --dataset-dir or set DATASET_DIR (e.g., export DATASET_DIR=/absolute/path/to/blux-ca-dataset)"
204
+ )
205
  return 1
206
  dataset_dir = Path(args.dataset_dir)
207
 
208
  adapter_path = args.run / "adapter"
209
  if not adapter_path.exists():
210
  adapter_path = args.run / "adapter_model"
211
+ if not adapter_path.exists() and not args.use_stub:
212
+ print(f"Adapter path not found under run: {args.run}. Use --use-stub to run heuristic-only evaluation.")
213
  return 1
214
+ if not adapter_path.exists():
215
+ adapter_path = None
216
 
217
  total, failures, messages = run_evaluation(
218
  args.base_model,
 
220
  dataset_dir,
221
  args.strict,
222
  max_new_tokens=args.max_new_tokens,
223
+ use_stub=args.use_stub,
224
  )
225
 
226
  report_path = args.run / "eval_report.md"
train/train_adapter.py CHANGED
@@ -15,11 +15,11 @@ import torch
15
  import yaml
16
  from datasets import load_dataset
17
  from peft import LoraConfig, get_peft_model
18
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
19
  from trl import SFTTrainer
20
 
21
  from prepare_dataset import prepare_dataset
22
- from validate_dataset import validate_dataset
23
 
24
 
25
  def _load_yaml(path: Path) -> Dict:
@@ -33,21 +33,28 @@ def _write_json(path: Path, payload: Dict) -> None:
33
  json.dump(payload, handle, indent=2, sort_keys=True)
34
 
35
 
 
 
 
36
  def _resolve_dataset_dir(raw: Optional[Path]) -> Path:
37
  if raw:
38
  return raw
39
  env_dir = os.environ.get("DATASET_DIR")
40
  if env_dir:
41
  return Path(env_dir)
42
- raise ValueError("Dataset directory is required. Provide --dataset-dir or set DATASET_DIR")
 
 
43
 
44
 
45
- def _load_base_model_name(config: Dict, override: Optional[str]) -> str:
46
  env_override = os.environ.get("BASE_MODEL")
47
  if env_override:
48
  return env_override
49
  if override:
50
  return override
 
 
51
  return config.get("base_model", "Qwen/Qwen2.5-7B-Instruct")
52
 
53
 
@@ -83,20 +90,52 @@ def _build_dataset(prepared_path: Path, tokenizer):
83
  return dataset.map(add_text, remove_columns=[])
84
 
85
 
86
- def _init_model(base_model: str, quant_config: Optional[BitsAndBytesConfig]):
87
  kwargs = {"device_map": "auto"}
88
  if quant_config is not None:
89
  kwargs["quantization_config"] = quant_config
90
  else:
91
  kwargs["torch_dtype"] = torch.float32
92
  kwargs["low_cpu_mem_usage"] = True
93
- return AutoModelForCausalLM.from_pretrained(base_model, **kwargs)
94
-
95
-
96
- def _init_tokenizer(base_model: str):
97
- tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  tokenizer.padding_side = "right"
99
- if tokenizer.pad_token is None:
100
  tokenizer.pad_token = tokenizer.eos_token
101
  return tokenizer
102
 
@@ -126,13 +165,22 @@ def _persist_config_snapshot(run_dir: Path, train_cfg: Dict, mix_config: Dict, b
126
  def train(args: argparse.Namespace) -> Path:
127
  dataset_dir = _resolve_dataset_dir(args.dataset_dir)
128
  if not dataset_dir.exists():
129
- raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")
 
 
130
 
131
  train_cfg = _load_yaml(args.config)
132
  mix_cfg = _load_yaml(args.mix_config)
133
  if args.max_samples is not None:
134
  mix_cfg = {**mix_cfg, "max_samples": args.max_samples, "__override_max_samples": True}
135
- base_model = _load_base_model_name(train_cfg, args.base_model)
 
 
 
 
 
 
 
136
 
137
  if args.strict:
138
  _, errors = validate_dataset(dataset_dir, strict=True)
@@ -145,7 +193,7 @@ def train(args: argparse.Namespace) -> Path:
145
  args.output_root,
146
  run_name=args.run_name,
147
  override_max_samples=args.max_samples,
148
- strict=False,
149
  )
150
  run_dir = prepared_path.parent
151
 
@@ -155,7 +203,7 @@ def train(args: argparse.Namespace) -> Path:
155
  resolved_mix_cfg = _load_yaml(resolved_mix_path)
156
 
157
  quant_config = _quantization_config()
158
- tokenizer = _init_tokenizer(base_model)
159
  train_dataset = _build_dataset(prepared_path, tokenizer)
160
 
161
  # Dry-run: load a few samples and ensure tokenization + model load succeed.
@@ -167,7 +215,7 @@ def train(args: argparse.Namespace) -> Path:
167
  truncation=True,
168
  padding="longest",
169
  )
170
- _ = _init_model(base_model, quant_config)
171
  _persist_config_snapshot(run_dir, train_cfg, resolved_mix_cfg, base_model)
172
  print("Dry-run successful: dataset prepared, tokenizer + model loaded, tokenization OK.")
173
  return run_dir
 
15
  import yaml
16
  from datasets import load_dataset
17
  from peft import LoraConfig, get_peft_model
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GPT2Config, TrainingArguments
19
  from trl import SFTTrainer
20
 
21
  from prepare_dataset import prepare_dataset
22
+ from validate_dataset import run_cli_validator, validate_dataset
23
 
24
 
25
  def _load_yaml(path: Path) -> Dict:
 
33
  json.dump(payload, handle, indent=2, sort_keys=True)
34
 
35
 
36
+ EXAMPLE_DATASET_CMD = "export DATASET_DIR=/absolute/path/to/blux-ca-dataset"
37
+
38
+
39
  def _resolve_dataset_dir(raw: Optional[Path]) -> Path:
40
  if raw:
41
  return raw
42
  env_dir = os.environ.get("DATASET_DIR")
43
  if env_dir:
44
  return Path(env_dir)
45
+ raise ValueError(
46
+ f"Dataset directory is required. Provide --dataset-dir or set DATASET_DIR (e.g., {EXAMPLE_DATASET_CMD})"
47
+ )
48
 
49
 
50
+ def _load_base_model_name(config: Dict, override: Optional[str], prefer_cpu_safe: bool = False) -> str:
51
  env_override = os.environ.get("BASE_MODEL")
52
  if env_override:
53
  return env_override
54
  if override:
55
  return override
56
+ if prefer_cpu_safe:
57
+ return config.get("cpu_base_model", "Qwen/Qwen2.5-1.5B-Instruct")
58
  return config.get("base_model", "Qwen/Qwen2.5-7B-Instruct")
59
 
60
 
 
90
  return dataset.map(add_text, remove_columns=[])
91
 
92
 
93
+ def _init_model(base_model: str, quant_config: Optional[BitsAndBytesConfig], allow_stub: bool = False):
94
  kwargs = {"device_map": "auto"}
95
  if quant_config is not None:
96
  kwargs["quantization_config"] = quant_config
97
  else:
98
  kwargs["torch_dtype"] = torch.float32
99
  kwargs["low_cpu_mem_usage"] = True
100
+ try:
101
+ return AutoModelForCausalLM.from_pretrained(base_model, **kwargs)
102
+ except Exception as exc: # pragma: no cover - fallback for offline environments
103
+ if not allow_stub:
104
+ raise
105
+ print(f"Model load failed ({exc}); using stub GPT-2 config for dry-run.")
106
+ tiny_config = GPT2Config(n_embd=64, n_layer=2, n_head=2, n_positions=128, vocab_size=256)
107
+ return AutoModelForCausalLM.from_config(tiny_config)
108
+
109
+
110
+ class _StubTokenizer:
111
+ def __init__(self) -> None:
112
+ self.pad_token = "<|pad|>"
113
+ self.eos_token = "</s>"
114
+ self.padding_side = "right"
115
+
116
+ def apply_chat_template(self, messages: List[Dict], tokenize: bool = False, **_: Dict) -> str:
117
+ return "\n".join(f"{m.get('role')}: {m.get('content')}" for m in messages)
118
+
119
+ def __call__(self, texts, max_length: int = 2048, truncation: bool = True, padding: str = "longest") -> Dict:
120
+ if isinstance(texts, str):
121
+ texts = [texts]
122
+ input_ids = []
123
+ for text in texts:
124
+ length = min(len(text.split()), max_length)
125
+ input_ids.append(list(range(length)))
126
+ return {"input_ids": input_ids}
127
+
128
+
129
+ def _init_tokenizer(base_model: str, allow_stub: bool = False):
130
+ try:
131
+ tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
132
+ except Exception as exc: # pragma: no cover - fallback for offline environments
133
+ if not allow_stub:
134
+ raise
135
+ print(f"Tokenizer load failed ({exc}); using stub tokenizer for dry-run.")
136
+ tokenizer = _StubTokenizer()
137
  tokenizer.padding_side = "right"
138
+ if getattr(tokenizer, "pad_token", None) is None:
139
  tokenizer.pad_token = tokenizer.eos_token
140
  return tokenizer
141
 
 
165
  def train(args: argparse.Namespace) -> Path:
166
  dataset_dir = _resolve_dataset_dir(args.dataset_dir)
167
  if not dataset_dir.exists():
168
+ raise FileNotFoundError(
169
+ f"Dataset directory not found: {dataset_dir}. Set DATASET_DIR first (e.g., `{EXAMPLE_DATASET_CMD}`)."
170
+ )
171
 
172
  train_cfg = _load_yaml(args.config)
173
  mix_cfg = _load_yaml(args.mix_config)
174
  if args.max_samples is not None:
175
  mix_cfg = {**mix_cfg, "max_samples": args.max_samples, "__override_max_samples": True}
176
+ prefer_cpu_safe = args.dry_run and not torch.cuda.is_available() and not args.base_model and not os.environ.get(
177
+ "BASE_MODEL"
178
+ )
179
+ base_model = _load_base_model_name(train_cfg, args.base_model, prefer_cpu_safe=prefer_cpu_safe)
180
+
181
+ validation_errors = run_cli_validator(dataset_dir)
182
+ if validation_errors:
183
+ raise ValueError("\n".join(validation_errors))
184
 
185
  if args.strict:
186
  _, errors = validate_dataset(dataset_dir, strict=True)
 
193
  args.output_root,
194
  run_name=args.run_name,
195
  override_max_samples=args.max_samples,
196
+ strict=args.strict,
197
  )
198
  run_dir = prepared_path.parent
199
 
 
203
  resolved_mix_cfg = _load_yaml(resolved_mix_path)
204
 
205
  quant_config = _quantization_config()
206
+ tokenizer = _init_tokenizer(base_model, allow_stub=args.dry_run)
207
  train_dataset = _build_dataset(prepared_path, tokenizer)
208
 
209
  # Dry-run: load a few samples and ensure tokenization + model load succeed.
 
215
  truncation=True,
216
  padding="longest",
217
  )
218
+ _ = _init_model(base_model, quant_config, allow_stub=True)
219
  _persist_config_snapshot(run_dir, train_cfg, resolved_mix_cfg, base_model)
220
  print("Dry-run successful: dataset prepared, tokenizer + model loaded, tokenization OK.")
221
  return run_dir
train/train_qlora.py CHANGED
@@ -20,7 +20,10 @@ from transformers import (
20
  from trl import SFTTrainer
21
 
22
  from prepare_dataset import prepare_dataset
23
- from validate_dataset import validate_dataset, validate_file
 
 
 
24
 
25
 
26
  def _load_yaml(path: Path) -> Dict:
@@ -33,6 +36,26 @@ def _write_json(path: Path, payload: Dict) -> None:
33
  json.dump(payload, handle, indent=2, sort_keys=True)
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def _validate_sources(dataset_dir: Path, mix_config: Path) -> None:
37
  mix_cfg = _load_yaml(mix_config)
38
  data_dir = dataset_dir / "data"
@@ -97,10 +120,11 @@ def _init_model(base_model: str, lora_config: Dict) -> AutoModelForCausalLM:
97
 
98
 
99
  def train(args: argparse.Namespace) -> Path:
100
- if args.dataset_dir is None:
101
- raise ValueError("Dataset directory is required. Provide --dataset-dir or set DATASET_DIR")
102
- if not args.dataset_dir.exists():
103
- raise FileNotFoundError(f"Dataset directory not found: {args.dataset_dir}")
 
104
  if not args.config.exists():
105
  raise FileNotFoundError(f"Config not found: {args.config}")
106
  if not args.mix_config.exists():
@@ -109,13 +133,16 @@ def train(args: argparse.Namespace) -> Path:
109
  qlora_cfg = _load_yaml(args.config)
110
  mix_config = args.mix_config
111
 
112
- env_base_model = os.environ.get("BASE_MODEL")
113
- if env_base_model:
114
- qlora_cfg["base_model"] = env_base_model
 
 
 
115
 
116
- _validate_sources(args.dataset_dir, mix_config)
117
 
118
- prepared_path = prepare_dataset(args.dataset_dir, mix_config, args.output_root, run_name=args.run_name)
119
  run_dir = prepared_path.parent
120
 
121
  tokenizer = AutoTokenizer.from_pretrained(qlora_cfg["base_model"], use_fast=True)
@@ -185,7 +212,7 @@ def parse_args() -> argparse.Namespace:
185
  "--dataset-dir",
186
  required=False,
187
  type=Path,
188
- default=os.environ.get("DATASET_DIR"),
189
  help="Path to dataset repository (or set DATASET_DIR)",
190
  )
191
  parser.add_argument("--config", type=Path, default=Path("train/configs/qlora.yaml"), help="QLoRA config path")
@@ -193,6 +220,7 @@ def parse_args() -> argparse.Namespace:
193
  parser.add_argument("--output-root", type=Path, default=Path("runs"), help="Root directory for outputs")
194
  parser.add_argument("--run-name", type=str, default=os.environ.get("RUN_NAME"), help="Optional run folder name")
195
  parser.add_argument("--dry-run", action="store_true", help="Load model/tokenizer and tokenize sample without training")
 
196
  return parser.parse_args()
197
 
198
 
 
20
  from trl import SFTTrainer
21
 
22
  from prepare_dataset import prepare_dataset
23
+ from validate_dataset import run_cli_validator, validate_file
24
+
25
+
26
+ EXAMPLE_DATASET_CMD = "export DATASET_DIR=/absolute/path/to/blux-ca-dataset"
27
 
28
 
29
  def _load_yaml(path: Path) -> Dict:
 
36
  json.dump(payload, handle, indent=2, sort_keys=True)
37
 
38
 
39
+ def _resolve_dataset_dir(raw: Optional[Path]) -> Path:
40
+ if raw:
41
+ return raw
42
+ env_dir = os.environ.get("DATASET_DIR")
43
+ if env_dir:
44
+ return Path(env_dir)
45
+ raise ValueError(
46
+ f"Dataset directory is required. Provide --dataset-dir or set DATASET_DIR (e.g., {EXAMPLE_DATASET_CMD})"
47
+ )
48
+
49
+
50
+ def _resolve_base_model(cfg: Dict, prefer_cpu_safe: bool = False) -> str:
51
+ env_base_model = os.environ.get("BASE_MODEL")
52
+ if env_base_model:
53
+ return env_base_model
54
+ if prefer_cpu_safe:
55
+ return cfg.get("cpu_base_model", cfg.get("base_model"))
56
+ return cfg.get("base_model")
57
+
58
+
59
  def _validate_sources(dataset_dir: Path, mix_config: Path) -> None:
60
  mix_cfg = _load_yaml(mix_config)
61
  data_dir = dataset_dir / "data"
 
120
 
121
 
122
  def train(args: argparse.Namespace) -> Path:
123
+ dataset_dir = _resolve_dataset_dir(args.dataset_dir)
124
+ if not dataset_dir.exists():
125
+ raise FileNotFoundError(
126
+ f"Dataset directory not found: {dataset_dir}. Set DATASET_DIR first (e.g., `{EXAMPLE_DATASET_CMD}`)."
127
+ )
128
  if not args.config.exists():
129
  raise FileNotFoundError(f"Config not found: {args.config}")
130
  if not args.mix_config.exists():
 
133
  qlora_cfg = _load_yaml(args.config)
134
  mix_config = args.mix_config
135
 
136
+ prefer_cpu_safe = args.dry_run and not torch.cuda.is_available() and not os.environ.get("BASE_MODEL")
137
+ qlora_cfg["base_model"] = _resolve_base_model(qlora_cfg, prefer_cpu_safe=prefer_cpu_safe)
138
+
139
+ validation_errors = run_cli_validator(dataset_dir)
140
+ if validation_errors:
141
+ raise ValueError("\n".join(validation_errors))
142
 
143
+ _validate_sources(dataset_dir, mix_config)
144
 
145
+ prepared_path = prepare_dataset(dataset_dir, mix_config, args.output_root, run_name=args.run_name, strict=args.strict)
146
  run_dir = prepared_path.parent
147
 
148
  tokenizer = AutoTokenizer.from_pretrained(qlora_cfg["base_model"], use_fast=True)
 
212
  "--dataset-dir",
213
  required=False,
214
  type=Path,
215
+ default=Path(os.environ["DATASET_DIR"]) if os.environ.get("DATASET_DIR") else None,
216
  help="Path to dataset repository (or set DATASET_DIR)",
217
  )
218
  parser.add_argument("--config", type=Path, default=Path("train/configs/qlora.yaml"), help="QLoRA config path")
 
220
  parser.add_argument("--output-root", type=Path, default=Path("runs"), help="Root directory for outputs")
221
  parser.add_argument("--run-name", type=str, default=os.environ.get("RUN_NAME"), help="Optional run folder name")
222
  parser.add_argument("--dry-run", action="store_true", help="Load model/tokenizer and tokenize sample without training")
223
+ parser.add_argument("--strict", action="store_true", help="Validate dataset strictly before mixing")
224
  return parser.parse_args()
225
 
226
 
train/validate_dataset.py CHANGED
@@ -8,6 +8,7 @@ from __future__ import annotations
8
  import argparse
9
  import importlib.util
10
  import json
 
11
  import sys
12
  from pathlib import Path
13
  from typing import Dict, List, Optional, Tuple
@@ -15,6 +16,32 @@ from typing import Dict, List, Optional, Tuple
15
  SYSTEM_PLACEHOLDER = "<SYSTEM_PROMPT_FROM_BLUX_CA>"
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def _load_external_validator(dataset_dir: Path):
19
  """Load dataset-provided validator if available.
20
 
@@ -146,6 +173,14 @@ def validate_dataset(dataset_dir: Path, files: Optional[str] = None, strict: boo
146
  if not eval_dir.exists():
147
  return 0, [f"Eval probes missing: {eval_dir}"]
148
 
 
 
 
 
 
 
 
 
149
  external_validator = _load_external_validator(dataset_dir)
150
  if external_validator:
151
  print("Using dataset-supplied validator")
 
8
  import argparse
9
  import importlib.util
10
  import json
11
+ import subprocess
12
  import sys
13
  from pathlib import Path
14
  from typing import Dict, List, Optional, Tuple
 
16
  SYSTEM_PLACEHOLDER = "<SYSTEM_PROMPT_FROM_BLUX_CA>"
17
 
18
 
19
+ def run_cli_validator(dataset_dir: Path, files: Optional[List[Path]] = None) -> List[str]:
20
+ """Invoke the dataset repository's validator script via subprocess."""
21
+
22
+ validator_path = dataset_dir / "tools" / "validate_jsonl.py"
23
+ if not validator_path.exists():
24
+ return []
25
+
26
+ rel_files = []
27
+ if files:
28
+ for f in files:
29
+ if f.is_absolute() and dataset_dir in f.parents:
30
+ rel_files.append(str(f.relative_to(dataset_dir)))
31
+ else:
32
+ rel_files.append(str(f))
33
+
34
+ cmd = [sys.executable, str(validator_path), *rel_files]
35
+ result = subprocess.run(cmd, capture_output=True, text=True, cwd=dataset_dir)
36
+ if result.returncode != 0:
37
+ output = (result.stdout + "\n" + result.stderr).strip()
38
+ return [line for line in output.splitlines() if line.strip()] or [
39
+ f"Validator exited with code {result.returncode}",
40
+ f"Re-run manually: python {validator_path}",
41
+ ]
42
+ return []
43
+
44
+
45
  def _load_external_validator(dataset_dir: Path):
46
  """Load dataset-provided validator if available.
47
 
 
173
  if not eval_dir.exists():
174
  return 0, [f"Eval probes missing: {eval_dir}"]
175
 
176
+ missing_files = [path for path in candidates if not path.exists()]
177
+ if missing_files:
178
+ return 0, [f"Missing file: {path}" for path in missing_files]
179
+
180
+ cli_errors = run_cli_validator(dataset_dir, candidates)
181
+ if cli_errors:
182
+ return 0, cli_errors
183
+
184
  external_validator = _load_external_validator(dataset_dir)
185
  if external_validator:
186
  print("Using dataset-supplied validator")