tabulm / code /eval_baselines.py
rakshi719's picture
Add TabuLM training and evaluation code
f32c034 verified
Raw
History Blame Contribute Delete
17.4 kB
#!/usr/bin/env python3
"""Baseline evaluation on TabQA-kin using mBERT, XLM-R, and KinyaBERT.
Each baseline uses the same cell-selection approach as TabuLM:
- Serialize table as flat text, track (row, col) token positions
- Mean-pool hidden states per cell → linear cell scorer
- Same 80/20 train/dev split (seed=42), 20 epochs, lr=2e-5
- Gold cell found by text-matching (same as finetune_tabqa.py)
"""
import argparse
import json
import os
import random
import sys
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from transformers import AutoTokenizer, AutoModel
# Absolute, cluster-specific locations of the TabQA-kin data and of this code
# tree. NOTE(review): hard-coded scratch paths; adjust when running elsewhere.
DATA_DIR = '/shared/scratch/0/tmp/v_ireddi_rakshitha_results/tabulm/data'
CODE_DIR = '/shared/scratch/0/tmp/v_ireddi_rakshitha_results/tabulm/code'
# QA items: a JSON list; each item is read below for 'table_file', 'question',
# 'answer' and (optionally) 'answer_type'.
TABQA_FILE = os.path.join(DATA_DIR, 'tabqa_kin.json')
# Directory holding the CSV tables named by each item's 'table_file'.
CSV_DIR = os.path.join(DATA_DIR, 'tables')
# Make the project-local tabular_serializer module importable (presumably it
# lives in CODE_DIR); this must run before the import that follows.
sys.path.insert(0, CODE_DIR)
from tabular_serializer import serialize_csv, TableCell
# ── Model registry ─────────────────────────────────────────────────────────────
# Short CLI name (used by --models and in output file names) -> Hugging Face
# hub identifier passed to AutoTokenizer/AutoModel.from_pretrained.
MODELS = dict(
    mbert='bert-base-multilingual-cased',
    xlmr='xlm-roberta-base',
    kinyabert='jean-paul/KinyaBERT-large',
)
# ── Gold cell lookup (same as finetune_tabqa.py) ───────────────────────────────
def find_gold_cell(cells: List[TableCell], answer_text: str,
                   question_text: str = '') -> Optional[Tuple[int, int]]:
    """Locate the data cell whose text equals the answer.

    Only data cells (row_id > 1, col_id > 0) are candidates. An exact match
    on the stripped text is tried first, then a case-insensitive one.

    When several cells match and ``question_text`` is non-empty, the
    candidate wins whose row label (column-1 text) and column header (row-1
    text) share the most whitespace-separated, lower-cased words with the
    question. Row overlap is compared first, then column overlap. Ties keep
    the earliest candidate in ``cells`` order.

    If no candidate overlaps the question, or no question is given, the
    smallest ``(row_id, col_id)`` is returned.

    Returns:
        The chosen ``(row_id, col_id)``, or ``None`` when no cell matches.
    """
    def is_data_cell(cell):
        return cell.row_id > 1 and cell.col_id > 0

    wanted = answer_text.strip()
    found = [(cell.row_id, cell.col_id) for cell in cells
             if is_data_cell(cell) and cell.content.strip() == wanted]
    if not found:
        wanted_ci = wanted.lower()
        found = [(cell.row_id, cell.col_id) for cell in cells
                 if is_data_cell(cell)
                 and cell.content.strip().lower() == wanted_ci]

    if not found:
        return None
    if len(found) == 1:
        return found[0]

    if question_text:
        q_words = set(question_text.lower().split())
        row_text = {cell.row_id: cell.content.strip().lower()
                    for cell in cells if cell.col_id == 1 and cell.row_id > 1}
        col_text = {cell.col_id: cell.content.strip().lower()
                    for cell in cells if cell.row_id == 1 and cell.col_id > 0}

        def overlap(rc):
            row_id, col_id = rc
            return (len(q_words & set(row_text.get(row_id, '').split())),
                    len(q_words & set(col_text.get(col_id, '').split())))

        # max() returns the first maximal candidate, i.e. ties keep cell order.
        winner = max(found, key=overlap)
        if any(overlap(winner)):
            return winner

    return min(found)
# ── Table serialization for HF tokenizers ─────────────────────────────────────
def serialize_table_for_bert(cells: List['TableCell'], question: str,
                             tokenizer, max_seq_len: int = 512):
    """
    Linearize table + question for a BERT-style model.

    Format: [CLS] h1 | h2 | ... [SEP] v1 | v2 | ... [SEP] ... question [SEP]

    Row 1 is treated as the header. Only data rows (row_id > 1) are tracked
    as candidate answer cells. Column order comes from the first (lowest
    row_id) row, so a column absent from that row is not serialized.

    Args:
        cells: table cells exposing ``row_id``, ``col_id`` and ``content``;
            cells with row_id <= 0 or col_id <= 0 are ignored.
        question: question text, appended after the table.
        tokenizer: Hugging Face-style tokenizer (``encode``,
            ``convert_tokens_to_ids`` and cls/sep/unk token ids are used).
        max_seq_len: maximum sequence length including special tokens.

    Returns:
        ``(input_ids, attn_mask, ordered_cells, cell_to_positions)``, or
        ``None`` if there are no usable cells. ``cell_to_positions`` maps
        each data cell ``(row_id, col_id)`` that survived truncation to its
        token positions. ``ordered_cells`` lists those cells by first
        position.
    """
    rows: Dict[int, Dict[int, str]] = {}
    for c in cells:
        if c.row_id > 0 and c.col_id > 0:
            rows.setdefault(c.row_id, {})[c.col_id] = c.content.strip()
    if not rows:
        return None
    sorted_rows = sorted(rows)
    col_ids_sorted = sorted(rows[sorted_rows[0]])
    last_col = len(col_ids_sorted) - 1

    tokens: List[int] = [tokenizer.cls_token_id]
    cell_to_positions: Dict[Tuple[int, int], List[int]] = {}
    pipe_id = tokenizer.convert_tokens_to_ids('|')
    if pipe_id == tokenizer.unk_token_id:
        pipe_id = None  # skip the pipe separator if it is not in the vocab

    for row_id in sorted_rows:
        for i, col_id in enumerate(col_ids_sorted):
            content = rows[row_id].get(col_id, '')
            sub = tokenizer.encode(content, add_special_tokens=False)
            if sub:
                start = len(tokens)
                tokens.extend(sub)
                if row_id > 1:  # data rows only (header row_id=1 not scored)
                    cell_to_positions.setdefault((row_id, col_id), []).extend(
                        range(start, len(tokens))
                    )
            if pipe_id is not None and i < last_col:
                tokens.append(pipe_id)
        tokens.append(tokenizer.sep_token_id)  # end of row

    # Question
    tokens.extend(tokenizer.encode(question, add_special_tokens=False))
    tokens.append(tokenizer.sep_token_id)

    # Truncate
    # NOTE(review): the question is serialized last, so truncating a long
    # table also removes (part of) the question from the model's input.
    if len(tokens) > max_seq_len:
        keep = max_seq_len - 1
        tokens = tokens[:keep] + [tokenizer.sep_token_id]
        # Position `keep` now holds the forced [SEP]; only positions below it
        # can still belong to a cell, otherwise a truncated cell would pool
        # the [SEP] hidden state into its embedding.
        cell_to_positions = {
            k: [p for p in v if p < keep]
            for k, v in cell_to_positions.items()
        }
        cell_to_positions = {k: v for k, v in cell_to_positions.items() if v}

    input_ids = torch.tensor(tokens, dtype=torch.long)
    attn_mask = torch.ones(len(tokens), dtype=torch.long)
    ordered_cells = sorted(cell_to_positions.keys(),
                           key=lambda rc: min(cell_to_positions[rc]))
    return input_ids, attn_mask, ordered_cells, cell_to_positions
# ── Cell-scoring model ─────────────────────────────────────────────────────────
class BaselineCellScorer(nn.Module):
    """Encoder plus a linear head that emits one logit per table cell.

    The encoder's last hidden states for a single (unbatched) sequence are
    mean-pooled over each cell's token positions, then projected to a scalar.
    """

    def __init__(self, encoder, d_model: int):
        super().__init__()
        self.encoder = encoder
        # Small-normal weights and zero bias for the freshly added head.
        head = nn.Linear(d_model, 1)
        nn.init.normal_(head.weight, std=0.02)
        nn.init.zeros_(head.bias)
        self.cell_head = head

    def forward(self, input_ids, attention_mask, ordered_cells,
                cell_to_positions, device):
        """Score each cell listed in ``ordered_cells``.

        Args:
            input_ids, attention_mask: tensors for one sequence; a batch
                dimension of 1 is added here.
            ordered_cells: ``(row_id, col_id)`` keys to score, in output order.
            cell_to_positions: maps each key to its token positions.
            device: device the inputs are moved to.

        Returns:
            ``(scores, valid_cells)``: a 1-D tensor of logits and the cells
            they belong to. Cells without any position inside the encoder
            output are dropped; ``(None, [])`` when none remain.
        """
        encoded = self.encoder(
            input_ids=input_ids.unsqueeze(0).to(device),
            attention_mask=attention_mask.unsqueeze(0).to(device),
        )
        seq_hidden = encoded.last_hidden_state[0]  # (seq_len, d_model)
        seq_len = seq_hidden.size(0)

        kept = []
        for cell in ordered_cells:
            in_range = [pos for pos in cell_to_positions[cell] if pos < seq_len]
            if in_range:
                kept.append((cell, seq_hidden[in_range].mean(0)))

        if not kept:
            return None, []
        pooled = torch.stack([vec for _, vec in kept])
        logits = self.cell_head(pooled).squeeze(-1)
        return logits, [cell for cell, _ in kept]
# ── Lookup prediction helper ───────────────────────────────────────────────────
def _predict_lookup(scores, valid_cells, cells, question_text):
q_words = set(question_text.lower().split())
row_labels = {c.row_id: c.content.strip().lower()
for c in cells if c.col_id == 1 and c.row_id > 1}
col_headers = {c.col_id: c.content.strip().lower()
for c in cells if c.row_id == 1 and c.col_id > 0}
if not q_words:
return valid_cells[scores.argmax().item()]
row_score = {r: len(q_words & set(lbl.split())) for r, lbl in row_labels.items()}
col_score = {c: len(q_words & set(hdr.split())) for c, hdr in col_headers.items()}
best_row = max(row_score, key=row_score.get) if row_score else None
best_col = max(col_score, key=col_score.get) if col_score else None
has_row = best_row is not None and row_score[best_row] > 0
has_col = best_col is not None and col_score[best_col] > 0
if not has_row and not has_col:
return valid_cells[scores.argmax().item()]
if has_row:
row_indices = [i for i, rc in enumerate(valid_cells) if rc[0] == best_row]
else:
row_indices = list(range(len(valid_cells)))
if not row_indices:
return valid_cells[scores.argmax().item()]
if has_col:
both = [i for i in row_indices if valid_cells[i][1] == best_col]
if both:
return valid_cells[max(both, key=lambda i: scores[i].item())]
return valid_cells[max(row_indices, key=lambda i: scores[i].item())]
def _predict_comparison(scores, valid_cells, cells, question_text):
"""For comparison questions: restrict to col-1 (entity name) cells of the
top-2 question-relevant rows. The answer is the winning entity name."""
q_words = set(question_text.lower().split())
row_labels = {c.row_id: c.content.strip().lower()
for c in cells if c.col_id == 1 and c.row_id > 1}
if not q_words or not row_labels:
return valid_cells[scores.argmax().item()]
row_score = {r: len(q_words & set(lbl.split())) for r, lbl in row_labels.items()}
top2 = sorted([r for r, s in row_score.items() if s > 0],
key=lambda r: row_score[r], reverse=True)[:2]
if not top2:
return valid_cells[scores.argmax().item()]
cand = [i for i, (r, c) in enumerate(valid_cells) if r in top2 and c == 1]
if not cand:
cand = [i for i, (r, c) in enumerate(valid_cells) if r in top2]
if not cand:
return valid_cells[scores.argmax().item()]
return valid_cells[max(cand, key=lambda i: scores[i].item())]
# ── Evaluation ─────────────────────────────────────────────────────────────────
def evaluate(model, tokenizer, items, csv_dir, device, split='dev'):
    """Measure exact-match cell accuracy of ``model`` on ``items``.

    Each item is read for 'table_file', 'question' and 'answer', and
    optionally 'answer_type'. An item is skipped, and left out of the EM
    denominator, when any of these holds:
      - its CSV is missing or empty;
      - no gold cell is found;
      - serialization fails;
      - the gold cell did not survive truncation.

    'lookup' and 'comparison' items are decoded with ``_predict_lookup`` /
    ``_predict_comparison``; any other type uses the argmax cell.

    Returns:
        ``(em, per_type_em)`` where ``per_type_em`` maps each answer type to
        its EM rounded to 4 decimals.
    """
    model.eval()
    decoders = {'lookup': _predict_lookup, 'comparison': _predict_comparison}

    def judge(item):
        """Return ``(answer_type, hit)`` for a scorable item, else ``None``."""
        csv_path = os.path.join(csv_dir, item['table_file'])
        if not os.path.exists(csv_path):
            return None
        cells = serialize_csv(csv_path)
        if not cells:
            return None
        question = item['question']
        atype = item.get('answer_type', '?')
        gold_rc = find_gold_cell(cells, item['answer'], question_text=question)
        if gold_rc is None:
            return None
        enc = serialize_table_for_bert(cells, question, tokenizer)
        if enc is None:
            return None
        input_ids, attn_mask, ordered_cells, cell_to_positions = enc
        if gold_rc not in cell_to_positions:
            return None
        scores, valid_cells = model(input_ids, attn_mask, ordered_cells,
                                    cell_to_positions, device)
        if scores is None or gold_rc not in valid_cells:
            return None
        decode = decoders.get(atype)
        if decode is None:
            pred_rc = valid_cells[scores.argmax().item()]
        else:
            pred_rc = decode(scores, valid_cells, cells, question)
        return atype, int(pred_rc == gold_rc)

    n_correct = n_total = n_skipped = 0
    per_type: Dict[str, List[int]] = {}
    with torch.no_grad():
        for item in items:
            verdict = judge(item)
            if verdict is None:
                n_skipped += 1
                continue
            atype, hit = verdict
            n_correct += hit
            n_total += 1
            per_type.setdefault(atype, []).append(hit)

    em = n_correct / n_total if n_total > 0 else 0.0
    ts = datetime.now().strftime('%H:%M:%S')
    print(f' [{ts}] {split} EM={em:.4f} ({n_correct}/{n_total}, {n_skipped} skipped)')
    for atype in sorted(per_type):
        hits = per_type[atype]
        rate = sum(hits) / len(hits)
        print(f' {atype}: {sum(hits)}/{len(hits)} = {rate:.3f}')
    return em, {atype: round(sum(hits) / len(hits), 4)
                for atype, hits in per_type.items()}
# ── Training ───────────────────────────────────────────────────────────────────
def train_baseline(model_name, hf_name, train_items, dev_items, csv_dir,
                   device, num_epochs=20, lr=2e-5):
    """Fine-tune one pretrained encoder as a cell scorer and track dev EM.

    Loads ``hf_name`` via AutoTokenizer/AutoModel, wraps it in
    BaselineCellScorer and reports zero-shot dev EM. It then trains for
    ``num_epochs`` epochs with:
      - AdamW (weight decay 0.01);
      - one item per optimizer step;
      - cross-entropy over the cells of each table;
      - gradient-norm clipping at 1.0.

    After every epoch the model is evaluated on ``dev_items``. Whenever dev
    EM beats the best so far (initially the zero-shot EM), the state dict is
    saved to ``DATA_DIR/baseline_<model_name>_best.pt``.

    Args:
        model_name: short name used in logs and the checkpoint file name.
        hf_name: Hugging Face hub identifier of the encoder.
        train_items, dev_items: QA item dicts (see ``evaluate``).
        csv_dir: directory holding the tables named by 'table_file'.
        device: torch device to train on.
        num_epochs: number of training epochs.
        lr: AdamW learning rate.

    Returns:
        ``(best_em, best_by_type, history)``, where:
          - ``best_em`` is the best dev EM;
          - ``best_by_type`` is its per-type breakdown from ``evaluate``;
          - ``history`` holds one ``{'epoch', 'loss', 'dev_em'}`` dict per
            epoch.
    """
    print(f'\n[BL] ===== {model_name.upper()} ({hf_name}) =====')
    tokenizer = AutoTokenizer.from_pretrained(hf_name)
    encoder = AutoModel.from_pretrained(hf_name)
    d_model = encoder.config.hidden_size
    model = BaselineCellScorer(encoder, d_model).to(device)
    print(f'[BL] d_model={d_model} '
          f'params={sum(p.numel() for p in model.parameters()):,}')

    print('[BL] Zero-shot baseline:')
    best_em, best_by_type = evaluate(model, tokenizer, dev_items,
                                     csv_dir, device, 'zero-shot')

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    history = []
    # Shuffle a private copy so the caller's list is not reordered in place
    # (main() hands the same list to every model it trains).
    train_items = list(train_items)

    for epoch in range(1, num_epochs + 1):
        model.train()
        random.shuffle(train_items)
        total_loss, n_ok, n_skip = 0.0, 0, 0
        for item in train_items:
            csv_path = os.path.join(csv_dir, item['table_file'])
            if not os.path.exists(csv_path):
                n_skip += 1
                continue
            cells = serialize_csv(csv_path)
            if not cells:
                n_skip += 1
                continue
            question = item['question']
            # Disambiguate duplicated answer strings with the question, exactly
            # as evaluate() does, so the training target is the same cell that
            # dev EM is scored against.
            gold_rc = find_gold_cell(cells, item['answer'],
                                     question_text=question)
            if gold_rc is None:
                n_skip += 1
                continue
            enc = serialize_table_for_bert(cells, question, tokenizer)
            if enc is None:
                n_skip += 1
                continue
            input_ids, attn_mask, ordered_cells, cell_to_positions = enc
            if gold_rc not in cell_to_positions:
                n_skip += 1
                continue
            scores, valid_cells = model(input_ids, attn_mask, ordered_cells,
                                        cell_to_positions, device)
            if scores is None or gold_rc not in valid_cells:
                n_skip += 1
                continue
            gold_idx = torch.tensor([valid_cells.index(gold_rc)],
                                    dtype=torch.long, device=device)
            loss = criterion(scores.unsqueeze(0), gold_idx)
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            n_ok += 1

        ts = datetime.now().strftime('%H:%M:%S')
        avg = total_loss / max(n_ok, 1)
        print(f'[{ts}] Epoch {epoch}/{num_epochs} loss={avg:.4f} '
              f'trained={n_ok} skipped={n_skip}')
        em, by_type = evaluate(model, tokenizer, dev_items, csv_dir, device)
        history.append({'epoch': epoch, 'loss': avg, 'dev_em': em})
        if em > best_em:
            best_em, best_by_type = em, by_type
            sp = os.path.join(DATA_DIR, f'baseline_{model_name}_best.pt')
            torch.save({'model_state_dict': model.state_dict(), 'em': em}, sp)
            print(f' ** New best EM={best_em:.4f}')

    print(f'\n[BL] {model_name.upper()} done. Best EM={best_em:.4f}')
    return best_em, best_by_type, history
# ── Main ───────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: train and evaluate the selected baselines.

    Steps:
      1. Seed ``random``, NumPy and torch with 42.
      2. Split TABQA_FILE 80/20 into train/dev.
      3. Run ``train_baseline`` for each model named in ``--models``.
      4. Write ``baseline_results.json`` to DATA_DIR.
      5. Print a summary table next to TabuLM's reference scores.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--models', nargs='+', default=['mbert'],
                        choices=list(MODELS.keys()))
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--lr', type=float, default=2e-5)
    args = parser.parse_args()

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'[BL] Device: {device}')

    # JSON is UTF-8 by spec and the data is Kinyarwanda text; don't depend on
    # the platform's default locale encoding.
    with open(TABQA_FILE, encoding='utf-8') as f:
        all_items = json.load(f)
    random.shuffle(all_items)
    split_n = int(0.8 * len(all_items))
    train_items = all_items[:split_n]
    dev_items = all_items[split_n:]
    print(f'[BL] {len(all_items)} items Train={len(train_items)} Dev={len(dev_items)}')

    out = os.path.join(DATA_DIR, 'baseline_results.json')
    all_results = {}
    for model_name in args.models:
        hf_name = MODELS[model_name]
        best_em, best_by_type, history = train_baseline(
            model_name, hf_name, train_items, dev_items, CSV_DIR, device,
            num_epochs=args.epochs, lr=args.lr,
        )
        all_results[model_name] = {
            'hf_name': hf_name, 'best_em': best_em,
            'by_type': best_by_type, 'history': history,
        }
        # Persist after every model so completed runs survive a crash or
        # timeout during a later model's training.
        with open(out, 'w', encoding='utf-8') as f:
            json.dump(all_results, f, indent=2)
    print(f'\n[BL] Results saved to {out}')

    print('\n' + '='*58)
    print(f'{"Model":<18} {"EM":>6} {"Agg":>5} {"Cmp":>5} {"Lkp":>5}')
    print('-'*58)
    for name, r in all_results.items():
        bt = r['by_type']
        print(f'{name:<18} {r["best_em"]:>6.3f} '
              f'{bt.get("aggregation",0):>5.3f} '
              f'{bt.get("comparison",0):>5.3f} '
              f'{bt.get("lookup",0):>5.3f}')
    # Reference row: TabuLM scores are hard-coded here, not computed.
    print(f'{"TabuLM (ours)":<18} {"0.560":>6} '
          f'{"0.792":>5} {"0.417":>5} {"0.286":>5}')


if __name__ == '__main__':
    main()