Data Classes

Data Elements contain the necessary information for each individual training sample.

PPO Data Classes

class trlx.data.ppo_types.PPORLElement(query_tensor: torch.Tensor, response_tensor: torch.Tensor, logprobs: torch.Tensor, values: torch.Tensor, rewards: torch.Tensor)[source]

A single data item for PPO training.

Parameters
  • query_tensor (torch.Tensor) – The query tensor, i.e. the prompt tokens. Should be a long tensor.

  • response_tensor (torch.Tensor) – The response tensor, i.e. the output tokens. Should be a long tensor.

  • logprobs (torch.Tensor) – The log probabilities over the response tokens generated by the policy network (i.e. the autoregressive model). Should be a float tensor of the same size as the response tensor.

  • values (torch.Tensor) – The values for each token generated by the value network or value head. Should be a float tensor of the same size as the response tensor.

  • rewards (torch.Tensor) – The rewards for each token in the response. Should be a float tensor of the same size as the response tensor.
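The shape invariants above can be illustrated with a stand-in sketch that uses plain Python lists in place of torch tensors (the `FakePPORLElement` class is hypothetical and only mirrors the field layout; it is not the trlX class):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class FakePPORLElement:
    # Hypothetical stand-in for trlx.data.ppo_types.PPORLElement,
    # with lists in place of long/float tensors.
    query_tensor: List[int]     # prompt token ids
    response_tensor: List[int]  # generated token ids
    logprobs: List[float]       # per-response-token log probabilities
    values: List[float]         # per-response-token value estimates
    rewards: List[float]        # per-response-token rewards

element = FakePPORLElement(
    query_tensor=[101, 2023, 2003],
    response_tensor=[1037, 3231, 102],
    logprobs=[-0.5, -1.2, -0.1],
    values=[0.3, 0.4, 0.2],
    rewards=[0.0, 0.0, 1.0],  # e.g. a sparse reward placed on the final token
)

# logprobs, values and rewards all align with the response tokens.
assert len(element.logprobs) == len(element.response_tensor)
assert len(element.values) == len(element.response_tensor)
assert len(element.rewards) == len(element.response_tensor)
```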

class trlx.data.ppo_types.PPORLBatch(query_tensors: torch.Tensor, response_tensors: torch.Tensor, logprobs: torch.Tensor, values: torch.Tensor, rewards: torch.Tensor)[source]

A batched version of the PPORLElement. See PPORLElement for more details on individual fields.

Parameters
  • query_tensors (torch.Tensor) – A batch of query tensors. Should be a long tensor.

  • response_tensors (torch.Tensor) – A batch of response tensors. Should be a long tensor.

  • logprobs (torch.Tensor) – A batch of log probabilities from the policy network.

  • values (torch.Tensor) – A batch of values from the value network.

  • rewards (torch.Tensor) – A batch of rewards.
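Batching variable-length elements requires padding to a common length. The exact collation is internal to trlX; the sketch below (with lists standing in for tensors, and the `pad_batch` helper hypothetical) shows one plausible convention in which queries are left-padded and responses right-padded:

```python
from typing import List

def pad_batch(seqs: List[List[int]], pad: int, left: bool = False) -> List[List[int]]:
    """Pad variable-length sequences to a common length (list stand-in for tensors)."""
    width = max(len(s) for s in seqs)
    return [[pad] * (width - len(s)) + s if left else s + [pad] * (width - len(s))
            for s in seqs]

queries = [[7, 8], [5, 6, 9]]
responses = [[1, 2, 3], [4]]

# Left-pad queries so every prompt ends flush at the query/response boundary;
# right-pad responses.
query_batch = pad_batch(queries, pad=0, left=True)
response_batch = pad_batch(responses, pad=0)

assert query_batch == [[0, 7, 8], [5, 6, 9]]
assert response_batch == [[1, 2, 3], [4, 0, 0]]
```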

ILQL Data Classes

class trlx.data.ilql_types.ILQLElement(input_ids: torch.Tensor, attention_mask: torch.Tensor, rewards: torch.Tensor, states_ixs: torch.Tensor, actions_ixs: torch.Tensor, dones: torch.Tensor)[source]

A single data item for ILQL training.

Parameters
  • input_ids (torch.Tensor) – Long tensor of input tokens.

  • attention_mask (torch.Tensor) – Attention mask for input tokens. Should be a long tensor.

  • rewards (torch.Tensor) – Rewards for each input token.

  • states_ixs (torch.Tensor) – Indices of states (for example, user input or environment input) in the input_ids tensor.

  • actions_ixs (torch.Tensor) – Indices of actions (model output) in the input_ids tensor.

  • dones (torch.Tensor) – Indicator of the terminal state (end of episode) in the input_ids tensor.
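The exact index convention is set by trlX's tokenization pipeline; purely as a hypothetical illustration, one plausible convention for a single prompt-plus-response sequence looks like this (plain lists stand in for tensors):

```python
# Hypothetical sketch: a sequence where the prompt occupies positions 0..2
# and the model's response occupies positions 3..5.
input_ids = [10, 11, 12, 20, 21, 22]  # prompt tokens, then response tokens
prompt_len = 3

# Actions are the positions of model-produced tokens.
actions_ixs = list(range(prompt_len, len(input_ids)))
# Each action is taken from the state one position earlier;
# the final position is also a state (the terminal one).
states_ixs = [i - 1 for i in actions_ixs] + [len(input_ids) - 1]
# Only the last state is terminal.
dones = [0] * (len(states_ixs) - 1) + [1]

assert actions_ixs == [3, 4, 5]
assert states_ixs == [2, 3, 4, 5]
assert dones == [0, 0, 0, 1]
```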

class trlx.models.modeling_ilql.CausalILQLOutput(logits: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, hidden_states: Optional[Tuple[torch.FloatTensor]] = None, value: Optional[torch.FloatTensor] = None, qs: Optional[Tuple[torch.FloatTensor]] = None, target_qs: Optional[Tuple[torch.FloatTensor]] = None)[source]

Output of the causal model with ILQL heads.

Parameters
  • logits (torch.FloatTensor) – Logits of the causal model.

  • past_key_values (Tuple[Tuple[torch.FloatTensor]]) – Tuple of past key values of the causal model.

  • hidden_states (Tuple[torch.FloatTensor]) – Tuple of hidden states of the causal model.

  • value (torch.FloatTensor) – Value function estimation for each token in the input sequence.

  • qs (Tuple[torch.FloatTensor]) – Q-function estimations for each token in the input sequence.

  • target_qs (Tuple[torch.FloatTensor]) – Q-function estimations from the target Q-head for each token in the input sequence.
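The qs and target_qs fields are tuples because ILQL-style training commonly maintains multiple Q-heads and takes a conservative elementwise minimum over them, as in clipped double Q-learning. A toy sketch of that reduction, with nested lists standing in for the per-head float tensors:

```python
# Toy stand-in: two Q-heads, each giving one estimate per token position.
qs = ([0.5, 0.9, 0.1],
      [0.4, 1.1, 0.2])

# Conservative estimate: elementwise minimum across heads.
min_q = [min(per_token) for per_token in zip(*qs)]

assert min_q == [0.4, 0.9, 0.1]
```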

class trlx.data.ilql_types.ILQLSeq2SeqElement(input_ids: torch.Tensor, attention_mask: torch.Tensor, decoder_input_ids: torch.Tensor, rewards: torch.Tensor, states_ixs: torch.Tensor, actions_ixs: torch.Tensor, dones: torch.Tensor)[source]

A single data item for seq2seq ILQL training.

Parameters
  • input_ids (torch.Tensor) – Long tensor of input tokens.

  • attention_mask (torch.Tensor) – Attention mask for input tokens. Should be a long tensor.

  • decoder_input_ids (torch.Tensor) – Long tensor of target input tokens.

  • rewards (torch.Tensor) – Rewards for each input token.

  • states_ixs (torch.Tensor) – Indices of states (for example, user input or environment input) in the input_ids tensor.

  • actions_ixs (torch.Tensor) – Indices of actions (model output) in the input_ids tensor.

  • dones (torch.Tensor) – Indicator of the terminal state (end of episode) in the input_ids tensor.

class trlx.models.modeling_ilql.Seq2SeqILQLOutput(logits: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, hidden_states: Optional[Tuple[torch.FloatTensor]] = None, value: Optional[torch.FloatTensor] = None, qs: Optional[Tuple[torch.FloatTensor]] = None, target_qs: Optional[Tuple[torch.FloatTensor]] = None, encoder_outputs: Optional[Tuple[Any]] = None)[source]

Output of the seq2seq model with ILQL heads.

Parameters
  • logits (torch.FloatTensor) – Logits of the seq2seq model.

  • past_key_values (Tuple[Tuple[torch.FloatTensor]]) – Tuple of past key values of the seq2seq model.

  • hidden_states (Tuple[torch.FloatTensor]) – Tuple of hidden states of the seq2seq model.

  • value (torch.FloatTensor) – Value function estimation for each token in the input sequence.

  • qs (Tuple[torch.FloatTensor]) – Q-function estimations for each token in the input sequence.

  • target_qs (Tuple[torch.FloatTensor]) – Q-function estimations from the target Q-head for each token in the input sequence.

  • encoder_outputs (Tuple[Any]) – Tuple of encoder outputs of the seq2seq model.

class trlx.data.ilql_types.ILQLBatch(input_ids: torch.Tensor, attention_mask: torch.Tensor, rewards: torch.Tensor, states_ixs: torch.Tensor, actions_ixs: torch.Tensor, dones: torch.Tensor)[source]

Batched ILQL data elements

Parameters
  • input_ids (torch.Tensor) – Batch of input tokens.

  • attention_mask (torch.Tensor) – Batch of attention masks.

  • rewards (torch.Tensor) – Batch of rewards, one for each token in each sequence.

  • states_ixs (torch.Tensor) – Batch of indices of states (for example, user input or environment input) in the input_ids tensor.

  • actions_ixs (torch.Tensor) – Batch of indices of actions (model output) in the input_ids tensor.

  • dones (torch.Tensor) – Batch of indicators of the terminal state (end of episode) in the input_ids tensor.

class trlx.data.ilql_types.ILQLSeq2SeqBatch(input_ids: torch.Tensor, attention_mask: torch.Tensor, decoder_input_ids: torch.Tensor, rewards: torch.Tensor, states_ixs: torch.Tensor, actions_ixs: torch.Tensor, dones: torch.Tensor)[source]

Batched ILQL data elements

Parameters
  • input_ids (torch.Tensor) – Batch of input tokens.

  • attention_mask (torch.Tensor) – Batch of attention masks.

  • decoder_input_ids (torch.Tensor) – Batch of target input tokens.

  • rewards (torch.Tensor) – Batch of rewards, one for each token in each sequence.

  • states_ixs (torch.Tensor) – Batch of indices of states (for example, user input or environment input) in the input_ids tensor.

  • actions_ixs (torch.Tensor) – Batch of indices of actions (model output) in the input_ids tensor.

  • dones (torch.Tensor) – Batch of indicators of the terminal state (end of episode) in the input_ids tensor.