tabulm / code /eval_compare.py
rakshi719's picture
Add TabuLM training and evaluation code
f32c034 verified
Raw
History Blame Contribute Delete
16.4 kB
#!/usr/bin/env python3
"""Paired per-item evaluation: TabuLM vs KinyaBERT on the same dev items.
Outputs:
data/compare_results.json — per-item predictions for both models
Prints bootstrap 95% CI for both models
Prints comparison examples (TabuLM correct, KinyaBERT wrong) for error analysis
"""
import argparse
import json
import os
import random
import sys
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
# Locations of the data files and of the project-local modules. The defaults
# are the paths used for the original experiments; set TABULM_DATA_DIR and/or
# TABULM_CODE_DIR in the environment to run the script from another checkout.
DATA_DIR = os.environ.get(
    'TABULM_DATA_DIR',
    '/shared/scratch/0/tmp/v_ireddi_rakshitha_results/tabulm/data')
CODE_DIR = os.environ.get(
    'TABULM_CODE_DIR',
    '/shared/scratch/0/tmp/v_ireddi_rakshitha_results/tabulm/code')
TABQA_FILE = os.path.join(DATA_DIR, 'tabqa_kin.json')
CSV_DIR = os.path.join(DATA_DIR, 'tables')
TABULM_CKPT = os.path.join(DATA_DIR, 'finetune_tabqa_v3_best.pt')
KINYABERT_CKPT = os.path.join(DATA_DIR, 'baseline_kinyabert_best.pt')
# Put CODE_DIR first on the module search path before the imports that follow.
sys.path.insert(0, CODE_DIR)
import youtokentome as yttm
from morpho_data_loaders import KBVocab
from tabular_serializer import serialize_csv, table_cells_to_text, TableCell
from morpho_stub import parse_text_stub
from tabulm_model import TabuLM, tabulm_base
# ── Shared gold-cell lookup ────────────────────────────────────────────────────
def find_gold_cell(cells: List[TableCell], answer_text: str,
                   question_text: str = '') -> Optional[Tuple[int, int]]:
    """Locate the (row_id, col_id) of the cell holding the gold answer.

    Only cells with ``row_id > 1`` and ``col_id > 0`` are candidates. Cells
    are matched on whitespace-stripped content, first exactly and, if that
    finds nothing, case-insensitively.

    When several cells match and ``question_text`` is given, the candidate
    whose row label (the cell at ``col_id == 1`` in its row) and column
    header (the cell at ``row_id == 1`` in its column) share the most
    whitespace-separated words with the question wins; row overlap is
    compared before column overlap. If no candidate overlaps at all, or no
    question is given, the candidate with the smallest (row_id, col_id) is
    returned.

    Returns:
        The chosen (row_id, col_id), or None when no cell matches.
    """
    def _candidates(normalise):
        # Normalise the answer once, then compare every eligible cell to it.
        target = normalise(answer_text)
        return [(cell.row_id, cell.col_id) for cell in cells
                if cell.row_id > 1 and cell.col_id > 0
                and normalise(cell.content) == target]

    answer_norm = answer_text.strip().lower()
    candidates = _candidates(lambda text: text.strip())
    if not candidates:
        candidates = [(cell.row_id, cell.col_id) for cell in cells
                      if cell.row_id > 1 and cell.col_id > 0
                      and cell.content.strip().lower() == answer_norm]
    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]

    if question_text:
        question_words = set(question_text.lower().split())
        label_of_row = {cell.row_id: cell.content.strip().lower()
                        for cell in cells
                        if cell.col_id == 1 and cell.row_id > 1}
        header_of_col = {cell.col_id: cell.content.strip().lower()
                         for cell in cells
                         if cell.row_id == 1 and cell.col_id > 0}

        def _overlap(rc):
            row_words = set(label_of_row.get(rc[0], '').split())
            col_words = set(header_of_col.get(rc[1], '').split())
            return (len(question_words & row_words),
                    len(question_words & col_words))

        # max() keeps the first of several equally good candidates, which is
        # what a strict ">" comparison while scanning in order would do.
        winner = max(candidates, key=_overlap)
        row_hits, col_hits = _overlap(winner)
        if row_hits > 0 or col_hits > 0:
            return winner

    # No usable question signal: fall back to the top-left-most match.
    return min(candidates)
def _predict_lookup(scores, valid_cells, cells, question_text):
q_words = set(question_text.lower().split())
row_labels = {c.row_id: c.content.strip().lower()
for c in cells if c.col_id == 1 and c.row_id > 1}
col_headers = {c.col_id: c.content.strip().lower()
for c in cells if c.row_id == 1 and c.col_id > 0}
if not q_words:
return valid_cells[scores.argmax().item()]
row_score = {r: len(q_words & set(lbl.split())) for r, lbl in row_labels.items()}
col_score = {c: len(q_words & set(hdr.split())) for c, hdr in col_headers.items()}
best_row = max(row_score, key=row_score.get) if row_score else None
best_col = max(col_score, key=col_score.get) if col_score else None
has_row = best_row is not None and row_score[best_row] > 0
has_col = best_col is not None and col_score[best_col] > 0
if not has_row and not has_col:
return valid_cells[scores.argmax().item()]
row_indices = [i for i, rc in enumerate(valid_cells) if rc[0] == best_row] if has_row \
else list(range(len(valid_cells)))
if not row_indices:
return valid_cells[scores.argmax().item()]
if has_col:
both = [i for i in row_indices if valid_cells[i][1] == best_col]
if both:
return valid_cells[max(both, key=lambda i: scores[i].item())]
return valid_cells[max(row_indices, key=lambda i: scores[i].item())]
# ── TabuLM inference ───────────────────────────────────────────────────────────
def load_tabulm(device):
    """Build the fine-tuned TabuLM cell scorer and its preprocessing objects.

    Reads the BPE model, the KBVocab state and the TabuLM checkpoint from
    DATA_DIR, derives the architecture arguments from the checkpoint's
    non-``cell_head`` weights, and loads the full checkpoint strictly.

    Returns:
        (model, args, kv, bpe, encode_item), where ``model`` is in eval mode
        on ``device`` and ``encode_item`` is the function imported from
        ``finetune_tabqa``.
    """
    from finetune_tabqa import (detect_arch_from_state, encode_item, TabQAModel)

    bpe_path = os.path.join(DATA_DIR, 'BPE-30k.mdl')
    bpe = yttm.BPE(model=bpe_path)

    vocab_path = os.path.join(DATA_DIR, 'kb_vocab_state_dict_2021-02-07.pt')
    kv = KBVocab()
    kv.load_state_dict(torch.load(vocab_path, map_location='cpu'))

    checkpoint = torch.load(TABULM_CKPT, map_location='cpu')
    full_state = checkpoint['model_state_dict']

    # Architecture detection looks only at the weights outside the cell head.
    backbone_state = {}
    for key, tensor in full_state.items():
        if key.startswith('cell_head'):
            continue
        backbone_state[key] = tensor

    # Drop a uniform 'module.' prefix before detection, if every key has one.
    prefix = 'module.'
    if all(key.startswith(prefix) for key in backbone_state):
        backbone_state = {key[len(prefix):]: tensor
                          for key, tensor in backbone_state.items()}

    args = detect_arch_from_state(backbone_state)
    tabulm_obj = tabulm_base(kv, None, None, device, args, saved_model_file=None)
    model = TabQAModel(tabulm_obj).to(device)
    # The unmodified checkpoint state is what gets loaded into the model.
    model.load_state_dict(full_state, strict=True)
    model.eval()
    return model, args, kv, bpe, encode_item
def tabulm_predict(model, args, kv, bpe, encode_item, cells, question, atype, device):
    """Predict the answer cell for one question with the TabuLM scorer.

    Args:
        model: callable returning ``(scores, valid_cells)`` for an encoded item.
        args: architecture arguments forwarded to ``model``.
        kv, bpe: vocabulary and BPE objects forwarded to ``encode_item``.
        encode_item: callable ``(cells, question, kv, bpe)`` returning the
            nine-element encoding, or None when the item cannot be encoded.
            If None is passed, ``finetune_tabqa.encode_item`` is imported
            and used instead.
        cells: table cells of the item.
        question: question text.
        atype: answer type; ``'lookup'`` enables the header-overlap filter.
        device: device forwarded to ``model``.

    Returns:
        The predicted (row_id, col_id), or None when the item could not be
        encoded or the model produced no scores.
    """
    # Use the encoder handed in by the caller; it was previously ignored in
    # favour of a fresh import on every call.
    if encode_item is None:
        from finetune_tabqa import encode_item
    enc = encode_item(cells, question, kv, bpe)
    if enc is None:
        return None
    (pos_tags, stems, affixes, tokens_lengths,
     row_ids, col_ids, cell_types, ordered_cells, cell_to_positions) = enc
    with torch.no_grad():
        scores, valid_cells = model(
            args, pos_tags, stems, affixes,
            tokens_lengths, row_ids, col_ids, cell_types,
            ordered_cells, cell_to_positions, device)
    if scores is None:
        return None
    if atype == 'lookup':
        return _predict_lookup(scores, valid_cells, cells, question)
    return valid_cells[scores.argmax().item()]
# ── KinyaBERT inference ────────────────────────────────────────────────────────
class KBCellScorer(nn.Module):
    """Encoder plus a linear head that assigns one score to each table cell."""

    def __init__(self, encoder, d_model):
        """Store the encoder and create a ``d_model -> 1`` scoring head."""
        super().__init__()
        self.encoder = encoder
        self.cell_head = nn.Linear(d_model, 1)

    def forward(self, input_ids, attention_mask, ordered_cells,
                cell_to_positions, device):
        """Score every cell that has at least one token inside the sequence.

        A batch dimension is added to ``input_ids`` and ``attention_mask``
        before they are passed to the encoder. Each cell is represented by
        the mean of the encoder's hidden states at its token positions;
        positions at or beyond the encoded length are dropped.

        Returns:
            ``(scores, valid_cells)`` with one score per kept cell, or
            ``(None, [])`` when no cell has an in-range position.
        """
        batch_ids = input_ids.unsqueeze(0).to(device)
        batch_mask = attention_mask.unsqueeze(0).to(device)
        encoded = self.encoder(input_ids=batch_ids, attention_mask=batch_mask)
        token_states = encoded.last_hidden_state[0]
        seq_len = token_states.size(0)

        pooled = []
        for rc in ordered_cells:
            in_range = [pos for pos in cell_to_positions[rc] if pos < seq_len]
            if in_range:
                pooled.append((rc, token_states[in_range].mean(0)))

        if not pooled:
            return None, []

        valid_cells = [rc for rc, _ in pooled]
        cell_matrix = torch.stack([vec for _, vec in pooled])
        scores = self.cell_head(cell_matrix).squeeze(-1)
        return scores, valid_cells
def load_kinyabert(device):
    """Build the KinyaBERT baseline cell scorer from its fine-tuned checkpoint.

    The encoder and tokenizer are loaded by name via ``AutoModel`` /
    ``AutoTokenizer``; the fine-tuned weights come from KINYABERT_CKPT.

    Returns:
        (model, tokenizer, serialize_table_for_bert), with ``model`` in eval
        mode on ``device``.
    """
    from eval_baselines import serialize_table_for_bert
    hf_name = 'jean-paul/KinyaBERT-large'
    tokenizer = AutoTokenizer.from_pretrained(hf_name)
    encoder = AutoModel.from_pretrained(hf_name)
    d_model = encoder.config.hidden_size
    model = KBCellScorer(encoder, d_model).to(device)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    saved = torch.load(KINYABERT_CKPT, map_location='cpu')
    # strict=False tolerates key mismatches, but a mismatch can leave parts of
    # the model (e.g. the scoring head) at their random initialisation. Report
    # what was skipped instead of ignoring it silently.
    load_result = model.load_state_dict(saved['model_state_dict'], strict=False)
    if load_result.missing_keys:
        print(f'[CMP] WARNING: {len(load_result.missing_keys)} keys missing from '
              f'KinyaBERT checkpoint, e.g. {load_result.missing_keys[:5]}')
    if load_result.unexpected_keys:
        print(f'[CMP] WARNING: {len(load_result.unexpected_keys)} unexpected keys in '
              f'KinyaBERT checkpoint, e.g. {load_result.unexpected_keys[:5]}')
    model.eval()
    return model, tokenizer, serialize_table_for_bert
def kb_predict(model, tokenizer, serialize_fn, cells, question, atype, device):
    """Predict the answer cell for one question with the KinyaBERT scorer.

    ``serialize_fn(cells, question, tokenizer)`` must return
    ``(input_ids, attention_mask, ordered_cells, cell_to_positions)`` or None.

    Returns:
        The predicted (row_id, col_id), or None when serialization fails or
        the model returns no scores. For ``atype == 'lookup'`` the choice is
        delegated to ``_predict_lookup``; otherwise the highest-scoring cell
        is returned.
    """
    serialized = serialize_fn(cells, question, tokenizer)
    if serialized is None:
        return None
    input_ids, attn_mask, ordered_cells, cell_to_positions = serialized

    with torch.no_grad():
        scores, valid_cells = model(
            input_ids, attn_mask, ordered_cells, cell_to_positions, device)

    if scores is None:
        return None
    if atype != 'lookup':
        top_index = scores.argmax().item()
        return valid_cells[top_index]
    return _predict_lookup(scores, valid_cells, cells, question)
# ── Bootstrap CI ──────────────────────────────────────────────────────────────
def bootstrap_ci(hits: List[int], n_boot: int = 10000, alpha: float = 0.05):
    """Estimate a percentile-bootstrap confidence interval for the mean hit rate.

    ``hits`` is resampled with replacement ``n_boot`` times using the global
    NumPy random state, so results are reproducible after ``np.random.seed``.

    Returns:
        ``(mean, lo, hi)``: the mean of ``hits`` and the ``alpha / 2`` and
        ``1 - alpha / 2`` percentiles of the resampled means.
    """
    samples = np.array(hits, dtype=float)
    sample_size = len(samples)

    resampled_means = []
    for _ in range(n_boot):
        draw = np.random.choice(samples, size=sample_size, replace=True)
        resampled_means.append(np.mean(draw))
    resampled_means = np.array(resampled_means)

    lower_pct = 100 * alpha / 2
    upper_pct = 100 * (1 - alpha / 2)
    lo = np.percentile(resampled_means, lower_pct)
    hi = np.percentile(resampled_means, upper_pct)
    return float(np.mean(samples)), lo, hi
def paired_bootstrap_pvalue(hits_a: List[int], hits_b: List[int],
                            n_boot: int = 10000) -> float:
    """Two-sided paired bootstrap p-value for H0: A and B have the same mean.

    Items are resampled jointly (the same indices for A and B) using the
    global NumPy random state. The bootstrap distribution of the mean
    difference is centred on the observed difference, so it has to be shifted
    to zero before it can stand in for the null distribution; the p-value is
    the fraction of shifted differences at least as extreme as the observed
    one. Without the shift, a large real difference yields a p-value near 1.

    Args:
        hits_a, hits_b: per-item 0/1 outcomes for the two systems, aligned
            item by item.
        n_boot: number of bootstrap resamples.

    Returns:
        The p-value in [0, 1], or NaN when there are no items or no resamples.

    Raises:
        ValueError: if the two inputs differ in length.
    """
    a = np.asarray(hits_a, dtype=float)
    b = np.asarray(hits_b, dtype=float)
    if a.shape != b.shape:
        raise ValueError(
            f'paired inputs must have equal length, got {len(a)} and {len(b)}')
    n = len(a)
    if n == 0 or n_boot <= 0:
        return float('nan')

    # With shared indices, mean(a[idx]) - mean(b[idx]) == mean((a - b)[idx]).
    per_item_diff = a - b
    obs_diff = per_item_diff.mean()
    diffs = np.empty(n_boot)
    for i in range(n_boot):
        idx = np.random.choice(n, size=n, replace=True)
        diffs[i] = per_item_diff[idx].mean()

    # Shift to the null (mean zero); the tolerance keeps exact ties on the
    # discrete grid of possible differences from flipping on rounding error.
    centered = diffs - obs_diff
    p = float(np.mean(np.abs(centered) >= np.abs(obs_diff) - 1e-12))
    return p
# ── Main ───────────────────────────────────────────────────────────────────────
def _print_comparison_example(record, tabulm_mark, kb_mark, show_position=False):
    """Print one paired record for error analysis, marking each model's result."""
    print(f'\n Table : {record["table"]}')
    print(f' Q : {record["question"]}')
    print(f' Type : {record["answer_type"]}')
    if show_position:
        print(f' Gold : {record["gold"]} '
              f'(row={record["gold_rc"][0]}, col={record["gold_rc"][1]})')
    else:
        print(f' Gold : {record["gold"]}')
    print(f' TabuLM: {record["tabulm_pred"]} {tabulm_mark}')
    print(f' KinyaB: {record["kb_pred"]} {kb_mark}')


def main():
    """Run the paired TabuLM vs KinyaBERT evaluation and save the results.

    Steps: load both models, take the last 20% of the seed-42-shuffled items
    as the dev split, score every item both models can evaluate, print
    per-type and overall exact match, bootstrap confidence intervals and a
    paired p-value, print example wins/losses, and write everything to the
    ``--out`` JSON file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--n-boot', type=int, default=10000)
    parser.add_argument('--out', default=os.path.join(DATA_DIR, 'compare_results.json'))
    args_cli = parser.parse_args()

    # Fixed seeds: the shuffle below defines the dev split and the bootstrap
    # draws from the global NumPy random state.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'[CMP] Device: {device}')
    print('[CMP] Loading TabuLM...')
    tabulm_model, tabulm_args, kv, bpe, enc_fn = load_tabulm(device)
    print('[CMP] Loading KinyaBERT...')
    kb_model, kb_tok, kb_ser = load_kinyabert(device)

    # Read as UTF-8 explicitly rather than relying on the platform default.
    with open(TABQA_FILE, encoding='utf-8') as f:
        all_items = json.load(f)
    random.shuffle(all_items)
    dev_items = all_items[int(0.8 * len(all_items)):]
    print(f'[CMP] {len(dev_items)} dev items')

    records = []
    for item in dev_items:
        csv_path = os.path.join(CSV_DIR, item['table_file'])
        atype = item.get('answer_type', '?')
        # 'count' items and items whose table file is missing are skipped.
        if atype == 'count' or not os.path.exists(csv_path):
            continue
        cells = serialize_csv(csv_path)
        if not cells:
            continue
        gold_rc = find_gold_cell(cells, item['answer'], item['question'])
        if gold_rc is None:
            continue

        # TabuLM prediction
        t_pred = tabulm_predict(tabulm_model, tabulm_args, kv, bpe, enc_fn,
                                cells, item['question'], atype, device)
        t_hit = int(t_pred == gold_rc) if t_pred is not None else None
        # KinyaBERT prediction
        k_pred = kb_predict(kb_model, kb_tok, kb_ser,
                            cells, item['question'], atype, device)
        k_hit = int(k_pred == gold_rc) if k_pred is not None else None
        # only count items both models can evaluate
        if t_hit is None or k_hit is None:
            continue

        # get cell text for readable output
        cell_map = {(c.row_id, c.col_id): c.content.strip() for c in cells}
        records.append({
            'question': item['question'],
            'table': item['table_file'],
            'answer_type': atype,
            'gold': item['answer'],
            'gold_rc': list(gold_rc),
            'tabulm_rc': list(t_pred),
            'tabulm_pred': cell_map.get(t_pred, '?'),
            'tabulm_hit': t_hit,
            'kb_rc': list(k_pred),
            'kb_pred': cell_map.get(k_pred, '?'),
            'kb_hit': k_hit,
        })

    print(f'\n[CMP] Paired items: {len(records)}')
    if not records:
        # Every statistic below divides by the number of paired items.
        print('[CMP] No paired items; nothing to compare.')
        return

    # ── Per-type breakdown ──────────────────────────────────────────────────
    by_type: Dict[str, Dict] = {}
    for r in records:
        bucket = by_type.setdefault(r['answer_type'], {'tabulm': [], 'kb': []})
        bucket['tabulm'].append(r['tabulm_hit'])
        bucket['kb'].append(r['kb_hit'])

    print(f'\n{"Type":<14} {"N":>4} {"TabuLM":>8} {"KinyaBERT":>10}')
    print('-' * 44)
    all_t, all_k = [], []
    for t, d in sorted(by_type.items()):
        n = len(d['tabulm'])
        type_t_em = sum(d['tabulm']) / n
        type_k_em = sum(d['kb']) / n
        # Both lists are extended in the same order, so they stay paired.
        all_t.extend(d['tabulm'])
        all_k.extend(d['kb'])
        print(f'{t:<14} {n:>4} {type_t_em:>8.3f} {type_k_em:>10.3f}')
    n_all = len(all_t)
    print(f'{"Overall":<14} {n_all:>4} {sum(all_t)/n_all:>8.3f} {sum(all_k)/n_all:>10.3f}')

    # ── Bootstrap CI ───────────────────────────────────────────────────────
    print(f'\n[CMP] Bootstrap CI (n_boot={args_cli.n_boot}):')
    t_em, t_lo, t_hi = bootstrap_ci(all_t, args_cli.n_boot)
    k_em, k_lo, k_hi = bootstrap_ci(all_k, args_cli.n_boot)
    p_val = paired_bootstrap_pvalue(all_t, all_k, args_cli.n_boot)
    print(f' TabuLM: {t_em:.3f} 95% CI [{t_lo:.3f}, {t_hi:.3f}]')
    print(f' KinyaBERT: {k_em:.3f} 95% CI [{k_lo:.3f}, {k_hi:.3f}]')
    print(f' Two-sided paired bootstrap p-value: {p_val:.4f}')

    # ── Error analysis: TabuLM correct, KinyaBERT wrong ────────────────────
    wins = [r for r in records if r['tabulm_hit'] == 1 and r['kb_hit'] == 0]
    losses = [r for r in records if r['tabulm_hit'] == 0 and r['kb_hit'] == 1]
    print(f'\n[CMP] TabuLM wins (T=1, KB=0): {len(wins)}')
    print(f'[CMP] TabuLM losses (T=0, KB=1): {len(losses)}')

    print('\n=== TabuLM WINS (comparison type preferred) ===')
    cmp_wins = [r for r in wins if r['answer_type'] == 'comparison']
    for r in cmp_wins[:5]:
        _print_comparison_example(r, '✓', '✗', show_position=True)

    print('\n=== TabuLM LOSSES ===')
    for r in losses[:3]:
        _print_comparison_example(r, '✗', '✓')

    # ── Save ───────────────────────────────────────────────────────────────
    out_data = {
        'n_paired': n_all,
        'tabulm_em': round(t_em, 4),
        'kb_em': round(k_em, 4),
        'tabulm_ci': [round(float(t_lo), 4), round(float(t_hi), 4)],
        'kb_ci': [round(float(k_lo), 4), round(float(k_hi), 4)],
        'p_value': round(p_val, 4),
        'by_type': {t: {
            'n': len(d['tabulm']),
            'tabulm_em': round(sum(d['tabulm'])/len(d['tabulm']), 4),
            'kb_em': round(sum(d['kb'])/len(d['kb']), 4),
        } for t, d in by_type.items()},
        'tabulm_wins': wins,
        'tabulm_losses': losses,
        'records': records,
    }
    # ensure_ascii=False emits non-ASCII text verbatim, so the file must be
    # opened as UTF-8 regardless of the platform's default encoding.
    with open(args_cli.out, 'w', encoding='utf-8') as f:
        json.dump(out_data, f, indent=2, ensure_ascii=False)
    print(f'\n[CMP] Saved to {args_cli.out}')
# Run the paired evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()