import torch
from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline.offline_pipeline import ILQLRolloutStorage
@register_orchestrator
class OfflineOrchestrator(Orchestrator):
    """
    Orchestrator that creates a static dataset for offline training
    """

    def __init__(self, model, split_token=None):
        """
        Args:
            model: ILQL-style model wrapper; must expose `tokenizer`,
                `tokenize`, and accept the finished dataset via its
                `store` attribute.
            split_token: optional substring that separates the prompt from
                the continuation inside each sample string. When None, the
                prompt is assumed to be a single bos token.
        """
        self.model = model
        self.split_token = split_token

    def make_experience(self, samples, rewards):
        """
        Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the model
        """
        if self.model.tokenizer:
            input_ids = self.model.tokenize(samples)
        else:
            # assumes `samples` are already token-id sequences — TODO confirm with callers
            input_ids = samples
        input_ids = list(map(torch.as_tensor, input_ids))

        states_ixs, actions_ixs = [], []
        dones = []
        for s, s_tok in zip(samples, input_ids):
            # split samples on (prompts, continuations) on a given substring `split_token`
            if self.split_token:
                prompt_str_len = s.index(self.split_token) + len(self.split_token)
                # re-tokenize only the prompt prefix to learn how many tokens it spans
                prompt_tok_len = len(self.model.tokenizer(s[:prompt_str_len]).input_ids)
            # else assume that the prompt is a bos token
            else:
                prompt_tok_len = 1

            # indices of continuations, to mask prompts in loss computation
            a_ixs = torch.arange(prompt_tok_len - 1, len(s_tok) - 1)
            # same continuations but for value computation, with the premise to eventually support interleaved dialog
            s_ixs = torch.arange(prompt_tok_len - 1, len(s_tok))
            # mask continuation's ending
            terminals = torch.ones_like(s_ixs)
            terminals[-1] = 0

            actions_ixs.append(a_ixs)
            states_ixs.append(s_ixs)
            dones.append(terminals)

        if self.model.tokenizer:
            # log one decoded (prompt, response) pair as a sanity check
            prompt = self.model.tokenizer.decode(input_ids[0][: states_ixs[0][1]])
            response = self.model.tokenizer.decode(input_ids[0][states_ixs[0][1] :])
            print("[Sample example]")
            print("Prompt: ", prompt)
            print("Response: ", response)

        print(f"[Mean reward] {torch.Tensor(rewards).mean():.2f}")
        print(
            f"[Mean sample length] {torch.mean(torch.Tensor(list(map(len, input_ids)))):.2f}"
        )

        # standardize terminal returns across the batch; the epsilon guards
        # against division by zero when all rewards are equal
        # NOTE(review): with a single sample, std() is nan — verify batch size > 1 upstream
        returns = torch.as_tensor(rewards, dtype=torch.float)
        returns = (returns - returns.mean()) / (returns.std() + 1e-30)

        # sparse reward shaping: zero everywhere except the final action,
        # which receives the standardized return
        rewards = [torch.zeros(x.shape[0]) for x in actions_ixs]
        for rs, G in zip(rewards, returns):
            rs[-1] = G

        attention_mask = [torch.ones(x.shape[0], dtype=int) for x in input_ids]

        self.model.store = ILQLRolloutStorage(
            input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones
        )