| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import fire |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| from transformers import AutoTokenizer |
| from utils import Seq2SeqDataset, pickle_save |
|
|
|
|
def save_len_file(
    tokenizer_name, data_dir, max_source_length=1024, max_target_length=1024, consider_target=False, **kwargs
):
    """Save max(src_len, tgt_len) for each example to allow dynamic batching.

    Tokenizes the train and val splits, records one length per example, and
    pickles the resulting lists to each dataset's ``len_file`` path.

    Args:
        tokenizer_name: Model id or path passed to ``AutoTokenizer.from_pretrained``.
        data_dir: Directory containing the train/val data read by ``Seq2SeqDataset``.
        max_source_length: Truncation length for source sequences.
        max_target_length: Truncation length for target sequences.
        consider_target: If True, record ``max(src_len, tgt_len)`` per example;
            otherwise record only the source length.
        **kwargs: Forwarded to ``Seq2SeqDataset``.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    train_ds = Seq2SeqDataset(tok, data_dir, max_source_length, max_target_length, type_path="train", **kwargs)
    pad = tok.pad_token_id

    def get_lens(ds):
        """Return one length per example in dataset order (shuffle=False keeps alignment)."""
        dl = tqdm(
            DataLoader(ds, batch_size=512, num_workers=8, shuffle=False, collate_fn=ds.collate_fn),
            desc=str(ds.len_file),
        )
        max_lens = []
        for batch in dl:
            # Length = number of non-pad tokens in each row.
            src_lens = batch["input_ids"].ne(pad).sum(1).tolist()
            if consider_target:
                # Only pay for the labels pass when target lengths matter.
                tgt_lens = batch["labels"].ne(pad).sum(1).tolist()
                max_lens.extend(max(src, tgt) for src, tgt in zip(src_lens, tgt_lens))
            else:
                max_lens.extend(src_lens)
        return max_lens

    train_lens = get_lens(train_ds)
    val_ds = Seq2SeqDataset(tok, data_dir, max_source_length, max_target_length, type_path="val", **kwargs)
    val_lens = get_lens(val_ds)
    pickle_save(train_lens, train_ds.len_file)
    pickle_save(val_lens, val_ds.len_file)
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: python-fire exposes save_len_file's parameters as
    # command-line flags (e.g. --tokenizer_name ... --data_dir ...).
    fire.Fire(save_len_file)
|
|