| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from argparse import Namespace |
| | from typing import Dict, Any |
| |
|
| | import torch |
| |
|
| | from .adaptor_generic import GenericAdaptor, AdaptorBase |
| |
|
# String-keyed mapping of tensors; used as the type of the ``state`` argument
# of AdaptorRegistry.create_adaptor (presumably a model state dict — confirm
# with callers).
state_t = Dict[str, torch.Tensor]
# String-keyed mapping with arbitrary values; used as the type of the
# ``adaptor_config`` argument of AdaptorRegistry.create_adaptor.
dict_t = Dict[str, Any]
| |
|
| |
|
class AdaptorRegistry:
    """Name-keyed registry of adaptor factory callables.

    Factories are added with the :meth:`register_adaptor` decorator and are
    invoked by :meth:`create_adaptor`. Names with no registered factory fall
    back to ``GenericAdaptor``.
    """

    def __init__(self):
        # Adaptor name -> factory, called as
        # factory(main_config, adaptor_config, state).
        self._registry: Dict[str, Any] = {}

    def register_adaptor(self, name: str):
        """Return a decorator that registers a factory under ``name``.

        The decorated factory is stored and returned unchanged, so it remains
        directly callable after decoration.

        Args:
            name: Key under which ``create_adaptor`` will look the factory up.

        Returns:
            A decorator taking the factory callable.

        Raises:
            ValueError: When the decorator is applied and ``name`` is already
                registered. Registration is first-come; it never overwrites.
        """
        def decorator(factory_function):
            if name in self._registry:
                raise ValueError(f"Adaptor '{name}' already registered")
            self._registry[name] = factory_function
            return factory_function
        return decorator

    def create_adaptor(self, name: str, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
        """Build the adaptor registered under ``name``.

        Args:
            name: Registry key. Unknown names are not an error; they fall back
                to ``GenericAdaptor``.
            main_config: Parsed top-level configuration, forwarded unchanged.
            adaptor_config: Adaptor-specific settings, forwarded unchanged.
            state: Mapping of tensors, forwarded unchanged.

        Returns:
            Whatever the selected factory returns for
            ``(main_config, adaptor_config, state)``.
        """
        # Single lookup; GenericAdaptor is the default for unregistered names.
        factory = self._registry.get(name, GenericAdaptor)
        return factory(main_config, adaptor_config, state)
| |
|
| | |
# Module-level registry instance. Factories register on it with
# ``@adaptor_registry.register_adaptor(name)`` and are built through
# ``adaptor_registry.create_adaptor(...)``.
adaptor_registry = AdaptorRegistry()
| |
|