Pipelines
Pipelines accumulate training data and convert it to the appropriate format.
- class trlx.pipeline.BaseRolloutStore(capacity=-1)
- class trlx.pipeline.offline_pipeline.DialogMessage(is_output: bool, tokens: Tuple[int])
Single message in a dialogue
- Parameters
is_output (bool) – Whether the message is a model output or a prompt
tokens (Tuple[int]) – Tokenized message
- class trlx.pipeline.offline_pipeline.DialogStore(dialogs: List[List[trlx.pipeline.offline_pipeline.DialogMessage]], tokenizer: transformers.tokenization_utils.PreTrainedTokenizer)
- trlx.pipeline.offline_pipeline.tokenize_dialogue(dialogue: Union[str, Iterable[str]], tokenizer: Union[transformers.tokenization_utils.PreTrainedTokenizer, transformers.tokenization_utils_fast.PreTrainedTokenizerFast], max_length=2048) → List[trlx.pipeline.offline_pipeline.DialogMessage]
Tokenize sample with the interleaved form of (prompt_1, output_1, prompt_2, output_2…)
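To make the interleaved form concrete, the sketch below tags alternating turns as prompt or output. It does not import trlx: `Message`, `toy_encode`, and `interleave` are hypothetical stand-ins, the whitespace "tokenizer" replaces a real `PreTrainedTokenizer`, and truncation to `max_length` and special-token handling are left out, so the library's actual behavior may differ.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Message:
    """Hypothetical stand-in mirroring the documented DialogMessage fields."""
    is_output: bool
    tokens: Tuple[int, ...]


def toy_encode(text: str, vocab: Dict[str, int]) -> Tuple[int, ...]:
    """Toy whitespace tokenizer (assumption): each new word gets the next free id."""
    return tuple(vocab.setdefault(word, len(vocab)) for word in text.split())


def interleave(dialogue: List[str], vocab: Dict[str, int]) -> List[Message]:
    """Tag turns at even positions as prompts and turns at odd positions as outputs."""
    return [
        Message(is_output=(i % 2 == 1), tokens=toy_encode(turn, vocab))
        for i, turn in enumerate(dialogue)
    ]


vocab: Dict[str, int] = {}
messages = interleave(["Hello there", "Hi", "How are you", "Fine"], vocab)
for m in messages:
    print("output" if m.is_output else "prompt", m.tokens)
```

Keeping the `is_output` flag next to the tokens lets later code tell model outputs apart from prompts without re-parsing the text.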
- class trlx.pipeline.ppo_pipeline.PPORolloutStorage(pad_token_id, padding_side)
Rollout storage for training PPO
- create_loader(batch_size: int, shuffle: bool) → torch.utils.data.dataloader.DataLoader
Create a dataloader for the rollout store
- Parameters
prep_fn (Callable) – Applied to RLElement after collation (typically tokenizer)
- push(exps: Iterable[trlx.data.ppo_types.PPORLElement])
Push experiences to rollout storage
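The push-then-batch pattern of a rollout store can be illustrated with a small, dependency-free sketch. Everything here is an assumption for illustration: `ToyRolloutStore` is a hypothetical class, experiences are plain lists of token ids rather than `PPORLElement` objects, batches are Python lists rather than a torch `DataLoader`, and shuffling is omitted.

```python
from typing import Iterable, Iterator, List


class ToyRolloutStore:
    """Hypothetical stand-in for a rollout store; not the trlx class."""

    def __init__(self, pad_token_id: int, padding_side: str = "right"):
        self.pad_token_id = pad_token_id
        self.padding_side = padding_side
        self.history: List[List[int]] = []

    def push(self, exps: Iterable[List[int]]) -> None:
        """Append new experiences (here: plain token-id lists) to the store."""
        self.history.extend(exps)

    def _pad(self, seq: List[int], length: int) -> List[int]:
        """Pad one sequence to `length` on the configured side."""
        padding = [self.pad_token_id] * (length - len(seq))
        return padding + seq if self.padding_side == "left" else seq + padding

    def batches(self, batch_size: int) -> Iterator[List[List[int]]]:
        """Yield batches, each padded to the longest sequence it contains."""
        for start in range(0, len(self.history), batch_size):
            chunk = self.history[start:start + batch_size]
            longest = max(len(seq) for seq in chunk)
            yield [self._pad(seq, longest) for seq in chunk]


store = ToyRolloutStore(pad_token_id=0, padding_side="left")
store.push([[5, 6, 7], [8]])
store.push([[9, 10]])
padded = list(store.batches(batch_size=2))
print(padded)
```

Padding per batch rather than to a global maximum keeps short batches small; whether the real storage does the same is not stated in this reference.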
- class trlx.pipeline.offline_pipeline.PromptPipeline(prompts: Union[List[Dict[str, Any]], List[str]], max_prompt_length: int, tokenizer: transformers.tokenization_utils.PreTrainedTokenizer, add_special_tokens: bool = False)
Dataloader which is used to supply prompts for either training or evaluation
- Args:
- prompts (List[str] or List[Dict[str, Any]]): list of raw text prompts, or list of dictionaries with a required key “prompt” plus extra information that is passed, along with the generation for that prompt, as keyword arguments to a reward function.
- max_prompt_length (int): maximum length of the prompt; longer prompts are truncated according to the tokenizer’s truncation setting.
- tokenizer (transformers.PreTrainedTokenizer): the tokenizer used to tokenize the prompts.
- add_special_tokens (bool): whether to encode prompts with the tokenizer’s special tokens (passed directly into tokenizer.encode).
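The two accepted prompt shapes and the length cap can be sketched without trlx or a real tokenizer. `split_prompts` and `truncate` are hypothetical helpers written for this illustration, and the `side` argument only mimics the idea of a tokenizer truncation setting; the pipeline's internals may be organized differently.

```python
from typing import Any, Dict, List, Tuple, Union


def split_prompts(
    prompts: Union[List[str], List[Dict[str, Any]]],
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Separate prompt text from extra per-prompt fields (hypothetical helper)."""
    texts: List[str] = []
    extras: List[Dict[str, Any]] = []
    for item in prompts:
        if isinstance(item, dict):
            extra = dict(item)                 # copy so the caller's dict is untouched
            texts.append(extra.pop("prompt"))  # "prompt" is the required key
            extras.append(extra)               # remaining keys travel with the prompt
        else:
            texts.append(item)
            extras.append({})
    return texts, extras


def truncate(token_ids: List[int], max_prompt_length: int, side: str = "right") -> List[int]:
    """Keep at most max_prompt_length tokens; `side` names the side that is cut."""
    if len(token_ids) <= max_prompt_length:
        return token_ids
    if side == "left":
        return token_ids[len(token_ids) - max_prompt_length:]
    return token_ids[:max_prompt_length]


texts, extras = split_prompts([
    {"prompt": "Translate: bonjour", "reference": "hello"},
    "Plain prompt",
])
print(texts, extras)
print(truncate([1, 2, 3, 4, 5], max_prompt_length=3, side="left"))
```

In this sketch, a reward function could later receive `reference="hello"` as a keyword argument for the first prompt, matching the description of the extra dictionary keys above.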
- class trlx.pipeline.offline_pipeline.ILQLRolloutStorage(input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones)
Rollout storage for training ILQL