Data Elements
All of the major Carper projects (trlX, CHEESE, and magiCARP) use dataclasses that represent batches of data to pass information between models and other components. trlX follows this pattern, with separate dataclasses for different stages such as training and inference. trlX currently supports PPO and ILQL, and each algorithm requires different kinds of data during training.
Basic Data Elements for Accelerate
- class trlx.data.accelerate_base_datatypes.PromptElement(text: str, tokens: torch.Tensor[torch.Tensor])
Dataclass for a single prompt, containing its string and tokenized form.
- Parameters
text (str) – The prompt text.
tokens (torch.Tensor) – The prompt tokens. Should be a long tensor.
- class trlx.data.accelerate_base_datatypes.PromptBatch(text: Iterable[str], tokens: torch.Tensor[torch.Tensor])
Batched version of PromptElement.
- Parameters
text (Iterable[str]) – An iterable of prompt texts.
tokens (torch.Tensor) – A long tensor batch of prompt tokens.
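The following is a minimal sketch, not trlX's own code, of how several PromptElements might be collated into a PromptBatch. So that it runs without trlX installed, it defines local stand-in dataclasses with the documented field names. The whitespace tokenizer toy_tokenize, the helper collate_prompts, the pad id 0, and the choice of left padding are all assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List

import torch
from torch.nn.utils.rnn import pad_sequence


# Local stand-ins mirroring the documented fields (not imported from trlX).
@dataclass
class PromptElement:
    text: str
    tokens: torch.Tensor  # 1-D long tensor


@dataclass
class PromptBatch:
    text: Iterable[str]
    tokens: torch.Tensor  # 2-D long tensor: (batch, max_prompt_len)


def toy_tokenize(text: str, vocab: Dict[str, int]) -> torch.Tensor:
    # Hypothetical whitespace tokenizer: each unseen word gets the next free id,
    # starting at 1 so that id 0 stays free for padding.
    ids = [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]
    return torch.tensor(ids, dtype=torch.long)


def collate_prompts(elems: List[PromptElement], pad_id: int = 0) -> PromptBatch:
    # Left-pad: flip each sequence, right-pad to the longest, then flip back,
    # so every prompt ends in the last column.
    tokens = pad_sequence(
        [e.tokens.flip(0) for e in elems], batch_first=True, padding_value=pad_id
    ).flip(1)
    return PromptBatch(text=[e.text for e in elems], tokens=tokens)


vocab: Dict[str, int] = {}
prompts = ["the movie was", "I think the ending"]
elems = [PromptElement(t, toy_tokenize(t, vocab)) for t in prompts]
batch = collate_prompts(elems)
print(batch.tokens.tolist())  # [[0, 1, 2, 3], [4, 5, 1, 6]]
```

Left padding is a common choice for decoder-only models because it places the last prompt token of every row in the same column, where generation continues.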
- class trlx.data.accelerate_base_datatypes.AccelerateRLElement(output_tokens: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor])
Dataclass for RL elements, containing output tokens and rewards for each token.
- Parameters
output_tokens (torch.Tensor) – The output tokens. Should be a long tensor.
rewards (torch.Tensor) – The rewards for each token. Should be a float tensor of the same size as output_tokens.
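The tensor contract described for AccelerateRLElement (a long tensor of output tokens and a float tensor of rewards with the same size) can be checked when an element is constructed. This is a sketch using a local stand-in class; the __post_init__ validation and the example reward layout (a single score on the final token) are assumptions for illustration, not documented trlX behavior.

```python
from dataclasses import dataclass

import torch


# Local stand-in mirroring the documented fields. The validation below is an
# illustrative addition, not something trlX is documented to perform.
@dataclass
class AccelerateRLElement:
    output_tokens: torch.Tensor  # long tensor of generated token ids
    rewards: torch.Tensor  # float tensor, one reward per output token

    def __post_init__(self) -> None:
        if self.output_tokens.dtype != torch.long:
            raise TypeError("output_tokens should be a long tensor")
        if not self.rewards.is_floating_point():
            raise TypeError("rewards should be a float tensor")
        if self.rewards.shape != self.output_tokens.shape:
            raise ValueError("rewards must have the same shape as output_tokens")


tokens = torch.tensor([17, 42, 8, 2], dtype=torch.long)
# Hypothetical reward layout: zero everywhere except a scalar score on the last token.
rewards = torch.zeros(tokens.shape, dtype=torch.float)
rewards[-1] = 0.75
elem = AccelerateRLElement(tokens, rewards)
```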
- class trlx.data.accelerate_base_datatypes.AccelerateRLBatchElement(output_tokens: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor])
Batched version of AccelerateRLElement.
- Parameters
output_tokens (torch.Tensor) – A batch of long tensors of output tokens.
rewards (torch.Tensor) – A batch of float tensors of rewards for each output token.
Data Elements for PPO
- class trlx.data.ppo_types.PPORLElement(query_tensor: torch.Tensor[torch.Tensor], response_tensor: torch.Tensor[torch.Tensor], logprobs: torch.Tensor[torch.Tensor], values: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor])
RL element for PPO.
- Parameters
query_tensor (torch.Tensor) – The query tensor, i.e., the prompt tokens. Should be a long tensor.
response_tensor (torch.Tensor) – The response tensor, i.e., the output tokens. Should be a long tensor.
logprobs (torch.Tensor) – The log probabilities over all tokens in the vocabulary for each token generated by the policy network (i.e., the autoregressive model). Should be a float tensor with the same length as response_tensor and an additional dimension across the vocabulary.
values (torch.Tensor) – The values for each generated token, produced by the value network or value head. Should be a float tensor of the same size as response_tensor.
rewards (torch.Tensor) – The rewards for each token in the response. Should be a float tensor of the same size as response_tensor.
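The sketch below builds a PPORLElement-shaped object from random tensors to make the documented shapes concrete: the query and response are 1-D long tensors, logprobs has one row per response token and one column per vocabulary entry, and values and rewards have one entry per response token. It uses a local stand-in dataclass rather than importing trlX, and the random logits are placeholders for a real policy's outputs. The last step shows how the log probability of each generated token can be read out of the full-vocabulary logprobs.

```python
from dataclasses import dataclass

import torch


# Local stand-in mirroring the documented PPORLElement fields.
@dataclass
class PPORLElement:
    query_tensor: torch.Tensor  # (query_len,) long: prompt tokens
    response_tensor: torch.Tensor  # (response_len,) long: generated tokens
    logprobs: torch.Tensor  # (response_len, vocab_size) float
    values: torch.Tensor  # (response_len,) float
    rewards: torch.Tensor  # (response_len,) float


torch.manual_seed(0)
vocab_size, query_len, response_len = 10, 3, 4

query = torch.randint(0, vocab_size, (query_len,))  # int64 by default
response = torch.randint(0, vocab_size, (response_len,))
# Placeholder for the policy's logits at each generated position.
logits = torch.randn(response_len, vocab_size)
logprobs = torch.log_softmax(logits, dim=-1)

elem = PPORLElement(
    query_tensor=query,
    response_tensor=response,
    logprobs=logprobs,
    values=torch.randn(response_len),  # placeholder value-head outputs
    rewards=torch.zeros(response_len),
)

# Pick out the log probability the policy assigned to each token it generated.
token_logprobs = elem.logprobs.gather(-1, elem.response_tensor.unsqueeze(-1)).squeeze(-1)
print(token_logprobs.shape)  # torch.Size([4])
```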
- class trlx.data.ppo_types.PPORLBatch(query_tensors: torch.Tensor[torch.Tensor], response_tensors: torch.Tensor[torch.Tensor], logprobs: torch.Tensor[torch.Tensor], values: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor])
A batched version of the PPORLElement. See PPORLElement for more details on individual fields.
- Parameters
query_tensors (torch.Tensor) – A batch of query tensors. Should be a long tensor.
response_tensors (torch.Tensor) – A batch of response tensors. Should be a long tensor.
logprobs (torch.Tensor) – A batch of log probabilities from the policy network.
values (torch.Tensor) – A batch of values from the value network.
rewards (torch.Tensor) – A batch of rewards.
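One way to turn a list of PPORLElements of different lengths into a PPORLBatch is to pad each field to the longest element. The sketch below uses local stand-in dataclasses; the helper collate_ppo, the pad id 0, the 0.0 padding for float fields, and the left/right padding scheme are assumptions for illustration rather than a description of trlX's own collation code.

```python
from dataclasses import dataclass
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


# Local stand-ins mirroring the documented fields.
@dataclass
class PPORLElement:
    query_tensor: torch.Tensor
    response_tensor: torch.Tensor
    logprobs: torch.Tensor
    values: torch.Tensor
    rewards: torch.Tensor


@dataclass
class PPORLBatch:
    query_tensors: torch.Tensor
    response_tensors: torch.Tensor
    logprobs: torch.Tensor
    values: torch.Tensor
    rewards: torch.Tensor


def collate_ppo(elems: List[PPORLElement], pad_id: int = 0) -> PPORLBatch:
    def right_pad(tensors, value):
        # Pads along dim 0; trailing dims (e.g. vocab for logprobs) must match.
        return pad_sequence(tensors, batch_first=True, padding_value=value)

    return PPORLBatch(
        # Left-pad queries (flip, right-pad, flip back) so all prompts end together.
        query_tensors=right_pad([e.query_tensor.flip(0) for e in elems], pad_id).flip(1),
        # Right-pad everything on the response side so all responses start together.
        response_tensors=right_pad([e.response_tensor for e in elems], pad_id),
        logprobs=right_pad([e.logprobs for e in elems], 0.0),
        values=right_pad([e.values for e in elems], 0.0),
        rewards=right_pad([e.rewards for e in elems], 0.0),
    )


V = 5  # toy vocabulary size
a = PPORLElement(
    torch.tensor([11, 12]), torch.tensor([21, 22, 23]),
    torch.zeros(3, V), torch.ones(3), torch.tensor([0.0, 0.0, 1.0]),
)
b = PPORLElement(
    torch.tensor([13, 14, 15]), torch.tensor([24]),
    torch.zeros(1, V), torch.ones(1), torch.tensor([0.5]),
)
batch = collate_ppo([a, b])
print(batch.query_tensors.tolist())     # [[0, 11, 12], [13, 14, 15]]
print(batch.response_tensors.tolist())  # [[21, 22, 23], [24, 0, 0]]
print(batch.logprobs.shape)             # torch.Size([2, 3, 5])
```

Left-padding the queries and right-padding the responses keeps a single boundary between prompt and response across the batch: concatenating query_tensors and response_tensors along dimension 1 puts the first token of every response in the same column.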
Data Elements for ILQL
- class trlx.data.ilql_types.ILQLElement(input_ids: torch.Tensor[torch.Tensor], attention_mask: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor], states_ixs: torch.Tensor[torch.Tensor], actions_ixs: torch.Tensor[torch.Tensor], dones: torch.Tensor[torch.Tensor])
Data element for ILQL.
- Parameters
input_ids (torch.Tensor) – Input tokens. Should be a long tensor.
attention_mask (torch.Tensor) – Attention mask. Should be a long tensor.
rewards (torch.Tensor) – Rewards for each token. Should be a float tensor of the same size as the tokens.
- class trlx.data.ilql_types.ILQLBatch(input_ids: torch.Tensor[torch.Tensor], attention_mask: torch.Tensor[torch.Tensor], rewards: torch.Tensor[torch.Tensor], states_ixs: torch.Tensor[torch.Tensor], actions_ixs: torch.Tensor[torch.Tensor], dones: torch.Tensor[torch.Tensor])
Batched version of ILQLElement.
- Parameters
input_ids (torch.Tensor) – A batch of input tokens.
attention_mask (torch.Tensor) – A batch of attention masks.
rewards (torch.Tensor) – A batch of rewards for each token in each sequence.
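A similar padding approach can collate ILQLElements into an ILQLBatch. In this sketch, which uses local stand-in dataclasses, every field is right-padded to the longest element, with 0 for integer fields and 0.0 for rewards. The helper collate_ilql and the padding values are assumptions. The fields states_ixs, actions_ixs, and dones are not documented above, so the sketch treats them as opaque 1-D tensors filled with placeholder values.

```python
from dataclasses import dataclass
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


# Local stand-ins mirroring the documented ILQL fields.
@dataclass
class ILQLElement:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    rewards: torch.Tensor
    states_ixs: torch.Tensor
    actions_ixs: torch.Tensor
    dones: torch.Tensor


@dataclass
class ILQLBatch:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    rewards: torch.Tensor
    states_ixs: torch.Tensor
    actions_ixs: torch.Tensor
    dones: torch.Tensor


def collate_ilql(elems: List[ILQLElement]) -> ILQLBatch:
    def pad(name: str, value: float) -> torch.Tensor:
        # Right-pad one field across all elements to the longest length.
        return pad_sequence(
            [getattr(e, name) for e in elems], batch_first=True, padding_value=value
        )

    return ILQLBatch(
        input_ids=pad("input_ids", 0),
        attention_mask=pad("attention_mask", 0),  # padded positions are masked out
        rewards=pad("rewards", 0.0),
        states_ixs=pad("states_ixs", 0),
        actions_ixs=pad("actions_ixs", 0),
        dones=pad("dones", 0),
    )


# Placeholder values for illustration only.
short = ILQLElement(
    input_ids=torch.tensor([5, 6, 7]),
    attention_mask=torch.ones(3, dtype=torch.long),
    rewards=torch.tensor([0.0, 1.0]),
    states_ixs=torch.tensor([0, 1, 2]),
    actions_ixs=torch.tensor([0, 1]),
    dones=torch.tensor([1, 1, 0]),
)
longer = ILQLElement(
    input_ids=torch.tensor([5, 6, 7, 8, 9]),
    attention_mask=torch.ones(5, dtype=torch.long),
    rewards=torch.tensor([0.0, 0.0, 0.0, 1.0]),
    states_ixs=torch.tensor([0, 1, 2, 3, 4]),
    actions_ixs=torch.tensor([0, 1, 2, 3]),
    dones=torch.tensor([1, 1, 1, 1, 0]),
)
batch = collate_ilql([short, longer])
print(batch.input_ids.tolist())       # [[5, 6, 7, 0, 0], [5, 6, 7, 8, 9]]
print(batch.attention_mask.tolist())  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```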