from dataclasses import dataclass

from torchtyping import TensorType


@dataclass
class PPORLElement:
"""
:param query_tensor: The query tensor i.e. the prompt tokens.
Should be a long tensor.
:type query_tensor: torch.Tensor
:param response_tensor: The response tensor i.e. the output tokens.
Should be a long tensor.
:type response_tensor: torch.Tensor
:param logprobs: The log probabilities over the response tokens generated
by the policy network (i.e. the autoregressive model).
Should be a float tensor of same size as tokens.
:type logprobs: torch.Tensor
:param values: The values for each token generated from the value network or value head.
Should be a float tensor of same size as tokens.
:type values: torch.Tensor
:param rewards: The rewards for each token outputted in response.
Should be a float tensor of same size as tokens.
:type rewards: torch.Tensor
"""
    query_tensor: TensorType["query_size"]
    response_tensor: TensorType["response_size"]
    logprobs: TensorType["response_size"]
    values: TensorType["response_size"]
    rewards: TensorType["response_size"]


@dataclass
class PPORLBatch:
"""
A batched version of the PPORLElement. See PPORLElement for more details on individual fields.
:param query_tensors: A batch of query tensors. Should be a long tensor.
:type query_tensors: torch.Tensor
:param response_tensors: A batch of response tensors. Should be a long tensor.
:type response_tensors: torch.Tensor
:param logprobs: A batch of log probabilities from policy
:type logprobs: torch.Tensor
:param values: A batch of values from value network
:type values: torch.Tensor
:param rewards: A batch of rewards
:type rewards: torch.Tensor
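
    Example (a minimal collation sketch; ``pad_sequence`` right-pads every
    field here purely as a simplifying assumption for the demo, and the
    library's own collation logic may differ, e.g. by left-padding queries)::

        import torch
        from torch.nn.utils.rnn import pad_sequence

        # Two variable-length rollouts: (query_len, response_len) pairs.
        elements = [
            PPORLElement(
                query_tensor=torch.randint(0, 1000, (q_len,)),
                response_tensor=torch.randint(0, 1000, (r_len,)),
                logprobs=torch.randn(r_len),
                values=torch.randn(r_len),
                rewards=torch.randn(r_len),
            )
            for q_len, r_len in [(3, 4), (5, 2)]
        ]

        batch = PPORLBatch(
            query_tensors=pad_sequence(
                [e.query_tensor for e in elements], batch_first=True, padding_value=0
            ),
            response_tensors=pad_sequence(
                [e.response_tensor for e in elements], batch_first=True, padding_value=0
            ),
            logprobs=pad_sequence(
                [e.logprobs for e in elements], batch_first=True, padding_value=0.0
            ),
            values=pad_sequence(
                [e.values for e in elements], batch_first=True, padding_value=0.0
            ),
            rewards=pad_sequence(
                [e.rewards for e in elements], batch_first=True, padding_value=0.0
            ),
        )

        batch.query_tensors.shape     # torch.Size([2, 5])
        batch.response_tensors.shape  # torch.Size([2, 4])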
"""
    query_tensors: TensorType["batch_size", "query_size"]
    response_tensors: TensorType["batch_size", "response_size"]
    logprobs: TensorType["batch_size", "response_size"]
    values: TensorType["batch_size", "response_size"]
    rewards: TensorType["batch_size", "response_size"]