Pipelines

Pipelines are used to accumulate training data and convert it into the appropriate format.

class trlx.pipeline.BasePipeline(path: str = 'dataset')
abstract create_loader(batch_size: int, shuffle: bool, prep_fn: Optional[Callable] = None, num_workers: int = 0) → torch.utils.data.dataloader.DataLoader

Create a dataloader for the pipeline

Parameters

prep_fn – Typically a tokenizer. Applied to GeneralElement after collation.
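To make the ordering concrete (collate first, then apply prep_fn), here is a torch-free sketch of the create_loader contract. `ToyPipeline`, `toy_prep_fn`, and returning a plain generator of batches are hypothetical stand-ins; the documented method returns a `torch.utils.data.DataLoader`.

```python
import random
from typing import Any, Callable, Iterator, List, Optional


class ToyPipeline:
    """Hypothetical stand-in for a BasePipeline subclass (no torch needed)."""

    def __init__(self, texts: List[str]):
        self.texts = texts

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, ix: int) -> str:
        return self.texts[ix]

    def create_loader(
        self,
        batch_size: int,
        shuffle: bool,
        prep_fn: Optional[Callable[[List[str]], Any]] = None,
    ) -> Iterator[Any]:
        order = list(range(len(self)))
        if shuffle:
            random.shuffle(order)
        for start in range(0, len(order), batch_size):
            # Collate: gather the raw elements of one batch into a list...
            batch = [self[ix] for ix in order[start:start + batch_size]]
            # ...and only then hand the collated batch to prep_fn.
            yield prep_fn(batch) if prep_fn is not None else batch


def toy_prep_fn(batch: List[str]) -> List[List[str]]:
    # Hypothetical "tokenizer": whitespace split of every text in the batch.
    return [text.split() for text in batch]


pipeline = ToyPipeline(["hello world", "a b c", "last one"])
batches = list(pipeline.create_loader(batch_size=2, shuffle=False, prep_fn=toy_prep_fn))
# batches == [[["hello", "world"], ["a", "b", "c"]], [["last", "one"]]]
```

Because prep_fn sees a whole batch, a real tokenizer can pad every element in the batch to a common length in one call.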

class trlx.pipeline.BaseRolloutStore(capacity=-1)
abstract create_loader(batch_size: int, shuffle: bool, prep_fn: Optional[Callable] = None, num_workers: int = 0) → torch.utils.data.dataloader.DataLoader

Create a dataloader for the rollout store

Parameters

prep_fn (Callable) – Applied to RLElement after collation (typically a tokenizer)

abstract push(exps: Iterable[Any])

Push experiences to rollout storage
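A rollout store accumulates experiences through push and later serves them in batches. The torch-free sketch below is hypothetical: `ToyRolloutStore` is not part of trlx, and both the reading of capacity=-1 as "unbounded" and the keep-newest eviction policy for a positive capacity are assumptions, since the reference above does not specify them.

```python
from typing import Any, Callable, Iterable, Iterator, List, Optional


class ToyRolloutStore:
    """Hypothetical, torch-free stand-in for a BaseRolloutStore subclass."""

    def __init__(self, capacity: int = -1):
        self.capacity = capacity  # assumption: -1 means "no limit"
        self.history: List[Any] = []

    def push(self, exps: Iterable[Any]) -> None:
        self.history.extend(exps)
        if self.capacity > 0:
            # Assumed eviction policy: keep only the newest `capacity` items.
            self.history = self.history[-self.capacity:]

    def __len__(self) -> int:
        return len(self.history)

    def create_loader(self, batch_size: int, prep_fn: Optional[Callable] = None) -> Iterator[Any]:
        for start in range(0, len(self.history), batch_size):
            batch = self.history[start:start + batch_size]
            # prep_fn is applied to the collated batch, not to single elements.
            yield prep_fn(batch) if prep_fn is not None else batch


store = ToyRolloutStore(capacity=3)
store.push([1, 2])
store.push([3, 4])
# len(store) == 3 and store.history == [2, 3, 4]
```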

class trlx.pipeline.offline_pipeline.DialogMessage(is_output: bool, tokens: Tuple[int])

Single message in a dialogue

Parameters
  • is_output (bool) – Whether the message is a model output or a prompt

  • tokens (Tuple[int]) – Tokenized message

class trlx.pipeline.offline_pipeline.DialogStore(dialogs: List[List[trlx.pipeline.offline_pipeline.DialogMessage]], tokenizer: transformers.tokenization_utils.PreTrainedTokenizer)
create_loader(batch_size: int, shuffle=False) → torch.utils.data.dataloader.DataLoader

Create a dataloader for the rollout store

Parameters
  • batch_size (int) – Number of samples per batch

  • shuffle (bool) – Whether to shuffle the samples
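A DialogStore receives dialogues that are already tokenized into DialogMessage lists. One plausible way to turn such a dialogue into a training example, assumed here rather than taken from the trlx source, is to concatenate all message tokens and replace prompt tokens in the labels with -100, the default ignore_index of PyTorch's CrossEntropyLoss, so that only model outputs contribute to the loss. `DialogMessage` is redefined locally as a NamedTuple mirroring the documented fields, and `flatten_dialog` is a hypothetical helper.

```python
from typing import Dict, List, NamedTuple, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


class DialogMessage(NamedTuple):
    # Local mirror of the documented DialogMessage fields.
    is_output: bool
    tokens: Tuple[int, ...]


def flatten_dialog(dialog: List[DialogMessage]) -> Dict[str, List[int]]:
    input_ids = [t for m in dialog for t in m.tokens]
    attention_mask = [1] * len(input_ids)
    # Only output tokens are supervised; prompt tokens are masked out.
    labels = [t if m.is_output else IGNORE_INDEX for m in dialog for t in m.tokens]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


dialog = [
    DialogMessage(False, (5, 6)),  # prompt_1
    DialogMessage(True, (7,)),     # output_1
    DialogMessage(False, (8,)),    # prompt_2
    DialogMessage(True, (9, 2)),   # output_2
]
example = flatten_dialog(dialog)
# example["labels"] == [-100, -100, 7, -100, 9, 2]
```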

trlx.pipeline.offline_pipeline.tokenize_dialogue(dialogue: Union[str, Iterable[str]], tokenizer: Union[transformers.tokenization_utils.PreTrainedTokenizer, transformers.tokenization_utils_fast.PreTrainedTokenizerFast], max_length=2048) → List[trlx.pipeline.offline_pipeline.DialogMessage]

Tokenize a sample given in the interleaved form (prompt_1, output_1, prompt_2, output_2, …)
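The interleaving can be illustrated with a simplified, hypothetical stand-in. `toy_tokenize_dialogue` takes any callable mapping text to token ids instead of a real tokenizer, treats a bare string as a single output with an empty prompt, and truncates by keeping the earliest tokens until max_length is reached. Those last two choices are assumptions for illustration; the real function's special-token and truncation-side handling may differ.

```python
from typing import Callable, Dict, Iterable, List, NamedTuple, Tuple, Union


class DialogMessage(NamedTuple):
    is_output: bool
    tokens: Tuple[int, ...]


def toy_tokenize_dialogue(
    dialogue: Union[str, Iterable[str]],
    tokenize: Callable[[str], List[int]],
    max_length: int = 2048,
) -> List[DialogMessage]:
    # Hypothetical convention: a bare string is a single output with an empty prompt.
    turns = ["", dialogue] if isinstance(dialogue, str) else list(dialogue)
    if len(turns) % 2 != 0:
        raise ValueError("expected alternating (prompt, output) pairs")

    messages, budget = [], max_length
    for i, text in enumerate(turns):
        # Even positions are prompts, odd positions are model outputs.
        tokens = tuple(tokenize(text))[:budget]
        budget -= len(tokens)
        messages.append(DialogMessage(is_output=i % 2 == 1, tokens=tokens))
    return messages


vocab: Dict[str, int] = {}


def toy_tokenize(text: str) -> List[int]:
    # Hypothetical word-level tokenizer: each new word gets the next free id.
    return [vocab.setdefault(word, len(vocab)) for word in text.split()]


msgs = toy_tokenize_dialogue(["hi there", "hello", "how are you", "fine thanks"], toy_tokenize, max_length=6)
```

With max_length=6 the first three turns use all six tokens, so the final output is truncated to an empty tuple while the prompt/output alternation is preserved.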

class trlx.pipeline.ppo_pipeline.PPORolloutStorage(pad_token_id, padding_side)

Rollout storage for training PPO

create_loader(batch_size: int, shuffle: bool) → torch.utils.data.dataloader.DataLoader

Create a dataloader for the rollout store

Parameters
  • batch_size (int) – Number of experiences per batch

  • shuffle (bool) – Whether to shuffle the experiences

push(exps: Iterable[trlx.data.ppo_types.PPORLElement])

Push experiences to rollout storage
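The constructor's pad_token_id and padding_side suggest that variable-length sequences are padded when experiences are collated into a batch. The helper below is a hypothetical, torch-free sketch of side-aware padding, not the storage's actual collate function.

```python
from typing import List, Sequence


def pad_batch(seqs: Sequence[Sequence[int]], pad_token_id: int, padding_side: str = "right") -> List[List[int]]:
    if padding_side not in ("left", "right"):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    width = max((len(s) for s in seqs), default=0)
    padded = []
    for s in seqs:
        pad = [pad_token_id] * (width - len(s))
        # Left padding aligns the last real token of every row at the end,
        # which is what decoder-only generation usually expects.
        padded.append(pad + list(s) if padding_side == "left" else list(s) + pad)
    return padded


left = pad_batch([[1, 2, 3], [4]], pad_token_id=0, padding_side="left")
# left == [[1, 2, 3], [0, 0, 4]]
```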

class trlx.pipeline.offline_pipeline.PromptPipeline(prompts: Union[List[Dict[str, Any]], List[str]], max_prompt_length: int, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, add_special_tokens: bool = False)

Dataloader that supplies prompts for either training or evaluation

Parameters
  • prompts (List[str] or List[Dict[str, Any]]) – List of raw text prompts, or of dictionaries with a required key “prompt” plus extra information that is passed along with the generation for that prompt as a keyword argument to the reward function.

  • max_prompt_length (int) – Maximum prompt length; longer prompts are truncated according to the tokenizer’s truncation setting.

  • tokenizer (transformers.PreTrainedTokenizer) – Tokenizer used to tokenize the prompts.

  • add_special_tokens (bool) – Whether to encode prompts with the tokenizer’s special tokens (passed directly into tokenizer.encode).
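The sketch below shows how the two accepted prompt forms can be normalized and how per-prompt extra keys could reach a reward function. `normalize_prompts` and `reward_fn` are hypothetical names, a whitespace split stands in for a real PreTrainedTokenizer, and keeping the first max_prompt_length words mimics right-side truncation only; the real pipeline follows the tokenizer's truncation setting. Gathering each extra key into a per-batch list is likewise an assumption.

```python
from typing import Any, Dict, List, Union


def normalize_prompts(
    prompts: Union[List[str], List[Dict[str, Any]]], max_prompt_length: int
) -> List[Dict[str, Any]]:
    out = []
    for p in prompts:
        record = {"prompt": p} if isinstance(p, str) else dict(p)
        if "prompt" not in record:
            raise KeyError("each dict prompt needs a 'prompt' key")
        # Stand-in tokenization: whitespace words, keeping the first max_prompt_length.
        record["tokens"] = record["prompt"].split()[:max_prompt_length]
        out.append(record)
    return out


def reward_fn(samples: List[str], **kwargs: List[Any]) -> List[float]:
    # Hypothetical reward: 1.0 when a sample mentions its prompt's "target" word.
    return [float(t in s) for s, t in zip(samples, kwargs["target"])]


records = normalize_prompts(
    [{"prompt": "Name a fruit", "target": "apple"}, {"prompt": "Name a color", "target": "red"}],
    max_prompt_length=2,
)
# Gather every extra key into a per-batch list and forward it as a keyword argument.
extra = {k: [r[k] for r in records] for k in records[0] if k not in ("prompt", "tokens")}
rewards = reward_fn(["an apple", "blue"], **extra)
# rewards == [1.0, 0.0]
```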

create_loader(batch_size: int, shuffle=False, sampler=None, drop_last=False) → torch.utils.data.dataloader.DataLoader

Create a dataloader for the pipeline

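Parameters
  • batch_size (int) – Number of prompts per batch

  • shuffle (bool) – Whether to shuffle the prompts

  • sampler – Optional sampler that determines the order in which prompts are drawn

  • drop_last (bool) – Whether to drop the last batch if it is incomplete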
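With N prompts and batch size B, dropping the last incomplete batch yields ⌊N/B⌋ batches instead of ⌈N/B⌉. The helper below only computes batch boundaries under the assumption that drop_last behaves as in `torch.utils.data.DataLoader`; `batch_bounds` is a hypothetical name, not part of trlx.

```python
from typing import List


def batch_bounds(n_items: int, batch_size: int, drop_last: bool = False) -> List[range]:
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_full, remainder = divmod(n_items, batch_size)
    bounds = [range(i * batch_size, (i + 1) * batch_size) for i in range(n_full)]
    if remainder and not drop_last:
        # The final, smaller batch is kept only when drop_last is False.
        bounds.append(range(n_full * batch_size, n_items))
    return bounds


kept = batch_bounds(10, 4)                   # 3 batches: 4 + 4 + 2 prompts
dropped = batch_bounds(10, 4, drop_last=True)  # 2 batches: the last 2 prompts are skipped
```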

class trlx.pipeline.offline_pipeline.ILQLRolloutStorage(input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones)

Rollout storage for training ILQL

create_loader(batch_size: int)

Create a dataloader for the rollout store

Parameters

batch_size (int) – Number of samples per batch

class trlx.pipeline.offline_pipeline.ILQLSeq2SeqRolloutStorage(input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones)

Rollout storage for training ILQL with Seq2Seq models

create_loader(batch_size: int)

Create a dataloader for the rollout store

Parameters

batch_size (int) – Number of samples per batch