"""Evaluate a trained Mirror checkpoint on downstream test sets and tabulate the results."""

from pathlib import Path

import pandas as pd
from rex.utils.initialization import set_seed_and_log_path
from rex.utils.io import load_json
from rich.console import Console
from rich.table import Table

from src.task import SchemaGuidedInstructBertTask

set_seed_and_log_path(log_path="tmp_eval.log")


if __name__ == "__main__":
    task_dir = "mirror_outputs/Mirror_Pretrain_AllExcluded_2"

    # Restore the task from its output directory with the best checkpoint loaded,
    # overriding a few evaluation-related config entries.
    task: SchemaGuidedInstructBertTask = SchemaGuidedInstructBertTask.from_taskdir(
        task_dir,
        load_best_model=True,
        initialize=False,
        dump_configfile=False,
        update_config={
            "regenerate_cache": True,
            "eval_on_data": ["dev"],
            "select_best_on_data": "dev",
            "select_best_by_key": "metric",
            "best_metric_field": "general_spans.micro.f1",
            "eval_batch_size": 32,
        },
    )
    table = Table(title=task_dir)

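    # Datasets to evaluate as (name, data path) pairs. The name prefix
    # (ent_ / rel_ / event_ / absa_ / cls_ / span / ...) decides which metric
    # is read from the evaluation results below.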
    data_pairs = [
        ["ent_conll03_test", "resources/Mirror/uie/ent/conll03/test.jsonl"],
        ["rel_conll04_test", "resources/Mirror/uie/rel/conll04/test.jsonl"],
        ["event_ace05_test", "resources/Mirror/uie/event/ace05-evt/test.jsonl"],
    ]

    eval_res = {"task": [], "dataset": [], "metric_val": []}
    table.add_column("Task", justify="left", style="cyan")
    table.add_column("Dataset", justify="left", style="magenta")
    table.add_column("Metric (%)", justify="right", style="green")
    for dname, fpath in data_pairs:
        dname = dname.lower()
        task.data_manager.update_datapath(dname, fpath)
        _, res = task.eval(dname, verbose=True, dump=True, dump_middle=True)

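        # Record the headline metric for this dataset according to its task-type prefix.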
        if dname.startswith("ent_"):
            eval_res["task"].append("ent")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["ent"]["micro"]["f1"])
        elif dname.startswith("rel_"):
            eval_res["task"].append("rel")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["rel"]["rel"]["micro"]["f1"])
        elif dname.startswith("event_"):
            eval_res["task"].append("event")
            eval_res["dataset"].append(dname + "_tgg")
            eval_res["metric_val"].append(res["event"]["trigger_cls"]["f1"])
            eval_res["task"].append("event")
            eval_res["dataset"].append(dname + "_arg")
            eval_res["metric_val"].append(res["event"]["arg_cls"]["f1"])
        elif dname.startswith("absa_"):
            eval_res["task"].append("absa")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["rel"]["rel"]["micro"]["f1"])
        elif dname.startswith("cls_"):
            eval_res["task"].append("cls")
            eval_res["dataset"].append(dname)
            if "_glue_" in dname:
                # GLUE CoLA is scored with Matthews correlation; other GLUE sets use accuracy.
                if "_cola" in dname:
                    eval_res["metric_val"].append(res["cls"]["mcc"])
                else:
                    eval_res["metric_val"].append(res["cls"]["acc"])
            else:
                eval_res["metric_val"].append(res["cls"]["mf1"]["micro"]["f1"])
        elif dname.startswith("span"):
            eval_res["task"].append("span_em")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["span"]["em"])
            eval_res["task"].append("span_f1")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["span"]["f1"]["f1"])
        elif dname.startswith("discontinuous_ent"):
            eval_res["task"].append("discontinuous_ent")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["discontinuous_ent"]["micro"]["f1"])
        elif dname.startswith("hyper_rel"):
            eval_res["task"].append("hyper_rel")
            eval_res["dataset"].append(dname)
            eval_res["metric_val"].append(res["hyper_rel"]["micro"]["f1"])
        else:
            raise ValueError(f"Unrecognized dataset name prefix: {dname}")

    for i in range(len(eval_res["task"])):
        table.add_row(
            eval_res["task"][i],
            eval_res["dataset"][i],
            f"{100 * eval_res['metric_val'][i]:.3f}",
        )

    console = Console()
    console.print(table)

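    # Export the collected metrics to an Excel sheet under the task's measures directory.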
    df = pd.DataFrame(eval_res)
    df.to_excel(task.measures_path.joinpath("data_eval_res.xlsx"))
|