infon-coref-pointer / js /src /model.ts
cp500's picture
Upload js/src/model.ts with huggingface_hub
6fb4bfe verified
Raw
History Blame Contribute Delete
13.6 kB
/**
* InfonCorefModel — the main public class.
*
* Wires together: tokenizer → backbone+BIO ONNX → BIO decode →
* mention scorer ONNX β†’ cluster grouping. Everything else in this
* package is a helper for one of those stages.
*/
import { decodeBio } from './bio.js';
import { fetchHubFile, fetchHubJson, hubUrl } from './hub.js';
import { createSession, makeTensor, type OrtSession, type OrtTensor } from './ort.js';
import { buildPairs, groupClusters, pickAntecedents } from './pairs.js';
import { loadTokenizer, type Tokenizer } from './tokenizer.js';
import type {
CorefResult,
LoadedModel,
Mention,
ModelOptions,
Token,
} from './types.js';
/**
 * Schema of the ``meta.json`` file at the root of a model repo (or local
 * model directory). It carries the architecture constants the client
 * needs and the filenames of the artefacts to load.
 */
interface RepoMeta {
  /** Width H of the backbone's hidden states; used to reshape them to
   * (T, H) before they are fed to the scorer. */
  hidden_size: number;
  /** Number of BIO label classes (exposed as ``meta.nBioClasses``). */
  n_bio_classes: number;
  /** Default tokenizer max length; ``ModelOptions.maxLength`` overrides it. */
  max_length: number;
  /** ``"<s>"`` for XLM-R, ``"[CLS]"`` for BERT. */
  cls_token: string;
  /**
   * Artefact filenames, relative to the repo root / model directory.
   * ONNX entries are per precision. At least one backbone and one scorer
   * variant must be listed; the loader falls back to the other precision
   * when the requested one is absent. ``tokenizer`` is always required.
   */
  files: {
    backbone_fp32?: string;
    backbone_fp16?: string;
    scorer_fp32?: string;
    scorer_fp16?: string;
    tokenizer: string;
  };
}
/**
 * Loads + runs the multilingual coreference pointer model.
 *
 * @example Browser
 * ```ts
 * import { InfonCorefModel } from '@cp500/infon-coref';
 *
 * const model = await InfonCorefModel.fromHub('cp500/infon-coref-pointer', {
 *   precision: 'fp16',
 *   device: 'auto', // tries WebGPU, falls back to WASM
 * });
 *
 * const result = await model.resolve(
 *   "Toyota announced a partnership with Panasonic. " +
 *   "The Japanese automaker said the deal is worth $250M."
 * );
 *
 * for (const cluster of result.clusters) {
 *   const mentions = cluster.map(i => result.mentions[i].text);
 *   console.log(mentions.join(' = '));
 *   // Toyota = The Japanese automaker
 * }
 * ```
 *
 * @example Node
 * ```ts
 * import { InfonCorefModel } from '@cp500/infon-coref';
 *
 * const model = await InfonCorefModel.fromLocal('./models/coref-pointer/');
 * const result = await model.resolve(text);
 * ```
 */
export class InfonCorefModel {
  /** Built via {@link fromHub} / {@link fromLocal}; do not construct
   * directly. */
  private constructor(private readonly loaded: LoadedModel) {}

  /** Load the model artefacts from a Hugging Face repo.
   *
   * Caches downloads in the browser Cache API (named
   * ``infon-coref-v1``); subsequent loads in the same origin are
   * instant.
   *
   * @param repo Hugging Face repo id, e.g. ``cp500/infon-coref-pointer``.
   * @param opts Model options plus an optional hub ``revision``.
   * @throws Error if ``meta.json`` lists no backbone or no scorer ONNX
   *   file in any precision.
   */
  static async fromHub(
    repo: string,
    opts: ModelOptions & { revision?: string } = {},
  ): Promise<InfonCorefModel> {
    const meta = await fetchHubJson<RepoMeta>({
      repo,
      path: 'meta.json',
      revision: opts.revision,
    });
    const { backboneFile, scorerFile } = InfonCorefModel.#pickOnnxFiles(
      meta,
      opts.precision,
      `repo ${repo} is missing backbone/scorer ONNX files; check meta.json`,
    );
    // The browser will follow .onnx → .onnx.data automatically when
    // ORT is given a URL — but ORT-web's URL loader doesn't always
    // negotiate external data correctly cross-origin. Fetching to
    // ArrayBuffer once and passing buffers in is more reliable.
    const [backboneBuf, scorerBuf, tokBuf] = await Promise.all([
      fetchHubFile({ repo, path: backboneFile, revision: opts.revision }),
      fetchHubFile({ repo, path: scorerFile, revision: opts.revision }),
      fetchHubFile({
        repo,
        path: meta.files.tokenizer,
        revision: opts.revision,
      }),
    ]);
    if (opts.debug) {
      const mb = (b: ArrayBuffer) => (b.byteLength / 1e6).toFixed(1);
      console.debug(
        `[infon-coref] loaded backbone ${mb(backboneBuf)} MB, ` +
          `scorer ${mb(scorerBuf)} MB, tokenizer ${mb(tokBuf)} MB`,
      );
    }
    return InfonCorefModel.#fromBuffers(
      backboneBuf,
      scorerBuf,
      tokBuf,
      meta,
      opts,
      `hub:${repo}`,
    );
  }

  /** Load the model from a local directory.
   *
   * Browser (window or Web Worker): ``baseUrl`` is a URL prefix
   * (``/models/coref/``). Node: a filesystem path
   * (``./models/coref/``). The directory must contain ``meta.json``,
   * the ONNX files referenced therein, and ``tokenizer.json``.
   *
   * @throws Error if ``meta.json`` lists no backbone or no scorer ONNX
   *   file in any precision, or if a file cannot be read/fetched.
   */
  static async fromLocal(
    baseUrl: string,
    opts: ModelOptions = {},
  ): Promise<InfonCorefModel> {
    // Only non-browser runtimes read from disk. Checking for a worker
    // global as well as ``window`` keeps Web Workers on the fetch path
    // instead of trying ``node:fs``.
    const isPath =
      !InfonCorefModel.#isBrowserLike() && !baseUrl.startsWith('http');
    const join = (name: string) =>
      isPath
        ? `${baseUrl.replace(/\/$/, '')}/${name}`
        : new URL(name, baseUrl.endsWith('/') ? baseUrl : baseUrl + '/').href;
    const meta = await loadJson<RepoMeta>(join('meta.json'), isPath);
    const { backboneFile, scorerFile } = InfonCorefModel.#pickOnnxFiles(
      meta,
      opts.precision,
      `local model missing backbone/scorer ONNX files in meta.json`,
    );
    const [backboneBuf, scorerBuf, tokBuf] = await Promise.all([
      loadBytes(join(backboneFile), isPath),
      loadBytes(join(scorerFile), isPath),
      loadBytes(join(meta.files.tokenizer), isPath),
    ]);
    return InfonCorefModel.#fromBuffers(
      backboneBuf,
      scorerBuf,
      tokBuf,
      meta,
      opts,
      `local:${baseUrl}`,
    );
  }

  /** Choose backbone/scorer filenames for the requested precision
   * (default ``fp16``). Falls back to the other precision when the
   * preferred file is not listed in meta.json.
   *
   * @throws Error(missingMessage) when the backbone or the scorer has
   *   no listed file in either precision.
   */
  static #pickOnnxFiles(
    meta: RepoMeta,
    precision: ModelOptions['precision'],
    missingMessage: string,
  ): { backboneFile: string; scorerFile: string } {
    const { files } = meta;
    const preferFp16 = (precision ?? 'fp16') === 'fp16';
    const backboneFile = preferFp16
      ? files.backbone_fp16 ?? files.backbone_fp32
      : files.backbone_fp32 ?? files.backbone_fp16;
    const scorerFile = preferFp16
      ? files.scorer_fp16 ?? files.scorer_fp32
      : files.scorer_fp32 ?? files.scorer_fp16;
    if (!backboneFile || !scorerFile) {
      throw new Error(missingMessage);
    }
    return { backboneFile, scorerFile };
  }

  /** True in a browser window *or* a Web Worker. Workers have no
   * ``window`` global, so testing ``window`` alone would classify them
   * as Node; ``importScripts`` is defined on worker global scopes. */
  static #isBrowserLike(): boolean {
    return 'window' in globalThis || 'importScripts' in globalThis;
  }

  static async #fromBuffers(
    backboneBuf: ArrayBuffer,
    scorerBuf: ArrayBuffer,
    tokBuf: ArrayBuffer,
    meta: RepoMeta,
    opts: ModelOptions,
    sourceTag: string,
  ): Promise<InfonCorefModel> {
    const device = opts.device ?? 'auto';
    const [backbone, scorer, tokenizer] = await Promise.all([
      createSession(new Uint8Array(backboneBuf), device),
      createSession(new Uint8Array(scorerBuf), device),
      loadTokenizer(new Uint8Array(tokBuf)),
    ]);
    return new InfonCorefModel({
      backbone,
      scorer,
      tokenizer,
      meta: {
        hiddenSize: meta.hidden_size,
        nBioClasses: meta.n_bio_classes,
        maxLength: opts.maxLength ?? meta.max_length,
        precision: opts.precision ?? 'fp16',
        device:
          device === 'auto'
            ? InfonCorefModel.#isBrowserLike()
              ? 'auto-browser'
              : 'auto-node'
            : device,
      },
    });
  }

  /** Run end-to-end coreference resolution on a single document.
   *
   * @param text Document to analyse.
   * @param opts Per-call overrides; ``maxLength``, ``bioThreshold`` and
   *   ``debug`` are read here.
   * @returns Tokens, detected mentions (with char offsets, cluster id and
   *   0-based antecedent, -1 for none), clusters of mention indices, and
   *   per-stage timings in ms.
   * @throws Error if expected backbone/scorer outputs are missing or the
   *   hidden-state size disagrees with ``meta.hidden_size``.
   */
  async resolve(text: string, opts: ModelOptions = {}): Promise<CorefResult> {
    const debug = opts.debug ?? false;
    const t0 = nowMs();

    // 1. Tokenize.
    const enc = this.loaded.tokenizer.tokenize(text, {
      maxLength: opts.maxLength ?? this.loaded.meta.maxLength,
    });
    const T = enc.inputIds.length;
    const t1 = nowMs();

    // 2. Backbone forward → (last_hidden_state, bio_logits).
    const idsT = await makeTensor('int64', enc.inputIds, [1, T]);
    const maskT = await makeTensor('int64', enc.attentionMask, [1, T]);
    const bbOut = await this.loaded.backbone.run({
      input_ids: idsT,
      attention_mask: maskT,
    });
    const hiddenTensor = bbOut.last_hidden_state ?? bbOut.hidden_states;
    const bioTensor = bbOut.bio_logits;
    if (!hiddenTensor || !bioTensor) {
      throw new Error(
        `backbone outputs missing; got: [${Object.keys(bbOut).join(', ')}]`,
      );
    }
    const t2 = nowMs();

    // 3. Decode BIO into wordpiece spans.
    const bioLogits = floatArray(bioTensor);
    const spans = decodeBio(bioLogits, enc.attentionMask, opts.bioThreshold);
    const t3 = nowMs();
    if (spans.length === 0) {
      // No mentions detected — short-circuit so we don't run the
      // scorer with empty inputs (some ORT EPs choke on M=0).
      return this.#emptyResult(
        text,
        enc.tokens,
        InfonCorefModel.#timing([t0, t1, t2, t3], debug),
      );
    }

    // 4. Build pair tensors + run scorer.
    const M = spans.length;
    const starts = BigInt64Array.from(spans.map(([s]) => BigInt(s)));
    const ends = BigInt64Array.from(spans.map(([, e]) => BigInt(e)));
    const [pairI, pairJ] = buildPairs(M);
    // Hidden is (1, T, H); the scorer wants (T, H). Check the element
    // count first so a wrong ``hidden_size`` in meta.json (or an
    // unexpected backbone output) fails with a clear message instead of
    // an opaque tensor-construction error.
    const hiddenFlat = floatArray(hiddenTensor);
    const H = this.loaded.meta.hiddenSize;
    if (hiddenFlat.length !== T * H) {
      throw new Error(
        `backbone hidden state has ${hiddenFlat.length} values; ` +
          `expected ${T} x ${H} (check hidden_size in meta.json)`,
      );
    }
    const hiddenT = await makeTensor('float32', hiddenFlat, [T, H]);
    const startsT = await makeTensor('int64', starts, [M]);
    const endsT = await makeTensor('int64', ends, [M]);
    const piT = await makeTensor('int64', pairI, [pairI.length]);
    const pjT = await makeTensor('int64', pairJ, [pairJ.length]);
    const scOut = await this.loaded.scorer.run({
      hidden: hiddenT,
      span_starts: startsT,
      span_ends: endsT,
      pair_i: piT,
      pair_j: pjT,
    });
    const scoresTensor = scOut.pair_scores;
    if (!scoresTensor) {
      throw new Error(
        `scorer output missing 'pair_scores'; got [${Object.keys(scOut).join(', ')}]`,
      );
    }
    const scores = floatArray(scoresTensor);
    const t4 = nowMs();

    // 5. Per-mention argmax + cluster grouping.
    const decisions = pickAntecedents(M, pairI, pairJ, scores);
    const grouping = groupClusters(decisions);

    // 6. Project wordpiece spans to char offsets via the tokenizer's
    // offset map.
    const mentions: Mention[] = spans.map(([wstart, wend], i) => {
      const sTok = enc.tokens[wstart];
      const eTok = enc.tokens[wend];
      const charStart = sTok?.start ?? 0;
      const charEnd = eTok?.end ?? charStart;
      return {
        start: wstart,
        end: wend,
        charStart,
        charEnd,
        text: text.slice(charStart, charEnd),
        cluster: grouping.cluster[i],
        antecedent: decisions[i].antecedent - 1, // 0-based; -1 = DUMMY
      };
    });

    return {
      text,
      tokens: enc.tokens,
      mentions,
      clusters: grouping.clusters,
      timing: InfonCorefModel.#timing([t0, t1, t2, t3, t4], debug),
    };
  }

  /** Architecture metadata loaded from the repo's ``meta.json``. */
  get meta() {
    return this.loaded.meta;
  }

  /** Result for a document with no detected mentions. */
  #emptyResult(
    text: string,
    tokens: Token[],
    timing: CorefResult['timing'],
  ): CorefResult {
    return { text, tokens, mentions: [], clusters: [], timing };
  }

  /** Convert stage timestamps (ms) into the per-stage ``timing`` record,
   * logging it when ``debug`` is set. ``t4`` is absent when the scorer
   * was skipped; the scorer stage then reports 0. */
  static #timing(
    [t0, t1, t2, t3, t4]: readonly [number, number, number, number, number?],
    debug: boolean,
  ): CorefResult['timing'] {
    const end = t4 ?? t3;
    const timing = {
      tokenize: t1 - t0,
      backbone: t2 - t1,
      bioDecode: t3 - t2,
      scorer: end - t3,
      total: end - t0,
    };
    if (debug) {
      console.debug('[infon-coref] timings (ms)', timing);
    }
    return timing;
  }
}
// ── Internal helpers ────────────────────────────────────────────────
/** Current time in milliseconds. Uses the high-resolution
 * ``performance.now()`` when the runtime provides it, else
 * ``Date.now()``. Only differences between two calls are meaningful. */
function nowMs(): number {
  if (typeof performance === 'undefined') {
    return Date.now();
  }
  return performance.now();
}
/**
 * Return a tensor's elements as a ``Float32Array``.
 *
 * FP32 data is returned as-is (no copy). FP16 data is widened into a new
 * array: ORT exposes it either as bit-packed halves in a ``Uint16Array``
 * or, on runtimes that provide one, as a native ``Float16Array``. The
 * backbone outputs (hidden states, BIO logits) and the scorer's pair
 * scores may all arrive as FP16 when ``fp16`` ONNX files are loaded, so
 * every tensor read goes through here.
 *
 * @throws Error for any other element type.
 */
function floatArray(t: OrtTensor): Float32Array {
  const data = t.data;
  if (data instanceof Float32Array) return data;
  if (data instanceof Uint16Array) {
    // Bit-packed IEEE 754 halves: decode in JS rather than depending on
    // a runtime fp16 polyfill.
    return Float32Array.from(data, (h) => halfToFloat(h));
  }
  // ``Float16Array`` is not present in every runtime or TS lib, so look
  // it up on ``globalThis`` instead of referencing it by name.
  const Float16Ctor = (
    globalThis as { Float16Array?: new (length: number) => ArrayLike<number> }
  ).Float16Array;
  if (Float16Ctor !== undefined && data instanceof Float16Ctor) {
    return Float32Array.from(data);
  }
  throw new Error(
    `expected Float32Array or Uint16Array tensor, got ${(t.data as { constructor: { name: string } }).constructor.name}`,
  );
}
/**
 * Decode one IEEE 754 binary16 value, given as its raw 16-bit pattern,
 * to a JS number. Handles signed zeros, subnormals, ±Infinity and NaN.
 *
 * ``floatArray`` calls this once per element when it widens a
 * Uint16Array-backed FP16 tensor. For the backbone's hidden states that
 * is T × hidden_size calls per document, not a few thousand. If this
 * shows up in profiles, prefer the FP32 backbone files; FP32 outputs
 * skip this path entirely.
 */
function halfToFloat(h: number): number {
  const sign = h & 0x8000 ? -1 : 1;
  const exp = (h & 0x7c00) >> 10;
  const frac = h & 0x03ff;
  if (exp === 0) {
    // Zero / subnormal: frac × 2^-24 (= 2^-14 × frac/1024).
    return sign * frac * 2 ** -24;
  }
  if (exp === 0x1f) {
    // All-ones exponent: Infinity when frac is 0, otherwise NaN.
    return frac ? NaN : sign * Infinity;
  }
  // Normal: (1 + frac/1024) × 2^(exp - 15).
  return sign * (1 + frac / 1024) * 2 ** (exp - 15);
}
/** Read and parse a JSON document, from disk when ``isFsPath`` is true
 * (Node) or over ``fetch`` otherwise. The parsed value is cast to ``T``
 * without validation.
 *
 * @throws Error on a non-2xx fetch response; fs/JSON errors propagate.
 */
async function loadJson<T>(path: string, isFsPath: boolean): Promise<T> {
  if (!isFsPath) {
    const response = await fetch(path);
    if (!response.ok) throw new Error(`fetch ${path}: ${response.status}`);
    return (await response.json()) as T;
  }
  const { readFile } = await import('node:fs/promises');
  const raw = await readFile(path, 'utf-8');
  return JSON.parse(raw) as T;
}
/** Read a binary file, from disk when ``isFsPath`` is true (Node) or
 * over ``fetch`` otherwise, and return its bytes as a standalone
 * ``ArrayBuffer`` of exactly the file's length.
 *
 * @throws Error on a non-2xx fetch response; fs errors propagate.
 */
async function loadBytes(
  path: string,
  isFsPath: boolean,
): Promise<ArrayBuffer> {
  if (!isFsPath) {
    const response = await fetch(path);
    if (!response.ok) throw new Error(`fetch ${path}: ${response.status}`);
    return await response.arrayBuffer();
  }
  const { readFile } = await import('node:fs/promises');
  const bytes = await readFile(path);
  // The returned Buffer may be a view into a larger backing store, so
  // copy exactly this file's bytes into a fresh ArrayBuffer.
  const copy = new ArrayBuffer(bytes.byteLength);
  new Uint8Array(copy).set(bytes);
  return copy;
}