| | |
| | |
| | |
| | """ |
| | Run script for fine-tuning OlmoE with adapters on specific text domains. |
| | Handles argument parsing and configuration. |
| | """ |
| |
|
import argparse
import os
import subprocess
import sys
from dataclasses import dataclass, field
from typing import Optional

from transformers import (
    HfArgumentParser,
    TrainingArguments,
)
| |
|
| |
|
@dataclass
class ScriptArguments:
    """
    CLI options for the run script that fall outside TrainingArguments.

    Each field becomes a ``--flag`` via HfArgumentParser; the ``help``
    metadata is surfaced in the generated usage text.
    """
    # Hub ID or local path of the base model to attach adapters to.
    model_path: str = field(default="allenai/OLMo-7B-Instruct", metadata={"help": "Path to the model to fine-tune"})
    # Where checkpoints and logs are written.
    output_dir: str = field(default="./output_olmoe_adapter", metadata={"help": "Directory to save the model and logs"})
    # Bottleneck dimension of the adapter layers.
    adapter_size: int = field(default=64, metadata={"help": "Size of the adapter layers"})
    # HF datasets identifier for the training corpus.
    dataset_name: str = field(default="mlfoundations/dclm-baseline-1.0", metadata={"help": "Name of the dataset to use"})
    # Hard cap on optimizer steps (streaming dataset has no epoch notion).
    max_steps: int = field(default=10000, metadata={"help": "Maximum number of training steps"})
    learning_rate: float = field(default=5e-5, metadata={"help": "Learning rate for fine-tuning"})
    per_device_batch_size: int = field(default=8, metadata={"help": "Batch size per device"})
    gradient_accumulation_steps: int = field(default=1, metadata={"help": "Number of steps to accumulate gradients"})
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
def main():
    """
    Parse CLI arguments and launch the adapter fine-tuning script.

    Builds the argument list for ``train_olmoe_adapter.py`` and runs it as a
    subprocess with the current directory prepended to PYTHONPATH. Exits with
    the child's exit code on failure.
    """
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    os.makedirs(args.output_dir, exist_ok=True)

    # Use sys.executable so the child runs under the same interpreter
    # (virtualenv/conda) as this launcher, not whatever "python" resolves to.
    cmd = [
        sys.executable,
        "train_olmoe_adapter.py",
        # Model / adapter configuration
        f"--model_name_or_path={args.model_path}",
        f"--adapter_size={args.adapter_size}",
        "--freeze_base_model=True",
        f"--checkpoint_dir={args.output_dir}",
        # Dataset configuration (streamed; no full download)
        f"--dataset_name={args.dataset_name}",
        "--streaming=True",
        "--streaming_buffer_size=8192",
        "--max_seq_length=1024",
        # Training configuration
        f"--output_dir={args.output_dir}",
        f"--per_device_train_batch_size={args.per_device_batch_size}",
        f"--gradient_accumulation_steps={args.gradient_accumulation_steps}",
        f"--learning_rate={args.learning_rate}",
        f"--max_steps={args.max_steps}",
        "--warmup_steps=500",
        "--logging_steps=10",
        "--save_steps=1000",
        "--save_total_limit=2",
        "--dataloader_num_workers=4",
        "--seed=42",
    ]

    print(f"Running command: {' '.join(cmd)}")

    # Prepend CWD to PYTHONPATH in a copied environment rather than
    # clobbering any PYTHONPATH the caller already set.
    env = os.environ.copy()
    env["PYTHONPATH"] = os.pathsep.join(
        p for p in (os.getcwd(), env.get("PYTHONPATH")) if p
    )

    # subprocess.run with an argument list avoids shell quoting/injection
    # issues and returns the real exit code (os.system returns a wait
    # status on Unix, which is not valid to pass to sys.exit).
    result = subprocess.run(cmd, env=env)

    if result.returncode != 0:
        print(f"Training failed with exit code {result.returncode}")
        sys.exit(result.returncode)

    print("Training completed successfully!")
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |