/**
 * InfonCorefModel — the main public class.
 *
 * Wires together: tokenizer → backbone+BIO ONNX → BIO decode →
 * mention scorer ONNX → cluster grouping. Everything else in this
 * package is a helper for one of those stages.
 */

import { decodeBio } from './bio.js';
import { fetchHubFile, fetchHubJson, hubUrl } from './hub.js';
import { createSession, makeTensor, type OrtSession, type OrtTensor } from './ort.js';
import { buildPairs, groupClusters, pickAntecedents } from './pairs.js';
import { loadTokenizer, type Tokenizer } from './tokenizer.js';
import type {
  CorefResult,
  LoadedModel,
  Mention,
  ModelOptions,
  Token,
} from './types.js';

/**
 * Shape of the repo's ``meta.json``: architecture constants plus the
 * names of the files that must exist in an HF model repo (or local
 * directory) for this client to load.
 */
interface RepoMeta {
  hidden_size: number;
  n_bio_classes: number;
  max_length: number;
  /** ``"<s>"`` for XLM-R, ``"[CLS]"`` for BERT. */
  cls_token: string;
  /** Filenames keyed by precision. */
  files: {
    backbone_fp32?: string;
    backbone_fp16?: string;
    scorer_fp32?: string;
    scorer_fp16?: string;
    tokenizer: string;
  };
}

/**
 * Chooses the backbone and scorer ONNX filenames for a precision.
 *
 * The requested precision is preferred; the other precision is used as
 * a fallback when ``meta.json`` lists only one variant. A name is
 * ``undefined`` when neither variant is listed — callers decide how to
 * report that.
 */
function pickOnnxFiles(
  files: RepoMeta['files'],
  precision: NonNullable<ModelOptions['precision']>,
): { backbone: string | undefined; scorer: string | undefined } {
  if (precision === 'fp16') {
    return {
      backbone: files.backbone_fp16 ?? files.backbone_fp32,
      scorer: files.scorer_fp16 ?? files.scorer_fp32,
    };
  }
  return {
    backbone: files.backbone_fp32 ?? files.backbone_fp16,
    scorer: files.scorer_fp32 ?? files.scorer_fp16,
  };
}

/**
 * True when a ``process.versions.node`` string is present, i.e. the
 * code runs under Node.js or a runtime that emulates it. Read through
 * ``globalThis`` so the check compiles without Node typings.
 */
function isNodeRuntime(): boolean {
  const proc = (
    globalThis as { process?: { versions?: { node?: string } } }
  ).process;
  return typeof proc?.versions?.node === 'string';
}

/**
 * Builds the per-stage timing record from the stage boundary marks.
 *
 * @param t0 start of tokenization
 * @param t1 end of tokenization / start of backbone
 * @param t2 end of backbone / start of BIO decode
 * @param t3 end of BIO decode / start of scorer
 * @param t4 end of scorer; omit when the scorer was skipped, in which
 *   case ``scorer`` is 0 and ``total`` ends at ``t3``
 */
function stageTimings(
  t0: number,
  t1: number,
  t2: number,
  t3: number,
  t4?: number,
): CorefResult['timing'] {
  return {
    tokenize: t1 - t0,
    backbone: t2 - t1,
    bioDecode: t3 - t2,
    scorer: t4 === undefined ? 0 : t4 - t3,
    total: (t4 ?? t3) - t0,
  };
}

/**
 * Loads + runs the multilingual coreference pointer model.
 *
 * @example Browser
 * ```ts
 * import { InfonCorefModel } from '@cp500/infon-coref';
 *
 * const model = await InfonCorefModel.fromHub('cp500/infon-coref-pointer', {
 *   precision: 'fp16',
 *   device: 'auto',   // tries WebGPU, falls back to WASM
 * });
 *
 * const result = await model.resolve(
 *   "Toyota announced a partnership with Panasonic. " +
 *   "The Japanese automaker said the deal is worth $250M."
 * );
 *
 * for (const cluster of result.clusters) {
 *   const mentions = cluster.map(i => result.mentions[i].text);
 *   console.log(mentions.join(' = '));
 *   // Toyota = The Japanese automaker
 * }
 * ```
 *
 * @example Node
 * ```ts
 * import { InfonCorefModel } from '@cp500/infon-coref';
 *
 * const model = await InfonCorefModel.fromLocal('./models/coref-pointer/');
 * const result = await model.resolve(text);
 * ```
 */
export class InfonCorefModel {
  /** Built via {@link fromHub} / {@link fromLocal}; do not construct
   *  directly. */
  private constructor(private readonly loaded: LoadedModel) {}

  /** Load the model artefacts from a Hugging Face repo.
   *
   *  Caches downloads in the browser Cache API (named
   *  ``infon-coref-v1``); subsequent loads in the same origin are
   *  instant.
   *
   *  @param repo Hub repo id, e.g. ``cp500/infon-coref-pointer``.
   *  @param opts Model options plus an optional Hub ``revision``.
   *  @throws Error when ``meta.json`` names no backbone or no scorer
   *    ONNX file for either precision.
   */
  static async fromHub(
    repo: string,
    opts: ModelOptions & { revision?: string } = {},
  ): Promise<InfonCorefModel> {
    const meta: RepoMeta = await fetchHubJson({
      repo,
      path: 'meta.json',
      revision: opts.revision,
    });

    const { backbone: backboneFile, scorer: scorerFile } = pickOnnxFiles(
      meta.files,
      opts.precision ?? 'fp16',
    );
    if (!backboneFile || !scorerFile) {
      throw new Error(
        `repo ${repo} is missing backbone/scorer ONNX files; check meta.json`,
      );
    }

    // The browser will follow .onnx → .onnx.data automatically when
    // ORT is given a URL — but ORT-web's URL loader doesn't always
    // negotiate external data correctly cross-origin. Fetching to
    // ArrayBuffer once and passing buffers in is more reliable.
    const [backboneBuf, scorerBuf, tokBuf] = await Promise.all([
      fetchHubFile({ repo, path: backboneFile, revision: opts.revision }),
      fetchHubFile({ repo, path: scorerFile, revision: opts.revision }),
      fetchHubFile({
        repo,
        path: meta.files.tokenizer,
        revision: opts.revision,
      }),
    ]);

    if (opts.debug) {
      const mb = (b: ArrayBuffer) => (b.byteLength / 1e6).toFixed(1);
      console.debug(
        `[infon-coref] loaded backbone ${mb(backboneBuf)} MB, ` +
          `scorer ${mb(scorerBuf)} MB, tokenizer ${mb(tokBuf)} MB`,
      );
    }

    return InfonCorefModel.#fromBuffers(
      backboneBuf,
      scorerBuf,
      tokBuf,
      meta,
      opts,
      `hub:${repo}`,
    );
  }

  /** Load the model from a local directory.
   *
   *  Browser: ``baseUrl`` is a URL prefix (``/models/coref/``). Node:
   *  a filesystem path (``./models/coref/``). The directory must
   *  contain ``meta.json``, the ONNX files referenced therein, and
   *  ``tokenizer.json``.
   *
   *  @throws Error when ``meta.json`` names no backbone or no scorer
   *    ONNX file for either precision.
   */
  static async fromLocal(
    baseUrl: string,
    opts: ModelOptions = {},
  ): Promise<InfonCorefModel> {
    // Filesystem access needs an actual Node runtime. A missing
    // `window` alone is not enough: Web Workers have no `window`
    // either, yet must load through fetch().
    const isPath =
      typeof window === 'undefined' &&
      isNodeRuntime() &&
      !baseUrl.startsWith('http');

    const join = (name: string): string => {
      if (isPath) return `${baseUrl.replace(/\/$/, '')}/${name}`;
      const dir = baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`;
      // `new URL(x, base)` throws when `base` is relative, so a prefix
      // such as `/models/coref/` is first resolved against the current
      // page/worker URL. Absolute URLs pass through unchanged.
      const here = typeof location !== 'undefined' ? location.href : undefined;
      return new URL(name, new URL(dir, here)).href;
    };

    const meta: RepoMeta = await loadJson(join('meta.json'), isPath);

    const { backbone: backboneFile, scorer: scorerFile } = pickOnnxFiles(
      meta.files,
      opts.precision ?? 'fp16',
    );
    if (!backboneFile || !scorerFile) {
      throw new Error(
        `local model missing backbone/scorer ONNX files in meta.json`,
      );
    }

    const [backboneBuf, scorerBuf, tokBuf] = await Promise.all([
      loadBytes(join(backboneFile), isPath),
      loadBytes(join(scorerFile), isPath),
      loadBytes(join(meta.files.tokenizer), isPath),
    ]);

    return InfonCorefModel.#fromBuffers(
      backboneBuf,
      scorerBuf,
      tokBuf,
      meta,
      opts,
      `local:${baseUrl}`,
    );
  }

  /** Creates both ORT sessions and the tokenizer from raw bytes and
   *  wraps them in a model instance.
   *
   *  @param sourceTag Where the bytes came from (``hub:…`` /
   *    ``local:…``); only used in debug output.
   */
  static async #fromBuffers(
    backboneBuf: ArrayBuffer,
    scorerBuf: ArrayBuffer,
    tokBuf: ArrayBuffer,
    meta: RepoMeta,
    opts: ModelOptions,
    sourceTag: string,
  ): Promise<InfonCorefModel> {
    const device = opts.device ?? 'auto';
    const [backbone, scorer, tokenizer] = await Promise.all([
      createSession(new Uint8Array(backboneBuf), device),
      createSession(new Uint8Array(scorerBuf), device),
      loadTokenizer(new Uint8Array(tokBuf)),
    ]);

    if (opts.debug) {
      console.debug(`[infon-coref] sessions ready (${sourceTag})`);
    }

    // 'auto' is reported together with the environment it was resolved in.
    const reportedDevice =
      device !== 'auto'
        ? device
        : typeof window !== 'undefined'
          ? 'auto-browser'
          : 'auto-node';

    return new InfonCorefModel({
      backbone,
      scorer,
      tokenizer,
      meta: {
        hiddenSize: meta.hidden_size,
        nBioClasses: meta.n_bio_classes,
        maxLength: opts.maxLength ?? meta.max_length,
        precision: opts.precision ?? 'fp16',
        device: reportedDevice,
      },
    });
  }

  /** Run end-to-end coreference resolution on a single document.
   *
   *  @param text Document to analyse.
   *  @param opts Per-call overrides; ``maxLength``, ``bioThreshold``
   *    and ``debug`` are read here.
   *  @returns Tokens, detected mentions, clusters (as indices into
   *    ``mentions``) and per-stage timings in milliseconds.
   *  @throws Error when a session does not return the expected output
   *    tensors, or the hidden state does not match ``T × hiddenSize``.
   */
  async resolve(text: string, opts: ModelOptions = {}): Promise<CorefResult> {
    const debug = opts.debug ?? false;
    const t0 = nowMs();

    // 1. Tokenize.
    const enc = this.loaded.tokenizer.tokenize(text, {
      maxLength: opts.maxLength ?? this.loaded.meta.maxLength,
    });
    const T = enc.inputIds.length;
    const t1 = nowMs();

    // 2. Backbone forward → (last_hidden_state, bio_logits).
    const [idsT, maskT] = await Promise.all([
      makeTensor('int64', enc.inputIds, [1, T]),
      makeTensor('int64', enc.attentionMask, [1, T]),
    ]);
    const bbOut = await this.loaded.backbone.run({
      input_ids: idsT,
      attention_mask: maskT,
    });
    const hiddenTensor = bbOut.last_hidden_state ?? bbOut.hidden_states;
    const bioTensor = bbOut.bio_logits;
    if (!hiddenTensor || !bioTensor) {
      throw new Error(
        `backbone outputs missing; got: [${Object.keys(bbOut).join(', ')}]`,
      );
    }
    const t2 = nowMs();

    // 3. Decode BIO into wordpiece spans.
    const bioLogits = floatArray(bioTensor);
    const spans = decodeBio(
      bioLogits,
      enc.attentionMask,
      opts.bioThreshold,
    );
    const t3 = nowMs();

    if (spans.length === 0) {
      // No mentions detected — short-circuit so we don't run the
      // scorer with empty inputs (some ORT EPs choke on M=0).
      return this.#emptyResult(text, enc.tokens, [t0, t1, t2, t3]);
    }

    // 4. Build pair tensors + run scorer.
    const M = spans.length;
    const starts = BigInt64Array.from(spans.map(([s]) => BigInt(s)));
    const ends = BigInt64Array.from(spans.map(([, e]) => BigInt(e)));
    const [pairI, pairJ] = buildPairs(M);

    // Hidden is (1, T, H); the scorer wants (T, H). Dropping the batch
    // axis is only a relabel when the element count is exactly T × H,
    // so verify that before reshaping.
    const hiddenFlat = floatArray(hiddenTensor);
    const H = this.loaded.meta.hiddenSize;
    if (hiddenFlat.length !== T * H) {
      throw new Error(
        `backbone hidden state has ${hiddenFlat.length} values; ` +
          `expected ${T} tokens × ${H} hidden size`,
      );
    }

    const [hiddenT, startsT, endsT, piT, pjT] = await Promise.all([
      makeTensor('float32', hiddenFlat, [T, H]),
      makeTensor('int64', starts, [M]),
      makeTensor('int64', ends, [M]),
      makeTensor('int64', pairI, [pairI.length]),
      makeTensor('int64', pairJ, [pairJ.length]),
    ]);

    const scOut = await this.loaded.scorer.run({
      hidden: hiddenT,
      span_starts: startsT,
      span_ends: endsT,
      pair_i: piT,
      pair_j: pjT,
    });
    const scoresTensor = scOut.pair_scores;
    if (!scoresTensor) {
      throw new Error(
        `scorer output missing 'pair_scores'; got [${Object.keys(scOut).join(', ')}]`,
      );
    }
    const scores = floatArray(scoresTensor);
    const t4 = nowMs();

    // 5. Per-mention argmax + cluster grouping.
    const decisions = pickAntecedents(M, pairI, pairJ, scores);
    const grouping = groupClusters(decisions);

    // 6. Project wordpiece spans to char offsets via the tokenizer's
    //    offset map.
    const mentions: Mention[] = spans.map(([wstart, wend], i) => {
      const sTok = enc.tokens[wstart];
      const eTok = enc.tokens[wend];
      // Fall back to an empty span when an index has no token entry.
      const charStart = sTok?.start ?? 0;
      const charEnd = eTok?.end ?? charStart;
      return {
        start: wstart,
        end: wend,
        charStart,
        charEnd,
        text: text.slice(charStart, charEnd),
        cluster: grouping.cluster[i],
        antecedent: decisions[i].antecedent - 1,  // 0-based; -1 = DUMMY
      };
    });

    // Stage 5–6 run after t4, so `total` covers tokenize → scorer only.
    const timing = stageTimings(t0, t1, t2, t3, t4);
    if (debug) {
      console.debug('[infon-coref] timings (ms)', timing);
    }

    return {
      text,
      tokens: enc.tokens,
      mentions,
      clusters: grouping.clusters,
      timing,
    };
  }

  /** Architecture metadata loaded from the repo's ``meta.json``. */
  get meta() {
    return this.loaded.meta;
  }

  /** Result for a document in which no mention was detected; the
   *  scorer stage is reported as taking 0 ms. */
  #emptyResult(
    text: string,
    tokens: Token[],
    [t0, t1, t2, t3]: [number, number, number, number],
  ): CorefResult {
    return {
      text,
      tokens,
      mentions: [],
      clusters: [],
      timing: stageTimings(t0, t1, t2, t3),
    };
  }
}

// ── Internal helpers ────────────────────────────────────────────────

/** Millisecond timestamp: ``performance.now()`` where available,
 *  ``Date.now()`` otherwise. Only differences between two calls are
 *  meaningful. */
function nowMs(): number {
  return typeof performance !== 'undefined' ? performance.now() : Date.now();
}

/**
 * Returns a tensor's payload as ``Float32Array``.
 *
 * A ``Float32Array`` payload is returned as-is (no copy). A
 * ``Uint16Array`` payload is treated as bit-packed half floats and
 * widened element by element into a new array.
 *
 * @throws Error for any other payload type.
 */
function floatArray(t: OrtTensor): Float32Array {
  // FP32 tensors are Float32Array; FP16 ORT tensors are Uint16Array
  // bit-packed half-floats. We only run the scorer in FP32 currently
  // (it's tiny and FP16 buys nothing), but the backbone may return
  // FP16 hidden states — promote to Float32 so the scorer feed is
  // shape-correct. For FP16 → FP32 we use a quick bit-twiddle (no
  // dependency on a runtime fp16 polyfill).
  const { data } = t;
  if (data instanceof Float32Array) return data;
  if (data instanceof Uint16Array) {
    const out = new Float32Array(data.length);
    for (let i = 0; i < data.length; i++) {
      out[i] = halfToFloat(data[i]);
    }
    return out;
  }
  throw new Error(
    `expected Float32Array or Uint16Array tensor, got ${(t.data as { constructor: { name: string } }).constructor.name}`,
  );
}

/** IEEE 754 half → single.
/**
 * Widens one IEEE 754 binary16 ("half") bit pattern to a JS number.
 *
 * Layout: 1 sign bit, 5 exponent bits (bias 15), 10 fraction bits.
 * Handles signed zero, subnormals, ±Infinity and NaN.
 *
 * Fast enough for the per-token volume we see (a few thousand floats
 * per doc); for very long inputs prefer an FP16 backbone variant that
 * ORT itself converts at the boundary.
 *
 * @param h Raw 16-bit pattern; only the low 16 bits are read.
 */
function halfToFloat(h: number): number {
  const sign = (h & 0x8000) !== 0 ? -1 : 1;
  const exp = (h >> 10) & 0x1f;
  const frac = h & 0x03ff;

  if (exp === 0) {
    // Zero / subnormal: no implicit leading 1, fixed exponent of -14.
    return sign * 2 ** -14 * (frac / 1024);
  }
  if (exp === 0x1f) {
    // All-ones exponent: NaN when any fraction bit is set, else ±Infinity.
    return frac !== 0 ? NaN : sign * Infinity;
  }
  return sign * 2 ** (exp - 15) * (1 + frac / 1024);
}

/**
 * Reads and parses a JSON document.
 *
 * @param path Filesystem path when `isFsPath` is true, otherwise a URL
 *   handed to `fetch`.
 * @param isFsPath Selects `node:fs/promises` (imported lazily, so the
 *   module is never requested on the fetch path) over `fetch`.
 * @returns The parsed value, cast to `T` without runtime validation.
 * @throws Error `fetch <path>: <status>` on a non-OK HTTP response;
 *   read and JSON parse errors propagate unchanged.
 */
async function loadJson<T>(path: string, isFsPath: boolean): Promise<T> {
  if (isFsPath) {
    const fs = await import('node:fs/promises');
    return JSON.parse(await fs.readFile(path, 'utf-8')) as T;
  }
  const r = await fetch(path);
  if (!r.ok) throw new Error(`fetch ${path}: ${r.status}`);
  return (await r.json()) as T;
}

/**
 * Reads a whole file or HTTP resource into a standalone `ArrayBuffer`.
 *
 * @param path Filesystem path when `isFsPath` is true, otherwise a URL
 *   handed to `fetch`.
 * @param isFsPath Selects `node:fs/promises` (imported lazily) over
 *   `fetch`.
 * @returns A buffer holding exactly the resource's bytes.
 * @throws Error `fetch <path>: <status>` on a non-OK HTTP response;
 *   read errors propagate unchanged.
 */
async function loadBytes(
  path: string,
  isFsPath: boolean,
): Promise<ArrayBuffer> {
  if (isFsPath) {
    const fs = await import('node:fs/promises');
    const buf = await fs.readFile(path);
    // A Buffer is a view that may sit at an offset inside a larger
    // backing store, so copy just its bytes into a fresh ArrayBuffer.
    const bytes = new ArrayBuffer(buf.byteLength);
    new Uint8Array(bytes).set(buf);
    return bytes;
  }
  const r = await fetch(path);
  if (!r.ok) throw new Error(`fetch ${path}: ${r.status}`);
  return await r.arrayBuffer();
}