# Source code for trlx.orchestrator

import sys
from abc import abstractmethod
from typing import Dict

from trlx.model import BaseRLModel
from trlx.pipeline import BasePipeline

# Registry of orchestrators: maps a lowercased name to the registered class.
# Populated by the register_orchestrator decorator.
_ORCH: Dict[str, type] = {}  # registry


def register_orchestrator(name):
    """Decorator that registers an orchestrator class in the module registry.

    Supports two forms:

    * Bare: ``@register_orchestrator`` registers the class under its
      lowercased ``__name__``.
    * Named: ``@register_orchestrator("MyOrch")`` registers the class under
      the given string, lowercased.

    In both forms the class is stored in ``_ORCH`` under the lowercased name.
    It is also bound as an attribute of this module under that same name.
    Registering a name that already exists silently replaces the earlier
    entry.

    Args:
        name: Either the class being decorated (bare form) or the string name
            to register the class under (named form).

    Returns:
        The bare form returns the decorated class unchanged. The named form
        returns a decorator that registers its argument and returns it
        unchanged.
    """

    def register_class(cls, key):
        _ORCH[key] = cls
        # Also expose the class as an attribute of this module under its key.
        setattr(sys.modules[__name__], key, cls)
        return cls

    if isinstance(name, str):
        # Named form: capture the lowercased key and defer registration until
        # the returned decorator is applied to a class.
        key = name.lower()
        return lambda c: register_class(c, key)

    # Bare form: ``name`` is actually the decorated class itself.
    cls = name
    register_class(cls, cls.__name__.lower())
    return cls


@register_orchestrator
class Orchestrator:
    """Base orchestrator, registered under the name ``"orchestrator"``.

    Holds a data pipeline and an RL model. Subclasses implement
    :meth:`make_experience` to turn pipeline data into experience for the
    model.

    Args:
        pipeline: Pipeline that experience is drawn from.
        rl_model: RL model used to act on the pipeline data.
    """

    def __init__(self, pipeline: BasePipeline, rl_model: BaseRLModel):
        self.pipeline = pipeline
        self.rl_model = rl_model

    # NOTE(review): Orchestrator does not derive from abc.ABC, so this
    # decorator is not enforced. A subclass that forgets to override
    # make_experience can still be instantiated, and the base implementation
    # silently returns None.
    @abstractmethod
    def make_experience(self):
        """
        Draw from pipeline, get action, generate reward
        Push to models RolloutStorage
        """
        pass