"""
Evaluate a trained autoencoder: reconstruction quality and latent cycle consistency.

    z0 = encoder(x)     # encode ground-truth tokens
    x1 = decoder(z0)    # reconstruct tokens from the latent
    z1 = encoder(x1)    # re-encode the reconstruction; compare z1 to z0
"""
|
|
| import argparse |
| import torch |
| import torch.nn.functional as F |
| from tqdm import tqdm |
| from transformers import AutoTokenizer |
|
|
| from src.config import ModelConfig, TrainConfig |
| from src.models.autoencoder import ReshapedAutoencoder |
| from src.utils.data_utils import prepare_data |
|
|
def pick_stop_id(tokenizer):
    """Return the tokenizer's EOS token id, falling back to SEP when EOS is unset.

    Returns None only if the tokenizer defines neither special token.
    """
    eos = tokenizer.eos_token_id
    if eos is not None:
        return eos
    return tokenizer.sep_token_id
|
|
def masked_mean(x, mask, eps=1e-6):
    """Average `x` over the positions selected by `mask`.

    Args:
        x: tensor of values.
        mask: tensor broadcastable against `x`; nonzero entries are kept.
        eps: floor on the denominator so an all-zero mask cannot divide by zero.

    Returns:
        Scalar tensor equal to sum(x * mask) / max(sum(mask), eps).
    """
    selected_sum = (x * mask).sum()
    denominator = torch.clamp(mask.sum(), min=eps)
    return selected_sum / denominator
|
|
@torch.no_grad()
def main():
    """Evaluate an autoencoder checkpoint on token reconstruction.

    Per batch: z0 = encode(ids), logits = decode(z0), pred = argmax(logits),
    then z1 = encode(pred) to measure latent cycle consistency (cosine / L2
    between z0 and z1). Also reports masked per-token cross-entropy, masked
    token accuracy, and rough EOS-placement diagnostics, printing a few
    qualitative ground-truth/reconstruction pairs along the way.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", type=str, default="wiki")
    ap.add_argument("--split", type=str, default="test")
    ap.add_argument("--max_seq_len", type=int, default=128)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--ckpt", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt", help="path to ae.state_dict()")
    ap.add_argument("--max_batches", type=int, default=0, help="0 means full eval")
    ap.add_argument("--print_n", type=int, default=8)
    args = ap.parse_args()

    # Model/training configs. The encoder name is a local path, consistent
    # with local_files_only=True below.
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=args.max_seq_len,
    )
    t_cfg = TrainConfig(batch_size=args.batch_size)

    device = t_cfg.device
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name,local_files_only=True,trust_remote_code=False)
    stop_id = pick_stop_id(tokenizer)  # EOS id, or SEP when the tokenizer has no EOS

    loader = prepare_data(args.dataset, tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split=args.split)

    ae = ReshapedAutoencoder(m_cfg).to(device).float()
    if args.ckpt:
        # strict=True: fail loudly on any key mismatch with the checkpoint.
        sd = torch.load(args.ckpt, map_location="cpu")
        ae.load_state_dict(sd, strict=True)
    ae.eval()

    # Masked CE / accuracy accumulators, normalized by real-token count.
    total_ce = 0.0
    total_acc = 0.0
    total_tokens = 0.0

    # EOS diagnostics: how often the prediction emits EOS at all, and how far
    # its first EOS lands from the ground truth's first EOS.
    eos_found = 0
    eos_pos_err = 0.0
    eos_count = 0

    # Latent cycle-consistency accumulators.
    total_cos = 0.0
    total_l2 = 0.0
    total_lat_tokens = 0.0

    printed = 0  # number of qualitative examples printed so far

    for bi, batch in enumerate(tqdm(loader, desc="Eval AE")):
        # max_batches == 0 means evaluate the full split.
        if args.max_batches and bi >= args.max_batches:
            break

        ids = batch["tgt_ids"].to(device)    # token ids, assumed (B, L) — TODO confirm against prepare_data
        mask = batch["tgt_mask"].to(device)  # presumably 1 = real token, 0 = padding; verify in data_utils

        # z0 = encoder(x); x1-logits = decoder(z0).
        # NOTE(review): decode() receives the ground-truth attention mask, so
        # decoding is conditioned on the true lengths — confirm this is intended.
        z0 = ae.encode(ids, mask)
        logits = ae.decode(z0, attention_mask=mask)
        pred = logits.argmax(dim=-1)

        # Summed CE over real tokens only; padded positions get label -100
        # and are skipped via ignore_index.
        labels = ids.masked_fill(mask == 0, -100)
        ce = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100, reduction="sum")
        total_ce += ce.item()

        # Exact-match token accuracy on unmasked positions.
        correct = ((pred == ids) & (mask.bool())).sum().item()
        tok = mask.sum().item()
        total_acc += correct
        total_tokens += tok

        # Per-sequence EOS diagnostics, restricted to the valid (unpadded) span.
        B, L = ids.shape
        for i in range(B):
            valid_len = int(mask[i].sum().item())
            true_seq = ids[i, :valid_len]
            pred_seq = pred[i, :valid_len]

            # Positions of the stop token in truth vs. prediction.
            true_pos = (true_seq == stop_id).nonzero(as_tuple=True)[0]
            pred_pos = (pred_seq == stop_id).nonzero(as_tuple=True)[0]

            if pred_pos.numel() > 0:
                eos_found += 1
            if true_pos.numel() > 0:
                eos_count += 1
                tpos = int(true_pos[0].item())
                # If the prediction never emits EOS, score it as if EOS sat at
                # the last valid position (pessimistic fallback).
                ppos = int(pred_pos[0].item()) if pred_pos.numel() > 0 else valid_len - 1
                eos_pos_err += abs(ppos - tpos)

        # Cycle consistency: re-encode the argmax reconstruction and compare
        # latents. NOTE(review): assumes z0/z1 have a per-token axis so that
        # cos/l2 broadcast against `mask` below — TODO confirm encoder output shape.
        z1 = ae.encode(pred, mask)
        cos = F.cosine_similarity(z0, z1, dim=-1)
        l2 = (z0 - z1).pow(2).mean(dim=-1)
        total_cos += (cos * mask).sum().item()
        total_l2 += (l2 * mask).sum().item()
        total_lat_tokens += mask.sum().item()

        # Print up to --print_n qualitative pairs (first sequence of the batch).
        if printed < args.print_n:
            s = tokenizer.decode(ids[0], skip_special_tokens=True)
            g = tokenizer.decode(pred[0], skip_special_tokens=True)
            print("\n--- Example ---")
            print("GT :", s)
            print("REC:", g)
            printed += 1

    # Normalize accumulators; max(..., 1.0) guards against an empty loader.
    avg_ce = total_ce / max(total_tokens, 1.0)
    avg_acc = total_acc / max(total_tokens, 1.0)
    avg_cos = total_cos / max(total_lat_tokens, 1.0)
    avg_l2 = total_l2 / max(total_lat_tokens, 1.0)

    # "Rough" by design: approximates the number of sequences as
    # total_tokens / max_seq_len, which overestimates the rate when
    # sequences are shorter than max_seq_len.
    eos_found_rate = eos_found / max(total_tokens / args.max_seq_len, 1.0)
    eos_mae = eos_pos_err / max(eos_count, 1)

    print("\n===== AE Metrics =====")
    print(f"Masked CE per-token: {avg_ce:.4f}")
    print(f"Token Acc (masked): {avg_acc:.4f}")
    print(f"Latent cycle cosine(z0,z1): {avg_cos:.4f}")
    print(f"Latent cycle l2(z0,z1): {avg_l2:.6f}")
    print(f"EOS found rate (rough): {eos_found_rate:.4f}")
    print(f"EOS position MAE (only where GT has EOS): {eos_mae:.2f}")
|
|
# Script entry point: run evaluation when executed directly (not on import).
if __name__ == "__main__":
    main()
|
|