import random
import sys
from abc import abstractmethod, abstractstaticmethod
from dataclasses import is_dataclass
from typing import Any, Callable, Dict, Iterable, Optional
from torch.utils.data import DataLoader, Dataset
from transformers.tokenization_utils_base import BatchEncoding
from trlx.data import GeneralElement, RLElement
from trlx.utils import logging
# Registry mapping lowercased names to data pipeline classes
_DATAPIPELINE: Dict[str, Any] = {}
logger = logging.get_logger(__name__)
def register_datapipeline(name):
"""Decorator used register a CARP architecture
Args:
name: Name of the architecture
"""
def register_class(cls, name):
_DATAPIPELINE[name] = cls
setattr(sys.modules[__name__], name, cls)
return cls
if isinstance(name, str):
name = name.lower()
return lambda c: register_class(c, name)
cls = name
name = cls.__name__
register_class(cls, name.lower())
return cls
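# Usage sketch (illustrative only; the class names below are hypothetical): the decorator
# can be applied bare or with an explicit name, and the registry key is lowercased.
#
#     @register_datapipeline
#     class MyPipeline(BasePipeline):        # registered as "mypipeline"
#         ...
#
#     @register_datapipeline("custom_name")
#     class AnotherPipeline(BasePipeline):   # registered as "custom_name"
#         ...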
@register_datapipeline
class BasePipeline(Dataset):
def __init__(self, path: str = "dataset"):
super().__init__()
@abstractmethod
def __getitem__(self, index: int) -> GeneralElement:
pass
@abstractmethod
def __len__(self) -> int:
pass
    @abstractmethod
def create_loader(
self,
batch_size: int,
shuffle: bool,
        prep_fn: Optional[Callable] = None,
num_workers: int = 0,
) -> DataLoader:
"""
Create a dataloader for the pipeline
:param prep_fn: Typically a tokenizer. Applied to GeneralElement after collation.
"""
pass
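# Minimal sketch of a concrete pipeline (assumed names, not part of trlx): it wraps a
# list of prompt strings and builds a DataLoader whose collate step applies ``prep_fn``
# (typically a tokenizer) to the collated batch.
#
#     @register_datapipeline
#     class ListPipeline(BasePipeline):
#         def __init__(self, prompts):
#             super().__init__()
#             self.prompts = list(prompts)
#
#         def __getitem__(self, index: int) -> str:
#             return self.prompts[index]
#
#         def __len__(self) -> int:
#             return len(self.prompts)
#
#         def create_loader(self, batch_size, shuffle, prep_fn=None, num_workers=0):
#             def collate(elems):
#                 return prep_fn(elems) if prep_fn is not None else elems
#
#             return DataLoader(
#                 self, batch_size=batch_size, shuffle=shuffle,
#                 collate_fn=collate, num_workers=num_workers,
#             )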
class BaseRolloutStore(Dataset):
def __init__(self, capacity=-1):
self.history: Iterable[Any] = None
self.capacity = capacity
    @abstractmethod
def push(self, exps: Iterable[Any]):
"""
Push experiences to rollout storage
"""
pass
def __getitem__(self, index: int) -> RLElement:
return self.history[index]
def __len__(self) -> int:
return len(self.history)
    @abstractmethod
def create_loader(
self,
batch_size: int,
shuffle: bool,
        prep_fn: Optional[Callable] = None,
num_workers: int = 0,
) -> DataLoader:
"""
Create a dataloader for the rollout store
:param prep_fn: Applied to RLElement after collation (typically tokenizer)
:type prep_fn: Callable
"""
pass
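# Minimal sketch of a rollout store (assumed names, not part of trlx): ``push`` appends
# experiences to an in-memory list, trimming to ``capacity`` when set, and
# ``create_loader`` serves them through a DataLoader that applies ``prep_fn`` to each
# collated batch.
#
#     class ListRolloutStore(BaseRolloutStore):
#         def __init__(self, capacity=-1):
#             super().__init__(capacity)
#             self.history = []
#
#         def push(self, exps):
#             self.history += list(exps)
#             if self.capacity > 0:
#                 self.history = self.history[-self.capacity:]
#
#         def create_loader(self, batch_size, shuffle, prep_fn=None, num_workers=0):
#             def collate(elems):
#                 return prep_fn(elems) if prep_fn is not None else elems
#
#             return DataLoader(
#                 self, batch_size=batch_size, shuffle=shuffle,
#                 collate_fn=collate, num_workers=num_workers,
#             )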
class MiniBatchIterator:
"""
A custom iterator for generating mini-batches from a PyTorch DataLoader.
"""
def __init__(self, data_loader, mb_size, num_mb):
"""
Initializes the MiniBatchIterator.
Args:
data_loader (torch.utils.data.DataLoader): The DataLoader to generate mini-batches from.
mb_size (int): The size of each mini-batch.
num_mb (int): The number of mini-batches to generate for each iteration.
"""
self.data_loader = data_loader
self.data_loader_iter = iter(data_loader)
self.mb_size = mb_size
self.num_mb = num_mb
def __iter__(self):
return self
def __next__(self): # noqa: C901
batch = next(self.data_loader_iter)
if batch is None:
logger.warning(
"WARNING: Not enough samples to saturate the minibatch size. Increase the number "
"of prompts or samples or decrease the minibatch size."
)
raise StopIteration
minibatches = []
for mbi in range(self.num_mb):
sliced_data = {}
batch_dict = batch
if is_dataclass(batch):
batch_dict = batch.__dict__
for key, value in batch_dict.items():
start_idx = mbi * self.mb_size
end_idx = (mbi + 1) * self.mb_size
sliced_data[key] = value[start_idx:end_idx]
if self.num_mb > 1 and len(sliced_data[key]) == 0:
logger.warning(
"WARNING: MiniBatchIterator generated a minibatch with 0 elements. "
"This may be due to the wrong mb_size and/or num_mb or the last batch"
"in the dataset being smaller."
)
sliced_data.pop(key)
break
elif self.num_mb > 1 and len(sliced_data[key]) < self.mb_size:
logger.warning(
"WARNING: MiniBatchIterator generated a minibatch with fewer elements than mb_size. "
"This may be due to the wrong mb_size and/or num_mb or the last batch in the dataset "
"being smaller."
)
if not sliced_data:
break
if isinstance(batch, BatchEncoding):
minibatch = BatchEncoding(sliced_data)
elif is_dataclass(batch):
minibatch = batch.__class__(**sliced_data)
            else:
                # Plain dict batches (e.g. from the default collate_fn) pass through as dicts
                minibatch = sliced_data
minibatches.append(minibatch)
if not minibatches:
raise StopIteration
return minibatches
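# Self-contained usage sketch (the toy dataset and helper names are hypothetical, not part
# of trlx): each DataLoader batch of 4 rows, collated into a BatchEncoding, is split by
# MiniBatchIterator into 2 minibatches of 2 rows each.
if __name__ == "__main__":
    import torch

    class _ToyDataset(Dataset):
        def __len__(self) -> int:
            return 8

        def __getitem__(self, index: int) -> Dict[str, Any]:
            return {"input_ids": torch.tensor([index, index + 1])}

    def _collate(items):
        return BatchEncoding({"input_ids": torch.stack([item["input_ids"] for item in items])})

    loader = DataLoader(_ToyDataset(), batch_size=4, collate_fn=_collate)
    for minibatches in MiniBatchIterator(loader, mb_size=2, num_mb=2):
        for mb in minibatches:
            print(mb["input_ids"].shape)  # torch.Size([2, 2])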