Spaces:
Runtime error
Runtime error
| #%% | |
| import argparse | |
| import time | |
| from tqdm import tqdm | |
| import pandas as pd | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import os | |
| import json | |
| import torch | |
| from dotenv import load_dotenv | |
| #%% | |
# Load variables from a local .env file into os.environ (for example
# TRANSFORMERS_CACHE, which is read when the translation models are loaded).
# NOTE(review): kept ahead of the nltk import — presumably so .env values are
# already visible to it; confirm before reordering.
load_dotenv()
from nltk.tokenize import sent_tokenize

# Absolute directory holding this script; the default dataset paths of the
# command-line interface are built relative to it.
wd = os.path.split(os.path.realpath(__file__))[0]
class BackTranslatorAugmenter:
    """
    A class that performs back-translation in order to do data augmentation.

    Text is translated from ``in_lang`` into the bottleneck language
    ``out_lang`` with the ``Helsinki-NLP/opus-mt-{in_lang}-{out_lang}`` model,
    then translated back with ``Helsinki-NLP/opus-mt-{out_lang}-{in_lang}``.
    For best results we recommend using bottleneck languages (``out_lang``)
    such as Russian (ru) and Spanish (es).

    Example
    -------
    .. code-block:: python

        data_augmenter = BackTranslatorAugmenter(out_lang="es")
        text = "I want to augment this sentence"
        print(text)
        data_augmenter.back_translate(text, verbose=True)

    :param in_lang: the text input language, defaults to "en"
    :type in_lang: str, optional
    :param out_lang: the language to translate with, defaults to "ru"
    :type out_lang: str, optional
    """

    def __init__(self, in_lang="en", out_lang="ru") -> None:
        # Run both models on the GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Optional download/cache directory for the pretrained weights.
        cache_dir = os.getenv("TRANSFORMERS_CACHE")
        forward_name = f"Helsinki-NLP/opus-mt-{in_lang}-{out_lang}"
        backward_name = f"Helsinki-NLP/opus-mt-{out_lang}-{in_lang}"
        self.in_tokenizer = AutoTokenizer.from_pretrained(
            forward_name, cache_dir=cache_dir
        )
        self.in_model = AutoModelForSeq2SeqLM.from_pretrained(
            forward_name, cache_dir=cache_dir
        ).to(self.device)
        self.out_tokenizer = AutoTokenizer.from_pretrained(
            backward_name, cache_dir=cache_dir
        )
        self.out_model = AutoModelForSeq2SeqLM.from_pretrained(
            backward_name, cache_dir=cache_dir
        ).to(self.device)

    def back_translate(self, text, verbose=False):
        """Back-translate a single text or a batch of texts.

        Texts whose tokenized length exceeds the forward tokenizer's
        ``model_max_length`` are split into sentences with
        ``nltk.sent_tokenize``, translated sentence by sentence and re-joined
        with a single space instead of being truncated wholesale (a single
        sentence that is itself too long is still truncated). All other texts
        of the batch are translated together in one batched pass.

        :param text: a string, or an iterable of strings
        :param verbose: print intermediate translations and the elapsed time
        :return: a list holding one back-translated string per input text, in
            input order (a single string yields a one-element list)
        """
        if verbose:
            tic = time.time()
        texts = [text] if isinstance(text, str) else list(text)
        results = [None] * len(texts)
        short_positions = []
        for position, item in enumerate(texts):
            if self._is_too_long(item):
                print('Text is too long ')
                results[position] = self._translate_by_sentence(item, verbose)
            else:
                short_positions.append(position)
        if short_positions:
            translated = self._translate_batch(
                [texts[p] for p in short_positions], verbose
            )
            # Put each batched prediction back at its original position so
            # the output stays aligned with the input.
            for position, prediction in zip(short_positions, translated):
                results[position] = prediction
        if verbose:
            print("Elapsed time : ", time.time() - tic)
        return results

    def back_translate_long(self, text, verbose=False):
        """Back-translate a long text sentence by sentence.

        :param text: the text to back-translate
        :param verbose: print intermediate translations
        :return: a one-element list holding the space-joined back-translation
        """
        return [self._translate_by_sentence(text, verbose)]

    def _is_too_long(self, text):
        """Return True if ``text`` has more tokens than the forward model accepts."""
        n_tokens = len(self.in_tokenizer(text)["input_ids"])
        return n_tokens > self.in_tokenizer.model_max_length

    def _translate_by_sentence(self, text, verbose=False):
        """Back-translate ``text`` one sentence at a time and join with spaces."""
        # Sentences go straight to _translate_batch (which truncates) rather
        # than back through back_translate, so an over-long single sentence
        # cannot cause endless re-splitting.
        sentences = sent_tokenize(text)
        return " ".join(self._translate_batch(sentences, verbose))

    def _translate_batch(self, texts, verbose=False):
        """Round-trip a list of texts through the forward and backward models."""
        if not texts:
            return []
        in_preds = self._generate(self.in_tokenizer, self.in_model, texts)
        if verbose:
            print("in_pred : ", in_preds)
        out_preds = self._generate(self.out_tokenizer, self.out_model, in_preds)
        if verbose:
            print("out_pred : ", out_preds)
        return out_preds

    def _generate(self, tokenizer, model, texts):
        """Translate a list of texts with one model and decode the outputs.

        Inputs longer than the tokenizer's limit are truncated.
        """
        encoded = tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        generated_ids = model.generate(
            inputs=encoded["input_ids"], attention_mask=encoded["attention_mask"]
        )
        return tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
def do_backtranslation(**args):
    """Back-translate selected columns of a CSV dataset and save the result.

    Keys read from ``args`` (as produced by this script's argument parser):
    ``input_data_path``, ``output_data_path``, ``in_lang``, ``out_lang``,
    ``col_map`` (dict mapping input column name -> output column name) and
    ``batch_size`` (positive int). An optional ``max_rows`` key limits how
    many input rows are processed (default: all rows).

    The output CSV holds one column per ``col_map`` value and is written with
    ``DataFrame.to_csv`` defaults (the row index is included).

    :raises ValueError: if ``batch_size`` is smaller than 1
    """
    batch_size = args["batch_size"]
    if batch_size < 1:
        # range() with a zero step raises, and a negative step silently
        # yields no batches; fail fast before loading data or models.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    df = pd.read_csv(args["input_data_path"])
    max_rows = args.get("max_rows")
    if max_rows is not None:
        df = df.iloc[:max_rows]

    data_augmenter = BackTranslatorAugmenter(
        in_lang=args["in_lang"], out_lang=args["out_lang"]
    )
    dict_res = {new_col: [] for new_col in args["col_map"].values()}
    for start in tqdm(range(0, len(df), batch_size)):
        for old_col, new_col in args["col_map"].items():
            # Empty cells are read as NaN floats, which the tokenizer rejects.
            batch = (
                df[old_col]
                .iloc[start : start + batch_size]
                .fillna("")
                .astype(str)
                .tolist()
            )
            dict_res[new_col] += data_augmenter.back_translate(batch)

    augmented_df = pd.DataFrame(dict_res)
    output_dir = os.path.dirname(args["output_data_path"])
    if output_dir:  # os.makedirs("") raises for a bare file name
        os.makedirs(output_dir, exist_ok=True)
    augmented_df.to_csv(args["output_data_path"])
if __name__ == "__main__":
    # Command-line entry point: parse options, validate them, then run the
    # augmentation over the whole dataset.
    parser = argparse.ArgumentParser(
        description="Back Translate a dataset for better training"
    )
    parser.add_argument(
        "-in_lang",
        type=str,
        default="en",
        help='the text input language, defaults to "en", '
        "one can choose between {'es','ru','en','fr','de','pt','zh'} "
        "but please have a look at https://huggingface.co/Helsinki-NLP to make "
        "sure the language pair you ask for is available",
    )
    parser.add_argument(
        "-out_lang",
        type=str,
        default="ru",
        help='the bottleneck language to translate through, defaults to "ru", '
        "one can choose between {'es','ru','en','fr','de','pt','zh'} "
        "but please have a look at https://huggingface.co/Helsinki-NLP to make "
        "sure the language pair you ask for is available",
    )
    parser.add_argument(
        "-input_data_path",
        type=str,
        default=os.path.join(wd, "dataset", "train_neurips_dataset.csv"),
        help="dataset location, please note it should be a CSV file containing "
        "every input column listed in --col_map (by default "
        '"abstract" and "tldr")',
    )
    parser.add_argument(
        "-output_data_path",
        type=str,
        default=os.path.join(
            wd, "dataset", "augmented_datas", "augmented_dataset_output.csv"
        ),
        help="augmented dataset output location",
    )
    parser.add_argument(
        "-columns_mapping",
        "--col_map",
        type=json.loads,
        default={"abstract": "text", "tldr": "summary"},
        help="JSON object mapping the names of the columns to apply data "
        "augmentation on to the names of the output columns, e.g. "
        '\'{"input_column_name1": "output_column_name1"}\' (JSON requires '
        'double quotes). By default it is {"abstract": "text", "tldr": '
        '"summary"}. If you don\'t want to change the column names, provide a '
        "mapping whose keys equal its values",
    )
    parser.add_argument(
        "-batch_size",
        type=int,
        default=25,
        help="number of rows back-translated per model call",
    )
    args = parser.parse_args()
    # json.loads accepts any JSON value; only a non-empty object is usable.
    if not isinstance(args.col_map, dict) or not args.col_map:
        parser.error("--col_map must be a non-empty JSON object")
    if args.batch_size < 1:
        parser.error("-batch_size must be a positive integer")
    do_backtranslation(**vars(args))