Source code for trlx.pipeline.offline_pipeline

from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Tuple, Union

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import (
    DataCollatorWithPadding,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

from trlx.data.ilql_types import (
    ILQLBatch,
    ILQLElement,
    ILQLSeq2SeqBatch,
    ILQLSeq2SeqElement,
)
from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline


@dataclass
class DialogMessage:
    """
    Single message in a dialogue

    :param is_output: Whether the message is a model output or a prompt
    :type is_output: bool

    :param tokens: Tokenized message
    :type tokens: Tuple[int, ...]
    """

    is_output: bool
    # Variable-length tuple of token ids; a bare ``Tuple[int]`` would denote
    # a tuple of exactly one element.
    tokens: Tuple[int, ...]
def tokenize_dialogue(  # noqa: C901
    dialogue: Union[str, Iterable[str]], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048
) -> List[DialogMessage]:
    """
    Tokenize sample with the interleaved form of (prompt_1, output_1, prompt_2, output_2...)

    A plain string is treated as a single output preceded by a prompt made of
    the tokenizer's BOS token (or its EOS token when no BOS token is set).
    The EOS token is appended to the last phrase if it is not already there.
    The total number of tokens is capped at `max_length`, dropping tokens from
    the side given by `tokenizer.truncation_side`.

    :param dialogue: A single output string, or an iterable of an even number
        of phrases alternating prompt and output
    :param tokenizer: Tokenizer called as `tokenizer(text, add_special_tokens=False)`
    :param max_length: Maximum total number of tokens kept across all messages

    :return: The non-empty messages that survive truncation; if the first of
        them is an output, a single-token BOS prompt is inserted before it.
        The list is empty when nothing survives truncation.
    :raises ValueError: If the iterable is empty or has an odd number of phrases
    """
    if isinstance(dialogue, str):
        bos_token = tokenizer.bos_token or tokenizer.eos_token
        dialogue = [bos_token, dialogue]
    elif isinstance(dialogue, Iterable):
        # Materialize first: generators and other unsized iterables have no len(),
        # and working on a copy keeps the caller's sequence untouched below.
        dialogue = list(dialogue)
        if not dialogue:
            raise ValueError("Dialogue must not be empty")
        if len(dialogue) % 2 != 0:
            raise ValueError("Dialogue must have an even number of phrases, alternating prompt and output")

    if not dialogue[-1].endswith(tokenizer.eos_token):
        dialogue[-1] = dialogue[-1] + tokenizer.eos_token

    # Phrases at odd positions are model outputs, phrases at even positions are prompts
    tokenized = [
        DialogMessage(is_output=i % 2 == 1, tokens=tuple(tokenizer(phrase, add_special_tokens=False).input_ids))
        for i, phrase in enumerate(dialogue)
    ]

    truncate_left = tokenizer.truncation_side == "left"

    # flip to truncate from the left
    if truncate_left:
        tokenized = [DialogMessage(is_output=m.is_output, tokens=m.tokens[::-1]) for m in tokenized[::-1]]

    # truncate if necessary: `budget` is how many tokens are still allowed
    # before the current message (it may go negative once exhausted)
    truncated = []
    budget = max_length
    for message in tokenized:
        truncated.append(DialogMessage(is_output=message.is_output, tokens=message.tokens[: max(budget, 0)]))
        budget -= len(message.tokens)

    # flip back if it was flipped to truncate from the left
    if truncate_left:
        truncated = [DialogMessage(is_output=m.is_output, tokens=m.tokens[::-1]) for m in truncated[::-1]]

    # remove empty messages
    out = [t for t in truncated if len(t.tokens) > 0]

    # The result has to start with a prompt; `out` may be empty if every
    # message was truncated away, in which case there is nothing to prepend to.
    if out and out[0].is_output:
        # Make room for the extra BOS token when the budget is already full
        if sum(len(msg.tokens) for msg in out) == max_length:
            if truncate_left:
                out[0].tokens = out[0].tokens[1:]
            else:
                out[-1].tokens = out[-1].tokens[:-1]

        # Same fallback as for string input: use EOS when the tokenizer has no BOS
        bos_token_id = tokenizer.bos_token_id
        if bos_token_id is None:
            bos_token_id = tokenizer.eos_token_id
        out.insert(0, DialogMessage(False, (bos_token_id,)))
    return out
class DialogStore(BaseRolloutStore):
    """
    Store of tokenized dialogues

    Every dialogue is flattened into one `input_ids` tensor, an all-ones
    `attention_mask` of the same length and a `labels` tensor in which the
    tokens of prompt (non-output) messages are replaced by -100.
    """

    def __init__(self, dialogs: List[List[DialogMessage]], tokenizer: PreTrainedTokenizer):
        """
        :param dialogs: Dialogues, each a list of tokenized messages
        :param tokenizer: Tokenizer handed to the padding collator in `create_loader`
        """
        super().__init__()
        self.tokenizer = tokenizer

        history = []
        for dialog in dialogs:
            token_ids = [t for m in dialog for t in m.tokens]
            # -100 is the ignore index for CrossEntropyLoss
            label_ids = [t if m.is_output else -100 for m in dialog for t in m.tokens]
            history.append(
                dict(
                    input_ids=torch.tensor(token_ids, dtype=torch.long),
                    attention_mask=torch.ones(len(token_ids), dtype=torch.bool),
                    labels=torch.tensor(label_ids, dtype=torch.long),
                )
            )
        self.history = history

    def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
        """
        Build a DataLoader yielding padded batches with `input_ids`,
        `attention_mask` and `labels`

        :param batch_size: Number of dialogues per batch
        :param shuffle: Forwarded to the DataLoader
        """
        hf_collate_fn = DataCollatorWithPadding(self.tokenizer)

        def collate_fn(elems: Iterable[dict]):
            batch = hf_collate_fn(
                {"input_ids": [e["input_ids"] for e in elems], "attention_mask": [e["attention_mask"] for e in elems]}
            )
            # The collator only pads keys it knows about, so labels are passed through it as "input_ids"
            labels = hf_collate_fn([{"input_ids": e["labels"]} for e in elems])["input_ids"]
            # That pads labels with the tokenizer's pad token id; overwrite the padded
            # positions with the ignore index so they cannot contribute to the loss
            batch["labels"] = labels.masked_fill(~batch["attention_mask"].bool(), -100)
            return batch

        return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle)
@register_datapipeline
class PromptPipeline(BasePipeline):
    """
    Dataloader which is used to supply prompts for either training or evaluation

    Args:
        prompts (`List[str]` or `List[Dict[str, Any]]`): list of raw text prompts or a dictionary with a required
            key `"prompt"` and extra information, that would be passed along the generation for that prompt as a
            keyword argument to a reward function. The dictionaries are not modified.
        max_prompt_length (`int`): max length of the prompt, if exceeded the prompt will be truncated according to
            tokenizer's truncation setting.
        tokenizer (`transformers.PreTrainedTokenizer`): a tokenizer to tokenize prompts with.
        add_special_tokens (`bool`): whether to encode prompts with tokenizer's special tokens (passed directly
            into the `tokenizer(...)` call)
    """

    def __init__(
        self,
        prompts: Union[List[Dict[str, Any]], List[str]],
        max_prompt_length: int,
        tokenizer: PreTrainedTokenizer,
        add_special_tokens: bool = False,
    ):
        super().__init__()

        if isinstance(prompts[0], dict):
            # Separate the prompt text from the extra keys without popping from the
            # caller's dicts, so the same prompt list can be reused afterwards.
            metadata = [{key: value for key, value in x.items() if key != "prompt"} for x in prompts]
            prompts = [x["prompt"] for x in prompts]
        else:
            metadata = [{}] * len(prompts)

        model_inputs = tokenizer(
            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=add_special_tokens
        )

        prompts_tokens = model_inputs["input_ids"]
        attention_mask = model_inputs["attention_mask"]

        self.tokenizer = tokenizer
        # One record per prompt: its tokens, its mask and whatever extra keys it came with
        self.prompts = [
            {"input_ids": tokens, "attention_mask": mask, **extra}
            for tokens, mask, extra in zip(prompts_tokens, attention_mask, metadata)
        ]

    def __getitem__(self, ix: int):
        return self.prompts[ix]

    def __len__(self) -> int:
        return len(self.prompts)

    def create_loader(self, batch_size: int, shuffle=False, sampler=None, drop_last=False) -> DataLoader:
        """
        Build a DataLoader over the stored prompts

        `input_ids` are padded with `tokenizer.pad` and returned as PyTorch tensors; every
        key other than `input_ids` and `attention_mask` is collected into a plain list.
        `batch_size`, `shuffle`, `sampler` and `drop_last` are forwarded to the DataLoader.
        """

        def collate_fn(xs):
            out = self.tokenizer.pad([{"input_ids": x["input_ids"]} for x in xs], return_tensors="pt")

            # Extra keys are taken from the first element of the batch
            for key in xs[0]:
                if key != "input_ids" and key != "attention_mask":
                    out[key] = [x[key] for x in xs]

            return out

        # Since all data is already pre-processed, no need to have
        # multi-process data loading
        return DataLoader(
            self,
            batch_size=batch_size,
            collate_fn=collate_fn,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=0,
            drop_last=drop_last,
        )
def ilql_collate_fn(elems: Iterable[ILQLElement]):
    """
    Collate ILQL elements into an `ILQLBatch`

    Each field is gathered across `elems` and padded to a common length with
    `pad_sequence(batch_first=True)`; `rewards` is padded with 0.0 and every
    other field with 0.
    """

    def padded(field: str, fill):
        # Stack one field of every element into a single padded tensor
        sequences = [getattr(elem, field) for elem in elems]
        return pad_sequence(sequences, batch_first=True, padding_value=fill)

    return ILQLBatch(
        padded("input_ids", 0),
        padded("attention_mask", 0),
        padded("rewards", 0.0),
        padded("states_ixs", 0),
        padded("actions_ixs", 0),
        padded("dones", 0),
    )
class ILQLRolloutStorage(BaseRolloutStore):
    """
    Rollout storage for training ILQL

    Holds six parallel, index-aligned collections; item `ix` is the
    `ILQLElement` built from entry `ix` of each of them.
    """

    def __init__(self, input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones):
        super().__init__()

        self.input_ids, self.attention_mask = input_ids, attention_mask
        self.rewards = rewards
        self.states_ixs, self.actions_ixs = states_ixs, actions_ixs
        self.dones = dones

    def __getitem__(self, ix: int) -> ILQLElement:
        # Order matches the positional fields expected by ILQLElement
        columns = (
            self.input_ids,
            self.attention_mask,
            self.rewards,
            self.states_ixs,
            self.actions_ixs,
            self.dones,
        )
        return ILQLElement(*[column[ix] for column in columns])

    def __len__(self) -> int:
        return len(self.input_ids)

    def create_loader(self, batch_size: int):
        """
        Build a shuffling DataLoader that collates with `ilql_collate_fn`
        """
        # The trailing incomplete batch is dropped only when torch.distributed is initialized
        drop_incomplete = torch.distributed.is_initialized()
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=ilql_collate_fn,
            drop_last=drop_incomplete,
        )
def ilql_seq2seq_collate_fn(elems: Iterable[ILQLSeq2SeqElement]):
    """
    Collate seq2seq ILQL elements into an `ILQLSeq2SeqBatch`

    Each field is gathered across `elems` and padded to a common length with
    `pad_sequence(batch_first=True)`; `rewards` is padded with 0.0 and every
    other field with 0.

    The elements must expose `decoder_input_ids`, hence the
    `ILQLSeq2SeqElement` annotation rather than `ILQLElement`.
    """
    return ILQLSeq2SeqBatch(
        pad_sequence([x.input_ids for x in elems], batch_first=True, padding_value=0),
        pad_sequence([x.attention_mask for x in elems], batch_first=True, padding_value=0),
        pad_sequence([x.decoder_input_ids for x in elems], batch_first=True, padding_value=0),
        pad_sequence([x.rewards for x in elems], batch_first=True, padding_value=0.0),
        pad_sequence([x.states_ixs for x in elems], batch_first=True, padding_value=0),
        pad_sequence([x.actions_ixs for x in elems], batch_first=True, padding_value=0),
        pad_sequence([x.dones for x in elems], batch_first=True, padding_value=0),
    )
class ILQLSeq2SeqRolloutStorage(BaseRolloutStore):
    """
    Rollout storage for training ILQL with Seq2Seq models

    Holds seven parallel, index-aligned collections; item `ix` is the
    `ILQLSeq2SeqElement` built from entry `ix` of each of them.
    """

    def __init__(self, input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones):
        super().__init__()

        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.decoder_input_ids = decoder_input_ids
        self.rewards = rewards
        self.states_ixs = states_ixs
        self.actions_ixs = actions_ixs
        self.dones = dones

    def __getitem__(self, ix: int) -> ILQLSeq2SeqElement:
        # Annotated with the type that is actually constructed here
        return ILQLSeq2SeqElement(
            self.input_ids[ix],
            self.attention_mask[ix],
            self.decoder_input_ids[ix],
            self.rewards[ix],
            self.states_ixs[ix],
            self.actions_ixs[ix],
            self.dones[ix],
        )

    def __len__(self) -> int:
        return len(self.input_ids)

    def create_loader(self, batch_size: int):
        """
        Build a shuffling DataLoader that collates with `ilql_seq2seq_collate_fn`

        The trailing incomplete batch is dropped only when torch.distributed
        is initialized.
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=ilql_seq2seq_collate_fn,
            drop_last=torch.distributed.is_initialized(),
        )