import os
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader


def custom_collate(batch):
    """Collate preprocessed examples into a model-ready batch.

    Tensor fields ('src_ids', 'src_mask', 'tgt_ids', 'tgt_mask') are stacked
    into (batch, max_len) tensors. Evaluation-only metadata is passed through
    as plain Python lists so the DataLoader does not try to tensorize it.
    """
    return {
        'src_ids': torch.stack([torch.tensor(x['src_ids']) for x in batch]),
        'src_mask': torch.stack([torch.tensor(x['src_mask']) for x in batch]),
        'tgt_ids': torch.stack([torch.tensor(x['tgt_ids']) for x in batch]),
        'tgt_mask': torch.stack([torch.tensor(x['tgt_mask']) for x in batch]),
        # Keep unit-test code for verification (only populated for eval sets).
        'test_code': [x.get('test_code', "") for x in batch],
        'entry_point': [x.get('entry_point', "") for x in batch],
    }


def _tokenize_pair(tokenizer, max_len, source_text, target_text, extra=None):
    """Tokenize a (source, target) text pair into the common feature dict.

    Shared by every task branch in prepare_data so the output schema stays
    consistent. `extra`, if given, is merged into the returned dict
    (used to carry 'test_code' / 'entry_point' for evaluation).
    """
    src = tokenizer(source_text, max_length=max_len,
                    padding="max_length", truncation=True)
    tgt = tokenizer(target_text, max_length=max_len,
                    padding="max_length", truncation=True)
    out = {
        'src_ids': src['input_ids'],
        'src_mask': src['attention_mask'],
        'tgt_ids': tgt['input_ids'],
        'tgt_mask': tgt['attention_mask'],
    }
    if extra:
        out.update(extra)
    return out


def prepare_data(task_name, tokenizer, max_len, batch_size, split="train"):
    """Load, tokenize and batch a dataset for seq2seq training/evaluation.

    Args:
        task_name: one of "codexglue", "humanevalpack", "wiki", "mbpp".
        tokenizer: a HuggingFace tokenizer (callable with max_length/padding/
            truncation keyword arguments).
        max_len: maximum sequence length for both source and target.
        batch_size: DataLoader batch size.
        split: "train" or anything else for an eval split; also controls
            whether the returned DataLoader shuffles.

    Returns:
        A torch DataLoader yielding dicts produced by custom_collate.

    Raises:
        ValueError: for an unknown task name or unrecognized column layout.
        NotImplementedError: for raw-code datasets without fix pairs.
    """
    print(f"Loading {task_name} ({split})...")

    if task_name == "codexglue":
        # Training set: Microsoft CodeXGLUE (code refinement),
        # GitHub bug -> fix pairs.
        dataset = load_dataset("./code_x_glue_cc_code_refinement_full",
                               "medium", split=split)  # ~40k
        if split == "train":
            dataset = dataset.select(range(40000))

        # BUGFIX: column names must be inspected BEFORE the format detection
        # below (the original referenced `cols` before it was assigned).
        cols = dataset.column_names

        if 'source' in cols and 'target' in cols:
            # Case A: standard refinement pairs (source -> target).
            print(">> Detected standard refinement pairs.")

            def preprocess(ex):
                return _tokenize_pair(tokenizer, max_len,
                                      ex['source'], ex['target'])
        elif 'code' in cols:
            # Case B: raw code only. Synthetic bug injection is not
            # implemented; fail loudly here instead of crashing later with a
            # KeyError inside dataset.map (the original fell through).
            print(">> Detected raw code. Not to inject synthetic bugs...")
            raise NotImplementedError(
                "Synthetic bug injection for raw-code datasets is not implemented.")
        else:
            raise ValueError(
                f"Dataset columns {cols} not recognized. "
                "Need 'source'/'target' or 'code'.")

    elif task_name == "humanevalpack":
        # Validation set: HumanEvalPack (fix task), containing buggy code and
        # the corresponding unit tests.
        dataset = load_dataset("./bigcode_humanevalpack_full", "python",
                               split="test")  # only a test split exists
        # Keep only the FIX-task entries.
        # NOTE(review): confirm task_ids in this local copy really use the
        # "Python/FIX" prefix; upstream HumanEvalPack ids look like "Python/0".
        dataset = dataset.filter(
            lambda x: x['task_id'].startswith("Python/FIX"))

        def preprocess(ex):
            # `prompt` is the leading description; for simplicity the model
            # input is prompt + buggy code, the target prompt + canonical fix.
            full_buggy = ex['prompt'] + "\n" + ex['buggy_solution']
            full_fixed = ex['prompt'] + "\n" + ex['canonical_solution']
            return _tokenize_pair(
                tokenizer, max_len, full_buggy, full_fixed,
                extra={
                    'test_code': ex['test'],           # keep the unit tests
                    'entry_point': ex['entry_point'],  # keep the entry fn name
                },
            )

    elif task_name == "wiki":
        # Try a local copy first; fall back to downloading from the Hub.
        try:
            dataset = load_dataset("./wikilarge-dataset")
        except Exception:  # BUGFIX: no bare except (kept KeyboardInterrupt alive)
            print("Local load failed, downloading from Hub...")
            dataset = load_dataset("wikilarge")
        # Manual split: first 20k rows for train, the next 5k as a held-out
        # eval slice (enough for a demo; the full ~290k set is too slow).
        if split == "train":
            dataset = dataset['train'].select(range(20000))
        else:
            dataset = dataset['train'].select(range(20000, 25000))

        # Auto-detect the column naming scheme.
        cols = dataset.column_names
        print(f"Wiki Dataset Columns: {cols}")
        if 'src' in cols and 'dst' in cols:
            src_key, tgt_key = 'src', 'dst'
        elif 'Normal' in cols and 'Simple' in cols:
            src_key, tgt_key = 'Normal', 'Simple'
        else:
            raise ValueError(f"Unknown column format for WikiLarge: {cols}")

        def preprocess(ex):
            # Source (complex) -> target (simplified).
            return _tokenize_pair(tokenizer, max_len,
                                  ex[src_key], ex[tgt_key])

    elif task_name == "mbpp":
        # NOTE(review): `split` is ignored here — MBPP always uses train[:500].
        dataset = load_dataset("mbpp", split="train[:500]")
        print(f"MBPP Dataset Columns: {dataset.column_names}")

        def preprocess(ex):
            # MBPP self-reconstruction task: src == tgt == code.
            return _tokenize_pair(tokenizer, max_len, ex['code'], ex['code'])

    else:
        raise ValueError(f"Unknown task: {task_name}")

    # --- Map & batch ---
    print(f"Preprocessing {task_name} data...")
    print(f"Preprocessing {len(dataset)} examples...")
    # remove_columns drops every ORIGINAL column; keys returned by preprocess
    # (including 'test_code'/'entry_point') are re-added by map, so eval
    # metadata survives into custom_collate.
    # BUGFIX: dropped batched=True — every preprocess above is written for a
    # single example (the humanevalpack string concatenation would receive
    # lists under batched mode and raise TypeError).
    dataset = dataset.map(
        preprocess,
        remove_columns=dataset.column_names,
        num_proc=4,
    )

    # Eval splits are not shuffled, to keep example order aligned.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(split == "train"),
        collate_fn=custom_collate,
    )