WrinkleBrane / Wrinkle /09_standalone_model /benchmark_optimizers.py
WCNegentropy's picture
Upload 510 files
3d7f6c5 verified
#!/usr/bin/env python3
"""Benchmark: AdamW vs Adafactor vs Muon on WrinkleBrane.
All runs use FP32 baseline config (same as the 10K step run).
500 steps each, same data, same seed for model init.
Optimizers:
1. AdamW — current baseline (lr=3e-4, betas=(0.9, 0.95), wd=0.01)
2. Adafactor — memory-efficient adaptive, no external scheduler (PyTorch 2.8)
3. Muon — Momentum + Newton-Schulz orthogonalization, lr=0.05, clip=2.0
DESIGNED FOR JUPYTERLAB TERMINAL — output is tee'd to both stdout and a
timestamped log file under logs/ (or --log_dir) so browser crashes don't
lose results. Run as:
cd /data/WrinkleBrane-Research/09_standalone_model
OMP_NUM_THREADS=8 MKL_NUM_THREADS=8 PYTHONPATH=src \\
nohup python3 -u benchmark_optimizers.py &
tail -f logs/benchmark_optimizers_<timestamp>.log
"""
from __future__ import annotations
import copy
import gc
import math
import os
import subprocess
import sys
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import argparse
import torch
from torch import nn, Tensor
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src"))
from wrinklebrane.standalone_model import WrinkleBraneConfig, WrinkleBraneModel
from wrinklebrane.data import load_train_val, VOCAB_SIZE
# ============================================================================
# Tee: write to stdout AND a log file simultaneously
# ============================================================================
class Tee:
    """Duplicate stdout writes to a log file.

    An instance is meant to be assigned to ``sys.stdout`` so that every
    ``print()`` call goes to both the real terminal and a persistent,
    line-buffered log file. The log therefore survives a browser/IDE crash.

    Parameters
    ----------
    log_path : str
        File to (over)write. Its parent directory is created when missing;
        a bare filename (no directory part) is written to the current
        working directory.
    """

    def __init__(self, log_path: str):
        self.terminal = sys.__stdout__
        # os.makedirs("") raises, so only create a directory when there is one.
        log_dir = os.path.dirname(log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Explicit UTF-8: the benchmark prints box-drawing characters, which a
        # non-UTF-8 locale default could refuse to encode.
        self.log = open(log_path, "w", buffering=1, encoding="utf-8")  # line-buffered
        print(f" Logging to: {log_path}", file=self.terminal)

    def write(self, message):
        """Write *message* to the terminal and the log; return its length."""
        self.terminal.write(message)
        self.log.write(message)
        return len(message)

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file and undo the ``sys.stdout`` redirection.

        Safe to call more than once. ``sys.stdout`` is only restored when it
        still points at this instance, so closing a Tee that was never
        installed does not clobber someone else's redirection.
        """
        if not self.log.closed:
            self.log.close()
        if sys.stdout is self:
            sys.stdout = self.terminal
# ============================================================================
# Configuration
# ============================================================================
# Default number of optimizer steps per benchmark run.
BENCHMARK_STEPS = 500
# Sequences per training batch.
BATCH_SIZE = 16
# Tokens per sequence; also used for the model's K and max_seq_len in make_config().
SEQ_LEN = 128
# Linear-warmup length (in steps) handed to make_cosine_scheduler().
WARMUP = 50
# Emit a training-log line (and record avg loss / throughput) every N steps.
LOG_EVERY = 50
# Run a validation pass every N steps.
EVAL_EVERY = 100
# Passed to torch.manual_seed() before each model is built.
SEED = 42
def make_config() -> WrinkleBraneConfig:
    """Build the model configuration shared by every benchmark run.

    The values mirror the configuration of the 10K-step training run; the
    sequence length drives both ``K`` and ``max_seq_len``.
    """
    settings = {
        # Model size
        "vocab_size": VOCAB_SIZE,
        "d_model": 128,
        "n_layers": 6,
        "n_heads": 4,
        # Codebook settings
        "L": 16,
        "K": SEQ_LEN,
        "code_init": "hadamard",
        "learnable_codes": True,
        "temperature": 0.5,
        # Feed-forward block
        "ffn_expansion": 4,
        "use_gated_ffn": True,
        # Sequence length, regularisation, and weight sharing
        "max_seq_len": SEQ_LEN,
        "dropout": 0.1,
        "ortho_lambda": 0.01,
        "persistence_lambda": 0.99,
        "weight_tying": True,
    }
    return WrinkleBraneConfig(**settings)
# ============================================================================
# Muon Optimizer — Momentum with Newton-Schulz Orthogonalization
# ============================================================================
class Muon(torch.optim.Optimizer):
    """Muon: Momentum + Orthogonalized Updates via Newton-Schulz.

    For every 2-D parameter whose dimensions are both larger than 1, the
    momentum buffer is (approximately) orthogonalized with Newton-Schulz
    iterations and rescaled to the buffer's norm before being applied.
    Note that this rule is purely shape-based, so 2-D embedding tables are
    orthogonalized as well. All other parameters (1-D vectors, scalars,
    higher-rank tensors) fall back to plain momentum SGD.

    Author's rationale for using it on WrinkleBrane:
    - Codebook parameters benefit from orthogonality preservation
    - Projection matrices (W_v, W_q, W_o) stay well-conditioned
    - The Newton-Schulz step acts like a natural preconditioner

    Reference: Keller Jordan et al., "Muon: An optimizer for hidden layers"

    Parameters
    ----------
    params : iterable
        Model parameters.
    lr : float
        Learning rate (default: 0.02, typically higher than Adam).
    momentum : float
        Momentum coefficient (default: 0.95).
    ns_steps : int
        Number of Newton-Schulz orthogonalization iterations (default: 5).
    weight_decay : float
        Decoupled weight decay (default: 0.0).

    Raises
    ------
    ValueError
        If any hyper-parameter is negative.
    """

    def __init__(
        self,
        params,
        lr: float = 0.02,
        momentum: float = 0.95,
        ns_steps: int = 5,
        weight_decay: float = 0.0,
    ):
        # Same style of sanity checks the built-in torch optimizers perform.
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if ns_steps < 0:
            raise ValueError(f"Invalid ns_steps value: {ns_steps}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            momentum=momentum,
            ns_steps=ns_steps,
            weight_decay=weight_decay,
        )
        super().__init__(params, defaults)

    @staticmethod
    def newton_schulz_orthogonalize(M: Tensor, steps: int = 5) -> Tensor:
        """Approximately orthogonalize matrix M using Newton-Schulz iteration.

        Approximates the polar factor of M (the nearest semi-orthogonal
        matrix) with the iteration

            X_{k+1} = 1.5 * X_k - 0.5 * X_k @ X_k^T @ X_k

        after scaling M by its Frobenius norm, which bounds the spectral norm
        by 1 so the iteration converges.

        Parameters
        ----------
        M : Tensor [m, n]
        steps : int
            Number of NS iterations (5 is usually sufficient).

        Returns
        -------
        Tensor [m, n]
            Orthogonalized matrix.
        """
        rows, cols = M.shape
        # X @ X.T has shape [rows, rows], so the Gram matrix is cheapest when
        # rows <= cols. Transpose tall matrices; the iteration commutes with
        # transposition, so the result is unchanged.
        transposed = rows > cols
        X = M.T if transposed else M
        # Frobenius norm >= spectral norm, so all singular values end up <= 1.
        X = X / (X.norm() + 1e-7)
        for _ in range(steps):
            gram = X @ X.T
            X = 1.5 * X - 0.5 * gram @ X
        return X.T if transposed else X

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Parameters
        ----------
        closure : callable, optional
            Re-evaluates the model and returns the loss; run with gradients
            enabled.

        Returns
        -------
        The closure's return value, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            momentum = group["momentum"]
            ns_steps = group["ns_steps"]
            wd = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                # Decoupled weight decay: shrink the weights directly rather
                # than adding wd * p to the gradient.
                if wd > 0:
                    p.mul_(1 - lr * wd)

                # Get or create momentum buffer
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p)
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(grad)

                if p.dim() == 2 and min(p.shape) > 1:
                    # Matrix parameter: orthogonalize the momentum buffer ...
                    update = self.newton_schulz_orthogonalize(buf, steps=ns_steps)
                    # ... then rescale so the update keeps the buffer's norm.
                    update = update * (buf.norm() / (update.norm() + 1e-7))
                else:
                    # Everything else: standard momentum
                    update = buf

                p.add_(update, alpha=-lr)

        return loss
# ============================================================================
# Training infrastructure
# ============================================================================
@dataclass
class BenchmarkResult:
    """Metrics collected from one optimizer benchmark run."""

    # Display name of the run (optimizer plus hyper-parameters).
    name: str
    # Number of training steps requested.
    steps: int
    # Wall-clock duration of the training loop, in seconds.
    total_time: float
    # Average training loss per logging interval.
    losses: List[float]
    # (step, validation loss) pairs in the order they were recorded.
    eval_losses: List[Tuple[int, float]]
    # Throughput per logging interval, in tokens per second.
    tok_per_sec: List[float]
    # Gradient norm recorded at every step.
    grad_norms: List[float]

    @property
    def avg_tok_per_sec(self) -> float:
        """Mean of the recorded throughput samples, or 0 when there are none."""
        samples = self.tok_per_sec
        if not samples:
            return 0
        return sum(samples) / len(samples)

    @property
    def final_loss(self) -> float:
        """Last logged training loss, or NaN when nothing was logged."""
        if not self.losses:
            return float("nan")
        return self.losses[-1]

    @property
    def final_eval_loss(self) -> float:
        """Loss of the last recorded evaluation, or NaN when there is none."""
        if not self.eval_losses:
            return float("nan")
        _, last_loss = self.eval_losses[-1]
        return last_loss
def evaluate(model, val_corpus, seq_len, batch_size=16, n_batches=10, device=None):
    """Return the mean per-token cross-entropy over a few validation batches.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(inp)``; its output is unpacked as ``[B, T, V]`` logits.
    val_corpus
        Object exposing ``get_batch(batch_size, seq_len) -> (inp, tgt)``.
    seq_len : int
        Sequence length requested from the corpus.
    batch_size : int
        Sequences per batch (default: 16).
    n_batches : int
        Number of batches to average over (default: 10).
    device : torch.device, optional
        Where to run; defaults to the device of the model's first parameter.

    Returns
    -------
    float
        Summed token loss divided by the number of tokens evaluated.

    Notes
    -----
    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if a batch raises.
    """
    if device is None:
        device = next(model.parameters()).device

    # Remember the caller's mode so evaluation has no lasting side effect.
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    try:
        with torch.no_grad():
            for _ in range(n_batches):
                inp, tgt = val_corpus.get_batch(batch_size, seq_len)
                inp, tgt = inp.to(device), tgt.to(device)
                logits = model(inp)
                B, T, V = logits.shape
                # Sum (not mean) so batches are weighted by their token count.
                loss = nn.functional.cross_entropy(
                    logits.reshape(B * T, V), tgt.reshape(B * T),
                    reduction="sum",
                )
                total_loss += loss.item()
                total_tokens += B * T
    finally:
        model.train(was_training)
    return total_loss / total_tokens
def run_benchmark(
    name: str,
    model: WrinkleBraneModel,
    config: WrinkleBraneConfig,
    optimizer: torch.optim.Optimizer,
    scheduler,
    train_corpus,
    val_corpus,
    steps: int = BENCHMARK_STEPS,
    grad_clip: float = 1.0,
    device: torch.device = None,
) -> BenchmarkResult:
    """Run a training benchmark for the given number of steps.

    Parameters
    ----------
    name : str
        Label used in the log output and stored on the result.
    model : WrinkleBraneModel
        Model to train in place; must expose ``ortho_loss()``.
    config : WrinkleBraneConfig
        Supplies ``ortho_lambda``, the weight of the orthogonality penalty.
    optimizer : torch.optim.Optimizer
        Optimizer under test.
    scheduler
        LR scheduler stepped once per training step, or None.
    train_corpus, val_corpus
        Objects exposing ``get_batch(batch_size, seq_len) -> (inp, tgt)``.
    steps : int
        Number of optimizer steps to run.
    grad_clip : float
        Max gradient norm passed to ``clip_grad_norm_``.
    device : torch.device, optional
        Defaults to the device of the model's first parameter.

    Returns
    -------
    BenchmarkResult
        Per-interval training losses and throughput, per-step gradient norms,
        evaluation losses (step 0, every EVAL_EVERY steps, and the final
        step) and the wall-clock time of the training loop.
    """
    if device is None:
        device = next(model.parameters()).device

    print(f"\n{'='*70}")
    print(f" BENCHMARK: {name}")
    print(f" {steps} steps, batch_size={BATCH_SIZE}, seq_len={SEQ_LEN}")
    print(f"{'='*70}")

    param_count = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {param_count:,}")

    # Initial eval
    init_eval = evaluate(model, val_corpus, SEQ_LEN)
    print(f" Initial eval loss: {init_eval:.4f} (PPL {math.exp(min(init_eval, 20)):.2f})")

    # Tracking
    losses = []
    eval_losses = [(0, init_eval)]
    tok_per_sec_list = []
    grad_norms_list = []
    running_loss = 0.0
    running_tokens = 0

    # perf_counter is monotonic, so intervals cannot be skewed by clock changes.
    interval_start = time.perf_counter()
    total_start = time.perf_counter()

    print(f"\n Training started at {time.strftime('%H:%M:%S')}")
    print(f" {'─'*64}")

    for step in range(1, steps + 1):
        model.train()
        optimizer.zero_grad()

        inp, tgt = train_corpus.get_batch(BATCH_SIZE, SEQ_LEN)
        inp, tgt = inp.to(device), tgt.to(device)

        # Forward
        logits = model(inp)
        B, T, V = logits.shape
        task_loss = nn.functional.cross_entropy(
            logits.reshape(B * T, V), tgt.reshape(B * T),
        )

        # Ortho regularization
        ortho = config.ortho_lambda * model.ortho_loss()
        total_loss = task_loss + ortho

        # Backward
        total_loss.backward()
        # clip_grad_norm_ returns the total norm measured before clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip).item()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Only the task loss is tracked; the ortho penalty is excluded.
        running_loss += task_loss.item() * B * T
        running_tokens += B * T
        grad_norms_list.append(grad_norm)

        # Log
        if step % LOG_EVERY == 0:
            elapsed = time.perf_counter() - interval_start
            avg_loss = running_loss / running_tokens
            tps = running_tokens / elapsed
            losses.append(avg_loss)
            tok_per_sec_list.append(tps)

            # Get current LR
            if scheduler is not None:
                lr = scheduler.get_last_lr()[0]
            else:
                lr = optimizer.param_groups[0]["lr"]

            print(f" step {step:4d}/{steps} | "
                  f"loss={avg_loss:.4f} ppl={math.exp(min(avg_loss, 20)):7.2f} | "
                  f"lr={lr:.2e} gnorm={grad_norm:.2f} | "
                  f"{tps:6.0f} tok/s",
                  flush=True)

            running_loss = 0.0
            running_tokens = 0
            interval_start = time.perf_counter()

        # Eval
        if step % EVAL_EVERY == 0:
            eval_begin = time.perf_counter()
            val_loss = evaluate(model, val_corpus, SEQ_LEN)
            eval_losses.append((step, val_loss))
            print(f" >>> EVAL step {step}: loss={val_loss:.4f}, "
                  f"ppl={math.exp(min(val_loss, 20)):.2f}", flush=True)
            # Push the interval origin forward so the time spent evaluating is
            # not billed to the next training-throughput measurement.
            interval_start += time.perf_counter() - eval_begin

    total_time = time.perf_counter() - total_start

    # Final eval: reuse the in-loop evaluation when the last step already ran
    # one, so each step appears once in eval_losses and the curve table agrees
    # with the reported final loss.
    if eval_losses[-1][0] == steps:
        final_eval = eval_losses[-1][1]
    else:
        final_eval = evaluate(model, val_corpus, SEQ_LEN)
        eval_losses.append((steps, final_eval))

    print(f"\n {'─'*64}")
    print(f" DONE: {name}")
    print(f" Total time: {total_time:.1f}s ({total_time/60:.1f} min)")
    print(f" Final train loss: {losses[-1]:.4f}" if losses else " No losses recorded")
    print(f" Final eval loss: {final_eval:.4f} (PPL {math.exp(min(final_eval, 20)):.2f})")
    print(f" Avg throughput: {sum(tok_per_sec_list)/len(tok_per_sec_list):.0f} tok/s" if tok_per_sec_list else "")

    return BenchmarkResult(
        name=name,
        steps=steps,
        total_time=total_time,
        losses=losses,
        eval_losses=eval_losses,
        tok_per_sec=tok_per_sec_list,
        grad_norms=grad_norms_list,
    )
def make_cosine_scheduler(optimizer, warmup, total_steps):
    """Cosine LR schedule with linear warmup.

    The LR multiplier ramps linearly from 0 to 1 over the first ``warmup``
    steps, then follows half a cosine down to 0 at ``total_steps``. Past
    ``total_steps`` the multiplier stays at 0 instead of following the cosine
    back up.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer whose learning rate is scaled.
    warmup : int
        Number of warmup steps.
    total_steps : int
        Step at which the cosine decay reaches 0.

    Returns
    -------
    torch.optim.lr_scheduler.LambdaLR
    """
    # Guard against total_steps <= warmup, which would divide by zero.
    decay_steps = max(total_steps - warmup, 1)

    def lr_schedule(step):
        if step < warmup:
            return step / warmup
        # Clamp so that stepping beyond total_steps cannot raise the LR again.
        progress = min((step - warmup) / decay_steps, 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
# ============================================================================
# Main benchmark
# ============================================================================
def _short_name(result):
    """Return a result's optimizer name without the parenthesised settings."""
    return result.name.split("(")[0].strip()


def _print_banner(title, inner_width=37):
    """Print *title* inside a three-line box."""
    rule = "─" * inner_width
    print(f" ┌{rule}┐")
    print(f" │ {title:<{inner_width - 1}}│")
    print(f" └{rule}┘")


def _print_row(label, cells, label_width=28):
    """Print one table row: a left-aligned label followed by one cell per run."""
    print(f" {label:<{label_width}}" + "".join(f" │ {cell} " for cell in cells))


def _cpu_model_name():
    """Best-effort CPU model string read from /proc/cpuinfo ("unknown" if absent)."""
    try:
        with open("/proc/cpuinfo", encoding="utf-8", errors="replace") as cpuinfo:
            for line in cpuinfo:
                if "model name" in line and ":" in line:
                    return line.split(":")[1].strip()
    except OSError:
        pass
    return "unknown"


def _available_cpu_count():
    """Thread count reported by ``nproc``, falling back to ``os.cpu_count()``.

    ``nproc`` is tried first because GNU nproc takes CPU affinity and the
    OMP_NUM_THREADS environment variable into account.
    """
    try:
        return int(subprocess.check_output("nproc", text=True).strip())
    except (OSError, subprocess.CalledProcessError, ValueError):
        return os.cpu_count() or 1


def _print_system_info(device):
    """Print hardware details and configure the CPU thread count."""
    print(" Hardware:")
    print(f" CPU: {_cpu_model_name()}")
    print(f" Device: {device}")
    if device.type == "cuda":
        print(f" GPU: {torch.cuda.get_device_name(device)}")
        print(f" VRAM: {torch.cuda.get_device_properties(device).total_memory / 1e9:.1f} GB")

    # Set thread count properly (CPU only; on GPU, threads matter less)
    n_threads = _available_cpu_count()
    torch.set_num_threads(n_threads)
    os.environ["OMP_NUM_THREADS"] = str(n_threads)
    os.environ["MKL_NUM_THREADS"] = str(n_threads)
    print(f" Torch threads: {n_threads}")
    print(f" PyTorch: {torch.__version__}")
    print()


def _run_all_optimizers(steps, device, train_corpus, val_corpus):
    """Train one freshly seeded model per optimizer and return the results.

    Each entry is (banner title, result name, optimizer factory,
    use cosine scheduler, gradient-clip norm).
    """
    specs = [
        (
            "AdamW (baseline)",
            "AdamW (lr=3e-4, β=(0.9,0.95), wd=0.01)",
            lambda params: torch.optim.AdamW(
                params,
                lr=3e-4,
                weight_decay=0.01,
                betas=(0.9, 0.95),
            ),
            True,
            1.0,
        ),
        (
            # Adafactor config (v2 — fixed):
            # - lr=1e-3 constant, NO external scheduler
            # - Layering a cosine schedule on top (v1) was double-scheduling it
            #   and preventing convergence
            # - beta2_decay=-0.8 and eps=(None, 1e-3) are the values PyTorch
            #   documents as defaults: eps[0] is the stabiliser added in the
            #   update's denominator (None lets PyTorch pick one), eps[1]
            #   guards against too-small updates under parameter scaling
            # - weight_decay=0.01 (match AdamW)
            "Adafactor",
            "Adafactor (lr=1e-3, no sched, β2d=-0.8, wd=0.01)",
            lambda params: torch.optim.Adafactor(
                params,
                lr=1e-3,
                beta2_decay=-0.8,
                eps=(None, 1e-3),
                weight_decay=0.01,
            ),
            False,
            1.0,
        ),
        (
            # Muon config (v2 — fixed):
            # - lr=0.05 (v1's 0.02 plateaued hard after step 150; orthogonalized
            #   updates are naturally bounded so Muon can push much harder)
            # - momentum=0.95 (standard heavy momentum)
            # - ns_steps=5 (Newton-Schulz iterations, standard)
            # - weight_decay=0.01 (decoupled, match others)
            # - grad_clip=2.0 (v1's avg grad norm was 0.474 — well under 1.0,
            #   so the clip was never helping; raise ceiling to let Muon breathe)
            "Muon",
            "Muon (lr=0.05, mom=0.95, ns=5, wd=0.01, clip=2.0)",
            lambda params: Muon(
                params,
                lr=0.05,
                momentum=0.95,
                ns_steps=5,
                weight_decay=0.01,
            ),
            True,
            2.0,
        ),
    ]

    results = []
    for index, (title, name, build_optimizer, use_cosine, grad_clip) in enumerate(specs, start=1):
        if index > 1:
            print()
        _print_banner(f"{index}/{len(specs)}: {title}")

        # Re-seed before building so every optimizer starts from the same weights.
        torch.manual_seed(SEED)
        config = make_config()
        model = WrinkleBraneModel(config).to(device)
        optimizer = build_optimizer(model.parameters())
        # The schedule must span the number of steps actually run.
        scheduler = make_cosine_scheduler(optimizer, WARMUP, steps) if use_cosine else None

        results.append(
            run_benchmark(
                name=name,
                model=model,
                config=config,
                optimizer=optimizer,
                scheduler=scheduler,
                train_corpus=train_corpus,
                val_corpus=val_corpus,
                steps=steps,
                grad_clip=grad_clip,
                device=device,
            )
        )

        # Free the model before the next run is allocated.
        del model, optimizer, scheduler
        gc.collect()
    return results


def _print_report(results):
    """Print the summary table, loss curves and winner for *results*.

    The first entry of *results* is treated as the baseline.
    """
    print("\n\n")
    print("=" * 70)
    print(" OPTIMIZER BENCHMARK RESULTS")
    print("=" * 70)
    print()

    # Summary table
    header = f" {'Metric':<28}"
    sep = f" {'─'*28}"
    for r in results:
        header += f" │ {_short_name(r):<26}"
        sep += f" │{'─'*26}"
    print(header)
    print(sep)

    _print_row("Total time", [f"{r.total_time:>7.1f}s ({r.total_time/60:>4.1f}m)" for r in results])
    _print_row("Time per step", [f"{r.total_time/r.steps*1000:>7.1f} ms/step" for r in results])
    _print_row("Avg throughput", [f"{r.avg_tok_per_sec:>7.0f} tok/s" for r in results])
    _print_row("Final train loss", [f"{r.final_loss:>7.4f}" for r in results])
    _print_row("Final eval loss", [f"{r.final_eval_loss:>7.4f}" for r in results])
    _print_row("Final eval PPL", [f"{math.exp(min(r.final_eval_loss, 20)):>7.2f}" for r in results])
    _print_row(
        "Avg gradient norm",
        [f"{(sum(r.grad_norms) / len(r.grad_norms) if r.grad_norms else 0):>7.3f}" for r in results],
    )
    _print_row(
        "Peak gradient norm",
        [f"{(max(r.grad_norms) if r.grad_norms else 0):>7.3f}" for r in results],
    )

    # Best eval loss (with step)
    best_cells = []
    for r in results:
        best_step, best_loss = min(r.eval_losses, key=lambda entry: entry[1])
        best_cells.append(f"{best_loss:.4f} @ {best_step:<4}")
    _print_row("Best eval loss (step)", best_cells)

    # Relative comparisons against the first (baseline) run
    print(f"\n{sep}")
    base_eval = results[0].final_eval_loss
    base_time = results[0].total_time

    delta_cells = []
    for r in results:
        delta = r.final_eval_loss - base_eval
        pct = (delta / base_eval) * 100
        sign = "+" if delta >= 0 else ""
        delta_cells.append(f"{sign}{delta:>6.4f} ({sign}{pct:.1f}%)")
    _print_row("Eval loss vs AdamW", delta_cells)
    _print_row("Time vs AdamW", [f"{r.total_time / base_time:>6.2f}x" for r in results])

    # Eval loss curves
    print(f"\n\n EVAL LOSS CURVES")
    header = f" {'Step':<8}"
    sep2 = f" {'─'*8}"
    for r in results:
        header += f" │ {_short_name(r):<26}"
        sep2 += f" │{'─'*26}"
    print(header)
    print(sep2)

    all_steps = sorted({s for r in results for s, _ in r.eval_losses})
    for step in all_steps:
        cells = []
        for r in results:
            val = next((l for s, l in r.eval_losses if s == step), None)
            if val is not None:
                cells.append(f"{val:>6.4f} (PPL {math.exp(min(val, 20)):>6.2f})")
            else:
                cells.append(f"{'—':>24}")
        _print_row(step, cells, label_width=8)

    # Training loss curves (per LOG_EVERY)
    print(f"\n\n TRAINING LOSS CURVES")
    print(header)
    print(sep2)

    max_entries = max(len(r.losses) for r in results)
    for i in range(max_entries):
        cells = []
        for r in results:
            if i < len(r.losses):
                cells.append(f"{r.losses[i]:>6.4f} (PPL {math.exp(min(r.losses[i], 20)):>6.2f})")
            else:
                cells.append(f"{'—':>24}")
        _print_row((i + 1) * LOG_EVERY, cells, label_width=8)

    # Winner announcement
    print(f"\n\n{'='*70}")
    best_result = min(results, key=lambda r: r.final_eval_loss)
    best_ppl = math.exp(min(best_result.final_eval_loss, 20))
    print(f" WINNER: {best_result.name}")
    print(f" Final eval loss: {best_result.final_eval_loss:.4f} (PPL {best_ppl:.2f})")
    if best_result.name != results[0].name:
        improvement = (base_eval - best_result.final_eval_loss) / base_eval * 100
        print(f" Improvement over AdamW: {improvement:.2f}%")
    print(f"{'='*70}")


def main():
    """Benchmark AdamW, Adafactor and Muon and print a comparison report.

    Command-line options: ``--log_dir`` (log directory), ``--steps``
    (training steps per optimizer), ``--device`` (torch device; auto-detected
    when omitted) and ``--data_dir`` (raw corpus directory).

    All output is mirrored to a timestamped log file, which is closed (and
    stdout restored) even when a run fails.
    """
    parser = argparse.ArgumentParser(description="WrinkleBrane optimizer benchmark")
    parser.add_argument(
        "--log_dir",
        default=os.path.join(os.path.dirname(__file__), "logs"),
        help="Directory for log files (default: ./logs/)",
    )
    parser.add_argument(
        "--steps", type=int, default=BENCHMARK_STEPS,
        help=f"Training steps per optimizer (default: {BENCHMARK_STEPS})",
    )
    parser.add_argument(
        "--device", type=str, default=None,
        help="Device: 'cuda', 'cuda:0', 'cpu'. Auto-detects if not set.",
    )
    parser.add_argument(
        "--data_dir", type=str, default="/data/WrinkleBrane-Research/raw",
        help="Directory holding the raw corpus (default: /data/WrinkleBrane-Research/raw)",
    )
    args = parser.parse_args()
    if args.steps < 1:
        parser.error("--steps must be a positive integer")

    # Redirect stdout to Tee (stdout + log file) — survives browser crashes
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(args.log_dir, f"benchmark_optimizers_{timestamp}.log")
    tee = Tee(log_path)
    sys.stdout = tee
    try:
        # Device setup
        if args.device:
            device = torch.device(args.device)
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print("=" * 70)
        print(" WrinkleBrane Optimizer Benchmark")
        print(f" AdamW vs Adafactor vs Muon — {args.steps} steps each (FP32) [v2: fixed configs]")
        print("=" * 70)
        print()

        _print_system_info(device)

        # Load data once
        print(" Loading data...")
        train_corpus, val_corpus = load_train_val(args.data_dir)
        print()

        results = _run_all_optimizers(args.steps, device, train_corpus, val_corpus)
        _print_report(results)

        print(f"\n Log saved to: {log_path}", flush=True)
    finally:
        # Always release the log file and restore stdout, even on failure.
        tee.close()


if __name__ == "__main__":
    main()