| import torch |
| import torch.optim as optim |
| from transformers import AutoTokenizer |
| from tqdm import tqdm |
| import torch.nn.functional as F |
| import os |
| import argparse |
| import sacrebleu |
|
|
| from src.config import ModelConfig, TrainConfig |
| from src.models.autoencoder import ReshapedAutoencoder |
| from src.models.dit import PatchedFlowDiT |
| from src.trainer import Trainer |
| from src.utils.data_utils import prepare_data |
|
|
| |
def _pick_stop_id(tokenizer):
    """Return the token id treated as "end of sequence" for generation.

    Uses ``tokenizer.eos_token_id`` when it is set; otherwise falls back to
    ``tokenizer.sep_token_id``. The result is ``None`` if neither is defined.
    """
    eos = tokenizer.eos_token_id
    if eos is None:
        return tokenizer.sep_token_id
    return eos
|
|
def _first_pos(x_1d, token_id, default):
    """Return the index of the first element of ``x_1d`` equal to ``token_id``.

    ``default`` is returned unchanged when ``token_id`` does not occur.
    """
    hits = torch.nonzero(x_1d == token_id, as_tuple=True)[0]
    if hits.numel() == 0:
        return default
    return int(hits[0])
|
|
def calculate_metrics(sources, predictions, references):
    """Compute corpus-level SARI, BLEU and the mean character compression ratio.

    Args:
        sources: Input sentences, one string per example.
        predictions: Generated sentences, aligned one-to-one with ``sources``.
        references: One reference sentence per example, aligned with ``sources``.

    Returns:
        dict with keys ``"SARI"``, ``"BLEU"`` and ``"Compression Ratio"``.
        ``"SARI"`` is NaN (and a ``RuntimeWarning`` is emitted) when it cannot be
        computed, so a missing metric is never mistaken for a real score of 0.
        ``"Compression Ratio"`` is the mean of ``len(prediction) / len(source)``
        in characters over examples whose source is non-empty (0.0 if none).
    """
    import warnings

    bleu = sacrebleu.corpus_bleu(predictions, [references])

    # ``corpus_sari`` is not part of sacrebleu's documented API, so look it up
    # explicitly instead of letting an AttributeError be swallowed and
    # reported as a genuine SARI of 0.0.
    sari_score = float("nan")
    corpus_sari = getattr(sacrebleu, "corpus_sari", None)
    if corpus_sari is None:
        warnings.warn(
            "sacrebleu.corpus_sari is not available in the installed sacrebleu; "
            "SARI reported as NaN.",
            RuntimeWarning,
            stacklevel=2,
        )
    else:
        try:
            sari_score = corpus_sari(sources, predictions, [references]).score
        except Exception as err:  # best effort: keep BLEU / ratio even if SARI fails
            warnings.warn(
                f"SARI computation failed ({err!r}); reported as NaN.",
                RuntimeWarning,
                stacklevel=2,
            )

    # An empty source has no defined ratio; skip it rather than averaging in 0.
    ratios = [len(p) / len(s) for p, s in zip(predictions, sources) if len(s) > 0]
    avg_ratio = sum(ratios) / len(ratios) if ratios else 0.0

    return {"SARI": sari_score, "BLEU": bleu.score, "Compression Ratio": avg_ratio}
|
|
def _tsv_field(text):
    """Replace line breaks and tabs in ``text`` with spaces so it fits in one TSV cell."""
    return text.replace("\r", " ").replace("\n", " ").replace("\t", " ")


def _sample_target_latent(flow, z_src, steps, use_oneshot):
    """Map source latents ``z_src`` to predicted target latents using ``flow``.

    ``flow(z, t, condition=z_cond)`` is used as a predictor of the t=1 latent:
    the one-shot path takes its output at t=0 as the final answer. The
    multi-step path integrates the implied velocity
    ``v = (z1_hat - z_t) / (1 - t)`` with ``steps`` explicit Euler steps from
    t=0 to t=1, always conditioning on the original source latent.

    Raises:
        ValueError: if ``use_oneshot`` is False and ``steps < 1``.
    """
    batch_size = z_src.shape[0]
    device = z_src.device
    z_cond = z_src.clone()

    if use_oneshot:
        t0 = torch.zeros(batch_size, device=device)
        return flow(z_src, t0, condition=z_cond).float()

    if steps < 1:
        raise ValueError(f"steps must be >= 1 for multi-step sampling, got {steps}")

    z_t = z_src
    dt = 1.0 / steps
    for i in range(steps):
        t_val = i / steps
        t = torch.full((batch_size,), t_val, device=device)
        pred_z1 = flow(z_t, t, condition=z_cond).float()
        # t_val <= 1 - dt, so the divisor is at least dt > 0. On the final step
        # dt / (1 - t_val) == 1, so z_t lands on that step's prediction.
        v = (pred_z1 - z_t) / (1.0 - t_val)
        z_t = z_t + v * dt
    return z_t


@torch.no_grad()
def inference_batch(ae, flow, loader, tokenizer, device, steps=10, save_path="results.txt", use_oneshot=True):
    """Generate an output for every example in ``loader``, write a TSV, and return the texts.

    Each source is encoded with ``ae``, mapped to a target latent with ``flow``
    (one-shot, or ``steps``-step Euler sampling), and decoded twice. The first
    pass, with a full attention mask, locates the first stop token per row. The
    second pass attends only to positions up to and including that token and
    produces the final text.

    Args:
        ae: Autoencoder exposing ``encode(ids, mask)`` and ``decode(z, attention_mask=...)``.
        flow: Flow model called as ``flow(z, t, condition=z_src)``.
        loader: Yields dicts with ``src_ids``, ``src_mask`` and ``tgt_ids`` tensors.
        tokenizer: HF tokenizer used for decoding and to pick the stop token.
        device: Device the models live on.
        steps: Number of Euler steps when ``use_oneshot`` is False (must be >= 1).
        save_path: Output file; a tab-separated header Source/Target/Generated,
            then one row per example.
        use_oneshot: Predict the target latent with a single flow call.

    Returns:
        ``(sources, targets, generated)``: lists of single-line strings in loader
        order, identical to the TSV columns (line breaks and tabs replaced by spaces).

    Raises:
        ValueError: if ``use_oneshot`` is False and ``steps < 1``.
    """
    ae.eval()
    flow.eval()
    stop_id = _pick_stop_id(tokenizer)

    print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")

    all_sources, all_targets, all_generated = [], [], []

    with open(save_path, "w", encoding="utf-8") as f:
        f.write("Source\tTarget\tGenerated\n")

        for batch in tqdm(loader, desc="Inferencing"):
            src_ids = batch['src_ids'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_ids = batch['tgt_ids']  # only decoded to text; no device transfer needed
            B, L = src_ids.shape

            z_src = ae.encode(src_ids, src_mask)
            z_gen = _sample_target_latent(flow, z_src, steps, use_oneshot)

            # Pass 1: decode all L positions to find where generation stops.
            full_mask = torch.ones(B, L, device=device)
            ids1 = ae.decode(z_gen, attention_mask=full_mask).argmax(dim=-1)
            if stop_id is None:
                # Tokenizer has neither EOS nor SEP: keep the full length.
                stop_pos = [L - 1] * B
            else:
                stop_pos = [_first_pos(ids1[i], stop_id, default=L - 1) for i in range(B)]

            # Pass 2: re-decode attending only to positions <= the stop token.
            positions = torch.arange(L, device=device)
            stop_tensor = torch.tensor(stop_pos, device=device)
            gen_mask = (positions.unsqueeze(0) <= stop_tensor.unsqueeze(1)).float()
            ids2 = ae.decode(z_gen, attention_mask=gen_mask).argmax(dim=-1)

            src_texts = tokenizer.batch_decode(src_ids, skip_special_tokens=True)
            tgt_texts = tokenizer.batch_decode(tgt_ids, skip_special_tokens=True)
            gen_texts = tokenizer.batch_decode(
                [ids2[i, : stop_pos[i] + 1].tolist() for i in range(B)],
                skip_special_tokens=True,
            )

            for s, t, g in zip(src_texts, tgt_texts, gen_texts):
                s_c, t_c, g_c = _tsv_field(s), _tsv_field(t), _tsv_field(g)
                f.write(f"{s_c}\t{t_c}\t{g_c}\n")
                all_sources.append(s_c)
                all_targets.append(t_c)
                all_generated.append(g_c)

    return all_sources, all_targets, all_generated
|
|
|
|
def _parse_args():
    """Parse the command-line options of the flow training/evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ae_ckpt", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt", help="Path to pre-trained AE checkpoint")
    parser.add_argument("--save_dir", type=str, default="residual_robust_checkpoints", help="Directory to save flow checkpoints")
    # `store_true` with default=True alone can never be turned off, which left the
    # multi-step sampler unreachable; --no_oneshot writes False to the same dest.
    parser.add_argument("--use_oneshot", dest="use_oneshot", action="store_true", default=True, help="Use one-shot sampling for inference")
    parser.add_argument("--no_oneshot", dest="use_oneshot", action="store_false", help="Use multi-step Euler sampling for inference instead of one-shot")
    parser.add_argument("--steps", type=int, default=10, help="Number of Euler steps used with --no_oneshot")
    parser.add_argument("--results_path", type=str, default="wiki_results.tsv", help="Path of the TSV file with inference results")
    return parser.parse_args()


def _load_frozen_autoencoder(ckpt_path, m_cfg, device, pad_token_id):
    """Build a ReshapedAutoencoder, load ``ckpt_path`` into it, and freeze it (eval mode, no grads)."""
    ae = ReshapedAutoencoder(m_cfg).to(device).float()
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints
    # (on torch >= 1.13, weights_only=True restricts loading to tensors).
    ae.load_state_dict(torch.load(ckpt_path, map_location=device))
    ae.eval()
    for param in ae.parameters():
        param.requires_grad = False
    # The encoder config may lack a pad id; fall back to the tokenizer's.
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = pad_token_id
    return ae


def _train_flow(trainer, flow, opt_flow, num_epochs, save_dir):
    """Run ``num_epochs`` flow epochs, saving best and last weights under ``save_dir``.

    Returns the lowest value returned by ``trainer.train_flow``. Returns ``inf``
    when no epoch improved on ``inf`` (zero epochs, or only NaN losses); in that
    case this run did not write ``flow_best.pt``.
    """
    best_flow_loss = float('inf')
    best_path = os.path.join(save_dir, "flow_best.pt")
    last_path = os.path.join(save_dir, "flow_last.pt")
    for epoch in range(num_epochs):
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")
        # NOTE(review): selection uses the value returned by train_flow
        # (presumably the training loss), not a validation metric.
        if loss < best_flow_loss:
            best_flow_loss = loss
            torch.save(flow.state_dict(), best_path)
        torch.save(flow.state_dict(), last_path)
    return best_flow_loss


def main():
    """Train a PatchedFlowDiT on top of a frozen pre-trained autoencoder, then evaluate it.

    The script runs these steps in order:

    1. Parse the CLI options.
    2. Load the wiki train/test loaders.
    3. Load and freeze the autoencoder from ``--ae_ckpt``.
    4. Train the flow for ``num_epochs_flow`` epochs, saving ``flow_best.pt`` and
       ``flow_last.pt`` under ``--save_dir``.
    5. Reload the best weights written by this run.
    6. Run inference on the test split and print SARI / BLEU / compression ratio.

    Raises:
        FileNotFoundError: if ``--ae_ckpt`` does not exist.
    """
    args = _parse_args()

    # Fail fast, before the slow tokenizer and dataset loading.
    if not os.path.exists(args.ae_ckpt):
        raise FileNotFoundError(f"AE checkpoint not found at {args.ae_ckpt}. Please run train_ae.py first.")

    os.makedirs(args.save_dir, exist_ok=True)

    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128,
    )
    t_cfg = TrainConfig(
        batch_size=16,
        num_epochs_flow=35,
        grad_accum_steps=4,
        use_amp=False,
        lr_flow=2e-4,
    )

    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
    test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")

    print(f"\n>>> Loading Pre-trained Autoencoder from {args.ae_ckpt} ...")
    ae = _load_frozen_autoencoder(args.ae_ckpt, m_cfg, t_cfg.device, tokenizer.pad_token_id)
    print(">>> Autoencoder loaded and frozen.")

    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()
    trainer = Trainer(
        ae=ae,
        flow=flow,
        cfg=t_cfg,
        loader=train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer),
    )
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)

    print("\n>>> Start Training Flow DiT...")
    best_flow_loss = _train_flow(trainer, flow, opt_flow, t_cfg.num_epochs_flow, args.save_dir)
    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")

    print("\n>>> Loading Best Flow Checkpoint for Evaluation...")
    best_flow_path = os.path.join(args.save_dir, "flow_best.pt")
    # Only trust flow_best.pt if this run wrote it; otherwise a stale file left
    # in the same save_dir by an earlier run would silently be evaluated.
    if best_flow_loss < float('inf') and os.path.exists(best_flow_path):
        flow.load_state_dict(torch.load(best_flow_path, map_location=t_cfg.device))
    else:
        print("Warning: Best checkpoint not found, utilizing last epoch weights.")

    print("\n--- Starting Inference ---")
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=args.steps,
        save_path=args.results_path,
        use_oneshot=args.use_oneshot,
    )

    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    print(f"\nResults saved to {args.results_path}")
|
|
# Script entry point: run training and evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()