| import os |
| from datasets import load_dataset |
| from transformers import AutoTokenizer |
| import torch |
| from torch.utils.data import DataLoader |
|
|
def custom_collate(batch):
    """Collate a list of examples into a training batch.

    Token id/mask fields are stacked into (batch, seq_len) tensors;
    'test_code' and 'entry_point' are passed through as plain Python
    lists, defaulting to "" when an example lacks them.
    """
    tensor_keys = ('src_ids', 'src_mask', 'tgt_ids', 'tgt_mask')
    collated = {
        key: torch.stack([torch.tensor(ex[key]) for ex in batch])
        for key in tensor_keys
    }
    # Metadata stays as lists of strings (only present for eval-style tasks).
    for meta_key in ('test_code', 'entry_point'):
        collated[meta_key] = [ex.get(meta_key, "") for ex in batch]
    return collated
|
|
def prepare_data(task_name, tokenizer, max_len, batch_size, split="train"):
    """Build a DataLoader of tokenized (source, target) pairs for a task.

    Supports a ``split`` argument so train and test sets can be separated.

    Args:
        task_name: One of "codexglue", "humanevalpack", "wiki", "mbpp".
        tokenizer: HuggingFace tokenizer applied to both source and target.
        max_len: Sequences are padded/truncated to this token length.
        batch_size: DataLoader batch size.
        split: Dataset split, e.g. "train" or "test".

    Returns:
        A torch DataLoader yielding batches built by ``custom_collate``
        (stacked id/mask tensors plus metadata string lists).

    Raises:
        ValueError: If the task name or dataset schema is unrecognized.
    """
    print(f"Loading {task_name} ({split})...")

    def encode_pair(sources, targets, extra=None):
        # Shared tokenization for a batch of (source, target) text lists.
        src = tokenizer(sources, max_length=max_len, padding="max_length", truncation=True)
        tgt = tokenizer(targets, max_length=max_len, padding="max_length", truncation=True)
        out = {
            'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
            'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
        }
        if extra:
            out.update(extra)
        return out

    if task_name == "codexglue":
        dataset = load_dataset("./code_x_glue_cc_code_refinement_full", "medium", split=split)
        # Cap the training set size to keep epochs manageable.
        if split == "train":
            dataset = dataset.select(range(40000))

        # BUG FIX: `cols` was previously read before being assigned
        # (NameError); compute the column names before inspecting them.
        cols = dataset.column_names

        if 'source' in cols and 'target' in cols:
            print(">> Detected standard refinement pairs.")

            def preprocess(ex):
                return encode_pair(ex['source'], ex['target'])
        elif 'code' in cols:
            # Raw code has no paired fix, so there is nothing to refine;
            # fail fast instead of crashing later inside map().
            print(">> Detected raw code. Not to inject synthetic bugs...")
            raise ValueError("Raw 'code'-only datasets are not supported for refinement.")
        else:
            raise ValueError(f"Dataset columns {cols} not recognized. Need 'source'/'target' or 'code'.")

    elif task_name == "humanevalpack":
        dataset = load_dataset("./bigcode_humanevalpack_full", "python", split="test")
        # Keep only the bug-fixing subset of HumanEvalPack.
        dataset = dataset.filter(lambda x: x['task_id'].startswith("Python/FIX"))

        def preprocess(ex):
            # BUG FIX: map() runs with batched=True, so every field is a
            # list; build prompt+solution strings element-wise instead of
            # concatenating lists with "\n".
            buggy = [p + "\n" + b for p, b in zip(ex['prompt'], ex['buggy_solution'])]
            fixed = [p + "\n" + c for p, c in zip(ex['prompt'], ex['canonical_solution'])]
            return encode_pair(buggy, fixed, extra={
                'test_code': ex['test'],
                'entry_point': ex['entry_point']
            })

    elif task_name == "wiki":
        try:
            dataset = load_dataset("./wikilarge-dataset")
        except Exception:  # any local-load failure falls back to the Hub
            print("Local load failed, downloading from Hub...")
            dataset = load_dataset("wikilarge")

        # WikiLarge only ships a 'train' split; carve out a held-out slice
        # for evaluation so train and eval never overlap.
        if split == "train":
            dataset = dataset['train'].select(range(20000))
        else:
            dataset = dataset['train'].select(range(20000, 25000))

        cols = dataset.column_names
        print(f"Wiki Dataset Columns: {cols}")

        # Column names vary between WikiLarge distributions.
        if 'src' in cols and 'dst' in cols:
            src_key, tgt_key = 'src', 'dst'
        elif 'Normal' in cols and 'Simple' in cols:
            src_key, tgt_key = 'Normal', 'Simple'
        else:
            raise ValueError(f"Unknown column format for WikiLarge: {cols}")

        def preprocess(ex):
            return encode_pair(ex[src_key], ex[tgt_key])

    elif task_name == "mbpp":
        dataset = load_dataset("mbpp", split="train[:500]")
        print(f"MBPP Dataset Columns: {dataset.column_names}")

        def preprocess(ex):
            # Identity pairs: target equals source (autoencoding-style).
            return encode_pair(ex['code'], ex['code'])

    else:
        raise ValueError(f"Unknown task: {task_name}")

    print(f"Preprocessing {task_name} data...")
    print(f"Preprocessing {len(dataset)} examples...")
    dataset = dataset.map(
        preprocess,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4
    )

    # Shuffle only for training so evaluation order stays deterministic.
    return DataLoader(dataset, batch_size=batch_size, shuffle=(split == "train"), collate_fn=custom_collate)