| import torch
|
| import torch.nn.functional as F
|
|
|
| from birwkv7 import BiRWKV7Layer
|
|
|
|
|
def wkv7_forward_scan(r, w, k, v, a, sab_scale, init_state=None):
    """Run the WKV-7 state recurrence step by step in plain PyTorch.

    Every input is cast to float32 before use. ``k`` is scaled by
    ``D ** -0.5``, ``w`` is turned into a multiplicative decay
    ``exp(-0.6065306597633104 * sigmoid(w))`` and ``a`` is passed through a
    sigmoid. The state is clamped to ``[-10, 10]`` after every step.

    Args:
        r, w, k, v, a: Tensors of shape ``(B, T, H, D)``.
        sab_scale: Scalar (or one-element tensor) weighting the ``sab``
            correction term; it is converted with ``float()``.
        init_state: Optional starting state of shape ``(B, H, D, D)``. It is
            cloned, never modified in place. A zero state is used when omitted.

    Returns:
        A tuple ``(outputs, state)`` where ``outputs`` is float32 of shape
        ``(B, T, H, D)`` and ``state`` is the detached final state of shape
        ``(B, H, D, D)``. For an empty sequence (``T == 0``) ``outputs`` is
        empty and ``state`` equals the initial state.
    """
    # Coefficient applied to sigmoid(w); numerically this is exp(-0.5).
    decay_coeff = 0.6065306597633104

    B, T, H, D = r.shape
    r, w, k, v, a = [x.float() for x in (r, w, k, v, a)]
    k = k * (D ** -0.5)
    decay = torch.exp(-decay_coeff * torch.sigmoid(w))
    a = torch.sigmoid(a)
    sab_s = float(sab_scale)

    if init_state is not None:
        state = init_state.float().clone()
    else:
        state = torch.zeros(B, H, D, D, device=r.device, dtype=torch.float32)

    # torch.stack rejects an empty list, so handle the empty sequence here.
    if T == 0:
        return r.new_empty((B, 0, H, D)), state.detach()

    outputs = []
    for t in range(T):
        kt, vt, rt, at, dt = k[:, t], v[:, t], r[:, t], a[:, t], decay[:, t]
        # State read out along -k, then spread back along k * a.
        sa = torch.einsum('bhij,bhj->bhi', state, -kt)
        sab = torch.einsum('bhi,bhj->bhij', sa, kt * at)
        # Decay is broadcast over the first state axis (one factor per column).
        state = (
            state * dt.unsqueeze(-2)
            + sab_s * sab
            + torch.einsum('bhi,bhj->bhij', vt, kt)
        )
        state = state.clamp(-10.0, 10.0)
        outputs.append(torch.einsum('bhij,bhj->bhi', state, rt))

    return torch.stack(outputs, dim=1), state.detach()
|
|
|
|
|
class SpanEncoder:
    """Encode text spans while carrying BiRWKV7 layer state between calls.

    While a span is being encoded, the ``forward`` of every ``BiRWKV7Layer``
    found in ``model`` is temporarily replaced by a variant that accepts an
    initial recurrent state and records the final one. This lets
    ``extend_right`` continue an already-encoded span without re-encoding it.
    Per-span results are kept in ``span_data``.
    """

    # A trailing chunk shorter than this many tokens is not embedded; it is
    # returned as residual hidden states instead.
    MIN_CHUNK_TOKENS = 32

    def __init__(self, model, tokenizer, device, chunk_size=512):
        """Collect the BiRWKV7 layers of ``model`` and set up empty caches.

        Args:
            model: Module whose ``modules()`` are scanned for ``BiRWKV7Layer``.
            tokenizer: Callable used to tokenize text into model inputs.
            device: Device the tokenized inputs are moved to.
            chunk_size: Number of tokens mean-pooled into one chunk embedding.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.chunk_size = chunk_size

        self.birwkv_layers = []
        self.birwkv_ids = {}  # id(layer) -> index into birwkv_layers
        for m in model.modules():
            if isinstance(m, BiRWKV7Layer):
                self.birwkv_ids[id(m)] = len(self.birwkv_layers)
                self.birwkv_layers.append(m)

        self._originals = {}  # id(layer) -> original forward while hooked
        self._hooked = False
        self._active_states = [None] * len(self.birwkv_layers)
        self.span_data = {}

    def _hook(self):
        """Replace each layer's ``forward`` with the state-carrying variant."""
        if self._hooked:
            return
        for layer in self.birwkv_layers:
            self._originals[id(layer)] = layer.forward
            layer.forward = self._make_fwd(layer)
        self._hooked = True

    def _unhook(self):
        """Restore the original ``forward`` of every hooked layer."""
        if not self._hooked:
            return
        for layer in self.birwkv_layers:
            layer.forward = self._originals[id(layer)]
        self._originals.clear()
        self._hooked = False

    def _make_fwd(self, layer):
        """Build the replacement ``forward`` for ``layer``.

        The returned function reads the layer's entry in ``_active_states`` as
        the starting point (previous last input row and WKV state) and
        overwrites that entry with the values reached at the end of the call.
        """
        enc = self
        idx = self.birwkv_ids[id(layer)]

        def fwd(x, attention_mask=None, **kwargs):
            # attention_mask and extra kwargs are accepted but not used.
            B, T, C_ = x.shape
            H, D = layer.num_heads, layer.head_size
            prev = enc._active_states[idx]

            # Token shift: x_prev[t] is x[t - 1]; position 0 uses the last row
            # of the previous call when available, zeros otherwise.
            if prev is not None:
                x_prev = torch.cat([prev['last_x'], x[:, :-1]], dim=1)
            else:
                x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

            def mix(mu):
                return x + (x_prev - x) * torch.sigmoid(mu)

            r = layer.W_r(mix(layer.mu_r)).view(B, T, H, D)
            w = layer.W_w(mix(layer.mu_w)).view(B, T, H, D)
            k = layer.W_k(mix(layer.mu_k)).view(B, T, H, D)
            v = layer.W_v(mix(layer.mu_v)).view(B, T, H, D)
            a = layer.W_a(mix(layer.mu_a)).view(B, T, H, D)
            g = torch.sigmoid(layer.W_g(mix(layer.mu_g)))
            sab_scale = torch.sigmoid(layer.sab_gate)
            init_st = prev['wkv_state'] if prev else None

            try:
                from birwkv7_triton import wkv7_scan_triton
                r_f, k_f, v_f = r.float(), k.float() * (D ** -0.5), v.float()
                a_f = torch.sigmoid(a.float())
                decay = torch.exp(
                    -0.6065306597633104 * torch.sigmoid(w.float()))
                out_fwd, wkv_state = wkv7_scan_triton(
                    r_f, decay, k_f, v_f, a_f, sab_scale,
                    return_state=True, init_state=init_st)
                # The backward direction scans the flipped sequence and never
                # receives an initial state.
                out_bwd = wkv7_scan_triton(
                    r_f.flip(1), decay.flip(1), k_f.flip(1),
                    v_f.flip(1), a_f.flip(1), sab_scale,
                    return_state=False).flip(1)
            except Exception:
                # Deliberate best effort: a missing Triton module (ImportError)
                # or any failure inside the kernel falls back to the
                # pure-PyTorch scan.
                out_fwd, wkv_state = wkv7_forward_scan(
                    r, w, k, v, a, sab_scale, init_st)
                out_bwd = wkv7_forward_scan(
                    r.flip(1), w.flip(1), k.flip(1),
                    v.flip(1), a.flip(1), sab_scale, None)[0].flip(1)

            enc._active_states[idx] = {
                'wkv_state': wkv_state,
                'last_x': x[:, -1:].detach().clone(),
            }
            out = ((out_fwd + out_bwd) * 0.5).reshape(B, T, C_)
            out = layer.group_norm(out.transpose(1, 2)).transpose(1, 2)
            out = layer.W_o(out * g)
            return out, None

        return fwd

    @torch.no_grad()
    def _forward_encode_raw(self, text, init_states=None, max_length=8192):
        """Encode ``text`` with hooked layers and return hidden states.

        Args:
            text: Text passed to the tokenizer.
            init_states: Optional per-layer state dicts to start from; they
                are cloned, so the caller's tensors are not modified.
            max_length: Truncation length passed to the tokenizer.

        Returns:
            ``(content, n_content, final_states)``: the hidden states of batch
            item 0 with the first and last positions removed (presumably the
            tokenizer's special tokens - confirm), moved to CPU; their count;
            and cloned per-layer states after the forward pass.
        """
        self._hook()
        try:
            if init_states is not None:
                self._active_states = [
                    {k: v.clone() for k, v in s.items()} if s else None
                    for s in init_states
                ]
            else:
                self._active_states = [None] * len(self.birwkv_layers)

            enc = self.tokenizer(text, return_tensors='pt', truncation=True,
                                 max_length=max_length)
            ids = enc['input_ids'].to(self.device)
            mask = enc['attention_mask'].to(self.device)

            h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
            content = h[0, 1:-1, :].cpu()
            n_content = content.shape[0]

            final_states = [
                {k: v.clone() for k, v in s.items()} if s else None
                for s in self._active_states
            ]
        finally:
            # Always restore the layers, even when tokenizing or the model
            # call raises; otherwise they would stay patched.
            self._unhook()
        return content, n_content, final_states

    def _chunk_hidden(self, content, return_residual=False):
        """Mean-pool ``content`` into L2-normalized chunk embeddings.

        Consecutive windows of ``chunk_size`` rows are pooled. Chunking stops
        at the first window shorter than ``MIN_CHUNK_TOKENS``. If that leaves
        no chunk at all and ``content`` is non-empty, the whole of ``content``
        is pooled into a single chunk.

        Args:
            content: 2-D tensor of hidden states, one row per token.
            return_residual: When true, also return the rows not covered by
                any chunk (``None`` if every row was covered).

        Returns:
            A list of ``(1, hidden)`` tensors, or ``(chunks, residual)``.
        """
        T = content.shape[0]
        chunks = []
        last_end = 0
        for start in range(0, T, self.chunk_size):
            end = min(start + self.chunk_size, T)
            if end - start < self.MIN_CHUNK_TOKENS:
                break
            emb = F.normalize(content[start:end].mean(0, keepdim=True),
                              p=2, dim=-1)
            chunks.append(emb)
            last_end = end
        if not chunks and T > 0:
            chunks.append(F.normalize(content.mean(0, keepdim=True),
                                      p=2, dim=-1))
            last_end = T
        if return_residual:
            residual = content[last_end:] if last_end < T else None
            return chunks, residual
        return chunks

    @torch.no_grad()
    def encode_query(self, query):
        """Return the L2-normalized, mask-weighted mean embedding of ``query``.

        Runs the model with its unpatched layers; the query is truncated to
        512 tokens and the result is moved to CPU.
        """
        # Internal invariant: span encoding always unhooks before returning.
        assert not self._hooked
        enc = self.tokenizer(query, return_tensors='pt', truncation=True,
                             max_length=512)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        m = mask.unsqueeze(-1).float()
        emb = (h * m).sum(1) / m.sum(1).clamp(min=1e-9)
        return F.normalize(emb, p=2, dim=-1).cpu()

    def encode_span(self, text, key):
        """Encode ``text`` from scratch and cache it under ``key``.

        Returns:
            The number of content tokens encoded.
        """
        content, n_tok, states = self._forward_encode_raw(text)
        chunks, residual = self._chunk_hidden(content, return_residual=True)
        self.span_data[key] = {
            'layer_states': states,
            'chunk_embs': chunks,
            'n_tokens': n_tok,
            'residual_hidden': residual,
        }
        return n_tok

    def extend_right(self, piece_text, old_key, new_key):
        """Extend the span cached under ``old_key`` with ``piece_text``.

        The new piece is encoded starting from the cached layer states. Any
        residual hidden states of the old span are prepended before chunking.
        The result is stored under ``new_key`` and the ``old_key`` entry is
        removed, but only once encoding has succeeded.

        Returns:
            The number of content tokens in the new piece.

        Raises:
            KeyError: If ``old_key`` is not cached.
        """
        old = self.span_data[old_key]
        content, n_new, states = self._forward_encode_raw(
            piece_text, init_states=old['layer_states'])
        if old.get('residual_hidden') is not None:
            content = torch.cat([old['residual_hidden'], content], dim=0)
        new_chunks, residual = self._chunk_hidden(
            content, return_residual=True)
        # Drop the old entry only now, so a failed encode does not lose it.
        del self.span_data[old_key]
        self.span_data[new_key] = {
            'layer_states': states,
            'chunk_embs': old['chunk_embs'] + new_chunks,
            'n_tokens': old['n_tokens'] + n_new,
            'residual_hidden': residual,
        }
        return n_new
|
|
|