Configs
Training a model in trlX requires setting several configs: ModelConfig, which contains general information about the model being trained; TrainConfig, which contains training hyperparameters; and MethodConfig, which contains hyperparameters or settings for the specific method being used (e.g. ILQL or PPO).
General
- class trlx.data.configs.TRLConfig(model: trlx.data.configs.ModelConfig, train: trlx.data.configs.TrainConfig, method: trlx.data.method_configs.MethodConfig)[source]
Top level config for trlX. Loads configs and can be converted to dictionary.
- class trlx.data.configs.ModelConfig(model_path: str, tokenizer_path: str, model_type: str, num_layers_unfrozen: int = -1)[source]
Config for a model.
- Parameters
model_path (str) – Path to the model (local or on huggingface hub)
tokenizer_path (str) – Path to the tokenizer (local or on huggingface hub)
model_type (str) – One of the registered RL models present in trlx.model
- class trlx.data.configs.TrainConfig(total_steps: int, seq_length: int, epochs: int, batch_size: int, lr_ramp_steps: int, lr_decay_steps: int, weight_decay: float, learning_rate_init: float, learning_rate_target: float, opt_betas: Tuple[float], checkpoint_interval: int, eval_interval: int, pipeline: str, orchestrator: str, checkpoint_dir: str = 'ckpts', project_name: str = 'trlx', seed: int = 1000)[source]
Config for train job on model.
- Parameters
total_steps (int) – Total number of training steps
seq_length (int) – Number of tokens to use as context (max length for tokenizer)
epochs (int) – Total number of passes through data
batch_size (int) – Batch size for training
lr_ramp_steps (int) – Number of steps before learning rate reaches learning_rate_init
lr_decay_steps (int) – Number of steps after ramp-up over which the learning rate decays to learning_rate_target
weight_decay (float) – Weight decay for optimizer
learning_rate_init (float) – Initial learning rate after ramp up
learning_rate_target (float) – Target learning rate after decay
checkpoint_interval (int) – Save model every checkpoint_interval steps
eval_interval (int) – Evaluate model every eval_interval steps
pipeline (str) – Pipeline to use for training. One of the registered pipelines present in trlx.pipeline
orchestrator (str) – Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator
checkpoint_dir (str) – Directory to save checkpoints
project_name (str) – Project name for wandb
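The interaction between lr_ramp_steps, lr_decay_steps, learning_rate_init and learning_rate_target can be sketched as a small schedule function. This is only an illustration: the linear shape of both the ramp and the decay is an assumption (the parameter descriptions do not state the curve), and `learning_rate_at` is a hypothetical helper, not part of trlX.

```python
def learning_rate_at(
    step: int,
    lr_ramp_steps: int,
    lr_decay_steps: int,
    learning_rate_init: float,
    learning_rate_target: float,
) -> float:
    """Hypothetical schedule: linear ramp to learning_rate_init, then linear
    decay to learning_rate_target. The linear shape is an assumption."""
    if lr_ramp_steps > 0 and step < lr_ramp_steps:
        # Ramp-up phase: grow from 0 towards learning_rate_init.
        return learning_rate_init * step / lr_ramp_steps

    # Decay phase: fraction of the decay window already consumed, capped at 1.
    steps_into_decay = step - lr_ramp_steps
    if lr_decay_steps > 0:
        progress = min(steps_into_decay / lr_decay_steps, 1.0)
    else:
        progress = 1.0
    return learning_rate_init + (learning_rate_target - learning_rate_init) * progress


# Illustrative values only.
schedule = [learning_rate_at(s, 100, 1000, 1e-4, 1e-6) for s in (0, 50, 100, 1100, 5000)]
```

Under these assumptions the rate starts at 0, reaches learning_rate_init at step lr_ramp_steps, and stays at learning_rate_target once lr_decay_steps further steps have passed.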
- class trlx.data.method_configs.MethodConfig(name: str)[source]
Config for a certain RL method.
- Parameters
name (str) – Name of the method
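To show how the three configs fit together under the top-level config, the following sketch uses stand-in dataclasses that mirror the signatures listed above. They are not the real trlX classes; in practice you would import those from trlx.data.configs and trlx.data.method_configs. All field values are illustrative placeholders (check the trlX registries for valid model_type, pipeline and orchestrator names), and the nested dictionary layout produced by the hypothetical `config_to_dict` helper is an assumption.

```python
from dataclasses import asdict, dataclass
from typing import Tuple


@dataclass
class ModelConfig:  # stand-in mirroring the documented signature
    model_path: str
    tokenizer_path: str
    model_type: str
    num_layers_unfrozen: int = -1


@dataclass
class TrainConfig:  # stand-in mirroring the documented signature
    total_steps: int
    seq_length: int
    epochs: int
    batch_size: int
    lr_ramp_steps: int
    lr_decay_steps: int
    weight_decay: float
    learning_rate_init: float
    learning_rate_target: float
    opt_betas: Tuple[float, ...]
    checkpoint_interval: int
    eval_interval: int
    pipeline: str
    orchestrator: str
    checkpoint_dir: str = "ckpts"
    project_name: str = "trlx"
    seed: int = 1000


@dataclass
class MethodConfig:  # stand-in mirroring the documented signature
    name: str


@dataclass
class TRLConfig:  # stand-in: groups the three configs
    model: ModelConfig
    train: TrainConfig
    method: MethodConfig


def config_to_dict(config: TRLConfig) -> dict:
    """Hypothetical helper: nested dict view of the config (layout assumed)."""
    return asdict(config)


# Placeholder values for illustration only.
config = TRLConfig(
    model=ModelConfig(model_path="gpt2", tokenizer_path="gpt2", model_type="AcceleratePPOModel"),
    train=TrainConfig(
        total_steps=10000, seq_length=64, epochs=10, batch_size=32,
        lr_ramp_steps=100, lr_decay_steps=9000, weight_decay=1e-6,
        learning_rate_init=1e-4, learning_rate_target=1e-6, opt_betas=(0.9, 0.95),
        checkpoint_interval=1000, eval_interval=100,
        pipeline="PPOPipeline", orchestrator="PPOOrchestrator",
    ),
    method=MethodConfig(name="ppoconfig"),
)
as_dict = config_to_dict(config)
```

Fields with defaults (num_layers_unfrozen, checkpoint_dir, project_name, seed) can be omitted, as in the example.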
PPO
- class trlx.data.method_configs.PPOConfig(name: str, ppo_epochs: int, num_rollouts: int, chunk_size: int, init_kl_coef: float, target: float, horizon: int, gamma: float, lam: float, cliprange: float, cliprange_value: float, vf_coef: float, gen_kwargs: dict)[source]
Config for PPO method
- Parameters
ppo_epochs (int) – Number of updates per batch
num_rollouts (int) – Number of experiences to observe before learning
init_kl_coef (float) – Initial value for KL coefficient
target (float) – Target KL value that the adaptive KL coefficient is adjusted toward
horizon (int) – Number of steps over which the KL coefficient is adapted toward the target
gamma (float) – Discount factor
lam (float) – GAE lambda
cliprange (float) – Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange)
cliprange_value (float) – Clipping range for predicted values (observed values - cliprange_value, observed values + cliprange_value)
vf_coef (float) – Scale of the value loss relative to the policy loss
gen_kwargs (Dict[str, Any]) – Additional kwargs for the generation
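The roles of gamma, lam, cliprange, cliprange_value and vf_coef can be illustrated with a scalar sketch of the standard PPO quantities. This is plain Python on single numbers, not the trlX implementation; the function names are hypothetical, and treating "observed values" as the value predictions recorded during rollouts is an assumption.

```python
from typing import List


def gae_advantages(rewards: List[float], values: List[float], gamma: float, lam: float) -> List[float]:
    """Generalized advantage estimation over one trajectory.
    The value after the final step is taken to be 0 (assumption)."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD residual
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


def clipped_policy_loss(ratio: float, advantage: float, cliprange: float) -> float:
    """PPO policy loss with the probability ratio clipped to
    (1 - cliprange, 1 + cliprange)."""
    clipped_ratio = max(1.0 - cliprange, min(ratio, 1.0 + cliprange))
    return max(-advantage * ratio, -advantage * clipped_ratio)


def clipped_value_loss(value: float, old_value: float, ret: float, cliprange_value: float) -> float:
    """Value loss with the new prediction clipped to within cliprange_value
    of the rollout-time prediction old_value."""
    clipped = max(old_value - cliprange_value, min(value, old_value + cliprange_value))
    return 0.5 * max((value - ret) ** 2, (clipped - ret) ** 2)


def ppo_loss(policy_loss: float, value_loss: float, vf_coef: float) -> float:
    """Total loss: policy loss plus the value loss scaled by vf_coef."""
    return policy_loss + vf_coef * value_loss


# Illustrative numbers only.
advantages = gae_advantages([1.0, 1.0], [0.0, 0.0], gamma=1.0, lam=1.0)
total = ppo_loss(
    clipped_policy_loss(ratio=1.5, advantage=1.0, cliprange=0.2),
    clipped_value_loss(value=2.0, old_value=1.0, ret=1.0, cliprange_value=0.2),
    vf_coef=1.0,
)
```

With a positive advantage, a ratio above 1 + cliprange earns no extra credit, which is what keeps each of the ppo_epochs updates close to the policy that generated the rollouts.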
ILQL
- class trlx.data.method_configs.ILQLConfig(name: str, tau: float, gamma: float, cql_scale: float, awac_scale: float, alpha: float, steps_for_target_q_sync: int, betas: List[float], two_qs: bool)[source]
Config for ILQL method
- Parameters
tau (float) – Controls the tradeoff in the value loss between penalizing the value network for underestimating the target Q (i.e. the Q value corresponding to the action taken), which is weighted more heavily at high tau, and for overestimating it, which is weighted more heavily at low tau
gamma (float) – Discount factor for future rewards
cql_scale (float) – Weight for CQL loss term
awac_scale (float) – Weight for AWAC loss term
steps_for_target_q_sync (int) – Number of steps to wait before syncing target Q network with Q network
two_qs (bool) – Use minimum of two Q-value estimates
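The effect of tau and two_qs can be sketched with scalar versions of an expectile-style value loss and a pessimistic Q estimate. This is a minimal illustration under the assumption that the value loss weights squared errors by tau when the value underestimates the target Q and by 1 - tau otherwise; the function names are hypothetical and this is not the trlX implementation.

```python
def expectile_value_loss(target_q: float, value: float, tau: float) -> float:
    """Asymmetric squared error between the target Q and the value estimate.
    Underestimation (value < target_q) is weighted by tau,
    overestimation (value > target_q) by 1 - tau."""
    diff = target_q - value
    weight = tau if diff >= 0 else 1.0 - tau
    return weight * diff ** 2


def q_estimate(q1: float, q2: float, two_qs: bool) -> float:
    """With two_qs enabled, use the minimum of the two Q-value estimates;
    otherwise fall back to the first one."""
    return min(q1, q2) if two_qs else q1


# Illustrative numbers only: the same absolute error costs more when the
# value network underestimates the target and tau is above 0.5.
under = expectile_value_loss(target_q=1.0, value=0.0, tau=0.7)
over = expectile_value_loss(target_q=0.0, value=1.0, tau=0.7)
```

At tau = 0.5 the two cases are weighted equally and the loss reduces to an ordinary scaled squared error.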