API

trlX exposes a single entrypoint for training, which runs the training routine determined by the passed config and arguments. For online training, prompts (a list of strings used to prompt the trained model) and reward_fn (a function that scores model outputs sampled from those prompts) are required, while for offline training, samples (a list of environment/model interactions) and rewards (precomputed scores for each interaction) are required.
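For instance, the two modes might be invoked roughly as follows (a minimal sketch; the model name, prompts, reward function, samples, and rewards are placeholders rather than part of trlX):

import trlx

# Online training: prompts plus a reward function that scores generated outputs.
trlx.train(
    "gpt2",
    reward_fn=lambda samples, prompts, outputs, **kwargs: [float(len(o)) for o in outputs],
    prompts=["Once upon a time", "The weather today is"],
)

# Offline training: precollected samples with precomputed rewards.
trlx.train(
    "gpt2",
    samples=["Once upon a time there was a happy ending.", "The weather today is terrible."],
    rewards=[1.0, 0.0],
)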

Training

trlx.train(model_path: Optional[str] = None, reward_fn: Optional[Callable[[List[str], List[str], List[str]], List[float]]] = None, dataset: Optional[Iterable[Tuple[str, float]]] = None, samples: Optional[List[str]] = None, rewards: Optional[List[float]] = None, prompts: Optional[List[str]] = None, eval_prompts: Optional[List[str]] = None, metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None, config: Optional[trlx.data.configs.TRLConfig] = None, stop_sequences: Optional[List[str]] = [])

Runs online or offline reinforcement learning training, or supervised finetuning, depending on the provided arguments. reward_fn and prompts are required for online training; samples and rewards are required for offline training.

Args:
model_path (Optional[str]):

Path to either a checkpoint on the Hugging Face Hub or a local directory.

config (Optional[TRLConfig]):

Training configuration object.

reward_fn (Optional[Callable[[List[str], List[str], List[str]], List[float]]]):

A function to rate batches of generated samples. Its required arguments are (samples, prompts, outputs), and it returns a list of scalar rewards, one per sample in the batch (see the example after this argument list).

dataset (List[Union[str, List[str]]], List[float]):

Lists of samples and rewards for offline training. (Use samples and rewards instead)

samples (List[Union[str, List[str]]]):

A list of strings, or a list of interleaved prompts (questions or environment states) and outputs to be optimized. In the latter case, each sample is expected in the form (prompt_0: str, output_0: str, prompt_1: str, output_1: str, …). Passing a single string s as a sample is shorthand for (tokenizer.bos_token, s).

rewards (List[float]):

List of scalar rewards, one per sample in samples.

prompts (Union[List[str], List[Dict[str, Any]]]):

Prompts to use for generations during online training. If a prompt is passed as a dict, it must contain the key “prompt”; all extra keys are passed along with the generation for that prompt as keyword arguments to the reward function.

eval_prompts (Union[List[str], List[Dict[str, Any]]]):

Prompts to use for periodic evaluation during training.

metric_fn (Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]]):

Function to compute statistics on batches of generated samples. Its arguments are the same as those of reward_fn (samples, prompts, outputs), but it returns a dictionary mapping each metric’s name to a list of scalar values, one per sample in the batch.

stop_sequences (Optional[List[str]]):

String sequences at which to truncate generations (both during experience generation and evaluation). Generations will not contain these sequences and will also be right-stripped.
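Putting several of these arguments together, an online training call might look roughly like the sketch below. The reward and metric functions, the “original_output” metadata key, and the config path are illustrative assumptions, not part of trlX:

from typing import Dict, List

import trlx
from trlx.data.configs import TRLConfig


def reward_fn(samples: List[str], prompts: List[str], outputs: List[str], **kwargs) -> List[float]:
    # Extra keys from dict prompts (here the hypothetical "original_output")
    # arrive as keyword arguments, one value per sample in the batch.
    references = kwargs["original_output"]
    return [float(out.strip() == ref.strip()) for out, ref in zip(outputs, references)]


def metric_fn(samples: List[str], prompts: List[str], outputs: List[str], **kwargs) -> Dict[str, List[float]]:
    # Each metric name maps to a list of scalar values, one per sample.
    return {"output_length": [float(len(out)) for out in outputs]}


# Assumed path to a PPO config; adapt to a config shipped with your trlX version.
config = TRLConfig.load_yaml("configs/ppo_config.yml")

trainer = trlx.train(
    "gpt2",
    reward_fn=reward_fn,
    metric_fn=metric_fn,
    # Dict prompts must contain "prompt"; extra keys are forwarded to reward_fn.
    prompts=[{"prompt": "Translate to French: cat", "original_output": "chat"}],
    eval_prompts=[{"prompt": "Translate to French: dog", "original_output": "chien"}],
    stop_sequences=["\n"],
    config=config,
)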

Distributed

Accelerate

To launch distributed training with Accelerate, you first have to create its configuration. The accelerate config command only has to be executed once per training node.

$ accelerate config
$ accelerate launch examples/ppo_sentiments.py

You can also use one of the configs provided in the trlX repository:

$ accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/ppo_sentiments.py

NVIDIA NeMo

For training with NeMo you have to use a model stored in the NeMo format. You can convert an existing LLaMA model with the following script:

$ python examples/llama_nemo/convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b

To start training you have to execute the Python script once per GPU, or launch the following sbatch script, which sets --ntasks-per-node=8:

$ sbatch examples/llama_nemo/dist_train.sh

An example run is available on wandb.