Data Elements

All of the major CarperAI projects (trlX, CHEESE, and magiCARP) use dataclasses that represent batches of data to pass information between models and other components. trlX follows the same pattern, with separate dataclasses for components such as training and inference. trlX currently supports PPO and ILQL, and because each algorithm needs different data during training, each has its own element and batch types.

Basic Data Elements for Accelerate

class trlx.data.accelerate_base_datatypes.PromptElement(text: str, tokens: torch.Tensor[torch.Tensor])

Dataclass for a single prompt, containing its string and tokenized form.

Parameters
  • text (str) – The prompt text.

  • tokens (torch.Tensor) – The prompt tokens. Should be a long tensor.
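As a rough illustration of the expected field types, the sketch below builds a PromptElement-like object. Rather than importing trlX, it defines a stand-in dataclass with the same two fields, and the whitespace "tokenizer" and its VOCAB dictionary are hypothetical; in practice the token ids would come from the model's tokenizer.

```python
from dataclasses import dataclass

import torch


@dataclass
class PromptElement:
    # Stand-in mirroring trlx.data.accelerate_base_datatypes.PromptElement
    text: str
    tokens: torch.Tensor  # 1-D long tensor of token ids


# Hypothetical toy vocabulary; a real setup would use the model's tokenizer.
VOCAB = {"<unk>": 0, "hello": 1, "world": 2}


def make_prompt(text: str) -> PromptElement:
    # Map each whitespace-separated word to an id, falling back to <unk>.
    ids = [VOCAB.get(word, VOCAB["<unk>"]) for word in text.split()]
    return PromptElement(text=text, tokens=torch.tensor(ids, dtype=torch.long))


prompt = make_prompt("hello big world")
print(prompt.tokens)        # tensor([1, 0, 2])
print(prompt.tokens.dtype)  # torch.int64
```

Note that torch.long and torch.int64 are the same dtype, so a "long tensor" prints as torch.int64.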

class trlx.data.accelerate_base_datatypes.PromptBatch(text: Iterable[str], tokens: torch.Tensor[torch.Tensor])

A batched version of PromptElement.

Parameters
  • text (Iterable[str]) – An iterable of prompt texts.

  • tokens (torch.Tensor) – A long tensor batch of prompt tokens.
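Prompts of different lengths have to be padded before they fit in a single long tensor. The sketch below assumes left padding with a hypothetical PAD_ID of 0 (this page does not specify either choice) and uses a stand-in PromptBatch dataclass rather than the trlX class. Left padding is a common choice for decoder-only generation because new tokens are appended on the right.

```python
from dataclasses import dataclass
from typing import Iterable, List

import torch


@dataclass
class PromptBatch:
    # Stand-in mirroring trlx.data.accelerate_base_datatypes.PromptBatch
    text: Iterable[str]
    tokens: torch.Tensor  # 2-D long tensor, shape (batch_size, max_len)


PAD_ID = 0  # hypothetical padding token id


def collate_prompts(texts: List[str], token_lists: List[torch.Tensor]) -> PromptBatch:
    max_len = max(t.size(0) for t in token_lists)
    # Start from an all-padding matrix, then copy each prompt into its right end.
    batch = torch.full((len(token_lists), max_len), PAD_ID, dtype=torch.long)
    for row, toks in enumerate(token_lists):
        batch[row, max_len - toks.size(0):] = toks
    return PromptBatch(text=texts, tokens=batch)


batch = collate_prompts(
    ["a b c", "d"],
    [torch.tensor([5, 6, 7]), torch.tensor([8])],
)
print(batch.tokens)
# tensor([[5, 6, 7],
#         [0, 0, 8]])
```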

class trlx.data.accelerate_base_datatypes.AccelerateRLElement(output_tokens: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor])

Dataclass for RL elements, containing output tokens and rewards for each token.

Parameters
  • output_tokens (torch.Tensor) – The output tokens. Should be a long tensor.

  • rewards (torch.Tensor) – The rewards for each token. Should be a float tensor of the same size as output_tokens.
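The dtype and shape constraints above can be checked when an element is created. The following is a minimal sketch using a stand-in dataclass with a __post_init__ validator; trlX's own class may not perform these checks. The reward layout, a single scalar score placed on the final token, is a hypothetical scheme chosen for illustration.

```python
from dataclasses import dataclass

import torch


@dataclass
class AccelerateRLElement:
    # Stand-in mirroring trlx.data.accelerate_base_datatypes.AccelerateRLElement
    output_tokens: torch.Tensor  # long tensor, shape (num_tokens,)
    rewards: torch.Tensor        # float tensor, shape (num_tokens,)

    def __post_init__(self):
        # Enforce the dtype and shape constraints described in the parameter list.
        if self.output_tokens.dtype != torch.long:
            raise TypeError("output_tokens must be a long tensor")
        if not self.rewards.is_floating_point():
            raise TypeError("rewards must be a float tensor")
        if self.rewards.shape != self.output_tokens.shape:
            raise ValueError("rewards must have the same shape as output_tokens")


tokens = torch.tensor([11, 12, 13])
# Hypothetical sparse scheme: a single scalar score credited to the final token.
rewards = torch.zeros(tokens.shape)
rewards[-1] = 0.75
element = AccelerateRLElement(output_tokens=tokens, rewards=rewards)
print(element.rewards)
```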

class trlx.data.accelerate_base_datatypes.AccelerateRLBatchElement(output_tokens: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor])

A batched version of AccelerateRLElement.

Parameters
  • output_tokens (torch.Tensor) – A batch of long tensors of output tokens.

  • rewards (torch.Tensor) – A batch of float tensors of rewards for each output token.
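One way to batch variable-length RL elements is to right-pad each field with torch.nn.utils.rnn.pad_sequence. The sketch below assumes right padding, a hypothetical PAD_ID of 0 for tokens, and a reward of 0.0 at padded positions so that per-sequence reward sums are unaffected; the collate_rl helper and the stand-in dataclass are illustrative, not part of trlX.

```python
from dataclasses import dataclass
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


@dataclass
class AccelerateRLBatchElement:
    # Stand-in mirroring trlx.data.accelerate_base_datatypes.AccelerateRLBatchElement
    output_tokens: torch.Tensor  # long tensor, shape (batch_size, max_len)
    rewards: torch.Tensor        # float tensor, shape (batch_size, max_len)


PAD_ID = 0  # hypothetical padding token id


def collate_rl(tokens: List[torch.Tensor], rewards: List[torch.Tensor]) -> AccelerateRLBatchElement:
    # Right-pad every sequence to the longest one; padded positions get zero
    # reward so they do not change per-sequence reward sums.
    return AccelerateRLBatchElement(
        output_tokens=pad_sequence(tokens, batch_first=True, padding_value=PAD_ID),
        rewards=pad_sequence(rewards, batch_first=True, padding_value=0.0),
    )


batch = collate_rl(
    [torch.tensor([3, 4, 5]), torch.tensor([6])],
    [torch.tensor([0.0, 0.0, 1.0]), torch.tensor([-0.5])],
)
print(batch.output_tokens.shape)  # torch.Size([2, 3])
print(batch.rewards.sum(dim=1))   # per-sequence total reward
```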

Data Elements for PPO

class trlx.data.ppo_types.PPORLElement(query_tensor: torch.Tensor[torch.Tensor], response_tensor: torch.Tensor[torch.Tensor], logprobs: torch.Tensor[torch.Tensor], values: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor])

A single RL element for PPO.

Parameters
  • query_tensor (torch.Tensor) – The query tensor, i.e. the prompt tokens. Should be a long tensor.

  • response_tensor (torch.Tensor) – The response tensor, i.e. the output tokens. Should be a long tensor.

  • logprobs (torch.Tensor) – The log probabilities over all tokens in the vocabulary for each token generated by the policy network (i.e. the autoregressive model). Should be a float tensor of shape (response length, vocabulary size).

  • values (torch.Tensor) – The value estimates for each response token from the value network or value head. Should be a float tensor of the same size as response_tensor.

  • rewards (torch.Tensor) – The reward for each token in the response. Should be a float tensor of the same size as response_tensor.
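To make the logprobs shape concrete, the sketch below fills a stand-in PPORLElement using log_softmax over the vocabulary dimension. The logits and values are random placeholders for what the policy and value head would produce, the vocabulary size of 10 is arbitrary, and the terminal-only reward is a hypothetical scheme. The last lines show how the log-probability of each generated token can be read out of the full-vocabulary tensor with gather.

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class PPORLElement:
    # Stand-in mirroring trlx.data.ppo_types.PPORLElement
    query_tensor: torch.Tensor     # long, (query_len,)
    response_tensor: torch.Tensor  # long, (response_len,)
    logprobs: torch.Tensor         # float, (response_len, vocab_size)
    values: torch.Tensor           # float, (response_len,)
    rewards: torch.Tensor          # float, (response_len,)


torch.manual_seed(0)
vocab_size = 10
query = torch.tensor([1, 2, 3])
response = torch.tensor([4, 5])

# Random placeholders for the policy logits and value-head outputs at each
# response position; a real trainer would obtain these from the model.
logits = torch.randn(response.size(0), vocab_size)
values = torch.randn(response.size(0))

# log_softmax over the last dimension turns logits into log-probabilities
# over the whole vocabulary for every generated token.
logprobs = F.log_softmax(logits, dim=-1)

# Hypothetical reward scheme: one scalar score on the last response token.
rewards = torch.zeros(response.size(0))
rewards[-1] = 1.0

element = PPORLElement(query, response, logprobs, values, rewards)
print(element.logprobs.shape)  # torch.Size([2, 10])

# The log-probability of each token that was actually generated:
chosen = logprobs.gather(-1, response.unsqueeze(-1)).squeeze(-1)
print(chosen.shape)  # torch.Size([2])
```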

class trlx.data.ppo_types.PPORLBatch(query_tensors: torch.Tensor[torch.Tensor], response_tensors: torch.Tensor[torch.Tensor], logprobs: torch.Tensor[torch.Tensor], values: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor])

A batched version of the PPORLElement. See PPORLElement for more details on individual fields.

Parameters
  • query_tensors (torch.Tensor) – A batch of query tensors. Should be a long tensor.

  • response_tensors (torch.Tensor) – A batch of response tensors. Should be a long tensor.

  • logprobs (torch.Tensor) – A batch of log probabilities from the policy network.

  • values (torch.Tensor) – A batch of values from the value network.

  • rewards (torch.Tensor) – A batch of rewards.
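When per-element tensors are collated into PPORLBatch fields, logprobs carries an extra vocabulary dimension. pad_sequence handles this because it only pads along the first dimension and requires the trailing dimensions to match. The sketch below assumes right padding with a hypothetical PAD_ID of 0 and 0.0 for float fields; pad_field is an illustrative helper, and padded positions would still need to be masked out of any loss.

```python
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # hypothetical padding token id


def pad_field(tensors: List[torch.Tensor], value: float) -> torch.Tensor:
    # pad_sequence right-pads along the first dimension; trailing dimensions
    # (e.g. the vocabulary axis of logprobs) must match and are left untouched.
    return pad_sequence(tensors, batch_first=True, padding_value=value)


vocab_size = 10
responses = [torch.tensor([4, 5, 6]), torch.tensor([7])]
# Uniform placeholder log-probabilities, one row per response token.
logprobs = [torch.log_softmax(torch.zeros(len(r), vocab_size), dim=-1) for r in responses]
values = [torch.zeros(len(r)) for r in responses]

response_tensors = pad_field(responses, PAD_ID)
logprob_batch = pad_field(logprobs, 0.0)
value_batch = pad_field(values, 0.0)
print(response_tensors.shape)  # torch.Size([2, 3])
print(logprob_batch.shape)     # torch.Size([2, 3, 10])
```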

Data Elements for ILQL

class trlx.data.ilql_types.ILQLElement(input_ids: torch.Tensor[torch.Tensor], attention_mask: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor], states_ixs: torch.Tensor[torch.Tensor], actions_ixs: torch.Tensor[torch.Tensor], dones: torch.Tensor[torch.Tensor])

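A single data element for ILQL.

Parameters
  • input_ids (torch.Tensor) – Input tokens. Should be a long tensor.

  • attention_mask (torch.Tensor) – Attention mask for input_ids. Should be a long tensor.

  • rewards (torch.Tensor) – Rewards for each token. Should be a float tensor of the same size as input_ids.

  • states_ixs (torch.Tensor) – Indices of the states in input_ids.

  • actions_ixs (torch.Tensor) – Indices of the actions (tokens produced by the model) in input_ids.

  • dones (torch.Tensor) – Indicators of the terminal state (end of episode) for each state.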
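The index fields are easiest to understand from a concrete prompt/response pair. The sketch below fills a stand-in ILQLElement under several assumptions that this page does not state: each action is indexed by the position whose hidden state predicts the next response token (one step to its left), the states are those positions plus the final one, dones uses 1 for non-terminal states and 0 for the terminal one, and the whole score is placed on the last token. make_ilql_element is a hypothetical helper, and the prompt must contain at least one token.

```python
from dataclasses import dataclass

import torch


@dataclass
class ILQLElement:
    # Stand-in mirroring trlx.data.ilql_types.ILQLElement
    input_ids: torch.Tensor       # long, (seq_len,)
    attention_mask: torch.Tensor  # long, (seq_len,)
    rewards: torch.Tensor         # float, (seq_len,)
    states_ixs: torch.Tensor      # long, (num_actions + 1,)
    actions_ixs: torch.Tensor     # long, (num_actions,)
    dones: torch.Tensor           # long, (num_actions + 1,)


def make_ilql_element(prompt: torch.Tensor, response: torch.Tensor, score: float) -> ILQLElement:
    input_ids = torch.cat([prompt, response])
    n_prompt, seq_len = prompt.size(0), input_ids.size(0)
    # Assumed convention: each action is indexed by the position whose hidden
    # state predicts the next response token, i.e. one step to its left.
    actions_ixs = torch.arange(n_prompt - 1, seq_len - 1)
    # States are the action positions plus the final position of the sequence.
    states_ixs = torch.cat([actions_ixs, torch.tensor([seq_len - 1])])
    # Assumed encoding: 1 for non-terminal states, 0 for the terminal one.
    dones = torch.ones(states_ixs.size(0), dtype=torch.long)
    dones[-1] = 0
    # Hypothetical reward scheme: the whole score lands on the last token.
    rewards = torch.zeros(seq_len)
    rewards[-1] = score
    return ILQLElement(
        input_ids=input_ids,
        attention_mask=torch.ones(seq_len, dtype=torch.long),
        rewards=rewards,
        states_ixs=states_ixs,
        actions_ixs=actions_ixs,
        dones=dones,
    )


element = make_ilql_element(torch.tensor([1, 2, 3]), torch.tensor([4, 5]), score=1.0)
print(element.actions_ixs)  # tensor([2, 3])
print(element.states_ixs)   # tensor([2, 3, 4])
```

Encoding dones this way lets it multiply the bootstrapped next-state value directly, zeroing it at the end of the episode; other encodings are equally possible.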

class trlx.data.ilql_types.ILQLBatch(input_ids: torch.Tensor[torch.Tensor], attention_mask: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor], states_ixs: torch.Tensor[torch.Tensor], actions_ixs: torch.Tensor[torch.Tensor], dones: torch.Tensor[torch.Tensor])

A batched version of ILQLElement.

Parameters
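  • input_ids (torch.Tensor) – Batch of input tokens.

  • attention_mask (torch.Tensor) – Batch of attention masks.

  • rewards (torch.Tensor) – Batch of per-token rewards.

  • states_ixs (torch.Tensor) – Batch of state indices.

  • actions_ixs (torch.Tensor) – Batch of action indices.

  • dones (torch.Tensor) – Batch of terminal-state indicators.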
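When ILQL elements of different lengths are batched, the attention mask has to mark which positions are real tokens. The sketch below right-pads input_ids and rewards with pad_sequence and builds the mask from the original lengths. It assumes right padding and a hypothetical PAD_ID of 0, and collate_ilql is an illustrative helper rather than trlX's collator. The index fields (states_ixs, actions_ixs, dones) could be padded the same way, but how padded indices are masked in the loss is not covered here.

```python
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # hypothetical padding token id


def collate_ilql(input_ids: List[torch.Tensor], rewards: List[torch.Tensor]):
    # Right-pad token ids and rewards to the longest sequence in the batch.
    ids = pad_sequence(input_ids, batch_first=True, padding_value=PAD_ID)
    rews = pad_sequence(rewards, batch_first=True, padding_value=0.0)
    # Build the mask from lengths rather than from ids == PAD_ID, because a
    # real token may share the padding id.
    lengths = torch.tensor([x.size(0) for x in input_ids])
    positions = torch.arange(ids.size(1))
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()
    return ids, attention_mask, rews


ids, mask, rews = collate_ilql(
    [torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6])],
    [torch.tensor([0.0, 0.0, 0.0, 1.0]), torch.tensor([0.0, -1.0])],
)
print(mask)
# tensor([[1, 1, 1, 1],
#         [1, 1, 0, 0]])
```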