# AGILLM4-diffusionblocks / dblocks_train.py
# Scott/Codex
# Tune DBlock backward-math speed line
# 9c90574
# Raw
# History Blame Contribute Delete
# 16.9 kB
"""DiffusionBlocks training mode folded into AGILLM-4 (gated by --dblock).
Block-wise EDM denoising on the real Encoder blocks, supervising AR + SAT(fixed+var)
+ NAT each step on ONE block, with grad-checkpointed layers and fused vocab-streaming
CE. Reuses the live data stream / optimizer / checkpointing of nB300_agillm4.
Lazy-imports nB300 inside functions to avoid a circular import.
"""
import math
import random
import time
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as _ck
from fused_ce import fused_ce
# Data-scale constant shared by the EDM preconditioning (_edm_pre) and loss
# weighting (_edm_w) helpers (presumably EDM's sigma_data — confirm).
SD = 0.5
def _profile_active(state, args):
limit = int(getattr(args, "profile_steps", 0) or 0)
return limit > 0 and int(state.get("profile_n", 0)) < limit
def _profile_add(state, name, seconds):
if seconds is None:
return
prof = state.setdefault("profile_times", defaultdict(float))
prof[name] += float(seconds)
def _profile_tic(enabled):
if not enabled:
return None
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.perf_counter()
def _profile_toc(state, name, start):
if start is None:
return
if torch.cuda.is_available():
torch.cuda.synchronize()
_profile_add(state, name, time.perf_counter() - start)
def _profile_step_done(state, args):
limit = int(getattr(args, "profile_steps", 0) or 0)
if limit <= 0:
return
n_prev = int(state.get("profile_n", 0))
if n_prev >= limit:
return
state["profile_n"] = n_prev + 1
n = int(state["profile_n"])
log_every = max(1, int(getattr(args, "profile_log_every", 25) or 25))
if n % log_every != 0 and n != limit:
return
times = state.get("profile_times", {})
keys = [
"data_stream", "tensor", "setup",
"ar_forward", "ar_ce", "ar_backward",
"sat_forward", "sat_ce", "sat_backward",
"nat_forward", "nat_ce", "nat_backward",
"opt_step", "step_total",
]
parts = []
for key in keys:
val = float(times.get(key, 0.0)) * 1000.0 / max(1, n)
if val > 0.01:
parts.append(f"{key}={val:.2f}ms")
print(f"[profile] n={n}/{limit} avg " + " ".join(parts), flush=True)
def _cdf(x):
return 0.5 * (1 + math.erf(x / math.sqrt(2)))
def _ppf(p):
return float(torch.erfinv(torch.tensor(2 * p - 1.0)) * math.sqrt(2))
def _block_sigmas(B, smin=0.002, smax=80.0, pm=-1.2, ps=1.2):
    """Return ``B + 1`` sigma boundaries splitting ``[smin, smax]`` into ``B`` blocks.

    ``log(sigma)`` is treated as normal with mean ``pm`` and std ``ps``; the
    boundaries are the quantiles at ``B + 1`` equally spaced probabilities
    between the CDF values of ``smin`` and ``smax``, so each block covers the
    same probability mass.
    """
    p_lo = _cdf((math.log(smin) - pm) / ps)
    p_hi = _cdf((math.log(smax) - pm) / ps)
    edges = []
    for i in range(B + 1):
        prob = p_lo + (p_hi - p_lo) * (i / B)
        edges.append(float(np.exp(pm + ps * _ppf(prob))))
    return edges
def _edm_pre(s):
    """Return the three preconditioning coefficients for per-sample noise levels ``s``.

    ``s`` is indexed as a 1-D tensor and broadcast to ``(N, 1, 1)``. The
    returned tuple is, in order: the weight applied to the noisy input in the
    output mix, the weight applied to the network output, and the scale
    applied to the network input (as used by ``_dblock_step``).
    """
    sig = s[:, None, None]
    total_var = sig**2 + SD**2
    skip_coef = SD**2 / total_var
    out_coef = sig * SD / total_var**0.5
    in_coef = 1 / total_var**0.5
    return skip_coef, out_coef, in_coef
def _edm_w(s, wmax=5.0):
    """Return the batch-mean loss weight ``(s^2 + SD^2) / (s * SD)^2`` as a float.

    Each per-sample weight is capped at ``wmax`` before averaging.
    """
    per_sample = (s**2 + SD**2) / (s * SD) ** 2
    capped = per_sample.clamp(max=wmax)
    return float(capped.mean())
def _dblock_init(core, args):
B = int(getattr(args, "dblock_blocks", 4))
L = len(core.blocks)
sp = max(1, L // B)
asg = [list(range(i * sp, (i + 1) * sp)) for i in range(B)]
asg[-1] = list(range((B - 1) * sp, L))
bsig = _block_sigmas(B)
schedule = getattr(args, "dblock_schedule", "loss_balanced")
print(f"[dblock] DiffusionBlocks mode: {L} layers -> {B} blocks {asg}")
print(f"[dblock] schedule={schedule} sigma boundaries: {[round(x, 3) for x in bsig]}")
return {
"B": B,
"assign": asg,
"bsig": bsig,
"step": 0,
"counts": [0 for _ in range(B)],
"loss_ema": [None for _ in range(B)],
}
def _choose_block(state, args):
B = state["B"]
schedule = str(getattr(args, "dblock_schedule", "loss_balanced") or "loss_balanced").lower()
step = int(state.get("step", 0))
counts = state.setdefault("counts", [0 for _ in range(B)])
emas = state.setdefault("loss_ema", [None for _ in range(B)])
if schedule == "random":
return random.randrange(B)
if schedule == "roundrobin":
return step % B
explore = float(getattr(args, "dblock_explore", 0.05))
warmup = int(getattr(args, "dblock_warmup_steps", max(8, B * 2)))
if step < warmup or any(c == 0 for c in counts):
return min(range(B), key=lambda i: (counts[i], i))
if explore > 0.0 and random.random() < explore:
return min(range(B), key=lambda i: (counts[i], i))
return max(range(B), key=lambda i: (-1.0 if emas[i] is None else emas[i], -counts[i]))
def _sample_sigma(ids, lo, hi, args, state):
cur_step = int(state.get("step", 0))
curriculum = int(getattr(args, "dblock_sigma_curriculum_steps", 0))
if curriculum > 0:
frac = min(1.0, max(0.05, (cur_step + 1) / float(curriculum)))
hi = lo * ((hi / max(lo, 1e-8)) ** frac)
sig_np = np.exp(
np.random.uniform(
math.log(max(lo, 1e-4)),
math.log(max(hi, lo + 1e-4)),
ids.size(0),
).astype("float32")
)
return torch.from_numpy(sig_np).to(ids.device)
def _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=None):
log_every = int(getattr(args, "dblock_log_every", 50))
step = int(state.get("step", 0))
if log_every <= 0 or step % log_every != 0:
return
counts = ",".join(str(x) for x in state.get("counts", []))
emas = ",".join("nan" if x is None else f"{x:.2f}" for x in state.get("loss_ema", []))
mem = ""
if peak_alloc is not None:
mem = f" peak_alloc={peak_alloc:.2f}GB peak_reserved={peak_reserved:.2f}GB"
print(
f"[dblock] step={step} block={bi} obj={objective or 'mixed'} layers={layers} "
f"loss={total_val:.3f} ar={ar_val:.3f} sat={sat_val:.3f} nat={nat_val:.3f} "
f"counts=[{counts}] ema=[{emas}]{mem}",
flush=True,
)
def _update_stats(state, bi, loss_value):
B = state["B"]
counts = state.setdefault("counts", [0 for _ in range(B)])
emas = state.setdefault("loss_ema", [None for _ in range(B)])
counts[bi] += 1
prev = emas[bi]
beta = 0.96
emas[bi] = float(loss_value) if prev is None else beta * float(prev) + (1.0 - beta) * float(loss_value)
state["step"] = int(state.get("step", 0)) + 1
def _activation_offload_enabled(args):
return bool(getattr(args, "dblock_activation_offload", False)) and torch.cuda.is_available()
def _activation_offload_hooks(args):
min_bytes = int(float(getattr(args, "dblock_activation_offload_min_mb", 1.0) or 1.0) * 1024 * 1024)
def pack(t):
if not torch.is_tensor(t) or not t.is_cuda or not t.is_floating_point() or t.numel() * t.element_size() < min_bytes:
return t
return ("cpu_offload", t.device, t.detach().to("cpu", non_blocking=True))
def unpack(x):
if isinstance(x, tuple) and len(x) == 3 and x[0] == "cpu_offload":
_, dev, cpu_t = x
return cpu_t.to(dev, non_blocking=True)
return x
return torch.autograd.graph.saved_tensors_hooks(pack, unpack)
def _run_block(block, x, mask, use_checkpoint, args=None):
if use_checkpoint:
return _ck.checkpoint(lambda y, block=block: block(y, mask), x, use_reentrant=False)
if args is not None and _activation_offload_enabled(args):
with _activation_offload_hooks(args):
return block(x, mask)
return block(x, mask)
def _dblock_checkpoint_this_layer(args, base_enabled, layer_pos, layer_count=None):
if not base_enabled:
return False
pos = int(layer_pos)
count = int(layer_count or 0)
skip_tail = max(0, int(getattr(args, "dblock_checkpoint_skip_tail", 0) or 0))
if skip_tail > 0 and count > 0 and pos >= max(0, count - skip_tail):
return False
stride = int(getattr(args, "dblock_checkpoint_stride", 1) or 1)
if stride <= 0:
return False
if stride == 1:
return True
return (pos % stride) == 0
def _sample_token_loss_inputs(hidden, targets, max_tokens):
max_tokens = int(max_tokens or 0)
if max_tokens <= 0:
return hidden.contiguous(), targets.contiguous(), int(targets.numel()), int(targets.numel())
flat_targets = targets.reshape(-1)
total = int(flat_targets.numel())
if total <= max_tokens:
return hidden.contiguous(), targets.contiguous(), total, total
# With-replacement sampling avoids building a full randperm each step; the sampled
# mean remains an unbiased estimator of the dense token CE mean.
idx = torch.randint(total, (max_tokens,), device=targets.device)
flat_hidden = hidden.reshape(total, hidden.size(-1))
return flat_hidden.index_select(0, idx).contiguous(), flat_targets.index_select(0, idx).contiguous(), int(max_tokens), total
def _choose_objectives(state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic):
mode = str(getattr(args, "dblock_objective_mode", "periodic") or "periodic").lower()
if mode != "stochastic":
return ar_weight > 0.0, sat_weight > 0.0 and do_sat_periodic, nat_weight > 0.0 and do_nat_periodic, "periodic"
choices = []
probs = []
if ar_weight > 0.0:
choices.append("ar")
probs.append(max(0.0, float(getattr(args, "dblock_ar_prob", 0.80))))
if sat_weight > 0.0 and not getattr(args, "ar_only", False):
choices.append("sat")
probs.append(max(0.0, float(getattr(args, "dblock_sat_prob", 0.10))))
if nat_weight > 0.0 and not getattr(args, "ar_only", False):
choices.append("nat")
probs.append(max(0.0, float(getattr(args, "dblock_nat_prob", 0.10))))
if not choices:
return False, False, False, "none"
total = sum(probs)
if total <= 0.0:
probs = [1.0 / len(choices) for _ in choices]
else:
probs = [p / total for p in probs]
picked = random.choices(choices, weights=probs, k=1)[0]
return picked == "ar", picked == "sat", picked == "nat", picked
def _dblock_step(core, ar_h, sat_h, nat_h, opt, scaler, args, ids, state):
    """Run one DiffusionBlocks training step on a single scheduled block.

    One block is chosen (``_choose_block``), a per-sample noise level is drawn
    from that block's sigma range, and each selected objective (AR, SAT, NAT)
    does its own forward pass through only that block's layers followed by an
    immediate backward, so gradients accumulate across objectives. A single
    clipped optimizer step follows.

    Args:
        core: model exposing ``emb``, ``blocks`` and ``ln``.
        ar_h, sat_h, nat_h: heads exposing ``proj.weight``; ``sat_h.gate`` may
            be ``None``; ``nat_h`` may be ``None`` (NAT is then disabled).
        opt: optimizer whose param groups are clipped and stepped.
        scaler: object with the ``scale`` / ``unscale_`` / ``step`` / ``update``
            interface (presumably a ``torch.amp.GradScaler`` — confirm).
        args: run configuration (read via ``getattr`` with defaults).
        ids: token ids indexed as ``(batch, time)``.
        state: dict produced by ``_dblock_init``; updated in place.

    Returns:
        The summed (weighted) loss of the objectives that ran, as a float.
        Returns 0.0 without touching the optimizer when no objective was
        scheduled; returns the non-finite value without stepping when the loss
        is NaN/inf.
    """
    import nB300_agillm4 as M  # lazy import (avoids a circular import at module load)

    prof = _profile_active(state, args)
    _step_t = _profile_tic(prof)
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    _setup_t = _profile_tic(prof)
    asg = state["assign"]
    bs = state["bsig"]
    T = ids.size(1)
    use_layer_checkpoint = bool(getattr(args, "grad_checkpoint", False))
    bi = _choose_block(state, args)
    # Sort so lo <= hi whatever the order of the stored boundaries.
    lo, hi = sorted([bs[bi], bs[bi + 1]])
    layers = asg[bi]
    sig = _sample_sigma(ids, lo, hi, args, state)
    cs, co, ci = _edm_pre(sig)
    w = _edm_w(sig, float(getattr(args, "dblock_edm_wmax", 5.0)))
    SATB = M.SAT_BLOCK
    ar_weight = float(getattr(args, "dblock_ar_weight", 1.0))
    sat_weight = float(getattr(args, "dblock_sat_weight", 1.0))
    nat_weight = float(getattr(args, "dblock_nat_weight", 1.0)) * float(getattr(args, "nat_loss_weight", 1.0))
    if nat_h is None:
        # Without a NAT head the objective must be ineligible in every mode;
        # stochastic selection ignores the periodic flag that checks nat_h and
        # would otherwise pick NAT and fail on nat_h.proj.
        nat_weight = 0.0
    do_sat_periodic = (not getattr(args, "ar_only", False)) and (
        int(getattr(args, "sat_every", 1)) <= 1 or ((int(state.get("step", 0)) + 1) % int(getattr(args, "sat_every", 1)) == 0)
    )
    do_nat_periodic = (
        nat_h is not None
        and (not getattr(args, "ar_only", False))
        and int(getattr(args, "nat_every", 1)) > 0
        and (
            int(getattr(args, "nat_every", 1)) <= 1
            or ((int(state.get("step", 0)) + 1) % int(getattr(args, "nat_every", 1)) == 0)
        )
    )
    run_ar, run_sat, run_nat, objective = _choose_objectives(
        state, args, ar_weight, sat_weight, nat_weight, do_sat_periodic, do_nat_periodic
    )
    _profile_toc(state, "setup", _setup_t)

    if not (run_ar or run_sat or run_nat):
        # Nothing to train this step (e.g. AR weight 0 on a step where the
        # periodic SAT/NAT gates are closed). No backward ran, so there are no
        # gradients: an enabled GradScaler rejects step() in that situation, and
        # a 0.0 "loss" would drag the block EMA down. Advance only the step
        # counter so the periodic gates keep progressing.
        state["step"] = int(state.get("step", 0)) + 1
        _profile_toc(state, "step_total", _step_t)
        _profile_step_done(state, args)
        return 0.0

    ar_val = 0.0
    sat_val = 0.0
    nat_val = 0.0

    # NOTE(review): the source indentation was lost; the extent of each
    # `with M.amp(...)` block below is reconstructed by analogy with the SAT
    # branch (which re-enters amp explicitly for its CE). Confirm that the AR
    # and NAT cross-entropies are meant to run outside the amp context.
    if run_ar:
        causal = M.causal_mask(T, structured=M.use_structured_masks(args))
        _t = _profile_tic(prof)
        with M.amp(args.amp):
            emb = core.emb(ids)
            # Noise the embeddings at the sampled sigma, then precondition.
            zt = emb + sig[:, None, None] * torch.randn_like(emb)
            h = ci * zt
            for lpos, li in enumerate(layers):
                h = _run_block(core.blocks[li], h, causal, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos, len(layers)), args)
            Dn = core.ln(cs * zt + co * h)
        _profile_toc(state, "ar_forward", _t)
        _t = _profile_tic(prof)
        # Next-token targets: position t predicts token t + 1.
        ar_hidden, ar_targets, _, _ = _sample_token_loss_inputs(
            Dn[:, :-1], ids[:, 1:], int(getattr(args, "dblock_ar_loss_tokens", 0))
        )
        ar = ar_weight * w * fused_ce(ar_hidden, ar_h.proj.weight, ar_targets)
        ar_val = float(ar.detach())
        _profile_toc(state, "ar_ce", _t)
        _t = _profile_tic(prof)
        scaler.scale(ar).backward()
        _profile_toc(state, "ar_backward", _t)
        # Drop this objective's graph before the next forward to limit peak memory.
        del causal, emb, zt, h, Dn, ar_hidden, ar_targets, ar

    if run_sat:
        smask = M.sat_mask(T, structured=M.use_structured_masks(args))
        _t = _profile_tic(prof)
        with M.amp(args.amp):
            emb2 = core.emb(ids)
            zt2 = emb2 + sig[:, None, None] * torch.randn_like(emb2)
            h2 = ci * zt2
            for lpos, li in enumerate(layers):
                h2 = _run_block(core.blocks[li], h2, smask, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos, len(layers)), args)
            Ds = core.ln(cs * zt2 + co * h2)
            last = Ds[:, -SATB:]
        _profile_toc(state, "sat_forward", _t)
        _t = _profile_tic(prof)
        sat_hidden, sat_targets, _, _ = _sample_token_loss_inputs(
            last, ids[:, 1 : SATB + 1], int(getattr(args, "dblock_sat_loss_tokens", 0))
        )
        with M.amp(args.amp):
            satf = fused_ce(sat_hidden, sat_h.proj.weight, sat_targets)
            # Optional gate loss on the first position, always with target class 1.
            satv = (
                M.EMIT_LAMBDA
                * F.cross_entropy(
                    sat_h.gate(Ds[:, 0].float()),
                    torch.ones(ids.size(0), dtype=torch.long, device=ids.device),
                )
                if sat_h.gate is not None
                else 0.0
            )
            sat = sat_weight * w * (satf + satv)
        _profile_toc(state, "sat_ce", _t)
        sat_val = float(sat.detach())
        _t = _profile_tic(prof)
        scaler.scale(sat).backward()
        _profile_toc(state, "sat_backward", _t)
        del smask, emb2, zt2, h2, Ds, last, sat_hidden, sat_targets, satf, satv, sat

    if run_nat:
        ratio = min(max(float(getattr(args, "nat_mask_ratio", 0.5)), 0.05), 0.95)
        nat_ids = M._nat_ids_for_training(ids, int(getattr(args, "nat_max_tokens", 0)))
        _t = _profile_tic(prof)
        with M.amp(args.amp):
            nat_in = nat_ids.clone()
            m = torch.rand(nat_ids.shape, device=nat_ids.device) < ratio
            if not bool(m.any()):
                # Guarantee at least one masked position so the loss is defined.
                m[..., -1] = True
            nat_in[m] = M.BLANK
            # NAT runs on clean (un-noised) embeddings with no attention mask.
            hn = core.emb(nat_in)
            for lpos, li in enumerate(layers):
                hn = _run_block(core.blocks[li], hn, None, _dblock_checkpoint_this_layer(args, use_layer_checkpoint, lpos, len(layers)), args)
            Dnat = core.ln(hn)
        _profile_toc(state, "nat_forward", _t)
        _t = _profile_tic(prof)
        # Only masked positions contribute to the NAT loss.
        nat_hidden = Dnat[m]
        nat_targets = nat_ids[m]
        nat_hidden, nat_targets, _, _ = _sample_token_loss_inputs(
            nat_hidden.unsqueeze(0), nat_targets.unsqueeze(0), int(getattr(args, "dblock_nat_loss_tokens", 0))
        )
        nat = nat_weight * fused_ce(nat_hidden, nat_h.proj.weight, nat_targets)
        nat_val = float(nat.detach())
        _profile_toc(state, "nat_ce", _t)
        _t = _profile_tic(prof)
        scaler.scale(nat).backward()
        _profile_toc(state, "nat_backward", _t)
        del nat_ids, nat_in, m, hn, Dnat, nat_hidden, nat_targets, nat

    total_val = ar_val + sat_val + nat_val
    if not math.isfinite(total_val):
        # Discard the accumulated gradients instead of applying a poisoned update.
        opt.zero_grad(set_to_none=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"[dblock] non-finite loss {total_val}; skipped optimizer step", flush=True)
        _profile_toc(state, "step_total", _step_t)
        _profile_step_done(state, args)
        _update_stats(state, bi, total_val)
        return total_val

    _t = _profile_tic(prof)
    # Unscale before clipping so the norm threshold applies to true gradients.
    scaler.unscale_(opt)
    nn.utils.clip_grad_norm_([p for g in opt.param_groups for p in g["params"]], 1.0)
    scaler.step(opt)
    scaler.update()
    opt.zero_grad(set_to_none=True)
    _profile_toc(state, "opt_step", _t)
    peak_alloc = None
    peak_reserved = None
    if torch.cuda.is_available():
        peak_alloc = torch.cuda.max_memory_allocated() / (1024**3)
        peak_reserved = torch.cuda.max_memory_reserved() / (1024**3)
    _profile_toc(state, "step_total", _step_t)
    _profile_step_done(state, args)
    _update_stats(state, bi, total_val)
    _maybe_log(state, args, bi, layers, ar_val, sat_val, nat_val, total_val, peak_alloc, peak_reserved, objective=objective)
    return total_val