# Source code for trlx.data.method_configs

import sys
from dataclasses import dataclass
from typing import Any, Dict, Type

# Registry of method configs: maps each lowercased registration name to the
# class registered under it. Written by `register_method` and read by
# `get_method`.
_METHODS: Dict[str, Any] = {}  # registry


def register_method(name):
    """Decorator used to register a method config.

    Supports two forms:

    * ``@register_method("SomeName")`` -- registers the decorated class under
      the lowercased string.
    * ``@register_method`` (bare) -- registers the decorated class under its
      own lowercased ``__name__``.

    In both forms the class is stored in ``_METHODS`` and also set as an
    attribute of this module under the same lowercased key.

    Args:
        name: Either the name to register under (``str``) or, when used as a
            bare decorator, the class being registered.

    Returns:
        The class itself for the bare form, or a one-argument decorator that
        registers and returns the class for the string form.
    """

    def _bind(target, key):
        # Record the class in the registry and expose it on this module.
        _METHODS[key] = target
        setattr(sys.modules[__name__], key, target)
        return target

    if not isinstance(name, str):
        # Bare usage: `name` is actually the class being decorated.
        return _bind(name, name.__name__.lower())

    key = name.lower()

    def decorator(target):
        return _bind(target, key)

    return decorator


# `register_method` is applied first, in its bare form, so the class is
# registered under its lowercased class name before `dataclass` processes it.
@dataclass
@register_method
class MethodConfig:
    """
    Config for a certain RL method.

    :param name: Name of the method
    :type name: str
    """

    name: str

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        """
        Build a config instance from a dictionary of field values.

        :param config: Mapping whose items are passed as keyword arguments to
            the constructor of ``cls``.
        :type config: Dict[str, Any]
        :return: A new instance of ``cls``.
        """
        return cls(**config)
def get_method(name: str) -> Type[MethodConfig]:
    """
    Return constructor for specified method config

    The lookup is case-insensitive: ``name`` is lowercased before it is
    looked up in the ``_METHODS`` registry.

    :param name: Name the method config was registered under.
    :type name: str
    :return: The registered class (not an instance).
    :raises Exception: If no method config is registered under ``name``.
    """
    name = name.lower()
    if name in _METHODS:
        return _METHODS[name]
    raise Exception("Error: Trying to access a method that has not been registered")