Configs

Training a model with trlX requires setting three configs: ModelConfig, which contains general information about the model being trained; TrainConfig, which contains settings such as the training hyperparameters; and MethodConfig, which contains the hyperparameters and settings for the specific method being used (e.g. ILQL or PPO).

General

class trlx.data.configs.TRLConfig(model: trlx.data.configs.ModelConfig, train: trlx.data.configs.TrainConfig, method: trlx.data.method_configs.MethodConfig)

Top-level config for trlX. Loads configs and can be converted to a dictionary.

classmethod load_yaml(yml_fp: str)

Load a YAML file as a TRLConfig.

Parameters

yml_fp (str) – Path to the YAML file

to_dict()

Convert the TRLConfig to a dictionary.
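TRLConfig nests the three configs under the fields model, train and method. The sketch below mirrors that nesting with small hypothetical dataclasses (SketchModelConfig, SketchTrainConfig, SketchMethodConfig and SketchTRLConfig are not part of trlX, and they carry only a subset of the documented fields) to show how a parsed YAML mapping could be turned into nested configs and back into a dictionary. It assumes the YAML file has one top-level section per field; the real load_yaml may differ in its details.

```python
from dataclasses import dataclass, asdict
from typing import Any, Dict


# Hypothetical stand-ins mirroring a SUBSET of the documented fields.
# They only illustrate the model/train/method nesting; they are not trlX classes.
@dataclass
class SketchModelConfig:
    model_path: str
    tokenizer_path: str
    model_type: str
    num_layers_unfrozen: int = -1


@dataclass
class SketchTrainConfig:
    total_steps: int
    batch_size: int
    seed: int = 1000


@dataclass
class SketchMethodConfig:
    name: str


@dataclass
class SketchTRLConfig:
    model: SketchModelConfig
    train: SketchTrainConfig
    method: SketchMethodConfig

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "SketchTRLConfig":
        # Assumption: one top-level section per field (model / train / method),
        # as a parsed YAML file might provide.
        return cls(
            model=SketchModelConfig(**d["model"]),
            train=SketchTrainConfig(**d["train"]),
            method=SketchMethodConfig(**d["method"]),
        )

    def to_dict(self) -> Dict[str, Any]:
        # asdict converts nested dataclasses recursively into plain dicts.
        return asdict(self)


# Placeholder values; "my_model_type" and "my_method" are hypothetical names.
raw = {
    "model": {"model_path": "gpt2", "tokenizer_path": "gpt2", "model_type": "my_model_type"},
    "train": {"total_steps": 1000, "batch_size": 8},
    "method": {"name": "my_method"},
}
config = SketchTRLConfig.from_dict(raw)
as_dict = config.to_dict()
```

Fields left out of a section fall back to their defaults (here num_layers_unfrozen=-1 and seed=1000), and to_dict returns the fully populated nested dictionary.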

class trlx.data.configs.ModelConfig(model_path: str, tokenizer_path: str, model_type: str, num_layers_unfrozen: int = -1)

Config for a model.

Parameters
  • model_path (str) – Path to the model (local or on the Hugging Face Hub)

  • tokenizer_path (str) – Path to the tokenizer (local or on the Hugging Face Hub)

  • model_type (str) – One of the registered RL models present in trlx.model

class trlx.data.configs.TrainConfig(total_steps: int, seq_length: int, epochs: int, batch_size: int, lr_ramp_steps: int, lr_decay_steps: int, weight_decay: float, learning_rate_init: float, learning_rate_target: float, opt_betas: Tuple[float], checkpoint_interval: int, eval_interval: int, pipeline: str, orchestrator: str, checkpoint_dir: str = 'ckpts', project_name: str = 'trlx', seed: int = 1000)

Config for the training job.

Parameters
  • total_steps (int) – Total number of training steps

  • seq_length (int) – Number of tokens to use as context (max length for the tokenizer)

  • epochs (int) – Total number of passes through the data

  • batch_size (int) – Batch size for training

  • lr_ramp_steps (int) – Number of steps before the learning rate reaches learning_rate_init

  • lr_decay_steps (int) – Number of steps after ramp-up before the learning rate decays to learning_rate_target

  • weight_decay (float) – Weight decay for the optimizer

  • learning_rate_init (float) – Initial learning rate after ramp-up

  • learning_rate_target (float) – Target learning rate after decay

  • checkpoint_interval (int) – Save the model every checkpoint_interval steps

  • eval_interval (int) – Evaluate model every eval_interval steps

  • pipeline (str) – Pipeline to use for training. One of the registered pipelines present in trlx.pipeline

  • orchestrator (str) – Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator

  • checkpoint_dir (str) – Directory to save checkpoints

  • project_name (str) – Project name for wandb
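As a rough illustration of how lr_ramp_steps, lr_decay_steps, learning_rate_init and learning_rate_target interact, the sketch below computes the learning rate at a given step. The linear warm-up from 0 and the cosine decay are assumptions chosen for illustration; the parameter descriptions above do not specify the curve shapes trlX actually uses, and learning_rate_at is a hypothetical helper.

```python
import math


def learning_rate_at(
    step: int,
    lr_ramp_steps: int,
    lr_decay_steps: int,
    learning_rate_init: float,
    learning_rate_target: float,
) -> float:
    """Learning rate at a given optimizer step (illustrative sketch only)."""
    if lr_ramp_steps > 0 and step < lr_ramp_steps:
        # Assumed shape: linear warm-up from 0 to learning_rate_init.
        return learning_rate_init * step / lr_ramp_steps
    decay_step = step - lr_ramp_steps
    if lr_decay_steps <= 0 or decay_step >= lr_decay_steps:
        # After the decay phase, hold at the target rate.
        return learning_rate_target
    # Assumed shape: cosine decay from learning_rate_init to learning_rate_target.
    progress = decay_step / lr_decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return learning_rate_target + (learning_rate_init - learning_rate_target) * cosine
```

With lr_ramp_steps=100 and lr_decay_steps=1000, the rate peaks at learning_rate_init at step 100, passes the midpoint between the two rates at step 600, and stays at learning_rate_target from step 1100 onward.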

class trlx.data.method_configs.MethodConfig(name: str)

Config for a specific RL method.

Parameters

name (str) – Name of the method

PPO

class trlx.data.method_configs.PPOConfig(name: str, ppo_epochs: int, num_rollouts: int, chunk_size: int, init_kl_coef: float, target: float, horizon: int, gamma: float, lam: float, cliprange: float, cliprange_value: float, vf_coef: float, gen_kwargs: dict)

Config for the PPO method.

Parameters
  • ppo_epochs (int) – Number of updates per batch

  • num_rollouts (int) – Number of experiences to observe before learning

  • init_kl_coef (float) – Initial value for KL coefficient

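  • target (float) – Target KL divergence; the KL coefficient is adjusted adaptively to keep the observed KL near this value

  • horizon (int) – Horizon, in steps, that controls how quickly the KL coefficient is adjusted toward the target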

  • gamma (float) – Discount factor

  • lam (float) – GAE lambda

  • cliprange (float) – Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange)

  • cliprange_value (float) – Clipping range for predicted values (observed values - cliprange_value, observed values + cliprange_value)

  • vf_coef (float) – Scale of the value loss relative to the policy loss

  • gen_kwargs (Dict[str, Any]) – Additional kwargs for generation
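The sketch below shows, in plain Python, where gamma, lam, cliprange, cliprange_value, vf_coef, init_kl_coef, target and horizon enter a typical PPO update. It is not trlX's implementation: the function and class names are hypothetical, values are per-token scalars rather than tensors, value clipping is assumed to be around the value predicted at rollout time (as in common PPO implementations), and the KL controller assumes the proportional scheme of Ziegler et al. (2019), in which target is the desired KL divergence and horizon sets how quickly the coefficient adapts.

```python
import math
from typing import List, Tuple


def gae_advantages(
    rewards: List[float], values: List[float], gamma: float, lam: float
) -> Tuple[List[float], List[float]]:
    """GAE over one finished episode; the value after the last step is taken as 0."""
    advantages = [0.0] * len(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        # One-step TD error, discounted by gamma.
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of TD errors (weight gamma * lam per step).
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    returns = [adv + v for adv, v in zip(advantages, values)]
    return advantages, returns


def ppo_token_loss(
    logprob: float, old_logprob: float, advantage: float,
    value: float, old_value: float, ret: float,
    cliprange: float, cliprange_value: float, vf_coef: float,
) -> float:
    """Clipped PPO policy loss plus vf_coef times the clipped value loss."""
    ratio = math.exp(logprob - old_logprob)
    clipped_ratio = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
    # Take the pessimistic (larger) of the unclipped and clipped policy losses.
    pg_loss = max(-advantage * ratio, -advantage * clipped_ratio)
    # Keep the new value prediction within cliprange_value of the rollout-time one.
    value_clipped = min(max(value, old_value - cliprange_value), old_value + cliprange_value)
    vf_loss = 0.5 * max((value - ret) ** 2, (value_clipped - ret) ** 2)
    return pg_loss + vf_coef * vf_loss


class AdaptiveKLSketch:
    """Proportional KL-coefficient controller (assumed scheme, see lead-in)."""

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Relative KL error, clipped to [-0.2, 0.2].
        error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        # Raise the coefficient when KL is above target, lower it when below.
        self.value *= 1.0 + error * n_steps / self.horizon
```

A larger horizon makes each update smaller, so the KL coefficient drifts toward the level that keeps the observed KL near target more slowly.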

ILQL

class trlx.data.method_configs.ILQLConfig(name: str, tau: float, gamma: float, cql_scale: float, awac_scale: float, alpha: float, steps_for_target_q_sync: int, betas: List[float], two_qs: bool)

Config for the ILQL method.

Parameters
  • tau (float) – Controls the tradeoff in the value loss between penalizing the value network for underestimating the target Q, i.e. the Q-value of the action taken (weighted more heavily at high tau), and for overestimating it (weighted more heavily at low tau)

  • gamma (float) – Discount factor for future rewards

  • cql_scale (float) – Weight for CQL loss term

  • awac_scale (float) – Weight for AWAC loss term

  • steps_for_target_q_sync (int) – Number of steps to wait before syncing target Q network with Q network

  • two_qs (bool) – Use minimum of two Q-value estimates
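The tau tradeoff described above corresponds to an expectile (asymmetric squared) loss. The sketch below is a hypothetical plain-Python illustration, not trlX's code: errors where the value estimate falls below the target Q are weighted by tau and errors where it lies above are weighted by 1 - tau, and with two_qs the target is assumed to be the minimum of two Q estimates.

```python
from typing import List


def expectile_value_loss(target_q: float, value: float, tau: float) -> float:
    """Asymmetric squared (expectile) loss on u = target_q - value."""
    u = target_q - value
    # u > 0: V underestimates the target Q, weighted by tau.
    # u < 0: V overestimates the target Q, weighted by 1 - tau.
    weight = tau if u > 0 else 1.0 - tau
    return weight * u * u


def target_q_estimate(q_values: List[float], two_qs: bool) -> float:
    """With two_qs, take the minimum of the first two Q estimates (pessimistic)."""
    return min(q_values[:2]) if two_qs else q_values[0]


def total_loss(value: float, tau: float) -> float:
    # Two equally likely target Q-values, 0.0 and 1.0.
    return sum(expectile_value_loss(q, value, tau) for q in (0.0, 1.0))
```

For two equally likely targets 0 and 1, total_loss is minimized at value = tau, so a higher tau pulls the value estimate toward the larger Q targets.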