Data Classes
Data Elements contain the necessary information for each individual training sample.
PPO Data Classes
- class trlx.data.ppo_types.PPORLElement(query_tensor: torch.Tensor, response_tensor: torch.Tensor, logprobs: torch.Tensor, values: torch.Tensor, rewards: torch.Tensor)[source]
- Parameters
query_tensor (torch.Tensor) – The query tensor, i.e. the prompt tokens. Should be a long tensor.
response_tensor (torch.Tensor) – The response tensor, i.e. the output tokens. Should be a long tensor.
logprobs (torch.Tensor) – The log probabilities over the response tokens generated by the policy network (i.e. the autoregressive model). Should be a float tensor of the same length as the response.
values (torch.Tensor) – The value estimate for each response token, generated by the value network or value head. Should be a float tensor of the same length as the response.
rewards (torch.Tensor) – The reward for each token in the response. Should be a float tensor of the same length as the response.
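As an illustrative sketch, the snippet below constructs one such element. The dataclass here simply mirrors PPORLElement's fields so the example runs without trlx installed; all token ids and numeric values are arbitrary placeholders.

```python
from dataclasses import dataclass

import torch


@dataclass
class PPORLElement:
    # Stand-in mirroring trlx.data.ppo_types.PPORLElement's fields,
    # so this sketch runs without trlx installed.
    query_tensor: torch.Tensor
    response_tensor: torch.Tensor
    logprobs: torch.Tensor
    values: torch.Tensor
    rewards: torch.Tensor


# One training sample: a 4-token prompt and a 3-token response.
# Token ids are arbitrary placeholders.
element = PPORLElement(
    query_tensor=torch.tensor([101, 2009, 2003, 102], dtype=torch.long),
    response_tensor=torch.tensor([2204, 999, 102], dtype=torch.long),
    logprobs=torch.tensor([-0.7, -1.2, -0.3]),  # one per response token
    values=torch.tensor([0.1, 0.4, 0.9]),       # one per response token
    rewards=torch.tensor([0.0, 0.0, 1.0]),      # reward on the final token
)
```

Note that `logprobs`, `values`, and `rewards` all share the length of the response tensor, while the query tensor may have a different length.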
- class trlx.data.ppo_types.PPORLBatch(query_tensors: torch.Tensor, response_tensors: torch.Tensor, logprobs: torch.Tensor, values: torch.Tensor, rewards: torch.Tensor)[source]
A batched version of the PPORLElement. See PPORLElement for more details on individual fields.
- Parameters
query_tensors (torch.Tensor) – A batch of query tensors. Should be a long tensor.
response_tensors (torch.Tensor) – A batch of response tensors. Should be a long tensor.
logprobs (torch.Tensor) – A batch of log probabilities from the policy network.
values (torch.Tensor) – A batch of value estimates from the value network.
rewards (torch.Tensor) – A batch of per-token rewards.
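Collating variable-length elements into a batch requires padding. The sketch below shows one common convention, left-padding queries and right-padding responses; the pad id and token ids are placeholders, and trlx's own collation logic may differ in detail.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two variable-length query/response pairs (placeholder token ids).
queries = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
responses = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]
pad_id = 0  # assumed pad token id for this sketch

# Left-pad queries (flip, right-pad, flip back) so prompts end at the
# same position; right-pad responses so generation starts aligned.
query_batch = pad_sequence(
    [q.flip(0) for q in queries], batch_first=True, padding_value=pad_id
).flip(1)
response_batch = pad_sequence(responses, batch_first=True, padding_value=pad_id)
```

After collation, `query_batch` is `[[5, 6, 7], [0, 8, 9]]` and `response_batch` is `[[1, 2, 0], [3, 4, 5]]`.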
ILQL Data Classes
- class trlx.data.ilql_types.ILQLElement(input_ids: torch.Tensor, attention_mask: torch.Tensor, rewards: torch.Tensor, states_ixs: torch.Tensor, actions_ixs: torch.Tensor, dones: torch.Tensor)[source]
A single data item for ILQL training
- Parameters
input_ids (torch.Tensor) – Long tensor of input tokens.
attention_mask (torch.Tensor) – Attention mask for input tokens. Should be a long tensor.
rewards (torch.Tensor) – Rewards for each input token.
states_ixs (torch.Tensor) – Indices of states (e.g. user input or environment input) in the input_ids tensor.
actions_ixs (torch.Tensor) – Indices of actions (model output) in the input_ids tensor.
dones (torch.Tensor) – Indicator of the terminal state (end of episode) in the input_ids tensor.
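To make the index fields concrete, the sketch below lays out one plausible convention for a single prompt/response episode. The token ids, rewards, and in particular the indexing convention (each action predicted from the position immediately before it, plus one terminal state) are illustrative assumptions, not taken from trlx's pipeline code.

```python
import torch

# A 6-token dialogue: a 3-token prompt followed by a 3-token model
# response. All ids and rewards are placeholders.
input_ids = torch.tensor([11, 12, 13, 21, 22, 23])
attention_mask = torch.ones_like(input_ids)
rewards = torch.zeros(6)
rewards[-1] = 1.0  # reward granted on the final token

prompt_len = 3
# Assume each response token (action) is predicted from the position
# immediately before it.
actions_ixs = torch.arange(prompt_len - 1, input_ids.numel() - 1)
# One state per action, plus the terminal state at the end of the episode.
states_ixs = torch.arange(prompt_len - 1, input_ids.numel())
dones = torch.zeros_like(states_ixs)
dones[-1] = 1  # the episode terminates at the last state
```

Under this convention `states_ixs` always has one more entry than `actions_ixs`, since the terminal state has no action following it.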
- class trlx.models.modeling_ilql.CausalILQLOutput(logits: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, hidden_states: Optional[Tuple[torch.FloatTensor]] = None, value: Optional[torch.FloatTensor] = None, qs: Optional[Tuple[torch.FloatTensor]] = None, target_qs: Optional[Tuple[torch.FloatTensor]] = None)[source]
Output of the causal model with ILQL heads.
- Parameters
logits (torch.FloatTensor) – Logits of the causal model.
past_key_values (Tuple[Tuple[torch.FloatTensor]]) – Tuple of past key values of the causal model.
hidden_states (Tuple[torch.FloatTensor]) – Tuple of hidden states of the causal model.
value (torch.FloatTensor) – Value function estimation for each token in the input sequence.
qs (Tuple[torch.FloatTensor]) – Q-function estimations for each token in the input sequence.
target_qs (Tuple[torch.FloatTensor]) – Q-function estimations from the target Q-head for each token in the input sequence.
- class trlx.data.ilql_types.ILQLSeq2SeqElement(input_ids: torch.Tensor, attention_mask: torch.Tensor, decoder_input_ids: torch.Tensor, rewards: torch.Tensor, states_ixs: torch.Tensor, actions_ixs: torch.Tensor, dones: torch.Tensor)[source]
A single data item for ILQL training with a seq2seq model
- Parameters
input_ids (torch.Tensor) – Long tensor of input tokens.
attention_mask (torch.Tensor) – Attention mask for input tokens. Should be a long tensor.
decoder_input_ids (torch.Tensor) – Long tensor of target input tokens.
rewards (torch.Tensor) – Rewards for each input token.
states_ixs (torch.Tensor) – Indices of states (e.g. user input or environment input) in the input_ids tensor.
actions_ixs (torch.Tensor) – Indices of actions (model output) in the input_ids tensor.
dones (torch.Tensor) – Indicator of the terminal state (end of episode) in the input_ids tensor.
- class trlx.models.modeling_ilql.Seq2SeqILQLOutput(logits: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, hidden_states: Optional[Tuple[torch.FloatTensor]] = None, value: Optional[torch.FloatTensor] = None, qs: Optional[Tuple[torch.FloatTensor]] = None, target_qs: Optional[Tuple[torch.FloatTensor]] = None, encoder_outputs: Optional[Tuple[Any]] = None)[source]
Output of the seq2seq model with ILQL heads.
- Parameters
logits (torch.FloatTensor) – Logits of the seq2seq model.
past_key_values (Tuple[Tuple[torch.FloatTensor]]) – Tuple of past key values of the seq2seq model.
hidden_states (Tuple[torch.FloatTensor]) – Tuple of hidden states of the seq2seq model.
value (torch.FloatTensor) – Value function estimation for each token in the input sequence.
qs (Tuple[torch.FloatTensor]) – Q-function estimations for each token in the input sequence.
target_qs (Tuple[torch.FloatTensor]) – Q-function estimations from the target Q-head for each token in the input sequence.
encoder_outputs (Tuple[Any]) – Tuple of encoder outputs of the seq2seq model.
- class trlx.data.ilql_types.ILQLBatch(input_ids: torch.Tensor, attention_mask: torch.Tensor, rewards: torch.Tensor, states_ixs: torch.Tensor, actions_ixs: torch.Tensor, dones: torch.Tensor)[source]
Batched ILQL data elements
- Parameters
input_ids (torch.Tensor) – Batch of input tokens.
attention_mask (torch.Tensor) – Batch of attention masks.
rewards (torch.Tensor) – Batch of per-token rewards, one row per sequence in the batch.
states_ixs (torch.Tensor) – Batch of indices of states (e.g. user input or environment input) in the input_ids tensor.
actions_ixs (torch.Tensor) – Batch of indices of actions (model output) in the input_ids tensor.
dones (torch.Tensor) – Batch of indicators of the terminal state (end of episode) in the input_ids tensor.
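Batching ILQL elements of different lengths again requires padding. A minimal sketch, assuming right-padding with pad id 0 (an assumption for this example, not a trlx default) and an attention mask derived from the padding:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two ILQL samples of different lengths (placeholder token ids).
seqs = [torch.tensor([11, 12, 13]), torch.tensor([21, 22, 23, 24, 25])]
pad_id = 0  # assumed pad token id

# Right-pad input_ids to a common length and mark real tokens with 1.
input_ids = pad_sequence(seqs, batch_first=True, padding_value=pad_id)
attention_mask = (input_ids != pad_id).long()
```

The resulting `attention_mask` is `[[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]`, so padded positions are ignored by the model.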
- class trlx.data.ilql_types.ILQLSeq2SeqBatch(input_ids: torch.Tensor, attention_mask: torch.Tensor, decoder_input_ids: torch.Tensor, rewards: torch.Tensor, states_ixs: torch.Tensor, actions_ixs: torch.Tensor, dones: torch.Tensor)[source]
Batched ILQL data elements for seq2seq training
- Parameters
input_ids (torch.Tensor) – Batch of input tokens.
attention_mask (torch.Tensor) – Batch of attention masks.
decoder_input_ids (torch.Tensor) – Batch of target input tokens.
rewards (torch.Tensor) – Batch of per-token rewards, one row per sequence in the batch.
states_ixs (torch.Tensor) – Batch of indices of states (e.g. user input or environment input) in the input_ids tensor.
actions_ixs (torch.Tensor) – Batch of indices of actions (model output) in the input_ids tensor.
dones (torch.Tensor) – Batch of indicators of the terminal state (end of episode) in the input_ids tensor.