| import os |
| from typing import Tuple |
| from functools import reduce |
|
|
| from argparse import Namespace |
| from omegaconf import DictConfig, OmegaConf |
|
|
|
|
| |
| |
| |
|
|
def register_resolver():
    """Register the arithmetic resolvers used by the configs.

    Three OmegaConf resolvers are registered:

    * ``add``: sum of all arguments.
    * ``multiply``: product of all arguments (at least one is required).
    * ``sub``: first argument minus the second.

    Every resolver is registered with ``replace=True`` so that calling this
    function more than once in the same process does not fail with an
    "already registered" error.
    """
    resolvers = {
        "add": lambda *numbers: sum(numbers),
        "multiply": lambda *numbers: reduce(lambda x, y: x * y, numbers),
        "sub": lambda n1, n2: n1 - n2,
    }
    for name, resolver in resolvers.items():
        OmegaConf.register_new_resolver(name, resolver, replace=True)
|
|
|
|
def _merge_args_and_config(
        cmd_args: Namespace,
        yaml_config: DictConfig,
        read_only: bool = False
) -> Tuple[DictConfig, DictConfig, DictConfig]:
    """Overlay parsed command-line arguments on top of a YAML config.

    Args:
        cmd_args: argparse namespace; each attribute is rendered as a
            ``key=value`` string and parsed with ``OmegaConf.from_cli``.
        yaml_config: base configuration.
        read_only: when true, the merged config is flagged read-only.

    Returns:
        A ``(merged, cmd_args_conf, yaml_config)`` tuple: the merged config,
        the config built from the command-line arguments, and the
        ``yaml_config`` argument as it was passed in.
    """
    cli_overrides = [
        f"{name}={value}" for name, value in vars(cmd_args).items()
    ]
    cmd_args_conf = OmegaConf.from_cli(cli_overrides)

    # The command-line config is merged last, on top of the YAML config.
    merged = OmegaConf.merge(yaml_config, cmd_args_conf)
    if read_only:
        OmegaConf.set_readonly(merged, True)

    return merged, cmd_args_conf, yaml_config
|
|
|
|
def merge_configs(args, method_cfg_path):
    """Merge command line args (argparse) with a YAML config file (OmegaConf).

    Args:
        args: parsed argparse namespace.
        method_cfg_path: config file path, joined onto ``./config``.

    Returns:
        The ``(merged, cmd_args_conf, yaml_config)`` tuple produced by
        ``_merge_args_and_config``.

    Raises:
        FileNotFoundError: if the resolved config path does not exist.
    """
    yaml_config_path = os.path.join("./", "config", method_cfg_path)
    try:
        yaml_config = OmegaConf.load(yaml_config_path)
    except FileNotFoundError as e:
        print(f"error: {e}")
        print(f"input file path: `{method_cfg_path}`")
        print(f"config path: `{yaml_config_path}` not found.")
        # Re-raise the original exception so errno, filename and the
        # traceback are kept instead of being wrapped in a new instance.
        raise
    return _merge_args_and_config(args, yaml_config, read_only=False)
|
|
|
|
def update_configs(source_args, update_nodes, strict=True, remove_update_nodes=True):
    """Update a config (OmegaConf) with a whitespace-separated dotlist.

    Args:
        source_args: config to update.
        update_nodes: ``key=value`` overrides separated by whitespace, e.g.
            ``"a.b=1 c=2"``. ``None`` or a blank string returns
            ``source_args`` unchanged.
        strict: when true, refuse to override a key whose current value is
            missing.
        remove_update_nodes: when true, the ``update`` entry of the result is
            reset to an empty string.

    Returns:
        ``source_args`` itself when there is nothing to apply, otherwise the
        result of merging the overrides on top of ``source_args``.

    Raises:
        AssertionError: in strict mode, if an overridden key is missing.
    """
    if update_nodes is None:
        return source_args

    update_args_list = str(update_nodes).split()
    if not update_args_list:
        return source_args

    if strict:
        for item in update_args_list:
            item_key = item.partition('=')[0]
            # Raised explicitly (rather than via ``assert``) so the check
            # still runs when Python is started with -O.
            if OmegaConf.is_missing(source_args, item_key):
                raise AssertionError(f"the value of {item_key} is missing.")

    # Keys that do not exist yet are handled by the merge itself; nothing is
    # written into ``source_args`` beforehand.
    overrides = OmegaConf.from_dotlist(update_args_list)
    merged_args = OmegaConf.merge(source_args, overrides)

    if remove_update_nodes:
        OmegaConf.update(merged_args, 'update', '')
    return merged_args
|
|
|
|
def update_if_exist(source_args, update_nodes):
    """Update a config (OmegaConf) with a whitespace-separated dotlist.

    Note: despite the name, no existence check is performed; every override
    in ``update_nodes`` is passed to the merge.

    Args:
        source_args: config to update.
        update_nodes: ``key=value`` overrides separated by whitespace.
            ``None`` or a blank string returns ``source_args`` unchanged.

    Returns:
        ``source_args`` itself when there is nothing to apply, otherwise the
        result of merging the overrides on top of ``source_args``.
    """
    if update_nodes is None:
        return source_args

    update_args_list = str(update_nodes).split()
    if not update_args_list:
        return source_args

    overrides = OmegaConf.from_dotlist(update_args_list)
    return OmegaConf.merge(source_args, overrides)
|
|
|
|
def merge_and_update_config(args):
    """Build the final configuration from parsed command-line arguments.

    Steps, in order:

    1. Register the custom resolvers.
    2. If ``args.config`` ends with ``.yaml``, merge that file with ``args``;
       otherwise ``args`` itself is used as both the merged and the
       command-line config, and the YAML config is ``None``.
    3. Apply the ``args.update`` overrides to the merged config, the YAML
       config and the command-line config.
    4. Attach the updated YAML and command-line configs to the result as
       ``yaml_config`` and ``cmd_args``.
    5. Replace a negative ``seed`` with a random integer in ``[0, 65535]``.

    Args:
        args: parsed arguments; ``config``, ``update`` and ``seed`` are read.

    Returns:
        The final configuration object.
    """
    register_resolver()

    config_name = args.config
    uses_yaml = config_name is not None and str(config_name).endswith('.yaml')
    if uses_yaml:
        merged, cli_conf, file_conf = merge_configs(args, config_name)
    else:
        # No YAML file: the parsed arguments stand in for both configs.
        merged, cli_conf, file_conf = args, args, None

    overrides = args.update
    final_args = update_configs(merged, overrides)

    # Apply the same overrides to the individual sources as well.
    file_conf_updated = update_if_exist(file_conf, overrides)
    cli_conf_updated = update_if_exist(cli_conf, overrides)
    cli_conf_updated.update = ""

    final_args.yaml_config = file_conf_updated
    final_args.cmd_args = cli_conf_updated

    # A negative seed is replaced by a randomly drawn one.
    if final_args.seed < 0:
        import random
        final_args.seed = random.randint(0, 65535)

    return final_args
|
|