| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
'use strict';
const fs = require('fs');
const { seam } = require('./compose');
const { wordsOnly } = require('./fragments');
|
|
| |
| |
| |
/**
 * Collect "fragment i is followed by fragment j" pairs within each source.
 *
 * Fragments with `tier === 1`, `posTag === 'clause'`, or a truthy `isSpan`
 * are ignored. The remaining fragment indices are grouped by their `src`
 * value (in store order) and every adjacent pair inside a group becomes one
 * transition.
 *
 * @param {{fragments: Array<object>}} store - Object exposing a `fragments` array.
 * @returns {{pred: number[], succ: number[]}} Parallel arrays of fragment
 *   indices; `succ[t]` is the fragment that follows `pred[t]` in its source.
 */
function buildTransitions(store) {
  const groups = new Map();
  store.fragments.forEach((frag, index) => {
    const excluded = frag.tier === 1 || frag.posTag === 'clause' || frag.isSpan;
    if (excluded) return;
    let members = groups.get(frag.src);
    if (members === undefined) {
      members = [];
      groups.set(frag.src, members);
    }
    members.push(index);
  });

  const pred = [];
  const succ = [];
  groups.forEach((members) => {
    // Pair every element with its immediate successor inside the same source.
    for (let pos = 1; pos < members.length; pos++) {
      pred.push(members[pos - 1]);
      succ.push(members[pos]);
    }
  });
  return { pred, succ };
}
|
|
| |
/**
 * Create a seeded pseudo-random number generator.
 *
 * The returned function keeps a 32-bit integer state, advances it by a fixed
 * increment on every call, scrambles it with integer multiplies and shifts,
 * and returns the result scaled into [0, 1).
 *
 * @param {number} a - Seed; coerced to a signed 32-bit integer on first use.
 * @returns {function(): number} Generator yielding values in [0, 1).
 */
function mulberry32(a) {
  let state = a;
  return function () {
    // Advance the state, wrapping to signed 32-bit.
    state = ((state | 0) + 0x6D2B79F5) | 0;
    const mixed = Math.imul(state ^ (state >>> 15), 1 | state);
    const folded = (mixed + Math.imul(mixed ^ (mixed >>> 7), 61 | mixed)) ^ mixed;
    // `>>> 0` reinterprets as unsigned before dividing by 2^32.
    return ((folded ^ (folded >>> 14)) >>> 0) / 4294967296;
  };
}
|
|
/**
 * Copy row `i` of a flat embedding table into a fresh Float32Array.
 *
 * @param {{d: number, vectors: ArrayLike<number>}} emb - Table with row width `d`;
 *   row `i` occupies `vectors[i*d .. i*d + d - 1]`.
 * @param {number} i - Row index.
 * @returns {Float32Array} New array of length `d` holding the row's values.
 */
function vecOf(emb, i) {
  const width = emb.d;
  const start = i * width;
  const row = new Float32Array(width);
  let k = 0;
  while (k < width) {
    row[k] = emb.vectors[start + k];
    k += 1;
  }
  return row;
}
/**
 * Dot product between row `i` of the embedding table and the vector `q`.
 *
 * No normalisation is applied here, so the result is a cosine similarity only
 * if both operands are already unit length (assumption suggested by the name;
 * not enforced by this function).
 *
 * @param {{d: number, vectors: ArrayLike<number>}} emb - Flat table with row width `d`.
 * @param {number} i - Row index.
 * @param {ArrayLike<number>} q - Query vector; the first `d` entries are read.
 * @returns {number} Sum over k < d of `vectors[i*d + k] * q[k]`.
 */
function cos(emb, i, q) {
  const width = emb.d;
  const start = i * width;
  let total = 0;
  let k = 0;
  while (k < width) {
    total += emb.vectors[start + k] * q[k];
    k += 1;
  }
  return total;
}
/**
 * Dot product of two vectors, iterating over the length of `a`.
 *
 * @param {ArrayLike<number>} a - First vector; its `length` sets the number of terms.
 * @param {ArrayLike<number>} b - Second vector, indexed in step with `a`.
 * @returns {number} Sum of `a[k] * b[k]` for every index of `a`.
 */
function dot(a, b) {
  const count = a.length;
  let total = 0;
  let k = 0;
  while (k < count) {
    total += a[k] * b[k];
    k += 1;
  }
  return total;
}
|
|
| |
| |
| |
/**
 * Load a flow-MLP model from a JSON file on disk.
 *
 * Loading is best-effort: a missing, unreadable, or malformed file yields
 * `null` so the caller can fall back to another predictor. Only those expected
 * failures are absorbed; anything else (for example a ReferenceError caused by
 * a missing dependency) is rethrown so it cannot silently disable the model.
 *
 * Relies on the module-level `fs` binding (`require('fs')`).
 *
 * @param {string} filePath - Path of the JSON file to read (UTF-8).
 * @returns {object|null} The parsed JSON value, or `null` if it could not be
 *   read or parsed.
 */
function loadFlowMLP(filePath) {
  try {
    return JSON.parse(fs.readFileSync(filePath, 'utf8'));
  } catch (err) {
    // Node file-system/argument errors carry a string `code`; JSON.parse
    // reports bad input as SyntaxError. Both mean "no usable model here".
    const isParseError = err instanceof SyntaxError;
    const isIoError = err != null && typeof err.code === 'string';
    if (isParseError || isIoError) return null;
    throw err;
  }
}
/**
 * Predict the next embedding with a one-hidden-layer residual MLP.
 *
 * Computes `out = x + W2 * relu(W1 * x + b1) + b2`, where `x` is row `curIdx`
 * of the embedding table, then scales `out` to unit L2 norm (left untouched if
 * its norm is zero).
 *
 * The row offset is computed with `mlp.d`, so the embedding table is read as
 * if its row width equals the MLP's input width.
 *
 * @param {{d: number, H: number, W1: number[][], b1: number[], W2: number[][], b2: number[]}} mlp
 *   `W1` is indexed `[hidden][input]`, `W2` is indexed `[output][hidden]`.
 * @param {{vectors: ArrayLike<number>}} emb - Flat embedding table.
 * @param {number} curIdx - Row index of the current fragment.
 * @returns {Float32Array} Predicted vector of length `mlp.d`.
 */
function predictNextMLP(mlp, emb, curIdx) {
  const dim = mlp.d;
  const hidden = mlp.H;

  // Input: copy of the current row (rounded to float32 by the typed array).
  const input = new Float32Array(dim);
  const start = curIdx * dim;
  for (let k = 0; k < dim; k++) {
    input[k] = emb.vectors[start + k];
  }

  // Hidden layer with ReLU; the array starts zeroed, so only positive
  // pre-activations need to be written.
  const activ = new Float32Array(hidden);
  for (let j = 0; j < hidden; j++) {
    const weights = mlp.W1[j];
    let pre = mlp.b1[j];
    for (let k = 0; k < dim; k++) {
      pre += weights[k] * input[k];
    }
    if (pre > 0) activ[j] = pre;
  }

  // Output layer added onto the input (residual connection).
  const out = new Float32Array(dim);
  for (let i = 0; i < dim; i++) {
    const weights = mlp.W2[i];
    let delta = mlp.b2[i];
    for (let j = 0; j < hidden; j++) {
      delta += weights[j] * activ[j];
    }
    out[i] = input[i] + delta;
  }

  // L2-normalise; `|| 1` avoids dividing by zero for an all-zero vector.
  let sumSq = 0;
  for (let k = 0; k < dim; k++) {
    sumSq += out[k] * out[k];
  }
  const norm = Math.sqrt(sumSq) || 1;
  for (let k = 0; k < dim; k++) {
    out[k] /= norm;
  }
  return out;
}
|
|
| |
/**
 * Per-dimension weighted dot product between row `i` of the embedding table
 * and the vector `cur`.
 *
 * @param {{d: number, vectors: ArrayLike<number>}} emb - Flat table with row width `d`.
 * @param {number} i - Row index.
 * @param {ArrayLike<number>} cur - Vector compared against the row.
 * @param {ArrayLike<number>} w - Weight applied to each dimension.
 * @returns {number} Sum over k < d of `w[k] * vectors[i*d + k] * cur[k]`.
 */
function wsim(emb, i, cur, w) {
  const width = emb.d;
  const start = i * width;
  let total = 0;
  let k = 0;
  while (k < width) {
    total += w[k] * emb.vectors[start + k] * cur[k];
    k += 1;
  }
  return total;
}
|
|
| |
/**
 * Multiply a row-major matrix `P` by `vec`.
 *
 * @param {ArrayLike<ArrayLike<number>>} P - Matrix rows; the first `r` rows are used.
 * @param {ArrayLike<number>} vec - Input vector; the first `d` entries are read.
 * @param {number} d - Number of columns to use from each row.
 * @param {number} r - Number of rows, i.e. the length of the result.
 * @returns {Float32Array} Vector whose entry `j` is the dot product of `P[j]` and `vec`.
 */
function project(P, vec, d, r) {
  const projected = new Float32Array(r);
  let j = 0;
  while (j < r) {
    const basis = P[j];
    let acc = 0;
    for (let k = 0; k < d; k++) {
      acc += basis[k] * vec[k];
    }
    projected[j] = acc;
    j += 1;
  }
  return projected;
}
| |
/**
 * Project every transition's predecessor embedding through `attn.P`, caching
 * the result on the `attn` object as `attn._keys`.
 *
 * The cache is reused only when its length matches `trans.pred.length * attn.r`.
 * Previously it was returned unconditionally, so reusing one `attn` object with
 * a transition table of a different size produced stale keys of the wrong
 * length. A table of the same size but different contents is still not
 * detected; clear `attn._keys` when switching stores or embeddings.
 *
 * @param {{d: number, vectors: ArrayLike<number>}} emb - Flat embedding table.
 * @param {{pred: number[]}} trans - Transition table; `pred[t]` is a row index into `emb`.
 * @param {{r: number, P: ArrayLike<ArrayLike<number>>, _keys?: Float32Array}} attn
 *   Projection with `r` rows; mutated to hold the cache.
 * @returns {Float32Array} Flat array of length `n * r`; entries
 *   `[t*r .. t*r + r - 1]` hold the projection of predecessor `t`.
 */
function projKeys(emb, trans, attn) {
  const d = emb.d;
  const r = attn.r;
  const n = trans.pred.length;

  const cached = attn._keys;
  if (cached && cached.length === n * r) return cached;

  const keys = new Float32Array(n * r);
  for (let t = 0; t < n; t++) {
    const off = trans.pred[t] * d;
    for (let j = 0; j < r; j++) {
      const row = attn.P[j];
      let s = 0;
      for (let k = 0; k < d; k++) {
        s += row[k] * emb.vectors[off + k];
      }
      keys[t * r + j] = s;
    }
  }
  attn._keys = keys;
  return keys;
}
|
|
| |
| |
| |
| |
| |
/**
 * Predict the next embedding by soft nearest-neighbour lookup over transitions.
 *
 * Each transition `t` is scored by how well its predecessor matches the
 * current fragment's embedding. The top `K` transitions are weighted with
 * `exp((score - bestScore) * tau)`, their successor embeddings are averaged
 * with those weights, and the result is scaled to unit L2 norm.
 *
 * Scoring mode depends on `attn`:
 *  - `attn.P` set: dot product in the projected space (`project` / `projKeys`);
 *  - `attn` set without `P`: `wsim` with the per-dimension weights `attn.w`;
 *  - no `attn`: plain `cos` against the predecessor row.
 *
 * @param {{d: number, vectors: ArrayLike<number>}} emb - Flat embedding table.
 * @param {{pred: number[], succ: number[]}} trans - Parallel transition arrays.
 * @param {number} curIdx - Row index of the current fragment.
 * @param {number} [K] - Number of best-scoring transitions to blend; falsy means 40.
 * @param {object} [attn] - Optional scoring parameters (`P` and `r`, or `w`) plus
 *   `tau`; when absent, `tau` is 8.
 * @returns {Float32Array} Predicted vector of length `emb.d` (all zeros when
 *   there are no transitions).
 */
function predictNext(emb, trans, curIdx, K, attn) {
  const cur = vecOf(emb, curIdx);
  const d = emb.d;

  // Choose the scoring function once, then apply it to every transition.
  let scoreOf;
  if (attn && attn.P) {
    const r = attn.r;
    const keys = projKeys(emb, trans, attn);
    const pq = project(attn.P, cur, d, r);
    scoreOf = (t) => {
      const base = t * r;
      let acc = 0;
      for (let j = 0; j < r; j++) acc += pq[j] * keys[base + j];
      return acc;
    };
  } else if (attn) {
    scoreOf = (t) => wsim(emb, trans.pred[t], cur, attn.w);
  } else {
    scoreOf = (t) => cos(emb, trans.pred[t], cur);
  }

  const scored = [];
  for (let t = 0; t < trans.pred.length; t++) {
    scored.push([t, scoreOf(t)]);
  }
  scored.sort((a, b) => b[1] - a[1]);
  const top = scored.slice(0, K || 40);

  const tau = attn ? attn.tau : 8;
  const out = new Float32Array(d);
  let weightSum = 0;
  // Subtracting the best score keeps every exponent at or below zero.
  const peak = top.length > 0 ? top[0][1] : 0;
  for (const [t, score] of top) {
    const weight = Math.exp((score - peak) * tau);
    weightSum += weight;
    const base = trans.succ[t] * d;
    for (let k = 0; k < d; k++) {
      out[k] += weight * emb.vectors[base + k];
    }
  }
  if (weightSum > 0) {
    for (let k = 0; k < d; k++) out[k] /= weightSum;
  }

  // L2-normalise; `|| 1` avoids dividing by zero for an all-zero vector.
  let sumSq = 0;
  for (let k = 0; k < d; k++) sumSq += out[k] * out[k];
  const norm = Math.sqrt(sumSq) || 1;
  for (let k = 0; k < d; k++) out[k] /= norm;
  return out;
}
|
|
| |
| |
/**
 * Compose a passage by chaining fragments along a predicted embedding "flow".
 *
 * Steps:
 *  1. Anchor: pick the sentence-initial fragment (excluding tier 1, clauses and
 *     spans) with the highest relevance; if none qualifies, fall back to the
 *     first sentence-initial fragment outside tier 1.
 *  2. Extend: up to 12 times, predict the next embedding from the chain's tail
 *     (`opts.mlp` via `predictNextMLP`, otherwise `predictNext`), score every
 *     unused fragment that `seam` accepts after the tail by
 *     `0.7 * similarity + 0.3 * relevance`, and append the best one (or sample
 *     among the top 8 when `opts.temp` is above 0.001).
 *  3. Stop once the word count reaches 125% of the target, or reaches 70% of
 *     the target with the tail ending in sentence-final punctuation, or no
 *     candidate remains.
 *
 * @param {{fragments: Array<object>, oracle: *}} store - Fragment store; `oracle`
 *   is passed through to `seam`.
 * @param {*} vp - Not read by this function.
 * @param {*} query - Not read by this function.
 * @param {object} [opts]
 * @param {{d: number, vectors: ArrayLike<number>}} opts.emb - Embedding table indexed by fragment.
 * @param {Map<number, number>} [opts.relevance] - Relevance per fragment index (missing = 0).
 * @param {{pred: number[], succ: number[]}} [opts._trans] - Prebuilt transition table.
 * @param {number} [opts.targetLength=90] - Target length in words.
 * @param {number} [opts.temp=0] - Sampling temperature; at or below 0.001 selection is greedy.
 * @param {number} [opts.seed=1] - Seed for the sampling RNG.
 * @param {object} [opts.mlp] - Flow MLP; when set it replaces the transition-based predictor.
 * @param {number} [opts.K=40] - Neighbour count forwarded to `predictNext`.
 * @param {object} [opts.attn] - Attention parameters forwarded to `predictNext`.
 * @returns {{text: string, fragmentsUsed: string[], words: number, target: number, method: string}}
 * @throws {Error} If the store has no fragment that can serve as the anchor.
 */
function composeFlow(store, vp, query, opts = {}) {
  const MAX_STEPS = 12;
  const SAMPLE_POOL = 8;
  const TERMINAL_RE = /[.!?…]["')\]]*$/;

  const { fragments, oracle } = store;
  const rel = opts.relevance || new Map();
  const target = opts.targetLength || 90;

  // --- 1. Anchor selection -------------------------------------------------
  let anchor = -1;
  let best = -Infinity;
  for (let i = 0; i < fragments.length; i++) {
    const f = fragments[i];
    if (f.tier === 1 || !f.sentenceInitial || f.posTag === 'clause' || f.isSpan) continue;
    const r = rel.get(i) || 0;
    if (r > best) {
      best = r;
      anchor = i;
    }
  }
  if (anchor < 0) anchor = fragments.findIndex((f) => f.sentenceInitial && f.tier !== 1);
  if (anchor < 0) {
    // Fail with an actionable message instead of a TypeError on `fragments[-1].text`.
    throw new Error('composeFlow: store has no sentence-initial fragment outside tier 1 to anchor on');
  }

  const emb = opts.emb;
  // The transition table is only read by predictNext; skip building it on the MLP path.
  const trans = opts.mlp ? null : (opts._trans || buildTransitions(store));

  const temp = opts.temp || 0;
  const rng = mulberry32((opts.seed || 1) >>> 0);
  // Pick from score-sorted candidates: greedy at (near-)zero temperature,
  // otherwise softmax sampling over the leading SAMPLE_POOL entries.
  const sampleTop = (cands) => {
    if (temp <= 0.001 || cands.length === 1) return cands[0][0];
    const N = Math.min(cands.length, SAMPLE_POOL);
    const top = cands.slice(0, N);
    const s0 = top[0][1];
    const ws = top.map(([, s]) => Math.exp((s - s0) / Math.max(0.05, temp)));
    const sum = ws.reduce((a, b) => a + b, 0);
    let r = rng() * sum;
    for (let k = 0; k < N; k++) {
      r -= ws[k];
      if (r <= 0) return top[k][0];
    }
    return top[N - 1][0]; // guards against floating-point leftovers
  };

  // --- 2. Chain extension --------------------------------------------------
  const chain = [anchor];
  const used = new Set([anchor]);
  let len = wordsOnly(fragments[anchor].text).length;

  for (let step = 0; step < MAX_STEPS && len < target * 1.25; step++) {
    const tail = chain[chain.length - 1];
    const tailF = fragments[tail];
    const terminal = TERMINAL_RE.test(tailF.text.trim());
    if (len >= target * 0.7 && terminal) break;

    const eNext = opts.mlp
      ? predictNextMLP(opts.mlp, emb, tail)
      : predictNext(emb, trans, tail, opts.K || 40, opts.attn);

    const cands = [];
    for (let i = 0; i < fragments.length; i++) {
      if (used.has(i) || fragments[i].tier === 1 || fragments[i].isSpan) continue;
      if (!seam(tailF, fragments[i], oracle)) continue;
      const flowSim = cos(emb, i, eNext);
      cands.push([i, flowSim * 0.7 + (rel.get(i) || 0) * 0.3]);
    }
    if (!cands.length) break;
    cands.sort((a, b) => b[1] - a[1]);

    const bestI = sampleTop(cands);
    chain.push(bestI);
    used.add(bestI);
    len += wordsOnly(fragments[bestI].text).length;
  }

  // --- 3. Assembly ---------------------------------------------------------
  // NOTE(review): the original chose the separator with
  // `seam(...) === 'sent' ? ' ' : ' '`, i.e. a single space either way, so the
  // extra seam() lookup was dead. If sentence seams were meant to use a
  // different separator, reinstate the distinction here.
  const chainF = chain.map((i) => fragments[i]);
  const fragmentsUsed = chainF.map((f) => f.text);
  const text = fragmentsUsed.join(' ');
  return { text, fragmentsUsed, words: wordsOnly(text).length, target, method: 'flow' };
}
|
|
| module.exports = { composeFlow, buildTransitions, loadFlowMLP, predictNextMLP, predictNext }; |
|
|