| |
|
| | """ |
| | Addressed State Attention (ASA) - Analysis Harness |
| | |
| | Research implementation with mechanistic intervention capabilities. |
| | For efficient training without interventions, use asm_training.py instead. |
| | |
| | Features: |
| | - Slot-mask causal interventions (slot_mask, slot_mask_where, slot_mask_scope) |
| | - Refinement decomposition (orthogonal/parallel gating) |
| | - Per-head geometry logging |
| | - Configurable information storage (info_level, info_cfg) |
| | |
| | Checkpoint Compatibility: |
| | All parameter/buffer names match asm_training.py for weight sharing. |
| | Do NOT rename: slot_keys, Wk_write, Wv_write, Wq_read, out_proj, |
| | _alibi_slopes, _alibi_strength_param, _content_read_gamma_raw, |
| | slot_in/slot_q/slot_k/slot_v/slot_out, _slotspace_gate_raw, |
| | rope/rope_slotspace buffers. |
| | |
| | Repository: https://github.com/DigitalDaimyo/AddressedStateAttention |
| | Paper: https://github.com/DigitalDaimyo/AddressedStateAttention/tree/main/paper_drafts |
| | """ |
| |
|
| | import math |
| | from dataclasses import dataclass |
| | from typing import Optional, Dict, Tuple, List |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
# Public names re-exported on star-import.
# NOTE(review): ASMBlock / ASMLanguageModel / ASMTrainConfig / build_model_from_cfg
# are presumably defined later in this file (not visible in this chunk) — verify.
__all__ = [
    'AddressedStateAttention',
    'ASMBlock',
    'ASMLanguageModel',
    'ASMTrainConfig',
    'build_model_from_cfg',
]
| |
|
| |
|
| | |
| |
|
| | def _rotate_half(x: torch.Tensor) -> torch.Tensor: |
| | x1 = x[..., ::2] |
| | x2 = x[..., 1::2] |
| | return torch.stack((-x2, x1), dim=-1).flatten(-2) |
| |
|
| |
|
class RotaryEmbedding(nn.Module):
    """Cache of RoPE cos/sin tables, rebuilt lazily per (length, device).

    The tables use the cat([freqs, freqs]) layout: shape [1, 1, T, dim] with
    the first dim//2 columns repeated in the second half.
    """

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        assert dim % 2 == 0, "RoPE requires even dim"
        # Standard RoPE inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._cos_cached = None
        self._sin_cached = None
        self._t_cached = None
        self._device_cached = None

    def get_cos_sin(self, T: int, device, dtype):
        """Return (cos, sin) tables of shape [1, 1, T, dim] on *device*/*dtype*.

        Serves any request with T <= cached length by slicing the cached table
        (row t depends only on t, so a prefix is exact); the original code only
        reused exact-length hits, forcing a rebuild on every length change.
        """
        if (
            self._cos_cached is not None
            and self._device_cached == device
            and self._t_cached is not None
            and self._t_cached >= T
        ):
            cos = self._cos_cached[:, :, :T, :]
            sin = self._sin_cached[:, :, :T, :]
            return cos.to(dtype=dtype), sin.to(dtype=dtype)
        # Build (and cache) in the buffer's dtype; convert on return only.
        t = torch.arange(T, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("t,f->tf", t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        self._t_cached = T
        self._device_cached = device
        self._cos_cached = cos
        self._sin_cached = sin
        return cos.to(dtype=dtype), sin.to(dtype=dtype)
| |
|
| |
|
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply a rotary position embedding: x * cos + rotate_half(x) * sin."""
    rotated = _rotate_half(x)
    return x * cos + rotated * sin
| |
|
| |
|
def alibi_slopes(num_heads: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """Per-head ALiBi slopes.

    Power-of-two head counts get a geometric sequence starting at
    2^(-8/n); other counts take the closest lower power of two plus every
    other slope from the doubled count (the standard ALiBi recipe).
    """
    def _geometric(n):
        first = 2.0 ** (-(2.0 ** -(math.log2(n) - 3)))
        # ratio equals the first term, so slope i is first^(i+1)
        return [first ** (i + 1) for i in range(n)]

    def _slopes(n):
        if math.log2(n).is_integer():
            return _geometric(n)
        base = 2 ** math.floor(math.log2(n))
        extra = _slopes(2 * base)[0::2][: n - base]
        return _geometric(base) + extra

    return torch.tensor(_slopes(num_heads), device=device, dtype=dtype)
| |
|
| |
|
| | def _inv_softplus(y: torch.Tensor) -> torch.Tensor: |
| | return torch.log(torch.expm1(y)) |
| |
|
| |
|
def phi(x: torch.Tensor) -> torch.Tensor:
    """Positive feature map for the slot-space linear attention: elu(x) + 1."""
    activated = F.elu(x)
    return activated + 1.0
| |
|
| |
|
| | |
| |
|
| | class AddressedStateAttention(nn.Module): |
| | """ |
| | Addressed State Attention (ASA) — unified research harness. |
| | |
| | Core mechanism |
| | -------------- |
| | * Prefix-softmax WRITE into K learned slots (streaming, O(T)) |
| | * READ routing from tokens → slots (softmax / top-k / external) |
| | * Content-conditioned READ term (gamma-weighted) |
| | * RoPE on write keys (geometry) |
| | * ALiBi bias on write logits (prefix-friendly) |
| | |
| | Slot-space refinement |
| | --------------------- |
| | * Causal linear attention in a low-dim slot-address coordinate space |
| | * Produces per-token signed weights over slots |
| | * Decoded through the same streaming slot-state basis |
| | * Gated by learnable ``slotspace_gate`` (softplus) |
| | |
| | Causal intervention (slot mask) |
| | ------------------------------- |
| | * ``slot_mask`` [K] float/bool, 1=keep 0=mask |
| | * ``slot_mask_where`` "read" | "content_read_only" | "slotspace_only" |
| | * ``slot_mask_scope`` "all" | "last_pos_only" |
| | |
| | Refine-delta intervention (instance attrs, NO-OP by default) |
| | ---------------------------------------------------------------- |
| | * ``_intv_mode`` "off" | "delta_par" | "delta_orth" | "orth_gate" | … |
| | * Decomposes refine delta into parallel / orthogonal vs base output |
| | * See User Guide for configuration details. |
| | |
| | Refine-geometry logging (NO output change) |
| | ------------------------------------------------ |
| | * ``_log_refine_geom = True`` enables per-head geometry vectors in info dict. |
| | |
| | Info storage |
| | ------------ |
| | * ``info_level`` "basic" | "logits" | "full" |
| | * ``info_cfg`` dict controlling which tensors to store, downsampling, CPU offload. |
| | """ |
| |
|
| | |
| |
|
    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 8,
        num_slots: int = 8,
        dropout: float = 0.1,
        # routing / numerics
        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,
        # RoPE on write keys
        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        # ALiBi bias on write logits
        use_alibi_write: bool = True,
        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,
        # content-conditioned read term
        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,
        # slot-space refinement
        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -4.0,
        slotspace_dropout: float = 0.05,
        slotspace_signed_weights: bool = True,
        # RoPE inside the slot-space coordinates
        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,
        # chunk sizes for the streaming scans in forward()
        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
    ):
        """Construct the attention module.

        Parameter/buffer names are checkpoint-critical (see module docstring):
        do NOT rename slot_keys, Wk_write, Wv_write, Wq_read, out_proj,
        _alibi_slopes, _alibi_strength_param, _content_read_gamma_raw,
        slot_in/slot_q/slot_k/slot_v/slot_out, _slotspace_gate_raw, or the
        rope/rope_slotspace submodules.
        """
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_slots = num_slots
        self.head_dim = embed_dim // num_heads

        self.dropout = nn.Dropout(dropout)

        # Scalar hyperparameters, coerced to plain Python types.
        self.read_temperature = float(read_temperature)
        self.write_temperature = float(write_temperature)
        self.state_fp32 = bool(state_fp32)  # keep streaming slot state in fp32
        self.slot_dropout = float(slot_dropout)
        self.normalize_k = bool(normalize_k)
        # Optional routing hook (callable or tensor); see _compute_read_weights.
        self.routing_override = None

        self.use_rope_keys = bool(use_rope_keys)
        self.use_alibi_write = bool(use_alibi_write)
        self.learn_alibi_strength = bool(learn_alibi_strength)
        self.min_strength = float(min_strength)

        self.use_content_read = bool(use_content_read)
        self.content_read_max_gamma = float(content_read_max_gamma)

        self.use_slotspace_refine = bool(use_slotspace_refine)
        self.slotspace_dim = int(slotspace_dim)
        self.slotspace_dropout = nn.Dropout(float(slotspace_dropout))
        self.slotspace_signed_weights = bool(slotspace_signed_weights)

        self.write_chunk_size = int(write_chunk_size)
        self.slotspace_chunk_size = int(slotspace_chunk_size)

        # Learned slot addresses: one bank of K keys per head, scaled ~1/sqrt(d).
        self.slot_keys = nn.Parameter(
            torch.randn(num_heads, num_slots, self.head_dim) / math.sqrt(self.head_dim)
        )

        # Bias-free write/read projections (checkpoint-name sensitive).
        self.Wk_write = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wv_write = nn.Linear(embed_dim, embed_dim, bias=False)
        self.Wq_read = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        # RoPE is applied to write keys only (see forward()).
        self.rope = RotaryEmbedding(self.head_dim, base=rope_base) if self.use_rope_keys else None

        # Slopes buffer registered in both cases (zeros when ALiBi is disabled).
        if self.use_alibi_write:
            self.register_buffer("_alibi_slopes", alibi_slopes(num_heads), persistent=False)
        else:
            self.register_buffer("_alibi_slopes", torch.zeros(num_heads), persistent=False)

        if self.use_alibi_write and self.learn_alibi_strength:
            # Parameterized in inverse-softplus space so that
            # softplus(param) + min_strength == alibi_strength_init at init.
            init = torch.tensor(float(alibi_strength_init) - self.min_strength).clamp_min(1e-8)
            self._alibi_strength_param = nn.Parameter(_inv_softplus(init))
        else:
            self._alibi_strength_param = None
            self.alibi_strength = float(alibi_strength_init)  # fixed-strength fallback

        # Raw (pre-softplus) gamma for the content-conditioned read term.
        if self.use_content_read:
            self._content_read_gamma_raw = nn.Parameter(torch.tensor(float(content_read_init)))
        else:
            self._content_read_gamma_raw = None

        # Slot-space refinement: low-dim linear attention over slot-address coords.
        self.use_rope_slotspace = bool(use_rope_slotspace) and bool(self.use_slotspace_refine)
        if self.use_slotspace_refine:
            self.slot_in = nn.Linear(num_slots, self.slotspace_dim, bias=False)
            self.slot_q = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
            self.slot_k = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
            self.slot_v = nn.Linear(self.slotspace_dim, self.slotspace_dim, bias=False)
            self.slot_out = nn.Linear(self.slotspace_dim, num_slots, bias=False)
            self._slotspace_gate_raw = nn.Parameter(torch.tensor(float(slotspace_gate_init)))
            if self.use_rope_slotspace:
                assert (self.slotspace_dim % 2) == 0, "use_rope_slotspace requires even slotspace_dim"
                self.rope_slotspace = RotaryEmbedding(self.slotspace_dim, base=float(rope_base_slotspace))
            else:
                self.rope_slotspace = None
        else:
            self.slot_in = None
            self.slot_q = self.slot_k = self.slot_v = None
            self.slot_out = None
            self._slotspace_gate_raw = None
            self.rope_slotspace = None

        # Refine-delta intervention knobs; these defaults make the hook a NO-OP.
        self._intv_mode: str = "off"
        self._intv_beta: float = 1.0
        self._intv_score_kind: str = "orth_frac"
        self._intv_tau_kind: str = "pctl"
        self._intv_tau: float = 0.15
        self._intv_tau_pctl: float = 75.0
        self._intv_mask_mode: str = "soft"
        self._intv_soft_temp: float = 0.05
        self._intv_par_beta: float = 1.0
        self._intv_head_mask: Optional[torch.Tensor] = None
        self._intv_score_clip_pctl: float = 99.0

        # Geometry-logging switch (diagnostics only; never changes outputs).
        self._log_refine_geom: bool = False
| |
|
| | |
| |
|
| | def _alibi_strength(self, dtype, device) -> torch.Tensor: |
| | if not (self.use_alibi_write and self.learn_alibi_strength): |
| | return torch.tensor(self.alibi_strength, dtype=dtype, device=device) |
| | return (F.softplus(self._alibi_strength_param) + self.min_strength).to(dtype=dtype, device=device) |
| |
|
| | def _content_read_gamma(self, dtype, device) -> torch.Tensor: |
| | if not self.use_content_read: |
| | return torch.tensor(0.0, dtype=dtype, device=device) |
| | g = F.softplus(self._content_read_gamma_raw) |
| | if self.content_read_max_gamma is not None and self.content_read_max_gamma > 0: |
| | g = g.clamp(max=self.content_read_max_gamma) |
| | return g.to(dtype=dtype, device=device) |
| |
|
| | def _slotspace_gate(self, dtype, device) -> torch.Tensor: |
| | if not self.use_slotspace_refine: |
| | return torch.tensor(0.0, dtype=dtype, device=device) |
| | return F.softplus(self._slotspace_gate_raw).to(dtype=dtype, device=device) |
| |
|
| | |
| |
|
| | @staticmethod |
| | def _safe_exp_sub_max(s: torch.Tensor, m: torch.Tensor) -> torch.Tensor: |
| | diff = s - m |
| | diff = diff.masked_fill(~torch.isfinite(m), float("-inf")) |
| | return torch.exp(diff) |
| |
|
| | |
| |
|
| | def _resolve_slot_mask( |
| | self, |
| | slot_mask: Optional[torch.Tensor], |
| | *, |
| | B: int, H: int, L: int, K: int, |
| | device, dtype, scope: str, |
| | ) -> Optional[torch.Tensor]: |
| | """Expand [K] mask → [B,H,L,K]. Falls back to self.slot_mask attr.""" |
| | if slot_mask is None: |
| | slot_mask = getattr(self, "slot_mask", None) |
| | if slot_mask is None: |
| | return None |
| | sm = slot_mask.to(device=device, dtype=dtype) |
| | if sm.ndim != 1 or sm.numel() != K: |
| | raise ValueError(f"slot_mask must be shape [K]={K}, got {tuple(sm.shape)}") |
| | sm = sm.view(1, 1, 1, K) |
| | if scope == "all": |
| | return sm.expand(B, H, L, K) |
| | if scope == "last_pos_only": |
| | out = torch.ones((B, H, L, K), device=device, dtype=dtype) |
| | out[:, :, -1:, :] = sm.expand(B, H, 1, K) |
| | return out |
| | raise ValueError(f"Unknown slot_mask_scope={scope!r}") |
| |
|
| | @staticmethod |
| | def _apply_hard_mask_and_renorm(w: torch.Tensor, keep: torch.Tensor) -> torch.Tensor: |
| | w = w * keep.to(w.dtype) |
| | return w / w.sum(dim=-1, keepdim=True).clamp_min(1e-8) |
| |
|
| | |
| |
|
| | @staticmethod |
| | def default_info_cfg() -> Dict: |
| | """Return default info_cfg dict. Copy and modify before passing to forward().""" |
| | return dict( |
| | store_read_weights=True, |
| | store_read_logits=True, |
| | store_write_logits=True, |
| | store_slot_state_norm=True, |
| | store_out1=False, |
| | store_delta=False, |
| | store_slot_w=False, |
| | detach_to_cpu=False, |
| | time_stride=1, |
| | batch_stride=1, |
| | ) |
| |
|
| | @staticmethod |
| | def _store_tensor( |
| | t: Optional[torch.Tensor], *, cfg: Dict, kind: str, |
| | ) -> Optional[torch.Tensor]: |
| | """Downsample + detach (+ optional CPU offload).""" |
| | if t is None: |
| | return None |
| | bstride = int(cfg.get("batch_stride", 1)) |
| | tstride = int(cfg.get("time_stride", 1)) |
| | to_cpu = bool(cfg.get("detach_to_cpu", False)) |
| | x = t |
| | if x.dim() >= 1 and bstride > 1: |
| | x = x[::bstride] |
| | if x.dim() == 4 and tstride > 1: |
| | if kind == "bhtk": |
| | x = x[:, :, ::tstride, :] |
| | elif kind == "bhkt": |
| | x = x[:, :, :, ::tstride] |
| | x = x.detach() |
| | if to_cpu: |
| | x = x.to("cpu", non_blocking=True) |
| | return x |
| |
|
| | |
| |
|
    def _compute_read_weights(
        self,
        *,
        read_logits: torch.Tensor,
        read_logits_key: torch.Tensor,
        read_logits_content: Optional[torch.Tensor],
        routing_mode: str,
        routing_topk: int,
        read_weights_override: Optional[torch.Tensor],
        routing_noise: Optional[str],
        routing_noise_scale: float,
        rtemp: float,
        sm: Optional[torch.Tensor],
        slot_mask_where: str,
        B: int, H: int, L: int, K: int,
        T_total: int,
        t0: int, t1: int,
        q_read_c: torch.Tensor,
        slot_keys: torch.Tensor,
        slot_state_t: torch.Tensor,
        valid: Optional[torch.Tensor],
        state_dtype,
    ) -> torch.Tensor:
        """Compute read weights for one write-chunk. Handles noise, overrides, masks.

        Order of operations:
          1. optional exploration noise added to the combined read logits
          2. ``self.routing_override`` (callable or tensor) short-circuits the
             built-in routing modes entirely
          3. otherwise route per ``routing_mode`` (softmax / top1 / topk / external)
          4. optional hard slot mask + renormalization when masking at the read site

        Returns [B, H, L, K] weights; all built-in paths normalize over the
        slot axis (callable overrides are trusted as-is).
        """
        # -- 1) exploration noise on logits ---------------------------------
        if routing_noise is not None:
            if routing_noise == "gumbel":
                u = torch.rand_like(read_logits)
                # Standard Gumbel(0,1) sample; clamps guard log(0).
                g = -torch.log(-torch.log(u.clamp_min(1e-8)).clamp_min(1e-8))
                read_logits = read_logits + routing_noise_scale * g
            elif routing_noise == "gaussian":
                read_logits = read_logits + routing_noise_scale * torch.randn_like(read_logits)
            else:
                raise ValueError(f"Unknown routing_noise={routing_noise}")

        # -- 2) module-level routing override -------------------------------
        if self.routing_override is not None:
            if callable(self.routing_override):
                # Callable override receives full chunk context; its output is
                # used verbatim (no sanitization or renormalization).
                ctx = dict(
                    t0=t0, t1=t1, B=B, H=H, T=T_total, K=K, d=self.head_dim,
                    rtemp=rtemp, state_dtype=state_dtype,
                    q_read_c=q_read_c, slot_keys=slot_keys,
                    slot_state_t=slot_state_t, valid=valid,
                )
                read_w = self.routing_override(
                    t0, t1, read_logits, read_logits_key, read_logits_content, ctx,
                )
            else:
                # Tensor override is indexed over the full sequence, then
                # sanitized (finite, nonnegative) and renormalized over slots.
                read_w = self.routing_override[:, :, t0:t1, :].to(read_logits.dtype)
                read_w = torch.nan_to_num(read_w, nan=0.0, posinf=0.0, neginf=0.0)
                read_w = read_w.clamp_min(0.0)
                read_w = read_w / read_w.sum(dim=-1, keepdim=True).clamp_min(1e-8)

        # -- 3) built-in routing modes --------------------------------------
        else:
            if routing_mode == "softmax":
                read_w = torch.softmax(read_logits / rtemp, dim=-1)
            elif routing_mode == "top1":
                top = read_logits.argmax(dim=-1)
                read_w = F.one_hot(top, num_classes=K).to(read_logits.dtype)
            elif routing_mode == "topk":
                kk = max(1, min(K, int(routing_topk)))
                vals, idx = torch.topk(read_logits, k=kk, dim=-1)
                # Softmax restricted to the top-k logits (others set to -inf).
                masked = torch.full_like(read_logits, float("-inf"))
                masked.scatter_(-1, idx, vals)
                read_w = torch.softmax(masked / rtemp, dim=-1)
            elif routing_mode == "external":
                if read_weights_override is None:
                    raise ValueError("routing_mode='external' requires read_weights_override")
                # Accept either a full-sequence tensor (sliced per chunk) or one
                # already sized to this chunk.
                if read_weights_override.shape[-2] == T_total:
                    read_w = read_weights_override[:, :, t0:t1, :]
                else:
                    read_w = read_weights_override
                read_w = read_w / read_w.sum(dim=-1, keepdim=True).clamp_min(1e-8)
            else:
                raise ValueError(f"Unknown routing_mode={routing_mode}")

        # -- 4) hard slot mask at the read site -----------------------------
        if slot_mask_where == "read" and sm is not None:
            read_w = self._apply_hard_mask_and_renorm(read_w, (sm > 0.0))

        return read_w
| |
|
| | |
| |
|
    def _apply_refine_intervention(
        self,
        out1: torch.Tensor,
        delta: torch.Tensor,
        slot_w: Optional[torch.Tensor],
    ):
        """Decompose refine delta into par/orth vs base output, optionally gate.

        Parameters
        ----------
        out1 : [B, H, L, d] base attention output (before refinement)
        delta : [B, H, L, d] refinement delta candidate
        slot_w : optional slot weights; required only for score_kind='slot_peaked'

        Returns
        -------
        (delta_mod, logs) — the (possibly modified) delta and an optional dict
        of diagnostics.  With ``_intv_mode == 'off'`` (the default) delta is
        returned unchanged, so this hook is a NO-OP unless configured.
        """
        eps = 1e-8
        B, H, L, d = out1.shape

        # Optional per-head blend mask (1 = use intervened delta, 0 = original).
        hm = getattr(self, "_intv_head_mask", None)
        if hm is not None:
            hm = hm.to(device=out1.device).view(1, H, 1, 1).to(dtype=out1.dtype)

        # Orthogonal decomposition of delta against out1:
        #   delta = alpha * out1 (parallel part) + delta_orth (orthogonal part)
        out1_norm2 = (out1 * out1).sum(dim=-1, keepdim=True).clamp_min(eps)
        alpha = (delta * out1).sum(dim=-1, keepdim=True) / out1_norm2
        delta_par = alpha * out1
        delta_orth = delta - delta_par

        logs = None

        # Geometry logging is side-effect free: it never alters delta.
        if getattr(self, "_log_refine_geom", False):
            out1n = out1.norm(dim=-1).clamp_min(eps)
            dn = delta.norm(dim=-1).clamp_min(eps)
            dparn = delta_par.norm(dim=-1)
            dorthn = delta_orth.norm(dim=-1)
            a = alpha.squeeze(-1)
            # Every stat is reduced over batch and time -> one value per head.
            logs = dict(
                geom_alpha_mean=a.mean(dim=(0, 2)),
                geom_alpha_abs=a.abs().mean(dim=(0, 2)),
                geom_sign_pos=(a > 0).float().mean(dim=(0, 2)),
                geom_orth_frac=(dorthn / dn).mean(dim=(0, 2)),
                geom_d_ratio=(dn / out1n).mean(dim=(0, 2)),
                geom_dpar_ratio=(dparn / dn).mean(dim=(0, 2)),
            )

        mode = getattr(self, "_intv_mode", "off")
        if mode is None or mode == "off":
            return delta, logs

        # Component selections: keep only the parallel part ...
        if mode == "delta_par":
            delta_mod = delta_par
            logs = logs or {}
            logs["alpha"] = alpha.squeeze(-1)

        # ... only the orthogonal part ...
        elif mode == "delta_orth":
            delta_mod = delta_orth
            logs = logs or {}
            logs["alpha"] = alpha.squeeze(-1)

        # ... or both (reconstruction control: equals delta up to numerics).
        elif mode == "delta_par_plus_orth":
            delta_mod = delta_par + delta_orth
            logs = logs or {}
            logs["alpha"] = alpha.squeeze(-1)

        # Gate the orthogonal component per (b, h, l) by a scalar score.
        elif mode == "orth_gate":
            beta = float(getattr(self, "_intv_beta", 1.0))
            sk = getattr(self, "_intv_score_kind", "orth_frac")
            out1n = out1.norm(dim=-1).clamp_min(eps)
            dorthn = delta_orth.norm(dim=-1)
            dn = delta.norm(dim=-1).clamp_min(eps)

            if sk == "orth_ratio":
                score = dorthn / out1n
            elif sk == "orth_frac":
                score = dorthn / dn
            elif sk == "alpha_abs":
                score = alpha.abs().squeeze(-1)
            elif sk == "slot_peaked":
                if slot_w is None:
                    raise ValueError("score_kind='slot_peaked' requires slot_w")
                # Peakedness = 1 - normalized entropy of the slot distribution.
                p = torch.softmax(slot_w.float(), dim=-1).clamp_min(1e-8)
                Hrw = -(p * p.log()).sum(dim=-1)
                K = p.shape[-1]
                score = (1.0 - Hrw / max(1e-8, math.log(K))).to(dtype=out1.dtype)
            else:
                raise ValueError(f"Unknown _intv_score_kind={sk}")

            # Clip extreme scores at a percentile to stabilize thresholding.
            clip_p = getattr(self, "_intv_score_clip_pctl", None)
            if clip_p is not None:
                clip_p = float(clip_p)
                if 0.0 < clip_p < 100.0:
                    smax = torch.quantile(score.detach().flatten(), clip_p / 100.0).to(score.dtype)
                    score = torch.clamp(score, max=smax)

            # Threshold tau: fixed absolute value or batch percentile of scores.
            tk = getattr(self, "_intv_tau_kind", "pctl")
            if tk == "abs":
                tau = torch.tensor(float(getattr(self, "_intv_tau", 0.15)),
                                   device=score.device, dtype=score.dtype)
            elif tk == "pctl":
                tau = torch.quantile(
                    score.detach().flatten(),
                    float(getattr(self, "_intv_tau_pctl", 75.0)) / 100.0,
                ).to(score.dtype)
            else:
                raise ValueError(f"Unknown _intv_tau_kind={tk}")

            # Hard step or sigmoid-soft mask around tau.
            mm = getattr(self, "_intv_mask_mode", "soft")
            if mm == "hard":
                mask = (score > tau).to(out1.dtype)
            elif mm == "soft":
                temp = max(1e-6, float(getattr(self, "_intv_soft_temp", 0.05)))
                mask = torch.sigmoid((score - tau) / temp).to(out1.dtype)
            else:
                raise ValueError(f"Unknown _intv_mask_mode={mm}")

            # Recombine: scaled parallel part + gated orthogonal part.
            par_beta = float(getattr(self, "_intv_par_beta", 1.0))
            delta_mod = par_beta * delta_par + beta * mask.unsqueeze(-1) * delta_orth

            logs = logs or {}
            logs.update(dict(
                score=score, tau=tau, mask=mask,
                alpha=alpha.squeeze(-1),
                out1_norm=out1n,
                dpar_norm=delta_par.norm(dim=-1),
                dorth_norm=dorthn,
            ))
        else:
            raise ValueError(f"Unknown _intv_mode={mode}")

        # Per-head blend between the intervened and the original delta.
        if hm is not None:
            delta_mod = hm * delta_mod + (1.0 - hm) * delta
            logs = logs or {}
            logs["head_mask"] = hm.squeeze(0).squeeze(-1).squeeze(-1).detach()

        return delta_mod, logs
| |
|
| | |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | attention_mask: Optional[torch.Tensor] = None, |
| | return_info: bool = False, |
| | |
| | |
| | routing_mode: str = "softmax", |
| | routing_topk: int = 2, |
| | read_weights_override: Optional[torch.Tensor] = None, |
| | routing_noise: Optional[str] = None, |
| | routing_noise_scale: float = 1.0, |
| | |
| | |
| | slot_mask: Optional[torch.Tensor] = None, |
| | slot_mask_where: str = "read", |
| | slot_mask_scope: str = "all", |
| | |
| | |
| | info_level: str = "full", |
| | info_cfg: Optional[Dict] = None, |
| | ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]: |
| | """ |
| | Parameters |
| | ---------- |
| | x : [B, T, C] |
| | attention_mask : [B, T] optional padding mask (1=valid, 0=pad) |
| | return_info : if True, return diagnostics dict as second element |
| | routing_mode : "softmax" | "top1" | "topk" | "external" |
| | routing_topk : k for topk mode |
| | read_weights_override : [B,H,T,K] or [B,H,L,K] for external routing |
| | routing_noise : None | "gumbel" | "gaussian" |
| | routing_noise_scale : scale for routing noise |
| | slot_mask : [K] where 1=keep, 0=mask |
| | slot_mask_where : "read" | "content_read_only" | "slotspace_only" |
| | slot_mask_scope : "all" | "last_pos_only" |
| | info_level : "basic" | "logits" | "full" |
| | info_cfg : dict (see default_info_cfg()) |
| | |
| | Returns |
| | ------- |
| | (output, info) where info is None if return_info=False. |
| | """ |
| |
|
| | B, T, C = x.shape |
| | H, K, d = self.num_heads, self.num_slots, self.head_dim |
| |
|
| | |
| | if info_cfg is None: |
| | info_cfg = self.default_info_cfg() |
| | store_read_weights = bool(info_cfg.get("store_read_weights", True)) |
| | store_read_logits = bool(info_cfg.get("store_read_logits", True)) and info_level in ("logits", "full") |
| | store_write_logits = bool(info_cfg.get("store_write_logits", True)) and info_level == "full" |
| | store_slot_norm = bool(info_cfg.get("store_slot_state_norm", True)) and info_level == "full" |
| | store_out1 = bool(info_cfg.get("store_out1", False)) and return_info |
| | store_delta = bool(info_cfg.get("store_delta", False)) and return_info |
| | store_slot_w = bool(info_cfg.get("store_slot_w", False)) and return_info |
| |
|
| | |
| | k_write = self.Wk_write(x).view(B, T, H, d).transpose(1, 2) |
| | v_write = self.Wv_write(x).view(B, T, H, d).transpose(1, 2) |
| | q_read = self.Wq_read(x).view(B, T, H, d).transpose(1, 2) |
| |
|
| | if self.normalize_k: |
| | k_write = F.normalize(k_write, dim=-1, eps=1e-8) |
| |
|
| | if self.use_rope_keys: |
| | cos, sin = self.rope.get_cos_sin(T, device=x.device, dtype=k_write.dtype) |
| | k_write = apply_rope(k_write, cos, sin) |
| |
|
| | |
| | slot_keys = self.slot_keys |
| | if self.training and self.slot_dropout > 0.0: |
| | drop = (torch.rand((H, K), device=x.device) < self.slot_dropout) |
| | slot_keys = slot_keys * (~drop).to(slot_keys.dtype).unsqueeze(-1) |
| |
|
| | |
| | write_logits_raw = torch.einsum("hkd,bhtd->bhkt", slot_keys, k_write) / math.sqrt(d) |
| | state_dtype = torch.float32 if (self.state_fp32 and x.dtype != torch.float32) else x.dtype |
| | write_logits = write_logits_raw.to(state_dtype) / max(1e-6, self.write_temperature) |
| |
|
| | |
| | alibi_bias_applied = None |
| | if self.use_alibi_write: |
| | strength = self._alibi_strength(dtype=state_dtype, device=x.device) |
| | slopes = self._alibi_slopes.to(device=x.device, dtype=state_dtype) * strength |
| | pos_i = torch.arange(T, device=x.device, dtype=state_dtype) |
| | alibi_bias = slopes.view(1, H, 1, 1) * pos_i.view(1, 1, 1, T) |
| | write_logits = write_logits + alibi_bias |
| | alibi_bias_applied = alibi_bias |
| |
|
| | |
| | if attention_mask is not None: |
| | valid = attention_mask.to(dtype=torch.bool) |
| | write_logits = write_logits.masked_fill(~valid.view(B, 1, 1, T), float("-inf")) |
| | else: |
| | valid = None |
| |
|
| | |
| | |
| | |
| | content_read_gamma = self._content_read_gamma(dtype=q_read.dtype, device=x.device) |
| | rtemp = max(1e-6, self.read_temperature) |
| |
|
| | out_h = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) |
| |
|
| | out1_full = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) if store_out1 else None |
| | delta_full = torch.empty((B, H, T, d), device=x.device, dtype=state_dtype) if store_delta else None |
| | slot_w_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) if store_slot_w else None |
| |
|
| | need_rw = bool(self.use_slotspace_refine) or (return_info and store_read_weights) |
| | read_weights = torch.empty((B, H, T, K), device=x.device, dtype=q_read.dtype) if need_rw else None |
| |
|
| | slot_state_norm_t = ( |
| | torch.empty((B, H, T, K), device=x.device, dtype=torch.float32) |
| | if (return_info and store_slot_norm) else None |
| | ) |
| |
|
| | if return_info and store_read_logits: |
| | read_logits_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) |
| | read_logits_key_full = torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) |
| | read_logits_content_full = ( |
| | torch.empty((B, H, T, K), device=x.device, dtype=state_dtype) if self.use_content_read else None |
| | ) |
| | else: |
| | read_logits_full = read_logits_key_full = read_logits_content_full = None |
| |
|
| | |
| | denom_state = torch.zeros((B, H, K), device=x.device, dtype=state_dtype) |
| | numer_state = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype) |
| | m_state = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype) |
| |
|
| | WRITE_CHUNK = self.write_chunk_size |
| |
|
| | for t0 in range(0, T, WRITE_CHUNK): |
| | t1 = min(T, t0 + WRITE_CHUNK) |
| | L = t1 - t0 |
| |
|
| | wlog_c = write_logits[:, :, :, t0:t1] |
| | m_c, _ = torch.cummax(wlog_c, dim=-1) |
| | m_new = torch.maximum(m_state.unsqueeze(-1), m_c) |
| |
|
| | scale = torch.exp(m_state.unsqueeze(-1) - m_new) |
| | denom_c = denom_state.unsqueeze(-1) * scale |
| | numer_c = numer_state.unsqueeze(-2) * scale.unsqueeze(-1) |
| |
|
| | w_new = self._safe_exp_sub_max(wlog_c, m_new) |
| | denom_c = denom_c + torch.cumsum(w_new, dim=-1) |
| |
|
| | v_c = v_write[:, :, t0:t1, :].to(state_dtype) |
| | add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2) |
| | numer_c = numer_c + add |
| |
|
| | slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1) |
| | slot_state_t = slot_state_c.permute(0, 1, 3, 2, 4).contiguous() |
| |
|
| | |
| | q_read_c = q_read[:, :, t0:t1, :] |
| | read_logits_key = torch.einsum("bhld,hkd->bhlk", q_read_c, slot_keys) / math.sqrt(d) |
| |
|
| | read_logits_content = None |
| | if self.use_content_read: |
| | read_logits_content = torch.einsum( |
| | "bhld,bhlkd->bhlk", q_read_c, slot_state_t.to(q_read_c.dtype), |
| | ) / math.sqrt(d) |
| |
|
| | |
| | sm = self._resolve_slot_mask( |
| | slot_mask, B=B, H=H, L=L, K=K, |
| | device=x.device, dtype=read_logits_key.dtype, scope=slot_mask_scope, |
| | ) |
| |
|
| | |
| | if slot_mask_where == "read": |
| | if sm is not None: |
| | read_logits_key = read_logits_key.masked_fill(sm <= 0.0, float("-inf")) |
| | if self.use_content_read and read_logits_content is not None: |
| | read_logits_content = read_logits_content.masked_fill(sm <= 0.0, float("-inf")) |
| | elif slot_mask_where == "content_read_only": |
| | if sm is not None and self.use_content_read and read_logits_content is not None: |
| | read_logits_content = read_logits_content.masked_fill(sm <= 0.0, 0.0) |
| | elif slot_mask_where == "slotspace_only": |
| | pass |
| | else: |
| | raise ValueError(f"Unknown slot_mask_where={slot_mask_where!r}") |
| |
|
| | |
| | rl = read_logits_key |
| | if self.use_content_read and read_logits_content is not None: |
| | rl = rl + content_read_gamma.to(rl.dtype) * read_logits_content |
| |
|
| | if return_info and store_read_logits: |
| | read_logits_full[:, :, t0:t1, :] = rl.to(state_dtype) |
| | read_logits_key_full[:, :, t0:t1, :] = read_logits_key.to(state_dtype) |
| | if self.use_content_read and read_logits_content_full is not None: |
| | read_logits_content_full[:, :, t0:t1, :] = read_logits_content.to(state_dtype) |
| |
|
| | |
| | read_w_c = self._compute_read_weights( |
| | read_logits=rl, read_logits_key=read_logits_key, |
| | read_logits_content=read_logits_content, |
| | routing_mode=routing_mode, routing_topk=routing_topk, |
| | read_weights_override=read_weights_override, |
| | routing_noise=routing_noise, routing_noise_scale=routing_noise_scale, |
| | rtemp=rtemp, sm=sm, slot_mask_where=slot_mask_where, |
| | B=B, H=H, L=L, K=K, T_total=T, t0=t0, t1=t1, |
| | q_read_c=q_read_c, slot_keys=slot_keys, |
| | slot_state_t=slot_state_t, valid=valid, |
| | state_dtype=state_dtype, |
| | ) |
| |
|
| | if read_weights is not None: |
| | read_weights[:, :, t0:t1, :] = read_w_c |
| |
|
| | |
| | out_h[:, :, t0:t1, :] = torch.einsum( |
| | "bhlk,bhlkd->bhld", read_w_c.to(state_dtype), slot_state_t.to(state_dtype), |
| | ) |
| |
|
| | if out1_full is not None: |
| | out1_full[:, :, t0:t1, :] = out_h[:, :, t0:t1, :] |
| |
|
| | if slot_state_norm_t is not None: |
| | slot_state_norm_t[:, :, t0:t1, :] = slot_state_t.to(torch.float32).norm(dim=-1) |
| |
|
| | m_state = m_new[:, :, :, -1] |
| | denom_state = denom_c[:, :, :, -1] |
| | numer_state = numer_c[:, :, :, -1, :] |
| |
|
| | |
| | |
| | |
| | slotspace_delta_norm_mean = None |
| | intv_logs_acc: Optional[Dict] = None |
| | intv_logs_count = 0 |
| |
|
| | if self.use_slotspace_refine: |
| | slotspace_dtype = state_dtype |
| | M = self.slotspace_dim |
| | assert read_weights is not None |
| |
|
| | u = self.slot_in(read_weights.to(slotspace_dtype)) |
| | q_s = self.slot_q(u) |
| | k_s = self.slot_k(u) |
| | v_s = self.slot_v(u) |
| |
|
| | if self.use_rope_slotspace: |
| | cos_s, sin_s = self.rope_slotspace.get_cos_sin(T, device=x.device, dtype=q_s.dtype) |
| | q_s = apply_rope(q_s, cos_s, sin_s) |
| | k_s = apply_rope(k_s, cos_s, sin_s) |
| |
|
| | qf = phi(q_s) |
| | kf = phi(k_s) |
| |
|
| | if valid is not None: |
| | vmask = valid.view(B, 1, T, 1).to(slotspace_dtype) |
| | qf = qf * vmask |
| | kf = kf * vmask |
| | v_s = v_s * vmask |
| |
|
| | u2 = torch.empty((B, H, T, M), device=x.device, dtype=slotspace_dtype) |
| | S_state = torch.zeros((B, H, M, M), device=x.device, dtype=slotspace_dtype) |
| | Z_state = torch.zeros((B, H, M), device=x.device, dtype=slotspace_dtype) |
| |
|
| | SS_CHUNK = self.slotspace_chunk_size |
| | for t0 in range(0, T, SS_CHUNK): |
| | t1 = min(T, t0 + SS_CHUNK) |
| | qf_c = qf[:, :, t0:t1, :] |
| | kf_c = kf[:, :, t0:t1, :] |
| | v_c = v_s[:, :, t0:t1, :] |
| |
|
| | kv = torch.einsum("bhlm,bhln->bhlmn", kf_c, v_c) |
| | S_c = torch.cumsum(kv, dim=2) + S_state.unsqueeze(2) |
| | Z_c = (torch.cumsum(kf_c, dim=2) + Z_state.unsqueeze(2)).clamp_min(1e-8) |
| |
|
| | num = torch.einsum("bhlm,bhlmn->bhln", qf_c, S_c) |
| | den = torch.einsum("bhlm,bhlm->bhl", qf_c, Z_c).unsqueeze(-1).clamp_min(1e-8) |
| | u2[:, :, t0:t1, :] = num / den |
| |
|
| | S_state = S_c[:, :, -1, :, :] |
| | Z_state = Z_c[:, :, -1, :] |
| |
|
| | u2 = self.slotspace_dropout(u2) |
| | slot_w = self.slot_out(u2) |
| |
|
| | if slot_w_full is not None: |
| | slot_w_full[:] = slot_w.to(state_dtype) |
| |
|
| | if self.slotspace_signed_weights: |
| | slot_w_eff = torch.tanh(slot_w) |
| | else: |
| | slot_w_eff = torch.softmax(slot_w, dim=-1) |
| |
|
| | |
| | if slot_mask_where == "slotspace_only": |
| | sm_full = self._resolve_slot_mask( |
| | slot_mask, B=B, H=H, L=T, K=K, |
| | device=x.device, dtype=slot_w_eff.dtype, scope=slot_mask_scope, |
| | ) |
| | if sm_full is not None: |
| | slot_w_eff = slot_w_eff * (sm_full > 0.0).to(slot_w_eff.dtype) |
| | if not self.slotspace_signed_weights: |
| | slot_w_eff = slot_w_eff / slot_w_eff.sum(dim=-1, keepdim=True).clamp_min(1e-8) |
| |
|
| | gate = self._slotspace_gate(dtype=state_dtype, device=x.device).to(state_dtype) |
| |
|
| | |
| | denom_state2 = torch.zeros((B, H, K), device=x.device, dtype=state_dtype) |
| | numer_state2 = torch.zeros((B, H, K, d), device=x.device, dtype=state_dtype) |
| | m_state2 = torch.full((B, H, K), float("-inf"), device=x.device, dtype=state_dtype) |
| |
|
| | delta_norm_sum = torch.zeros((), device=x.device, dtype=torch.float32) |
| | delta_norm_count = 0 |
| |
|
| | for t0 in range(0, T, WRITE_CHUNK): |
| | t1 = min(T, t0 + WRITE_CHUNK) |
| | Lc = t1 - t0 |
| |
|
| | wlog_c = write_logits[:, :, :, t0:t1] |
| | m_c, _ = torch.cummax(wlog_c, dim=-1) |
| | m_new = torch.maximum(m_state2.unsqueeze(-1), m_c) |
| |
|
| | scale = torch.exp(m_state2.unsqueeze(-1) - m_new) |
| | denom_c = denom_state2.unsqueeze(-1) * scale |
| | numer_c = numer_state2.unsqueeze(-2) * scale.unsqueeze(-1) |
| |
|
| | w_new = self._safe_exp_sub_max(wlog_c, m_new) |
| | denom_c = denom_c + torch.cumsum(w_new, dim=-1) |
| |
|
| | v_c = v_write[:, :, t0:t1, :].to(state_dtype) |
| | add = torch.cumsum(w_new.unsqueeze(-1) * v_c.unsqueeze(2), dim=-2) |
| | numer_c = numer_c + add |
| |
|
| | slot_state_c = numer_c / denom_c.clamp_min(1e-8).unsqueeze(-1) |
| | slot_state_t2 = slot_state_c.permute(0, 1, 3, 2, 4).contiguous() |
| |
|
| | slot_w_c = slot_w_eff[:, :, t0:t1, :].to(state_dtype) |
| | delta_c = torch.einsum("bhlk,bhlkd->bhld", slot_w_c, slot_state_t2.to(state_dtype)) |
| |
|
| | delta = gate * delta_c |
| |
|
| | if delta_full is not None: |
| | delta_full[:, :, t0:t1, :] = delta |
| |
|
| | |
| | slot_w_for_score = slot_w[:, :, t0:t1, :] if store_slot_w else None |
| | delta_mod, logs = self._apply_refine_intervention( |
| | out1=out_h[:, :, t0:t1, :], delta=delta, slot_w=slot_w_for_score, |
| | ) |
| |
|
| | out_h[:, :, t0:t1, :] = out_h[:, :, t0:t1, :] + delta_mod |
| |
|
| | |
| | if logs is not None and return_info: |
| | if intv_logs_acc is None: |
| | intv_logs_acc = {} |
| | for klog, v in logs.items(): |
| | if torch.is_tensor(v): |
| | vv = v.detach().to(torch.float32) |
| | intv_logs_acc[klog] = vv if vv.ndim == 1 else vv.mean() |
| | intv_logs_count = 1 |
| | else: |
| | for klog, v in logs.items(): |
| | if torch.is_tensor(v) and klog in intv_logs_acc: |
| | vv = v.detach().to(torch.float32) |
| | intv_logs_acc[klog] = intv_logs_acc[klog] + (vv if vv.ndim == 1 else vv.mean()) |
| | intv_logs_count += 1 |
| |
|
| | delta_norm_sum = delta_norm_sum + delta.detach().to(torch.float32).norm(dim=-1).sum() |
| | delta_norm_count += B * H * Lc |
| |
|
| | m_state2 = m_new[:, :, :, -1] |
| | denom_state2 = denom_c[:, :, :, -1] |
| | numer_state2 = numer_c[:, :, :, -1, :] |
| |
|
| | slotspace_delta_norm_mean = (delta_norm_sum / max(1, delta_norm_count)).detach().cpu() |
| |
|
| | |
| | |
| | |
| | out = out_h.transpose(1, 2).contiguous().view(B, T, C) |
| | out = self.out_proj(out) |
| | out = self.dropout(out) |
| |
|
| | |
| | info = None |
| | if return_info: |
| | info = { |
| | "content_read_gamma": content_read_gamma.detach().to(torch.float32).cpu(), |
| | "routing_mode": routing_mode, |
| | "slot_mask_where": slot_mask_where, |
| | "slot_mask_scope": slot_mask_scope, |
| | "intv_mode": getattr(self, "_intv_mode", "off"), |
| | } |
| |
|
| | if alibi_bias_applied is not None and info_level == "full": |
| | info["alibi_bias_applied"] = self._store_tensor(alibi_bias_applied.to(torch.float32), cfg=info_cfg, kind="other") |
| |
|
| | if self.use_alibi_write and self.learn_alibi_strength: |
| | info["alibi_strength"] = self._alibi_strength(dtype=torch.float32, device=x.device).detach().cpu() |
| |
|
| | if self.use_slotspace_refine: |
| | info["slotspace_gate"] = self._slotspace_gate(dtype=torch.float32, device=x.device).detach().cpu() |
| | info["use_rope_slotspace"] = torch.tensor(bool(self.use_rope_slotspace)) |
| | if slotspace_delta_norm_mean is not None: |
| | info["slotspace_delta_norm"] = slotspace_delta_norm_mean |
| |
|
| | |
| | if store_read_weights and read_weights is not None: |
| | info["read_weights"] = self._store_tensor(read_weights, cfg=info_cfg, kind="bhtk") |
| | else: |
| | info["read_weights"] = None |
| |
|
| | |
| | if store_slot_norm and slot_state_norm_t is not None: |
| | s = slot_state_norm_t.permute(0, 1, 3, 2).contiguous() |
| | info["slot_state_norm"] = self._store_tensor(s, cfg=info_cfg, kind="bhkt") |
| | else: |
| | info["slot_state_norm"] = None |
| |
|
| | |
| | if store_read_logits and read_logits_full is not None: |
| | info["read_logits"] = self._store_tensor(read_logits_full.to(torch.float32), cfg=info_cfg, kind="bhtk") |
| | info["read_logits_key"] = self._store_tensor(read_logits_key_full.to(torch.float32), cfg=info_cfg, kind="bhtk") |
| | info["read_logits_content"] = ( |
| | self._store_tensor(read_logits_content_full.to(torch.float32), cfg=info_cfg, kind="bhtk") |
| | if read_logits_content_full is not None else None |
| | ) |
| | else: |
| | info["read_logits"] = info["read_logits_key"] = info["read_logits_content"] = None |
| |
|
| | |
| | if store_write_logits and info_level == "full": |
| | info["write_logits_raw"] = self._store_tensor(write_logits_raw, cfg=info_cfg, kind="bhkt") |
| | info["write_logits"] = self._store_tensor(write_logits.to(torch.float32), cfg=info_cfg, kind="bhkt") |
| | else: |
| | info["write_logits_raw"] = info["write_logits"] = None |
| |
|
| | |
| | info["out1"] = self._store_tensor(out1_full.to(torch.float32), cfg=info_cfg, kind="other") if out1_full is not None else None |
| | info["delta"] = self._store_tensor(delta_full.to(torch.float32), cfg=info_cfg, kind="other") if delta_full is not None else None |
| | info["slot_w"] = self._store_tensor(slot_w_full.to(torch.float32), cfg=info_cfg, kind="bhtk") if slot_w_full is not None else None |
| |
|
| | |
| | if intv_logs_acc is not None and intv_logs_count > 0: |
| | for klog, v in intv_logs_acc.items(): |
| | info[klog] = (v / float(intv_logs_count)).detach().cpu() |
| |
|
| | |
| | for alias_from, alias_to in [ |
| | ("score", "intv_score_mean"), ("mask", "intv_mask_mean"), |
| | ("tau", "intv_tau"), ("alpha", "intv_alpha_mean"), |
| | ("out1_norm", "intv_out1_norm_mean"), |
| | ("dpar_norm", "intv_dpar_norm_mean"), |
| | ("dorth_norm", "intv_dorth_norm_mean"), |
| | ]: |
| | if alias_from in intv_logs_acc: |
| | val = info.get(alias_from) |
| | if torch.is_tensor(val) and val.ndim != 1: |
| | info[alias_to] = val |
| |
|
| | return out, info |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
@dataclass
class ASMTrainConfig:
    """Configuration for ASM language-model training and analysis runs.

    Groups dataset, batching, optimizer, model-architecture, and
    ASA-specific hyperparameters.  The architecture/ASA fields are
    consumed by ``build_model_from_cfg``; the dataset/optimizer fields
    are read by the training harness.
    """

    # --- Dataset / tokenization ---
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-103-raw-v1"
    tokenizer_name: str = "gpt2"

    # --- Sequence windows / reproducibility ---
    max_seq_len: int = 256
    # Validation window stride as a fraction of max_seq_len — TODO confirm against harness.
    stride_frac_val: float = 0.50
    seed: int = 1337

    # --- Batch construction ---
    micro_batch_size: int = 2
    grad_accum_steps: int = 8
    train_samples_target: int = 100_000_000
    val_samples_target: int = 25_000

    # --- Optimizer / schedule ---
    batch_size: int = 64
    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)
    grad_clip: float = 1.0
    warmup_steps: int = 1_000
    total_steps: int = 75_000
    eval_interval: int = 1_000
    log_interval: int = 100

    # --- Model architecture ---
    vocab_size: int = 50257  # GPT-2 tokenizer vocabulary size
    embed_dim: int = 384
    num_layers: int = 23
    num_heads: int = 8
    num_slots: int = 32
    mlp_ratio: float = 4.0
    dropout: float = 0.1
    tie_weights: bool = True  # share lm_head weight with token embedding

    # --- ASA core behavior ---
    read_temperature: float = 1.0
    write_temperature: float = 1.0
    slot_dropout: float = 0.05
    state_fp32: bool = True  # presumably runs slot-state accumulation in float32 — confirm in ASA
    normalize_k: bool = False

    # --- Positional information (absolute / RoPE / ALiBi write bias) ---
    use_abs_pos: bool = False
    use_rope_keys: bool = True
    rope_base: float = 10000.0
    use_alibi_write: bool = True
    alibi_strength_init: float = 0.1
    learn_alibi_strength: bool = True
    min_strength: float = 0.0

    # --- Content-conditioned read path ---
    use_content_read: bool = True
    content_read_init: float = -4.0  # initial raw gate value; negative init starts the path near-off — TODO confirm
    content_read_max_gamma: float = 3.0

    # --- Slot-space refinement stage ---
    use_slotspace_refine: bool = True
    slotspace_dim: int = 64
    slotspace_gate_init: float = -4.0
    slotspace_dropout: float = 0.05
    slotspace_signed_weights: bool = True  # tanh-signed slot weights instead of softmax

    # --- RoPE inside the slot-space refinement ---
    use_rope_slotspace: bool = True
    rope_base_slotspace: float = 100000.0

    # --- Chunking / compilation ---
    write_chunk_size: int = 128
    slotspace_chunk_size: int = 128
    enable_compiled: bool = False

    # --- Evaluation / analytics ---
    eval_max_batches: int = 150
    analytics_last_k: int = 32

    # --- Output paths ---
    output_dir: str = "./drive/MyDrive/asm_outputs"
    tag: str = "asm_wikitext"
    cache_dir: str = "./drive/MyDrive/asm_caches"
    val_windows_cache: str = "./drive/MyDrive/asm_val_cache_windows_1024.pkl"
| |
|
| |
|
| | |
| | |
| | |
class ASMBlock(nn.Module):
    """Pre-norm transformer block: ASA attention followed by a GELU MLP.

    Both sub-layers use residual connections.  The attribute names
    ``norm1``/``asa``/``norm2``/``mlp`` are part of the checkpoint
    contract and must not be renamed.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_slots: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,

        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.0,
        normalize_k: bool = False,

        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,

        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,

        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,

        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,

        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,

        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)

        # Collect the ASA configuration in one place before construction.
        asa_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_slots=num_slots,
            dropout=dropout,
            read_temperature=read_temperature,
            write_temperature=write_temperature,
            state_fp32=state_fp32,
            slot_dropout=slot_dropout,
            normalize_k=normalize_k,
            use_rope_keys=use_rope_keys,
            rope_base=rope_base,
            use_alibi_write=use_alibi_write,
            alibi_strength_init=alibi_strength_init,
            learn_alibi_strength=learn_alibi_strength,
            min_strength=min_strength,
            use_content_read=use_content_read,
            content_read_init=content_read_init,
            content_read_max_gamma=content_read_max_gamma,
            use_slotspace_refine=use_slotspace_refine,
            slotspace_dim=slotspace_dim,
            slotspace_gate_init=slotspace_gate_init,
            slotspace_dropout=slotspace_dropout,
            slotspace_signed_weights=slotspace_signed_weights,
            use_rope_slotspace=use_rope_slotspace,
            rope_base_slotspace=rope_base_slotspace,
            write_chunk_size=write_chunk_size,
            slotspace_chunk_size=slotspace_chunk_size,
        )
        self.asa = AddressedStateAttention(**asa_kwargs)

        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,

        routing_mode: str = "softmax",
        routing_topk: int = 2,
        read_weights_override: Optional[torch.Tensor] = None,
        routing_noise: Optional[str] = None,
        routing_noise_scale: float = 1.0,

        slot_mask: Optional[torch.Tensor] = None,
        slot_mask_where: str = "read",
        slot_mask_scope: str = "all",

        info_level: str = "full",
        info_cfg: Optional[Dict] = None,
    ):
        """Run ASA then the MLP, each pre-normed with a residual add.

        All routing/intervention/info keyword arguments are forwarded
        unchanged to the ASA layer.  Returns ``(x, info)`` where ``info``
        is whatever the ASA layer produced (``None`` unless
        ``return_info`` is set).
        """
        attn_out, info = self.asa(
            self.norm1(x),
            attention_mask=attention_mask,
            return_info=return_info,
            routing_mode=routing_mode,
            routing_topk=routing_topk,
            read_weights_override=read_weights_override,
            routing_noise=routing_noise,
            routing_noise_scale=routing_noise_scale,
            slot_mask=slot_mask,
            slot_mask_where=slot_mask_where,
            slot_mask_scope=slot_mask_scope,
            info_level=info_level,
            info_cfg=info_cfg,
        )
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x, info
| |
|
| |
|
| | |
| | |
| | |
class ASMLanguageModel(nn.Module):
    """Decoder-style language model built from a stack of ``ASMBlock``s.

    Token embedding -> (optional absolute position embedding) -> dropout
    -> ``num_layers`` blocks -> final LayerNorm -> LM head.  With
    ``tie_weights`` the LM head shares its weight matrix with the token
    embedding.

    Fix over previous revision: when ``use_abs_pos`` is enabled, a
    sequence longer than ``max_seq_len`` previously failed inside
    ``nn.Embedding`` with an opaque IndexError (or a CUDA device-side
    assert); ``forward`` now raises a descriptive ``ValueError`` instead.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 384,
        num_layers: int = 6,
        num_heads: int = 8,
        num_slots: int = 8,
        max_seq_len: int = 1024,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,

        read_temperature: float = 1.0,
        write_temperature: float = 1.0,
        state_fp32: bool = True,
        slot_dropout: float = 0.05,
        normalize_k: bool = False,
        tie_weights: bool = True,

        use_abs_pos: bool = False,

        use_rope_keys: bool = True,
        rope_base: float = 10000.0,
        use_alibi_write: bool = True,

        alibi_strength_init: float = 0.1,
        learn_alibi_strength: bool = True,
        min_strength: float = 0.0,

        use_content_read: bool = True,
        content_read_init: float = -4.0,
        content_read_max_gamma: float = 3.0,

        use_slotspace_refine: bool = True,
        slotspace_dim: int = 32,
        slotspace_gate_init: float = -10.0,
        slotspace_dropout: float = 0.0,
        slotspace_signed_weights: bool = True,

        use_rope_slotspace: bool = True,
        rope_base_slotspace: float = 100000.0,

        write_chunk_size: int = 128,
        slotspace_chunk_size: int = 128,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.use_abs_pos = bool(use_abs_pos)

        self.tok = nn.Embedding(vocab_size, embed_dim)
        # Absolute positions are optional; RoPE/ALiBi inside ASA are the default.
        self.pos = nn.Embedding(max_seq_len, embed_dim) if self.use_abs_pos else None
        self.drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            ASMBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                num_slots=num_slots,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                read_temperature=read_temperature,
                write_temperature=write_temperature,
                state_fp32=state_fp32,
                slot_dropout=slot_dropout,
                normalize_k=normalize_k,
                use_rope_keys=use_rope_keys,
                rope_base=rope_base,
                use_alibi_write=use_alibi_write,
                alibi_strength_init=alibi_strength_init,
                learn_alibi_strength=learn_alibi_strength,
                min_strength=min_strength,
                use_content_read=use_content_read,
                content_read_init=content_read_init,
                content_read_max_gamma=content_read_max_gamma,
                use_slotspace_refine=use_slotspace_refine,
                slotspace_dim=slotspace_dim,
                slotspace_gate_init=slotspace_gate_init,
                slotspace_dropout=slotspace_dropout,
                slotspace_signed_weights=slotspace_signed_weights,
                use_rope_slotspace=use_rope_slotspace,
                rope_base_slotspace=rope_base_slotspace,
                write_chunk_size=write_chunk_size,
                slotspace_chunk_size=slotspace_chunk_size,
            )
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        if tie_weights:
            # Weight tying: the head and the token embedding share one tensor.
            self.lm_head.weight = self.tok.weight

        # Applied after tying, so the shared tensor is initialized exactly once
        # per module visit (tok and the tied lm_head point at the same storage).
        self.apply(self._init)

    def _init(self, m):
        """Standard GPT-style init: N(0, 0.02) weights, unit/zero LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_info: bool = False,

        routing_mode: str = "softmax",
        routing_topk: int = 2,
        read_weights_override: Optional[torch.Tensor] = None,
        routing_noise: Optional[str] = None,
        routing_noise_scale: float = 1.0,

        slot_mask: Optional[torch.Tensor] = None,
        slot_mask_where: str = "read",
        slot_mask_scope: str = "all",

        info_level: str = "full",
        info_cfg: Optional[Dict] = None,
    ):
        """Compute LM logits for ``input_ids`` of shape ``(B, T)``.

        Routing/intervention keyword arguments are forwarded unchanged to
        every block.  Returns ``logits`` of shape ``(B, T, vocab_size)``,
        or ``(logits, infos)`` — one info dict per layer — when
        ``return_info`` is True.

        Raises:
            ValueError: if ``use_abs_pos`` is enabled and ``T`` exceeds
                ``max_seq_len`` (the position table has no entries beyond
                that length).
        """
        B, T = input_ids.shape

        x = self.tok(input_ids)
        if self.use_abs_pos:
            if T > self.max_seq_len:
                raise ValueError(
                    f"sequence length {T} exceeds max_seq_len {self.max_seq_len} "
                    "for absolute positional embeddings"
                )
            pos = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)
            x = x + self.pos(pos)
        x = self.drop(x)

        infos: List[Optional[Dict[str, torch.Tensor]]] = []
        for blk in self.blocks:
            x, info = blk(
                x,
                attention_mask=attention_mask,
                return_info=return_info,
                routing_mode=routing_mode,
                routing_topk=routing_topk,
                read_weights_override=read_weights_override,
                routing_noise=routing_noise,
                routing_noise_scale=routing_noise_scale,
                slot_mask=slot_mask,
                slot_mask_where=slot_mask_where,
                slot_mask_scope=slot_mask_scope,
                info_level=info_level,
                info_cfg=info_cfg,
            )
            if return_info:
                infos.append(info)

        x = self.norm(x)
        logits = self.lm_head(x)
        return (logits, infos) if return_info else logits
| |
|
| |
|
| | |
| | |
| | |
def build_model_from_cfg(cfg: ASMTrainConfig) -> ASMLanguageModel:
    """Instantiate an ``ASMLanguageModel`` from the model-relevant fields of *cfg*.

    Only architecture and ASA hyperparameters are forwarded; the dataset,
    batching, and optimizer fields on the config are ignored here.
    """
    # Field names on ASMTrainConfig that map 1:1 onto ASMLanguageModel kwargs.
    model_fields = (
        "vocab_size", "embed_dim", "num_layers", "num_heads", "num_slots",
        "max_seq_len", "mlp_ratio", "dropout",
        "read_temperature", "write_temperature", "state_fp32",
        "slot_dropout", "normalize_k", "tie_weights",
        "use_abs_pos", "use_rope_keys", "rope_base",
        "use_alibi_write", "alibi_strength_init", "learn_alibi_strength",
        "min_strength",
        "use_content_read", "content_read_init", "content_read_max_gamma",
        "use_slotspace_refine", "slotspace_dim", "slotspace_gate_init",
        "slotspace_dropout", "slotspace_signed_weights",
        "use_rope_slotspace", "rope_base_slotspace",
        "write_chunk_size", "slotspace_chunk_size",
    )
    kwargs = {name: getattr(cfg, name) for name in model_fields}
    return ASMLanguageModel(**kwargs)
| |
|