Configs
Training requires configuration to be passed in through a set of configs: a TrainConfig with general training settings, a ModelConfig, a TokenizerConfig, an OptimizerConfig, a SchedulerConfig, and a MethodConfig holding the configuration specific to a particular algorithm (PPO, ILQL or SFT).
General
- class trlx.data.configs.TRLConfig(method: trlx.data.method_configs.MethodConfig, model: trlx.data.configs.ModelConfig, optimizer: trlx.data.configs.OptimizerConfig, scheduler: trlx.data.configs.SchedulerConfig, tokenizer: trlx.data.configs.TokenizerConfig, train: trlx.data.configs.TrainConfig)[source]
Top-level config for trlX. Loads configs and can be converted to a dictionary.
- evolve(**kwargs) → trlx.data.configs.TRLConfig[source]
Evolve TRLConfig with new parameters. Can update nested parameters.
>>> config = trlx.data.default_configs.default_ilql_config()
>>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100)))
>>> config.method.gamma
0.99
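For example, a full config can be assembled from one of the bundled defaults and handed to the training entry point (a minimal sketch; the model name, override values, and constant-length reward function are illustrative stand-ins):

import trlx
from trlx.data.default_configs import default_ppo_config

# Start from the bundled PPO defaults and override nested fields with evolve()
config = default_ppo_config()
config = config.evolve(
    model=dict(model_path="gpt2"),
    train=dict(batch_size=16, total_steps=1000),
)

# The assembled config is then passed to the training entry point
trainer = trlx.train(
    reward_fn=lambda samples, **kwargs: [float(len(s)) for s in samples],
    config=config,
)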
- class trlx.data.configs.TrainConfig(total_steps: int, seq_length: int, epochs: int, batch_size: int, checkpoint_interval: int, eval_interval: int, pipeline: str, trainer: str, trainer_kwargs: Dict[str, Any] = <factory>, project_name: str = 'trlx', run_name: Optional[str] = None, entity_name: Optional[str] = None, group_name: Optional[str] = None, checkpoint_dir: str = 'ckpts', rollout_logging_dir: Optional[str] = None, save_best: bool = True, save_optimizer: bool = True, resume_from_checkpoint: Optional[str] = None, tracker: Optional[str] = 'wandb', logging_dir: Optional[str] = None, tags: Optional[List[str]] = <factory>, seed: int = 1000, minibatch_size: Optional[int] = None)[source]
Config for a training job on a model.
- Parameters
total_steps (int) – Total number of training steps
seq_length (int) – Number of tokens to use as context (max length for tokenizer)
epochs (int) – Total number of passes through data
batch_size (int) – Batch size for training
tracker (str) – Tracker to use for logging. Default: “wandb”
checkpoint_interval (int) – Save model every checkpoint_interval steps. Each checkpoint is stored in a sub-directory of the TrainConfig.checkpoint_dir directory in the format checkpoint_dir/checkpoint_{step}.
eval_interval (int) – Evaluate model every eval_interval steps
pipeline (str) – Pipeline to use for training. One of the registered pipelines present in trlx.pipeline
trainer (str) – Trainer to use for training. One of the registered trainers present in trlx.trainer
trainer_kwargs (Dict[str, Any]) – Extra keyword arguments for the trainer
project_name (str) – Project name for wandb
entity_name (str) – Entity name for wandb
group_name (str) – Group name for wandb (used for grouping runs)
checkpoint_dir (str) – Directory to save checkpoints
rollout_logging_dir (Optional[str]) – Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOTrainer.
save_best (bool) – Save best model based on mean reward
seed (int) – Random seed
minibatch_size (Optional[int]) – Size of the model input during one forward pass. Must evenly divide the batch size
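A minimal TrainConfig might look like this (a sketch with illustrative values; the pipeline and trainer names follow the bundled PPO defaults):

from trlx.data.configs import TrainConfig

train = TrainConfig(
    total_steps=10000,              # total number of optimizer steps
    seq_length=1024,                # max context length for the tokenizer
    epochs=100,                     # upper bound on passes through the data
    batch_size=32,
    checkpoint_interval=10000,      # saved as ckpts/checkpoint_{step}
    eval_interval=100,
    pipeline="PromptPipeline",      # registered in trlx.pipeline
    trainer="AcceleratePPOTrainer", # registered in trlx.trainer
)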
- class trlx.data.configs.ModelConfig(model_path: str, model_arch_type: str = 'causal', num_layers_unfrozen: int = -1, peft_config: Optional[Any] = None, model_extra_configs: Dict[str, Any] = <factory>)[source]
Config for a model.
- Parameters
model_path (str) – Path or name of the model (local or on huggingface hub)
model_arch_type (str) – Type of model architecture. Either “causal” or “seq2seq”
num_layers_unfrozen (int) – Number of layers to unfreeze for fine-tuning. -1 means all layers are unfrozen.
peft_config (Union[peft.PeftConfig, Dict[str, Any]]) –
Configuration for peft (Parameter-Efficient Fine-Tuning library). Peft is designed to reduce the number of trainable parameters and the memory footprint, without significant loss in performance. It supports multiple techniques such as LoRA or prefix tuning (cf. https://github.com/huggingface/peft).
- Here is an example of a LoRA configuration:
{"peft_type": "LORA", "r": 8, "lora_alpha": 32, "lora_dropout": 0.1}
(Parameter-efficient fine-tuning was previously done in trlx with OpenDelta, but it is no longer supported.)
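For illustration, a ModelConfig that unfreezes only the top two layers, or alternatively hands fine-tuning over to peft with the LoRA dict from above (the model name and values are examples, not requirements):

from trlx.data.configs import ModelConfig

model = ModelConfig(
    model_path="gpt2",         # local path or name on the Hugging Face hub
    model_arch_type="causal",
    num_layers_unfrozen=2,     # fine-tune only the last two layers
    # Or delegate to peft instead:
    # peft_config={"peft_type": "LORA", "r": 8, "lora_alpha": 32, "lora_dropout": 0.1},
)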
- class trlx.data.configs.TokenizerConfig(tokenizer_path: str, padding_side: str = 'left', truncation_side: str = 'right', tokenizer_extra_configs: Dict[str, Any] = <factory>)[source]
Config for a tokenizer.
- Parameters
tokenizer_path (str) – Path or name of the tokenizer (local or on huggingface hub)
padding_side (str) – Padding side (“left” or “right”)
truncation_side (str) – Truncation side (“left” or “right”)
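A typical causal-LM setup pads on the left so that generation continues directly from the end of the prompt (a sketch with illustrative values):

from trlx.data.configs import TokenizerConfig

tokenizer = TokenizerConfig(
    tokenizer_path="gpt2",    # local path or name on the Hugging Face hub
    padding_side="left",      # pad prompts on the left for causal generation
    truncation_side="right",
)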
- class trlx.data.configs.OptimizerConfig(name: str, kwargs: Dict[str, Any] = <factory>)[source]
Config for an optimizer.
- Parameters
name (str) – Name of the optimizer
kwargs (Dict[str, Any]) – Keyword arguments for the optimizer (e.g. lr, betas, eps, weight_decay)
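The name is looked up among the optimizers registered in trlx; for instance, the bundled defaults use adamw (hyperparameter values here are illustrative):

from trlx.data.configs import OptimizerConfig

optimizer = OptimizerConfig(
    name="adamw",
    kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6),
)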
- class trlx.data.configs.SchedulerConfig(name: str, kwargs: Dict[str, Any] = <factory>)[source]
Config for a learning rate scheduler.
- Parameters
name (str) – Name of the scheduler
kwargs (Dict[str, Any]) – Keyword arguments for the scheduler instance (e.g. warmup_steps, T_max)
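Likewise, the bundled defaults pair the optimizer above with cosine annealing (values are illustrative; T_max is usually set to match total_steps):

from trlx.data.configs import SchedulerConfig

scheduler = SchedulerConfig(
    name="cosine_annealing",
    kwargs=dict(T_max=10000, eta_min=1.0e-4),
)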
- class trlx.data.method_configs.MethodConfig(name: str)[source]
Config for a certain RL method.
- Parameters
name (str) – Name of the method
PPO
- class trlx.models.modeling_ppo.PPOConfig(name: str, ppo_epochs: int, num_rollouts: int, chunk_size: int, init_kl_coef: float, target: float, horizon: int, gamma: float, lam: float, cliprange: float, cliprange_value: float, vf_coef: float, scale_reward: Optional[str], ref_mean: Optional[float], ref_std: Optional[float], cliprange_reward: float, gen_kwargs: dict, gen_experience_kwargs: Optional[dict] = None, num_value_layers_unfrozen: int = 0)[source]
Config for the PPO method.
- Parameters
ppo_epochs (int) – Number of updates per batch
num_rollouts (int) – Number of experiences to observe before learning
init_kl_coef (float) – Initial value for KL coefficient
target (float) – Target value for KL coefficient
horizon (int) – Number of steps for KL coefficient to reach target
gamma (float) – Discount factor
lam (float) – GAE lambda
cliprange (float) – Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange)
cliprange_value (float) – Clipping range for predicted values (observed values - cliprange_value, observed values + cliprange_value)
vf_coef (float) – Value loss scale w.r.t policy loss
gen_kwargs (Dict[str, Any]) – Additional kwargs for the generation
gen_experience_kwargs (Optional[Dict[str, Any]]) – If not None, these kwargs are used instead of gen_kwargs when generating experience
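Since most fields have no defaults, PPOConfig is usually constructed in full; a sketch with values in the same ballpark as the bundled defaults (all of them illustrative, to be tuned per task):

from trlx.models.modeling_ppo import PPOConfig

method = PPOConfig(
    name="PPOConfig",
    ppo_epochs=4,              # updates per batch
    num_rollouts=128,          # experiences gathered before each learning phase
    chunk_size=128,            # prompts processed per generation pass
    init_kl_coef=0.05,
    target=6.0,                # target KL for the adaptive coefficient
    horizon=10000,
    gamma=1.0,
    lam=0.95,
    cliprange=0.2,
    cliprange_value=0.2,
    vf_coef=1.0,
    scale_reward="ignored",
    ref_mean=None,
    ref_std=None,
    cliprange_reward=10.0,
    gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True),
)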
- get_advantages_and_returns(values: torch.Tensor, rewards: torch.Tensor, response_length: int, use_whitening: Optional[bool] = True) → Tuple[torch.Tensor, torch.Tensor][source]
Function that computes advantages and returns from rewards and values. Calculated as in the original PPO paper: https://arxiv.org/abs/1707.06347 Note that rewards may include a KL divergence loss term.
Advantages look like this:
Adv1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + … − V1 + γ * (1 − λ) * V2 + γ^2 * λ * (1 − λ) * V3 + …
Returns look like this:
Ret1 = R1 + γ * λ * R2 + γ^2 * λ^2 * R3 + … + γ * (1 − λ) * V2 + γ^2 * λ * (1 − λ) * V3 + …
- Parameters
values (torch.Tensor) – Tensor of shape (batch_size, response_size)
rewards (torch.Tensor) – Tensor of shape (batch_size, response_size)
response_length (int) – Length of the response sequence
use_whitening (Optional[bool]) – Whether to use whitening (i.e. normalize advantages) or not
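These closed forms are just the unrolled GAE recursion; a minimal self-contained sketch of that recursion (whitening omitted; the function name is ours, not part of trlx):

import torch

def gae_advantages_and_returns(values: torch.Tensor, rewards: torch.Tensor,
                               gamma: float = 1.0, lam: float = 0.95):
    """values, rewards: tensors of shape (batch_size, response_size)."""
    response_length = rewards.shape[-1]
    lastgaelam = torch.zeros_like(values[:, 0])
    advantages_reversed = []
    for t in reversed(range(response_length)):
        # Bootstrap with the next value estimate; zero beyond the last step
        nextvalues = values[:, t + 1] if t < response_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]  # TD error
        lastgaelam = delta + gamma * lam * lastgaelam              # GAE recursion
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values  # regression targets for the value head
    return advantages, returns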
- loss(logprobs: torch.Tensor, values: torch.Tensor, old_logprobs: torch.Tensor, old_values: torch.Tensor, advantages: torch.Tensor, returns: torch.Tensor, mask: torch.Tensor)[source]
PPO objective function. References:
- https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html
ILQL
- class trlx.models.modeling_ilql.ILQLConfig(name: str, tau: float, gamma: float, cql_scale: float, awac_scale: float, alpha: float, beta: float, steps_for_target_q_sync: int, two_qs: bool, gen_kwargs: dict)[source]
Config for the ILQL method.
- Parameters
tau (float) – Parameter for expectile regression of the value function towards the Q estimates, in (0, 1), where tau=0.5 is equivalent to the mean squared error and tau=1 is equivalent to taking the maximum over the Q estimates
gamma (float) – Discount factor
cql_scale (float) – Scale for the CQL loss (conservative q-learning loss)
awac_scale (float) – Scale for the AWAC loss (weighted cross-entropy loss)
alpha (float) – Parameter for Polyak averaging of the target Q-head sync, in (0, 1)
beta (float) – Parameter for magnitude of weighting effect in the AWAC loss, in (0, 1)
steps_for_target_q_sync (int) – Number of steps between target Q-head syncs
two_qs (bool) – Whether to use two Q-heads and take the minimum of their separate estimates, rather than a single Q-head
gen_kwargs (dict) – Keyword arguments for the generation method
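A sketch of constructing an ILQLConfig directly, with illustrative values close to the bundled ILQL defaults:

from trlx.models.modeling_ilql import ILQLConfig

method = ILQLConfig(
    name="ILQLConfig",
    tau=0.7,                     # expectile: 0.5 ≈ MSE, closer to 1 ≈ max over Q
    gamma=0.99,                  # discount factor
    cql_scale=0.1,               # weight of the conservative Q-learning loss
    awac_scale=1.0,              # weight of the weighted cross-entropy loss
    alpha=0.001,                 # Polyak coefficient for target Q-head syncs
    beta=0.0,                    # magnitude of the AWAC weighting
    steps_for_target_q_sync=5,
    two_qs=True,                 # take the min over two separate Q-heads
    gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0),
)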