API
trlX uses a single entrypoint for training, which executes a training routine conditioned on the passed config and the arguments required by that routine. Online training requires prompts (a list of strings to prompt the model being trained) and reward_fn (a function that scores model outputs sampled from those prompts), while offline training requires samples (a list of environment/model interactions) and rewards (precomputed scores for each interaction).
Training
- trlx.train(model_path: Optional[str] = None, reward_fn: Optional[Callable[[List[str], List[str], List[str]], List[float]]] = None, dataset: Optional[Iterable[Tuple[str, float]]] = None, samples: Optional[List[str]] = None, rewards: Optional[List[float]] = None, prompts: Optional[List[str]] = None, eval_prompts: Optional[List[str]] = None, metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None, config: Optional[trlx.data.configs.TRLConfig] = None, stop_sequences: Optional[List[str]] = [])[source]
Runs online or offline reinforcement training, or supervised finetuning, depending on the provided arguments. reward_fn and prompts are required for online training; samples and rewards are required for offline training.
- Args:
- model_path (Optional[str]):
Path to either huggingface hub checkpoint or a local directory.
- config (Optional[TRLConfig]):
Training configuration object.
- reward_fn (Optional[Callable[[List[str], List[str], List[str]], List[float]]]):
A function to rate batches of generated samples. Its required arguments are (samples, prompts, outputs), and it returns a list of scalar rewards, one per sample in the batch.
- dataset (List[Union[str, List[str]]], List[float]):
Lists of samples and rewards for offline training. (Use samples and rewards instead)
- samples (List[Union[str, List[str]]]):
A list of strings, or a list of interleaved prompts (questions or environment states) and outputs to be optimized. In the latter case the following form is expected: (prompt_0: str, output_0: str, prompt_1: str, output_1: str …). Passing a single string s as a sample is shorthand for (tokenizer.bos_token, s).
- rewards (List[float]):
List of scalar rewards, one per sample in samples.
- prompts (Union[List[str], List[Dict[str, Any]]]):
Prompts to use for generation during online training. If a prompt is passed as a dict, it must contain the key "prompt"; all extra keys are passed along with the generation for that prompt as keyword arguments to the reward function.
- eval_prompts (Union[List[str], List[Dict[str, Any]]]):
Prompts to use for periodic evaluation during training.
- metric_fn (Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]]):
A function to compute statistics on batches of generated samples. Its arguments are the same as those of reward_fn (samples, prompts, outputs), but it returns a dictionary mapping each metric's name to a list of scalar values, one per sample in the batch.
- stop_sequences (Optional[List[str]]):
String sequences at which to trim generations (both during experience generation and evaluation). Generations are cut at the first occurrence of any of these sequences, so they will not contain them, and they are also right-stripped.
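To make the expected callback signatures concrete, here is a minimal sketch. The reward and metric functions below are hypothetical illustrations (a toy length-based reward, not a recommended objective); the trlx.train call that would consume them is shown commented out, since it requires an installed trlX and a GPU-backed model.

```python
from typing import Dict, List

def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]:
    # Toy reward for illustration only: longer outputs score higher.
    # Each argument is a batch-aligned list; return one scalar per sample.
    return [float(len(output.split())) for output in outputs]

def metric_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> Dict[str, List[float]]:
    # Same arguments as reward_fn, but returns a dict mapping each
    # metric name to a list of per-sample scalar values.
    return {"output_length": [float(len(output.split())) for output in outputs]}

# Online training (uncomment with trlX installed):
# import trlx
# trainer = trlx.train(
#     model_path="gpt2",                       # any HF hub checkpoint or local dir
#     reward_fn=reward_fn,
#     prompts=["The movie was", "I thought the plot"],
#     eval_prompts=["Overall the film was"],
#     metric_fn=metric_fn,
# )
#
# Offline training instead passes precomputed interactions and scores:
# trainer = trlx.train(
#     model_path="gpt2",
#     samples=[["What is 2+2?", "4"], ["Capital of France?", "Paris"]],
#     rewards=[1.0, 1.0],
# )
```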
Distributed
Accelerate
To launch distributed training with Accelerate, first specify the training configuration. You only have to execute this command once per training node.
$ accelerate config
$ accelerate launch examples/ppo_sentiments.py
You can also use one of the configs provided in the trlX repository:
$ accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/ppo_sentiments.py
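For reference, an Accelerate config for DeepSpeed ZeRO-2 with bf16 might look like the sketch below. This is an assumption about the general shape of such a file, not the exact contents of the repository's zero2-bf16.yaml; adjust num_processes and num_machines to your hardware.

```yaml
# Sketch of an Accelerate + DeepSpeed ZeRO-2 bf16 config (illustrative values)
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero_stage: 2
mixed_precision: bf16
num_machines: 1
num_processes: 8   # one process per GPU
```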
NVIDIA NeMo
For training with NeMo you must use a model stored in the NeMo format. You can convert an existing Llama model with the following script:
$ python examples/llama_nemo/convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b
To start training you have to run the Python script once per GPU, or launch the following sbatch script, which uses --ntasks-per-node=8:
$ sbatch examples/llama_nemo/dist_train.sh
Example run logs are available on wandb.