| import copy |
| import os |
| import yaml |
| import json |
| import csv |
| from typing import Any, Dict, List, Optional, Tuple |
| from utils import torch_load, print_message |
| from embedder import get_embedding_filename |
| from base_models.get_base_models import get_tokenizer |
|
|
|
|
| if os.environ.get('WANDB_AVAILABLE') == 'true': |
| import wandb |
|
|
|
|
| class HyperoptModule: |
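| """Wandb hyperparameter sweep helper for a single model/dataset pair: applies |
| trial configs on top of snapshotted probe/trainer args, trains a probe or |
| finetuned model per trial, and collects results so the best configuration |
| can be retrained afterwards (see `run_wandb_hyperopt`).""" |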
| def __init__( |
| self, |
| main_process, |
| model_name: str, |
| data_name: str, |
| dataset: Tuple, |
| emb_dict: Any, |
| sweep_config: Dict[str, Any], |
| results_list: List[Dict[str, Any]], |
| swept_param_keys: Optional[List[str]] = None |
| ): |
| self.mp = main_process |
| self.model_name = model_name |
| self.data_name = data_name |
| self.dataset = dataset |
| self.emb_dict = emb_dict |
| self.sweep_config = sweep_config |
| self.results_list = results_list |
| self.swept_param_keys = swept_param_keys or [] |
| |
| self.base_probe_args = copy.deepcopy(self.mp.probe_args.__dict__) |
| self.base_trainer_args = copy.deepcopy(self.mp.trainer_args.__dict__) |
| |
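| # Routing tables: keys in these sets are applied to probe_args, trainer_args, |
| # or embedding_args respectively; int_keys are cast to int because sweep |
| # values may arrive as floats. |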
| self.probe_keys = { |
| 'hidden_size', 'transformer_hidden_size', 'dropout', 'transformer_dropout', 'n_layers', 'pre_ln', |
| 'classifier_size', 'classifier_dropout', 'n_heads', 'rotary', 'use_bias', 'probe_pooling_types', |
| 'lora', 'lora_r', 'lora_alpha', 'lora_dropout', 'probe_type', 'tokenwise', 'pooling_types' |
| } |
| self.trainer_keys = { |
| 'lr','weight_decay','num_epochs','probe_batch_size', |
| 'base_batch_size','probe_grad_accum','base_grad_accum', |
| 'patience','seed' |
| } |
| self.embedding_keys = { |
| 'embedding_pooling_types' |
| } |
| self.int_keys = { |
| 'hidden_size', 'transformer_hidden_size', 'n_layers', 'classifier_size', 'n_heads', |
| 'lora_r', 'lora_alpha', 'num_epochs', 'probe_batch_size', |
| 'base_batch_size', 'probe_grad_accum', 'base_grad_accum', |
| 'patience', 'seed' |
| } |
|
|
| def apply_config(self, cfg: Dict[str, Any]): |
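| """Reset probe/trainer args to their baseline snapshots, then apply one sweep |
| config: cast integer-valued params, derive `n_heads`, resolve aliases, and |
| route each key to probe_args, trainer_args, or embedding_args.""" |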
| self.mp.probe_args.__dict__.update(copy.deepcopy(self.base_probe_args)) |
| self.mp.trainer_args.__dict__.update(copy.deepcopy(self.base_trainer_args)) |
| |
| |
| for key in self.int_keys: |
| if key in cfg: |
| cfg[key] = int(cfg[key]) |
| |
| # Derive the number of attention heads from the swept hidden size, |
| # targeting roughly 64 dimensions per head. |
| if 'hidden_size' in cfg: |
| cfg['n_heads'] = max(1, cfg['hidden_size'] // 64) |
| if 'transformer_hidden_size' in cfg: |
| cfg['n_heads'] = max(1, cfg['transformer_hidden_size'] // 64) |
|
| # Keep the transformer probe's dropout in sync with the swept 'dropout' value. |
| if 'dropout' in cfg: |
| cfg['transformer_dropout'] = cfg['dropout'] |
|
|
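| # 'probe_pooling_types' is the sweep-facing alias for probe_args.pooling_types. |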
| if 'probe_pooling_types' in cfg: |
| cfg['pooling_types'] = cfg['probe_pooling_types'] |
|
|
| for k, v in cfg.items(): |
| if k in self.probe_keys and hasattr(self.mp.probe_args, k): |
| setattr(self.mp.probe_args, k, v) |
| if k in self.trainer_keys and hasattr(self.mp.trainer_args, k): |
| setattr(self.mp.trainer_args, k, v) |
| |
| if k in self.embedding_keys: |
| # embedding pooling is stored as a list, even when a single type is swept |
| self.mp.embedding_args.pooling_types = [v] if isinstance(v, str) else v |
|
|
| def train_model(self, sweep_mode=True): |
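| """Train a single model under the current args and return |
| (model, valid_metrics, test_metrics). Dispatches to full finetuning, the |
| hybrid probe, or the plain NN probe based on `full_args`.""" |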
| train_set, valid_set, test_set, _, _, ppi = self.dataset |
| |
| if self.mp.full_args.full_finetuning: |
| model, valid_metrics, test_metrics = self.mp._run_full_finetuning( |
| self.model_name, self.data_name, |
| train_set, valid_set, test_set, |
| ppi=ppi, sweep_mode=sweep_mode |
| ) |
| return model, valid_metrics, test_metrics |
|
|
| elif self.mp.full_args.hybrid_probe: |
| tokenizer = get_tokenizer(self.model_name) |
| model, valid_metrics, test_metrics = self.mp._run_hybrid_probe( |
| self.model_name, self.data_name, |
| train_set, valid_set, test_set, |
| tokenizer, |
| emb_dict=self.emb_dict, |
| ppi=ppi, |
| sweep_mode=sweep_mode |
| ) |
| return model, valid_metrics, test_metrics |
|
|
| else: |
| tokenizer = get_tokenizer(self.model_name) |
| probe, valid_metrics, test_metrics = self.mp._run_nn_probe( |
| self.model_name, self.data_name, |
| train_set, valid_set, test_set, |
| tokenizer, |
| emb_dict=self.emb_dict, |
| ppi=ppi, |
| sweep_mode=sweep_mode |
| ) |
| return probe, valid_metrics, test_metrics |
|
|
| def select_metric(self, valid_metrics: Dict[str, Any], test_metrics: Dict[str, Any], sweep_metric: str) -> float: |
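| """Return `sweep_metric` from the validation metrics if present, otherwise |
| from the test metrics; raise KeyError if it is missing from both.""" |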
| if valid_metrics and sweep_metric in valid_metrics: |
| return float(valid_metrics[sweep_metric]) |
| elif test_metrics and sweep_metric in test_metrics: |
| return float(test_metrics[sweep_metric]) |
| |
| |
| available_keys = [] |
| if valid_metrics: |
| available_keys.extend(valid_metrics.keys()) |
| if test_metrics: |
| available_keys.extend(test_metrics.keys()) |
| raise KeyError(f"Metric '{sweep_metric}' not found in validation or test metrics. Available metrics: {available_keys}") |
|
|
| def objective(self): |
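| """One wandb sweep trial: read the trial config, apply it, train, log the |
| resulting metrics, append them to `results_list`, and return the value of |
| the dataset's sweep metric.""" |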
| run = wandb.init( |
| project=self.mp.full_args.wandb_project, |
| entity=self.mp.full_args.wandb_entity, |
| config=self.sweep_config, |
| reinit=True, |
| tags=["sweep", f"model:{self.model_name}", f"data:{self.data_name}"], |
| ) |
| run.name = f"sweep-{self.model_name}_{self.data_name}-{run.id[:6]}" |
| |
| |
| full_config = dict(wandb.config) |
| self.apply_config(full_config) |
| |
| applied_config = {k: v for k, v in full_config.items() if k in self.swept_param_keys} |
| self.mp.trainer_args.make_plots = False |
| |
| |
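| # If the embedding pooling itself is being swept, the cached embeddings and |
| # the probe's input size have to be re-resolved for this trial. |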
| if 'embedding_pooling_types' in full_config and not self.mp.full_args.full_finetuning: |
| _, _, _, _, _, ppi = self.dataset |
| tokenizer = get_tokenizer(self.model_name) |
| test_seq = self.mp.all_seqs[0] |
| |
| if self.mp._sql: |
| filename = get_embedding_filename(self.model_name, self.mp._full, |
| self.mp.embedding_args.pooling_types, 'db') |
| save_path = os.path.join(self.mp.embedding_args.embedding_save_dir, filename) |
| input_dim = self.mp.get_embedding_dim_sql(save_path, test_seq, tokenizer) |
| self.emb_dict = None |
| else: |
| filename = get_embedding_filename(self.model_name, self.mp._full, |
| self.mp.embedding_args.pooling_types, 'pth') |
| save_path = os.path.join(self.mp.embedding_args.embedding_save_dir, filename) |
| self.emb_dict = torch_load(save_path) |
| input_dim = self.mp.get_embedding_dim_pth(self.emb_dict, test_seq, tokenizer) |
| |
| self.mp.probe_args.input_size = input_dim * 2 if (ppi and not self.mp._full) else input_dim |
| |
| _, valid_metrics, test_metrics = self.train_model(sweep_mode=True) |
| |
| |
| label_type = self.mp.probe_args.task_type |
| metric_cls = getattr(self.mp.full_args, 'sweep_metric_cls', None) |
| metric_reg = getattr(self.mp.full_args, 'sweep_metric_reg', None) |
| dataset_metric = metric_cls if label_type in ["singlelabel", "multilabel"] else metric_reg |
|
|
| # Log validation and test metrics in a single call; keys present in both |
| # dicts end up with the test value. |
| all_metrics = {} |
| if isinstance(valid_metrics, dict): |
| all_metrics.update(valid_metrics) |
| if isinstance(test_metrics, dict): |
| all_metrics.update(test_metrics) |
| wandb.log(all_metrics) |
| |
| metric_value = self.select_metric(valid_metrics, test_metrics, dataset_metric) |
| |
| self.results_list.append({ |
| "wandb_run_id": run.id, |
| dataset_metric: metric_value, |
| "config": applied_config, |
| "valid_metrics": valid_metrics, |
| "test_metrics": test_metrics, |
| }) |
| |
| run.finish() |
| return float(metric_value) |
|
|
| @classmethod |
| def run_wandb_hyperopt(cls, mp): |
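| """Load the sweep config, filter its parameters by probe type and LoRA usage, |
| then for every model/dataset pair launch a wandb sweep, write the ranked |
| results to CSV, and retrain the best configuration in a final wandb run.""" |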
| mp.logger.info("Called method: run_wandb_hyperopt") |
|
|
| sweep_config_path = mp.full_args.sweep_config_path |
| if not os.path.exists(sweep_config_path): |
| raise ValueError(f"Sweep config file not found: {sweep_config_path}") |
| with open(sweep_config_path, 'r') as f: |
| sweep_config = yaml.safe_load(f) |
|
|
| params_to_hyperopt = sweep_config.get("parameters", {}) |
| |
| |
| probe_type = getattr(mp.probe_args, 'probe_type', 'linear') |
| use_lora = getattr(mp.probe_args, 'lora', False) |
| |
| |
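| # Restrict the sweep to parameters that are meaningful for the configured |
| # probe type; LoRA parameters are included only when LoRA is enabled. |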
| linear_probe_params = {'lr', 'weight_decay', 'hidden_size', 'n_layers', 'dropout', 'pre_ln', 'use_bias', 'probe_batch_size'} |
| transformer_probe_params = {'lr', 'weight_decay', 'transformer_hidden_size', 'n_layers', 'transformer_dropout', 'pre_ln', |
| 'classifier_dropout', 'classifier_size', 'use_bias', 'probe_pooling_types', 'embedding_pooling_types', 'probe_batch_size'} |
| lora_params = {'lora_r', 'lora_alpha', 'lora_dropout'} |
| |
| |
| if probe_type == 'linear': |
| relevant_params = linear_probe_params |
| elif probe_type == 'transformer': |
| relevant_params = transformer_probe_params |
| else: |
| # unknown probe type: sweep the union of linear and transformer parameters |
| relevant_params = linear_probe_params | transformer_probe_params |
| |
| |
| if use_lora: |
| relevant_params = relevant_params | lora_params |
| |
| |
| params_to_hyperopt = {k: v for k, v in params_to_hyperopt.items() if k in relevant_params} |
| |
| |
| mp.logger.info(f"Probe type: {probe_type}, LoRA enabled: {use_lora}") |
| mp.logger.info(f"Sweeping over {len(params_to_hyperopt)} parameters: {list(params_to_hyperopt.keys())}") |
|
|
| method = mp.full_args.sweep_method |
| early_term = sweep_config.get("early_terminate", None) |
|
|
| total_combinations = len(mp.model_args.model_names) * len(mp.datasets) |
| mp.logger.info(f"Hyperopt over {total_combinations} model/dataset combinations") |
| for model_name in mp.model_args.model_names: |
| tokenizer = get_tokenizer(model_name) |
| test_seq = mp.all_seqs[0] |
|
|
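| # Baseline models (random / one-hot) are not swept; train them once per |
| # dataset with the default hyperparameters instead. |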
| if "random" in model_name.lower() or "onehot" in model_name.lower(): |
| print_message(f"Skipping hyperparameter optimization for {model_name}.") |
|
|
| for data_name, dataset in mp.datasets.items(): |
| train_set, valid_set, test_set, num_labels, label_type, ppi = dataset |
| mp.probe_args.num_labels = num_labels |
| mp.probe_args.task_type = label_type |
| mp.trainer_args.task_type = label_type |
| mp.trainer_args.make_plots = True |
|
|
| emb_dict = None |
| if not mp.full_args.full_finetuning: |
| if mp._sql: |
| filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'db') |
| save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename) |
| input_dim = mp.get_embedding_dim_sql(save_path, test_seq, tokenizer) |
| else: |
| filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'pth') |
| save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename) |
| emb_dict = torch_load(save_path) |
| input_dim = mp.get_embedding_dim_pth(emb_dict, test_seq, tokenizer) |
| mp.probe_args.input_size = input_dim * 2 if (ppi and not mp._full) else input_dim |
| if mp.full_args.full_finetuning: |
| _ = mp._run_full_finetuning(model_name, data_name, train_set, valid_set, test_set, ppi, sweep_mode=False) |
| elif mp.full_args.hybrid_probe: |
| _ = mp._run_hybrid_probe(model_name, data_name, train_set, valid_set, test_set, tokenizer, emb_dict=emb_dict, ppi=ppi, sweep_mode=False) |
| else: |
| _ = mp._run_nn_probe(model_name, data_name, train_set, valid_set, test_set, tokenizer, emb_dict=emb_dict, ppi=ppi, sweep_mode=False) |
| continue |
|
|
| for data_name, dataset in mp.datasets.items(): |
| mp.logger.info(f"Sweeping over {data_name} with {model_name}") |
| train_set, _, _, num_labels, label_type, ppi = dataset |
| mp.probe_args.num_labels = num_labels |
| mp.probe_args.task_type = label_type |
| mp.trainer_args.task_type = label_type |
|
|
| emb_dict = None |
| if not mp.full_args.full_finetuning: |
| if mp._sql: |
| filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'db') |
| save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename) |
| input_dim = mp.get_embedding_dim_sql(save_path, test_seq, tokenizer) |
| else: |
| filename = get_embedding_filename(model_name, mp._full, mp.embedding_args.pooling_types, 'pth') |
| save_path = os.path.join(mp.embedding_args.embedding_save_dir, filename) |
| emb_dict = torch_load(save_path) |
| input_dim = mp.get_embedding_dim_pth(emb_dict, test_seq, tokenizer) |
| mp.probe_args.input_size = input_dim * 2 if (ppi and not mp._full) else input_dim |
|
|
| |
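| # Snapshot the dataset-specific defaults so the best configuration can be |
| # applied to a clean baseline once the sweep finishes. |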
| base_probe = copy.deepcopy(mp.probe_args.__dict__) |
| base_trainer = copy.deepcopy(mp.trainer_args.__dict__) |
|
|
| results_list = [] |
| |
| metric_cls = getattr(mp.full_args, 'sweep_metric_cls', None) |
| metric_reg = getattr(mp.full_args, 'sweep_metric_reg', None) |
| dataset_metric = metric_cls if label_type in ["singlelabel", "multilabel"] else metric_reg |
| |
| hyperopt_module = cls( |
| main_process=mp, |
| model_name=model_name, |
| data_name=data_name, |
| dataset=dataset, |
| emb_dict=emb_dict, |
| sweep_config=sweep_config, |
| results_list=results_list, |
| swept_param_keys=list(params_to_hyperopt.keys()) |
| ) |
|
|
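| # Register the sweep with wandb and run `sweep_count` trials, each of which |
| # calls `objective`. |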
| wb_sweep = { |
| "method": method, |
| "metric": {"name": dataset_metric, "goal": mp.full_args.sweep_goal}, |
| "early_terminate": early_term, |
| "parameters": params_to_hyperopt, |
| } |
| sweep_id = wandb.sweep(sweep=wb_sweep, project=mp.full_args.wandb_project, entity=mp.full_args.wandb_entity) |
| wandb.agent(sweep_id, function=hyperopt_module.objective, count=mp.full_args.sweep_count) |
|
|
| |
| # Rank trials best-first according to the sweep goal. |
| reverse_flag = (mp.full_args.sweep_goal == "maximize") |
| results_list.sort(key=lambda x: x[dataset_metric], reverse=reverse_flag) |
| sweep_log_path = os.path.join(mp.full_args.log_dir, f"{mp.random_id}_sweep_{data_name}_{model_name}.csv") |
| with open(sweep_log_path, 'w', newline='', encoding='utf-8') as f: |
| writer = csv.writer(f, delimiter=',') |
| |
| columns = ["rank","wandb_run_id",dataset_metric,"config","valid_metrics","test_metrics"] |
| writer.writerow(columns) |
| for idx, res in enumerate(results_list, start=1): |
| writer.writerow([ |
| idx, |
| res['wandb_run_id'], |
| res[dataset_metric], |
| json.dumps(res['config']), |
| json.dumps(res['valid_metrics']), |
| json.dumps(res['test_metrics']), |
| ]) |
|
|
| |
| if not results_list: |
| raise RuntimeError(f"Sweep for {model_name} on {data_name} produced no results") |
| best = results_list[0] |
| best_score = best[dataset_metric] |
| best_config = best['config'] |
| print_message(f"Best sweep result - {dataset_metric}: {best_score}") |
| print_message(f"Best hyperparameters: {json.dumps(best_config, indent=2)}") |
|
|
| |
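| # Restore the pre-sweep defaults, then apply the best configuration on top |
| # for the final training run. |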
| mp.probe_args.__dict__.update(copy.deepcopy(base_probe)) |
| mp.trainer_args.__dict__.update(copy.deepcopy(base_trainer)) |
| hyperopt_module.apply_config(best_config) |
| mp.trainer_args.make_plots = True |
| |
| final_config = { |
| **best_config, |
| 'probe_batch_size': mp.trainer_args.probe_batch_size, |
| 'seed': mp.trainer_args.seed, |
| 'patience': mp.trainer_args.patience, |
| 'num_epochs': mp.trainer_args.num_epochs, |
| } |
| print_message(f"Final training config: {json.dumps(final_config, indent=2)}") |
|
|
| |
| final_run = wandb.init( |
| project=mp.full_args.wandb_project, |
| entity=mp.full_args.wandb_entity, |
| config=final_config, |
| reinit=True, |
| tags=["final_model", f"model:{model_name}", f"data:{data_name}", f"best_sweep_score:{best_score}"], |
| name=f"final-{model_name}_{data_name}-best", |
| ) |
|
|
| |
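| # Retrain with the best configuration and log its metrics under a "final_" prefix. |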
| _, valid_metrics, test_metrics = hyperopt_module.train_model(sweep_mode=False) |
| |
| |
| all_final_metrics = {} |
| if isinstance(valid_metrics, dict): |
| for k, v in valid_metrics.items(): |
| all_final_metrics[f"final_{k}"] = v |
| if isinstance(test_metrics, dict): |
| for k, v in test_metrics.items(): |
| all_final_metrics[f"final_{k}"] = v |
| wandb.log(all_final_metrics) |
| |
| final_run.finish() |