Source code for trlx.pipeline.offline_pipeline

from typing import Iterable

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

from trlx.data.ilql_types import ILQLBatch, ILQLElement
from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline


@register_datapipeline
class PromptPipeline(BasePipeline):
    """
    Tokenizes texts and then pads them into batches
    """

    def __init__(self, prompts, tokenizer=None):
        super().__init__()
        self.tokenizer = tokenizer
        # Tokenize each prompt upfront if a tokenizer is given; otherwise store prompts as-is
        self.prompts = list(map(tokenizer if tokenizer else (lambda x: x), prompts))

    def __getitem__(self, ix: int):
        return self.prompts[ix]

    def __len__(self) -> int:
        return len(self.prompts)

    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
        # Pad tokenized prompts with the tokenizer's pad token; fall back to
        # stacking raw tensors when no tokenizer was provided
        collate_fn = (
            DataCollatorWithPadding(self.tokenizer) if self.tokenizer else torch.vstack
        )
        return DataLoader(
            self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle
        )
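A minimal usage sketch, assuming a Hugging Face GPT-2 tokenizer; the model name and prompt strings are illustrative, not part of this module:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

pipeline = PromptPipeline(["Hello there!", "How do you do?"], tokenizer=tokenizer)
loader = pipeline.create_loader(batch_size=2)
batch = next(iter(loader))  # padded batch with "input_ids" and "attention_mask"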
class ILQLRolloutStorage(BaseRolloutStore):
    """
    Rollout storage for training ILQL
    """

    def __init__(
        self, input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones
    ):
        super().__init__()
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.rewards = rewards
        self.states_ixs = states_ixs
        self.actions_ixs = actions_ixs
        self.dones = dones

    def __getitem__(self, ix: int) -> ILQLElement:
        return ILQLElement(
            self.input_ids[ix],
            self.attention_mask[ix],
            self.rewards[ix],
            self.states_ixs[ix],
            self.actions_ixs[ix],
            self.dones[ix],
        )

    def __len__(self) -> int:
        return len(self.input_ids)

    def create_loader(self, batch_size: int):
        def collate_fn(elems: Iterable[ILQLElement]):
            # Pad every field of the batch to the longest sequence it contains
            return ILQLBatch(
                pad_sequence(
                    [x.input_ids for x in elems], batch_first=True, padding_value=0
                ),
                pad_sequence(
                    [x.attention_mask for x in elems], batch_first=True, padding_value=0
                ),
                pad_sequence(
                    [x.rewards for x in elems], batch_first=True, padding_value=0.0
                ),
                pad_sequence(
                    [x.states_ixs for x in elems], batch_first=True, padding_value=0
                ),
                pad_sequence(
                    [x.actions_ixs for x in elems], batch_first=True, padding_value=0
                ),
                pad_sequence(
                    [x.dones for x in elems], batch_first=True, padding_value=0
                ),
            )

        return DataLoader(
            self, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
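A minimal sketch of building the storage from toy rollouts and drawing one batch; the tensor values and shapes below are illustrative placeholders, not real ILQL data:

import torch

# Two rollouts of different lengths; every field is a list of 1-D tensors.
input_ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
attention_mask = [torch.ones(3, dtype=torch.long), torch.ones(2, dtype=torch.long)]
rewards = [torch.tensor([0.0, 1.0]), torch.tensor([0.5])]
states_ixs = [torch.tensor([0, 1, 2]), torch.tensor([0, 1])]
actions_ixs = [torch.tensor([1, 2]), torch.tensor([1])]
dones = [torch.tensor([0, 0, 1]), torch.tensor([0, 1])]

storage = ILQLRolloutStorage(
    input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones
)
loader = storage.create_loader(batch_size=2)
batch = next(iter(loader))  # an ILQLBatch whose fields are padded to equal length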