| | import math |
| | import re |
| | from collections import defaultdict |
| | from datetime import datetime |
| | from typing import List |
| |
|
| | import torch |
| | import torch.optim as optim |
| | from rex import accelerator |
| | from rex.data.data_manager import DataManager |
| | from rex.data.dataset import CachedDataset, StreamReadDataset |
| | from rex.tasks.simple_metric_task import SimpleMetricTask |
| | from rex.utils.batch import decompose_batch_into_instances |
| | from rex.utils.config import ConfigParser |
| | from rex.utils.dict import flatten_dict |
| | from rex.utils.io import load_jsonlines |
| | from rex.utils.registry import register |
| | from torch.utils.tensorboard import SummaryWriter |
| | from transformers.optimization import ( |
| | get_cosine_schedule_with_warmup, |
| | get_linear_schedule_with_warmup, |
| | ) |
| |
|
| | from .metric import MrcNERMetric, MrcSpanMetric, MultiPartSpanMetric |
| | from .model import ( |
| | MrcGlobalPointerModel, |
| | MrcPointerMatrixModel, |
| | SchemaGuidedInstructBertModel, |
| | ) |
| | from .transform import ( |
| | CachedLabelPointerTransform, |
| | CachedPointerMRCTransform, |
| | CachedPointerTaggingTransform, |
| | ) |
| |
|
| |
|
@register("task")
class MrcTaggingTask(SimpleMetricTask):
    """MRC-style pointer-tagging NER task with TensorBoard logging."""

    def __init__(self, config, **kwargs) -> None:
        super().__init__(config, **kwargs)

    def after_initialization(self):
        # One timestamped sub-directory per run so repeated runs never
        # overwrite each other's event files.
        timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
        self.tb_logger: SummaryWriter = SummaryWriter(
            log_dir=self.task_path / "tb_summary" / timestamp,
            comment=self.config.comment,
        )

    def after_whole_train(self):
        # Flush and release the TensorBoard writer once training finishes.
        self.tb_logger.close()

    def get_grad_norm(self):
        """Return the total L2 norm over all parameter gradients.

        Parameters without a gradient (frozen or not yet backpropagated)
        are skipped.
        """
        squared_sum = sum(
            p.grad.detach().data.norm(2).item() ** 2
            for p in self.model.parameters()
            if p.grad is not None
        )
        return squared_sum ** 0.5

    def log_loss(
        self, idx: int, loss_item: float, step_or_epoch: str, dataset_name: str
    ):
        """Log loss plus current learning rate and gradient norm at ``idx``."""
        writer = self.tb_logger
        writer.add_scalar(f"loss/{dataset_name}/{step_or_epoch}", loss_item, idx)
        writer.add_scalar("lr", self.optimizer.param_groups[0]["lr"], idx)
        writer.add_scalar("grad_norm_total", self.get_grad_norm(), idx)

    def log_metrics(
        self, idx: int, metrics: dict, step_or_epoch: str, dataset_name: str
    ):
        """Log a (possibly nested) metrics dict as flat TensorBoard scalars."""
        self.tb_logger.add_scalars(
            f"{dataset_name}/{step_or_epoch}", flatten_dict(metrics), idx
        )

    def init_transform(self):
        return CachedPointerTaggingTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            self.config.ent_type2query_filepath,
            mode=self.config.mode,
            negative_sample_prob=self.config.negative_sample_prob,
        )

    def init_data_manager(self):
        return DataManager(
            self.config.train_filepath,
            self.config.dev_filepath,
            self.config.test_filepath,
            CachedDataset,
            self.transform,
            load_jsonlines,
            self.config.train_batch_size,
            self.config.eval_batch_size,
            self.transform.collate_fn,
            use_stream_transform=False,
            debug_mode=self.config.debug_mode,
            dump_cache_dir=self.config.dump_cache_dir,
            regenerate_cache=self.config.regenerate_cache,
        )

    def init_model(self):
        return MrcGlobalPointerModel(
            self.config.plm_dir,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
            mode=self.config.mode,
        )

    def init_metric(self):
        return MrcNERMetric()

    def init_optimizer(self):
        """AdamW with grouped learning rates and selective weight decay.

        PLM parameters keep ``learning_rate``; all other parameters use
        ``other_learning_rate``. Embedding/LayerNorm/bias parameters get no
        weight decay, and the lower PLM layers (embeddings, layers 0-3) are
        frozen.
        """
        no_decay = r"(embedding|LayerNorm|\.bias$)"
        plm_lr = r"^plm\."
        non_trainable = r"^plm\.(emb|encoder\.layer\.[0-3])"

        groups = []
        for name, param in self.model.named_parameters():
            if re.search(non_trainable, name):
                # Frozen params keep grad=None, so AdamW will skip them.
                param.requires_grad = False
            groups.append(
                {
                    "params": param,
                    "lr": (
                        self.config.learning_rate
                        if re.search(plm_lr, name)
                        else self.config.other_learning_rate
                    ),
                    "weight_decay": (
                        0.0
                        if re.search(no_decay, name)
                        else self.config.weight_decay
                    ),
                }
            )
        return optim.AdamW(
            groups,
            lr=self.config.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-6,
        )

    def init_lr_scheduler(self):
        """Cosine schedule with linear warmup over the whole training run."""
        total_steps = int(
            len(self.data_manager.train_loader)
            * self.config.num_epochs
            * accelerator.num_processes
        )
        warmup_steps = math.floor(total_steps * self.config.warmup_proportion)
        return get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

    def predict_api(self, texts: List[str], **kwargs):
        """Predict entity sets for raw texts, returned in input id order."""
        raw_dataset = self.transform.predict_transform(texts)
        ordered_ids = sorted({ins["id"] for ins in raw_dataset})
        loader = accelerator.prepare_data_loader(
            self.data_manager.prepare_loader(raw_dataset)
        )
        id2ents = defaultdict(set)
        for batch in loader:
            outputs = self.model(**batch, is_eval=True)
            # One text may be split into several instances; merge their preds.
            for _id, _pred in zip(batch["id"], outputs["pred"]):
                id2ents[_id].update(_pred)
        return [id2ents[_id] for _id in ordered_ids]
| |
|
| |
|
@register("task")
class MrcQaTask(MrcTaggingTask):
    """MRC span-extraction (question answering) variant of the tagging task."""

    def init_transform(self):
        return CachedPointerMRCTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            mode=self.config.mode,
        )

    def init_model(self):
        return MrcGlobalPointerModel(
            self.config.plm_dir,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
            mode=self.config.mode,
        )

    def init_metric(self):
        return MrcSpanMetric()

    def predict_api(self, data: list[dict], **kwargs):
        """
        Args:
            data: a list of dict with query, context, and background strings

        Returns one list per instance of ``(answer_string, token_indices)``
        pairs decoded from the predicted index lists.
        """
        loader = self.data_manager.prepare_loader(
            self.transform.predict_transform(data)
        )
        results = []
        for batch in loader:
            batch["pred"] = self.model(**batch, is_eval=True)["pred"]
            for ins in decompose_batch_into_instances(batch):
                span_results = [
                    (
                        "".join(ins["raw_tokens"][i] for i in index_list),
                        tuple(index_list),
                    )
                    for index_list in ins["pred"]
                ]
                results.append(span_results)
        return results
| |
|
| |
|
class StreamReadDatasetWithLen(StreamReadDataset):
    """Stream-read dataset that reports a fixed length.

    Streaming datasets cannot know their size up-front, but the trainer needs
    ``len()`` to size the LR schedule and progress bars.
    """

    # Hard-coded instance count of the training corpus this task streams.
    # NOTE(review): presumably matches the current train file's line count —
    # confirm and update whenever the corpus changes. Subclasses may override.
    NUM_INSTANCES = 631346

    def __len__(self) -> int:
        return self.NUM_INSTANCES
| |
|
| |
|
@register("task")
class SchemaGuidedInstructBertTask(MrcTaggingTask):
    """Schema-guided InstructBert extraction task over UDI-formatted data.

    Decodes multi-part span predictions into classification, entity,
    relation, event, discontinuous-entity and hyper-relation results,
    guided by the per-instance schema.
    """

    def init_transform(self):
        # Annotation only: narrows self.transform's type for static checkers.
        self.transform: CachedLabelPointerTransform
        return CachedLabelPointerTransform(
            self.config.max_seq_len,
            self.config.plm_dir,
            mode=self.config.mode,
            label_span=self.config.label_span,
            include_instructions=self.config.get("include_instructions", True),
        )

    def init_data_manager(self):
        # In stream mode the per-instance transform runs on the fly; otherwise
        # instances are transformed once and cached to dump_cache_dir.
        if self.config.get("stream_mode", False):
            DatasetClass = StreamReadDatasetWithLen
            transform = self.transform.transform
        else:
            DatasetClass = CachedDataset
            transform = self.transform
        return DataManager(
            self.config.train_filepath,
            self.config.dev_filepath,
            self.config.test_filepath,
            DatasetClass,
            transform,
            load_jsonlines,
            self.config.train_batch_size,
            self.config.eval_batch_size,
            self.transform.collate_fn,
            use_stream_transform=self.config.get("stream_mode", False),
            debug_mode=self.config.debug_mode,
            dump_cache_dir=self.config.dump_cache_dir,
            regenerate_cache=self.config.regenerate_cache,
        )

    def init_model(self):
        self.model = SchemaGuidedInstructBertModel(
            self.config.plm_dir,
            # The transform's tokenizer may have added special tokens, so the
            # model embedding size must follow the tokenizer, not the PLM.
            vocab_size=len(self.transform.tokenizer),
            use_rope=self.config.use_rope,
            biaffine_size=self.config.biaffine_size,
            dropout=self.config.dropout,
        )
        # Optionally warm-start weights from a previously trained checkpoint.
        if self.config.get("base_model_path"):
            self.load(
                self.config.base_model_path,
                load_config=False,
                load_model=True,
                load_optimizer=False,
                load_history=False,
            )
        return self.model

    def init_optimizer(self):
        """AdamW with grouped learning rates and selective weight decay.

        Unlike the parent task, no parameters are frozen here: the
        ``non_trainable`` pattern is a sentinel that matches no parameter
        name by design.
        """
        no_decay = r"(embedding|LayerNorm|\.bias$)"
        plm_lr = r"^plm\."
        # Sentinel pattern that never matches: keep every parameter trainable.
        non_trainable = "no_non_trainable"

        param_groups = []
        for name, param in self.model.named_parameters():
            lr = self.config.learning_rate
            weight_decay = self.config.weight_decay
            if re.search(non_trainable, name):
                param.requires_grad = False
            if not re.search(plm_lr, name):
                lr = self.config.other_learning_rate
            if re.search(no_decay, name):
                weight_decay = 0.0
            param_groups.append(
                {"params": param, "lr": lr, "weight_decay": weight_decay}
            )
        return optim.AdamW(
            param_groups,
            lr=self.config.learning_rate,
            betas=(0.9, 0.98),
            eps=1e-6,
        )

    def init_metric(self):
        return MultiPartSpanMetric()

    def _convert_span_to_string(self, span, token_ids, tokenizer):
        """Decode a (start,) or (start, end) token span into text.

        Returns a ``(string, positions)`` pair where ``positions`` are the
        span indices re-based onto the text part of the input (see
        ``reset_position``). Spans that are empty or longer than two indices
        decode to an empty string.
        """
        string = ""
        if len(span) == 0 or len(span) > 2:
            pass
        elif len(span) == 1:
            string = tokenizer.decode(token_ids[span[0]])
        elif len(span) == 2:
            string = tokenizer.decode(token_ids[span[0] : span[1] + 1])
        return (string, self.reset_position(token_ids, span))

    def reset_position(self, input_ids: list[int], span: list[int]) -> list[int]:
        """Shift span indices so position 0 is the first real text token.

        The offset is the position right after the first ``tp_token`` /
        ``tl_token`` marker found in ``input_ids``; if neither marker is
        present the span is returned unshifted.
        """
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.cpu().tolist()
        if len(span) < 1:
            return span

        tp_token_id, tl_token_id = self.transform.tokenizer.convert_tokens_to_ids(
            [self.transform.tp_token, self.transform.tl_token]
        )
        offset = 0
        if tp_token_id in input_ids:
            offset = input_ids.index(tp_token_id) + 1
        elif tl_token_id in input_ids:
            offset = input_ids.index(tl_token_id) + 1
        return [i - offset for i in span]

    def predict_api(self, data: list[dict], **kwargs):
        """
        Args:
            data: a list of dict in UDI:
                {
                    "id": str,
                    "instruction": str,
                    "schema": {
                        "ent": list,
                        "rel": list,
                        "event": dict,
                        "cls": list,
                        "discontinuous_ent": list,
                        "hyper_rel": dict
                    },
                    "text": str,
                    "bg": str,
                    "ans": {},  # empty dict
                }

        Returns:
            One dict per instance: ``{"id": ..., "results": {...}}`` with the
            decoded predictions grouped by task.
        """
        raw_dataset = [self.transform.transform(d) for d in data]
        loader = self.data_manager.prepare_loader(raw_dataset)
        results = []
        for batch in loader:
            batch_out = self.model(**batch, is_eval=True)
            batch["pred"] = batch_out["pred"]
            instances = decompose_batch_into_instances(batch)
            for ins in instances:
                pred_clses = []
                pred_ents = []
                pred_rels = []
                # trigger span -> partially assembled event record
                pred_trigger_to_event = defaultdict(
                    lambda: {"event_type": "", "arguments": []}
                )
                pred_events = []
                pred_spans = []
                pred_discon_ents = []
                pred_hyper_rels = []
                raw_schema = ins["raw"]["schema"]
                for multi_part_span in ins["pred"]:
                    span = tuple(multi_part_span)
                    span_to_label = ins["span_to_label"]
                    if span[0] in span_to_label:
                        # First part points at a schema label token: dispatch
                        # on its task, validating the expected part count.
                        label = span_to_label[span[0]]
                        if label["task"] == "cls" and len(span) == 1:
                            pred_clses.append(label["string"])
                        elif label["task"] == "ent" and len(span) == 2:
                            string = self._convert_span_to_string(
                                span[1], ins["input_ids"], self.transform.tokenizer
                            )
                            pred_ents.append((label["string"], string))
                        elif label["task"] == "rel" and len(span) == 3:
                            head = self._convert_span_to_string(
                                span[1], ins["input_ids"], self.transform.tokenizer
                            )
                            tail = self._convert_span_to_string(
                                span[2], ins["input_ids"], self.transform.tokenizer
                            )
                            pred_rels.append((label["string"], head, tail))
                        elif label["task"] == "event":
                            # "lm": (label, trigger); "lr": (label, trigger, arg)
                            if label["type"] == "lm" and len(span) == 2:
                                pred_trigger_to_event[span[1]]["event_type"] = label[
                                    "string"
                                ]
                            elif label["type"] == "lr" and len(span) == 3:
                                arg = self._convert_span_to_string(
                                    span[2], ins["input_ids"], self.transform.tokenizer
                                )
                                pred_trigger_to_event[span[1]]["arguments"].append(
                                    {"argument": arg, "role": label["string"]}
                                )
                        elif label["task"] == "discontinuous_ent" and len(span) > 1:
                            parts = [
                                self._convert_span_to_string(
                                    part, ins["input_ids"], self.transform.tokenizer
                                )
                                for part in span[1:]
                            ]
                            string = " ".join(part[0] for part in parts)
                            # Each part's positions were already re-based by
                            # _convert_span_to_string, so collect them as-is.
                            # BUGFIX: this used to call self.reset_position(position)
                            # with the input_ids argument missing, raising
                            # TypeError for every discontinuous entity.
                            position = [part[1] for part in parts]
                            pred_discon_ents.append(
                                (label["string"], string, position)
                            )
                        elif (
                            label["task"] == "hyper_rel"
                            and len(span) == 5
                            and span[3] in span_to_label
                        ):
                            # Parts: label, head, tail, qualifier-label, value.
                            q_label = span_to_label[span[3]]
                            span_1 = self._convert_span_to_string(
                                span[1], ins["input_ids"], self.transform.tokenizer
                            )
                            span_2 = self._convert_span_to_string(
                                span[2], ins["input_ids"], self.transform.tokenizer
                            )
                            span_4 = self._convert_span_to_string(
                                span[4], ins["input_ids"], self.transform.tokenizer
                            )
                            pred_hyper_rels.append(
                                (
                                    label["string"],
                                    span_1,
                                    span_2,
                                    q_label["string"],
                                    span_4,
                                )
                            )
                    else:
                        # No label on the first part: decode as a free span.
                        pred_token_ids = []
                        for part in span:
                            _pred_token_ids = [ins["input_ids"][i] for i in part]
                            pred_token_ids.extend(_pred_token_ids)
                        span_string = self.transform.tokenizer.decode(pred_token_ids)
                        pred_spans.append(
                            (
                                span_string,
                                tuple(
                                    tuple(
                                        self.reset_position(
                                            ins["input_ids"].cpu().tolist(), part
                                        )
                                    )
                                    for part in span
                                ),
                            )
                        )
                # Assemble events, filtering by the instance's schema: drop
                # events with unknown types and arguments with illegal roles.
                for trigger, item in pred_trigger_to_event.items():
                    trigger = self._convert_span_to_string(
                        trigger, ins["input_ids"], self.transform.tokenizer
                    )
                    if item["event_type"] not in raw_schema["event"]:
                        continue
                    legal_roles = raw_schema["event"][item["event_type"]]
                    pred_events.append(
                        {
                            "trigger": trigger,
                            "event_type": item["event_type"],
                            "arguments": [
                                arg
                                for arg in item["arguments"]
                                if arg["role"] in legal_roles
                            ],
                        }
                    )
                results.append(
                    {
                        "id": ins["raw"]["id"],
                        "results": {
                            "cls": pred_clses,
                            "ent": pred_ents,
                            "rel": pred_rels,
                            "event": pred_events,
                            "span": pred_spans,
                            "discon_ent": pred_discon_ents,
                            "hyper_rel": pred_hyper_rels,
                        },
                    }
                )

        return results
| |
|
| |
|
if __name__ == "__main__":
    # Intentionally a no-op: this module only defines task classes for the
    # registry and is meant to be imported by a training driver, not run.
    pass
| | |
| |
|
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|