Source code for trlx.pipeline

import random
import sys
from abc import abstractmethod, abstractstaticmethod
from typing import Any, Callable, Dict, Iterable

from datasets import load_from_disk
from torch.utils.data import DataLoader, Dataset

from trlx.data import GeneralElement, RLElement

# Registry mapping lowercased pipeline names to the classes registered
# through ``register_datapipeline``.
_DATAPIPELINE: Dict[str, Any] = {}  # registry


def register_datapipeline(name):
    """Register a data pipeline class in the module-level pipeline registry.

    Works both bare (``@register_datapipeline``) and with an explicit name
    (``@register_datapipeline("my_pipeline")``). The registration key is always
    lowercased. The class is stored in ``_DATAPIPELINE`` under that key and is
    also bound as an attribute of this module under the same key.

    Args:
        name: The registration name (a ``str``), or, when the decorator is
            used bare, the class being decorated.

    Returns:
        The class itself when used bare; otherwise a decorator that registers
        its argument and returns it unchanged.
    """

    def _register(target, key):
        # Record in the lookup table and expose it as a module attribute.
        _DATAPIPELINE[key] = target
        setattr(sys.modules[__name__], key, target)
        return target

    if not isinstance(name, str):
        # Bare usage: ``name`` is actually the decorated class.
        return _register(name, name.__name__.lower())

    key = name.lower()

    def decorator(target):
        return _register(target, key)

    return decorator


@register_datapipeline
class BasePipeline(Dataset):
    """Base class for data pipelines exposed as a torch ``Dataset``.

    Subclasses must implement ``__getitem__``, ``__len__`` and
    ``create_loader``. This class does not use ``ABCMeta``, so
    ``@abstractmethod`` is not enforced at instantiation. The abstract methods
    therefore raise ``NotImplementedError`` instead of silently returning
    ``None``.
    """

    def __init__(self, path: str = "dataset"):
        # NOTE(review): ``path`` is accepted but not used here; presumably kept
        # so subclasses share a common constructor signature - confirm.
        super().__init__()

    @abstractmethod
    def __getitem__(self, index: int) -> GeneralElement:
        """Return the element stored at ``index``. Must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of elements in the pipeline. Must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def create_loader(
        self,
        batch_size: int,
        shuffle: bool,
        prep_fn: Callable = None,
        num_workers: int = 0,
    ) -> DataLoader:
        """
        Create a dataloader for the pipeline

        :param prep_fn: Typically a tokenizer. Applied to GeneralElement after collation.
        """
        raise NotImplementedError
class BaseRolloutStore(Dataset):
    """Base class for rollout (experience) storage exposed as a torch ``Dataset``.

    Subclasses implement ``push`` to add experiences and ``create_loader`` to
    build a DataLoader over ``self.history``. This class does not use
    ``ABCMeta``, so ``@abstractmethod`` is not enforced at instantiation. The
    abstract methods therefore raise ``NotImplementedError`` explicitly.
    """

    def __init__(self, capacity=-1):
        super().__init__()
        # Stored experiences; ``None`` until a subclass populates it. Once set it
        # must support ``len()`` and integer indexing (see __len__/__getitem__).
        self.history: Iterable[Any] = None
        # NOTE(review): not read by this base class; presumably -1 means
        # "unbounded" for subclasses - confirm.
        self.capacity = capacity

    @abstractmethod
    def push(self, exps: Iterable[Any]):
        """
        Push experiences to rollout storage
        """
        raise NotImplementedError

    def __getitem__(self, index: int) -> RLElement:
        # An unpopulated store behaves like an empty sequence instead of
        # failing with "'NoneType' object is not subscriptable".
        if self.history is None:
            raise IndexError("rollout store is empty")
        return self.history[index]

    def __len__(self) -> int:
        # Report zero for an unpopulated store rather than raising on len(None).
        return 0 if self.history is None else len(self.history)

    @abstractmethod
    def create_loader(
        self,
        batch_size: int,
        shuffle: bool,
        prep_fn: Callable = None,
        num_workers: int = 0,
    ) -> DataLoader:
        """
        Create a dataloader for the rollout store

        :param prep_fn: Applied to RLElement after collation (typically tokenizer)
        :type prep_fn: Callable
        """
        raise NotImplementedError