"""Baseline evaluation on TabQA-kin using mBERT, XLM-R, and KinyaBERT.

Each baseline uses the same cell-selection approach as TabuLM:
- Serialize the table as flat text and track the (row, col) token positions
- Pool hidden states per cell → linear cell scorer
- Same 80/20 train/dev split (seed=42); 20 epochs and lr=2e-5 by default
  (see --epochs / --lr)
- Gold cell found by text-matching (same as finetune_tabqa.py)
"""
|
|
| import argparse |
| import json |
| import os |
| import random |
| import sys |
| from datetime import datetime |
| from typing import Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch.nn.utils import clip_grad_norm_ |
| from transformers import AutoTokenizer, AutoModel |
|
|
# Locations of the TabQA-kin data and of the TabuLM code base. The defaults are
# the paths this script was written against; set TABULM_DATA_DIR /
# TABULM_CODE_DIR to run it anywhere else without editing the source.
DATA_DIR = os.environ.get(
    'TABULM_DATA_DIR',
    '/shared/scratch/0/tmp/v_ireddi_rakshitha_results/tabulm/data',
)
CODE_DIR = os.environ.get(
    'TABULM_CODE_DIR',
    '/shared/scratch/0/tmp/v_ireddi_rakshitha_results/tabulm/code',
)
TABQA_FILE = os.path.join(DATA_DIR, 'tabqa_kin.json')  # JSON list of QA items
CSV_DIR = os.path.join(DATA_DIR, 'tables')             # CSVs named by item['table_file']


# `tabular_serializer` lives in CODE_DIR rather than in an installed package.
sys.path.insert(0, CODE_DIR)
from tabular_serializer import serialize_csv, TableCell  # noqa: E402


# CLI model key -> Hugging Face Hub model id.
MODELS = {
    'mbert': 'bert-base-multilingual-cased',
    'xlmr': 'xlm-roberta-base',
    'kinyabert': 'jean-paul/KinyaBERT-large',
}
|
|
|
|
| |
|
|
def find_gold_cell(cells: List['TableCell'], answer_text: str,
                   question_text: str = '') -> Optional[Tuple[int, int]]:
    """Locate the data cell whose text matches the gold answer.

    Only data cells are candidates: row 1 is treated as the header row and
    cells with ``col_id == 0`` are ignored. An exact (whitespace-stripped)
    match is tried first, then a case-insensitive one.

    When several cells match, the question breaks the tie. Each candidate is
    scored by word overlap between the question and (a) its row label (the
    column-1 cell of its row) and (b) its column header (the row-1 cell of its
    column). Row overlap takes precedence over column overlap. If no candidate
    overlaps the question at all, the top-most, left-most match is returned.

    Args:
        cells: Table cells exposing ``row_id``, ``col_id`` and ``content``.
        answer_text: Gold answer. Non-string answers (e.g. numbers loaded from
            JSON) are compared through ``str()``.
        question_text: Optional question used only for tie-breaking.

    Returns:
        ``(row_id, col_id)`` of the gold cell, or ``None`` if the answer is
        missing/empty or no data cell matches it.
    """
    if answer_text is None:
        return None
    answer = str(answer_text).strip()
    if not answer:
        # An empty answer would otherwise "match" every blank cell.
        return None
    answer_norm = answer.lower()

    data_cells = [c for c in cells if c.row_id > 1 and c.col_id > 0]
    matches = [(c.row_id, c.col_id) for c in data_cells
               if c.content.strip() == answer]
    if not matches:
        matches = [(c.row_id, c.col_id) for c in data_cells
                   if c.content.strip().lower() == answer_norm]
    if not matches:
        return None
    if len(matches) == 1:
        return matches[0]

    if question_text:
        q_words = set(question_text.lower().split())
        row_labels = {c.row_id: c.content.strip().lower()
                      for c in cells if c.col_id == 1 and c.row_id > 1}
        col_headers = {c.col_id: c.content.strip().lower()
                       for c in cells if c.row_id == 1 and c.col_id > 0}
        best, best_score = None, (-1, -1)
        for row_id, col_id in matches:
            # Lexicographic (row overlap, column overlap) score.
            score = (len(q_words & set(row_labels.get(row_id, '').split())),
                     len(q_words & set(col_headers.get(col_id, '').split())))
            if score > best_score:
                best_score, best = score, (row_id, col_id)
        if best_score[0] > 0 or best_score[1] > 0:
            return best

    # No usable disambiguation: fall back to the first match in reading order.
    return min(matches)
|
|
|
|
| |
|
|
def serialize_table_for_bert(cells: List['TableCell'], question: str,
                             tokenizer, max_seq_len: int = 512):
    """Linearize a table plus question into one BERT-style input sequence.

    Layout::

        [CLS] h1 | h2 | ... [SEP] v1 | v2 | ... [SEP] ... question [SEP]

    Every table row (the header included) is closed with ``[SEP]``. Within a
    row, cells are separated by a ``|`` token when the tokenizer has one; the
    separator is omitted if ``'|'`` maps to the unknown token. Column order is
    taken from the first (header) row. Only data cells (``row_id > 1``) get
    token positions, because header cells are never answers.

    If the sequence would exceed ``max_seq_len``, the *table* is truncated
    (and closed with ``[SEP]``) so the question is still present. The question
    itself is trimmed only if it does not fit next to the full table and is
    longer than about half the window. This assumes ``max_seq_len`` is larger
    than a handful of tokens.

    Args:
        cells: Table cells exposing ``row_id``, ``col_id`` and ``content``.
        question: Question text, appended after the table.
        tokenizer: Tokenizer providing ``encode``, ``convert_tokens_to_ids``
            and ``cls_token_id`` / ``sep_token_id`` / ``unk_token_id``.
        max_seq_len: Maximum length of the returned sequence.

    Returns:
        ``(input_ids, attn_mask, ordered_cells, cell_to_positions)``, where:

        * ``input_ids`` and ``attn_mask`` are 1-D ``torch.long`` tensors;
        * ``cell_to_positions`` maps ``(row_id, col_id)`` to that cell's token
          indices;
        * ``ordered_cells`` lists those keys in sequence order.

        Returns ``None`` when the table has no usable cells.
    """
    rows: Dict[int, Dict[int, str]] = {}
    for c in cells:
        if c.row_id > 0 and c.col_id > 0:
            rows.setdefault(c.row_id, {})[c.col_id] = c.content.strip()
    if not rows:
        return None

    sorted_rows = sorted(rows)
    # Column order comes from the first (header) row; columns that appear only
    # in later rows are not serialized.
    col_ids_sorted = sorted(rows[sorted_rows[0]])
    last_col = len(col_ids_sorted) - 1

    sep_id = tokenizer.sep_token_id
    pipe_id = tokenizer.convert_tokens_to_ids('|')
    if pipe_id == tokenizer.unk_token_id:
        pipe_id = None  # no dedicated '|' token: emit no column separators

    tokens: List[int] = [tokenizer.cls_token_id]
    cell_to_positions: Dict[Tuple[int, int], List[int]] = {}
    for row_id in sorted_rows:
        for i, col_id in enumerate(col_ids_sorted):
            sub = tokenizer.encode(rows[row_id].get(col_id, ''),
                                   add_special_tokens=False)
            if sub:
                start = len(tokens)
                tokens.extend(sub)
                if row_id > 1:  # header cells are never answer candidates
                    cell_to_positions.setdefault((row_id, col_id), []).extend(
                        range(start, len(tokens)))
            if pipe_id is not None and i < last_col:
                tokens.append(pipe_id)
        tokens.append(sep_id)

    q_ids = tokenizer.encode(question, add_special_tokens=False)

    # The +1 accounts for the [SEP] that closes the question.
    if len(tokens) + len(q_ids) + 1 > max_seq_len:
        # The question may use whatever the full table leaves free, and is
        # always allowed up to half the window; the table absorbs the rest.
        q_cap = max(max_seq_len - len(tokens) - 1, max_seq_len // 2 - 1, 0)
        q_ids = q_ids[:q_cap]
        table_len = max_seq_len - len(q_ids) - 1  # [CLS] + table + its [SEP]
        if len(tokens) > table_len:
            # Keep `keep` table tokens and close the segment with [SEP]. Only
            # positions < keep still hold cell text; the closing [SEP] must not
            # be pooled into any cell.
            keep = max(table_len - 1, 1)
            tokens = tokens[:keep]
            if tokens[-1] != sep_id:
                tokens.append(sep_id)
            cell_to_positions = {
                rc: [p for p in pos if p < keep]
                for rc, pos in cell_to_positions.items()
            }
            cell_to_positions = {rc: pos for rc, pos in cell_to_positions.items()
                                 if pos}

    tokens.extend(q_ids)
    tokens.append(sep_id)

    input_ids = torch.tensor(tokens, dtype=torch.long)
    attn_mask = torch.ones(len(tokens), dtype=torch.long)
    ordered_cells = sorted(cell_to_positions,
                           key=lambda rc: min(cell_to_positions[rc]))
    return input_ids, attn_mask, ordered_cells, cell_to_positions
|
|
|
|
| |
|
|
class BaselineCellScorer(nn.Module):
    """Pretrained encoder plus a linear head producing one logit per cell.

    A cell's vector is the mean of the encoder's last hidden states over the
    token positions that belong to the cell.
    """

    def __init__(self, encoder, d_model: int):
        super().__init__()
        self.encoder = encoder
        # d_model -> 1 logit; small-normal weights and zero bias.
        head = nn.Linear(d_model, 1)
        nn.init.normal_(head.weight, std=0.02)
        nn.init.zeros_(head.bias)
        self.cell_head = head

    def forward(self, input_ids, attention_mask, ordered_cells,
                cell_to_positions, device):
        """Score the cells of a single (unbatched) example.

        ``input_ids`` / ``attention_mask`` are given without a batch dimension;
        one is added here. Positions at or beyond the encoder's output length
        are ignored, and cells left with no positions are dropped.

        Returns:
            ``(scores, valid_cells)``: a 1-D tensor with one logit per kept
            cell and the matching list of ``(row_id, col_id)`` keys, or
            ``(None, [])`` when no cell survives.
        """
        encoded = self.encoder(
            input_ids=input_ids.unsqueeze(0).to(device),
            attention_mask=attention_mask.unsqueeze(0).to(device),
        )
        hidden = encoded.last_hidden_state[0]
        seq_len = hidden.size(0)

        pooled, kept = [], []
        for rc in ordered_cells:
            in_range = [p for p in cell_to_positions[rc] if p < seq_len]
            if in_range:
                pooled.append(hidden[in_range].mean(0))
                kept.append(rc)

        if not pooled:
            return None, []
        logits = self.cell_head(torch.stack(pooled)).squeeze(-1)
        return logits, kept
|
|
|
|
| |
|
|
def _predict_lookup(scores, valid_cells, cells, question_text):
    """Choose the answer cell for a ``lookup`` question.

    The question's words are matched against the row labels (column-1 cells
    of data rows) and the column headers (row-1 cells). The best-overlapping
    row/column narrow the candidates, and the highest-scoring candidate wins:

    * row and column matched, intersection present -> that intersection cell;
    * row matched -> best-scoring cell in that row;
    * only the column matched -> best-scoring cell in that column (or overall,
      if the column has no scored cell);
    * nothing matched, empty question, or the matched row has no scored
      cells -> global argmax of ``scores``.
    """
    lowered_q = set(question_text.lower().split())
    labels_by_row = {c.row_id: c.content.strip().lower()
                     for c in cells if c.col_id == 1 and c.row_id > 1}
    headers_by_col = {c.col_id: c.content.strip().lower()
                      for c in cells if c.row_id == 1 and c.col_id > 0}

    def argmax_cell():
        return valid_cells[scores.argmax().item()]

    def best_of(indices):
        return valid_cells[max(indices, key=lambda i: scores[i].item())]

    if not lowered_q:
        return argmax_cell()

    row_hits = {r: len(lowered_q & set(text.split()))
                for r, text in labels_by_row.items()}
    col_hits = {c: len(lowered_q & set(text.split()))
                for c, text in headers_by_col.items()}

    target_row = max(row_hits, key=row_hits.get) if row_hits else None
    target_col = max(col_hits, key=col_hits.get) if col_hits else None
    row_found = target_row is not None and row_hits[target_row] > 0
    col_found = target_col is not None and col_hits[target_col] > 0

    if not (row_found or col_found):
        return argmax_cell()

    if row_found:
        candidates = [i for i, rc in enumerate(valid_cells) if rc[0] == target_row]
    else:
        candidates = list(range(len(valid_cells)))
    if not candidates:
        return argmax_cell()

    if col_found:
        intersection = [i for i in candidates if valid_cells[i][1] == target_col]
        if intersection:
            return best_of(intersection)

    return best_of(candidates)
|
|
|
|
def _predict_comparison(scores, valid_cells, cells, question_text):
    """Choose the answer cell for a ``comparison`` question.

    The (up to) two rows whose label (column-1 cell) shares the most words with
    the question are treated as the compared entities. Among those rows, the
    column-1 (entity name) cells are preferred, falling back to any cell of
    those rows. The highest-scoring candidate is returned. With no question
    words, no row labels, no overlapping row, or no candidate cells, the
    global argmax of ``scores`` is used.
    """
    words = set(question_text.lower().split())
    labels_by_row = {c.row_id: c.content.strip().lower()
                     for c in cells if c.col_id == 1 and c.row_id > 1}

    def argmax_cell():
        return valid_cells[scores.argmax().item()]

    if not (words and labels_by_row):
        return argmax_cell()

    overlap = {r: len(words & set(label.split()))
               for r, label in labels_by_row.items()}
    mentioned = [r for r, n in overlap.items() if n > 0]
    # Stable sort: rows with equal overlap keep their table order.
    mentioned.sort(key=overlap.__getitem__, reverse=True)
    top_rows = mentioned[:2]
    if not top_rows:
        return argmax_cell()

    in_top_rows = [i for i, rc in enumerate(valid_cells) if rc[0] in top_rows]
    entity_cells = [i for i in in_top_rows if valid_cells[i][1] == 1]
    candidates = entity_cells or in_top_rows
    if not candidates:
        return argmax_cell()

    return valid_cells[max(candidates, key=lambda i: scores[i].item())]
|
|
|
|
| |
|
|
def evaluate(model, tokenizer, items, csv_dir, device, split='dev'):
    """Compute exact-match cell-selection accuracy of ``model`` on ``items``.

    For each item, the table CSV is loaded and the gold cell is located by text
    match (with question-based tie-breaking). The model then scores all cells.
    Answer types ``lookup`` and ``comparison`` are decoded by
    ``_predict_lookup`` / ``_predict_comparison``; other types use the argmax.

    An item is counted as skipped, and excluded from the EM denominator, if:

    * its CSV is missing or empty;
    * its answer is not found;
    * its gold cell does not survive serialization/scoring.

    Returns:
        ``(em, per_type)`` where ``per_type`` maps answer type to its EM
        rounded to 4 decimals.
    """
    model.eval()
    predictors = {'lookup': _predict_lookup, 'comparison': _predict_comparison}

    def score_one(item):
        """Return ``(answer_type, hit)`` for a usable item, else ``None``."""
        csv_path = os.path.join(csv_dir, item['table_file'])
        if not os.path.exists(csv_path):
            return None
        cells = serialize_csv(csv_path)
        if not cells:
            return None

        question = item['question']
        atype = item.get('answer_type', '?')
        gold_rc = find_gold_cell(cells, item['answer'], question_text=question)
        if gold_rc is None:
            return None

        encoded = serialize_table_for_bert(cells, question, tokenizer)
        if encoded is None:
            return None
        input_ids, attn_mask, ordered_cells, cell_to_positions = encoded
        if gold_rc not in cell_to_positions:
            return None

        scores, valid_cells = model(input_ids, attn_mask, ordered_cells,
                                    cell_to_positions, device)
        if scores is None or gold_rc not in valid_cells:
            return None

        predict = predictors.get(atype)
        if predict is None:
            pred_rc = valid_cells[scores.argmax().item()]
        else:
            pred_rc = predict(scores, valid_cells, cells, question)
        return atype, int(pred_rc == gold_rc)

    n_skipped = 0
    hits_by_type: Dict[str, List[int]] = {}
    with torch.no_grad():
        for item in items:
            outcome = score_one(item)
            if outcome is None:
                n_skipped += 1
                continue
            atype, hit = outcome
            hits_by_type.setdefault(atype, []).append(hit)

    all_hits = [h for hits in hits_by_type.values() for h in hits]
    n_correct, n_total = sum(all_hits), len(all_hits)
    em = n_correct / n_total if n_total > 0 else 0.0

    stamp = datetime.now().strftime('%H:%M:%S')
    print(f' [{stamp}] {split} EM={em:.4f} ({n_correct}/{n_total}, {n_skipped} skipped)')
    for t, hits in sorted(hits_by_type.items()):
        print(f' {t}: {sum(hits)}/{len(hits)} = {sum(hits)/len(hits):.3f}')
    return em, {t: round(sum(h) / len(h), 4) for t, h in hits_by_type.items()}
|
|
|
|
| |
|
|
def train_baseline(model_name, hf_name, train_items, dev_items, csv_dir,
                   device, num_epochs=20, lr=2e-5):
    """Fine-tune one pretrained encoder as a cell scorer on TabQA-kin.

    The encoder is loaded from the Hugging Face id ``hf_name`` and wrapped in
    ``BaselineCellScorer``. The model is first evaluated before any
    fine-tuning ("zero-shot"), and that EM seeds ``best_em``. Each epoch then:

    * trains one item at a time (AdamW, weight decay 0.01, grad-norm clipping
      at 1.0);
    * evaluates on ``dev_items``;
    * if dev EM improved, saves a checkpoint to
      ``DATA_DIR/baseline_{model_name}_best.pt``.

    Training items are prepared like evaluation items. In particular, the gold
    cell is disambiguated with the question, so train and dev labels agree.
    Items that cannot become a training example are counted as skipped.
    ``train_items`` itself is not modified.

    Args:
        model_name: Short key used in log lines and the checkpoint filename.
        hf_name: Hugging Face model id passed to ``from_pretrained``.
        train_items / dev_items: QA item dicts (``table_file``, ``question``,
            ``answer`` and optionally ``answer_type``).
        csv_dir: Directory containing the table CSVs.
        device: Torch device to train on.
        num_epochs: Number of passes over ``train_items``.
        lr: AdamW learning rate.

    Returns:
        ``(best_em, best_by_type, history)`` where ``history`` holds one
        ``{'epoch', 'loss', 'dev_em'}`` dict per epoch.
    """
    print(f'\n[BL] ===== {model_name.upper()} ({hf_name}) =====')
    tokenizer = AutoTokenizer.from_pretrained(hf_name)
    encoder = AutoModel.from_pretrained(hf_name)
    d_model = encoder.config.hidden_size
    model = BaselineCellScorer(encoder, d_model).to(device)

    print(f'[BL] d_model={d_model} '
          f'params={sum(p.numel() for p in model.parameters()):,}')

    print('[BL] Zero-shot baseline:')
    best_em, best_by_type = evaluate(model, tokenizer, dev_items,
                                     csv_dir, device, 'zero-shot')

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    history = []

    # Shuffle a private copy: shuffling in place would reorder the caller's
    # list, which main() reuses for every model.
    epoch_items = list(train_items)

    for epoch in range(1, num_epochs + 1):
        model.train()
        random.shuffle(epoch_items)
        total_loss, n_ok, n_skip = 0.0, 0, 0

        for item in epoch_items:
            csv_path = os.path.join(csv_dir, item['table_file'])
            if not os.path.exists(csv_path):
                n_skip += 1
                continue
            cells = serialize_csv(csv_path)
            if not cells:
                n_skip += 1
                continue

            question = item['question']
            # Same tie-breaking as evaluate(). Without the question, ambiguous
            # answers would be labelled with the top-left match during
            # training but with the question-matched cell on dev.
            gold_rc = find_gold_cell(cells, item['answer'], question_text=question)
            if gold_rc is None:
                n_skip += 1
                continue

            enc = serialize_table_for_bert(cells, question, tokenizer)
            if enc is None:
                n_skip += 1
                continue
            input_ids, attn_mask, ordered_cells, cell_to_positions = enc
            if gold_rc not in cell_to_positions:
                n_skip += 1
                continue

            scores, valid_cells = model(input_ids, attn_mask, ordered_cells,
                                        cell_to_positions, device)
            if scores is None or gold_rc not in valid_cells:
                n_skip += 1
                continue

            # Softmax over all scored cells; the target is the gold cell's index.
            gold_idx = torch.tensor([valid_cells.index(gold_rc)],
                                    dtype=torch.long, device=device)
            loss = criterion(scores.unsqueeze(0), gold_idx)
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            n_ok += 1

        ts = datetime.now().strftime('%H:%M:%S')
        avg = total_loss / max(n_ok, 1)
        print(f'[{ts}] Epoch {epoch}/{num_epochs} loss={avg:.4f} '
              f'trained={n_ok} skipped={n_skip}')

        em, by_type = evaluate(model, tokenizer, dev_items, csv_dir, device)
        history.append({'epoch': epoch, 'loss': avg, 'dev_em': em})

        if em > best_em:
            best_em, best_by_type = em, by_type
            sp = os.path.join(DATA_DIR, f'baseline_{model_name}_best.pt')
            torch.save({'model_state_dict': model.state_dict(), 'em': em}, sp)
            print(f' ** New best EM={best_em:.4f}')

    print(f'\n[BL] {model_name.upper()} done. Best EM={best_em:.4f}')
    return best_em, best_by_type, history
|
|
|
|
| |
|
|
def main():
    """CLI entry point: train and evaluate the selected baselines.

    Steps:

    1. Seed ``random``, ``numpy`` and ``torch`` with 42.
    2. Load and shuffle the TabQA-kin items, then make an 80/20 train/dev
       split.
    3. Run ``train_baseline`` for every model passed via ``--models``.
    4. Write the results to ``DATA_DIR/baseline_results.json`` and print a
       summary table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--models', nargs='+', default=['mbert'],
                        choices=list(MODELS))
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--lr', type=float, default=2e-5)
    args = parser.parse_args()

    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'[BL] Device: {device}')

    # Explicit UTF-8: the items may contain non-ASCII text, and the platform's
    # default encoding is not guaranteed to be UTF-8.
    with open(TABQA_FILE, encoding='utf-8') as f:
        all_items = json.load(f)
    random.shuffle(all_items)
    split_n = int(0.8 * len(all_items))
    train_items = all_items[:split_n]
    dev_items = all_items[split_n:]
    print(f'[BL] {len(all_items)} items Train={len(train_items)} Dev={len(dev_items)}')

    all_results = {}
    for model_name in args.models:
        hf_name = MODELS[model_name]
        best_em, best_by_type, history = train_baseline(
            model_name, hf_name, train_items, dev_items, CSV_DIR, device,
            num_epochs=args.epochs, lr=args.lr,
        )
        all_results[model_name] = {
            'hf_name': hf_name, 'best_em': best_em,
            'by_type': best_by_type, 'history': history,
        }

    out = os.path.join(DATA_DIR, 'baseline_results.json')
    with open(out, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)
    print(f'\n[BL] Results saved to {out}')

    print('\n' + '=' * 58)
    print(f'{"Model":<18} {"EM":>6} {"Agg":>5} {"Cmp":>5} {"Lkp":>5}')
    print('-' * 58)
    for name, r in all_results.items():
        bt = r['by_type']
        print(f'{name:<18} {r["best_em"]:>6.3f} '
              f'{bt.get("aggregation",0):>5.3f} '
              f'{bt.get("comparison",0):>5.3f} '
              f'{bt.get("lookup",0):>5.3f}')
    # Reference row: TabuLM scores are hard-coded here, not computed by this script.
    print(f'{"TabuLM (ours)":<18} {"0.560":>6} '
          f'{"0.792":>5} {"0.417":>5} {"0.286":>5}')


if __name__ == '__main__':
    main()
|
|