# scripts/eval_ae_consistency.py
"""
z0 = encoder(x)
x^1 = decoder(z0)
z1 = encoder(x^1)
"""
import argparse
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.utils.data_utils import prepare_data
def pick_stop_id(tokenizer):
    # Prefer EOS; fall back to SEP for tokenizers that define no EOS token.
    return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
def masked_mean(x, mask, eps=1e-6):
    # x: [B,L] per-token values (reduce over D first if needed), mask: [B,L]
    denom = mask.sum().clamp(min=eps)
    return (x * mask).sum() / denom
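
# Quick sanity check for masked_mean (illustrative values):
#   x = torch.tensor([[1., 2., 3.]]); mask = torch.tensor([[1., 1., 0.]])
#   masked_mean(x, mask) -> (1 + 2) / 2 = 1.5; the padded position is ignored.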
@torch.no_grad()
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", type=str, default="wiki")
    ap.add_argument("--split", type=str, default="test")
    ap.add_argument("--max_seq_len", type=int, default=128)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--ckpt", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt", help="path to ae.state_dict()")
    ap.add_argument("--max_batches", type=int, default=0, help="0 means full eval")
    ap.add_argument("--print_n", type=int, default=8)
    args = ap.parse_args()
    # configs
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=args.max_seq_len,
    )
    t_cfg = TrainConfig(batch_size=args.batch_size)
    device = t_cfg.device
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    stop_id = pick_stop_id(tokenizer)
    loader = prepare_data(args.dataset, tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split=args.split)
    ae = ReshapedAutoencoder(m_cfg).to(device).float()
    if args.ckpt:
        sd = torch.load(args.ckpt, map_location="cpu")
        ae.load_state_dict(sd, strict=True)
    ae.eval()
    total_ce = 0.0
    total_acc = 0.0
    total_tokens = 0.0
    total_seqs = 0  # number of sequences seen (denominator for the EOS found rate)
    eos_found = 0
    eos_pos_err = 0.0
    eos_count = 0
    total_cos = 0.0
    total_l2 = 0.0
    total_lat_tokens = 0.0
    printed = 0
    for bi, batch in enumerate(tqdm(loader, desc="Eval AE")):
        if args.max_batches and bi >= args.max_batches:
            break
        ids = batch["tgt_ids"].to(device)
        mask = batch["tgt_mask"].to(device)
        # --- forward ---
        z0 = ae.encode(ids, mask)                    # [B,L,D]
        logits = ae.decode(z0, attention_mask=mask)  # [B,L,V]
        pred = logits.argmax(dim=-1)                 # [B,L]
        # --- masked CE ---
        labels = ids.masked_fill(mask == 0, -100)
        ce = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100, reduction="sum")
        total_ce += ce.item()
        # --- token acc (masked) ---
        correct = ((pred == ids) & mask.bool()).sum().item()
        tok = mask.sum().item()
        total_acc += correct
        total_tokens += tok
        # --- EOS stats ---
        # true/pred EOS position (first occurrence)
        B, L = ids.shape
        total_seqs += B
        for i in range(B):
            # only search within valid (non-padding) tokens
            valid_len = int(mask[i].sum().item())
            true_seq = ids[i, :valid_len]
            pred_seq = pred[i, :valid_len]
            true_pos = (true_seq == stop_id).nonzero(as_tuple=True)[0]
            pred_pos = (pred_seq == stop_id).nonzero(as_tuple=True)[0]
            if pred_pos.numel() > 0:
                eos_found += 1
            if true_pos.numel() > 0:
                eos_count += 1
                tpos = int(true_pos[0].item())
                # if the prediction never emits a stop token, score it at the last valid position
                ppos = int(pred_pos[0].item()) if pred_pos.numel() > 0 else valid_len - 1
                eos_pos_err += abs(ppos - tpos)
        # --- latent cycle: z0 -> tokens -> z1 ---
        z1 = ae.encode(pred, mask)
        cos = F.cosine_similarity(z0, z1, dim=-1)  # [B,L]
        l2 = (z0 - z1).pow(2).mean(dim=-1)         # [B,L]
        total_cos += (cos * mask).sum().item()
        total_l2 += (l2 * mask).sum().item()
        total_lat_tokens += mask.sum().item()
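        # Note (a sketch, not wired in): masked_mean(cos, mask) would give the
        # masked average for this batch in one call; global token-weighted sums
        # are accumulated here instead so the final average spans the whole eval set.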
        # --- print a few examples ---
        if printed < args.print_n:
            s = tokenizer.decode(ids[0], skip_special_tokens=True)
            ## Note: the reconstruction is not truncated at the predicted EOS here.
            ## An alternative would cut at the first stop (eos/sep) token:
            # valid_len = int(mask[0].sum().item())
            # pred_seq = pred[0, :valid_len]
            # end = _first_pos(pred_seq, stop_id, default=valid_len - 1) + 1  # _first_pos: helper not defined in this script
            # g = tokenizer.decode(pred_seq[:end], skip_special_tokens=True)
            g = tokenizer.decode(pred[0], skip_special_tokens=True)
            print("\n--- Example ---")
            print("GT :", s)
            print("REC:", g)
            printed += 1
    avg_ce = total_ce / max(total_tokens, 1.0)
    avg_acc = total_acc / max(total_tokens, 1.0)
    avg_cos = total_cos / max(total_lat_tokens, 1.0)
    avg_l2 = total_l2 / max(total_lat_tokens, 1.0)
    # fraction of sequences whose reconstruction contains a stop token
    eos_found_rate = eos_found / max(total_seqs, 1)
    eos_mae = eos_pos_err / max(eos_count, 1)
    print("\n===== AE Metrics =====")
    print(f"Masked CE per-token: {avg_ce:.4f}")
    print(f"Token Acc (masked): {avg_acc:.4f}")
    print(f"Latent cycle cosine(z0,z1): {avg_cos:.4f}")
    print(f"Latent cycle l2(z0,z1): {avg_l2:.6f}")
    print(f"EOS found rate: {eos_found_rate:.4f}")
    print(f"EOS position MAE (only where GT has EOS): {eos_mae:.2f}")
if __name__ == "__main__":
    main()