Configs
Training a model in trlX requires you to set several configs: ModelConfig, which contains general information about the model being trained; TrainConfig, which contains training hyperparameters; and MethodConfig, which contains hyperparameters and settings for the specific RL method being used (e.g. ILQL or PPO).
General
- class trlx.data.configs.TRLConfig(model: trlx.data.configs.ModelConfig, train: trlx.data.configs.TrainConfig, method: trlx.data.method_configs.MethodConfig)[source]
Top level config for trlX. Loads configs and can be converted to dictionary.
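As a sketch, the three sub-configs are typically written as sections of a single YAML file that is loaded into a TRLConfig. The field names below follow the dataclass signatures in this section; the values and the exact file layout are illustrative only and may differ between trlX versions:

```yaml
model:
  model_path: "gpt2"            # local path or Hugging Face Hub id
  tokenizer_path: "gpt2"
  model_type: "AcceleratePPOModel"   # illustrative; must be a registered RL model
  num_layers_unfrozen: 2

train:
  total_steps: 10000
  seq_length: 48
  epochs: 100
  batch_size: 128
  lr_ramp_steps: 100
  lr_decay_steps: 1000
  weight_decay: 1.0e-6
  learning_rate_init: 1.0e-4
  learning_rate_target: 1.0e-5
  opt_betas: [0.9, 0.95]
  checkpoint_interval: 1000
  eval_interval: 100
  pipeline: "PPOPipeline"          # illustrative; must be a registered pipeline
  orchestrator: "PPOOrchestrator"  # illustrative; must be a registered orchestrator

method:
  name: "ppoconfig"
  # method-specific hyperparameters go here (see PPOConfig / ILQLConfig below)
```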
- class trlx.data.configs.ModelConfig(model_path: str, tokenizer_path: str, model_type: str, num_layers_unfrozen: int = -1)[source]
Config for a model.
- Parameters
model_path (str) – Path to the model (local or on huggingface hub)
tokenizer_path (str) – Path to the tokenizer (local or on huggingface hub)
model_type (str) – One of the registered RL models present in trlx.model
num_layers_unfrozen (int) – Number of topmost layers to leave unfrozen during training (-1 unfreezes all layers)
- class trlx.data.configs.TrainConfig(total_steps: int, seq_length: int, epochs: int, batch_size: int, lr_ramp_steps: int, lr_decay_steps: int, weight_decay: float, learning_rate_init: float, learning_rate_target: float, opt_betas: Tuple[float], checkpoint_interval: int, eval_interval: int, pipeline: str, orchestrator: str, checkpoint_dir: str = 'ckpts', project_name: str = 'trlx', seed: int = 1000)[source]
Config for train job on model.
- Parameters
total_steps (int) – Total number of training steps
seq_length (int) – Number of tokens to use as context (max length for tokenizer)
epochs (int) – Total number of passes through data
batch_size (int) – Batch size for training
lr_ramp_steps (int) – Number of steps before learning rate reaches learning_rate_init
lr_decay_steps (int) – Number of steps after ramp-up before the learning rate decays to learning_rate_target
weight_decay (float) – Weight decay for optimizer
learning_rate_init (float) – Initial learning rate after ramp up
learning_rate_target (float) – Target learning rate after decay
checkpoint_interval (int) – Save model every checkpoint_interval steps
eval_interval (int) – Evaluate model every eval_interval steps
pipeline (str) – Pipeline to use for training. One of the registered pipelines present in trlx.pipeline
orchestrator (str) – Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator
project_name (str) – Project name for wandb
opt_betas (Tuple[float]) – Beta coefficients for the optimizer
checkpoint_dir (str) – Directory in which to save checkpoints
seed (int) – Random seed
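Together, lr_ramp_steps, lr_decay_steps, learning_rate_init, and learning_rate_target describe a warm-up-then-decay schedule. A minimal sketch, assuming a piecewise-linear ramp and decay (the actual schedule used by trlX may differ):

```python
def learning_rate(step: int, lr_ramp_steps: int, lr_decay_steps: int,
                  lr_init: float, lr_target: float) -> float:
    """Piecewise-linear LR: ramp from 0 up to lr_init over lr_ramp_steps,
    then decay from lr_init down to lr_target over lr_decay_steps."""
    if step < lr_ramp_steps:
        # Warm-up phase: linear ramp toward lr_init
        return lr_init * step / lr_ramp_steps
    # Decay phase: linear interpolation toward lr_target, then hold
    decay_step = min(step - lr_ramp_steps, lr_decay_steps)
    frac = decay_step / lr_decay_steps
    return lr_init + (lr_target - lr_init) * frac
```

For example, with lr_ramp_steps=100 and lr_decay_steps=1000, the rate reaches learning_rate_init at step 100 and settles at learning_rate_target from step 1100 onward.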
- class trlx.data.method_configs.MethodConfig(name: str)[source]
Config for a certain RL method.
- Parameters
name (str) – Name of the method
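Method configs are looked up by name, which suggests a simple registry pattern. The sketch below is illustrative only (the decorator name and lookup helper are hypothetical, not trlX's actual registration mechanism):

```python
from dataclasses import dataclass

# Hypothetical registry mapping lowercased class names to config classes
_METHODS = {}

def register_method(cls):
    """Register a MethodConfig subclass under its lowercased class name."""
    _METHODS[cls.__name__.lower()] = cls
    return cls

@dataclass
class MethodConfig:
    name: str

@register_method
@dataclass
class PPOConfig(MethodConfig):
    ppo_epochs: int = 4  # illustrative default

def get_method(name: str):
    """Look up a registered method config class by (case-insensitive) name."""
    return _METHODS[name.lower()]
```

With this pattern, the `name` field in a config file (e.g. `name: "ppoconfig"`) selects which MethodConfig subclass to instantiate.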
PPO
- class trlx.data.method_configs.PPOConfig(name: str, ppo_epochs: int, num_rollouts: int, chunk_size: int, init_kl_coef: float, target: float, horizon: int, gamma: float, lam: float, cliprange: float, cliprange_value: float, vf_coef: float, gen_kwargs: dict)[source]
Config for the PPO method.
- Parameters
ppo_epochs (int) – Number of optimization epochs per batch of rollouts
num_rollouts (int) – Number of rollouts to collect before each round of PPO training
chunk_size (int) – Number of samples to generate per chunk when collecting rollouts
init_kl_coef (float) – Initial coefficient for the KL penalty
target (float) – Target KL divergence for the adaptive KL controller
horizon (int) – Number of steps over which the KL coefficient is adapted
gamma (float) – Discount factor
lam (float) – Lambda for generalized advantage estimation (GAE)
cliprange (float) – Clipping range for the policy (probability-ratio) loss
cliprange_value (float) – Clipping range for the value loss
vf_coef (float) – Coefficient of the value loss term
gen_kwargs (dict) – Keyword arguments passed to the model's generate method
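The trio init_kl_coef, target, and horizon matches the adaptive KL controller from Ziegler et al. (2019), which scales the KL penalty up or down depending on the observed divergence. A minimal sketch under that assumption (trlX's exact update rule may differ):

```python
class AdaptiveKLController:
    """Adapt the KL penalty coefficient toward a target KL divergence."""

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Proportional error, clipped to [-0.2, 0.2] to keep updates stable
        error = max(min(current_kl / self.target - 1.0, 0.2), -0.2)
        # Grow the coefficient when KL exceeds the target, shrink it otherwise
        self.value *= 1.0 + error * n_steps / self.horizon
```

With horizon large relative to the number of steps per update, the coefficient changes slowly, so the penalty tracks the target KL rather than oscillating.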
ILQL
- class trlx.data.method_configs.ILQLConfig(name: str, tau: float, gamma: float, cql_scale: float, awac_scale: float, alpha: float, steps_for_target_q_sync: int, betas: List[float], two_qs: bool)[source]
Config for the ILQL method.
- Parameters
tau (float) – Expectile used in the value (expectile regression) loss
gamma (float) – Discount factor
cql_scale (float) – Scale of the CQL regularization loss
awac_scale (float) – Scale of the AWAC (advantage-weighted) loss
alpha (float) – Polyak coefficient for updating the target Q-heads
steps_for_target_q_sync (int) – Sync the target Q-heads every steps_for_target_q_sync steps
betas (List[float]) – Inverse-temperature values to sweep over during generation
two_qs (bool) – Whether to use two Q-heads (clipped double Q-learning)
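The pair alpha and steps_for_target_q_sync suggests a periodic Polyak-style update of the target Q-network. The following is an illustrative sketch over plain parameter lists (an assumption about the mechanism, not trlX's actual code):

```python
def sync_target_q(q_params, target_params, alpha):
    """Polyak update: target <- alpha * q + (1 - alpha) * target."""
    return [alpha * q + (1.0 - alpha) * t for q, t in zip(q_params, target_params)]

def maybe_sync(step, steps_for_target_q_sync, q_params, target_params, alpha):
    """Apply the Polyak update only every steps_for_target_q_sync steps."""
    if step % steps_for_target_q_sync == 0:
        return sync_target_q(q_params, target_params, alpha)
    return target_params
```

With alpha = 1.0 this is a hard copy of the Q-network into the target; smaller values of alpha move the target only part of the way, smoothing the learning signal.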