File size: 5,746 Bytes
# scripts/eval_ae_consistency.py
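"""
Evaluate autoencoder reconstruction quality and latent cycle consistency.

The cycle under test is:

    z0  = encoder(x)
    x^1 = decoder(z0)      (argmax over the decoder logits)
    z1  = encoder(x^1)

Reported metrics: masked per-token cross-entropy, masked token accuracy,
stop-token (EOS/SEP) statistics, and the cosine similarity / squared
distance between z0 and z1.
"""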
import argparse
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.utils.data_utils import prepare_data
def pick_stop_id(tokenizer):
    """Return the token id used to mark the end of a sequence.

    The tokenizer's ``eos_token_id`` is preferred; when it is ``None`` the
    ``sep_token_id`` is returned instead (which may itself be ``None``).
    """
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        return tokenizer.sep_token_id
    return eos_id
def masked_mean(x, mask, eps=1e-6):
    """Average the entries of ``x`` selected by ``mask``.

    ``x`` is multiplied element-wise by ``mask`` and summed; the total is
    divided by the mask sum, which is clamped to at least ``eps`` so that an
    all-zero mask does not divide by zero.

    Per the original note, ``x`` is expected to be already reduced to the
    mask's shape ([B, L]) -- TODO confirm against callers.
    """
    masked_total = torch.sum(x * mask)
    valid_count = torch.clamp(torch.sum(mask), min=eps)
    return masked_total / valid_count
@torch.no_grad()
def main():
    """Evaluate a trained autoencoder on reconstruction and cycle consistency.

    Parses command-line options, builds the tokenizer, data loader and model,
    optionally loads a checkpoint, then iterates over the evaluation split and
    accumulates:

    * masked cross-entropy and token accuracy of ``decode(encode(x))``,
    * stop-token statistics (how often a stop token is predicted inside the
      valid region, and the position error where the ground truth has one),
    * latent cycle agreement between ``z0 = encode(x)`` and
      ``z1 = encode(argmax(decode(z0)))``.

    A few ground-truth / reconstruction pairs and the final metrics are
    printed to stdout.

    Raises:
        ValueError: if the tokenizer defines neither an EOS nor a SEP token id.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", type=str, default="wiki")
    ap.add_argument("--split", type=str, default="test")
    ap.add_argument("--max_seq_len", type=int, default=128)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--ckpt", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt", help="path to ae.state_dict()")
    ap.add_argument("--max_batches", type=int, default=0, help="0 means full eval")
    ap.add_argument("--print_n", type=int, default=8)
    args = ap.parse_args()

    # configs
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=args.max_seq_len,
    )
    t_cfg = TrainConfig(batch_size=args.batch_size)
    device = t_cfg.device

    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    stop_id = pick_stop_id(tokenizer)
    if stop_id is None:
        # Without a stop id the EOS statistics below cannot be computed; fail
        # here with a clear message instead of deep inside the batch loop.
        raise ValueError("tokenizer defines neither eos_token_id nor sep_token_id")

    loader = prepare_data(args.dataset, tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split=args.split)

    ae = ReshapedAutoencoder(m_cfg).to(device).float()
    if args.ckpt:
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from a trusted source (consider weights_only=True if the installed
        # torch version supports it).
        sd = torch.load(args.ckpt, map_location="cpu")
        ae.load_state_dict(sd, strict=True)
    ae.eval()

    total_ce = 0.0
    total_acc = 0.0
    total_tokens = 0.0
    total_seqs = 0  # number of sequences seen; denominator for the EOS found rate

    eos_found = 0
    eos_pos_err = 0.0
    eos_count = 0

    total_cos = 0.0
    total_l2 = 0.0
    total_lat_tokens = 0.0

    printed = 0

    for bi, batch in enumerate(tqdm(loader, desc="Eval AE")):
        if args.max_batches and bi >= args.max_batches:
            break

        ids = batch["tgt_ids"].to(device)
        mask = batch["tgt_mask"].to(device)

        # --- forward ---
        z0 = ae.encode(ids, mask)                    # [B,L,D]
        logits = ae.decode(z0, attention_mask=mask)  # [B,L,V]
        pred = logits.argmax(dim=-1)                 # [B,L]

        # --- masked CE ---
        # Positions with mask == 0 get label -100 so cross_entropy ignores them.
        labels = ids.masked_fill(mask == 0, -100)
        ce = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        total_ce += ce.item()

        # --- token acc (masked) ---
        correct = ((pred == ids) & (mask.bool())).sum().item()
        tok = mask.sum().item()
        total_acc += correct
        total_tokens += tok

        # --- EOS stats ---
        # true/pred EOS position (first occurrence), searched only within the
        # first `valid_len` positions of each sequence.
        total_seqs += ids.size(0)
        valid_lens = mask.sum(dim=1).tolist()
        for i, valid_len in enumerate(valid_lens):
            valid_len = int(valid_len)
            true_seq = ids[i, :valid_len]
            pred_seq = pred[i, :valid_len]

            true_pos = (true_seq == stop_id).nonzero(as_tuple=True)[0]
            pred_pos = (pred_seq == stop_id).nonzero(as_tuple=True)[0]

            if pred_pos.numel() > 0:
                eos_found += 1

            if true_pos.numel() > 0:
                eos_count += 1
                tpos = int(true_pos[0].item())
                # If no stop token was predicted, fall back to the last valid position.
                ppos = int(pred_pos[0].item()) if pred_pos.numel() > 0 else valid_len - 1
                eos_pos_err += abs(ppos - tpos)

        # --- latent cycle: z0 -> token -> z1 ---
        z1 = ae.encode(pred, mask)
        cos = F.cosine_similarity(z0, z1, dim=-1)  # [B,L]
        l2 = (z0 - z1).pow(2).mean(dim=-1)         # [B,L]
        total_cos += (cos * mask).sum().item()
        total_l2 += (l2 * mask).sum().item()
        total_lat_tokens += tok

        # --- print a few examples ---
        if printed < args.print_n:
            s = tokenizer.decode(ids[0], skip_special_tokens=True)
            # NOTE: the prediction is decoded over the full length; it is not
            # truncated at the valid length or at the first stop token.
            g = tokenizer.decode(pred[0], skip_special_tokens=True)
            print("\n--- Example ---")
            print("GT :", s)
            print("REC:", g)
            printed += 1

    avg_ce = total_ce / max(total_tokens, 1.0)
    avg_acc = total_acc / max(total_tokens, 1.0)
    avg_cos = total_cos / max(total_lat_tokens, 1.0)
    avg_l2 = total_l2 / max(total_lat_tokens, 1.0)

    # Fraction of sequences whose prediction contains a stop token, computed
    # from the real sequence count rather than estimated from token totals.
    # The printed label is kept unchanged for log compatibility.
    eos_found_rate = eos_found / max(total_seqs, 1)
    eos_mae = eos_pos_err / max(eos_count, 1)

    print("\n===== AE Metrics =====")
    print(f"Masked CE per-token: {avg_ce:.4f}")
    print(f"Token Acc (masked): {avg_acc:.4f}")
    print(f"Latent cycle cosine(z0,z1): {avg_cos:.4f}")
    print(f"Latent cycle l2(z0,z1): {avg_l2:.6f}")
    print(f"EOS found rate (rough): {eos_found_rate:.4f}")
    print(f"EOS position MAE (only where GT has EOS): {eos_mae:.2f}")
# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|