blux-ca / train /prepare_dataset.py
~JADIS
Improve training validation and offline safety flow (#9)
5ce8003
"""Prepare a weighted, shuffled training set for BLUX-cA QLoRA."""
from __future__ import annotations
import argparse
import json
import random
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
import yaml
from validate_dataset import SYSTEM_PLACEHOLDER, validate_dataset
def _timestamp() -> str:
return datetime.utcnow().strftime("%Y%m%d_%H%M%S")
def _load_config(path: Path) -> Dict:
with path.open("r", encoding="utf-8") as handle:
return yaml.safe_load(handle)
def _load_jsonl(path: Path) -> List[Dict]:
records: List[Dict] = []
with path.open("r", encoding="utf-8") as handle:
for line in handle:
if line.strip():
records.append(json.loads(line))
return records
def _sample_records(records: List[Dict], target: int, rng: random.Random) -> List[Dict]:
if not records:
return []
if target <= len(records):
return rng.sample(records, target)
return [rng.choice(records) for _ in range(target)]
def prepare_dataset(
dataset_dir: Path,
mix_config: Path,
output_root: Path,
run_name: Optional[str] = None,
override_max_samples: Optional[int] = None,
strict: bool = False,
) -> Path:
"""Create a weighted, deterministic training mix.
If ``override_max_samples`` is provided it takes precedence over the YAML
``max_samples`` value. When ``strict`` is True, data files are validated
before mixing.
"""
if strict:
_, errors = validate_dataset(dataset_dir, strict=True)
if errors:
joined = "\n".join(errors)
raise ValueError(f"Dataset validation failed before mixing:\n{joined}")
config = _load_config(mix_config)
sources = config.get("sources", [])
shuffle = bool(config.get("shuffle", True))
max_samples = override_max_samples if override_max_samples is not None else config.get("max_samples")
seed = int(config.get("seed", 42))
rng = random.Random(seed)
total_weight = sum(float(src.get("weight", 1.0)) for src in sources)
if total_weight <= 0:
raise ValueError("Total weight must be positive")
collected: List[Dict] = []
for src in sources:
raw_path = Path(src["file"])
if raw_path.is_absolute():
file_path = raw_path
elif raw_path.parts and raw_path.parts[0] == "data":
file_path = dataset_dir / raw_path
else:
file_path = dataset_dir / "data" / raw_path
if not file_path.exists():
raise FileNotFoundError(f"Missing dataset file: {file_path}")
weight = float(src.get("weight", 1.0))
records = _load_jsonl(file_path)
if max_samples is None:
target = max(len(records), 0)
else:
target = max(1, round((weight / total_weight) * max_samples))
sampled = _sample_records(records, target, rng)
collected.extend(sampled)
if not collected:
raise ValueError("No samples collected from provided sources")
if shuffle:
rng.shuffle(collected)
folder_name = _timestamp() if not run_name else f"{_timestamp()}_{run_name}"
run_dir = output_root / folder_name
run_dir.mkdir(parents=True, exist_ok=True)
output_path = run_dir / "prepared_train.jsonl"
with output_path.open("w", encoding="utf-8") as handle:
for record in collected:
if "messages" in record:
system_msgs = [m for m in record["messages"] if m.get("role") == "system"]
if system_msgs:
system_msgs[0]["content"] = SYSTEM_PLACEHOLDER
handle.write(json.dumps(record, ensure_ascii=False) + "\n")
resolved_mix_path = run_dir / "mix_config_resolved.yaml"
with resolved_mix_path.open("w", encoding="utf-8") as handle:
yaml.safe_dump(
{
**config,
"max_samples": max_samples,
"seed": seed,
"dataset_dir": str(dataset_dir),
},
handle,
sort_keys=False,
)
return output_path
def main() -> int:
parser = argparse.ArgumentParser(description="Prepare weighted training data for QLoRA")
parser.add_argument("--dataset-dir", required=True, type=Path, help="Path to dataset repository")
parser.add_argument("--mix-config", type=Path, default=Path("train/configs/dataset_mix.yaml"), help="Mixing config YAML")
parser.add_argument("--output-root", type=Path, default=Path("runs"), help="Root directory for run outputs")
parser.add_argument("--run-name", type=str, default=None, help="Optional run folder name (otherwise timestamp)")
parser.add_argument("--max-samples", type=int, default=None, help="Override max_samples in config for quick smoke runs")
parser.add_argument("--strict", action="store_true", help="Validate input files strictly before mixing")
args = parser.parse_args()
if not args.dataset_dir.exists():
print(f"Dataset directory not found: {args.dataset_dir}")
return 1
output_path = prepare_dataset(
args.dataset_dir,
args.mix_config,
args.output_root,
run_name=args.run_name,
override_max_samples=args.max_samples,
strict=args.strict,
)
print(f"Prepared dataset written to {output_path}")
return 0
if __name__ == "__main__":
raise SystemExit(main())