| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Create a registry.""" |
| |
|
| | from typing import Optional, List |
| | import json |
| | from json import JSONDecodeError |
| |
|
| |
|
class Registry:
    """Create the registry that will map name to object.

    This facilitates the users to create custom registry.

    Parameters
    ----------
    name
        The name of the registry

    Examples
    --------

    >>> from prediff.utils.registry import Registry
    >>> # Create a registry
    >>> MODEL_REGISTRY = Registry('MODEL')
    >>>
    >>> # To register a class/function with decorator
    >>> @MODEL_REGISTRY.register()
    ... class MyModel:
    ...     pass
    >>> @MODEL_REGISTRY.register()
    ... def my_model():
    ...     return
    >>>
    >>> # To register a class object with decorator and provide nickname:
    >>> @MODEL_REGISTRY.register('test_class')
    ... class MyModelWithNickName:
    ...     pass
    >>> @MODEL_REGISTRY.register('test_function')
    ... def my_model_with_nick_name():
    ...     return
    >>>
    >>> # To register a class/function object by function call
    >>> class MyModel2:
    ...     pass
    >>> _ = MODEL_REGISTRY.register(MyModel2)
    >>> # To register with a given name
    >>> _ = MODEL_REGISTRY.register('my_model2', MyModel2)
    >>> # To list all the registered objects:
    >>> MODEL_REGISTRY.list_keys()
    ['MyModel', 'my_model', 'test_class', 'test_function', 'MyModel2', 'my_model2']
    >>> # To get the registered object/class
    >>> MODEL_REGISTRY.get('test_class')
    <class '__main__.MyModelWithNickName'>

    """

    def __init__(self, name: str) -> None:
        """Initialize an empty registry called ``name``."""
        self._name: str = name
        # Maps the registered name to the registered object, in insertion order.
        self._obj_map: dict[str, object] = {}

    def _do_register(self, name: str, obj: object) -> None:
        """Store ``obj`` under ``name``.

        Raises
        ------
        AssertionError
            If ``name`` is already present in this registry.
        """
        # Raised explicitly (instead of using an ``assert`` statement) so the
        # duplicate check is not stripped when Python runs with ``-O``; the
        # exception type is the same one the assert statement would raise.
        if name in self._obj_map:
            raise AssertionError(
                "An object named '{}' was already registered in '{}' registry!".format(
                    name, self._name
                )
            )
        self._obj_map[name] = obj

    def register(self, *args):
        """
        Register the given object under either the nickname or `obj.__name__`. It can be used as
        either a decorator or not. See docstring of this class for usage.

        Supported call shapes:

        - ``register()`` returns a decorator that registers under ``obj.__name__``.
        - ``register('nickname')`` returns a decorator that registers under the nickname.
        - ``register(obj)`` registers ``obj`` under ``obj.__name__`` and returns ``obj``,
          which also makes the bare ``@registry.register`` decorator form work.
        - ``register('nickname', obj)`` registers ``obj`` under the nickname and returns ``obj``.

        Raises
        ------
        ValueError
            If more than two positional arguments are given.
        AssertionError
            If the target name is already registered.
        """
        if len(args) > 2:
            raise ValueError('Do not support the usage!')

        if len(args) == 2:
            # Register an object with a nickname by function call.
            nickname, obj = args
            self._do_register(nickname, obj)
            return obj

        if len(args) == 1:
            target = args[0]
            if isinstance(target, str):
                # Register an object with a nickname by decorator.
                def deco_with_nickname(func_or_class: object) -> object:
                    self._do_register(target, func_or_class)
                    return func_or_class

                return deco_with_nickname

            # Register an object by function call (or as a bare decorator).
            # Returning the object keeps ``@registry.register`` from rebinding
            # the decorated name to None.
            self._do_register(target.__name__, target)
            return target

        # No arguments: register by decorator under the object's own name.
        def deco(func_or_class: object) -> object:
            self._do_register(func_or_class.__name__, func_or_class)
            return func_or_class

        return deco

    def get(self, name: str) -> object:
        """Return the object registered under ``name``.

        Raises
        ------
        KeyError
            If nothing is registered under ``name``.
        """
        try:
            # Real key lookup: a key whose stored value happens to be None is
            # still found, consistent with what ``list_keys`` reports.
            return self._obj_map[name]
        except KeyError:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(
                    name, self._name
                )
            ) from None

    def list_keys(self) -> List:
        """Return the registered names, in registration order."""
        return list(self._obj_map)

    def __repr__(self) -> str:
        return '{name}(keys={keys})'.format(name=self._name,
                                            keys=self.list_keys())

    def create(self, name: str, *args, **kwargs) -> object:
        """Create the class object with the given args and kwargs

        Parameters
        ----------
        name
            The name in the registry
        args
            Positional arguments forwarded to the registered callable
        kwargs
            Keyword arguments forwarded to the registered callable

        Returns
        -------
        ret
            The created object

        Raises
        ------
        KeyError
            If ``name`` is not registered.
        Exception
            Whatever the registered callable raises is re-raised unchanged after
            the failing call has been printed.
        """
        obj = self.get(name)
        try:
            return obj(*args, **kwargs)
        except Exception:
            print('Cannot create name="{}" --> {} with the provided arguments!\n'
                  ' args={},\n'
                  ' kwargs={},\n'
                  .format(name, obj, args, kwargs))
            # Bare ``raise`` re-raises the active exception without adding an
            # extra traceback entry for this line.
            raise

    def create_with_json(self, name: str, json_str: str):
        """Create the registered object using arguments decoded from a JSON string.

        A JSON array is passed as positional arguments and a JSON object is
        passed as keyword arguments to :meth:`create`.

        Parameters
        ----------
        name
            The name in the registry
        json_str
            JSON text that decodes to a list or a dict

        Returns
        -------
        ret
            The created object

        Raises
        ------
        ValueError
            If ``json_str`` is not valid JSON.
        NotImplementedError
            If the decoded JSON value is neither a list nor a dict.
        """
        try:
            decoded = json.loads(json_str)
        except JSONDecodeError as err:
            # Chain the decoder error so the position of the syntax problem
            # stays visible in the traceback.
            raise ValueError('Unable to decode the json string: json_str="{}"'
                             .format(json_str)) from err
        if isinstance(decoded, list):
            return self.create(name, *decoded)
        if isinstance(decoded, dict):
            return self.create(name, **decoded)
        raise NotImplementedError('The format of json string is not supported! We only support '
                                  'list/dict. json_str="{}".'
                                  .format(json_str))
| |
|