# Diff-Refine / train_flow.py
# Uploaded by 2ira using the upload-large-folder tool (commit 77d636f, verified).
import torch
import torch.optim as optim
from transformers import AutoTokenizer
from tqdm import tqdm
import torch.nn.functional as F
import os
import argparse
import sacrebleu
from src.config import ModelConfig, TrainConfig
from src.models.autoencoder import ReshapedAutoencoder
from src.models.dit import PatchedFlowDiT
from src.trainer import Trainer
from src.utils.data_utils import prepare_data
# --- Helper Functions for Inference (copied here so this script runs standalone) ---
def _pick_stop_id(tokenizer):
return tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
def _first_pos(x_1d, token_id, default):
idx = (x_1d == token_id).nonzero(as_tuple=True)[0]
return idx[0].item() if idx.numel() > 0 else default
def calculate_metrics(sources, predictions, references):
    """Score generated sentences against references.

    Returns a dict with corpus BLEU, corpus SARI (0.0 when the installed
    sacrebleu build does not provide SARI or scoring fails), and the mean
    character-length ratio of prediction to source.
    """
    bleu = sacrebleu.corpus_bleu(predictions, [references])
    try:
        sari_score = sacrebleu.corpus_sari(sources, predictions, [references]).score
    except Exception:
        sari_score = 0.0  # SARI unavailable in this sacrebleu build
    # Mean |prediction| / |source| in characters; empty sources count as 0.
    total, count = 0.0, 0
    for pred, src in zip(predictions, sources):
        total += len(pred) / len(src) if len(src) > 0 else 0
        count += 1
    avg_ratio = total / count if count else 0
    return {"SARI": sari_score, "BLEU": bleu.score, "Compression Ratio": avg_ratio}
@torch.no_grad()
def inference_batch(ae, flow, loader, tokenizer, device, steps=10, save_path="results.txt", use_oneshot=True):
    """Generate text for every batch in `loader` and write a TSV of triples.

    Per batch: encode the source with the frozen AE, sample a target latent
    with the flow model (either a single one-shot prediction at t=0 or Euler
    integration over `steps` steps), then decode twice — pass 1 with a full
    attention mask only to locate each row's stop token, pass 2 with a mask
    truncated at that stop position for a clean decode.

    Args:
        ae: autoencoder exposing .encode(ids, mask) and .decode(z, attention_mask=...).
        flow: flow model called as flow(z, t, condition=...).
        loader: DataLoader yielding dicts with 'src_ids', 'src_mask', 'tgt_ids'.
        tokenizer: HF tokenizer used to decode ids back to text.
        device: torch device for inference.
        steps: number of Euler steps when use_oneshot is False.
        save_path: output TSV path ("Source\\tTarget\\tGenerated" header).
        use_oneshot: if True, one flow call at t=0 predicts the target latent directly.

    Returns:
        (all_sources, all_targets, all_generated): parallel lists of
        newline-stripped strings, for metric computation.
    """
    ae.eval()
    flow.eval()
    stop_id = _pick_stop_id(tokenizer)
    pad_id = tokenizer.pad_token_id
    print(f"\n>>> Running Inference on {len(loader.dataset)} examples...")
    all_sources, all_targets, all_generated = [], [], []
    # NOTE(review): the original read `scale = getattr(ae, "latent_scale", 10.0)`
    # here (a compatibility fallback) but never used it; the dead assignment
    # has been removed.
    with open(save_path, "w", encoding="utf-8") as f:
        f.write("Source\tTarget\tGenerated\n")
        for batch in tqdm(loader, desc="Inferencing"):
            src_ids = batch['src_ids'].to(device)
            src_mask = batch['src_mask'].to(device)
            tgt_ids = batch['tgt_ids'].to(device)
            B, L = src_ids.shape
            # Encode; the source latent doubles as the flow condition.
            z_curr = ae.encode(src_ids, src_mask)
            z_cond = z_curr.clone()
            # Flow sampling
            if use_oneshot:
                t0 = torch.zeros(B, device=device)
                z_curr = flow(z_curr, t0, condition=z_cond).float()
            else:
                dt = 1.0 / steps
                pred_z1 = z_curr  # guard: keeps pred_z1 bound even if the loop body never runs
                for i in range(steps):
                    t_val = i / steps
                    if t_val >= 0.999:
                        break
                    t = torch.ones(B, device=device) * t_val
                    pred_z1 = flow(z_curr, t, condition=z_cond).float()
                    # Velocity toward the predicted endpoint, scaled by remaining time.
                    v = (pred_z1 - z_curr) / (1.0 - t_val + 1e-4)
                    z_curr = z_curr + v * dt
                # NOTE(review): the integrated trajectory is discarded in favor of
                # the final endpoint prediction — confirm this is intentional.
                z_curr = pred_z1
            # Decode (Pass 1): full mask, used only to find each row's stop token.
            full_mask = torch.ones(B, L, device=device)
            logits1 = ae.decode(z_curr, attention_mask=full_mask)
            ids1 = logits1.argmax(dim=-1)
            stop_pos = []
            for i in range(B):
                pos = _first_pos(ids1[i], stop_id, default=L - 1)
                stop_pos.append(pos)
            # Decode (Pass 2): mask truncated at the stop position for a clean decode.
            gen_mask = torch.zeros(B, L, device=device)
            for i in range(B):
                gen_mask[i, : stop_pos[i] + 1] = 1.0
            logits2 = ae.decode(z_curr, attention_mask=gen_mask)
            ids2 = logits2.argmax(dim=-1)
            ids2 = ids2.masked_fill(gen_mask == 0, pad_id)
            # Convert ids back to text.
            src_texts = tokenizer.batch_decode(src_ids, skip_special_tokens=True)
            tgt_texts = tokenizer.batch_decode(tgt_ids, skip_special_tokens=True)
            gen_texts = []
            for i in range(B):
                end = stop_pos[i] + 1
                ids_cut = ids2[i, :end]
                gen_texts.append(tokenizer.decode(ids_cut, skip_special_tokens=True))
            # Newlines would break the TSV format, so flatten them to spaces.
            for s, t, g in zip(src_texts, tgt_texts, gen_texts):
                s_c = s.replace("\n", " ")
                t_c = t.replace("\n", " ")
                g_c = g.replace("\n", " ")
                f.write(f"{s_c}\t{t_c}\t{g_c}\n")
                all_sources.append(s_c)
                all_targets.append(t_c)
                all_generated.append(g_c)
    return all_sources, all_targets, all_generated
def main():
    """Train the flow DiT on top of a frozen pre-trained AE, then evaluate it.

    Stages: parse CLI args; build configs, tokenizer and dataloaders; load and
    freeze the AE; train the flow model while checkpointing the best epoch by
    training loss; finally run inference on the test split and print
    SARI / BLEU / compression metrics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ae_ckpt", type=str, default="/mnt/hdfs/user/lixinyu.222/CodeFlow/residual_robust_checkpoints/ae_best.pt", help="Path to pre-trained AE checkpoint")
    parser.add_argument("--save_dir", type=str, default="residual_robust_checkpoints", help="Directory to save flow checkpoints")
    # BUG FIX: the original `action="store_true", default=True` made the flag
    # permanently True (impossible to disable from the CLI). A paired
    # --no_oneshot flag makes it controllable while keeping --use_oneshot
    # and the default behavior unchanged.
    parser.add_argument("--use_oneshot", dest="use_oneshot", action="store_true", default=True, help="Use one-shot sampling for inference")
    parser.add_argument("--no_oneshot", dest="use_oneshot", action="store_false", help="Use multi-step Euler sampling for inference")
    args = parser.parse_args()
    os.makedirs(args.save_dir, exist_ok=True)
    # --- Config ---
    m_cfg = ModelConfig(
        encoder_name='../jina-embeddings-v2-base-code',
        latent_dim=512,
        max_seq_len=128
    )
    t_cfg = TrainConfig(
        batch_size=16,
        num_epochs_flow=35,  # epochs for the flow stage only
        grad_accum_steps=4,
        use_amp=False,
        lr_flow=2e-4
    )
    # --- Tokenizer & Data ---
    tokenizer = AutoTokenizer.from_pretrained(m_cfg.encoder_name, local_files_only=True, trust_remote_code=False)
    train_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="train")
    test_loader = prepare_data("wiki", tokenizer, m_cfg.max_seq_len, t_cfg.batch_size, split="test")
    # --- Load AE (Pre-trained) ---
    print(f"\n>>> Loading Pre-trained Autoencoder from {args.ae_ckpt} ...")
    ae = ReshapedAutoencoder(m_cfg).to(t_cfg.device).float()
    if not os.path.exists(args.ae_ckpt):
        raise FileNotFoundError(f"AE checkpoint not found at {args.ae_ckpt}. Please run train_ae.py first.")
    ae.load_state_dict(torch.load(args.ae_ckpt, map_location=t_cfg.device))
    # Freeze all AE parameters; flow training must not update the AE.
    ae.eval()
    for param in ae.parameters():
        param.requires_grad = False
    print(">>> Autoencoder loaded and frozen.")
    if ae.encoder.config.pad_token_id is None:
        ae.encoder.config.pad_token_id = tokenizer.pad_token_id
    # --- Initialize Flow ---
    flow = PatchedFlowDiT(m_cfg).to(t_cfg.device).float()
    # --- Trainer ---
    trainer = Trainer(
        ae=ae,
        flow=flow,
        cfg=t_cfg,
        loader=train_loader,
        pad_id=tokenizer.pad_token_id,
        stop_id=_pick_stop_id(tokenizer)
    )
    # --- Optimizer ---
    opt_flow = optim.AdamW(flow.parameters(), lr=t_cfg.lr_flow)
    # --- Training Loop ---
    best_flow_loss = float('inf')
    print("\n>>> Start Training Flow DiT...")
    for epoch in range(t_cfg.num_epochs_flow):
        # Pass opt_flow so the trainer updates only the flow model.
        loss = trainer.train_flow(opt_flow)
        print(f"Flow Epoch {epoch}: Loss {loss:.4f}")
        # Save Best: checkpoint the lowest-loss epoch.
        if loss < best_flow_loss:
            best_flow_loss = loss
            save_path = os.path.join(args.save_dir, "flow_best.pt")
            torch.save(flow.state_dict(), save_path)
    # Save Last: final-epoch weights, kept alongside the best checkpoint.
    torch.save(flow.state_dict(), os.path.join(args.save_dir, "flow_last.pt"))
    print(f"Flow Training Done. Best Loss: {best_flow_loss:.4f}")
    # --- Inference / Evaluation ---
    print("\n>>> Loading Best Flow Checkpoint for Evaluation...")
    best_flow_path = os.path.join(args.save_dir, "flow_best.pt")
    if os.path.exists(best_flow_path):
        flow.load_state_dict(torch.load(best_flow_path, map_location=t_cfg.device))
    else:
        print("Warning: Best checkpoint not found, utilizing last epoch weights.")
    print("\n--- Starting Inference ---")
    sources, targets, gens = inference_batch(
        ae, flow, test_loader, tokenizer, t_cfg.device,
        steps=10,
        save_path="wiki_results.tsv",
        use_oneshot=args.use_oneshot
    )
    # Metrics
    metrics = calculate_metrics(sources, gens, targets)
    print("\n=== Metrics ===")
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")
    print("\nResults saved to wiki_results.tsv")
# Script entry point: run the full train-then-evaluate pipeline.
if __name__ == "__main__":
    main()