File size: 1,387 Bytes
fd0b01f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# merge lora into base, export fp32 onnx, quantize to int8.
# writes runs/<run-id>/merged/, onnx/model.onnx, onnx/int8/model.onnx.

import argparse
from pathlib import Path

from cleanup.config import load_train_config
from cleanup.export.merge import merge_adapter
from cleanup.export.quantize import quantize_int8
from cleanup.export.to_onnx import export_onnx


def main(argv: "list[str] | None" = None) -> None:
    """Merge a run's LoRA adapter into the base model and export it.

    All paths live under ``<runs-dir>/<run-id>/``:

    1. the adapter in ``model/`` is merged into the base model described by
       ``--config`` and written to ``merged/`` (``merge_adapter``);
    2. unless ``--skip-onnx``, the merged model is exported to ONNX under
       ``onnx/`` (``export_onnx``);
    3. unless ``--skip-int8`` (or ``--skip-onnx``), the exported ONNX model
       is quantized to int8 under ``onnx/int8/`` (``quantize_int8``).

    Args:
        argv: Command-line arguments, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so the
            script entry point behaves exactly as before.

    Raises:
        SystemExit: via argparse (exit code 2) on invalid arguments or when
            the run's adapter directory does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/train.yaml")
    parser.add_argument("--runs-dir", default="runs")
    parser.add_argument("--run-id", required=True)
    parser.add_argument(
        "--skip-int8", action="store_true", help="export fp32 onnx only, no int8 quantization"
    )
    parser.add_argument("--skip-onnx", action="store_true", help="merge only, no onnx export")
    args = parser.parse_args(argv)

    run_dir = Path(args.runs_dir) / args.run_id
    adapter_dir = run_dir / "model"
    merged_dir = run_dir / "merged"
    onnx_dir = run_dir / "onnx"
    int8_dir = onnx_dir / "int8"

    # Fail fast with a CLI-style error on a mistyped --run-id / --runs-dir,
    # before loading the config or starting the (potentially slow) merge.
    if not adapter_dir.exists():
        parser.error(f"adapter directory not found: {adapter_dir}")

    cfg = load_train_config(args.config)
    merge_adapter(cfg, adapter_dir, merged_dir)
    if args.skip_onnx:
        print("[export] skipping onnx per --skip-onnx")
        return
    fp32_onnx = export_onnx(merged_dir, onnx_dir)
    if not args.skip_int8:
        quantize_int8(fp32_onnx, int8_dir)
    print(f"next: make benchmark RUN_ID={args.run_id}    (LOCAL cpu only)")


if __name__ == "__main__":
    # Run the export CLI only when executed as a script, not when imported.
    main()