# Diff-Refine / src/utils/data_utils.py
# Uploaded by 2ira via upload-large-folder tool (commit 77d636f, verified)
import os
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader
def custom_collate(batch):
    """Collate a list of example dicts into a single batch dict.

    The four id/mask fields are stacked into (batch, seq_len) integer
    tensors; the evaluation-only metadata fields are gathered into plain
    Python lists (empty string when an example lacks them).
    """
    tensor_keys = ('src_ids', 'src_mask', 'tgt_ids', 'tgt_mask')
    collated = {
        key: torch.stack([torch.tensor(ex[key]) for ex in batch])
        for key in tensor_keys
    }
    # Keep the unit tests alongside the batch so they can be executed
    # during evaluation (only present for eval datasets).
    collated['test_code'] = [ex.get('test_code', "") for ex in batch]
    collated['entry_point'] = [ex.get('entry_point', "") for ex in batch]
    return collated
def prepare_data(task_name, tokenizer, max_len, batch_size, split="train"):
    """Build a DataLoader for one of the supported tasks.

    Args:
        task_name: one of "codexglue", "humanevalpack", "wiki", "mbpp".
        tokenizer: HuggingFace tokenizer (callable on a list of strings).
        max_len: max sequence length; inputs are padded/truncated to it.
        batch_size: DataLoader batch size.
        split: "train" or "test"; controls subset selection and shuffling.

    Returns:
        torch.utils.data.DataLoader yielding dicts via custom_collate.

    Raises:
        ValueError: unknown task name or unrecognized dataset columns.

    Note: dataset.map is called with batched=True, so every `preprocess`
    below receives dict-of-lists batches (the tokenizer accepts lists).
    """
    print(f"Loading {task_name} ({split})...")

    if task_name == "codexglue":
        # Training set: Microsoft CodeXGLUE code refinement (GitHub bug -> fix).
        dataset = load_dataset("./code_x_glue_cc_code_refinement_full", "medium", split=split)
        if split == "train":
            dataset = dataset.select(range(40000))  # cap at 40k examples

        # BUGFIX: `cols` was read before being assigned (NameError); inspect
        # the columns before branching on them.
        cols = dataset.column_names

        if 'source' in cols and 'target' in cols:
            # Case A: standard refinement pairs (buggy source -> fixed target).
            print(">> Detected standard refinement pairs.")

            def preprocess(ex):
                src = tokenizer(ex['source'], max_length=max_len, padding="max_length", truncation=True)
                tgt = tokenizer(ex['target'], max_length=max_len, padding="max_length", truncation=True)
                return {
                    'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                    'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
                }
        elif 'code' in cols:
            # Case B: raw code only. BUGFIX: this branch previously defined no
            # preprocess function at all (the later one assumed 'source'/'target'
            # and would KeyError). No synthetic bug injection is performed, so
            # fall back to self-reconstruction: src == tgt == code.
            print(">> Detected raw code. Not to inject synthetic bugs...")

            def preprocess(ex):
                enc = tokenizer(ex['code'], max_length=max_len, padding="max_length", truncation=True)
                return {
                    'src_ids': enc['input_ids'], 'src_mask': enc['attention_mask'],
                    'tgt_ids': enc['input_ids'], 'tgt_mask': enc['attention_mask']
                }
        else:
            raise ValueError(f"Dataset columns {cols} not recognized. Need 'source'/'target' or 'code'.")

    elif task_name == "humanevalpack":
        # Eval set: HumanEvalPack fix task (buggy code + unit tests).
        # Only a "test" split exists for this dataset.
        dataset = load_dataset("./bigcode_humanevalpack_full", "python", split="test")
        # Keep only the FIX-task instances.
        dataset = dataset.filter(lambda x: x['task_id'].startswith("Python/FIX"))

        def preprocess(ex):
            # For simplicity, the model input is prompt + buggy_solution and
            # the target is prompt + canonical_solution.
            # BUGFIX: with batched=True each field is a list, so the old
            # per-example `ex['prompt'] + "\n" + ...` raised TypeError
            # (list + str); concatenate element-wise instead.
            full_buggy = [p + "\n" + b for p, b in zip(ex['prompt'], ex['buggy_solution'])]
            full_fixed = [p + "\n" + c for p, c in zip(ex['prompt'], ex['canonical_solution'])]
            src = tokenizer(full_buggy, max_length=max_len, padding="max_length", truncation=True)
            tgt = tokenizer(full_fixed, max_length=max_len, padding="max_length", truncation=True)
            return {
                'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask'],
                # Core: keep the test code and entry-point name so eval can
                # actually execute the generated fixes. These survive the
                # remove_columns below because map re-adds returned keys.
                'test_code': ex['test'],
                'entry_point': ex['entry_point']
            }

    elif task_name == "wiki":
        # Try the local copy first; fall back to downloading from the Hub.
        try:
            dataset = load_dataset("./wikilarge-dataset")
        except Exception:  # BUGFIX: bare `except:` also caught KeyboardInterrupt/SystemExit
            print("Local load failed, downloading from Hub...")
            dataset = load_dataset("wikilarge")

        # Manual split: first 20k rows for train, the next 5k for test
        # (WikiLarge has ~290k rows; the full set is too slow for a demo).
        if split == "train":
            dataset = dataset['train'].select(range(20000))
        else:
            dataset = dataset['train'].select(range(20000, 25000))

        # Auto-detect the column naming convention.
        cols = dataset.column_names
        print(f"Wiki Dataset Columns: {cols}")
        if 'src' in cols and 'dst' in cols:
            src_key, tgt_key = 'src', 'dst'
        elif 'Normal' in cols and 'Simple' in cols:
            src_key, tgt_key = 'Normal', 'Simple'
        else:
            raise ValueError(f"Unknown column format for WikiLarge: {cols}")

        def preprocess(ex):
            # Complex sentence (source) -> simplified sentence (target).
            src = tokenizer(ex[src_key], max_length=max_len, padding="max_length", truncation=True)
            tgt = tokenizer(ex[tgt_key], max_length=max_len, padding="max_length", truncation=True)
            return {
                'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
            }

    elif task_name == "mbpp":
        dataset = load_dataset("mbpp", split="train[:500]")
        print(f"MBPP Dataset Columns: {dataset.column_names}")

        def preprocess(ex):
            # MBPP self-reconstruction task: src == tgt == code.
            enc = tokenizer(ex['code'], max_length=max_len, padding="max_length", truncation=True)
            return {
                'src_ids': enc['input_ids'], 'src_mask': enc['attention_mask'],
                'tgt_ids': enc['input_ids'], 'tgt_mask': enc['attention_mask']
            }

    else:
        raise ValueError(f"Unknown task: {task_name}")

    # --- Map & Batch ---
    print(f"Preprocessing {len(dataset)} {task_name} examples...")
    dataset = dataset.map(
        preprocess,
        batched=True,
        # Drop all raw columns; keys returned by preprocess are kept.
        remove_columns=dataset.column_names,
        num_proc=4
    )
    # Do not shuffle the test split so outputs stay aligned with inputs.
    return DataLoader(dataset, batch_size=batch_size, shuffle=(split == "train"), collate_fn=custom_collate)