| import math
| import contextlib
| import logging
| import sys
| from typing import Dict, List, Tuple, Optional
|
|
| import torch |
| import torch.distributed as dist |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint as checkpoint |
|
|
| from .torch_utils import cpu_autocast |
|
|
| from .optimization import configure_optimizer |
| from .compression import decompress_bits |
| from .parity import enforce_parity |
|
|
| _mask_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {} |
| _attention_cache: Dict[str, torch.Tensor] = {} |
| _MAX_CACHE_SIZE = 50 |
|
|
|
|
| def clear_cache(): |
| """Clear memory caches to prevent OOM in long sequences.""" |
| global _mask_cache, _attention_cache |
| _mask_cache.clear() |
| _attention_cache.clear() |
|
|
|
|
| def get_tri_mask(seq_len: int, device: torch.device) -> torch.Tensor: |
| """Return or create a cached upper-triangular mask with memory management.""" |
| key = (seq_len, device) |
| |
| |
| if len(_mask_cache) > _MAX_CACHE_SIZE: |
| clear_cache() |
| |
| if key not in _mask_cache: |
| _mask_cache[key] = torch.triu( |
| torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), 1 |
| ) |
| return _mask_cache[key] |
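|
| # Usage sketch (hypothetical values, not executed at import time):
| #     mask = get_tri_mask(4, torch.device("cpu"))
| #     # mask[i, j] is True for j > i, i.e. True marks future positions a query may NOT
| #     # attend to, matching the boolean attn_mask convention of nn.MultiheadAttention.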
|
|
| # Use torch.compile when it is available and supported; otherwise fall back to a no-op decorator.
| try:
| _torch_major_minor = tuple(map(int, torch.__version__.split(".")[:2]))
| if _torch_major_minor >= (2, 0) and sys.version_info < (3, 11):
| compile_fn = torch.compile
| else:
| raise RuntimeError("torch.compile unavailable")
| except Exception:
|
|
| def compile_fn(fn=None, **kwargs):
| """No-op fallback used when ``torch.compile`` is unavailable."""
| if fn is None:
| return lambda f: f
| return fn
|
|
|
|
| class PositionalEncoding(nn.Module): |
| """Sinusoidal positional encoding.""" |
|
|
| def __init__(self, d_model: int, max_len: int = 1024) -> None: |
| super().__init__() |
| pe = torch.zeros(max_len, d_model) |
| pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) |
| inv = torch.exp( |
| torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) |
| ) |
| pe[:, 0::2] = torch.sin(pos * inv) |
| pe[:, 1::2] = torch.cos(pos * inv) |
| self.register_buffer("pe", pe.unsqueeze(1)) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| """Add positional encoding to input tensor.""" |
| return x + self.pe[: x.size(0)] |
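|
| # Shape sketch (hypothetical values, not executed at import time):
| #     enc = PositionalEncoding(d_model=64, max_len=128)
| #     x = torch.zeros(16, 4, 64)   # (T, B, d_model) -- sequence-first layout used inside the model
| #     y = enc(x)                   # same shape; pe[:16] of shape (16, 1, 64) broadcasts over the batch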
|
|
|
|
| class LoggingTransformerEncoderLayer(nn.Module): |
| """Transformer encoder layer that exposes attention weights. |
| |
| It optionally performs chunked attention with a fixed window size. |
| """ |
|
|
| def __init__( |
| self, |
| d_model: int, |
| nhead: int, |
| dim_feedforward: int = 512, |
| dropout: float = 0.1, |
| chunk_size: Optional[int] = None, |
| overlap: int = 0, |
| full_attn_logging: Optional[bool] = None, |
| ) -> None: |
| super().__init__() |
| self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) |
| self.chunk_size = chunk_size |
| self.overlap = overlap |
| if full_attn_logging is None: |
| full_attn_logging = chunk_size is None
| self.full_attn_logging = full_attn_logging |
| self.linear1 = nn.Linear(d_model, dim_feedforward) |
| self.dropout = nn.Dropout(dropout) |
| self.linear2 = nn.Linear(dim_feedforward, d_model) |
| self.norm1 = nn.LayerNorm(d_model) |
| self.norm2 = nn.LayerNorm(d_model) |
| self.dropout1 = nn.Dropout(dropout) |
| self.dropout2 = nn.Dropout(dropout) |
| self.activation = F.relu |
|
|
| def _chunked_attn( |
| self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """Perform memory-efficient chunked self attention with overlap.""" |
| T, B, D = src.shape |
| |
| |
| # Fall back to full attention for short sequences or when chunking is disabled or ineffective.
| if T <= 128 or self.chunk_size is None or self.chunk_size >= T:
| return self._full_attn(src, attn_mask) |
| |
| src_b = src.transpose(0, 1) |
| C = self.chunk_size |
| O = self.overlap |
| n_chunks = (T + C - 1) // C |
| pad_len = n_chunks * C - T |
| |
| |
| outputs = [] |
| weights_list = [] |
| |
| |
| with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()): |
| for chunk_idx in range(n_chunks): |
| start_idx = chunk_idx * C |
| end_idx = min(start_idx + C + 2 * O, T + O) |
| |
| |
| chunk_start = max(0, start_idx - O) |
| chunk_end = min(T, end_idx) |
| chunk = src_b[:, chunk_start:chunk_end] |
| |
| |
| if chunk.size(1) < C + 2 * O: |
| pad_size = C + 2 * O - chunk.size(1) |
| chunk = F.pad(chunk, (0, 0, 0, pad_size)) |
| |
| chunk_len = chunk.size(1) |
| mask = get_tri_mask(chunk_len, src.device) if attn_mask is not None else None |
| |
| |
| out, weights = self.self_attn( |
| chunk, chunk, chunk, |
| attn_mask=mask, |
| need_weights=self.full_attn_logging, |
| average_attn_weights=False, |
| ) |
| |
| |
| core_start = O if chunk_idx > 0 else 0 |
| core_end = core_start + min(C, T - start_idx) |
| outputs.append(out[:, core_start:core_end]) |
| |
| if self.full_attn_logging and weights is not None: |
| weights_list.append(weights[:, :, core_start:core_end]) |
| |
| |
| del out, weights, chunk |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| |
| |
| seq = torch.cat(outputs, dim=1) |
| |
| |
| if self.full_attn_logging and weights_list: |
|
| # Avoid materializing a very large concatenated attention map for long sequences.
| if T > 1024:
| attn_out = torch.empty(0, device=src.device) |
| else: |
| attn_out = torch.cat(weights_list, dim=2) |
| else: |
| attn_out = torch.empty(0, device=src.device) |
| |
| return seq.transpose(0, 1), attn_out |
| |
| def _full_attn(self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: |
| """Standard full attention for smaller sequences.""" |
| qkv = src.transpose(0, 1) |
| attn_output, attn_weights = self.self_attn( |
| qkv, qkv, qkv, |
| attn_mask=attn_mask, |
| need_weights=True, |
| average_attn_weights=False, |
| ) |
| return attn_output.transpose(0, 1), attn_weights |
|
|
| def forward( |
| self, src: torch.Tensor, attn_mask: Optional[torch.Tensor] = None |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """Return output and attention map.""" |
| if self.chunk_size is not None: |
| attn_output, attn_weights = self._chunked_attn(src, attn_mask) |
| else: |
| qkv = src.transpose(0, 1) |
| attn_output, attn_weights = self.self_attn( |
| qkv, |
| qkv, |
| qkv, |
| attn_mask=attn_mask, |
| need_weights=True, |
| average_attn_weights=False, |
| ) |
| attn_output = attn_output.transpose(0, 1) |
| src = src + self.dropout1(attn_output) |
| src = self.norm1(src) |
| out = self.linear2(self.dropout(self.activation(self.linear1(src)))) |
| src = src + self.dropout2(out) |
| src = self.norm2(src) |
| return src, attn_weights.detach() |
|
|
|
|
| class ReversibleLoggingTransformerEncoderLayer(nn.Module): |
| """Reversible transformer encoder layer with checkpointing.""" |
|
|
| def __init__( |
| self, |
| d_model: int, |
| nhead: int, |
| dim_feedforward: int = 512, |
| dropout: float = 0.1, |
| chunk_size: Optional[int] = None, |
| overlap: int = 0, |
| full_attn_logging: Optional[bool] = None, |
| ) -> None: |
| super().__init__() |
| self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) |
| self.chunk_size = chunk_size |
| self.overlap = overlap |
| if full_attn_logging is None: |
| full_attn_logging = chunk_size is None
| self.full_attn_logging = full_attn_logging |
| self.linear1 = nn.Linear(d_model, dim_feedforward) |
| self.dropout = nn.Dropout(dropout) |
| self.linear2 = nn.Linear(dim_feedforward, d_model) |
| self.norm1 = nn.LayerNorm(d_model) |
| self.norm2 = nn.LayerNorm(d_model) |
| self.dropout1 = nn.Dropout(dropout) |
| self.dropout2 = nn.Dropout(dropout) |
| self.activation = F.relu |
|
|
| @compile_fn |
| def _sa_block( |
| self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| if self.chunk_size is not None: |
| T, B, D = x.shape |
| x_b = x.transpose(0, 1) |
| C = self.chunk_size or T |
| O = self.overlap |
| n_chunks = (T + C - 1) // C |
| pad_len = n_chunks * C - T |
| src_pad = F.pad(x_b, (0, 0, O, pad_len + O)) |
| chunk_len = C + 2 * O |
| # unfold appends the window dimension last; reorder to (B, n_chunks, chunk_len, D) before reshaping.
| chunks = src_pad.unfold(1, chunk_len, C).permute(0, 1, 3, 2)
| mask = get_tri_mask(chunk_len, x.device) if attn_mask is not None else None |
| out, weights = self.self_attn( |
| chunks.reshape(B * n_chunks, chunk_len, D), |
| chunks.reshape(B * n_chunks, chunk_len, D), |
| chunks.reshape(B * n_chunks, chunk_len, D), |
| attn_mask=mask, |
| need_weights=True, |
| average_attn_weights=False, |
| ) |
| out = out.view(B, n_chunks, chunk_len, D)[:, :, O : O + C] |
| weights = weights.view(B, n_chunks, self.self_attn.num_heads, chunk_len, chunk_len)[ |
| :, :, :, O : O + C |
| ] |
| seq = out.reshape(B, n_chunks * C, D)[:, :T] |
| if self.full_attn_logging and C < T: |
| full_attn = torch.zeros( |
| B, self.self_attn.num_heads, n_chunks * C, n_chunks * C, device=x.device |
| ) |
| for idx in range(n_chunks): |
| s = idx * C |
| start = max(s - O, 0) |
| end = min(s + C, n_chunks * C) |
| src_start = O - (s - start) |
| src_end = src_start + (end - start) |
| # Copy this chunk's weights into the global map, slicing the key dimension to the visible span.
| full_attn[:, :, s : s + C, start:end] = weights[
| :, idx, :, :, src_start:src_end
| ]
| full_attn = full_attn[:, :, :T, :T] |
| weights = full_attn.detach() |
| else: |
| weights = torch.empty(0, device=x.device) |
| attn_out = seq.transpose(0, 1) |
| else: |
| qkv = x.transpose(0, 1) |
| attn_out, weights = self.self_attn( |
| qkv, |
| qkv, |
| qkv, |
| attn_mask=attn_mask, |
| need_weights=True, |
| average_attn_weights=False, |
| ) |
| attn_out = attn_out.transpose(0, 1) |
| x = self.norm1(x + self.dropout1(attn_out)) |
| return x, weights.detach() |
|
|
| @compile_fn |
| def _ff_block(self, x: torch.Tensor) -> torch.Tensor: |
| out = self.linear2(self.dropout(self.activation(self.linear1(x)))) |
| x = self.norm2(x + self.dropout2(out)) |
| return x |
|
|
| def forward( |
| self, |
| x1: torch.Tensor, |
| x2: torch.Tensor, |
| attn_mask: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| y1, weights = self._sa_block(x2, attn_mask) |
| y1 = x1 + y1 |
| y2 = x2 + self._ff_block(y1) |
| return y1, y2, weights |
|
|
|
|
| class BitTransformerLM(nn.Module): |
| """Transformer language model that operates on raw bits (0/1) with telemetry.""" |
|
|
| def __init__( |
| self, |
| d_model: int = 128, |
| nhead: int = 8, |
| num_layers: int = 4, |
| dim_feedforward: int = 512, |
| max_seq_len: int = 1024, |
| lambda_K: float = 1.0, |
| lambda_C: float = 1.0, |
| lambda_S: float = 1.0, |
| reversible: bool = False, |
| use_checkpoint: bool = True, |
| use_autocast: bool = False, |
| use_act: bool = False, |
| act_threshold: float = 0.9, |
| chunk_size: Optional[int] = None, |
| overlap: int = 0, |
| full_attn_logging: Optional[bool] = None, |
| ) -> None: |
| """Create a BitTransformer language model. |
| |
| Args: |
| full_attn_logging: When ``False`` and ``chunk_size`` is |
| smaller than the sequence length, the model skips |
| reconstructing the full ``T×T`` attention matrices for
| telemetry to reduce memory use. |
| """ |
| super().__init__() |
| self.d_model = d_model |
| self.num_layers = num_layers |
| self.lambda_K = lambda_K |
| self.lambda_C = lambda_C |
| self.lambda_S = lambda_S |
| self.reversible = reversible |
| self.use_checkpoint = use_checkpoint |
| self.use_autocast = use_autocast |
| self.use_act = use_act |
| self.act_threshold = act_threshold |
| self.chunk_size = chunk_size |
| self.overlap = overlap |
| if full_attn_logging is None: |
| full_attn_logging = chunk_size is None
| self.full_attn_logging = full_attn_logging |
|
|
| |
| self.embedding = nn.Embedding(2, d_model) |
| self.pos_enc = PositionalEncoding(d_model, max_len=max_seq_len) |
|
|
| layer_cls = ( |
| ReversibleLoggingTransformerEncoderLayer |
| if reversible |
| else LoggingTransformerEncoderLayer |
| ) |
| self.layers = nn.ModuleList( |
| [ |
| layer_cls( |
| d_model=d_model, |
| nhead=nhead, |
| dim_feedforward=dim_feedforward, |
| chunk_size=chunk_size, |
| overlap=overlap, |
| full_attn_logging=full_attn_logging, |
| ) |
| for _ in range(num_layers) |
| ] |
| ) |
|
|
| if self.use_act: |
| self.halt_projs = nn.ModuleList( |
| [nn.Linear(d_model, 1) for _ in range(num_layers)] |
| ) |
|
|
| self.out_head = nn.Linear(d_model, 2) |
|
|
| def expand_positional_encoding(self, new_len: int) -> None: |
| """Expand positional encoding to at least ``new_len``.""" |
| cur_len = self.pos_enc.pe.size(0) |
| if new_len <= cur_len: |
| return |
| device = self.pos_enc.pe.device |
| d_model = self.d_model |
| pe = torch.zeros(new_len, d_model, device=device) |
| pe[:cur_len] = self.pos_enc.pe.squeeze(1) |
| pos = torch.arange(cur_len, new_len, dtype=torch.float32, device=device).unsqueeze(1) |
| inv = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)) |
| pe[cur_len:, 0::2] = torch.sin(pos * inv) |
| pe[cur_len:, 1::2] = torch.cos(pos * inv) |
| self.pos_enc.pe = pe.unsqueeze(1) |
|
|
| def set_lambdas(self, lambda_K: float, lambda_C: float, lambda_S: float) -> None: |
| """Update weighting coefficients for telemetry metrics.""" |
| self.lambda_K = lambda_K |
| self.lambda_C = lambda_C |
| self.lambda_S = lambda_S |
|
|
| def _maybe_decompress(self, codes: torch.Tensor) -> torch.Tensor: |
| """Return raw bit sequences, decompressing if input appears run-length encoded.""" |
| if codes.dim() <= 1: |
| return codes |
| needs_decompress = codes.max().item() > 1 |
| if not needs_decompress and codes.size(1) % 2 == 0: |
| vals = codes[:, 0::2] |
| if torch.all(vals[:, 1:] != vals[:, :-1]): |
| needs_decompress = True |
| if not needs_decompress: |
| return codes |
| seqs = [decompress_bits(row.to(torch.uint8)) for row in codes] |
| max_len = max(seq.numel() for seq in seqs) |
| padded = [F.pad(seq, (0, max_len - seq.numel())) for seq in seqs] |
| return torch.stack(padded) |
|
|
| def negentropy_kpi(self, codes: torch.Tensor) -> torch.Tensor: |
| """Approximate negentropy of bit sequences. |
| |
| Returns a value in ``[0, 1]`` where ``1`` denotes a perfectly ordered |
| sequence (all zeros or ones) and ``0`` reflects maximal entropy. |
| """ |
| codes = self._maybe_decompress(codes) |
| p = codes.float().mean(dim=1) |
| entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9)) |
| max_e = math.log(2.0) |
| return 1 - entropy / max_e |
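|
| # Worked sketch (hypothetical inputs, not executed at import time):
| #     model.negentropy_kpi(torch.ones(1, 8, dtype=torch.long))          # p = 1.0 -> entropy ~ 0 -> score ~ 1.0
| #     model.negentropy_kpi(torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]]))    # p = 0.5 -> entropy = ln 2 -> score ~ 0.0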
|
|
| def lz_complexity(self, codes: torch.Tensor) -> torch.Tensor: |
| """Differentiable proxy for LempelโZiv complexity. |
| |
| Values near ``0`` indicate highly compressible sequences while values |
| approaching ``1`` correspond to rapid bit alternation. |
| """ |
| codes = self._maybe_decompress(codes) |
| diffs = torch.abs(codes[:, 1:] - codes[:, :-1]) |
| return diffs.float().mean(dim=1) |
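|
| # Worked sketch (hypothetical inputs, not executed at import time):
| #     model.lz_complexity(torch.ones(1, 8))                          # no bit flips -> 0.0
| #     model.lz_complexity(torch.tensor([[0, 1, 0, 1, 0, 1, 0, 1]]))  # every adjacent pair flips -> 1.0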
|
|
| def negentropy_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor: |
| """Negentropy computed from model logits. |
| |
| Parameters |
| ---------- |
| logits: ``torch.Tensor`` |
| Logit tensor of shape ``(B, T, 2)``. |
| detach: bool, default ``True`` |
| When ``True`` the computation is detached from the autograd graph. |
| """ |
| assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]" |
| prob = logits.softmax(-1) |
| if detach: |
| prob = prob.detach() |
| p = prob[..., 1].mean(dim=1) |
| entropy = -(p * torch.log(p + 1e-9) + (1 - p) * torch.log(1 - p + 1e-9)) |
| max_e = math.log(2.0) |
| return 1 - entropy / max_e |
|
|
| def lz_complexity_logits(self, logits: torch.Tensor, detach: bool = True) -> torch.Tensor: |
| """LZ complexity proxy computed from logits. |
| |
| Parameters |
| ---------- |
| logits: ``torch.Tensor`` |
| Logit tensor of shape ``(B, T, 2)``. |
| detach: bool, default ``True`` |
| When ``True`` the computation is detached from the autograd graph. |
| """ |
| assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]" |
| prob = logits.softmax(-1) |
| if detach: |
| prob = prob.detach() |
| prob1 = prob[..., 1] |
| diffs = torch.abs(prob1[:, 1:] - prob1[:, :-1]) |
| return diffs.mean(dim=1) |
|
|
| def symbiosis_kl_logits( |
| self, logits: torch.Tensor, ref_prob: float = 0.5, detach: bool = True |
| ) -> torch.Tensor: |
| """Symbiosis score from KL divergence to a reference distribution. |
| |
| Returns a value in ``[0, 1]`` with ``1`` meaning perfect agreement with |
| the reference distribution and ``0`` indicating maximal divergence. |
| """ |
| assert logits.dim() == 3 and logits.size(-1) == 2, "logits must be [B,T,2]" |
| probs = logits.softmax(-1) |
| if detach: |
| probs = probs.detach() |
| ref = torch.tensor([1 - ref_prob, ref_prob], device=logits.device) |
| kl = (probs * (probs.clamp_min(1e-9).log() - ref.log())).sum(-1).mean(dim=1) |
| max_kl = math.log(2.0) |
| return 1 - kl / max_kl |
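|
| # Worked sketch (hypothetical inputs, not executed at import time):
| #     model.symbiosis_kl_logits(torch.zeros(1, 8, 2))   # uniform p = 0.5 matches the reference -> KL = 0 -> score = 1.0
| #     # Strongly peaked logits (e.g. [-20., 20.] at every position) drive KL toward ln 2 -> score ~ 0.0.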
|
|
| def _act_step( |
| self, |
| hidden: torch.Tensor, |
| idx: int, |
| halt_prob: torch.Tensor, |
| act_state: torch.Tensor, |
| halt_history: List[torch.Tensor], |
| ) -> Tuple[torch.Tensor, torch.Tensor, bool]: |
| """Apply one step of ACT halting logic.""" |
| p = torch.sigmoid(self.halt_projs[idx](hidden)) |
| delta = (1 - halt_prob) * p |
| halt_prob = halt_prob + delta |
| act_state = act_state + hidden * delta |
| halt_history.append(halt_prob.detach()) |
| min_prob = halt_prob.detach().min() |
| if dist.is_initialized(): |
| dist.all_reduce(min_prob, op=dist.ReduceOp.MIN) |
| return halt_prob, act_state, min_prob.item() >= self.act_threshold |
|
|
| def forward( |
| self, bit_seq: torch.Tensor, causal: bool = True |
| ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
| """Forward pass returning logits and telemetry from the same graph. |
| |
| By default the model uses causal masking and (optional) chunked |
| attention. When ``causal`` is ``False`` the model operates in |
| "Diffusion LM" mode. In this mode chunked attention is temporarily |
| disabled so that every token can attend to the full sequence |
| bidirectionally. The original chunking configuration is restored after |
| the forward pass. |
| """ |
|
|
| |
| orig_chunks = None |
| orig_model_chunk = None |
| if not causal and self.chunk_size is not None: |
| orig_model_chunk = self.chunk_size |
| orig_chunks = [layer.chunk_size for layer in self.layers] |
| self.chunk_size = None |
| for layer in self.layers: |
| layer.chunk_size = None |
|
|
| try: |
| ctx = cpu_autocast() if self.use_autocast else contextlib.nullcontext() |
| with ctx: |
| x = self.embedding(bit_seq).transpose(0, 1) * math.sqrt(self.d_model) |
| x = self.pos_enc(x) |
|
|
| attn_mask = get_tri_mask(x.size(0), x.device) if causal else None |
|
|
| activations: List[torch.Tensor] = [] |
| attn_maps: List[torch.Tensor] = [] |
| halt_history: List[torch.Tensor] = [] |
| if self.use_act: |
| halt_prob = torch.zeros(x.size(0), x.size(1), 1, device=x.device) |
| act_state = torch.zeros_like(x) |
| if self.reversible: |
| x1, x2 = x, x |
| for idx, layer in enumerate(self.layers): |
| if self.use_checkpoint: |
| x1, x2, attn = checkpoint.checkpoint( |
| layer, x1, x2, attn_mask |
| ) |
| else: |
| x1, x2, attn = layer(x1, x2, attn_mask) |
| combined = (x1 + x2) / 2 |
| activations.append(combined) |
| if attn.numel() > 0: |
| attn_maps.append(attn) |
| if self.use_act: |
| halt_prob, act_state, should_break = self._act_step( |
| combined, idx, halt_prob, act_state, halt_history |
| ) |
| if should_break: |
| break |
| x = (x1 + x2) / 2 |
| else: |
| for idx, layer in enumerate(self.layers): |
| if self.use_checkpoint: |
| x, attn = checkpoint.checkpoint(layer, x, attn_mask) |
| else: |
| x, attn = layer(x, attn_mask) |
| activations.append(x) |
| if attn.numel() > 0: |
| attn_maps.append(attn) |
| if self.use_act: |
| halt_prob, act_state, should_break = self._act_step( |
| x, idx, halt_prob, act_state, halt_history |
| ) |
| if should_break: |
| break |
| if self.use_act: |
| act_state = act_state + x * (1 - halt_prob) |
| x = act_state |
| logits = self.out_head(x) |
|
|
| |
| entropies = [] |
| for act in activations: |
| prob = act.softmax(-1) |
| ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1).mean() |
| entropies.append(ent) |
|
|
| attn_entropies = [] |
| for attn in attn_maps: |
| prob = attn |
| ent = -(prob * prob.clamp_min(1e-9).log()).sum(-1) |
| ent = ent.mean(1) |
| attn_entropies.append(ent) |
| if attn_entropies: |
| attn_entropy_map = torch.stack(attn_entropies).mean(0) |
| else: |
| attn_entropy_map = torch.zeros( |
| bit_seq.size(0), bit_seq.size(1), device=bit_seq.device |
| ) |
| max_ent = math.log(attn_entropy_map.size(-1)) |
| attn_entropy_map = attn_entropy_map / max_ent |
| attn_entropy = attn_entropy_map.mean(1) |
|
|
| logits_bt = logits.transpose(0, 1) |
| negentropy_in = self.negentropy_kpi(bit_seq) |
| lz_in = self.lz_complexity(bit_seq.float()) |
| negentropy_logits_b = self.negentropy_logits(logits_bt, detach=False) |
| lz_logits_b = self.lz_complexity_logits(logits_bt, detach=False) |
| kl_div_b = self.symbiosis_kl_logits(logits_bt, detach=False) |
|
|
| raw_sym = ( |
| (self.lambda_K * negentropy_logits_b + self.lambda_C * lz_logits_b) / 2 |
| + negentropy_logits_b * lz_logits_b |
| - self.lambda_S * kl_div_b |
| - 0.1 * attn_entropy |
| ) |
| weight_norm = torch.stack([p.norm() for p in self.parameters()]).mean().detach() |
| raw_sym = raw_sym - 0.01 * weight_norm |
| sym_score = torch.sigmoid(raw_sym) |
|
|
| B, T = bit_seq.shape |
| assert logits_bt.shape[:2] == (B, T) |
| assert attn_entropy_map.shape == (B, T) |
|
|
| telemetry = { |
| "activations": activations, |
| "attention_maps": attn_maps, |
| "attention_entropy": attn_entropy_map, |
| "entropy": entropies, |
| "attention_entropy_mean": attn_entropy, |
| "negentropy_input": negentropy_in.detach(), |
| "lz_complexity_input": lz_in.detach(), |
| "negentropy_logits": negentropy_logits_b.detach(), |
| "lz_complexity_logits": lz_logits_b.detach(), |
| "symbiosis_kl": kl_div_b.detach(), |
| "symbiosis_score": sym_score.detach(), |
| } |
| if self.use_act: |
| telemetry["halt_probs"] = halt_history |
|
|
| return logits_bt, telemetry |
| finally: |
| if orig_chunks is not None: |
| self.chunk_size = orig_model_chunk |
| for layer, chunk in zip(self.layers, orig_chunks): |
| layer.chunk_size = chunk |
|
|
| def forward_compressed( |
| self, compressed_bits, causal: bool = True |
| ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
| """Decompress bit sequences then run the normal forward pass.""" |
| if isinstance(compressed_bits, torch.Tensor) and compressed_bits.dim() == 1: |
| sequences = [decompress_bits(compressed_bits).to(torch.long)] |
| else: |
| sequences = [decompress_bits(c).to(torch.long) for c in compressed_bits] |
| lengths = [seq.numel() for seq in sequences] |
| if len(set(lengths)) != 1: |
| raise ValueError("Sequences decompress to different lengths") |
| bits = torch.stack(sequences) |
| return self.forward(bits, causal=causal) |
|
|
| def _current_params(self) -> Dict: |
| """Return a dictionary with the current model hyperparameters.""" |
| return { |
| "d_model": self.d_model, |
| "nhead": self.layers[0].self_attn.num_heads, |
| "num_layers": self.num_layers, |
| "dim_feedforward": self.layers[0].linear1.out_features, |
| "max_seq_len": self.pos_enc.pe.size(0), |
| "lambda_K": self.lambda_K, |
| "lambda_C": self.lambda_C, |
| "lambda_S": self.lambda_S, |
| "reversible": self.reversible, |
| "use_checkpoint": self.use_checkpoint, |
| "use_autocast": self.use_autocast, |
| "use_act": self.use_act, |
| "act_threshold": self.act_threshold, |
| "chunk_size": self.chunk_size, |
| "overlap": self.overlap, |
| } |
|
|
| def double_width(self) -> "BitTransformerLM": |
| """Return a copy of the model with doubled hidden size.""" |
| from .scale import expand_model |
|
|
| params = self._current_params() |
| params["d_model"] *= 2 |
| params["dim_feedforward"] *= 2 |
| return expand_model(self, params) |
|
|
| def double_layers(self) -> "BitTransformerLM": |
| """Return a copy of the model with twice as many layers.""" |
| from .scale import expand_model |
|
|
| params = self._current_params() |
| params["num_layers"] *= 2 |
| return expand_model(self, params) |
|
|
| def double_length(self) -> "BitTransformerLM": |
| """Return a copy of the model with doubled maximum sequence length.""" |
| from .scale import expand_model |
|
|
| params = self._current_params() |
| params["max_seq_len"] *= 2 |
| params["chunk_size"] = params["max_seq_len"] |
| return expand_model(self, params) |
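|
| # Scaling sketch (assumes ``expand_model`` from the local ``scale`` module transfers the existing
| # weights into the enlarged configuration; sizes below are hypothetical, not executed at import time):
| #     model = BitTransformerLM(d_model=64, nhead=4, num_layers=2, dim_feedforward=256)
| #     model = model.double_width()     # d_model 64 -> 128, dim_feedforward 256 -> 512
| #     model = model.double_layers()    # num_layers 2 -> 4
| #     model = model.double_length()    # max_seq_len doubles; chunk_size is set to the new length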
|
|
| def train_full_sequence( |
| self, |
| bits: torch.Tensor, |
| *, |
| ctx_bits: int = 4096, |
| detach_every_n: int = 1_048_576, |
| ) -> float: |
| """Train on a long bit tensor using sliding windows. |
| |
| Parameters |
| ---------- |
| bits: ``torch.Tensor`` |
| 1D tensor containing the full bit sequence. |
| ctx_bits: int |
| Size of the training context window. |
| detach_every_n: int |
| Interval in bits for optimizer updates and graph detachment. |
| Returns |
| ------- |
| float |
| Mean loss over all windows. |
| """ |
| self.train() |
| optimizer, scheduler = configure_optimizer( |
| self, lr=1e-3, total_steps=max(1, bits.numel() // ctx_bits) |
| ) |
| accum = 0 |
| total_loss = 0.0 |
| count = 0 |
| for start in range(0, bits.numel() - ctx_bits - 1, ctx_bits): |
| segment = bits[start : start + ctx_bits + 1].unsqueeze(0) |
| logits, _ = self(segment) |
| pred = logits[:, :-1, :].reshape(-1, 2) |
| target = segment[:, 1:].reshape(-1) |
| loss = F.cross_entropy(pred, target) |
| loss.backward() |
| accum += ctx_bits |
| total_loss += loss.item() |
| count += 1 |
| if accum >= detach_every_n: |
| torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0) |
| optimizer.step() |
| scheduler.step() |
| optimizer.zero_grad() |
| accum = 0 |
| if accum > 0: |
| torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0) |
| optimizer.step() |
| scheduler.step() |
| optimizer.zero_grad() |
| return total_loss / max(1, count) |
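|
| # Usage sketch (hypothetical sizes, not executed at import time):
| #     stream = torch.randint(0, 2, (1_000_000,), dtype=torch.long)   # one long bit tensor
| #     mean_loss = model.train_full_sequence(stream, ctx_bits=4096, detach_every_n=65_536)
| #     # Gradients accumulate across windows; an optimizer/scheduler step fires whenever
| #     # ``detach_every_n`` bits have been processed, and once more for any remainder.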
|
|
|
|
| def infer_long_sequence( |
| model: BitTransformerLM, |
| bits: torch.Tensor, |
| *, |
| ctx_bits: int = 4096, |
| overlap: int = 256, |
| ) -> Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]: |
| """Infer a long bit sequence using sliding windows with overlap.""" |
| model.eval() |
| device = next(model.parameters()).device |
| bits = bits.to(device) |
| step = ctx_bits - overlap |
| outputs: List[torch.Tensor] = [] |
| logs: List[Dict[str, torch.Tensor]] = [] |
| for start in range(0, bits.numel(), step): |
| window = bits[start : start + ctx_bits].unsqueeze(0) |
| logits, tele = model(window, causal=True) |
| pred = logits.argmax(-1).squeeze(0) |
| # Later windows overlap the previous one by ``overlap`` bits; keep only the new positions.
| outputs.append(pred if start == 0 else pred[overlap:])
| logs.append(tele) |
| out = torch.cat(outputs)[: bits.numel()] |
| return out, logs |
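|
| # Usage sketch (hypothetical sizes, not executed at import time):
| #     bits = torch.randint(0, 2, (10_000,), dtype=torch.long)
| #     preds, logs = infer_long_sequence(model, bits, ctx_bits=4096, overlap=256)
| #     # ``preds`` matches ``bits`` in length; each window after the first contributes only the
| #     # positions beyond its ``overlap`` prefix, so overlapping regions are not predicted twice.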
|
|
|
|
| def diffusion_inference( |
| model: BitTransformerLM, |
| *, |
| length: int, |
| steps: int = 8, |
| batch_size: int = 1, |
| init_bits: Optional[torch.Tensor] = None, |
| schedule: str = "linear", |
| ) -> torch.Tensor: |
| """Generate bit sequences using iterative denoising diffusion. |
| |
| Parameters |
| ---------- |
| model: ``BitTransformerLM`` |
| The model used for denoising. It is run in non-causal mode with |
| chunked attention disabled, enabling full-context bidirectional |
| attention. |
| length: int |
| Length of the bit sequences to generate. |
| steps: int, default ``8`` |
| Number of denoising iterations. More steps generally yield sharper |
| samples at the cost of compute. |
| batch_size: int, default ``1`` |
| Number of sequences to generate in parallel. |
| init_bits: ``torch.Tensor`` | ``None`` |
| Optional initial noisy bits of shape ``(batch_size, length)``. When |
| ``None`` random noise is used. |
| schedule: str, default ``"linear"`` |
| Noise schedule for the denoising mask probability. Options are |
| ``"linear"``, ``"cosine"``, and ``"exp"``. |
| |
| Returns |
| ------- |
| ``torch.Tensor`` |
| A tensor of shape ``(batch_size, length)`` containing generated bits. |
| """ |
|
|
| model.eval() |
| device = next(model.parameters()).device |
| if init_bits is None: |
| bits = torch.randint(0, 2, (batch_size, length), device=device) |
| else: |
| bits = init_bits.to(device) |
| if bits.shape != (batch_size, length): |
| raise ValueError("init_bits must have shape (batch_size, length)") |
|
|
| for step in range(steps): |
| logits, _ = model(bits, causal=False) |
| prob = logits.softmax(-1)[..., 1] |
| t = (step + 1) / steps |
| if schedule == "linear": |
| mask_prob = 1.0 - t |
| elif schedule == "cosine": |
| mask_prob = math.cos(math.pi * t / 2) |
| elif schedule == "exp": |
| mask_prob = math.exp(-5 * t) |
| else: |
| raise ValueError(f"unknown schedule: {schedule}") |
| mask = (torch.rand_like(bits.float()) < mask_prob).long() |
| sampled = torch.bernoulli(prob).long() |
| bits = torch.where(mask.bool(), sampled, bits) |
| if bits.shape[-1] % 9 == 0: |
| bits, corrections = enforce_parity(bits) |
| if corrections: |
| logging.info("Parity corrections applied: %d", corrections) |
| try: |
| from .safety import hil_safe_inference |
|
|
| hil_safe_inference(model, bits, causal=False, strict=False) |
| except RuntimeError as exc: |
| logging.warning("Safety gate warning: %s", exc) |
| return bits |
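|
| # Usage sketch (hypothetical sizes, not executed at import time):
| #     samples = diffusion_inference(model, length=256, steps=16, batch_size=4, schedule="cosine")
| #     # Each step re-samples a shrinking random subset of positions from the model's non-causal
| #     # predictions; lengths divisible by 9 are additionally passed through enforce_parity.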
|
|
|
|
| def example_usage() -> float: |
| """Run the example from the README and return the loss.""" |
| B, L = 4, 16 |
| model = BitTransformerLM( |
| d_model=64, nhead=4, num_layers=2, dim_feedforward=256, max_seq_len=L |
| ) |
| bits = torch.randint(0, 2, (B, L), dtype=torch.long) |
| logits, _ = model(bits) |
| pred = logits[:, :-1, :].reshape(-1, 2) |
| target = bits[:, 1:].reshape(-1) |
| loss = F.cross_entropy(pred, target) |
| return loss.item() |
|
|
|
|
| def example_training_step() -> Tuple[float, Dict[str, torch.Tensor]]: |
| """Demonstrate a training step where metrics do not affect gradients.""" |
| B, L = 4, 16 |
| model = BitTransformerLM( |
| d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L |
| ) |
| optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=1) |
|
|
| bits = torch.randint(0, 2, (B, L), dtype=torch.long) |
| logits, telemetry = model(bits) |
|
|
| pred = logits[:, :-1, :].reshape(-1, 2) |
| target = bits[:, 1:].reshape(-1) |
| loss = F.cross_entropy(pred, target) |
|
|
| loss.backward() |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
| optimizer.step() |
| scheduler.step() |
| optimizer.zero_grad() |
| return loss.item(), telemetry |
|
|
|
|
| if __name__ == "__main__": |
| loss, telemetry = example_training_step() |
| print("Composite loss:", loss) |
| print("Telemetry keys:", list(telemetry.keys())) |
|
|